Skip to content

fl_server_ai.tests.test_ai_worker

Classes:

Name Description
AiWorkerTest Tests for how the AI worker's ModelTrainer handles finished training and SWAG rounds.

Classes

AiWorkerTest

Bases: TestCase


              flowchart TD
              fl_server_ai.tests.test_ai_worker.AiWorkerTest[AiWorkerTest]

              

              click fl_server_ai.tests.test_ai_worker.AiWorkerTest href "" "fl_server_ai.tests.test_ai_worker.AiWorkerTest"
            

Methods:

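Name Description
test_process_training SWAG round: the participants' local SWAG moments are aggregated into the model, the training is unlocked and returns to the ONGOING state.
test_process_training_bad1 A round containing an extra (duplicate-owner) model update must raise RuntimeError without aggregating.
test_process_training_bad2 A round containing only a single model update must raise RuntimeError without aggregating.
test_process_training_good Two valid model updates are mean-aggregated into the training's global model.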
Source code in fl_server_ai/tests/test_ai_worker.py
class AiWorkerTest(TestCase):
    """Tests for how ``ModelTrainer`` / ``process_trainer_task`` handle finished rounds."""

    def test_process_training_good(self):
        """Two round-0 updates are mean-aggregated into the training's global model."""
        base_model = Dummy.create_model(weights=from_torch_module(_create_torchscript_model_and_init(0)))
        training = Dummy.create_training(state=TrainingState.ONGOING, locked=True, model=base_model)
        # Evaluate the unordered participants queryset once. Indexing an unevaluated Django
        # queryset issues a separate LIMIT/OFFSET query per access, and without ORDER BY the
        # database does not guarantee the same row order between those queries.
        participants = list(training.participants.all())
        # NOTE(review): assumes _create_torchscript_model_and_init(v) fills the parameters
        # with the constant v, which is why the expected mean below is 0.5.
        model1 = _create_torchscript_model_and_init(0)
        model2 = _create_torchscript_model_and_init(1)

        Dummy.create_model_update(
            base_model=training.model,
            owner=participants[0],
            round=0,
            weights=from_torch_module(model1),  # torchscript model
        )
        Dummy.create_model_update(
            base_model=training.model,
            owner=participants[1],
            round=0,
            weights=from_torch_module(model2),  # torchscript model
        )
        with self.assertLogs(level="INFO"):
            ModelTrainer(training, TrainerOptions(skip_model_tests=True)).handle_cls(TrainingRoundFinished)
        res = training.model.get_torch_model().state_dict()
        torch.testing.assert_close(res["0.weight"], torch.ones_like(res["0.weight"]) * 0.5)
        torch.testing.assert_close(res["3.weight"], torch.ones_like(res["3.weight"]) * 0.5)

    @patch("fl_server_ai.aggregation.mean.MeanAggregation.aggregate")
    def test_process_training_bad1(self, aggregate_updates: MagicMock):
        """A round with an extra, duplicate-owner update must fail before aggregation runs."""
        base_model = Dummy.create_model(weights=from_torch_module(_create_torchscript_model_and_init(0)))
        training = Dummy.create_training(state=TrainingState.ONGOING, locked=True, model=base_model)
        participants = list(training.participants.all())  # evaluate once, see test_process_training_good
        model1 = _create_torchscript_model_and_init(0)
        model2 = _create_torchscript_model_and_init(1)
        model3 = _create_torchscript_model_and_init(500)
        Dummy.create_model_update(
            base_model=training.model,
            owner=participants[0],
            round=0,
            weights=from_torch_module(model1),  # torchscript model
        )
        Dummy.create_model_update(
            base_model=training.model,
            owner=participants[1],
            round=0,
            weights=from_torch_module(model2),  # torchscript model
        )
        # Deliberately a second round-0 update from the same participant, so the round
        # holds more updates than expected and the trainer must reject it.
        Dummy.create_model_update(
            base_model=training.model,
            owner=participants[1],
            round=0,
            weights=from_torch_module(model3),  # torchscript model
        )
        with self.assertLogs(level="ERROR"), self.assertRaises(RuntimeError):
            ModelTrainer(training, TrainerOptions(skip_model_tests=True)).handle_cls(TrainingRoundFinished)

        aggregate_updates.assert_not_called()

    @patch("fl_server_ai.aggregation.mean.MeanAggregation.aggregate")
    def test_process_training_bad2(self, aggregate_updates: MagicMock):
        """A round with only a single update must fail before aggregation runs."""
        base_model = Dummy.create_model(weights=from_torch_module(_create_torchscript_model_and_init(0)))
        training = Dummy.create_training(state=TrainingState.ONGOING, locked=True, model=base_model)
        model1 = _create_torchscript_model_and_init(0)
        # Only the first participant submits a round-0 update.
        Dummy.create_model_update(
            base_model=training.model,
            owner=training.participants.all()[0],
            round=0,
            weights=from_torch_module(model1),  # torchscript model
        )
        with self.assertLogs(level="ERROR"), self.assertRaises(RuntimeError):
            ModelTrainer(training, TrainerOptions(skip_model_tests=True)).handle_cls(TrainingRoundFinished)

        aggregate_updates.assert_not_called()

    @patch.object(Notification, "send")
    def test_process_training(self, send_notification):
        """A finished SWAG round aggregates the local SWAG moments and unlocks the training."""
        torch_model = torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(1, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 1),
        ))
        # Linear(1, 5) has 5 + 5 and Linear(5, 1) has 5 + 1 parameters (16 in total);
        # derive the SWAG moment length from the model instead of hard-coding it.
        num_params = sum(p.numel() for p in torch_model.parameters())
        model = Dummy.create_model(
            model_cls=SWAGModel,
            weights=from_torch_module(torch_model),
            round=100,
        )
        training = Dummy.create_training(
            state=TrainingState.SWAG_ROUND,
            uncertainty_method=UncertaintyMethod.SWAG,
            locked=True,
            model=model,
            target_num_updates=100,
        )
        # Evaluate the unordered participants queryset once (see test_process_training_good).
        participants = list(training.participants.all())
        # Each of the first two participants reports identical local SWAG statistics,
        # so the aggregate must equal them: first moment 0, second moment 1.
        for reporter in participants[:2]:
            Dummy.create_metric(
                step=100,
                key="SWAG First Moment Local",
                model=model,
                reporter=reporter,
                value_binary=from_torch_tensor(torch.zeros(num_params)),
            )
            Dummy.create_metric(
                step=100,
                key="SWAG Second Moment Local",
                model=model,
                reporter=reporter,
                value_binary=from_torch_tensor(torch.ones(num_params)),
            )
            Dummy.create_metric(
                step=100,
                key="SWAG Sample Size Local",
                model=model,
                reporter=reporter,
                value_float=1000,
            )
        with self.assertLogs("fl.server", level="INFO") as cm:
            process_trainer_task(training.id, SWAGRoundFinished)
        self.assertEqual(cm.output, [
            f"INFO:fl.server:Training {training.id}: Doing SWAG aggregation as all 2 updates arrived",
            f"INFO:fl.server:SWAG completed for training {training.id}",
        ])
        send_notification.assert_called()
        # Reload from the database to observe the state persisted by the task.
        training = Training.objects.get(id=training.id)
        self.assertFalse(training.locked)
        aggregated_model = training.model
        self.assertEqual(TrainingState.ONGOING, training.state)  # next would be ModelTestFinished
        torch.testing.assert_close(torch.zeros(num_params), aggregated_model.first_moment)
        torch.testing.assert_close(torch.ones(num_params), aggregated_model.second_moment)

Functions

test_process_training
test_process_training(send_notification)
Source code in fl_server_ai/tests/test_ai_worker.py
@patch.object(Notification, "send")
def test_process_training(self, send_notification):
    """A finished SWAG round aggregates the local SWAG moments and unlocks the training."""
    torch_model = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(1, 5),
        torch.nn.ReLU(),
        torch.nn.Linear(5, 1),
    ))
    # Linear(1, 5) has 5 + 5 and Linear(5, 1) has 5 + 1 parameters (16 in total);
    # derive the SWAG moment length from the model instead of hard-coding it.
    num_params = sum(p.numel() for p in torch_model.parameters())
    model = Dummy.create_model(
        model_cls=SWAGModel,
        weights=from_torch_module(torch_model),
        round=100,
    )
    training = Dummy.create_training(
        state=TrainingState.SWAG_ROUND,
        uncertainty_method=UncertaintyMethod.SWAG,
        locked=True,
        model=model,
        target_num_updates=100,
    )
    # Evaluate the unordered participants queryset once. Indexing an unevaluated Django
    # queryset issues a separate LIMIT/OFFSET query per access, and without ORDER BY the
    # database does not guarantee the same row order between those queries.
    participants = list(training.participants.all())
    # Each of the first two participants reports identical local SWAG statistics,
    # so the aggregate must equal them: first moment 0, second moment 1.
    for reporter in participants[:2]:
        Dummy.create_metric(
            step=100,
            key="SWAG First Moment Local",
            model=model,
            reporter=reporter,
            value_binary=from_torch_tensor(torch.zeros(num_params)),
        )
        Dummy.create_metric(
            step=100,
            key="SWAG Second Moment Local",
            model=model,
            reporter=reporter,
            value_binary=from_torch_tensor(torch.ones(num_params)),
        )
        Dummy.create_metric(
            step=100,
            key="SWAG Sample Size Local",
            model=model,
            reporter=reporter,
            value_float=1000,
        )
    with self.assertLogs("fl.server", level="INFO") as cm:
        process_trainer_task(training.id, SWAGRoundFinished)
    self.assertEqual(cm.output, [
        f"INFO:fl.server:Training {training.id}: Doing SWAG aggregation as all 2 updates arrived",
        f"INFO:fl.server:SWAG completed for training {training.id}",
    ])
    send_notification.assert_called()
    # Reload from the database to observe the state persisted by the task.
    training = Training.objects.get(id=training.id)
    self.assertFalse(training.locked)
    aggregated_model = training.model
    self.assertEqual(TrainingState.ONGOING, training.state)  # next would be ModelTestFinished
    torch.testing.assert_close(torch.zeros(num_params), aggregated_model.first_moment)
    torch.testing.assert_close(torch.ones(num_params), aggregated_model.second_moment)
test_process_training_bad1
test_process_training_bad1(aggregate_updates: MagicMock)
Source code in fl_server_ai/tests/test_ai_worker.py
@patch("fl_server_ai.aggregation.mean.MeanAggregation.aggregate")
def test_process_training_bad1(self, aggregate_updates: MagicMock):
    """A round with an extra, duplicate-owner update must fail before aggregation runs."""
    base_model = Dummy.create_model(weights=from_torch_module(_create_torchscript_model_and_init(0)))
    training = Dummy.create_training(state=TrainingState.ONGOING, locked=True, model=base_model)
    # Evaluate the unordered participants queryset once so that participants[0] and
    # participants[1] come from a single query rather than two independent LIMIT/OFFSET queries.
    participants = list(training.participants.all())
    model1 = _create_torchscript_model_and_init(0)
    model2 = _create_torchscript_model_and_init(1)
    model3 = _create_torchscript_model_and_init(500)
    Dummy.create_model_update(
        base_model=training.model,
        owner=participants[0],
        round=0,
        weights=from_torch_module(model1),  # torchscript model
    )
    Dummy.create_model_update(
        base_model=training.model,
        owner=participants[1],
        round=0,
        weights=from_torch_module(model2),  # torchscript model
    )
    # Deliberately a second round-0 update from the same participant, so the round
    # holds more updates than expected and the trainer must reject it.
    Dummy.create_model_update(
        base_model=training.model,
        owner=participants[1],
        round=0,
        weights=from_torch_module(model3),  # torchscript model
    )
    with self.assertLogs(level="ERROR"), self.assertRaises(RuntimeError):
        ModelTrainer(training, TrainerOptions(skip_model_tests=True)).handle_cls(TrainingRoundFinished)

    aggregate_updates.assert_not_called()
test_process_training_bad2
test_process_training_bad2(aggregate_updates: MagicMock)
Source code in fl_server_ai/tests/test_ai_worker.py
@patch("fl_server_ai.aggregation.mean.MeanAggregation.aggregate")
def test_process_training_bad2(self, aggregate_updates: MagicMock):
    """A round with only a single update must fail before aggregation runs."""
    base_model = Dummy.create_model(weights=from_torch_module(_create_torchscript_model_and_init(0)))
    training = Dummy.create_training(state=TrainingState.ONGOING, locked=True, model=base_model)
    model1 = _create_torchscript_model_and_init(0)
    # Only the first participant submits a round-0 update.
    Dummy.create_model_update(
        base_model=training.model,
        owner=training.participants.all()[0],
        round=0,
        weights=from_torch_module(model1),  # torchscript model
    )
    with self.assertLogs(level="ERROR"), self.assertRaises(RuntimeError):
        ModelTrainer(training, TrainerOptions(skip_model_tests=True)).handle_cls(TrainingRoundFinished)

    aggregate_updates.assert_not_called()
test_process_training_good
test_process_training_good()
Source code in fl_server_ai/tests/test_ai_worker.py
def test_process_training_good(self):
    """Two round-0 updates are mean-aggregated into the training's global model."""
    base_model = Dummy.create_model(weights=from_torch_module(_create_torchscript_model_and_init(0)))
    training = Dummy.create_training(state=TrainingState.ONGOING, locked=True, model=base_model)
    # Evaluate the unordered participants queryset once. Indexing an unevaluated Django
    # queryset issues a separate LIMIT/OFFSET query per access, and without ORDER BY the
    # database does not guarantee the same row order between those queries.
    participants = list(training.participants.all())
    # NOTE(review): assumes _create_torchscript_model_and_init(v) fills the parameters
    # with the constant v, which is why the expected mean below is 0.5.
    model1 = _create_torchscript_model_and_init(0)
    model2 = _create_torchscript_model_and_init(1)

    Dummy.create_model_update(
        base_model=training.model,
        owner=participants[0],
        round=0,
        weights=from_torch_module(model1),  # torchscript model
    )
    Dummy.create_model_update(
        base_model=training.model,
        owner=participants[1],
        round=0,
        weights=from_torch_module(model2),  # torchscript model
    )
    with self.assertLogs(level="INFO"):
        ModelTrainer(training, TrainerOptions(skip_model_tests=True)).handle_cls(TrainingRoundFinished)
    res = training.model.get_torch_model().state_dict()
    torch.testing.assert_close(res["0.weight"], torch.ones_like(res["0.weight"]) * 0.5)
    torch.testing.assert_close(res["3.weight"], torch.ones_like(res["3.weight"]) * 0.5)

Functions