Module fl_server_ai.aggregation.fed_prox
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from .mean import MeanAggregation


class FedProx(MeanAggregation):
    """
    To tackle the problem that client models drift away from the optimum due to data heterogeneity,
    different learning speeds, etc., one proposed solution is FedProx.
    FedProx limits client drift by using a modified learning objective while keeping the standard
    FedAvg aggregation method.

    Note:
        On the server side, FedProx does nothing different from standard FedAvg.
        The difference lies in a special loss function applied on the client side.

    Paper: [Federated Optimization in Heterogeneous Networks](https://arxiv.org/abs/1812.06127)
    """
    pass
Classes
FedProx
To tackle the problem that client models drift away from the optimum due to data heterogeneity,
different learning speeds, etc., one proposed solution is FedProx. FedProx limits client drift by using a modified learning objective while keeping the standard FedAvg aggregation method.
Note: On the server side, FedProx does nothing different from standard FedAvg. The difference lies in a special loss function applied on the client side.
Paper: [Federated Optimization in Heterogeneous Networks](https://arxiv.org/abs/1812.06127)
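Since the server side is plain FedAvg, the FedProx-specific work happens in each client's training loop, where the local objective F_k(w) is augmented with a proximal term: h_k(w; w^t) = F_k(w) + (mu / 2) * ||w - w^t||^2, with w^t the global weights of the current round. The following is a minimal illustrative sketch of such a client-side loss; it is not part of this module, and the helper name fedprox_loss as well as the default mu are assumptions.

import torch

def fedprox_loss(
    task_loss: torch.Tensor,
    local_model: torch.nn.Module,
    global_model: torch.nn.Module,
    mu: float = 0.01,
) -> torch.Tensor:
    # Illustrative only (not part of fl_server_ai): add the FedProx
    # proximal term (mu / 2) * ||w - w^t||^2 to an ordinary task loss.
    proximal = 0.0
    for w, w_t in zip(local_model.parameters(), global_model.parameters()):
        # w^t stays fixed during the local update, hence detach()
        proximal = proximal + torch.sum((w - w_t.detach()) ** 2)
    return task_loss + (mu / 2.0) * proximal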
Ancestors (in MRO)
- fl_server_ai.aggregation.mean.MeanAggregation
- fl_server_ai.aggregation.base.Aggregation
- abc.ABC
Methods
aggregate
def aggregate(
    self,
    models: Sequence[torch.nn.modules.module.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.modules.module.Module
Aggregate models by calculating the sample-size-weighted mean.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
models | Sequence[torch.nn.Module] | The models to be aggregated. | required |
model_sample_sizes | Sequence[int] | The sample sizes for each model. | required |
deepcopy | bool | Whether to create a deep copy of the models. Defaults to True. | True |
Returns:
Type | Description |
---|---|
torch.nn.Module | The aggregated model. |
Raises:
Type | Description |
---|---|
AggregationException | If the models do not have the same architecture. |
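For reference, the aggregation computes a per-layer weighted mean of the client parameters, where each model's weights w_k are scaled by its sample size n_k (this restates the accumulate-then-divide loop in the source below):

w_agg = (Σ_k n_k · w_k) / (Σ_k n_k)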
View Source
@torch.no_grad()
def aggregate(
    self,
    models: Sequence[torch.nn.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.Module:
    """
    Aggregate models by calculating the sample-size-weighted mean.

    Args:
        models (Sequence[torch.nn.Module]): The models to be aggregated.
        model_sample_sizes (Sequence[int]): The sample sizes for each model.
        deepcopy (bool, optional): Whether to create a deep copy of the models. Defaults to True.

    Returns:
        torch.nn.Module: The aggregated model.

    Raises:
        AggregationException: If the models do not have the same architecture.
    """
    assert len(models) == len(model_sample_sizes)
    self._logger.debug(f"Doing mean aggregation for {len(models)} models!")
    model_state_dicts = [model.state_dict() for model in models]
    total_dataset_size = model_sample_sizes[0]
    result_dict = model_state_dicts[0]
    # scale the first model's parameters by its sample size
    for layer_name in result_dict:
        result_dict[layer_name] *= model_sample_sizes[0]
    # sum accumulation: add the remaining models, each weighted by its sample size
    for model_dict, dataset_size in zip(model_state_dicts[1:], model_sample_sizes[1:]):
        if set(model_dict.keys()) != set(result_dict.keys()):
            raise AggregationException("Models do not have the same architecture!")
        total_dataset_size += dataset_size
        for layer_name in result_dict:
            result_dict[layer_name] += model_dict[layer_name] * dataset_size
    # factor 1/n: divide by the total sample size to obtain the weighted mean
    for layer_name in result_dict:
        result_dict[layer_name] = result_dict[layer_name] / total_dataset_size
    # return aggregated model
    result_model = copy.deepcopy(models[0]) if deepcopy else models[0]
    result_model.load_state_dict(result_dict)
    return result_model
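A minimal usage sketch, assuming two clients that trained the same tiny architecture and that the aggregator can be constructed without arguments (the models and sample sizes below are made up for illustration):

import torch
from fl_server_ai.aggregation.fed_prox import FedProx

# two client models with an identical (toy) architecture
model_a = torch.nn.Linear(4, 2)
model_b = torch.nn.Linear(4, 2)

aggregator = FedProx()  # assumption: no constructor arguments required
# model_a saw 300 samples, model_b saw 100, so model_a contributes
# with weight 300/400 = 0.75 to every aggregated parameter
global_model = aggregator.aggregate([model_a, model_b], [300, 100])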