
fl_server_ai.trainer.model_trainer

Classes:

FedDCModelTrainer: Federated daisy-chaining (FedDC) model trainer.
ModelTrainer: Common model trainer.
SWAGModelTrainer: Stochastic weight averaging Gaussian (SWAG) model trainer.

Functions:

get_trainer: Get a model trainer instance for the given training object.
get_trainer_class: Get the appropriate model trainer class based on a training object.

Classes

FedDCModelTrainer

Bases: ModelTrainer



Federated daisy-chaining (FedDC) model trainer.

Client datasets are potentially quite small, so locally trained models tend to overfit and therefore predict poorly on unseen data. One proposed solution is FedDC (federated daisy-chaining). Between aggregation steps, FedDC randomly redistributes the client models among the clients, and each client continues training the model it receives on its own local data. From the model's perspective, this is as if it had been trained on a larger dataset.

Paper: Picking Daisies in Private: Federated Learning from Small Datasets

Methods:

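handle: Handle a model trainer event. A TrainingRoundFinished event is replaced by a DaisyChainRoundFinished event before it is handled.
start_round: Start a training round. On daisy-chaining rounds, send the shuffled local models of the previous round to the participants.
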
Source code in fl_server_ai/trainer/model_trainer.py
class FedDCModelTrainer(ModelTrainer):
    """
    Federated daisy-chaining (FedDC) model trainer.

    Client datasets are potentially quite small, so locally trained models tend to overfit and therefore
    predict poorly on unseen data. One proposed solution is FedDC (federated daisy-chaining).
    Between aggregation steps, FedDC randomly redistributes the client models among the clients, and each
    client continues training the model it receives on its own local data.
    From the model's perspective, this is as if it had been trained on a larger dataset.

    Paper: [Picking Daisies in Private: Federated Learning from Small Datasets](https://openreview.net/forum?id=GVDwiINkMR)
    """  # noqa: E501

    def start_round(self):
        dc_period = self.training.options.get("daisy_chain_period", 0)
        if dc_period < 1 or self.training.model.round % dc_period == 0:
            # regular round (daisy chaining disabled or round is a multiple of the period): no local models to permute
            TrainingRoundStartNotification.from_training(self.training).send()
            return

        # daisy-chaining round: send the permuted local models of the previous round back for further training
        clients = self.training.participants.all()
        model_ids = list(LocalModel.objects.filter(
            base_model=self.training.model,
            round=self.training.model.round - 1
        ).values_list("pk", flat=True))
        shuffle(model_ids)
        for client, model_id in zip(clients, model_ids):
            TrainingRoundStartNotification(
                receivers=[client],
                body=TrainingRoundStartNotification.Body(
                    round=self.training.model.round,
                    global_model_uuid=model_id
                ),
                training_uuid=self.training.id
            ).send()

    def handle(self, event: ModelTrainerEvent):
        if type(event) is TrainingRoundFinished:
            real_event = DaisyChainRoundFinished(self)
            real_event.handle()
            real_event.next()
        else:
            event.handle()
            event.next()

Functions

handle
handle(event: ModelTrainerEvent)
Source code in fl_server_ai/trainer/model_trainer.py
def handle(self, event: ModelTrainerEvent):
    if type(event) is TrainingRoundFinished:
        real_event = DaisyChainRoundFinished(self)
        real_event.handle()
        real_event.next()
    else:
        event.handle()
        event.next()

start_round
start_round()
Source code in fl_server_ai/trainer/model_trainer.py
def start_round(self):
    dc_period = self.training.options.get("daisy_chain_period", 0)
    if dc_period < 1 or self.training.model.round % dc_period == 0:
        # regular round (daisy chaining disabled or round is a multiple of the period): no local models to permute
        TrainingRoundStartNotification.from_training(self.training).send()
        return

    # daisy-chaining round: send the permuted local models of the previous round back for further training
    clients = self.training.participants.all()
    model_ids = list(LocalModel.objects.filter(
        base_model=self.training.model,
        round=self.training.model.round - 1
    ).values_list("pk", flat=True))
    shuffle(model_ids)
    for client, model_id in zip(clients, model_ids):
        TrainingRoundStartNotification(
            receivers=[client],
            body=TrainingRoundStartNotification.Body(
                round=self.training.model.round,
                global_model_uuid=model_id
            ),
            training_uuid=self.training.id
        ).send()
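The scheduling rule in start_round can be isolated in plain Python. This is a minimal sketch, not part of fl_server_ai. The helper names is_daisy_chain_round and assign_local_models are hypothetical. Clients and local model IDs are plain strings rather than Django objects, and no notifications are sent.

```python
import random
from typing import Dict, List, Optional, Sequence


def is_daisy_chain_round(round_number: int, dc_period: int) -> bool:
    """Hypothetical mirror of the branch at the top of FedDCModelTrainer.start_round.

    A round is a regular round (global model sent to everyone) when daisy chaining
    is disabled (period < 1) or the round number is a multiple of the period.
    Every other round redistributes the previous round's local models.
    """
    return not (dc_period < 1 or round_number % dc_period == 0)


def assign_local_models(
    clients: Sequence[str],
    local_model_ids: Sequence[str],
    rng: Optional[random.Random] = None,
) -> Dict[str, str]:
    """Hypothetical helper: shuffle local model IDs and pair them with clients.

    Like the zip() in start_round, surplus clients or models are silently dropped.
    A plain shuffle may also hand a model back to the client that trained it.
    """
    rng = rng or random.Random()
    model_ids: List[str] = list(local_model_ids)
    rng.shuffle(model_ids)
    return dict(zip(clients, model_ids))
```

For example, with a period of 3, rounds 1, 2, 4 and 5 out of rounds 0 to 6 are daisy-chaining rounds. With a period of 1, every round number is a multiple of the period, so the daisy-chaining branch never runs.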

ModelTrainer

Common model trainer.

Methods:

__init__: Initialize the trainer with the given training and options.
__new__: Ensure that the correct trainer class is returned based on the training.
finish: Finish the training and send a finished notification.
get_trainer_class: Get the appropriate model trainer class based on a training object.
handle: Handle a model trainer event and proceed to the next event.
handle_cls: Handle a model trainer event class by creating an instance of the event and handling it.
start: Start the training and send a start notification.
start_round: Start a training round and send a round start notification.
test_round: Test a training round and send a model test notification.

Attributes:

options (TrainerOptions): The options for the trainer.
training (Training): The training to be handled by the trainer.

Source code in fl_server_ai/trainer/model_trainer.py
class ModelTrainer:
    """
    Common model trainer.
    """

    def __new__(cls, training: Training, options: Optional[TrainerOptions] = None) -> "ModelTrainer":
        """
        Ensure that the correct trainer class is returned based on the training.

        The returned trainer class is determined by the `get_trainer_class` method.
        It could for example return a `SWAGModelTrainer` if the training uses the SWAG uncertainty method.

        Args:
            training (Training): The training for which to get the trainer.
            options (Optional[TrainerOptions]): The options for the trainer. If None, default options will be used.

        Returns:
            "ModelTrainer": The trainer instance.
        """
        return super().__new__(cls.get_trainer_class(training))

    @classmethod
    def get_trainer_class(cls, training: Training) -> Type["ModelTrainer"]:
        """
        Get the appropriate model trainer class based on a training object.

        Args:
            training (Training): The training for which to get the trainer class.

        Returns:
            Type["ModelTrainer"]: The appropriate trainer class.
        """
        return get_trainer_class(training)

    def __init__(self, training: Training, options: Optional[TrainerOptions] = None):
        """
        Initialize the trainer with the given training and options.

        Args:
            training (Training): The training to be handled by the trainer.
            options (Optional[TrainerOptions]): The options for the trainer. If None, default options will be used.
        """
        super().__init__()
        self.training = training
        """The training to be handled by the trainer."""
        self.options = options if options else TrainerOptions()
        """The options for the trainer."""

    def start(self):
        """
        Start the training and send a start notification.
        """
        self.training.state = TrainingState.ONGOING
        self.training.save()
        TrainingStartNotification.from_training(self.training).send()
        TrainingRoundStartNotification.from_training(self.training).send()

    def finish(self):
        """
        Finish the training and send a finished notification.
        """
        self.training.state = TrainingState.COMPLETED
        self.training.save()
        TrainingFinishedNotification.from_training(self.training).send()

    def start_round(self):
        """
        Start a training round and send a round start notification.
        """
        TrainingRoundStartNotification.from_training(self.training).send()

    def test_round(self):
        """
        Test a training round and send a model test notification.
        """
        TrainingModelTestNotification.from_training(self.training).send()

    def handle(self, event: ModelTrainerEvent):
        """
        Handle a model trainer event and proceed to the next event.

        Args:
            event (ModelTrainerEvent): The event to handle.
        """
        event.handle()
        event.next()

    def handle_cls(self, event_cls: Type[ModelTrainerEvent]):
        """
        Handle a model trainer event class by creating an instance of the event and handling it.

        Args:
            event_cls (Type[ModelTrainerEvent]): The class of the event to handle.
        """
        self.handle(event_cls(self))

Attributes

options instance-attribute
options = options if options else TrainerOptions()

The options for the trainer.

training instance-attribute
training = training

The training to be handled by the trainer.

Functions

__init__
__init__(training: Training, options: TrainerOptions | None = None)

Initialize the trainer with the given training and options.

Parameters:

training (Training, required): The training to be handled by the trainer.
options (TrainerOptions | None, default None): The options for the trainer. If None, default options will be used.

Source code in fl_server_ai/trainer/model_trainer.py
def __init__(self, training: Training, options: Optional[TrainerOptions] = None):
    """
    Initialize the trainer with the given training and options.

    Args:
        training (Training): The training to be handled by the trainer.
        options (Optional[TrainerOptions]): The options for the trainer. If None, default options will be used.
    """
    super().__init__()
    self.training = training
    """The training to be handled by the trainer."""
    self.options = options if options else TrainerOptions()
    """The options for the trainer."""

__new__
__new__(training: Training, options: TrainerOptions | None = None) -> ModelTrainer

Ensure that the correct trainer class is returned based on the training.

The returned trainer class is determined by the get_trainer_class method. It could for example return a SWAGModelTrainer if the training uses the SWAG uncertainty method.

Parameters:

training (Training, required): The training for which to get the trainer.
options (TrainerOptions | None, default None): The options for the trainer. If None, default options will be used.

Returns:

ModelTrainer: The trainer instance.

Source code in fl_server_ai/trainer/model_trainer.py
def __new__(cls, training: Training, options: Optional[TrainerOptions] = None) -> "ModelTrainer":
    """
    Ensure that the correct trainer class is returned based on the training.

    The returned trainer class is determined by the `get_trainer_class` method.
    It could for example return a `SWAGModelTrainer` if the training uses the SWAG uncertainty method.

    Args:
        training (Training): The training for which to get the trainer.
        options (Optional[TrainerOptions]): The options for the trainer. If None, default options will be used.

    Returns:
        "ModelTrainer": The trainer instance.
    """
    return super().__new__(cls.get_trainer_class(training))
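ModelTrainer.__new__ picks the concrete class before __init__ runs. The sketch below reproduces that dispatch pattern in isolation. FakeTraining, Trainer, SwagTrainer and choose_class are hypothetical stand-ins, not fl_server_ai names, and the selection rule is reduced to a single flag.

```python
from dataclasses import dataclass
from typing import Type


@dataclass
class FakeTraining:
    # Hypothetical stand-in for the Training model: one flag drives the choice.
    uses_swag: bool = False


class Trainer:
    def __new__(cls, training: FakeTraining) -> "Trainer":
        # Allocate an instance of whatever class choose_class selects,
        # following the same pattern as ModelTrainer.__new__.
        return super().__new__(cls.choose_class(training))

    @classmethod
    def choose_class(cls, training: FakeTraining) -> Type["Trainer"]:
        return SwagTrainer if training.uses_swag else Trainer

    def __init__(self, training: FakeTraining) -> None:
        self.training = training


class SwagTrainer(Trainer):
    pass
```

Python calls __init__ only if the object returned by __new__ is an instance of the class that was called. So Trainer(FakeTraining(uses_swag=True)) yields an initialized SwagTrainer. SwagTrainer(FakeTraining()), however, returns a plain Trainer whose __init__ never ran. ModelTrainer.__new__ follows the same pattern, so instantiating a subclass such as FedDCModelTrainer directly for a training that resolves to a different class would likewise return an uninitialized object. Calling ModelTrainer(...) or get_trainer(...) avoids that case.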

finish
finish()

Finish the training and send a finished notification.

Source code in fl_server_ai/trainer/model_trainer.py
def finish(self):
    """
    Finish the training and send a finished notification.
    """
    self.training.state = TrainingState.COMPLETED
    self.training.save()
    TrainingFinishedNotification.from_training(self.training).send()

get_trainer_class classmethod
get_trainer_class(training: Training) -> Type[ModelTrainer]

Get the appropriate model trainer class based on a training object.

Parameters:

training (Training, required): The training for which to get the trainer class.

Returns:

Type[ModelTrainer]: The appropriate trainer class.

Source code in fl_server_ai/trainer/model_trainer.py
@classmethod
def get_trainer_class(cls, training: Training) -> Type["ModelTrainer"]:
    """
    Get the appropriate model trainer class based on a training object.

    Args:
        training (Training): The training for which to get the trainer class.

    Returns:
        Type["ModelTrainer"]: The appropriate trainer class.
    """
    return get_trainer_class(training)

handle
handle(event: ModelTrainerEvent)

Handle a model trainer event and proceed to the next event.

Parameters:

event (ModelTrainerEvent, required): The event to handle.

Source code in fl_server_ai/trainer/model_trainer.py
def handle(self, event: ModelTrainerEvent):
    """
    Handle a model trainer event and proceed to the next event.

    Args:
        event (ModelTrainerEvent): The event to handle.
    """
    event.handle()
    event.next()

handle_cls
handle_cls(event_cls: Type[ModelTrainerEvent])

Handle a model trainer event class by creating an instance of the event and handling it.

Parameters:

event_cls (Type[ModelTrainerEvent], required): The class of the event to handle.

Source code in fl_server_ai/trainer/model_trainer.py
def handle_cls(self, event_cls: Type[ModelTrainerEvent]):
    """
    Handle a model trainer event class by creating an instance of the event and handling it.

    Args:
        event_cls (Type[ModelTrainerEvent]): The class of the event to handle.
    """
    self.handle(event_cls(self))
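handle and handle_cls rely on a small contract: an event is constructed with the trainer, and the trainer then calls event.handle() followed by event.next(). The sketch below shows that flow with hypothetical stand-ins (RecordingTrainer, Event, RoundFinished). The event interface here is inferred only from how ModelTrainer calls it; the real ModelTrainerEvent classes may do considerably more.

```python
from typing import List, Type


class RecordingTrainer:
    """Hypothetical stand-in for ModelTrainer that logs event activity."""

    def __init__(self) -> None:
        self.log: List[str] = []

    def handle(self, event: "Event") -> None:
        # Same two-step contract as ModelTrainer.handle.
        event.handle()
        event.next()

    def handle_cls(self, event_cls: Type["Event"]) -> None:
        # Same as ModelTrainer.handle_cls: build the event with the trainer, then handle it.
        self.handle(event_cls(self))


class Event:
    """Hypothetical event exposing the handle()/next() interface assumed above."""

    def __init__(self, trainer: RecordingTrainer) -> None:
        self.trainer = trainer

    def handle(self) -> None:
        self.trainer.log.append(f"{type(self).__name__}.handle")

    def next(self) -> None:
        self.trainer.log.append(f"{type(self).__name__}.next")


class RoundFinished(Event):
    pass
```

For example, RecordingTrainer().handle_cls(RoundFinished) records "RoundFinished.handle" and then "RoundFinished.next". Subclasses such as SWAGModelTrainer and FedDCModelTrainer override handle to change this two-step sequence for specific event types.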

start
start()

Start the training and send a start notification.

Source code in fl_server_ai/trainer/model_trainer.py
def start(self):
    """
    Start the training and send a start notification.
    """
    self.training.state = TrainingState.ONGOING
    self.training.save()
    TrainingStartNotification.from_training(self.training).send()
    TrainingRoundStartNotification.from_training(self.training).send()

start_round
start_round()

Start a training round and send a round start notification.

Source code in fl_server_ai/trainer/model_trainer.py
def start_round(self):
    """
    Start a training round and send a round start notification.
    """
    TrainingRoundStartNotification.from_training(self.training).send()

test_round
test_round()

Test a training round and send a model test notification.

Source code in fl_server_ai/trainer/model_trainer.py
def test_round(self):
    """
    Test a training round and send a model test notification.
    """
    TrainingModelTestNotification.from_training(self.training).send()

SWAGModelTrainer

Bases: ModelTrainer



Stochastic weight averaging Gaussian (SWAG) model trainer.

Methods:

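handle: Handle a model trainer event. After a TrainingRoundFinished event, start a SWAG round instead of proceeding to the next event.
start_swag_round: Start a SWAG round and send a SWAG round start notification.
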
Source code in fl_server_ai/trainer/model_trainer.py
class SWAGModelTrainer(ModelTrainer):
    """
    Stochastic weight averaging Gaussian (SWAG) model trainer.
    """

    def start_swag_round(self):
        """
        Start a SWAG round and send a SWAG round start notification.
        """
        self.training.state = TrainingState.SWAG_ROUND
        self.training.save()
        TrainingSWAGRoundStartNotification.from_training(self.training).send()

    def handle(self, event: ModelTrainerEvent):
        event.handle()
        if type(event) is TrainingRoundFinished:
            self.start_swag_round()
        else:
            event.next()

Functions

handle
handle(event: ModelTrainerEvent)
Source code in fl_server_ai/trainer/model_trainer.py
def handle(self, event: ModelTrainerEvent):
    event.handle()
    if type(event) is TrainingRoundFinished:
        self.start_swag_round()
    else:
        event.next()
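SWAGModelTrainer.handle differs from the base implementation in one branch. After a finished training round, it does not call event.next() and starts a SWAG round instead. The sketch below isolates that branch. TrainingRoundFinished here is a local dummy rather than the fl_server_ai event, and swag_handle is a hypothetical helper that only logs; it does not change the training state or send notifications. The sketch also shows that `type(event) is TrainingRoundFinished` is an exact type check, so an event subclass takes the else branch. FedDCModelTrainer.handle uses the same check.

```python
from typing import List


class TrainingRoundFinished:
    """Hypothetical stand-in for the real event class of the same name."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def handle(self) -> None:
        self.log.append("handle")

    def next(self) -> None:
        self.log.append("next")


class OtherEvent(TrainingRoundFinished):
    """A subclass: the exact type check below does not treat it as round-finished."""


def swag_handle(event: TrainingRoundFinished, log: List[str]) -> None:
    # Mirrors SWAGModelTrainer.handle: always handle the event, then either
    # start the SWAG round or continue with the event's own next step.
    event.handle()
    if type(event) is TrainingRoundFinished:
        log.append("start_swag_round")
    else:
        event.next()
```

A finished training round therefore logs "handle" then "start_swag_round", while OtherEvent logs "handle" then "next".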

start_swag_round
start_swag_round()

Start a SWAG round and send a SWAG round start notification.

Source code in fl_server_ai/trainer/model_trainer.py
def start_swag_round(self):
    """
    Start a SWAG round and send a SWAG round start notification.
    """
    self.training.state = TrainingState.SWAG_ROUND
    self.training.save()
    TrainingSWAGRoundStartNotification.from_training(self.training).send()

Functions

get_trainer

get_trainer(training: Training, options: TrainerOptions | None = None) -> ModelTrainer

Get a model trainer instance for the given training object.

Parameters:

training (Training, required): The training for which to get the trainer.
options (TrainerOptions | None, default None): The options for the trainer. If None, default options will be used.

Returns:

ModelTrainer: The trainer instance.

Source code in fl_server_ai/trainer/model_trainer.py
def get_trainer(training: Training, options: Optional[TrainerOptions] = None) -> "ModelTrainer":
    """
    Get a model trainer instance for the given training object.

    Args:
        training (Training): The training for which to get the trainer.
        options (Optional[TrainerOptions]): The options for the trainer. Defaults to None.

    Returns:
        "ModelTrainer": The trainer instance.
    """
    return get_trainer_class(training)(training, options)

get_trainer_class

get_trainer_class(training: Training) -> Type[ModelTrainer]

Get the appropriate model trainer class based on a training object.

Parameters:

training (Training, required): The training for which to get the trainer class.

Returns:

Type[ModelTrainer]: The appropriate trainer class.

Source code in fl_server_ai/trainer/model_trainer.py
def get_trainer_class(training: Training) -> Type["ModelTrainer"]:
    """
    Get the appropriate model trainer class based on a training object.

    Args:
        training (Training): The training for which to get the trainer class.

    Returns:
        Type["ModelTrainer"]: The appropriate trainer class.
    """
    if training.uncertainty_method == UncertaintyMethod.SWAG:
        return SWAGModelTrainer
    if training.options.get("daisy_chain_period", 0) > 0:
        return FedDCModelTrainer
    return ModelTrainer
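The order of the checks in get_trainer_class determines precedence. The sketch below reproduces that order with hypothetical stand-ins. UncertaintyMethod is a two-member local enum, FakeTraining replaces the Training model, and select_trainer_name returns class names as strings instead of classes.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict


class UncertaintyMethod(Enum):
    # Hypothetical two-member stand-in for the real enum.
    NONE = "NONE"
    SWAG = "SWAG"


@dataclass
class FakeTraining:
    # Hypothetical stand-in exposing only the two fields get_trainer_class reads.
    uncertainty_method: UncertaintyMethod = UncertaintyMethod.NONE
    options: Dict[str, Any] = field(default_factory=dict)


def select_trainer_name(training: FakeTraining) -> str:
    # Same order of checks as get_trainer_class: SWAG first, then daisy chaining.
    if training.uncertainty_method == UncertaintyMethod.SWAG:
        return "SWAGModelTrainer"
    if training.options.get("daisy_chain_period", 0) > 0:
        return "FedDCModelTrainer"
    return "ModelTrainer"
```

Because the SWAG check comes first, a training that uses SWAG gets a SWAGModelTrainer even if it also sets a positive daisy_chain_period. In that case the daisy-chaining option has no effect on the trainer choice.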