Skip to content

Module fl_server_ai.aggregation.base

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from logging import getLogger
import torch
from typing import Sequence


class Aggregation(ABC):
    """
    Interface implemented by every model-aggregation strategy.

    Concrete strategies subclass this and override :meth:`aggregate`.
    Because that method is abstract, this class cannot be instantiated
    directly.
    """

    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True,
    ) -> torch.nn.Module:
        """
        Combine several models into one aggregated model.

        This base declaration has no implementation. If a subclass reaches it
        via ``super()``, it simply returns ``None``.

        Args:
            models (Sequence[torch.nn.Module]): The models to combine.
            model_sample_sizes (Sequence[int]): The sample size belonging to
                each model (presumably used for weighting — confirm in the
                concrete strategy).
            deepcopy (bool, optional): Keyword-only; whether the models should
                be deep-copied. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.
        """
        ...

    # Logger shared by all strategies, registered under the name "fl.server".
    _logger = getLogger("fl.server")

Classes

Aggregation

class Aggregation(
    /,
    *args,
    **kwargs
)

Abstract base class for aggregation strategies.

View Source
class Aggregation(ABC):
    """
    Interface implemented by every model-aggregation strategy.

    Concrete strategies subclass this and override :meth:`aggregate`.
    Because that method is abstract, this class cannot be instantiated
    directly.
    """

    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True,
    ) -> torch.nn.Module:
        """
        Combine several models into one aggregated model.

        This base declaration has no implementation. If a subclass reaches it
        via ``super()``, it simply returns ``None``.

        Args:
            models (Sequence[torch.nn.Module]): The models to combine.
            model_sample_sizes (Sequence[int]): The sample size belonging to
                each model (presumably used for weighting — confirm in the
                concrete strategy).
            deepcopy (bool, optional): Keyword-only; whether the models should
                be deep-copied. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.
        """
        ...

    # Logger shared by all strategies, registered under the name "fl.server".
    _logger = getLogger("fl.server")

Ancestors (in MRO)

  • abc.ABC

Descendants

  • fl_server_ai.aggregation.mean.MeanAggregation

Methods

aggregate

def aggregate(
    self,
    models: Sequence[torch.nn.modules.module.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.modules.module.Module

Abstract method for aggregating models.

Parameters:

Name Type Description Default
models Sequence[torch.nn.Module] The models to be aggregated. required
model_sample_sizes Sequence[int] The sample sizes for each model. required
deepcopy bool Keyword-only. Whether to create a deep copy of the models. True

Returns:

Type Description
torch.nn.Module The aggregated model.
View Source
    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True,
    ) -> torch.nn.Module:
        """
        Combine several models into one aggregated model.

        This base declaration has no implementation. If a subclass reaches it
        via ``super()``, it simply returns ``None``.

        Args:
            models (Sequence[torch.nn.Module]): The models to combine.
            model_sample_sizes (Sequence[int]): The sample size belonging to
                each model.
            deepcopy (bool, optional): Keyword-only; whether the models should
                be deep-copied. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.
        """
        ...