Skip to content

Module fl_server_ai.aggregation

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from .base import Aggregation
from .mean import MeanAggregation
from .method import get_aggregation_class


__all__ = ["Aggregation", "get_aggregation_class", "MeanAggregation"]

Sub-modules

Functions

get_aggregation_class

def get_aggregation_class(
    value: fl_server_core.models.training.AggregationMethod | fl_server_core.models.model.Model | fl_server_core.models.training.Training
) -> Type[fl_server_ai.aggregation.base.Aggregation]

Get the aggregation class based on the provided model, training or aggregation method.

Parameters:

Name Type Description Default
value AggregationMethod | Model | Training The value based on which the aggregation class is determined. required

Returns:

Type Description
Type[Aggregation] The type of the aggregation class.

Raises:

Type Description
ValueError If the type of value or the aggregation method is unknown.
View Source
def get_aggregation_class(value: AggregationMethod | Model | Training) -> Type[Aggregation]:
    """
    Get the aggregation class based on the provided model, training or aggregation method.

    Args:
        value (AggregationMethod | Model | Training): The value based on which the aggregation class is determined.

    Returns:
        Type[Aggregation]: The type of the aggregation class.

    Raises:
        ValueError: If the type of value or the aggregation method is unknown.
    """
    if isinstance(value, AggregationMethod):
        method = value
    elif isinstance(value, Training):
        method = value.aggregation_method
    elif isinstance(value, Model):
        aggregation_method = Training.objects.filter(model=value) \
                .values("aggregation_method") \
                .first()["aggregation_method"]
        method = aggregation_method
    else:
        raise ValueError(f"Unknown type: {type(value)}")

    match method:
        case AggregationMethod.FED_AVG: return MeanAggregation
        case AggregationMethod.FED_DC: return FedDC
        case AggregationMethod.FED_PROX: return FedProx
        case _: raise ValueError(f"Unknown aggregation method: {method}")

Classes

Aggregation

class Aggregation(
    /,
    *args,
    **kwargs
)

Abstract base class for aggregation strategies.

View Source
class Aggregation(ABC):
    """Abstract base class for aggregation strategies.

    Concrete strategies implement :meth:`aggregate`, which combines several
    client models into a single aggregated model.
    """

    # Shared logger for all aggregation strategies.
    _logger = getLogger("fl.server")

    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """Combine the given models into one aggregated model.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
            deepcopy (bool, optional): Whether to create a deep copy of the models. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.
        """
        ...

Ancestors (in MRO)

  • abc.ABC

Descendants

  • fl_server_ai.aggregation.MeanAggregation

Methods

aggregate

def aggregate(
    self,
    models: Sequence[torch.nn.modules.module.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.modules.module.Module

Abstract method for aggregating models.

Parameters:

Name Type Description Default
models Sequence[torch.nn.Module] The models to be aggregated. required
model_sample_sizes Sequence[int] The sample sizes for each model. required
deepcopy bool Whether to create a deep copy of the models. Defaults to True. True

Returns:

Type Description
torch.nn.Module The aggregated model.
View Source
    # NOTE(review): duplicated "View Source" rendering of Aggregation.aggregate;
    # this is the abstract interface only — subclasses supply the weighting logic.
    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Abstract method for aggregating models.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
            deepcopy (bool, optional): Whether to create a deep copy of the models. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.
        """
        pass

MeanAggregation

class MeanAggregation(
    /,
    *args,
    **kwargs
)

Implements the aggregate method for aggregating models by calculating their mean.

View Source
class MeanAggregation(Aggregation):
    """
    Implements the aggregate method for aggregating models by calculating their mean.
    """

    @torch.no_grad()
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Aggregate models by calculating the sample-size-weighted mean.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
            deepcopy (bool, optional): Whether to create a deep copy of the models. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.

        Raises:
            AggregationException: If the models do not have the same architecture.
        """
        assert len(models) == len(model_sample_sizes)

        self._logger.debug(f"Doing mean aggregation for {len(models)} models!")
        model_state_dicts = [model.state_dict() for model in models]

        # state_dict() tensors alias the live model parameters, so in-place
        # arithmetic on them would silently corrupt models[0] (even when
        # deepcopy=True, since the deepcopy only happens at the end). Build
        # the accumulator from fresh tensors instead: `tensor * size`
        # allocates a new tensor, leaving the inputs untouched.
        total_dataset_size = model_sample_sizes[0]
        result_dict = {
            layer_name: tensor * model_sample_sizes[0]
            for layer_name, tensor in model_state_dicts[0].items()
        }

        # sum accumulation
        for model_dict, dataset_size in zip(model_state_dicts[1:], model_sample_sizes[1:]):
            if set(model_dict.keys()) != set(result_dict.keys()):
                raise AggregationException("Models do not have the same architecture!")

            total_dataset_size += dataset_size
            for layer_name in result_dict:
                # += is safe here: it targets the freshly allocated accumulator,
                # not the models' own parameter tensors.
                result_dict[layer_name] += model_dict[layer_name] * dataset_size

        # factor 1/n
        for layer_name in result_dict:
            result_dict[layer_name] = result_dict[layer_name] / total_dataset_size

        # return aggregated model
        result_model = copy.deepcopy(models[0]) if deepcopy else models[0]
        result_model.load_state_dict(result_dict)
        return result_model

Ancestors (in MRO)

  • fl_server_ai.aggregation.Aggregation
  • abc.ABC

Descendants

  • fl_server_ai.aggregation.fed_dc.FedDC
  • fl_server_ai.aggregation.fed_prox.FedProx

Methods

aggregate

def aggregate(
    self,
    models: Sequence[torch.nn.modules.module.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.modules.module.Module

Aggregate models by calculating the mean.

Parameters:

Name Type Description Default
models Sequence[torch.nn.Module] The models to be aggregated. required
model_sample_sizes Sequence[int] The sample sizes for each model. required
deepcopy bool Whether to create a deep copy of the models. Defaults to True. True

Returns:

Type Description
torch.nn.Module The aggregated model.

Raises:

Type Description
AggregationException If the models do not have the same architecture.
View Source
    # NOTE(review): duplicated "View Source" rendering of MeanAggregation.aggregate.
    @torch.no_grad()  # pure tensor bookkeeping; no autograd graph needed
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Aggregate models by calculating the mean.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
            deepcopy (bool, optional): Whether to create a deep copy of the models. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.

        Raises:
            AggregationException: If the models do not have the same architecture.
        """
        # one sample size per model is required
        assert len(models) == len(model_sample_sizes)

        self._logger.debug(f"Doing mean aggregation for {len(models)} models!")
        model_state_dicts = [model.state_dict() for model in models]

        # NOTE(review): state_dict() tensors alias the model parameters, so the
        # in-place *= below (and the += accumulation) also mutate models[0]
        # itself — confirm this side effect on the input model is intended.
        total_dataset_size = model_sample_sizes[0]
        result_dict = model_state_dicts[0]
        for layer_name in result_dict:
            result_dict[layer_name] *= model_sample_sizes[0]

        # sum accumulation
        for model_dict, dataset_size in zip(model_state_dicts[1:], model_sample_sizes[1:]):
            if set(model_dict.keys()) != set(result_dict.keys()):
                raise AggregationException("Models do not have the same architecture!")

            total_dataset_size += dataset_size
            for layer_name in result_dict:
                result_dict[layer_name] += model_dict[layer_name] * dataset_size

        # factor 1/n
        for layer_name in result_dict:
            result_dict[layer_name] = result_dict[layer_name] / total_dataset_size

        # return aggregated model
        result_model = copy.deepcopy(models[0]) if deepcopy else models[0]
        result_model.load_state_dict(result_dict)
        return result_model