
fl_server_ai.uncertainty.method

Functions:

Name                    Description
get_uncertainty_class   Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.


Functions

get_uncertainty_class

get_uncertainty_class(value: Model) -> Type[UncertaintyBase]
get_uncertainty_class(value: Training) -> Type[UncertaintyBase]
get_uncertainty_class(value: UncertaintyMethod) -> Type[UncertaintyBase]
get_uncertainty_class(value: Model | Training | UncertaintyMethod) -> Type[UncertaintyBase]

Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.

Parameters:

Name    Type                                   Description                                         Default
value   Model | Training | UncertaintyMethod   The object to retrieve the uncertainty class for.   required

Returns:

Type                    Description
Type[UncertaintyBase]   The uncertainty class associated with the given object.

Raises:

Type         Description
ValueError   If the given object is not a Model, Training, or UncertaintyMethod, or if the uncertainty method associated with the object is unknown.
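
The lookup order (method as-is, then a training's method, then the method of the first training that uses a model) and the error cases can be sketched without a database. This sketch assumes hypothetical stand-ins: plain dataclasses for `Model` and `Training`, a `str`-based `Method` enum, and a module-level `TRAININGS` list in place of `Training.objects`. Raising `ValueError` for a model without any training is a choice made for this sketch only.

```python
from dataclasses import dataclass
from enum import Enum


class Method(str, Enum):  # hypothetical stand-in for UncertaintyMethod
    NONE = "NONE"
    SWAG = "SWAG"


@dataclass
class Model:  # hypothetical stand-in for the Django Model
    name: str


@dataclass
class Training:  # hypothetical stand-in for the Django Training
    model: Model
    uncertainty_method: Method


TRAININGS: list[Training] = []  # stands in for Training.objects


def resolve_method(value: object) -> Method:
    """Resolve a method the way the documented function does, before mapping it to a class."""
    if isinstance(value, Method):
        return value
    if isinstance(value, Training):
        return value.uncertainty_method
    if isinstance(value, Model):
        # Like .filter(model=value).first(): take the first matching training.
        training = next((t for t in TRAININGS if t.model is value), None)
        if training is None:
            raise ValueError(f"No training found for model: {value}")
        return training.uncertainty_method
    raise ValueError(f"Unknown type: {type(value)}")


demo = Model("demo")
TRAININGS.append(Training(demo, Method.SWAG))
print(resolve_method(demo).value)  # SWAG
try:
    resolve_method(42)
except ValueError as err:
    print(err)  # Unknown type: <class 'int'>
```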

Source code in fl_server_ai/uncertainty/method.py
def get_uncertainty_class(value: Model | Training | UncertaintyMethod) -> Type[UncertaintyBase]:
    """
    Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.

    Args:
        value (Model | Training | UncertaintyMethod): The object to retrieve the uncertainty class for.

    Returns:
        Type[UncertaintyBase]: The uncertainty class associated with the given object.

    Raises:
        ValueError: If the given object is not a Model, Training, or UncertaintyMethod,
                    or if the uncertainty method associated with the object is unknown.
    """
    if isinstance(value, UncertaintyMethod):
        method = value
    elif isinstance(value, Training):
        method = value.uncertainty_method
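    elif isinstance(value, Model):
        # Use the uncertainty method of the first training that uses this model.
        training = (
            Training.objects.filter(model=value)
            .values("uncertainty_method")
            .first()
        )
        if training is None:
            raise ValueError(f"No training found for model: {value}")
        method = training["uncertainty_method"]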
    else:
        raise ValueError(f"Unknown type: {type(value)}")

    match method:
        case UncertaintyMethod.ENSEMBLE: return Ensemble
        case UncertaintyMethod.MC_DROPOUT: return MCDropout
        case UncertaintyMethod.NONE: return NoneUncertainty
        case UncertaintyMethod.SWAG: return SWAG
        case _: raise ValueError(f"Unknown uncertainty method: {method}")