Module fl_server_ai.trainer¶
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from .model_trainer import ModelTrainer
from .options import TrainerOptions
from .tasks import process_trainer_task
# Names exported by `from fl_server_ai.trainer import *`; all three are imported from sub-modules above.
__all__ = ["ModelTrainer", "process_trainer_task", "TrainerOptions"]
Sub-modules¶
- fl_server_ai.trainer.events
- fl_server_ai.trainer.model_trainer
- fl_server_ai.trainer.options
- fl_server_ai.trainer.tasks
Variables¶
Classes¶
ModelTrainer¶
class ModelTrainer(
training: fl_server_core.models.training.Training,
options: Optional[fl_server_ai.trainer.options.TrainerOptions] = None
)
Common model trainer.
View Source
class ModelTrainer:
    """
    Common model trainer.

    Creating a `ModelTrainer` may yield an instance of a different trainer class:
    `__new__` asks `get_trainer_class` which class to instantiate for the given training.
    """

    def __new__(cls, training: Training, options: Optional[TrainerOptions] = None) -> "ModelTrainer":
        """
        Ensure that the correct trainer class is returned based on the training.

        The returned trainer class is determined by the `get_trainer_class` method.
        It could for example return a `SWAGModelTrainer` if the training uses the SWAG uncertainty method.

        Args:
            training (Training): The training for which to get the trainer.
            options (Optional[TrainerOptions]): The options for the trainer. If None, default options will be used.

        Returns:
            "ModelTrainer": The trainer instance.
        """
        # `options` is not needed to pick the class; it is consumed later by `__init__`.
        return super().__new__(cls.get_trainer_class(training))

    @classmethod
    def get_trainer_class(cls, training: Training) -> Type["ModelTrainer"]:
        """
        Get the appropriate model trainer class based on a training object.

        Args:
            training (Training): The training for which to get the trainer class.

        Returns:
            Type["ModelTrainer"]: The appropriate trainer class.
        """
        # Inside a method the bare name resolves to the module-level function, not this classmethod.
        return get_trainer_class(training)

    def __init__(self, training: Training, options: Optional[TrainerOptions] = None):
        """
        Initialize the trainer with the given training and options.

        Args:
            training (Training): The training to be handled by the trainer.
            options (Optional[TrainerOptions]): The options for the trainer. If None, default options will be used.
        """
        super().__init__()
        self.training = training
        """The training to be handled by the trainer."""
        # Compare against None explicitly so that a supplied options object is never
        # replaced by the defaults just because it happens to evaluate as falsy.
        self.options = options if options is not None else TrainerOptions()
        """The options for the trainer."""

    def start(self):
        """
        Start the training.

        Sets the training state to ongoing, saves the training and sends a training start
        notification followed by a round start notification.
        """
        self.training.state = TrainingState.ONGOING
        self.training.save()
        TrainingStartNotification.from_training(self.training).send()
        TrainingRoundStartNotification.from_training(self.training).send()

    def finish(self):
        """
        Finish the training.

        Sets the training state to completed, saves the training and sends a finished notification.
        """
        self.training.state = TrainingState.COMPLETED
        self.training.save()
        TrainingFinishedNotification.from_training(self.training).send()

    def start_round(self):
        """
        Start a training round and send a round start notification.
        """
        TrainingRoundStartNotification.from_training(self.training).send()

    def test_round(self):
        """
        Test a training round and send a model test notification.
        """
        TrainingModelTestNotification.from_training(self.training).send()

    def handle(self, event: ModelTrainerEvent):
        """
        Handle a model trainer event and proceed to the next event.

        Args:
            event (ModelTrainerEvent): The event to handle.
        """
        event.handle()
        event.next()

    def handle_cls(self, event_cls: Type[ModelTrainerEvent]):
        """
        Handle a model trainer event class by creating an instance of the event and handling it.

        Args:
            event_cls (Type[ModelTrainerEvent]): The class of the event to handle.
                It is instantiated with this trainer as its only argument.
        """
        self.handle(event_cls(self))
Descendants¶
- fl_server_ai.trainer.model_trainer.SWAGModelTrainer
- fl_server_ai.trainer.model_trainer.FedDCModelTrainer
Static methods¶
get_trainer_class¶
def get_trainer_class(
training: fl_server_core.models.training.Training
) -> Type[ForwardRef('ModelTrainer')]
Get the appropriate model trainer class based on a training object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
training | Training | The training for which to get the trainer class. | None |
Returns:
Type | Description |
---|---|
Type["ModelTrainer"] | The appropriate trainer class. |
View Source
@classmethod
def get_trainer_class(cls, training: Training) -> Type["ModelTrainer"]:
    """
    Resolve the model trainer class to use for a training object.

    The lookup is delegated to the module-level `get_trainer_class` function.

    Args:
        training (Training): The training for which the trainer class is looked up.

    Returns:
        Type["ModelTrainer"]: The trainer class selected for the training.
    """
    trainer_cls = get_trainer_class(training)
    return trainer_cls
Methods¶
finish¶
Finish the training and send a finished notification.
View Source
handle¶
Handle a model trainer event and proceed to the next event.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event | ModelTrainerEvent | The event to handle. | None |
View Source
handle_cls¶
Handle a model trainer event class by creating an instance of the event and handling it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event_cls | Type[ModelTrainerEvent] | The class of the event to handle. | None |
View Source
start¶
Start the training and send a training start notification followed by a round start notification.
View Source
start_round¶
Start a training round and send a round start notification.
View Source
test_round¶
Test a training round and send a model test notification.
View Source
TrainerOptions¶
class TrainerOptions(
skip_model_tests: bool = False,
model_test_after_each_round: bool = True,
delete_local_models_after_aggregation: bool = True
)
Trainer options including their default values.
View Source
@dataclass
class TrainerOptions:
    """
    Configuration switches for a model trainer, each with its default value.
    """

    skip_model_tests: bool = False
    """Whether model tests should be skipped (default: `False`)."""

    model_test_after_each_round: bool = True
    """Whether a model test should be performed after each round (default: `True`)."""

    delete_local_models_after_aggregation: bool = True
    """Whether local models should be deleted after aggregation (default: `True`)."""