fl_server_ai.trainer

Modules:

Name Description
events
model_trainer
options
tasks

Classes:

Name Description
ModelTrainer

Common model trainer.

TrainerOptions

Trainer options including their default values.

Functions:

Name Description
process_trainer_task

Celery task that processes a dispatched trainer task.

Attributes

__all__ module-attribute

__all__ = ['ModelTrainer', 'process_trainer_task', 'TrainerOptions']

Classes

ModelTrainer

Common model trainer.

Methods:

Name Description
__init__

Initialize the trainer with the given training and options.

__new__

Ensure that the correct trainer class is returned based on the training.

finish

Finish the training and send a finished notification.

get_trainer_class

Get the appropriate model trainer class based on a training object.

handle

Handle a model trainer event and proceed to the next event.

handle_cls

Handle a model trainer event class by creating an instance of the event and handling it.

start

Start the training and send the training start and first round start notifications.

start_round

Start a training round and send a round start notification.

test_round

Test a training round and send a model test notification.

Attributes:

Name Type Description
options

The options for the trainer.

training

The training to be handled by the trainer.

Source code in fl_server_ai/trainer/model_trainer.py
class ModelTrainer:
    """
    Common model trainer.
    """

    def __new__(cls, training: Training, options: Optional[TrainerOptions] = None) -> "ModelTrainer":
        """
        Ensure that the correct trainer class is returned based on the training.

        The returned trainer class is determined by the `get_trainer_class` method.
        It could for example return a `SWAGModelTrainer` if the training uses the SWAG uncertainty method.

        Args:
            training (Training): The training for which to get the trainer.
            options (Optional[TrainerOptions]): The options for the trainer. If None, default options will be used.

        Returns:
            ModelTrainer: The trainer instance.
        """
        return super().__new__(cls.get_trainer_class(training))

    @classmethod
    def get_trainer_class(cls, training: Training) -> Type["ModelTrainer"]:
        """
        Get the appropriate model trainer class based on a training object.

        Args:
            training (Training): The training for which to get the trainer class.

        Returns:
            Type[ModelTrainer]: The appropriate trainer class.
        """
        return get_trainer_class(training)

    def __init__(self, training: Training, options: Optional[TrainerOptions] = None):
        """
        Initialize the trainer with the given training and options.

        Args:
            training (Training): The training to be handled by the trainer.
            options (Optional[TrainerOptions]): The options for the trainer. If None, default options will be used.
        """
        super().__init__()
        self.training = training
        """The training to be handled by the trainer."""
        self.options = options if options else TrainerOptions()
        """The options for the trainer."""

    def start(self):
        """
        Start the training and send the training start and first round start notifications.
        """
        self.training.state = TrainingState.ONGOING
        self.training.save()
        TrainingStartNotification.from_training(self.training).send()
        TrainingRoundStartNotification.from_training(self.training).send()

    def finish(self):
        """
        Finish the training and send a finished notification.
        """
        self.training.state = TrainingState.COMPLETED
        self.training.save()
        TrainingFinishedNotification.from_training(self.training).send()

    def start_round(self):
        """
        Start a training round and send a round start notification.
        """
        TrainingRoundStartNotification.from_training(self.training).send()

    def test_round(self):
        """
        Test a training round and send a model test notification.
        """
        TrainingModelTestNotification.from_training(self.training).send()

    def handle(self, event: ModelTrainerEvent):
        """
        Handle a model trainer event and proceed to the next event.

        Args:
            event (ModelTrainerEvent): The event to handle.
        """
        event.handle()
        event.next()

    def handle_cls(self, event_cls: Type[ModelTrainerEvent]):
        """
        Handle a model trainer event class by creating an instance of the event and handling it.

        Args:
            event_cls (Type[ModelTrainerEvent]): The class of the event to handle.
        """
        self.handle(event_cls(self))

Attributes

options instance-attribute
options = options if options else TrainerOptions()

The options for the trainer.

training instance-attribute
training = training

The training to be handled by the trainer.

Functions

__init__
__init__(training: Training, options: TrainerOptions | None = None)

Initialize the trainer with the given training and options.

Parameters:

Name Type Description Default
training
Training

The training to be handled by the trainer.

required
options
TrainerOptions | None

The options for the trainer. If None, default options will be used.

None
Source code in fl_server_ai/trainer/model_trainer.py
def __init__(self, training: Training, options: Optional[TrainerOptions] = None):
    """
    Initialize the trainer with the given training and options.

    Args:
        training (Training): The training to be handled by the trainer.
        options (Optional[TrainerOptions]): The options for the trainer. If None, default options will be used.
    """
    super().__init__()
    self.training = training
    """The training to be handled by the trainer."""
    self.options = options if options else TrainerOptions()
    """The options for the trainer."""
__new__
__new__(training: Training, options: TrainerOptions | None = None) -> ModelTrainer

Ensure that the correct trainer class is returned based on the training.

The returned trainer class is determined by the get_trainer_class method. It could for example return a SWAGModelTrainer if the training uses the SWAG uncertainty method.

Parameters:

Name Type Description Default
training
Training

The training for which to get the trainer.

required
options
TrainerOptions | None

The options for the trainer. If None, default options will be used.

None

Returns:

Type Description
ModelTrainer

The trainer instance.

Source code in fl_server_ai/trainer/model_trainer.py
def __new__(cls, training: Training, options: Optional[TrainerOptions] = None) -> "ModelTrainer":
    """
    Ensure that the correct trainer class is returned based on the training.

    The returned trainer class is determined by the `get_trainer_class` method.
    It could for example return a `SWAGModelTrainer` if the training uses the SWAG uncertainty method.

    Args:
        training (Training): The training for which to get the trainer.
        options (Optional[TrainerOptions]): The options for the trainer. If None, default options will be used.

    Returns:
        ModelTrainer: The trainer instance.
    """
    return super().__new__(cls.get_trainer_class(training))
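The dispatch performed by `__new__` can be illustrated with a minimal, self-contained sketch. The `Training` stand-in, its `uncertainty_method` field, and the selection rule inside `get_trainer_class` are assumptions made for illustration only; the real selection logic lives in a `get_trainer_class` function that this page does not show.

```python
from dataclasses import dataclass
from typing import Optional, Type


@dataclass
class Training:
    """Hypothetical stand-in for the real Training model."""
    uncertainty_method: str = "NONE"


class ModelTrainer:
    def __new__(cls, training: Training, options: Optional[dict] = None) -> "ModelTrainer":
        # Allocate an instance of the class selected for this training,
        # not necessarily of the class that was called.
        return super().__new__(cls.get_trainer_class(training))

    @classmethod
    def get_trainer_class(cls, training: Training) -> Type["ModelTrainer"]:
        # Assumed selection rule, for illustration only.
        if training.uncertainty_method == "SWAG":
            return SWAGModelTrainer
        return ModelTrainer

    def __init__(self, training: Training, options: Optional[dict] = None):
        self.training = training
        self.options = options or {}


class SWAGModelTrainer(ModelTrainer):
    """Hypothetical specialised trainer."""


plain = ModelTrainer(Training())
swag = ModelTrainer(Training(uncertainty_method="SWAG"))
print(type(plain).__name__, type(swag).__name__)
```

Python runs `__init__` automatically only when the object returned by `__new__` is an instance of the class that was called. That holds in this sketch because the selected class is always `ModelTrainer` or one of its subclasses.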
finish
finish()

Finish the training and send a finished notification.

Source code in fl_server_ai/trainer/model_trainer.py
def finish(self):
    """
    Finish the training and send a finished notification.
    """
    self.training.state = TrainingState.COMPLETED
    self.training.save()
    TrainingFinishedNotification.from_training(self.training).send()
get_trainer_class classmethod
get_trainer_class(training: Training) -> Type[ModelTrainer]

Get the appropriate model trainer class based on a training object.

Parameters:

Name Type Description Default
training
Training

The training for which to get the trainer class.

required

Returns:

Type Description
Type[ModelTrainer]

The appropriate trainer class.

Source code in fl_server_ai/trainer/model_trainer.py
@classmethod
def get_trainer_class(cls, training: Training) -> Type["ModelTrainer"]:
    """
    Get the appropriate model trainer class based on a training object.

    Args:
        training (Training): The training for which to get the trainer class.

    Returns:
        Type[ModelTrainer]: The appropriate trainer class.
    """
    return get_trainer_class(training)
handle
handle(event: ModelTrainerEvent)

Handle a model trainer event and proceed to the next event.

Parameters:

Name Type Description Default
event
ModelTrainerEvent

The event to handle.

required
Source code in fl_server_ai/trainer/model_trainer.py
def handle(self, event: ModelTrainerEvent):
    """
    Handle a model trainer event and proceed to the next event.

    Args:
        event (ModelTrainerEvent): The event to handle.
    """
    event.handle()
    event.next()
handle_cls
handle_cls(event_cls: Type[ModelTrainerEvent])

Handle a model trainer event class by creating an instance of the event and handling it.

Parameters:

Name Type Description Default
event_cls
Type[ModelTrainerEvent]

The class of the event to handle.

required
Source code in fl_server_ai/trainer/model_trainer.py
def handle_cls(self, event_cls: Type[ModelTrainerEvent]):
    """
    Handle a model trainer event class by creating an instance of the event and handling it.

    Args:
        event_cls (Type[ModelTrainerEvent]): The class of the event to handle.
    """
    self.handle(event_cls(self))
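How `handle` and `handle_cls` cooperate with an event can be sketched as follows. The only details taken from the listings are that an event is constructed with the trainer as its single argument and exposes `handle()` and `next()`. The `RoundFinished` event, the `log` list, and what `next()` does here are hypothetical.

```python
from typing import List, Type

log: List[str] = []  # records the call order for this sketch


class ModelTrainerEvent:
    """Hypothetical stand-in for the real event base class."""

    def __init__(self, trainer: "ModelTrainer"):
        self.trainer = trainer

    def handle(self) -> None:
        raise NotImplementedError

    def next(self) -> None:
        raise NotImplementedError


class RoundFinished(ModelTrainerEvent):
    """Hypothetical event: do some work, then ask the trainer to continue."""

    def handle(self) -> None:
        log.append("handle")

    def next(self) -> None:
        log.append("next")
        self.trainer.start_round()


class ModelTrainer:
    def start_round(self) -> None:
        log.append("start_round")

    def handle(self, event: ModelTrainerEvent) -> None:
        # First process the event, then let it decide what follows.
        event.handle()
        event.next()

    def handle_cls(self, event_cls: Type[ModelTrainerEvent]) -> None:
        # Instantiate the event with this trainer, then handle it.
        self.handle(event_cls(self))


ModelTrainer().handle_cls(RoundFinished)
print(log)
```

Passing the event class rather than an instance means the caller does not need a trainer to build the event; the trainer supplies itself when it instantiates the class.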
start
start()

Start the training and send the training start and first round start notifications.

Source code in fl_server_ai/trainer/model_trainer.py
def start(self):
    """
    Start the training and send the training start and first round start notifications.
    """
    self.training.state = TrainingState.ONGOING
    self.training.save()
    TrainingStartNotification.from_training(self.training).send()
    TrainingRoundStartNotification.from_training(self.training).send()
start_round
start_round()

Start a training round and send a round start notification.

Source code in fl_server_ai/trainer/model_trainer.py
def start_round(self):
    """
    Start a training round and send a round start notification.
    """
    TrainingRoundStartNotification.from_training(self.training).send()
test_round
test_round()

Test a training round and send a model test notification.

Source code in fl_server_ai/trainer/model_trainer.py
def test_round(self):
    """
    Test a training round and send a model test notification.
    """
    TrainingModelTestNotification.from_training(self.training).send()
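The state changes and notifications produced by the lifecycle methods can be traced with a small recording sketch. Everything here is a stand-in: the `TrainingState` values, the in-memory `Training`, and the `sent` list (which replaces real notification delivery) are assumptions, and the order in which the methods are called at the end is just one plausible sequence.

```python
from enum import Enum
from typing import List


class TrainingState(Enum):
    # Hypothetical values; only the member names ONGOING and COMPLETED
    # appear in the listings.
    INITIAL = "initial"
    ONGOING = "ongoing"
    COMPLETED = "completed"


class Training:
    """Hypothetical in-memory stand-in for the real Training model."""

    def __init__(self) -> None:
        self.state = TrainingState.INITIAL
        self.saved_states: List[TrainingState] = []

    def save(self) -> None:
        self.saved_states.append(self.state)


sent: List[str] = []  # names of notifications that would have been sent


class RecordingTrainer:
    """Mirrors the lifecycle methods, recording instead of sending."""

    def __init__(self, training: Training) -> None:
        self.training = training

    def start(self) -> None:
        self.training.state = TrainingState.ONGOING
        self.training.save()
        sent.append("TrainingStartNotification")
        sent.append("TrainingRoundStartNotification")

    def start_round(self) -> None:
        sent.append("TrainingRoundStartNotification")

    def test_round(self) -> None:
        sent.append("TrainingModelTestNotification")

    def finish(self) -> None:
        self.training.state = TrainingState.COMPLETED
        self.training.save()
        sent.append("TrainingFinishedNotification")


training = Training()
trainer = RecordingTrainer(training)
trainer.start()        # also announces the first round
trainer.test_round()
trainer.start_round()  # second round
trainer.test_round()
trainer.finish()
print(sent)
```

Note that `start` already emits a round start notification, so `start_round` is only needed from the second round onward in this sequence.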

TrainerOptions dataclass

Trainer options including their default values.

Methods:

Name Description
__init__

Attributes:

Name Type Description
delete_local_models_after_aggregation bool

Flag indicating if local models should be deleted after aggregation.

model_test_after_each_round bool

Flag indicating if a model test should be performed after each round.

skip_model_tests bool

Flag indicating if model tests should be skipped.

Source code in fl_server_ai/trainer/options.py
@dataclass
class TrainerOptions:
    """
    Trainer options including their default values.
    """

    skip_model_tests: bool = False
    """Flag indicating if model tests should be skipped."""
    model_test_after_each_round: bool = True
    """Flag indicating if a model test should be performed after each round."""
    delete_local_models_after_aggregation: bool = True
    """Flag indicating if local models should be deleted after aggregation."""

Attributes

delete_local_models_after_aggregation class-attribute instance-attribute
delete_local_models_after_aggregation: bool = True

Flag indicating if local models should be deleted after aggregation.

model_test_after_each_round class-attribute instance-attribute
model_test_after_each_round: bool = True

Flag indicating if a model test should be performed after each round.

skip_model_tests class-attribute instance-attribute
skip_model_tests: bool = False

Flag indicating if model tests should be skipped.

Functions

__init__
__init__(skip_model_tests: bool = False, model_test_after_each_round: bool = True, delete_local_models_after_aggregation: bool = True) -> None
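Because `TrainerOptions` is a dataclass, individual defaults can be overridden by keyword or derived from an existing instance. The sketch below redefines the class locally, with the fields and defaults copied from the listing, so that it runs without the package installed; `dataclasses.replace` is standard library functionality and not something this module provides.

```python
from dataclasses import asdict, dataclass, replace


@dataclass
class TrainerOptions:
    """Local copy of the documented fields, for illustration only."""
    skip_model_tests: bool = False
    model_test_after_each_round: bool = True
    delete_local_models_after_aggregation: bool = True


# All defaults.
defaults = TrainerOptions()

# Override a single flag by keyword.
keep_models = TrainerOptions(delete_local_models_after_aggregation=False)

# Derive a modified copy without mutating the original.
no_round_tests = replace(defaults, model_test_after_each_round=False)

print(asdict(keep_models))
print(asdict(no_round_tests))
```

A dataclass instance is truthy unless it defines `__bool__` or `__len__`, so the `options if options else TrainerOptions()` check in `ModelTrainer.__init__` falls back to the defaults only when `None` is passed.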

Functions

process_trainer_task
process_trainer_task(training_uuid: UUID, event_cls: Type[ModelTrainerEvent])

Celery task that processes a dispatched trainer task.

Parameters:

Name Type Description Default
training_uuid
UUID

The UUID of the training.

required
event_cls
Type[ModelTrainerEvent]

The class of the event to handle.

required
Source code in fl_server_ai/trainer/tasks.py
@app.task(bind=False, ignore_result=False)
def process_trainer_task(training_uuid: UUID, event_cls: Type[ModelTrainerEvent]):
    """
    Celery task that processes a dispatched trainer task.

    Args:
        training_uuid (UUID): The UUID of the training.
        event_cls (Type[ModelTrainerEvent]): The class of the event to handle.
    """
    logger = get_task_logger("fl.celery")
    # Bind the name up front so the `finally` block cannot hit an unbound
    # variable when the lookup below raises.
    training = None
    try:
        training = Training.objects.get(id=training_uuid)
        ModelTrainer(training).handle_cls(event_cls)
    except Exception as e:
        error_msg = f"Exception occurred for training {training_uuid}: {e}"
        logger.error(error_msg)
        logger.debug(error_msg + "\n" + "".join(format_exception(e)))
        # Re-raise the active exception with its original traceback.
        raise
    finally:
        logger.info(f"Unlocking training {training_uuid}")
        if training:
            # Reload the training from the database before clearing the lock.
            training = Training.objects.get(id=training_uuid)
            training.locked = False
            training.save()
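The lock-release step in the task's `finally` block can be sketched in isolation. A `finally` block that reads a variable assigned inside `try` needs that variable bound beforehand, otherwise a failed lookup surfaces as an `UnboundLocalError` that masks the original exception. The in-memory `STORE`, the `Training` stand-in, and the `run_locked` helper below are all hypothetical and replace the database and Celery machinery.

```python
from typing import Callable, Dict, Optional
from uuid import UUID, uuid4


class Training:
    """Hypothetical in-memory stand-in for the real Training model."""

    def __init__(self, id: UUID) -> None:
        self.id = id
        self.locked = True


STORE: Dict[UUID, Training] = {}  # replaces the database in this sketch


def get_training(training_uuid: UUID) -> Training:
    # Raises KeyError for an unknown id, like a failed database lookup.
    return STORE[training_uuid]


def run_locked(training_uuid: UUID, work: Callable[[Training], None]) -> None:
    # Bind the name before `try` so `finally` can always inspect it.
    training: Optional[Training] = None
    try:
        training = get_training(training_uuid)
        work(training)
    finally:
        # Release the lock whether `work` succeeded or raised,
        # but only if the training was actually found.
        if training is not None:
            training.locked = False


def failing_work(training: Training) -> None:
    raise ValueError("simulated failure while handling the event")


known_id = uuid4()
STORE[known_id] = Training(known_id)

try:
    run_locked(known_id, failing_work)
except ValueError:
    pass  # the original error propagates after the lock is released

print(STORE[known_id].locked)
```

With an unknown id, `run_locked` propagates the `KeyError` from the lookup unchanged, because the `finally` block finds `training` bound to `None` and skips the unlock.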