fl_server_ai.aggregation.method

Functions:

get_aggregation_class: Get the aggregation class based on the provided model, training or aggregation method.

Functions

get_aggregation_class

get_aggregation_class(value: Model) -> Type[Aggregation]
get_aggregation_class(value: Training) -> Type[Aggregation]
get_aggregation_class(value: AggregationMethod) -> Type[Aggregation]
get_aggregation_class(value: AggregationMethod | Model | Training) -> Type[Aggregation]

Get the aggregation class based on the provided model, training or aggregation method.

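Parameters:

value (AggregationMethod | Model | Training, required): The input from which the aggregation method is resolved. An AggregationMethod is used directly, a Training supplies its aggregation_method field, and for a Model the aggregation method of the first Training that references it is looked up.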

Returns:

Type[Aggregation]: The aggregation class itself (not an instance) that implements the resolved method.

Raises:

ValueError: If value is not an AggregationMethod, Model or Training, or if the resolved aggregation method is unknown.

Source code in fl_server_ai/aggregation/method.py
from typing import Type

# AggregationMethod, Model, Training, Aggregation, MeanAggregation, FedDC and FedProx
# come from the project's own modules (their imports are not shown in this excerpt).


def get_aggregation_class(value: AggregationMethod | Model | Training) -> Type[Aggregation]:
    """
    Get the aggregation class based on the provided model, training or aggregation method.

    Args:
        value (AggregationMethod | Model | Training): The value based on which the aggregation class is determined.

    Returns:
        Type[Aggregation]: The type of the aggregation class.

    Raises:
        ValueError: If the type of value or the aggregation method is unknown,
            or if no training uses the given model.
    """
    if isinstance(value, AggregationMethod):
        method = value
    elif isinstance(value, Training):
        method = value.aggregation_method
    elif isinstance(value, Model):
        # Use the aggregation method of the first training that references this model.
        training = (
            Training.objects.filter(model=value)
            .values("aggregation_method")
            .first()
        )
        # first() returns None when no training matches; avoid indexing None.
        if training is None:
            raise ValueError(f"No training found for model: {value}")
        method = training["aggregation_method"]
    else:
        raise ValueError(f"Unknown type: {type(value)}")

    # Map the aggregation method to the class that implements it.
    match method:
        case AggregationMethod.FED_AVG:
            return MeanAggregation
        case AggregationMethod.FED_DC:
            return FedDC
        case AggregationMethod.FED_PROX:
            return FedProx
        case _:
            raise ValueError(f"Unknown aggregation method: {method}")
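The Model branch relies on two Django QuerySet behaviours: values() yields dictionaries holding only the requested fields, and first() returns the first matching row (ordered by primary key when no ordering is set) or None when nothing matches. The sketch below mimics that lookup in plain Python so its edge cases are visible without a database. TrainingRow and first_aggregation_method are hypothetical names, list order stands in for primary-key order, and the method strings are assumed values.

```python
from dataclasses import dataclass


@dataclass
class TrainingRow:
    # Hypothetical stand-in for a Training row: only the fields the lookup reads.
    model_id: int
    aggregation_method: str


def first_aggregation_method(trainings: list[TrainingRow], model_id: int) -> str:
    """Mimic filter(model=...).values("aggregation_method").first()["aggregation_method"]."""
    # Like values(): keep only the requested field, as a dict per matching row.
    rows = [
        {"aggregation_method": t.aggregation_method}
        for t in trainings
        if t.model_id == model_id
    ]
    # Like first(): take the first row, or None if nothing matched.
    first = rows[0] if rows else None
    if first is None:
        raise ValueError(f"No training found for model: {model_id}")
    return first["aggregation_method"]
```

If several trainings share one model with different aggregation methods, only the first one decides the result, so callers that hold a specific Training are better served by passing it directly.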