Module fl_server_ai.aggregation.base
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from logging import getLogger
import torch
from typing import Sequence
class Aggregation(ABC):
    """
    Common interface implemented by every model aggregation strategy.

    Concrete strategies subclass this and override `aggregate`. Because that
    method is abstract, `Aggregation` itself cannot be instantiated.
    """

    # Class attribute: subclasses inherit this "fl.server" logger unless they override it.
    _logger = getLogger(name="fl.server")

    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True,
    ) -> torch.nn.Module:
        """
        Combine several models into a single aggregated model.

        Args:
            models (Sequence[torch.nn.Module]): Models to combine.
            model_sample_sizes (Sequence[int]): Sample size belonging to each
                entry of `models`.
            deepcopy (bool, optional): Keyword-only. Whether a deep copy of the
                models should be created. Defaults to True.

        Returns:
            torch.nn.Module: The model produced by the aggregation.
        """
        ...
Classes
Aggregation
Abstract base class for aggregation strategies.
View Source
class Aggregation(ABC):
    """
    Common interface implemented by every model aggregation strategy.

    Concrete strategies subclass this and override `aggregate`. Because that
    method is abstract, `Aggregation` itself cannot be instantiated.
    """

    # Class attribute: subclasses inherit this "fl.server" logger unless they override it.
    _logger = getLogger(name="fl.server")

    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True,
    ) -> torch.nn.Module:
        """
        Combine several models into a single aggregated model.

        Args:
            models (Sequence[torch.nn.Module]): Models to combine.
            model_sample_sizes (Sequence[int]): Sample size belonging to each
                entry of `models`.
            deepcopy (bool, optional): Keyword-only. Whether a deep copy of the
                models should be created. Defaults to True.

        Returns:
            torch.nn.Module: The model produced by the aggregation.
        """
        ...
Ancestors (in MRO)
- abc.ABC
 
Descendants
- fl_server_ai.aggregation.mean.MeanAggregation
 
Methods
aggregate
def aggregate(
    self,
    models: Sequence[torch.nn.modules.module.Module],
    model_sample_sizes: Sequence[int],
    *,
    deepcopy: bool = True
) -> torch.nn.modules.module.Module
Abstract method for aggregating models.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| models | Sequence[torch.nn.Module] | The models to be aggregated. | *required* | 
| model_sample_sizes | Sequence[int] | The sample sizes for each model. | *required* | 
| deepcopy | bool | Keyword-only. Whether to create a deep copy of the models. Defaults to True. | True | 
Returns:
| Type | Description | 
|---|---|
| torch.nn.Module | The aggregated model. | 
View Source
    @abstractmethod
    def aggregate(
        self,
        models: Sequence[torch.nn.Module],
        model_sample_sizes: Sequence[int],
        *,
        deepcopy: bool = True,
    ) -> torch.nn.Module:
        """
        Combine several models into a single aggregated model.

        Subclasses must override this; the base implementation does nothing
        and returns None.

        Args:
            models (Sequence[torch.nn.Module]): Models to combine.
            model_sample_sizes (Sequence[int]): Sample size belonging to each
                entry of `models`.
            deepcopy (bool, optional): Keyword-only. Whether a deep copy of the
                models should be created. Defaults to True.

        Returns:
            torch.nn.Module: The model produced by the aggregation.
        """
        ...