Skip to content

fl_server_api.views

Modules:

Name Description
base
dummy
group
inference
model
training
user

Classes:

Name Description
Group

Group Model ViewSet.

Inference

Inference ViewSet for performing inference on a model.

Model

Model ViewSet.

Training

Training model ViewSet.

User

User model ViewSet.

Attributes

__all__ module-attribute

__all__ = ['Group', 'Inference', 'Model', 'Training', 'User']

Classes

Group

Bases: ViewSet


              flowchart TD
              fl_server_api.views.Group[Group]
              fl_server_api.views.base.ViewSet[ViewSet]

                              fl_server_api.views.base.ViewSet --> fl_server_api.views.Group
                


              click fl_server_api.views.Group href "" "fl_server_api.views.Group"
              click fl_server_api.views.base.ViewSet href "" "fl_server_api.views.base.ViewSet"
            

Group Model ViewSet.

Methods:

Name Description
create

Create a new group.

destroy

Remove group by id.

list

Get all groups.

partial_update

Update group information partially.

retrieve

Get group information by id.

update

Update group information.

Attributes:

Name Type Description
serializer_class

The serializer for the ViewSet.

Source code in fl_server_api/views/group.py
class Group(ViewSet):
    """
    Group Model ViewSet.
    """

    serializer_class = GroupSerializer
    """The serializer for the ViewSet."""

    def _get_group(self, user: UserModel, group_id: int) -> GroupModel:
        """
        Fetch a group by its id, ensuring the requesting user belongs to it.

        Args:
            user (UserModel): The user making the request.
            group_id (int): The id of the group.

        Raises:
            PermissionDenied: If the user is not a member of the group.

        Returns:
            GroupModel: The group instance.
        """
        requested_group = get_entity(GroupModel, pk=group_id)
        if user.groups.contains(requested_group):
            return requested_group
        raise PermissionDenied("You are not allowed to access this group.")

    @extend_schema(
        responses={
            status.HTTP_200_OK: GroupSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def list(self, request: HttpRequest) -> HttpResponse:
        """
        Get all groups.

        Only superusers are permitted to enumerate every group.

        Args:
            request (HttpRequest): request object

        Raises:
            PermissionDenied: If user is not a superuser.

        Returns:
            HttpResponse: list of groups as json response
        """
        if not request.user.is_superuser:
            raise PermissionDenied("You are not allowed to access all groups.")
        all_groups = GroupModel.objects.all()
        return Response(GroupSerializer(all_groups, many=True).data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: GroupSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[_default_group_example]
    )
    def retrieve(self, request: HttpRequest, id: int) -> HttpResponse:
        """
        Get group information by id.

        Args:
            request (HttpRequest): request object
            id (int): group id

        Returns:
            HttpResponse: group as json response
        """
        requested_group = self._get_group(request.user, id)
        return Response(GroupSerializer(requested_group).data)

    @extend_schema(
        responses={
            status.HTTP_201_CREATED: GroupSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[OpenApiExample(
            name="Create group",
            description="Create a new group.",
            value={"name": "My new amazing group"},
        )]
    )
    def create(self, request: HttpRequest) -> HttpResponse:
        """
        Create a new group from the request payload.

        Args:
            request (HttpRequest): request object

        Returns:
            HttpResponse: new created group as json response
        """
        new_group = GroupSerializer(data=request.data)
        new_group.is_valid(raise_exception=True)
        new_group.save()
        return Response(new_group.data, status=status.HTTP_201_CREATED)

    def _update(self, request: HttpRequest, id: int, *, partial: bool) -> HttpResponse:
        """
        Shared implementation for full and partial group updates.

        Args:
            request (HttpRequest): request object
            id (int): group id
            partial (bool): allow partial update

        Returns:
            HttpResponse: updated group as json response
        """
        target_group = self._get_group(request.user, id)
        serializer = GroupSerializer(target_group, data=request.data, partial=partial)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        if getattr(target_group, '_prefetched_objects_cache', None):
            # A queryset built with 'prefetch_related' caches related objects;
            # drop the now-stale cache so the response reflects the update.
            target_group._prefetched_objects_cache = {}
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: GroupSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            _default_group_example,
            OpenApiExample(
                name="Update group",
                description="Update group fields.",
                value={"name": "My new amazing group is the best!"},
            )
        ]
    )
    def update(self, request: HttpRequest, id: int) -> HttpResponse:
        """
        Update group information (full update).

        Args:
            request (HttpRequest): request object
            id (int): group id

        Returns:
            HttpResponse: updated group as json response
        """
        return self._update(request, id, partial=False)

    @extend_schema(
        responses={
            status.HTTP_200_OK: GroupSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            _default_group_example,
            OpenApiExample(
                name="Update group partially",
                description="Update only some group fields.",
                value={"name": "My new amazing group is the best!"},
            )
        ]
    )
    def partial_update(self, request: HttpRequest, id: int) -> HttpResponse:
        """
        Update group information partially.

        Args:
            request (HttpRequest): request object
            id (int): group id

        Returns:
            HttpResponse: updated group as json response
        """
        return self._update(request, id, partial=True)

    @extend_schema(
        responses={
            status.HTTP_204_NO_CONTENT: None,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[_default_group_example]
    )
    def destroy(self, request: HttpRequest, id: int) -> HttpResponse:
        """
        Remove group by id.

        Args:
            request (HttpRequest): request object
            id (int): group id

        Returns:
            HttpResponse: 204 NO CONTENT
        """
        group_to_remove = self._get_group(request.user, id)
        group_to_remove.delete()
        return Response(status=status.HTTP_204_NO_CONTENT)

Attributes

serializer_class class-attribute instance-attribute
serializer_class = GroupSerializer

The serializer for the ViewSet.

Functions

create
create(request: HttpRequest) -> HttpResponse

Create a new group.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required

Returns:

Name Type Description
HttpResponse HttpResponse

new created group as json response

Source code in fl_server_api/views/group.py
@extend_schema(
    responses={
        status.HTTP_201_CREATED: GroupSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[OpenApiExample(
        name="Create group",
        description="Create a new group.",
        value={"name": "My new amazing group"},
    )]
)
def create(self, request: HttpRequest) -> HttpResponse:
    """
    Create a new group from the request payload.

    Args:
        request (HttpRequest): request object

    Returns:
        HttpResponse: new created group as json response
    """
    new_group = GroupSerializer(data=request.data)
    new_group.is_valid(raise_exception=True)
    new_group.save()
    return Response(new_group.data, status=status.HTTP_201_CREATED)
destroy
destroy(request: HttpRequest, id: int) -> HttpResponse

Remove group by id.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
int

group id

required

Returns:

Name Type Description
HttpResponse HttpResponse

204 NO CONTENT

Source code in fl_server_api/views/group.py
@extend_schema(
    responses={
        status.HTTP_204_NO_CONTENT: None,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[_default_group_example]
)
def destroy(self, request: HttpRequest, id: int) -> HttpResponse:
    """
    Remove group by id.

    Args:
        request (HttpRequest): request object
        id (int): group id

    Returns:
        HttpResponse: 204 NO CONTENT
    """
    group_to_remove = self._get_group(request.user, id)
    group_to_remove.delete()
    return Response(status=status.HTTP_204_NO_CONTENT)
list
list(request: HttpRequest) -> HttpResponse

Get all groups.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required

Raises:

Type Description
PermissionDenied

If user is not a superuser.

Returns:

Name Type Description
HttpResponse HttpResponse

list of groups as json response

Source code in fl_server_api/views/group.py
@extend_schema(
    responses={
        status.HTTP_200_OK: GroupSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def list(self, request: HttpRequest) -> HttpResponse:
    """
    Get all groups.

    Only superusers are permitted to enumerate every group.

    Args:
        request (HttpRequest): request object

    Raises:
        PermissionDenied: If user is not a superuser.

    Returns:
        HttpResponse: list of groups as json response
    """
    if not request.user.is_superuser:
        raise PermissionDenied("You are not allowed to access all groups.")
    all_groups = GroupModel.objects.all()
    return Response(GroupSerializer(all_groups, many=True).data)
partial_update
partial_update(request: HttpRequest, id: int) -> HttpResponse

Update group information partially.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
int

group id

required

Returns:

Name Type Description
HttpResponse HttpResponse

updated group as json response

Source code in fl_server_api/views/group.py
@extend_schema(
    responses={
        status.HTTP_200_OK: GroupSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        _default_group_example,
        OpenApiExample(
            name="Update group partially",
            description="Update only some group fields.",
            value={"name": "My new amazing group is the best!"},
        )
    ]
)
def partial_update(self, request: HttpRequest, id: int) -> HttpResponse:
    """
    Update group information partially.

    Delegates to the shared update helper with partial semantics enabled.

    Args:
        request (HttpRequest): request object
        id (int): group id

    Returns:
        HttpResponse: updated group as json response
    """
    return self._update(request, id, partial=True)
retrieve
retrieve(request: HttpRequest, id: int) -> HttpResponse

Get group information by id.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
int

group id

required

Returns:

Name Type Description
HttpResponse HttpResponse

group as json response

Source code in fl_server_api/views/group.py
@extend_schema(
    responses={
        status.HTTP_200_OK: GroupSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[_default_group_example]
)
def retrieve(self, request: HttpRequest, id: int) -> HttpResponse:
    """
    Get group information by id.

    Args:
        request (HttpRequest): request object
        id (int): group id

    Returns:
        HttpResponse: group as json response
    """
    requested_group = self._get_group(request.user, id)
    return Response(GroupSerializer(requested_group).data)
update
update(request: HttpRequest, id: int) -> HttpResponse

Update group information.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
int

group id

required

Returns:

Name Type Description
HttpResponse HttpResponse

updated group as json response

Source code in fl_server_api/views/group.py
@extend_schema(
    responses={
        status.HTTP_200_OK: GroupSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        _default_group_example,
        OpenApiExample(
            name="Update group",
            description="Update group fields.",
            value={"name": "My new amazing group is the best!"},
        )
    ]
)
def update(self, request: HttpRequest, id: int) -> HttpResponse:
    """
    Update group information (full update).

    Delegates to the shared update helper with partial semantics disabled.

    Args:
        request (HttpRequest): request object
        id (int): group id

    Returns:
        HttpResponse: updated group as json response
    """
    return self._update(request, id, partial=False)

Inference

Bases: ViewSet


              flowchart TD
              fl_server_api.views.Inference[Inference]
              fl_server_api.views.base.ViewSet[ViewSet]

                              fl_server_api.views.base.ViewSet --> fl_server_api.views.Inference
                


              click fl_server_api.views.Inference href "" "fl_server_api.views.Inference"
              click fl_server_api.views.base.ViewSet href "" "fl_server_api.views.base.ViewSet"
            

Inference ViewSet for performing inference on a model.

Methods:

Name Description
inference

Performs inference on the provided model and input data.

Attributes:

Name Type Description
serializer_class

The serializer for the ViewSet.

Source code in fl_server_api/views/inference.py
class Inference(ViewSet):
    """
    Inference ViewSet for performing inference on a model.
    """

    serializer_class = inline_serializer("InferenceSerializer", fields={
        "inference": ListField(child=ListField(child=FloatField())),
        "uncertainty": DictField(child=FloatField())
    })
    """The serializer for the ViewSet."""

    @extend_schema(
        request=inline_serializer(
            "InferenceJsonSerializer",
            fields={
                "model_id": CharField(),
                "model_input": ListField(child=ListField(child=FloatField())),
                "return_format": ChoiceField(["binary", "json"])
            }
        ),
        responses={
            status.HTTP_200_OK: serializer_class,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        },
        examples=[
            OpenApiExample("JSON Example", value={
                "model_id": "mymodel",
                "model_input": [
                    [1.0, 2.3, -0.4, 3],
                    [0.01, 9.7, 5.6, 7]
                ],
                "return_format": "json"
            }, request_only=True),
        ]
    )
    def inference(self, request: HttpRequest) -> HttpResponse:
        """
        Performs inference on the provided model and input data.

        This method takes in an HTTP request containing the necessary metadata and input data,
        performs any required preprocessing on the input data, runs the inference using the specified model,
        and returns a response in the format specified by the `return_format` parameter including
        possible uncertainty measurements if defined.

        Args:
            request (HttpRequest): The current HTTP request.

        Returns:
            HttpResponse: A HttpResponse containing the result of the inference as well as its uncertainty.
        """
        request_body, is_json = self._get_handle_content_type(request)
        model, preprocessing, input_shape, return_format = self._get_inference_metadata(
            request_body,
            "json" if is_json else "binary"
        )
        model_input = self._get_model_input(request, request_body)

        if preprocessing:
            model_input = preprocessing(model_input)
        else:
            # if no preprocessing is defined, at least try to convert/interpret the model_input as
            # PyTorch tensor, before raising an exception
            model_input = self._try_cast_model_input_to_tensor(model_input)
        self._validate_model_input_after_preprocessing(model_input, input_shape, bool(preprocessing))

        uncertainty_cls, inference, uncertainty = self._do_inference(model, model_input)
        return self._make_response(uncertainty_cls, inference, uncertainty, return_format)

    def _get_handle_content_type(self, request: HttpRequest) -> Tuple[dict, bool]:
        """
        Handles HTTP request body based on their content type.

        This function checks if the request content type is either `application/json`
        or `multipart/form-data`. If it matches, it returns the corresponding data and
        a boolean indicating whether it's JSON (True) or multipart/form-data (False).

        Args:
            request (HttpRequest): The request.

        Returns:
            tuple: A tuple containing the parsed data and a boolean indicating the content type.
                * If content type is `application/json`, returns the JSON payload as a Python object (dict)
                and True to indicate it's JSON.
                * If content type is `multipart/form-data`, returns the request POST data and False.

        Raises:
            UnsupportedMediaType: If an unknown content type is specified, raising an error with
                details on supported types (`application/json` and `multipart/form-data`).
        """
        match request.content_type.lower():
            case s if s.startswith("multipart/form-data"):
                return request.POST, False
            case s if s.startswith("application/json"):
                return json.loads(request.body), True

        # if the content type is specified, but not supported, return 415
        self._logger.error(f"Unknown Content-Type '{request.content_type}'")
        raise UnsupportedMediaType(
            "Only Content-Type 'application/json' and 'multipart/form-data' is supported."
        )

    def _get_inference_metadata(
        self,
        request_body: dict,
        return_format_default: Literal["binary", "json"]
    ) -> Tuple[Model, Optional[torch.nn.Module], Optional[List[Optional[int]]], str]:
        """
        Retrieves inference metadata based on the content of the provided request body.

        This method checks if a `model_id` is present in the request body and retrieves
        the corresponding model entity. It then determines the return format based on the
        request body or default to one of the two supported formats (`binary` or `json`).

        Args:
            request_body (dict): The data sent with the request, containing at least `model_id`.
            return_format_default (Literal["binary", "json"]): The default return format to use if not specified in
                the request body.

        Returns:
            Tuple[Model, Optional[torch.nn.Module], Optional[List[Optional[int]]], str]: A tuple containing:
                * The retrieved model entity.
                * The global model's preprocessing torch module (if applicable).
                * The input shape of the global model (if applicable).
                * The return format (`binary` or `json`).

        Raises:
            ValidationError: If no valid `model_id` is provided in the request body, or if an unknown return format
                is specified.
        """
        if "model_id" not in request_body:
            self._logger.error("No 'model_id' provided in request.")
            raise ValidationError("No 'model_id' provided in request.")
        model_id = request_body["model_id"]
        model = get_entity(Model, pk=model_id)

        return_format = request_body.get("return_format", return_format_default)
        if return_format not in ["binary", "json"]:
            self._logger.error(f"Unknown return format '{return_format}'. Supported are binary and json.")
            raise ValidationError(f"Unknown return format '{return_format}'. Supported are binary and json.")

        global_model: Optional[GlobalModel] = None
        if isinstance(model, GlobalModel):
            global_model = model
        elif isinstance(model, LocalModel):
            global_model = model.base_model
        else:
            self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel. Skip preprocessing.")

        preprocessing: Optional[torch.nn.Module] = None
        input_shape: Optional[List[Optional[int]]] = None
        if global_model:
            if global_model.preprocessing is not None:
                preprocessing = global_model.get_preprocessing_torch_model()
            if global_model.input_shape is not None:
                input_shape = global_model.input_shape

        return model, preprocessing, input_shape, return_format

    def _get_model_input(self, request: HttpRequest, request_body: dict) -> Any:
        """
        Retrieves and decodes the model input from either an uploaded file or the request body.

        Args:
            request (HttpRequest): The current HTTP request.
            request_body (dict): The parsed request body as a dictionary.

        Returns:
            Any: The decoded model input data.

        Raises:
            ValidationError: If no `model_input` is found in the uploaded file or the request body.
        """
        uploaded_file = request.FILES.get("model_input", None)
        if uploaded_file and uploaded_file.file:
            model_input = uploaded_file.file.read()
        else:
            model_input = request_body.get("model_input", None)
        if not model_input:
            raise ValidationError("No uploaded file 'model_input' found.")
        return self._try_decode_model_input(model_input)

    def _try_decode_model_input(self, model_input: Any) -> Any:
        """
        Attempts to decode the input `model_input` from various formats and returns it in a usable form.

        This function first tries to deserialize the input as a PyTorch tensor. If that fails, it attempts to
        decode the input as a base64-encoded string. If neither attempt is successful, the original input is returned.

        Args:
            model_input (Any): The input to be decoded, which can be in any format.

        Returns:
            Any: The decoded input, which may still be in an unknown format if decoding attempts fail.
        """
        # 1. try to deserialize model_input as PyTorch tensor
        try:
            with disable_logger(self._logger):
                model_input = to_torch_tensor(model_input)
        except Exception:
            pass
        # 2. try to decode model_input as base64
        try:
            is_base64, tmp_model_input = self._is_base64(model_input)
            if is_base64:
                model_input = tmp_model_input
        except Exception:
            pass
        # result
        return model_input

    def _try_cast_model_input_to_tensor(self, model_input: Any) -> Any:
        """
        Attempt to cast the given model input to a PyTorch tensor.

        This function tries to interpret the input in several formats:

        1. PIL Image (and later convert it to a PyTorch tensor, see 3.)
        2. PyTorch tensor via `torch.as_tensor`
        3. PyTorch tensor via torchvision `ToTensor` (supports e.g. PIL images)

        If none of these attempts are successful, the original input is returned.

        Args:
            model_input: The input data to be cast to a PyTorch tensor.
                Can be any type that can be converted to a tensor.

        Returns:
            A PyTorch tensor representation of the input data, or the original
            input if it cannot be converted.
        """
        def _try_to_pil_image(model_input: Any) -> Any:
            stream = BytesIO(model_input)
            return Image.open(stream)

        if isinstance(model_input, torch.Tensor):
            return model_input

        # In the following order, try to:
        # 1. interpret model_input as PIL image (and later to PyTorch tensor, see step 3),
        # 2. interpret model_input as PyTorch tensor,
        # 3. interpret model_input as PyTorch tensor via torchvision ToTensor (supports e.g. PIL images).
        for fn in [_try_to_pil_image, torch.as_tensor, to_tensor]:
            try:
                model_input = fn(model_input)  # type: ignore
            except Exception:
                pass
        return model_input

    def _is_base64(self, sb: str | bytes) -> Tuple[bool, bytes]:
        """
        Check if a string or bytes object is a valid Base64 encoded string.

        This function checks if the input can be decoded and re-encoded without any changes.
        If decoding and encoding returns the same result as the original input, it's likely
        that the input was indeed a valid Base64 encoded string.

        Note: This code is based on the reference implementation from the linked Stack Overflow answer.

        Args:
            sb (str | bytes): The input string or bytes object to check.

        Returns:
            Tuple[bool, bytes]: A tuple containing a boolean indicating whether the input is
                a valid Base64 encoded string and the decoded bytes if it is.

        References:
            https://stackoverflow.com/a/45928164
        """
        try:
            if isinstance(sb, str):
                # If there's any unicode here, an exception will be thrown and the function will return false
                sb_bytes = bytes(sb, "ascii")
            elif isinstance(sb, bytes):
                sb_bytes = sb
            else:
                raise ValueError("Argument must be string or bytes")
            decoded = base64.b64decode(sb_bytes)
            return base64.b64encode(decoded) == sb_bytes, decoded
        except Exception:
            return False, b""

    def _validate_model_input_after_preprocessing(
        self,
        model_input: Any,
        model_input_shape: Optional[List[Optional[int]]],
        preprocessing: bool
    ) -> None:
        """
        Validates the model input after preprocessing.

        Ensures that the provided `model_input` is a valid PyTorch tensor and its shape matches
        the expected`model_input_shape`.

        Args:
            model_input (Any): The model input to be validated.
            model_input_shape (Optional[List[Optional[int]]]): The expected shape of the model input.
                Can contain None values if not all dimensions are fixed (e.g. first dimension as batch size).
            preprocessing (bool): Whether a preprocessing model was defined or not. (Only for a better error message.)

        Raises:
            ValidationError: If the `model_input` is not a valid PyTorch tensor or
                its shape does not match the expected `model_input_shape`.
        """
        if not isinstance(model_input, torch.Tensor):
            msg = "Model input could not be casted or interpreted as a PyTorch tensor object"
            if preprocessing:
                msg += " and is still not a PyTorch tensor after preprecessing."
            else:
                msg += " and no preprecessing is defined."
            raise ValidationError(msg)

        if model_input_shape and not all(
            dim_input == dim_model
            for (dim_input, dim_model) in zip(model_input.shape, model_input_shape)
            if dim_model is not None
        ):
            raise ValidationError("Input shape does not match model input shape.")

    def _make_response(
        self,
        uncertainty_cls: Type[UncertaintyBase],
        inference: torch.Tensor,
        uncertainty: Any,
        return_type: str
    ) -> HttpResponse:
        """
        Build the response object with the result data.

        This method checks the return type and makes a response with the appropriate content type.

        If return_type is "binary", a binary-encoded response will be generated using pickle.
        Otherwise, a JSON response will be generated by serializing the uncertainty object using its to_json method.

        Args:
            uncertainty_cls (Type[UncertaintyBase]): The uncertainty class.
            inference (torch.Tensor): The inference.
            uncertainty (Any): The uncertainty.
            return_type (str): The return type.

        Returns:
            HttpResponse: The inference result response.
        """
        if return_type == "binary":
            response_bytes = pickle.dumps(dict(inference=inference, uncertainty=uncertainty))
            return HttpResponse(response_bytes, content_type="application/octet-stream")

        return HttpResponse(uncertainty_cls.to_json(inference, uncertainty), content_type="application/json")

    def _do_inference(
        self, model: Model, input_tensor: torch.Tensor
    ) -> Tuple[Type[UncertaintyBase], torch.Tensor, Dict[str, Any]]:
        """
        Perform inference on a given input tensor using the provided model.

        This methods retrieves the uncertainty class, performs the prediction.
        The output of this method consists of:

        * The uncertainty class used for inference
        * The result of the model's prediction on the input tensor
        * Any associated uncertainty for the prediction

        Args:
            model (Model): The model to perform inference with.
            input_tensor (torch.Tensor): Input tensor to pass through the model.

        Returns:
            Tuple[Type[UncertaintyBase], torch.Tensor, Dict[str, Any]]:
                A tuple containing the uncertainty class, prediction result, and any associated uncertainty.

        Raises:
            APIException: If an error occurs during inference
        """
        try:
            uncertainty_cls = get_uncertainty_class(model)
            inference, uncertainty = uncertainty_cls.prediction(input_tensor, model)
            return uncertainty_cls, inference, uncertainty
        except TorchDeserializationException as e:
            raise APIException(e) from e
        except Exception as e:
            self._logger.error(e)
            raise APIException("Internal Server Error occurred during inference!") from e

Attributes

serializer_class class-attribute instance-attribute
serializer_class = inline_serializer('InferenceSerializer', fields={'inference': ListField(child=ListField(child=FloatField())), 'uncertainty': DictField(child=FloatField())})

The serializer for the ViewSet.

Functions

inference
inference(request: HttpRequest) -> HttpResponse

Performs inference on the provided model and input data.

This method takes in an HTTP request containing the necessary metadata and input data, performs any required preprocessing on the input data, runs the inference using the specified model, and returns a response in the format specified by the return_format parameter including possible uncertainty measurements if defined.

Parameters:

Name Type Description Default
request
HttpRequest

The current HTTP request.

required

Returns:

Name Type Description
HttpResponse HttpResponse

A HttpResponse containing the result of the inference as well as its uncertainty.

Source code in fl_server_api/views/inference.py
@extend_schema(
    request=inline_serializer(
        "InferenceJsonSerializer",
        fields={
            "model_id": CharField(),
            "model_input": ListField(child=ListField(child=FloatField())),
            "return_format": ChoiceField(["binary", "json"])
        }
    ),
    responses={
        status.HTTP_200_OK: serializer_class,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    },
    examples=[
        OpenApiExample("JSON Example", value={
            "model_id": "mymodel",
            "model_input": [
                [1.0, 2.3, -0.4, 3],
                [0.01, 9.7, 5.6, 7]
            ],
            "return_format": "json"
        }, request_only=True),
    ]
)
def inference(self, request: HttpRequest) -> HttpResponse:
    """
    Run model inference for the data supplied with the request.

    The request metadata selects the model and the desired response format,
    while the payload carries the raw model input. If the model defines a
    preprocessing step it is applied first; otherwise the input is merely
    cast to a PyTorch tensor. The response contains the prediction together
    with any uncertainty measurements the model provides.

    Args:
        request (HttpRequest): The current HTTP request.

    Returns:
        HttpResponse: The inference result, including its uncertainty,
            rendered in the requested format.
    """
    body, is_json = self._get_handle_content_type(request)
    default_format = "json" if is_json else "binary"
    model, preprocessing, input_shape, return_format = self._get_inference_metadata(body, default_format)

    raw_input = self._get_model_input(request, body)
    if preprocessing:
        prepared_input = preprocessing(raw_input)
    else:
        # no preprocessing defined: best effort - try to interpret the raw
        # input as a PyTorch tensor before raising an exception
        prepared_input = self._try_cast_model_input_to_tensor(raw_input)
    self._validate_model_input_after_preprocessing(prepared_input, input_shape, bool(preprocessing))

    uncertainty_cls, prediction, uncertainty = self._do_inference(model, prepared_input)
    return self._make_response(uncertainty_cls, prediction, uncertainty, return_format)

Model

Bases: ViewSet


              flowchart TD
              fl_server_api.views.Model[Model]
              fl_server_api.views.base.ViewSet[ViewSet]

                              fl_server_api.views.base.ViewSet --> fl_server_api.views.Model
                


              click fl_server_api.views.Model href "" "fl_server_api.views.Model"
              click fl_server_api.views.base.ViewSet href "" "fl_server_api.views.base.ViewSet"
            

Model ViewSet.

Methods:

Name Description
create_local_model

Upload a partially trained model file from a client.

create_model

Upload a global model file.

create_model_metrics

Upload model metrics.

create_swag_stats

Upload SWAG statistics.

get_metadata

Get model meta data.

get_model

Download the whole model as PyTorch serialized file.

get_model_metrics

Reports all metrics for the selected model.

get_model_proprecessing

Download the whole preprocessing model as PyTorch serialized file.

get_models

Get a list of all global models associated with the requesting user.

get_training_models

Get a list of all models associated with a specific training process and the requesting user.

get_training_models_latest

Get a list of the latest models for a specific training process associated with the requesting user.

remove_model

Remove an existing model.

upload_model_preprocessing

Upload a preprocessing model file for a global model.

Attributes:

Name Type Description
serializer_class

The serializer for the ViewSet.

Source code in fl_server_api/views/model.py
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
class Model(ViewSet):
    """
    Model ViewSet.
    """

    serializer_class = GlobalModelSerializer
    """The serializer for the ViewSet."""

    def _get_user_related_global_models(self, user: Union[AbstractBaseUser, AnonymousUser]) -> List[ModelDB]:
        """
        Collect every global model the given user is related to.

        A model counts as related if the user owns it, or if the user acts as
        the actor or one of the participants of the training the model
        belongs to.

        Args:
            user (Union[AbstractBaseUser, AnonymousUser]): The user.

        Returns:
            List[ModelDB]: The global models related to the user.
        """
        related_model_ids = (
            Training.objects
            .filter(Q(actor=user) | Q(participants=user))
            .distinct()
            .values_list("model__id", flat=True)
        )
        return ModelDB.objects.filter(Q(owner=user) | Q(id__in=related_model_ids)).distinct()

    def _get_local_models_for_global_model(self, global_model: GlobalModelDB) -> List[LocalModelDB]:
        """
        Fetch every local model that was derived from the given global model.

        Args:
            global_model (GlobalModelDB): The global model.

        Returns:
            List[LocalModelDB]: The local models for the global model.
        """
        derived = LocalModelDB.objects.filter(base_model=global_model)
        return derived.all()

    def _get_local_models_for_global_models(self, global_models: List[GlobalModelDB]) -> List[LocalModelDB]:
        """
        Fetch every local model that was derived from any of the given global models.

        Args:
            global_models (List[GlobalModelDB]): The global models.

        Returns:
            List[LocalModelDB]: The local models for the global models.
        """
        derived = LocalModelDB.objects.filter(base_model__in=global_models)
        return derived.all()

    def _filter_by_training(self, models: List[ModelDB], training_id: str) -> List[ModelDB]:
        """
        Keep only those models whose training matches the given training id.

        Models without an associated training are dropped.

        Args:
            models (List[ModelDB]): The models to filter.
            training_id (str): The ID of the training.

        Returns:
            List[ModelDB]: The models associated with the training.
        """
        def _belongs_to_training(candidate: ModelDB) -> bool:
            training = candidate.get_training()
            return training is not None and training.pk == UUID(training_id)

        return [model for model in models if _belongs_to_training(model)]

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_models(self, request: HttpRequest) -> HttpResponse:
        """
        List all global models associated with the requesting user.

        A global model is associated with a user if the user owns it, or is
        an actor or a participant in the model's training process.

        Args:
            request (HttpRequest): The incoming request object.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        related_models = self._get_user_related_global_models(request.user)
        return Response(ModelSerializer(related_models, many=True).data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_training_models(self, request: HttpRequest, training_id: str) -> HttpResponse:
        """
        List every model (global and local) of a training associated with the requesting user.

        A model is associated with a user if the user owns it, or is an actor
        or a participant in the model's training process.

        Args:
            request (HttpRequest): The incoming request object.
            training_id (str): The unique identifier of the training process.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        related_globals = self._filter_by_training(
            self._get_user_related_global_models(request.user),
            training_id
        )
        related_locals = self._get_local_models_for_global_models(related_globals)
        all_models = list(related_globals) + list(related_locals)
        return Response(ModelSerializer(all_models, many=True).data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_training_models_latest(self, request: HttpRequest, training_id: str) -> HttpResponse:
        """
        Get a list of the latest models for a specific training process associated with the requesting user.

        A model is considered associated with a user if the user is either the owner of the model,
        or if the user is an actor or a participant in the model's training process.
        The latest model refers to the model from the most recent round (highest round number) of
        a participant's training process.

        If the user is not associated with any global model of the training,
        an empty list is returned.

        Args:
            request (HttpRequest): The incoming request object.
            training_id (str): The unique identifier of the training process.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        models: List[ModelDB] = []
        global_models = self._get_user_related_global_models(request.user)
        global_models = self._filter_by_training(global_models, training_id)
        # guard: `max` on an empty sequence raises ValueError; return an empty
        # list instead (consistent with the other model-listing endpoints)
        if global_models:
            # add latest global model
            models.append(max(global_models, key=lambda m: m.round))
            # add the latest local model of each participant
            local_models = self._get_local_models_for_global_models(global_models)
            local_models = sorted(local_models, key=lambda m: str(m.owner.pk))  # groupby requires sorted input
            for _, owner_models in groupby(local_models, key=lambda m: str(m.owner.pk)):
                models.append(max(owner_models, key=lambda m: m.round))
        serializer = ModelSerializer(models, many=True)
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeightsWithStats(),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_metadata(self, _request: HttpRequest, id: str) -> HttpResponse:
        """
        Get model meta data, including statistics.

        Args:
            _request (HttpRequest): The incoming request object (unused).
            id (str): The unique identifier of the model.

        Returns:
            HttpResponse: Model meta data as JSON response.
        """
        entity = get_entity(ModelDB, pk=id)
        serialized = ModelSerializer(entity, context={"with-stats": True})
        return Response(serialized.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: OpenApiResponse(response=bytes, description="Model is returned as bytes"),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_model(self, _request: HttpRequest, id: str) -> HttpResponseBase:
        """
        Download the whole model as PyTorch serialized file.

        SWAG models with recorded moments are not downloadable via this
        endpoint and raise instead.

        Args:
            _request (HttpRequest): The incoming request object (unused).
            id (str): The unique identifier of the model.

        Returns:
            HttpResponseBase: model as file response
        """
        model = get_entity(ModelDB, pk=id)
        has_swag_stats = isinstance(model, SWAGModelDB) and model.swag_first_moment is not None
        if has_swag_stats:
            if model.swag_second_moment is None:
                raise APIException(f"Model {model.id} is in inconsistent state!")
            raise NotImplementedError(
                "SWAG models need to be returned in 3 parts: model architecture, first moment, second moment"
            )
        # NOTE: FileResponse does strange stuff with bytes
        #       and in case of sqlite the weights will be bytes and not a memoryview
        response = HttpResponse(model.weights, content_type="application/octet-stream")
        response["Content-Disposition"] = f'filename="model-{id}.pt"'
        return response

    @extend_schema(
        responses={
            status.HTTP_200_OK: OpenApiResponse(
                response=bytes,
                description="Proprecessing model is returned as bytes"
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
            status.HTTP_404_NOT_FOUND: error_response_404,
        },
    )
    def get_model_proprecessing(self, _request: HttpRequest, id: str) -> HttpResponseBase:
        """
        Download the whole preprocessing model as PyTorch serialized file.

        For a local model, the preprocessing is looked up on its base (global) model.

        Args:
            _request (HttpRequest): The incoming request object (unused).
            id (str): The unique identifier of the model.

        Raises:
            ValidationError: If the model is neither a global nor a local model.
            NotFound: If no preprocessing model is defined for the model.

        Returns:
            HttpResponseBase: proprecessing model as file response or 404 if proprecessing model not found
        """
        model = get_entity(ModelDB, pk=id)
        # the database entity holding the preprocessing bytes; local models
        # delegate to their base (global) model
        global_model: GlobalModelDB
        if isinstance(model, GlobalModelDB):
            global_model = model
        elif isinstance(model, LocalModelDB):
            global_model = model.base_model
        else:
            self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel.")
            raise ValidationError(f"Unknown model type. Model id: {id}")
        if global_model.preprocessing is None:
            raise NotFound(f"Model '{id}' has no preprocessing model defined.")
        # NOTE: FileResponse does strange stuff with bytes
        #       and in case of sqlite the weights will be bytes and not a memoryview
        response = HttpResponse(global_model.preprocessing, content_type="application/octet-stream")
        response["Content-Disposition"] = f'filename="model-{id}-proprecessing.pt"'
        return response

    @extend_schema(responses={
        status.HTTP_200_OK: inline_serializer(
            "DeleteModelSuccessSerializer",
            fields={
                "detail": CharField(default="Model removed!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def remove_model(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Remove an existing model.

        Only the model owner or the actor of the corresponding training is
        allowed to delete a model.

        Args:
            request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponse: 200 Response if model was removed, else corresponding error code
        """
        model = get_entity(ModelDB, pk=id)
        is_owner = model.owner == request.user
        if not is_owner:
            training = model.get_training()
            is_training_actor = training is not None and training.actor == request.user
            if not is_training_actor:
                raise PermissionDenied(
                    "You are neither the owner of the model nor the actor of the corresponding training."
                )
        model.delete()
        return JsonResponse({"detail": "Model removed!"})

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "description": {"type": "string"},
                    "model_file": {"type": "string", "format": "binary"},
                    "model_preprocessing_file": {"type": "string", "format": "binary", "required": "false"},
                },
            },
        },
        responses={
            status.HTTP_201_CREATED: inline_serializer("ModelUploadSerializer", fields={
                "detail": CharField(default="Model Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_model(self, request: HttpRequest) -> HttpResponse:
        """
        Upload a global model file.

        The model file should be a PyTorch serialized model.
        Providing the model via `torch.save` as well as in TorchScript format is supported.

        Args:
            request (HttpRequest): The incoming request object.

        Returns:
            HttpResponse: upload success message as json response
        """
        created_model = load_and_create_model_request(request)
        payload = {
            "detail": "Model Upload Accepted",
            "model_id": str(created_model.id),
        }
        return JsonResponse(payload, status=status.HTTP_201_CREATED)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "model_preprocessing_file": {"type": "string", "format": "binary"},
                },
            },
        },
        responses={
            status.HTTP_202_ACCEPTED: inline_serializer("PreprocessingModelUploadSerializer", fields={
                "detail": CharField(default="Proprocessing Model Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def upload_model_preprocessing(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload a preprocessing model file for a global model.

        The preprocessing model file should be a PyTorch serialized model.
        Providing the model via `torch.save` as well as in TorchScript format is supported.

        ```python
        transforms = torch.nn.Sequential(
            torchvision.transforms.CenterCrop(10),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        )
        ```

        Make sure to only use transformations that inherit from `torch.nn.Module`.
        It is advised to use the `torchvision.transforms.v2` module for common transformations.

        Please note that this function is still in the beta phase.

        Args:
            request (HttpRequest): request object
            id (str): global model UUID

        Raises:
            PermissionDenied: Unauthorized to upload preprocessing model for the specified model
            ValidationError: Preprocessing model is not a valid torch model

        Returns:
            HttpResponse: upload success message as json response
        """
        model = get_entity(GlobalModelDB, pk=id)
        if request.user.id != model.owner.id:
            raise PermissionDenied(f"You are not the owner of model {model.id}!")
        model.preprocessing = get_file(request, "model_preprocessing_file")
        verify_model_object(model.preprocessing, "preprocessing")
        model.save()
        return JsonResponse({
            "detail": "Proprocessing Model Upload Accepted",
            # include the model id, as declared by the response serializer above
            "model_id": str(model.id),
        }, status=status.HTTP_202_ACCEPTED)

    @extend_schema(
        responses={
            status.HTTP_200_OK: MetricSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Reports all metrics that were recorded for the selected model.

        Args:
            request (HttpRequest):  request object
            id (str):  model UUID

        Returns:
            HttpResponse: Metrics as JSON Array
        """
        model = get_entity(ModelDB, pk=id)
        metric_entries = MetricDB.objects.filter(model=model).all()
        serialized = MetricSerializer(metric_entries, many=True)
        return Response(serialized.data)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "metric_names": {"type": "list"},
                    "metric_values": {"type": "list"},
                },
            },
        },
        responses={
            # NOTE(review): the schema declares 200 but the endpoint actually
            # returns 201 Created below - consider aligning these.
            status.HTTP_200_OK: inline_serializer("MetricUploadResponseSerializer", fields={
                "detail": CharField(default="Model Metrics Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            OpenApiExample("Example", value={
                "metric_names": ["accuracy", "training loss"],
                "metric_values": [0.6, 0.04]
            }, media_type="multipart/form-data")
        ]
    )
    def create_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload model metrics.

        Stores the reported metrics for the model and, for global models,
        checks whether all training participants have reported metrics for the
        current round; if so, the `ModelTestFinished` trainer task is
        dispatched.

        Args:
            request (HttpRequest):  request object
            id (str):  model uuid

        Returns:
            HttpResponse: upload success message as json response
        """
        model = get_entity(ModelDB, pk=id)
        formdata = dict(request.POST)

        # the table lock prevents concurrent uploads from racing on the
        # reporter count below (which decides whether to dispatch the task)
        with locked_atomic_transaction(MetricDB):
            self._metric_upload(formdata, model, request.user)

            if isinstance(model, GlobalModelDB):
                # number of distinct participants that reported metrics for the current round
                n_metrics = MetricDB.objects.filter(model=model, step=model.round).distinct("reporter").count()
                training = model.get_training()
                if training:
                    if n_metrics == training.participants.count():
                        dispatch_trainer_task(training, ModelTestFinished, False)
                else:
                    self._logger.warning(f"Global model {id} is not connected to any training.")

        return JsonResponse({
            "detail": "Model Metrics Upload Accepted",
            "model_id": str(model.id),
        }, status=status.HTTP_201_CREATED)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "owner": {"type": "string"},
                    "round": {"type": "int"},
                    "sample_size": {"type": "int"},
                    "metric_names": {"type": "list[string]"},
                    "metric_values": {"type": "list[float]"},
                    "model_file": {"type": "string", "format": "binary"},
                },
            },
        },
        responses={
            status.HTTP_200_OK: ModelSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_local_model(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload a partially trained model file from a client.

        Validates that the referenced global model has an ongoing training and
        that this client is allowed to submit an update for the given round,
        stores the local model (and optional metrics), and dispatches the
        `TrainingRoundFinished` trainer task once all participants have
        submitted their update for the round.

        Args:
            request (HttpRequest):  request object
            id (str):  model uuid of the model, which was used for training

        Raises:
            NotFound: If no training process exists for the model.
            ParseError: If required form fields are missing.

        Returns:
            HttpResponse: upload success message as json response
        """
        try:
            formdata = dict(request.POST)
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
            client = request.user
            model_file = get_file(request, "model_file")
            global_model = get_entity(GlobalModelDB, pk=id)

            # ensure that a training process corresponding to the model exists, else the process will error out
            training = Training.objects.get(model=global_model)
            self._verify_valid_update(client, training, round_num, TrainingState.ONGOING)

            verify_model_object(model_file)
            local_model = LocalModelDB.objects.create(
                base_model=global_model, weights=model_file,
                round=round_num, owner=client, sample_size=sample_size
            )
            self._metric_upload(formdata, local_model, client, metrics_required=False)

            updates = LocalModelDB.objects.filter(base_model=global_model, round=round_num)
            if updates.count() == training.participants.count():
                dispatch_trainer_task(training, TrainingRoundFinished, True)

            return JsonResponse({"detail": "Model Update Accepted"}, status=status.HTTP_201_CREATED)
        except Training.DoesNotExist as e:
            # chain the cause so the original failure stays visible in logs
            raise NotFound(f"Model with ID {id} does not have a training process running") from e
        except (MultiValueDictKeyError, KeyError) as e:
            raise ParseError(e) from e

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "round": {"type": "int"},
                    "sample_size": {"type": "int"},
                    "first_moment_file": {"type": "string", "format": "binary"},
                    "second_moment_file": {"type": "string", "format": "binary"}
                },
            },
        },
        responses={
            status.HTTP_200_OK: inline_serializer("MetricUploadSerializer", fields={
                "detail": CharField(default="SWAg Statistics Accepted"),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_swag_stats(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload SWAG statistics.

        Stores the client's first and second SWAG moments for the current
        round and dispatches the `SWAGRoundFinished` trainer task once all
        participants have reported consistent statistics.

        Args:
            request (HttpRequest): request object
            id (str): global model uuid

        Raises:
            APIException: internal server error
            NotFound: model not found
            ParseError: request data not valid

        Returns:
            HttpResponse: upload success message as json response
        """
        try:
            client = request.user
            formdata = dict(request.POST)
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
            fst_moment = get_file(request, "first_moment_file")
            snd_moment = get_file(request, "second_moment_file")
            model = get_entity(GlobalModelDB, pk=id)

            # ensure that a training process corresponding to the model exists, else the process will error out
            training = Training.objects.get(model=model)
            self._verify_valid_update(client, training, round_num, TrainingState.SWAG_ROUND)

            self._save_swag_stats(fst_moment, snd_moment, model, client, sample_size)

            swag_stats_first = MetricDB.objects.filter(model=model, step=model.round, key="SWAG First Moment Local")
            swag_stats_second = MetricDB.objects.filter(model=model, step=model.round, key="SWAG Second Moment Local")

            if swag_stats_first.count() != swag_stats_second.count():
                training.state = TrainingState.ERROR
                # persist the error state - assigning the attribute alone does not write to the database
                training.save()
                raise APIException("SWAG stats in inconsistent state!")
            if swag_stats_first.count() == training.participants.count():
                dispatch_trainer_task(training, SWAGRoundFinished, True)

            return JsonResponse({"detail": "SWAG Statistic Accepted"}, status=status.HTTP_201_CREATED)
        except Training.DoesNotExist as e:
            raise NotFound(f"Model with ID {id} does not have a training process running") from e
        except (MultiValueDictKeyError, KeyError) as e:
            raise ParseError(e) from e
        except APIException:
            # let intended API errors (incl. the inconsistency error above and
            # errors raised by helpers) pass through without re-wrapping
            raise
        except Exception as e:
            raise APIException(e) from e

    @staticmethod
    def _save_swag_stats(
        fst_moment: bytes, snd_moment: bytes, model: GlobalModelDB, client: UserDB, sample_size: int
    ):
        """
        Save the first and second moments, and the sample size of the SWAG to the database.

        This function creates and saves three metrics for each round of the model:

        - the first moment,
        - the second moment, and
        - the sample size.

        These metrics are associated with the model, the round, and the client that reported them.

        Args:
            fst_moment (bytes): The first moment of the SWAG.
            snd_moment (bytes): The second moment of the SWAG.
            model (GlobalModelDB): The global model for which the metrics are being reported.
            client (UserDB): The client reporting the metrics.
            sample_size (int): The sample size of the SWAG.
        """
        # `objects.create` already persists the row; no extra `.save()` needed
        # (the previous trailing `.save()` triggered a redundant second write)
        MetricDB.objects.create(
            model=model,
            key="SWAG First Moment Local",
            value_binary=fst_moment,
            step=model.round,
            reporter=client
        )
        MetricDB.objects.create(
            model=model,
            key="SWAG Second Moment Local",
            value_binary=snd_moment,
            step=model.round,
            reporter=client
        )
        MetricDB.objects.create(
            model=model,
            key="SWAG Sample Size Local",
            value_float=sample_size,
            step=model.round,
            reporter=client
        )

    @transaction.atomic()
    def _metric_upload(self, formdata: dict, model: ModelDB, client: UserDB, metrics_required: bool = True):
        """
        Uploads metrics associated with a model.

        For each pair of metric name and value, it attempts to convert the value to a float.
        If this fails, it treats the value as a binary string.

        It then creates a new metric object with the model, the metric name, the float or binary value,
        the model's round number, and the client, and saves this object to the database.

        Args:
            formdata (dict): The form data containing the metric names and values.
            model (ModelDB): The model with which the metrics are associated.
            client (UserDB): The client reporting the metrics.
            metrics_required (bool): A flag indicating whether metrics are required. Defaults to True.

        Raises:
            ParseError: If `metric_names` or `metric_values` are not in formdata,
                or if they do not have the same length and metrics are required.
        """
        if "metric_names" not in formdata or "metric_values" not in formdata:
            # providing only one of the two fields is always an error,
            # even when metrics are optional
            if metrics_required or ("metric_names" in formdata) != ("metric_values" in formdata):
                raise ParseError("Metric names or values are missing")
            return
        if len(formdata["metric_names"]) != len(formdata["metric_values"]):
            if metrics_required:
                raise ParseError("Metric names and values must have the same length")
            return

        for key, value in zip(formdata["metric_names"], formdata["metric_values"]):
            try:
                metric_float = float(value)
                metric_binary = None
            except (TypeError, ValueError):
                # not a numeric value: store the raw value as binary instead
                metric_float = None
                metric_binary = bytes(value, encoding="utf-8")
            # `objects.create` already persists the row; a trailing `.save()`
            # would only trigger a redundant second write
            MetricDB.objects.create(
                model=model,
                key=key,
                value_float=metric_float,
                value_binary=metric_binary,
                step=model.round,
                reporter=client
            )

    def _verify_valid_update(self, client: UserDB, train: Training, round_num: int, expected_state: tuple[str, Any]):
        """
        Verify that a client is allowed to post an update for a training process.

        The update is accepted only if

        - the client is registered as a participant of the training,
        - the training is currently in the expected state, and
        - the reported round number equals the current round of the training's model.

        Args:
            client (UserDB): The client attempting to update the training process.
            train (Training): The training process to be updated.
            round_num (int): The round number reported by the client.
            expected_state (tuple[str, Any]): The expected state of the training process.

        Raises:
            PermissionDenied: If the client is not a participant of the training process.
            ValidationError: If the training process is not in the expected state or if the
                round number does not match the current round of the model.
        """
        participant_ids = {participant.id for participant in train.participants.all()}
        if client.id not in participant_ids:
            raise PermissionDenied(f"You are not a participant of training {train.id}!")
        if train.state != expected_state:
            raise ValidationError(f"Training with ID {train.id} is in state {train.state}")
        if int(round_num) != train.model.round:
            raise ValidationError(f"Training with ID {train.id} is not currently in round {round_num}")

Attributes

serializer_class class-attribute instance-attribute
serializer_class = GlobalModelSerializer

The serializer for the ViewSet.

Functions

create_local_model
create_local_model(request: HttpRequest, id: str) -> HttpResponse

Upload a partial trained model file from client.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
str

The UUID of the model that was used for training

required

Returns:

Name Type Description
HttpResponse HttpResponse

upload success message as json response

Source code in fl_server_api/views/model.py
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "owner": {"type": "string"},
                "round": {"type": "int"},
                "sample_size": {"type": "int"},
                "metric_names": {"type": "list[string]"},
                "metric_values": {"type": "list[float]"},
                "model_file": {"type": "string", "format": "binary"},
            },
        },
    },
    responses={
        status.HTTP_200_OK: ModelSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_local_model(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload a partial trained model file from client.

    Expects multipart form data with `round`, `sample_size`, the binary
    `model_file`, and optionally `metric_names`/`metric_values` pairs.
    When the last participant of the round has uploaded its update, the
    aggregation task for the round is dispatched.

    NOTE(review): the schema above declares 200 OK, but the view returns
    201 Created on success — confirm which one is intended.

    Args:
        request (HttpRequest):  request object
        id (str):  model uuid of the model, which was used for training

    Raises:
        NotFound: If no training process exists for the model.
        ParseError: If required form fields are missing.

    Returns:
        HttpResponse: upload success message as json response
    """
    try:
        formdata = dict(request.POST)
        # Each POST field arrives as a one-element list; unpack the single value.
        (round_num,) = formdata["round"]
        (sample_size,) = formdata["sample_size"]
        # NOTE(review): int() may raise ValueError here, which is not caught
        # below and would surface as a server error — confirm intended.
        round_num, sample_size = int(round_num), int(sample_size)
        client = request.user
        model_file = get_file(request, "model_file")
        global_model = get_entity(GlobalModelDB, pk=id)

        # ensure that a training process corresponding to the model exists, else the process will error out
        training = Training.objects.get(model=global_model)
        self._verify_valid_update(client, training, round_num, TrainingState.ONGOING)

        verify_model_object(model_file)
        local_model = LocalModelDB.objects.create(
            base_model=global_model, weights=model_file,
            round=round_num, owner=client, sample_size=sample_size
        )
        # Metrics are optional for local model uploads.
        self._metric_upload(formdata, local_model, client, metrics_required=False)

        # Dispatch aggregation once every participant has submitted an update
        # for this round.
        updates = LocalModelDB.objects.filter(base_model=global_model, round=round_num)
        if updates.count() == training.participants.count():
            dispatch_trainer_task(training, TrainingRoundFinished, True)

        return JsonResponse({"detail": "Model Update Accepted"}, status=status.HTTP_201_CREATED)
    except Training.DoesNotExist:
        raise NotFound(f"Model with ID {id} does not have a training process running")
    except (MultiValueDictKeyError, KeyError) as e:
        raise ParseError(e)
create_model
create_model(request: HttpRequest) -> HttpResponse

Upload a global model file.

The model file should be a PyTorch serialized model. Providing the model via torch.save as well as in TorchScript format is supported.

Parameters:

Name Type Description Default
request
HttpRequest

The incoming request object.

required

Returns:

Name Type Description
HttpResponse HttpResponse

upload success message as json response

Source code in fl_server_api/views/model.py
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "description": {"type": "string"},
                "model_file": {"type": "string", "format": "binary"},
                "model_preprocessing_file": {"type": "string", "format": "binary", "required": "false"},
            },
        },
    },
    responses={
        status.HTTP_201_CREATED: inline_serializer("ModelUploadSerializer", fields={
            "detail": CharField(default="Model Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_model(self, request: HttpRequest) -> HttpResponse:
    """
    Upload a global model file.

    The uploaded file must be a PyTorch serialized model; both `torch.save`
    output and TorchScript format are accepted.

    Args:
        request (HttpRequest): The incoming request object.

    Returns:
        HttpResponse: upload success message as json response
    """
    created = load_and_create_model_request(request)
    payload = {
        "detail": "Model Upload Accepted",
        "model_id": str(created.id),
    }
    return JsonResponse(payload, status=status.HTTP_201_CREATED)
create_model_metrics
create_model_metrics(request: HttpRequest, id: str) -> HttpResponse

Upload model metrics.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
str

model uuid

required

Returns:

Name Type Description
HttpResponse HttpResponse

upload success message as json response

Source code in fl_server_api/views/model.py
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "metric_names": {"type": "list"},
                "metric_values": {"type": "list"},
            },
        },
    },
    responses={
        status.HTTP_200_OK: inline_serializer("MetricUploadResponseSerializer", fields={
            "detail": CharField(default="Model Metrics Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        OpenApiExample("Example", value={
            "metric_names": ["accuracy", "training loss"],
            "metric_values": [0.6, 0.04]
        }, media_type="multipart/form-data")
    ]
)
def create_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload model metrics.

    For a global model, once every training participant has reported metrics for
    the current round, the model-test-finished task is dispatched.

    NOTE(review): the schema above declares 200 OK, but the view returns
    201 Created on success — confirm which one is intended.

    Args:
        request (HttpRequest):  request object
        id (str):  model uuid

    Returns:
        HttpResponse: upload success message as json response
    """
    model = get_entity(ModelDB, pk=id)
    formdata = dict(request.POST)

    # Lock the metric table so the "all participants reported" check below
    # cannot race with concurrent metric uploads for the same round.
    with locked_atomic_transaction(MetricDB):
        self._metric_upload(formdata, model, request.user)

        if isinstance(model, GlobalModelDB):
            # One distinct reporter per participant marks the test round complete.
            n_metrics = MetricDB.objects.filter(model=model, step=model.round).distinct("reporter").count()
            training = model.get_training()
            if training:
                if n_metrics == training.participants.count():
                    dispatch_trainer_task(training, ModelTestFinished, False)
            else:
                # Orphaned global model: metrics are stored but no task can run.
                self._logger.warning(f"Global model {id} is not connected to any training.")

    return JsonResponse({
        "detail": "Model Metrics Upload Accepted",
        "model_id": str(model.id),
    }, status=status.HTTP_201_CREATED)
create_swag_stats
create_swag_stats(request: HttpRequest, id: str) -> HttpResponse

Upload SWAG statistics.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
str

global model uuid

required

Raises:

Type Description
APIException

internal server error

NotFound

model not found

ParseError

request data not valid

Returns:

Name Type Description
HttpResponse HttpResponse

upload success message as json response

Source code in fl_server_api/views/model.py
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "round": {"type": "int"},
                "sample_size": {"type": "int"},
                "first_moment_file": {"type": "string", "format": "binary"},
                "second_moment_file": {"type": "string", "format": "binary"}
            },
        },
    },
    responses={
        status.HTTP_200_OK: inline_serializer("MetricUploadSerializer", fields={
            # Default fixed to match the actual response string below.
            "detail": CharField(default="SWAG Statistic Accepted"),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_swag_stats(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload SWAG statistics.

    Stores the client's first/second SWAG moments and sample size as metrics,
    and dispatches the SWAG-round-finished task once every participant of the
    training has reported consistent statistics.

    Args:
        request (HttpRequest): request object
        id (str): global model uuid

    Raises:
        APIException: internal server error
        NotFound: model not found
        ParseError: request data not valid

    Returns:
        HttpResponse: upload success message as json response
    """
    try:
        client = request.user
        formdata = dict(request.POST)
        # Each POST field arrives as a one-element list; unpack the single value.
        (round_num,) = formdata["round"]
        (sample_size,) = formdata["sample_size"]
        round_num, sample_size = int(round_num), int(sample_size)
        fst_moment = get_file(request, "first_moment_file")
        snd_moment = get_file(request, "second_moment_file")
        model = get_entity(GlobalModelDB, pk=id)

        # ensure that a training process corresponding to the model exists, else the process will error out
        training = Training.objects.get(model=model)
        self._verify_valid_update(client, training, round_num, TrainingState.SWAG_ROUND)

        self._save_swag_stats(fst_moment, snd_moment, model, client, sample_size)

        swag_stats_first = MetricDB.objects.filter(model=model, step=model.round, key="SWAG First Moment Local")
        swag_stats_second = MetricDB.objects.filter(model=model, step=model.round, key="SWAG Second Moment Local")

        if swag_stats_first.count() != swag_stats_second.count():
            training.state = TrainingState.ERROR
            # BUGFIX: the state change was previously never persisted.
            training.save(update_fields=["state"])
            raise APIException("SWAG stats in inconsistent state!")
        if swag_stats_first.count() == training.participants.count():
            dispatch_trainer_task(training, SWAGRoundFinished, True)

        return JsonResponse({"detail": "SWAG Statistic Accepted"}, status=status.HTTP_201_CREATED)
    except Training.DoesNotExist:
        raise NotFound(f"Model with ID {id} does not have a training process running")
    except (MultiValueDictKeyError, KeyError) as e:
        raise ParseError(e)
    except APIException:
        # BUGFIX: re-raise DRF exceptions (PermissionDenied, ValidationError, ...)
        # unchanged so they keep their proper status codes instead of being
        # converted into a generic 500 APIException below.
        raise
    except Exception as e:
        raise APIException(e)
get_metadata
get_metadata(_request: HttpRequest, id: str) -> HttpResponse

Get model meta data.

Parameters:

Name Type Description Default
request
HttpRequest

The incoming request object.

required
id
str

The unique identifier of the model.

required

Returns:

Name Type Description
HttpResponse HttpResponse

Model meta data as JSON response.

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeightsWithStats(),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_metadata(self, _request: HttpRequest, id: str) -> HttpResponse:
    """
    Get model meta data.

    Args:
        _request (HttpRequest): The incoming request object (unused).
        id (str): The unique identifier of the model.

    Returns:
        HttpResponse: Model meta data as JSON response.
    """
    entity = get_entity(ModelDB, pk=id)
    # "with-stats" makes the serializer include model statistics.
    stats_serializer = ModelSerializer(entity, context={"with-stats": True})
    return Response(stats_serializer.data)
get_model
get_model(_request: HttpRequest, id: str) -> HttpResponseBase

Download the whole model as PyTorch serialized file.

Parameters:

Name Type Description Default
request
HttpRequest

The incoming request object.

required
id
str

The unique identifier of the model.

required

Returns:

Name Type Description
HttpResponseBase HttpResponseBase

model as file response

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: OpenApiResponse(response=bytes, description="Model is returned as bytes"),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_model(self, _request: HttpRequest, id: str) -> HttpResponseBase:
    """
    Download the whole model as PyTorch serialized file.

    Args:
        _request (HttpRequest): The incoming request object (unused).
        id (str): The unique identifier of the model.

    Returns:
        HttpResponseBase: model as file response
    """
    model = get_entity(ModelDB, pk=id)
    is_swag_with_stats = isinstance(model, SWAGModelDB) and model.swag_first_moment is not None
    if is_swag_with_stats:
        # A first moment without a matching second moment is an invalid state.
        if model.swag_second_moment is None:
            raise APIException(f"Model {model.id} is in inconsistent state!")
        raise NotImplementedError(
            "SWAG models need to be returned in 3 parts: model architecture, first moment, second moment"
        )
    # NOTE: FileResponse does strange stuff with bytes
    #       and in case of sqlite the weights will be bytes and not a memoryview
    file_response = HttpResponse(model.weights, content_type="application/octet-stream")
    file_response["Content-Disposition"] = f'filename="model-{id}.pt"'
    return file_response
get_model_metrics
get_model_metrics(request: HttpRequest, id: str) -> HttpResponse

Reports all metrics for the selected model.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
str

model UUID

required

Returns:

Name Type Description
HttpResponse HttpResponse

Metrics as JSON Array

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: MetricSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Reports all metrics for the selected model.

    Args:
        request (HttpRequest):  request object
        id (str):  model UUID

    Returns:
        HttpResponse: Metrics as JSON Array
    """
    model = get_entity(ModelDB, pk=id)
    metric_rows = MetricDB.objects.filter(model=model).all()
    serialized = MetricSerializer(metric_rows, many=True)
    return Response(serialized.data)
get_model_proprecessing
get_model_proprecessing(_request: HttpRequest, id: str) -> HttpResponseBase

Download the whole preprocessing model as PyTorch serialized file.

Parameters:

Name Type Description Default
request
HttpRequest

The incoming request object.

required
id
str

The unique identifier of the model.

required

Returns:

Name Type Description
HttpResponseBase HttpResponseBase

proprecessing model as file response or 404 if proprecessing model not found

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: OpenApiResponse(
            response=bytes,
            description="Proprecessing model is returned as bytes"
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
        status.HTTP_404_NOT_FOUND: error_response_404,
    },
)
def get_model_proprecessing(self, _request: HttpRequest, id: str) -> HttpResponseBase:
    """
    Download the whole preprocessing model as PyTorch serialized file.

    Args:
        _request (HttpRequest): The incoming request object (unused).
        id (str): The unique identifier of the model.

    Raises:
        ValidationError: If the model is neither a global nor a local model.
        NotFound: If the model has no preprocessing model attached.

    Returns:
        HttpResponseBase: proprecessing model as file response or 404 if proprecessing model not found
    """
    model = get_entity(ModelDB, pk=id)
    # Preprocessing is stored on the global model; a local model delegates to
    # its base model. NOTE(review): annotation corrected from torch.nn.Module —
    # both branches below assign a GlobalModelDB instance.
    global_model: GlobalModelDB
    if isinstance(model, GlobalModelDB):
        global_model = model
    elif isinstance(model, LocalModelDB):
        global_model = model.base_model
    else:
        self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel.")
        raise ValidationError(f"Unknown model type. Model id: {id}")
    if global_model.preprocessing is None:
        raise NotFound(f"Model '{id}' has no preprocessing model defined.")
    # NOTE: FileResponse does strange stuff with bytes
    #       and in case of sqlite the weights will be bytes and not a memoryview
    response = HttpResponse(global_model.preprocessing, content_type="application/octet-stream")
    response["Content-Disposition"] = f'filename="model-{id}-proprecessing.pt"'
    return response
get_models
get_models(request: HttpRequest) -> HttpResponse

Get a list of all global models associated with the requesting user.

A global model is deemed associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process.

Parameters:

Name Type Description Default
request
HttpRequest

The incoming request object.

required

Returns:

Name Type Description
HttpResponse HttpResponse

Model list as JSON response.

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_models(self, request: HttpRequest) -> HttpResponse:
    """
    Get a list of all global models associated with the requesting user.

    A global model is deemed associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.

    Args:
        request (HttpRequest): The incoming request object.

    Returns:
        HttpResponse: Model list as JSON response.
    """
    user_models = self._get_user_related_global_models(request.user)
    return Response(ModelSerializer(user_models, many=True).data)
get_training_models
get_training_models(request: HttpRequest, training_id: str) -> HttpResponse

Get a list of all models associated with a specific training process and the requesting user.

A model is deemed associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process.

Parameters:

Name Type Description Default
request
HttpRequest

The incoming request object.

required
training_id
str

The unique identifier of the training process.

required

Returns:

Name Type Description
HttpResponse HttpResponse

Model list as JSON response.

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_training_models(self, request: HttpRequest, training_id: str) -> HttpResponse:
    """
    Get a list of all models associated with a specific training process and the requesting user.

    A model is deemed associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.

    Args:
        request (HttpRequest): The incoming request object.
        training_id (str): The unique identifier of the training process.

    Returns:
        HttpResponse: Model list as JSON response.
    """
    # Restrict the user's global models to the requested training, then pull
    # the local models derived from them.
    globals_for_training = self._filter_by_training(
        self._get_user_related_global_models(request.user), training_id
    )
    locals_for_training = self._get_local_models_for_global_models(globals_for_training)
    combined = [*globals_for_training, *locals_for_training]
    return Response(ModelSerializer(combined, many=True).data)
get_training_models_latest
get_training_models_latest(request: HttpRequest, training_id: str) -> HttpResponse

Get a list of the latest models for a specific training process associated with the requesting user.

A model is considered associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process. The latest model refers to the model from the most recent round (highest round number) of a participant's training process.

Parameters:

Name Type Description Default
request
HttpRequest

The incoming request object.

required
training_id
str

The unique identifier of the training process.

required

Returns:

Name Type Description
HttpResponse HttpResponse

Model list as JSON response.

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_training_models_latest(self, request: HttpRequest, training_id: str) -> HttpResponse:
    """
    Get a list of the latest models for a specific training process associated with the requesting user.

    A model is considered associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.
    The latest model refers to the model from the most recent round (highest round number) of
    a participant's training process.

    Args:
        request (HttpRequest): The incoming request object.
        training_id (str): The unique identifier of the training process.

    Returns:
        HttpResponse: Model list as JSON response. Empty list if the user has no
            models for this training.
    """
    models: List[ModelDB] = []
    # add latest global model
    global_models = self._get_user_related_global_models(request.user)
    global_models = self._filter_by_training(global_models, training_id)
    # BUGFIX: max() on an empty sequence raises ValueError, which previously
    # surfaced as an HTTP 500 when the user had no models for this training.
    latest_global = max(global_models, key=lambda m: m.round, default=None)
    if latest_global is not None:
        models.append(latest_global)
    # add latest local models (one per owner)
    local_models = self._get_local_models_for_global_models(global_models)
    local_models = sorted(local_models, key=lambda m: str(m.owner.pk))  # required for groupby
    for _, owner_models in groupby(local_models, key=lambda m: str(m.owner.pk)):
        models.append(max(owner_models, key=lambda m: m.round))
    serializer = ModelSerializer(models, many=True)
    return Response(serializer.data)
remove_model
remove_model(request: HttpRequest, id: str) -> HttpResponse

Remove an existing model.

Parameters:

Name Type Description Default
request
HttpRequest

The incoming request object.

required
id
str

The unique identifier of the model.

required

Returns:

Name Type Description
HttpResponse HttpResponse

200 Response if model was removed, else corresponding error code

Source code in fl_server_api/views/model.py
@extend_schema(responses={
    status.HTTP_200_OK: inline_serializer(
        "DeleteModelSuccessSerializer",
        fields={
            "detail": CharField(default="Model removed!")
        }
    ),
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def remove_model(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Remove an existing model.

    Deletion is allowed for the model owner, and for the actor of the
    training the model belongs to.

    Args:
        request (HttpRequest): The incoming request object.
        id (str): The unique identifier of the model.

    Returns:
        HttpResponse: 200 Response if model was removed, else corresponding error code
    """
    model = get_entity(ModelDB, pk=id)
    if model.owner != request.user:
        # Not the owner: fall back to checking the training's actor.
        training = model.get_training()
        actor_allowed = training is not None and training.actor == request.user
        if not actor_allowed:
            raise PermissionDenied(
                "You are neither the owner of the model nor the actor of the corresponding training."
            )
    model.delete()
    return JsonResponse({"detail": "Model removed!"})
upload_model_preprocessing
upload_model_preprocessing(request: HttpRequest, id: str) -> HttpResponse

Upload a preprocessing model file for a global model.

The preprocessing model file should be a PyTorch serialized model. Providing the model via torch.save as well as in TorchScript format is supported.

transforms = torch.nn.Sequential(
    torchvision.transforms.CenterCrop(10),
    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)

Make sure to only use transformations that inherit from torch.nn.Module. It is advised to use the torchvision.transforms.v2 module for common transformations.

Please note that this function is still in the beta phase.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
str

global model UUID

required

Raises:

Type Description
PermissionDenied

Unauthorized to upload preprocessing model for the specified model

ValidationError

Preprocessing model is not a valid torch model

Returns:

Name Type Description
HttpResponse HttpResponse

upload success message as json response

Source code in fl_server_api/views/model.py
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "model_preprocessing_file": {"type": "string", "format": "binary"},
            },
        },
    },
    responses={
        status.HTTP_202_ACCEPTED: inline_serializer("PreprocessingModelUploadSerializer", fields={
            "detail": CharField(default="Proprocessing Model Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def upload_model_preprocessing(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload a preprocessing model file for a global model.

    The preprocessing model file should be a PyTorch serialized model.
    Providing the model via `torch.save` as well as in TorchScript format is supported.

    ```python
    transforms = torch.nn.Sequential(
        torchvision.transforms.CenterCrop(10),
        torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    )
    ```

    Make sure to only use transformations that inherit from `torch.nn.Module`.
    It is advised to use the `torchvision.transforms.v2` module for common transformations.

    Please note that this function is still in the beta phase.

    Args:
        request (HttpRequest): request object
        id (str): global model UUID

    Raises:
        PermissionDenied: Unauthorized to upload preprocessing model for the specified model
        ValidationError: Preprocessing model is not a valid torch model

    Returns:
        HttpResponse: upload success message as json response
    """
    model = get_entity(GlobalModelDB, pk=id)
    if request.user.id != model.owner.id:
        raise PermissionDenied(f"You are not the owner of model {model.id}!")
    model.preprocessing = get_file(request, "model_preprocessing_file")
    verify_model_object(model.preprocessing, "preprocessing")
    model.save()
    return JsonResponse({
        "detail": "Proprocessing Model Upload Accepted",
        # BUGFIX: the declared response serializer includes model_id, but the
        # response previously omitted it.
        "model_id": str(model.id),
    }, status=status.HTTP_202_ACCEPTED)

Training

Bases: ViewSet


              flowchart TD
              fl_server_api.views.Training[Training]
              fl_server_api.views.base.ViewSet[ViewSet]

                              fl_server_api.views.base.ViewSet --> fl_server_api.views.Training
                


              click fl_server_api.views.Training href "" "fl_server_api.views.Training"
              click fl_server_api.views.base.ViewSet href "" "fl_server_api.views.base.ViewSet"
            

Training model ViewSet.

This ViewSet is used to create and manage trainings.

Methods:

Name Description
create_training

Create a new training process.

get_training

Get information about the selected training.

get_trainings

Get information about all owned trainings.

register_clients

Register one or more clients for a training process.

remove_clients

Remove one or more clients from a training process.

remove_training

Remove an existing training process.

start_training

Start a training process.

Attributes:

Name Type Description
serializer_class

The serializer for the ViewSet.

Source code in fl_server_api/views/training.py
class Training(ViewSet):
    """
    Training model ViewSet.

    This ViewSet is used to create and manage trainings.
    """

    serializer_class = TrainingSerializer
    """The serializer for the ViewSet."""

    def _check_user_permission_for_training(self, user: UserDB, training_id: UUID | str) -> TrainingDB:
        """
        Check if a user has permission for a training.

        This method checks if the user is the actor of the training or a participant in the training.

        Args:
            user (UserDB): The user.
            training_id (UUID | str): The ID of the training.

        Raises:
            PermissionDenied: If the user is neither the actor nor a participant of the training.

        Returns:
            TrainingDB: The training.
        """
        if isinstance(training_id, str):
            training_id = UUID(training_id)
        training = get_entity(TrainingDB, pk=training_id)
        if training.actor != user and user not in training.participants.all():
            raise PermissionDenied()
        return training

    def _get_clients_from_body(self, body_raw: bytes) -> list[UserDB]:
        """
        Get clients or participants from a request body.

        This method retrieves and loads all client data associated with the provided list of UUIDs contained
        within the request's clients field in the request body.

        Args:
            body_raw (bytes): The raw request body.

        Returns:
            list[UserDB]: The clients.
        """
        body: ClientAdministrationBody = self._load_marshmallow_request(ClientAdministrationBodySchema(), body_raw)
        return self._get_clients_from_uuid_list(body.clients)

    def _get_clients_from_uuid_list(self, uuids: list[UUID]) -> list[UserDB]:
        """
        Get clients from a list of UUIDs.

        This method gets the clients with the IDs in the list of UUIDs from the database.

        Args:
            uuids (list[UUID]): The list of UUIDs.

        Raises:
            ParseError: If not every provided UUID corresponds to an existing user.

        Returns:
            list[UserDB]: The clients.
        """
        if uuids is None or len(uuids) == 0:
            return []
        # Note: filter "in" does not raise UserDB.DoesNotExist exceptions
        clients = UserDB.objects.filter(id__in=uuids)
        if len(clients) != len(uuids):
            raise ParseError("Not all provided users were found!")
        return clients

    def _load_marshmallow_request(self, schema: Schema, json_data: str | bytes | bytearray):
        """
        Load JSON data from a request using a Marshmallow schema.

        Args:
            schema (Schema): The Marshmallow schema to use for loading the request.
            json_data (str | bytes | bytearray): The JSON data to load.

        Raises:
            ParseError: If a MarshmallowValidationError occurs.

        Returns:
            dict: The loaded data.
        """
        try:
            return schema.load(json.loads(json_data))  # should `schema.loads` be used instead?
        except MarshmallowValidationError as e:
            raise ParseError(e.messages) from e

    @extend_schema(responses={
        status.HTTP_200_OK: TrainingSerializer(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def get_trainings(self, request: HttpRequest) -> HttpResponse:
        """
        Get information about all owned trainings.

        Args:
            request (HttpRequest):  request object

        Returns:
            HttpResponse: list of training data as json response
        """
        trainings = TrainingDB.objects.filter(actor=request.user)
        serializer = TrainingSerializer(trainings, many=True)
        return Response(serializer.data)

    @extend_schema(responses={
        status.HTTP_200_OK: TrainingSerializerWithRounds,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def get_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Get information about the selected training.

        Args:
            request (HttpRequest):  request object
            id (str):  training uuid

        Returns:
            HttpResponse: training data as json response
        """
        train = self._check_user_permission_for_training(request.user, id)
        serializer = TrainingSerializerWithRounds(train)
        return Response(serializer.data)

    @extend_schema(
        request=inline_serializer("EmptyBodySerializer", fields={}),
        responses={
            status.HTTP_200_OK: inline_serializer(
                "StartTrainingSuccessSerializer",
                fields={
                    "detail": CharField(default="Training started!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def start_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Start a training process.

        This method checks if there are any participants registered for the training process.
        If there are participants, it checks if the training process is in the INITIAL state and starts the training
        session.

        Args:
            request (HttpRequest): The request object, which includes information about the user making the request.
            id (str): The UUID of the training process to start.

        Raises:
            ParseError: If there are no participants registered for the training process or if the training process
                is not in the INITIAL state.

        Returns:
            HttpResponse: A JSON response indicating that the training process has started.
        """
        training = self._check_user_permission_for_training(request.user, id)
        if training.participants.count() == 0:
            raise ParseError("At least one participant must be registered!")
        if training.state != TrainingState.INITIAL:
            raise ParseError(f"Training {training.id} is not in state INITIAL!")
        ModelTrainer(training).start()
        return JsonResponse({"detail": "Training started!"}, status=status.HTTP_202_ACCEPTED)

    @extend_schema(
        request=inline_serializer(
            "RegisterClientsSerializer",
            fields={
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            status.HTTP_200_OK: inline_serializer(
                "RegisteredClientsSuccessSerializer",
                fields={
                    "detail": CharField(default="Users registered as participants!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def register_clients(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Register one or more clients for a training process.

        This method is designed to be called by a POST request with a JSON body of the form
        `{"clients": [<list of UUIDs>]}`.
        It adds these clients as participants of the training process.

        Note: This method should be called once before the training process is started.

        Args:
            request (HttpRequest): The request object.
            id (str): The UUID of the training process.

        Returns:
            HttpResponse: 202 Response if clients were registered, else corresponding error code.
        """
        train = self._check_user_permission_for_training(request.user, id)
        clients = self._get_clients_from_body(request.body)
        train.participants.add(*clients)
        return JsonResponse({"detail": "Users registered as participants!"}, status=status.HTTP_202_ACCEPTED)

    @extend_schema(responses={
        status.HTTP_200_OK: inline_serializer(
            "DeleteTrainingSuccessSerializer",
            fields={
                "detail": CharField(default="Training removed!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def remove_training(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Remove an existing training process.

        Args:
            request (HttpRequest):  request object
            id (str):  training uuid

        Raises:
            PermissionDenied: If the requesting user is not the actor of the training.

        Returns:
            HttpResponse: 200 Response if training was removed, else corresponding error code
        """
        training = get_entity(TrainingDB, pk=id)
        if training.actor != request.user:
            # fixed grammar of the user-facing message ("owner the" -> "owner of the")
            raise PermissionDenied("You are not the owner of the training.")
        training.delete()
        return JsonResponse({"detail": "Training removed!"})

    @extend_schema(
        request=inline_serializer(
            "RemoveClientsSerializer",
            fields={
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            status.HTTP_200_OK: inline_serializer(
                "RemovedClientsSuccessSerializer",
                fields={
                    "detail": CharField(default="Users removed from training participants!")
                }
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def remove_clients(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Remove one or more clients from a training process.

        This method is designed to modify an already existing training process.

        Args:
            request (HttpRequest): The request object.
            id (str): The UUID of the training process.

        Returns:
            HttpResponse: 200 Response if clients were removed, else corresponding error code.
        """
        train = self._check_user_permission_for_training(request.user, id)
        clients = self._get_clients_from_body(request.body)
        train.participants.remove(*clients)
        return JsonResponse({"detail": "Users removed from training participants!"})

    @extend_schema(
        request=inline_serializer(
            name="TrainingCreationSerializer",
            fields={
                "model_id": CharField(),
                "target_num_updates": IntegerField(),
                "metric_names": ListField(child=CharField()),
                "aggregation_method": CharField(),
                "clients": ListField(child=UUIDField())
            }
        ),
        responses={
            status.HTTP_200_OK: inline_serializer("TrainingCreatedSerializer", fields={
                "detail": CharField(default="Training created successfully!"),
                "training_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def create_training(self, request: HttpRequest) -> HttpResponse:
        """
        Create a new training process.

        This method is designed to be called by a POST request according to the `CreateTrainingRequestSchema`.
        The request should include a model file (the initial model) as an attached FILE.

        Args:
            request (HttpRequest):  The request object.

        Returns:
            HttpResponse: 201 if training could be registered.
        """
        parsed_request: CreateTrainingRequest = self._load_marshmallow_request(
            CreateTrainingRequestSchema(),
            request.body.decode("utf-8")
        )
        model = get_entity(ModelDB, pk=parsed_request.model_id)
        if model.owner != request.user:
            raise PermissionDenied()
        if TrainingDB.objects.filter(model=model).exists():
            # the selected model is already referenced by another training, so we need to copy it
            model = clone_model(model)

        clients = self._get_clients_from_uuid_list(parsed_request.clients)
        train = TrainingDB.objects.create(
            model=model,
            actor=request.user,
            target_num_updates=parsed_request.target_num_updates,
            state=TrainingState.INITIAL,
            uncertainty_method=parsed_request.uncertainty_method.value,
            aggregation_method=parsed_request.aggregation_method.value,
            options=parsed_request.options
        )
        train.participants.add(*clients)
        return JsonResponse({
            "detail": "Training created successfully!",
            "training_id": train.id
        }, status=status.HTTP_201_CREATED)

Attributes

serializer_class class-attribute instance-attribute
serializer_class = TrainingSerializer

The serializer for the ViewSet.

Functions

create_training
create_training(request: HttpRequest) -> HttpResponse

Create a new training process.

This method is designed to be called by a POST request according to the CreateTrainingRequestSchema. The request should include a model file (the initial model) as an attached FILE.

Parameters:

Name Type Description Default
request
HttpRequest

The request object.

required

Returns:

Name Type Description
HttpResponse HttpResponse

201 if training could be registered.

Source code in fl_server_api/views/training.py
@extend_schema(
    request=inline_serializer(
        name="TrainingCreationSerializer",
        fields={
            "model_id": CharField(),
            "target_num_updates": IntegerField(),
            "metric_names": ListField(child=CharField()),
            "aggregation_method": CharField(),
            "clients": ListField(child=UUIDField())
        }
    ),
    responses={
        status.HTTP_200_OK: inline_serializer("TrainingCreatedSerializer", fields={
            "detail": CharField(default="Training created successfully!"),
            "training_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def create_training(self, request: HttpRequest) -> HttpResponse:
    """
    Register a new training process.

    Expects a POST body that validates against `CreateTrainingRequestSchema`;
    the initial model file should be attached to the request as a FILE.

    Args:
        request (HttpRequest):  The request object.

    Returns:
        HttpResponse: 201 if training could be registered.
    """
    body: CreateTrainingRequest = self._load_marshmallow_request(
        CreateTrainingRequestSchema(),
        request.body.decode("utf-8")
    )
    model = get_entity(ModelDB, pk=body.model_id)
    if model.owner != request.user:
        raise PermissionDenied()
    # a model may only back one training at a time; copy it if already referenced
    if TrainingDB.objects.filter(model=model).exists():
        model = clone_model(model)

    participants = self._get_clients_from_uuid_list(body.clients)
    training = TrainingDB.objects.create(
        model=model,
        actor=request.user,
        target_num_updates=body.target_num_updates,
        state=TrainingState.INITIAL,
        uncertainty_method=body.uncertainty_method.value,
        aggregation_method=body.aggregation_method.value,
        options=body.options
    )
    training.participants.add(*participants)
    return JsonResponse({
        "detail": "Training created successfully!",
        "training_id": training.id
    }, status=status.HTTP_201_CREATED)
get_training
get_training(request: HttpRequest, id: str) -> HttpResponse

Get information about the selected training.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
str

training uuid

required

Returns:

Name Type Description
HttpResponse HttpResponse

training data as json response

Source code in fl_server_api/views/training.py
@extend_schema(responses={
    status.HTTP_200_OK: TrainingSerializerWithRounds,
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def get_training(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Return details of a single training accessible to the requesting user.

    Args:
        request (HttpRequest):  request object
        id (str):  training uuid

    Returns:
        HttpResponse: training data as json response
    """
    # permission check also resolves the training entity
    training = self._check_user_permission_for_training(request.user, id)
    return Response(TrainingSerializerWithRounds(training).data)
get_trainings
get_trainings(request: HttpRequest) -> HttpResponse

Get information about all owned trainings.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required

Returns:

Name Type Description
HttpResponse HttpResponse

list of training data as json response

Source code in fl_server_api/views/training.py
@extend_schema(responses={
    status.HTTP_200_OK: TrainingSerializer(many=True),
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def get_trainings(self, request: HttpRequest) -> HttpResponse:
    """
    List all trainings owned (acted on) by the requesting user.

    Args:
        request (HttpRequest):  request object

    Returns:
        HttpResponse: list of training data as json response
    """
    owned = TrainingDB.objects.filter(actor=request.user)
    return Response(TrainingSerializer(owned, many=True).data)
register_clients
register_clients(request: HttpRequest, id: str) -> HttpResponse

Register one or more clients for a training process.

This method is designed to be called by a POST request with a JSON body of the form {"clients": [<list of UUIDs>]}. It adds these clients as participants of the training process.

Note: This method should be called once before the training process is started.

Parameters:

Name Type Description Default
request
HttpRequest

The request object.

required
id
str

The UUID of the training process.

required

Returns:

Name Type Description
HttpResponse HttpResponse

202 Response if clients were registered, else corresponding error code.

Source code in fl_server_api/views/training.py
@extend_schema(
    request=inline_serializer(
        "RegisterClientsSerializer",
        fields={
            "clients": ListField(child=UUIDField())
        }
    ),
    responses={
        status.HTTP_200_OK: inline_serializer(
            "RegisteredClientsSuccessSerializer",
            fields={
                "detail": CharField(default="Users registered as participants!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def register_clients(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Register one or more clients for a training process.

    Expects a POST body of the form `{"clients": [<list of UUIDs>]}` and adds
    the referenced users as participants of the training.

    Note: This method should be called once before the training process is started.

    Args:
        request (HttpRequest): The request object.
        id (str): The UUID of the training process.

    Returns:
        HttpResponse: 202 Response if clients were registered, else corresponding error code.
    """
    training = self._check_user_permission_for_training(request.user, id)
    new_participants = self._get_clients_from_body(request.body)
    training.participants.add(*new_participants)
    return JsonResponse({"detail": "Users registered as participants!"}, status=status.HTTP_202_ACCEPTED)
remove_clients
remove_clients(request: HttpRequest, id: str) -> HttpResponse

Remove one or more clients from a training process.

This method is designed to modify an already existing training process.

Parameters:

Name Type Description Default
request
HttpRequest

The request object.

required
id
str

The UUID of the training process.

required

Returns:

Name Type Description
HttpResponse HttpResponse

200 Response if clients were removed, else corresponding error code.

Source code in fl_server_api/views/training.py
@extend_schema(
    request=inline_serializer(
        "RemoveClientsSerializer",
        fields={
            "clients": ListField(child=UUIDField())
        }
    ),
    responses={
        status.HTTP_200_OK: inline_serializer(
            "RemovedClientsSuccessSerializer",
            fields={
                "detail": CharField(default="Users removed from training participants!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def remove_clients(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Remove one or more clients from a training process.

    Modifies an already existing training process by dropping the users listed
    in the request body from its participants.

    Args:
        request (HttpRequest): The request object.
        id (str): The UUID of the training process.

    Returns:
        HttpResponse: 200 Response if clients were removed, else corresponding error code.
    """
    training = self._check_user_permission_for_training(request.user, id)
    removed = self._get_clients_from_body(request.body)
    training.participants.remove(*removed)
    return JsonResponse({"detail": "Users removed from training participants!"})
remove_training
remove_training(request: HttpRequest, id: str) -> HttpResponse

Remove an existing training process.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
str

training uuid

required

Returns:

Name Type Description
HttpResponse HttpResponse

200 Response if training was removed, else corresponding error code

Source code in fl_server_api/views/training.py
@extend_schema(responses={
    status.HTTP_200_OK: inline_serializer(
        "DeleteTrainingSuccessSerializer",
        fields={
            "detail": CharField(default="Training removed!")
        }
    ),
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def remove_training(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Remove an existing training process.

    Only the actor (owner) of the training is allowed to delete it.

    Args:
        request (HttpRequest):  request object
        id (str):  training uuid

    Raises:
        PermissionDenied: If the requesting user is not the actor of the training.

    Returns:
        HttpResponse: 200 Response if training was removed, else corresponding error code
    """
    training = get_entity(TrainingDB, pk=id)
    if training.actor != request.user:
        # fixed grammar of the user-facing message ("owner the" -> "owner of the")
        raise PermissionDenied("You are not the owner of the training.")
    training.delete()
    return JsonResponse({"detail": "Training removed!"})
start_training
start_training(request: HttpRequest, id: str) -> HttpResponse

Start a training process.

This method checks if there are any participants registered for the training process. If there are participants, it checks if the training process is in the INITIAL state and starts the training session.

Parameters:

Name Type Description Default
request
HttpRequest

The request object, which includes information about the user making the request.

required
id
str

The UUID of the training process to start.

required

Raises:

Type Description
ParseError

If there are no participants registered for the training process or if the training process is not in the INITIAL state.

Returns:

Name Type Description
HttpResponse HttpResponse

A JSON response indicating that the training process has started.

Source code in fl_server_api/views/training.py
@extend_schema(
    request=inline_serializer("EmptyBodySerializer", fields={}),
    responses={
        status.HTTP_200_OK: inline_serializer(
            "StartTrainingSuccessSerializer",
            fields={
                "detail": CharField(default="Training started!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def start_training(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Start a training process.

    Validates that at least one participant is registered and that the training
    is still in the INITIAL state before launching the training session.

    Args:
        request (HttpRequest): The request object, which includes information about the user making the request.
        id (str): The UUID of the training process to start.

    Raises:
        ParseError: If there are no participants registered for the training process or if the training process
            is not in the INITIAL state.

    Returns:
        HttpResponse: A JSON response indicating that the training process has started.
    """
    train = self._check_user_permission_for_training(request.user, id)
    # guard clauses: a training needs participants and must not have been started already
    if train.participants.count() == 0:
        raise ParseError("At least one participant must be registered!")
    if train.state != TrainingState.INITIAL:
        raise ParseError(f"Training {train.id} is not in state INITIAL!")
    ModelTrainer(train).start()
    return JsonResponse({"detail": "Training started!"}, status=status.HTTP_202_ACCEPTED)

User

Bases: ViewSet


              flowchart TD
              fl_server_api.views.User[User]
              fl_server_api.views.base.ViewSet[ViewSet]

                              fl_server_api.views.base.ViewSet --> fl_server_api.views.User
                


              click fl_server_api.views.User href "" "fl_server_api.views.User"
              click fl_server_api.views.base.ViewSet href "" "fl_server_api.views.base.ViewSet"
            

User model ViewSet.

Methods:

Name Description
create_user

Create a new user.

get_myself

Get current user.

get_user

Get user information.

get_user_groups

Get user groups.

get_user_trainings

Get user trainings.

get_users

Get all registered users as list.

Attributes:

Name Type Description
serializer_class

The serializer for the ViewSet.

Source code in fl_server_api/views/user.py
class User(ViewSet):
    """
    User model ViewSet.
    """

    serializer_class = UserSerializer
    """The serializer for the ViewSet."""

    @extend_schema(
        responses={
            status.HTTP_200_OK: UserSerializer(many=True),
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def get_users(self, request: HttpRequest) -> HttpResponse:
        """
        Get all registered users as list.

        Args:
            request (HttpRequest):  request object

        Returns:
            HttpResponse: user list as json response
        """
        serializer = UserSerializer(UserModel.objects.all(), many=True)
        return Response(serializer.data)

    def get_myself(self, request: HttpRequest) -> HttpResponse:
        """
        Get current user.

        Args:
            request (HttpRequest):  request object

        Returns:
            HttpResponse: user data as json response
        """
        serializer = UserSerializer(request.user)
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: UserSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            OpenApiExample(
                name="Get user by id",
                description="Retrieve user data by user ID.",
                value="88a9f11a-846b-43b5-bd15-367fc332ba59",
                parameter_only=("id", OpenApiParameter.PATH)
            )
        ]
    )
    def get_user(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Get user information.

        Args:
            request (HttpRequest):  request object
            id (str):  user uuid

        Returns:
            HttpResponse: user as json response
        """
        serializer = UserSerializer(get_entity(UserModel, pk=id), context={"request_user_id": request.user.id})
        return Response(serializer.data)

    # registration endpoint: deliberately open (no authentication/permission classes)
    @decorators.authentication_classes([])
    @decorators.permission_classes([])
    @extend_schema(
        responses={
            status.HTTP_200_OK: UserSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        },
        examples=[
            OpenApiExample("Jane Doe", value={
                "message_endpoint": "http://example.com/",
                "actor": True,
                "client": True,
                "username": "jane",
                "first_name": "Jane",
                "last_name": "Doe",
                "email": "jane.doe@example.com",
                "password": "my-super-secret-password"
            })
        ]
    )
    def create_user(self, request: HttpRequest) -> HttpResponse:
        """
        Create a new user.

        Args:
            request (HttpRequest):  request object

        Returns:
            HttpResponse: new created user as json response
        """
        user = UserSerializer().create(request.data)
        serializer = UserSerializer(user, context={"request_user_id": user.id})
        return Response(serializer.data, status=status.HTTP_201_CREATED)

    @extend_schema(
        responses={
            status.HTTP_200_OK: GroupSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            OpenApiExample(
                name="User uuid",
                value="88a9f11a-846b-43b5-bd15-367fc332ba59",
                parameter_only=("id", OpenApiParameter.PATH)
            )
        ]
    )
    def get_user_groups(self, request: HttpRequest) -> HttpResponse:
        """
        Get the groups of the requesting user.

        Args:
            request (HttpRequest):  request object

        Returns:
            HttpResponse: user groups as json response
        """
        serializer = GroupSerializer(request.user.groups, many=True)
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: TrainingSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        }
    )
    def get_user_trainings(self, request: HttpRequest) -> HttpResponse:
        """
        Get user trainings.

        Returns trainings where the requesting user is either the actor or a participant.

        Args:
            request (HttpRequest):  request object

        Returns:
            HttpResponse: user trainings as json response
        """
        trainings = TrainingModel.objects.filter(Q(actor=request.user) | Q(participants=request.user)).distinct()
        serializer = TrainingSerializer(trainings, many=True)
        return Response(serializer.data)

Attributes

serializer_class class-attribute instance-attribute
serializer_class = UserSerializer

The serializer for the ViewSet.

Functions

create_user
create_user(request: HttpRequest) -> HttpResponse

Create a new user.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required

Returns:

Name Type Description
HttpResponse HttpResponse

new created user as json response

Source code in fl_server_api/views/user.py
@decorators.authentication_classes([])
@decorators.permission_classes([])
@extend_schema(
    responses={
        status.HTTP_200_OK: UserSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    },
    examples=[
        OpenApiExample("Jane Doe", value={
            "message_endpoint": "http://example.com/",
            "actor": True,
            "client": True,
            "username": "jane",
            "first_name": "Jane",
            "last_name": "Doe",
            "email": "jane.doe@example.com",
            "password": "my-super-secret-password"
        })
    ]
)
def create_user(self, request: HttpRequest) -> HttpResponse:
    """
    Create a new user from the request payload.

    Args:
        request (HttpRequest):  request object

    Returns:
        HttpResponse: new created user as json response
    """
    new_user = UserSerializer().create(request.data)
    payload = UserSerializer(new_user, context={"request_user_id": new_user.id}).data
    return Response(payload, status=status.HTTP_201_CREATED)
get_myself
get_myself(request: HttpRequest) -> HttpResponse

Get current user.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required

Returns:

Name Type Description
HttpResponse HttpResponse

user data as json response

Source code in fl_server_api/views/user.py
def get_myself(self, request: HttpRequest) -> HttpResponse:
    """
    Get the profile of the authenticated user making this request.

    Args:
        request (HttpRequest):  request object

    Returns:
        HttpResponse: user data as json response
    """
    return Response(UserSerializer(request.user).data)
get_user
get_user(request: HttpRequest, id: str) -> HttpResponse

Get user information.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required
id
str

user uuid

required

Returns:

Name Type Description
HttpResponse HttpResponse

user as json response

Source code in fl_server_api/views/user.py
@extend_schema(
    responses={
        status.HTTP_200_OK: UserSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        OpenApiExample(
            name="Get user by id",
            description="Retrieve user data by user ID.",
            value="88a9f11a-846b-43b5-bd15-367fc332ba59",
            parameter_only=("id", OpenApiParameter.PATH)
        )
    ]
)
def get_user(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Get user information.

    Args:
        request (HttpRequest):  request object
        id (str):  user uuid

    Returns:
        HttpResponse: user as json response
    """
    # request_user_id lets the serializer tailor output to the requester.
    target = get_entity(UserModel, pk=id)
    serializer = UserSerializer(target, context={"request_user_id": request.user.id})
    return Response(serializer.data)
get_user_groups
get_user_groups(request: HttpRequest) -> HttpResponse

Get user groups.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required

Returns:

Name Type Description
HttpResponse HttpResponse

user groups as json response

Source code in fl_server_api/views/user.py
@extend_schema(
    responses={
        status.HTTP_200_OK: GroupSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_user_groups(self, request: HttpRequest) -> HttpResponse:
    """
    Get the groups of the requesting user.

    The previous version documented an ``id`` argument and declared an
    ``id`` path parameter example, but the endpoint only ever reads
    ``request.user`` — both stale artifacts are removed.

    Args:
        request (HttpRequest):  request object

    Returns:
        HttpResponse: user groups as json response
    """
    serializer = GroupSerializer(request.user.groups, many=True)
    return Response(serializer.data)
get_user_trainings
get_user_trainings(request: HttpRequest) -> HttpResponse

Get user trainings.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required

Returns:

Name Type Description
HttpResponse HttpResponse

user trainings as json response

Source code in fl_server_api/views/user.py
@extend_schema(
    responses={
        status.HTTP_200_OK: TrainingSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def get_user_trainings(self, request: HttpRequest) -> HttpResponse:
    """
    Get all trainings the requesting user takes part in.

    A training matches when the user is its actor or among its
    participants; `distinct()` removes rows matched by both conditions.

    Args:
        request (HttpRequest):  request object

    Returns:
        HttpResponse: user trainings as json response
    """
    involvement = Q(actor=request.user) | Q(participants=request.user)
    matching = TrainingModel.objects.filter(involvement).distinct()
    return Response(TrainingSerializer(matching, many=True).data)
get_users
get_users(request: HttpRequest) -> HttpResponse

Get all registered users as list.

Parameters:

Name Type Description Default
request
HttpRequest

request object

required

Returns:

Name Type Description
HttpResponse HttpResponse

user list as json response

Source code in fl_server_api/views/user.py
@extend_schema(
    responses={
        status.HTTP_200_OK: UserSerializer(many=True),
        status.HTTP_403_FORBIDDEN: error_response_403,
    }
)
def get_users(self, request: HttpRequest) -> HttpResponse:
    """
    Get all registered users as list.

    Args:
        request (HttpRequest):  request object

    Returns:
        HttpResponse: user list as json response
    """
    everyone = UserModel.objects.all()
    return Response(UserSerializer(everyone, many=True).data)