Module fl_server_ai.trainer.events¶
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from .base import ModelTrainerEvent
from .daisy_chain_round_finished import DaisyChainRoundFinished
from .model_test_finished import ModelTestFinished
from .swag_round_finished import SWAGRoundFinished
from .training_round_finished import TrainingRoundFinished
__all__ = [
"ModelTrainerEvent",
"DaisyChainRoundFinished",
"ModelTestFinished",
"SWAGRoundFinished",
"TrainingRoundFinished",
]
Sub-modules¶
- fl_server_ai.trainer.events.base
- fl_server_ai.trainer.events.daisy_chain_round_finished
- fl_server_ai.trainer.events.model_test_finished
- fl_server_ai.trainer.events.swag_round_finished
- fl_server_ai.trainer.events.training_round_finished
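These events form a small lifecycle: handle() reacts to what just happened (e.g. aggregating arrived model updates) and next() decides what follows (testing, another round, or finishing). A minimal sketch of that flow, assuming trainer is an existing fl_server_ai.trainer.ModelTrainer instance (not constructed here):
from fl_server_ai.trainer.events import TrainingRoundFinished

event = TrainingRoundFinished(trainer)  # `trainer` is assumed to already exist
event.handle()  # aggregate the arrived local models into the global model
event.next()    # schedule model tests or the next training round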
Classes¶
DaisyChainRoundFinished¶
Federated daisy chain (FedDC) round finished event.
View Source
class DaisyChainRoundFinished(TrainingRoundFinished):
"""
Federated daisy chain (FedDC) round finished event.
"""
def __init__(self, trainer: "model_trainer.ModelTrainer"):
super().__init__(trainer)
self.trainer.options.model_test_after_each_round = False
def handle(self):
"""
Handle the FedDC event.
(a) local client training is finished -> do the aggregation
(b) local client training is not finished yet (but a daisy-chaining period has been reached)
-> no aggregation, just send the permuted client models back for further training
- also see `FedDCModelTrainer.handle()`
"""
# the round increment is not done yet, therefore `model.round + 1`
if (self.training.model.round + 1) >= self.training.target_num_updates:
# local client training is finished, let's do the aggregation
super().handle() # also does the round increment
return
# local client training is not finished yet, but a daisy-chaining period has been reached
# => no aggregation, just send the permuted client models back for further training
# also see `FedDCModelTrainer.handle()`
self.training.model.round += 1
self.training.model.save()
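A worked illustration of the branch condition in handle() above; the values are invented for illustration, not library defaults:
# Suppose the training is configured for five rounds in total.
target_num_updates = 5
for model_round in range(5):  # round counter before the increment
    if (model_round + 1) >= target_num_updates:
        print(model_round, "-> branch (a): aggregate local models")
    else:
        print(model_round, "-> branch (b): bump round, resend permuted models")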
Ancestors (in MRO)¶
- fl_server_ai.trainer.events.TrainingRoundFinished
- fl_server_ai.trainer.events.ModelTrainerEvent
- abc.ABC
Methods¶
handle¶
Handle the FedDC event.
(a) local client training is finished -> do the aggregation
(b) local client training is not finished yet (but a daisy-chaining period has been reached)
-> no aggregation, just send the permuted client models back for further training
- also see FedDCModelTrainer.handle()
View Source
def handle(self):
"""
Handle the FedDC event.
(a) local client training is finished -> do the aggregation
(b) local client training is not finished yet (but a daisy-chaining period has been reached)
-> no aggregation, just send the permuted client models back for further training
- also see `FedDCModelTrainer.handle()`
"""
# the round increment is not done yet, therefore `model.round + 1`
if (self.training.model.round + 1) >= self.training.target_num_updates:
# local client training is finished, let's do the aggregation
super().handle() # also does the round increment
return
# local client training is not finished yet, but a daisy-chaining period has been reached
# => no aggregation, just send the permuted client models back for further training
# also see `FedDCModelTrainer.handle()`
self.training.model.round += 1
self.training.model.save()
next¶
Proceed with the next event.
View Source
def next(self):
tests_enabled = not self.trainer.options.skip_model_tests
if tests_enabled and self.trainer.options.model_test_after_each_round:
self.trainer.test_round()
elif tests_enabled and self.training.model.round >= self.training.target_num_updates:
# at least test the final trained model
self.trainer.test_round()
else:
ModelTestFinished(self.trainer).next()
ModelTestFinished¶
Model test finished event.
View Source
class ModelTestFinished(ModelTrainerEvent):
"""
Model test finished event.
"""
def next(self):
if self.training.model.round < self.training.target_num_updates:
self.trainer.start_round()
else:
self.trainer.finish()
def handle(self):
# currently do nothing
# Potentially, one could aggregate all common metrics here.
pass
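A short illustration of the branching in next() above, with invented numbers:
# model.round == 3, target_num_updates == 5 -> trainer.start_round()
# model.round == 5, target_num_updates == 5 -> trainer.finish()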
Ancestors (in MRO)¶
- fl_server_ai.trainer.events.ModelTrainerEvent
- abc.ABC
Methods¶
handle¶
Handle the event.
next¶
Proceed with the next event.
ModelTrainerEvent¶
Abstract base class for a model trainer event.
View Source
class ModelTrainerEvent(ABC):
"""
Abstract base class for a model trainer event.
"""
_logger = getLogger("fl.server")
def __init__(self, trainer: "model_trainer.ModelTrainer"):
"""
Initialize the event with the given trainer.
Args:
trainer (model_trainer.ModelTrainer): The trainer that the event is associated with.
"""
super().__init__()
self.trainer = trainer
"""The trainer that the event is associated with."""
self.training = trainer.training
"""The training that the event is associated with."""
@abstractmethod
def handle(self):
"""
Handle the event.
"""
pass
@abstractmethod
def next(self):
"""
Proceed with the next event.
"""
pass
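A minimal sketch of a concrete event; LoggingOnlyEvent is hypothetical and not part of this package:
class LoggingOnlyEvent(ModelTrainerEvent):
    """Hypothetical terminal event that only logs its occurrence."""
    def handle(self):
        self._logger.info(f"Training {self.training.id}: event handled")

    def next(self):
        pass  # terminal event: nothing further is scheduled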
Ancestors (in MRO)¶
- abc.ABC
Descendants¶
- fl_server_ai.trainer.events.ModelTestFinished
- fl_server_ai.trainer.events.TrainingRoundFinished
Methods¶
handle¶
Handle the event.
next¶
Proceed with the next event.
SWAGRoundFinished¶
Stochastic weight averaging Gaussian (SWAG) round finished event.
View Source
class SWAGRoundFinished(TrainingRoundFinished):
"""
Stochastic weight averaging Gaussian (SWAG) round finished event.
"""
def next(self):
self.training.state = TrainingState.ONGOING
self.training.save()
super().next()
def handle(self):
"""
Handle the SWAG event by collecting the SWAG first and second moments
from all participants and aggregating them.
"""
# collect metric value
swag_fst = [m.to_torch() for m in self._get_metric("SWAG First Moment Local")]
swag_snd = [m.to_torch() for m in self._get_metric("SWAG Second Moment Local")]
sample_sizes = [m.value_float for m in self._get_metric("SWAG Sample Size Local")]
n_participants = self.training.participants.count()
# validate
self._validate_swag(swag_fst, swag_snd, sample_sizes, n_participants)
self._logger.info(
f"Training {self.training.id}: Doing SWAG aggregation as all {n_participants} updates arrived"
)
# SWAG aggregation and save
self.training.model.first_moment = self._aggregate_param_vectors(swag_fst, sample_sizes)
self.training.model.second_moment = self._aggregate_param_vectors(swag_snd, sample_sizes)
self.training.model.save()
self._logger.info(f"SWAG completed for training {self.training.id}")
def _get_metric(self, key: str) -> QuerySet[Metric]:
"""
Get database metrics that match the training model and round as well as the given key.
Args:
key (str): The key of the metric to retrieve.
Returns:
QuerySet[Metric]: A QuerySet of Metric objects that match the training model and round as well as the given key.
"""
return Metric.objects.filter(
model=self.training.model,
step=self.training.model.round,
key=key
)
def _validate_swag(
self,
swag_fst: List[torch.Tensor],
swag_snd: List[torch.Tensor],
sample_sizes: List[int],
n_participants: int
):
"""
Validate the SWAG parameters and participant number for the training.
This method checks that the lengths of the first and second SWAG moments and the sample sizes
are equal, and that they match the number of participants.
If any of these conditions are not met, an error is logged and a `RuntimeError` is raised.
Args:
swag_fst (List[torch.Tensor]): List of first SWAG moments.
swag_snd (List[torch.Tensor]): List of second SWAG moments.
sample_sizes (List[int]): List of sample sizes.
n_participants (int): The number of participants in the training.
Raises:
ValueError: If the lengths of the first and second SWAG moments and the sample sizes are not equal.
RuntimeError: If the number of first SWAG moments does not match the number of participants.
"""
if not (len(swag_fst) == len(swag_snd) == len(sample_sizes)):
self.training.state = TrainingState.ERROR
self.training.save()
raise ValueError("SWAG stats in inconsistent state!")
if len(swag_fst) != n_participants:
text = f"Aggregation was started, but training {self.training.id}" \
f"has {len(swag_fst)} updates," \
f"but {n_participants} clients!"
self._logger.error(text)
raise RuntimeError(text)
@torch.no_grad()
def _aggregate_param_vectors(
self,
param_vectors: List[torch.Tensor],
sample_sizes: List[int]
) -> torch.Tensor:
"""
Aggregate parameter vectors using sample sizes.
This method checks that all parameter vectors have the same length and that the number of
parameter vectors matches the number of sample sizes.
If any of these conditions are not met, a `RuntimeError` is raised.
Args:
param_vectors (List[torch.Tensor]): List of parameter vectors.
sample_sizes (List[int]): List of sample sizes.
Returns:
torch.Tensor: Aggregated parameter vector.
Raises:
AggregationException: If not all parameter vectors have the same length.
RuntimeError: If the number of sample sizes does not match the number of parameter vectors.
"""
if not all(map(lambda v: len(v) == len(param_vectors[0]), param_vectors[1:])):
raise AggregationException("Models do not have the same number of parameters!")
if len(param_vectors) != len(sample_sizes):
raise RuntimeError("len(sample_sizes) != len(param_vectors)")
factors = torch.tensor([s / sum(sample_sizes) for s in sample_sizes])
result = torch.stack(param_vectors) * factors[:, None]
result = torch.sum(result, dim=0)
return result
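For intuition, a self-contained sketch of the sample-size-weighted average computed above; the tensors and sizes are invented:
import torch

vecs = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
sizes = [100, 300]  # the second client saw three times as many samples
factors = torch.tensor([s / sum(sizes) for s in sizes])  # tensor([0.2500, 0.7500])
agg = torch.sum(torch.stack(vecs) * factors[:, None], dim=0)
print(agg)  # tensor([2.5000, 3.5000])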
Ancestors (in MRO)¶
- fl_server_ai.trainer.events.TrainingRoundFinished
- fl_server_ai.trainer.events.ModelTrainerEvent
- abc.ABC
Methods¶
handle¶
Handle the SWAG event by collecting the SWAG first and second moments
from all participants and aggregating them.
View Source
def handle(self):
"""
Handle the SWAG event by collecting the SWAG first and second moments
from all participants and aggregating them.
"""
# collect metric value
swag_fst = [m.to_torch() for m in self._get_metric("SWAG First Moment Local")]
swag_snd = [m.to_torch() for m in self._get_metric("SWAG Second Moment Local")]
sample_sizes = [m.value_float for m in self._get_metric("SWAG Sample Size Local")]
n_participants = self.training.participants.count()
# validate
self._validate_swag(swag_fst, swag_snd, sample_sizes, n_participants)
self._logger.info(
f"Training {self.training.id}: Doing SWAG aggregation as all {n_participants} updates arrived"
)
# SWAG aggregation and save
self.training.model.first_moment = self._aggregate_param_vectors(swag_fst, sample_sizes)
self.training.model.second_moment = self._aggregate_param_vectors(swag_snd, sample_sizes)
self.training.model.save()
self._logger.info(f"SWAG completed for training {self.training.id}")
next¶
Proceed with the next event.
TrainingRoundFinished¶
Training round finished event.
This event should only be triggered when all model updates (local models) that are to participate in the aggregation have arrived.
View Source
class TrainingRoundFinished(ModelTrainerEvent):
"""
Training round finished event.
This event should only be triggered when all model updates (local models)
that are to participate in the aggregation have arrived.
"""
def next(self):
tests_enabled = not self.trainer.options.skip_model_tests
if tests_enabled and self.trainer.options.model_test_after_each_round:
self.trainer.test_round()
elif tests_enabled and self.training.model.round >= self.training.target_num_updates:
# at least test the final trained model
self.trainer.test_round()
else:
ModelTestFinished(self.trainer).next()
def handle(self):
"""
Handle the training round finished event.
- aggregate all model updates (local models) into a new global model
- save the new global model into the database (i.e. updates/overwrites the weights field of the model)
- increase the round field of the global model by 1
- delete the model updates (local models) from the database, if the trainer options do not disagree
Note: if no updates or an unexpected number of updates have arrived, `_validate` raises a `RuntimeError`.
"""
model_updates = LocalModel.objects.filter(base_model=self.training.model, round=self.training.model.round)
models = [m.get_torch_model() for m in model_updates]
model_sample_sizes = [m.sample_size for m in model_updates]
n_participants = self.training.participants.count()
# validate
self._validate(models, n_participants)
self._logger.info(f"Training {self.training.id}: Doing aggregation as all {n_participants} updates arrived")
# do aggregation
aggregation_cls = get_aggregation_class(self.training)
final_model = aggregation_cls().aggregate(
models,
model_sample_sizes,
deepcopy=not self.trainer.options.delete_local_models_after_aggregation
)
# write the result back to the database and update the training's round
self.training.model.set_torch_model(final_model)
self.training.model.round += 1
self.training.model.save()
# clean local updates
if self.trainer.options.delete_local_models_after_aggregation:
model_updates.delete()
def _validate(self, models: List, n_participants: int):
"""
Validate the models and participant number for the training.
This method checks if there are any models and if the number of models matches the number of participants.
If any of these conditions are not met, an error is logged and a `RuntimeError` is raised.
Args:
models (List): The list of models for the training.
n_participants (int): The number of participants in the training.
Raises:
RuntimeError: If there are no models or if the number of models does not match the number of participants.
"""
if not models:
text = f"Aggregation was run for training {self.training.id} but no model updates were in db!"
self._logger.error(text)
raise RuntimeError(text)
if len(models) != n_participants:
text = f"Aggregation was started, but training {self.training.id} has {len(models)} updates," \
f"but {n_participants} clients!"
self._logger.error(text)
raise RuntimeError(text)
Ancestors (in MRO)¶
- fl_server_ai.trainer.events.ModelTrainerEvent
- abc.ABC
Descendants¶
- fl_server_ai.trainer.events.DaisyChainRoundFinished
- fl_server_ai.trainer.events.SWAGRoundFinished
Methods¶
handle¶
Handle the training round finished event.
- aggregate all model updates (local models) into a new global model
- save the new global model into the database (i.e. updates/overwrites the weights field of the model)
- increase the round field of the global model by 1
- delete the model updates (local models) from the database, if the trainer options do not disagree
Note: if no updates or an unexpected number of updates have arrived, _validate raises a RuntimeError.
View Source
def handle(self):
"""
Handle the training round finished event.
- aggregate all model updates (local models) into a new global model
- save the new global model into the database (i.e. updates/overwrites the weights field of the model)
- increase the round field of the global model by 1
- delete the model updates (local models) from the database, if the trainer options do not disagree
Note: if no updates or an unexpected number of updates have arrived, `_validate` raises a `RuntimeError`.
"""
model_updates = LocalModel.objects.filter(base_model=self.training.model, round=self.training.model.round)
models = [m.get_torch_model() for m in model_updates]
model_sample_sizes = [m.sample_size for m in model_updates]
n_participants = self.training.participants.count()
# validate
self._validate(models, n_participants)
self._logger.info(f"Training {self.training.id}: Doing aggregation as all {n_participants} updates arrived")
# do aggregation
aggregation_cls = get_aggregation_class(self.training)
final_model = aggregation_cls().aggregate(
models,
model_sample_sizes,
deepcopy=not self.trainer.options.delete_local_models_after_aggregation
)
# write the result back to the database and update the training's round
self.training.model.set_torch_model(final_model)
self.training.model.round += 1
self.training.model.save()
# clean local updates
if self.trainer.options.delete_local_models_after_aggregation:
model_updates.delete()
next¶
Proceed with the next event.
View Source
def next(self):
tests_enabled = not self.trainer.options.skip_model_tests
if tests_enabled and self.trainer.options.model_test_after_each_round:
self.trainer.test_round()
elif tests_enabled and self.training.model.round >= self.training.target_num_updates:
# at least test the final trained model
self.trainer.test_round()
else:
ModelTestFinished(self.trainer).next()
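To summarize the routing in next() above, a hedged sketch; would_test is a hypothetical helper mirroring the branching, not part of the package:
def would_test(skip_model_tests: bool, test_each_round: bool, round_: int, target: int) -> bool:
    # Mirrors next(): tests run only if not skipped, and then either after
    # every round or once the final round has been reached.
    tests_enabled = not skip_model_tests
    return tests_enabled and (test_each_round or round_ >= target)

print(would_test(False, False, 2, 5))  # False: intermediate round, no test
print(would_test(False, False, 5, 5))  # True: at least test the final model
print(would_test(False, True, 2, 5))   # True: test after every round
print(would_test(True, True, 5, 5))    # False: tests skipped entirely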