Module fl_server_ai.uncertainty.method¶
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from typing import overload, Type
from fl_server_core.models import Model, Training
from fl_server_core.models.training import UncertaintyMethod
from .base import UncertaintyBase
from .ensemble import Ensemble
from .mc_dropout import MCDropout
from .none import NoneUncertainty
from .swag import SWAG
@overload
def get_uncertainty_class(value: Model) -> Type[UncertaintyBase]: ...
@overload
def get_uncertainty_class(value: Training) -> Type[UncertaintyBase]: ...
@overload
def get_uncertainty_class(value: UncertaintyMethod) -> Type[UncertaintyBase]: ...
def get_uncertainty_class(value: Model | Training | UncertaintyMethod) -> Type[UncertaintyBase]:
"""
Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.
Args:
value (Model | Training | UncertaintyMethod): The object to retrieve the uncertainty class for.
Returns:
Type[UncertaintyBase]: The uncertainty class associated with the given object.
Raises:
ValueError: If the given object is not a Model, Training, or UncertaintyMethod,
or if the uncertainty method associated with the object is unknown.
"""
if isinstance(value, UncertaintyMethod):
method = value
elif isinstance(value, Training):
method = value.uncertainty_method
elif isinstance(value, Model):
uncertainty_method = Training.objects.filter(model=value) \
.values("uncertainty_method") \
.first()["uncertainty_method"]
method = uncertainty_method
else:
raise ValueError(f"Unknown type: {type(value)}")
match method:
case UncertaintyMethod.ENSEMBLE: return Ensemble
case UncertaintyMethod.MC_DROPOUT: return MCDropout
case UncertaintyMethod.NONE: return NoneUncertainty
case UncertaintyMethod.SWAG: return SWAG
case _: raise ValueError(f"Unknown uncertainty method: {method}")
Functions¶
get_uncertainty_class¶
def get_uncertainty_class(
value: fl_server_core.models.model.Model | fl_server_core.models.training.Training | fl_server_core.models.training.UncertaintyMethod
) -> Type[fl_server_ai.uncertainty.base.UncertaintyBase]
Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Model \| Training \| UncertaintyMethod | The object to retrieve the uncertainty class for. | required |
Returns:
Type | Description |
---|---|
Type[UncertaintyBase] | The uncertainty class associated with the given object. |
Raises:
Type | Description |
---|---|
ValueError | If the given object is not a Model, Training, or UncertaintyMethod, or if the uncertainty method associated with the object is unknown. |
View Source
def get_uncertainty_class(value: Model | Training | UncertaintyMethod) -> Type[UncertaintyBase]:
"""
Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.
Args:
value (Model | Training | UncertaintyMethod): The object to retrieve the uncertainty class for.
Returns:
Type[UncertaintyBase]: The uncertainty class associated with the given object.
Raises:
ValueError: If the given object is not a Model, Training, or UncertaintyMethod,
or if the uncertainty method associated with the object is unknown.
"""
if isinstance(value, UncertaintyMethod):
method = value
elif isinstance(value, Training):
method = value.uncertainty_method
elif isinstance(value, Model):
uncertainty_method = Training.objects.filter(model=value) \
.values("uncertainty_method") \
.first()["uncertainty_method"]
method = uncertainty_method
else:
raise ValueError(f"Unknown type: {type(value)}")
match method:
case UncertaintyMethod.ENSEMBLE: return Ensemble
case UncertaintyMethod.MC_DROPOUT: return MCDropout
case UncertaintyMethod.NONE: return NoneUncertainty
case UncertaintyMethod.SWAG: return SWAG
case _: raise ValueError(f"Unknown uncertainty method: {method}")