Skip to content

fl_server_ai.tests.test_uncertainty_swag

Classes:

Name Description
SwagTest

Classes

SwagTest

Bases: TestCase


              flowchart TD
              fl_server_ai.tests.test_uncertainty_swag.SwagTest[SwagTest]

              

              click fl_server_ai.tests.test_uncertainty_swag.SwagTest href "" "fl_server_ai.tests.test_uncertainty_swag.SwagTest"
            

Methods:

Name Description
test_handle_swag_round_finished
test_prediction
test_start_swag_round
test_start_swag_round_via_handle
test_trainer_class
test_trainer_type
Source code in fl_server_ai/tests/test_uncertainty_swag.py
class SwagTest(TestCase):
    """Tests for trainer selection, round handling and prediction with ``UncertaintyMethod.SWAG``."""

    def test_trainer_class(self):
        """``get_trainer_class`` resolves to ``SWAGModelTrainer`` for a SWAG training."""
        training = Dummy.create_training(uncertainty_method=UncertaintyMethod.SWAG)
        trainer_cls = get_trainer_class(training)
        # assertIs reports both objects on failure instead of "False is not true".
        self.assertIs(trainer_cls, SWAGModelTrainer)

    def test_trainer_type(self):
        """``get_trainer`` returns an object whose exact type is ``SWAGModelTrainer``."""
        training = Dummy.create_training(uncertainty_method=UncertaintyMethod.SWAG)
        trainer = get_trainer(training)
        self.assertIs(type(trainer), SWAGModelTrainer)

    @patch.object(TrainingSWAGRoundStartNotification, "send")
    def test_start_swag_round(self, send_method):
        """``start_swag_round`` sets the state to ``SWAG_ROUND`` and sends the start notification.

        ``send_method`` is the mock replacing ``TrainingSWAGRoundStartNotification.send``.
        """
        training = Dummy.create_training(uncertainty_method=UncertaintyMethod.SWAG)
        trainer = get_trainer(training)
        # unittest assertion instead of a bare ``assert`` so the check survives ``python -O``.
        self.assertIs(type(trainer), SWAGModelTrainer)
        trainer.start_swag_round()
        self.assertEqual(TrainingState.SWAG_ROUND, training.state)
        send_method.assert_called()

    @patch.object(TrainingRoundFinished, "handle")
    @patch.object(TrainingRoundFinished, "next")
    @patch.object(TrainingSWAGRoundStartNotification, "send")
    def test_start_swag_round_via_handle(self, send_method, next_method, handle_method):
        """Handling a ``TrainingRoundFinished`` event starts the SWAG round.

        Mocks are injected bottom-up from the decorator stack: ``send_method`` replaces
        ``TrainingSWAGRoundStartNotification.send``, ``next_method`` replaces
        ``TrainingRoundFinished.next`` and ``handle_method`` replaces
        ``TrainingRoundFinished.handle``.
        """
        training = Dummy.create_training(uncertainty_method=UncertaintyMethod.SWAG)
        trainer = get_trainer(training)
        self.assertIs(type(trainer), SWAGModelTrainer)
        event = TrainingRoundFinished(trainer)
        trainer.handle(event)
        self.assertEqual(TrainingState.SWAG_ROUND, training.state)
        handle_method.assert_called()
        # The event's own ``next`` must not be called in this path.
        next_method.assert_not_called()
        send_method.assert_called()

    @patch.object(SWAGRoundFinished, "handle")
    @patch.object(TrainingRoundFinished, "next")
    def test_handle_swag_round_finished(self, base_cls_next_method, handle_method):
        """Handling a ``SWAGRoundFinished`` event calls ``handle`` and ``TrainingRoundFinished.next``.

        ``base_cls_next_method`` replaces ``TrainingRoundFinished.next`` and
        ``handle_method`` replaces ``SWAGRoundFinished.handle``. The training state is
        expected to be ``ONGOING`` afterwards.
        """
        training = Dummy.create_training(uncertainty_method=UncertaintyMethod.SWAG)
        trainer = get_trainer(training)
        self.assertIs(type(trainer), SWAGModelTrainer)
        event = SWAGRoundFinished(trainer)
        trainer.handle(event)
        self.assertEqual(TrainingState.ONGOING, training.state)
        handle_method.assert_called()
        base_cls_next_method.assert_called()

    def test_prediction(self):
        """``SWAG.prediction`` yields outputs whose signs match the expected labels.

        Uses a scripted ``Linear(1, 1)`` + ``Tanh`` network and identical first and
        second SWAG moments ``[1.0, 0.0]``.
        """
        torch_model = from_torch_module(torch.jit.script(
            torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Tanh())
        ))
        # NOTE(review): presumably the moments are the flattened (weight, bias)
        # statistics of the linear layer -- confirm against SWAG.prediction.
        first_moment = from_torch_tensor(torch.tensor([1.0, 0.0]))
        second_moment = from_torch_tensor(torch.tensor([1.0, 0.0]))
        model = Dummy.create_model(SWAGModel, weights=torch_model, swag_first_moment=first_moment,
                                   swag_second_moment=second_moment)
        # The training is created for its options only; its return value is not needed.
        Dummy.create_training(
            model=model,
            uncertainty_method=UncertaintyMethod.SWAG,
            options=dict(uncertainty={"N": 10})
        )
        X = torch.tensor([[-4.0], [-2.0], [2.0], [4.0]])
        y = torch.tensor([-1.0, -1.0, 1.0, 1.0])
        logits, _ = SWAG.prediction(X, model)
        # Only the sign of each prediction is compared with the expected label.
        torch.testing.assert_close(y, torch.sign(torch.squeeze(logits)))

Functions

test_handle_swag_round_finished
test_handle_swag_round_finished(base_cls_next_method, handle_method)
Source code in fl_server_ai/tests/test_uncertainty_swag.py
@patch.object(SWAGRoundFinished, "handle")
@patch.object(TrainingRoundFinished, "next")
def test_handle_swag_round_finished(self, base_cls_next_method, handle_method):
    """Handling a ``SWAGRoundFinished`` event calls ``handle`` and ``TrainingRoundFinished.next``.

    ``base_cls_next_method`` replaces ``TrainingRoundFinished.next`` and
    ``handle_method`` replaces ``SWAGRoundFinished.handle``. The training state is
    expected to be ``ONGOING`` afterwards.
    """
    training = Dummy.create_training(uncertainty_method=UncertaintyMethod.SWAG)
    trainer = get_trainer(training)
    # unittest assertion instead of a bare ``assert`` so the check survives ``python -O``.
    self.assertIs(type(trainer), SWAGModelTrainer)
    event = SWAGRoundFinished(trainer)
    trainer.handle(event)
    self.assertEqual(TrainingState.ONGOING, training.state)
    # Mock assertions name the mock in the failure message.
    handle_method.assert_called()
    base_cls_next_method.assert_called()
test_prediction
test_prediction()
Source code in fl_server_ai/tests/test_uncertainty_swag.py
def test_prediction(self):
    """Check that ``SWAG.prediction`` outputs have the same signs as the expected labels.

    The network is a scripted ``Linear(1, 1)`` followed by ``Tanh``; both SWAG
    moments are the tensor ``[1.0, 0.0]``.
    """
    network = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Tanh())
    weights = from_torch_module(torch.jit.script(network))
    mean = from_torch_tensor(torch.tensor([1.0, 0.0]))
    squared_mean = from_torch_tensor(torch.tensor([1.0, 0.0]))
    swag_model = Dummy.create_model(
        SWAGModel,
        weights=weights,
        swag_first_moment=mean,
        swag_second_moment=squared_mean,
    )
    # The training is created for its options only; its return value is not used.
    Dummy.create_training(
        model=swag_model,
        uncertainty_method=UncertaintyMethod.SWAG,
        options={"uncertainty": {"N": 10}},
    )
    inputs = torch.tensor([[-4.0], [-2.0], [2.0], [4.0]])
    expected_signs = torch.tensor([-1.0, -1.0, 1.0, 1.0])
    prediction = SWAG.prediction(inputs, swag_model)
    logits, _ = prediction
    flattened = torch.squeeze(logits)
    predicted_signs = torch.sign(flattened)
    torch.testing.assert_close(expected_signs, predicted_signs)
test_start_swag_round
test_start_swag_round(send_method)
Source code in fl_server_ai/tests/test_uncertainty_swag.py
@patch.object(TrainingSWAGRoundStartNotification, "send")
def test_start_swag_round(self, send_method):
    """``start_swag_round`` sets the state to ``SWAG_ROUND`` and sends the start notification.

    ``send_method`` is the mock replacing ``TrainingSWAGRoundStartNotification.send``.
    """
    training = Dummy.create_training(uncertainty_method=UncertaintyMethod.SWAG)
    trainer = get_trainer(training)
    # unittest assertion instead of a bare ``assert`` so the check survives ``python -O``.
    self.assertIs(type(trainer), SWAGModelTrainer)
    trainer.start_swag_round()
    self.assertEqual(TrainingState.SWAG_ROUND, training.state)
    # Mock assertion names the mock in the failure message.
    send_method.assert_called()
test_start_swag_round_via_handle
test_start_swag_round_via_handle(send_method, next_method, handle_method)
Source code in fl_server_ai/tests/test_uncertainty_swag.py
@patch.object(TrainingRoundFinished, "handle")
@patch.object(TrainingRoundFinished, "next")
@patch.object(TrainingSWAGRoundStartNotification, "send")
def test_start_swag_round_via_handle(self, send_method, next_method, handle_method):
    """Handling a ``TrainingRoundFinished`` event starts the SWAG round.

    Mocks are injected bottom-up from the decorator stack: ``send_method`` replaces
    ``TrainingSWAGRoundStartNotification.send``, ``next_method`` replaces
    ``TrainingRoundFinished.next`` and ``handle_method`` replaces
    ``TrainingRoundFinished.handle``.
    """
    training = Dummy.create_training(uncertainty_method=UncertaintyMethod.SWAG)
    trainer = get_trainer(training)
    # unittest assertion instead of a bare ``assert`` so the check survives ``python -O``.
    self.assertIs(type(trainer), SWAGModelTrainer)
    event = TrainingRoundFinished(trainer)
    trainer.handle(event)
    self.assertEqual(TrainingState.SWAG_ROUND, training.state)
    handle_method.assert_called()
    # The event's own ``next`` must not be called in this path.
    next_method.assert_not_called()
    send_method.assert_called()
test_trainer_class
test_trainer_class()
Source code in fl_server_ai/tests/test_uncertainty_swag.py
def test_trainer_class(self):
    """``get_trainer_class`` resolves to ``SWAGModelTrainer`` for a SWAG training."""
    training = Dummy.create_training(uncertainty_method=UncertaintyMethod.SWAG)
    trainer_cls = get_trainer_class(training)
    # assertIs reports both objects on failure instead of "False is not true".
    self.assertIs(trainer_cls, SWAGModelTrainer)
test_trainer_type
test_trainer_type()
Source code in fl_server_ai/tests/test_uncertainty_swag.py
def test_trainer_type(self):
    """``get_trainer`` returns an object whose exact type is ``SWAGModelTrainer``."""
    training = Dummy.create_training(uncertainty_method=UncertaintyMethod.SWAG)
    trainer = get_trainer(training)
    # Exact-type identity check; assertIs shows the actual type on failure.
    self.assertIs(type(trainer), SWAGModelTrainer)

Functions