
fl_server_api.views.model

Classes:

| Name | Description |
| --- | --- |
| `Model` | Model ViewSet. |

Attributes

Classes

Model

Bases: ViewSet


```mermaid
flowchart TD
  fl_server_api.views.model.Model[Model]
  fl_server_api.views.base.ViewSet[ViewSet]

  fl_server_api.views.base.ViewSet --> fl_server_api.views.model.Model
```

Model ViewSet.

Methods:

| Name | Description |
| --- | --- |
| `create_local_model` | Upload a partially trained model file from a client. |
| `create_model` | Upload a global model file. |
| `create_model_metrics` | Upload model metrics. |
| `create_swag_stats` | Upload SWAG statistics. |
| `get_metadata` | Get model metadata. |
| `get_model` | Download the whole model as a PyTorch serialized file. |
| `get_model_metrics` | Report all metrics for the selected model. |
| `get_model_proprecessing` | Download the whole preprocessing model as a PyTorch serialized file. |
| `get_models` | Get a list of all global models associated with the requesting user. |
| `get_training_models` | Get a list of all models associated with a specific training process and the requesting user. |
| `get_training_models_latest` | Get a list of the latest models for a specific training process associated with the requesting user. |
| `remove_model` | Remove an existing model. |
| `upload_model_preprocessing` | Upload a preprocessing model file for a global model. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `serializer_class` | | The serializer for the ViewSet. |
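The metrics endpoints store each reported value either as a float or, when it cannot be parsed as a number, as a UTF-8 encoded byte string (see `_metric_upload` in the source below). A minimal sketch of that conversion rule — `to_metric_value` is a hypothetical helper name for illustration, not part of the API:

```python
def to_metric_value(value: str) -> tuple:
    """Sketch of the float-or-binary rule: try to parse a float first;
    otherwise fall back to the UTF-8 encoded bytes of the raw string."""
    try:
        return float(value), None           # stored as value_float
    except ValueError:
        return None, value.encode("utf-8")  # stored as value_binary

print(to_metric_value("0.6"))        # a parseable metric value
print(to_metric_value("confusion"))  # a non-numeric metric value
```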

Source code in fl_server_api/views/model.py
class Model(ViewSet):
    """
    Model ViewSet.
    """

    serializer_class = GlobalModelSerializer
    """The serializer for the ViewSet."""

    def _get_user_related_global_models(self, user: Union[AbstractBaseUser, AnonymousUser]) -> List[ModelDB]:
        """
        Get global models related to a user.

        This method retrieves all global models where the user is the actor or a participant of.

        Args:
            user (Union[AbstractBaseUser, AnonymousUser]): The user.

        Returns:
            List[ModelDB]: The global models related to the user.
        """
        model_ids = Training.objects.filter(
            Q(actor=user) | Q(participants=user)
        ).distinct().values_list("model__id", flat=True)
        return ModelDB.objects.filter(Q(owner=user) | Q(id__in=model_ids)).distinct()

    def _get_local_models_for_global_model(self, global_model: GlobalModelDB) -> List[LocalModelDB]:
        """
        Get all local models that are based on the global model.

        Args:
            global_model (GlobalModelDB): The global model.

        Returns:
            List[LocalModelDB]: The local models for the global model.
        """
        return LocalModelDB.objects.filter(base_model=global_model).all()

    def _get_local_models_for_global_models(self, global_models: List[GlobalModelDB]) -> List[LocalModelDB]:
        """
        Get all local models that are based on any of the global models.

        Args:
            global_models (List[GlobalModelDB]): The global models.

        Returns:
            List[LocalModelDB]: The local models for the global models.
        """
        return LocalModelDB.objects.filter(base_model__in=global_models).all()

    def _filter_by_training(self, models: List[ModelDB], training_id: str) -> List[ModelDB]:
        """
        Filter a list of models by checking if they are associated with the training.

        Args:
            models (List[ModelDB]): The models to filter.
            training_id (str): The ID of the training.

        Returns:
            List[ModelDB]: The models associated with the training.
        """
        def associated_with_training(m: ModelDB) -> bool:
            training = m.get_training()
            if training is None:
                return False
            return training.pk == UUID(training_id)
        return list(filter(associated_with_training, models))

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_models(self, request: HttpRequest) -> HttpResponse:
        """
        Get a list of all global models associated with the requesting user.

        A global model is deemed associated with a user if the user is either the owner of the model,
        or if the user is an actor or a participant in the model's training process.

        Args:
            request (HttpRequest): The incoming request object.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        models = self._get_user_related_global_models(request.user)
        serializer = ModelSerializer(models, many=True)
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_training_models(self, request: HttpRequest, training_id: str) -> HttpResponse:
        """
        Get a list of all models associated with a specific training process and the requesting user.

        A model is deemed associated with a user if the user is either the owner of the model,
        or if the user is an actor or a participant in the model's training process.

        Args:
            request (HttpRequest): The incoming request object.
            training_id (str): The unique identifier of the training process.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        global_models = self._get_user_related_global_models(request.user)
        global_models = self._filter_by_training(global_models, training_id)
        local_models = self._get_local_models_for_global_models(global_models)
        serializer = ModelSerializer([*global_models, *local_models], many=True)
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_training_models_latest(self, request: HttpRequest, training_id: str) -> HttpResponse:
        """
        Get a list of the latest models for a specific training process associated with the requesting user.

        A model is considered associated with a user if the user is either the owner of the model,
        or if the user is an actor or a participant in the model's training process.
        The latest model refers to the model from the most recent round (highest round number) of
        a participant's training process.

        Args:
            request (HttpRequest): The incoming request object.
            training_id (str): The unique identifier of the training process.

        Returns:
            HttpResponse: Model list as JSON response.
        """
        models: List[ModelDB] = []
        # add latest global model
        global_models = self._get_user_related_global_models(request.user)
        global_models = self._filter_by_training(global_models, training_id)
        if global_models:
            models.append(max(global_models, key=lambda m: m.round))
        # add latest local models
        local_models = self._get_local_models_for_global_models(global_models)
        local_models = sorted(local_models, key=lambda m: str(m.owner.pk))  # required for groupby
        for _, group in groupby(local_models, key=lambda m: str(m.owner.pk)):
            models.append(max(group, key=lambda m: m.round))
        serializer = ModelSerializer(models, many=True)
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: ModelSerializerNoWeightsWithStats(),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_metadata(self, _request: HttpRequest, id: str) -> HttpResponse:
        """
        Get model metadata.

        Args:
            _request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponse: Model metadata as JSON response.
        """
        model = get_entity(ModelDB, pk=id)
        serializer = ModelSerializer(model, context={"with-stats": True})
        return Response(serializer.data)

    @extend_schema(
        responses={
            status.HTTP_200_OK: OpenApiResponse(response=bytes, description="Model is returned as bytes"),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_model(self, _request: HttpRequest, id: str) -> HttpResponseBase:
        """
        Download the whole model as PyTorch serialized file.

        Args:
            _request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponseBase: Model as file response.
        """
        model = get_entity(ModelDB, pk=id)
        if isinstance(model, SWAGModelDB) and model.swag_first_moment is not None:
            if model.swag_second_moment is None:
                raise APIException(f"Model {model.id} is in inconsistent state!")
            raise NotImplementedError(
                "SWAG models need to be returned in 3 parts: model architecture, first moment, second moment"
            )
        # NOTE: FileResponse does strange stuff with bytes
        #       and in case of sqlite the weights will be bytes and not a memoryview
        response = HttpResponse(model.weights, content_type="application/octet-stream")
        response["Content-Disposition"] = f'filename="model-{id}.pt"'
        return response

    @extend_schema(
        responses={
            status.HTTP_200_OK: OpenApiResponse(
                response=bytes,
                description="Preprocessing model is returned as bytes"
            ),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
            status.HTTP_404_NOT_FOUND: error_response_404,
        },
    )
    def get_model_proprecessing(self, _request: HttpRequest, id: str) -> HttpResponseBase:
        """
        Download the whole preprocessing model as PyTorch serialized file.

        Args:
            _request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponseBase: Preprocessing model as file response, or 404 if no preprocessing model is defined.
        """
        model = get_entity(ModelDB, pk=id)
        global_model: GlobalModelDB
        if isinstance(model, GlobalModelDB):
            global_model = model
        elif isinstance(model, LocalModelDB):
            global_model = model.base_model
        else:
            self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel.")
            raise ValidationError(f"Unknown model type. Model id: {id}")
        if global_model.preprocessing is None:
            raise NotFound(f"Model '{id}' has no preprocessing model defined.")
        # NOTE: FileResponse does strange stuff with bytes
        #       and in case of sqlite the weights will be bytes and not a memoryview
        response = HttpResponse(global_model.preprocessing, content_type="application/octet-stream")
        response["Content-Disposition"] = f'filename="model-{id}-preprocessing.pt"'
        return response

    @extend_schema(responses={
        status.HTTP_200_OK: inline_serializer(
            "DeleteModelSuccessSerializer",
            fields={
                "detail": CharField(default="Model removed!")
            }
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    })
    def remove_model(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Remove an existing model.

        Args:
            request (HttpRequest): The incoming request object.
            id (str): The unique identifier of the model.

        Returns:
            HttpResponse: 200 Response if model was removed, else corresponding error code
        """
        model = get_entity(ModelDB, pk=id)
        if model.owner != request.user:
            training = model.get_training()
            if training is None or training.actor != request.user:
                raise PermissionDenied(
                    "You are neither the owner of the model nor the actor of the corresponding training."
                )
        model.delete()
        return JsonResponse({"detail": "Model removed!"})

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "description": {"type": "string"},
                    "model_file": {"type": "string", "format": "binary"},
                    "model_preprocessing_file": {"type": "string", "format": "binary", "required": "false"},
                },
            },
        },
        responses={
            status.HTTP_201_CREATED: inline_serializer("ModelUploadSerializer", fields={
                "detail": CharField(default="Model Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_model(self, request: HttpRequest) -> HttpResponse:
        """
        Upload a global model file.

        The model file should be a PyTorch serialized model.
        Providing the model via `torch.save` as well as in TorchScript format is supported.

        Args:
            request (HttpRequest): The incoming request object.

        Returns:
            HttpResponse: upload success message as json response
        """
        model = load_and_create_model_request(request)
        return JsonResponse({
            "detail": "Model Upload Accepted",
            "model_id": str(model.id),
        }, status=status.HTTP_201_CREATED)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "model_preprocessing_file": {"type": "string", "format": "binary"},
                },
            },
        },
        responses={
            status.HTTP_202_ACCEPTED: inline_serializer("PreprocessingModelUploadSerializer", fields={
                "detail": CharField(default="Preprocessing Model Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def upload_model_preprocessing(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload a preprocessing model file for a global model.

        The preprocessing model file should be a PyTorch serialized model.
        Providing the model via `torch.save` as well as in TorchScript format is supported.

        ```python
        transforms = torch.nn.Sequential(
            torchvision.transforms.CenterCrop(10),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        )
        ```

        Make sure to only use transformations that inherit from `torch.nn.Module`.
        It is advised to use the `torchvision.transforms.v2` module for common transformations.

        Please note that this function is still in the beta phase.

        Args:
            request (HttpRequest): request object
            id (str): global model UUID

        Raises:
            PermissionDenied: Unauthorized to upload preprocessing model for the specified model
            ValidationError: Preprocessing model is not a valid torch model

        Returns:
            HttpResponse: upload success message as json response
        """
        model = get_entity(GlobalModelDB, pk=id)
        if request.user.id != model.owner.id:
            raise PermissionDenied(f"You are not the owner of model {model.id}!")
        model.preprocessing = get_file(request, "model_preprocessing_file")
        verify_model_object(model.preprocessing, "preprocessing")
        model.save()
        return JsonResponse({
            "detail": "Preprocessing Model Upload Accepted",
        }, status=status.HTTP_202_ACCEPTED)

    @extend_schema(
        responses={
            status.HTTP_200_OK: MetricSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def get_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Reports all metrics for the selected model.

        Args:
            request (HttpRequest):  request object
            id (str):  model UUID

        Returns:
            HttpResponse: Metrics as JSON Array
        """
        model = get_entity(ModelDB, pk=id)
        metrics = MetricDB.objects.filter(model=model).all()
        return Response(MetricSerializer(metrics, many=True).data)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "metric_names": {"type": "array", "items": {"type": "string"}},
                    "metric_values": {"type": "array", "items": {"type": "string"}},
                },
            },
        },
        responses={
            status.HTTP_201_CREATED: inline_serializer("MetricUploadResponseSerializer", fields={
                "detail": CharField(default="Model Metrics Upload Accepted"),
                "model_id": UUIDField(),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
        examples=[
            OpenApiExample("Example", value={
                "metric_names": ["accuracy", "training loss"],
                "metric_values": [0.6, 0.04]
            }, media_type="multipart/form-data")
        ]
    )
    def create_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload model metrics.

        Args:
            request (HttpRequest):  request object
            id (str):  model uuid

        Returns:
            HttpResponse: upload success message as json response
        """
        model = get_entity(ModelDB, pk=id)
        formdata = dict(request.POST)

        with locked_atomic_transaction(MetricDB):
            self._metric_upload(formdata, model, request.user)

            if isinstance(model, GlobalModelDB):
                n_metrics = MetricDB.objects.filter(model=model, step=model.round).distinct("reporter").count()
                training = model.get_training()
                if training:
                    if n_metrics == training.participants.count():
                        dispatch_trainer_task(training, ModelTestFinished, False)
                else:
                    self._logger.warning(f"Global model {id} is not connected to any training.")

        return JsonResponse({
            "detail": "Model Metrics Upload Accepted",
            "model_id": str(model.id),
        }, status=status.HTTP_201_CREATED)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "owner": {"type": "string"},
                    "round": {"type": "integer"},
                    "sample_size": {"type": "integer"},
                    "metric_names": {"type": "array", "items": {"type": "string"}},
                    "metric_values": {"type": "array", "items": {"type": "string"}},
                    "model_file": {"type": "string", "format": "binary"},
                },
            },
        },
        responses={
            status.HTTP_201_CREATED: ModelSerializer,
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_local_model(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload a partially trained model file from a client.

        Args:
            request (HttpRequest):  request object
            id (str):  model uuid of the model, which was used for training

        Returns:
            HttpResponse: upload success message as json response
        """
        try:
            formdata = dict(request.POST)
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
            client = request.user
            model_file = get_file(request, "model_file")
            global_model = get_entity(GlobalModelDB, pk=id)

            # ensure that a training process corresponding to the model exists; otherwise Training.DoesNotExist is raised
            training = Training.objects.get(model=global_model)
            self._verify_valid_update(client, training, round_num, TrainingState.ONGOING)

            verify_model_object(model_file)
            local_model = LocalModelDB.objects.create(
                base_model=global_model, weights=model_file,
                round=round_num, owner=client, sample_size=sample_size
            )
            self._metric_upload(formdata, local_model, client, metrics_required=False)

            updates = LocalModelDB.objects.filter(base_model=global_model, round=round_num)
            if updates.count() == training.participants.count():
                dispatch_trainer_task(training, TrainingRoundFinished, True)

            return JsonResponse({"detail": "Model Update Accepted"}, status=status.HTTP_201_CREATED)
        except Training.DoesNotExist:
            raise NotFound(f"Model with ID {id} does not have a training process running")
        except (MultiValueDictKeyError, KeyError) as e:
            raise ParseError(e)

    @extend_schema(
        request={
            "multipart/form-data": {
                "type": "object",
                "properties": {
                    "round": {"type": "integer"},
                    "sample_size": {"type": "integer"},
                    "first_moment_file": {"type": "string", "format": "binary"},
                    "second_moment_file": {"type": "string", "format": "binary"}
                },
            },
        },
        responses={
            status.HTTP_201_CREATED: inline_serializer("MetricUploadSerializer", fields={
                "detail": CharField(default="SWAG Statistic Accepted"),
            }),
            status.HTTP_400_BAD_REQUEST: ErrorSerializer,
            status.HTTP_403_FORBIDDEN: error_response_403,
        },
    )
    def create_swag_stats(self, request: HttpRequest, id: str) -> HttpResponse:
        """
        Upload SWAG statistics.

        Args:
            request (HttpRequest): request object
            id (str): global model uuid

        Raises:
            APIException: internal server error
            NotFound: model not found
            ParseError: request data not valid

        Returns:
            HttpResponse: upload success message as json response
        """
        try:
            client = request.user
            formdata = dict(request.POST)
            (round_num,) = formdata["round"]
            (sample_size,) = formdata["sample_size"]
            round_num, sample_size = int(round_num), int(sample_size)
            fst_moment = get_file(request, "first_moment_file")
            snd_moment = get_file(request, "second_moment_file")
            model = get_entity(GlobalModelDB, pk=id)

            # ensure that a training process corresponding to the model exists; otherwise Training.DoesNotExist is raised
            training = Training.objects.get(model=model)
            self._verify_valid_update(client, training, round_num, TrainingState.SWAG_ROUND)

            self._save_swag_stats(fst_moment, snd_moment, model, client, sample_size)

            swag_stats_first = MetricDB.objects.filter(model=model, step=model.round, key="SWAG First Moment Local")
            swag_stats_second = MetricDB.objects.filter(model=model, step=model.round, key="SWAG Second Moment Local")

            if swag_stats_first.count() != swag_stats_second.count():
                training.state = TrainingState.ERROR
                training.save()
                raise APIException("SWAG stats in inconsistent state!")
            if swag_stats_first.count() == training.participants.count():
                dispatch_trainer_task(training, SWAGRoundFinished, True)

            return JsonResponse({"detail": "SWAG Statistic Accepted"}, status=status.HTTP_201_CREATED)
        except Training.DoesNotExist:
            raise NotFound(f"Model with ID {id} does not have a training process running")
        except (MultiValueDictKeyError, KeyError) as e:
            raise ParseError(e)
        except APIException:
            # do not mask intentional API errors (e.g. PermissionDenied, ValidationError)
            raise
        except Exception as e:
            raise APIException(e)

    @staticmethod
    def _save_swag_stats(
        fst_moment: bytes, snd_moment: bytes, model: GlobalModelDB, client: UserDB, sample_size: int
    ):
        """
        Save the first and second moments, and the sample size of the SWAG to the database.

        This function creates and saves three metrics for each round of the model:

        - the first moment,
        - the second moment, and
        - the sample size.

        These metrics are associated with the model, the round, and the client that reported them.

        Args:
            fst_moment (bytes): The first moment of the SWAG.
            snd_moment (bytes): The second moment of the SWAG.
            model (GlobalModelDB): The global model for which the metrics are being reported.
            client (UserDB): The client reporting the metrics.
            sample_size (int): The sample size of the SWAG.
        """
        # objects.create() already saves the instance; no extra save() needed
        MetricDB.objects.create(
            model=model,
            key="SWAG First Moment Local",
            value_binary=fst_moment,
            step=model.round,
            reporter=client
        )
        MetricDB.objects.create(
            model=model,
            key="SWAG Second Moment Local",
            value_binary=snd_moment,
            step=model.round,
            reporter=client
        )
        MetricDB.objects.create(
            model=model,
            key="SWAG Sample Size Local",
            value_float=sample_size,
            step=model.round,
            reporter=client
        )

    @transaction.atomic()
    def _metric_upload(self, formdata: dict, model: ModelDB, client: UserDB, metrics_required: bool = True):
        """
        Uploads metrics associated with a model.

        For each pair of metric name and value, it attempts to convert the value to a float.
        If this fails, it treats the value as a binary string.

        It then creates a new metric object with the model, the metric name, the float or binary value,
        the model's round number, and the client, and saves this object to the database.

        Args:
            formdata (dict): The form data containing the metric names and values.
            model (ModelDB): The model with which the metrics are associated.
            client (UserDB): The client reporting the metrics.
            metrics_required (bool): A flag indicating whether metrics are required. Defaults to True.

        Raises:
            ParseError: If `metric_names` or `metric_values` are not in formdata,
                or if they do not have the same length and metrics are required.
        """
        if "metric_names" not in formdata or "metric_values" not in formdata:
            if metrics_required or ("metric_names" in formdata) != ("metric_values" in formdata):
                raise ParseError("Metric names or values are missing")
            return
        if len(formdata["metric_names"]) != len(formdata["metric_values"]):
            if metrics_required:
                raise ParseError("Metric names and values must have the same length")
            return

        for key, value in zip(formdata["metric_names"], formdata["metric_values"]):
            try:
                metric_float = float(value)
                metric_binary = None
            except Exception:
                metric_float = None
                metric_binary = bytes(value, encoding="utf-8")
            MetricDB.objects.create(
                model=model,
                key=key,
                value_float=metric_float,
                value_binary=metric_binary,
                step=model.round,
                reporter=client
            )

    def _verify_valid_update(self, client: UserDB, train: Training, round_num: int, expected_state: tuple[str, Any]):
        """
        Verifies if a client can update a training process.

        This function checks if

        - the client is a participant of the training process,
        - the training process is in the expected state, and if
        - the round number matches the current round of the model associated with the training process.

        Args:
            client (UserDB): The client attempting to update the training process.
            train (Training): The training process to be updated.
            round_num (int): The round number reported by the client.
            expected_state (tuple[str, Any]): The expected state of the training process.

        Raises:
            PermissionDenied: If the client is not a participant of the training process.
            ValidationError: If the training process is not in the expected state or if the round number does not match
                the current round of the model.
        """
        if client.id not in [p.id for p in train.participants.all()]:
            raise PermissionDenied(f"You are not a participant of training {train.id}!")
        if train.state != expected_state:
            raise ValidationError(f"Training with ID {train.id} is in state {train.state}")
        if int(round_num) != train.model.round:
            raise ValidationError(f"Training with ID {train.id} is not currently in round {round_num}")
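The float-or-binary coercion that `_metric_upload` applies to each reported value can be sketched as a standalone helper (the name `coerce_metric_value` is hypothetical; the server inlines this logic in its loop):

```python
def coerce_metric_value(value: str) -> tuple:
    """Return (value_float, value_binary): try to parse the reported value
    as a float; otherwise keep it as UTF-8 encoded bytes."""
    try:
        return float(value), None
    except (TypeError, ValueError):
        return None, bytes(value, encoding="utf-8")
```

A numeric report such as `"0.93"` lands in `value_float`, while a non-numeric payload (e.g. a serialized confusion matrix) falls through to `value_binary`.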

Attributes

serializer_class class-attribute instance-attribute
serializer_class = GlobalModelSerializer

The serializer for the ViewSet.

Functions

create_local_model
create_local_model(request: HttpRequest, id: str) -> HttpResponse

Upload a partial trained model file from client.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | request object | *required* |
| `id` | `str` | UUID of the model that was used for training | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | upload success message as JSON response |

Source code in fl_server_api/views/model.py
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "owner": {"type": "string"},
                "round": {"type": "integer"},
                "sample_size": {"type": "integer"},
                "metric_names": {"type": "array", "items": {"type": "string"}},
                "metric_values": {"type": "array", "items": {"type": "number"}},
                "model_file": {"type": "string", "format": "binary"},
            },
        },
    },
    responses={
        status.HTTP_200_OK: ModelSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_local_model(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload a partial trained model file from client.

    Args:
        request (HttpRequest): request object
        id (str): UUID of the model that was used for training

    Returns:
        HttpResponse: upload success message as json response
    """
    try:
        formdata = dict(request.POST)
        (round_num,) = formdata["round"]
        (sample_size,) = formdata["sample_size"]
        round_num, sample_size = int(round_num), int(sample_size)
        client = request.user
        model_file = get_file(request, "model_file")
        global_model = get_entity(GlobalModelDB, pk=id)

        # ensure that a training process corresponding to the model exists, else the process will error out
        training = Training.objects.get(model=global_model)
        self._verify_valid_update(client, training, round_num, TrainingState.ONGOING)

        verify_model_object(model_file)
        local_model = LocalModelDB.objects.create(
            base_model=global_model, weights=model_file,
            round=round_num, owner=client, sample_size=sample_size
        )
        self._metric_upload(formdata, local_model, client, metrics_required=False)

        updates = LocalModelDB.objects.filter(base_model=global_model, round=round_num)
        if updates.count() == training.participants.count():
            dispatch_trainer_task(training, TrainingRoundFinished, True)

        return JsonResponse({"detail": "Model Update Accepted"}, status=status.HTTP_201_CREATED)
    except Training.DoesNotExist:
        raise NotFound(f"Model with ID {id} does not have a training process running")
    except (MultiValueDictKeyError, KeyError) as e:
        raise ParseError(e)
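As a client-side illustration, the multipart payload for this endpoint might be assembled as follows. This is a sketch: the helper name, base URL, endpoint path, and auth header are assumptions, and the actual HTTP call is left commented out.

```python
def build_local_model_payload(round_num, sample_size, model_bytes,
                              metric_names=None, metric_values=None):
    """Assemble form fields and files for the local-model upload endpoint."""
    data = {"round": str(round_num), "sample_size": str(sample_size)}
    if metric_names and metric_values:
        data["metric_names"] = metric_names            # e.g. ["accuracy"]
        data["metric_values"] = [str(v) for v in metric_values]
    files = {"model_file": ("model.pt", model_bytes)}
    return data, files

# Example (nothing is sent anywhere):
data, files = build_local_model_payload(3, 128, b"<torch-serialized weights>",
                                        ["accuracy"], [0.91])
# import requests  # hypothetical client call:
# requests.post(f"{BASE_URL}/models/{model_id}/", data=data, files=files,
#               headers={"Authorization": f"Token {token}"})
```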
create_model
create_model(request: HttpRequest) -> HttpResponse

Upload a global model file.

The model file should be a PyTorch serialized model. Providing the model via torch.save as well as in TorchScript format is supported.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | The incoming request object. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | upload success message as JSON response |

Source code in fl_server_api/views/model.py
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "description": {"type": "string"},
                "model_file": {"type": "string", "format": "binary"},
                "model_preprocessing_file": {"type": "string", "format": "binary", "required": "false"},
            },
        },
    },
    responses={
        status.HTTP_201_CREATED: inline_serializer("ModelUploadSerializer", fields={
            "detail": CharField(default="Model Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_model(self, request: HttpRequest) -> HttpResponse:
    """
    Upload a global model file.

    The model file should be a PyTorch serialized model.
    Providing the model via `torch.save` as well as in TorchScript format is supported.

    Args:
        request (HttpRequest): The incoming request object.

    Returns:
        HttpResponse: upload success message as json response
    """
    model = load_and_create_model_request(request)
    return JsonResponse({
        "detail": "Model Upload Accepted",
        "model_id": str(model.id),
    }, status=status.HTTP_201_CREATED)
create_model_metrics
create_model_metrics(request: HttpRequest, id: str) -> HttpResponse

Upload model metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | request object | *required* |
| `id` | `str` | model UUID | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | upload success message as JSON response |

Source code in fl_server_api/views/model.py
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "metric_names": {"type": "array", "items": {"type": "string"}},
                "metric_values": {"type": "array", "items": {"type": "number"}},
            },
        },
    },
    responses={
        status.HTTP_201_CREATED: inline_serializer("MetricUploadResponseSerializer", fields={
            "detail": CharField(default="Model Metrics Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
    examples=[
        OpenApiExample("Example", value={
            "metric_names": ["accuracy", "training loss"],
            "metric_values": [0.6, 0.04]
        }, media_type="multipart/form-data")
    ]
)
def create_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload model metrics.

    Args:
        request (HttpRequest): request object
        id (str): model UUID

    Returns:
        HttpResponse: upload success message as json response
    """
    model = get_entity(ModelDB, pk=id)
    formdata = dict(request.POST)

    with locked_atomic_transaction(MetricDB):
        self._metric_upload(formdata, model, request.user)

        if isinstance(model, GlobalModelDB):
            n_metrics = MetricDB.objects.filter(model=model, step=model.round).distinct("reporter").count()
            training = model.get_training()
            if training:
                if n_metrics == training.participants.count():
                    dispatch_trainer_task(training, ModelTestFinished, False)
            else:
                self._logger.warning(f"Global model {id} is not connected to any training.")

    return JsonResponse({
        "detail": "Model Metrics Upload Accepted",
        "model_id": str(model.id),
    }, status=status.HTTP_201_CREATED)
create_swag_stats
create_swag_stats(request: HttpRequest, id: str) -> HttpResponse

Upload SWAG statistics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | request object | *required* |
| `id` | `str` | global model UUID | *required* |

Raises:

| Type | Description |
| --- | --- |
| `APIException` | internal server error |
| `NotFound` | model not found |
| `ParseError` | request data not valid |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | upload success message as JSON response |

Source code in fl_server_api/views/model.py
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "round": {"type": "integer"},
                "sample_size": {"type": "integer"},
                "first_moment_file": {"type": "string", "format": "binary"},
                "second_moment_file": {"type": "string", "format": "binary"}
            },
        },
    },
    responses={
        status.HTTP_201_CREATED: inline_serializer("MetricUploadSerializer", fields={
            "detail": CharField(default="SWAG Statistic Accepted"),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def create_swag_stats(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload SWAG statistics.

    Args:
        request (HttpRequest): request object
        id (str): global model uuid

    Raises:
        APIException: internal server error
        NotFound: model not found
        ParseError: request data not valid

    Returns:
        HttpResponse: upload success message as json response
    """
    try:
        client = request.user
        formdata = dict(request.POST)
        (round_num,) = formdata["round"]
        (sample_size,) = formdata["sample_size"]
        round_num, sample_size = int(round_num), int(sample_size)
        fst_moment = get_file(request, "first_moment_file")
        snd_moment = get_file(request, "second_moment_file")
        model = get_entity(GlobalModelDB, pk=id)

        # ensure that a training process corresponding to the model exists, else the process will error out
        training = Training.objects.get(model=model)
        self._verify_valid_update(client, training, round_num, TrainingState.SWAG_ROUND)

        self._save_swag_stats(fst_moment, snd_moment, model, client, sample_size)

        swag_stats_first = MetricDB.objects.filter(model=model, step=model.round, key="SWAG First Moment Local")
        swag_stats_second = MetricDB.objects.filter(model=model, step=model.round, key="SWAG Second Moment Local")

        if swag_stats_first.count() != swag_stats_second.count():
            training.state = TrainingState.ERROR
            training.save()
            raise APIException("SWAG stats in inconsistent state!")
        if swag_stats_first.count() == training.participants.count():
            dispatch_trainer_task(training, SWAGRoundFinished, True)

        return JsonResponse({"detail": "SWAG Statistic Accepted"}, status=status.HTTP_201_CREATED)
    except Training.DoesNotExist:
        raise NotFound(f"Model with ID {id} does not have a training process running")
    except (MultiValueDictKeyError, KeyError) as e:
        raise ParseError(e)
    except Exception as e:
        raise APIException(e)
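The round-completion logic above — compare the counts of locally reported first and second moments, error on a mismatch, and aggregate once every participant has reported — can be sketched as a pure function (a hypothetical helper, not part of the server code):

```python
def swag_round_action(first_moments: int, second_moments: int,
                      participants: int) -> str:
    """Decide what happens after a SWAG statistics upload."""
    if first_moments != second_moments:
        return "error"        # inconsistent state -> TrainingState.ERROR
    if first_moments == participants:
        return "aggregate"    # all clients reported -> dispatch SWAGRoundFinished
    return "wait"             # keep waiting for the remaining clients
```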
get_metadata
get_metadata(_request: HttpRequest, id: str) -> HttpResponse

Get model meta data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | The incoming request object. | *required* |
| `id` | `str` | The unique identifier of the model. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | Model meta data as JSON response. |

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeightsWithStats(),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_metadata(self, _request: HttpRequest, id: str) -> HttpResponse:
    """
    Get model meta data.

    Args:
        request (HttpRequest): The incoming request object.
        id (str): The unique identifier of the model.

    Returns:
        HttpResponse: Model meta data as JSON response.
    """
    model = get_entity(ModelDB, pk=id)
    serializer = ModelSerializer(model, context={"with-stats": True})
    return Response(serializer.data)
get_model
get_model(_request: HttpRequest, id: str) -> HttpResponseBase

Download the whole model as PyTorch serialized file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | The incoming request object. | *required* |
| `id` | `str` | The unique identifier of the model. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponseBase` | `HttpResponseBase` | model as file response |

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: OpenApiResponse(response=bytes, description="Model is returned as bytes"),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_model(self, _request: HttpRequest, id: str) -> HttpResponseBase:
    """
    Download the whole model as PyTorch serialized file.

    Args:
        request (HttpRequest): The incoming request object.
        id (str): The unique identifier of the model.

    Returns:
        HttpResponseBase: model as file response
    """
    model = get_entity(ModelDB, pk=id)
    if isinstance(model, SWAGModelDB) and model.swag_first_moment is not None:
        if model.swag_second_moment is None:
            raise APIException(f"Model {model.id} is in inconsistent state!")
        raise NotImplementedError(
            "SWAG models need to be returned in 3 parts: model architecture, first moment, second moment"
        )
    # NOTE: FileResponse does strange stuff with bytes
    #       and in case of sqlite the weights will be bytes and not a memoryview
    response = HttpResponse(model.weights, content_type="application/octet-stream")
    response["Content-Disposition"] = f'filename="model-{id}.pt"'
    return response
get_model_metrics
get_model_metrics(request: HttpRequest, id: str) -> HttpResponse

Reports all metrics for the selected model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | request object | *required* |
| `id` | `str` | model UUID | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | Metrics as JSON Array |

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: MetricSerializer,
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_model_metrics(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Reports all metrics for the selected model.

    Args:
        request (HttpRequest):  request object
        id (str):  model UUID

    Returns:
        HttpResponse: Metrics as JSON Array
    """
    model = get_entity(ModelDB, pk=id)
    metrics = MetricDB.objects.filter(model=model).all()
    return Response(MetricSerializer(metrics, many=True).data)
get_model_proprecessing
get_model_proprecessing(_request: HttpRequest, id: str) -> HttpResponseBase

Download the whole preprocessing model as PyTorch serialized file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | The incoming request object. | *required* |
| `id` | `str` | The unique identifier of the model. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponseBase` | `HttpResponseBase` | preprocessing model as file response, or 404 if no preprocessing model is found |

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: OpenApiResponse(
            response=bytes,
            description="Preprocessing model is returned as bytes"
        ),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
        status.HTTP_404_NOT_FOUND: error_response_404,
    },
)
def get_model_proprecessing(self, _request: HttpRequest, id: str) -> HttpResponseBase:
    """
    Download the whole preprocessing model as PyTorch serialized file.

    Args:
        request (HttpRequest): The incoming request object.
        id (str): The unique identifier of the model.

    Returns:
        HttpResponseBase: preprocessing model as file response, or 404 if no preprocessing model is found
    """
    model = get_entity(ModelDB, pk=id)
    global_model: torch.nn.Module
    if isinstance(model, GlobalModelDB):
        global_model = model
    elif isinstance(model, LocalModelDB):
        global_model = model.base_model
    else:
        self._logger.error("Unknown model type. Not a GlobalModel and not a LocalModel.")
        raise ValidationError(f"Unknown model type. Model id: {id}")
    if global_model.preprocessing is None:
        raise NotFound(f"Model '{id}' has no preprocessing model defined.")
    # NOTE: FileResponse does strange stuff with bytes
    #       and in case of sqlite the weights will be bytes and not a memoryview
    response = HttpResponse(global_model.preprocessing, content_type="application/octet-stream")
    response["Content-Disposition"] = f'filename="model-{id}-proprecessing.pt"'
    return response
get_models
get_models(request: HttpRequest) -> HttpResponse

Get a list of all global models associated with the requesting user.

A global model is deemed associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | The incoming request object. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | Model list as JSON response. |

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_models(self, request: HttpRequest) -> HttpResponse:
    """
    Get a list of all global models associated with the requesting user.

    A global model is deemed associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.

    Args:
        request (HttpRequest): The incoming request object.

    Returns:
        HttpResponse: Model list as JSON response.
    """
    models = self._get_user_related_global_models(request.user)
    serializer = ModelSerializer(models, many=True)
    return Response(serializer.data)
get_training_models
get_training_models(request: HttpRequest, training_id: str) -> HttpResponse

Get a list of all models associated with a specific training process and the requesting user.

A model is deemed associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | The incoming request object. | *required* |
| `training_id` | `str` | The unique identifier of the training process. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | Model list as JSON response. |

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_training_models(self, request: HttpRequest, training_id: str) -> HttpResponse:
    """
    Get a list of all models associated with a specific training process and the requesting user.

    A model is deemed associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.

    Args:
        request (HttpRequest): The incoming request object.
        training_id (str): The unique identifier of the training process.

    Returns:
        HttpResponse: Model list as JSON response.
    """
    global_models = self._get_user_related_global_models(request.user)
    global_models = self._filter_by_training(global_models, training_id)
    local_models = self._get_local_models_for_global_models(global_models)
    serializer = ModelSerializer([*global_models, *local_models], many=True)
    return Response(serializer.data)
get_training_models_latest
get_training_models_latest(request: HttpRequest, training_id: str) -> HttpResponse

Get a list of the latest models for a specific training process associated with the requesting user.

A model is considered associated with a user if the user is either the owner of the model, or if the user is an actor or a participant in the model's training process. The latest model refers to the model from the most recent round (highest round number) of a participant's training process.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | The incoming request object. | *required* |
| `training_id` | `str` | The unique identifier of the training process. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | Model list as JSON response. |

Source code in fl_server_api/views/model.py
@extend_schema(
    responses={
        status.HTTP_200_OK: ModelSerializerNoWeights(many=True),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def get_training_models_latest(self, request: HttpRequest, training_id: str) -> HttpResponse:
    """
    Get a list of the latest models for a specific training process associated with the requesting user.

    A model is considered associated with a user if the user is either the owner of the model,
    or if the user is an actor or a participant in the model's training process.
    The latest model refers to the model from the most recent round (highest round number) of
    a participant's training process.

    Args:
        request (HttpRequest): The incoming request object.
        training_id (str): The unique identifier of the training process.

    Returns:
        HttpResponse: Model list as JSON response.
    """
    models: List[ModelDB] = []
    # add latest global model
    global_models = self._get_user_related_global_models(request.user)
    global_models = self._filter_by_training(global_models, training_id)
    models.append(max(global_models, key=lambda m: m.round))
    # add latest local models
    local_models = self._get_local_models_for_global_models(global_models)
    local_models = sorted(local_models, key=lambda m: str(m.owner.pk))  # required for groupby
    for _, group in groupby(local_models, key=lambda m: str(m.owner.pk)):
        models.append(max(group, key=lambda m: m.round))
    serializer = ModelSerializer(models, many=True)
    return Response(serializer.data)
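The latest-per-participant selection above relies on the sort-then-groupby idiom: records must be sorted by the grouping key before `itertools.groupby` will collect them into one run per key. A minimal sketch with hypothetical `(owner, round)` pairs:

```python
from itertools import groupby

# hypothetical (owner, round) pairs standing in for local model records
local_models = [("alice", 1), ("bob", 2), ("alice", 3), ("bob", 1)]
local_models.sort(key=lambda m: m[0])           # required for groupby
latest = [max(group, key=lambda m: m[1])        # highest round per owner
          for _, group in groupby(local_models, key=lambda m: m[0])]
# latest == [("alice", 3), ("bob", 2)]
```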
remove_model
remove_model(request: HttpRequest, id: str) -> HttpResponse

Remove an existing model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | The incoming request object. | *required* |
| `id` | `str` | The unique identifier of the model. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | 200 response if the model was removed, else the corresponding error code |

Source code in fl_server_api/views/model.py
@extend_schema(responses={
    status.HTTP_200_OK: inline_serializer(
        "DeleteModelSuccessSerializer",
        fields={
            "detail": CharField(default="Model removed!")
        }
    ),
    status.HTTP_400_BAD_REQUEST: ErrorSerializer,
    status.HTTP_403_FORBIDDEN: error_response_403,
})
def remove_model(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Remove an existing model.

    Args:
        request (HttpRequest): The incoming request object.
        id (str): The unique identifier of the model.

    Returns:
        HttpResponse: 200 Response if model was removed, else corresponding error code
    """
    model = get_entity(ModelDB, pk=id)
    if model.owner != request.user:
        training = model.get_training()
        if training is None or training.actor != request.user:
            raise PermissionDenied(
                "You are neither the owner of the model nor the actor of the corresponding training."
            )
    model.delete()
    return JsonResponse({"detail": "Model removed!"})
upload_model_preprocessing
upload_model_preprocessing(request: HttpRequest, id: str) -> HttpResponse

Upload a preprocessing model file for a global model.

The preprocessing model file should be a PyTorch serialized model. Providing the model via torch.save as well as in TorchScript format is supported.

```python
transforms = torch.nn.Sequential(
    torchvision.transforms.CenterCrop(10),
    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
```

Make sure to only use transformations that inherit from torch.nn.Module. It is advised to use the torchvision.transforms.v2 module for common transformations.

Please note that this function is still in the beta phase.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `request` | `HttpRequest` | request object | *required* |
| `id` | `str` | global model UUID | *required* |

Raises:

| Type | Description |
| --- | --- |
| `PermissionDenied` | Unauthorized to upload preprocessing model for the specified model |
| `ValidationError` | Preprocessing model is not a valid torch model |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `HttpResponse` | `HttpResponse` | upload success message as JSON response |

Source code in fl_server_api/views/model.py
@extend_schema(
    request={
        "multipart/form-data": {
            "type": "object",
            "properties": {
                "model_preprocessing_file": {"type": "string", "format": "binary"},
            },
        },
    },
    responses={
        status.HTTP_202_ACCEPTED: inline_serializer("PreprocessingModelUploadSerializer", fields={
            "detail": CharField(default="Proprocessing Model Upload Accepted"),
            "model_id": UUIDField(),
        }),
        status.HTTP_400_BAD_REQUEST: ErrorSerializer,
        status.HTTP_403_FORBIDDEN: error_response_403,
    },
)
def upload_model_preprocessing(self, request: HttpRequest, id: str) -> HttpResponse:
    """
    Upload a preprocessing model file for a global model.

    The preprocessing model file should be a PyTorch serialized model.
    Providing the model via `torch.save` as well as in TorchScript format is supported.

    ```python
    transforms = torch.nn.Sequential(
        torchvision.transforms.CenterCrop(10),
        torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    )
    ```

    Make sure to only use transformations that inherit from `torch.nn.Module`.
    It is advised to use the `torchvision.transforms.v2` module for common transformations.

    Please note that this function is still in the beta phase.

    Args:
        request (HttpRequest): request object
        id (str): global model UUID

    Raises:
        PermissionDenied: Unauthorized to upload preprocessing model for the specified model
        ValidationError: Preprocessing model is not a valid torch model

    Returns:
        HttpResponse: upload success message as json response
    """
    model = get_entity(GlobalModelDB, pk=id)
    if request.user.id != model.owner.id:
        raise PermissionDenied(f"You are not the owner of model {model.id}!")
    model.preprocessing = get_file(request, "model_preprocessing_file")
    verify_model_object(model.preprocessing, "preprocessing")
    model.save()
    return JsonResponse({
        "detail": "Proprocessing Model Upload Accepted",
    }, status=status.HTTP_202_ACCEPTED)

Functions