Skip to content

Module fl_server_ai.trainer.tasks

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from celery.utils.log import get_task_logger
from django.db import transaction, DatabaseError
from logging import getLogger
from traceback import format_exception
from typing import Type
from uuid import UUID

from fl_server_core.models import Training

from ..celery_tasks import app

from .events import ModelTrainerEvent
from .model_trainer import ModelTrainer


@app.task(bind=False, ignore_result=False)
def process_trainer_task(training_uuid: UUID, event_cls: Type[ModelTrainerEvent]):
    """
    Celery task that processes a dispatched trainer task.

    The training is loaded by its UUID and handed to a `ModelTrainer`, which
    handles the given event class. Afterwards the training's `locked` flag is
    reset to `False`, regardless of whether handling succeeded or failed.

    Args:
        training_uuid (UUID): The UUID of the training.
        event_cls (Type[ModelTrainerEvent]): The class of the event to handle.

    Raises:
        Exception: Any exception raised while loading the training or handling
            the event is logged and then re-raised unchanged.
    """
    logger = get_task_logger("fl.celery")
    # Bind the name up front: if the lookup below fails, the `finally` block
    # must still be able to inspect `training` without raising an
    # UnboundLocalError that would mask the original exception.
    training = None
    try:
        training = Training.objects.get(id=training_uuid)
        ModelTrainer(training).handle_cls(event_cls)
    except Exception as e:
        error_msg = f"Exception occurred for training {training_uuid}: {e}"
        logger.error(error_msg)
        logger.debug(error_msg + "\n" + "".join(format_exception(e)))
        # bare `raise` re-raises the active exception with its original traceback
        raise
    finally:
        logger.info(f"Unlocking training {training_uuid}")
        if training is not None:
            # Re-fetch before saving; presumably so that changes persisted while
            # handling the event are not overwritten by the stale instance.
            training = Training.objects.get(id=training_uuid)
            training.locked = False
            training.save()


def dispatch_trainer_task(training: Training, event_cls: Type[ModelTrainerEvent], lock_training: bool):
    """
    Dispatch a trainer task asynchronously.

    If `lock_training` is set, the training is refreshed from the database and
    its `locked` flag is set inside a transaction before the task is dispatched.
    If the training is already locked, or locking fails with a database error,
    nothing is dispatched and the function returns silently.

    Args:
        training (Training): The training to dispatch the task for.
        event_cls (Type[ModelTrainerEvent]): The class of the event to handle.
        lock_training (bool): Whether to lock the training.
    """
    logger = getLogger("fl.server")
    if lock_training:
        try:
            with transaction.atomic():
                training.refresh_from_db()
                # Explicit check instead of `assert`: assert statements are
                # removed when Python runs with `-O`, which would silently
                # disable the lock check.
                if training.locked:
                    logger.debug(f"Training {training.id} is locked!")
                    return

                # set lock and do the aggregation
                training.locked = True
                training.save()

            logger.debug(f"Locking training {training.id}")
        except (DatabaseError, AssertionError):
            # AssertionError stays in the tuple so that any assertion raised
            # while refreshing/saving is still handled as before.
            logger.debug(f"Training {training.id} is locked!")
            return

    # start task async
    process_trainer_task.s(training_uuid=training.id, event_cls=event_cls).apply_async(retry=False)

Functions

dispatch_trainer_task

def dispatch_trainer_task(
    training: fl_server_core.models.training.Training,
    event_cls: Type[fl_server_ai.trainer.events.base.ModelTrainerEvent],
    lock_training: bool
)

Dispatch a trainer task asynchronously.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| training | Training | The training to dispatch the task for. | None |
| event_cls | Type[ModelTrainerEvent] | The class of the event to handle. | None |
| lock_training | bool | Whether to lock the training. | None |
View Source
def dispatch_trainer_task(training: Training, event_cls: Type[ModelTrainerEvent], lock_training: bool):
    """
    Dispatch a trainer task asynchronously.

    If `lock_training` is set, the training is refreshed from the database and
    its `locked` flag is set inside a transaction before the task is dispatched.
    If the training is already locked, or locking fails with a database error,
    nothing is dispatched and the function returns silently.

    Args:
        training (Training): The training to dispatch the task for.
        event_cls (Type[ModelTrainerEvent]): The class of the event to handle.
        lock_training (bool): Whether to lock the training.
    """
    logger = getLogger("fl.server")
    if lock_training:
        try:
            with transaction.atomic():
                training.refresh_from_db()
                # Explicit check instead of `assert`: assert statements are
                # removed when Python runs with `-O`, which would silently
                # disable the lock check.
                if training.locked:
                    logger.debug(f"Training {training.id} is locked!")
                    return

                # set lock and do the aggregation
                training.locked = True
                training.save()

            logger.debug(f"Locking training {training.id}")
        except (DatabaseError, AssertionError):
            # AssertionError stays in the tuple so that any assertion raised
            # while refreshing/saving is still handled as before.
            logger.debug(f"Training {training.id} is locked!")
            return

    # start task async
    process_trainer_task.s(training_uuid=training.id, event_cls=event_cls).apply_async(retry=False)