Module fl_server_ai.aggregation.base¶
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from logging import getLogger
import torch
from typing import Sequence
class Aggregation(ABC):
    """
    Abstract base class for aggregation strategies.

    Concrete strategies subclass this class and implement `aggregate`; the
    class cannot be instantiated until that method is overridden.
    """

    # Logger shared by every aggregation strategy (the "fl.server" channel).
    _logger = getLogger("fl.server")

    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Abstract method for aggregating models.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
                Presumably aligned index-by-index with `models` - confirm
                against the concrete strategies.
            deepcopy (bool, optional): Whether to create a deep copy of the models.
                Keyword-only. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.
        """
Classes¶
Aggregation¶
Abstract base class for aggregation strategies.
View Source
class Aggregation(ABC):
    """
    Abstract base class for aggregation strategies.

    Concrete strategies subclass this class and implement `aggregate`; the
    class cannot be instantiated until that method is overridden.
    """

    # Logger shared by every aggregation strategy (the "fl.server" channel).
    _logger = getLogger("fl.server")

    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True
    ) -> torch.nn.Module:
        """
        Abstract method for aggregating models.

        Args:
            models (Sequence[torch.nn.Module]): The models to be aggregated.
            model_sample_sizes (Sequence[int]): The sample sizes for each model.
                Presumably aligned index-by-index with `models` - confirm
                against the concrete strategies.
            deepcopy (bool, optional): Whether to create a deep copy of the models.
                Keyword-only. Defaults to True.

        Returns:
            torch.nn.Module: The aggregated model.
        """
Ancestors (in MRO)¶
- abc.ABC
Descendants¶
- fl_server_ai.aggregation.mean.MeanAggregation
Methods¶
aggregate¶
def aggregate(
self,
models: Sequence[torch.nn.modules.module.Module],
model_sample_sizes: Sequence[int],
*,
deepcopy: bool = True
) -> torch.nn.modules.module.Module
Abstract method for aggregating models.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
models | Sequence[torch.nn.Module] | The models to be aggregated. | required |
model_sample_sizes | Sequence[int] | The sample sizes for each model. | required |
deepcopy | bool | Whether to create a deep copy of the models. Defaults to True. | True |
Returns:
Type | Description |
---|---|
torch.nn.Module | The aggregated model. |
View Source
@abstractmethod
def aggregate(
    self,
    models: Sequence[torch.nn.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.Module:
    """
    Abstract method for aggregating models.

    Args:
        models (Sequence[torch.nn.Module]): The models to be aggregated.
        model_sample_sizes (Sequence[int]): The sample sizes for each model.
            Presumably aligned index-by-index with `models` - confirm
            against the concrete strategies.
        deepcopy (bool, optional): Whether to create a deep copy of the models.
            Keyword-only. Defaults to True.

    Returns:
        torch.nn.Module: The aggregated model.
    """