Skip to content

fl_server_ai.trainer.events

Modules:

Name Description
base
daisy_chain_round_finished
model_test_finished
swag_round_finished
training_round_finished

Classes:

Name Description
DaisyChainRoundFinished

Federated daisy chain (FedDC) round finished event.

ModelTestFinished

Model test finished event.

ModelTrainerEvent

Abstract base class for a model trainer event.

SWAGRoundFinished

Stochastic weight averaging Gaussian (SWAG) round finished event.

TrainingRoundFinished

Training round finished event.

Attributes

__all__ module-attribute

__all__ = ['ModelTrainerEvent', 'DaisyChainRoundFinished', 'ModelTestFinished', 'SWAGRoundFinished', 'TrainingRoundFinished']

Classes

DaisyChainRoundFinished

Bases: TrainingRoundFinished


              flowchart TD
              fl_server_ai.trainer.events.DaisyChainRoundFinished[DaisyChainRoundFinished]
              fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished[TrainingRoundFinished]
              fl_server_ai.trainer.events.base.ModelTrainerEvent[ModelTrainerEvent]

                              fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished --> fl_server_ai.trainer.events.DaisyChainRoundFinished
                                fl_server_ai.trainer.events.base.ModelTrainerEvent --> fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished
                



              click fl_server_ai.trainer.events.DaisyChainRoundFinished href "" "fl_server_ai.trainer.events.DaisyChainRoundFinished"
              click fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished href "" "fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished"
              click fl_server_ai.trainer.events.base.ModelTrainerEvent href "" "fl_server_ai.trainer.events.base.ModelTrainerEvent"
            

Federated daisy chain (FedDC) round finished event.

Methods:

Name Description
__init__
handle

Handle the FedDC event.

Source code in fl_server_ai/trainer/events/daisy_chain_round_finished.py
class DaisyChainRoundFinished(TrainingRoundFinished):
    """
    Event fired when a federated daisy chain (FedDC) round has finished.
    """

    def __init__(self, trainer: "model_trainer.ModelTrainer"):
        """
        Create the event and switch off the model test after each round.

        Args:
            trainer (model_trainer.ModelTrainer): The trainer that the event is associated with.
        """
        super().__init__(trainer)
        # FedDC events always disable the per-round model test on the trainer options
        self.trainer.options.model_test_after_each_round = False

    def handle(self):
        """
        Process the finished FedDC round.

        Two situations are distinguished:

        (a) the upcoming round reaches the target number of updates
            -> delegate to the parent handler, which aggregates and increments the round

        (b) the target is not reached yet (a daisy chaining period was hit)
            -> no aggregation; only the round counter is incremented and saved,
               the client models are sent back for further training elsewhere
            - also see `FedDCModelTrainer.handle()`
        """
        # `model.round` has not been incremented for this round yet, hence the `+ 1`
        upcoming_round = self.training.model.round + 1

        if upcoming_round >= self.training.target_num_updates:
            # case (a): the parent handler aggregates and also increments the round
            super().handle()
        else:
            # case (b): skip the aggregation, just advance and persist the round counter
            self.training.model.round += 1
            self.training.model.save()

Functions

__init__
__init__(trainer: ModelTrainer)
Source code in fl_server_ai/trainer/events/daisy_chain_round_finished.py
def __init__(self, trainer: "model_trainer.ModelTrainer"):
    """
    Create the FedDC round finished event for the given trainer.

    After the regular parent initialization, the `model_test_after_each_round`
    option of the trainer is switched off.
    """
    super().__init__(trainer)
    trainer_options = self.trainer.options
    trainer_options.model_test_after_each_round = False
handle
handle()

Handle the FedDC event.

(a) local client training is finished -> do the aggregation

(b) local client training is not finished yet (but we reach a daisy chaining period) -> no aggregation, just send the permuted client models back for further training (also see FedDCModelTrainer.handle())

Source code in fl_server_ai/trainer/events/daisy_chain_round_finished.py
def handle(self):
    """
    Process the finished FedDC round.

    Two situations are distinguished:

    (a) the upcoming round reaches the target number of updates
        -> delegate to the parent handler, which aggregates and increments the round

    (b) the target is not reached yet (a daisy chaining period was hit)
        -> no aggregation; only the round counter is incremented and saved,
           the client models are sent back for further training elsewhere
        - also see `FedDCModelTrainer.handle()`
    """
    # `model.round` has not been incremented for this round yet, hence the `+ 1`
    upcoming_round = self.training.model.round + 1

    if upcoming_round >= self.training.target_num_updates:
        # case (a): the parent handler aggregates and also increments the round
        super().handle()
    else:
        # case (b): skip the aggregation, just advance and persist the round counter
        self.training.model.round += 1
        self.training.model.save()

ModelTestFinished

Bases: ModelTrainerEvent


              flowchart TD
              fl_server_ai.trainer.events.ModelTestFinished[ModelTestFinished]
              fl_server_ai.trainer.events.base.ModelTrainerEvent[ModelTrainerEvent]

                              fl_server_ai.trainer.events.base.ModelTrainerEvent --> fl_server_ai.trainer.events.ModelTestFinished
                


              click fl_server_ai.trainer.events.ModelTestFinished href "" "fl_server_ai.trainer.events.ModelTestFinished"
              click fl_server_ai.trainer.events.base.ModelTrainerEvent href "" "fl_server_ai.trainer.events.base.ModelTrainerEvent"
            

Model test finished event.

Methods:

Name Description
handle
next
Source code in fl_server_ai/trainer/events/model_test_finished.py
class ModelTestFinished(ModelTrainerEvent):
    """
    Event fired when a model test has finished.
    """

    def next(self):
        """
        Start another training round or finish the training.

        A new round is started as long as the round counter of the model is
        below the target number of updates; otherwise the trainer is finished.
        """
        if self.training.model.round < self.training.target_num_updates:
            self.trainer.start_round()
            return
        self.trainer.finish()

    def handle(self):
        """
        Handle the model test finished event.

        Intentionally a no-op for now. This would be the place to aggregate
        all common metrics.
        """
        return None

Functions

handle
handle()
Source code in fl_server_ai/trainer/events/model_test_finished.py
def handle(self):
    """
    Handle the model test finished event.

    Intentionally a no-op for now. This would be the place to aggregate
    all common metrics.
    """
    return None
next
next()
Source code in fl_server_ai/trainer/events/model_test_finished.py
def next(self):
    """
    Start another training round or finish the training.

    A new round is started as long as the round counter of the model is
    below the target number of updates; otherwise the trainer is finished.
    """
    rounds_remaining = self.training.model.round < self.training.target_num_updates
    if rounds_remaining:
        self.trainer.start_round()
    else:
        self.trainer.finish()

ModelTrainerEvent

Bases: ABC


              flowchart TD
              fl_server_ai.trainer.events.ModelTrainerEvent[ModelTrainerEvent]

              

              click fl_server_ai.trainer.events.ModelTrainerEvent href "" "fl_server_ai.trainer.events.ModelTrainerEvent"
            

Abstract base class for a model trainer event.

Methods:

Name Description
__init__

Initialize the event with the given trainer.

handle

Handle the event.

next

Proceed with the next event.

Attributes:

Name Type Description
trainer

The trainer that the event is associated with.

training

The training that the event is associated with.

Source code in fl_server_ai/trainer/events/base.py
class ModelTrainerEvent(ABC):
    """
    Common abstract base of all model trainer events.

    Concrete events have to implement `handle` and `next`.
    """

    # logger shared by all events
    _logger = getLogger("fl.server")

    def __init__(self, trainer: "model_trainer.ModelTrainer"):
        """
        Bind the event to a trainer and to the training of that trainer.

        Args:
            trainer (model_trainer.ModelTrainer): The trainer that the event is associated with.
        """
        super().__init__()
        self.trainer = trainer
        """The trainer that the event is associated with."""
        self.training = trainer.training
        """The training that the event is associated with."""

    @abstractmethod
    def handle(self):
        """
        Handle the event.

        Must be implemented by every concrete event.
        """

    @abstractmethod
    def next(self):
        """
        Proceed with the next event.

        Must be implemented by every concrete event.
        """

Attributes

trainer instance-attribute
trainer = trainer

The trainer that the event is associated with.

training instance-attribute
training = training

The training that the event is associated with.

Functions

__init__
__init__(trainer: ModelTrainer)

Initialize the event with the given trainer.

Parameters:

Name Type Description Default
trainer
ModelTrainer

The trainer that the event is associated with.

required
Source code in fl_server_ai/trainer/events/base.py
def __init__(self, trainer: "model_trainer.ModelTrainer"):
    """
    Bind the event to a trainer and to the training of that trainer.

    Args:
        trainer (model_trainer.ModelTrainer): The trainer that the event is associated with.
    """
    super().__init__()
    self.trainer = trainer
    """The trainer that the event is associated with."""
    self.training = trainer.training
    """The training that the event is associated with (taken from `trainer.training`)."""
handle abstractmethod
handle()

Handle the event.

Source code in fl_server_ai/trainer/events/base.py
@abstractmethod
def handle(self):
    """
    Handle the event.

    Must be implemented by every concrete event.
    """
next abstractmethod
next()

Proceed with the next event.

Source code in fl_server_ai/trainer/events/base.py
@abstractmethod
def next(self):
    """
    Proceed with the next event.

    Must be implemented by every concrete event.
    """

SWAGRoundFinished

Bases: TrainingRoundFinished


              flowchart TD
              fl_server_ai.trainer.events.SWAGRoundFinished[SWAGRoundFinished]
              fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished[TrainingRoundFinished]
              fl_server_ai.trainer.events.base.ModelTrainerEvent[ModelTrainerEvent]

                              fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished --> fl_server_ai.trainer.events.SWAGRoundFinished
                                fl_server_ai.trainer.events.base.ModelTrainerEvent --> fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished
                



              click fl_server_ai.trainer.events.SWAGRoundFinished href "" "fl_server_ai.trainer.events.SWAGRoundFinished"
              click fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished href "" "fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished"
              click fl_server_ai.trainer.events.base.ModelTrainerEvent href "" "fl_server_ai.trainer.events.base.ModelTrainerEvent"
            

Stochastic weight averaging Gaussian (SWAG) round finished event.

Methods:

Name Description
handle

Handle the SWAG event by collecting the SWAG first and second moments from all participants and aggregating them.

next
Source code in fl_server_ai/trainer/events/swag_round_finished.py
class SWAGRoundFinished(TrainingRoundFinished):
    """
    Stochastic weight averaging Gaussian (SWAG) round finished event.
    """

    def next(self):
        """
        Set the training state back to `ONGOING`, save it and continue with the parent's `next()`.
        """
        self.training.state = TrainingState.ONGOING
        self.training.save()
        super().next()

    def handle(self):
        """
        Handle the SWAG event by collecting the SWAG first and second moments
        from all participants and aggregating them.

        The aggregated moments are stored in the `first_moment` and `second_moment`
        fields of the training model, which is saved afterwards.

        Raises:
            ValueError: If the numbers of first moments, second moments and sample sizes differ.
            RuntimeError: If the number of collected updates does not match the number of participants.
        """
        # collect metric value
        # NOTE(review): the three queries are issued independently and without an explicit
        # ordering; the aggregation assumes that entry i of every list belongs to the same
        # participant - confirm that the database returns them in a consistent order.
        swag_fst = [m.to_torch() for m in self._get_metric("SWAG First Moment Local")]
        swag_snd = [m.to_torch() for m in self._get_metric("SWAG Second Moment Local")]
        sample_sizes = [m.value_float for m in self._get_metric("SWAG Sample Size Local")]
        n_participants = self.training.participants.count()

        # validate
        self._validate_swag(swag_fst, swag_snd, sample_sizes, n_participants)
        self._logger.info(
            f"Training {self.training.id}: Doing SWAG aggregation as all {n_participants} updates arrived"
        )

        # SWAG aggregation and save
        self.training.model.first_moment = self._aggregate_param_vectors(swag_fst, sample_sizes)
        self.training.model.second_moment = self._aggregate_param_vectors(swag_snd, sample_sizes)
        self.training.model.save()
        self._logger.info(f"SWAG completed for training {self.training.id}")

    def _get_metric(self, key: str) -> QuerySet[Metric]:
        """
        Get database metrics that match the training model and round as well as given key.

        Args:
            key (str): The key of the metric to retrieve.

        Returns:
            QuerySet[Metric]: A QuerySet of Metric objects that match the training model and round as well as given key.
        """
        return Metric.objects.filter(
            model=self.training.model,
            step=self.training.model.round,
            key=key
        )

    def _validate_swag(
        self,
        swag_fst: List[torch.Tensor],
        swag_snd: List[torch.Tensor],
        sample_sizes: List[int],
        n_participants: int
    ):
        """
        Validate the SWAG parameters and participant number for the training.

        This method checks that the numbers of first SWAG moments, second SWAG moments and
        sample sizes are all equal and that this number matches the number of participants.
        If any of these conditions is not met, an error is logged and an exception is raised.
        In the inconsistent-lengths case the training is additionally set to the `ERROR` state.

        Args:
            swag_fst (List[torch.Tensor]): List of first SWAG moments.
            swag_snd (List[torch.Tensor]): List of second SWAG moments.
            sample_sizes (List[int]): List of sample sizes.
            n_participants (int): The number of participants in the training.

        Raises:
            ValueError: If the lengths of first and second SWAG moments, and sample sizes are not equal.
            RuntimeError: If the length of first SWAG moments does not match the number of participants.
        """
        n_fst, n_snd, n_sizes = len(swag_fst), len(swag_snd), len(sample_sizes)

        # `a != b != c` would only mean `(a != b) and (b != c)` and therefore miss cases
        # like 2/2/3; all three lengths have to be equal
        if not (n_fst == n_snd == n_sizes):
            self.training.state = TrainingState.ERROR
            self.training.save()
            text = "SWAG stats in inconsistent state!"
            self._logger.error(text)
            raise ValueError(text)

        if n_fst != n_participants:
            text = (
                f"Aggregation was started, but training {self.training.id} "
                f"has {n_fst} updates, "
                f"but {n_participants} clients!"
            )
            self._logger.error(text)
            raise RuntimeError(text)

    @torch.no_grad()
    def _aggregate_param_vectors(
        self,
        param_vectors: List[torch.Tensor],
        sample_sizes: List[int]
    ) -> torch.Tensor:
        """
        Aggregate parameter vectors using sample sizes.

        Every parameter vector is weighted with its share of the total sample size
        (`sample_size / sum(sample_sizes)`) and the weighted vectors are summed up.

        This method checks if all parameter vectors have the same length and if the length of
        parameter vectors matches the length of sample sizes.
        If any of these conditions are not met, an exception is raised.

        Args:
            param_vectors (List[torch.Tensor]): List of parameter vectors.
            sample_sizes (List[int]): List of sample sizes.

        Returns:
            torch.Tensor: Aggregated parameter vector.

        Raises:
            AggregationException: If not all parameter vectors have the same length.
            RuntimeError: If the length of sample sizes does not match the length of parameter vectors.
            ZeroDivisionError: If the sample sizes sum up to zero.
        """
        if not all(map(lambda v: len(v) == len(param_vectors[0]), param_vectors[1:])):
            raise AggregationException("Models do not have the same number of parameters!")
        if len(param_vectors) != len(sample_sizes):
            raise RuntimeError("len(sample_sizes) != len(param_vectors)")

        # compute the total only once instead of once per participant
        total_samples = sum(sample_sizes)
        factors = torch.tensor([s / total_samples for s in sample_sizes])
        # `factors[:, None]` scales each stacked vector (one per row) with its own factor
        result = torch.stack(param_vectors) * factors[:, None]
        result = torch.sum(result, dim=0)
        return result

Functions

handle
handle()

Handle the SWAG event by collecting the SWAG first and second moments from all participants and aggregating them.

Source code in fl_server_ai/trainer/events/swag_round_finished.py
def handle(self):
    """
    Aggregate the SWAG statistics that the participants reported for the current round.

    The first moments, second moments and sample sizes are collected, validated
    via `_validate_swag` and aggregated via `_aggregate_param_vectors`.
    The aggregated moments are written to the `first_moment` and `second_moment`
    fields of the training model, which is saved afterwards.
    """
    # gather the locally computed statistics of all participants
    first_moments = [metric.to_torch() for metric in self._get_metric("SWAG First Moment Local")]
    second_moments = [metric.to_torch() for metric in self._get_metric("SWAG Second Moment Local")]
    weights = [metric.value_float for metric in self._get_metric("SWAG Sample Size Local")]
    participant_count = self.training.participants.count()

    # check the collected statistics before anything is aggregated
    self._validate_swag(first_moments, second_moments, weights, participant_count)
    training_id = self.training.id
    self._logger.info(
        f"Training {training_id}: Doing SWAG aggregation as all {participant_count} updates arrived"
    )

    # aggregate both moments and persist them on the global model
    aggregated_first = self._aggregate_param_vectors(first_moments, weights)
    global_model = self.training.model
    global_model.first_moment = aggregated_first
    global_model.second_moment = self._aggregate_param_vectors(second_moments, weights)
    global_model.save()
    self._logger.info(f"SWAG completed for training {training_id}")
next
next()
Source code in fl_server_ai/trainer/events/swag_round_finished.py
def next(self):
    """
    Set the training state back to `ONGOING`, save it and continue with the parent's `next()`.
    """
    training = self.training
    training.state = TrainingState.ONGOING
    training.save()
    super().next()

TrainingRoundFinished

Bases: ModelTrainerEvent


              flowchart TD
              fl_server_ai.trainer.events.TrainingRoundFinished[TrainingRoundFinished]
              fl_server_ai.trainer.events.base.ModelTrainerEvent[ModelTrainerEvent]

                              fl_server_ai.trainer.events.base.ModelTrainerEvent --> fl_server_ai.trainer.events.TrainingRoundFinished
                


              click fl_server_ai.trainer.events.TrainingRoundFinished href "" "fl_server_ai.trainer.events.TrainingRoundFinished"
              click fl_server_ai.trainer.events.base.ModelTrainerEvent href "" "fl_server_ai.trainer.events.base.ModelTrainerEvent"
            

Training round finished event.

This event should only be triggered when all model updates (local models) that are to participate in the aggregation have arrived.

Methods:

Name Description
handle

Handle the training round finished event.

next
Source code in fl_server_ai/trainer/events/training_round_finished.py
class TrainingRoundFinished(ModelTrainerEvent):
    """
    Training round finished event.

    This event should only be triggered when all model updates (local models)
    that are to participate in the aggregation have arrived.
    """

    def next(self):
        """
        Trigger a model test round or directly proceed as if the test had finished.

        A test round is started if model tests are not skipped and either a test after
        each round is requested or the target number of updates is reached (so that at
        least the final trained model is tested). Otherwise `ModelTestFinished.next()`
        is invoked directly.
        """
        options = self.trainer.options
        tests_enabled = not options.skip_model_tests
        if tests_enabled and (
            options.model_test_after_each_round
            # at least test the final trained model
            or self.training.model.round >= self.training.target_num_updates
        ):
            self.trainer.test_round()
        else:
            ModelTestFinished(self.trainer).next()

    def handle(self):
        """
        Handle the training round finished event.

        - aggregate all model updates (local models) into a new global model
        - save the new global model into the database (i.e. updates/overwrites the weights field of the model)
        - increase the round field of the global model by 1
        - delete the model updates (local models) from the database, if the trainer options do not disagree

        Raises:
            RuntimeError: If no model updates are found or if the number of model updates
                does not match the number of participants (see `_validate`).
        """
        model_updates = LocalModel.objects.filter(base_model=self.training.model, round=self.training.model.round)
        models = [m.get_torch_model() for m in model_updates]
        model_sample_sizes = [m.sample_size for m in model_updates]
        n_participants = self.training.participants.count()

        # validate
        self._validate(models, n_participants)
        self._logger.info(f"Training {self.training.id}: Doing aggregation as all {n_participants} updates arrived")

        # do aggregation
        aggregation_cls = get_aggregation_class(self.training)
        final_model = aggregation_cls().aggregate(
            models,
            model_sample_sizes,
            # a deep copy is only requested if the local models are kept afterwards
            deepcopy=not self.trainer.options.delete_local_models_after_aggregation
        )

        # write the result back to database and update the trainings round
        self.training.model.set_torch_model(final_model)
        self.training.model.round += 1
        self.training.model.save()

        # clean local updates
        if self.trainer.options.delete_local_models_after_aggregation:
            model_updates.delete()

    def _validate(self, models: List, n_participants: int):
        """
        Validate the models and participant number for the training.

        This method checks if there are any models and if the number of models matches the number of participants.
        If any of these conditions are not met, an error is logged and a `RuntimeError` is raised.

        Args:
            models (List): The list of models for the training.
            n_participants (int): The number of participants in the training.

        Raises:
            RuntimeError: If there are no models or if the number of models does not match the number of participants.
        """
        if not models:
            text = f"Aggregation was run for training {self.training.id} but no model updates were in db!"
            self._logger.error(text)
            raise RuntimeError(text)

        if len(models) != n_participants:
            # keep a space between the concatenated parts of the message
            text = (
                f"Aggregation was started, but training {self.training.id} has {len(models)} updates, "
                f"but {n_participants} clients!"
            )
            self._logger.error(text)
            raise RuntimeError(text)

Functions

handle
handle()

Handle the training round finished event.

  • aggregate all model updates (local models) into a new global model
  • save the new global model into the database (i.e. updates/overwrites the weights field of the model)
  • increase the round field of the global model by 1
  • delete the model updates (local models) from the database, if the trainer options do not disagree

Note: If no model updates are found, or if their number does not match the number of participants, a RuntimeError is raised and nothing is aggregated.

Source code in fl_server_ai/trainer/events/training_round_finished.py
def handle(self):
    """
    Aggregate the local models of the current round into a new global model.

    Steps:

    - load all model updates (local models) that belong to the global model and its current round
    - validate them via `_validate` together with the number of participants
    - aggregate them with the aggregation class selected for the training
    - store the aggregated model in the global model, increase its round by 1 and save it
    - delete the model updates from the database if the trainer option
      `delete_local_models_after_aggregation` is set

    Note: `_validate` is expected to raise if the model updates are missing or incomplete,
    in which case nothing is aggregated.
    """
    local_updates = LocalModel.objects.filter(base_model=self.training.model, round=self.training.model.round)
    local_models = [update.get_torch_model() for update in local_updates]
    sample_sizes = [update.sample_size for update in local_updates]
    participant_count = self.training.participants.count()

    # make sure that all expected updates are available
    self._validate(local_models, participant_count)
    self._logger.info(f"Training {self.training.id}: Doing aggregation as all {participant_count} updates arrived")

    # aggregate the local models; a deep copy is only requested if the local models are kept
    aggregator = get_aggregation_class(self.training)()
    keep_local_models = not self.trainer.options.delete_local_models_after_aggregation
    aggregated_model = aggregator.aggregate(local_models, sample_sizes, deepcopy=keep_local_models)

    # persist the new global model and advance the round counter
    global_model = self.training.model
    global_model.set_torch_model(aggregated_model)
    global_model.round += 1
    global_model.save()

    # remove the local updates if requested
    if self.trainer.options.delete_local_models_after_aggregation:
        local_updates.delete()
next
next()
Source code in fl_server_ai/trainer/events/training_round_finished.py
def next(self):
    """
    Trigger a model test round or directly proceed as if the test had finished.

    A test round is started if model tests are not skipped and either a test after
    each round is requested or the target number of updates is reached (so that at
    least the final trained model is tested). Otherwise `ModelTestFinished.next()`
    is invoked directly.
    """
    options = self.trainer.options
    run_tests = not options.skip_model_tests and (
        options.model_test_after_each_round
        # at least test the final trained model
        or self.training.model.round >= self.training.target_num_updates
    )
    if run_tests:
        self.trainer.test_round()
    else:
        ModelTestFinished(self.trainer).next()