Skip to content

fl_server_ai.tests.test_aggregation_process

Classes:

| Name | Description |
| --- | --- |
| `AggregationProcessTest` | |

Classes

AggregationProcessTest

Bases: TransactionTestCase


```mermaid
flowchart TD
    fl_server_ai.tests.test_aggregation_process.AggregationProcessTest[AggregationProcessTest]
    click fl_server_ai.tests.test_aggregation_process.AggregationProcessTest href "" "fl_server_ai.tests.test_aggregation_process.AggregationProcessTest"
```
            

Methods:

| Name | Description |
| --- | --- |
| `setUp` | |
| `test_check_and_run_aggregation_if_applicable_training_finished` | |
| `test_check_and_run_aggregation_if_applicable_training_step` | |
Source code in fl_server_ai/tests/test_aggregation_process.py
class AggregationProcessTest(TransactionTestCase):
    """Tests for dispatching ``TrainingRoundFinished`` to a ``ModelTrainer``.

    Each test builds an ongoing training with ten participants and ten model
    updates for the model's current round. It then runs the trainer with
    ``skip_model_tests=True`` and checks the resulting model round and
    training state.
    """

    def setUp(self):
        # The authenticated user is used as the actor/owner of the trainings below.
        self.user = Dummy.create_user_and_authenticate(self.client)

    def test_check_and_run_aggregation_if_applicable_training_step(self):
        """An intermediate finished round bumps the model round but keeps training ongoing.

        The training has ``target_num_updates=10`` and ten round-0 updates.
        After handling ``TrainingRoundFinished`` the model metadata must report
        round 1, and the training must still be ``ONGOING``.
        """
        clients = [Dummy.create_user() for _ in range(10)]
        training = Dummy.create_training(state=TrainingState.ONGOING, target_num_updates=10, actor=self.user)
        training.participants.set(clients)
        updates = [Dummy.create_model_update(base_model=training.model, round=0) for _ in range(10)]
        for update in updates:
            update.save()

        with self.assertLogs("fl.server", level="INFO"):
            ModelTrainer(training, TrainerOptions(skip_model_tests=True)).handle_cls(TrainingRoundFinished)

        response = self.client.get(f"{BASE_URL}/models/{training.model.id}/metadata/")
        # Check the status before parsing so an error response fails with a clear
        # message instead of a JSON decode error or a missing "round" key.
        self.assertEqual(response.status_code, 200, response.content)
        content = json.loads(response.content)
        self.assertEqual(content["round"], 1)

        # Reload from the database; the local instance is not refreshed by the trainer call.
        training = Training.objects.get(id=training.id)
        self.assertEqual(TrainingState.ONGOING, training.state)

    def test_check_and_run_aggregation_if_applicable_training_finished(self):
        """Finishing the last targeted round moves the training to ``COMPLETED``.

        The model starts at round 1, the training has ``target_num_updates=1``,
        and ten round-1 updates are created before ``TrainingRoundFinished`` is handled.
        """
        clients = [Dummy.create_user() for _ in range(10)]
        model = Dummy.create_model(owner=self.user, round=1)
        training = Dummy.create_training(state=TrainingState.ONGOING, target_num_updates=1, actor=self.user,
                                         model=model)
        training.participants.set(clients)
        updates = [Dummy.create_model_update(base_model=training.model, round=1) for _ in range(10)]
        for update in updates:
            update.save()

        options = TrainerOptions(skip_model_tests=True)
        trainer = ModelTrainer(training, options)
        with self.assertLogs("fl.server", level="INFO"):
            trainer.handle_cls(TrainingRoundFinished)

        # Reload from the database; the local instance is not refreshed by the trainer call.
        training = Training.objects.get(id=training.id)
        self.assertEqual(TrainingState.COMPLETED, training.state)

Functions

setUp
setUp()
Source code in fl_server_ai/tests/test_aggregation_process.py
def setUp(self):
    """Create the acting user via ``Dummy.create_user_and_authenticate`` and store it on ``self.user``."""
    authenticated_user = Dummy.create_user_and_authenticate(self.client)
    self.user = authenticated_user
test_check_and_run_aggregation_if_applicable_training_finished
test_check_and_run_aggregation_if_applicable_training_finished()
Source code in fl_server_ai/tests/test_aggregation_process.py
def test_check_and_run_aggregation_if_applicable_training_finished(self):
    """Finishing the last targeted round should move the training to ``COMPLETED``.

    Ten participants are registered on a training with ``target_num_updates=1``
    whose model already reports round 1. Ten round-1 model updates are created
    before ``TrainingRoundFinished`` is dispatched.
    """
    participants = [Dummy.create_user() for _ in range(10)]
    base_model = Dummy.create_model(owner=self.user, round=1)
    training = Dummy.create_training(
        state=TrainingState.ONGOING,
        target_num_updates=1,
        actor=self.user,
        model=base_model,
    )
    training.participants.set(participants)

    pending_updates = [
        Dummy.create_model_update(base_model=training.model, round=1)
        for _ in range(10)
    ]
    for pending in pending_updates:
        pending.save()

    trainer = ModelTrainer(training, TrainerOptions(skip_model_tests=True))
    with self.assertLogs("fl.server", level="INFO"):
        trainer.handle_cls(TrainingRoundFinished)

    # Re-read the row so the assertion checks the database state, not the local instance.
    reloaded = Training.objects.get(id=training.id)
    self.assertEqual(TrainingState.COMPLETED, reloaded.state)
test_check_and_run_aggregation_if_applicable_training_step
test_check_and_run_aggregation_if_applicable_training_step()
Source code in fl_server_ai/tests/test_aggregation_process.py
def test_check_and_run_aggregation_if_applicable_training_step(self):
    """An intermediate finished round bumps the model round but keeps training ongoing.

    The training has ``target_num_updates=10`` and ten round-0 updates. After
    handling ``TrainingRoundFinished`` the model metadata must report round 1,
    and the training must still be ``ONGOING``.
    """
    clients = [Dummy.create_user() for _ in range(10)]
    training = Dummy.create_training(state=TrainingState.ONGOING, target_num_updates=10, actor=self.user)
    training.participants.set(clients)
    updates = [Dummy.create_model_update(base_model=training.model, round=0) for _ in range(10)]
    for update in updates:
        update.save()

    with self.assertLogs("fl.server", level="INFO"):
        ModelTrainer(training, TrainerOptions(skip_model_tests=True)).handle_cls(TrainingRoundFinished)

    response = self.client.get(f"{BASE_URL}/models/{training.model.id}/metadata/")
    # Check the status before parsing so an error response fails with a clear
    # message instead of a JSON decode error or a missing "round" key.
    self.assertEqual(response.status_code, 200, response.content)
    content = json.loads(response.content)
    self.assertEqual(content["round"], 1)

    # Reload from the database; the local instance is not refreshed by the trainer call.
    training = Training.objects.get(id=training.id)
    self.assertEqual(TrainingState.ONGOING, training.state)