Module fl_server_ai.trainer.events.daisy_chain_round_finished
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from .training_round_finished import TrainingRoundFinished
from .. import model_trainer
class DaisyChainRoundFinished(TrainingRoundFinished):
    """
    Federated daisy chain (FedDC) round finished event.

    Aggregation is only delegated to `TrainingRoundFinished.handle()` once the
    upcoming round reaches the training's target number of updates. Before that,
    only the round counter is advanced and persisted, so the (permuted) client
    models can be sent back for further local training
    (see `FedDCModelTrainer.handle()`).
    """

    def __init__(self, trainer: "model_trainer.ModelTrainer"):
        super().__init__(trainer)
        # FedDC does not test after every round. Note that this writes to the
        # trainer's options object, not to a copy.
        options = self.trainer.options
        options.model_test_after_each_round = False

    def handle(self):
        """
        Handle the FedDC event.

        (a) Local client training is finished: do the aggregation.
        (b) Local client training is not finished yet, but a daisy-chaining
            period has been reached: no aggregation; only advance and save the
            round so the permuted client models can be sent back for further
            training.

        Also see `FedDCModelTrainer.handle()`.
        """
        model = self.training.model
        # The round has not been incremented for the current round yet,
        # hence the look-ahead by one.
        upcoming_round = model.round + 1
        if upcoming_round >= self.training.target_num_updates:
            # Training is done: aggregate. The parent handler also performs
            # the round increment.
            super().handle()
        else:
            # Daisy-chaining period: skip aggregation, just bump the round.
            model.round = upcoming_round
            model.save()
Classes
DaisyChainRoundFinished
Federated daisy chain (FedDC) round finished event.
View Source
class DaisyChainRoundFinished(TrainingRoundFinished):
    """
    Federated daisy chain (FedDC) round finished event.

    Aggregation is only delegated to `TrainingRoundFinished.handle()` once the
    upcoming round reaches the training's target number of updates. Before that,
    only the round counter is advanced and persisted, so the (permuted) client
    models can be sent back for further local training
    (see `FedDCModelTrainer.handle()`).
    """

    def __init__(self, trainer: "model_trainer.ModelTrainer"):
        super().__init__(trainer)
        # FedDC does not test after every round. Note that this writes to the
        # trainer's options object, not to a copy.
        options = self.trainer.options
        options.model_test_after_each_round = False

    def handle(self):
        """
        Handle the FedDC event.

        (a) Local client training is finished: do the aggregation.
        (b) Local client training is not finished yet, but a daisy-chaining
            period has been reached: no aggregation; only advance and save the
            round so the permuted client models can be sent back for further
            training.

        Also see `FedDCModelTrainer.handle()`.
        """
        model = self.training.model
        # The round has not been incremented for the current round yet,
        # hence the look-ahead by one.
        upcoming_round = model.round + 1
        if upcoming_round >= self.training.target_num_updates:
            # Training is done: aggregate. The parent handler also performs
            # the round increment.
            super().handle()
        else:
            # Daisy-chaining period: skip aggregation, just bump the round.
            model.round = upcoming_round
            model.save()
Ancestors (in MRO)
- fl_server_ai.trainer.events.training_round_finished.TrainingRoundFinished
- fl_server_ai.trainer.events.base.ModelTrainerEvent
- abc.ABC
Methods
handle
Handle the FedDC event.
(a) local client training is finished -> do the aggregation
(b) local client training is not finished yet (but a daisy chaining period has been reached)
-> no aggregation, just send the permuted client models back for further training
- also see `FedDCModelTrainer.handle()`
View Source
def handle(self):
    """
    Handle the FedDC event.

    (a) Local client training is finished: do the aggregation.
    (b) Local client training is not finished yet, but a daisy-chaining
        period has been reached: no aggregation; only advance and save the
        round so the permuted client models can be sent back for further
        training.

    Also see `FedDCModelTrainer.handle()`.
    """
    model = self.training.model
    # The round has not been incremented for the current round yet,
    # hence the look-ahead by one.
    upcoming_round = model.round + 1
    if upcoming_round >= self.training.target_num_updates:
        # Training is done: aggregate. The parent handler also performs
        # the round increment.
        super().handle()
    else:
        # Daisy-chaining period: skip aggregation, just bump the round.
        model.round = upcoming_round
        model.save()
next
Proceed with the next event.
View Source
def next(self):
    """
    Proceed with the next event.

    Runs a model test round when model tests are not skipped and either
    per-round testing is enabled or the final round has been reached
    (so at least the final trained model gets tested). Otherwise it continues
    directly with the `ModelTestFinished` event.
    """
    options = self.trainer.options
    if not options.skip_model_tests:
        if options.model_test_after_each_round or (
            self.training.model.round >= self.training.target_num_updates
        ):
            self.trainer.test_round()
            return
    ModelTestFinished(self.trainer).next()