Module fl_server_api.views.model

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.models import AnonymousUser
from django.db import transaction
from django.db.models import Q
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.http.response import HttpResponseBase
from django.utils.datastructures import MultiValueDictKeyError
from drf_spectacular.utils import extend_schema, OpenApiExample, inline_serializer, OpenApiResponse
from itertools import groupby
from rest_framework import status
from rest_framework.exceptions import APIException, NotFound, ParseError, PermissionDenied, ValidationError
from rest_framework.response import Response
from rest_framework.fields import UUIDField, CharField
import torch
from typing import Any, List, Union
from uuid import UUID

from fl_server_core.models import (
    GlobalModel as GlobalModelDB,
    LocalModel as LocalModelDB,
    Metric as MetricDB,
    Model as ModelDB,
    SWAGModel as SWAGModelDB,
    User as UserDB,
)
from fl_server_core.models.training import Training, TrainingState
from fl_server_core.utils.locked_atomic_transaction import locked_atomic_transaction
from fl_server_ai.trainer.events import ModelTestFinished, SWAGRoundFinished, TrainingRoundFinished
from fl_server_ai.trainer.tasks import dispatch_trainer_task

from .base import ViewSet
from ..utils import get_entity, get_file
from ..serializers.generic import ErrorSerializer, MetricSerializer
from ..serializers.model import ModelSerializer, ModelSerializerNoWeightsWithStats, load_and_create_model_request, \
    GlobalModelSerializer, ModelSerializerNoWeights, verify_model_object
from ..openapi import error_response_403, error_response_404


class Model(ViewSet):
    """
    Model ViewSet.
    """

    serializer_class = GlobalModelSerializer
    """The serializer for the ViewSet."""

    def _get_user_related_global_models(self, user: Union[AbstractBaseUser, AnonymousUser]) -> List[ModelDB]:
        """
        Get global models related to a user.

        This method retrieves all global models that the user owns, as well as those belonging to trainings where the user is the actor or a participant.

        Args:
            user (Union[AbstractBaseUser, AnonymousUser]): The user.

        Returns:
            List[ModelDB]: The global models related to the user.
        """
        model_ids = Training.objects.filter(
            Q(actor=user) | Q(participants=user)
        ).distinct().values_list("model__id", flat=True)
        return ModelDB.objects.filter(Q(owner=user) | Q(id__in=model_ids)).distinct()

    def _get_local_models_for_global_model(self, global_model: GlobalModelDB) -> List[LocalModelDB]:
        """
        Get all local models that are based on the global model.

        Args:
            global_model (GlobalModelDB): The global model.

        Returns:
            List[LocalModelDB]: The local models for the global model.
        """
        return LocalModelDB.objects.filter(base_model=global_model).all()

    def _get_local_models_for_global_models(self, global_models: List[GlobalModelDB]) -> List[LocalModelDB]:
        """
        Get all local models that are based on any of the global models.

        Args:
            global_models (List[GlobalModelDB]): The global models.

        Returns:
            List[LocalModelDB]: The local models for the global models.
        """
        return LocalModelDB.objects.filter(base_model__in=global_models).all()

    def _filter_by_training(self, models: List[ModelDB], training_id: str) -> List[ModelDB]:
        """
        Filter a list of models by checking if they are associated with the training.

        Args:
            models (List[ModelDB]): The models to filter.
            training_id (str): The ID of the training.

        Returns:
            List[ModelDB]: The models associated with the training.
        """
        def associated_with_training(m: ModelDB) -> bool:
            training = m.get_training()
            if training is None:
                return False
            return training.pk == UUID(training_id)
        return list(filter(associated_with_training, models))

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_models(self, request: HttpRequest) -> HttpResponse:
        """
        Get a list of all global models associated with the requesting user.

        A global model is deemed associated with a user if the user is either the owner of the model,
        or if the user is an actor or a participant in the model's training process.

        Args:
            request (HttpRequest): The incoming request object.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        models = self._get_user_related_global_models(request.user)
        serializer = ModelSerializer(models, many=True)
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_training_models(self, request: HttpRequest, training_id: str) -> HttpResponse:
        """
        Get a list of all models associated with a specific training process and the requesting user.

        A model is deemed associated with a user if the user is either the owner of the model,
        or if the user is an actor or a participant in the model's training process.

        Args:
            request (HttpRequest): The incoming request object.
            training_id (str): The unique identifier of the training process.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        global_models = self._get_user_related_global_models(request.user)
        global_models = self._filter_by_training(global_models, training_id)
        local_models = self._get_local_models_for_global_models(global_models)
        serializer = ModelSerializer([*global_models, *local_models], many=True)
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_training_models_latest(self, request: HttpRequest, training_id: str) -> HttpResponse:
        """
        Get a list of the latest models for a specific training process associated with the requesting user.

        A model is considered associated with a user if the user is either the owner of the model,
        or if the user is an actor or a participant in the model's training process.
        The latest model refers to the model from the most recent round (highest round number) of
        a participant's training process.

        Args:
            request (HttpRequest): The incoming request object.
            training_id (str): The unique identifier of the training process.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        models: List[ModelDB] = []
        # add latest global model
        global_models = self._get_user_related_global_models(request.user)
        global_models = self._filter_by_training(global_models, training_id)
        if global_models:
            models.append(max(global_models, key=lambda m: m.round))
        # add latest local models
        local_models = self._get_local_models_for_global_models(global_models)
        local_models = sorted(local_models, key=lambda m: str(m.owner.pk))  # required for groupby
        for _, group in groupby(local_models, key=lambda m: str(m.owner.pk)):
            models.append(max(group, key=lambda m: m.round))
        serializer = ModelSerializer(models, many=True)
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeightsWithStats(),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_metadata(self, _request: HttpRequest, id: str) -> HttpResponse:
        """
        Get model metadata.

        Args:
            request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponse: Model metadata as JSON response.
        """
        model = get_entity(ModelDB, pk=id)
        serializer = ModelSerializer(model, context={"with-stats": True})
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: OpenApiResponse(response=bytes, description="Model is returned as bytes"),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_model(self, _request: HttpRequest, id: str) -> HttpResponseBase:
        """
        Download the whole model as a PyTorch serialized file.
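
        A client might load the returned bytes back into PyTorch roughly like this
        (a minimal sketch; the base URL, credentials and model UUID are placeholders,
        the actual routes are defined in the project's URL configuration):

        ```python
        import io

        import requests
        import torch

        base = "http://localhost:8000/api"  # assumed base URL (placeholder)
        model_id = "00000000-0000-0000-0000-000000000000"  # model UUID (placeholder)
        data = requests.get(f"{base}/models/{model_id}/", auth=("user", "pass")).content
        # depending on the PyTorch version, torch.load may need weights_only=False;
        # TorchScript models are loaded via torch.jit.load instead
        model = torch.load(io.BytesIO(data))
        ```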

        Args:
            request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponseBase: model as file response
        """
        model = get_entity(ModelDB, pk=id)
        if isinstance(model, SWAGModelDB) and model.swag_first_moment is not None:
            if model.swag_second_moment is None:
                raise APIException(f"Model {model.id} is in inconsistent state!")
            raise NotImplementedError(
                "SWAG models need to be returned in 3 parts: model architecture, first moment, second moment"
            )
        # NOTE: FileResponse handles raw bytes poorly, and with SQLite the weights
        #       are plain bytes rather than a memoryview
        response = HttpResponse(model.weights, content_type="application/octet-stream")
        response["Content-Disposition"] = f'filename="model-{id}.pt"'
        return response

    @extend_schema(
        responses={
            status.HTTP_200_OK: OpenApiResponse(
                response=bytes,
                description="Proprecessing model is returned as bytes"
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
            status.HTTP_404_NOT_FOUND: error_response_404,
        },
    )
    def get_model_proprecessing(self, _request: HttpRequest, id: str) -> HttpResponseBase:
        """
        Download the whole preprocessing model as a PyTorch serialized file.

        Args:
            request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponseBase: preprocessing model as file response, or 404 if no preprocessing model is defined
        """
        model = get_entity(ModelDB, pk=id)
        global_model: GlobalModelDB
        if isinstance(model, GlobalModelDB):
            global_model = model
        elif isinstance(model, LocalModelDB):
            global_model = model.base_model
        else:
            self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel.")
            raise ValidationError(f"Unknown model type. Model id: {id}")
        if global_model.preprocessing is None:
            raise NotFound(f"Model '{id}' has no preprocessing model defined.")
        # NOTE: FileResponse handles raw bytes poorly, and with SQLite the weights
        #       are plain bytes rather than a memoryview
        response = HttpResponse(global_model.preprocessing, content_type="application/octet-stream")
        response["Content-Disposition"] = f'filename="model-{id}-proprecessing.pt"'
        return response

    @extend_schema(responses={
        status.HTTP_200_OK: inline_serializer(
            "DeleteModelSuccessSerializer",
            fields={
                "detail": CharField(default="Model removed!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def remove_model(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Remove an existing model.

        Args:
            request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponse: 200 Response if model was removed, else corresponding error code
        """
        model = get_entity(ModelDB, pk=id)
        if model.owner != request.user:
            training = model.get_training()
            if training is None or training.actor != request.user:
                raise PermissionDenied(
                    "You are neither the owner of the model nor the actor of the corresponding training."
                )
        model.delete()
        return JsonResponse({"detail": "Model removed!"})

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "description": {"type": "string"},
                    "model_file": {"type": "string", "format": "binary"},
                    "model_preprocessing_file": {"type": "string", "format": "binary", "required": "false"},
                },
            },
        },
        responses={
            status.HTTP_201_CREATED: inline_serializer("ModelUploadSerializer", fields={
                "detail": CharField(default="Model Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_model(self, request: HttpRequest) -> HttpResponse:
        """
        Upload a global model file.

        The model file should be a PyTorch serialized model.
        Models saved via `torch.save` as well as models exported in TorchScript format are supported.
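
        For example, a model could be serialized for upload roughly like this (a
        minimal sketch; how the resulting bytes are sent as the `model_file` form
        field is up to the client):

        ```python
        import io

        import torch

        model = torch.nn.Linear(4, 2)           # stand-in for the real model
        buffer = io.BytesIO()
        torch.save(model, buffer)               # plain PyTorch serialization
        # torch.jit.script(model).save(buffer)  # alternative: TorchScript export
        buffer.seek(0)
        # `buffer` now holds the bytes for the `model_file` form field
        ```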

        Args:
            request (HttpRequest): The incoming request object.

        Returns:
            HttpResponse: Upload success message as JSON response.
        """
        model = load_and_create_model_request(request)
        return JsonResponse({
            "detail": "Model Upload Accepted",
            "model_id": str(model.id),
        }, status=status.HTTP_201_CREATED)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "model_preprocessing_file": {"type": "string", "format": "binary"},
                },
            },
        },
        responses={
            status.HTTP_202_ACCEPTED: inline_serializer("PreprocessingModelUploadSerializer", fields={
                "detail": CharField(default="Proprocessing Model Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def upload_model_preprocessing(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload a preprocessing model file for a global model.

        The preprocessing model file should be a PyTorch serialized model.
        Models saved via `torch.save` as well as models exported in TorchScript format are supported.

        ```python
        transforms = torch.nn.Sequential(
            torchvision.transforms.CenterCrop(10),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        )
        ```

        Make sure to only use transformations that inherit from `torch.nn.Module`.
        It is advised to use the `torchvision.transforms.v2` module for common transformations.

        Please note that this function is still in the beta phase.

        Args:
            request (HttpRequest): request object
            id (str): global model UUID

        Raises:
            PermissionDenied: Unauthorized to upload preprocessing model for the specified model
            ValidationError: Preprocessing model is not a valid torch model

        Returns:
            HttpResponse: Upload success message as JSON response.
        """
        model = get_entity(GlobalModelDB, pk=id)
        if request.user.id != model.owner.id:
            raise PermissionDenied(f"You are not the owner of model {model.id}!")
        model.preprocessing = get_file(request, "model_preprocessing_file")
        verify_model_object(model.preprocessing, "preprocessing")
        model.save()
        return JsonResponse({
            "detail": "Preprocessing Model Upload Accepted",
            "model_id": str(model.id),
        }, status=status.HTTP_202_ACCEPTED)

    @extend_schema(
        responses={
            status.HTTP_200_OK: MetricSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Get all metrics for the selected model.

        Args:
            request (HttpRequest):  request object
            id (str):  model UUID

        Returns:
            HttpResponse: Metrics as JSON Array
        """
        model = get_entity(ModelDB, pk=id)
        metrics = MetricDB.objects.filter(model=model).all()
        return Response(MetricSerializer(metrics, many=True).data)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "metric_names": {"type": "list"},
                    "metric_values": {"type": "list"},
                },
            },
        },
        responses={
            status.HTTP_200_OK: inline_serializer("MetricUploadResponseSerializer", fields={
                "detail": CharField(default="Model Metrics Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            OpenApiExample("Example", value={
                "metric_names": ["accuracy", "training loss"],
                "metric_values": [0.6, 0.04]
            }, media_type="multipart/form-data")
        ]
    )
    def create_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload model metrics.

        Args:
            request (HttpRequest): request object
            id (str): model UUID

        Returns:
            HttpResponse: Upload success message as JSON response.
        """
        model = get_entity(ModelDB, pk=id)
        formdata = dict(request.POST)

        with locked_atomic_transaction(MetricDB):
            self._metric_upload(formdata, model, request.user)

            if isinstance(model, GlobalModelDB):
                n_metrics = MetricDB.objects.filter(model=model, step=model.round).distinct("reporter").count()
                training = model.get_training()
                if training:
                    if n_metrics == training.participants.count():
                        dispatch_trainer_task(training, ModelTestFinished, False)
                else:
                    self._logger.warning(f"Global model {id} is not connected to any training.")

        return JsonResponse({
            "detail": "Model Metrics Upload Accepted",
            "model_id": str(model.id),
        }, status=status.HTTP_201_CREATED)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "owner": {"type": "string"},
                    "round": {"type": "int"},
                    "sample_size": {"type": "int"},
                    "metric_names": {"type": "list[string]"},
                    "metric_values": {"type": "list[float]"},
                    "model_file": {"type": "string", "format": "binary"},
                },
            },
        },
        responses={
            status.HTTP_201_CREATED: inline_serializer("LocalModelUploadSerializer", fields={
                "detail": CharField(default="Model Update Accepted"),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_local_model(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload a partially trained model file from a client.
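
        A participant might submit its update roughly like this (a minimal sketch;
        the base URL, route, credentials and UUID are placeholders, the actual
        routes are defined in the project's URL configuration):

        ```python
        import io

        import requests
        import torch

        base = "http://localhost:8000/api"  # assumed base URL (placeholder)
        model_id = "00000000-0000-0000-0000-000000000000"  # global model UUID (placeholder)
        local_model = torch.nn.Linear(4, 2)  # stand-in for the locally trained model

        buffer = io.BytesIO()
        torch.save(local_model, buffer)
        requests.post(
            f"{base}/models/{model_id}/",  # assumed route for local model updates
            data={"round": 3, "sample_size": 1500,
                  "metric_names": ["accuracy"], "metric_values": [0.87]},
            files={"model_file": ("model.pt", buffer.getvalue())},
            auth=("participant", "password"),
        )
        ```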

        Args:
            request (HttpRequest): request object
            id (str): UUID of the global model that was used for training

        Returns:
            HttpResponse: Upload success message as JSON response.
        """
        try:
            formdata = dict(request.POST)
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
            client = request.user
            model_file = get_file(request, "model_file")
            global_model = get_entity(GlobalModelDB, pk=id)

            # ensure that a training process corresponding to the model exists; otherwise Training.DoesNotExist is raised
            training = Training.objects.get(model=global_model)
            self._verify_valid_update(client, training, round_num, TrainingState.ONGOING)

            verify_model_object(model_file)
            local_model = LocalModelDB.objects.create(
                base_model=global_model, weights=model_file,
                round=round_num, owner=client, sample_size=sample_size
            )
            self._metric_upload(formdata, local_model, client, metrics_required=False)

            updates = LocalModelDB.objects.filter(base_model=global_model, round=round_num)
            if updates.count() == training.participants.count():
                dispatch_trainer_task(training, TrainingRoundFinished, True)

            return JsonResponse({"detail": "Model Update Accepted"}, status=status.HTTP_201_CREATED)
        except Training.DoesNotExist:
            raise NotFound(f"Model with ID {id} does not have a training process running")
        except (MultiValueDictKeyError, KeyError) as e:
            raise ParseError(e)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "round": {"type": "int"},
                    "sample_size": {"type": "int"},
                    "first_moment_file": {"type": "string", "format": "binary"},
                    "second_moment_file": {"type": "string", "format": "binary"}
                },
            },
        },
        responses={
            status.HTTP_200_OK: inline_serializer("MetricUploadSerializer", fields={
                "detail": CharField(default="SWAg Statistics Accepted"),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_swag_stats(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload SWAG statistics.
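
        The two moment files are assumed to hold the element-wise mean and mean of
        squares of the weights collected over the local SWAG epochs; a client might
        produce them roughly like this (a minimal sketch under that assumption):

        ```python
        import io

        import torch

        # stand-ins for the parameter snapshots captured after each local SWAG epoch
        snapshots = [torch.nn.utils.parameters_to_vector(torch.nn.Linear(4, 2).parameters())
                     for _ in range(3)]
        weights = torch.stack(snapshots)
        first_moment = weights.mean(dim=0)           # element-wise mean of the weights
        second_moment = (weights ** 2).mean(dim=0)   # element-wise mean of the squared weights

        fst, snd = io.BytesIO(), io.BytesIO()
        torch.save(first_moment, fst)
        torch.save(second_moment, snd)
        # fst.getvalue() / snd.getvalue() are sent as `first_moment_file` / `second_moment_file`
        ```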

        Args:
            request (HttpRequest): request object
            id (str): global model uuid

        Raises:
            APIException: internal server error
            NotFound: model not found
            ParseError: request data not valid

        Returns:
            HttpResponse: Upload success message as JSON response.
        """
        try:
            client = request.user
            formdata = dict(request.POST)
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
            fst_moment = get_file(request, "first_moment_file")
            snd_moment = get_file(request, "second_moment_file")
            model = get_entity(GlobalModelDB, pk=id)

            # ensure that a training process corresponding to the model exists; otherwise Training.DoesNotExist is raised
            training = Training.objects.get(model=model)
            self._verify_valid_update(client, training, round_num, TrainingState.SWAG_ROUND)

            self._save_swag_stats(fst_moment, snd_moment, model, client, sample_size)

            swag_stats_first = MetricDB.objects.filter(model=model, step=model.round, key="SWAG First Moment Local")
            swag_stats_second = MetricDB.objects.filter(model=model, step=model.round, key="SWAG Second Moment Local")

            if swag_stats_first.count() != swag_stats_second.count():
                training.state = TrainingState.ERROR
                training.save()
                raise APIException("SWAG stats in inconsistent state!")
            if swag_stats_first.count() == training.participants.count():
                dispatch_trainer_task(training, SWAGRoundFinished, True)

            return JsonResponse({"detail": "SWAG Statistic Accepted"}, status=status.HTTP_201_CREATED)
        except Training.DoesNotExist:
            raise NotFound(f"Model with ID {id} does not have a training process running")
        except (MultiValueDictKeyError, KeyError) as e:
            raise ParseError(e)
        except APIException:
            raise
        except Exception as e:
            raise APIException(e)

    @staticmethod
    def _save_swag_stats(
        fst_moment: bytes, snd_moment: bytes, model: GlobalModelDB, client: UserDB, sample_size: int
    ):
        """
        Save the first and second moments, and the sample size of the SWAG to the database.

        This function creates and saves three metrics for each round of the model:

        - the first moment,
        - the second moment, and
        - the sample size.

        These metrics are associated with the model, the round, and the client that reported them.

        Args:
            fst_moment (bytes): The first moment of the SWAG.
            snd_moment (bytes): The second moment of the SWAG.
            model (GlobalModelDB): The global model for which the metrics are being reported.
            client (UserDB): The client reporting the metrics.
            sample_size (int): The sample size of the SWAG.
        """
        MetricDB.objects.create(
            model=model,
            key="SWAG First Moment Local",
            value_binary=fst_moment,
            step=model.round,
            reporter=client
        )
        MetricDB.objects.create(
            model=model,
            key="SWAG Second Moment Local",
            value_binary=snd_moment,
            step=model.round,
            reporter=client
        )
        MetricDB.objects.create(
            model=model,
            key="SWAG Sample Size Local",
            value_float=sample_size,
            step=model.round,
            reporter=client
        )

    @transaction.atomic()
    def _metric_upload(self, formdata: dict, model: ModelDB, client: UserDB, metrics_required: bool = True):
        """
        Uploads metrics associated with a model.

        For each pair of metric name and value, it attempts to convert the value to a float.
        If this fails, it treats the value as a binary string.

        It then creates a new metric object with the model, the metric name, the float or binary value,
        the model's round number, and the client, and saves this object to the database.
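
        For example (a minimal sketch of the conversion, not part of the actual
        implementation):

        ```python
        formdata = {
            "metric_names": ["accuracy", "confusion_matrix"],
            "metric_values": ["0.93", "[[5, 1], [0, 4]]"],
        }
        # "0.93" parses as a float and is stored as value_float=0.93;
        # "[[5, 1], [0, 4]]" does not, so it is stored as value_binary (UTF-8 encoded bytes)
        ```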

        Args:
            formdata (dict): The form data containing the metric names and values.
            model (ModelDB): The model with which the metrics are associated.
            client (UserDB): The client reporting the metrics.
            metrics_required (bool): A flag indicating whether metrics are required. Defaults to True.

        Raises:
            ParseError: If `metric_names` or `metric_values` are not in formdata,
                or if they do not have the same length and metrics are required.
        """
        if "metric_names" not in formdata or "metric_values" not in formdata:
            if metrics_required or ("metric_names" in formdata) != ("metric_values" in formdata):
                raise ParseError("Metric names or values are missing")
            return
        if len(formdata["metric_names"]) != len(formdata["metric_values"]):
            if metrics_required:
                raise ParseError("Metric names and values must have the same length")
            return

        for key, value in zip(formdata["metric_names"], formdata["metric_values"]):
            try:
                metric_float = float(value)
                metric_binary = None
            except Exception:
                metric_float = None
                metric_binary = bytes(value, encoding="utf-8")
            MetricDB.objects.create(
                model=model,
                key=key,
                value_float=metric_float,
                value_binary=metric_binary,
                step=model.round,
                reporter=client
            )

    def _verify_valid_update(self, client: UserDB, train: Training, round_num: int, expected_state: TrainingState):
        """
        Verifies if a client can update a training process.

        This function checks if

        - the client is a participant of the training process,
        - the training process is in the expected state, and if
        - the round number matches the current round of the model associated with the training process.

        Args:
            client (UserDB): The client attempting to update the training process.
            train (Training): The training process to be updated.
            round_num (int): The round number reported by the client.
            expected_state (TrainingState): The expected state of the training process.

        Raises:
            PermissionDenied: If the client is not a participant of the training process.
            ValidationError: If the training process is not in the expected state or if the round number does not match
                the current round of the model.
        """
        if client.id not in [p.id for p in train.participants.all()]:
            raise PermissionDenied(f"You are not a participant of training {train.id}!")
        if train.state != expected_state:
            raise ValidationError(f"Training with ID {train.id} is in state {train.state}")
        if int(round_num) != train.model.round:
            raise ValidationError(f"Training with ID {train.id} is not currently in round {round_num}")

Classes

Model

class Model(
    **kwargs
)

Model ViewSet.

Ancestors (in MRO)

  • fl_server_api.views.base.ViewSet
  • rest_framework.viewsets.ViewSet
  • rest_framework.viewsets.ViewSetMixin
  • rest_framework.views.APIView
  • django.views.generic.base.View

Class variables

authentication_classes
basename
content_negotiation_class
description
detail
http_method_names
metadata_class
name
parser_classes
permission_classes
renderer_classes
schema
serializer_class

The serializer for the ViewSet.

settings
suffix
throttle_classes
versioning_class

Static methods

as_view

def as_view(
    actions=None,
    **initkwargs
)

Because of the way class based views create a closure around the instantiated view, we need to totally reimplement .as_view, and slightly modify the view function that is created and returned.

View Source
    @classonlymethod
    def as_view(cls, actions=None, **initkwargs):
        """
        Because of the way class based views create a closure around the
        instantiated view, we need to totally reimplement `.as_view`,
        and slightly modify the view function that is created and returned.
        """
        # The name and description initkwargs may be explicitly overridden for
        # certain route configurations. eg, names of extra actions.
        cls.name = None
        cls.description = None

        # The suffix initkwarg is reserved for displaying the viewset type.
        # This initkwarg should have no effect if the name is provided.
        # eg. 'List' or 'Instance'.
        cls.suffix = None

        # The detail initkwarg is reserved for introspecting the viewset type.
        cls.detail = None

        # Setting a basename allows a view to reverse its action urls. This
        # value is provided by the router through the initkwargs.
        cls.basename = None

        # actions must not be empty
        if not actions:
            raise TypeError("The `actions` argument must be provided when "
                            "calling `.as_view()` on a ViewSet. For example "
                            "`.as_view({'get': 'list'})`")

        # sanitize keyword arguments
        for key in initkwargs:
            if key in cls.http_method_names:
                raise TypeError("You tried to pass in the %s method name as a "
                                "keyword argument to %s(). Don't do that."
                                % (key, cls.__name__))
            if not hasattr(cls, key):
                raise TypeError("%s() received an invalid keyword %r" % (
                    cls.__name__, key))

        # name and suffix are mutually exclusive
        if 'name' in initkwargs and 'suffix' in initkwargs:
            raise TypeError("%s() received both `name` and `suffix`, which are "
                            "mutually exclusive arguments." % (cls.__name__))

        def view(request, *args, **kwargs):
            self = cls(**initkwargs)

            if 'get' in actions and 'head' not in actions:
                actions['head'] = actions['get']

            # We also store the mapping of request methods to actions,
            # so that we can later set the action attribute.
            # eg. `self.action = 'list'` on an incoming GET request.
            self.action_map = actions

            # Bind methods to actions
            # This is the bit that's different to a standard view
            for method, action in actions.items():
                handler = getattr(self, action)
                setattr(self, method, handler)

            self.request = request
            self.args = args
            self.kwargs = kwargs

            # And continue as usual
            return self.dispatch(request, *args, **kwargs)

        # take name and docstring from class
        update_wrapper(view, cls, updated=())

        # and possible attributes set by decorators
        # like csrf_exempt from dispatch
        update_wrapper(view, cls.dispatch, assigned=())

        # We need to set these on the view function, so that breadcrumb
        # generation can pick out these bits of information from a
        # resolved URL.
        view.cls = cls
        view.initkwargs = initkwargs
        view.actions = actions
        return csrf_exempt(view)

get_extra_actions

def get_extra_actions()

Get the methods that are marked as an extra ViewSet @action.

View Source
    @classmethod
    def get_extra_actions(cls):
        """
        Get the methods that are marked as an extra ViewSet `@action`.
        """
        return [_check_attr_name(method, name)
                for name, method
                in getmembers(cls, _is_extra_action)]

Instance variables

allowed_methods

Wrap Django's private _allowed_methods interface in a public property.

default_response_headers

Methods

check_object_permissions

def check_object_permissions(
    self,
    request,
    obj
)

Check if the request should be permitted for a given object.

Raises an appropriate exception if the request is not permitted.

View Source
    def check_object_permissions(self, request, obj):
        """
        Check if the request should be permitted for a given object.
        Raises an appropriate exception if the request is not permitted.
        """
        for permission in self.get_permissions():
            if not permission.has_object_permission(request, self, obj):
                self.permission_denied(
                    request,
                    message=getattr(permission, 'message', None),
                    code=getattr(permission, 'code', None)
                )

check_permissions

def check_permissions(
    self,
    request
)

Check if the request should be permitted.

Raises an appropriate exception if the request is not permitted.

View Source
    def check_permissions(self, request):
        """
        Check if the request should be permitted.
        Raises an appropriate exception if the request is not permitted.
        """
        for permission in self.get_permissions():
            if not permission.has_permission(request, self):
                self.permission_denied(
                    request,
                    message=getattr(permission, 'message', None),
                    code=getattr(permission, 'code', None)
                )

check_throttles

def check_throttles(
    self,
    request
)

Check if request should be throttled.

Raises an appropriate exception if the request is throttled.

View Source
    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            # changes, see #1438
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            duration = max(durations, default=None)
            self.throttled(request, duration)

create_local_model

def create_local_model(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse

Upload a partially trained model file from a client.

Parameters:

Name Type Description Default
request HttpRequest request object None
id str UUID of the model that was used for training None

Returns:

Type Description
HttpResponse upload success message as json response
View Source
    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "owner": {"type": "string"},
                    "round": {"type": "int"},
                    "sample_size": {"type": "int"},
                    "metric_names": {"type": "list[string]"},
                    "metric_values": {"type": "list[float]"},
                    "model_file": {"type": "string", "format": "binary"},
                },
            },
        },
        responses={
            status.HTTP_200_OK: ModelSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_local_model(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload a partially trained model file from a client.

        Args:
            request (HttpRequest):  request object
            id (str):  UUID of the model that was used for training

        Returns:
            HttpResponse: upload success message as json response
        """
        try:
            formdata = dict(request.POST)
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
            client = request.user
            model_file = get_file(request, "model_file")
            global_model = get_entity(GlobalModelDB, pk=id)

            # ensure that a training process corresponding to the model exists, else the process will error out
            training = Training.objects.get(model=global_model)
            self._verify_valid_update(client, training, round_num, TrainingState.ONGOING)

            verify_model_object(model_file)
            local_model = LocalModelDB.objects.create(
                base_model=global_model, weights=model_file,
                round=round_num, owner=client, sample_size=sample_size
            )
            self._metric_upload(formdata, local_model, client, metrics_required=False)

            updates = LocalModelDB.objects.filter(base_model=global_model, round=round_num)
            if updates.count() == training.participants.count():
                dispatch_trainer_task(training, TrainingRoundFinished, True)

            return JsonResponse({"detail": "Model Update Accepted"}, status=status.HTTP_201_CREATED)
        except Training.DoesNotExist:
            raise NotFound(f"Model with ID {id} does not have a training process running")
        except (MultiValueDictKeyError, KeyError) as e:
            raise ParseError(e)
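
A minimal client-side sketch of such an upload using the requests library; the URL, token authentication, and placeholder values below are assumptions and depend on the actual URL configuration and deployment:

import io
import requests
import torch

SERVER_URL, TOKEN = "http://localhost:8000", "<api-token>"   # placeholders
model_id = "<global-model-uuid>"                              # placeholder

# serialize the locally trained model; torch.save and TorchScript exports are both accepted
local_model = torch.nn.Linear(10, 2)   # stands in for the actual trained model
buffer = io.BytesIO()
torch.save(local_model, buffer)
buffer.seek(0)

response = requests.post(
    f"{SERVER_URL}/api/models/{model_id}/",                   # hypothetical route
    data={"round": 0, "sample_size": 50},
    files={"model_file": ("model.pt", buffer.getvalue())},
    headers={"Authorization": f"Token {TOKEN}"},
)
response.raise_for_status()   # expects 201 with {"detail": "Model Update Accepted"}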

create_model

def create_model(
    self,
    request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse

Upload a global model file.

The model file should be a PyTorch serialized model. Providing the model via torch.save as well as in TorchScript format is supported.

Parameters:

Name Type Description Default
request HttpRequest The incoming request object. None

Returns:

Type Description
HttpResponse upload success message as json response
View Source
    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "description": {"type": "string"},
                    "model_file": {"type": "string", "format": "binary"},
                    "model_preprocessing_file": {"type": "string", "format": "binary", "required": "false"},
                },
            },
        },
        responses={
            status.HTTP_201_CREATED: inline_serializer("ModelUploadSerializer", fields={
                "detail": CharField(default="Model Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_model(self, request: HttpRequest) -> HttpResponse:
        """
        Upload a global model file.

        The model file should be a PyTorch serialized model.
        Providing the model via `torch.save` as well as in TorchScript format is supported.

        Args:
            request (HttpRequest): The incoming request object.

        Returns:
            HttpResponse: upload success message as json response
        """
        model = load_and_create_model_request(request)
        return JsonResponse({
            "detail": "Model Upload Accepted",
            "model_id": str(model.id),
        }, status=status.HTTP_201_CREATED)
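
For illustration, a global model upload might look roughly like the following sketch; only the form fields come from the schema above, while the route, authentication, and placeholder values are assumptions:

import io
import requests
import torch

SERVER_URL, TOKEN = "http://localhost:8000", "<api-token>"   # placeholders

model = torch.nn.Linear(10, 2)        # any torch.nn.Module
buffer = io.BytesIO()
torch.save(model, buffer)             # a TorchScript export (torch.jit.save) would also be accepted
buffer.seek(0)

response = requests.post(
    f"{SERVER_URL}/api/models/",                              # hypothetical route
    data={"name": "my-model", "description": "demo upload"},
    files={"model_file": ("model.pt", buffer.getvalue())},
    headers={"Authorization": f"Token {TOKEN}"},
)
print(response.json())                # {"detail": "Model Upload Accepted", "model_id": "..."}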

create_model_metrics

def create_model_metrics(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse

Upload model metrics.

Parameters:

Name Type Description Default
request HttpRequest request object None
id str model uuid None

Returns:

Type Description
HttpResponse upload success message as json response
View Source
    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "metric_names": {"type": "list"},
                    "metric_values": {"type": "list"},
                },
            },
        },
        responses={
            status.HTTP_200_OK: inline_serializer("MetricUploadResponseSerializer", fields={
                "detail": CharField(default="Model Metrics Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            OpenApiExample("Example", value={
                "metric_names": ["accuracy", "training loss"],
                "metric_values": [0.6, 0.04]
            }, media_type="multipart/form-data")
        ]
    )
    def create_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload model metrics.

        Args:
            request (HttpRequest):  request object
            id (str):  model uuid

        Returns:
            HttpResponse: upload success message as json response
        """
        model = get_entity(ModelDB, pk=id)
        formdata = dict(request.POST)

        with locked_atomic_transaction(MetricDB):
            self._metric_upload(formdata, model, request.user)

            if isinstance(model, GlobalModelDB):
                n_metrics = MetricDB.objects.filter(model=model, step=model.round).distinct("reporter").count()
                training = model.get_training()
                if training:
                    if n_metrics == training.participants.count():
                        dispatch_trainer_task(training, ModelTestFinished, False)
                else:
                    self._logger.warning(f"Global model {id} is not connected to any training.")

        return JsonResponse({
            "detail": "Model Metrics Upload Accepted",
            "model_id": str(model.id),
        }, status=status.HTTP_201_CREATED)
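
Matching the OpenAPI example above, a metric upload from a client could look like this sketch; the route, authentication, and placeholder values are assumptions:

import requests

SERVER_URL, TOKEN = "http://localhost:8000", "<api-token>"   # placeholders
model_id = "<model-uuid>"                                     # placeholder

response = requests.post(
    f"{SERVER_URL}/api/models/{model_id}/metrics/",           # hypothetical route
    data={
        "metric_names": ["accuracy", "training loss"],
        "metric_values": [0.6, 0.04],
    },
    headers={"Authorization": f"Token {TOKEN}"},
)
print(response.json())   # {"detail": "Model Metrics Upload Accepted", "model_id": "..."}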

create_swag_stats

def create_swag_stats(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse

Upload SWAG statistics.

Parameters:

Name Type Description Default
request HttpRequest request object None
id str global model uuid None

Returns:

Type Description
HttpResponse upload success message as json response

Raises:

Type Description
APIException internal server error
NotFound model not found
ParseError request data not valid
View Source
    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "round": {"type": "int"},
                    "sample_size": {"type": "int"},
                    "first_moment_file": {"type": "string", "format": "binary"},
                    "second_moment_file": {"type": "string", "format": "binary"}
                },
            },
        },
        responses={
            status.HTTP_200_OK: inline_serializer("MetricUploadSerializer", fields={
                "detail": CharField(default="SWAg Statistics Accepted"),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_swag_stats(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload SWAG statistics.

        Args:
            request (HttpRequest): request object
            id (str): global model uuid

        Raises:
            APIException: internal server error
            NotFound: model not found
            ParseError: request data not valid

        Returns:
            HttpResponse: upload success message as json response
        """
        try:
            client = request.user
            formdata = dict(request.POST)
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
            fst_moment = get_file(request, "first_moment_file")
            snd_moment = get_file(request, "second_moment_file")
            model = get_entity(GlobalModelDB, pk=id)

            # ensure that a training process corresponding to the model exists, else the process will error out
            training = Training.objects.get(model=model)
            self._verify_valid_update(client, training, round_num, TrainingState.SWAG_ROUND)

            self._save_swag_stats(fst_moment, snd_moment, model, client, sample_size)

            swag_stats_first = MetricDB.objects.filter(model=model, step=model.round, key="SWAG First Moment Local")
            swag_stats_second = MetricDB.objects.filter(model=model, step=model.round, key="SWAG Second Moment Local")

            if swag_stats_first.count() != swag_stats_second.count():
                training.state = TrainingState.ERROR
                raise APIException("SWAG stats in inconsistent state!")
            if swag_stats_first.count() == training.participants.count():
                dispatch_trainer_task(training, SWAGRoundFinished, True)

            return JsonResponse({"detail": "SWAG Statistic Accepted"}, status=status.HTTP_201_CREATED)
        except Training.DoesNotExist:
            raise NotFound(f"Model with ID {id} does not have a training process running")
        except (MultiValueDictKeyError, KeyError) as e:
            raise ParseError(e)
        except Exception as e:
            raise APIException(e)
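
A sketch of a client reporting its SWAG statistics; serializing the moment tensors with torch.save is an assumption here (the endpoint only expects raw bytes), as are the route, authentication, and placeholder values:

import io
import requests
import torch

SERVER_URL, TOKEN = "http://localhost:8000", "<api-token>"   # placeholders
model_id = "<global-model-uuid>"                              # placeholder

def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return buffer.getvalue()

first_moment = torch.zeros(100)    # placeholder moment vectors
second_moment = torch.ones(100)

response = requests.post(
    f"{SERVER_URL}/api/models/{model_id}/swag/",              # hypothetical route
    data={"round": 0, "sample_size": 50},
    files={
        "first_moment_file": ("first_moment.pt", tensor_to_bytes(first_moment)),
        "second_moment_file": ("second_moment.pt", tensor_to_bytes(second_moment)),
    },
    headers={"Authorization": f"Token {TOKEN}"},
)
response.raise_for_status()   # expects 201 with {"detail": "SWAG Statistic Accepted"}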

determine_version

def determine_version(
    self,
    request,
    *args,
    **kwargs
)

If versioning is being used, then determine any API version for the incoming request. Returns a two-tuple of (version, versioning_scheme).

View Source
    def determine_version(self, request, *args, **kwargs):
        """
        If versioning is being used, then determine any API version for the
        incoming request. Returns a two-tuple of (version, versioning_scheme)
        """
        if self.versioning_class is None:
            return (None, None)
        scheme = self.versioning_class()
        return (scheme.determine_version(request, *args, **kwargs), scheme)

dispatch

def dispatch(
    self,
    request,
    *args,
    **kwargs
)

.dispatch() is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling.

View Source
    def dispatch(self, request, *args, **kwargs):
        """
        `.dispatch()` is pretty much the same as Django's regular dispatch,
        but with extra hooks for startup, finalize, and exception handling.
        """
        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            self.initial(request, *args, **kwargs)

            # Get the appropriate handler method
            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = handler(request, *args, **kwargs)

        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

finalize_response

def finalize_response(
    self,
    request,
    response,
    *args,
    **kwargs
)

Returns the final response object.

View Source
    def finalize_response(self, request, response, *args, **kwargs):
        """
        Returns the final response object.
        """
        # Make the error obvious if a proper response is not returned
        assert isinstance(response, HttpResponseBase), (
            'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
            'to be returned from the view, but received a `%s`'
            % type(response)
        )

        if isinstance(response, Response):
            if not getattr(request, 'accepted_renderer', None):
                neg = self.perform_content_negotiation(request, force=True)
                request.accepted_renderer, request.accepted_media_type = neg

            response.accepted_renderer = request.accepted_renderer
            response.accepted_media_type = request.accepted_media_type
            response.renderer_context = self.get_renderer_context()

        # Add new vary headers to the response instead of overwriting.
        vary_headers = self.headers.pop('Vary', None)
        if vary_headers is not None:
            patch_vary_headers(response, cc_delim_re.split(vary_headers))

        for key, value in self.headers.items():
            response[key] = value

        return response

get_authenticate_header

def get_authenticate_header(
    self,
    request
)

If a request is unauthenticated, determine the WWW-Authenticate header to use for 401 responses, if any.

View Source
    def get_authenticate_header(self, request):
        """
        If a request is unauthenticated, determine the WWW-Authenticate
        header to use for 401 responses, if any.
        """
        authenticators = self.get_authenticators()
        if authenticators:
            return authenticators[0].authenticate_header(request)

get_authenticators

def get_authenticators(
    self
)

Get the authenticators for the ViewSet.

This method gets the view method and, if it has authentication classes defined via the decorator, returns them. Otherwise, it falls back to the default authenticators.

Returns:

Type Description
list The authenticators for the ViewSet.
View Source
    def get_authenticators(self):
        """
        Get the authenticators for the ViewSet.

        This method gets the view method and, if it has authentication classes defined via the decorator, returns them.
        Otherwise, it falls back to the default authenticators.

        Returns:
            list: The authenticators for the ViewSet.
        """
        if method := self._get_view_method():
            if hasattr(method, "authentication_classes"):
                return method.authentication_classes
        return super().get_authenticators()

get_content_negotiator

def get_content_negotiator(
    self
)

Instantiate and return the content negotiation class to use.

View Source
    def get_content_negotiator(self):
        """
        Instantiate and return the content negotiation class to use.
        """
        if not getattr(self, '_negotiator', None):
            self._negotiator = self.content_negotiation_class()
        return self._negotiator

get_exception_handler

def get_exception_handler(
    self
)

Returns the exception handler that this view uses.

View Source
    def get_exception_handler(self):
        """
        Returns the exception handler that this view uses.
        """
        return self.settings.EXCEPTION_HANDLER

get_exception_handler_context

def get_exception_handler_context(
    self
)

Returns a dict that is passed through to EXCEPTION_HANDLER, as the context argument.

View Source
    def get_exception_handler_context(self):
        """
        Returns a dict that is passed through to EXCEPTION_HANDLER,
        as the `context` argument.
        """
        return {
            'view': self,
            'args': getattr(self, 'args', ()),
            'kwargs': getattr(self, 'kwargs', {}),
            'request': getattr(self, 'request', None)
        }

get_extra_action_url_map

def get_extra_action_url_map(
    self
)

Build a map of {names: urls} for the extra actions.

This method will noop if detail was not provided as a view initkwarg.

View Source
    def get_extra_action_url_map(self):
        """
        Build a map of {names: urls} for the extra actions.

        This method will noop if `detail` was not provided as a view initkwarg.
        """
        action_urls = OrderedDict()

        # exit early if `detail` has not been provided
        if self.detail is None:
            return action_urls

        # filter for the relevant extra actions
        actions = [
            action for action in self.get_extra_actions()
            if action.detail == self.detail
        ]

        for action in actions:
            try:
                url_name = '%s-%s' % (self.basename, action.url_name)
                namespace = self.request.resolver_match.namespace
                if namespace:
                    url_name = '%s:%s' % (namespace, url_name)

                url = reverse(url_name, self.args, self.kwargs, request=self.request)
                view = self.__class__(**action.kwargs)
                action_urls[view.get_view_name()] = url
            except NoReverseMatch:
                pass  # URL requires additional arguments, ignore

        return action_urls

get_format_suffix

def get_format_suffix(
    self,
    **kwargs
)

Determine if the request includes a '.json' style format suffix

View Source
    def get_format_suffix(self, **kwargs):
        """
        Determine if the request includes a '.json' style format suffix
        """
        if self.settings.FORMAT_SUFFIX_KWARG:
            return kwargs.get(self.settings.FORMAT_SUFFIX_KWARG)

get_metadata

def get_metadata(
    self,
    _request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse

Get model meta data.

Parameters:

Name Type Description Default
request HttpRequest The incoming request object. None
id str The unique identifier of the model. None

Returns:

Type Description
HttpResponse Model meta data as JSON response.
View Source
    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeightsWithStats(),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_metadata(self, _request: HttpRequest, id: str) -> HttpResponse:
        """
        Get model meta data.

        Args:
            request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponse: Model meta data as JSON response.
        """
        model = get_entity(ModelDB, pk=id)
        serializer = ModelSerializer(model, context={"with-stats": True})
        return Response(serializer.data)

get_model

def get_model(
    self,
    _request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponseBase

Download the whole model as PyTorch serialized file.

Parameters:

Name Type Description Default
request HttpRequest The incoming request object. None
id str The unique identifier of the model. None

Returns:

Type Description
HttpResponseBase model as file response
View Source
    @extend_schema(
        responses={
            status.HTTP_200_OK: OpenApiResponse(response=bytes, description="Model is returned as bytes"),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_model(self, _request: HttpRequest, id: str) -> HttpResponseBase:
        """
        Download the whole model as PyTorch serialized file.

        Args:
            request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponseBase: model as file response
        """
        model = get_entity(ModelDB, pk=id)
        if isinstance(model, SWAGModelDB) and model.swag_first_moment is not None:
            if model.swag_second_moment is None:
                raise APIException(f"Model {model.id} is in inconsistent state!")
            raise NotImplementedError(
                "SWAG models need to be returned in 3 parts: model architecture, first moment, second moment"
            )
        # NOTE: FileResponse does strange stuff with bytes
        #       and in case of sqlite the weights will be bytes and not a memoryview
        response = HttpResponse(model.weights, content_type="application/octet-stream")
        response["Content-Disposition"] = f'filename="model-{id}.pt"'
        return response
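
Because the model is returned as raw bytes, a client can load it straight from the response body; a minimal sketch with assumed route, authentication, and placeholder values:

import io
import requests
import torch

SERVER_URL, TOKEN = "http://localhost:8000", "<api-token>"   # placeholders
model_id = "<model-uuid>"                                     # placeholder

response = requests.get(
    f"{SERVER_URL}/api/models/{model_id}/",                   # hypothetical route
    headers={"Authorization": f"Token {TOKEN}"},
)
response.raise_for_status()
model = torch.load(io.BytesIO(response.content))   # use torch.jit.load for TorchScript models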

get_model_metrics

def get_model_metrics(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse

Reports all metrics for the selected model.

Parameters:

Name Type Description Default
request HttpRequest request object None
id str model UUID None

Returns:

Type Description
HttpResponse Metrics as JSON Array
View Source
    @extend_schema(
        responses={
            status.HTTP_200_OK: MetricSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Reports all metrics for the selected model.

        Args:
            request (HttpRequest):  request object
            id (str):  model UUID

        Returns:
            HttpResponse: Metrics as JSON Array
        """
        model = get_entity(ModelDB, pk=id)
        metrics = MetricDB.objects.filter(model=model).all()
        return Response(MetricSerializer(metrics, many=True).data)

get_model_proprecessing

def get_model_proprecessing(
    self,
    _request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponseBase

Download the whole preprocessing model as PyTorch serialized file.

Parameters:

Name Type Description Default
request HttpRequest The incoming request object. None
id str The unique identifier of the model. None

Returns:

Type Description
HttpResponseBase preprocessing model as file response or 404 if preprocessing model not found
View Source
    @extend_schema(
        responses={
            status.HTTP_200_OK: OpenApiResponse(
                response=bytes,
                description="Proprecessing model is returned as bytes"
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
            status.HTTP_404_NOT_FOUND: error_response_404,
        },
    )
    def get_model_proprecessing(self, _request: HttpRequest, id: str) -> HttpResponseBase:
        """
        Download the whole preprocessing model as PyTorch serialized file.

        Args:
            request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponseBase: preprocessing model as file response or 404 if preprocessing model not found
        """
        model = get_entity(ModelDB, pk=id)
        global_model: torch.nn.Module
        if isinstance(model, GlobalModelDB):
            global_model = model
        elif isinstance(model, LocalModelDB):
            global_model = model.base_model
        else:
            self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel.")
            raise ValidationError(f"Unknown model type. Model id: {id}")
        if global_model.preprocessing is None:
            raise NotFound(f"Model '{id}' has no preprocessing model defined.")
        # NOTE: FileResponse does strange stuff with bytes
        #       and in case of sqlite the weights will be bytes and not a memoryview
        response = HttpResponse(global_model.preprocessing, content_type="application/octet-stream")
        response["Content-Disposition"] = f'filename="model-{id}-proprecessing.pt"'
        return response
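
Analogously, the returned preprocessing bytes can be loaded and applied on the client side; a small sketch under the same assumptions (route, authentication, placeholder values) as above:

import io
import requests
import torch

SERVER_URL, TOKEN = "http://localhost:8000", "<api-token>"   # placeholders
model_id = "<model-uuid>"                                     # placeholder

response = requests.get(
    f"{SERVER_URL}/api/models/{model_id}/preprocessing/",     # hypothetical route
    headers={"Authorization": f"Token {TOKEN}"},
)
response.raise_for_status()
preprocessing = torch.load(io.BytesIO(response.content))      # torch.jit.load for TorchScript pipelines
batch = torch.rand(1, 3, 224, 224)                            # dummy input batch
prepared = preprocessing(batch)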

get_models

def get_models(
    self,
    request: django.http.request.HttpRequest
) -> django.http.response.HttpResponse

Get a list of all global models associated with the requesting user.

A global model is deemed associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process.

Parameters:

Name Type Description Default
request HttpRequest The incoming request object. None

Returns:

Type Description
HttpResponse Model list as JSON response.
View Source
    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_models(self, request: HttpRequest) -> HttpResponse:
        """
        Get a list of all global models associated with the requesting user.

        A global model is deemed associated with a user if the user is either the owner of the model,
        or if the user is an actor or a participant in the model's training process.

        Args:
            request (HttpRequest): The incoming request object.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        models = self._get_user_related_global_models(request.user)
        serializer = ModelSerializer(models, many=True)
        return Response(serializer.data)

get_parser_context

def get_parser_context(
    self,
    http_request
)

Returns a dict that is passed through to Parser.parse(), as the parser_context keyword argument.

View Source
    def get_parser_context(self, http_request):
        """
        Returns a dict that is passed through to Parser.parse(),
        as the `parser_context` keyword argument.
        """
        # Note: Additionally `request` and `encoding` will also be added
        #       to the context by the Request object.
        return {
            'view': self,
            'args': getattr(self, 'args', ()),
            'kwargs': getattr(self, 'kwargs', {})
        }

get_parsers

def get_parsers(
    self
)

Instantiates and returns the list of parsers that this view can use.

View Source
    def get_parsers(self):
        """
        Instantiates and returns the list of parsers that this view can use.
        """
        return [parser() for parser in self.parser_classes]

get_permissions

def get_permissions(
    self
)

Get the permissions for the ViewSet.

This method gets the view method and, if it has permission classes defined via the decorator, returns them. Otherwise, it falls back to the default permissions.

Returns:

Type Description
list The permissions for the ViewSet.
View Source
    def get_permissions(self):
        """
        Get the permissions for the ViewSet.

        This method gets the view method and, if it has permission classes defined via the decorator, returns them.
        Otherwise, it falls back to the default permissions.

        Returns:
            list: The permissions for the ViewSet.
        """
        if method := self._get_view_method():
            if hasattr(method, "permission_classes"):
                return method.permission_classes
        return super().get_permissions()

get_renderer_context

def get_renderer_context(
    self
)

Returns a dict that is passed through to Renderer.render(), as the renderer_context keyword argument.

View Source
    def get_renderer_context(self):
        """
        Returns a dict that is passed through to Renderer.render(),
        as the `renderer_context` keyword argument.
        """
        # Note: Additionally 'response' will also be added to the context,
        #       by the Response object.
        return {
            'view': self,
            'args': getattr(self, 'args', ()),
            'kwargs': getattr(self, 'kwargs', {}),
            'request': getattr(self, 'request', None)
        }

get_renderers

def get_renderers(
    self
)

Instantiates and returns the list of renderers that this view can use.

View Source
    def get_renderers(self):
        """
        Instantiates and returns the list of renderers that this view can use.
        """
        return [renderer() for renderer in self.renderer_classes]

get_throttles

def get_throttles(
    self
)

Instantiates and returns the list of throttles that this view uses.

View Source
    def get_throttles(self):
        """
        Instantiates and returns the list of throttles that this view uses.
        """
        return [throttle() for throttle in self.throttle_classes]

get_training_models

def get_training_models(
    self,
    request: django.http.request.HttpRequest,
    training_id: str
) -> django.http.response.HttpResponse

Get a list of all models associated with a specific training process and the requesting user.

A model is deemed associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process.

Parameters:

Name Type Description Default
request HttpRequest The incoming request object. None
training_id str The unique identifier of the training process. None

Returns:

Type Description
HttpResponse Model list as JSON response.
View Source
    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_training_models(self, request: HttpRequest, training_id: str) -> HttpResponse:
        """
        Get a list of all models associated with a specific training process and the requesting user.

        A model is deemed associated with a user if the user is either the owner of the model,
        or if the user is an actor or a participant in the model's training process.

        Args:
            request (HttpRequest): The incoming request object.
            training_id (str): The unique identifier of the training process.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        global_models = self._get_user_related_global_models(request.user)
        global_models = self._filter_by_training(global_models, training_id)
        local_models = self._get_local_models_for_global_models(global_models)
        serializer = ModelSerializer([*global_models, *local_models], many=True)
        return Response(serializer.data)

get_training_models_latest

def get_training_models_latest(
    self,
    request: django.http.request.HttpRequest,
    training_id: str
) -> django.http.response.HttpResponse

Get a list of the latest models for a specific training process associated with the requesting user.

A model is considered associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process. The latest model refers to the model from the most recent round (highest round number) of a participant's training process.

Parameters:

Name Type Description Default
request HttpRequest The incoming request object. None
training_id str The unique identifier of the training process. None

Returns:

Type Description
HttpResponse Model list as JSON response.
View Source
    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_training_models_latest(self, request: HttpRequest, training_id: str) -> HttpResponse:
        """
        Get a list of the latest models for a specific training process associated with the requesting user.

        A model is considered associated with a user if the user is either the owner of the model,
        or if the user is an actor or a participant in the model's training process.
        The latest model refers to the model from the most recent round (highest round number) of
        a participant's training process.

        Args:
            request (HttpRequest): The incoming request object.
            training_id (str): The unique identifier of the training process.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        models: List[ModelDB] = []
        # add latest global model
        global_models = self._get_user_related_global_models(request.user)
        global_models = self._filter_by_training(global_models, training_id)
        models.append(max(global_models, key=lambda m: m.round))
        # add latest local models
        local_models = self._get_local_models_for_global_models(global_models)
        local_models = sorted(local_models, key=lambda m: str(m.owner.pk))  # required for groupby
        for _, group in groupby(local_models, key=lambda m: str(m.owner.pk)):
            models.append(max(group, key=lambda m: m.round))
        serializer = ModelSerializer(models, many=True)
        return Response(serializer.data)

get_view_description

def get_view_description(
    self,
    html=False
)

Return some descriptive text for the view, as used in OPTIONS responses and in the browsable API.

View Source
    def get_view_description(self, html=False):
        """
        Return some descriptive text for the view, as used in OPTIONS responses
        and in the browsable API.
        """
        func = self.settings.VIEW_DESCRIPTION_FUNCTION
        return func(self, html)

get_view_name

def get_view_name(
    self
)

Return the view name, as used in OPTIONS responses and in the browsable API.

View Source
    def get_view_name(self):
        """
        Return the view name, as used in OPTIONS responses and in the
        browsable API.
        """
        func = self.settings.VIEW_NAME_FUNCTION
        return func(self)

handle_exception

def handle_exception(
    self,
    exc
)

Handle any exception that occurs, by returning an appropriate response, or re-raising the error.

View Source
    def handle_exception(self, exc):
        """
        Handle any exception that occurs, by returning an appropriate response,
        or re-raising the error.
        """
        if isinstance(exc, (exceptions.NotAuthenticated,
                            exceptions.AuthenticationFailed)):
            # WWW-Authenticate header for 401 responses, else coerce to 403
            auth_header = self.get_authenticate_header(self.request)

            if auth_header:
                exc.auth_header = auth_header
            else:
                exc.status_code = status.HTTP_403_FORBIDDEN

        exception_handler = self.get_exception_handler()

        context = self.get_exception_handler_context()
        response = exception_handler(exc, context)

        if response is None:
            self.raise_uncaught_exception(exc)

        response.exception = True
        return response

http_method_not_allowed

def http_method_not_allowed(
    self,
    request,
    *args,
    **kwargs
)

If request.method does not correspond to a handler method, determine what kind of exception to raise.

View Source
    def http_method_not_allowed(self, request, *args, **kwargs):
        """
        If `request.method` does not correspond to a handler method,
        determine what kind of exception to raise.
        """
        raise exceptions.MethodNotAllowed(request.method)

initial

def initial(
    self,
    request,
    *args,
    **kwargs
)

Runs anything that needs to occur prior to calling the method handler.

View Source
    def initial(self, request, *args, **kwargs):
        """
        Runs anything that needs to occur prior to calling the method handler.
        """
        self.format_kwarg = self.get_format_suffix(**kwargs)

        # Perform content negotiation and store the accepted info on the request
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Determine the API version, if versioning is in use.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        # Ensure that the incoming request is permitted
        self.perform_authentication(request)
        self.check_permissions(request)
        self.check_throttles(request)

initialize_request

def initialize_request(
    self,
    request,
    *args,
    **kwargs
)

Set the .action attribute on the view, depending on the request method.

View Source
    def initialize_request(self, request, *args, **kwargs):
        """
        Set the `.action` attribute on the view, depending on the request method.
        """
        request = super().initialize_request(request, *args, **kwargs)
        method = request.method.lower()
        if method == 'options':
            # This is a special case as we always provide handling for the
            # options method in the base `View` class.
            # Unlike the other explicitly defined actions, 'metadata' is implicit.
            self.action = 'metadata'
        else:
            self.action = self.action_map.get(method)
        return request

options

def options(
    self,
    request,
    *args,
    **kwargs
)

Handler method for HTTP 'OPTIONS' request.

View Source
    def options(self, request, *args, **kwargs):
        """
        Handler method for HTTP 'OPTIONS' request.
        """
        if self.metadata_class is None:
            return self.http_method_not_allowed(request, *args, **kwargs)
        data = self.metadata_class().determine_metadata(request, self)
        return Response(data, status=status.HTTP_200_OK)

perform_authentication

def perform_authentication(
    self,
    request
)

Perform authentication on the incoming request.

Note that if you override this and simply 'pass', then authentication will instead be performed lazily, the first time either request.user or request.auth is accessed.

View Source
    def perform_authentication(self, request):
        """
        Perform authentication on the incoming request.

        Note that if you override this and simply 'pass', then authentication
        will instead be performed lazily, the first time either
        `request.user` or `request.auth` is accessed.
        """
        request.user

perform_content_negotiation

def perform_content_negotiation(
    self,
    request,
    force=False
)

Determine which renderer and media type to use to render the response.

View Source
    def perform_content_negotiation(self, request, force=False):
        """
        Determine which renderer and media type to use render the response.
        """
        renderers = self.get_renderers()
        conneg = self.get_content_negotiator()

        try:
            return conneg.select_renderer(request, renderers, self.format_kwarg)
        except Exception:
            if force:
                return (renderers[0], renderers[0].media_type)
            raise

permission_denied

def permission_denied(
    self,
    request,
    message=None,
    code=None
)

If request is not permitted, determine what kind of exception to raise.

View Source
    def permission_denied(self, request, message=None, code=None):
        """
        If request is not permitted, determine what kind of exception to raise.
        """
        if request.authenticators and not request.successful_authenticator:
            raise exceptions.NotAuthenticated()
        raise exceptions.PermissionDenied(detail=message, code=code)

raise_uncaught_exception

def raise_uncaught_exception(
    self,
    exc
)
View Source
    def raise_uncaught_exception(self, exc):
        if settings.DEBUG:
            request = self.request
            renderer_format = getattr(request.accepted_renderer, 'format')
            use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin')
            request.force_plaintext_errors(use_plaintext_traceback)
        raise exc

remove_model

def remove_model(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse

Remove an existing model.

Parameters:

Name Type Description Default
request HttpRequest The incoming request object. None
id str The unique identifier of the model. None

Returns:

Type Description
HttpResponse 200 Response if model was removed, else corresponding error code
View Source
    @extend_schema(responses={
        status.HTTP_200_OK: inline_serializer(
            "DeleteModelSuccessSerializer",
            fields={
                "detail": CharField(default="Model removed!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def remove_model(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Remove an existing model.

        Args:
            request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponse: 200 Response if model was removed, else corresponding error code
        """
        model = get_entity(ModelDB, pk=id)
        if model.owner != request.user:
            training = model.get_training()
            if training is None or training.actor != request.user:
                raise PermissionDenied(
                    "You are neither the owner of the model nor the actor of the corresponding training."
                )
        model.delete()
        return JsonResponse({"detail": "Model removed!"})

reverse_action

def reverse_action(
    self,
    url_name,
    *args,
    **kwargs
)

Reverse the action for the given url_name.

View Source
    def reverse_action(self, url_name, *args, **kwargs):
        """
        Reverse the action for the given `url_name`.
        """
        url_name = '%s-%s' % (self.basename, url_name)
        namespace = None
        if self.request and self.request.resolver_match:
            namespace = self.request.resolver_match.namespace
        if namespace:
            url_name = namespace + ':' + url_name
        kwargs.setdefault('request', self.request)

        return reverse(url_name, *args, **kwargs)

setup

def setup(
    self,
    request,
    *args,
    **kwargs
)

Initialize attributes shared by all view methods.

View Source
    def setup(self, request, *args, **kwargs):
        """Initialize attributes shared by all view methods."""
        if hasattr(self, "get") and not hasattr(self, "head"):
            self.head = self.get
        self.request = request
        self.args = args
        self.kwargs = kwargs

throttled

def throttled(
    self,
    request,
    wait
)

If request is throttled, determine what kind of exception to raise.

View Source
    def throttled(self, request, wait):
        """
        If request is throttled, determine what kind of exception to raise.
        """
        raise exceptions.Throttled(wait)

upload_model_preprocessing

def upload_model_preprocessing(
    self,
    request: django.http.request.HttpRequest,
    id: str
) -> django.http.response.HttpResponse

Upload a preprocessing model file for a global model.

The preprocessing model file should be a PyTorch serialized model. Providing the model via torch.save as well as in TorchScript format is supported.

transforms = torch.nn.Sequential(
    torchvision.transforms.CenterCrop(10),
    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)

Make sure to only use transformations that inherit from torch.nn.Module. It is advised to use the torchvision.transforms.v2 module for common transformations.

Please note that this function is still in the beta phase.

Parameters:

Name Type Description Default
request HttpRequest request object None
id str global model UUID None

Returns:

Type Description
HttpResponse upload success message as json response

Raises:

Type Description
PermissionDenied Unauthorized to upload preprocessing model for the specified model
ValidationError Preprocessing model is not a valid torch model
View Source
    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "model_preprocessing_file": {"type": "string", "format": "binary"},
                },
            },
        },
        responses={
            status.HTTP_202_ACCEPTED: inline_serializer("PreprocessingModelUploadSerializer", fields={
                "detail": CharField(default="Proprocessing Model Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def upload_model_preprocessing(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload a preprocessing model file for a global model.

        The preprocessing model file should be a PyTorch serialized model.
        Providing the model via `torch.save` as well as in TorchScript format is supported.

        ```python
        transforms = torch.nn.Sequential(
            torchvision.transforms.CenterCrop(10),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        )
        ```

        Make sure to only use transformations that inherit from `torch.nn.Module`.
        It is advised to use the `torchvision.transforms.v2` module for common transformations.

        Please note that this function is still in the beta phase.

        Args:
            request (HttpRequest): request object
            id (str): global model UUID

        Raises:
            PermissionDenied: Unauthorized to upload preprocessing model for the specified model
            ValidationError: Preprocessing model is not a valid torch model

        Returns:
            HttpResponse: upload success message as json response
        """
        model = get_entity(GlobalModelDB, pk=id)
        if request.user.id != model.owner.id:
            raise PermissionDenied(f"You are not the owner of model {model.id}!")
        model.preprocessing = get_file(request, "model_preprocessing_file")
        verify_model_object(model.preprocessing, "preprocessing")
        model.save()
        return JsonResponse({
            "detail": "Proprocessing Model Upload Accepted",
        }, status=status.HTTP_202_ACCEPTED)
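
A hedged sketch of preparing and uploading such a preprocessing pipeline, building on the transforms example above; the route, authentication, and placeholder values are assumptions:

import io
import requests
import torch
import torchvision

SERVER_URL, TOKEN = "http://localhost:8000", "<api-token>"   # placeholders
model_id = "<global-model-uuid>"                              # placeholder

transforms = torch.nn.Sequential(
    torchvision.transforms.CenterCrop(10),
    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
buffer = io.BytesIO()
torch.save(transforms, buffer)   # alternatively: torch.jit.save(torch.jit.script(transforms), buffer)
buffer.seek(0)

response = requests.post(
    f"{SERVER_URL}/api/models/{model_id}/preprocessing/",     # hypothetical route
    files={"model_preprocessing_file": ("preprocessing.pt", buffer.getvalue())},
    headers={"Authorization": f"Token {TOKEN}"},
)
response.raise_for_status()   # expects 202 Accepted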