Module fl_server_api.views.training¶
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from django.http import HttpRequest, HttpResponse, JsonResponse
import json
from marshmallow import Schema
from marshmallow.exceptions import ValidationError as MarshmallowValidationError
from rest_framework import status
from rest_framework.exceptions import ParseError, PermissionDenied
from rest_framework.response import Response
from uuid import UUID
from fl_server_core.models import (
    Model as ModelDB,
    Training as TrainingDB,
    User as UserDB,
)
from fl_server_core.models.model import clone_model
from fl_server_core.models.training import TrainingState
from fl_server_ai.trainer import ModelTrainer
from .base import ViewSet
from ..utils import get_entity
from ..serializers.generic import TrainingSerializer, TrainingSerializerWithRounds
from ..serializers.training import (
    CreateTrainingRequest, CreateTrainingRequestSchema,
    ClientAdministrationBody, ClientAdministrationBodySchema
)
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.fields import UUIDField, CharField, IntegerField, ListField
from ..openapi import error_response_403
from ..serializers.generic import ErrorSerializer
class Training(ViewSet):
    """
    Training model ViewSet.
    This ViewSet is used to create and manage trainings.
    """
    serializer_class = TrainingSerializer
    """The serializer for the ViewSet."""
    def _check_user_permission_for_training(self, user: UserDB, training_id: UUID | str) -> TrainingDB:
        """
        Check if a user has permission for a training.
        The user is permitted if it is the actor of the training or one of its participants.
        Args:
            user (UserDB): The user.
            training_id (UUID | str): The ID of the training.
        Raises:
            ParseError: If `training_id` is a string that is not a valid UUID.
            PermissionDenied: If the user is neither the actor nor a participant of the training.
        Returns:
            TrainingDB: The training.
        """
        if isinstance(training_id, str):
            try:
                training_id = UUID(training_id)
            except ValueError as e:
                # a malformed ID is a client error and must not surface as an unhandled exception
                raise ParseError(f"'{training_id}' is not a valid UUID!") from e
        training = get_entity(TrainingDB, pk=training_id)
        if training.actor != user and user not in training.participants.all():
            raise PermissionDenied()
        return training
    def _get_clients_from_body(self, body_raw: bytes) -> list[UserDB]:
        """
        Resolve the users referenced in a client administration request body.
        The raw body is parsed with the `ClientAdministrationBodySchema`; the UUIDs found in its `clients`
        field are then looked up via `_get_clients_from_uuid_list`.
        Args:
            body_raw (bytes): The raw request body.
        Returns:
            list[UserDB]: The clients.
        """
        schema = ClientAdministrationBodySchema()
        parsed: ClientAdministrationBody = self._load_marshmallow_request(schema, body_raw)
        client_ids = parsed.clients
        return self._get_clients_from_uuid_list(client_ids)
    def _get_clients_from_uuid_list(self, uuids: list[UUID]) -> list[UserDB]:
        """
        Get clients from a list of UUIDs.
        This method gets the clients with the IDs in the list of UUIDs from the database.
        Duplicate UUIDs are allowed and refer to the same client, which is returned only once.
        Args:
            uuids (list[UUID]): The list of UUIDs; `None` or an empty list yields an empty result.
        Raises:
            ParseError: If at least one of the UUIDs does not belong to an existing user.
        Returns:
            list[UserDB]: The clients.
        """
        if not uuids:
            return []
        # the query returns each user once, so only the distinct IDs can be compared against the result
        unique_ids = set(uuids)
        # Note: filter "in" does not raise UserDB.DoesNotExist exceptions
        clients = list(UserDB.objects.filter(id__in=unique_ids))
        if len(clients) != len(unique_ids):
            raise ParseError("Not all provided users were found!")
        return clients
    def _load_marshmallow_request(self, schema: Schema, json_data: str | bytes | bytearray):
        """
        Load JSON data from a request using a Marshmallow schema.
        Args:
            schema (Schema): The Marshmallow schema to use for loading the request.
            json_data (str | bytes | bytearray): The JSON data to load.
        Raises:
            ParseError: If the data is not valid JSON or if a MarshmallowValidationError occurs.
        Returns:
            The deserialized data as returned by `schema.load`.
        """
        try:
            # `json.loads` decodes bytes itself, hence decoding errors have to be handled here as well
            data = json.loads(json_data)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # malformed input is a client error and must not surface as an unhandled exception
            raise ParseError(f"Invalid JSON: {e}") from e
        try:
            return schema.load(data)  # should `schema.loads` be used instead?
        except MarshmallowValidationError as e:
            raise ParseError(e.messages) from e
    @extend_schema(responses={
        status.HTTP_200_OK: TrainingSerializer(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def get_trainings(self, request: HttpRequest) -> HttpResponse:
        """
        List all trainings whose actor is the requesting user.
        Args:
            request (HttpRequest):  request object
        Returns:
            HttpResponse: list of training data as json response
        """
        owned_trainings = TrainingDB.objects.filter(actor=request.user)
        return Response(TrainingSerializer(owned_trainings, many=True).data)
    @extend_schema(responses={
        status.HTTP_200_OK: TrainingSerializerWithRounds,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def get_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Return the data of a single training, serialized with `TrainingSerializerWithRounds`.
        The requesting user has to be the actor or a participant of the training.
        Args:
            request (HttpRequest):  request object
            id (str):  training uuid
        Returns:
            HttpResponse: training data as json response
        """
        training = self._check_user_permission_for_training(request.user, id)
        return Response(TrainingSerializerWithRounds(training).data)
    @extend_schema(
        request=inline_serializer("EmptyBodySerializer", fields={}),
        responses={
            # the view answers with 202 on success, so this is the status code to document
            status.HTTP_202_ACCEPTED: inline_serializer(
                "StartTrainingSuccessSerializer",
                fields={
                    "detail": CharField(default="Training started!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def start_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Start a training process.
        This method checks if there are any participants registered for the training process.
        If there are participants, it checks if the training process is in the INITIAL state and starts the training
        session.
        Args:
            request (HttpRequest): The request object, which includes information about the user making the request.
            id (str): The UUID of the training process to start.
        Raises:
            ParseError: If there are no participants registered for the training process or if the training process
                is not in the INITIAL state.
        Returns:
            HttpResponse: 202 JSON response indicating that the training process has started.
        """
        training = self._check_user_permission_for_training(request.user, id)
        if training.participants.count() == 0:
            raise ParseError("At least one participant must be registered!")
        if training.state != TrainingState.INITIAL:
            raise ParseError(f"Training {training.id} is not in state INITIAL!")
        ModelTrainer(training).start()
        return JsonResponse({"detail": "Training started!"}, status=status.HTTP_202_ACCEPTED)
    @extend_schema(
        request=inline_serializer(
            "RegisterClientsSerializer",
            fields={
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            # the view answers with 202 on success, so this is the status code to document
            status.HTTP_202_ACCEPTED: inline_serializer(
                "RegisteredClientsSuccessSerializer",
                fields={
                    "detail": CharField(default="Users registered as participants!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def register_clients(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Register one or more clients for a training process.
        This method is designed to be called by a POST request with a JSON body of the form
        `{"clients": [<list of UUIDs>]}`.
        It adds these clients as participants of the training process.
        Note: This method should be called once before the training process is started.
        Args:
            request (HttpRequest): The request object.
            id (str): The UUID of the training process.
        Returns:
            HttpResponse: 202 Response if clients were registered, else corresponding error code.
        """
        train = self._check_user_permission_for_training(request.user, id)
        clients = self._get_clients_from_body(request.body)
        train.participants.add(*clients)
        return JsonResponse({"detail": "Users registered as participants!"}, status=status.HTTP_202_ACCEPTED)
    @extend_schema(responses={
        status.HTTP_200_OK: inline_serializer(
            "DeleteTrainingSuccessSerializer",
            fields={
                "detail": CharField(default="Training removed!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def remove_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Delete a training process.
        Only the actor of the training is allowed to delete it.
        Args:
            request (HttpRequest):  request object
            id (str):  training uuid
        Returns:
            HttpResponse: 200 Response if training was removed, else corresponding error code
        """
        target = get_entity(TrainingDB, pk=id)
        requester = request.user
        if target.actor != requester:
            raise PermissionDenied("You are not the owner the training.")
        target.delete()
        payload = {"detail": "Training removed!"}
        return JsonResponse(payload)
    @extend_schema(
        request=inline_serializer(
            "RemoveClientsSerializer",
            fields={
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            status.HTTP_200_OK: inline_serializer(
                "RemovedClientsSuccessSerializer",
                fields={
                    "detail": CharField(default="Users removed from training participants!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def remove_clients(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Remove one or more clients from the participants of a training process.
        The clients are taken from the `clients` field of the JSON request body.
        Args:
            request (HttpRequest): The request object.
            id (str): The UUID of the training process.
        Returns:
            HttpResponse: 200 Response if clients were removed, else corresponding error code.
        """
        training = self._check_user_permission_for_training(request.user, id)
        to_remove = self._get_clients_from_body(request.body)
        training.participants.remove(*to_remove)
        payload = {"detail": "Users removed from training participants!"}
        return JsonResponse(payload)
    @extend_schema(
        request=inline_serializer(
            name="TrainingCreationSerializer",
            fields={
                "model_id": CharField(),
                "target_num_updates": IntegerField(),
                "metric_names": ListField(child=CharField()),
                "aggregation_method": CharField(),
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            # the view answers with 201 on success, so this is the status code to document
            status.HTTP_201_CREATED: inline_serializer("TrainingCreatedSerializer", fields={
                "detail": CharField(default="Training created successfully!"),
                "training_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def create_training(self, request: HttpRequest) -> HttpResponse:
        """
        Create a new training process.
        This method is designed to be called by a POST request according to the `CreateTrainingRequestSchema`.
        The initial model is referenced by its `model_id` and has to be owned by the requesting user. If the
        model is already referenced by another training, a clone of the model is used instead.
        Args:
            request (HttpRequest):  The request object.
        Raises:
            ParseError: If the request body is not valid UTF-8, violates the schema or references unknown clients.
            PermissionDenied: If the requesting user is not the owner of the model.
        Returns:
            HttpResponse: 201 if training could be registered.
        """
        try:
            body = request.body.decode("utf-8")
        except UnicodeDecodeError as e:
            # an undecodable body is a client error and must not surface as an unhandled exception
            raise ParseError("Request body is not valid UTF-8!") from e
        parsed_request: CreateTrainingRequest = self._load_marshmallow_request(
            CreateTrainingRequestSchema(),
            body
        )
        model = get_entity(ModelDB, pk=parsed_request.model_id)
        if model.owner != request.user:
            raise PermissionDenied()
        if TrainingDB.objects.filter(model=model).exists():
            # the selected model is already referenced by another training, so we need to copy it
            model = clone_model(model)
        clients = self._get_clients_from_uuid_list(parsed_request.clients)
        train = TrainingDB.objects.create(
            model=model,
            actor=request.user,
            target_num_updates=parsed_request.target_num_updates,
            state=TrainingState.INITIAL,
            uncertainty_method=parsed_request.uncertainty_method.value,
            aggregation_method=parsed_request.aggregation_method.value,
            options=parsed_request.options
        )
        train.participants.add(*clients)
        return JsonResponse({
            "detail": "Training created successfully!",
            "training_id": train.id
        }, status=status.HTTP_201_CREATED)
Classes¶
Training¶
Training model ViewSet.
This ViewSet is used to create and manage trainings.
View Source
class Training(ViewSet):
    """
    Training model ViewSet.
    This ViewSet is used to create and manage trainings.
    """
    serializer_class = TrainingSerializer
    """The serializer for the ViewSet."""
    def _check_user_permission_for_training(self, user: UserDB, training_id: UUID | str) -> TrainingDB:
        """
        Check if a user has permission for a training.
        The user is permitted if it is the actor of the training or one of its participants.
        Args:
            user (UserDB): The user.
            training_id (UUID | str): The ID of the training.
        Raises:
            ParseError: If `training_id` is a string that is not a valid UUID.
            PermissionDenied: If the user is neither the actor nor a participant of the training.
        Returns:
            TrainingDB: The training.
        """
        if isinstance(training_id, str):
            try:
                training_id = UUID(training_id)
            except ValueError as e:
                # a malformed ID is a client error and must not surface as an unhandled exception
                raise ParseError(f"'{training_id}' is not a valid UUID!") from e
        training = get_entity(TrainingDB, pk=training_id)
        if training.actor != user and user not in training.participants.all():
            raise PermissionDenied()
        return training
    def _get_clients_from_body(self, body_raw: bytes) -> list[UserDB]:
        """
        Resolve the users referenced in a client administration request body.
        The raw body is parsed with the `ClientAdministrationBodySchema`; the UUIDs found in its `clients`
        field are then looked up via `_get_clients_from_uuid_list`.
        Args:
            body_raw (bytes): The raw request body.
        Returns:
            list[UserDB]: The clients.
        """
        schema = ClientAdministrationBodySchema()
        parsed: ClientAdministrationBody = self._load_marshmallow_request(schema, body_raw)
        client_ids = parsed.clients
        return self._get_clients_from_uuid_list(client_ids)
    def _get_clients_from_uuid_list(self, uuids: list[UUID]) -> list[UserDB]:
        """
        Get clients from a list of UUIDs.
        This method gets the clients with the IDs in the list of UUIDs from the database.
        Duplicate UUIDs are allowed and refer to the same client, which is returned only once.
        Args:
            uuids (list[UUID]): The list of UUIDs; `None` or an empty list yields an empty result.
        Raises:
            ParseError: If at least one of the UUIDs does not belong to an existing user.
        Returns:
            list[UserDB]: The clients.
        """
        if not uuids:
            return []
        # the query returns each user once, so only the distinct IDs can be compared against the result
        unique_ids = set(uuids)
        # Note: filter "in" does not raise UserDB.DoesNotExist exceptions
        clients = list(UserDB.objects.filter(id__in=unique_ids))
        if len(clients) != len(unique_ids):
            raise ParseError("Not all provided users were found!")
        return clients
    def _load_marshmallow_request(self, schema: Schema, json_data: str | bytes | bytearray):
        """
        Load JSON data from a request using a Marshmallow schema.
        Args:
            schema (Schema): The Marshmallow schema to use for loading the request.
            json_data (str | bytes | bytearray): The JSON data to load.
        Raises:
            ParseError: If the data is not valid JSON or if a MarshmallowValidationError occurs.
        Returns:
            The deserialized data as returned by `schema.load`.
        """
        try:
            # `json.loads` decodes bytes itself, hence decoding errors have to be handled here as well
            data = json.loads(json_data)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # malformed input is a client error and must not surface as an unhandled exception
            raise ParseError(f"Invalid JSON: {e}") from e
        try:
            return schema.load(data)  # should `schema.loads` be used instead?
        except MarshmallowValidationError as e:
            raise ParseError(e.messages) from e
    @extend_schema(responses={
        status.HTTP_200_OK: TrainingSerializer(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def get_trainings(self, request: HttpRequest) -> HttpResponse:
        """
        List all trainings whose actor is the requesting user.
        Args:
            request (HttpRequest):  request object
        Returns:
            HttpResponse: list of training data as json response
        """
        owned_trainings = TrainingDB.objects.filter(actor=request.user)
        return Response(TrainingSerializer(owned_trainings, many=True).data)
    @extend_schema(responses={
        status.HTTP_200_OK: TrainingSerializerWithRounds,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def get_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Return the data of a single training, serialized with `TrainingSerializerWithRounds`.
        The requesting user has to be the actor or a participant of the training.
        Args:
            request (HttpRequest):  request object
            id (str):  training uuid
        Returns:
            HttpResponse: training data as json response
        """
        training = self._check_user_permission_for_training(request.user, id)
        return Response(TrainingSerializerWithRounds(training).data)
    @extend_schema(
        request=inline_serializer("EmptyBodySerializer", fields={}),
        responses={
            # the view answers with 202 on success, so this is the status code to document
            status.HTTP_202_ACCEPTED: inline_serializer(
                "StartTrainingSuccessSerializer",
                fields={
                    "detail": CharField(default="Training started!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def start_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Start a training process.
        This method checks if there are any participants registered for the training process.
        If there are participants, it checks if the training process is in the INITIAL state and starts the training
        session.
        Args:
            request (HttpRequest): The request object, which includes information about the user making the request.
            id (str): The UUID of the training process to start.
        Raises:
            ParseError: If there are no participants registered for the training process or if the training process
                is not in the INITIAL state.
        Returns:
            HttpResponse: 202 JSON response indicating that the training process has started.
        """
        training = self._check_user_permission_for_training(request.user, id)
        if training.participants.count() == 0:
            raise ParseError("At least one participant must be registered!")
        if training.state != TrainingState.INITIAL:
            raise ParseError(f"Training {training.id} is not in state INITIAL!")
        ModelTrainer(training).start()
        return JsonResponse({"detail": "Training started!"}, status=status.HTTP_202_ACCEPTED)
    @extend_schema(
        request=inline_serializer(
            "RegisterClientsSerializer",
            fields={
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            # the view answers with 202 on success, so this is the status code to document
            status.HTTP_202_ACCEPTED: inline_serializer(
                "RegisteredClientsSuccessSerializer",
                fields={
                    "detail": CharField(default="Users registered as participants!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def register_clients(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Register one or more clients for a training process.
        This method is designed to be called by a POST request with a JSON body of the form
        `{"clients": [<list of UUIDs>]}`.
        It adds these clients as participants of the training process.
        Note: This method should be called once before the training process is started.
        Args:
            request (HttpRequest): The request object.
            id (str): The UUID of the training process.
        Returns:
            HttpResponse: 202 Response if clients were registered, else corresponding error code.
        """
        train = self._check_user_permission_for_training(request.user, id)
        clients = self._get_clients_from_body(request.body)
        train.participants.add(*clients)
        return JsonResponse({"detail": "Users registered as participants!"}, status=status.HTTP_202_ACCEPTED)
    @extend_schema(responses={
        status.HTTP_200_OK: inline_serializer(
            "DeleteTrainingSuccessSerializer",
            fields={
                "detail": CharField(default="Training removed!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def remove_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Delete a training process.
        Only the actor of the training is allowed to delete it.
        Args:
            request (HttpRequest):  request object
            id (str):  training uuid
        Returns:
            HttpResponse: 200 Response if training was removed, else corresponding error code
        """
        target = get_entity(TrainingDB, pk=id)
        requester = request.user
        if target.actor != requester:
            raise PermissionDenied("You are not the owner the training.")
        target.delete()
        payload = {"detail": "Training removed!"}
        return JsonResponse(payload)
    @extend_schema(
        request=inline_serializer(
            "RemoveClientsSerializer",
            fields={
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            status.HTTP_200_OK: inline_serializer(
                "RemovedClientsSuccessSerializer",
                fields={
                    "detail": CharField(default="Users removed from training participants!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def remove_clients(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Remove one or more clients from the participants of a training process.
        The clients are taken from the `clients` field of the JSON request body.
        Args:
            request (HttpRequest): The request object.
            id (str): The UUID of the training process.
        Returns:
            HttpResponse: 200 Response if clients were removed, else corresponding error code.
        """
        training = self._check_user_permission_for_training(request.user, id)
        to_remove = self._get_clients_from_body(request.body)
        training.participants.remove(*to_remove)
        payload = {"detail": "Users removed from training participants!"}
        return JsonResponse(payload)
    @extend_schema(
        request=inline_serializer(
            name="TrainingCreationSerializer",
            fields={
                "model_id": CharField(),
                "target_num_updates": IntegerField(),
                "metric_names": ListField(child=CharField()),
                "aggregation_method": CharField(),
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            # the view answers with 201 on success, so this is the status code to document
            status.HTTP_201_CREATED: inline_serializer("TrainingCreatedSerializer", fields={
                "detail": CharField(default="Training created successfully!"),
                "training_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def create_training(self, request: HttpRequest) -> HttpResponse:
        """
        Create a new training process.
        This method is designed to be called by a POST request according to the `CreateTrainingRequestSchema`.
        The initial model is referenced by its `model_id` and has to be owned by the requesting user. If the
        model is already referenced by another training, a clone of the model is used instead.
        Args:
            request (HttpRequest):  The request object.
        Raises:
            ParseError: If the request body is not valid UTF-8, violates the schema or references unknown clients.
            PermissionDenied: If the requesting user is not the owner of the model.
        Returns:
            HttpResponse: 201 if training could be registered.
        """
        try:
            body = request.body.decode("utf-8")
        except UnicodeDecodeError as e:
            # an undecodable body is a client error and must not surface as an unhandled exception
            raise ParseError("Request body is not valid UTF-8!") from e
        parsed_request: CreateTrainingRequest = self._load_marshmallow_request(
            CreateTrainingRequestSchema(),
            body
        )
        model = get_entity(ModelDB, pk=parsed_request.model_id)
        if model.owner != request.user:
            raise PermissionDenied()
        if TrainingDB.objects.filter(model=model).exists():
            # the selected model is already referenced by another training, so we need to copy it
            model = clone_model(model)
        clients = self._get_clients_from_uuid_list(parsed_request.clients)
        train = TrainingDB.objects.create(
            model=model,
            actor=request.user,
            target_num_updates=parsed_request.target_num_updates,
            state=TrainingState.INITIAL,
            uncertainty_method=parsed_request.uncertainty_method.value,
            aggregation_method=parsed_request.aggregation_method.value,
            options=parsed_request.options
        )
        train.participants.add(*clients)
        return JsonResponse({
            "detail": "Training created successfully!",
            "training_id": train.id
        }, status=status.HTTP_201_CREATED)
Ancestors (in MRO)¶
- fl_server_api.views.base.ViewSet
 - rest_framework.viewsets.ViewSet
 - rest_framework.viewsets.ViewSetMixin
 - rest_framework.views.APIView
 - django.views.generic.base.View
 
Class variables¶
serializer_class: The serializer for the ViewSet (`TrainingSerializer`).
Static methods¶
as_view¶
Because of the way class based views create a closure around the
instantiated view, we need to totally reimplement .as_view,
and slightly modify the view function that is created and returned.
View Source
    @classonlymethod
    def as_view(cls, actions=None, **initkwargs):
        """
        Because of the way class based views create a closure around the
        instantiated view, we need to totally reimplement `.as_view`,
        and slightly modify the view function that is created and returned.

        `actions` maps HTTP method names to the names of the handler
        attributes on the viewset, e.g. `{'get': 'list'}`; it must not be
        empty. `initkwargs` are passed to the class constructor on every
        request and must name existing class attributes that are not HTTP
        method names.

        Raises `TypeError` if `actions` is missing or empty, or if an
        `initkwargs` key is invalid.

        Returns the view function wrapped in `csrf_exempt`.
        """
        # The name and description initkwargs may be explicitly overridden for
        # certain route configurations. eg, names of extra actions.
        cls.name = None
        cls.description = None
        # The suffix initkwarg is reserved for displaying the viewset type.
        # This initkwarg should have no effect if the name is provided.
        # eg. 'List' or 'Instance'.
        cls.suffix = None
        # The detail initkwarg is reserved for introspecting the viewset type.
        cls.detail = None
        # Setting a basename allows a view to reverse its action urls. This
        # value is provided by the router through the initkwargs.
        cls.basename = None
        # actions must not be empty
        if not actions:
            raise TypeError("The `actions` argument must be provided when "
                            "calling `.as_view()` on a ViewSet. For example "
                            "`.as_view({'get': 'list'})`")
        # sanitize keyword arguments
        for key in initkwargs:
            if key in cls.http_method_names:
                raise TypeError("You tried to pass in the %s method name as a "
                                "keyword argument to %s(). Don't do that."
                                % (key, cls.__name__))
            # only attributes that already exist on the class may be set
            if not hasattr(cls, key):
                raise TypeError("%s() received an invalid keyword %r" % (
                    cls.__name__, key))
        # name and suffix are mutually exclusive
        if 'name' in initkwargs and 'suffix' in initkwargs:
            raise TypeError("%s() received both `name` and `suffix`, which are "
                            "mutually exclusive arguments." % (cls.__name__))
        def view(request, *args, **kwargs):
            # a fresh viewset instance is created for every request
            self = cls(**initkwargs)
            # HEAD falls back to the GET handler unless mapped explicitly;
            # note that this writes into the `actions` dict of the closure
            if 'get' in actions and 'head' not in actions:
                actions['head'] = actions['get']
            # We also store the mapping of request methods to actions,
            # so that we can later set the action attribute.
            # eg. `self.action = 'list'` on an incoming GET request.
            self.action_map = actions
            # Bind methods to actions
            # This is the bit that's different to a standard view
            for method, action in actions.items():
                handler = getattr(self, action)
                setattr(self, method, handler)
            self.request = request
            self.args = args
            self.kwargs = kwargs
            # And continue as usual
            return self.dispatch(request, *args, **kwargs)
        # take name and docstring from class
        update_wrapper(view, cls, updated=())
        # and possible attributes set by decorators
        # like csrf_exempt from dispatch
        update_wrapper(view, cls.dispatch, assigned=())
        # We need to set these on the view function, so that breadcrumb
        # generation can pick out these bits of information from a
        # resolved URL.
        view.cls = cls
        view.initkwargs = initkwargs
        view.actions = actions
        return csrf_exempt(view)
get_extra_actions¶
Get the methods that are marked as an extra ViewSet @action.
View Source
Instance variables¶
Wrap Django's private _allowed_methods interface in a public property.
Methods¶
check_object_permissions¶
Check if the request should be permitted for a given object.
Raises an appropriate exception if the request is not permitted.
View Source
    def check_object_permissions(self, request, obj):
        """
        Verify that every configured permission grants access to `obj`.
        Each permission that rejects the request is reported through
        `self.permission_denied`, together with the permission's optional
        `message` and `code` attributes.
        """
        # lazy on purpose: a permission is only evaluated after the previous
        # rejection (if any) has been reported
        rejecting = (
            perm for perm in self.get_permissions()
            if not perm.has_object_permission(request, self, obj)
        )
        for perm in rejecting:
            denial_message = getattr(perm, 'message', None)
            denial_code = getattr(perm, 'code', None)
            self.permission_denied(request, message=denial_message, code=denial_code)
check_permissions¶
Check if the request should be permitted.
Raises an appropriate exception if the request is not permitted.
View Source
    def check_permissions(self, request):
        """
        Verify that every configured permission grants the request.
        Each permission that rejects the request is reported through
        `self.permission_denied`, together with the permission's optional
        `message` and `code` attributes.
        """
        for perm in self.get_permissions():
            if perm.has_permission(request, self):
                continue
            denial = {
                'message': getattr(perm, 'message', None),
                'code': getattr(perm, 'code', None),
            }
            self.permission_denied(request, **denial)
check_throttles¶
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
View Source
    def check_throttles(self, request):
        """
        Check if request should be throttled.
        The wait times of all throttles that reject the request are collected;
        if there is at least one, `self.throttled` is called with the longest
        known wait time (or `None` if no throttle reported one).
        """
        waits = [
            throttle.wait()
            for throttle in self.get_throttles()
            if not throttle.allow_request(request, self)
        ]
        if not waits:
            return
        # Filter out `None` values which may happen in case of config / rate
        # changes, see #1438
        known_waits = [wait for wait in waits if wait is not None]
        self.throttled(request, max(known_waits, default=None))
create_training¶
def create_training(
    self,
    request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Create a new training process.
This method is designed to be called by a POST request according to the CreateTrainingRequestSchema.
The request should include a model file (the initial model) as an attached FILE.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| request | HttpRequest | The request object. | None | 
Returns:
| Type | Description | 
|---|---|
| HttpResponse | 201 if training could be registered. | 
View Source
    @extend_schema(
        request=inline_serializer(
            name="TrainingCreationSerializer",
            fields={
                "model_id": CharField(),
                "target_num_updates": IntegerField(),
                "metric_names": ListField(child=CharField()),
                "aggregation_method": CharField(),
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            # The view answers with 201 CREATED, so the schema documents that status.
            status.HTTP_201_CREATED: inline_serializer("TrainingCreatedSerializer", fields={
                "detail": CharField(default="Training created successfully!"),
                "training_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def create_training(self, request: HttpRequest) -> HttpResponse:
        """
        Create a new training process.

        This method is designed to be called by a POST request according to the `CreateTrainingRequestSchema`.
        The request body references an already existing model via `model_id`; the requesting user must be
        the owner of that model. If the model is already used by another training, a clone of it is used
        for the new training instead.

        Args:
            request (HttpRequest):  The request object.

        Raises:
            PermissionDenied: If the requesting user is not the owner of the referenced model.

        Returns:
            HttpResponse: 201 if training could be registered.
        """
        parsed_request: CreateTrainingRequest = self._load_marshmallow_request(
            CreateTrainingRequestSchema(),
            request.body.decode("utf-8")
        )
        model = get_entity(ModelDB, pk=parsed_request.model_id)
        if model.owner != request.user:
            raise PermissionDenied()
        if TrainingDB.objects.filter(model=model).exists():
            # the selected model is already referenced by another training, so we need to copy it
            model = clone_model(model)
        # resolve the participants before the training is created
        clients = self._get_clients_from_uuid_list(parsed_request.clients)
        train = TrainingDB.objects.create(
            model=model,
            actor=request.user,
            target_num_updates=parsed_request.target_num_updates,
            state=TrainingState.INITIAL,
            uncertainty_method=parsed_request.uncertainty_method.value,
            aggregation_method=parsed_request.aggregation_method.value,
            options=parsed_request.options
        )
        train.participants.add(*clients)
        return JsonResponse({
            "detail": "Training created successfully!",
            "training_id": train.id
        }, status=status.HTTP_201_CREATED)
determine_version¶
If versioning is being used, then determine any API version for the
incoming request. Returns a two-tuple of (version, versioning_scheme)
View Source
    def determine_version(self, request, *args, **kwargs):
        """
        Determine the API version of the incoming request, if versioning is used.

        Args:
            request: The incoming request.
            *args: Positional arguments forwarded to the versioning scheme.
            **kwargs: Keyword arguments forwarded to the versioning scheme.

        Returns:
            tuple: `(version, versioning_scheme)`, or `(None, None)` if no
                versioning class is configured.
        """
        versioning_cls = self.versioning_class
        if versioning_cls is None:
            return (None, None)
        scheme = versioning_cls()
        version = scheme.determine_version(request, *args, **kwargs)
        return (version, scheme)
dispatch¶
.dispatch() is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling.
View Source
    def dispatch(self, request, *args, **kwargs):
        """
        `.dispatch()` is pretty much the same as Django's regular dispatch,
        but with extra hooks for startup, finalize, and exception handling.

        The request is wrapped via `initialize_request`, `initial` is run, the
        handler matching the HTTP method is looked up and called, and the result
        is passed through `finalize_response`. Any exception raised by `initial`
        or the handler is converted into a response by `handle_exception`.

        Args:
            request: The incoming request.
            *args: Positional arguments passed on to the hooks and the handler.
            **kwargs: Keyword arguments passed on to the hooks and the handler.

        Returns:
            The finalized response, which is also stored on `self.response`.
        """
        # Keep the call arguments and the initialized request available on the view.
        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?
        try:
            self.initial(request, *args, **kwargs)
            # Get the appropriate handler method
            if request.method.lower() in self.http_method_names:
                # Fall back to `http_method_not_allowed` if the view has no such method.
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed
            response = handler(request, *args, **kwargs)
        except Exception as exc:
            # Errors from `initial` and from the handler are both turned into a response here.
            response = self.handle_exception(exc)
        # Finalization happens outside the try block, so its errors are not handled above.
        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response
finalize_response¶
Returns the final response object.
View Source
    def finalize_response(self, request, response, *args, **kwargs):
        """
        Prepare the response produced by a handler for being returned.

        For `Response` objects the accepted renderer, media type and renderer
        context are attached. Afterwards the view's headers are copied onto the
        response, with a `Vary` header being merged instead of overwritten.

        Args:
            request: The request that is being answered.
            response: The response returned by the handler.
            *args: Unused positional arguments of the view call.
            **kwargs: Unused keyword arguments of the view call.

        Returns:
            The finalized response object.
        """
        # Make the error obvious if a proper response is not returned
        assert isinstance(response, HttpResponseBase), (
            'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
            'to be returned from the view, but received a `%s`'
            % type(response)
        )

        if isinstance(response, Response):
            # Negotiation may not have happened yet, e.g. if an error occurred early.
            if not getattr(request, 'accepted_renderer', None):
                renderer, media_type = self.perform_content_negotiation(request, force=True)
                request.accepted_renderer = renderer
                request.accepted_media_type = media_type

            response.accepted_renderer = request.accepted_renderer
            response.accepted_media_type = request.accepted_media_type
            response.renderer_context = self.get_renderer_context()

        # Add new vary headers to the response instead of overwriting.
        vary = self.headers.pop('Vary', None)
        if vary is not None:
            patch_vary_headers(response, cc_delim_re.split(vary))

        # All remaining headers are set directly on the response.
        for header_name, header_value in self.headers.items():
            response[header_name] = header_value
        return response
get_authenticate_header¶
If a request is unauthenticated, determine the WWW-Authenticate
header to use for 401 responses, if any.
View Source
get_authenticators¶
Get the authenticators for the ViewSet.
This method gets the view method and, if it has authentication classes defined via the decorator, returns them. Otherwise, it falls back to the default authenticators.
Returns:
| Type | Description | 
|---|---|
| list | The authenticators for the ViewSet. | 
View Source
    def get_authenticators(self):
        """
        Return the authenticators to use for the current view method.

        If the resolved view method carries an `authentication_classes` attribute,
        that attribute is returned as is. In every other case the authenticators
        of the parent class are used.

        Returns:
            list: The authenticators for the ViewSet.
        """
        method = self._get_view_method()
        if method and hasattr(method, "authentication_classes"):
            return method.authentication_classes
        return super().get_authenticators()
get_content_negotiator¶
Instantiate and return the content negotiation class to use.
View Source
get_exception_handler¶
Returns the exception handler that this view uses.
View Source
get_exception_handler_context¶
Returns a dict that is passed through to EXCEPTION_HANDLER,
as the context argument.
View Source
get_extra_action_url_map¶
Build a map of {names: urls} for the extra actions.
This method is a no-op if detail was not provided as a view initkwarg.
View Source
    def get_extra_action_url_map(self):
        """
        Build a map of {names: urls} for the extra actions.

        Only extra actions whose `detail` flag equals the view's `detail` are
        considered. Actions whose URL cannot be reversed are left out.
        The result is empty if `detail` was not provided as a view initkwarg.

        Returns:
            OrderedDict: View names of the extra actions mapped to their URLs.
        """
        url_map = OrderedDict()

        # exit early if `detail` has not been provided
        if self.detail is None:
            return url_map

        # filter for the relevant extra actions
        matching = [
            extra for extra in self.get_extra_actions()
            if extra.detail == self.detail
        ]

        for extra in matching:
            try:
                route_name = f"{self.basename!s}-{extra.url_name!s}"
                namespace = self.request.resolver_match.namespace
                if namespace:
                    route_name = f"{namespace!s}:{route_name!s}"
                url = reverse(route_name, self.args, self.kwargs, request=self.request)
                # A throwaway view instance is only needed to obtain the display name.
                instance = self.__class__(**extra.kwargs)
                url_map[instance.get_view_name()] = url
            except NoReverseMatch:
                # URL requires additional arguments, ignore
                continue
        return url_map
get_format_suffix¶
Determine if the request includes a '.json' style format suffix
View Source
get_parser_context¶
Returns a dict that is passed through to Parser.parse(),
as the parser_context keyword argument.
View Source
    def get_parser_context(self, http_request):
        """
        Build the dict handed to `Parser.parse()` as the `parser_context` keyword argument.

        Args:
            http_request: The incoming request (not used for building the context).

        Returns:
            dict: The view itself plus its `args` and `kwargs`, which default to
                `()` and `{}` if they are not set on the view.
        """
        # Note: Additionally `request` and `encoding` will also be added
        #       to the context by the Request object.
        context = {'view': self}
        context['args'] = getattr(self, 'args', ())
        context['kwargs'] = getattr(self, 'kwargs', {})
        return context
get_parsers¶
Instantiates and returns the list of parsers that this view can use.
View Source
get_permissions¶
Get the permissions for the ViewSet.
This method gets the view method and, if it has permission classes defined via the decorator, returns them. Otherwise, it falls back to the default permissions.
Returns:
| Type | Description | 
|---|---|
| list | The permissions for the ViewSet. | 
View Source
    def get_permissions(self):
        """
        Return the permissions to use for the current view method.

        If the resolved view method carries a `permission_classes` attribute,
        that attribute is returned as is. In every other case the permissions
        of the parent class are used.

        Returns:
            list: The permissions for the ViewSet.
        """
        method = self._get_view_method()
        if method and hasattr(method, "permission_classes"):
            return method.permission_classes
        return super().get_permissions()
get_renderer_context¶
Returns a dict that is passed through to Renderer.render(),
as the renderer_context keyword argument.
View Source
    def get_renderer_context(self):
        """
        Build the dict handed to `Renderer.render()` as the `renderer_context` keyword argument.

        Returns:
            dict: The view itself plus its `args`, `kwargs` and `request`, which
                default to `()`, `{}` and `None` if they are not set on the view.
        """
        # Note: Additionally 'response' will also be added to the context,
        #       by the Response object.
        context = {'view': self}
        context['args'] = getattr(self, 'args', ())
        context['kwargs'] = getattr(self, 'kwargs', {})
        context['request'] = getattr(self, 'request', None)
        return context
get_renderers¶
Instantiates and returns the list of renderers that this view can use.
View Source
get_throttles¶
Instantiates and returns the list of throttles that this view uses.
View Source
get_training¶
def get_training(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse
Get information about the selected training.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| request | HttpRequest | request object | None | 
| id | str | training uuid | None | 
Returns:
| Type | Description | 
|---|---|
| HttpResponse | training data as json response | 
View Source
    @extend_schema(responses={
        status.HTTP_200_OK: TrainingSerializerWithRounds,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def get_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Return the details of a single training.

        The training is resolved through `_check_user_permission_for_training`
        for the requesting user and serialized with `TrainingSerializerWithRounds`.

        Args:
            request (HttpRequest):  request object
            id (str):  training uuid

        Returns:
            HttpResponse: training data as json response
        """
        training = self._check_user_permission_for_training(request.user, id)
        return Response(TrainingSerializerWithRounds(training).data)
get_trainings¶
def get_trainings(
    self,
    request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse
Get information about all owned trainings.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| request | HttpRequest | request object | None | 
Returns:
| Type | Description | 
|---|---|
| HttpResponse | list of training data as json response | 
View Source
    @extend_schema(responses={
        status.HTTP_200_OK: TrainingSerializer(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def get_trainings(self, request: HttpRequest) -> HttpResponse:
        """
        Return all trainings whose actor is the requesting user.

        Args:
            request (HttpRequest):  request object

        Returns:
            HttpResponse: list of training data as json response
        """
        owned = TrainingDB.objects.filter(actor=request.user)
        payload = TrainingSerializer(owned, many=True).data
        return Response(payload)
get_view_description¶
Return some descriptive text for the view, as used in OPTIONS responses
and in the browsable API.
View Source
get_view_name¶
Return the view name, as used in OPTIONS responses and in the
browsable API.
View Source
handle_exception¶
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
View Source
    def handle_exception(self, exc):
        """
        Turn an exception into a response, or re-raise it.

        Authentication errors get a `WWW-Authenticate` header if one is
        available and are coerced to 403 otherwise. The configured exception
        handler then builds the response; if it returns `None`, the exception
        is passed to `raise_uncaught_exception`.

        Args:
            exc: The exception that occurred.

        Returns:
            The response created by the exception handler, flagged with `exception = True`.
        """
        auth_errors = (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)
        if isinstance(exc, auth_errors):
            # WWW-Authenticate header for 401 responses, else coerce to 403
            auth_header = self.get_authenticate_header(self.request)
            if not auth_header:
                exc.status_code = status.HTTP_403_FORBIDDEN
            else:
                exc.auth_header = auth_header

        handler = self.get_exception_handler()
        response = handler(exc, self.get_exception_handler_context())

        if response is None:
            self.raise_uncaught_exception(exc)

        response.exception = True
        return response
http_method_not_allowed¶
If request.method does not correspond to a handler method,
determine what kind of exception to raise.
View Source
initial¶
Runs anything that needs to occur prior to calling the method handler.
View Source
    def initial(self, request, *args, **kwargs):
        """
        Runs anything that needs to occur prior to calling the method handler.

        In order: the format suffix is stored on the view, content negotiation
        and version determination results are stored on the request, and then
        authentication, permission checks and throttle checks are performed.

        Args:
            request: The initialized request.
            *args: Positional arguments forwarded to `determine_version`.
            **kwargs: Keyword arguments forwarded to `get_format_suffix` and `determine_version`.
        """
        self.format_kwarg = self.get_format_suffix(**kwargs)
        # Perform content negotiation and store the accepted info on the request
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg
        # Determine the API version, if versioning is in use.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme
        # Ensure that the incoming request is permitted
        self.perform_authentication(request)
        self.check_permissions(request)
        self.check_throttles(request)
initialize_request¶
Set the .action attribute on the view, depending on the request method.
View Source
    def initialize_request(self, request, *args, **kwargs):
        """
        Initialize the request and set the `.action` attribute on the view.

        The action is `'metadata'` for OPTIONS requests; for all other methods
        it is looked up in `action_map` (and is `None` if the method is not mapped).

        Args:
            request: The incoming request.
            *args: Positional arguments forwarded to the parent implementation.
            **kwargs: Keyword arguments forwarded to the parent implementation.

        Returns:
            The request as returned by the parent implementation.
        """
        request = super().initialize_request(request, *args, **kwargs)
        http_verb = request.method.lower()
        # OPTIONS is a special case as handling for it is always provided by the
        # base `View` class. Unlike the other explicitly defined actions,
        # 'metadata' is implicit.
        self.action = 'metadata' if http_verb == 'options' else self.action_map.get(http_verb)
        return request
options¶
Handler method for HTTP 'OPTIONS' request.
View Source
    def options(self, request, *args, **kwargs):
        """
        Handle an HTTP 'OPTIONS' request.

        Args:
            request: The incoming request.
            *args: Positional arguments forwarded to `http_method_not_allowed`.
            **kwargs: Keyword arguments forwarded to `http_method_not_allowed`.

        Returns:
            A 200 response with the metadata determined by `metadata_class`, or the
            result of `http_method_not_allowed` if no metadata class is configured.
        """
        metadata_cls = self.metadata_class
        if metadata_cls is None:
            return self.http_method_not_allowed(request, *args, **kwargs)
        metadata = metadata_cls().determine_metadata(request, self)
        return Response(metadata, status=status.HTTP_200_OK)
perform_authentication¶
Perform authentication on the incoming request.
Note that if you override this and simply 'pass', then authentication
will instead be performed lazily, the first time either
request.user or request.auth is accessed.
View Source
perform_content_negotiation¶
Determine which renderer and media type to use to render the response.
View Source
    def perform_content_negotiation(self, request, force=False):
        """
        Determine which renderer and media type to use to render the response.

        Args:
            request: The incoming request.
            force: If true, a failing negotiation falls back to the first
                available renderer instead of raising.

        Returns:
            tuple: `(renderer, media_type)` as selected by the content negotiator,
                or the first renderer and its media type on a forced fallback.
        """
        available = self.get_renderers()
        negotiator = self.get_content_negotiator()

        try:
            return negotiator.select_renderer(request, available, self.format_kwarg)
        except Exception:
            if not force:
                raise
            fallback = available[0]
            return (fallback, fallback.media_type)
permission_denied¶
If request is not permitted, determine what kind of exception to raise.
View Source
    def permission_denied(self, request, message=None, code=None):
        """
        Raise the exception that fits a request which is not permitted.

        Args:
            request: The request that was refused.
            message: Optional detail passed to `PermissionDenied`.
            code: Optional error code passed to `PermissionDenied`.

        Raises:
            NotAuthenticated: If the request has authenticators but none of them succeeded.
            PermissionDenied: In every other case.
        """
        unauthenticated = request.authenticators and not request.successful_authenticator
        if unauthenticated:
            raise exceptions.NotAuthenticated()
        raise exceptions.PermissionDenied(detail=message, code=code)
raise_uncaught_exception¶
View Source
register_clients¶
def register_clients(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse
Register one or more clients for a training process.
This method is designed to be called by a POST request with a JSON body of the form
{"clients": [<list of UUIDs>]}.
It adds these clients as participants of the training process.
Note: This method should be called once before the training process is started.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| request | HttpRequest | The request object. | None | 
| id | str | The UUID of the training process. | None | 
Returns:
| Type | Description | 
|---|---|
| HttpResponse | 202 Response if clients were registered, else corresponding error code. | 
View Source
    @extend_schema(
        request=inline_serializer(
            "RegisterClientsSerializer",
            fields={
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            # The view answers with 202 ACCEPTED, so the schema documents that status.
            status.HTTP_202_ACCEPTED: inline_serializer(
                "RegisteredClientsSuccessSerializer",
                fields={
                    "detail": CharField(default="Users registered as participants!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def register_clients(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Register one or more clients for a training process.

        This method is designed to be called by a POST request with a JSON body of the form
        `{"clients": [<list of UUIDs>]}`.
        It adds these clients as participants of the training process.

        Note: This method should be called once before the training process is started.

        Args:
            request (HttpRequest): The request object.
            id (str): The UUID of the training process.

        Returns:
            HttpResponse: 202 Response if clients were registered, else corresponding error code.
        """
        # resolves the training and checks the requesting user's permission for it
        train = self._check_user_permission_for_training(request.user, id)
        clients = self._get_clients_from_body(request.body)
        train.participants.add(*clients)
        return JsonResponse({"detail": "Users registered as participants!"}, status=status.HTTP_202_ACCEPTED)
remove_clients¶
def remove_clients(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse
Remove one or more clients from a training process.
This method is designed to modify an already existing training process.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| request | HttpRequest | The request object. | None | 
| id | str | The UUID of the training process. | None | 
Returns:
| Type | Description | 
|---|---|
| HttpResponse | 200 Response if clients were removed, else corresponding error code. | 
View Source
    @extend_schema(
        request=inline_serializer(
            "RemoveClientsSerializer",
            fields={"clients": ListField(child=UUIDField())},
        ),
        responses={
            status.HTTP_200_OK: inline_serializer(
                "RemovedClientsSuccessSerializer",
                fields={"detail": CharField(default="Users removed from training participants!")},
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def remove_clients(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Remove one or more clients from the participants of an existing training process.

        The clients to remove are read from the request body.

        Args:
            request (HttpRequest): The request object.
            id (str): The UUID of the training process.

        Returns:
            HttpResponse: 200 Response if clients were removed, else corresponding error code.
        """
        training = self._check_user_permission_for_training(request.user, id)
        to_remove = self._get_clients_from_body(request.body)
        training.participants.remove(*to_remove)
        return JsonResponse({"detail": "Users removed from training participants!"})
remove_training¶
def remove_training(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse
Remove an existing training process.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| request | HttpRequest | request object | None | 
| id | str | training uuid | None | 
Returns:
| Type | Description | 
|---|---|
| HttpResponse | 200 Response if training was removed, else corresponding error code | 
View Source
    @extend_schema(responses={
        status.HTTP_200_OK: inline_serializer(
            "DeleteTrainingSuccessSerializer",
            fields={"detail": CharField(default="Training removed!")},
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def remove_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Delete an existing training process.

        Only the actor of the training is allowed to delete it.

        Args:
            request (HttpRequest):  request object
            id (str):  training uuid

        Raises:
            PermissionDenied: If the requesting user is not the actor of the training.

        Returns:
            HttpResponse: 200 Response if training was removed, else corresponding error code
        """
        target = get_entity(TrainingDB, pk=id)
        if target.actor != request.user:
            # NOTE(review): the message reads "owner the training" (missing "of");
            # kept unchanged here since clients or tests may match the exact text.
            raise PermissionDenied("You are not the owner the training.")
        target.delete()
        return JsonResponse({"detail": "Training removed!"})
reverse_action¶
Reverse the action for the given url_name.
View Source
    def reverse_action(self, url_name, *args, **kwargs):
        """
        Reverse the URL of the action with the given `url_name`.

        The name is prefixed with the view's `basename` and, if the current
        request was resolved within a namespace, with that namespace.

        Args:
            url_name: The URL name of the action, without the basename prefix.
            *args: Positional arguments forwarded to `reverse`.
            **kwargs: Keyword arguments forwarded to `reverse`; `request`
                defaults to the view's current request.

        Returns:
            The result of `reverse` for the fully qualified URL name.
        """
        qualified_name = f"{self.basename!s}-{url_name!s}"

        # The namespace is only known if a request with a resolver match is present.
        if self.request and self.request.resolver_match:
            namespace = self.request.resolver_match.namespace
        else:
            namespace = None
        if namespace:
            qualified_name = namespace + ':' + qualified_name

        kwargs.setdefault('request', self.request)
        return reverse(qualified_name, *args, **kwargs)
setup¶
Initialize attributes shared by all view methods.
View Source
start_training¶
def start_training(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse
Start a training process.
This method checks if there are any participants registered for the training process. If there are participants, it checks if the training process is in the INITIAL state and starts the training session.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| request | HttpRequest | The request object, which includes information about the user making the request. | None | 
| id | str | The UUID of the training process to start. | None | 
Returns:
| Type | Description | 
|---|---|
| HttpResponse | A JSON response indicating that the training process has started. | 
Raises:
| Type | Description | 
|---|---|
| ParseError | If there are no participants registered for the training process or if the training process is not in the INITIAL state.  | 
View Source
    @extend_schema(
        request=inline_serializer("EmptyBodySerializer", fields={}),
        responses={
            # The view answers with 202 ACCEPTED, so the schema documents that status.
            status.HTTP_202_ACCEPTED: inline_serializer(
                "StartTrainingSuccessSerializer",
                fields={
                    "detail": CharField(default="Training started!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def start_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Start a training process.

        This method checks if there are any participants registered for the training process.
        If there are participants, it checks if the training process is in the INITIAL state and starts the training
        session.

        Args:
            request (HttpRequest): The request object, which includes information about the user making the request.
            id (str): The UUID of the training process to start.

        Raises:
            ParseError: If there are no participants registered for the training process or if the training process
                is not in the INITIAL state.

        Returns:
            HttpResponse: A 202 JSON response indicating that the training process has started.
        """
        training = self._check_user_permission_for_training(request.user, id)
        # only the existence of a participant matters, so avoid counting all of them
        if not training.participants.exists():
            raise ParseError("At least one participant must be registered!")
        if training.state != TrainingState.INITIAL:
            raise ParseError(f"Training {training.id} is not in state INITIAL!")
        ModelTrainer(training).start()
        return JsonResponse({"detail": "Training started!"}, status=status.HTTP_202_ACCEPTED)
throttled¶
If request is throttled, determine what kind of exception to raise.