Skip to content

Module fl_server_ai.uncertainty.method

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from typing import overload, Type

from fl_server_core.models import Model, Training
from fl_server_core.models.training import UncertaintyMethod

from .base import UncertaintyBase
from .ensemble import Ensemble
from .mc_dropout import MCDropout
from .none import NoneUncertainty
from .swag import SWAG


@overload
def get_uncertainty_class(value: Model) -> Type[UncertaintyBase]:
    """Typing stub: resolve the uncertainty class from a Model's associated Training."""


@overload
def get_uncertainty_class(value: Training) -> Type[UncertaintyBase]:
    """Typing stub: resolve the uncertainty class from a Training's `uncertainty_method`."""


@overload
def get_uncertainty_class(value: UncertaintyMethod) -> Type[UncertaintyBase]:
    """Typing stub: resolve the uncertainty class directly from an UncertaintyMethod value."""


def get_uncertainty_class(value: Model | Training | UncertaintyMethod) -> Type[UncertaintyBase]:
    """
    Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.

    Args:
        value (Model | Training | UncertaintyMethod): The object to retrieve the uncertainty class for.

    Returns:
        Type[UncertaintyBase]: The uncertainty class associated with the given object.

    Raises:
        ValueError: If the given object is not a Model, Training, or UncertaintyMethod,
                    or if the uncertainty method associated with the object is unknown.
    """
    if isinstance(value, UncertaintyMethod):
        method = value
    elif isinstance(value, Training):
        method = value.uncertainty_method
    elif isinstance(value, Model):
        uncertainty_method = Training.objects.filter(model=value) \
                .values("uncertainty_method") \
                .first()["uncertainty_method"]
        method = uncertainty_method
    else:
        raise ValueError(f"Unknown type: {type(value)}")

    match method:
        case UncertaintyMethod.ENSEMBLE: return Ensemble
        case UncertaintyMethod.MC_DROPOUT: return MCDropout
        case UncertaintyMethod.NONE: return NoneUncertainty
        case UncertaintyMethod.SWAG: return SWAG
        case _: raise ValueError(f"Unknown uncertainty method: {method}")

Functions

get_uncertainty_class

def get_uncertainty_class(
    value: fl_server_core.models.model.Model | fl_server_core.models.training.Training | fl_server_core.models.training.UncertaintyMethod
) -> Type[fl_server_ai.uncertainty.base.UncertaintyBase]

Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.

Parameters:

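| Name | Type | Description | Default |
|------|------|-------------|---------|
| value | Model \| Training \| UncertaintyMethod | The object to retrieve the uncertainty class for. | required |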

Returns:

| Type | Description |
|------|-------------|
| Type[UncertaintyBase] | The uncertainty class associated with the given object. |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If the given object is not a Model, Training, or UncertaintyMethod, or if the uncertainty method associated with the object is unknown. |
View Source
def get_uncertainty_class(value: Model | Training | UncertaintyMethod) -> Type[UncertaintyBase]:
    """
    Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.

    Args:
        value (Model | Training | UncertaintyMethod): The object to retrieve the uncertainty class for.

    Returns:
        Type[UncertaintyBase]: The uncertainty class associated with the given object.

    Raises:
        ValueError: If the given object is not a Model, Training, or UncertaintyMethod,
                    or if the uncertainty method associated with the object is unknown.
    """
    if isinstance(value, UncertaintyMethod):
        method = value
    elif isinstance(value, Training):
        method = value.uncertainty_method
    elif isinstance(value, Model):
        uncertainty_method = Training.objects.filter(model=value) \
                .values("uncertainty_method") \
                .first()["uncertainty_method"]
        method = uncertainty_method
    else:
        raise ValueError(f"Unknown type: {type(value)}")

    match method:
        case UncertaintyMethod.ENSEMBLE: return Ensemble
        case UncertaintyMethod.MC_DROPOUT: return MCDropout
        case UncertaintyMethod.NONE: return NoneUncertainty
        case UncertaintyMethod.SWAG: return SWAG
        case _: raise ValueError(f"Unknown uncertainty method: {method}")