fl_server_core.models.model

Classes:

GlobalModel: Model class for global models.
LocalModel: Model class for local models.
MeanModel: Model class for mean models, i.e. global models that average the outputs of several global models.
MeanModule: PyTorch module for mean models.
Model: Base class for all types of models.
SWAGModel: Model class for SWAG models.

Functions:

clone_model: Copies a model instance in the database.

Attributes:

TModel: Type variable bound to Model.

Attributes

TModel module-attribute

TModel = TypeVar('TModel', bound=Model)
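Because TModel is a TypeVar bound to Model, a function annotated with it accepts Model or any subclass and is typed as returning that same subclass. The sketch below illustrates this with plain-Python stand-ins for Model and GlobalModel and a hypothetical `duplicate` helper; none of these are the library's own classes.

```python
import copy
from typing import TypeVar


class Model:
    """Hypothetical stand-in for the Django base model."""

    def __init__(self, round: int) -> None:
        self.round = round


class GlobalModel(Model):
    """Hypothetical stand-in subclass."""


# Same declaration shape as the module attribute: only Model or its subclasses are allowed.
TModel = TypeVar("TModel", bound=Model)


def duplicate(model: TModel) -> TModel:
    """Hypothetical helper: a shallow copy typed as the caller's concrete subclass."""
    return copy.copy(model)


original = GlobalModel(round=3)
clone = duplicate(original)  # a type checker infers GlobalModel here, not just Model
print(type(clone).__name__, clone.round, clone is original)  # GlobalModel 3 False
```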

Classes

GlobalModel

Bases: Model


Model class for global models.

Methods:

get_preprocessing_torch_model: Converts the preprocessing to a PyTorch model.
set_preprocessing_torch_model: Sets the preprocessing from a PyTorch model.

Attributes:

description (TextField): Description of the model.
input_shape (ArrayField): Input shape of the model, stored as a nullable array of nullable integers.
name (CharField): Name of the model (at most 256 characters).
preprocessing (BinaryField): Serialized preprocessing module of the model; may be null.

Source code in fl_server_core/models/model.py
class GlobalModel(Model):
    """
    Model class for global models.
    """

    name: CharField = CharField(max_length=256)
    """Name of the model."""
    description: TextField = TextField()
    """Description of the model."""
    # alternative to be postgres independent: create a new model for each nullable integer field
    # and map the corresponding list of integers to the model (but pay attention to the order)
    input_shape: ArrayField = ArrayField(IntegerField(null=True), null=True)
    """Input shape of the model."""
    preprocessing: BinaryField = BinaryField(null=True)
    """Preprocessing of the model."""

    def get_preprocessing_torch_model(self) -> torch.nn.Module:
        """
        Converts the preprocessing to a PyTorch model.

        Returns:
            torch.nn.Module: The PyTorch model.
        """
        return to_torch_module(self.preprocessing)

    def set_preprocessing_torch_model(self, value: torch.nn.Module):
        """
        Sets the preprocessing from a PyTorch model.

        Args:
            value (torch.nn.Module): The PyTorch model.
        """
        self.preprocessing = from_torch_module(value)

Attributes

description class-attribute instance-attribute
description: TextField = TextField()

Description of the model.

input_shape class-attribute instance-attribute
input_shape: ArrayField = ArrayField(IntegerField(null=True), null=True)

Input shape of the model.

name class-attribute instance-attribute
name: CharField = CharField(max_length=256)

Name of the model.

preprocessing class-attribute instance-attribute
preprocessing: BinaryField = BinaryField(null=True)

Preprocessing of the model.

Functions

get_preprocessing_torch_model
get_preprocessing_torch_model() -> Module

Converts the preprocessing to a PyTorch model.

Returns:

Module: The PyTorch model.

set_preprocessing_torch_model
set_preprocessing_torch_model(value: Module)

Sets the preprocessing from a PyTorch model.

Parameters:

value (Module, required): The PyTorch model.
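The two preprocessing methods wrap a nullable BinaryField: the setter serializes a module to bytes, and the getter deserializes those bytes. The helpers to_torch_module and from_torch_module are not shown on this page, so the sketch below substitutes pickle for them. This is an assumption for illustration only and is not how the library serializes PyTorch modules. It also uses a plain class instead of a Django model. While preprocessing is still None there is nothing to deserialize; the sketch returns None in that case, which is a choice made for the sketch.

```python
import pickle
from typing import Any, Optional


def from_object(value: Any) -> bytes:
    """Hypothetical stand-in for from_torch_module: serialize an object to bytes."""
    return pickle.dumps(value)


def to_object(data: bytes) -> Any:
    """Hypothetical stand-in for to_torch_module: deserialize bytes back to an object."""
    return pickle.loads(data)


class GlobalModelSketch:
    """Plain-Python stand-in for GlobalModel's nullable `preprocessing` BinaryField."""

    def __init__(self) -> None:
        self.preprocessing: Optional[bytes] = None  # BinaryField(null=True) starts empty

    def set_preprocessing(self, value: Any) -> None:
        # Only the serialized bytes are stored, as with the real BinaryField.
        self.preprocessing = from_object(value)

    def get_preprocessing(self) -> Any:
        # Sketch-specific choice: report "no preprocessing" instead of deserializing None.
        if self.preprocessing is None:
            return None
        return to_object(self.preprocessing)


model = GlobalModelSketch()
print(model.get_preprocessing())  # None
model.set_preprocessing({"mean": 0.5, "std": 0.25})
print(type(model.preprocessing).__name__, model.get_preprocessing())
```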

LocalModel

Bases: Model


Model class for local models.

Methods:

get_training: Gets the training associated with the base model.

Attributes:

base_model (ForeignKey): Base model of the local model. Deleting the base model also deletes its local models (on_delete=CASCADE).
sample_size (IntegerField): Sample size of the local model.

Source code in fl_server_core/models/model.py
class LocalModel(Model):
    """
    Model class for local models.
    """

    base_model: ForeignKey = ForeignKey(GlobalModel, on_delete=CASCADE)
    """Base model of the local model."""
    sample_size: IntegerField = IntegerField()
    """Sample size of the local model."""

    def get_training(self) -> Optional["models.Training"]:
        """
        Gets the training associated with the base model.

        Returns:
            models.Training: The training associated with the base model.
        """
        return models.Training.objects.filter(model=self.base_model).first()

Attributes

base_model class-attribute instance-attribute
base_model: ForeignKey = ForeignKey(GlobalModel, on_delete=CASCADE)

Base model of the local model.

sample_size class-attribute instance-attribute
sample_size: IntegerField = IntegerField()

Sample size of the local model.

Functions

get_training
get_training() -> Training | None

Gets the training associated with the base model.

Returns:

Training | None: The training associated with the base model, or None if no training references it.

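LocalModel.get_training finds the training through the local model's base_model. Django's `QuerySet.first()` returns the first matching object, or None when nothing matches, and orders by primary key when the queryset has no explicit ordering. The sketch below mimics that lookup over an in-memory list, using list order instead of primary-key order. The GlobalModel, LocalModel and Training classes here are hypothetical stand-ins, not the Django models.

```python
from dataclasses import dataclass, field
from typing import List, Optional
from uuid import UUID, uuid4


@dataclass(eq=False)
class GlobalModel:
    """Hypothetical stand-in for a GlobalModel row (compared by identity)."""
    id: UUID = field(default_factory=uuid4)


@dataclass(eq=False)
class LocalModel:
    """Hypothetical stand-in: a client's update pointing at its base (global) model."""
    base_model: GlobalModel
    sample_size: int


@dataclass
class Training:
    """Hypothetical stand-in for models.Training, which references a global model."""
    name: str
    model: GlobalModel


def get_training(local: LocalModel, trainings: List[Training]) -> Optional[Training]:
    # Mirrors Training.objects.filter(model=local.base_model).first():
    # the first matching training, or None if nothing references the base model.
    return next((t for t in trainings if t.model is local.base_model), None)


g1, g2 = GlobalModel(), GlobalModel()
trainings = [Training("mnist", g1)]
print(get_training(LocalModel(g1, 128), trainings))  # the "mnist" training
print(get_training(LocalModel(g2, 64), trainings))   # None
```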

MeanModel

Bases: GlobalModel


Model class for mean models: a global model whose output is the mean of the outputs of several other global models.

Methods:

get_torch_model: Combines the member models into one PyTorch module that averages their outputs.
set_torch_model: Not supported; always raises NotImplementedError.

Attributes:

models (ManyToManyField): Global models that make up the mean model.

Source code in fl_server_core/models/model.py
class MeanModel(GlobalModel):
    """
    Model class for mean models.
    """

    models: ManyToManyField = ManyToManyField(GlobalModel, related_name="mean_models")
    """Models of the mean model."""

    def get_torch_model(self) -> torch.nn.Module:
        """
        Converts the models to a PyTorch model.

        Returns:
            torch.nn.Module: The PyTorch model.
        """
        torch_models: List[torch.nn.Module] = [model.get_torch_model() for model in self.models.all()]
        model = MeanModule(torch_models)
        if all(is_torchscript_instance(m) for m in torch_models):
            return torch.jit.script(model)
        return model

    def set_torch_model(self, value: torch.nn.Module):
        """
        Sets the models from a PyTorch model.

        Args:
            value (torch.nn.Module): The PyTorch model.
        """
        raise NotImplementedError()

Attributes

models class-attribute instance-attribute
models: ManyToManyField = ManyToManyField(GlobalModel, related_name='mean_models')

Models of the mean model.

Functions

get_torch_model
get_torch_model() -> Module

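Wraps the PyTorch modules of all member models in a MeanModule. If every member is a TorchScript module, the MeanModule is compiled with torch.jit.script; otherwise the eager module is returned.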

Returns:

Module: A module that returns the mean of the member models' outputs.

set_torch_model
set_torch_model(value: Module)

Not implemented for mean models; always raises NotImplementedError.

Parameters:

value (Module, required): The PyTorch model.
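MeanModel.get_torch_model calls torch.jit.script only when every member is already a TorchScript module. A single eager member keeps the whole ensemble eager. One edge case is worth keeping in mind: Python's `all()` returns True for an empty iterable, so a MeanModel without members would still take the scripting branch, and its forward pass would have no outputs to stack (torch.stack rejects an empty list). The stdlib sketch below mirrors the branching. It uses a hypothetical `Compiled` wrapper in place of TorchScript and adds an explicit empty check; both are assumptions of the sketch, not behavior of the library.

```python
from typing import Callable, Sequence, Union

Fn = Callable[[float], float]


class Compiled:
    """Hypothetical stand-in for a TorchScript module: a callable flagged as compiled."""

    def __init__(self, fn: Fn) -> None:
        self.fn = fn

    def __call__(self, x: float) -> float:
        return self.fn(x)


def build_mean_model(members: Sequence[Fn]) -> Union[Compiled, Fn]:
    # all([]) is True, so guard the empty case explicitly (sketch-specific check).
    if not members:
        raise ValueError("a mean model needs at least one member model")

    def mean(x: float) -> float:
        outputs = [member(x) for member in members]
        return sum(outputs) / len(outputs)

    # Mirror get_torch_model: compile the ensemble only if every member is compiled.
    if all(isinstance(member, Compiled) for member in members):
        return Compiled(mean)
    return mean


scripted = build_mean_model([Compiled(lambda x: x), Compiled(lambda x: 3 * x)])
mixed = build_mean_model([Compiled(lambda x: x), lambda x: 3 * x])
print(isinstance(scripted, Compiled), scripted(2.0))  # True 4.0
print(isinstance(mixed, Compiled), mixed(2.0))        # False 4.0
```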

MeanModule

Bases: Module


PyTorch module that averages the outputs of several models (used by mean models).

Methods:

__init__: Initializes the mean module with its member models.
forward: Applies every member model to the input and returns the element-wise mean of their outputs.

Attributes:

models: Models of the mean model.

Source code in fl_server_core/models/model.py
class MeanModule(Module):
    """
    PyTorch module for mean models.
    """

    def __init__(self, models: Sequence[torch.nn.Module]):
        """
        Initializes the mean models.

        Args:
            models (Sequence[torch.nn.Module]): The models of the mean model.
        """
        super().__init__()
        self.models = models
        """Models of the mean model."""

    def forward(self, input: Tensor) -> Tensor:
        """
        Forward pass of the mean models.

        Args:
            input (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        return torch.stack([model(input) for model in self.models], dim=0).mean(dim=0)

Attributes

models instance-attribute
models = models

Models of the mean model.

Functions

__init__
__init__(models: Sequence[Module])

Initializes the mean module with its member models.

Parameters:

models (Sequence[Module], required): The models of the mean model.
forward
forward(input: Tensor) -> Tensor

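Forward pass: applies every member model to input, stacks the outputs along a new dimension 0, and returns their mean over that dimension.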

Parameters:

input (Tensor, required): The input tensor.

Returns:

Tensor: The element-wise mean of the member models' outputs.

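MeanModule.forward applies each member to the same input, stacks the outputs along a new leading dimension (dim=0), and averages over it. The result therefore has the shape of a single member's output. The pure-Python sketch below performs the same element-wise mean on flat lists of floats instead of tensors, to stay dependency-free. Like torch.stack, it requires all outputs to have the same shape.

Separately, note that MeanModule stores `models` as a plain sequence. torch.nn.Module only registers submodules that are assigned directly as attributes or placed in containers such as torch.nn.ModuleList. As a result, the members are not visible to `.parameters()`, `.to()` or `.eval()` called on the MeanModule itself.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def mean_forward(models: Sequence[Callable[[Vector], Vector]], x: Vector) -> Vector:
    """Element-wise mean of the members' outputs, like stack(dim=0).mean(dim=0)."""
    outputs = [model(x) for model in models]  # one output per member, i.e. "dim 0"
    if not outputs:
        raise ValueError("need at least one member output to average")
    if any(len(out) != len(outputs[0]) for out in outputs):
        raise ValueError("all member outputs must have the same shape")
    # zip(*outputs) walks the remaining dimension; averaging each column collapses dim 0.
    return [sum(column) / len(outputs) for column in zip(*outputs)]


def double(v: Vector) -> Vector:
    return [2 * e for e in v]


def shift(v: Vector) -> Vector:
    return [e + 1 for e in v]


print(mean_forward([double, shift], [1.0, 2.0, 3.0]))  # [2.0, 3.5, 5.0]
```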

Model

Bases: PolymorphicModel


Base class for all types of models.

Methods:

get_torch_model: Converts the model weights to a PyTorch model.
get_training: Gets the training associated with the model.
is_global_model: Checks if the model is a global model.
is_local_model: Checks if the model is a local model.
set_torch_model: Sets the model weights from a PyTorch model.

Attributes:

id (UUIDField): Unique identifier for the model (primary key, generated with uuid4).
owner (ForeignKey): User who owns the model.
round (IntegerField): Round number of the model.
weights (BinaryField): Weights of the model.

Source code in fl_server_core/models/model.py
class Model(PolymorphicModel):
    """
    Base class for all types of models.
    """

    id: UUIDField = UUIDField(primary_key=True, editable=False, default=uuid4)
    """Unique identifier for the model."""
    owner: ForeignKey = ForeignKey(User, on_delete=CASCADE)
    """User who owns the model."""
    round: IntegerField = IntegerField()
    """Round number of the model."""
    weights: BinaryField = BinaryField()
    """Weights of the model."""

    def is_global_model(self):
        """
        Checks if the model is a global model.

        Returns:
            bool: True if the model is a global model, False otherwise.
        """
        return isinstance(self, GlobalModel)

    def is_local_model(self):
        """
        Checks if the model is a local model.

        Returns:
            bool: True if the model is a local model, False otherwise.
        """
        return isinstance(self, LocalModel)

    def get_torch_model(self) -> torch.nn.Module:
        """
        Converts the model weights to a PyTorch model.

        Returns:
            torch.nn.Module: The PyTorch model.
        """
        return to_torch_module(self.weights)

    def set_torch_model(self, value: torch.nn.Module):
        """
        Sets the model weights from a PyTorch model.

        Args:
            value (torch.nn.Module): The PyTorch model.
        """
        self.weights = from_torch_module(value)

    def get_training(self) -> Optional["models.Training"]:
        """
        Gets the training associated with the model.

        Returns:
            models.Training: The training associated with the model.
        """
        return models.Training.objects.filter(model=self).first()

Attributes

id class-attribute instance-attribute
id: UUIDField = UUIDField(primary_key=True, editable=False, default=uuid4)

Unique identifier for the model.

owner class-attribute instance-attribute
owner: ForeignKey = ForeignKey(User, on_delete=CASCADE)

User who owns the model.

round class-attribute instance-attribute
round: IntegerField = IntegerField()

Round number of the model.

weights class-attribute instance-attribute
weights: BinaryField = BinaryField()

Weights of the model.

Functions

get_torch_model
get_torch_model() -> Module

Converts the model weights to a PyTorch model.

Returns:

Module: The PyTorch model.

get_training
get_training() -> Training | None

Gets the training associated with the model.

Returns:

Training | None: The training associated with the model, or None if no training references it.

is_global_model
is_global_model()

Checks if the model is a global model.

Returns:

bool: True if the model is a global model, False otherwise.

is_local_model
is_local_model()

Checks if the model is a local model.

Returns:

bool: True if the model is a local model, False otherwise.

set_torch_model
set_torch_model(value: Module)

Sets the model weights from a PyTorch model.

Parameters:

value (Module, required): The PyTorch model.
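is_global_model and is_local_model are plain isinstance checks. Every subclass of GlobalModel, including MeanModel and SWAGModel, therefore counts as a global model. PolymorphicModel appears to be django-polymorphic's base class. If so, querying Model.objects yields instances of the concrete subclasses, which is what makes these checks meaningful on query results. The sketch below reproduces only the class hierarchy, using plain Python stand-ins rather than the Django models.

```python
class Model:
    """Stand-in for the polymorphic base model."""

    # Names are resolved at call time, so the subclasses defined below are visible here.
    def is_global_model(self) -> bool:
        return isinstance(self, GlobalModel)

    def is_local_model(self) -> bool:
        return isinstance(self, LocalModel)


class GlobalModel(Model):
    """Stand-in for GlobalModel."""


class LocalModel(Model):
    """Stand-in for LocalModel."""


class MeanModel(GlobalModel):
    """Stand-in for MeanModel: inherits from GlobalModel, so it is a global model."""


class SWAGModel(GlobalModel):
    """Stand-in for SWAGModel: also a global model."""


for instance in (GlobalModel(), MeanModel(), SWAGModel(), LocalModel()):
    print(type(instance).__name__, instance.is_global_model(), instance.is_local_model())
# GlobalModel True False
# MeanModel True False
# SWAGModel True False
# LocalModel False True
```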

SWAGModel

Bases: GlobalModel


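Model class for SWAG (SWA-Gaussian) models: a global model that additionally stores the serialized first and second moments used by SWAG.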

Attributes:

first_moment (Tensor): First moment of the SWAG model; writable property backed by swag_first_moment.
second_moment (Tensor): Second moment of the SWAG model; writable property backed by swag_second_moment.
swag_first_moment (BinaryField): Serialized first moment of the SWAG model.
swag_second_moment (BinaryField): Serialized second moment of the SWAG model.

Source code in fl_server_core/models/model.py
class SWAGModel(GlobalModel):
    """
    Model class for SWAG models.
    """

    swag_first_moment: BinaryField = BinaryField()
    """First moment of the SWAG model."""
    swag_second_moment: BinaryField = BinaryField()
    """Second moment of the SWAG model."""

    @property
    def first_moment(self) -> Tensor:
        """
        Gets the first moment of the SWAG model.

        Returns:
            Tensor: The first moment of the SWAG model.
        """
        return to_torch_tensor(self.swag_first_moment)

    @first_moment.setter
    def first_moment(self, value: Tensor):
        """
        Sets the first moment of the SWAG model.

        Args:
            value (Tensor): The first moment of the SWAG model.
        """
        self.swag_first_moment = from_torch_tensor(value)

    @property
    def second_moment(self) -> Tensor:
        """
        Gets the second moment of the SWAG model.

        Returns:
            Tensor: The second moment of the SWAG model.
        """
        return to_torch_tensor(self.swag_second_moment)

    @second_moment.setter
    def second_moment(self, value: Tensor):
        """
        Sets the second moment of the SWAG model.

        Args:
            value (Tensor): The second moment of the SWAG model.
        """
        self.swag_second_moment = from_torch_tensor(value)

Attributes

first_moment property writable
first_moment: Tensor

Gets the first moment of the SWAG model.

Returns:

Tensor: The first moment of the SWAG model.

second_moment property writable
second_moment: Tensor

Gets the second moment of the SWAG model.

Returns:

Tensor: The second moment of the SWAG model.

swag_first_moment class-attribute instance-attribute
swag_first_moment: BinaryField = BinaryField()

First moment of the SWAG model.

swag_second_moment class-attribute instance-attribute
swag_second_moment: BinaryField = BinaryField()

Second moment of the SWAG model.
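SWAGModel only stores the two moments; this page does not show how they are computed. In SWAG as commonly described:

- The first moment is the running mean of the weights collected along the training trajectory.
- The second moment is the running mean of their squares.
- A diagonal Gaussian over the weights uses the variance second_moment - first_moment**2.

The sketch below shows that bookkeeping on flat lists of floats instead of tensors. Whether fl_server_core updates the moments in exactly this way is an assumption. The clamp at zero absorbs small negative values caused by floating-point round-off.

```python
from typing import List, Tuple

Vector = List[float]


def update_moments(first: Vector, second: Vector, weights: Vector, n: int) -> Tuple[Vector, Vector]:
    """Fold the (n+1)-th weight snapshot into running means of w and w**2."""
    new_first = [(n * m + w) / (n + 1) for m, w in zip(first, weights)]
    new_second = [(n * s + w * w) / (n + 1) for s, w in zip(second, weights)]
    return new_first, new_second


def diagonal_variance(first: Vector, second: Vector) -> Vector:
    # Var[w] = E[w^2] - E[w]^2, clamped at 0 to absorb round-off.
    return [max(s - m * m, 0.0) for m, s in zip(first, second)]


first, second = [0.0, 0.0], [0.0, 0.0]
snapshots = [[1.0, 2.0], [3.0, 2.0]]
for n, snapshot in enumerate(snapshots):
    first, second = update_moments(first, second, snapshot, n)

print(first, second, diagonal_variance(first, second))  # [2.0, 2.0] [5.0, 4.0] [1.0, 0.0]
```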

Functions

clone_model

clone_model(model: Model) -> Model

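Copies a model instance in the database, following https://docs.djangoproject.com/en/5.0/topics/db/queries/#copying-model-instances and stackoverflow.com/questions/4733609/how-do-i-clone-a-django-model-instance-object-and-save-it-to-the-database

The given model is saved first. A shallow copy with its primary key cleared is then saved as a new row and returned. Many-to-many relations, such as MeanModel.models, are not copied and must be set on the new instance separately.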

Parameters:

model (Model, required): The model to be copied.

Returns:

Model: New Model instance that is a copy of the old one.

Source code in fl_server_core/models/model.py
def clone_model(model: Model) -> Model:
    """
    Copies a model instance in the database
    See https://docs.djangoproject.com/en/5.0/topics/db/queries/#copying-model-instances
    and stackoverflow.com/questions/4733609/how-do-i-clone-a-django-model-instance-object-and-save-it-to-the-database

    Args:
        model (Model):  the model to be copied

    Returns:
        Model: New Model instance that is a copy of the old one
    """
    model.save()
    new_model = copy(model)
    new_model.pk = None
    new_model.id = None
    new_model._state.adding = True
    try:
        delattr(new_model, '_prefetched_objects_cache')
    except AttributeError:
        pass
    new_model.save()
    new_model.owner = model.owner
    new_model.save()
    return new_model
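clone_model follows the pattern from the Django documentation linked above:

1. Make a shallow copy.
2. Clear the primary key.
3. Mark the instance as not yet saved (`_state.adding = True`).
4. Save it, which inserts a new row.

Because `id` has `default=uuid4`, Django can generate a new key when that row is inserted. The sketch below imitates the flow with an in-memory dictionary standing in for the database table; the `Row` and `Table` classes and the `clone_row` helper are hypothetical. It also shows why relations stored outside the row itself, such as MeanModel.models, are not carried over by this kind of copy.

```python
import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from uuid import UUID, uuid4


@dataclass
class Row:
    """Hypothetical model instance holding only its own column values."""
    owner: str
    round: int
    id: Optional[UUID] = None


@dataclass
class Table:
    """Hypothetical table keyed by primary key, plus a separate many-to-many link table."""
    rows: Dict[UUID, Row] = field(default_factory=dict)
    members: Dict[UUID, List[str]] = field(default_factory=dict)

    def save(self, row: Row) -> None:
        if row.id is None:
            row.id = uuid4()  # like UUIDField(default=uuid4) filling an empty primary key
        self.rows[row.id] = row


def clone_row(table: Table, row: Row) -> Row:
    table.save(row)           # make sure the original exists in the table
    new_row = copy.copy(row)  # shallow copy of the column values
    new_row.id = None         # clearing the key turns the next save into an insert
    table.save(new_row)
    return new_row


table = Table()
original = Row(owner="alice", round=1)
table.save(original)
table.members[original.id] = ["model-a", "model-b"]  # relation stored outside the row

clone = clone_row(table, original)
print(len(table.rows), clone.id != original.id, clone.owner)  # 2 True alice
print(table.members.get(clone.id))  # None: link-table entries are not copied
```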