Module fl_server_api.views¶
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from .group import Group
from .inference import Inference
from .model import Model
from .training import Training
from .user import User
__all__ = ["Group", "Inference", "Model", "Training", "User"]
Sub-modules¶
- fl_server_api.views.base
- fl_server_api.views.dummy
- fl_server_api.views.group
- fl_server_api.views.inference
- fl_server_api.views.model
- fl_server_api.views.training
- fl_server_api.views.user
Classes¶
Group¶
Group Model ViewSet.
View Source
class Group(ViewSet):
    """
    ViewSet providing list, retrieve, create, update and delete endpoints for groups.
    """

    serializer_class = GroupSerializer
    """The serializer for the ViewSet."""

    def _get_group(self, user: UserModel, group_id: int) -> GroupModel:
        """
        Look up a group and make sure the requesting user belongs to it.

        Args:
            user (UserModel): The user making the request.
            group_id (int): The id of the requested group.

        Raises:
            PermissionDenied: If the user is not a member of the group.

        Returns:
            GroupModel: The requested group instance.
        """
        requested_group = get_entity(GroupModel, pk=group_id)
        if user.groups.contains(requested_group):
            return requested_group
        raise PermissionDenied("You are not allowed to access this group.")

    @extend_schema(
        responses={
            status.HTTP_200_OK: GroupSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def list(self, request: HttpRequest) -> HttpResponse:
        """
        Return every group; restricted to superusers.

        Args:
            request (HttpRequest): request object

        Raises:
            PermissionDenied: If the requesting user is not a superuser.

        Returns:
            HttpResponse: all groups, serialized as a json list
        """
        if request.user.is_superuser:
            all_groups = GroupModel.objects.all()
            return Response(GroupSerializer(all_groups, many=True).data)
        raise PermissionDenied("You are not allowed to access all groups.")

    @extend_schema(
        responses={
            status.HTTP_200_OK: GroupSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[_default_group_example]
    )
    def retrieve(self, request: HttpRequest, id: int) -> HttpResponse:
        """
        Return a single group the requesting user is a member of.

        Args:
            request (HttpRequest): request object
            id (int): group id

        Raises:
            PermissionDenied: If the requesting user is not a member of the group.

        Returns:
            HttpResponse: the group as json response
        """
        return Response(GroupSerializer(self._get_group(request.user, id)).data)

    @extend_schema(
        responses={
            status.HTTP_201_CREATED: GroupSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[OpenApiExample(
            name="Create group",
            description="Create a new group.",
            value={"name": "My new amazing group"},
        )]
    )
    def create(self, request: HttpRequest) -> HttpResponse:
        """
        Validate the request payload and store it as a new group.

        Args:
            request (HttpRequest): request object

        Returns:
            HttpResponse: the newly created group as json response (201 CREATED)
        """
        new_group = GroupSerializer(data=request.data)
        # invalid payloads raise here instead of returning a falsy flag
        new_group.is_valid(raise_exception=True)
        new_group.save()
        return Response(new_group.data, status=status.HTTP_201_CREATED)

    def _update(self, request: HttpRequest, id: int, *, partial: bool) -> HttpResponse:
        """
        Shared implementation of the full and the partial group update.

        Args:
            request (HttpRequest): request object
            id (int): group id
            partial (bool): allow partial update

        Raises:
            PermissionDenied: If the requesting user is not a member of the group.

        Returns:
            HttpResponse: the updated group as json response
        """
        target = self._get_group(request.user, id)
        group_serializer = GroupSerializer(target, data=request.data, partial=partial)
        group_serializer.is_valid(raise_exception=True)
        group_serializer.save()
        if getattr(target, '_prefetched_objects_cache', None):
            # If 'prefetch_related' has been applied to a queryset, the prefetch
            # cache on the instance is stale after saving and must be dropped.
            target._prefetched_objects_cache = {}
        return Response(group_serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: GroupSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            _default_group_example,
            OpenApiExample(
                name="Update group",
                description="Update group fields.",
                value={"name": "My new amazing group is the best!"},
            )
        ]
    )
    def update(self, request: HttpRequest, id: int) -> HttpResponse:
        """
        Replace the group's fields (non-partial update).

        Args:
            request (HttpRequest): request object
            id (int): group id

        Returns:
            HttpResponse: the updated group as json response
        """
        return self._update(request, id, partial=False)

    @extend_schema(
        responses={
            status.HTTP_200_OK: GroupSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            _default_group_example,
            OpenApiExample(
                name="Update group partially",
                description="Update only some group fields.",
                value={"name": "My new amazing group is the best!"},
            )
        ]
    )
    def partial_update(self, request: HttpRequest, id: int) -> HttpResponse:
        """
        Update only the group fields contained in the request.

        Args:
            request (HttpRequest): request object
            id (int): group id

        Returns:
            HttpResponse: the updated group as json response
        """
        return self._update(request, id, partial=True)

    @extend_schema(
        responses={
            status.HTTP_204_NO_CONTENT: None,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[_default_group_example]
    )
    def destroy(self, request: HttpRequest, id: int) -> HttpResponse:
        """
        Delete a group the requesting user is a member of.

        Args:
            request (HttpRequest): request object
            id (int): group id

        Raises:
            PermissionDenied: If the requesting user is not a member of the group.

        Returns:
            HttpResponse: 204 NO CONTENT
        """
        self._get_group(request.user, id).delete()
        return Response(status=status.HTTP_204_NO_CONTENT)
Ancestors (in MRO)¶
- fl_server_api.views.base.ViewSet
- rest_framework.viewsets.ViewSet
- rest_framework.viewsets.ViewSetMixin
- rest_framework.views.APIView
- django.views.generic.base.View
Class variables¶
The serializer for the ViewSet.
Static methods¶
as_view¶
Because of the way class based views create a closure around the
instantiated view, we need to totally reimplement `.as_view`,
and slightly modify the view function that is created and returned.
View Source
@classonlymethod
def as_view(cls, actions=None, **initkwargs):
    """
    Because of the way class based views create a closure around the
    instantiated view, we need to totally reimplement `.as_view`,
    and slightly modify the view function that is created and returned.

    Args:
        actions (dict): Maps HTTP method names to the names of the handler
            methods on the ViewSet, e.g. `{'get': 'list'}`. Must not be empty.
        **initkwargs: Keyword arguments passed to the ViewSet constructor for
            every request. Each key must be an existing class attribute and
            must not be an HTTP method name.

    Raises:
        TypeError: If `actions` is empty, if an initkwarg is an HTTP method name
            or not an attribute of the class, or if both `name` and `suffix`
            are passed.

    Returns:
        The `csrf_exempt`-wrapped view function.
    """
    # The name and description initkwargs may be explicitly overridden for
    # certain route configurations. eg, names of extra actions.
    cls.name = None
    cls.description = None
    # The suffix initkwarg is reserved for displaying the viewset type.
    # This initkwarg should have no effect if the name is provided.
    # eg. 'List' or 'Instance'.
    cls.suffix = None
    # The detail initkwarg is reserved for introspecting the viewset type.
    cls.detail = None
    # Setting a basename allows a view to reverse its action urls. This
    # value is provided by the router through the initkwargs.
    cls.basename = None
    # actions must not be empty
    if not actions:
        raise TypeError("The `actions` argument must be provided when "
                        "calling `.as_view()` on a ViewSet. For example "
                        "`.as_view({'get': 'list'})`")
    # sanitize keyword arguments
    for key in initkwargs:
        if key in cls.http_method_names:
            raise TypeError("You tried to pass in the %s method name as a "
                            "keyword argument to %s(). Don't do that."
                            % (key, cls.__name__))
        if not hasattr(cls, key):
            raise TypeError("%s() received an invalid keyword %r" % (
                cls.__name__, key))
    # name and suffix are mutually exclusive
    if 'name' in initkwargs and 'suffix' in initkwargs:
        raise TypeError("%s() received both `name` and `suffix`, which are "
                        "mutually exclusive arguments." % (cls.__name__))

    def view(request, *args, **kwargs):
        # A fresh ViewSet instance is created for every incoming request.
        self = cls(**initkwargs)
        # HEAD falls back to the GET handler when it is not mapped explicitly.
        # Note that this mutates the shared `actions` mapping.
        if 'get' in actions and 'head' not in actions:
            actions['head'] = actions['get']
        # We also store the mapping of request methods to actions,
        # so that we can later set the action attribute.
        # eg. `self.action = 'list'` on an incoming GET request.
        self.action_map = actions
        # Bind methods to actions
        # This is the bit that's different to a standard view
        for method, action in actions.items():
            handler = getattr(self, action)
            setattr(self, method, handler)
        self.request = request
        self.args = args
        self.kwargs = kwargs
        # And continue as usual
        return self.dispatch(request, *args, **kwargs)

    # take name and docstring from class
    update_wrapper(view, cls, updated=())
    # and possible attributes set by decorators
    # like csrf_exempt from dispatch
    update_wrapper(view, cls.dispatch, assigned=())
    # We need to set these on the view function, so that breadcrumb
    # generation can pick out these bits of information from a
    # resolved URL.
    view.cls = cls
    view.initkwargs = initkwargs
    view.actions = actions
    return csrf_exempt(view)
get_extra_actions¶
Get the methods that are marked as an extra ViewSet `@action`.
View Source
Instance variables¶
Wrap Django's private `_allowed_methods` interface in a public property.
Methods¶
check_object_permissions¶
Check if the request should be permitted for a given object.
Raises an appropriate exception if the request is not permitted.
View Source
def check_object_permissions(self, request, obj):
    """
    Verify that every permission of the view grants access to `obj`.

    For each permission that refuses the object, `permission_denied` is called
    with the permission's optional `message` and `code` attributes.
    """
    # the generator is lazy, so checks and denials interleave per permission
    refusing = (
        perm for perm in self.get_permissions()
        if not perm.has_object_permission(request, self, obj)
    )
    for perm in refusing:
        self.permission_denied(
            request,
            message=getattr(perm, 'message', None),
            code=getattr(perm, 'code', None)
        )
check_permissions¶
Check if the request should be permitted.
Raises an appropriate exception if the request is not permitted.
View Source
def check_permissions(self, request):
    """
    Verify that every permission of the view admits the request.

    For each permission that refuses the request, `permission_denied` is called
    with the permission's optional `message` and `code` attributes.
    """
    for perm in self.get_permissions():
        if perm.has_permission(request, self):
            continue
        denial_message = getattr(perm, 'message', None)
        denial_code = getattr(perm, 'code', None)
        self.permission_denied(request, message=denial_message, code=denial_code)
check_throttles¶
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
View Source
def check_throttles(self, request):
    """
    Ask every throttle whether the request is allowed.

    If at least one throttle refuses, `throttled` is called with the longest
    known wait duration (or `None` when no throttle reports one).
    """
    waits = [
        throttle.wait()
        for throttle in self.get_throttles()
        if not throttle.allow_request(request, self)
    ]
    if not waits:
        return
    # `None` waits may happen in case of config / rate changes, see #1438
    known_waits = [wait for wait in waits if wait is not None]
    self.throttled(request, max(known_waits, default=None))
create¶
Create a new group.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
Returns:
Type | Description |
---|---|
HttpResponse | newly created group as json response |
View Source
@extend_schema(
    responses={
        status.HTTP_201_CREATED: GroupSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[OpenApiExample(
        name="Create group",
        description="Create a new group.",
        value={"name": "My new amazing group"},
    )]
)
def create(self, request: HttpRequest) -> HttpResponse:
    """
    Validate the request payload and store it as a new group.

    Args:
        request (HttpRequest): request object

    Returns:
        HttpResponse: the newly created group as json response (201 CREATED)
    """
    new_group = GroupSerializer(data=request.data)
    # invalid payloads raise here instead of returning a falsy flag
    new_group.is_valid(raise_exception=True)
    new_group.save()
    return Response(new_group.data, status=status.HTTP_201_CREATED)
destroy¶
def destroy(
self,
request: django.http.request.HttpRequest,
id: int
) -> django.http.response.HttpResponse
Remove group by id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | int | group id | None |
Returns:
Type | Description |
---|---|
HttpResponse | 204 NO CONTENT |
View Source
@extend_schema(
    responses={
        status.HTTP_204_NO_CONTENT: None,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[_default_group_example]
)
def destroy(self, request: HttpRequest, id: int) -> HttpResponse:
    """
    Delete the group with the given id.

    The group is resolved through `_get_group` with the requesting user before
    it is deleted.

    Args:
        request (HttpRequest): request object
        id (int): group id

    Returns:
        HttpResponse: 204 NO CONTENT
    """
    self._get_group(request.user, id).delete()
    return Response(status=status.HTTP_204_NO_CONTENT)
determine_version¶
If versioning is being used, then determine any API version for the
incoming request. Returns a two-tuple of (version, versioning_scheme)
View Source
def determine_version(self, request, *args, **kwargs):
    """
    Resolve the API version of the incoming request.

    Returns a two-tuple of (version, versioning_scheme); both entries are
    `None` when the view has no versioning class configured.
    """
    if self.versioning_class is not None:
        scheme = self.versioning_class()
        version = scheme.determine_version(request, *args, **kwargs)
        return (version, scheme)
    return (None, None)
dispatch¶
`.dispatch()` is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling.
View Source
def dispatch(self, request, *args, **kwargs):
    """
    Route the request to the handler matching its HTTP method.

    Works like Django's regular dispatch, but wraps the handler call with the
    `initial`, `handle_exception` and `finalize_response` hooks.
    """
    self.args = args
    self.kwargs = kwargs
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?

    try:
        self.initial(request, *args, **kwargs)

        # Pick the handler; unknown or unbound methods end up in the fallback
        fallback = self.http_method_not_allowed
        http_verb = request.method.lower()
        if http_verb in self.http_method_names:
            handler = getattr(self, http_verb, fallback)
        else:
            handler = fallback

        outcome = handler(request, *args, **kwargs)
    except Exception as exc:
        outcome = self.handle_exception(exc)

    self.response = self.finalize_response(request, outcome, *args, **kwargs)
    return self.response
finalize_response¶
Returns the final response object.
View Source
def finalize_response(self, request, response, *args, **kwargs):
    """
    Prepare the handler's response for rendering and return it.

    `Response` objects receive the negotiated renderer, media type and renderer
    context; the view's headers are then copied onto the response.
    """
    # Make the error obvious if a proper response is not returned
    assert isinstance(response, HttpResponseBase), (
        'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
        'to be returned from the view, but received a `%s`'
        % type(response)
    )

    if isinstance(response, Response):
        if not getattr(request, 'accepted_renderer', None):
            # negotiation has not happened yet; force a renderer selection
            renderer, media_type = self.perform_content_negotiation(request, force=True)
            request.accepted_renderer = renderer
            request.accepted_media_type = media_type

        response.accepted_renderer = request.accepted_renderer
        response.accepted_media_type = request.accepted_media_type
        response.renderer_context = self.get_renderer_context()

    # Add new vary headers to the response instead of overwriting.
    vary = self.headers.pop('Vary', None)
    if vary is not None:
        patch_vary_headers(response, cc_delim_re.split(vary))

    for header_name, header_value in self.headers.items():
        response[header_name] = header_value

    return response
get_authenticate_header¶
If a request is unauthenticated, determine the WWW-Authenticate
header to use for 401 responses, if any.
View Source
get_authenticators¶
Get the authenticators for the ViewSet.
This method gets the view method and, if it has authentication classes defined via the decorator, returns them. Otherwise, it falls back to the default authenticators.
Returns:
Type | Description |
---|---|
list | The authenticators for the ViewSet. |
View Source
def get_authenticators(self):
    """
    Get the authenticators for the ViewSet.

    This method gets the view method and, if it has authentication classes defined via the decorator,
    returns them. Otherwise, it falls back to the default authenticators.

    Entries that are still classes are instantiated, so the result has the same shape as the parent
    implementation's (which, per DRF's `APIView.get_authenticators`, returns authenticator instances);
    entries that already are instances are passed through unchanged.

    Returns:
        list: The authenticators for the ViewSet.
    """
    method = self._get_view_method()
    if method and hasattr(method, "authentication_classes"):
        return [
            auth() if isinstance(auth, type) else auth
            for auth in method.authentication_classes
        ]
    return super().get_authenticators()
get_content_negotiator¶
Instantiate and return the content negotiation class to use.
View Source
get_exception_handler¶
Returns the exception handler that this view uses.
View Source
get_exception_handler_context¶
Returns a dict that is passed through to EXCEPTION_HANDLER,
as the `context` argument.
View Source
get_extra_action_url_map¶
Build a map of {names: urls} for the extra actions.
This method will noop if `detail` was not provided as a view initkwarg.
View Source
def get_extra_action_url_map(self):
    """
    Collect the URLs of the extra actions, keyed by their view names.

    Returns an empty mapping when `detail` was not provided as a view
    initkwarg. Actions whose URL cannot be reversed are left out.
    """
    url_map = OrderedDict()

    # without `detail` there is nothing to match the extra actions against
    if self.detail is None:
        return url_map

    # only extra actions of the same kind (list vs. detail) are relevant
    matching = [
        extra for extra in self.get_extra_actions()
        if extra.detail == self.detail
    ]

    for extra in matching:
        try:
            route_name = '%s-%s' % (self.basename, extra.url_name)
            namespace = self.request.resolver_match.namespace
            if namespace:
                route_name = '%s:%s' % (namespace, route_name)

            url = reverse(route_name, self.args, self.kwargs, request=self.request)
            action_view = self.__class__(**extra.kwargs)
            url_map[action_view.get_view_name()] = url
        except NoReverseMatch:
            pass  # URL requires additional arguments, ignore

    return url_map
get_format_suffix¶
Determine if the request includes a '.json' style format suffix
View Source
get_parser_context¶
Returns a dict that is passed through to Parser.parse(),
as the `parser_context` keyword argument.
View Source
def get_parser_context(self, http_request):
    """
    Build the `parser_context` dict handed to Parser.parse().

    Contains the view itself plus its positional and keyword URL arguments
    (empty defaults when they have not been set yet).
    """
    # Note: Additionally `request` and `encoding` will also be added
    # to the context by the Request object.
    url_args = getattr(self, 'args', ())
    url_kwargs = getattr(self, 'kwargs', {})
    return dict(view=self, args=url_args, kwargs=url_kwargs)
get_parsers¶
Instantiates and returns the list of parsers that this view can use.
View Source
get_permissions¶
Get the permissions for the ViewSet.
This method gets the view method and, if it has permission classes defined via the decorator, returns them. Otherwise, it falls back to the default permissions.
Returns:
Type | Description |
---|---|
list | The permissions for the ViewSet. |
View Source
def get_permissions(self):
    """
    Get the permissions for the ViewSet.

    This method gets the view method and, if it has permission classes defined via the decorator,
    returns them. Otherwise, it falls back to the default permissions.

    Entries that are still classes are instantiated, because the permission checks call
    `has_permission(request, view)` on every returned item, which requires an instance.
    Entries that already are instances are passed through unchanged.

    Returns:
        list: The permissions for the ViewSet.
    """
    method = self._get_view_method()
    if method and hasattr(method, "permission_classes"):
        return [
            permission() if isinstance(permission, type) else permission
            for permission in method.permission_classes
        ]
    return super().get_permissions()
get_renderer_context¶
Returns a dict that is passed through to Renderer.render(),
as the `renderer_context` keyword argument.
View Source
def get_renderer_context(self):
    """
    Build the `renderer_context` dict handed to Renderer.render().

    Contains the view itself, its positional and keyword URL arguments and the
    current request (defaults are used for attributes that are not set yet).
    """
    # Note: Additionally 'response' will also be added to the context,
    # by the Response object.
    context = {'view': self}
    context['args'] = getattr(self, 'args', ())
    context['kwargs'] = getattr(self, 'kwargs', {})
    context['request'] = getattr(self, 'request', None)
    return context
get_renderers¶
Instantiates and returns the list of renderers that this view can use.
View Source
get_throttles¶
Instantiates and returns the list of throttles that this view uses.
View Source
get_view_description¶
Return some descriptive text for the view, as used in OPTIONS responses
and in the browsable API.
View Source
get_view_name¶
Return the view name, as used in OPTIONS responses and in the
browsable API.
View Source
handle_exception¶
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
View Source
def handle_exception(self, exc):
    """
    Turn an exception raised during request handling into a response.

    The configured exception handler builds the response; when it returns
    `None`, the exception is passed to `raise_uncaught_exception`.
    """
    auth_errors = (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)
    if isinstance(exc, auth_errors):
        # WWW-Authenticate header for 401 responses, else coerce to 403
        auth_header = self.get_authenticate_header(self.request)
        if not auth_header:
            exc.status_code = status.HTTP_403_FORBIDDEN
        else:
            exc.auth_header = auth_header

    handler = self.get_exception_handler()
    response = handler(exc, self.get_exception_handler_context())

    if response is None:
        self.raise_uncaught_exception(exc)

    # flag the response as the result of an exception
    response.exception = True
    return response
http_method_not_allowed¶
If `request.method` does not correspond to a handler method,
determine what kind of exception to raise.
View Source
initial¶
Runs anything that needs to occur prior to calling the method handler.
View Source
def initial(self, request, *args, **kwargs):
    """
    Run everything that has to happen before the method handler is called:
    content negotiation, version detection, authentication, permission and
    throttle checks.
    """
    self.format_kwarg = self.get_format_suffix(**kwargs)

    # Negotiate the response format and remember the result on the request
    renderer, media_type = self.perform_content_negotiation(request)
    request.accepted_renderer = renderer
    request.accepted_media_type = media_type

    # Determine the API version, if versioning is in use.
    request.version, request.versioning_scheme = self.determine_version(request, *args, **kwargs)

    # Ensure that the incoming request is permitted
    for gate in (self.perform_authentication, self.check_permissions, self.check_throttles):
        gate(request)
initialize_request¶
Set the `.action` attribute on the view, depending on the request method.
View Source
def initialize_request(self, request, *args, **kwargs):
    """
    Initialize the request via the parent class and derive the view's
    `.action` attribute from the request method.
    """
    request = super().initialize_request(request, *args, **kwargs)
    http_verb = request.method.lower()
    # 'options' is a special case as handling for it is always provided by the
    # base `View` class. Unlike the explicitly defined actions, 'metadata' is
    # implicit; every other method is looked up in the action map.
    self.action = 'metadata' if http_verb == 'options' else self.action_map.get(http_verb)
    return request
list¶
Get all groups.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
Returns:
Type | Description |
---|---|
HttpResponse | list of groups as json response |
Raises:
Type | Description |
---|---|
PermissionDenied | If user is not a superuser. |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: GroupSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def list(self, request: HttpRequest) -> HttpResponse:
    """
    Return every group; restricted to superusers.

    Args:
        request (HttpRequest): request object

    Raises:
        PermissionDenied: If the requesting user is not a superuser.

    Returns:
        HttpResponse: all groups, serialized as a json list
    """
    if request.user.is_superuser:
        all_groups = GroupModel.objects.all()
        return Response(GroupSerializer(all_groups, many=True).data)
    raise PermissionDenied("You are not allowed to access all groups.")
options¶
Handler method for HTTP 'OPTIONS' request.
View Source
def options(self, request, *args, **kwargs):
    """
    Answer an HTTP 'OPTIONS' request with the view's metadata.

    Without a configured metadata class the request is treated as a
    not-allowed method.
    """
    if self.metadata_class is not None:
        metadata = self.metadata_class().determine_metadata(request, self)
        return Response(metadata, status=status.HTTP_200_OK)
    return self.http_method_not_allowed(request, *args, **kwargs)
partial_update¶
def partial_update(
self,
request: django.http.request.HttpRequest,
id: int
) -> django.http.response.HttpResponse
Update group information partially.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | int | group id | None |
Returns:
Type | Description |
---|---|
HttpResponse | updated group as json response |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: GroupSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        _default_group_example,
        OpenApiExample(
            name="Update group partially",
            description="Update only some group fields.",
            value={"name": "My new amazing group is the best!"},
        )
    ]
)
def partial_update(self, request: HttpRequest, id: int) -> HttpResponse:
    """
    Update only the group fields contained in the request.

    Delegates to `_update` with `partial=True`.

    Args:
        request (HttpRequest): request object
        id (int): group id

    Returns:
        HttpResponse: the updated group as json response
    """
    response = self._update(request, id, partial=True)
    return response
perform_authentication¶
Perform authentication on the incoming request.
Note that if you override this and simply 'pass', then authentication
will instead be performed lazily, the first time either
`request.user` or `request.auth` is accessed.
View Source
perform_content_negotiation¶
Determine which renderer and media type to use to render the response.
View Source
def perform_content_negotiation(self, request, force=False):
    """
    Select the renderer and media type used to render the response.

    When negotiation fails and `force` is set, the first available renderer and
    its media type are returned; otherwise the error is re-raised.
    """
    available = self.get_renderers()
    negotiator = self.get_content_negotiator()

    try:
        return negotiator.select_renderer(request, available, self.format_kwarg)
    except Exception:
        if not force:
            raise
        fallback = available[0]
        return (fallback, fallback.media_type)
permission_denied¶
If request is not permitted, determine what kind of exception to raise.
View Source
def permission_denied(self, request, message=None, code=None):
    """
    Raise the exception matching a refused request.

    `NotAuthenticated` is raised when authenticators are configured but none of
    them authenticated the request; otherwise `PermissionDenied` is raised.
    """
    lacks_authentication = request.authenticators and not request.successful_authenticator
    if lacks_authentication:
        raise exceptions.NotAuthenticated()
    raise exceptions.PermissionDenied(detail=message, code=code)
raise_uncaught_exception¶
View Source
retrieve¶
def retrieve(
self,
request: django.http.request.HttpRequest,
id: int
) -> django.http.response.HttpResponse
Get group information by id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | int | group id | None |
Returns:
Type | Description |
---|---|
HttpResponse | group as json response |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: GroupSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[_default_group_example]
)
def retrieve(self, request: HttpRequest, id: int) -> HttpResponse:
    """
    Return the group with the given id.

    The group is resolved through `_get_group` with the requesting user.

    Args:
        request (HttpRequest): request object
        id (int): group id

    Returns:
        HttpResponse: the group as json response
    """
    return Response(GroupSerializer(self._get_group(request.user, id)).data)
reverse_action¶
Reverse the action for the given `url_name`.
View Source
def reverse_action(self, url_name, *args, **kwargs):
    """
    Reverse the URL of the action with the given `url_name`.

    The name is prefixed with the view's basename and, if the current request
    was resolved inside a URL namespace, with that namespace.
    """
    qualified_name = '%s-%s' % (self.basename, url_name)

    has_match = self.request and self.request.resolver_match
    namespace = self.request.resolver_match.namespace if has_match else None
    if namespace:
        qualified_name = namespace + ':' + qualified_name

    # callers may pass their own request; otherwise the current one is used
    kwargs.setdefault('request', self.request)
    return reverse(qualified_name, *args, **kwargs)
setup¶
Initialize attributes shared by all view methods.
View Source
throttled¶
If request is throttled, determine what kind of exception to raise.
View Source
update¶
def update(
self,
request: django.http.request.HttpRequest,
id: int
) -> django.http.response.HttpResponse
Update group information.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | int | group id | None |
Returns:
Type | Description |
---|---|
HttpResponse | updated group as json response |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: GroupSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        _default_group_example,
        OpenApiExample(
            name="Update group",
            description="Update group fields.",
            value={"name": "My new amazing group is the best!"},
        )
    ]
)
def update(self, request: HttpRequest, id: int) -> HttpResponse:
    """
    Replace the group's fields (non-partial update).

    Delegates to `_update` with `partial=False`.

    Args:
        request (HttpRequest): request object
        id (int): group id

    Returns:
        HttpResponse: the updated group as json response
    """
    response = self._update(request, id, partial=False)
    return response
Inference¶
Inference ViewSet for performing inference on a model.
View Source
class Inference(ViewSet):
"""
Inference ViewSet for performing inference on a model.
"""
serializer_class = inline_serializer("InferenceSerializer", fields={
"inference": ListField(child=ListField(child=FloatField())),
"uncertainty": DictField(child=FloatField())
})
"""The serializer for the ViewSet."""
@extend_schema(
    request=inline_serializer(
        "InferenceJsonSerializer",
        fields={
            "model_id": CharField(),
            "model_input": ListField(child=ListField(child=FloatField())),
            "return_format": ChoiceField(["binary", "json"])
        }
    ),
    responses={
        status.HTTP_200_OK: serializer_class,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    },
    examples=[
        OpenApiExample("JSON Example", value={
            "model_id": "mymodel",
            "model_input": [
                [1.0, 2.3, -0.4, 3],
                [0.01, 9.7, 5.6, 7]
            ],
            "return_format": "json"
        }, request_only=True),
    ]
)
def inference(self, request: HttpRequest) -> HttpResponse:
    """
    Run inference with the requested model on the submitted input data.

    The request body is read according to its content type, the model and its
    metadata are looked up, the input is preprocessed (or, without a
    preprocessing step, cast to a tensor) and validated, and finally the
    inference result together with its uncertainty is returned in the
    requested `return_format`.

    Args:
        request (HttpRequest): The current HTTP request.

    Returns:
        HttpResponse: A HttpResponse containing the result of the inference as well as its uncertainty.
    """
    request_body, is_json = self._get_handle_content_type(request)
    # the response format defaults to the format of the request
    default_format = "json" if is_json else "binary"
    model, preprocessing, input_shape, return_format = self._get_inference_metadata(
        request_body,
        default_format
    )

    model_input = self._get_model_input(request, request_body)
    if not preprocessing:
        # if no preprocessing is defined, at least try to convert/interpret the model_input as
        # PyTorch tensor, before raising an exception
        model_input = self._try_cast_model_input_to_tensor(model_input)
    else:
        model_input = preprocessing(model_input)
    self._validate_model_input_after_preprocessing(model_input, input_shape, bool(preprocessing))

    uncertainty_cls, prediction, uncertainty = self._do_inference(model, model_input)
    return self._make_response(uncertainty_cls, prediction, uncertainty, return_format)
def _get_handle_content_type(self, request: HttpRequest) -> Tuple[dict, bool]:
"""
Handles HTTP request body based on their content type.
This function checks if the request content type is either `application/json`
or `multipart/form-data`. If it matches, it returns the corresponding data and
a boolean indicating whether it's JSON (True) or multipart/form-data (False).
Args:
request (HttpRequest): The request.
Returns:
tuple: A tuple containing the parsed data and a boolean indicating the content type.
* If content type is `application/json`, returns the JSON payload as a Python object (dict)
and True to indicate it's JSON.
* If content type is `multipart/form-data`, returns the request POST data and False.
Raises:
UnsupportedMediaType: If an unknown content type is specified, raising an error with
details on supported types (`application/json` and `multipart/form-data`).
"""
match request.content_type.lower():
case s if s.startswith("multipart/form-data"):
return request.POST, False
case s if s.startswith("application/json"):
return json.loads(request.body), True
# if the content type is specified, but not supported, return 415
self._logger.error(f"Unknown Content-Type '{request.content_type}'")
raise UnsupportedMediaType(
"Only Content-Type 'application/json' and 'multipart/form-data' is supported."
)
def _get_inference_metadata(
    self,
    request_body: dict,
    return_format_default: Literal["binary", "json"]
) -> Tuple[Model, Optional[torch.nn.Module], Optional[List[Optional[int]]], str]:
    """
    Collect everything needed to run an inference from the request body.

    The model referenced by `model_id` is loaded, the requested return format is
    validated (falling back to `return_format_default`) and, if the model is a global
    model or a local model with a base model, the preprocessing module and the expected
    input shape of that global model are looked up.

    Args:
        request_body (dict): The data sent with the request, containing at least `model_id`.
        return_format_default (Literal["binary", "json"]): The return format used if the
            request body does not specify `return_format`.

    Returns:
        Tuple[Model, Optional[torch.nn.Module], Optional[List[Optional[int]]], str]: A tuple containing:
            * The retrieved model entity.
            * The global model's preprocessing torch module (if applicable).
            * The input shape of the global model (if applicable).
            * The return format (`binary` or `json`).

    Raises:
        ValidationError: If no `model_id` is provided in the request body, or if an unknown
            return format is specified.
    """
    if "model_id" not in request_body:
        missing_id_msg = "No 'model_id' provided in request."
        self._logger.error(missing_id_msg)
        raise ValidationError(missing_id_msg)
    model = get_entity(Model, pk=request_body["model_id"])

    return_format = request_body.get("return_format", return_format_default)
    if return_format not in ["binary", "json"]:
        unknown_format_msg = f"Unknown return format '{return_format}'. Supported are binary and json."
        self._logger.error(unknown_format_msg)
        raise ValidationError(unknown_format_msg)

    # resolve the global model that carries the preprocessing and input shape information
    global_model: Optional[GlobalModel]
    if isinstance(model, GlobalModel):
        global_model = model
    elif isinstance(model, LocalModel):
        global_model = model.base_model
    else:
        global_model = None
        self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel. Skip preprocessing.")

    if not global_model:
        return model, None, None, return_format

    preprocessing: Optional[torch.nn.Module] = None
    if global_model.preprocessing is not None:
        preprocessing = global_model.get_preprocessing_torch_model()
    input_shape: Optional[List[Optional[int]]] = global_model.input_shape
    return model, preprocessing, input_shape, return_format
def _get_model_input(self, request: HttpRequest, request_body: dict) -> Any:
"""
Retrieves and decodes the model input from either an uploaded file or the request body.
Args:
request (HttpRequest): The current HTTP request.
request_body (dict): The parsed request body as a dictionary.
Returns:
Any: The decoded model input data.
Raises:
ValidationError: If no `model_input` is found in the uploaded file or the request body.
"""
uploaded_file = request.FILES.get("model_input", None)
if uploaded_file and uploaded_file.file:
model_input = uploaded_file.file.read()
else:
model_input = request_body.get("model_input", None)
if not model_input:
raise ValidationError("No uploaded file 'model_input' found.")
return self._try_decode_model_input(model_input)
def _try_decode_model_input(self, model_input: Any) -> Any:
    """
    Best-effort decoding of the raw `model_input` into a more usable representation.

    Two independent attempts are made, in this order, each one silently skipped on failure:

    1. deserialize the input as a PyTorch tensor (with the view logger disabled),
    2. decode the (possibly already converted) input as base64.

    Args:
        model_input (Any): The input to be decoded, which can be in any format.

    Returns:
        Any: The decoded input, or the unchanged input if neither attempt succeeded.
    """
    decoded = model_input

    # attempt 1: serialized PyTorch tensor
    try:
        with disable_logger(self._logger):
            decoded = to_torch_tensor(decoded)
    except Exception:
        pass

    # attempt 2: base64 encoded payload
    try:
        looks_like_base64, raw_bytes = self._is_base64(decoded)
    except Exception:
        looks_like_base64 = False
    if looks_like_base64:
        decoded = raw_bytes

    return decoded
def _try_cast_model_input_to_tensor(self, model_input: Any) -> Any:
"""
Attempt to cast the given model input to a PyTorch tensor.
This function tries to interpret the input in several formats:
1. PIL Image (and later convert it to a PyTorch tensor, see 3.)
2. PyTorch tensor via `torch.as_tensor`
3. PyTorch tensor via torchvision `ToTensor` (supports e.g. PIL images)
If none of these attempts are successful, the original input is returned.
Args:
model_input: The input data to be cast to a PyTorch tensor.
Can be any type that can be converted to a tensor.
Returns:
A PyTorch tensor representation of the input data, or the original
input if it cannot be converted.
"""
def _try_to_pil_image(model_input: Any) -> Any:
stream = BytesIO(model_input)
return Image.open(stream)
if isinstance(model_input, torch.Tensor):
return model_input
# In the following order, try to:
# 1. interpret model_input as PIL image (and later to PyTorch tensor, see step 3),
# 2. interpret model_input as PyTorch tensor,
# 3. interpret model_input as PyTorch tensor via torchvision ToTensor (supports e.g. PIL images).
for fn in [_try_to_pil_image, torch.as_tensor, to_tensor]:
try:
model_input = fn(model_input) # type: ignore
except Exception:
pass
return model_input
def _is_base64(self, sb: str | bytes) -> Tuple[bool, bytes]:
"""
Check if a string or bytes object is a valid Base64 encoded string.
This function checks if the input can be decoded and re-encoded without any changes.
If decoding and encoding returns the same result as the original input, it's likely
that the input was indeed a valid Base64 encoded string.
Note: This code is based on the reference implementation from the linked Stack Overflow answer.
Args:
sb (str | bytes): The input string or bytes object to check.
Returns:
Tuple[bool, bytes]: A tuple containing a boolean indicating whether the input is
a valid Base64 encoded string and the decoded bytes if it is.
References:
https://stackoverflow.com/a/45928164
"""
try:
if isinstance(sb, str):
# If there's any unicode here, an exception will be thrown and the function will return false
sb_bytes = bytes(sb, "ascii")
elif isinstance(sb, bytes):
sb_bytes = sb
else:
raise ValueError("Argument must be string or bytes")
decoded = base64.b64decode(sb_bytes)
return base64.b64encode(decoded) == sb_bytes, decoded
except Exception:
return False, b""
def _validate_model_input_after_preprocessing(
self,
model_input: Any,
model_input_shape: Optional[List[Optional[int]]],
preprocessing: bool
) -> None:
"""
Validates the model input after preprocessing.
Ensures that the provided `model_input` is a valid PyTorch tensor and its shape matches
the expected`model_input_shape`.
Args:
model_input (Any): The model input to be validated.
model_input_shape (Optional[List[Optional[int]]]): The expected shape of the model input.
Can contain None values if not all dimensions are fixed (e.g. first dimension as batch size).
preprocessing (bool): Whether a preprocessing model was defined or not. (Only for a better error message.)
Raises:
ValidationError: If the `model_input` is not a valid PyTorch tensor or
its shape does not match the expected `model_input_shape`.
"""
if not isinstance(model_input, torch.Tensor):
msg = "Model input could not be casted or interpreted as a PyTorch tensor object"
if preprocessing:
msg += " and is still not a PyTorch tensor after preprecessing."
else:
msg += " and no preprecessing is defined."
raise ValidationError(msg)
if model_input_shape and not all(
dim_input == dim_model
for (dim_input, dim_model) in zip(model_input.shape, model_input_shape)
if dim_model is not None
):
raise ValidationError("Input shape does not match model input shape.")
def _make_response(
    self,
    uncertainty_cls: Type[UncertaintyBase],
    inference: torch.Tensor,
    uncertainty: Any,
    return_type: str
) -> HttpResponse:
    """
    Create the HTTP response carrying the inference result.

    For `return_type == "binary"` the inference and uncertainty are pickled into a dict
    and sent as `application/octet-stream`. For every other value the uncertainty class
    serializes both via its `to_json` method and the result is sent as `application/json`.

    Args:
        uncertainty_cls (Type[UncertaintyBase]): The uncertainty class.
        inference (torch.Tensor): The inference.
        uncertainty (Any): The uncertainty.
        return_type (str): The return type.

    Returns:
        HttpResponse: The inference result response.
    """
    if return_type != "binary":
        json_payload = uncertainty_cls.to_json(inference, uncertainty)
        return HttpResponse(json_payload, content_type="application/json")

    pickled_payload = pickle.dumps({"inference": inference, "uncertainty": uncertainty})
    return HttpResponse(pickled_payload, content_type="application/octet-stream")
def _do_inference(
    self, model: Model, input_tensor: torch.Tensor
) -> Tuple[Type[UncertaintyBase], torch.Tensor, Dict[str, Any]]:
    """
    Run the prediction for `input_tensor` with the given model.

    The uncertainty class belonging to the model is looked up and its `prediction`
    method is called with the input tensor and the model.

    Args:
        model (Model): The model to perform inference with.
        input_tensor (torch.Tensor): Input tensor to pass through the model.

    Returns:
        Tuple[Type[UncertaintyBase], torch.Tensor, Dict[str, Any]]:
            The uncertainty class, the prediction result and the associated uncertainty.

    Raises:
        APIException: If an error occurs during inference. A `TorchDeserializationException`
            is wrapped directly; every other error is logged and replaced by a generic message.
    """
    try:
        uncertainty_cls = get_uncertainty_class(model)
        prediction, uncertainty = uncertainty_cls.prediction(input_tensor, model)
    except TorchDeserializationException as e:
        raise APIException(e) from e
    except Exception as e:
        self._logger.error(e)
        raise APIException("Internal Server Error occurred during inference!") from e
    return uncertainty_cls, prediction, uncertainty
Ancestors (in MRO)¶
- fl_server_api.views.base.ViewSet
- rest_framework.viewsets.ViewSet
- rest_framework.viewsets.ViewSetMixin
- rest_framework.views.APIView
- django.views.generic.base.View
Class variables¶
The serializer for the ViewSet.
Static methods¶
as_view¶
Because of the way class based views create a closure around the
instantiated view, we need to totally reimplement `.as_view`,
and slightly modify the view function that is created and returned.
View Source
@classonlymethod
def as_view(cls, actions=None, **initkwargs):
    """
    Because of the way class based views create a closure around the
    instantiated view, we need to totally reimplement `.as_view`,
    and slightly modify the view function that is created and returned.

    `actions` maps HTTP method names to the names of the viewset methods that
    handle them (e.g. `{'get': 'list'}`) and must not be empty. `initkwargs`
    are passed to the class constructor for every request and must name
    existing class attributes that are not HTTP method names.

    Returns the view function wrapped in `csrf_exempt`.

    Raises `TypeError` for missing actions or invalid keyword arguments.
    """
    # The name and description initkwargs may be explicitly overridden for
    # certain route configurations. eg, names of extra actions.
    cls.name = None
    cls.description = None

    # The suffix initkwarg is reserved for displaying the viewset type.
    # This initkwarg should have no effect if the name is provided.
    # eg. 'List' or 'Instance'.
    cls.suffix = None

    # The detail initkwarg is reserved for introspecting the viewset type.
    cls.detail = None

    # Setting a basename allows a view to reverse its action urls. This
    # value is provided by the router through the initkwargs.
    cls.basename = None

    # actions must not be empty
    if not actions:
        raise TypeError("The `actions` argument must be provided when "
                        "calling `.as_view()` on a ViewSet. For example "
                        "`.as_view({'get': 'list'})`")

    # sanitize keyword arguments
    for key in initkwargs:
        if key in cls.http_method_names:
            raise TypeError("You tried to pass in the %s method name as a "
                            "keyword argument to %s(). Don't do that."
                            % (key, cls.__name__))
        if not hasattr(cls, key):
            raise TypeError("%s() received an invalid keyword %r" % (
                cls.__name__, key))

    # name and suffix are mutually exclusive
    if 'name' in initkwargs and 'suffix' in initkwargs:
        raise TypeError("%s() received both `name` and `suffix`, which are "
                        "mutually exclusive arguments." % (cls.__name__))

    def view(request, *args, **kwargs):
        # a fresh viewset instance is created for every request
        self = cls(**initkwargs)

        # reuse the 'get' handler for 'head' requests unless 'head' is mapped explicitly
        if 'get' in actions and 'head' not in actions:
            actions['head'] = actions['get']

        # We also store the mapping of request methods to actions,
        # so that we can later set the action attribute.
        # eg. `self.action = 'list'` on an incoming GET request.
        self.action_map = actions

        # Bind methods to actions
        # This is the bit that's different to a standard view
        for method, action in actions.items():
            handler = getattr(self, action)
            setattr(self, method, handler)

        self.request = request
        self.args = args
        self.kwargs = kwargs

        # And continue as usual
        return self.dispatch(request, *args, **kwargs)

    # take name and docstring from class
    update_wrapper(view, cls, updated=())

    # and possible attributes set by decorators
    # like csrf_exempt from dispatch
    update_wrapper(view, cls.dispatch, assigned=())

    # We need to set these on the view function, so that breadcrumb
    # generation can pick out these bits of information from a
    # resolved URL.
    view.cls = cls
    view.initkwargs = initkwargs
    view.actions = actions
    return csrf_exempt(view)
get_extra_actions¶
Get the methods that are marked as an extra ViewSet `@action`.
View Source
Instance variables¶
Wrap Django's private `_allowed_methods` interface in a public property.
Methods¶
check_object_permissions¶
Check if the request should be permitted for a given object.
Raises an appropriate exception if the request is not permitted.
View Source
def check_object_permissions(self, request, obj):
    """
    Verify that the request is permitted for the given object.

    Every permission returned by `self.get_permissions()` is asked via
    `has_object_permission`; for each refusal `self.permission_denied` is
    called with the permission's optional `message` and `code`.
    """
    refused = (
        permission for permission in self.get_permissions()
        if not permission.has_object_permission(request, self, obj)
    )
    for permission in refused:
        denial_message = getattr(permission, 'message', None)
        denial_code = getattr(permission, 'code', None)
        self.permission_denied(request, message=denial_message, code=denial_code)
check_permissions¶
Check if the request should be permitted.
Raises an appropriate exception if the request is not permitted.
View Source
def check_permissions(self, request):
    """
    Verify that the request is permitted at all.

    Every permission returned by `self.get_permissions()` is asked via
    `has_permission`; for each refusal `self.permission_denied` is called
    with the permission's optional `message` and `code`.
    """
    refused = (
        permission for permission in self.get_permissions()
        if not permission.has_permission(request, self)
    )
    for permission in refused:
        denial_message = getattr(permission, 'message', None)
        denial_code = getattr(permission, 'code', None)
        self.permission_denied(request, message=denial_message, code=denial_code)
check_throttles¶
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
View Source
def check_throttles(self, request):
    """
    Check whether the request has to be throttled.

    The wait time of every throttle that rejects the request is collected. If at
    least one throttle rejected it, `self.throttled` is called with the longest
    known wait time (or `None` if no throttle reported one).
    """
    reported_waits = [
        throttle.wait()
        for throttle in self.get_throttles()
        if not throttle.allow_request(request, self)
    ]
    if not reported_waits:
        return

    # Filter out `None` values which may happen in case of config / rate
    # changes, see #1438
    known_waits = [wait for wait in reported_waits if wait is not None]
    self.throttled(request, max(known_waits, default=None))
determine_version¶
If versioning is being used, then determine any API version for the
incoming request. Returns a two-tuple of (version, versioning_scheme)
View Source
def determine_version(self, request, *args, **kwargs):
    """
    Determine the API version of the incoming request, if versioning is in use.

    Returns a two-tuple of (version, versioning_scheme); both entries are `None`
    when no versioning class is configured.
    """
    versioning_class = self.versioning_class
    if versioning_class is None:
        return (None, None)

    scheme = versioning_class()
    version = scheme.determine_version(request, *args, **kwargs)
    return (version, scheme)
dispatch¶
`.dispatch()` is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling.
View Source
def dispatch(self, request, *args, **kwargs):
    """
    Entry point for handling a request.

    The request is initialized, `self.initial` is run, and the handler matching the
    lowercased HTTP method is called (falling back to `self.http_method_not_allowed`).
    Any exception raised on the way is turned into a response by `self.handle_exception`.
    The response is finally passed through `self.finalize_response`, stored on
    `self.response` and returned.
    """
    self.args = args
    self.kwargs = kwargs
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?

    try:
        self.initial(request, *args, **kwargs)

        # Pick the handler method for the HTTP verb
        verb = request.method.lower()
        fallback = self.http_method_not_allowed
        handler = getattr(self, verb, fallback) if verb in self.http_method_names else fallback

        response = handler(request, *args, **kwargs)
    except Exception as exc:
        response = self.handle_exception(exc)

    self.response = self.finalize_response(request, response, *args, **kwargs)
    return self.response
finalize_response¶
Returns the final response object.
View Source
def finalize_response(self, request, response, *args, **kwargs):
    """
    Returns the final response object.

    For `Response` instances the accepted renderer, accepted media type and
    renderer context are attached. Afterwards the headers stored in
    `self.headers` are copied onto the response, with `Vary` being merged
    into the existing header instead of replacing it.
    """
    # Make the error obvious if a proper response is not returned
    assert isinstance(response, HttpResponseBase), (
        'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
        'to be returned from the view, but received a `%s`'
        % type(response)
    )

    if isinstance(response, Response):
        # no renderer has been negotiated for the request yet, so force one
        if not getattr(request, 'accepted_renderer', None):
            neg = self.perform_content_negotiation(request, force=True)
            request.accepted_renderer, request.accepted_media_type = neg

        response.accepted_renderer = request.accepted_renderer
        response.accepted_media_type = request.accepted_media_type
        response.renderer_context = self.get_renderer_context()

    # Add new vary headers to the response instead of overwriting.
    # Note: 'Vary' is removed from `self.headers` here, so it is not set again below.
    vary_headers = self.headers.pop('Vary', None)
    if vary_headers is not None:
        patch_vary_headers(response, cc_delim_re.split(vary_headers))

    for key, value in self.headers.items():
        response[key] = value

    return response
get_authenticate_header¶
If a request is unauthenticated, determine the WWW-Authenticate
header to use for 401 responses, if any.
View Source
get_authenticators¶
Get the authenticators for the ViewSet.
This method gets the view method and, if it has authentication classes defined via the decorator, returns them. Otherwise, it falls back to the default authenticators.
Returns:
Type | Description |
---|---|
list | The authenticators for the ViewSet. |
View Source
def get_authenticators(self):
    """
    Get the authenticators for the ViewSet.

    If the current view method exists and carries an `authentication_classes` attribute
    (set via the decorator), that attribute is returned. In every other case the default
    authenticators of the parent class are used.

    Returns:
        list: The authenticators for the ViewSet.
    """
    view_method = self._get_view_method()
    if not view_method or not hasattr(view_method, "authentication_classes"):
        return super().get_authenticators()
    return view_method.authentication_classes
get_content_negotiator¶
Instantiate and return the content negotiation class to use.
View Source
get_exception_handler¶
Returns the exception handler that this view uses.
View Source
get_exception_handler_context¶
Returns a dict that is passed through to EXCEPTION_HANDLER,
as the `context` argument.
View Source
get_extra_action_url_map¶
Build a map of {names: urls} for the extra actions.
This method will noop if `detail` was not provided as a view initkwarg.
View Source
def get_extra_action_url_map(self):
    """
    Build a map of {names: urls} for the extra actions.

    Only extra actions whose `detail` flag equals the one of this view are considered;
    actions whose URL cannot be reversed are left out. An empty map is returned if
    `detail` was not provided as a view initkwarg.
    """
    url_map = OrderedDict()

    # nothing to do when `detail` has not been provided
    if self.detail is None:
        return url_map

    matching_actions = [
        extra for extra in self.get_extra_actions()
        if extra.detail == self.detail
    ]

    for extra in matching_actions:
        try:
            route_name = f"{self.basename}-{extra.url_name}"
            namespace = self.request.resolver_match.namespace
            if namespace:
                route_name = f"{namespace}:{route_name}"

            url = reverse(route_name, self.args, self.kwargs, request=self.request)
            action_view = self.__class__(**extra.kwargs)
            url_map[action_view.get_view_name()] = url
        except NoReverseMatch:
            continue  # URL requires additional arguments, ignore

    return url_map
get_format_suffix¶
Determine if the request includes a '.json' style format suffix
View Source
get_parser_context¶
Returns a dict that is passed through to Parser.parse(),
as the `parser_context` keyword argument.
View Source
def get_parser_context(self, http_request):
    """
    Build the dict handed to Parser.parse() as the `parser_context` keyword argument.

    It contains the view itself and its positional/keyword URL arguments, defaulting
    to an empty tuple / dict when they are not set on the view yet.
    """
    # Note: Additionally `request` and `encoding` will also be added
    # to the context by the Request object.
    view_args = getattr(self, 'args', ())
    view_kwargs = getattr(self, 'kwargs', {})
    return dict(view=self, args=view_args, kwargs=view_kwargs)
get_parsers¶
Instantiates and returns the list of parsers that this view can use.
View Source
get_permissions¶
Get the permissions for the ViewSet.
This method gets the view method and, if it has permission classes defined via the decorator, returns them. Otherwise, it falls back to the default permissions.
Returns:
Type | Description |
---|---|
list | The permissions for the ViewSet. |
View Source
def get_permissions(self):
    """
    Get the permissions for the ViewSet.

    If the current view method exists and carries a `permission_classes` attribute
    (set via the decorator), that attribute is returned. In every other case the default
    permissions of the parent class are used.

    Returns:
        list: The permissions for the ViewSet.
    """
    view_method = self._get_view_method()
    if not view_method or not hasattr(view_method, "permission_classes"):
        return super().get_permissions()
    return view_method.permission_classes
get_renderer_context¶
Returns a dict that is passed through to Renderer.render(),
as the `renderer_context` keyword argument.
View Source
def get_renderer_context(self):
    """
    Build the dict handed to Renderer.render() as the `renderer_context` keyword argument.

    It contains the view itself, its positional/keyword URL arguments and the current
    request, each falling back to a neutral default when not set on the view yet.
    """
    # Note: Additionally 'response' will also be added to the context,
    # by the Response object.
    view_args = getattr(self, 'args', ())
    view_kwargs = getattr(self, 'kwargs', {})
    current_request = getattr(self, 'request', None)
    return dict(view=self, args=view_args, kwargs=view_kwargs, request=current_request)
get_renderers¶
Instantiates and returns the list of renderers that this view can use.
View Source
get_throttles¶
Instantiates and returns the list of throttles that this view uses.
View Source
get_view_description¶
Return some descriptive text for the view, as used in OPTIONS responses
and in the browsable API.
View Source
get_view_name¶
Return the view name, as used in OPTIONS responses and in the
browsable API.
View Source
handle_exception¶
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
View Source
def handle_exception(self, exc):
    """
    Turn an exception into a response, or re-raise it.

    Authentication errors get the `WWW-Authenticate` header attached if one is
    available and are otherwise coerced to a 403. The configured exception handler
    then builds the response; if it returns `None`, the exception is passed to
    `self.raise_uncaught_exception`.
    """
    authentication_errors = (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)
    if isinstance(exc, authentication_errors):
        # WWW-Authenticate header for 401 responses, else coerce to 403
        www_authenticate = self.get_authenticate_header(self.request)
        if www_authenticate:
            exc.auth_header = www_authenticate
        else:
            exc.status_code = status.HTTP_403_FORBIDDEN

    handler = self.get_exception_handler()
    response = handler(exc, self.get_exception_handler_context())

    if response is None:
        self.raise_uncaught_exception(exc)

    response.exception = True
    return response
http_method_not_allowed¶
If `request.method` does not correspond to a handler method,
determine what kind of exception to raise.
View Source
inference¶
def inference(
self,
request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Performs inference on the provided model and input data.
This method takes in an HTTP request containing the necessary metadata and input data,
performs any required preprocessing on the input data, runs the inference using the specified model,
and returns a response in the format specified by the `return_format` parameter including
possible uncertainty measurements if defined.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The current HTTP request. | None |
Returns:
Type | Description |
---|---|
HttpResponse | A HttpResponse containing the result of the inference as well as its uncertainty. |
View Source
@extend_schema(
    request=inline_serializer(
        "InferenceJsonSerializer",
        fields={
            "model_id": CharField(),
            "model_input": ListField(child=ListField(child=FloatField())),
            "return_format": ChoiceField(["binary", "json"])
        }
    ),
    responses={
        status.HTTP_200_OK: serializer_class,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    },
    examples=[
        OpenApiExample("JSON Example", value={
            "model_id": "mymodel",
            "model_input": [
                [1.0, 2.3, -0.4, 3],
                [0.01, 9.7, 5.6, 7]
            ],
            "return_format": "json"
        }, request_only=True),
    ]
)
def inference(self, request: HttpRequest) -> HttpResponse:
    """
    Run an inference for the model and input data given in the request.

    The request body is parsed according to its content type, the inference metadata
    (model, preprocessing, input shape, return format) is resolved and the model input is
    read. If a preprocessing module is defined it is applied to the input; otherwise the
    input is cast to a PyTorch tensor on a best-effort basis. After validating the input,
    the inference is performed and returned in the requested `return_format` together with
    possible uncertainty measurements.

    Args:
        request (HttpRequest): The current HTTP request.

    Returns:
        HttpResponse: A HttpResponse containing the result of the inference as well as its uncertainty.
    """
    request_body, is_json = self._get_handle_content_type(request)
    # JSON requests are answered with JSON by default, multipart requests with binary data
    default_return_format = "json" if is_json else "binary"
    model, preprocessing, input_shape, return_format = self._get_inference_metadata(
        request_body, default_return_format
    )

    model_input = self._get_model_input(request, request_body)
    has_preprocessing = bool(preprocessing)
    if has_preprocessing:
        model_input = preprocessing(model_input)
    else:
        # if no preprocessing is defined, at least try to convert/interpret the model_input as
        # PyTorch tensor, before raising an exception
        model_input = self._try_cast_model_input_to_tensor(model_input)
    self._validate_model_input_after_preprocessing(model_input, input_shape, has_preprocessing)

    uncertainty_cls, prediction, uncertainty = self._do_inference(model, model_input)
    return self._make_response(uncertainty_cls, prediction, uncertainty, return_format)
initial¶
Runs anything that needs to occur prior to calling the method handler.
View Source
def initial(self, request, *args, **kwargs):
    """
    Prepare the view and request before the method handler is called.

    Stores the format suffix, attaches the negotiated renderer/media type and the
    determined API version to the request, and finally runs authentication,
    permission and throttle checks (in this order).
    """
    self.format_kwarg = self.get_format_suffix(**kwargs)

    # Perform content negotiation and store the accepted info on the request
    (
        request.accepted_renderer,
        request.accepted_media_type,
    ) = self.perform_content_negotiation(request)

    # Determine the API version, if versioning is in use.
    (
        request.version,
        request.versioning_scheme,
    ) = self.determine_version(request, *args, **kwargs)

    # Ensure that the incoming request is permitted
    for gate in (self.perform_authentication, self.check_permissions, self.check_throttles):
        gate(request)
initialize_request¶
Set the `.action` attribute on the view, depending on the request method.
View Source
def initialize_request(self, request, *args, **kwargs):
    """
    Initialize the request via the parent class and set the `.action` attribute
    on the view, depending on the request method.
    """
    request = super().initialize_request(request, *args, **kwargs)
    verb = request.method.lower()
    # 'options' is a special case as we always provide handling for the
    # options method in the base `View` class.
    # Unlike the other explicitly defined actions, 'metadata' is implicit.
    self.action = 'metadata' if verb == 'options' else self.action_map.get(verb)
    return request
options¶
Handler method for HTTP 'OPTIONS' request.
View Source
def options(self, request, *args, **kwargs):
    """
    Handle an HTTP 'OPTIONS' request.

    Without a configured metadata class the request is answered via
    `self.http_method_not_allowed`; otherwise the metadata determined by
    that class is returned with status 200.
    """
    metadata_class = self.metadata_class
    if metadata_class is not None:
        metadata = metadata_class().determine_metadata(request, self)
        return Response(metadata, status=status.HTTP_200_OK)
    return self.http_method_not_allowed(request, *args, **kwargs)
perform_authentication¶
Perform authentication on the incoming request.
Note that if you override this and simply 'pass', then authentication
will instead be performed lazily, the first time either
`request.user` or `request.auth` is accessed.
View Source
perform_content_negotiation¶
Determine which renderer and media type to use render the response.
View Source
def perform_content_negotiation(self, request, force=False):
    """
    Determine which renderer and media type to use to render the response.

    If the negotiation fails and `force` is set, the first available renderer and its
    media type are returned; without `force` the error is propagated.
    """
    renderers = self.get_renderers()
    negotiator = self.get_content_negotiator()

    try:
        selection = negotiator.select_renderer(request, renderers, self.format_kwarg)
    except Exception:
        if not force:
            raise
        return (renderers[0], renderers[0].media_type)
    return selection
permission_denied¶
If request is not permitted, determine what kind of exception to raise.
View Source
def permission_denied(self, request, message=None, code=None):
    """
    Raise the exception matching a denied request.

    `NotAuthenticated` is raised when authenticators are configured but none of them
    authenticated the request; otherwise `PermissionDenied` is raised with the given
    `message` and `code`.
    """
    unauthenticated = request.authenticators and not request.successful_authenticator
    if unauthenticated:
        raise exceptions.NotAuthenticated()
    raise exceptions.PermissionDenied(detail=message, code=code)
raise_uncaught_exception¶
View Source
reverse_action¶
Reverse the action for the given `url_name`.
View Source
def reverse_action(self, url_name, *args, **kwargs):
    """
    Reverse the action for the given `url_name`.

    The name is prefixed with the viewset's basename and, if the current request was
    resolved inside a namespace, with that namespace. The current request is passed to
    `reverse` unless the caller supplies its own `request` keyword argument.
    """
    full_name = '%s-%s' % (self.basename, url_name)

    current_request = self.request
    namespace = None
    if current_request and current_request.resolver_match:
        namespace = current_request.resolver_match.namespace
    if namespace:
        full_name = namespace + ':' + full_name

    kwargs.setdefault('request', self.request)
    return reverse(full_name, *args, **kwargs)
setup¶
Initialize attributes shared by all view methods.
View Source
throttled¶
If request is throttled, determine what kind of exception to raise.
View Source
Model¶
Model ViewSet.
View Source
class Model(ViewSet):
"""
Model ViewSet.
"""
serializer_class = GlobalModelSerializer
"""The serializer for the ViewSet."""
def _get_user_related_global_models(self, user: Union[AbstractBaseUser, AnonymousUser]) -> List[ModelDB]:
    """
    Get the models related to a user.

    A model is related to the user if the user owns it, or if the user is the actor or
    a participant of a training that references the model.

    Args:
        user (Union[AbstractBaseUser, AnonymousUser]): The user.

    Returns:
        List[ModelDB]: The models related to the user.
    """
    involved_in_training = Q(actor=user) | Q(participants=user)
    trained_model_ids = (
        Training.objects
        .filter(involved_in_training)
        .distinct()
        .values_list("model__id", flat=True)
    )
    owned_or_trained = Q(owner=user) | Q(id__in=trained_model_ids)
    return ModelDB.objects.filter(owned_or_trained).distinct()
def _get_local_models_for_global_model(self, global_model: GlobalModelDB) -> List[LocalModelDB]:
    """
    Look up every local model whose base model is the given global model.

    Args:
        global_model (GlobalModelDB): The global model.

    Returns:
        List[LocalModelDB]: The local models for the global model.
    """
    based_on_model = LocalModelDB.objects.filter(base_model=global_model)
    return based_on_model.all()
def _get_local_models_for_global_models(self, global_models: List[GlobalModelDB]) -> List[LocalModelDB]:
    """
    Look up every local model whose base model is one of the given global models.

    Args:
        global_models (List[GlobalModelDB]): The global models.

    Returns:
        List[LocalModelDB]: The local models for the global models.
    """
    based_on_models = LocalModelDB.objects.filter(base_model__in=global_models)
    return based_on_models.all()
def _filter_by_training(self, models: List[ModelDB], training_id: str) -> List[ModelDB]:
"""
Filter a list of models by checking if they are associated with the training.
Args:
models (List[ModelDB]): The models to filter.
training_id (str): The ID of the training.
Returns:
List[ModelDB]: The models associated with the training.
"""
def associated_with_training(m: ModelDB) -> bool:
training = m.get_training()
if training is None:
return False
return training.pk == UUID(training_id)
return list(filter(associated_with_training, models))
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_models(self, request: HttpRequest) -> HttpResponse:
    """
    List all models associated with the requesting user.

    A model is deemed associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.

    Args:
        request (HttpRequest): The incoming request object.

    Returns:
        HttpResponse: Model list as JSON response.
    """
    related_models = self._get_user_related_global_models(request.user)
    return Response(ModelSerializer(related_models, many=True).data)
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_training_models(self, request: HttpRequest, training_id: str) -> HttpResponse:
    """
    List all models of one training process that are related to the requesting user.

    The user-related global models are narrowed down to the given training; the local
    models based on those global models are appended to the result.

    Args:
        request (HttpRequest): The incoming request object.
        training_id (str): The unique identifier of the training process.

    Returns:
        HttpResponse: Model list (global models first, then local models) as JSON response.
    """
    training_globals = self._filter_by_training(
        self._get_user_related_global_models(request.user),
        training_id,
    )
    training_locals = self._get_local_models_for_global_models(training_globals)

    combined = [*training_globals]
    combined.extend(training_locals)
    payload = ModelSerializer(combined, many=True).data
    return Response(payload)
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_training_models_latest(self, request: HttpRequest, training_id: str) -> HttpResponse:
    """
    Get a list of the latest models for a specific training process associated with the requesting user.

    A model is considered associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.
    The latest model refers to the model from the most recent round (highest round number):
    one global model plus one local model per local model owner.

    If no global model of the training is related to the user, an empty list is returned
    (same result as `get_training_models` in that situation).

    Args:
        request (HttpRequest): The incoming request object.
        training_id (str): The unique identifier of the training process.

    Returns:
        HttpResponse: Model list as JSON response.
    """
    models: List[ModelDB] = []

    # add latest global model
    global_models = self._get_user_related_global_models(request.user)
    global_models = self._filter_by_training(global_models, training_id)
    # `default=None` prevents `max()` from raising ValueError (HTTP 500) on an empty list
    latest_global = max(global_models, key=lambda m: m.round, default=None)
    if latest_global is not None:
        models.append(latest_global)

    # add latest local models, one per owner
    def owner_key(m: ModelDB) -> str:
        return str(m.owner.pk)

    local_models = self._get_local_models_for_global_models(global_models)
    local_models = sorted(local_models, key=owner_key)  # groupby requires sorted input
    for _, group in groupby(local_models, key=owner_key):
        models.append(max(group, key=lambda m: m.round))

    serializer = ModelSerializer(models, many=True)
    return Response(serializer.data)
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeightsWithStats(),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_metadata(self, _request: HttpRequest, id: str) -> HttpResponse:
    """
    Return the meta data of a single model.

    The model is serialized with the `with-stats` serializer context flag enabled.

    Args:
        _request (HttpRequest): The incoming request object (unused).
        id (str): The unique identifier of the model.

    Returns:
        HttpResponse: Model meta data as JSON response.
    """
    requested_model = get_entity(ModelDB, pk=id)
    serializer_context = {"with-stats": True}
    payload = ModelSerializer(requested_model, context=serializer_context).data
    return Response(payload)
@extend_schema(
    responses={
        status.HTTP_200_OK: OpenApiResponse(response=bytes, description="Model is returned as bytes"),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_model(self, _request: HttpRequest, id: str) -> HttpResponseBase:
    """
    Download the stored weights of a model as a binary file response.

    Args:
        _request (HttpRequest): The incoming request object (unused).
        id (str): The unique identifier of the model.

    Raises:
        APIException: If a SWAG model has a first moment but no second moment.
        NotImplementedError: If a SWAG model with a first moment is requested.

    Returns:
        HttpResponseBase: model as file response
    """
    model = get_entity(ModelDB, pk=id)

    has_swag_moments = isinstance(model, SWAGModelDB) and model.swag_first_moment is not None
    if has_swag_moments and model.swag_second_moment is None:
        raise APIException(f"Model {model.id} is in inconsistent state!")
    if has_swag_moments:
        raise NotImplementedError(
            "SWAG models need to be returned in 3 parts: model architecture, first moment, second moment"
        )

    # NOTE: a plain HttpResponse is used instead of FileResponse because the weights
    #       may be bytes (e.g. with sqlite) rather than a memoryview
    download = HttpResponse(model.weights, content_type="application/octet-stream")
    download["Content-Disposition"] = f'filename="model-{id}.pt"'
    return download
@extend_schema(
    responses={
        status.HTTP_200_OK: OpenApiResponse(
            response=bytes,
            description="Proprecessing model is returned as bytes"
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
        status.HTTP_404_NOT_FOUND: error_response_404,
    },
)
def get_model_proprecessing(self, _request: HttpRequest, id: str) -> HttpResponseBase:
    """
    Download the preprocessing model that belongs to a model as a binary file response.

    For a global model its own preprocessing model is returned; for a local model the
    preprocessing model of its base (global) model is returned.

    Args:
        _request (HttpRequest): The incoming request object (unused).
        id (str): The unique identifier of the model.

    Raises:
        ValidationError: If the model is neither a global nor a local model.
        NotFound: If no preprocessing model is defined.

    Returns:
        HttpResponseBase: preprocessing model as file response
    """
    model = get_entity(ModelDB, pk=id)

    # the variable holds a database model, not a torch module
    global_model: GlobalModelDB
    if isinstance(model, GlobalModelDB):
        global_model = model
    elif isinstance(model, LocalModelDB):
        global_model = model.base_model
    else:
        self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel.")
        raise ValidationError(f"Unknown model type. Model id: {id}")

    if global_model.preprocessing is None:
        raise NotFound(f"Model '{id}' has no preprocessing model defined.")

    # NOTE: a plain HttpResponse is used instead of FileResponse because the stored
    #       value may be bytes (e.g. with sqlite) rather than a memoryview
    response = HttpResponse(global_model.preprocessing, content_type="application/octet-stream")
    response["Content-Disposition"] = f'filename="model-{id}-proprecessing.pt"'
    return response
@extend_schema(responses={
    status.HTTP_200_OK: inline_serializer(
        "DeleteModelSuccessSerializer",
        fields={
            "detail": CharField(default="Model removed!")
        }
    ),
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def remove_model(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Delete a model on behalf of its owner or of the actor of its training.

    Args:
        request (HttpRequest): The incoming request object.
        id (str): The unique identifier of the model.

    Raises:
        PermissionDenied: If the requesting user is neither the model owner
            nor the actor of the model's training.

    Returns:
        HttpResponse: 200 Response if model was removed, else corresponding error code
    """
    target = get_entity(ModelDB, pk=id)
    requester = request.user

    requester_is_owner = not (target.owner != requester)
    if not requester_is_owner:
        # the training is only looked up when ownership does not already grant access
        linked_training = target.get_training()
        requester_is_actor = linked_training is not None and not (linked_training.actor != requester)
        if not requester_is_actor:
            raise PermissionDenied(
                "You are neither the owner of the model nor the actor of the corresponding training."
            )

    target.delete()
    return JsonResponse({"detail": "Model removed!"})
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "description": {"type": "string"},
                "model_file": {"type": "string", "format": "binary"},
                "model_preprocessing_file": {"type": "string", "format": "binary", "required": "false"},
            },
        },
    },
    responses={
        status.HTTP_201_CREATED: inline_serializer("ModelUploadSerializer", fields={
            "detail": CharField(default="Model Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_model(self, request: HttpRequest) -> HttpResponse:
    """
    Create a global model from an uploaded model file.

    The model file should be a PyTorch serialized model.
    Providing the model via `torch.save` as well as in TorchScript format is supported.
    Parsing of the request and creation of the model is delegated to
    `load_and_create_model_request`.

    Args:
        request (HttpRequest): The incoming request object.

    Returns:
        HttpResponse: upload success message (including the new model id) as json response
    """
    created_model = load_and_create_model_request(request)
    body = {
        "detail": "Model Upload Accepted",
        "model_id": str(created_model.id),
    }
    return JsonResponse(body, status=status.HTTP_201_CREATED)
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "model_preprocessing_file": {"type": "string", "format": "binary"},
            },
        },
    },
    responses={
        status.HTTP_202_ACCEPTED: inline_serializer("PreprocessingModelUploadSerializer", fields={
            "detail": CharField(default="Proprocessing Model Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def upload_model_preprocessing(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload a preprocessing model file for a global model.

    The preprocessing model file should be a PyTorch serialized model.
    Providing the model via `torch.save` as well as in TorchScript format is supported.

    ```python
    transforms = torch.nn.Sequential(
        torchvision.transforms.CenterCrop(10),
        torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    )
    ```

    Make sure to only use transformations that inherit from `torch.nn.Module`.
    It is advised to use the `torchvision.transforms.v2` module for common transformations.

    Please note that this function is still in the beta phase.

    Args:
        request (HttpRequest): request object
        id (str): global model UUID

    Raises:
        PermissionDenied: Unauthorized to upload preprocessing model for the specified model
        ValidationError: Preprocessing model is not a valid torch model

    Returns:
        HttpResponse: upload success message (including the model id) as json response
    """
    model = get_entity(GlobalModelDB, pk=id)
    if request.user.id != model.owner.id:
        raise PermissionDenied(f"You are not the owner of model {model.id}!")

    # verify the upload first so the model instance is only touched with a valid file
    preprocessing_file = get_file(request, "model_preprocessing_file")
    verify_model_object(preprocessing_file, "preprocessing")
    model.preprocessing = preprocessing_file
    model.save()

    return JsonResponse({
        "detail": "Proprocessing Model Upload Accepted",
        # the documented response schema declares `model_id`, so it is returned as well
        "model_id": str(model.id),
    }, status=status.HTTP_202_ACCEPTED)
@extend_schema(
    responses={
        status.HTTP_200_OK: MetricSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Return every metric that was stored for the selected model.

    Args:
        request (HttpRequest): request object
        id (str): model UUID

    Returns:
        HttpResponse: Metrics as JSON Array
    """
    selected_model = get_entity(ModelDB, pk=id)
    model_metrics = MetricDB.objects.filter(model=selected_model).all()
    payload = MetricSerializer(model_metrics, many=True).data
    return Response(payload)
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "metric_names": {"type": "list"},
                "metric_values": {"type": "list"},
            },
        },
    },
    responses={
        # the view answers with 201 Created, so the schema documents that status
        status.HTTP_201_CREATED: inline_serializer("MetricUploadResponseSerializer", fields={
            "detail": CharField(default="Model Metrics Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        OpenApiExample("Example", value={
            "metric_names": ["accuracy", "training loss"],
            "metric_values": [0.6, 0.04]
        }, media_type="multipart/form-data")
    ]
)
def create_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload model metrics.

    The metrics are stored for the model. If the model is a global model that belongs to a
    training and the number of distinct reporters for the current round equals the number
    of training participants, the `ModelTestFinished` trainer task is dispatched.

    Args:
        request (HttpRequest): request object
        id (str): model uuid

    Raises:
        ParseError: If metric names or values are missing or inconsistent
            (raised by `_metric_upload`).

    Returns:
        HttpResponse: upload success message as json response
    """
    model = get_entity(ModelDB, pk=id)
    formdata = dict(request.POST)

    # storing and counting happen under the same lock
    with locked_atomic_transaction(MetricDB):
        self._metric_upload(formdata, model, request.user)

        if isinstance(model, GlobalModelDB):
            # NOTE: `distinct(<field>)` (DISTINCT ON) is only supported by PostgreSQL in Django
            n_metrics = MetricDB.objects.filter(model=model, step=model.round).distinct("reporter").count()
            training = model.get_training()
            if training:
                if n_metrics == training.participants.count():
                    dispatch_trainer_task(training, ModelTestFinished, False)
            else:
                self._logger.warning(f"Global model {id} is not connected to any training.")

    return JsonResponse({
        "detail": "Model Metrics Upload Accepted",
        "model_id": str(model.id),
    }, status=status.HTTP_201_CREATED)
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "owner": {"type": "string"},
                "round": {"type": "int"},
                "sample_size": {"type": "int"},
                "metric_names": {"type": "list[string]"},
                "metric_values": {"type": "list[float]"},
                "model_file": {"type": "string", "format": "binary"},
            },
        },
    },
    # NOTE(review): the view answers with 201 and a `detail` message, not with a
    #               serialized model; the documented response should be confirmed
    responses={
        status.HTTP_200_OK: ModelSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_local_model(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload a partial trained model file from client.

    The local model is stored for the current round of the global model. Optional metrics
    are stored alongside. Once the number of local models for the round equals the number
    of training participants, the `TrainingRoundFinished` trainer task is dispatched.

    Args:
        request (HttpRequest): request object
        id (str): model uuid of the model, which was used for training

    Raises:
        NotFound: If no training process exists for the model.
        ParseError: If a required field is missing, or `round`/`sample_size` is not
            a single integer value.

    Returns:
        HttpResponse: upload success message as json response
    """
    try:
        formdata = dict(request.POST)
        try:
            # each field must carry exactly one value that parses as an integer
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
        except ValueError as e:
            # malformed client input is a 400, not an internal server error
            raise ParseError("Fields 'round' and 'sample_size' must each be a single integer value.") from e
        client = request.user
        model_file = get_file(request, "model_file")
        global_model = get_entity(GlobalModelDB, pk=id)

        # ensure that a training process coresponding to the model exists, else the process will error out
        training = Training.objects.get(model=global_model)
        self._verify_valid_update(client, training, round_num, TrainingState.ONGOING)

        verify_model_object(model_file)

        local_model = LocalModelDB.objects.create(
            base_model=global_model, weights=model_file,
            round=round_num, owner=client, sample_size=sample_size
        )
        self._metric_upload(formdata, local_model, client, metrics_required=False)

        updates = LocalModelDB.objects.filter(base_model=global_model, round=round_num)
        if updates.count() == training.participants.count():
            dispatch_trainer_task(training, TrainingRoundFinished, True)

        return JsonResponse({"detail": "Model Update Accepted"}, status=status.HTTP_201_CREATED)
    except Training.DoesNotExist:
        raise NotFound(f"Model with ID {id} does not have a training process running")
    except (MultiValueDictKeyError, KeyError) as e:
        raise ParseError(e)
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "round": {"type": "int"},
                "sample_size": {"type": "int"},
                "first_moment_file": {"type": "string", "format": "binary"},
                "second_moment_file": {"type": "string", "format": "binary"}
            },
        },
    },
    responses={
        status.HTTP_200_OK: inline_serializer("MetricUploadSerializer", fields={
            "detail": CharField(default="SWAg Statistics Accepted"),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_swag_stats(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload SWAG statistics.

    The first and second moment files and the sample size are stored as metrics of the
    global model. Once the number of first-moment metrics for the current round equals
    the number of training participants, the `SWAGRoundFinished` trainer task is dispatched.

    Args:
        request (HttpRequest): request object
        id (str): global model uuid

    Raises:
        APIException: internal server error
        NotFound: model not found
        ParseError: request data not valid
        PermissionDenied: client is not a participant of the training
        ValidationError: training state or round number does not match

    Returns:
        HttpResponse: upload success message as json response
    """
    try:
        client = request.user
        formdata = dict(request.POST)
        try:
            # each field must carry exactly one value that parses as an integer
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
        except ValueError as e:
            # malformed client input is a 400, not an internal server error
            raise ParseError("Fields 'round' and 'sample_size' must each be a single integer value.") from e
        fst_moment = get_file(request, "first_moment_file")
        snd_moment = get_file(request, "second_moment_file")
        model = get_entity(GlobalModelDB, pk=id)

        # ensure that a training process coresponding to the model exists, else the process will error out
        training = Training.objects.get(model=model)
        self._verify_valid_update(client, training, round_num, TrainingState.SWAG_ROUND)

        self._save_swag_stats(fst_moment, snd_moment, model, client, sample_size)

        swag_stats_first = MetricDB.objects.filter(model=model, step=model.round, key="SWAG First Moment Local")
        swag_stats_second = MetricDB.objects.filter(model=model, step=model.round, key="SWAG Second Moment Local")
        first_count = swag_stats_first.count()

        if first_count != swag_stats_second.count():
            training.state = TrainingState.ERROR
            # persist the error state; assigning the attribute alone is lost after the request
            training.save()
            raise APIException("SWAG stats in inconsistent state!")
        if first_count == training.participants.count():
            dispatch_trainer_task(training, SWAGRoundFinished, True)

        return JsonResponse({"detail": "SWAG Statistic Accepted"}, status=status.HTTP_201_CREATED)
    except Training.DoesNotExist:
        raise NotFound(f"Model with ID {id} does not have a training process running")
    except (MultiValueDictKeyError, KeyError) as e:
        raise ParseError(e)
    except APIException:
        # keep the status code of API errors (e.g. 400/403) instead of
        # rewrapping them as a generic internal server error below
        raise
    except Exception as e:
        raise APIException(e) from e
@staticmethod
def _save_swag_stats(
    fst_moment: bytes, snd_moment: bytes, model: GlobalModelDB, client: UserDB, sample_size: int
):
    """
    Save the first and second moments, and the sample size of the SWAG to the database.

    This function creates three metrics for the current round of the model:

    - the first moment,
    - the second moment, and
    - the sample size.

    These metrics are associated with the model, the round, and the client that reported them.

    Args:
        fst_moment (bytes): The first moment of the SWAG.
        snd_moment (bytes): The second moment of the SWAG.
        model (GlobalModelDB): The global model for which the metrics are being reported.
        client (UserDB): The client reporting the metrics.
        sample_size (int): The sample size of the SWAG.
    """
    # fields shared by all three metrics
    common = {"model": model, "step": model.round, "reporter": client}

    # `objects.create` already persists the object; no additional `save()` is required
    MetricDB.objects.create(key="SWAG First Moment Local", value_binary=fst_moment, **common)
    MetricDB.objects.create(key="SWAG Second Moment Local", value_binary=snd_moment, **common)
    MetricDB.objects.create(key="SWAG Sample Size Local", value_float=sample_size, **common)
@transaction.atomic()
def _metric_upload(self, formdata: dict, model: ModelDB, client: UserDB, metrics_required: bool = True):
    """
    Uploads metrics associated with a model.

    For each pair of metric name and value, it attempts to convert the value to a float.
    If this fails, it treats the value as a binary string (UTF-8 encoded).
    It then creates a new metric object with the model, the metric name, the float or binary value,
    the model's round number, and the client.

    If metrics are not required, nothing is stored when both fields are absent or when
    their lengths differ.

    Args:
        formdata (dict): The form data containing the metric names and values.
        model (ModelDB): The model with which the metrics are associated.
        client (UserDB): The client reporting the metrics.
        metrics_required (bool): A flag indicating whether metrics are required. Defaults to True.

    Raises:
        ParseError: If only one of `metric_names` and `metric_values` is in formdata,
            or, when metrics are required, if either is missing or their lengths differ.
    """
    has_names = "metric_names" in formdata
    has_values = "metric_values" in formdata
    if not (has_names and has_values):
        # providing only one of the two fields is always an error
        if metrics_required or has_names != has_values:
            raise ParseError("Metric names or values are missing")
        return

    names, values = formdata["metric_names"], formdata["metric_values"]
    if len(names) != len(values):
        if metrics_required:
            raise ParseError("Metric names and values must have the same length")
        return

    for key, value in zip(names, values):
        try:
            metric_float = float(value)
            metric_binary = None
        except (TypeError, ValueError):
            # not a number: keep the raw text as binary value
            metric_float = None
            metric_binary = bytes(value, encoding="utf-8")
        # `objects.create` already persists the object; no additional `save()` is required
        MetricDB.objects.create(
            model=model,
            key=key,
            value_float=metric_float,
            value_binary=metric_binary,
            step=model.round,
            reporter=client
        )
def _verify_valid_update(self, client: UserDB, train: Training, round_num: int, expected_state: tuple[str, Any]):
"""
Verifies if a client can update a training process.
This function checks if
- the client is a participant of the training process,
- the training process is in the expected state, and if
- the round number matches the current round of the model associated with the training process.
Args:
client (UserDB): The client attempting to update the training process.
train (Training): The training process to be updated.
round_num (int): The round number reported by the client.
expected_state (tuple[str, Any]): The expected state of the training process.
Raises:
PermissionDenied: If the client is not a participant of the training process.
ValidationError: If the training process is not in the expected state or if the round number does not match
the current round of the model.
"""
if client.id not in [p.id for p in train.participants.all()]:
raise PermissionDenied(f"You are not a participant of training {train.id}!")
if train.state != expected_state:
raise ValidationError(f"Training with ID {train.id} is in state {train.state}")
if int(round_num) != train.model.round:
raise ValidationError(f"Training with ID {train.id} is not currently in round {round_num}")
Ancestors (in MRO)¶
- fl_server_api.views.base.ViewSet
- rest_framework.viewsets.ViewSet
- rest_framework.viewsets.ViewSetMixin
- rest_framework.views.APIView
- django.views.generic.base.View
Class variables¶
The serializer for the ViewSet.
Static methods¶
as_view¶
Because of the way class based views create a closure around the
instantiated view, we need to totally reimplement `.as_view`,
and slightly modify the view function that is created and returned.
View Source
@classonlymethod
def as_view(cls, actions=None, **initkwargs):
    """
    Because of the way class based views create a closure around the
    instantiated view, we need to totally reimplement `.as_view`,
    and slightly modify the view function that is created and returned.

    Args:
        actions: Mapping of HTTP method name to the name of the viewset attribute
            that handles it, e.g. `{'get': 'list'}`. Must not be empty.
        **initkwargs: Keyword arguments passed to the class when a view instance
            is created for a request.

    Raises:
        TypeError: If `actions` is empty, if an initkwarg is an HTTP method name or
            not an attribute of the class, or if both `name` and `suffix` are given.

    Returns:
        The `view` function wrapped with `csrf_exempt`.
    """
    # The name and description initkwargs may be explicitly overridden for
    # certain route configurations. eg, names of extra actions.
    cls.name = None
    cls.description = None

    # The suffix initkwarg is reserved for displaying the viewset type.
    # This initkwarg should have no effect if the name is provided.
    # eg. 'List' or 'Instance'.
    cls.suffix = None

    # The detail initkwarg is reserved for introspecting the viewset type.
    cls.detail = None

    # Setting a basename allows a view to reverse its action urls. This
    # value is provided by the router through the initkwargs.
    cls.basename = None

    # actions must not be empty
    if not actions:
        raise TypeError("The `actions` argument must be provided when "
                        "calling `.as_view()` on a ViewSet. For example "
                        "`.as_view({'get': 'list'})`")

    # sanitize keyword arguments
    for key in initkwargs:
        if key in cls.http_method_names:
            raise TypeError("You tried to pass in the %s method name as a "
                            "keyword argument to %s(). Don't do that."
                            % (key, cls.__name__))
        if not hasattr(cls, key):
            raise TypeError("%s() received an invalid keyword %r" % (
                cls.__name__, key))

    # name and suffix are mutually exclusive
    if 'name' in initkwargs and 'suffix' in initkwargs:
        raise TypeError("%s() received both `name` and `suffix`, which are "
                        "mutually exclusive arguments." % (cls.__name__))

    def view(request, *args, **kwargs):
        # a new instance of the class is created for every call of the view
        self = cls(**initkwargs)

        # HEAD falls back to the GET action unless it is mapped explicitly
        if 'get' in actions and 'head' not in actions:
            actions['head'] = actions['get']

        # We also store the mapping of request methods to actions,
        # so that we can later set the action attribute.
        # eg. `self.action = 'list'` on an incoming GET request.
        self.action_map = actions

        # Bind methods to actions
        # This is the bit that's different to a standard view
        for method, action in actions.items():
            handler = getattr(self, action)
            setattr(self, method, handler)

        self.request = request
        self.args = args
        self.kwargs = kwargs

        # And continue as usual
        return self.dispatch(request, *args, **kwargs)

    # take name and docstring from class
    update_wrapper(view, cls, updated=())

    # and possible attributes set by decorators
    # like csrf_exempt from dispatch
    update_wrapper(view, cls.dispatch, assigned=())

    # We need to set these on the view function, so that breadcrumb
    # generation can pick out these bits of information from a
    # resolved URL.
    view.cls = cls
    view.initkwargs = initkwargs
    view.actions = actions
    return csrf_exempt(view)
get_extra_actions¶
Get the methods that are marked as an extra ViewSet `@action`.
View Source
Instance variables¶
Wrap Django's private `_allowed_methods` interface in a public property.
Methods¶
check_object_permissions¶
Check if the request should be permitted for a given object.
Raises an appropriate exception if the request is not permitted.
View Source
def check_object_permissions(self, request, obj):
    """
    Verify every configured permission against a specific object.

    For each permission that rejects the object, `permission_denied` is invoked
    with the permission's optional `message` and `code` attributes.
    """
    for perm in self.get_permissions():
        if perm.has_object_permission(request, self, obj):
            continue
        denial_message = getattr(perm, 'message', None)
        denial_code = getattr(perm, 'code', None)
        self.permission_denied(request, message=denial_message, code=denial_code)
check_permissions¶
Check if the request should be permitted.
Raises an appropriate exception if the request is not permitted.
View Source
def check_permissions(self, request):
    """
    Verify every configured permission against the request.

    For each permission that rejects the request, `permission_denied` is invoked
    with the permission's optional `message` and `code` attributes.
    """
    rejecting = (
        perm for perm in self.get_permissions()
        if not perm.has_permission(request, self)
    )
    # the generator is lazy: each permission is checked right before its denial is reported
    for perm in rejecting:
        self.permission_denied(
            request,
            message=getattr(perm, 'message', None),
            code=getattr(perm, 'code', None),
        )
check_throttles¶
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
View Source
def check_throttles(self, request):
    """
    Apply every configured throttle to the request.

    If at least one throttle rejects the request, `throttled` is invoked with the
    longest known wait time (or `None` if no throttle reported one).
    """
    reported_waits = [
        throttle.wait()
        for throttle in self.get_throttles()
        if not throttle.allow_request(request, self)
    ]
    if not reported_waits:
        return

    # Drop `None` values which may happen in case of config / rate
    # changes, see #1438
    known_waits = (wait for wait in reported_waits if wait is not None)
    self.throttled(request, max(known_waits, default=None))
create_local_model¶
def create_local_model(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Upload a partial trained model file from client.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | str | model uuid of the model, which was used for training | None |
Returns:
Type | Description |
---|---|
HttpResponse | upload success message as json response |
View Source
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "owner": {"type": "string"},
                "round": {"type": "int"},
                "sample_size": {"type": "int"},
                "metric_names": {"type": "list[string]"},
                "metric_values": {"type": "list[float]"},
                "model_file": {"type": "string", "format": "binary"},
            },
        },
    },
    # NOTE(review): the view answers with 201 and a `detail` message, not with a
    #               serialized model; the documented response should be confirmed
    responses={
        status.HTTP_200_OK: ModelSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_local_model(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload a partial trained model file from client.

    The local model is stored for the current round of the global model. Optional metrics
    are stored alongside. Once the number of local models for the round equals the number
    of training participants, the `TrainingRoundFinished` trainer task is dispatched.

    Args:
        request (HttpRequest): request object
        id (str): model uuid of the model, which was used for training

    Raises:
        NotFound: If no training process exists for the model.
        ParseError: If a required field is missing, or `round`/`sample_size` is not
            a single integer value.

    Returns:
        HttpResponse: upload success message as json response
    """
    try:
        formdata = dict(request.POST)
        try:
            # each field must carry exactly one value that parses as an integer
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
        except ValueError as e:
            # malformed client input is a 400, not an internal server error
            raise ParseError("Fields 'round' and 'sample_size' must each be a single integer value.") from e
        client = request.user
        model_file = get_file(request, "model_file")
        global_model = get_entity(GlobalModelDB, pk=id)

        # ensure that a training process coresponding to the model exists, else the process will error out
        training = Training.objects.get(model=global_model)
        self._verify_valid_update(client, training, round_num, TrainingState.ONGOING)

        verify_model_object(model_file)

        local_model = LocalModelDB.objects.create(
            base_model=global_model, weights=model_file,
            round=round_num, owner=client, sample_size=sample_size
        )
        self._metric_upload(formdata, local_model, client, metrics_required=False)

        updates = LocalModelDB.objects.filter(base_model=global_model, round=round_num)
        if updates.count() == training.participants.count():
            dispatch_trainer_task(training, TrainingRoundFinished, True)

        return JsonResponse({"detail": "Model Update Accepted"}, status=status.HTTP_201_CREATED)
    except Training.DoesNotExist:
        raise NotFound(f"Model with ID {id} does not have a training process running")
    except (MultiValueDictKeyError, KeyError) as e:
        raise ParseError(e)
create_model¶
def create_model(
self,
request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Upload a global model file.
The model file should be a PyTorch serialized model.
Providing the model via `torch.save` as well as in TorchScript format is supported.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The incoming request object. | None |
Returns:
Type | Description |
---|---|
HttpResponse | upload success message as json response |
View Source
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "description": {"type": "string"},
                "model_file": {"type": "string", "format": "binary"},
                "model_preprocessing_file": {"type": "string", "format": "binary", "required": "false"},
            },
        },
    },
    responses={
        status.HTTP_201_CREATED: inline_serializer("ModelUploadSerializer", fields={
            "detail": CharField(default="Model Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_model(self, request: HttpRequest) -> HttpResponse:
    """
    Create a global model from an uploaded model file.

    The model file should be a PyTorch serialized model.
    Providing the model via `torch.save` as well as in TorchScript format is supported.
    Parsing of the request and creation of the model is delegated to
    `load_and_create_model_request`.

    Args:
        request (HttpRequest): The incoming request object.

    Returns:
        HttpResponse: upload success message (including the new model id) as json response
    """
    created_model = load_and_create_model_request(request)
    body = {
        "detail": "Model Upload Accepted",
        "model_id": str(created_model.id),
    }
    return JsonResponse(body, status=status.HTTP_201_CREATED)
create_model_metrics¶
def create_model_metrics(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Upload model metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | str | model uuid | None |
Returns:
Type | Description |
---|---|
HttpResponse | upload success message as json response |
View Source
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "metric_names": {"type": "list"},
                "metric_values": {"type": "list"},
            },
        },
    },
    responses={
        # the view answers with 201 Created, so the schema documents that status
        status.HTTP_201_CREATED: inline_serializer("MetricUploadResponseSerializer", fields={
            "detail": CharField(default="Model Metrics Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        OpenApiExample("Example", value={
            "metric_names": ["accuracy", "training loss"],
            "metric_values": [0.6, 0.04]
        }, media_type="multipart/form-data")
    ]
)
def create_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload model metrics.

    The metrics are stored for the model. If the model is a global model that belongs to a
    training and the number of distinct reporters for the current round equals the number
    of training participants, the `ModelTestFinished` trainer task is dispatched.

    Args:
        request (HttpRequest): request object
        id (str): model uuid

    Raises:
        ParseError: If metric names or values are missing or inconsistent
            (raised by `_metric_upload`).

    Returns:
        HttpResponse: upload success message as json response
    """
    model = get_entity(ModelDB, pk=id)
    formdata = dict(request.POST)

    # storing and counting happen under the same lock
    with locked_atomic_transaction(MetricDB):
        self._metric_upload(formdata, model, request.user)

        if isinstance(model, GlobalModelDB):
            # NOTE: `distinct(<field>)` (DISTINCT ON) is only supported by PostgreSQL in Django
            n_metrics = MetricDB.objects.filter(model=model, step=model.round).distinct("reporter").count()
            training = model.get_training()
            if training:
                if n_metrics == training.participants.count():
                    dispatch_trainer_task(training, ModelTestFinished, False)
            else:
                self._logger.warning(f"Global model {id} is not connected to any training.")

    return JsonResponse({
        "detail": "Model Metrics Upload Accepted",
        "model_id": str(model.id),
    }, status=status.HTTP_201_CREATED)
create_swag_stats¶
def create_swag_stats(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Upload SWAG statistics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | str | global model uuid | None |
Returns:
Type | Description |
---|---|
HttpResponse | upload success message as json response |
Raises:
Type | Description |
---|---|
APIException | internal server error |
NotFound | model not found |
ParseError | request data not valid |
View Source
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "round": {"type": "int"},
                "sample_size": {"type": "int"},
                "first_moment_file": {"type": "string", "format": "binary"},
                "second_moment_file": {"type": "string", "format": "binary"}
            },
        },
    },
    responses={
        status.HTTP_200_OK: inline_serializer("MetricUploadSerializer", fields={
            "detail": CharField(default="SWAg Statistics Accepted"),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_swag_stats(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload SWAG statistics.

    The first and second moment files and the sample size are stored as metrics of the
    global model. Once the number of first-moment metrics for the current round equals
    the number of training participants, the `SWAGRoundFinished` trainer task is dispatched.

    Args:
        request (HttpRequest): request object
        id (str): global model uuid

    Raises:
        APIException: internal server error
        NotFound: model not found
        ParseError: request data not valid
        PermissionDenied: client is not a participant of the training
        ValidationError: training state or round number does not match

    Returns:
        HttpResponse: upload success message as json response
    """
    try:
        client = request.user
        formdata = dict(request.POST)
        try:
            # each field must carry exactly one value that parses as an integer
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
        except ValueError as e:
            # malformed client input is a 400, not an internal server error
            raise ParseError("Fields 'round' and 'sample_size' must each be a single integer value.") from e
        fst_moment = get_file(request, "first_moment_file")
        snd_moment = get_file(request, "second_moment_file")
        model = get_entity(GlobalModelDB, pk=id)

        # ensure that a training process coresponding to the model exists, else the process will error out
        training = Training.objects.get(model=model)
        self._verify_valid_update(client, training, round_num, TrainingState.SWAG_ROUND)

        self._save_swag_stats(fst_moment, snd_moment, model, client, sample_size)

        swag_stats_first = MetricDB.objects.filter(model=model, step=model.round, key="SWAG First Moment Local")
        swag_stats_second = MetricDB.objects.filter(model=model, step=model.round, key="SWAG Second Moment Local")
        first_count = swag_stats_first.count()

        if first_count != swag_stats_second.count():
            training.state = TrainingState.ERROR
            # persist the error state; assigning the attribute alone is lost after the request
            training.save()
            raise APIException("SWAG stats in inconsistent state!")
        if first_count == training.participants.count():
            dispatch_trainer_task(training, SWAGRoundFinished, True)

        return JsonResponse({"detail": "SWAG Statistic Accepted"}, status=status.HTTP_201_CREATED)
    except Training.DoesNotExist:
        raise NotFound(f"Model with ID {id} does not have a training process running")
    except (MultiValueDictKeyError, KeyError) as e:
        raise ParseError(e)
    except APIException:
        # keep the status code of API errors (e.g. 400/403) instead of
        # rewrapping them as a generic internal server error below
        raise
    except Exception as e:
        raise APIException(e) from e
determine_version¶
If versioning is being used, then determine any API version for the
incoming request. Returns a two-tuple of (version, versioning_scheme)
View Source
def determine_version(self, request, *args, **kwargs):
    """
    Resolve the API version of the incoming request.

    Returns:
        tuple: `(version, versioning_scheme)`; `(None, None)` if the view has no versioning class.
    """
    if self.versioning_class is not None:
        versioning = self.versioning_class()
        version = versioning.determine_version(request, *args, **kwargs)
        return (version, versioning)
    return (None, None)
dispatch¶
.dispatch()
is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling.
View Source
def dispatch(self, request, *args, **kwargs):
    """
    Entry point of the view: wrap the request, pick the handler and build the response.

    Works like Django's regular dispatch, with additional hooks for startup (`initial`),
    exception handling (`handle_exception`) and finalization (`finalize_response`).
    """
    self.args = args
    self.kwargs = kwargs
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?

    try:
        self.initial(request, *args, **kwargs)

        # Look up the handler named after the HTTP method, if that method is supported
        method_name = request.method.lower()
        handler = self.http_method_not_allowed
        if method_name in self.http_method_names:
            handler = getattr(self, method_name, self.http_method_not_allowed)

        response = handler(request, *args, **kwargs)
    except Exception as exc:
        response = self.handle_exception(exc)

    self.response = self.finalize_response(request, response, *args, **kwargs)
    return self.response
finalize_response¶
Returns the final response object.
View Source
def finalize_response(self, request, response, *args, **kwargs):
    """
    Attach renderer information and the view's default headers to the response and return it.
    """
    # Make the error obvious if a proper response is not returned
    assert isinstance(response, HttpResponseBase), (
        'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
        'to be returned from the view, but received a `%s`'
        % type(response)
    )

    if isinstance(response, Response):
        if not getattr(request, 'accepted_renderer', None):
            # no renderer negotiated yet: force a negotiation result
            renderer, media_type = self.perform_content_negotiation(request, force=True)
            request.accepted_renderer = renderer
            request.accepted_media_type = media_type

        response.accepted_renderer = request.accepted_renderer
        response.accepted_media_type = request.accepted_media_type
        response.renderer_context = self.get_renderer_context()

    # Add new vary headers to the response instead of overwriting.
    vary = self.headers.pop('Vary', None)
    if vary is not None:
        patch_vary_headers(response, cc_delim_re.split(vary))

    for header_name, header_value in self.headers.items():
        response[header_name] = header_value

    return response
get_authenticate_header¶
If a request is unauthenticated, determine the WWW-Authenticate
header to use for 401 responses, if any.
View Source
get_authenticators¶
Get the authenticators for the ViewSet.
This method gets the view method and, if it has authentication classes defined via the decorator, returns them. Otherwise, it falls back to the default authenticators.
Returns:
Type | Description |
---|---|
list | The authenticators for the ViewSet. |
View Source
def get_authenticators(self):
    """
    Get the authenticators for the ViewSet.

    If the resolved view method carries an `authentication_classes` attribute, that attribute is returned.
    Otherwise, the default authenticators of the parent class are used.

    Returns:
        list: The authenticators for the ViewSet.
    """
    view_method = self._get_view_method()
    if view_method and hasattr(view_method, "authentication_classes"):
        return view_method.authentication_classes
    return super().get_authenticators()
get_content_negotiator¶
Instantiate and return the content negotiation class to use.
View Source
get_exception_handler¶
Returns the exception handler that this view uses.
View Source
get_exception_handler_context¶
Returns a dict that is passed through to EXCEPTION_HANDLER,
as the context
argument.
View Source
get_extra_action_url_map¶
Build a map of {names: urls} for the extra actions.
This method will noop if detail
was not provided as a view initkwarg.
View Source
def get_extra_action_url_map(self):
    """
    Build a map of {names: urls} for the extra actions.

    Returns an empty map if `detail` was not provided as a view initkwarg.
    """
    url_map = OrderedDict()

    # exit early if `detail` has not been provided
    if self.detail is None:
        return url_map

    # keep only the extra actions matching the view's `detail` flag
    matching_actions = [
        extra for extra in self.get_extra_actions()
        if extra.detail == self.detail
    ]

    for extra in matching_actions:
        try:
            name = '%s-%s' % (self.basename, extra.url_name)
            namespace = self.request.resolver_match.namespace
            if namespace:
                name = '%s:%s' % (namespace, name)

            resolved_url = reverse(name, self.args, self.kwargs, request=self.request)
            action_view = self.__class__(**extra.kwargs)
            url_map[action_view.get_view_name()] = resolved_url
        except NoReverseMatch:
            continue  # URL requires additional arguments, ignore

    return url_map
get_format_suffix¶
Determine if the request includes a '.json' style format suffix
View Source
get_metadata¶
def get_metadata(
self,
_request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Get model meta data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The incoming request object. | None |
id | str | The unique identifier of the model. | None |
Returns:
Type | Description |
---|---|
HttpResponse | Model meta data as JSON response. |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeightsWithStats(),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_metadata(self, _request: HttpRequest, id: str) -> HttpResponse:
    """
    Get model meta data.

    Args:
        _request (HttpRequest): The incoming request object (unused).
        id (str): The unique identifier of the model.

    Returns:
        HttpResponse: Model meta data as JSON response.
    """
    entity = get_entity(ModelDB, pk=id)
    # the "with-stats" context flag is passed on to the serializer
    payload = ModelSerializer(entity, context={"with-stats": True}).data
    return Response(payload)
get_model¶
def get_model(
self,
_request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponseBase
Download the whole model as PyTorch serialized file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The incoming request object. | None |
id | str | The unique identifier of the model. | None |
Returns:
Type | Description |
---|---|
HttpResponseBase | model as file response |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: OpenApiResponse(response=bytes, description="Model is returned as bytes"),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_model(self, _request: HttpRequest, id: str) -> HttpResponseBase:
    """
    Download the whole model as PyTorch serialized file.

    Args:
        _request (HttpRequest): The incoming request object (unused).
        id (str): The unique identifier of the model.

    Raises:
        APIException: If a SWAG model has a first moment but no second moment.
        NotImplementedError: If a SWAG model with both moments is requested.

    Returns:
        HttpResponseBase: model as file response
    """
    db_model = get_entity(ModelDB, pk=id)

    has_swag_stats = isinstance(db_model, SWAGModelDB) and db_model.swag_first_moment is not None
    if has_swag_stats:
        if db_model.swag_second_moment is None:
            raise APIException(f"Model {db_model.id} is in inconsistent state!")
        raise NotImplementedError(
            "SWAG models need to be returned in 3 parts: model architecture, first moment, second moment"
        )

    # NOTE: FileResponse does strange stuff with bytes
    #       and in case of sqlite the weights will be bytes and not a memoryview
    file_response = HttpResponse(db_model.weights, content_type="application/octet-stream")
    file_response["Content-Disposition"] = f'filename="model-{id}.pt"'
    return file_response
get_model_metrics¶
def get_model_metrics(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Reports all metrics for the selected model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | str | model UUID | None |
Returns:
Type | Description |
---|---|
HttpResponse | Metrics as JSON Array |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: MetricSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Report all metrics stored for the selected model.

    Args:
        request (HttpRequest): request object
        id (str): model UUID

    Returns:
        HttpResponse: Metrics as JSON Array
    """
    selected_model = get_entity(ModelDB, pk=id)
    model_metrics = MetricDB.objects.filter(model=selected_model).all()
    serialized = MetricSerializer(model_metrics, many=True)
    return Response(serialized.data)
get_model_proprecessing¶
def get_model_proprecessing(
self,
_request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponseBase
Download the whole preprocessing model as PyTorch serialized file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The incoming request object. | None |
id | str | The unique identifier of the model. | None |
Returns:
Type | Description |
---|---|
HttpResponseBase | preprocessing model as file response or 404 if preprocessing model not found |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: OpenApiResponse(
            response=bytes,
            description="Proprecessing model is returned as bytes"
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
        status.HTTP_404_NOT_FOUND: error_response_404,
    },
)
def get_model_proprecessing(self, _request: HttpRequest, id: str) -> HttpResponseBase:
    """
    Download the whole preprocessing model as PyTorch serialized file.

    For a local model, the preprocessing model of its base (global) model is returned.

    Args:
        _request (HttpRequest): The incoming request object (unused).
        id (str): The unique identifier of the model.

    Raises:
        ValidationError: If the model is neither a global nor a local model.
        NotFound: If no preprocessing model is defined.

    Returns:
        HttpResponseBase: preprocessing model as file response or 404 if preprocessing model not found
    """
    model = get_entity(ModelDB, pk=id)
    # the variable holds a database model, not a torch module
    global_model: GlobalModelDB
    if isinstance(model, GlobalModelDB):
        global_model = model
    elif isinstance(model, LocalModelDB):
        global_model = model.base_model
    else:
        self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel.")
        raise ValidationError(f"Unknown model type. Model id: {id}")

    if global_model.preprocessing is None:
        raise NotFound(f"Model '{id}' has no preprocessing model defined.")

    # NOTE: FileResponse does strange stuff with bytes
    #       and in case of sqlite the weights will be bytes and not a memoryview
    response = HttpResponse(global_model.preprocessing, content_type="application/octet-stream")
    response["Content-Disposition"] = f'filename="model-{id}-proprecessing.pt"'
    return response
get_models¶
def get_models(
self,
request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Get a list of all global models associated with the requesting user.
A global model is deemed associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The incoming request object. | None |
Returns:
Type | Description |
---|---|
HttpResponse | Model list as JSON response. |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_models(self, request: HttpRequest) -> HttpResponse:
    """
    List all global models associated with the requesting user.

    A global model is deemed associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.

    Args:
        request (HttpRequest): The incoming request object.

    Returns:
        HttpResponse: Model list as JSON response.
    """
    user_models = self._get_user_related_global_models(request.user)
    return Response(ModelSerializer(user_models, many=True).data)
get_parser_context¶
Returns a dict that is passed through to Parser.parse(),
as the parser_context
keyword argument.
View Source
def get_parser_context(self, http_request):
    """
    Build the dict that is handed to `Parser.parse()` as the `parser_context` keyword argument.
    """
    # Note: Additionally `request` and `encoding` will also be added
    #       to the context by the Request object.
    view_args = getattr(self, 'args', ())
    view_kwargs = getattr(self, 'kwargs', {})
    return {'view': self, 'args': view_args, 'kwargs': view_kwargs}
get_parsers¶
Instantiates and returns the list of parsers that this view can use.
View Source
get_permissions¶
Get the permissions for the ViewSet.
This method gets the view method and, if it has permission classes defined via the decorator, returns them. Otherwise, it falls back to the default permissions.
Returns:
Type | Description |
---|---|
list | The permissions for the ViewSet. |
View Source
def get_permissions(self):
    """
    Get the permissions for the ViewSet.

    If the resolved view method carries a `permission_classes` attribute, that attribute is returned.
    Otherwise, the default permissions of the parent class are used.

    Returns:
        list: The permissions for the ViewSet.
    """
    view_method = self._get_view_method()
    if view_method and hasattr(view_method, "permission_classes"):
        return view_method.permission_classes
    return super().get_permissions()
get_renderer_context¶
Returns a dict that is passed through to Renderer.render(),
as the renderer_context
keyword argument.
View Source
def get_renderer_context(self):
    """
    Build the dict that is handed to `Renderer.render()` as the `renderer_context` keyword argument.
    """
    # Note: Additionally 'response' will also be added to the context,
    #       by the Response object.
    context = {'view': self}
    context['args'] = getattr(self, 'args', ())
    context['kwargs'] = getattr(self, 'kwargs', {})
    context['request'] = getattr(self, 'request', None)
    return context
get_renderers¶
Instantiates and returns the list of renderers that this view can use.
View Source
get_throttles¶
Instantiates and returns the list of throttles that this view uses.
View Source
get_training_models¶
def get_training_models(
self,
request: django.http.request.HttpRequest,
training_id: str
) -> django.http.response.HttpResponse
Get a list of all models associated with a specific training process and the requesting user.
A model is deemed associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The incoming request object. | None |
training_id | str | The unique identifier of the training process. | None |
Returns:
Type | Description |
---|---|
HttpResponse | Model list as JSON response. |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_training_models(self, request: HttpRequest, training_id: str) -> HttpResponse:
    """
    List all models of a specific training process that are associated with the requesting user.

    A model is deemed associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.

    Args:
        request (HttpRequest): The incoming request object.
        training_id (str): The unique identifier of the training process.

    Returns:
        HttpResponse: Model list as JSON response.
    """
    user_globals = self._get_user_related_global_models(request.user)
    training_globals = self._filter_by_training(user_globals, training_id)
    training_locals = self._get_local_models_for_global_models(training_globals)
    # global models first, followed by the local models
    combined = [*training_globals, *training_locals]
    return Response(ModelSerializer(combined, many=True).data)
get_training_models_latest¶
def get_training_models_latest(
self,
request: django.http.request.HttpRequest,
training_id: str
) -> django.http.response.HttpResponse
Get a list of the latest models for a specific training process associated with the requesting user.
A model is considered associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process. The latest model refers to the model from the most recent round (highest round number) of a participant's training process.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The incoming request object. | None |
training_id | str | The unique identifier of the training process. | None |
Returns:
Type | Description |
---|---|
HttpResponse | Model list as JSON response. |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_training_models_latest(self, request: HttpRequest, training_id: str) -> HttpResponse:
    """
    Get a list of the latest models for a specific training process associated with the requesting user.

    A model is considered associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.
    The latest model refers to the model from the most recent round (highest round number) of
    a participant's training process.

    If no global model matches, no global model is added to the list instead of failing
    with an internal server error.

    Args:
        request (HttpRequest): The incoming request object.
        training_id (str): The unique identifier of the training process.

    Returns:
        HttpResponse: Model list as JSON response.
    """
    models: List[ModelDB] = []
    # add latest global model
    global_models = self._get_user_related_global_models(request.user)
    global_models = self._filter_by_training(global_models, training_id)
    # `max` raises ValueError on an empty iterable unless a default is given
    latest_global = max(global_models, key=lambda m: m.round, default=None)
    if latest_global is not None:
        models.append(latest_global)
    # add latest local models
    local_models = self._get_local_models_for_global_models(global_models)
    local_models = sorted(local_models, key=lambda m: str(m.owner.pk))  # required for groupby
    for _, group in groupby(local_models, key=lambda m: str(m.owner.pk)):
        # every group holds at least one model, hence no default is needed here
        models.append(max(group, key=lambda m: m.round))
    serializer = ModelSerializer(models, many=True)
    return Response(serializer.data)
get_view_description¶
Return some descriptive text for the view, as used in OPTIONS responses
and in the browsable API.
View Source
get_view_name¶
Return the view name, as used in OPTIONS responses and in the
browsable API.
View Source
handle_exception¶
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
View Source
def handle_exception(self, exc):
    """
    Turn an exception into a response via the configured exception handler,
    or re-raise it if the handler does not produce a response.
    """
    authentication_errors = (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)
    if isinstance(exc, authentication_errors):
        # WWW-Authenticate header for 401 responses, else coerce to 403
        auth_header = self.get_authenticate_header(self.request)
        if not auth_header:
            exc.status_code = status.HTTP_403_FORBIDDEN
        else:
            exc.auth_header = auth_header

    handler = self.get_exception_handler()
    handler_context = self.get_exception_handler_context()
    response = handler(exc, handler_context)

    if response is None:
        self.raise_uncaught_exception(exc)

    response.exception = True
    return response
http_method_not_allowed¶
If request.method
does not correspond to a handler method,
determine what kind of exception to raise.
View Source
initial¶
Runs anything that needs to occur prior to calling the method handler.
View Source
def initial(self, request, *args, **kwargs):
    """
    Run everything that has to happen before the method handler is called.
    """
    self.format_kwarg = self.get_format_suffix(**kwargs)

    # Perform content negotiation and store the accepted info on the request
    renderer, media_type = self.perform_content_negotiation(request)
    request.accepted_renderer = renderer
    request.accepted_media_type = media_type

    # Determine the API version, if versioning is in use.
    api_version, versioning_scheme = self.determine_version(request, *args, **kwargs)
    request.version = api_version
    request.versioning_scheme = versioning_scheme

    # Ensure that the incoming request is permitted
    self.perform_authentication(request)
    self.check_permissions(request)
    self.check_throttles(request)
initialize_request¶
Set the .action
attribute on the view, depending on the request method.
View Source
def initialize_request(self, request, *args, **kwargs):
    """
    Wrap the request via the parent class and set the `.action` attribute on the view,
    depending on the request method.
    """
    request = super().initialize_request(request, *args, **kwargs)
    http_method = request.method.lower()
    # 'options' is a special case as we always provide handling for the
    # options method in the base `View` class.
    # Unlike the other explicitly defined actions, 'metadata' is implicit.
    self.action = 'metadata' if http_method == 'options' else self.action_map.get(http_method)
    return request
options¶
Handler method for HTTP 'OPTIONS' request.
View Source
def options(self, request, *args, **kwargs):
    """
    Handler method for HTTP 'OPTIONS' request.

    Answers with the metadata determined by the view's metadata class, or delegates to
    `http_method_not_allowed` if no metadata class is configured.
    """
    if self.metadata_class is not None:
        metadata = self.metadata_class().determine_metadata(request, self)
        return Response(metadata, status=status.HTTP_200_OK)
    return self.http_method_not_allowed(request, *args, **kwargs)
perform_authentication¶
Perform authentication on the incoming request.
Note that if you override this and simply 'pass', then authentication
will instead be performed lazily, the first time either
request.user
or request.auth
is accessed.
View Source
perform_content_negotiation¶
Determine which renderer and media type to use to render the response.
View Source
def perform_content_negotiation(self, request, force=False):
    """
    Determine which renderer and media type to use to render the response.

    With `force=True`, a failing negotiation falls back to the first available renderer
    instead of propagating the error.
    """
    available = self.get_renderers()
    negotiator = self.get_content_negotiator()

    try:
        return negotiator.select_renderer(request, available, self.format_kwarg)
    except Exception:
        if not force:
            raise
        fallback = available[0]
        return (fallback, fallback.media_type)
permission_denied¶
If request is not permitted, determine what kind of exception to raise.
View Source
def permission_denied(self, request, message=None, code=None):
    """
    Raise the exception matching a request that is not permitted.

    Raises:
        NotAuthenticated: If authenticators are configured but none authenticated the request.
        PermissionDenied: In every other case.
    """
    not_authenticated = request.authenticators and not request.successful_authenticator
    if not_authenticated:
        raise exceptions.NotAuthenticated()
    raise exceptions.PermissionDenied(detail=message, code=code)
raise_uncaught_exception¶
View Source
remove_model¶
def remove_model(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Remove an existing model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The incoming request object. | None |
id | str | The unique identifier of the model. | None |
Returns:
Type | Description |
---|---|
HttpResponse | 200 Response if model was removed, else corresponding error code |
View Source
@extend_schema(responses={
    status.HTTP_200_OK: inline_serializer(
        "DeleteModelSuccessSerializer",
        fields={
            "detail": CharField(default="Model removed!")
        }
    ),
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def remove_model(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Remove an existing model.

    Only the owner of the model or the actor of the corresponding training may remove it.

    Args:
        request (HttpRequest): The incoming request object.
        id (str): The unique identifier of the model.

    Raises:
        PermissionDenied: If the user is neither the model owner nor the actor of its training.

    Returns:
        HttpResponse: 200 Response if model was removed, else corresponding error code
    """
    target = get_entity(ModelDB, pk=id)
    if target.owner != request.user:
        # not the owner: the actor of the related training is allowed as well
        related_training = target.get_training()
        if related_training is None or related_training.actor != request.user:
            raise PermissionDenied(
                "You are neither the owner of the model nor the actor of the corresponding training."
            )
    target.delete()
    return JsonResponse({"detail": "Model removed!"})
reverse_action¶
Reverse the action for the given url_name
.
View Source
def reverse_action(self, url_name, *args, **kwargs):
    """
    Reverse the action for the given `url_name`.

    The name is prefixed with the view's basename and, if the current request has a resolver match
    with a namespace, with that namespace as well.
    """
    full_name = '%s-%s' % (self.basename, url_name)

    current_request = self.request
    namespace = None
    if current_request and current_request.resolver_match:
        namespace = current_request.resolver_match.namespace
    if namespace:
        full_name = namespace + ':' + full_name

    kwargs.setdefault('request', self.request)
    return reverse(full_name, *args, **kwargs)
setup¶
Initialize attributes shared by all view methods.
View Source
throttled¶
If request is throttled, determine what kind of exception to raise.
View Source
upload_model_preprocessing¶
def upload_model_preprocessing(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Upload a preprocessing model file for a global model.
The preprocessing model file should be a PyTorch serialized model.
Providing the model via torch.save
as well as in TorchScript format is supported.
transforms = torch.nn.Sequential(
torchvision.transforms.CenterCrop(10),
torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
Make sure to only use transformations that inherit from torch.nn.Module
.
It is advised to use the torchvision.transforms.v2
module for common transformations.
Please note that this function is still in the beta phase.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | str | global model UUID | None |
Returns:
Type | Description |
---|---|
HttpResponse | upload success message as json response |
Raises:
Type | Description |
---|---|
PermissionDenied | Unauthorized to upload preprocessing model for the specified model |
ValidationError | Preprocessing model is not a valid torch model |
View Source
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "model_preprocessing_file": {"type": "string", "format": "binary"},
            },
        },
    },
    responses={
        status.HTTP_202_ACCEPTED: inline_serializer("PreprocessingModelUploadSerializer", fields={
            "detail": CharField(default="Proprocessing Model Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def upload_model_preprocessing(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload a preprocessing model file for a global model.

    The preprocessing model file should be a PyTorch serialized model.
    Providing the model via `torch.save` as well as in TorchScript format is supported.

    ```python
    transforms = torch.nn.Sequential(
        torchvision.transforms.CenterCrop(10),
        torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    )
    ```

    Make sure to only use transformations that inherit from `torch.nn.Module`.
    It is advised to use the `torchvision.transforms.v2` module for common transformations.

    Please note that this function is still in the beta phase.

    Args:
        request (HttpRequest): request object
        id (str): global model UUID

    Raises:
        PermissionDenied: Unauthorized to upload preprocessing model for the specified model
        ValidationError: Preprocessing model is not a valid torch model

    Returns:
        HttpResponse: upload success message and the model id as json response
    """
    model = get_entity(GlobalModelDB, pk=id)
    if request.user.id != model.owner.id:
        raise PermissionDenied(f"You are not the owner of model {model.id}!")
    model.preprocessing = get_file(request, "model_preprocessing_file")
    # the uploaded object is verified before anything is written to the database
    verify_model_object(model.preprocessing, "preprocessing")
    model.save()
    return JsonResponse({
        "detail": "Proprocessing Model Upload Accepted",
        # the documented response schema declares `model_id`, so it has to be part of the payload
        "model_id": str(model.id),
    }, status=status.HTTP_202_ACCEPTED)
Training¶
Training model ViewSet.
This ViewSet is used to create and manage trainings.
View Source
class Training(ViewSet):
"""
Training model ViewSet.
This ViewSet is used to create and manage trainings.
"""
serializer_class = TrainingSerializer
"""The serializer for the ViewSet."""
def _check_user_permission_for_training(self, user: UserDB, training_id: UUID | str) -> TrainingDB:
    """
    Check if a user has permission for a training.

    This method checks if the user is the actor of the training or a participant in the training.

    Args:
        user (UserDB): The user.
        training_id (UUID | str): The ID of the training.

    Raises:
        ParseError: If `training_id` is a string that is not a valid UUID.
        PermissionDenied: If the user is neither the actor nor a participant of the training.

    Returns:
        TrainingDB: The training.
    """
    if isinstance(training_id, str):
        try:
            training_id = UUID(training_id)
        except ValueError as e:
            # a malformed id is a client error; without this it surfaces as an unhandled ValueError
            raise ParseError(f"'{training_id}' is not a valid UUID.") from e
    training = get_entity(TrainingDB, pk=training_id)
    if training.actor != user and user not in training.participants.all():
        raise PermissionDenied()
    return training
def _get_clients_from_body(self, body_raw: bytes) -> list[UserDB]:
    """
    Get clients or participants from a request body.

    The body is parsed with `ClientAdministrationBodySchema`; the UUIDs of its `clients` field
    are then resolved to users.

    Args:
        body_raw (bytes): The raw request body.

    Returns:
        list[UserDB]: The clients.
    """
    schema = ClientAdministrationBodySchema()
    parsed_body: ClientAdministrationBody = self._load_marshmallow_request(schema, body_raw)
    client_ids = parsed_body.clients
    return self._get_clients_from_uuid_list(client_ids)
def _get_clients_from_uuid_list(self, uuids: list[UUID]) -> list[UserDB]:
    """
    Get clients from a list of UUIDs.

    This method gets the clients with the IDs in the list of UUIDs from the database.
    UUIDs that occur more than once in the list are looked up (and counted) only once.

    Args:
        uuids (list[UUID]): The list of UUIDs.

    Raises:
        ParseError: If at least one of the UUIDs does not belong to an existing user.

    Returns:
        list[UserDB]: The clients.
    """
    if not uuids:
        return []
    # Note: filter "in" does not raise UserDB.DoesNotExist exceptions
    clients = UserDB.objects.filter(id__in=uuids)
    # the query yields every user at most once, so compare against the number of distinct ids;
    # otherwise a repeated id would be reported as a missing user
    if len(clients) != len(set(uuids)):
        raise ParseError("Not all provided users were found!")
    return clients
def _load_marshmallow_request(self, schema: Schema, json_data: str | bytes | bytearray):
    """
    Load JSON data from a request using a Marshmallow schema.

    Args:
        schema (Schema): The Marshmallow schema to use for loading the request.
        json_data (str | bytes | bytearray): The JSON data to load.

    Raises:
        ParseError: If the data is not valid JSON or a MarshmallowValidationError occurs.

    Returns:
        Any: The result of `schema.load` for the decoded JSON data.
    """
    try:
        decoded = json.loads(json_data)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # malformed request bodies are client errors and must not end up as internal server errors
        raise ParseError(f"Request body is not valid JSON: {e}") from e
    try:
        return schema.load(decoded)  # should `schema.loads` be used instead?
    except MarshmallowValidationError as e:
        raise ParseError(e.messages) from e
@extend_schema(responses={
    status.HTTP_200_OK: TrainingSerializer(many=True),
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def get_trainings(self, request: HttpRequest) -> HttpResponse:
    """
    Get information about all trainings whose actor is the requesting user.

    Args:
        request (HttpRequest): request object

    Returns:
        HttpResponse: list of training data as json response
    """
    owned_trainings = TrainingDB.objects.filter(actor=request.user)
    return Response(TrainingSerializer(owned_trainings, many=True).data)
@extend_schema(responses={
    status.HTTP_200_OK: TrainingSerializerWithRounds,
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def get_training(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Get information about the selected training.

    The permission of the requesting user for the training is checked first.

    Args:
        request (HttpRequest): request object
        id (str): training uuid

    Returns:
        HttpResponse: training data as json response
    """
    selected_training = self._check_user_permission_for_training(request.user, id)
    payload = TrainingSerializerWithRounds(selected_training).data
    return Response(payload)
@extend_schema(
    request=inline_serializer("EmptyBodySerializer", fields={}),
    responses={
        # The handler answers with 202 Accepted, so document that status code.
        status.HTTP_202_ACCEPTED: inline_serializer(
            "StartTrainingSuccessSerializer",
            fields={
                "detail": CharField(default="Training started!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def start_training(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Start a training process.

    This method checks if there are any participants registered for the training process.
    If there are participants, it checks if the training process is in the INITIAL state and starts the training
    session.

    Args:
        request (HttpRequest): The request object, which includes information about the user making the request.
        id (str): The UUID of the training process to start.

    Raises:
        ParseError: If there are no participants registered for the training process or if the training process
            is not in the INITIAL state.

    Returns:
        HttpResponse: A 202 JSON response indicating that the training process has started.
    """
    training = self._check_user_permission_for_training(request.user, id)
    if training.participants.count() == 0:
        raise ParseError("At least one participant must be registered!")
    if training.state != TrainingState.INITIAL:
        raise ParseError(f"Training {training.id} is not in state INITIAL!")
    ModelTrainer(training).start()
    return JsonResponse({"detail": "Training started!"}, status=status.HTTP_202_ACCEPTED)
@extend_schema(
    request=inline_serializer(
        "RegisterClientsSerializer",
        fields={
            "clients": ListField(child=UUIDField())
        }
    ),
    responses={
        # The handler answers with 202 Accepted, so document that status code.
        status.HTTP_202_ACCEPTED: inline_serializer(
            "RegisteredClientsSuccessSerializer",
            fields={
                "detail": CharField(default="Users registered as participants!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def register_clients(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Register one or more clients for a training process.

    This method is designed to be called by a POST request with a JSON body of the form
    `{"clients": [<list of UUIDs>]}`.
    It adds these clients as participants of the training process.

    Note: This method should be called once before the training process is started.

    Args:
        request (HttpRequest): The request object.
        id (str): The UUID of the training process.

    Returns:
        HttpResponse: 202 Response if clients were registered, else corresponding error code.
    """
    train = self._check_user_permission_for_training(request.user, id)
    clients = self._get_clients_from_body(request.body)
    train.participants.add(*clients)
    return JsonResponse({"detail": "Users registered as participants!"}, status=status.HTTP_202_ACCEPTED)
@extend_schema(responses={
    status.HTTP_200_OK: inline_serializer(
        "DeleteTrainingSuccessSerializer",
        fields={
            "detail": CharField(default="Training removed!")
        }
    ),
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def remove_training(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Delete a training process on behalf of its owner.

    Args:
        request (HttpRequest): request object
        id (str): training uuid

    Raises:
        PermissionDenied: If the requesting user is not the actor of the training.

    Returns:
        HttpResponse: 200 Response if training was removed, else corresponding error code
    """
    target = get_entity(TrainingDB, pk=id)
    owner = target.actor
    # only the actor that created the training may delete it
    if owner != request.user:
        raise PermissionDenied("You are not the owner the training.")
    target.delete()
    confirmation = {"detail": "Training removed!"}
    return JsonResponse(confirmation)
@extend_schema(
    request=inline_serializer(
        "RemoveClientsSerializer",
        fields={
            "clients": ListField(child=UUIDField())
        }
    ),
    responses={
        status.HTTP_200_OK: inline_serializer(
            "RemovedClientsSuccessSerializer",
            fields={
                "detail": CharField(default="Users removed from training participants!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def remove_clients(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Drop one or more clients from the participants of a training process.

    The clients are read from the request body and removed from an already
    existing training process.

    Args:
        request (HttpRequest): The request object.
        id (str): The UUID of the training process.

    Returns:
        HttpResponse: 200 Response if clients were removed, else corresponding error code.
    """
    training = self._check_user_permission_for_training(request.user, id)
    departing = self._get_clients_from_body(request.body)
    training.participants.remove(*departing)
    confirmation = {"detail": "Users removed from training participants!"}
    return JsonResponse(confirmation)
@extend_schema(
    request=inline_serializer(
        name="TrainingCreationSerializer",
        fields={
            "model_id": CharField(),
            "target_num_updates": IntegerField(),
            "metric_names": ListField(child=CharField()),
            "aggregation_method": CharField(),
            "clients": ListField(child=UUIDField())
        }
    ),
    responses={
        # The handler answers with 201 Created, so document that status code.
        status.HTTP_201_CREATED: inline_serializer("TrainingCreatedSerializer", fields={
            "detail": CharField(default="Training created successfully!"),
            "training_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def create_training(self, request: HttpRequest) -> HttpResponse:
    """
    Create a new training process.

    This method is designed to be called by a POST request according to the `CreateTrainingRequestSchema`.
    The initial model is looked up by the `model_id` of the request and must be owned by the requesting user.
    If the model is already referenced by another training, a clone of it is used instead.

    Args:
        request (HttpRequest): The request object.

    Raises:
        ParseError: If the request body cannot be loaded or not all given clients exist.
        PermissionDenied: If the requesting user does not own the selected model.

    Returns:
        HttpResponse: 201 if training could be registered.
    """
    parsed_request: CreateTrainingRequest = self._load_marshmallow_request(
        CreateTrainingRequestSchema(),
        request.body.decode("utf-8")
    )
    model = get_entity(ModelDB, pk=parsed_request.model_id)
    if model.owner != request.user:
        raise PermissionDenied()
    # Validate the participants before a possible `clone_model` call, so that a request
    # with unknown client ids fails without having cloned the model first.
    clients = self._get_clients_from_uuid_list(parsed_request.clients)
    if TrainingDB.objects.filter(model=model).exists():
        # the selected model is already referenced by another training, so we need to copy it
        model = clone_model(model)
    train = TrainingDB.objects.create(
        model=model,
        actor=request.user,
        target_num_updates=parsed_request.target_num_updates,
        state=TrainingState.INITIAL,
        uncertainty_method=parsed_request.uncertainty_method.value,
        aggregation_method=parsed_request.aggregation_method.value,
        options=parsed_request.options
    )
    train.participants.add(*clients)
    return JsonResponse({
        "detail": "Training created successfully!",
        "training_id": train.id
    }, status=status.HTTP_201_CREATED)
Ancestors (in MRO)¶
- fl_server_api.views.base.ViewSet
- rest_framework.viewsets.ViewSet
- rest_framework.viewsets.ViewSetMixin
- rest_framework.views.APIView
- django.views.generic.base.View
Class variables¶
The serializer for the ViewSet.
Static methods¶
as_view¶
Because of the way class based views create a closure around the
instantiated view, we need to totally reimplement `.as_view`,
and slightly modify the view function that is created and returned.
View Source
@classonlymethod
def as_view(cls, actions=None, **initkwargs):
    """
    Because of the way class based views create a closure around the
    instantiated view, we need to totally reimplement `.as_view`,
    and slightly modify the view function that is created and returned.

    Args:
        actions: Mapping of HTTP method names to the names of the handler
            attributes they are bound to, e.g. `{'get': 'list'}`. Must not be empty.
        **initkwargs: Keyword arguments passed to the class on every request.
            Each key must be an existing attribute of the class and must not be
            an HTTP method name.

    Raises:
        TypeError: If `actions` is empty, if an initkwarg is an HTTP method name
            or not an attribute of the class, or if both `name` and `suffix`
            are given.

    Returns:
        The `csrf_exempt`-wrapped view function.
    """
    # The name and description initkwargs may be explicitly overridden for
    # certain route configurations. eg, names of extra actions.
    cls.name = None
    cls.description = None

    # The suffix initkwarg is reserved for displaying the viewset type.
    # This initkwarg should have no effect if the name is provided.
    # eg. 'List' or 'Instance'.
    cls.suffix = None

    # The detail initkwarg is reserved for introspecting the viewset type.
    cls.detail = None

    # Setting a basename allows a view to reverse its action urls. This
    # value is provided by the router through the initkwargs.
    cls.basename = None

    # actions must not be empty
    if not actions:
        raise TypeError("The `actions` argument must be provided when "
                        "calling `.as_view()` on a ViewSet. For example "
                        "`.as_view({'get': 'list'})`")

    # sanitize keyword arguments
    for key in initkwargs:
        if key in cls.http_method_names:
            raise TypeError("You tried to pass in the %s method name as a "
                            "keyword argument to %s(). Don't do that."
                            % (key, cls.__name__))
        if not hasattr(cls, key):
            raise TypeError("%s() received an invalid keyword %r" % (
                cls.__name__, key))

    # name and suffix are mutually exclusive
    if 'name' in initkwargs and 'suffix' in initkwargs:
        raise TypeError("%s() received both `name` and `suffix`, which are "
                        "mutually exclusive arguments." % (cls.__name__))

    def view(request, *args, **kwargs):
        # A fresh instance is created for every call of the view function.
        self = cls(**initkwargs)

        # Serve HEAD with the GET action unless HEAD is mapped explicitly.
        # Note that this writes into the `actions` mapping given to `as_view`.
        if 'get' in actions and 'head' not in actions:
            actions['head'] = actions['get']

        # We also store the mapping of request methods to actions,
        # so that we can later set the action attribute.
        # eg. `self.action = 'list'` on an incoming GET request.
        self.action_map = actions

        # Bind methods to actions
        # This is the bit that's different to a standard view
        for method, action in actions.items():
            handler = getattr(self, action)
            setattr(self, method, handler)

        self.request = request
        self.args = args
        self.kwargs = kwargs

        # And continue as usual
        return self.dispatch(request, *args, **kwargs)

    # take name and docstring from class
    update_wrapper(view, cls, updated=())

    # and possible attributes set by decorators
    # like csrf_exempt from dispatch
    update_wrapper(view, cls.dispatch, assigned=())

    # We need to set these on the view function, so that breadcrumb
    # generation can pick out these bits of information from a
    # resolved URL.
    view.cls = cls
    view.initkwargs = initkwargs
    view.actions = actions
    return csrf_exempt(view)
get_extra_actions¶
Get the methods that are marked as an extra ViewSet `@action`.
View Source
Instance variables¶
Wrap Django's private `_allowed_methods` interface in a public property.
Methods¶
check_object_permissions¶
Check if the request should be permitted for a given object.
Raises an appropriate exception if the request is not permitted.
View Source
def check_object_permissions(self, request, obj):
    """
    Run every permission's object-level check against `obj`.

    For each permission that refuses the object, `permission_denied` is called
    with the permission's optional `message` and `code` attributes.
    """
    for perm in self.get_permissions():
        if perm.has_object_permission(request, self, obj):
            continue
        denial_message = getattr(perm, 'message', None)
        denial_code = getattr(perm, 'code', None)
        self.permission_denied(request, message=denial_message, code=denial_code)
check_permissions¶
Check if the request should be permitted.
Raises an appropriate exception if the request is not permitted.
View Source
def check_permissions(self, request):
    """
    Run every permission's request-level check.

    For each permission that refuses the request, `permission_denied` is called
    with the permission's optional `message` and `code` attributes.
    """
    for perm in self.get_permissions():
        if perm.has_permission(request, self):
            continue
        denial_message = getattr(perm, 'message', None)
        denial_code = getattr(perm, 'code', None)
        self.permission_denied(request, message=denial_message, code=denial_code)
check_throttles¶
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
View Source
def check_throttles(self, request):
    """
    Ask every throttle whether the request is allowed.

    If at least one throttle refuses the request, `throttled` is called with
    the longest reported wait time (or `None` if no wait time is known).
    """
    reported_waits = [
        throttle.wait()
        for throttle in self.get_throttles()
        if not throttle.allow_request(request, self)
    ]
    if not reported_waits:
        return
    # Filter out `None` values which may happen in case of config / rate
    # changes, see #1438
    known_waits = [wait for wait in reported_waits if wait is not None]
    self.throttled(request, max(known_waits, default=None))
create_training¶
def create_training(
self,
request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Create a new training process.
This method is designed to be called by a POST request according to the `CreateTrainingRequestSchema`.
The request should include a model file (the initial model) as an attached FILE.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The request object. | None |
Returns:
Type | Description |
---|---|
HttpResponse | 201 if training could be registered. |
View Source
@extend_schema(
    request=inline_serializer(
        name="TrainingCreationSerializer",
        fields={
            "model_id": CharField(),
            "target_num_updates": IntegerField(),
            "metric_names": ListField(child=CharField()),
            "aggregation_method": CharField(),
            "clients": ListField(child=UUIDField())
        }
    ),
    responses={
        # The handler answers with 201 Created, so document that status code.
        status.HTTP_201_CREATED: inline_serializer("TrainingCreatedSerializer", fields={
            "detail": CharField(default="Training created successfully!"),
            "training_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def create_training(self, request: HttpRequest) -> HttpResponse:
    """
    Create a new training process.

    This method is designed to be called by a POST request according to the `CreateTrainingRequestSchema`.
    The initial model is looked up by the `model_id` of the request and must be owned by the requesting user.
    If the model is already referenced by another training, a clone of it is used instead.

    Args:
        request (HttpRequest): The request object.

    Raises:
        ParseError: If the request body cannot be loaded or not all given clients exist.
        PermissionDenied: If the requesting user does not own the selected model.

    Returns:
        HttpResponse: 201 if training could be registered.
    """
    parsed_request: CreateTrainingRequest = self._load_marshmallow_request(
        CreateTrainingRequestSchema(),
        request.body.decode("utf-8")
    )
    model = get_entity(ModelDB, pk=parsed_request.model_id)
    if model.owner != request.user:
        raise PermissionDenied()
    # Validate the participants before a possible `clone_model` call, so that a request
    # with unknown client ids fails without having cloned the model first.
    clients = self._get_clients_from_uuid_list(parsed_request.clients)
    if TrainingDB.objects.filter(model=model).exists():
        # the selected model is already referenced by another training, so we need to copy it
        model = clone_model(model)
    train = TrainingDB.objects.create(
        model=model,
        actor=request.user,
        target_num_updates=parsed_request.target_num_updates,
        state=TrainingState.INITIAL,
        uncertainty_method=parsed_request.uncertainty_method.value,
        aggregation_method=parsed_request.aggregation_method.value,
        options=parsed_request.options
    )
    train.participants.add(*clients)
    return JsonResponse({
        "detail": "Training created successfully!",
        "training_id": train.id
    }, status=status.HTTP_201_CREATED)
determine_version¶
If versioning is being used, then determine any API version for the
incoming request. Returns a two-tuple of (version, versioning_scheme)
View Source
def determine_version(self, request, *args, **kwargs):
    """
    Work out the API version of the incoming request.

    Returns a two-tuple of (version, versioning_scheme); both entries are
    `None` when no versioning class is configured.
    """
    if self.versioning_class is not None:
        scheme = self.versioning_class()
        version = scheme.determine_version(request, *args, **kwargs)
        return (version, scheme)
    return (None, None)
dispatch¶
`.dispatch()` is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling.
View Source
def dispatch(self, request, *args, **kwargs):
    """
    `.dispatch()` is pretty much the same as Django's regular dispatch,
    but with extra hooks for startup, finalize, and exception handling.

    Args:
        request: The incoming request; it is replaced by the result of
            `initialize_request` before any handler runs.
        *args: Positional arguments forwarded to the hooks and the handler.
        **kwargs: Keyword arguments forwarded to the hooks and the handler.

    Returns:
        The result of `finalize_response`, which is also stored on `self.response`.
    """
    self.args = args
    self.kwargs = kwargs
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?

    try:
        self.initial(request, *args, **kwargs)

        # Get the appropriate handler method
        if request.method.lower() in self.http_method_names:
            handler = getattr(self, request.method.lower(),
                              self.http_method_not_allowed)
        else:
            handler = self.http_method_not_allowed

        response = handler(request, *args, **kwargs)

    except Exception as exc:
        # Failures of `initial` and of the handler are both turned into a response here.
        response = self.handle_exception(exc)

    self.response = self.finalize_response(request, response, *args, **kwargs)
    return self.response
finalize_response¶
Returns the final response object.
View Source
def finalize_response(self, request, response, *args, **kwargs):
    """
    Returns the final response object.

    For `Response` instances the accepted renderer, the accepted media type and
    the renderer context are attached. The headers stored on `self.headers` are
    then applied to the response; a `Vary` entry is merged instead of overwritten.

    Args:
        request: The request being answered.
        response: The response produced by the handler or the exception handling.
        *args: Unused by this implementation.
        **kwargs: Unused by this implementation.

    Returns:
        The same `response` object, updated in place.
    """
    # Make the error obvious if a proper response is not returned
    assert isinstance(response, HttpResponseBase), (
        'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
        'to be returned from the view, but received a `%s`'
        % type(response)
    )

    if isinstance(response, Response):
        # No renderer has been recorded on the request yet, so negotiate one now
        # (with `force=True`).
        if not getattr(request, 'accepted_renderer', None):
            neg = self.perform_content_negotiation(request, force=True)
            request.accepted_renderer, request.accepted_media_type = neg

        response.accepted_renderer = request.accepted_renderer
        response.accepted_media_type = request.accepted_media_type
        response.renderer_context = self.get_renderer_context()

    # Add new vary headers to the response instead of overwriting.
    vary_headers = self.headers.pop('Vary', None)
    if vary_headers is not None:
        patch_vary_headers(response, cc_delim_re.split(vary_headers))

    for key, value in self.headers.items():
        response[key] = value

    return response
get_authenticate_header¶
If a request is unauthenticated, determine the WWW-Authenticate
header to use for 401 responses, if any.
View Source
get_authenticators¶
Get the authenticators for the ViewSet.
This method gets the view method and, if it has authentication classes defined via the decorator, returns them. Otherwise, it falls back to the default authenticators.
Returns:
Type | Description |
---|---|
list | The authenticators for the ViewSet. |
View Source
def get_authenticators(self):
    """
    Resolve the authenticators that apply to the current view method.

    Authentication classes attached to the view method (via the decorator) take
    precedence; without them the default authenticators of the parent class are used.

    Returns:
        list: The authenticators for the ViewSet.
    """
    view_method = self._get_view_method()
    if view_method and hasattr(view_method, "authentication_classes"):
        return view_method.authentication_classes
    return super().get_authenticators()
get_content_negotiator¶
Instantiate and return the content negotiation class to use.
View Source
get_exception_handler¶
Returns the exception handler that this view uses.
View Source
get_exception_handler_context¶
Returns a dict that is passed through to EXCEPTION_HANDLER,
as the `context` argument.
View Source
get_extra_action_url_map¶
Build a map of {names: urls} for the extra actions.
This method will noop if `detail` was not provided as a view initkwarg.
View Source
def get_extra_action_url_map(self):
    """
    Collect the URLs of the extra actions, keyed by their view names.

    Only extra actions whose `detail` flag equals the one of this view are
    considered. The map stays empty if `detail` was not provided as a view
    initkwarg.
    """
    url_map = OrderedDict()

    # nothing to do unless `detail` has been provided
    if self.detail is not None:
        matching_actions = [
            candidate for candidate in self.get_extra_actions()
            if candidate.detail == self.detail
        ]
        for extra in matching_actions:
            try:
                route_name = '%s-%s' % (self.basename, extra.url_name)
                ns = self.request.resolver_match.namespace
                if ns:
                    route_name = '%s:%s' % (ns, route_name)
                target_url = reverse(route_name, self.args, self.kwargs, request=self.request)
                action_view = self.__class__(**extra.kwargs)
                url_map[action_view.get_view_name()] = target_url
            except NoReverseMatch:
                continue  # URL requires additional arguments, ignore

    return url_map
get_format_suffix¶
Determine if the request includes a '.json' style format suffix
View Source
get_parser_context¶
Returns a dict that is passed through to Parser.parse(),
as the `parser_context` keyword argument.
View Source
def get_parser_context(self, http_request):
    """
    Build the dict handed to Parser.parse() as its `parser_context`
    keyword argument.
    """
    # Note: Additionally `request` and `encoding` will also be added
    # to the context by the Request object.
    context = {'view': self}
    context['args'] = getattr(self, 'args', ())
    context['kwargs'] = getattr(self, 'kwargs', {})
    return context
get_parsers¶
Instantiates and returns the list of parsers that this view can use.
View Source
get_permissions¶
Get the permissions for the ViewSet.
This method gets the view method and, if it has permission classes defined via the decorator, returns them. Otherwise, it falls back to the default permissions.
Returns:
Type | Description |
---|---|
list | The permissions for the ViewSet. |
View Source
def get_permissions(self):
    """
    Resolve the permissions that apply to the current view method.

    Permission classes attached to the view method (via the decorator) take
    precedence; without them the default permissions of the parent class are used.

    Returns:
        list: The permissions for the ViewSet.
    """
    view_method = self._get_view_method()
    if view_method and hasattr(view_method, "permission_classes"):
        return view_method.permission_classes
    return super().get_permissions()
get_renderer_context¶
Returns a dict that is passed through to Renderer.render(),
as the `renderer_context` keyword argument.
View Source
def get_renderer_context(self):
    """
    Build the dict handed to Renderer.render() as its `renderer_context`
    keyword argument.
    """
    # Note: Additionally 'response' will also be added to the context,
    # by the Response object.
    context = {'view': self}
    context['args'] = getattr(self, 'args', ())
    context['kwargs'] = getattr(self, 'kwargs', {})
    context['request'] = getattr(self, 'request', None)
    return context
get_renderers¶
Instantiates and returns the list of renderers that this view can use.
View Source
get_throttles¶
Instantiates and returns the list of throttles that this view uses.
View Source
get_training¶
def get_training(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Get information about the selected training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | str | training uuid | None |
Returns:
Type | Description |
---|---|
HttpResponse | training data as json response |
View Source
@extend_schema(responses={
    status.HTTP_200_OK: TrainingSerializerWithRounds,
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def get_training(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Return the details of a single training.

    The training is resolved through `_check_user_permission_for_training`
    for the requesting user.

    Args:
        request (HttpRequest): request object
        id (str): training uuid

    Returns:
        HttpResponse: training data as json response
    """
    training = self._check_user_permission_for_training(request.user, id)
    return Response(TrainingSerializerWithRounds(training).data)
get_trainings¶
def get_trainings(
self,
request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Get information about all owned trainings.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
Returns:
Type | Description |
---|---|
HttpResponse | list of training data as json response |
View Source
@extend_schema(responses={
    status.HTTP_200_OK: TrainingSerializer(many=True),
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def get_trainings(self, request: HttpRequest) -> HttpResponse:
    """
    List every training whose actor is the requesting user.

    Args:
        request (HttpRequest): request object

    Returns:
        HttpResponse: list of training data as json response
    """
    owned_trainings = TrainingDB.objects.filter(actor=request.user)
    return Response(TrainingSerializer(owned_trainings, many=True).data)
get_view_description¶
Return some descriptive text for the view, as used in OPTIONS responses
and in the browsable API.
View Source
get_view_name¶
Return the view name, as used in OPTIONS responses and in the
browsable API.
View Source
handle_exception¶
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
View Source
def handle_exception(self, exc):
    """
    Turn an exception into a response via the configured exception handler,
    or re-raise it if the handler does not produce a response.
    """
    auth_failures = (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)
    if isinstance(exc, auth_failures):
        # WWW-Authenticate header for 401 responses, else coerce to 403
        www_authenticate = self.get_authenticate_header(self.request)
        if www_authenticate:
            exc.auth_header = www_authenticate
        else:
            exc.status_code = status.HTTP_403_FORBIDDEN

    handler = self.get_exception_handler()
    response = handler(exc, self.get_exception_handler_context())

    if response is None:
        self.raise_uncaught_exception(exc)

    response.exception = True
    return response
http_method_not_allowed¶
If `request.method` does not correspond to a handler method,
determine what kind of exception to raise.
View Source
initial¶
Runs anything that needs to occur prior to calling the method handler.
View Source
def initial(self, request, *args, **kwargs):
    """
    Prepare the view and the request before the method handler is called.
    """
    self.format_kwarg = self.get_format_suffix(**kwargs)

    # Perform content negotiation and store the accepted info on the request
    renderer, media_type = self.perform_content_negotiation(request)
    request.accepted_renderer = renderer
    request.accepted_media_type = media_type

    # Determine the API version, if versioning is in use.
    api_version, versioning_scheme = self.determine_version(request, *args, **kwargs)
    request.version = api_version
    request.versioning_scheme = versioning_scheme

    # Ensure that the incoming request is permitted
    self.perform_authentication(request)
    self.check_permissions(request)
    self.check_throttles(request)
initialize_request¶
Set the `.action` attribute on the view, depending on the request method.
View Source
def initialize_request(self, request, *args, **kwargs):
    """
    Initialize the request via the parent class and derive the `.action`
    attribute of the view from the request method.
    """
    request = super().initialize_request(request, *args, **kwargs)
    http_verb = request.method.lower()
    # 'options' is a special case as we always provide handling for the
    # options method in the base `View` class.
    # Unlike the other explicitly defined actions, 'metadata' is implicit.
    self.action = 'metadata' if http_verb == 'options' else self.action_map.get(http_verb)
    return request
options¶
Handler method for HTTP 'OPTIONS' request.
View Source
def options(self, request, *args, **kwargs):
    """
    Answer an HTTP 'OPTIONS' request with the metadata of the view.

    Without a metadata class the request is treated as a not allowed method.
    """
    if self.metadata_class is not None:
        metadata = self.metadata_class().determine_metadata(request, self)
        return Response(metadata, status=status.HTTP_200_OK)
    return self.http_method_not_allowed(request, *args, **kwargs)
perform_authentication¶
Perform authentication on the incoming request.
Note that if you override this and simply 'pass', then authentication
will instead be performed lazily, the first time either
`request.user` or `request.auth` is accessed.
View Source
perform_content_negotiation¶
Determine which renderer and media type to use to render the response.
View Source
def perform_content_negotiation(self, request, force=False):
    """
    Select the renderer and media type used to render the response.

    If the negotiation fails and `force` is set, the first renderer and its
    media type are used; otherwise the error is re-raised.
    """
    renderers = self.get_renderers()
    negotiator = self.get_content_negotiator()

    try:
        return negotiator.select_renderer(request, renderers, self.format_kwarg)
    except Exception:
        if not force:
            raise
        fallback = renderers[0]
        return (fallback, fallback.media_type)
permission_denied¶
If request is not permitted, determine what kind of exception to raise.
View Source
def permission_denied(self, request, message=None, code=None):
    """
    Raise the exception that matches a refused request.

    `NotAuthenticated` is raised if authenticators are configured but none of
    them succeeded, `PermissionDenied` otherwise.
    """
    lacks_authentication = request.authenticators and not request.successful_authenticator
    if lacks_authentication:
        raise exceptions.NotAuthenticated()
    raise exceptions.PermissionDenied(detail=message, code=code)
raise_uncaught_exception¶
View Source
register_clients¶
def register_clients(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Register one or more clients for a training process.
This method is designed to be called by a POST request with a JSON body of the form
`{"clients": [<list of UUIDs>]}`.
It adds these clients as participants of the training process.
Note: This method should be called once before the training process is started.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The request object. | None |
id | str | The UUID of the training process. | None |
Returns:
Type | Description |
---|---|
HttpResponse | 202 Response if clients were registered, else corresponding error code. |
View Source
@extend_schema(
    request=inline_serializer(
        "RegisterClientsSerializer",
        fields={
            "clients": ListField(child=UUIDField())
        }
    ),
    responses={
        # The handler answers with 202 Accepted, so document that status code.
        status.HTTP_202_ACCEPTED: inline_serializer(
            "RegisteredClientsSuccessSerializer",
            fields={
                "detail": CharField(default="Users registered as participants!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def register_clients(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Register one or more clients for a training process.

    This method is designed to be called by a POST request with a JSON body of the form
    `{"clients": [<list of UUIDs>]}`.
    It adds these clients as participants of the training process.

    Note: This method should be called once before the training process is started.

    Args:
        request (HttpRequest): The request object.
        id (str): The UUID of the training process.

    Returns:
        HttpResponse: 202 Response if clients were registered, else corresponding error code.
    """
    train = self._check_user_permission_for_training(request.user, id)
    clients = self._get_clients_from_body(request.body)
    train.participants.add(*clients)
    return JsonResponse({"detail": "Users registered as participants!"}, status=status.HTTP_202_ACCEPTED)
remove_clients¶
def remove_clients(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Remove one or more clients from a training process.
This method is designed to modify an already existing training process.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The request object. | None |
id | str | The UUID of the training process. | None |
Returns:
Type | Description |
---|---|
HttpResponse | 200 Response if clients were removed, else corresponding error code. |
View Source
@extend_schema(
    request=inline_serializer(
        "RemoveClientsSerializer",
        fields={
            "clients": ListField(child=UUIDField())
        }
    ),
    responses={
        status.HTTP_200_OK: inline_serializer(
            "RemovedClientsSuccessSerializer",
            fields={
                "detail": CharField(default="Users removed from training participants!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def remove_clients(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Drop one or more clients from the participants of a training process.

    The clients are read from the request body and removed from an already
    existing training process.

    Args:
        request (HttpRequest): The request object.
        id (str): The UUID of the training process.

    Returns:
        HttpResponse: 200 Response if clients were removed, else corresponding error code.
    """
    training = self._check_user_permission_for_training(request.user, id)
    departing = self._get_clients_from_body(request.body)
    training.participants.remove(*departing)
    confirmation = {"detail": "Users removed from training participants!"}
    return JsonResponse(confirmation)
remove_training¶
def remove_training(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Remove an existing training process.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | str | training uuid | None |
Returns:
Type | Description |
---|---|
HttpResponse | 200 Response if training was removed, else corresponding error code |
View Source
@extend_schema(responses={
    status.HTTP_200_OK: inline_serializer(
        "DeleteTrainingSuccessSerializer",
        fields={
            "detail": CharField(default="Training removed!")
        }
    ),
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def remove_training(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Delete a training process on behalf of its owner.

    Args:
        request (HttpRequest): request object
        id (str): training uuid

    Raises:
        PermissionDenied: If the requesting user is not the actor of the training.

    Returns:
        HttpResponse: 200 Response if training was removed, else corresponding error code
    """
    target = get_entity(TrainingDB, pk=id)
    owner = target.actor
    # only the actor that created the training may delete it
    if owner != request.user:
        raise PermissionDenied("You are not the owner the training.")
    target.delete()
    confirmation = {"detail": "Training removed!"}
    return JsonResponse(confirmation)
reverse_action¶
Reverse the action for the given `url_name`.
View Source
def reverse_action(self, url_name, *args, **kwargs):
    """
    Build the URL of one of this viewset's actions.

    The route name is `<basename>-<url_name>`; when the current request was
    resolved inside a URL namespace, that namespace is prefixed. The current
    request is passed to `reverse` unless the caller supplied its own.
    """
    current = self.request
    route = f'{self.basename!s}-{url_name!s}'

    # Only a resolved request can tell us which namespace we are served from.
    active_namespace = (
        current.resolver_match.namespace
        if current and current.resolver_match
        else None
    )
    if active_namespace:
        route = active_namespace + ':' + route

    kwargs.setdefault('request', current)
    return reverse(route, *args, **kwargs)
setup¶
Initialize attributes shared by all view methods.
View Source
start_training¶
def start_training(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Start a training process.
This method checks if there are any participants registered for the training process. If there are participants, it checks if the training process is in the INITIAL state and starts the training session.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | The request object, which includes information about the user making the request. | None |
id | str | The UUID of the training process to start. | None |
Returns:
Type | Description |
---|---|
HttpResponse | A JSON response indicating that the training process has started. |
Raises:
Type | Description |
---|---|
ParseError | If there are no participants registered for the training process or if the training process is not in the INITIAL state. |
View Source
@extend_schema(
    request=inline_serializer("EmptyBodySerializer", fields={}),
    responses={
        # The handler answers with 202 Accepted, so the schema documents that code.
        status.HTTP_202_ACCEPTED: inline_serializer(
            "StartTrainingSuccessSerializer",
            fields={
                "detail": CharField(default="Training started!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def start_training(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Start a training process.

    This method checks if there are any participants registered for the training process.
    If there are participants, it checks if the training process is in the INITIAL state and starts the training
    session.

    Args:
        request (HttpRequest): The request object, which includes information about the user making the request.
        id (str): The UUID of the training process to start.

    Raises:
        ParseError: If there are no participants registered for the training process or if the training process
            is not in the INITIAL state.

    Returns:
        HttpResponse: A 202 JSON response indicating that the training process has started.
    """
    training = self._check_user_permission_for_training(request.user, id)
    # Only the existence of a participant matters, so avoid counting all rows.
    if not training.participants.exists():
        raise ParseError("At least one participant must be registered!")
    if training.state != TrainingState.INITIAL:
        raise ParseError(f"Training {training.id} is not in state INITIAL!")
    ModelTrainer(training).start()
    return JsonResponse({"detail": "Training started!"}, status=status.HTTP_202_ACCEPTED)
throttled¶
If request is throttled, determine what kind of exception to raise.
View Source
User¶
User model ViewSet.
View Source
class User(ViewSet):
    """
    User model ViewSet.

    Bundles the user related endpoints: listing users, reading a single user
    (or the requesting user), registering a new user and listing the groups
    and trainings of the requesting user.
    """

    serializer_class = UserSerializer
    """The serializer for the ViewSet."""

    @extend_schema(
        responses={
            status.HTTP_200_OK: UserSerializer(many=True),
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def get_users(self, request: HttpRequest) -> HttpResponse:
        """
        Get all registered users as list.

        Args:
            request (HttpRequest): request object

        Returns:
            HttpResponse: user list as json response
        """
        serializer = UserSerializer(UserModel.objects.all(), many=True)
        return Response(serializer.data)

    def get_myself(self, request: HttpRequest) -> HttpResponse:
        """
        Get current user.

        Args:
            request (HttpRequest): request object

        Returns:
            HttpResponse: user data as json response
        """
        serializer = UserSerializer(request.user)
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: UserSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            OpenApiExample(
                name="Get user by id",
                description="Retrieve user data by user ID.",
                value="88a9f11a-846b-43b5-bd15-367fc332ba59",
                parameter_only=("id", OpenApiParameter.PATH)
            )
        ]
    )
    def get_user(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Get user information.

        Args:
            request (HttpRequest): request object
            id (str): user uuid

        Returns:
            HttpResponse: user as json response
        """
        serializer = UserSerializer(get_entity(UserModel, pk=id), context={"request_user_id": request.user.id})
        return Response(serializer.data)

    @decorators.authentication_classes([])
    @decorators.permission_classes([])
    @extend_schema(
        responses={
            # The handler answers with 201 Created, so the schema documents that code.
            status.HTTP_201_CREATED: UserSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        },
        examples=[
            OpenApiExample("Jane Doe", value={
                "message_endpoint": "http://example.com/",
                "actor": True,
                "client": True,
                "username": "jane",
                "first_name": "Jane",
                "last_name": "Doe",
                "email": "jane.doe@example.com",
                "password": "my-super-secret-password"
            })
        ]
    )
    def create_user(self, request: HttpRequest) -> HttpResponse:
        """
        Create a new user.

        The endpoint is declared with empty authentication and permission class lists.

        Args:
            request (HttpRequest): request object

        Returns:
            HttpResponse: new created user as json response (status 201)
        """
        # NOTE(review): `create` is called on the raw request data without a prior
        # `is_valid()`; presumably validation happens inside the serializer - confirm.
        user = UserSerializer().create(request.data)
        serializer = UserSerializer(user, context={"request_user_id": user.id})
        return Response(serializer.data, status=status.HTTP_201_CREATED)

    @extend_schema(
        responses={
            # A list is returned, matching the `many=True` serialization below.
            status.HTTP_200_OK: GroupSerializer(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            # NOTE(review): this example targets an `id` path parameter, but the
            # handler takes none - confirm against the URL configuration.
            OpenApiExample(
                name="User uuid",
                value="88a9f11a-846b-43b5-bd15-367fc332ba59",
                parameter_only=("id", OpenApiParameter.PATH)
            )
        ]
    )
    def get_user_groups(self, request: HttpRequest) -> HttpResponse:
        """
        Get the groups of the requesting user.

        Args:
            request (HttpRequest): request object

        Returns:
            HttpResponse: user groups as json response
        """
        serializer = GroupSerializer(request.user.groups, many=True)
        return Response(serializer.data)

    @extend_schema(
        responses={
            # A list is returned, matching the `many=True` serialization below.
            status.HTTP_200_OK: TrainingSerializer(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def get_user_trainings(self, request: HttpRequest) -> HttpResponse:
        """
        Get the trainings of the requesting user.

        A training is included when the user is its actor or one of its participants.

        Args:
            request (HttpRequest): request object

        Returns:
            HttpResponse: user trainings as json response
        """
        # `distinct()` drops duplicates produced by the join over participants.
        trainings = TrainingModel.objects.filter(Q(actor=request.user) | Q(participants=request.user)).distinct()
        serializer = TrainingSerializer(trainings, many=True)
        return Response(serializer.data)
Ancestors (in MRO)¶
- fl_server_api.views.base.ViewSet
- rest_framework.viewsets.ViewSet
- rest_framework.viewsets.ViewSetMixin
- rest_framework.views.APIView
- django.views.generic.base.View
Class variables¶
The serializer for the ViewSet.
Static methods¶
as_view¶
Because of the way class based views create a closure around the instantiated view, we need to totally reimplement `.as_view`, and slightly modify the view function that is created and returned.
View Source
@classonlymethod
def as_view(cls, actions=None, **initkwargs):
    """
    Because of the way class based views create a closure around the
    instantiated view, we need to totally reimplement `.as_view`,
    and slightly modify the view function that is created and returned.

    Args:
        actions: Mapping of HTTP method name to the name of the viewset
            method handling it, e.g. `{'get': 'list'}`. Must not be empty.
        **initkwargs: Keyword arguments passed to the viewset constructor on
            every request. Each key must be an existing attribute of the
            class and must not be an HTTP method name.

    Raises:
        TypeError: If `actions` is empty, if an initkwarg is an HTTP method
            name or not an attribute of the class, or if both `name` and
            `suffix` are given.

    Returns:
        The `view` function wrapped in `csrf_exempt`, carrying `cls`,
        `initkwargs` and `actions` as attributes.
    """
    # The name and description initkwargs may be explicitly overridden for
    # certain route configurations. eg, names of extra actions.
    cls.name = None
    cls.description = None
    # The suffix initkwarg is reserved for displaying the viewset type.
    # This initkwarg should have no effect if the name is provided.
    # eg. 'List' or 'Instance'.
    cls.suffix = None
    # The detail initkwarg is reserved for introspecting the viewset type.
    cls.detail = None
    # Setting a basename allows a view to reverse its action urls. This
    # value is provided by the router through the initkwargs.
    cls.basename = None
    # actions must not be empty
    if not actions:
        raise TypeError("The `actions` argument must be provided when "
                        "calling `.as_view()` on a ViewSet. For example "
                        "`.as_view({'get': 'list'})`")
    # sanitize keyword arguments
    for key in initkwargs:
        if key in cls.http_method_names:
            raise TypeError("You tried to pass in the %s method name as a "
                            "keyword argument to %s(). Don't do that."
                            % (key, cls.__name__))
        if not hasattr(cls, key):
            raise TypeError("%s() received an invalid keyword %r" % (
                cls.__name__, key))
    # name and suffix are mutually exclusive
    if 'name' in initkwargs and 'suffix' in initkwargs:
        raise TypeError("%s() received both `name` and `suffix`, which are "
                        "mutually exclusive arguments." % (cls.__name__))

    def view(request, *args, **kwargs):
        # A fresh viewset instance is created for every incoming request.
        self = cls(**initkwargs)
        # HEAD falls back to the GET action unless it is mapped explicitly.
        if 'get' in actions and 'head' not in actions:
            actions['head'] = actions['get']
        # We also store the mapping of request methods to actions,
        # so that we can later set the action attribute.
        # eg. `self.action = 'list'` on an incoming GET request.
        self.action_map = actions
        # Bind methods to actions
        # This is the bit that's different to a standard view
        for method, action in actions.items():
            handler = getattr(self, action)
            setattr(self, method, handler)
        self.request = request
        self.args = args
        self.kwargs = kwargs
        # And continue as usual
        return self.dispatch(request, *args, **kwargs)
    # take name and docstring from class
    update_wrapper(view, cls, updated=())
    # and possible attributes set by decorators
    # like csrf_exempt from dispatch
    update_wrapper(view, cls.dispatch, assigned=())
    # We need to set these on the view function, so that breadcrumb
    # generation can pick out these bits of information from a
    # resolved URL.
    view.cls = cls
    view.initkwargs = initkwargs
    view.actions = actions
    return csrf_exempt(view)
get_extra_actions¶
Get the methods that are marked as an extra ViewSet `@action`.
View Source
Instance variables¶
Wrap Django's private `_allowed_methods` interface in a public property.
Methods¶
check_object_permissions¶
Check if the request should be permitted for a given object.
Raises an appropriate exception if the request is not permitted.
View Source
def check_object_permissions(self, request, obj):
    """
    Verify that every permission of the view grants access to `obj`.

    For each permission that refuses the object, `permission_denied` is
    called with the permission's optional `message` and `code` attributes.
    """
    # Lazy filter: each permission is evaluated right before it is reported.
    refusing = (
        candidate
        for candidate in self.get_permissions()
        if not candidate.has_object_permission(request, self, obj)
    )
    for candidate in refusing:
        self.permission_denied(
            request,
            message=getattr(candidate, 'message', None),
            code=getattr(candidate, 'code', None),
        )
check_permissions¶
Check if the request should be permitted.
Raises an appropriate exception if the request is not permitted.
View Source
def check_permissions(self, request):
    """
    Verify that every permission of the view grants the request.

    For each permission that refuses the request, `permission_denied` is
    called with the permission's optional `message` and `code` attributes.
    """
    # Lazy filter: each permission is evaluated right before it is reported.
    refusing = (
        candidate
        for candidate in self.get_permissions()
        if not candidate.has_permission(request, self)
    )
    for candidate in refusing:
        self.permission_denied(
            request,
            message=getattr(candidate, 'message', None),
            code=getattr(candidate, 'code', None),
        )
check_throttles¶
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
View Source
def check_throttles(self, request):
    """
    Ask every throttle whether the request may proceed.

    When at least one throttle refuses, `throttled` is called with the
    longest reported wait time (or `None` if no throttle reported one).
    """
    reported_waits = [
        limiter.wait()
        for limiter in self.get_throttles()
        if not limiter.allow_request(request, self)
    ]
    if not reported_waits:
        return

    # `None` waits may happen in case of config / rate changes, see #1438.
    longest = max(
        (wait for wait in reported_waits if wait is not None),
        default=None,
    )
    self.throttled(request, longest)
create_user¶
def create_user(
self,
request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Create a new user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
Returns:
Type | Description |
---|---|
HttpResponse | new created user as json response |
View Source
@decorators.authentication_classes([])
@decorators.permission_classes([])
@extend_schema(
    responses={
        # The handler answers with 201 Created, so the schema documents that code.
        status.HTTP_201_CREATED: UserSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    },
    examples=[
        OpenApiExample("Jane Doe", value={
            "message_endpoint": "http://example.com/",
            "actor": True,
            "client": True,
            "username": "jane",
            "first_name": "Jane",
            "last_name": "Doe",
            "email": "jane.doe@example.com",
            "password": "my-super-secret-password"
        })
    ]
)
def create_user(self, request: HttpRequest) -> HttpResponse:
    """
    Create a new user.

    The endpoint is declared with empty authentication and permission class lists.

    Args:
        request (HttpRequest): request object

    Returns:
        HttpResponse: new created user as json response (status 201)
    """
    # NOTE(review): `create` is called on the raw request data without a prior
    # `is_valid()`; presumably validation happens inside the serializer - confirm.
    user = UserSerializer().create(request.data)
    serializer = UserSerializer(user, context={"request_user_id": user.id})
    return Response(serializer.data, status=status.HTTP_201_CREATED)
determine_version¶
If versioning is being used, then determine any API version for the
incoming request. Returns a two-tuple of (version, versioning_scheme)
View Source
def determine_version(self, request, *args, **kwargs):
    """
    Work out the API version of the incoming request.

    Returns a `(version, versioning_scheme)` two-tuple; both entries are
    `None` when the view has no versioning class configured.
    """
    if self.versioning_class is not None:
        scheme = self.versioning_class()
        version = scheme.determine_version(request, *args, **kwargs)
        return (version, scheme)
    return (None, None)
dispatch¶
`.dispatch()` is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling.
View Source
def dispatch(self, request, *args, **kwargs):
    """
    Route the request to its handler method.

    Works like Django's regular dispatch, with additional hooks: the request
    is wrapped by `initialize_request`, `initial` runs before the handler,
    any exception is turned into a response by `handle_exception`, and the
    result always passes through `finalize_response`.
    """
    self.args, self.kwargs = args, kwargs
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?

    try:
        self.initial(request, *args, **kwargs)

        # Pick the handler named after the HTTP verb, if that verb is supported.
        verb = request.method.lower()
        handler = (
            getattr(self, verb, self.http_method_not_allowed)
            if verb in self.http_method_names
            else self.http_method_not_allowed
        )
        response = handler(request, *args, **kwargs)
    except Exception as exc:
        response = self.handle_exception(exc)

    self.response = self.finalize_response(request, response, *args, **kwargs)
    return self.response
finalize_response¶
Returns the final response object.
View Source
def finalize_response(self, request, response, *args, **kwargs):
    """
    Prepare the handler's response for being returned to the client.

    A `Response` gets the negotiated renderer, media type and renderer
    context attached. The view's default headers are then copied onto the
    response; a `Vary` header is merged instead of overwritten.
    """
    # Make the error obvious if a proper response is not returned
    assert isinstance(response, HttpResponseBase), (
        'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
        'to be returned from the view, but received a `%s`'
        % type(response)
    )

    if isinstance(response, Response):
        if not getattr(request, 'accepted_renderer', None):
            # Negotiation has not happened yet (or failed): force a renderer.
            request.accepted_renderer, request.accepted_media_type = (
                self.perform_content_negotiation(request, force=True)
            )
        response.accepted_renderer = request.accepted_renderer
        response.accepted_media_type = request.accepted_media_type
        response.renderer_context = self.get_renderer_context()

    # Add new vary headers to the response instead of overwriting.
    vary = self.headers.pop('Vary', None)
    if vary is not None:
        patch_vary_headers(response, cc_delim_re.split(vary))

    for header_name, header_value in self.headers.items():
        response[header_name] = header_value
    return response
get_authenticate_header¶
If a request is unauthenticated, determine the WWW-Authenticate
header to use for 401 responses, if any.
View Source
get_authenticators¶
Get the authenticators for the ViewSet.
This method gets the view method and, if it has authentication classes defined via the decorator, returns them. Otherwise, it falls back to the default authenticators.
Returns:
Type | Description |
---|---|
list | The authenticators for the ViewSet. |
View Source
def get_authenticators(self):
    """
    Get the authenticators for the ViewSet.

    If the current view method carries an `authentication_classes` attribute
    (set via the decorator), that value is returned. Otherwise the default
    authenticators of the parent class are used.

    Returns:
        list: The authenticators for the ViewSet.
    """
    view_method = self._get_view_method()
    if view_method and hasattr(view_method, "authentication_classes"):
        return view_method.authentication_classes
    return super().get_authenticators()
get_content_negotiator¶
Instantiate and return the content negotiation class to use.
View Source
get_exception_handler¶
Returns the exception handler that this view uses.
View Source
get_exception_handler_context¶
Returns a dict that is passed through to EXCEPTION_HANDLER, as the `context` argument.
View Source
get_extra_action_url_map¶
Build a map of {names: urls} for the extra actions.
This method will noop if `detail` was not provided as a view initkwarg.
View Source
def get_extra_action_url_map(self):
    """
    Build an ordered `{view name: url}` mapping for the extra actions.

    Only extra actions whose `detail` flag equals the view's own are
    considered; the map stays empty when `detail` was not provided as a
    view initkwarg. Actions whose URL cannot be reversed are left out.
    """
    url_by_name = OrderedDict()

    # exit early if `detail` has not been provided
    if self.detail is None:
        return url_by_name

    # filter for the relevant extra actions
    relevant = [
        extra for extra in self.get_extra_actions()
        if extra.detail == self.detail
    ]

    for extra in relevant:
        try:
            route = f'{self.basename!s}-{extra.url_name!s}'
            active_namespace = self.request.resolver_match.namespace
            if active_namespace:
                route = f'{active_namespace!s}:{route!s}'

            target = reverse(route, self.args, self.kwargs, request=self.request)
            action_view = self.__class__(**extra.kwargs)
            url_by_name[action_view.get_view_name()] = target
        except NoReverseMatch:
            continue  # URL requires additional arguments, ignore

    return url_by_name
get_format_suffix¶
Determine if the request includes a '.json' style format suffix
View Source
get_myself¶
def get_myself(
self,
request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Get current user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
Returns:
Type | Description |
---|---|
HttpResponse | user data as json response |
View Source
get_parser_context¶
Returns a dict that is passed through to Parser.parse(), as the `parser_context` keyword argument.
View Source
def get_parser_context(self, http_request):
    """
    Assemble the `parser_context` dict handed to `Parser.parse()`.

    Contains the view itself plus its positional and keyword URL arguments,
    falling back to an empty tuple / dict when they are not set yet.
    """
    # Note: Additionally `request` and `encoding` will also be added
    # to the context by the Request object.
    context = {'view': self}
    context['args'] = getattr(self, 'args', ())
    context['kwargs'] = getattr(self, 'kwargs', {})
    return context
get_parsers¶
Instantiates and returns the list of parsers that this view can use.
View Source
get_permissions¶
Get the permissions for the ViewSet.
This method gets the view method and, if it has permission classes defined via the decorator, returns them. Otherwise, it falls back to the default permissions.
Returns:
Type | Description |
---|---|
list | The permissions for the ViewSet. |
View Source
def get_permissions(self):
    """
    Get the permissions for the ViewSet.

    If the current view method carries a `permission_classes` attribute
    (set via the decorator), that value is returned. Otherwise the default
    permissions of the parent class are used.

    Returns:
        list: The permissions for the ViewSet.
    """
    view_method = self._get_view_method()
    if view_method and hasattr(view_method, "permission_classes"):
        return view_method.permission_classes
    return super().get_permissions()
get_renderer_context¶
Returns a dict that is passed through to Renderer.render(), as the `renderer_context` keyword argument.
View Source
def get_renderer_context(self):
    """
    Assemble the `renderer_context` dict handed to `Renderer.render()`.

    Contains the view itself, its positional and keyword URL arguments and
    the current request, each with a neutral fallback when not set yet.
    """
    # Note: Additionally 'response' will also be added to the context,
    # by the Response object.
    context = {'view': self}
    context['args'] = getattr(self, 'args', ())
    context['kwargs'] = getattr(self, 'kwargs', {})
    context['request'] = getattr(self, 'request', None)
    return context
get_renderers¶
Instantiates and returns the list of renderers that this view can use.
View Source
get_throttles¶
Instantiates and returns the list of throttles that this view uses.
View Source
get_user¶
def get_user(
self,
request: django.http.request.HttpRequest,
id: str
) -> django.http.response.HttpResponse
Get user information.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | str | user uuid | None |
Returns:
Type | Description |
---|---|
HttpResponse | user as json response |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: UserSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        OpenApiExample(
            name="Get user by id",
            description="Retrieve user data by user ID.",
            value="88a9f11a-846b-43b5-bd15-367fc332ba59",
            parameter_only=("id", OpenApiParameter.PATH)
        )
    ]
)
def get_user(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Return the serialized user with the given id.

    The id of the requesting user is handed to the serializer as
    `request_user_id` context.

    Args:
        request (HttpRequest): request object
        id (str): user uuid

    Returns:
        HttpResponse: user as json response
    """
    requested_user = get_entity(UserModel, pk=id)
    serializer_context = {"request_user_id": request.user.id}
    payload = UserSerializer(requested_user, context=serializer_context).data
    return Response(payload)
get_user_groups¶
def get_user_groups(
self,
request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Get user groups.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
id | str | user uuid | None |
Returns:
Type | Description |
---|---|
HttpResponse | user groups as json response |
View Source
@extend_schema(
    responses={
        # A list is returned, matching the `many=True` serialization below.
        status.HTTP_200_OK: GroupSerializer(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        # NOTE(review): this example targets an `id` path parameter, but the
        # handler takes none - confirm against the URL configuration.
        OpenApiExample(
            name="User uuid",
            value="88a9f11a-846b-43b5-bd15-367fc332ba59",
            parameter_only=("id", OpenApiParameter.PATH)
        )
    ]
)
def get_user_groups(self, request: HttpRequest) -> HttpResponse:
    """
    Get the groups of the requesting user.

    Args:
        request (HttpRequest): request object

    Returns:
        HttpResponse: user groups as json response
    """
    serializer = GroupSerializer(request.user.groups, many=True)
    return Response(serializer.data)
get_user_trainings¶
def get_user_trainings(
self,
request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Get user trainings.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
Returns:
Type | Description |
---|---|
HttpResponse | user trainings as json response |
View Source
@extend_schema(
    responses={
        # A list is returned, matching the `many=True` serialization below.
        status.HTTP_200_OK: TrainingSerializer(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def get_user_trainings(self, request: HttpRequest) -> HttpResponse:
    """
    Get the trainings of the requesting user.

    A training is included when the user is its actor or one of its participants.

    Args:
        request (HttpRequest): request object

    Returns:
        HttpResponse: user trainings as json response
    """
    # `distinct()` drops duplicates produced by the join over participants.
    trainings = TrainingModel.objects.filter(Q(actor=request.user) | Q(participants=request.user)).distinct()
    serializer = TrainingSerializer(trainings, many=True)
    return Response(serializer.data)
get_users¶
def get_users(
self,
request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Get all registered users as list.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | HttpRequest | request object | None |
Returns:
Type | Description |
---|---|
HttpResponse | user list as json response |
View Source
@extend_schema(
    responses={
        status.HTTP_200_OK: UserSerializer(many=True),
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def get_users(self, request: HttpRequest) -> HttpResponse:
    """
    Return every registered user as a serialized list.

    Args:
        request (HttpRequest): request object

    Returns:
        HttpResponse: user list as json response
    """
    every_user = UserModel.objects.all()
    payload = UserSerializer(every_user, many=True).data
    return Response(payload)
get_view_description¶
Return some descriptive text for the view, as used in OPTIONS responses
and in the browsable API.
View Source
get_view_name¶
Return the view name, as used in OPTIONS responses and in the
browsable API.
View Source
handle_exception¶
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
View Source
def handle_exception(self, exc):
    """
    Turn an exception raised while handling the request into a response.

    The configured exception handler produces the response; when it returns
    `None`, the exception is passed to `raise_uncaught_exception`.
    """
    authentication_errors = (
        exceptions.NotAuthenticated,
        exceptions.AuthenticationFailed,
    )
    if isinstance(exc, authentication_errors):
        # WWW-Authenticate header for 401 responses, else coerce to 403
        challenge = self.get_authenticate_header(self.request)
        if not challenge:
            exc.status_code = status.HTTP_403_FORBIDDEN
        else:
            exc.auth_header = challenge

    handler = self.get_exception_handler()
    response = handler(exc, self.get_exception_handler_context())

    if response is None:
        self.raise_uncaught_exception(exc)

    response.exception = True
    return response
http_method_not_allowed¶
If `request.method` does not correspond to a handler method, determine what kind of exception to raise.
View Source
initial¶
Runs anything that needs to occur prior to calling the method handler.
View Source
def initial(self, request, *args, **kwargs):
    """
    Run the checks and setup that must precede the handler method.

    Stores the format suffix on the view, the negotiated renderer / media
    type and the API version / scheme on the request, then performs
    authentication, permission and throttle checks in that order.
    """
    self.format_kwarg = self.get_format_suffix(**kwargs)

    # Perform content negotiation and store the accepted info on the request
    renderer, media_type = self.perform_content_negotiation(request)
    request.accepted_renderer = renderer
    request.accepted_media_type = media_type

    # Determine the API version, if versioning is in use.
    api_version, versioning_scheme = self.determine_version(request, *args, **kwargs)
    request.version = api_version
    request.versioning_scheme = versioning_scheme

    # Ensure that the incoming request is permitted
    self.perform_authentication(request)
    self.check_permissions(request)
    self.check_throttles(request)
initialize_request¶
Set the `.action` attribute on the view, depending on the request method.
View Source
def initialize_request(self, request, *args, **kwargs):
    """
    Wrap the request via the parent class and record the current action.

    `self.action` becomes `'metadata'` for OPTIONS requests and otherwise
    the action mapped to the HTTP method in `self.action_map` (`None` when
    the method is not mapped).
    """
    request = super().initialize_request(request, *args, **kwargs)
    verb = request.method.lower()
    # OPTIONS is a special case as we always provide handling for the
    # options method in the base `View` class. Unlike the other explicitly
    # defined actions, 'metadata' is implicit.
    self.action = 'metadata' if verb == 'options' else self.action_map.get(verb)
    return request
options¶
Handler method for HTTP 'OPTIONS' request.
View Source
def options(self, request, *args, **kwargs):
    """
    Handle an HTTP 'OPTIONS' request.

    Answers with the metadata produced by the view's metadata class, or
    delegates to `http_method_not_allowed` when no metadata class is set.
    """
    if self.metadata_class is not None:
        provider = self.metadata_class()
        return Response(
            provider.determine_metadata(request, self),
            status=status.HTTP_200_OK,
        )
    return self.http_method_not_allowed(request, *args, **kwargs)
perform_authentication¶
Perform authentication on the incoming request.
Note that if you override this and simply 'pass', then authentication will instead be performed lazily, the first time either `request.user` or `request.auth` is accessed.
View Source
perform_content_negotiation¶
Determine which renderer and media type to use to render the response.
View Source
def perform_content_negotiation(self, request, force=False):
    """
    Select the renderer and media type used to render the response.

    Returns the negotiator's `(renderer, media_type)` choice. If negotiation
    raises and `force` is true, the first available renderer and its media
    type are returned instead; without `force` the error propagates.
    """
    available = self.get_renderers()
    negotiator = self.get_content_negotiator()

    try:
        selection = negotiator.select_renderer(request, available, self.format_kwarg)
    except Exception:
        if not force:
            raise
        selection = (available[0], available[0].media_type)
    return selection
permission_denied¶
If request is not permitted, determine what kind of exception to raise.
View Source
def permission_denied(self, request, message=None, code=None):
    """
    Raise the exception matching a refused request.

    `NotAuthenticated` is raised when the request has authenticators but
    none of them succeeded; otherwise `PermissionDenied` is raised with the
    given message and code.
    """
    lacks_authentication = request.authenticators and not request.successful_authenticator
    if lacks_authentication:
        raise exceptions.NotAuthenticated()
    raise exceptions.PermissionDenied(detail=message, code=code)
raise_uncaught_exception¶
View Source
reverse_action¶
Reverse the action for the given `url_name`.
View Source
def reverse_action(self, url_name, *args, **kwargs):
    """
    Build the URL of one of this viewset's actions.

    The route name is `<basename>-<url_name>`; when the current request was
    resolved inside a URL namespace, that namespace is prefixed. The current
    request is passed to `reverse` unless the caller supplied its own.
    """
    current = self.request
    route = f'{self.basename!s}-{url_name!s}'

    # Only a resolved request can tell us which namespace we are served from.
    active_namespace = (
        current.resolver_match.namespace
        if current and current.resolver_match
        else None
    )
    if active_namespace:
        route = active_namespace + ':' + route

    kwargs.setdefault('request', current)
    return reverse(route, *args, **kwargs)
setup¶
Initialize attributes shared by all view methods.
View Source
throttled¶
If request is throttled, determine what kind of exception to raise.