Module fl_server_ai.aggregation.mean¶
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
import copy
import torch
from typing import Sequence
from ..exceptions import AggregationException
from .base import Aggregation
class MeanAggregation(Aggregation):
    """
    Implements the aggregate method for aggregating models by calculating their mean.
    """

    @torch.no_grad()
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Aggregate models by calculating the mean.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
            deepcopy (bool, optional): Whether to create a deep copy of the models. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.

        Raises:
            AggregationException: If the models do not have the same architecture.
        """
        assert len(models) == len(model_sample_sizes)
        self._logger.debug(f"Doing mean aggregation for {len(models)} models!")
        model_state_dicts = [model.state_dict() for model in models]
        total_dataset_size = model_sample_sizes[0]
        result_dict = model_state_dicts[0]
        # weight the first model's parameters by its sample size
        for layer_name in result_dict:
            result_dict[layer_name] *= model_sample_sizes[0]
        # accumulate the sample-size-weighted sum of the remaining models
        for model_dict, dataset_size in zip(model_state_dicts[1:], model_sample_sizes[1:]):
            if set(model_dict.keys()) != set(result_dict.keys()):
                raise AggregationException("Models do not have the same architecture!")
            total_dataset_size += dataset_size
            for layer_name in result_dict:
                result_dict[layer_name] += model_dict[layer_name] * dataset_size
        # divide by the total sample size to obtain the weighted mean
        for layer_name in result_dict:
            result_dict[layer_name] = result_dict[layer_name] / total_dataset_size
        # load the averaged weights into the result model
        result_model = copy.deepcopy(models[0]) if deepcopy else models[0]
        result_model.load_state_dict(result_dict)
        return result_model
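A minimal usage sketch (not part of the module source), assuming MeanAggregation needs no constructor arguments; two single-layer models are averaged with sample sizes 1 and 3:

import torch
from fl_server_ai.aggregation.mean import MeanAggregation

# two models with identical architecture but different weights
model_a = torch.nn.Linear(2, 1, bias=False)
model_b = torch.nn.Linear(2, 1, bias=False)
with torch.no_grad():
    model_a.weight.fill_(1.0)
    model_b.weight.fill_(4.0)

aggregated = MeanAggregation().aggregate(
    [model_a, model_b],
    model_sample_sizes=[1, 3],
)
print(aggregated.weight)  # (1 * 1.0 + 3 * 4.0) / (1 + 3) = 3.25 for every entry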
Classes¶
MeanAggregation¶
Implements the aggregate method for aggregating models by calculating their mean.
View Source
class MeanAggregation(Aggregation):
    """
    Implements the aggregate method for aggregating models by calculating their mean.
    """

    @torch.no_grad()
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Aggregate models by calculating the mean.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
            deepcopy (bool, optional): Whether to create a deep copy of the models. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.

        Raises:
            AggregationException: If the models do not have the same architecture.
        """
        assert len(models) == len(model_sample_sizes)
        self._logger.debug(f"Doing mean aggregation for {len(models)} models!")
        model_state_dicts = [model.state_dict() for model in models]
        total_dataset_size = model_sample_sizes[0]
        result_dict = model_state_dicts[0]
        # weight the first model's parameters by its sample size
        for layer_name in result_dict:
            result_dict[layer_name] *= model_sample_sizes[0]
        # accumulate the sample-size-weighted sum of the remaining models
        for model_dict, dataset_size in zip(model_state_dicts[1:], model_sample_sizes[1:]):
            if set(model_dict.keys()) != set(result_dict.keys()):
                raise AggregationException("Models do not have the same architecture!")
            total_dataset_size += dataset_size
            for layer_name in result_dict:
                result_dict[layer_name] += model_dict[layer_name] * dataset_size
        # divide by the total sample size to obtain the weighted mean
        for layer_name in result_dict:
            result_dict[layer_name] = result_dict[layer_name] / total_dataset_size
        # load the averaged weights into the result model
        result_model = copy.deepcopy(models[0]) if deepcopy else models[0]
        result_model.load_state_dict(result_dict)
        return result_model
Ancestors (in MRO)¶
- fl_server_ai.aggregation.base.Aggregation
- abc.ABC
Descendants¶
- fl_server_ai.aggregation.fed_dc.FedDC
- fl_server_ai.aggregation.fed_prox.FedProx
Methods¶
aggregate¶
def aggregate(
    self,
    models: Sequence[torch.nn.modules.module.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.modules.module.Module
Aggregate models by calculating the mean.
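Concretely, the aggregation is a sample-size-weighted mean over the entries of the models' state dicts: for parameter tensors $\theta_i$ from $k$ models with sample sizes $n_i$,

$$
\theta_{\text{agg}} = \frac{\sum_{i=1}^{k} n_i \, \theta_i}{\sum_{i=1}^{k} n_i}
$$

With equal sample sizes this reduces to the plain arithmetic mean.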
Parameters:
Name | Type | Description | Default |
---|---|---|---|
models | Sequence[torch.nn.Module] | The models to be aggregated. | required |
model_sample_sizes | Sequence[int] | The sample sizes for each model. | required |
deepcopy | bool | Whether to create a deep copy of the models. | True |
Returns:
Type | Description |
---|---|
torch.nn.Module | The aggregated model. |
Raises:
Type | Description |
---|---|
AggregationException | If the models do not have the same architecture. |
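A short failure-case sketch with hypothetical models; the AggregationException import path is inferred from the module's relative import (from ..exceptions):

import torch
from fl_server_ai.exceptions import AggregationException
from fl_server_ai.aggregation.mean import MeanAggregation

flat = torch.nn.Linear(2, 1)                          # state dict keys: 'weight', 'bias'
wrapped = torch.nn.Sequential(torch.nn.Linear(2, 1))  # state dict keys: '0.weight', '0.bias'

try:
    MeanAggregation().aggregate([flat, wrapped], model_sample_sizes=[10, 10])
except AggregationException as exc:
    print(exc)  # Models do not have the same architecture!

Note that the check compares state dict key sets only; models with identical keys but mismatched tensor shapes pass it and instead fail with a RuntimeError during the weighted-sum accumulation.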
View Source
@torch.no_grad()
def aggregate(
    self,
    models: Sequence[torch.nn.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.Module:
    """
    Aggregate models by calculating the mean.

    Args:
        models (Sequence[torch.nn.Module]): The models to be aggregated.
        model_sample_sizes (Sequence[int]): The sample sizes for each model.
        deepcopy (bool, optional): Whether to create a deep copy of the models. Defaults to True.

    Returns:
        torch.nn.Module: The aggregated model.

    Raises:
        AggregationException: If the models do not have the same architecture.
    """
    assert len(models) == len(model_sample_sizes)
    self._logger.debug(f"Doing mean aggregation for {len(models)} models!")
    model_state_dicts = [model.state_dict() for model in models]
    total_dataset_size = model_sample_sizes[0]
    result_dict = model_state_dicts[0]
    # weight the first model's parameters by its sample size
    for layer_name in result_dict:
        result_dict[layer_name] *= model_sample_sizes[0]
    # accumulate the sample-size-weighted sum of the remaining models
    for model_dict, dataset_size in zip(model_state_dicts[1:], model_sample_sizes[1:]):
        if set(model_dict.keys()) != set(result_dict.keys()):
            raise AggregationException("Models do not have the same architecture!")
        total_dataset_size += dataset_size
        for layer_name in result_dict:
            result_dict[layer_name] += model_dict[layer_name] * dataset_size
    # divide by the total sample size to obtain the weighted mean
    for layer_name in result_dict:
        result_dict[layer_name] = result_dict[layer_name] / total_dataset_size
    # load the averaged weights into the result model
    result_model = copy.deepcopy(models[0]) if deepcopy else models[0]
    result_model.load_state_dict(result_dict)
    return result_model
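A note on deepcopy: with deepcopy=False the first input model doubles as the output container, which avoids a copy when the input models are disposable. A minimal sketch, again assuming a no-argument constructor:

import torch
from fl_server_ai.aggregation.mean import MeanAggregation

models = [torch.nn.Linear(4, 2) for _ in range(3)]
aggregated = MeanAggregation().aggregate(models, model_sample_sizes=[5, 5, 10], deepcopy=False)
assert aggregated is models[0]  # the averaged weights were loaded into the first model

Independently of this flag, the weighted sums are accumulated in place on tensors from models[0].state_dict(), which share storage with that model's parameters, so the first input model's weights are modified during aggregation even when deepcopy=True.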