
Module fl_server_ai.aggregation.fed_prox

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from .mean import MeanAggregation


class FedProx(MeanAggregation):
    """
    To tackle the problem that client models drift away from the optimum due to data heterogeneity,
    different learning speeds, etc., one proposed solution is FedProx.
    FedProx limits client drift by using a modified learning objective while keeping the standard
    FedAvg aggregation method.

    Note:
    On the server side, FedProx does nothing different from standard FedAvg.
    The difference lies in the application of a special loss function on the client side.

    Paper: [Federated Optimization in Heterogeneous Networks](https://arxiv.org/abs/1812.06127)
    """
    pass

Classes

FedProx

class FedProx(
    /,
    *args,
    **kwargs
)

To tackle the problem that client models drift away from the optimum due to data heterogeneity, different learning speeds, etc., one proposed solution is FedProx. FedProx limits client drift by using a modified learning objective while keeping the standard FedAvg aggregation method.

Note: On the server side, FedProx does nothing different from standard FedAvg. The difference lies in the application of a special loss function on the client side.

Paper: [Federated Optimization in Heterogeneous Networks](https://arxiv.org/abs/1812.06127)
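
Since the server keeps plain FedAvg aggregation, the FedProx-specific logic lives entirely in the client's training objective. Below is a minimal sketch of the proximal loss term described in the paper; the function name, the mu value, and the way the global weights are passed in are illustrative assumptions, not part of this module:

import torch
from typing import Iterable

def fedprox_loss(task_loss: torch.Tensor,
                 local_params: Iterable[torch.nn.Parameter],
                 global_params: Iterable[torch.Tensor],
                 mu: float = 0.01) -> torch.Tensor:
    # FedProx objective: task loss plus (mu / 2) * ||w - w_global||^2,
    # penalizing local weights for drifting away from the global model.
    prox = sum(torch.sum((w - w_g.detach()) ** 2)
               for w, w_g in zip(local_params, global_params))
    return task_loss + (mu / 2) * prox

With mu = 0 the proximal term vanishes and the client objective reduces to ordinary FedAvg training.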

View Source
class FedProx(MeanAggregation):
    """
    To tackle the problem that client models drift away from the optimum due to data heterogeneity,
    different learning speeds, etc., one proposed solution is FedProx.
    FedProx limits client drift by using a modified learning objective while keeping the standard
    FedAvg aggregation method.

    Note:
    On the server side, FedProx does nothing different from standard FedAvg.
    The difference lies in the application of a special loss function on the client side.

    Paper: [Federated Optimization in Heterogeneous Networks](https://arxiv.org/abs/1812.06127)
    """
    pass

Ancestors (in MRO)

  • fl_server_ai.aggregation.mean.MeanAggregation
  • fl_server_ai.aggregation.base.Aggregation
  • abc.ABC

Methods

aggregate

def aggregate(
    self,
    models: Sequence[torch.nn.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.Module

Aggregate models by calculating the mean.

Parameters:

Name                Type                        Description                                    Default
models              Sequence[torch.nn.Module]   The models to be aggregated.                   None
model_sample_sizes  Sequence[int]               The sample sizes for each model.               None
deepcopy            bool                        Whether to create a deep copy of the models.   True

Returns:

Type             Description
torch.nn.Module  The aggregated model.

Raises:

Type                  Description
AggregationException  If the models do not have the same architecture.
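
For orientation, here is a hedged usage sketch. It assumes FedProx can be constructed with no arguments, as the generated signature above suggests, and uses toy models for illustration:

import torch
from fl_server_ai.aggregation.fed_prox import FedProx

# Three clients with the same architecture but different dataset sizes.
models = [torch.nn.Linear(4, 2) for _ in range(3)]
sample_sizes = [100, 50, 150]

aggregator = FedProx()
global_model = aggregator.aggregate(models, sample_sizes, deepcopy=True)

Because model_sample_sizes enters as a weight, clients that trained on more data pull the aggregate more strongly toward their parameters.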
View Source
    @torch.no_grad()
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Aggregate models by calculating the mean.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
            deepcopy (bool, optional): Whether to create a deep copy of the models. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.

        Raises:
            AggregationException: If the models do not have the same architecture.
        """
        assert len(models) == len(model_sample_sizes)

        self._logger.debug(f"Doing mean aggregation for {len(models)} models!")
        model_state_dicts = [model.state_dict() for model in models]

        # start the accumulator from the first model's state dict, scaled by its sample size
        total_dataset_size = model_sample_sizes[0]
        result_dict = model_state_dicts[0]
        for layer_name in result_dict:
            result_dict[layer_name] *= model_sample_sizes[0]

        # accumulate the sample-size-weighted sum of the remaining models
        for model_dict, dataset_size in zip(model_state_dicts[1:], model_sample_sizes[1:]):
            if set(model_dict.keys()) != set(result_dict.keys()):
                raise AggregationException("Models do not have the same architecture!")

            total_dataset_size += dataset_size
            for layer_name in result_dict:
                result_dict[layer_name] += model_dict[layer_name] * dataset_size

        # divide by the total sample count (the 1/n factor)
        for layer_name in result_dict:
            result_dict[layer_name] = result_dict[layer_name] / total_dataset_size

        # return aggregated model
        result_model = copy.deepcopy(models[0]) if deepcopy else models[0]
        result_model.load_state_dict(result_dict)
        return result_model
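
The listing computes a per-layer, sample-size-weighted mean. A minimal standalone check of that arithmetic with two single-tensor "models" (toy numbers, independent of this module):

import torch

# Weighted mean: w = (n1 * w1 + n2 * w2) / (n1 + n2)
w1 = torch.tensor([1.0, 3.0])
w2 = torch.tensor([5.0, 7.0])
n1, n2 = 100, 300

w = (n1 * w1 + n2 * w2) / (n1 + n2)
print(w)  # tensor([4., 6.]), pulled toward w2, which saw three times the data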