fl_server_ai.uncertainty.mc_dropout

Classes:

Name        Description
MCDropout   Monte Carlo (MC) Dropout Uncertainty Estimation

Functions:

Name          Description
set_dropout   Set the state of the dropout layers to enable or disable them even during inference.

Classes

MCDropout

Bases: UncertaintyBase


              flowchart TD
              fl_server_ai.uncertainty.mc_dropout.MCDropout[MCDropout]
              fl_server_ai.uncertainty.base.UncertaintyBase[UncertaintyBase]

                              fl_server_ai.uncertainty.base.UncertaintyBase --> fl_server_ai.uncertainty.mc_dropout.MCDropout
                


              click fl_server_ai.uncertainty.mc_dropout.MCDropout href "" "fl_server_ai.uncertainty.mc_dropout.MCDropout"
              click fl_server_ai.uncertainty.base.UncertaintyBase href "" "fl_server_ai.uncertainty.base.UncertaintyBase"
            

Monte Carlo (MC) Dropout Uncertainty Estimation

Requirements:

  • model with dropout layers
  • N, number of samples per input (number of Monte Carlo samples/forward passes; option `N`, default 10)
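For reference, the quantities usually derived from the N stochastic passes can be written as below. This follows the standard MC dropout formulation used in the referenced paper, where pass $t$ samples a fresh dropout mask $\hat{\omega}_t$. The averaged probability $\bar{p}_c$ matches the `inference` returned by `prediction` when the `softmax` option is enabled; whether `interpret` (inherited from `UncertaintyBase`) reports exactly predictive entropy and mutual information is not shown on this page, so treat these as assumed definitions.

```latex
\begin{aligned}
p_{t,c}(x) &= \mathrm{softmax}\big(f_{\hat{\omega}_t}(x)\big)_c, \qquad t = 1, \dots, N \\
\bar{p}_c(x) &= \frac{1}{N} \sum_{t=1}^{N} p_{t,c}(x) \\
\mathbb{H}[y \mid x] &= -\sum_{c} \bar{p}_c(x) \log \bar{p}_c(x) \\
\mathbb{I}[y, \omega \mid x] &\approx \mathbb{H}[y \mid x] + \frac{1}{N} \sum_{t=1}^{N} \sum_{c} p_{t,c}(x) \log p_{t,c}(x)
\end{aligned}
```

The mutual information is the predictive entropy minus the average per-pass entropy, so it is close to zero when all passes agree and grows when they disagree.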

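References:

  • Paper: Understanding Measures of Uncertainty for Adversarial Example Detection <https://arxiv.org/abs/1803.08533>
  • Code inspiration: <https://github.com/lsgos/uncertainty-adversarial-paper/tree/master>
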
Methods:

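Name         Description
prediction   Run N stochastic forward passes with dropout enabled; return the mean prediction and an uncertainty dictionary.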
Source code in fl_server_ai/uncertainty/mc_dropout.py
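from typing import Any, Dict, Tuple

import torch
from torch import Tensor
from torch.nn import Module

from fl_server_ai.uncertainty.base import UncertaintyBase

# `Model` (the fl_server model wrapper that provides get_torch_model()) is
# imported from elsewhere in the project; its import path is not shown here.


class MCDropout(UncertaintyBase):
    """
    Monte Carlo (MC) Dropout Uncertainty Estimation

    Requirements:

    - model with dropout layers
    - N, number of samples per input (number of Monte Carlo samples/forward passes)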

    References:

    - Paper: Understanding Measures of Uncertainty for Adversarial Example Detection
             <https://arxiv.org/abs/1803.08533>
    - Code inspiration: <https://github.com/lsgos/uncertainty-adversarial-paper/tree/master>
    """

    @classmethod
    def prediction(cls, input: Tensor, model: Model) -> Tuple[torch.Tensor, Dict[str, Any]]:
        options = cls.get_options(model)
        N = options.get("N", 10)
        softmax = options.get("softmax", False)

        net: Module = model.get_torch_model()
        net.eval()
        set_dropout(net, state=True)

        out_list = []
        for _ in range(N):
            output = net(input).detach()
            # convert to probabilities if necessary
            if softmax:
                output = torch.softmax(output, dim=1)
            out_list.append(output)
        out = torch.stack(out_list, dim=0)  # (n_mc, batch_size, n_classes)

        inference = out.mean(dim=0)
        uncertainty = cls.interpret(out)
        return inference, uncertainty
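A minimal, self-contained sketch of the same MC dropout loop on a plain `torch.nn` model, without the fl_server `Model` wrapper or `UncertaintyBase`. The helper names `mc_dropout_samples` and `uncertainty_measures` are hypothetical, and the measures computed (predictive entropy, expected entropy, mutual information) are the standard ones from the MC dropout literature: an assumption about, not a copy of, what `interpret` returns. Unlike `prediction`, the sketch always applies softmax and runs the passes under `torch.no_grad()`.

```python
from torch import nn
import torch
from torch.nn.modules.dropout import _DropoutNd


def mc_dropout_samples(net: nn.Module, x: torch.Tensor, n: int = 10) -> torch.Tensor:
    """Run n stochastic passes; return softmax outputs of shape (n, batch_size, n_classes)."""
    net.eval()  # inference behaviour for all layers (e.g. BatchNorm running stats)
    for m in net.modules():
        if isinstance(m, _DropoutNd):
            m.train()  # switch dropout back on so each pass samples a new mask
    with torch.no_grad():
        return torch.stack([torch.softmax(net(x), dim=1) for _ in range(n)], dim=0)


def uncertainty_measures(probs: torch.Tensor, eps: float = 1e-12) -> dict:
    """Per-input uncertainty measures from stacked probabilities (n, batch_size, n_classes)."""
    mean = probs.mean(dim=0)  # (batch_size, n_classes)
    predictive_entropy = -(mean * torch.log(mean + eps)).sum(dim=1)
    # average entropy of the individual passes
    expected_entropy = -(probs * torch.log(probs + eps)).sum(dim=2).mean(dim=0)
    return {
        "predictive_entropy": predictive_entropy,
        "expected_entropy": expected_entropy,
        "mutual_information": predictive_entropy - expected_entropy,
    }


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 3))
x = torch.randn(8, 4)

probs = mc_dropout_samples(net, x, n=20)
mean_prediction = probs.mean(dim=0)  # analogue of `inference` in prediction()
measures = uncertainty_measures(probs)
print(mean_prediction.argmax(dim=1), measures["mutual_information"])
```

Mutual information stays near zero when the N passes agree and grows when they disagree, which separates uncertainty caused by the sampled dropout masks from uncertainty that every pass shares.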

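Methods

prediction classmethod
prediction(input: Tensor, model: Model) -> tuple[Tensor, dict[str, Any]]

Puts the model's PyTorch network in eval mode, re-enables its dropout layers with set_dropout, and runs N forward passes on input (option N, default 10). If the softmax option is set (default False), each output is converted to probabilities with torch.softmax along dim 1. The outputs are stacked into a tensor of shape (N, batch_size, n_classes); the method returns their mean over the N passes as the prediction, together with the uncertainty dictionary computed by interpret. The dropout layers remain enabled after the call returns.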

Functions

set_dropout

set_dropout(model: Module, state: bool = True)

Set the state of the dropout layers to enable or disable them even during inference.

Parameters:

Name    Type     Description                                       Default
model   Module   PyTorch module (eager or TorchScript)             required
state   bool     Enable (True) or disable (False) dropout layers   True
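A small eager-mode check of the intended effect: after `eval()`, only the dropout layers go back into training mode, while BatchNorm keeps using its running statistics. `set_dropout_eager` is a hypothetical, simplified stand-in for `set_dropout` that omits the TorchScript branch.

```python
from torch import nn
from torch.nn.modules.dropout import _DropoutNd


def set_dropout_eager(model: nn.Module, state: bool = True) -> None:
    """Simplified stand-in for set_dropout that handles eager modules only."""
    for m in model.modules():
        if isinstance(m, _DropoutNd) or "dropout" in m.__class__.__name__.lower():
            m.train(mode=state)


model = nn.Sequential(
    nn.Linear(8, 8),
    nn.BatchNorm1d(8),
    nn.Dropout(p=0.5),
    nn.Linear(8, 2),
)
model.eval()              # every layer, dropout included, switches to eval mode
set_dropout_eager(model)  # re-enable only the dropout layers

flags = {type(m).__name__: m.training for m in model.modules()}
print(flags)  # {'Sequential': False, 'Linear': False, 'BatchNorm1d': False, 'Dropout': True}
```

Because `Module.train()` recurses into children, a matched container (for example a custom class whose name contains "dropout") would switch its whole subtree, including any normalization layers, into training mode.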
Source code in fl_server_ai/uncertainty/mc_dropout.py
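from torch.nn import Module
from torch.nn.modules.dropout import _DropoutNd

# is_torchscript_instance is an fl_server helper; its import path is not shown here.


def set_dropout(model: Module, state: bool = True):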
    """
    Set the state of the dropout layers to enable or disable them even during inference.

    Args:
        model (Module): PyTorch module
        state (bool, optional): Enable or disable dropout layers. Defaults to True.
    """
    is_torchscript_model = is_torchscript_instance(model)
    for m in model.modules():
        name = m.original_name if is_torchscript_model else m.__class__.__name__
        assert isinstance(name, str)
        if isinstance(m, _DropoutNd) or "dropout" in name.lower():
            m.train(mode=state)