Module fl_server_ai.aggregation.method
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from typing import overload, Type
from fl_server_core.models import Model, Training
from fl_server_core.models.training import AggregationMethod
from .base import Aggregation
from .fed_dc import FedDC
from .fed_prox import FedProx
from .mean import MeanAggregation
@overload
def get_aggregation_class(value: Model) -> Type[Aggregation]: ...
@overload
def get_aggregation_class(value: Training) -> Type[Aggregation]: ...
@overload
def get_aggregation_class(value: AggregationMethod) -> Type[Aggregation]: ...
def get_aggregation_class(value: AggregationMethod | Model | Training) -> Type[Aggregation]:
"""
Get the aggregation class based on the provided model, training or aggregation method.
Args:
value (AggregationMethod | Model | Training): The value based on which the aggregation class is determined.
Returns:
Type[Aggregation]: The type of the aggregation class.
Raises:
ValueError: If the type of value or the aggregation method is unknown.
"""
if isinstance(value, AggregationMethod):
method = value
elif isinstance(value, Training):
method = value.aggregation_method
elif isinstance(value, Model):
aggregation_method = Training.objects.filter(model=value) \
.values("aggregation_method") \
.first()["aggregation_method"]
method = aggregation_method
else:
raise ValueError(f"Unknown type: {type(value)}")
match method:
case AggregationMethod.FED_AVG: return MeanAggregation
case AggregationMethod.FED_DC: return FedDC
case AggregationMethod.FED_PROX: return FedProx
case _: raise ValueError(f"Unknown aggregation method: {method}")
Functions
get_aggregation_class
def get_aggregation_class(
value: fl_server_core.models.training.AggregationMethod | fl_server_core.models.model.Model | fl_server_core.models.training.Training
) -> Type[fl_server_ai.aggregation.base.Aggregation]
Get the aggregation class based on the provided model, training or aggregation method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | AggregationMethod \| Model \| Training | The value based on which the aggregation class is determined. | *required* |
Returns:
Type | Description |
---|---|
Type[Aggregation] | The type of the aggregation class. |
Raises:
Type | Description |
---|---|
ValueError | If the type of value or the aggregation method is unknown. |
View Source
def get_aggregation_class(value: AggregationMethod | Model | Training) -> Type[Aggregation]:
"""
Get the aggregation class based on the provided model, training or aggregation method.
Args:
value (AggregationMethod | Model | Training): The value based on which the aggregation class is determined.
Returns:
Type[Aggregation]: The type of the aggregation class.
Raises:
ValueError: If the type of value or the aggregation method is unknown.
"""
if isinstance(value, AggregationMethod):
method = value
elif isinstance(value, Training):
method = value.aggregation_method
elif isinstance(value, Model):
aggregation_method = Training.objects.filter(model=value) \
.values("aggregation_method") \
.first()["aggregation_method"]
method = aggregation_method
else:
raise ValueError(f"Unknown type: {type(value)}")
match method:
case AggregationMethod.FED_AVG: return MeanAggregation
case AggregationMethod.FED_DC: return FedDC
case AggregationMethod.FED_PROX: return FedProx
case _: raise ValueError(f"Unknown aggregation method: {method}")