
fl_server_ai.uncertainty.swag

Classes:

Name    Description
SWAG    Stochastic Weight Averaging Gaussian (SWAG) uncertainty estimation.

Classes

SWAG

Bases: UncertaintyBase


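Stochastic Weight Averaging Gaussian (SWAG) uncertainty estimation. The posterior over the network weights is approximated by a diagonal Gaussian. Its mean is the stored first moment of the weights, and its variance is the second moment minus the squared first moment. Predictions are averaged over N weight vectors sampled from this Gaussian (N is read from the model options and defaults to 10). The spread of the N predictions is passed to interpret to produce the uncertainty metrics.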
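The quantities used by prediction follow the standard diagonal SWAG construction. In the sketch below, $\theta_t$ is the flattened weight vector at the $t$-th of $T$ collection points during training, squares are elementwise, and $f(x;\theta)$ is the network evaluated with weights $\theta$. This page does not document how the project collects the moments, so the running averages are the textbook definition and are assumed rather than confirmed. model.first_moment corresponds to $\bar\theta$ and model.second_moment to $\overline{\theta^2}$.

```latex
\bar\theta = \frac{1}{T}\sum_{t=1}^{T} \theta_t, \qquad
\overline{\theta^2} = \frac{1}{T}\sum_{t=1}^{T} \theta_t^2, \qquad
\sigma^2 = \overline{\theta^2} - \bar\theta^{\,2}

\tilde\theta_n \sim \mathcal{N}\!\left(\bar\theta,\ \operatorname{diag}(\sigma^2)\right), \quad n = 1, \dots, N

\hat{y}(x) = \frac{1}{N} \sum_{n=1}^{N} f\!\left(x; \tilde\theta_n\right)
```

torch.normal expects a standard deviation, so the value passed as std should be the elementwise square root $\sigma$, not $\sigma^2$. $\hat y(x)$ is the inference value returned by prediction.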

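Methods:

Name          Description
prediction    Sample N weight vectors from the diagonal SWAG Gaussian, run the network once per sample, and return the mean prediction together with the uncertainty metrics produced by interpret.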
Source code in fl_server_ai/uncertainty/swag.py
class SWAG(UncertaintyBase):
    """
    Stochastic Weight Averaging Gaussian (SWAG) uncertainty estimation.
    """

    @classmethod
    def prediction(cls, input: torch.Tensor, model: SWAGModel) -> Tuple[torch.Tensor, Dict[str, Any]]:
        options = cls.get_options(model)
        N = options.get("N", 10)  # number of posterior weight samples

        net: torch.nn.Module = model.get_torch_model()

        # first and second moment are already ensured to be in
        # alphabetical order in the database
        fm = model.first_moment
        sm = model.second_moment
        # diagonal SWAG: Var[w] = E[w^2] - E[w]^2; clamp tiny negative values
        # caused by floating-point round-off, then take the square root
        std = torch.sqrt(torch.clamp(sm - torch.pow(fm, 2), min=0.0))
        # draw N independent samples; expanding a single draw would repeat it N times
        params = torch.normal(mean=fm.expand(N, -1), std=std.expand(N, -1))

        prediction_list = []
        with torch.no_grad():
            for n in range(N):
                # load the n-th sampled weight vector into the network
                torch.nn.utils.vector_to_parameters(params[n], net.parameters())
                prediction = net(input)
                prediction_list.append(prediction)
        predictions = torch.stack(prediction_list)  # shape (N, *output_shape)

        inference = predictions.mean(dim=0)
        uncertainty = cls.interpret(predictions)
        return inference, uncertainty
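The sampling step can be checked in isolation. The following is a minimal standalone sketch of diagonal-SWAG prediction that does not depend on SWAGModel or UncertaintyBase. The function name swag_diag_predict and its parameters are hypothetical stand-ins for model.first_moment, model.second_moment, and the N option. It draws one independent Gaussian sample per row, which is why the mean and standard deviation are expanded to (n_samples, num_params) before calling torch.normal. Expanding a single draw afterwards would feed the same weights to every forward pass and make the spread across predictions zero.

```python
import torch
from torch import nn


def swag_diag_predict(
    net: nn.Module,
    first_moment: torch.Tensor,
    second_moment: torch.Tensor,
    x: torch.Tensor,
    n_samples: int = 10,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: model averaging over diagonal-SWAG weight samples."""
    # Var[w] = E[w^2] - E[w]^2, clamped at zero to absorb round-off error
    variance = torch.clamp(second_moment - first_moment.pow(2), min=0.0)
    std = variance.sqrt()
    # one independent row per sample: shape (n_samples, num_params)
    samples = torch.normal(
        mean=first_moment.expand(n_samples, -1),
        std=std.expand(n_samples, -1),
    )
    outputs = []
    with torch.no_grad():
        for theta in samples:
            # overwrite the network weights with the sampled flat vector
            nn.utils.vector_to_parameters(theta, net.parameters())
            outputs.append(net(x))
    predictions = torch.stack(outputs)  # (n_samples, batch, out_features)
    return predictions.mean(dim=0), predictions


# Usage on a toy linear layer (all values here are illustrative)
torch.manual_seed(0)
net = nn.Linear(3, 2)
num_params = sum(p.numel() for p in net.parameters())  # 3 * 2 weights + 2 biases
fm = torch.zeros(num_params)
sm = torch.full((num_params,), 0.25)  # variance 0.25, i.e. std 0.5
x = torch.randn(4, 3)
mean, preds = swag_diag_predict(net, fm, sm, x, n_samples=5)
```

With a nonzero variance the five forward passes differ, and their mean is the averaged prediction. If the second moment equals the squared first moment, every sample collapses onto the mean weights and all predictions coincide.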

Functions

prediction classmethod
prediction(input: Tensor, model: SWAGModel) -> tuple[Tensor, dict[str, Any]]
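The dictionary in the return value is produced by cls.interpret, which is defined on UncertaintyBase and not documented on this page. As an illustration only, the hypothetical helper summarize_samples below reduces the stacked (N, batch, ...) predictions to per-element statistics. Its key names ("mean", "variance", "std") are assumptions and not the actual output of interpret.

```python
import torch


def summarize_samples(predictions: torch.Tensor) -> dict[str, torch.Tensor]:
    """Hypothetical reduction of stacked SWAG predictions; NOT UncertaintyBase.interpret."""
    # predictions has shape (N, batch, ...): dim 0 indexes the weight samples
    return {
        "mean": predictions.mean(dim=0),
        # population statistics across samples (unbiased=False divides by N)
        "variance": predictions.var(dim=0, unbiased=False),
        "std": predictions.std(dim=0, unbiased=False),
    }


# Two weight samples (N = 2), each producing two output values
stacked = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
stats = summarize_samples(stacked)
```

Reducing over dim 0 keeps the batch and output dimensions intact, so each output element gets its own uncertainty estimate.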