
fl_server_ai.uncertainty.ensemble

Classes:

Name      Description
Ensemble  Ensemble uncertainty estimation.

Classes

Ensemble

Bases: UncertaintyBase



Ensemble uncertainty estimation.

Methods:

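Name        Description
prediction  Runs the input through every model in the ensemble and returns the mean output together with an uncertainty estimate.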
Source code in fl_server_ai/uncertainty/ensemble.py
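import torch
from typing import Any, Dict, Tuple

from .base import UncertaintyBase

# MeanModel comes from the project's model definitions (import path not shown here).


class Ensemble(UncertaintyBase):
    """
    Ensemble uncertainty estimation.
    """

    @classmethod
    def prediction(cls, input: torch.Tensor, model: MeanModel) -> Tuple[torch.Tensor, Dict[str, Any]]:
        # Run the input through every member model of the ensemble.
        output_list = []
        for m in model.models.all():
            net = m.get_torch_model()
            # detach() drops the autograd graph; only the output values are needed.
            output = net(input).detach()
            output_list.append(output)
        # Shape: (N, batch_size, n_classes), where N is the number of models.
        outputs = torch.stack(output_list, dim=0)

        # The ensemble prediction is the mean over the model axis.
        inference = outputs.mean(dim=0)
        # interpret() is inherited from UncertaintyBase and derives the
        # uncertainty dictionary from the per-model outputs.
        uncertainty = cls.interpret(outputs)
        return inference, uncertainty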

Functions

prediction classmethod
prediction(input: Tensor, model: MeanModel) -> tuple[Tensor, dict[str, Any]]
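The aggregation step in prediction (collect one output per member model, stack them along a new model axis, then average over that axis) can be checked by hand on a tiny example. The sketch below is a dependency-free illustration that uses plain Python lists in place of tensors. The names ensemble_predict, Member, and the fixed-score members are hypothetical, and the per-class variance is only an assumed stand-in for whatever UncertaintyBase.interpret actually computes, which this page does not show.

```python
from statistics import fmean, pvariance
from typing import Any, Callable, Dict, List, Sequence, Tuple

Batch = List[List[float]]  # shape (batch_size, n_classes)
Member = Callable[[Sequence[Any]], Batch]  # hypothetical stand-in for a model


def ensemble_predict(inputs: Sequence[Any], members: Sequence[Member]) -> Tuple[Batch, Dict[str, Any]]:
    """Average the member outputs and report how much the members disagree."""
    if not members:
        raise ValueError("the ensemble needs at least one member")

    # One output batch per member: shape (N, batch_size, n_classes).
    # This plays the role of torch.stack(output_list, dim=0).
    outputs = [member(inputs) for member in members]

    mean: Batch = []
    variance: Batch = []
    # zip(*outputs) yields, for each sample, the N rows produced by the members.
    for rows in zip(*outputs):
        # zip(*rows) yields, for each class, the N values predicted for it.
        per_class = list(zip(*rows))
        # Mean over the model axis, like outputs.mean(dim=0).
        mean.append([fmean(values) for values in per_class])
        # Assumption: population variance as an illustrative disagreement measure.
        variance.append([pvariance(values) for values in per_class])

    return mean, {"variance": variance}


# Three hypothetical members that always return fixed class scores.
members = [
    lambda batch: [[0.9, 0.1] for _ in batch],
    lambda batch: [[0.7, 0.3] for _ in batch],
    lambda batch: [[0.8, 0.2] for _ in batch],
]

mean, info = ensemble_predict(["sample-1", "sample-2"], members)
print([round(v, 3) for v in mean[0]])
print([round(v, 4) for v in info["variance"][0]])
```

With these three members, the averaged scores for each sample come out at roughly 0.8 and 0.2, and the variance grows as the members disagree more. An empty ensemble is rejected explicitly here. In the tensor version, torch.stack would raise on an empty list, so a similar guard may be worth considering.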