Skip to content

Module fl_server_ai.trainer.events.daisy_chain_round_finished

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from .training_round_finished import TrainingRoundFinished
from .. import model_trainer


class DaisyChainRoundFinished(TrainingRoundFinished):
    """
    Federated daisy chain (FedDC) round finished event.

    Aggregation (delegated to the parent `TrainingRoundFinished.handle()`) only
    happens once the target number of updates is reached. Every earlier round is
    a daisy chaining period, in which this event merely advances and persists the
    round counter.
    """

    def __init__(self, trainer: "model_trainer.ModelTrainer"):
        super().__init__(trainer)
        options = self.trainer.options
        # Do not test after every FedDC round. With this flag off, the `next()`
        # handler only runs a test round once the final round has been reached
        # (and only if model tests are not skipped altogether).
        options.model_test_after_each_round = False

    def handle(self):
        """
        Handle the end of a FedDC round.

        (a) local client training is finished -> aggregate (done by the parent handler)

        (b) local client training is not finished yet (a daisy chaining period was reached)
            -> no aggregation; only advance and persist the round counter so the permuted
            client models can be sent back for further training
            - also see `FedDCModelTrainer.handle()`
        """
        model = self.training.model
        # `model.round` has not been incremented for this round yet, hence the `+ 1`
        local_training_done = model.round + 1 >= self.training.target_num_updates
        if local_training_done:
            # aggregation path; the parent handler also performs the round increment
            super().handle()
        else:
            # daisy chaining period: skip aggregation, just advance and save the round
            model.round += 1
            model.save()

Classes

DaisyChainRoundFinished

class DaisyChainRoundFinished(
    trainer: 'model_trainer.ModelTrainer'
)

Federated daisy chain (FedDC) round finished event.

View Source
class DaisyChainRoundFinished(TrainingRoundFinished):
    """
    Federated daisy chain (FedDC) round finished event.

    Aggregation (delegated to the parent `TrainingRoundFinished.handle()`) only
    happens once the target number of updates is reached. Every earlier round is
    a daisy chaining period, in which this event merely advances and persists the
    round counter.
    """

    def __init__(self, trainer: "model_trainer.ModelTrainer"):
        super().__init__(trainer)
        options = self.trainer.options
        # Do not test after every FedDC round. With this flag off, the `next()`
        # handler only runs a test round once the final round has been reached
        # (and only if model tests are not skipped altogether).
        options.model_test_after_each_round = False

    def handle(self):
        """
        Handle the end of a FedDC round.

        (a) local client training is finished -> aggregate (done by the parent handler)

        (b) local client training is not finished yet (a daisy chaining period was reached)
            -> no aggregation; only advance and persist the round counter so the permuted
            client models can be sent back for further training
            - also see `FedDCModelTrainer.handle()`
        """
        model = self.training.model
        # `model.round` has not been incremented for this round yet, hence the `+ 1`
        local_training_done = model.round + 1 >= self.training.target_num_updates
        if local_training_done:
            # aggregation path; the parent handler also performs the round increment
            super().handle()
        else:
            # daisy chaining period: skip aggregation, just advance and save the round
            model.round += 1
            model.save()

Ancestors (in MRO)

  • fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished
  • fl_server_ai.trainer.events.base.ModelTrainerEvent
  • abc.ABC

Methods

handle

def handle(
    self
)

Handle the FedDC event.

(a) local client training is finished -> do the aggregation

(b) local client training is not finished yet (but a daisy chaining period has been reached) -> no aggregation; just send the permuted client models back for further training (also see `FedDCModelTrainer.handle()`)

View Source
    def handle(self):
        """
        Handle the end of a FedDC round.

        (a) local client training is finished -> aggregate (done by the parent handler)

        (b) local client training is not finished yet (a daisy chaining period was reached)
            -> no aggregation; only advance and persist the round counter so the permuted
            client models can be sent back for further training
            - also see `FedDCModelTrainer.handle()`
        """
        model = self.training.model
        # `model.round` has not been incremented for this round yet, hence the `+ 1`
        local_training_done = model.round + 1 >= self.training.target_num_updates
        if local_training_done:
            # aggregation path; the parent handler also performs the round increment
            super().handle()
        else:
            # daisy chaining period: skip aggregation, just advance and save the round
            model.round += 1
            model.save()

next

def next(
    self
)

Proceed with the next event.

View Source
    def next(self):
        """
        Proceed with the next event.

        If model tests are skipped, continue directly with `ModelTestFinished`.
        Otherwise, run a test round when testing after each round is enabled or
        when the final round has been reached. In all other cases, continue with
        `ModelTestFinished`.
        """
        options = self.trainer.options
        if options.skip_model_tests:
            ModelTestFinished(self.trainer).next()
            return

        run_tests = options.model_test_after_each_round or (
            # at least test the final trained model
            self.training.model.round >= self.training.target_num_updates
        )
        if run_tests:
            self.trainer.test_round()
        else:
            ModelTestFinished(self.trainer).next()