Skip to content

Module fl_server_ai.trainer

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from .model_trainer import ModelTrainer
from .options import TrainerOptions
from .tasks import process_trainer_task


# Public API of the package, re-exported from the submodules imported above.
__all__ = [
    "ModelTrainer",
    "process_trainer_task",
    "TrainerOptions",
]

Sub-modules

Variables

process_trainer_task

Classes

ModelTrainer

class ModelTrainer(
    training: fl_server_core.models.training.Training,
    options: Optional[fl_server_ai.trainer.options.TrainerOptions] = None
)

Common model trainer.

View Source
class ModelTrainer:
    """
    Common model trainer.

    Constructing a ``ModelTrainer`` yields an instance of the trainer class selected
    by `get_trainer_class` for the given training. The trainer updates the training's
    state, sends training notifications and dispatches model trainer events.
    """

    def __new__(cls, training: Training, options: Optional[TrainerOptions] = None) -> "ModelTrainer":
        """
        Allocate an instance of the trainer class that fits ``training``.

        The concrete class comes from `get_trainer_class`; it could for example be a
        `SWAGModelTrainer` if the training uses the SWAG uncertainty method.
        Python only calls `__init__` automatically afterwards when the returned
        object is an instance of ``cls``.

        Args:
            training (Training): The training that decides which trainer class is used.
            options (Optional[TrainerOptions]): Trainer options. Not used here; they are
                applied by `__init__`.

        Returns:
            "ModelTrainer": A new, not yet initialized trainer instance.
        """
        trainer_cls = cls.get_trainer_class(training)
        return super().__new__(trainer_cls)

    @classmethod
    def get_trainer_class(cls, training: Training) -> Type["ModelTrainer"]:
        """
        Resolve which model trainer class should handle ``training``.

        Delegates to the module-level `get_trainer_class` function. Because `__new__`
        looks the class up through ``cls``, subclasses may override this method.

        Args:
            training (Training): The training whose trainer class is requested.

        Returns:
            Type["ModelTrainer"]: The trainer class to instantiate.
        """
        resolved = get_trainer_class(training)
        return resolved

    def __init__(self, training: Training, options: Optional[TrainerOptions] = None):
        """
        Bind the trainer to a training and its options.

        Args:
            training (Training): The training this trainer manages.
            options (Optional[TrainerOptions]): Trainer options. When not given (or
                falsy), a default `TrainerOptions()` instance is used.
        """
        super().__init__()
        self.training = training
        """The training to be handled by the trainer."""
        self.options = options or TrainerOptions()
        """The options for the trainer."""

    def start(self):
        """
        Mark the training as ongoing and announce it.

        Saves the training with state ``TrainingState.ONGOING``, then sends a training
        start notification followed by a training round start notification.
        """
        self.training.state = TrainingState.ONGOING
        self.training.save()
        # NOTE: the round start notification is sent directly here, not via
        # `start_round`, so overrides of `start_round` are not involved.
        for notification_cls in (TrainingStartNotification, TrainingRoundStartNotification):
            notification_cls.from_training(self.training).send()

    def finish(self):
        """
        Mark the training as completed and announce it.

        Saves the training with state ``TrainingState.COMPLETED``, then sends a
        training finished notification.
        """
        self.training.state = TrainingState.COMPLETED
        self.training.save()
        notification = TrainingFinishedNotification.from_training(self.training)
        notification.send()

    def start_round(self):
        """
        Begin a training round by sending a training round start notification.
        """
        notification = TrainingRoundStartNotification.from_training(self.training)
        notification.send()

    def test_round(self):
        """
        Send a training model test notification for the training.

        No testing happens locally; this method only emits the notification.
        """
        notification = TrainingModelTestNotification.from_training(self.training)
        notification.send()

    def handle(self, event: ModelTrainerEvent):
        """
        Process ``event`` and then move on to the event that follows it.

        Args:
            event (ModelTrainerEvent): The event to process.
        """
        # Order matters: the event does its own work before it advances.
        event.handle()
        event.next()

    def handle_cls(self, event_cls: Type[ModelTrainerEvent]):
        """
        Instantiate ``event_cls`` for this trainer and handle the resulting event.

        Args:
            event_cls (Type[ModelTrainerEvent]): The event class; it is constructed
                with this trainer as its only argument.
        """
        event = event_cls(self)
        self.handle(event)

Descendants

  • fl_server_ai.trainer.model_trainer.SWAGModelTrainer
  • fl_server_ai.trainer.model_trainer.FedDCModelTrainer

Class methods

get_trainer_class

def get_trainer_class(
    training: fl_server_core.models.training.Training
) -> Type["ModelTrainer"]

Get the appropriate model trainer class based on a training object.

Parameters:

Name Type Description Default
training Training The training for which to get the trainer class. required

Returns:

Type Description
Type["ModelTrainer"] The appropriate trainer class.
View Source
    @classmethod
    def get_trainer_class(cls, training: Training) -> Type["ModelTrainer"]:
        """
        Resolve which model trainer class should handle ``training``.

        Delegates to the module-level `get_trainer_class` function. Because `__new__`
        looks the class up through ``cls``, subclasses may override this method.

        Args:
            training (Training): The training whose trainer class is requested.

        Returns:
            Type["ModelTrainer"]: The trainer class to instantiate.
        """
        resolved = get_trainer_class(training)
        return resolved

Methods

finish

def finish(
    self
)

Finish the training and send a finished notification.

View Source
    def finish(self):
        """
        Mark the training as completed and announce it.

        Saves the training with state ``TrainingState.COMPLETED``, then sends a
        training finished notification.
        """
        self.training.state = TrainingState.COMPLETED
        self.training.save()
        notification = TrainingFinishedNotification.from_training(self.training)
        notification.send()

handle

def handle(
    self,
    event: fl_server_ai.trainer.events.base.ModelTrainerEvent
)

Handle a model trainer event and proceed to the next event.

Parameters:

Name Type Description Default
event ModelTrainerEvent The event to handle. required
View Source
    def handle(self, event: ModelTrainerEvent):
        """
        Process ``event`` and then move on to the event that follows it.

        Args:
            event (ModelTrainerEvent): The event to process.
        """
        # Order matters: the event does its own work before it advances.
        event.handle()
        event.next()

handle_cls

def handle_cls(
    self,
    event_cls: Type[fl_server_ai.trainer.events.base.ModelTrainerEvent]
)

Handle a model trainer event class by creating an instance of the event and handling it.

Parameters:

Name Type Description Default
event_cls Type[ModelTrainerEvent] The class of the event to handle. required
View Source
    def handle_cls(self, event_cls: Type[ModelTrainerEvent]):
        """
        Instantiate ``event_cls`` for this trainer and handle the resulting event.

        Args:
            event_cls (Type[ModelTrainerEvent]): The event class; it is constructed
                with this trainer as its only argument.
        """
        event = event_cls(self)
        self.handle(event)

start

def start(
    self
)

Start the training and send a start notification followed by a round start notification.

View Source
    def start(self):
        """
        Mark the training as ongoing and announce it.

        Saves the training with state ``TrainingState.ONGOING``, then sends a training
        start notification followed by a training round start notification.
        """
        self.training.state = TrainingState.ONGOING
        self.training.save()
        # NOTE: the round start notification is sent directly here, not via
        # `start_round`, so overrides of `start_round` are not involved.
        for notification_cls in (TrainingStartNotification, TrainingRoundStartNotification):
            notification_cls.from_training(self.training).send()

start_round

def start_round(
    self
)

Start a training round and send a round start notification.

View Source
    def start_round(self):
        """
        Begin a training round by sending a training round start notification.
        """
        notification = TrainingRoundStartNotification.from_training(self.training)
        notification.send()

test_round

def test_round(
    self
)

Test a training round and send a model test notification.

View Source
    def test_round(self):
        """
        Send a training model test notification for the training.

        No testing happens locally; this method only emits the notification.
        """
        notification = TrainingModelTestNotification.from_training(self.training)
        notification.send()

TrainerOptions

class TrainerOptions(
    skip_model_tests: bool = False,
    model_test_after_each_round: bool = True,
    delete_local_models_after_aggregation: bool = True
)

Trainer options including their default values.

View Source
@dataclass()
class TrainerOptions:
    """
    Configuration flags for a model trainer.

    Every flag has a default, so ``TrainerOptions()`` gives the standard setup:
    model tests are not skipped and run after each round, and local models are
    deleted once they have been aggregated.
    """

    skip_model_tests: bool = False
    """When True, model tests are skipped."""
    model_test_after_each_round: bool = True
    """When True, a model test is performed after every training round."""
    delete_local_models_after_aggregation: bool = True
    """When True, local models are deleted after they have been aggregated."""

Class variables

delete_local_models_after_aggregation
model_test_after_each_round
skip_model_tests