Bases: UncertaintyBase
flowchart TD
fl_server_ai.uncertainty.swag.SWAG[SWAG]
fl_server_ai.uncertainty.base.UncertaintyBase[UncertaintyBase]
fl_server_ai.uncertainty.base.UncertaintyBase --> fl_server_ai.uncertainty.swag.SWAG
click fl_server_ai.uncertainty.swag.SWAG href "" "fl_server_ai.uncertainty.swag.SWAG"
click fl_server_ai.uncertainty.base.UncertaintyBase href "" "fl_server_ai.uncertainty.base.UncertaintyBase"
Stochastic Weight Averaging Gaussian (SWAG) uncertainty estimation.
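Methods:
- prediction (classmethod): returns the mean model prediction together with a dictionary of uncertainty measures.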
Source code in fl_server_ai/uncertainty/swag.py
class SWAG(UncertaintyBase):
    """
    Stochastic Weight Averaging Gaussian (SWAG) uncertainty estimation.

    Weights are sampled from a diagonal Gaussian whose mean is the model's
    first moment and whose variance is ``second_moment - first_moment ** 2``.
    The network is evaluated once per weight sample and the resulting
    predictions are aggregated.
    """

    @classmethod
    def prediction(cls, input: torch.Tensor, model: SWAGModel) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Run SWAG inference for ``input``.

        Args:
            input: Tensor that is fed to the network unchanged.
            model: SWAG model providing the torch network as well as the
                flattened first and second moments of its parameters.

        Returns:
            A tuple ``(inference, uncertainty)`` where ``inference`` is the mean
            of the predictions over all sampled weight vectors and
            ``uncertainty`` is whatever ``cls.interpret`` derives from the
            stacked per-sample predictions.
        """
        options = cls.get_options(model)
        # number of weight samples (one forward pass each); defaults to 10
        N = options.get("N", 10)
        net: torch.nn.Module = model.get_torch_model()

        # first and second moment are already ensured to be in
        # alphabetical order in the database
        fm = model.first_moment
        sm = model.second_moment

        # E[w^2] - E[w]^2 is the diagonal *variance*. Round-off can push it
        # slightly below zero, so clamp before taking the square root to get
        # the standard deviation that torch.normal expects.
        variance = torch.clamp(sm - torch.pow(fm, 2), min=0.0)
        std = torch.sqrt(variance)

        # Expand *before* sampling so that each of the N rows is an independent
        # draw. Sampling once and expanding afterwards would only repeat a
        # single weight vector N times.
        mean = fm[None, :].expand(N, -1)
        params = torch.normal(mean=mean, std=std.expand_as(mean))

        prediction_list = []
        for n in range(N):
            # overwrites the parameters of ``net`` in place with the n-th sample
            torch.nn.utils.vector_to_parameters(params[n], net.parameters())
            prediction = net(input)
            prediction_list.append(prediction)
        predictions = torch.stack(prediction_list)

        inference = predictions.mean(dim=0)
        uncertainty = cls.interpret(predictions)
        return inference, uncertainty
|
Functions
prediction (classmethod)
Source code in fl_server_ai/uncertainty/swag.py
@classmethod
def prediction(cls, input: torch.Tensor, model: SWAGModel) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """
    Run SWAG inference for ``input``.

    Weight vectors are sampled from a diagonal Gaussian with mean
    ``model.first_moment`` and variance
    ``model.second_moment - model.first_moment ** 2``; the network is
    evaluated once per sample.

    Args:
        input: Tensor that is fed to the network unchanged.
        model: SWAG model providing the torch network as well as the
            flattened first and second moments of its parameters.

    Returns:
        A tuple ``(inference, uncertainty)`` where ``inference`` is the mean
        of the predictions over all sampled weight vectors and
        ``uncertainty`` is whatever ``cls.interpret`` derives from the
        stacked per-sample predictions.
    """
    options = cls.get_options(model)
    # number of weight samples (one forward pass each); defaults to 10
    N = options.get("N", 10)
    net: torch.nn.Module = model.get_torch_model()

    # first and second moment are already ensured to be in
    # alphabetical order in the database
    fm = model.first_moment
    sm = model.second_moment

    # E[w^2] - E[w]^2 is the diagonal *variance*. Round-off can push it
    # slightly below zero, so clamp before taking the square root to get
    # the standard deviation that torch.normal expects.
    variance = torch.clamp(sm - torch.pow(fm, 2), min=0.0)
    std = torch.sqrt(variance)

    # Expand *before* sampling so that each of the N rows is an independent
    # draw. Sampling once and expanding afterwards would only repeat a
    # single weight vector N times.
    mean = fm[None, :].expand(N, -1)
    params = torch.normal(mean=mean, std=std.expand_as(mean))

    prediction_list = []
    for n in range(N):
        # overwrites the parameters of ``net`` in place with the n-th sample
        torch.nn.utils.vector_to_parameters(params[n], net.parameters())
        prediction = net(input)
        prediction_list.append(prediction)
    predictions = torch.stack(prediction_list)

    inference = predictions.mean(dim=0)
    uncertainty = cls.interpret(predictions)
    return inference, uncertainty
|