Skip to content

fl_server_api.tests.test_training

Classes:

Name Description
TrainingTests

Classes

TrainingTests

Bases: TestCase


              flowchart TD
              fl_server_api.tests.test_training.TrainingTests[TrainingTests]

              

              click fl_server_api.tests.test_training.TrainingTests href "" "fl_server_api.tests.test_training.TrainingTests"
            

Methods:

Name Description
setUp
test_create_training
test_create_training_invalid_aggregation_method
test_create_training_not_model_owner
test_create_training_with_clients
test_create_training_with_trained_model
test_delete_non_existing_training
test_delete_training_as_actor
test_delete_training_as_other_user
test_delete_training_as_participant
test_get_training_bad
test_get_training_good
test_get_trainings
test_register_clients_bad
test_register_clients_good
test_remove_clients_good
test_start_training
test_start_training_no_participants
test_start_training_not_initial_state
Source code in fl_server_api/tests/test_training.py
class TrainingTests(TestCase):
    """
    API tests for the ``/trainings/`` endpoints.

    Covers creating, listing, fetching and deleting trainings, registering and
    removing clients, and starting a training. Every test runs with the test
    client authenticated as ``self.user`` (see ``setUp``).
    """

    def setUp(self):
        """Create a user via ``Dummy`` and authenticate the test client with it."""
        self.user = Dummy.create_user_and_authenticate(self.client)

    def test_create_training(self):
        """Creating a training for a model owned by ``self.user`` answers 201 with a success detail."""
        model = Dummy.create_model(owner=self.user)
        request_body = dict(
            model_id=str(model.id),
            target_num_updates=100,
            metric_names=["accuracy", "f1_score"],
            uncertainty_method="NONE",
            aggregation_method="FedAvg"
        )
        response = self.client.post(
            f"{BASE_URL}/trainings/",
            data=json.dumps(request_body),
            content_type="application/json"
        )
        self.assertEqual(201, response.status_code)
        response_json = response.json()
        self.assertIsNotNone(response_json)
        self.assertEqual("Training created successfully!", response_json["detail"])

    def test_create_training_with_clients(self):
        """Creating a training with an explicit ``clients`` id list answers 201 with a success detail."""
        model = Dummy.create_model(owner=self.user)
        clients = [Dummy.create_client(username=f"client-{n}") for n in range(3)]
        request_body = dict(
            model_id=str(model.id),
            target_num_updates=100,
            metric_names=["accuracy", "f1_score"],
            uncertainty_method="NONE",
            aggregation_method="FedAvg",
            clients=list(map(lambda c: str(c.id), clients))
        )
        response = self.client.post(
            f"{BASE_URL}/trainings/",
            data=json.dumps(request_body),
            content_type="application/json"
        )
        self.assertEqual(201, response.status_code)
        response_json = response.json()
        self.assertIsNotNone(response_json)
        self.assertEqual("Training created successfully!", response_json["detail"])

    def test_create_training_invalid_aggregation_method(self):
        """An unknown ``aggregation_method`` value is rejected with 400 and a logged warning."""
        model = Dummy.create_model(owner=self.user)
        request_body = dict(
            model_id=str(model.id),
            target_num_updates=100,
            metric_names=["accuracy", "f1_score"],
            uncertainty_method="NONE",
            aggregation_method="INVALID"
        )
        with self.assertLogs("root", level="WARNING"):
            response = self.client.post(
                f"{BASE_URL}/trainings/",
                data=json.dumps(request_body),
                content_type="application/json"
            )
        self.assertEqual(400, response.status_code)

    def test_create_training_not_model_owner(self):
        """Creating a training for a model not created with ``owner=self.user`` answers 403."""
        model = Dummy.create_model()
        request_body = dict(
            model_id=str(model.id),
            target_num_updates=100,
            metric_names=["accuracy", "f1_score"],
            uncertainty_method="NONE",
            aggregation_method="FedAvg"
        )
        with self.assertLogs("django.request", level="WARNING") as cm:
            response = self.client.post(
                f"{BASE_URL}/trainings/",
                data=json.dumps(request_body),
                content_type="application/json"
            )
        # exactly one warning is expected from the request logger
        self.assertEqual(cm.output, [
            "WARNING:django.request:Forbidden: /api/trainings/",
        ])
        self.assertEqual(403, response.status_code)
        response_json = response.json()
        self.assertIsNotNone(response_json)
        self.assertEqual("You do not have permission to perform this action.", response_json["detail"])

    def test_get_trainings(self):
        """Listing trainings returns exactly the trainings created with ``actor=self.user``."""
        # make user actor
        self.user.actor = True
        self.user.save()
        # create trainings - some related to user some not
        for _ in range(3):
            Dummy.create_training()
        trainings = [Dummy.create_training(actor=self.user) for _ in range(3)]
        # get user related trainings
        response = self.client.get(f"{BASE_URL}/trainings/")
        self.assertEqual(200, response.status_code)
        self.assertEqual("application/json", response["content-type"])
        response_json = response.json()
        self.assertEqual(len(trainings), len(response_json))
        # compare as sorted id lists so the response order does not matter
        self.assertEqual(
            sorted([str(training.id) for training in trainings]),
            sorted([training["id"] for training in response_json])
        )

    def test_get_training_good(self):
        """Fetching a training whose actor is ``self.user`` answers 200 with the expected fields."""
        training = Dummy.create_training(actor=self.user)
        response = self.client.get(f"{BASE_URL}/trainings/{training.id}/")
        self.assertEqual(response.status_code, 200)
        body = response.json()
        self.assertEqual(str(training.actor.id), body["actor"])
        self.assertEqual(TrainingState.INITIAL, body["state"])
        self.assertEqual(0, body["target_num_updates"])

    def test_get_training_bad(self):
        """Fetching a training created without ``actor=self.user`` answers 403 and logs a warning."""
        training = Dummy.create_training()
        with self.assertLogs("root", level="WARNING"):
            response = self.client.get(f"{BASE_URL}/trainings/{training.id}/")
        self.assertEqual(response.status_code, 403)

    def test_delete_training_as_actor(self):
        """Deleting a training as its actor answers 200 and removes the database row."""
        training = Dummy.create_training(actor=self.user)
        response = self.client.delete(f"{BASE_URL}/trainings/{training.id}/")
        self.assertEqual(response.status_code, 200)
        body = response.json()
        self.assertEqual("Training removed!", body["detail"])
        self.assertRaises(ObjectDoesNotExist, Training.objects.get, pk=training.id)

    def test_delete_training_as_participant(self):
        """Deleting a training as a mere participant answers 403 and keeps the training."""
        participants = [Dummy.create_client(), self.user, Dummy.create_client()]
        training = Dummy.create_training(participants=participants)
        with self.assertLogs("django.request", level="WARNING"):
            response = self.client.delete(f"{BASE_URL}/trainings/{training.id}/")
        self.assertEqual(response.status_code, 403)
        body = response.json()
        self.assertEqual("You are not the owner the training.", body["detail"])
        self.assertIsNotNone(Training.objects.get(pk=training.id))

    def test_delete_training_as_other_user(self):
        """Deleting a training created without ``actor=self.user`` answers 403 and keeps the training."""
        training = Dummy.create_training()
        with self.assertLogs("django.request", level="WARNING"):
            response = self.client.delete(f"{BASE_URL}/trainings/{training.id}/")
        self.assertEqual(response.status_code, 403)
        body = response.json()
        self.assertEqual("You are not the owner the training.", body["detail"])
        self.assertIsNotNone(Training.objects.get(pk=training.id))

    def test_delete_non_existing_training(self):
        """Deleting a training id that was never created answers 400 with a not-found detail."""
        training_id = str(uuid4())
        with self.assertLogs("django.request", level="WARNING"):
            response = self.client.delete(f"{BASE_URL}/trainings/{training_id}/")
        self.assertEqual(response.status_code, 400)
        body = response.json()
        self.assertEqual(f"Training {training_id} not found.", body["detail"])

    def test_register_clients_good(self):
        """Registering existing users as clients of an own training answers 202."""
        training = Dummy.create_training(actor=self.user)
        users = [str(Dummy.create_user(username=f"client{i}").id) for i in range(1, 5)]
        request_body = dict(clients=users)
        response = self.client.put(
            f"{BASE_URL}/trainings/{training.id}/clients/",
            json.dumps(request_body)
        )
        self.assertEqual(response.status_code, 202)
        body = response.json()
        self.assertEqual("Users registered as participants!", body["detail"])

    def test_register_clients_bad(self):
        """Registering a client list that contains an unknown id answers 400."""
        training = Dummy.create_training(actor=self.user)
        # four real users plus one random id that belongs to nobody
        users = [str(Dummy.create_user(username=f"client{i}").id) for i in range(1, 5)] + [str(uuid4())]
        request_body = dict(clients=users)
        with self.assertLogs("root", level="WARNING"):
            response = self.client.put(
                f"{BASE_URL}/trainings/{training.id}/clients/",
                json.dumps(request_body)
            )
        self.assertEqual(response.status_code, 400)
        self.assertIsNotNone(response.content)
        response_body = response.json()
        self.assertEqual("Not all provided users were found!", response_body["detail"])

    def test_remove_clients_good(self):
        """Removing all participants of an own training leaves its participant list empty."""
        training = Dummy.create_training(actor=self.user)
        users = [str(t.id) for t in training.participants.all()]
        # guard: without pre-existing participants the removal check below would be vacuous
        self.assertTrue(users)
        request_body = dict(clients=users)

        response = self.client.delete(
            f"{BASE_URL}/trainings/{training.id}/clients/",
            json.dumps(request_body)
        )
        self.assertEqual(response.status_code, 200)
        response = self.client.get(f"{BASE_URL}/trainings/{training.id}/")
        self.assertEqual(response.status_code, 200)
        body = response.json()
        self.assertEqual(0, len(body["participants"]))

    @patch("fl_server_ai.notification.notification.send_notifications.apply_async")
    def test_start_training(self, apply_async: MagicMock):
        """Starting an own training with one participant answers 202 and schedules two notifications."""
        user = Dummy.create_user(message_endpoint="http://example.com")
        training = Dummy.create_training(actor=self.user)
        training.participants.set([user])
        training.save()
        response = self.client.post(f"{BASE_URL}/trainings/{training.id}/start/")
        self.assertEqual(response.status_code, 202)
        self.assertEqual(2, apply_async.call_count)  # TrainingStartNotification, TrainingRoundStartNotification

    def test_start_training_no_participants(self):
        """Starting a training whose participant set is empty answers 400."""
        training = Dummy.create_training(actor=self.user)
        training.participants.set([])
        training.save()
        with self.assertLogs("root", level="WARNING"):
            response = self.client.post(f"{BASE_URL}/trainings/{training.id}/start/")
        self.assertEqual(response.status_code, 400)
        self.assertIsNotNone(response.content)
        response_body = response.json()
        self.assertEqual("At least one participant must be registered!", response_body["detail"])

    def test_start_training_not_initial_state(self):
        """Starting a training that is already in state ``ONGOING`` answers 400."""
        user = Dummy.create_user(message_endpoint="http://example.com")
        training = Dummy.create_training(actor=self.user, state=TrainingState.ONGOING)
        training.participants.set([user])
        training.save()
        with self.assertLogs("django.request", level="WARNING") as cm:
            response = self.client.post(f"{BASE_URL}/trainings/{training.id}/start/")
        # exactly one warning is expected from the request logger
        self.assertEqual(cm.output, [
            f"WARNING:django.request:Bad Request: /api/trainings/{training.id}/start/",
        ])
        self.assertEqual(response.status_code, 400)
        self.assertIsNotNone(response.content)
        response_body = response.json()
        self.assertEqual(f"Training {training.id} is not in state INITIAL!", response_body["detail"])

    def test_create_training_with_trained_model(self):
        """Creating a second training from the model of an existing training answers 201 and is retrievable."""
        training = Dummy.create_training(actor=self.user)
        model = training.model
        request_body = dict(
            model_id=str(model.id),
            target_num_updates=100,
            metric_names=["accuracy", "f1_score"],
            uncertainty_method="NONE",
            aggregation_method="FedAvg"
        )
        response = self.client.post(
            f"{BASE_URL}/trainings/",
            data=json.dumps(request_body),
            content_type="application/json"
        )
        self.assertEqual(201, response.status_code)
        response_json = response.json()
        self.assertIsNotNone(response_json)
        self.assertEqual("Training created successfully!", response_json["detail"])

        response = self.client.get(f"{BASE_URL}/trainings/{response_json['training_id']}/")
        self.assertEqual(200, response.status_code)
        response_json = response.json()
        self.assertIsNotNone(response_json)
        # Compare string with string: the JSON body carries ids as strings, so a raw
        # non-string id object would never compare equal and the assertion could not fail.
        # NOTE(review): "id" looks like the training's own id - confirm whether the
        # model field of the response is what should be compared here.
        self.assertNotEqual(str(model.id), response_json['id'])

Functions

setUp
setUp()
Source code in fl_server_api/tests/test_training.py
def setUp(self):
    """Create a user via ``Dummy``, authenticate the test client with it and keep it on ``self.user``."""
    authenticated_user = Dummy.create_user_and_authenticate(self.client)
    self.user = authenticated_user
test_create_training
test_create_training()
Source code in fl_server_api/tests/test_training.py
def test_create_training(self):
    """Creating a training for a model owned by ``self.user`` answers 201 with a success detail."""
    owned_model = Dummy.create_model(owner=self.user)
    payload = {
        "model_id": str(owned_model.id),
        "target_num_updates": 100,
        "metric_names": ["accuracy", "f1_score"],
        "uncertainty_method": "NONE",
        "aggregation_method": "FedAvg",
    }
    url = f"{BASE_URL}/trainings/"
    response = self.client.post(url, data=json.dumps(payload), content_type="application/json")
    self.assertEqual(201, response.status_code)
    parsed = response.json()
    self.assertIsNotNone(parsed)
    self.assertEqual("Training created successfully!", parsed["detail"])
test_create_training_invalid_aggregation_method
test_create_training_invalid_aggregation_method()
Source code in fl_server_api/tests/test_training.py
def test_create_training_invalid_aggregation_method(self):
    """An unknown ``aggregation_method`` value is rejected with 400 and a logged warning."""
    owned_model = Dummy.create_model(owner=self.user)
    payload = {
        "model_id": str(owned_model.id),
        "target_num_updates": 100,
        "metric_names": ["accuracy", "f1_score"],
        "uncertainty_method": "NONE",
        "aggregation_method": "INVALID",
    }
    url = f"{BASE_URL}/trainings/"
    # the request must emit at least one record of level WARNING or above
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(url, data=json.dumps(payload), content_type="application/json")
    self.assertEqual(400, response.status_code)
test_create_training_not_model_owner
test_create_training_not_model_owner()
Source code in fl_server_api/tests/test_training.py
def test_create_training_not_model_owner(self):
    """Creating a training for a model not created with ``owner=self.user`` answers 403."""
    foreign_model = Dummy.create_model()
    payload = {
        "model_id": str(foreign_model.id),
        "target_num_updates": 100,
        "metric_names": ["accuracy", "f1_score"],
        "uncertainty_method": "NONE",
        "aggregation_method": "FedAvg",
    }
    url = f"{BASE_URL}/trainings/"
    with self.assertLogs("django.request", level="WARNING") as captured:
        response = self.client.post(url, data=json.dumps(payload), content_type="application/json")
    # exactly one warning is expected from the request logger
    expected_logs = ["WARNING:django.request:Forbidden: /api/trainings/"]
    self.assertEqual(captured.output, expected_logs)
    self.assertEqual(403, response.status_code)
    parsed = response.json()
    self.assertIsNotNone(parsed)
    self.assertEqual("You do not have permission to perform this action.", parsed["detail"])
test_create_training_with_clients
test_create_training_with_clients()
Source code in fl_server_api/tests/test_training.py
def test_create_training_with_clients(self):
    """Creating a training with an explicit ``clients`` id list answers 201 with a success detail."""
    owned_model = Dummy.create_model(owner=self.user)
    created_clients = [Dummy.create_client(username=f"client-{n}") for n in range(3)]
    payload = {
        "model_id": str(owned_model.id),
        "target_num_updates": 100,
        "metric_names": ["accuracy", "f1_score"],
        "uncertainty_method": "NONE",
        "aggregation_method": "FedAvg",
        "clients": [str(client.id) for client in created_clients],
    }
    url = f"{BASE_URL}/trainings/"
    response = self.client.post(url, data=json.dumps(payload), content_type="application/json")
    self.assertEqual(201, response.status_code)
    parsed = response.json()
    self.assertIsNotNone(parsed)
    self.assertEqual("Training created successfully!", parsed["detail"])
test_create_training_with_trained_model
test_create_training_with_trained_model()
Source code in fl_server_api/tests/test_training.py
def test_create_training_with_trained_model(self):
    """Creating a second training from the model of an existing training answers 201 and is retrievable."""
    training = Dummy.create_training(actor=self.user)
    model = training.model
    request_body = dict(
        model_id=str(model.id),
        target_num_updates=100,
        metric_names=["accuracy", "f1_score"],
        uncertainty_method="NONE",
        aggregation_method="FedAvg"
    )
    response = self.client.post(
        f"{BASE_URL}/trainings/",
        data=json.dumps(request_body),
        content_type="application/json"
    )
    self.assertEqual(201, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    self.assertEqual("Training created successfully!", response_json["detail"])

    # fetch the training that was just created, using the id returned by the POST
    response = self.client.get(f"{BASE_URL}/trainings/{response_json['training_id']}/")
    self.assertEqual(200, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    # Compare string with string: the JSON body carries ids as strings, so a raw
    # non-string id object would never compare equal and the assertion could not fail.
    # NOTE(review): "id" looks like the training's own id - confirm whether the
    # model field of the response is what should be compared here.
    self.assertNotEqual(str(model.id), response_json['id'])
test_delete_non_existing_training
test_delete_non_existing_training()
Source code in fl_server_api/tests/test_training.py
def test_delete_non_existing_training(self):
    """Deleting a training id that was never created answers 400 with a not-found detail."""
    missing_id = str(uuid4())
    url = f"{BASE_URL}/trainings/{missing_id}/"
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(url)
    self.assertEqual(response.status_code, 400)
    detail = response.json()["detail"]
    self.assertEqual(f"Training {missing_id} not found.", detail)
test_delete_training_as_actor
test_delete_training_as_actor()
Source code in fl_server_api/tests/test_training.py
def test_delete_training_as_actor(self):
    """Deleting a training as its actor answers 200 and removes the database row."""
    own_training = Dummy.create_training(actor=self.user)
    url = f"{BASE_URL}/trainings/{own_training.id}/"
    response = self.client.delete(url)
    self.assertEqual(response.status_code, 200)
    payload = response.json()
    self.assertEqual("Training removed!", payload["detail"])
    # the row must be gone: looking it up again has to raise
    with self.assertRaises(ObjectDoesNotExist):
        Training.objects.get(pk=own_training.id)
test_delete_training_as_other_user
test_delete_training_as_other_user()
Source code in fl_server_api/tests/test_training.py
def test_delete_training_as_other_user(self):
    """Deleting a training created without ``actor=self.user`` answers 403 and keeps the training."""
    foreign_training = Dummy.create_training()
    url = f"{BASE_URL}/trainings/{foreign_training.id}/"
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(url)
    self.assertEqual(response.status_code, 403)
    payload = response.json()
    self.assertEqual("You are not the owner the training.", payload["detail"])
    # the training must still be retrievable from the database
    still_there = Training.objects.get(pk=foreign_training.id)
    self.assertIsNotNone(still_there)
test_delete_training_as_participant
test_delete_training_as_participant()
Source code in fl_server_api/tests/test_training.py
def test_delete_training_as_participant(self):
    """Deleting a training as a mere participant answers 403 and keeps the training."""
    first_client = Dummy.create_client()
    second_client = Dummy.create_client()
    # the authenticated user takes part in the training but is not passed as its actor
    joined_training = Dummy.create_training(participants=[first_client, self.user, second_client])
    url = f"{BASE_URL}/trainings/{joined_training.id}/"
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(url)
    self.assertEqual(response.status_code, 403)
    payload = response.json()
    self.assertEqual("You are not the owner the training.", payload["detail"])
    still_there = Training.objects.get(pk=joined_training.id)
    self.assertIsNotNone(still_there)
test_get_training_bad
test_get_training_bad()
Source code in fl_server_api/tests/test_training.py
def test_get_training_bad(self):
    """Fetching a training created without ``actor=self.user`` answers 403 and logs a warning."""
    foreign_training = Dummy.create_training()
    url = f"{BASE_URL}/trainings/{foreign_training.id}/"
    with self.assertLogs("root", level="WARNING"):
        response = self.client.get(url)
    self.assertEqual(response.status_code, 403)
test_get_training_good
test_get_training_good()
Source code in fl_server_api/tests/test_training.py
def test_get_training_good(self):
    """Fetching a training whose actor is ``self.user`` answers 200 with the expected fields."""
    own_training = Dummy.create_training(actor=self.user)
    url = f"{BASE_URL}/trainings/{own_training.id}/"
    response = self.client.get(url)
    self.assertEqual(response.status_code, 200)
    payload = response.json()
    expected_actor = str(own_training.actor.id)
    self.assertEqual(expected_actor, payload["actor"])
    self.assertEqual(TrainingState.INITIAL, payload["state"])
    self.assertEqual(0, payload["target_num_updates"])
test_get_trainings
test_get_trainings()
Source code in fl_server_api/tests/test_training.py
def test_get_trainings(self):
    """Listing trainings returns exactly the trainings created with ``actor=self.user``."""
    # make user actor
    self.user.actor = True
    self.user.save()
    # create trainings - some related to user some not
    # (plain loop: these objects are only needed for their side effect, not as a list)
    for _ in range(3):
        Dummy.create_training()
    trainings = [Dummy.create_training(actor=self.user) for _ in range(3)]
    # get user related trainings
    response = self.client.get(f"{BASE_URL}/trainings/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    response_json = response.json()
    self.assertEqual(len(trainings), len(response_json))
    # compare as sorted id lists so the response order does not matter
    self.assertEqual(
        sorted([str(training.id) for training in trainings]),
        sorted([training["id"] for training in response_json])
    )
test_register_clients_bad
test_register_clients_bad()
Source code in fl_server_api/tests/test_training.py
def test_register_clients_bad(self):
    """Registering a client list that contains an unknown id answers 400."""
    own_training = Dummy.create_training(actor=self.user)
    client_ids = [str(Dummy.create_user(username=f"client{i}").id) for i in range(1, 5)]
    # append one random id that belongs to no created user
    client_ids.append(str(uuid4()))
    payload = {"clients": client_ids}
    url = f"{BASE_URL}/trainings/{own_training.id}/clients/"
    with self.assertLogs("root", level="WARNING"):
        response = self.client.put(url, json.dumps(payload))
    self.assertEqual(response.status_code, 400)
    self.assertIsNotNone(response.content)
    parsed = response.json()
    self.assertEqual("Not all provided users were found!", parsed["detail"])
test_register_clients_good
test_register_clients_good()
Source code in fl_server_api/tests/test_training.py
def test_register_clients_good(self):
    """Registering existing users as clients of an own training answers 202."""
    own_training = Dummy.create_training(actor=self.user)
    client_ids = []
    for i in range(1, 5):
        new_user = Dummy.create_user(username=f"client{i}")
        client_ids.append(str(new_user.id))
    payload = {"clients": client_ids}
    url = f"{BASE_URL}/trainings/{own_training.id}/clients/"
    response = self.client.put(url, json.dumps(payload))
    self.assertEqual(response.status_code, 202)
    parsed = response.json()
    self.assertEqual("Users registered as participants!", parsed["detail"])
test_remove_clients_good
test_remove_clients_good()
Source code in fl_server_api/tests/test_training.py
def test_remove_clients_good(self):
    """Removing all participants of an own training leaves its participant list empty."""
    training = Dummy.create_training(actor=self.user)
    users = [str(t.id) for t in training.participants.all()]
    # Guard: without pre-existing participants the removal check below would be vacuous.
    # A unittest assertion is used instead of a bare ``assert`` so the check survives
    # ``python -O`` and fails like every other assertion in this test case.
    self.assertTrue(users)
    request_body = dict(clients=users)

    response = self.client.delete(
        f"{BASE_URL}/trainings/{training.id}/clients/",
        json.dumps(request_body)
    )
    self.assertEqual(response.status_code, 200)
    # re-fetch the training to check the participants as reported by the API
    response = self.client.get(f"{BASE_URL}/trainings/{training.id}/")
    self.assertEqual(response.status_code, 200)
    body = response.json()
    self.assertEqual(0, len(body["participants"]))
test_start_training
test_start_training(apply_async: MagicMock)
Source code in fl_server_api/tests/test_training.py
@patch("fl_server_ai.notification.notification.send_notifications.apply_async")
def test_start_training(self, apply_async: MagicMock):
    """Starting an own training with one participant answers 202 and schedules two notifications."""
    participant = Dummy.create_user(message_endpoint="http://example.com")
    own_training = Dummy.create_training(actor=self.user)
    own_training.participants.set([participant])
    own_training.save()
    start_url = f"{BASE_URL}/trainings/{own_training.id}/start/"
    response = self.client.post(start_url)
    self.assertEqual(response.status_code, 202)
    # one call each for TrainingStartNotification and TrainingRoundStartNotification
    self.assertEqual(2, apply_async.call_count)
test_start_training_no_participants
test_start_training_no_participants()
Source code in fl_server_api/tests/test_training.py
def test_start_training_no_participants(self):
    """Starting a training whose participant set is empty answers 400."""
    own_training = Dummy.create_training(actor=self.user)
    # explicitly clear the participants before trying to start
    own_training.participants.set([])
    own_training.save()
    start_url = f"{BASE_URL}/trainings/{own_training.id}/start/"
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(start_url)
    self.assertEqual(response.status_code, 400)
    self.assertIsNotNone(response.content)
    parsed = response.json()
    self.assertEqual("At least one participant must be registered!", parsed["detail"])
test_start_training_not_initial_state
test_start_training_not_initial_state()
Source code in fl_server_api/tests/test_training.py
def test_start_training_not_initial_state(self):
    """Starting a training that is already in state ``ONGOING`` answers 400."""
    participant = Dummy.create_user(message_endpoint="http://example.com")
    ongoing_training = Dummy.create_training(actor=self.user, state=TrainingState.ONGOING)
    ongoing_training.participants.set([participant])
    ongoing_training.save()
    start_url = f"{BASE_URL}/trainings/{ongoing_training.id}/start/"
    with self.assertLogs("django.request", level="WARNING") as captured:
        response = self.client.post(start_url)
    # exactly one warning is expected from the request logger
    expected_logs = [f"WARNING:django.request:Bad Request: /api/trainings/{ongoing_training.id}/start/"]
    self.assertEqual(captured.output, expected_logs)
    self.assertEqual(response.status_code, 400)
    self.assertIsNotNone(response.content)
    parsed = response.json()
    self.assertEqual(f"Training {ongoing_training.id} is not in state INITIAL!", parsed["detail"])