
Module fl_server_ai.aggregation.fed_dc

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from .mean import MeanAggregation


class FedDC(MeanAggregation):
    """
    FedDC (Federated daisy-chaining)

    Client datasets are potentially quite small, so the client models tend to overfit and
    generalize poorly to unseen data. FedDC (also named Federated daisy-chaining) is one
    proposed solution to this problem.
    Before each aggregation step, FedDC sends each client model to another randomly selected
    client, which trains it on its local data.
    From the model's perspective, it is as if the model were trained on a larger dataset.

    Paper: [Picking Daisies in Private: Federated Learning from Small Datasets](https://openreview.net/forum?id=GVDwiINkMR)
    """  # noqa: E501
    pass

Classes

FedDC

class FedDC(
    /,
    *args,
    **kwargs
)

FedDC (Federated daisy-chaining)

Client datasets are potentially quite small, so the client models tend to overfit and generalize poorly to unseen data. FedDC (also named Federated daisy-chaining) is one proposed solution to this problem. Before each aggregation step, FedDC sends each client model to another randomly selected client, which trains it on its local data. From the model's perspective, it is as if the model were trained on a larger dataset.

Paper: Picking Daisies in Private: Federated Learning from Small Datasets (https://openreview.net/forum?id=GVDwiINkMR)
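
The daisy-chaining step itself is not implemented in this module. The following is a minimal sketch of the idea only, assuming hypothetical client objects with a train_locally method and that models are simply permuted among clients each round:

    import random
    from typing import List, Sequence

    import torch


    def daisy_chain_round(models: List[torch.nn.Module], clients: Sequence) -> List[torch.nn.Module]:
        """Illustrative sketch: reassign each model to a randomly selected client for one round of local training."""
        assert len(models) == len(clients)
        permutation = list(range(len(clients)))
        random.shuffle(permutation)  # random reassignment of models to clients
        for model_idx, client_idx in enumerate(permutation):
            # `train_locally` is a hypothetical client method that trains the given model on the
            # client's local data and returns it. A plain permutation may occasionally map a model
            # back to its own client; a stricter implementation could exclude such fixed points.
            models[model_idx] = clients[client_idx].train_locally(models[model_idx])
        return models

After such a round, the server aggregates the (daisy-chained) models with the inherited mean aggregation shown below.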

View Source
class FedDC(MeanAggregation):
    """
    FedDC (Federated daisy-chaining)

    Client datasets are potentially quite small, so the client models tend to overfit and
    generalize poorly to unseen data. FedDC (also named Federated daisy-chaining) is one
    proposed solution to this problem.
    Before each aggregation step, FedDC sends each client model to another randomly selected
    client, which trains it on its local data.
    From the model's perspective, it is as if the model were trained on a larger dataset.

    Paper: [Picking Daisies in Private: Federated Learning from Small Datasets](https://openreview.net/forum?id=GVDwiINkMR)
    """  # noqa: E501
    pass

Ancestors (in MRO)

  • fl_server_ai.aggregation.mean.MeanAggregation
  • fl_server_ai.aggregation.base.Aggregation
  • abc.ABC

Methods

aggregate

def aggregate(
    self,
    models: Sequence[torch.nn.modules.module.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.modules.module.Module

Aggregate models by calculating the mean.

Parameters:

Name Type Description Default
models Sequence[torch.nn.Module] The models to be aggregated. required
model_sample_sizes Sequence[int] The sample sizes for each model. required
deepcopy bool Whether to create a deep copy of the models. Defaults to True. True

Returns:

Type Description
torch.nn.Module The aggregated model.

Raises:

Type Description
AggregationException If the models do not have the same architecture.
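
As the source below shows, each layer of the result is the sample-size-weighted mean of the corresponding layers: result = (n_1 * theta_1 + ... + n_k * theta_k) / (n_1 + ... + n_k), where theta_i is the state dict of model i and n_i its sample size. For example, aggregating two models with sample sizes 100 and 300 weights their parameters by 0.25 and 0.75, respectively.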
View Source
    @torch.no_grad()
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Aggregate models by calculating the mean.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
            deepcopy (bool, optional): Whether to create a deep copy of the models. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.

        Raises:
            AggregationException: If the models do not have the same architecture.
        """
        assert len(models) == len(model_sample_sizes)

        self._logger.debug(f"Doing mean aggregation for {len(models)} models!")
        model_state_dicts = [model.state_dict() for model in models]

        total_dataset_size = model_sample_sizes[0]
        result_dict = model_state_dicts[0]
        for layer_name in result_dict:
            result_dict[layer_name] *= model_sample_sizes[0]

        # sum accumulation
        for model_dict, dataset_size in zip(model_state_dicts[1:], model_sample_sizes[1:]):
            if set(model_dict.keys()) != set(result_dict.keys()):
                raise AggregationException("Models do not have the same architecture!")

            total_dataset_size += dataset_size
            for layer_name in result_dict:
                result_dict[layer_name] += model_dict[layer_name] * dataset_size

        # factor 1/n
        for layer_name in result_dict:
            result_dict[layer_name] = result_dict[layer_name] / total_dataset_size

        # return aggregated model
        result_model = copy.deepcopy(models[0]) if deepcopy else models[0]
        result_model.load_state_dict(result_dict)
        return result_model
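
A short usage sketch of the inherited aggregate method, assuming FedDC can be instantiated without constructor arguments (its __init__ is not shown here); the toy models and sample sizes are illustrative:

    import torch

    from fl_server_ai.aggregation.fed_dc import FedDC

    # Two toy client models with the same architecture but different (random) weights.
    model_a = torch.nn.Linear(2, 1)
    model_b = torch.nn.Linear(2, 1)

    # model_b was trained on three times as many samples, so it contributes with weight 0.75.
    aggregated = FedDC().aggregate(
        models=[model_a, model_b],
        model_sample_sizes=[100, 300],
        deepcopy=True,  # leave the input models untouched
    )

    print(aggregated.state_dict())  # sample-size-weighted mean of both state dicts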