
fl_server_ai.uncertainty

Modules:

  • base
  • ensemble
  • mc_dropout
  • method
  • none
  • swag

Classes:

  • Ensemble: Ensemble uncertainty estimation.
  • MCDropout: Monte Carlo (MC) Dropout uncertainty estimation.
  • NoneUncertainty: Empty uncertainty estimation when no specific uncertainty method is used.
  • SWAG: Stochastic Weight Averaging Gaussian (SWAG) uncertainty estimation.
  • UncertaintyBase: Abstract base class for uncertainty estimation.

Functions:

  • get_uncertainty_class: Get the uncertainty class associated with a given Model, Training, or UncertaintyMethod object.

Attributes

__all__ module-attribute

__all__ = ['get_uncertainty_class', 'Ensemble', 'MCDropout', 'NoneUncertainty', 'SWAG', 'UncertaintyBase']

Classes

Ensemble

Bases: UncertaintyBase



Ensemble uncertainty estimation.

Methods:

  • prediction

Source code in fl_server_ai/uncertainty/ensemble.py
class Ensemble(UncertaintyBase):
    """
    Ensemble uncertainty estimation.
    """

    @classmethod
    def prediction(cls, input: torch.Tensor, model: MeanModel) -> Tuple[torch.Tensor, Dict[str, Any]]:
        output_list = []
        for m in model.models.all():
            net = m.get_torch_model()
            output = net(input).detach()
            output_list.append(output)
        outputs = torch.stack(output_list, dim=0)  # (N, batch_size, n_classes)  # N = number of models

        inference = outputs.mean(dim=0)
        uncertainty = cls.interpret(outputs)
        return inference, uncertainty
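The averaging performed by `Ensemble.prediction` can be reproduced on plain modules. The following is a minimal sketch, not the package's implementation: the hypothetical helper `ensemble_predict` takes a list of `torch.nn.Module` members instead of a `MeanModel`, stacks their outputs to shape (N, batch_size, n_classes), and returns the mean together with the per-class variance across members.

```python
import torch
from torch import nn


def ensemble_predict(members: list[nn.Module], x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical stand-in for Ensemble.prediction on plain modules."""
    with torch.no_grad():  # inference only, no autograd graph needed
        outputs = torch.stack([m(x) for m in members], dim=0)  # (N, batch_size, n_classes)
    mean = outputs.mean(dim=0)  # ensemble prediction
    variance = outputs.var(dim=0)  # disagreement between members
    return mean, variance


torch.manual_seed(0)
members = [nn.Sequential(nn.Linear(4, 3), nn.Softmax(dim=1)) for _ in range(5)]
x = torch.randn(2, 4)
mean, variance = ensemble_predict(members, x)
print(mean.shape, variance.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```

Because every member here ends in a softmax, the averaged output is still a probability distribution per input.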

Functions

prediction classmethod
prediction(input: Tensor, model: MeanModel) -> tuple[Tensor, dict[str, Any]]

MCDropout

Bases: UncertaintyBase



Monte Carlo (MC) Dropout Uncertainty Estimation

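Requirements:

  • a model with dropout layers
  • N, the number of samples per input (number of Monte Carlo samples/forward passes), read from the "N" uncertainty option (default 10)

References:

  • Paper: Understanding Measures of Uncertainty for Adversarial Example Detection <https://arxiv.org/abs/1803.08533>
  • Code inspiration: <https://github.com/lsgos/uncertainty-adversarial-paper/tree/master>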

Methods:

  • prediction

Source code in fl_server_ai/uncertainty/mc_dropout.py
class MCDropout(UncertaintyBase):
    """
    Monte Carlo (MC) Dropout Uncertainty Estimation

    Requirements:

    - model with dropout layers
    - N, number of samples per input (number of Monte Carlo samples/forward passes),
      read from the "N" uncertainty option (default 10)

    References:

    - Paper: Understanding Measures of Uncertainty for Adversarial Example Detection
             <https://arxiv.org/abs/1803.08533>
    - Code inspiration: <https://github.com/lsgos/uncertainty-adversarial-paper/tree/master>
    """

    @classmethod
    def prediction(cls, input: Tensor, model: Model) -> Tuple[torch.Tensor, Dict[str, Any]]:
        options = cls.get_options(model)
        N = options.get("N", 10)
        softmax = options.get("softmax", False)

        net: Module = model.get_torch_model()
        net.eval()
        set_dropout(net, state=True)

        out_list = []
        for _ in range(N):
            output = net(input).detach()
            # convert to probabilities if necessary
            if softmax:
                output = torch.softmax(output, dim=1)
            out_list.append(output)
        out = torch.stack(out_list, dim=0)  # (n_mc, batch_size, n_classes)

        inference = out.mean(dim=0)
        uncertainty = cls.interpret(out)
        return inference, uncertainty
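MC Dropout keeps dropout active at inference time so that repeated forward passes differ. A minimal, self-contained sketch, assuming a plain `torch.nn.Module`; the helper `enable_dropout` is hypothetical and only approximates what the project's `set_dropout(net, state=True)` appears to do (switch dropout layers back to training mode after `net.eval()`).

```python
import torch
from torch import nn


def enable_dropout(net: nn.Module) -> None:
    """Hypothetical helper: put only the dropout layers back into training mode."""
    for module in net.modules():
        if isinstance(module, nn.Dropout):
            module.train()


def mc_dropout_predict(net: nn.Module, x: torch.Tensor, n: int = 10, softmax: bool = True):
    net.eval()  # deterministic behaviour for all other layers
    enable_dropout(net)  # but keep dropout stochastic
    with torch.no_grad():
        outs = [net(x) for _ in range(n)]
        if softmax:
            outs = [torch.softmax(o, dim=1) for o in outs]  # logits -> probabilities
    out = torch.stack(outs, dim=0)  # (n, batch_size, n_classes)
    return out.mean(dim=0), out


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 3))
mean, samples = mc_dropout_predict(net, torch.randn(4, 8), n=20)
print(samples.shape)  # torch.Size([20, 4, 3])
```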

Functions

prediction classmethod
prediction(input: Tensor, model: Model) -> tuple[Tensor, dict[str, Any]]

NoneUncertainty

Bases: UncertaintyBase



Empty uncertainty estimation when no specific uncertainty method is used.

This class does not calculate any uncertainty; it returns the predicted class indices (argmax over the class dimension) together with an empty uncertainty dictionary.

Methods:

  • prediction

Source code in fl_server_ai/uncertainty/none.py
class NoneUncertainty(UncertaintyBase):
    """
    Empty uncertainty estimation when no specific uncertainty method is used.

    This class does not calculate any uncertainty; it returns the predicted class indices
    (argmax over the class dimension) together with an empty uncertainty dictionary.
    """

    @classmethod
    def prediction(cls, input: torch.Tensor, model: Model) -> Tuple[torch.Tensor, Dict[str, Any]]:
        net: torch.nn.Module = model.get_torch_model()
        prediction: torch.Tensor = net(input)
        return prediction.argmax(dim=1), {}

Functions

prediction classmethod
prediction(input: Tensor, model: Model) -> tuple[Tensor, dict[str, Any]]

SWAG

Bases: UncertaintyBase



Stochastic Weight Averaging Gaussian (SWAG) uncertainty estimation.

Methods:

  • prediction

Source code in fl_server_ai/uncertainty/swag.py
class SWAG(UncertaintyBase):
    """
    Stochastic Weight Averaging Gaussian (SWAG) uncertainty estimation.
    """

    @classmethod
    def prediction(cls, input: torch.Tensor, model: SWAGModel) -> Tuple[torch.Tensor, Dict[str, Any]]:
        options = cls.get_options(model)
        N = options.get("N", 10)

        net: torch.nn.Module = model.get_torch_model()

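        # first and second moment are already ensured to be in
        # alphabetical order in the database
        fm = model.first_moment
        sm = model.second_moment
        # diagonal variance Var[w] = E[w^2] - E[w]^2; clamp small negative values from rounding
        std = torch.sqrt(torch.clamp(sm - torch.pow(fm, 2), min=0.0))
        # draw N independent weight vectors, shape (N, n_params)
        params = fm + std * torch.randn((N, fm.numel()), dtype=fm.dtype, device=fm.device)

        prediction_list = []
        for n in range(N):
            # load the n-th sampled weight vector into the network
            torch.nn.utils.vector_to_parameters(params[n], net.parameters())
            prediction = net(input).detach()
            prediction_list.append(prediction)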
        predictions = torch.stack(prediction_list)

        inference = predictions.mean(dim=0)
        uncertainty = cls.interpret(predictions)
        return inference, uncertainty
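Diagonal SWAG approximates the posterior over the flattened weights by a Gaussian whose mean is the first moment and whose variance is the second moment minus the squared first moment; the standard deviation is the square root of that variance, and each of the N weight samples has to be drawn independently. The following minimal sketch uses a toy `nn.Linear`, assuming the moment vectors are flattened in the same order as `net.parameters()` (which `vector_to_parameters` relies on); the moments are synthetic values for illustration, not statistics collected during training.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

torch.manual_seed(0)
net = nn.Linear(4, 3)
N = 8

# Synthetic moments: first moment = current weights, second moment chosen so that Var[w] = 0.01
fm = parameters_to_vector(net.parameters()).detach()
sm = fm.pow(2) + 0.01

std = torch.sqrt(torch.clamp(sm - fm.pow(2), min=0.0))  # sqrt of the diagonal variance
params = fm + std * torch.randn(N, fm.numel())  # N independent samples, shape (N, n_params)

x = torch.randn(5, 4)
predictions = []
with torch.no_grad():
    for n in range(N):
        vector_to_parameters(params[n], net.parameters())  # load the n-th weight sample
        predictions.append(net(x))
predictions = torch.stack(predictions)  # (N, batch_size, n_classes)
print(predictions.shape)  # torch.Size([8, 5, 3])
```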

Functions

prediction classmethod
prediction(input: Tensor, model: SWAGModel) -> tuple[Tensor, dict[str, Any]]

UncertaintyBase

Bases: ABC



Abstract base class for uncertainty estimation.

This class defines the interface for uncertainty estimation in federated learning.

Methods:

  • expected_entropy: Calculate the mean entropy of the predictive distribution across the MC samples.
  • get_options: Get uncertainty options from training options.
  • interpret: Interpret the different network (model) outputs and calculate the uncertainty.
  • mutual_information: Calculate the BALD (Bayesian Active Learning by Disagreement) score, also known as the mutual information (MI).
  • prediction: Make a prediction using the given input and model.
  • predictive_entropy: Calculate the entropy of the mean predictive distribution across the MC samples.
  • to_json: Convert the given inference and uncertainty data to a JSON string.

Source code in fl_server_ai/uncertainty/base.py
class UncertaintyBase(ABC):
    """
    Abstract base class for uncertainty estimation.

    This class defines the interface for uncertainty estimation in federated learning.
    """

    _logger = getLogger("fl.server")
    """The private logger instance for the uncertainty estimation."""

    @classmethod
    @abstractmethod
    def prediction(cls, input: torch.Tensor, model: Model) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Make a prediction using the given input and model.

        Args:
            input (torch.Tensor): The input to make a prediction for.
            model (Model): The model to use for making the prediction.

        Returns:
            Tuple[torch.Tensor, Dict[str, Any]]: The prediction and any associated uncertainty.
        """
        pass

    @classmethod
    def interpret(cls, outputs: torch.Tensor) -> Dict[str, Any]:
        """
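        Interpret the different network (model) outputs and calculate the uncertainty.

        Entropy-based measures are only computed if all outputs lie in [0, 1], i.e. if
        they can be treated as probabilities; otherwise only variance and std are returned.

        Args:
            outputs (torch.Tensor): multiple network (model) outputs (N, batch_size, n_classes)

        Returns:
            Dict[str, Any]: uncertainty measures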
        """
        variance = outputs.var(dim=0)
        std = outputs.std(dim=0)
        if not (torch.all(outputs <= 1.) and torch.all(outputs >= 0.)):
            return dict(variance=variance, std=std)

        predictive_entropy = cls.predictive_entropy(outputs)
        expected_entropy = cls.expected_entropy(outputs)
        mutual_info = predictive_entropy - expected_entropy  # see cls.mutual_information
        return dict(
            variance=variance,
            std=std,
            predictive_entropy=predictive_entropy,
            expected_entropy=expected_entropy,
            mutual_info=mutual_info,
        )

    @classmethod
    def expected_entropy(cls, predictions: torch.Tensor) -> torch.Tensor:
        """
        Calculate the mean entropy of the predictive distribution across the MC samples.

        Args:
            predictions (torch.Tensor): predictions of shape (n_mc x batch_size x n_classes)

        Returns:
            torch.Tensor: mean entropy of the predictive distribution
        """
        return torch.distributions.Categorical(probs=predictions).entropy().mean(dim=0)

    @classmethod
    def predictive_entropy(cls, predictions: torch.Tensor) -> torch.Tensor:
        """
        Calculate the entropy of the mean predictive distribution across the MC samples.

        Args:
            predictions (torch.Tensor): predictions of shape (n_mc x batch_size x n_classes)

        Returns:
            torch.Tensor: entropy of the mean predictive distribution
        """
        return torch.distributions.Categorical(probs=predictions.mean(dim=0)).entropy()

    @classmethod
    def mutual_information(cls, predictions: torch.Tensor) -> torch.Tensor:
        """
        Calculate the BALD (Bayesian Active Learning by Disagreement) score of a model,
        i.e. the difference between the entropy of the mean predictive distribution and
        the mean entropy of the individual predictive distributions.
        This quantity is also referred to as the mutual information (MI).

        Args:
            predictions (torch.Tensor): predictions of shape (n_mc x batch_size x n_classes)

        Returns:
            torch.Tensor: difference between the entropy of the mean predictive distribution
                and the mean entropy of the individual predictive distributions
        """
        return cls.predictive_entropy(predictions) - cls.expected_entropy(predictions)

    @classmethod
    def get_options(cls, obj: Model | Training) -> Dict[str, Any]:
        """
        Get uncertainty options from training options.

        Args:
            obj (Model | Training): The Model or Training object to retrieve options for.

        Returns:
            Dict[str, Any]: Uncertainty options.

        Raises:
            TypeError: If the given object is not a Model or Training.
        """
        if isinstance(obj, Model):
            return Training.objects.filter(model=obj) \
                .values("options") \
                .first()["options"] \
                .get("uncertainty", {})
        if isinstance(obj, Training):
            return obj.options.get("uncertainty", {})
        raise TypeError(f"Expected Model or Training, got {type(obj)}")

    @classmethod
    def to_json(cls, inference: torch.Tensor, uncertainty: Dict[str, Any] = {}, **json_kwargs) -> str:
        """
        Convert the given inference and uncertainty data to a JSON string.

        Args:
            inference (torch.Tensor): The inference to convert.
            uncertainty (Dict[str, Any]): The uncertainty to convert.
            **json_kwargs: Additional keyword arguments to pass to `json.dumps`.

        Returns:
            str: A JSON string representation of the given inference and uncertainty data.
        """
        def prepare(v):
            if isinstance(v, torch.Tensor):
                return v.cpu().tolist()
            if isinstance(v, np.ndarray):  # cspell:ignore ndarray
                return v.tolist()
            if isinstance(v, dict):
                return {k: prepare(_v) for k, _v in v.items()}
            return v

        return json.dumps({
            "inference": inference.tolist(),
            "uncertainty": prepare(uncertainty) if uncertainty else {},
        }, **json_kwargs)

Functions

expected_entropy classmethod
expected_entropy(predictions: Tensor) -> Tensor

Calculate the mean entropy of the predictive distribution across the MC samples.

Parameters:

  • predictions (Tensor, required): predictions of shape (n_mc x batch_size x n_classes)

Returns:

  • Tensor: mean entropy of the predictive distribution
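For a single input with N sampled class-probability vectors p_n over C classes, the value returned by `expected_entropy` is the average of the per-sample entropies (natural logarithm, as used by `torch.distributions.Categorical`):

```latex
\mathbb{E}_n\big[H(p_n)\big] = -\frac{1}{N}\sum_{n=1}^{N}\sum_{c=1}^{C} p_{n,c}\,\log p_{n,c}
```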

get_options classmethod
get_options(obj: Model | Training) -> dict[str, Any]

Get uncertainty options from training options.

Parameters:

  • obj (Model | Training, required): The Model or Training object to retrieve options for.

Returns:

  • dict[str, Any]: Uncertainty options.

Raises:

  • TypeError: If the given object is not a Model or Training.
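`get_options` returns the "uncertainty" entry of a training's options, and the `prediction` methods read their settings from it ("N" in `MCDropout` and `SWAG`, "softmax" in `MCDropout`). A minimal sketch of that lookup on a plain dictionary, assuming the training options are a JSON-like dict; the helper name `uncertainty_options` and the example values are hypothetical.

```python
from typing import Any


def uncertainty_options(training_options: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical stand-in for the Training branch of get_options."""
    return training_options.get("uncertainty", {})


# Example training options: 20 MC Dropout passes with softmax applied to the logits
options = {"uncertainty": {"N": 20, "softmax": True}}
unc = uncertainty_options(options)
print(unc.get("N", 10), unc.get("softmax", False))  # 20 True

# Without an "uncertainty" entry the prediction methods fall back to their defaults
print(uncertainty_options({}).get("N", 10))  # 10
```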

interpret classmethod
interpret(outputs: Tensor) -> dict[str, Any]

Interpret the different network (model) outputs and calculate the uncertainty.

Parameters:

  • outputs (Tensor, required): multiple network (model) outputs (N, batch_size, n_classes)

Returns:

  • dict[str, Any]: uncertainty measures (variance, std); the entropy-based measures (predictive_entropy, expected_entropy, mutual_info) are only included if all outputs lie in [0, 1]
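The branch in `interpret` decides, purely from the value range of the outputs, whether they are treated as probabilities. A standalone sketch that mirrors this logic (the function name `interpret_outputs` is hypothetical and not part of the package) shows that raw logits only yield `variance` and `std`, while softmax probabilities also yield the entropy-based measures:

```python
import torch
from torch.distributions import Categorical


def interpret_outputs(outputs: torch.Tensor) -> dict[str, torch.Tensor]:
    """Mirror of UncertaintyBase.interpret; outputs has shape (N, batch_size, n_classes)."""
    result = {"variance": outputs.var(dim=0), "std": outputs.std(dim=0)}
    if not (torch.all(outputs <= 1.0) and torch.all(outputs >= 0.0)):
        return result  # not probabilities: no entropies
    predictive = Categorical(probs=outputs.mean(dim=0)).entropy()  # entropy of the mean
    expected = Categorical(probs=outputs).entropy().mean(dim=0)  # mean of the entropies
    result.update(
        predictive_entropy=predictive,
        expected_entropy=expected,
        mutual_info=predictive - expected,
    )
    return result


torch.manual_seed(0)
logits = torch.randn(10, 2, 3) * 3  # (N, batch_size, n_classes), values outside [0, 1]
probs = torch.softmax(logits, dim=2)
print(sorted(interpret_outputs(logits)))  # ['std', 'variance']
print(sorted(interpret_outputs(probs)))
# ['expected_entropy', 'mutual_info', 'predictive_entropy', 'std', 'variance']
```

The check only looks at the value range, so outputs in [0, 1] that do not sum to one (e.g. per-class sigmoids) are still passed to `Categorical`, which normalizes them.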

mutual_information classmethod
mutual_information(predictions: Tensor) -> Tensor

Calculate the BALD (Bayesian Active Learning by Disagreement) score of a model, i.e. the difference between the entropy of the mean predictive distribution and the mean entropy of the individual predictive distributions. This quantity is also referred to as the mutual information (MI).

Parameters:

  • predictions (Tensor, required): predictions of shape (n_mc x batch_size x n_classes)

Returns:

  • Tensor: difference between the entropy of the mean predictive distribution and the mean entropy of the individual predictive distributions
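A worked example of why the mutual information captures disagreement: suppose two MC samples for one input are each confident but predict different classes. Each sample has near-zero entropy, yet their average is uniform over two classes, so the mutual information approaches ln 2 ≈ 0.693; if both samples agree, it is zero. The snippet below re-implements the computation with `torch.distributions.Categorical`, as the base class does:

```python
import math

import torch
from torch.distributions import Categorical


def mutual_information(predictions: torch.Tensor) -> torch.Tensor:
    """predictions: (n_mc, batch_size, n_classes) probabilities."""
    predictive = Categorical(probs=predictions.mean(dim=0)).entropy()  # entropy of the mean
    expected = Categorical(probs=predictions).entropy().mean(dim=0)  # mean of the entropies
    return predictive - expected


disagree = torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])  # 2 samples, 1 input, 2 classes
agree = torch.tensor([[[0.5, 0.5]], [[0.5, 0.5]]])

print(mutual_information(disagree))  # tensor([0.6931]), i.e. ln 2
print(mutual_information(agree))  # tensor([0.])
```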

prediction abstractmethod classmethod
prediction(input: Tensor, model: Model) -> tuple[Tensor, dict[str, Any]]

Make a prediction using the given input and model.

Parameters:

  • input (Tensor, required): The input to make a prediction for.
  • model (Model, required): The model to use for making the prediction.

Returns:

  • tuple[Tensor, dict[str, Any]]: The prediction and any associated uncertainty.

predictive_entropy classmethod
predictive_entropy(predictions: Tensor) -> Tensor

Calculate the entropy of the mean predictive distribution across the MC samples.

Parameters:

  • predictions (Tensor, required): predictions of shape (n_mc x batch_size x n_classes)

Returns:

  • Tensor: entropy of the mean predictive distribution
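For a single input with N sampled class-probability vectors p_n over C classes, `predictive_entropy` first averages the probabilities over the samples and then takes one entropy (natural logarithm). Subtracting the expected entropy from this value gives what `mutual_information` returns.

```latex
H(\bar{p}) = -\sum_{c=1}^{C} \bar{p}_c \log \bar{p}_c,
\qquad
\bar{p}_c = \frac{1}{N}\sum_{n=1}^{N} p_{n,c}
```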

to_json classmethod
to_json(inference: Tensor, uncertainty: dict[str, Any] = {}, **json_kwargs) -> str

Convert the given inference and uncertainty data to a JSON string.

Parameters:

  • inference (Tensor, required): The inference to convert.
  • uncertainty (dict[str, Any], default {}): The uncertainty to convert.
  • **json_kwargs (default {}): Additional keyword arguments to pass to json.dumps.

Returns:

  • str: A JSON string representation of the given inference and uncertainty data.
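`to_json` converts tensors (moved to CPU) and NumPy arrays to nested lists, recursing into dictionaries, so that the result of `prediction` can be serialized with the standard `json` module. A standalone sketch of the same conversion; the name `to_json_sketch` is hypothetical, and it uses a `None` default instead of a mutable `{}` default.

```python
import json
from typing import Any, Optional

import numpy as np
import torch


def to_json_sketch(inference: torch.Tensor, uncertainty: Optional[dict[str, Any]] = None, **json_kwargs) -> str:
    def prepare(v: Any) -> Any:
        if isinstance(v, torch.Tensor):
            return v.cpu().tolist()  # tensors -> (nested) Python lists
        if isinstance(v, np.ndarray):
            return v.tolist()
        if isinstance(v, dict):
            return {k: prepare(x) for k, x in v.items()}  # recurse into nested dicts
        return v

    return json.dumps({
        "inference": inference.cpu().tolist(),
        "uncertainty": prepare(uncertainty) if uncertainty else {},
    }, **json_kwargs)


text = to_json_sketch(
    torch.tensor([[0.25, 0.75]]),
    {"std": torch.tensor([[0.5, 0.5]]), "extra": {"n": np.array([1, 2])}},
)
print(text)
# {"inference": [[0.25, 0.75]], "uncertainty": {"std": [[0.5, 0.5]], "extra": {"n": [1, 2]}}}
```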


Functions

get_uncertainty_class

get_uncertainty_class(value: Model) -> Type[UncertaintyBase]
get_uncertainty_class(value: Training) -> Type[UncertaintyBase]
get_uncertainty_class(value: UncertaintyMethod) -> Type[UncertaintyBase]
get_uncertainty_class(value: Model | Training | UncertaintyMethod) -> Type[UncertaintyBase]

Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.

Parameters:

  • value (Model | Training | UncertaintyMethod, required): The object to retrieve the uncertainty class for.

Returns:

  • Type[UncertaintyBase]: The uncertainty class associated with the given object.

Raises:

  • ValueError: If the given object is not a Model, Training, or UncertaintyMethod, or if the uncertainty method associated with the object is unknown.

Source code in fl_server_ai/uncertainty/method.py
def get_uncertainty_class(value: Model | Training | UncertaintyMethod) -> Type[UncertaintyBase]:
    """
    Get uncertainty class associated with a given Model, Training, or UncertaintyMethod object.

    Args:
        value (Model | Training | UncertaintyMethod): The object to retrieve the uncertainty class for.

    Returns:
        Type[UncertaintyBase]: The uncertainty class associated with the given object.

    Raises:
        ValueError: If the given object is not a Model, Training, or UncertaintyMethod,
                    or if the uncertainty method associated with the object is unknown.
    """
    if isinstance(value, UncertaintyMethod):
        method = value
    elif isinstance(value, Training):
        method = value.uncertainty_method
    elif isinstance(value, Model):
        uncertainty_method = Training.objects.filter(model=value) \
                .values("uncertainty_method") \
                .first()["uncertainty_method"]
        method = uncertainty_method
    else:
        raise ValueError(f"Unknown type: {type(value)}")

    match method:
        case UncertaintyMethod.ENSEMBLE: return Ensemble
        case UncertaintyMethod.MC_DROPOUT: return MCDropout
        case UncertaintyMethod.NONE: return NoneUncertainty
        case UncertaintyMethod.SWAG: return SWAG
        case _: raise ValueError(f"Unknown uncertainty method: {method}")