Skip to content

Module fl_server_ai.trainer.events.model_test_finished

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from .base import ModelTrainerEvent


class ModelTestFinished(ModelTrainerEvent):
    """
    Event signalling that a model test has finished.

    Once handled, `next` either starts another training round or finishes
    the trainer, depending on the model's current `round` compared with the
    training's `target_num_updates`.
    """

    def next(self):
        """
        Proceed with the next event.

        Starts a new round via the trainer while the model's `round` is still
        below `target_num_updates`; otherwise tells the trainer to finish.
        """
        if self.training.model.round < self.training.target_num_updates:
            self.trainer.start_round()
            return
        self.trainer.finish()

    def handle(self):
        """
        Handle the event.

        Intentionally a no-op for now. Aggregating the common metrics could
        be done here in the future.
        """

Classes

ModelTestFinished

class ModelTestFinished(
    trainer: 'model_trainer.ModelTrainer'
)

Model test finished event.

View Source
class ModelTestFinished(ModelTrainerEvent):
    """
    Event signalling that a model test has finished.

    Once handled, `next` either starts another training round or finishes
    the trainer, depending on the model's current `round` compared with the
    training's `target_num_updates`.
    """

    def next(self):
        """
        Proceed with the next event.

        Starts a new round via the trainer while the model's `round` is still
        below `target_num_updates`; otherwise tells the trainer to finish.
        """
        if self.training.model.round < self.training.target_num_updates:
            self.trainer.start_round()
            return
        self.trainer.finish()

    def handle(self):
        """
        Handle the event.

        Intentionally a no-op for now. Aggregating the common metrics could
        be done here in the future.
        """

Ancestors (in MRO)

  • fl_server_ai.trainer.events.base.ModelTrainerEvent
  • abc.ABC

Methods

handle

def handle(
    self
)

Handle the event.

View Source
    def handle(self):
        """
        Handle the event.

        Intentionally a no-op for now. Aggregating the common metrics could
        be done here in the future.
        """

next

def next(
    self
)

Proceed with the next event.

View Source
    def next(self):
        """
        Proceed with the next event.

        Starts a new round via the trainer while the model's `round` is still
        below `target_num_updates`; otherwise tells the trainer to finish.
        """
        if self.training.model.round < self.training.target_num_updates:
            self.trainer.start_round()
            return
        self.trainer.finish()