Module fl_server_ai.aggregation¶
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from .base import Aggregation
from .mean import MeanAggregation
from .method import get_aggregation_class
__all__ = ["Aggregation", "get_aggregation_class", "MeanAggregation"]
Sub-modules¶
- fl_server_ai.aggregation.base
- fl_server_ai.aggregation.fed_dc
- fl_server_ai.aggregation.fed_prox
- fl_server_ai.aggregation.mean
- fl_server_ai.aggregation.method
Functions¶
get_aggregation_class¶
def get_aggregation_class(
value: fl_server_core.models.training.AggregationMethod | fl_server_core.models.model.Model | fl_server_core.models.training.Training
) -> Type[fl_server_ai.aggregation.base.Aggregation]
Get the aggregation class based on the provided model, training or aggregation method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | AggregationMethod \| Model \| Training | The value based on which the aggregation class is determined. | required |
Returns:
Type | Description |
---|---|
Type[Aggregation] | The type of the aggregation class. |
Raises:
Type | Description |
---|---|
ValueError | If the type of value or the aggregation method is unknown. |
View Source
def get_aggregation_class(value: AggregationMethod | Model | Training) -> Type[Aggregation]:
"""
Get the aggregation class based on the provided model, training or aggregation method.
Args:
value (AggregationMethod | Model | Training): The value based on which the aggregation class is determined.
Returns:
Type[Aggregation]: The type of the aggregation class.
Raises:
ValueError: If the type of value or the aggregation method is unknown.
"""
if isinstance(value, AggregationMethod):
method = value
elif isinstance(value, Training):
method = value.aggregation_method
elif isinstance(value, Model):
aggregation_method = Training.objects.filter(model=value) \
.values("aggregation_method") \
.first()["aggregation_method"]
method = aggregation_method
else:
raise ValueError(f"Unknown type: {type(value)}")
match method:
case AggregationMethod.FED_AVG: return MeanAggregation
case AggregationMethod.FED_DC: return FedDC
case AggregationMethod.FED_PROX: return FedProx
case _: raise ValueError(f"Unknown aggregation method: {method}")
Classes¶
Aggregation¶
Abstract base class for aggregation strategies.
View Source
class Aggregation(ABC):
    """Base interface that every model-aggregation strategy implements."""

    # Shared logger used by all concrete aggregation strategies.
    _logger = getLogger("fl.server")

    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Combine several client models into a single aggregated model.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
            deepcopy (bool, optional): Whether to create a deep copy of the models.
                Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.
        """
        ...
Ancestors (in MRO)¶
- abc.ABC
Descendants¶
- fl_server_ai.aggregation.MeanAggregation
Methods¶
aggregate¶
def aggregate(
self,
models: Sequence[torch.nn.modules.module.Module],
model_sample_sizes: Sequence[int],
*,
deepcopy: bool = True
) -> torch.nn.modules.module.Module
Abstract method for aggregating models.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
models | Sequence[torch.nn.Module] | The models to be aggregated. | None |
model_sample_sizes | Sequence[int] | The sample sizes for each model. | None |
deepcopy | bool | Whether to create a deep copy of the models. Defaults to True. | True |
Returns:
Type | Description |
---|---|
torch.nn.Module | The aggregated model. |
View Source
@abstractmethod
def aggregate(
    self,
    models: Sequence[torch.nn.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.Module:
    """
    Combine several client models into a single aggregated model.

    Args:
        models (Sequence[torch.nn.Module]): The models to be aggregated.
        model_sample_sizes (Sequence[int]): The sample sizes for each model.
        deepcopy (bool, optional): Whether to create a deep copy of the models.
            Defaults to True.

    Returns:
        torch.nn.Module: The aggregated model.
    """
    ...
MeanAggregation¶
Implements the aggregate method for aggregating models by calculating their mean.
View Source
class MeanAggregation(Aggregation):
    """
    Implements the aggregate method for aggregating models by calculating their mean.
    """

    @torch.no_grad()
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Aggregate models by calculating the sample-size-weighted mean of their parameters.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
            deepcopy (bool, optional): Whether to create a deep copy of the first model
                as the result container. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.

        Raises:
            ValueError: If the number of models and sample sizes differ.
            AggregationException: If the models do not have the same architecture.
        """
        if len(models) != len(model_sample_sizes):
            raise ValueError("Each model requires exactly one sample size.")
        self._logger.debug(f"Doing mean aggregation for {len(models)} models!")
        model_state_dicts = [model.state_dict() for model in models]
        # state_dict() returns references to the live parameter tensors, so the
        # accumulation below must work on clones; otherwise the in-place `*=`/`+=`
        # would corrupt the weights of models[0] even when deepcopy=True.
        result_dict = {name: tensor.clone() for name, tensor in model_state_dicts[0].items()}
        total_dataset_size = model_sample_sizes[0]
        for layer_name in result_dict:
            result_dict[layer_name] *= model_sample_sizes[0]
        # weighted sum accumulation over the remaining models
        for model_dict, dataset_size in zip(model_state_dicts[1:], model_sample_sizes[1:]):
            if set(model_dict.keys()) != set(result_dict.keys()):
                raise AggregationException("Models do not have the same architecture!")
            total_dataset_size += dataset_size
            for layer_name in result_dict:
                result_dict[layer_name] += model_dict[layer_name] * dataset_size
        # normalize by the total number of samples (factor 1/n)
        for layer_name in result_dict:
            result_dict[layer_name] = result_dict[layer_name] / total_dataset_size
        # return aggregated model
        result_model = copy.deepcopy(models[0]) if deepcopy else models[0]
        result_model.load_state_dict(result_dict)
        return result_model
Ancestors (in MRO)¶
- fl_server_ai.aggregation.Aggregation
- abc.ABC
Descendants¶
- fl_server_ai.aggregation.fed_dc.FedDC
- fl_server_ai.aggregation.fed_prox.FedProx
Methods¶
aggregate¶
def aggregate(
self,
models: Sequence[torch.nn.modules.module.Module],
model_sample_sizes: Sequence[int],
*,
deepcopy: bool = True
) -> torch.nn.modules.module.Module
Aggregate models by calculating the mean.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
models | Sequence[torch.nn.Module] | The models to be aggregated. | None |
model_sample_sizes | Sequence[int] | The sample sizes for each model. | None |
deepcopy | bool | Whether to create a deep copy of the models. Defaults to True. | True |
Returns:
Type | Description |
---|---|
torch.nn.Module | The aggregated model. |
Raises:
Type | Description |
---|---|
AggregationException | If the models do not have the same architecture. |
View Source
@torch.no_grad()
def aggregate(
    self,
    models: Sequence[torch.nn.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.Module:
    """
    Aggregate models by calculating the sample-size-weighted mean of their parameters.

    Args:
        models (Sequence[torch.nn.Module]): The models to be aggregated.
        model_sample_sizes (Sequence[int]): The sample sizes for each model.
        deepcopy (bool, optional): Whether to create a deep copy of the first model
            as the result container. Defaults to True.

    Returns:
        torch.nn.Module: The aggregated model.

    Raises:
        ValueError: If the number of models and sample sizes differ.
        AggregationException: If the models do not have the same architecture.
    """
    if len(models) != len(model_sample_sizes):
        raise ValueError("Each model requires exactly one sample size.")
    self._logger.debug(f"Doing mean aggregation for {len(models)} models!")
    model_state_dicts = [model.state_dict() for model in models]
    # state_dict() returns references to the live parameter tensors, so the
    # accumulation below must work on clones; otherwise the in-place `*=`/`+=`
    # would corrupt the weights of models[0] even when deepcopy=True.
    result_dict = {name: tensor.clone() for name, tensor in model_state_dicts[0].items()}
    total_dataset_size = model_sample_sizes[0]
    for layer_name in result_dict:
        result_dict[layer_name] *= model_sample_sizes[0]
    # weighted sum accumulation over the remaining models
    for model_dict, dataset_size in zip(model_state_dicts[1:], model_sample_sizes[1:]):
        if set(model_dict.keys()) != set(result_dict.keys()):
            raise AggregationException("Models do not have the same architecture!")
        total_dataset_size += dataset_size
        for layer_name in result_dict:
            result_dict[layer_name] += model_dict[layer_name] * dataset_size
    # normalize by the total number of samples (factor 1/n)
    for layer_name in result_dict:
        result_dict[layer_name] = result_dict[layer_name] / total_dataset_size
    # return aggregated model
    result_model = copy.deepcopy(models[0]) if deepcopy else models[0]
    result_model.load_state_dict(result_dict)
    return result_model