Bases: TransactionTestCase
```mermaid
flowchart TD
fl_server_ai.tests.test_aggregation_process.AggregationProcessTest[AggregationProcessTest]
click fl_server_ai.tests.test_aggregation_process.AggregationProcessTest href "" "fl_server_ai.tests.test_aggregation_process.AggregationProcessTest"
```
Methods:
Source code in fl_server_ai/tests/test_aggregation_process.py
class AggregationProcessTest(TransactionTestCase):
    """Tests for ``ModelTrainer`` handling a ``TrainingRoundFinished`` event.

    Two scenarios are covered: a round after which the model's round counter
    advances and the training stays ``ONGOING``, and a round after which the
    training is marked ``COMPLETED``.
    """

    def setUp(self):
        """Create a user, authenticate ``self.client`` as that user, and keep it on ``self.user``."""
        self.user = Dummy.create_user_and_authenticate(self.client)

    def test_check_and_run_aggregation_if_applicable_training_step(self):
        """After a round-0 aggregation the model reports round 1 and training stays ONGOING.

        Ten participants each contribute one round-0 update to a training created with
        ``target_num_updates=10``.
        """
        clients = [Dummy.create_user() for _ in range(10)]
        training = Dummy.create_training(state=TrainingState.ONGOING, target_num_updates=10, actor=self.user)
        training.participants.set(clients)
        updates = [Dummy.create_model_update(base_model=training.model, round=0) for _ in range(10)]
        for update in updates:
            update.save()

        # The trainer is expected to emit at least one INFO record on "fl.server".
        with self.assertLogs("fl.server", level="INFO"):
            ModelTrainer(training, TrainerOptions(skip_model_tests=True)).handle_cls(TrainingRoundFinished)

        response = self.client.get(f"{BASE_URL}/models/{training.model.id}/metadata/")
        # Check the status first so a failing request is reported clearly instead of
        # surfacing later as a JSON decode error or a missing "round" key.
        self.assertEqual(response.status_code, 200, response.content)
        content = json.loads(response.content)
        self.assertEqual(content["round"], 1)

        # Re-read from the database to observe the state persisted by the trainer.
        training = Training.objects.get(id=training.id)
        self.assertEqual(TrainingState.ONGOING, training.state)

    def test_check_and_run_aggregation_if_applicable_training_finished(self):
        """Aggregating on a round-1 model with ``target_num_updates=1`` marks training COMPLETED.

        Ten participants each contribute one round-1 update.
        """
        clients = [Dummy.create_user() for _ in range(10)]
        model = Dummy.create_model(owner=self.user, round=1)
        training = Dummy.create_training(state=TrainingState.ONGOING, target_num_updates=1, actor=self.user,
                                         model=model)
        training.participants.set(clients)
        updates = [Dummy.create_model_update(base_model=training.model, round=1) for _ in range(10)]
        for update in updates:
            update.save()

        options = TrainerOptions(skip_model_tests=True)
        trainer = ModelTrainer(training, options)
        # The trainer is expected to emit at least one INFO record on "fl.server".
        with self.assertLogs("fl.server", level="INFO"):
            trainer.handle_cls(TrainingRoundFinished)

        # Re-read from the database to observe the state persisted by the trainer.
        training = Training.objects.get(id=training.id)
        self.assertEqual(TrainingState.COMPLETED, training.state)
|
Functions
setUp
Source code in fl_server_ai/tests/test_aggregation_process.py
def setUp(self):
    """Give every test an authenticated client and expose the matching user as ``self.user``."""
    authenticated_user = Dummy.create_user_and_authenticate(self.client)
    self.user = authenticated_user
|
test_check_and_run_aggregation_if_applicable_training_finished
test_check_and_run_aggregation_if_applicable_training_finished()
Source code in fl_server_ai/tests/test_aggregation_process.py
def test_check_and_run_aggregation_if_applicable_training_finished(self):
    """Aggregating on a round-1 model with ``target_num_updates=1`` must end in COMPLETED."""
    participants = [Dummy.create_user() for _ in range(10)]
    round_one_model = Dummy.create_model(owner=self.user, round=1)
    training = Dummy.create_training(
        state=TrainingState.ONGOING,
        target_num_updates=1,
        actor=self.user,
        model=round_one_model,
    )
    training.participants.set(participants)

    pending_updates = [Dummy.create_model_update(base_model=training.model, round=1) for _ in range(10)]
    for pending in pending_updates:
        pending.save()

    trainer = ModelTrainer(training, TrainerOptions(skip_model_tests=True))
    # At least one INFO record on "fl.server" is required while handling the event.
    with self.assertLogs("fl.server", level="INFO"):
        trainer.handle_cls(TrainingRoundFinished)

    # Fetch a fresh copy so the assertion sees the persisted state.
    reloaded = Training.objects.get(id=training.id)
    self.assertEqual(TrainingState.COMPLETED, reloaded.state)
|
test_check_and_run_aggregation_if_applicable_training_step
test_check_and_run_aggregation_if_applicable_training_step()
Source code in fl_server_ai/tests/test_aggregation_process.py
def test_check_and_run_aggregation_if_applicable_training_step(self):
    """After a round-0 aggregation the model reports round 1 and training stays ONGOING.

    Ten participants each contribute one round-0 update to a training created with
    ``target_num_updates=10``.
    """
    clients = [Dummy.create_user() for _ in range(10)]
    training = Dummy.create_training(state=TrainingState.ONGOING, target_num_updates=10, actor=self.user)
    training.participants.set(clients)
    updates = [Dummy.create_model_update(base_model=training.model, round=0) for _ in range(10)]
    for update in updates:
        update.save()

    # The trainer is expected to emit at least one INFO record on "fl.server".
    with self.assertLogs("fl.server", level="INFO"):
        ModelTrainer(training, TrainerOptions(skip_model_tests=True)).handle_cls(TrainingRoundFinished)

    response = self.client.get(f"{BASE_URL}/models/{training.model.id}/metadata/")
    # Check the status first so a failing request is reported clearly instead of
    # surfacing later as a JSON decode error or a missing "round" key.
    self.assertEqual(response.status_code, 200, response.content)
    content = json.loads(response.content)
    self.assertEqual(content["round"], 1)

    # Re-read from the database to observe the state persisted by the trainer.
    training = Training.objects.get(id=training.id)
    self.assertEqual(TrainingState.ONGOING, training.state)
|