# Module fl_server_api.tests.test_model
# View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from django.core.exceptions import ObjectDoesNotExist
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import TransactionTestCase
import io
import pickle
import torch
import torchvision
from torchvision.transforms import v2 as transforms
from unittest.mock import MagicMock, patch
from uuid import uuid4
from fl_server_ai.trainer.events import SWAGRoundFinished, TrainingRoundFinished
from fl_server_api.utils import get_entity
from fl_server_core.models import GlobalModel, MeanModel, Model, SWAGModel
from fl_server_core.models.training import Training, TrainingState
from fl_server_core.tests import BASE_URL, Dummy
from fl_server_core.utils.torch_serialization import from_torch_module
class ModelTests(TransactionTestCase):
def setUp(self):
    """Create a dummy user and authenticate the Django test client as that user."""
    self.user = Dummy.create_user_and_authenticate(self.client)
def test_unauthorized(self):
    """Requests without an Authorization header are rejected with 401."""
    # drop the credentials that setUp() configured on the client
    del self.client.defaults["HTTP_AUTHORIZATION"]
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(f"{BASE_URL}/models/", {"model_file": b"Hello World!"})
    self.assertEqual(401, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("Authentication credentials were not provided.", body["detail"])
def test_get_all_models(self):
    """Listing models returns exactly the models related to the requesting user."""
    # give the user both the actor and the client role
    self.user.actor = True
    self.user.client = True
    self.user.save()
    # unrelated models/trainings must not show up in the listing
    [Dummy.create_model() for _ in range(2)]
    expected = [Dummy.create_model(owner=self.user) for _ in range(2)]
    [Dummy.create_training() for _ in range(2)]
    related_trainings = [Dummy.create_training(actor=self.user) for _ in range(2)]
    related_trainings += [Dummy.create_training(participants=[self.user]) for _ in range(2)]
    expected += [t.model for t in related_trainings]
    # fetch the user-related models
    response = self.client.get(f"{BASE_URL}/models/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    body = response.json()
    self.assertEqual(len(expected), len(body))
    self.assertEqual(
        sorted(str(m.id) for m in expected),
        sorted(entry["id"] for entry in body)
    )
def test_get_all_models_for_a_training(self):
    """The training model listing contains the base model plus every update round."""
    # give the user both the actor and the client role
    self.user.actor = True
    self.user.client = True
    self.user.save()
    # four participants with differing numbers of update rounds
    participants = [Dummy.create_user() for _ in range(4)]
    participant_rounds = [3, 4, 4, 3]
    # noise: unrelated trainings and model updates
    [Dummy.create_training() for _ in range(2)]
    [Dummy.create_training(actor=self.user) for _ in range(2)]
    [Dummy.create_training(participants=[self.user]) for _ in range(2)]
    [Dummy.create_model_update() for _ in range(2)]
    [Dummy.create_model_update(owner=self.user) for _ in range(2)]
    training = Dummy.create_training(actor=self.user, participants=participants)
    # per participant, create one update per round on top of the base model
    base_model = training.model
    expected = [base_model]
    for participant, rounds in zip(participants, participant_rounds):
        expected.extend(
            Dummy.create_model_update(base_model=base_model, owner=participant, round=r + 1)
            for r in range(rounds)
        )
    # fetch the models of this specific training
    response = self.client.get(f"{BASE_URL}/trainings/{training.pk}/models/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    body = response.json()
    self.assertEqual(len(expected), len(body))
    self.assertEqual(
        sorted(str(m.id) for m in expected),
        sorted(entry["id"] for entry in body)
    )
def test_get_all_models_for_a_training_latest_only(self):
    """Query the ``models/latest/`` listing of a training; compare ids and rounds.

    NOTE(review): the expected list contains *every* created update round,
    not only the newest round per participant - confirm this matches the
    intended semantics of the "latest" endpoint.
    """
    # make user actor and client
    self.user.actor = True
    self.user.client = True
    self.user.save()
    # create participants
    participants = [Dummy.create_user() for _ in range(4)]
    participant_rounds = [3, 4, 4, 3]
    # create models and trainings - some related to user some not
    [Dummy.create_training() for _ in range(2)]
    [Dummy.create_training(actor=self.user) for _ in range(2)]
    [Dummy.create_training(participants=[self.user]) for _ in range(2)]
    [Dummy.create_model_update() for _ in range(2)]
    [Dummy.create_model_update(owner=self.user) for _ in range(2)]
    training = Dummy.create_training(actor=self.user, participants=participants)
    # create model update for 4 users
    base_model = training.model
    models_latest = [base_model]
    for participant, rounds in zip(participants, participant_rounds):
        for round_idx in range(rounds):
            model = Dummy.create_model_update(base_model=base_model, owner=participant, round=round_idx+1)
            models_latest.append(model)
    # get user related "latest" models for a special training
    response = self.client.get(f"{BASE_URL}/trainings/{training.pk}/models/latest/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    response_json = response.json()
    self.assertEqual(len(models_latest), len(response_json))
    # sort both sides by model id so ids and rounds can be compared pairwise
    models_latest = sorted(models_latest, key=lambda m: str(m.pk))
    response_models = sorted(response_json, key=lambda m: m["id"])
    self.assertEqual(
        [str(model.id) for model in models_latest],
        [model["id"] for model in response_models]
    )
    self.assertEqual(
        [model.round for model in models_latest],
        [model["round"] for model in response_models]
    )
def test_get_model_metadata(self):
    """The metadata endpoint reports model fields plus per-layer torch summary stats.

    Builds a plain (non-scripted) 3-layer Sequential, uploads it as a dummy
    model with input shape ``[None, 3]`` and checks the ``stats`` block:
    overall totals plus one ``summary_list`` entry per module (the Sequential
    container itself at depth 0, then its three child layers at depth 1).
    """
    model_bytes = from_torch_module(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    model = Dummy.create_model(weights=model_bytes, input_shape=[None, 3])
    response = self.client.get(f"{BASE_URL}/models/{model.id}/metadata/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    response_json = response.json()
    self.assertEqual(str(model.id), response_json["id"])
    self.assertEqual(str(model.name), response_json["name"])
    self.assertEqual(str(model.description), response_json["description"])
    self.assertEqual(model.input_shape, response_json["input_shape"])
    self.assertFalse(response_json["has_preprocessing"])
    # check stats
    stats = response_json["stats"]
    self.assertIsNotNone(stats)
    # the batch dimension (None) is materialized as 1 for the summary
    self.assertEqual([[1, 3]], stats["input_size"])
    self.assertIsNotNone(stats["total_input"])
    self.assertIsNotNone(stats["total_mult_adds"])
    self.assertIsNotNone(stats["total_output_bytes"])
    self.assertIsNotNone(stats["total_param_bytes"])
    self.assertIsNotNone(stats["total_params"])
    self.assertIsNotNone(stats["trainable_params"])
    # layer 1 stats: the Sequential container itself (depth 0)
    layer1 = stats["summary_list"][0]
    self.assertEqual("Sequential", layer1["class_name"])
    self.assertEqual(0, layer1["depth"])
    self.assertEqual(1, layer1["depth_index"])
    self.assertEqual(True, layer1["executed"])
    self.assertEqual("Sequential", layer1["var_name"])
    self.assertEqual(False, layer1["is_leaf_layer"])
    self.assertEqual(False, layer1["contains_lazy_param"])
    self.assertEqual(False, layer1["is_recursive"])
    self.assertEqual([1, 3], layer1["input_size"])
    self.assertEqual([1, 1], layer1["output_size"])
    self.assertEqual(None, layer1["kernel_size"])
    self.assertIsNotNone(layer1["trainable_params"])
    self.assertIsNotNone(layer1["num_params"])
    self.assertIsNotNone(layer1["param_bytes"])
    self.assertIsNotNone(layer1["output_bytes"])
    self.assertIsNotNone(layer1["macs"])
    # layer 2 stats: Linear(3, 64)
    layer2 = stats["summary_list"][1]
    self.assertEqual("Linear", layer2["class_name"])
    self.assertEqual(1, layer2["depth"])
    self.assertEqual(1, layer2["depth_index"])
    self.assertEqual(True, layer2["executed"])
    self.assertEqual("0", layer2["var_name"])
    self.assertEqual(True, layer2["is_leaf_layer"])
    self.assertEqual(False, layer2["contains_lazy_param"])
    self.assertEqual(False, layer2["is_recursive"])
    self.assertEqual([1, 3], layer2["input_size"])
    self.assertEqual([1, 64], layer2["output_size"])
    self.assertEqual(None, layer2["kernel_size"])
    self.assertIsNotNone(layer2["trainable_params"])
    self.assertIsNotNone(layer2["num_params"])
    self.assertIsNotNone(layer2["param_bytes"])
    self.assertIsNotNone(layer2["output_bytes"])
    self.assertIsNotNone(layer2["macs"])
    # layer 3 stats: ELU activation
    layer3 = stats["summary_list"][2]
    self.assertEqual("ELU", layer3["class_name"])
    self.assertEqual(1, layer3["depth"])
    self.assertEqual(2, layer3["depth_index"])
    self.assertEqual(True, layer3["executed"])
    self.assertEqual("1", layer3["var_name"])
    self.assertEqual(True, layer3["is_leaf_layer"])
    self.assertEqual(False, layer3["contains_lazy_param"])
    self.assertEqual(False, layer3["is_recursive"])
    self.assertEqual([1, 64], layer3["input_size"])
    self.assertEqual([1, 64], layer3["output_size"])
    self.assertEqual(None, layer3["kernel_size"])
    self.assertIsNotNone(layer3["trainable_params"])
    self.assertIsNotNone(layer3["num_params"])
    self.assertIsNotNone(layer3["param_bytes"])
    self.assertIsNotNone(layer3["output_bytes"])
    self.assertIsNotNone(layer3["macs"])
    # layer 4 stats: Linear(64, 1)
    layer4 = stats["summary_list"][3]
    self.assertEqual("Linear", layer4["class_name"])
    self.assertEqual(1, layer4["depth"])
    self.assertEqual(3, layer4["depth_index"])
    self.assertEqual(True, layer4["executed"])
    self.assertEqual("2", layer4["var_name"])
    self.assertEqual(True, layer4["is_leaf_layer"])
    self.assertEqual(False, layer4["contains_lazy_param"])
    self.assertEqual(False, layer4["is_recursive"])
    self.assertEqual([1, 64], layer4["input_size"])
    self.assertEqual([1, 1], layer4["output_size"])
    self.assertEqual(None, layer4["kernel_size"])
    self.assertIsNotNone(layer4["trainable_params"])
    self.assertIsNotNone(layer4["num_params"])
    self.assertIsNotNone(layer4["param_bytes"])
    self.assertIsNotNone(layer4["output_bytes"])
    self.assertIsNotNone(layer4["macs"])
def test_get_model_metadata_with_preprocessing(self):
    """Metadata of a model with attached preprocessing reports has_preprocessing=True."""
    model_bytes = from_torch_module(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    # a torchvision-v2 transform pipeline serves as the preprocessing model
    preprocessing_bytes = from_torch_module(transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ]))
    model = Dummy.create_model(weights=model_bytes, preprocessing=preprocessing_bytes, input_shape=[None, 3])
    response = self.client.get(f"{BASE_URL}/models/{model.id}/metadata/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    body = response.json()
    self.assertEqual(str(model.id), body["id"])
    self.assertEqual(str(model.name), body["name"])
    self.assertEqual(str(model.description), body["description"])
    self.assertEqual(model.input_shape, body["input_shape"])
    self.assertTrue(body["has_preprocessing"])
def test_get_model_metadata_torchscript_model(self):
    """Metadata works for TorchScript-compiled models as well.

    Only the aggregate stats and the summary length are checked here; the
    per-layer details are covered by ``test_get_model_metadata``.
    """
    torchscript_model_bytes = from_torch_module(torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    )))
    model = Dummy.create_model(weights=torchscript_model_bytes, input_shape=[None, 3])
    response = self.client.get(f"{BASE_URL}/models/{model.id}/metadata/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    response_json = response.json()
    self.assertEqual(str(model.id), response_json["id"])
    self.assertEqual(str(model.name), response_json["name"])
    self.assertEqual(str(model.description), response_json["description"])
    self.assertEqual(model.input_shape, response_json["input_shape"])
    # check stats
    stats = response_json["stats"]
    self.assertIsNotNone(stats)
    self.assertEqual([[1, 3]], stats["input_size"])
    self.assertIsNotNone(stats["total_input"])
    self.assertIsNotNone(stats["total_mult_adds"])
    self.assertIsNotNone(stats["total_output_bytes"])
    self.assertIsNotNone(stats["total_param_bytes"])
    self.assertIsNotNone(stats["total_params"])
    self.assertIsNotNone(stats["trainable_params"])
    # container + 3 child layers
    self.assertEqual(4, len(stats["summary_list"]))
def test_get_model(self):
    """Downloading a model returns its raw weight bytes as an octet stream."""
    model = Dummy.create_model(weights=b"Hello World!")
    url = f"{BASE_URL}/models/{model.id}/"
    response = self.client.get(url)
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/octet-stream", response["content-type"])
    self.assertEqual(b"Hello World!", response.getvalue())
def test_delete_model_without_training_as_model_owner(self):
    """The owner may delete their model; afterwards it no longer exists."""
    model = Dummy.create_model(owner=self.user)
    response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("Model removed!", response.json()["detail"])
    with self.assertRaises(ObjectDoesNotExist):
        Model.objects.get(pk=model.id)
def test_delete_global_model_with_training_as_model_owner(self):
    """Deleting a global model also removes the training that uses it."""
    model = Dummy.create_model(owner=self.user)
    training = Dummy.create_training(model=model)
    response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("Model removed!", response.json()["detail"])
    with self.assertRaises(ObjectDoesNotExist):
        Model.objects.get(pk=model.id)
    # GlobalModel deletion cascades onto the associated training
    with self.assertRaises(ObjectDoesNotExist):
        Training.objects.get(pk=training.id)
def test_delete_local_model_with_training_as_model_owner(self):
    """Deleting a local model update leaves the global model and training intact."""
    global_model = Dummy.create_model()
    local_model = Dummy.create_model_update(base_model=global_model, owner=self.user)
    training = Dummy.create_training(model=global_model)
    response = self.client.delete(f"{BASE_URL}/models/{local_model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("Model removed!", response.json()["detail"])
    with self.assertRaises(ObjectDoesNotExist):
        Model.objects.get(pk=local_model.id)
    # the base model and its training must survive the delete
    self.assertIsNotNone(Model.objects.get(pk=global_model.id))
    self.assertIsNotNone(Training.objects.get(pk=training.id))
def test_delete_model_as_training_owner(self):
    """The actor of a training may delete its model; the training is removed too."""
    model = Dummy.create_model()
    training = Dummy.create_training(model=model, actor=self.user)
    response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("Model removed!", response.json()["detail"])
    with self.assertRaises(ObjectDoesNotExist):
        Model.objects.get(pk=model.id)
    # GlobalModel deletion cascades onto the associated training
    with self.assertRaises(ObjectDoesNotExist):
        Training.objects.get(pk=training.id)
def test_delete_model_as_training_participant(self):
    """A mere training participant is not allowed to delete the model."""
    model = Dummy.create_model()
    Dummy.create_training(model=model, participants=[Dummy.create_client(), self.user, Dummy.create_client()])
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(403, response.status_code)
    self.assertEqual(
        "You are neither the owner of the model nor the actor of the corresponding training.",
        response.json()["detail"]
    )
    # the model must still exist
    self.assertIsNotNone(Model.objects.get(pk=model.id))
def test_delete_model_with_training_as_unrelated_user(self):
    """A user unrelated to the model and its training gets a 403 on delete."""
    model = Dummy.create_model()
    Dummy.create_training(model=model)
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(403, response.status_code)
    self.assertEqual(
        "You are neither the owner of the model nor the actor of the corresponding training.",
        response.json()["detail"]
    )
    # the model must still exist
    self.assertIsNotNone(Model.objects.get(pk=model.id))
def test_delete_model_without_training_as_unrelated_user(self):
    """A user unrelated to a training-less model gets a 403 on delete."""
    model = Dummy.create_model()
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(403, response.status_code)
    self.assertEqual(
        "You are neither the owner of the model nor the actor of the corresponding training.",
        response.json()["detail"]
    )
    # the model must still exist
    self.assertIsNotNone(Model.objects.get(pk=model.id))
def test_delete_non_existing_model(self):
    """Deleting an unknown model id yields a 400 with a not-found detail."""
    missing_id = str(uuid4())
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{missing_id}/")
    self.assertEqual(400, response.status_code)
    self.assertEqual(f"Model {missing_id} not found.", response.json()["detail"])
def test_get_model_and_unpickle(self):
    """A downloaded model must be loadable as a TorchScript module."""
    model = Dummy.create_model()
    response = self.client.get(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/octet-stream", response["content-type"])
    loaded = torch.jit.load(io.BytesIO(response.content))
    self.assertIsNotNone(loaded)
    self.assertIsInstance(loaded, torch.nn.Module)
def test_upload(self):
    """Uploading a TorchScript model creates a GlobalModel with the given input shape."""
    torch_model = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    model_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch_model),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/", {
        "model_file": model_file,
        "name": "Test Model",
        "description": "Test Model Description - Test Model Description Test",
        "input_shape": [None, 3]
    })
    self.assertEqual(201, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    self.assertEqual("Model Upload Accepted", response_json["detail"])
    uuid = response_json["model_id"]
    self.assertIsNotNone(uuid)
    # bugfix: assertIsNot compares object identity (only catches "" via CPython
    # string interning); assertNotEqual actually verifies the id is non-empty
    self.assertNotEqual("", uuid)
    self.assertEqual(GlobalModel, type(Model.objects.get(id=uuid)))
    self.assertEqual([None, 3], Model.objects.get(id=uuid).input_shape)
def test_upload_swag_model(self):
    """Uploading with ``type=SWAG`` creates a SWAGModel instance."""
    torch_model = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    model_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch_model),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/", {
        "type": "SWAG",
        "model_file": model_file,
        "name": "Test SWAG Model",
        "description": "Test SWAG Model Description - Test SWAG Model Description Test",
    })
    self.assertEqual(201, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    self.assertEqual("Model Upload Accepted", response_json["detail"])
    uuid = response_json["model_id"]
    self.assertIsNotNone(uuid)
    # bugfix: assertIsNot checks identity, not equality - use assertNotEqual
    self.assertNotEqual("", uuid)
    self.assertEqual(SWAGModel, type(Model.objects.get(id=uuid)))
def test_upload_mean_model(self):
    """Uploading with ``type=MEAN`` and a list of model ids creates a MeanModel."""
    models = [Dummy.create_model(owner=self.user) for _ in range(10)]
    model_uuids = [str(m.id) for m in models]
    response = self.client.post(f"{BASE_URL}/models/", {
        "type": "MEAN",
        "name": "Test MEAN Model",
        "description": "Test MEAN Model Description - Test MEAN Model Description Test",
        "models": model_uuids,
    }, "application/json")
    self.assertEqual(201, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    self.assertEqual("Model Upload Accepted", response_json["detail"])
    uuid = response_json["model_id"]
    self.assertIsNotNone(uuid)
    # bugfix: assertIsNot checks identity, not equality - use assertNotEqual
    self.assertNotEqual("", uuid)
    self.assertEqual(MeanModel, type(Model.objects.get(id=uuid)))
def test_upload_with_preprocessing(self):
    """Uploading model plus preprocessing stores both and makes them loadable."""
    torch_model = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    model_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch_model),  # torchscript model
        content_type="application/octet-stream"
    )
    torch_model_preprocessing = torch.jit.script(torch.nn.Sequential(
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ))
    model_preprocessing_file = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(torch_model_preprocessing),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/", {
        "model_file": model_file,
        "model_preprocessing_file": model_preprocessing_file,
        "name": "Test Model",
        "description": "Test Model Description - Test Model Description Test",
        "input_shape": [None, 3]
    })
    self.assertEqual(201, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    self.assertEqual("Model Upload Accepted", response_json["detail"])
    uuid = response_json["model_id"]
    self.assertIsNotNone(uuid)
    # bugfix: assertIsNot checks identity, not equality - use assertNotEqual
    self.assertNotEqual("", uuid)
    self.assertEqual(GlobalModel, type(Model.objects.get(id=uuid)))
    self.assertEqual([None, 3], Model.objects.get(id=uuid).input_shape)
    model = get_entity(GlobalModel, pk=uuid)
    self.assertIsNotNone(model)
    self.assertIsNotNone(model.weights)
    self.assertIsNotNone(model.preprocessing)
    self.assertTrue(isinstance(model.get_torch_model(), torch.nn.Module))
    self.assertTrue(isinstance(model.get_preprocessing_torch_model(), torch.nn.Module))
def test_upload_model_preprocessing(self):
    """A preprocessing model can be attached to an existing model via POST."""
    model = Dummy.create_model(owner=self.user, preprocessing=None)
    scripted = torch.jit.script(torch.nn.Sequential(
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ))
    upload = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(scripted),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/{model.id}/preprocessing/", {
        "model_preprocessing_file": upload,
    })
    self.assertEqual(202, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    # "Proprocessing" [sic] matches the server's response text
    self.assertEqual("Proprocessing Model Upload Accepted", body["detail"])
    model.refresh_from_db()
    self.assertIsNotNone(model)
    self.assertIsNotNone(model.preprocessing)
    self.assertIsInstance(model.get_preprocessing_torch_model(), torch.nn.Module)
def test_upload_model_preprocessing_v1_Compose_bad(self):
    """Uploading a v1 ``torchvision.transforms.Compose`` is rejected with 400.

    The v1 Compose is not a ``torch.nn.Module``, so the server fails to load
    it as a torch model and reports an invalid preprocessing file.
    """
    model = Dummy.create_model(owner=self.user, preprocessing=None)
    # torchvision.transforms.Compose (v1 not v2) does not inherit from torch.nn.Module!!
    torch_model_preprocessing = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.,), std=(1.,)),
    ])
    model_preprocessing_file = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(torch_model_preprocessing),  # (normal) transforms.Compose
        content_type="application/octet-stream"
    )
    with self.assertLogs("fl.server", level="ERROR"):  # Loaded torch object is not of expected type.
        with self.assertLogs("django.request", level="WARNING"):  # Bad Request
            response = self.client.post(f"{BASE_URL}/models/{model.id}/preprocessing/", {
                "model_preprocessing_file": model_preprocessing_file,
            })
    self.assertEqual(400, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    # the error payload is a list; the message is its first element
    self.assertEqual(
        "Invalid preprocessing file: Loaded torch object is not of expected type.",
        response_json[0],
    )
def test_upload_model_preprocessing_v2_Compose_good(self):
    """A v2 ``transforms.Compose`` pipeline is accepted as preprocessing."""
    model = Dummy.create_model(owner=self.user, preprocessing=None)
    pipeline = transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ])
    upload = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(pipeline),  # (normal) transforms.Compose
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/{model.id}/preprocessing/", {
        "model_preprocessing_file": upload,
    })
    self.assertEqual(202, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    # "Proprocessing" [sic] matches the server's response text
    self.assertEqual("Proprocessing Model Upload Accepted", body["detail"])
    model.refresh_from_db()
    self.assertIsNotNone(model)
    self.assertIsNotNone(model.preprocessing)
    self.assertIsInstance(model.get_preprocessing_torch_model(), torch.nn.Module)
def test_download_model_preprocessing(self):
    """The preprocessing model can be downloaded and loaded as TorchScript."""
    preprocessing_bytes = from_torch_module(torch.jit.script(torch.nn.Sequential(
        transforms.Normalize(mean=(0.,), std=(1.,)),
    )))
    model = Dummy.create_model(owner=self.user, preprocessing=preprocessing_bytes)
    response = self.client.get(f"{BASE_URL}/models/{model.id}/preprocessing/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/octet-stream", response["content-type"])
    loaded = torch.jit.load(io.BytesIO(response.content))
    self.assertIsNotNone(loaded)
    self.assertIsInstance(loaded, torch.nn.Module)
def test_download_model_preprocessing_with_undefined_preprocessing(self):
    """Downloading preprocessing of a model that has none yields a 404.

    Also asserts the exact warning record django.request emits for the 404.
    """
    model = Dummy.create_model(owner=self.user, preprocessing=None)
    with self.assertLogs("django.request", level="WARNING") as cm:
        response = self.client.get(f"{BASE_URL}/models/{model.id}/preprocessing/")
    self.assertEqual(cm.output, [
        f"WARNING:django.request:Not Found: /api/models/{model.id}/preprocessing/",
    ])
    self.assertEqual(404, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    self.assertEqual(f"Model '{model.id}' has no preprocessing model defined.", response_json["detail"])
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_update(self, apply_async: MagicMock):
    """A participant's model update is accepted; aggregation is not triggered
    while another participant's update is still outstanding."""
    model = Dummy.create_model(owner=self.user, round=0)
    Dummy.create_training(model=model, actor=self.user, state=TrainingState.ONGOING,
                          participants=[self.user, Dummy.create_user()])
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    response = self.client.post(
        f"{BASE_URL}/models/{model.id}/",
        {"model_file": update_file, "round": 0, "sample_size": 100}
    )
    self.assertEqual(201, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("Model Update Accepted", body["detail"])
    # the second participant has not uploaded yet -> no aggregation task
    self.assertFalse(apply_async.called)
def test_upload_update_bad_keys(self):
    """An update upload without the 'model_file' form key is rejected with 400."""
    model = Dummy.create_model(owner=self.user, round=0)
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/",
            {"xXx_model_file_xXx": update_file, "round": 0, "sample_size": 100}
        )
    self.assertEqual(400, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("No uploaded file 'model_file' found.", body["detail"])
def test_upload_update_no_training(self):
    """An update for a model without a running training yields a 404."""
    model = Dummy.create_model(owner=self.user, round=0)
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/",
            {"model_file": update_file, "round": 0, "sample_size": 100}
        )
    self.assertEqual(404, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual(f"Model with ID {model.id} does not have a training process running", body["detail"])
def test_upload_update_no_participant(self):
    """A user who is not a participant of the training gets a 403 on update upload.

    The authenticated user from setUp() is neither actor nor participant of
    the training created here.
    """
    # force a JSON error response body
    self.client.defaults["HTTP_ACCEPT"] = "application/json"
    actor = Dummy.create_actor()
    model = Dummy.create_model(owner=actor, round=0)
    training = Dummy.create_training(
        model=model, actor=actor, state=TrainingState.ONGOING,
        participants=[actor, Dummy.create_client()]
    )
    model_update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/",
            {"model_file": model_update_file, "round": 0,
             "sample_size": 500}
        )
    self.assertEqual(403, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    self.assertEqual(f"You are not a participant of training {training.id}!", response_json["detail"])
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_update_and_aggregate(self, apply_async: MagicMock):
    """When the sole participant uploads its update, aggregation is scheduled.

    The celery task must be enqueued exactly once with the training uuid and
    the TrainingRoundFinished event class, and without retries.
    """
    model = Dummy.create_model(owner=self.user, round=0)
    train = Dummy.create_training(model=model, actor=self.user, state=TrainingState.ONGOING,
                                  participants=[self.user])
    model_update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    response = self.client.post(
        f"{BASE_URL}/models/{model.id}/",
        {"model_file": model_update_file, "round": 0,
         "sample_size": 100}
    )
    self.assertEqual(201, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    self.assertEqual("Model Update Accepted", response_json["detail"])
    self.assertTrue(apply_async.called)
    # positional args empty, kwargs carry the training uuid and event class
    apply_async.assert_called_once_with(
        (),
        {"training_uuid": train.id, "event_cls": TrainingRoundFinished},
        retry=False
    )
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_update_and_not_aggregate_since_training_is_locked(self, apply_async: MagicMock):
    """A locked training accepts updates but does not schedule aggregation."""
    model = Dummy.create_model(owner=self.user, round=0)
    training = Dummy.create_training(
        model=model, actor=self.user, state=TrainingState.ONGOING, participants=[self.user]
    )
    # lock the training so no aggregation may start
    training.locked = True
    training.save()
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    response = self.client.post(
        f"{BASE_URL}/models/{model.id}/",
        {"model_file": update_file, "round": 0, "sample_size": 100}
    )
    self.assertEqual(201, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("Model Update Accepted", body["detail"])
    self.assertFalse(apply_async.called)
def test_upload_update_with_metrics(self):
    """A model update may carry parallel metric_names/metric_values lists."""
    model = Dummy.create_model(owner=self.user, round=0)
    Dummy.create_training(model=model, actor=self.user, state=TrainingState.ONGOING,
                          participants=[self.user, Dummy.create_user()])
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    payload = {
        "model_file": update_file,
        "round": 0,
        "metric_names": ["loss", "accuracy", "dummy_binary"],
        "metric_values": [1999.0, 0.12, b"Hello World!"],
        "sample_size": 50
    }
    response = self.client.post(f"{BASE_URL}/models/{model.id}/", payload)
    self.assertEqual(201, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("Model Update Accepted", body["detail"])
def test_upload_update_with_metrics_bad(self):
    """An update with a malformed metrics payload is rejected with 400.

    ``metric_names`` is sent as the scalar ``5`` instead of a list of names,
    which the server must reject.
    """
    model = Dummy.create_model(owner=self.user)
    Dummy.create_training(model=model, actor=self.user, state=TrainingState.ONGOING,
                          participants=[self.user, Dummy.create_user()])
    model_update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/",
            {"model_file": model_update_file, "round": 0, "metric_names": 5,
             "sample_size": 500}
        )
    self.assertEqual(400, response.status_code)
def test_upload_global_model_metrics(self):
    """Metrics for a global model without a training are accepted with a warning.

    The server logs exactly one warning that the global model is not connected
    to any training, but still stores the metrics (201).
    """
    model = Dummy.create_model(owner=self.user, round=0)
    metrics = dict(
        metric_names=["loss", "accuracy", "dummy_binary"],
        metric_values=[1999.0, 0.12, b"Hello World!"],
    )
    with self.assertLogs("fl.server", level="WARNING") as cm:
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/metrics/",
            metrics,
        )
    self.assertEqual(cm.output, [
        f"WARNING:fl.server:Global model {model.id} is not connected to any training.",
    ])
    self.assertEqual(201, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    self.assertEqual("Model Metrics Upload Accepted", response_json["detail"])
    self.assertEqual(str(model.id), response_json["model_id"])
def test_upload_local_model_metrics(self):
    """Metrics for a local model update are accepted without warnings."""
    model = Dummy.create_model_update(owner=self.user)
    response = self.client.post(
        f"{BASE_URL}/models/{model.id}/metrics/",
        dict(
            metric_names=["loss", "accuracy", "dummy_binary"],
            metric_values=[1999.0, 0.12, b"Hello World!"],
        ),
    )
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Metrics Upload Accepted", payload["detail"])
    self.assertEqual(str(model.id), payload["model_id"])
def test_upload_bad_metrics(self):
    """Mismatched metric name/value list lengths are rejected with HTTP 400."""
    model = Dummy.create_model(owner=self.user, round=0)
    # 3 names but only 2 values
    bad_metrics = dict(
        metric_names=["loss", "accuracy", "dummy_binary"],
        metric_values=[1999.0, b"Hello World!"],
    )
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.post(f"{BASE_URL}/models/{model.id}/metrics/", bad_metrics)
    self.assertEqual(400, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Metric names and values must have the same length", payload["detail"])
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_swag_stats(self, apply_async: MagicMock):
    """SWAG first/second moment upload is accepted and schedules the SWAG-round-finished task."""
    model = Dummy.create_model(owner=self.user, round=0)
    train = Dummy.create_training(
        model=model,
        actor=self.user,
        state=TrainingState.SWAG_ROUND,
        participants=[self.user],
    )

    def make_moment_file(filename: str) -> SimpleUploadedFile:
        # pickled state_dict of a small throw-away network
        state = torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid(),
        ).state_dict()
        return SimpleUploadedFile(filename, pickle.dumps(state), content_type="application/octet-stream")

    response = self.client.post(f"{BASE_URL}/models/{model.id}/swag/", {
        "first_moment_file": make_moment_file("first_moment.pkl"),
        "second_moment_file": make_moment_file("second_moment.pkl"),
        "sample_size": 100,
        "round": 0,
    })
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("SWAG Statistic Accepted", payload["detail"])
    self.assertTrue(apply_async.called)
    apply_async.assert_called_once_with(
        (),
        {"training_uuid": train.id, "event_cls": SWAGRoundFinished},
        retry=False
    )
def test_get_global_model_metrics(self):
    """Fetching metrics of a global model returns the single stored metric."""
    model = Dummy.create_model(owner=self.user)
    metric = Dummy.create_metric(model=model)
    response = self.client.get(f"{BASE_URL}/models/{model.id}/metrics/")
    self.assertEqual(200, response.status_code)
    metrics = response.json()
    self.assertEqual(1, len(metrics))
    self.assertEqual(metric.value_float, metrics[0]["value_float"])
    self.assertEqual(metric.key, metrics[0]["key"])
def test_get_local_model_metrics(self):
    """Fetching metrics of a local model update returns the single stored metric."""
    model = Dummy.create_model_update(owner=self.user)
    metric = Dummy.create_metric(model=model)
    response = self.client.get(f"{BASE_URL}/models/{model.id}/metrics/")
    self.assertEqual(200, response.status_code)
    metrics = response.json()
    self.assertEqual(1, len(metrics))
    self.assertEqual(metric.value_float, metrics[0]["value_float"])
    self.assertEqual(metric.key, metrics[0]["key"])
Variables¶
Classes¶
ModelTests¶
A class whose instances are single test cases.
By default, the test code itself should be placed in a method named 'runTest'.
If the fixture may be used for many test cases, create as many test methods as are needed. When instantiating such a TestCase subclass, specify in the constructor arguments the name of the test method that the instance is to execute.
Test authors should subclass TestCase for their own tests. Construction and deconstruction of the test's environment ('fixture') can be implemented by overriding the 'setUp' and 'tearDown' methods respectively.
If it is necessary to override the `__init__` method, the base class `__init__` method must always be called. It is important that subclasses should not change the signature of their `__init__` method, since instances of the classes are instantiated automatically by parts of the framework in order to be run.
When subclassing TestCase, you can set these attributes: * failureException: determines which exception will be raised when the instance's assertion methods fail; test methods raising this exception will be deemed to have 'failed' rather than 'errored'. * longMessage: determines whether long messages (including repr of objects used in assert methods) will be printed on failure in addition to any explicit message passed. * maxDiff: sets the maximum length of a diff in failure messages by assert methods using difflib. It is looked up as an instance attribute so can be configured by individual tests if required.
View Source
class ModelTests(TransactionTestCase):
def setUp(self):
    """Create a dummy user and authenticate the Django test client as that user."""
    self.user = Dummy.create_user_and_authenticate(self.client)
def test_unauthorized(self):
    """Requests without an Authorization header are rejected with HTTP 401."""
    # drop the credentials installed by setUp
    del self.client.defaults["HTTP_AUTHORIZATION"]
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(f"{BASE_URL}/models/", {"model_file": b"Hello World!"})
    self.assertEqual(401, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Authentication credentials were not provided.", payload["detail"])
def test_get_all_models(self):
    """The model list contains exactly the models owned by or trained with the user."""
    # user acts as both actor and client
    self.user.actor = True
    self.user.client = True
    self.user.save()
    # noise that must NOT appear in the listing
    [Dummy.create_model() for _ in range(2)]
    # models related to the user: owned directly ...
    expected = [Dummy.create_model(owner=self.user) for _ in range(2)]
    # ... or via trainings where the user is actor or participant
    [Dummy.create_training() for _ in range(2)]
    related = [Dummy.create_training(actor=self.user) for _ in range(2)]
    related += [Dummy.create_training(participants=[self.user]) for _ in range(2)]
    expected += [t.model for t in related]
    response = self.client.get(f"{BASE_URL}/models/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    listed = response.json()
    self.assertEqual(len(expected), len(listed))
    self.assertEqual(
        sorted(str(m.id) for m in expected),
        sorted(m["id"] for m in listed),
    )
def test_get_all_models_for_a_training(self):
    """The per-training model list returns the base model plus every participant update."""
    self.user.actor = True
    self.user.client = True
    self.user.save()
    participants = [Dummy.create_user() for _ in range(4)]
    rounds_per_participant = [3, 4, 4, 3]
    # noise: unrelated trainings and model updates
    [Dummy.create_training() for _ in range(2)]
    [Dummy.create_training(actor=self.user) for _ in range(2)]
    [Dummy.create_training(participants=[self.user]) for _ in range(2)]
    [Dummy.create_model_update() for _ in range(2)]
    [Dummy.create_model_update(owner=self.user) for _ in range(2)]
    training = Dummy.create_training(actor=self.user, participants=participants)
    # base model plus one update per participant per round
    base_model = training.model
    expected = [base_model]
    for participant, n_rounds in zip(participants, rounds_per_participant):
        expected.extend(
            Dummy.create_model_update(base_model=base_model, owner=participant, round=r + 1)
            for r in range(n_rounds)
        )
    response = self.client.get(f"{BASE_URL}/trainings/{training.pk}/models/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    listed = response.json()
    self.assertEqual(len(expected), len(listed))
    self.assertEqual(
        sorted(str(m.id) for m in expected),
        sorted(m["id"] for m in listed),
    )
def test_get_all_models_for_a_training_latest_only(self):
    """The "latest" endpoint returns models with their rounds for a specific training."""
    self.user.actor = True
    self.user.client = True
    self.user.save()
    participants = [Dummy.create_user() for _ in range(4)]
    rounds_per_participant = [3, 4, 4, 3]
    # noise: unrelated trainings and model updates
    [Dummy.create_training() for _ in range(2)]
    [Dummy.create_training(actor=self.user) for _ in range(2)]
    [Dummy.create_training(participants=[self.user]) for _ in range(2)]
    [Dummy.create_model_update() for _ in range(2)]
    [Dummy.create_model_update(owner=self.user) for _ in range(2)]
    training = Dummy.create_training(actor=self.user, participants=participants)
    base_model = training.model
    # NOTE(review): every created update is expected in the "latest" response below —
    # confirm against the endpoint whether Dummy.create_model_update replaces
    # earlier rounds or the endpoint really returns all of these.
    models_latest = [base_model]
    for participant, n_rounds in zip(participants, rounds_per_participant):
        models_latest.extend(
            Dummy.create_model_update(base_model=base_model, owner=participant, round=r + 1)
            for r in range(n_rounds)
        )
    response = self.client.get(f"{BASE_URL}/trainings/{training.pk}/models/latest/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    listed = response.json()
    self.assertEqual(len(models_latest), len(listed))
    # compare ids and rounds pairwise after sorting both sides by id
    models_latest = sorted(models_latest, key=lambda m: str(m.pk))
    listed = sorted(listed, key=lambda m: m["id"])
    self.assertEqual(
        [str(m.id) for m in models_latest],
        [m["id"] for m in listed],
    )
    self.assertEqual(
        [m.round for m in models_latest],
        [m["round"] for m in listed],
    )
def test_get_model_metadata(self):
    """The metadata endpoint returns model info plus per-layer statistics."""
    weights = from_torch_module(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    model = Dummy.create_model(weights=weights, input_shape=[None, 3])
    response = self.client.get(f"{BASE_URL}/models/{model.id}/metadata/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    meta = response.json()
    self.assertEqual(str(model.id), meta["id"])
    self.assertEqual(str(model.name), meta["name"])
    self.assertEqual(str(model.description), meta["description"])
    self.assertEqual(model.input_shape, meta["input_shape"])
    self.assertFalse(meta["has_preprocessing"])
    # aggregate statistics
    stats = meta["stats"]
    self.assertIsNotNone(stats)
    self.assertEqual([[1, 3]], stats["input_size"])
    for key in ("total_input", "total_mult_adds", "total_output_bytes",
                "total_param_bytes", "total_params", "trainable_params"):
        self.assertIsNotNone(stats[key])
    # per-layer expectations:
    # (class_name, depth, depth_index, var_name, is_leaf_layer, input_size, output_size)
    expected_layers = [
        ("Sequential", 0, 1, "Sequential", False, [1, 3], [1, 1]),
        ("Linear", 1, 1, "0", True, [1, 3], [1, 64]),
        ("ELU", 1, 2, "1", True, [1, 64], [1, 64]),
        ("Linear", 1, 3, "2", True, [1, 64], [1, 1]),
    ]
    for layer, (cls_name, depth, depth_index, var_name, is_leaf, in_size, out_size) \
            in zip(stats["summary_list"], expected_layers):
        self.assertEqual(cls_name, layer["class_name"])
        self.assertEqual(depth, layer["depth"])
        self.assertEqual(depth_index, layer["depth_index"])
        self.assertEqual(True, layer["executed"])
        self.assertEqual(var_name, layer["var_name"])
        self.assertEqual(is_leaf, layer["is_leaf_layer"])
        self.assertEqual(False, layer["contains_lazy_param"])
        self.assertEqual(False, layer["is_recursive"])
        self.assertEqual(in_size, layer["input_size"])
        self.assertEqual(out_size, layer["output_size"])
        self.assertEqual(None, layer["kernel_size"])
        for key in ("trainable_params", "num_params", "param_bytes", "output_bytes", "macs"):
            self.assertIsNotNone(layer[key])
def test_get_model_metadata_with_preprocessing(self):
    """Metadata reports has_preprocessing=True when a preprocessing model is stored."""
    weights = from_torch_module(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    preprocessing = from_torch_module(transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ]))
    model = Dummy.create_model(weights=weights, preprocessing=preprocessing, input_shape=[None, 3])
    response = self.client.get(f"{BASE_URL}/models/{model.id}/metadata/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    meta = response.json()
    self.assertEqual(str(model.id), meta["id"])
    self.assertEqual(str(model.name), meta["name"])
    self.assertEqual(str(model.description), meta["description"])
    self.assertEqual(model.input_shape, meta["input_shape"])
    self.assertTrue(meta["has_preprocessing"])
def test_get_model_metadata_torchscript_model(self):
    """Metadata statistics are also computed for TorchScript (scripted) models."""
    scripted = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    model = Dummy.create_model(weights=from_torch_module(scripted), input_shape=[None, 3])
    response = self.client.get(f"{BASE_URL}/models/{model.id}/metadata/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    meta = response.json()
    self.assertEqual(str(model.id), meta["id"])
    self.assertEqual(str(model.name), meta["name"])
    self.assertEqual(str(model.description), meta["description"])
    self.assertEqual(model.input_shape, meta["input_shape"])
    # aggregate statistics are present; 4 layers: Sequential + Linear + ELU + Linear
    stats = meta["stats"]
    self.assertIsNotNone(stats)
    self.assertEqual([[1, 3]], stats["input_size"])
    for key in ("total_input", "total_mult_adds", "total_output_bytes",
                "total_param_bytes", "total_params", "trainable_params"):
        self.assertIsNotNone(stats[key])
    self.assertEqual(4, len(stats["summary_list"]))
def test_get_model(self):
    """Downloading a model streams its raw weight bytes back unchanged."""
    model = Dummy.create_model(weights=b"Hello World!")
    response = self.client.get(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/octet-stream", response["content-type"])
    self.assertEqual(b"Hello World!", response.getvalue())
def test_delete_model_without_training_as_model_owner(self):
    """The owner may delete a model that is not attached to any training."""
    model = Dummy.create_model(owner=self.user)
    response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("Model removed!", response.json()["detail"])
    with self.assertRaises(ObjectDoesNotExist):
        Model.objects.get(pk=model.id)
def test_delete_global_model_with_training_as_model_owner(self):
    """Deleting a global model cascades and removes its training as well."""
    model = Dummy.create_model(owner=self.user)
    training = Dummy.create_training(model=model)
    response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("Model removed!", response.json()["detail"])
    with self.assertRaises(ObjectDoesNotExist):
        Model.objects.get(pk=model.id)
    # GlobalModel deletion cascades onto the training
    with self.assertRaises(ObjectDoesNotExist):
        Training.objects.get(pk=training.id)
def test_delete_local_model_with_training_as_model_owner(self):
    """Deleting a local model update leaves the global model and training intact."""
    global_model = Dummy.create_model()
    local_model = Dummy.create_model_update(base_model=global_model, owner=self.user)
    training = Dummy.create_training(model=global_model)
    response = self.client.delete(f"{BASE_URL}/models/{local_model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("Model removed!", response.json()["detail"])
    with self.assertRaises(ObjectDoesNotExist):
        Model.objects.get(pk=local_model.id)
    # only the local update is gone; base model and training survive
    self.assertIsNotNone(Model.objects.get(pk=global_model.id))
    self.assertIsNotNone(Training.objects.get(pk=training.id))
def test_delete_model_as_training_owner(self):
    """The training actor may delete the training's model; the training is cascaded away."""
    model = Dummy.create_model()
    training = Dummy.create_training(model=model, actor=self.user)
    response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("Model removed!", response.json()["detail"])
    with self.assertRaises(ObjectDoesNotExist):
        Model.objects.get(pk=model.id)
    # GlobalModel deletion cascades onto the training
    with self.assertRaises(ObjectDoesNotExist):
        Training.objects.get(pk=training.id)
def test_delete_model_as_training_participant(self):
    """A mere training participant may not delete the model (HTTP 403)."""
    model = Dummy.create_model()
    Dummy.create_training(
        model=model,
        participants=[Dummy.create_client(), self.user, Dummy.create_client()],
    )
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(403, response.status_code)
    self.assertEqual(
        "You are neither the owner of the model nor the actor of the corresponding training.",
        response.json()["detail"]
    )
    # model must still exist
    self.assertIsNotNone(Model.objects.get(pk=model.id))
def test_delete_model_with_training_as_unrelated_user(self):
    """An unrelated user may not delete a model that has a training (HTTP 403)."""
    model = Dummy.create_model()
    Dummy.create_training(model=model)
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(403, response.status_code)
    self.assertEqual(
        "You are neither the owner of the model nor the actor of the corresponding training.",
        response.json()["detail"]
    )
    # model must still exist
    self.assertIsNotNone(Model.objects.get(pk=model.id))
def test_delete_model_without_training_as_unrelated_user(self):
    """An unrelated user may not delete a training-less model either (HTTP 403)."""
    model = Dummy.create_model()
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(403, response.status_code)
    self.assertEqual(
        "You are neither the owner of the model nor the actor of the corresponding training.",
        response.json()["detail"]
    )
    # model must still exist
    self.assertIsNotNone(Model.objects.get(pk=model.id))
def test_delete_non_existing_model(self):
    """Deleting an unknown model id responds with HTTP 400 and a not-found detail."""
    unknown_id = str(uuid4())
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{unknown_id}/")
    self.assertEqual(400, response.status_code)
    self.assertEqual(f"Model {unknown_id} not found.", response.json()["detail"])
def test_get_model_and_unpickle(self):
    """Downloaded model bytes can be loaded back as a TorchScript module."""
    model = Dummy.create_model()
    response = self.client.get(f"{BASE_URL}/models/{model.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/octet-stream", response["content-type"])
    restored = torch.jit.load(io.BytesIO(response.content))
    self.assertIsNotNone(restored)
    self.assertTrue(isinstance(restored, torch.nn.Module))
def test_upload(self):
    """Uploading a TorchScript model creates a GlobalModel with the given input shape."""
    scripted = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    upload = SimpleUploadedFile(
        "model.pt",
        from_torch_module(scripted),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/", {
        "model_file": upload,
        "name": "Test Model",
        "description": "Test Model Description - Test Model Description Test",
        "input_shape": [None, 3]
    })
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Upload Accepted", payload["detail"])
    model_id = payload["model_id"]
    self.assertIsNotNone(model_id)
    self.assertIsNot("", model_id)
    stored = Model.objects.get(id=model_id)
    self.assertEqual(GlobalModel, type(stored))
    self.assertEqual([None, 3], stored.input_shape)
def test_upload_swag_model(self):
    """Uploading with type=SWAG stores the model as a SWAGModel."""
    scripted = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    upload = SimpleUploadedFile(
        "model.pt",
        from_torch_module(scripted),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/", {
        "type": "SWAG",
        "model_file": upload,
        "name": "Test SWAG Model",
        "description": "Test SWAG Model Description - Test SWAG Model Description Test",
    })
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Upload Accepted", payload["detail"])
    model_id = payload["model_id"]
    self.assertIsNotNone(model_id)
    self.assertIsNot("", model_id)
    self.assertEqual(SWAGModel, type(Model.objects.get(id=model_id)))
def test_upload_mean_model(self):
    """Uploading with type=MEAN and a list of model ids creates a MeanModel."""
    member_models = [Dummy.create_model(owner=self.user) for _ in range(10)]
    response = self.client.post(f"{BASE_URL}/models/", {
        "type": "MEAN",
        "name": "Test MEAN Model",
        "description": "Test MEAN Model Description - Test MEAN Model Description Test",
        "models": [str(m.id) for m in member_models],
    }, "application/json")
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Upload Accepted", payload["detail"])
    model_id = payload["model_id"]
    self.assertIsNotNone(model_id)
    self.assertIsNot("", model_id)
    self.assertEqual(MeanModel, type(Model.objects.get(id=model_id)))
def test_upload_with_preprocessing(self):
    """Uploading model + preprocessing file stores both and both load as torch modules."""
    scripted = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    scripted_preprocessing = torch.jit.script(torch.nn.Sequential(
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ))
    model_upload = SimpleUploadedFile(
        "model.pt",
        from_torch_module(scripted),  # torchscript model
        content_type="application/octet-stream"
    )
    preprocessing_upload = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(scripted_preprocessing),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/", {
        "model_file": model_upload,
        "model_preprocessing_file": preprocessing_upload,
        "name": "Test Model",
        "description": "Test Model Description - Test Model Description Test",
        "input_shape": [None, 3]
    })
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Upload Accepted", payload["detail"])
    model_id = payload["model_id"]
    self.assertIsNotNone(model_id)
    self.assertIsNot("", model_id)
    self.assertEqual(GlobalModel, type(Model.objects.get(id=model_id)))
    self.assertEqual([None, 3], Model.objects.get(id=model_id).input_shape)
    # both stored artifacts must round-trip back to torch modules
    stored = get_entity(GlobalModel, pk=model_id)
    self.assertIsNotNone(stored)
    self.assertIsNotNone(stored.weights)
    self.assertIsNotNone(stored.preprocessing)
    self.assertTrue(isinstance(stored.get_torch_model(), torch.nn.Module))
    self.assertTrue(isinstance(stored.get_preprocessing_torch_model(), torch.nn.Module))
def test_upload_model_preprocessing(self):
    """A preprocessing model can be attached to an existing model via its own endpoint."""
    model = Dummy.create_model(owner=self.user, preprocessing=None)
    scripted_preprocessing = torch.jit.script(torch.nn.Sequential(
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ))
    preprocessing_upload = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(scripted_preprocessing),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/{model.id}/preprocessing/", {
        "model_preprocessing_file": preprocessing_upload,
    })
    self.assertEqual(202, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    # NOTE(review): "Proprocessing" is the server's literal response text
    self.assertEqual("Proprocessing Model Upload Accepted", payload["detail"])
    model.refresh_from_db()
    self.assertIsNotNone(model)
    self.assertIsNotNone(model.preprocessing)
    self.assertTrue(isinstance(model.get_preprocessing_torch_model(), torch.nn.Module))
def test_upload_model_preprocessing_v1_Compose_bad(self):
    """A v1 transforms.Compose is rejected: it is not a torch.nn.Module."""
    model = Dummy.create_model(owner=self.user, preprocessing=None)
    # torchvision.transforms.Compose (v1 not v2) does not inherit from torch.nn.Module!!
    v1_compose = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.,), std=(1.,)),
    ])
    preprocessing_upload = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(v1_compose),  # (normal) transforms.Compose
        content_type="application/octet-stream"
    )
    # server logs an ERROR (wrong torch type) and django logs the Bad Request
    with self.assertLogs("fl.server", level="ERROR"), \
            self.assertLogs("django.request", level="WARNING"):
        response = self.client.post(f"{BASE_URL}/models/{model.id}/preprocessing/", {
            "model_preprocessing_file": preprocessing_upload,
        })
    self.assertEqual(400, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual(
        "Invalid preprocessing file: Loaded torch object is not of expected type.",
        payload[0],
    )
def test_upload_model_preprocessing_v2_Compose_good(self):
    """A v2 transforms.Compose is accepted as preprocessing model."""
    model = Dummy.create_model(owner=self.user, preprocessing=None)
    v2_compose = transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ])
    preprocessing_upload = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(v2_compose),  # (normal) transforms.Compose
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/{model.id}/preprocessing/", {
        "model_preprocessing_file": preprocessing_upload,
    })
    self.assertEqual(202, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    # NOTE(review): "Proprocessing" is the server's literal response text
    self.assertEqual("Proprocessing Model Upload Accepted", payload["detail"])
    model.refresh_from_db()
    self.assertIsNotNone(model)
    self.assertIsNotNone(model.preprocessing)
    self.assertTrue(isinstance(model.get_preprocessing_torch_model(), torch.nn.Module))
def test_download_model_preprocessing(self):
    """Stored preprocessing bytes can be downloaded and loaded as a torch module."""
    preprocessing_bytes = from_torch_module(torch.jit.script(torch.nn.Sequential(
        transforms.Normalize(mean=(0.,), std=(1.,)),
    )))
    model = Dummy.create_model(owner=self.user, preprocessing=preprocessing_bytes)
    response = self.client.get(f"{BASE_URL}/models/{model.id}/preprocessing/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/octet-stream", response["content-type"])
    restored = torch.jit.load(io.BytesIO(response.content))
    self.assertIsNotNone(restored)
    self.assertTrue(isinstance(restored, torch.nn.Module))
def test_download_model_preprocessing_with_undefined_preprocessing(self):
    """Downloading preprocessing from a model that has none responds with HTTP 404."""
    model = Dummy.create_model(owner=self.user, preprocessing=None)
    with self.assertLogs("django.request", level="WARNING") as cm:
        response = self.client.get(f"{BASE_URL}/models/{model.id}/preprocessing/")
    self.assertEqual(cm.output, [
        f"WARNING:django.request:Not Found: /api/models/{model.id}/preprocessing/",
    ])
    self.assertEqual(404, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual(f"Model '{model.id}' has no preprocessing model defined.", payload["detail"])
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_update(self, apply_async: MagicMock):
    """A model update from one of several participants is accepted without aggregation."""
    model = Dummy.create_model(owner=self.user, round=0)
    # a second participant exists, so this single update must NOT trigger aggregation
    Dummy.create_training(
        model=model, actor=self.user, state=TrainingState.ONGOING,
        participants=[self.user, Dummy.create_user()],
    )
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid(),
        ))),
        content_type="application/octet-stream",
    )
    response = self.client.post(
        f"{BASE_URL}/models/{model.id}/",
        {"model_file": update_file, "round": 0, "sample_size": 100},
    )
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Update Accepted", payload["detail"])
    self.assertFalse(apply_async.called)
def test_upload_update_bad_keys(self):
    """An update posted under the wrong form field name is rejected with HTTP 400."""
    model = Dummy.create_model(owner=self.user, round=0)
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid(),
        ))),
        content_type="application/octet-stream",
    )
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/",
            {"xXx_model_file_xXx": update_file, "round": 0, "sample_size": 100},
        )
    self.assertEqual(400, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("No uploaded file 'model_file' found.", payload["detail"])
def test_upload_update_no_training(self):
    """An update for a model without a running training is rejected with HTTP 404."""
    model = Dummy.create_model(owner=self.user, round=0)
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid(),
        ))),
        content_type="application/octet-stream",
    )
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/",
            {"model_file": update_file, "round": 0, "sample_size": 100},
        )
    self.assertEqual(404, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual(f"Model with ID {model.id} does not have a training process running", payload["detail"])
def test_upload_update_no_participant(self):
    """A user who is not a training participant may not post updates (HTTP 403)."""
    self.client.defaults["HTTP_ACCEPT"] = "application/json"
    actor = Dummy.create_actor()
    model = Dummy.create_model(owner=actor, round=0)
    # self.user is deliberately NOT in the participant list
    training = Dummy.create_training(
        model=model, actor=actor, state=TrainingState.ONGOING,
        participants=[actor, Dummy.create_client()]
    )
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid(),
        ))),
        content_type="application/octet-stream",
    )
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/",
            {"model_file": update_file, "round": 0, "sample_size": 500},
        )
    self.assertEqual(403, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual(f"You are not a participant of training {training.id}!", payload["detail"])
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_update_and_aggregate(self, apply_async: MagicMock):
    """When the sole participant uploads its update, aggregation is scheduled."""
    model = Dummy.create_model(owner=self.user, round=0)
    train = Dummy.create_training(
        model=model, actor=self.user, state=TrainingState.ONGOING,
        participants=[self.user],
    )
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid(),
        ))),
        content_type="application/octet-stream",
    )
    response = self.client.post(
        f"{BASE_URL}/models/{model.id}/",
        {"model_file": update_file, "round": 0, "sample_size": 100},
    )
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Update Accepted", payload["detail"])
    # all participants reported -> training-round-finished task is dispatched
    self.assertTrue(apply_async.called)
    apply_async.assert_called_once_with(
        (),
        {"training_uuid": train.id, "event_cls": TrainingRoundFinished},
        retry=False
    )
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_update_and_not_aggregate_since_training_is_locked(self, apply_async: MagicMock):
    """While the training is locked, an accepted update must not trigger aggregation."""
    model = Dummy.create_model(owner=self.user, round=0)
    training = Dummy.create_training(
        model=model, actor=self.user, state=TrainingState.ONGOING, participants=[self.user]
    )
    # lock the training so the aggregation task must not be dispatched
    training.locked = True
    training.save()
    upload = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    response = self.client.post(
        f"{BASE_URL}/models/{model.id}/",
        {"model_file": upload, "round": 0, "sample_size": 100}
    )
    self.assertEqual(201, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("Model Update Accepted", body["detail"])
    self.assertFalse(apply_async.called)
def test_upload_update_with_metrics(self):
    """A model update accompanied by well-formed metrics should be accepted with HTTP 201."""
    model = Dummy.create_model(owner=self.user, round=0)
    Dummy.create_training(
        model=model,
        actor=self.user,
        state=TrainingState.ONGOING,
        participants=[self.user, Dummy.create_user()],
    )
    upload = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    # metric values may be numeric or binary blobs; names and values are parallel lists
    payload = {
        "model_file": upload,
        "round": 0,
        "metric_names": ["loss", "accuracy", "dummy_binary"],
        "metric_values": [1999.0, 0.12, b"Hello World!"],
        "sample_size": 50,
    }
    response = self.client.post(f"{BASE_URL}/models/{model.id}/", payload)
    self.assertEqual(201, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("Model Update Accepted", body["detail"])
def test_upload_update_with_metrics_bad(self):
    """A malformed ``metric_names`` value (a scalar instead of a list) must be rejected with HTTP 400.

    The model is created at round 0 — consistent with the sibling upload tests — so the
    posted ``"round": 0`` matches the model and the 400 can only stem from the bad metrics,
    not from an incidental round mismatch.
    """
    model = Dummy.create_model(owner=self.user, round=0)
    Dummy.create_training(
        model=model,
        actor=self.user,
        state=TrainingState.ONGOING,
        participants=[self.user, Dummy.create_user()],
    )
    upload = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/",
            {"model_file": upload, "round": 0, "metric_names": 5, "sample_size": 500}
        )
    self.assertEqual(400, response.status_code)
def test_upload_global_model_metrics(self):
    """Metrics for a global model without any training are accepted, but a warning is logged."""
    model = Dummy.create_model(owner=self.user, round=0)
    payload = {
        "metric_names": ["loss", "accuracy", "dummy_binary"],
        "metric_values": [1999.0, 0.12, b"Hello World!"],
    }
    with self.assertLogs("fl.server", level="WARNING") as cm:
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/metrics/",
            payload,
        )
    # exactly one warning about the missing training connection is expected
    self.assertEqual(cm.output, [
        f"WARNING:fl.server:Global model {model.id} is not connected to any training.",
    ])
    self.assertEqual(201, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("Model Metrics Upload Accepted", body["detail"])
    self.assertEqual(str(model.id), body["model_id"])
def test_upload_local_model_metrics(self):
    """Metrics for a local model update are accepted without any warning."""
    local_model = Dummy.create_model_update(owner=self.user)
    payload = {
        "metric_names": ["loss", "accuracy", "dummy_binary"],
        "metric_values": [1999.0, 0.12, b"Hello World!"],
    }
    response = self.client.post(
        f"{BASE_URL}/models/{local_model.id}/metrics/",
        payload,
    )
    self.assertEqual(201, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("Model Metrics Upload Accepted", body["detail"])
    self.assertEqual(str(local_model.id), body["model_id"])
def test_upload_bad_metrics(self):
    """Mismatched metric name/value list lengths must be rejected with HTTP 400."""
    model = Dummy.create_model(owner=self.user, round=0)
    # three names but only two values -> invalid payload
    payload = {
        "metric_names": ["loss", "accuracy", "dummy_binary"],
        "metric_values": [1999.0, b"Hello World!"],
    }
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/metrics/",
            payload,
        )
    self.assertEqual(400, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("Metric names and values must have the same length", body["detail"])
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_swag_stats(self, apply_async: MagicMock):
    """Uploading SWAG first/second moment files should be accepted and schedule the SWAG task."""
    model = Dummy.create_model(owner=self.user, round=0)
    training = Dummy.create_training(
        model=model,
        actor=self.user,
        state=TrainingState.SWAG_ROUND,
        participants=[self.user]
    )

    def make_moment_file(filename: str) -> SimpleUploadedFile:
        # pickled state dict of a small throw-away network
        state = torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ).state_dict()
        return SimpleUploadedFile(filename, pickle.dumps(state), content_type="application/octet-stream")

    response = self.client.post(f"{BASE_URL}/models/{model.id}/swag/", {
        "first_moment_file": make_moment_file("first_moment.pkl"),
        "second_moment_file": make_moment_file("second_moment.pkl"),
        "sample_size": 100,
        "round": 0
    })
    self.assertEqual(201, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("SWAG Statistic Accepted", body["detail"])
    # the SWAG-round task must have been scheduled exactly once
    self.assertTrue(apply_async.called)
    apply_async.assert_called_once_with(
        (),
        {"training_uuid": training.id, "event_cls": SWAGRoundFinished},
        retry=False
    )
def test_get_global_model_metrics(self):
    """Stored metrics of a global model are returned by the metrics endpoint."""
    model = Dummy.create_model(owner=self.user)
    metric = Dummy.create_metric(model=model)
    response = self.client.get(f"{BASE_URL}/models/{model.id}/metrics/")
    self.assertEqual(200, response.status_code)
    entries = response.json()
    self.assertEqual(1, len(entries))
    self.assertEqual(metric.value_float, entries[0]["value_float"])
    self.assertEqual(metric.key, entries[0]["key"])
def test_get_local_model_metrics(self):
    """Stored metrics of a local model update are returned by the metrics endpoint."""
    local_model = Dummy.create_model_update(owner=self.user)
    metric = Dummy.create_metric(model=local_model)
    response = self.client.get(f"{BASE_URL}/models/{local_model.id}/metrics/")
    self.assertEqual(200, response.status_code)
    entries = response.json()
    self.assertEqual(1, len(entries))
    self.assertEqual(metric.value_float, entries[0]["value_float"])
    self.assertEqual(metric.key, entries[0]["key"])
Ancestors (in MRO)¶
- django.test.testcases.TransactionTestCase
- django.test.testcases.SimpleTestCase
- unittest.case.TestCase
Class variables¶
Static methods¶
addClassCleanup¶
Same as addCleanup, except the cleanup items are called even if
setUpClass fails (unlike tearDownClass).
View Source
doClassCleanups¶
Execute all class cleanup functions. Normally called for you after
tearDownClass.
View Source
@classmethod
def doClassCleanups(cls):
"""Execute all class cleanup functions. Normally called for you after
tearDownClass."""
cls.tearDown_exceptions = []
while cls._class_cleanups:
function, args, kwargs = cls._class_cleanups.pop()
try:
function(*args, **kwargs)
except Exception:
cls.tearDown_exceptions.append(sys.exc_info())
setUpClass¶
Hook method for setting up class fixture before running tests in the class.
View Source
@classmethod
def setUpClass(cls):
super().setUpClass()
if cls._overridden_settings:
cls._cls_overridden_context = override_settings(**cls._overridden_settings)
cls._cls_overridden_context.enable()
cls.addClassCleanup(cls._cls_overridden_context.disable)
if cls._modified_settings:
cls._cls_modified_context = modify_settings(cls._modified_settings)
cls._cls_modified_context.enable()
cls.addClassCleanup(cls._cls_modified_context.disable)
cls._add_databases_failures()
cls.addClassCleanup(cls._remove_databases_failures)
tearDownClass¶
Hook method for deconstructing the class fixture after running all tests in the class.
View Source
Methods¶
addCleanup¶
Add a function, with arguments, to be called when the test is
completed. Functions added are called on a LIFO basis and are called after tearDown on test failure or success.
Cleanup items are called even if setUp fails (unlike tearDown).
View Source
def addCleanup(self, function, /, *args, **kwargs):
"""Add a function, with arguments, to be called when the test is
completed. Functions added are called on a LIFO basis and are
called after tearDown on test failure or success.
Cleanup items are called even if setUp fails (unlike tearDown)."""
self._cleanups.append((function, args, kwargs))
addTypeEqualityFunc¶
Add a type specific assertEqual style function to compare a type.
This method is for use by TestCase subclasses that need to register their own type equality functions to provide nicer error messages.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
typeobj | None | The data type to call this function on when both values are of the same type in assertEqual(). |
None |
function | None | The callable taking two arguments and an optional msg= argument that raises self.failureException with a useful error message when the two arguments are not equal. |
None |
View Source
def addTypeEqualityFunc(self, typeobj, function):
"""Add a type specific assertEqual style function to compare a type.
This method is for use by TestCase subclasses that need to register
their own type equality functions to provide nicer error messages.
Args:
typeobj: The data type to call this function on when both values
are of the same type in assertEqual().
function: The callable taking two arguments and an optional
msg= argument that raises self.failureException with a
useful error message when the two arguments are not equal.
"""
self._type_equality_funcs[typeobj] = function
assertAlmostEqual¶
Fail if the two objects are unequal as determined by their
difference rounded to the given number of decimal places (default 7) and comparing to zero, or by comparing that the difference between the two objects is more than the given delta.
Note that decimal places (from zero) are usually not the same as significant digits (measured from the most significant digit).
If the two objects compare equal then they will automatically compare almost equal.
View Source
def assertAlmostEqual(self, first, second, places=None, msg=None,
delta=None):
"""Fail if the two objects are unequal as determined by their
difference rounded to the given number of decimal places
(default 7) and comparing to zero, or by comparing that the
difference between the two objects is more than the given
delta.
Note that decimal places (from zero) are usually not the same
as significant digits (measured from the most significant digit).
If the two objects compare equal then they will automatically
compare almost equal.
"""
if first == second:
# shortcut
return
if delta is not None and places is not None:
raise TypeError("specify delta or places not both")
diff = abs(first - second)
if delta is not None:
if diff <= delta:
return
standardMsg = '%s != %s within %s delta (%s difference)' % (
safe_repr(first),
safe_repr(second),
safe_repr(delta),
safe_repr(diff))
else:
if places is None:
places = 7
if round(diff, places) == 0:
return
standardMsg = '%s != %s within %r places (%s difference)' % (
safe_repr(first),
safe_repr(second),
places,
safe_repr(diff))
msg = self._formatMessage(msg, standardMsg)
raise self.failureException(msg)
assertAlmostEquals¶
View Source
assertContains¶
Assert that a response indicates that some content was retrieved
successfully (i.e., the HTTP status code was as expected) and that
`text` occurs `count` times in the content of the response.
If `count` is None, the count doesn't matter - the assertion is true
if the text occurs at least once in the response.
View Source
def assertContains(
self, response, text, count=None, status_code=200, msg_prefix="", html=False
):
"""
Assert that a response indicates that some content was retrieved
successfully, (i.e., the HTTP status code was as expected) and that
``text`` occurs ``count`` times in the content of the response.
If ``count`` is None, the count doesn't matter - the assertion is true
if the text occurs at least once in the response.
"""
text_repr, real_count, msg_prefix = self._assert_contains(
response, text, status_code, msg_prefix, html
)
if count is not None:
self.assertEqual(
real_count,
count,
msg_prefix
+ "Found %d instances of %s in response (expected %d)"
% (real_count, text_repr, count),
)
else:
self.assertTrue(
real_count != 0, msg_prefix + "Couldn't find %s in response" % text_repr
)
assertCountEqual¶
Asserts that two iterables have the same elements, the same number of
times, without regard to order.
self.assertEqual(Counter(list(first)),
Counter(list(second)))
Example: - [0, 1, 1] and [1, 0, 1] compare equal. - [0, 0, 1] and [0, 1] compare unequal.
View Source
def assertCountEqual(self, first, second, msg=None):
"""Asserts that two iterables have the same elements, the same number of
times, without regard to order.
self.assertEqual(Counter(list(first)),
Counter(list(second)))
Example:
- [0, 1, 1] and [1, 0, 1] compare equal.
- [0, 0, 1] and [0, 1] compare unequal.
"""
first_seq, second_seq = list(first), list(second)
try:
first = collections.Counter(first_seq)
second = collections.Counter(second_seq)
except TypeError:
# Handle case with unhashable elements
differences = _count_diff_all_purpose(first_seq, second_seq)
else:
if first == second:
return
differences = _count_diff_hashable(first_seq, second_seq)
if differences:
standardMsg = 'Element counts were not equal:\n'
lines = ['First has %d, Second has %d: %r' % diff for diff in differences]
diffMsg = '\n'.join(lines)
standardMsg = self._truncateMessage(standardMsg, diffMsg)
msg = self._formatMessage(msg, standardMsg)
self.fail(msg)
assertDictContainsSubset¶
Checks whether dictionary is a superset of subset.
View Source
def assertDictContainsSubset(self, subset, dictionary, msg=None):
"""Checks whether dictionary is a superset of subset."""
warnings.warn('assertDictContainsSubset is deprecated',
DeprecationWarning,
stacklevel=2)
missing = []
mismatched = []
for key, value in subset.items():
if key not in dictionary:
missing.append(key)
elif value != dictionary[key]:
mismatched.append('%s, expected: %s, actual: %s' %
(safe_repr(key), safe_repr(value),
safe_repr(dictionary[key])))
if not (missing or mismatched):
return
standardMsg = ''
if missing:
standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
missing)
if mismatched:
if standardMsg:
standardMsg += '; '
standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
self.fail(self._formatMessage(msg, standardMsg))
assertDictEqual¶
View Source
def assertDictEqual(self, d1, d2, msg=None):
self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
if d1 != d2:
standardMsg = '%s != %s' % _common_shorten_repr(d1, d2)
diff = ('\n' + '\n'.join(difflib.ndiff(
pprint.pformat(d1).splitlines(),
pprint.pformat(d2).splitlines())))
standardMsg = self._truncateMessage(standardMsg, diff)
self.fail(self._formatMessage(msg, standardMsg))
assertEqual¶
Fail if the two objects are unequal as determined by the '=='
operator.
View Source
assertEquals¶
View Source
assertFalse¶
Check that the expression is false.
View Source
assertFieldOutput¶
def assertFieldOutput(
self,
fieldclass,
valid,
invalid,
field_args=None,
field_kwargs=None,
empty_value=''
)
Assert that a form field behaves correctly with various inputs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fieldclass | None | the class of the field to be tested. | None |
valid | None | a dictionary mapping valid inputs to their expected cleaned values. |
None |
invalid | None | a dictionary mapping invalid inputs to one or more raised error messages. |
None |
field_args | None | the args passed to instantiate the field | None |
field_kwargs | None | the kwargs passed to instantiate the field | None |
empty_value | None | the expected clean output for inputs in empty_values | None |
View Source
def assertFieldOutput(
self,
fieldclass,
valid,
invalid,
field_args=None,
field_kwargs=None,
empty_value="",
):
"""
Assert that a form field behaves correctly with various inputs.
Args:
fieldclass: the class of the field to be tested.
valid: a dictionary mapping valid inputs to their expected
cleaned values.
invalid: a dictionary mapping invalid inputs to one or more
raised error messages.
field_args: the args passed to instantiate the field
field_kwargs: the kwargs passed to instantiate the field
empty_value: the expected clean output for inputs in empty_values
"""
if field_args is None:
field_args = []
if field_kwargs is None:
field_kwargs = {}
required = fieldclass(*field_args, **field_kwargs)
optional = fieldclass(*field_args, **{**field_kwargs, "required": False})
# test valid inputs
for input, output in valid.items():
self.assertEqual(required.clean(input), output)
self.assertEqual(optional.clean(input), output)
# test invalid inputs
for input, errors in invalid.items():
with self.assertRaises(ValidationError) as context_manager:
required.clean(input)
self.assertEqual(context_manager.exception.messages, errors)
with self.assertRaises(ValidationError) as context_manager:
optional.clean(input)
self.assertEqual(context_manager.exception.messages, errors)
# test required inputs
error_required = [required.error_messages["required"]]
for e in required.empty_values:
with self.assertRaises(ValidationError) as context_manager:
required.clean(e)
self.assertEqual(context_manager.exception.messages, error_required)
self.assertEqual(optional.clean(e), empty_value)
# test that max_length and min_length are always accepted
if issubclass(fieldclass, CharField):
field_kwargs.update({"min_length": 2, "max_length": 20})
self.assertIsInstance(fieldclass(*field_args, **field_kwargs), fieldclass)
assertFormError¶
Assert that a form used to render the response has a specific field
error.
View Source
def assertFormError(self, response, form, field, errors, msg_prefix=""):
"""
Assert that a form used to render the response has a specific field
error.
"""
if msg_prefix:
msg_prefix += ": "
# Put context(s) into a list to simplify processing.
contexts = to_list(response.context)
if not contexts:
self.fail(
msg_prefix + "Response did not use any contexts to render the response"
)
# Put error(s) into a list to simplify processing.
errors = to_list(errors)
# Search all contexts for the error.
found_form = False
for i, context in enumerate(contexts):
if form not in context:
continue
found_form = True
for err in errors:
if field:
if field in context[form].errors:
field_errors = context[form].errors[field]
self.assertTrue(
err in field_errors,
msg_prefix + "The field '%s' on form '%s' in"
" context %d does not contain the error '%s'"
" (actual errors: %s)"
% (field, form, i, err, repr(field_errors)),
)
elif field in context[form].fields:
self.fail(
msg_prefix
+ (
"The field '%s' on form '%s' in context %d contains no "
"errors"
)
% (field, form, i)
)
else:
self.fail(
msg_prefix
+ (
"The form '%s' in context %d does not contain the "
"field '%s'"
)
% (form, i, field)
)
else:
non_field_errors = context[form].non_field_errors()
self.assertTrue(
err in non_field_errors,
msg_prefix + "The form '%s' in context %d does not"
" contain the non-field error '%s'"
" (actual errors: %s)"
% (form, i, err, non_field_errors or "none"),
)
if not found_form:
self.fail(
msg_prefix + "The form '%s' was not used to render the response" % form
)
assertFormsetError¶
Assert that a formset used to render the response has a specific error.
For field errors, specify the `form_index` and the `field`.
For non-field errors, specify the `form_index` and pass the `field` as None.
For non-form errors, specify `form_index` as None and the `field` as None.
View Source
def assertFormsetError(
self, response, formset, form_index, field, errors, msg_prefix=""
):
"""
Assert that a formset used to render the response has a specific error.
For field errors, specify the ``form_index`` and the ``field``.
For non-field errors, specify the ``form_index`` and the ``field`` as
None.
For non-form errors, specify ``form_index`` as None and the ``field``
as None.
"""
# Add punctuation to msg_prefix
if msg_prefix:
msg_prefix += ": "
# Put context(s) into a list to simplify processing.
contexts = to_list(response.context)
if not contexts:
self.fail(
msg_prefix + "Response did not use any contexts to "
"render the response"
)
# Put error(s) into a list to simplify processing.
errors = to_list(errors)
# Search all contexts for the error.
found_formset = False
for i, context in enumerate(contexts):
if formset not in context or not hasattr(context[formset], "forms"):
continue
found_formset = True
for err in errors:
if field is not None:
if field in context[formset].forms[form_index].errors:
field_errors = context[formset].forms[form_index].errors[field]
self.assertTrue(
err in field_errors,
msg_prefix + "The field '%s' on formset '%s', "
"form %d in context %d does not contain the "
"error '%s' (actual errors: %s)"
% (field, formset, form_index, i, err, repr(field_errors)),
)
elif field in context[formset].forms[form_index].fields:
self.fail(
msg_prefix
+ (
"The field '%s' on formset '%s', form %d in context "
"%d contains no errors"
)
% (field, formset, form_index, i)
)
else:
self.fail(
msg_prefix
+ (
"The formset '%s', form %d in context %d does not "
"contain the field '%s'"
)
% (formset, form_index, i, field)
)
elif form_index is not None:
non_field_errors = (
context[formset].forms[form_index].non_field_errors()
)
self.assertFalse(
not non_field_errors,
msg_prefix + "The formset '%s', form %d in context %d "
"does not contain any non-field errors."
% (formset, form_index, i),
)
self.assertTrue(
err in non_field_errors,
msg_prefix + "The formset '%s', form %d in context %d "
"does not contain the non-field error '%s' (actual errors: %s)"
% (formset, form_index, i, err, repr(non_field_errors)),
)
else:
non_form_errors = context[formset].non_form_errors()
self.assertFalse(
not non_form_errors,
msg_prefix + "The formset '%s' in context %d does not "
"contain any non-form errors." % (formset, i),
)
self.assertTrue(
err in non_form_errors,
msg_prefix + "The formset '%s' in context %d does not "
"contain the non-form error '%s' (actual errors: %s)"
% (formset, i, err, repr(non_form_errors)),
)
if not found_formset:
self.fail(
msg_prefix
+ "The formset '%s' was not used to render the response" % formset
)
assertGreater¶
Just like self.assertTrue(a > b), but with a nicer default message.
View Source
assertGreaterEqual¶
Just like self.assertTrue(a >= b), but with a nicer default message.
View Source
assertHTMLEqual¶
Assert that two HTML snippets are semantically the same.
Whitespace in most cases is ignored, and attribute ordering is not significant. The arguments must be valid HTML.
View Source
def assertHTMLEqual(self, html1, html2, msg=None):
"""
Assert that two HTML snippets are semantically the same.
Whitespace in most cases is ignored, and attribute ordering is not
significant. The arguments must be valid HTML.
"""
dom1 = assert_and_parse_html(
self, html1, msg, "First argument is not valid HTML:"
)
dom2 = assert_and_parse_html(
self, html2, msg, "Second argument is not valid HTML:"
)
if dom1 != dom2:
standardMsg = "%s != %s" % (safe_repr(dom1, True), safe_repr(dom2, True))
diff = "\n" + "\n".join(
difflib.ndiff(
str(dom1).splitlines(),
str(dom2).splitlines(),
)
)
standardMsg = self._truncateMessage(standardMsg, diff)
self.fail(self._formatMessage(msg, standardMsg))
assertHTMLNotEqual¶
Assert that two HTML snippets are not semantically equivalent.
View Source
def assertHTMLNotEqual(self, html1, html2, msg=None):
"""Assert that two HTML snippets are not semantically equivalent."""
dom1 = assert_and_parse_html(
self, html1, msg, "First argument is not valid HTML:"
)
dom2 = assert_and_parse_html(
self, html2, msg, "Second argument is not valid HTML:"
)
if dom1 == dom2:
standardMsg = "%s == %s" % (safe_repr(dom1, True), safe_repr(dom2, True))
self.fail(self._formatMessage(msg, standardMsg))
assertIn¶
Just like self.assertTrue(a in b), but with a nicer default message.
View Source
assertInHTML¶
View Source
def assertInHTML(self, needle, haystack, count=None, msg_prefix=""):
needle = assert_and_parse_html(
self, needle, None, "First argument is not valid HTML:"
)
haystack = assert_and_parse_html(
self, haystack, None, "Second argument is not valid HTML:"
)
real_count = haystack.count(needle)
if count is not None:
self.assertEqual(
real_count,
count,
msg_prefix
+ "Found %d instances of '%s' in response (expected %d)"
% (real_count, needle, count),
)
else:
self.assertTrue(
real_count != 0, msg_prefix + "Couldn't find '%s' in response" % needle
)
assertIs¶
Just like self.assertTrue(a is b), but with a nicer default message.
View Source
assertIsInstance¶
Same as self.assertTrue(isinstance(obj, cls)), with a nicer
default message.
View Source
assertIsNone¶
Same as self.assertTrue(obj is None), with a nicer default message.
View Source
assertIsNot¶
Just like self.assertTrue(a is not b), but with a nicer default message.
View Source
assertIsNotNone¶
Included for symmetry with assertIsNone.
View Source
assertJSONEqual¶
Assert that the JSON fragments raw and expected_data are equal.
Usual JSON non-significant whitespace rules apply as the heavyweight is delegated to the json library.
View Source
def assertJSONEqual(self, raw, expected_data, msg=None):
"""
Assert that the JSON fragments raw and expected_data are equal.
Usual JSON non-significant whitespace rules apply as the heavyweight
is delegated to the json library.
"""
try:
data = json.loads(raw)
except json.JSONDecodeError:
self.fail("First argument is not valid JSON: %r" % raw)
if isinstance(expected_data, str):
try:
expected_data = json.loads(expected_data)
except ValueError:
self.fail("Second argument is not valid JSON: %r" % expected_data)
self.assertEqual(data, expected_data, msg=msg)
assertJSONNotEqual¶
Assert that the JSON fragments raw and expected_data are not equal.
Usual JSON non-significant whitespace rules apply as the heavyweight is delegated to the json library.
View Source
def assertJSONNotEqual(self, raw, expected_data, msg=None):
"""
Assert that the JSON fragments raw and expected_data are not equal.
Usual JSON non-significant whitespace rules apply as the heavyweight
is delegated to the json library.
"""
try:
data = json.loads(raw)
except json.JSONDecodeError:
self.fail("First argument is not valid JSON: %r" % raw)
if isinstance(expected_data, str):
try:
expected_data = json.loads(expected_data)
except json.JSONDecodeError:
self.fail("Second argument is not valid JSON: %r" % expected_data)
self.assertNotEqual(data, expected_data, msg=msg)
assertLess¶
Just like self.assertTrue(a < b), but with a nicer default message.
View Source
assertLessEqual¶
Just like self.assertTrue(a <= b), but with a nicer default message.
View Source
assertListEqual¶
A list-specific equality assertion.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
list1 | None | The first list to compare. | None |
list2 | None | The second list to compare. | None |
msg | None | Optional message to use on failure instead of a list of differences. |
None |
View Source
def assertListEqual(self, list1, list2, msg=None):
"""A list-specific equality assertion.
Args:
list1: The first list to compare.
list2: The second list to compare.
msg: Optional message to use on failure instead of a list of
differences.
"""
self.assertSequenceEqual(list1, list2, msg, seq_type=list)
assertLogs¶
Fail unless a log message of level level or higher is emitted
on logger_name or its children. If omitted, level defaults to INFO and logger defaults to the root logger.
This method must be used as a context manager, and will yield
a recording object with two attributes: `output` and `records`.
At the end of the context manager, the `output` attribute will
be a list of the matching formatted log messages and the
`records` attribute will be a list of the corresponding LogRecord
objects.
Example::
with self.assertLogs('foo', level='INFO') as cm:
logging.getLogger('foo').info('first message')
logging.getLogger('foo.bar').error('second message')
self.assertEqual(cm.output, ['INFO:foo:first message',
'ERROR:foo.bar:second message'])
View Source
def assertLogs(self, logger=None, level=None):
"""Fail unless a log message of level *level* or higher is emitted
on *logger_name* or its children. If omitted, *level* defaults to
INFO and *logger* defaults to the root logger.
This method must be used as a context manager, and will yield
a recording object with two attributes: `output` and `records`.
At the end of the context manager, the `output` attribute will
be a list of the matching formatted log messages and the
`records` attribute will be a list of the corresponding LogRecord
objects.
Example::
with self.assertLogs('foo', level='INFO') as cm:
logging.getLogger('foo').info('first message')
logging.getLogger('foo.bar').error('second message')
self.assertEqual(cm.output, ['INFO:foo:first message',
'ERROR:foo.bar:second message'])
"""
# Lazy import to avoid importing logging if it is not needed.
from ._log import _AssertLogsContext
return _AssertLogsContext(self, logger, level, no_logs=False)
assertMultiLineEqual¶
Assert that two multi-line strings are equal.
View Source
def assertMultiLineEqual(self, first, second, msg=None):
"""Assert that two multi-line strings are equal."""
self.assertIsInstance(first, str, 'First argument is not a string')
self.assertIsInstance(second, str, 'Second argument is not a string')
if first != second:
# don't use difflib if the strings are too long
if (len(first) > self._diffThreshold or
len(second) > self._diffThreshold):
self._baseAssertEqual(first, second, msg)
firstlines = first.splitlines(keepends=True)
secondlines = second.splitlines(keepends=True)
if len(firstlines) == 1 and first.strip('\r\n') == first:
firstlines = [first + '\n']
secondlines = [second + '\n']
standardMsg = '%s != %s' % _common_shorten_repr(first, second)
diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
standardMsg = self._truncateMessage(standardMsg, diff)
self.fail(self._formatMessage(msg, standardMsg))
assertNoLogs¶
Fail unless no log messages of level level or higher are emitted
on logger_name or its children.
This method must be used as a context manager.
View Source
def assertNoLogs(self, logger=None, level=None):
""" Fail unless no log messages of level *level* or higher are emitted
on *logger_name* or its children.
This method must be used as a context manager.
"""
from ._log import _AssertLogsContext
return _AssertLogsContext(self, logger, level, no_logs=True)
assertNotAlmostEqual¶
Fail if the two objects are equal as determined by their
difference rounded to the given number of decimal places (default 7) and comparing to zero, or by comparing that the difference between the two objects is less than the given delta.
Note that decimal places (from zero) are usually not the same as significant digits (measured from the most significant digit).
Objects that are equal automatically fail.
View Source
def assertNotAlmostEqual(self, first, second, places=None, msg=None,
delta=None):
"""Fail if the two objects are equal as determined by their
difference rounded to the given number of decimal places
(default 7) and comparing to zero, or by comparing that the
difference between the two objects is less than the given delta.
Note that decimal places (from zero) are usually not the same
as significant digits (measured from the most significant digit).
Objects that are equal automatically fail.
"""
if delta is not None and places is not None:
raise TypeError("specify delta or places not both")
diff = abs(first - second)
if delta is not None:
if not (first == second) and diff > delta:
return
standardMsg = '%s == %s within %s delta (%s difference)' % (
safe_repr(first),
safe_repr(second),
safe_repr(delta),
safe_repr(diff))
else:
if places is None:
places = 7
if not (first == second) and round(diff, places) != 0:
return
standardMsg = '%s == %s within %r places' % (safe_repr(first),
safe_repr(second),
places)
msg = self._formatMessage(msg, standardMsg)
raise self.failureException(msg)
assertNotAlmostEquals¶
View Source
assertNotContains¶
Assert that a response indicates that some content was retrieved
successfully (i.e., the HTTP status code was as expected) and that
`text` doesn't occur in the content of the response.
View Source
def assertNotContains(
self, response, text, status_code=200, msg_prefix="", html=False
):
"""
Assert that a response indicates that some content was retrieved
successfully, (i.e., the HTTP status code was as expected) and that
``text`` doesn't occur in the content of the response.
"""
text_repr, real_count, msg_prefix = self._assert_contains(
response, text, status_code, msg_prefix, html
)
self.assertEqual(
real_count, 0, msg_prefix + "Response should not contain %s" % text_repr
)
assertNotEqual¶
Fail if the two objects are equal as determined by the '!='
operator.
View Source
assertNotEquals¶
View Source
assertNotIn¶
Just like self.assertTrue(a not in b), but with a nicer default message.
View Source
assertNotIsInstance¶
Included for symmetry with assertIsInstance.
View Source
assertNotRegex¶
Fail the test if the text matches the regular expression.
View Source
def assertNotRegex(self, text, unexpected_regex, msg=None):
"""Fail the test if the text matches the regular expression."""
if isinstance(unexpected_regex, (str, bytes)):
unexpected_regex = re.compile(unexpected_regex)
match = unexpected_regex.search(text)
if match:
standardMsg = 'Regex matched: %r matches %r in %r' % (
text[match.start() : match.end()],
unexpected_regex.pattern,
text)
# _formatMessage ensures the longMessage option is respected
msg = self._formatMessage(msg, standardMsg)
raise self.failureException(msg)
assertNotRegexpMatches¶
View Source
assertNumQueries¶
View Source
assertQuerysetEqual¶
View Source
def assertQuerysetEqual(self, qs, values, transform=None, ordered=True, msg=None):
values = list(values)
# RemovedInDjango41Warning.
if transform is None:
if (
values
and isinstance(values[0], str)
and qs
and not isinstance(qs[0], str)
):
# Transform qs using repr() if the first element of values is a
# string and the first element of qs is not (which would be the
# case if qs is a flattened values_list).
warnings.warn(
"In Django 4.1, repr() will not be called automatically "
"on a queryset when compared to string values. Set an "
"explicit 'transform' to silence this warning.",
category=RemovedInDjango41Warning,
stacklevel=2,
)
transform = repr
items = qs
if transform is not None:
items = map(transform, items)
if not ordered:
return self.assertDictEqual(Counter(items), Counter(values), msg=msg)
# For example qs.iterator() could be passed as qs, but it does not
# have 'ordered' attribute.
if len(values) > 1 and hasattr(qs, "ordered") and not qs.ordered:
raise ValueError(
"Trying to compare non-ordered queryset against more than one "
"ordered value."
)
return self.assertEqual(list(items), values, msg=msg)
assertRaises¶
Fail unless an exception of class expected_exception is raised
by the callable when invoked with specified positional and keyword arguments. If a different type of exception is raised, it will not be caught, and the test case will be deemed to have suffered an error, exactly as for an unexpected exception.
If called with the callable and arguments omitted, will return a context object used like this::
with self.assertRaises(SomeException):
do_something()
An optional keyword argument 'msg' can be provided when assertRaises is used as a context object.
The context manager keeps a reference to the exception as the 'exception' attribute. This allows you to inspect the exception after the assertion::
with self.assertRaises(SomeException) as cm:
do_something()
the_exception = cm.exception
self.assertEqual(the_exception.error_code, 3)
View Source
def assertRaises(self, expected_exception, *args, **kwargs):
"""Fail unless an exception of class expected_exception is raised
by the callable when invoked with specified positional and
keyword arguments. If a different type of exception is
raised, it will not be caught, and the test case will be
deemed to have suffered an error, exactly as for an
unexpected exception.
If called with the callable and arguments omitted, will return a
context object used like this::
with self.assertRaises(SomeException):
do_something()
An optional keyword argument 'msg' can be provided when assertRaises
is used as a context object.
The context manager keeps a reference to the exception as
the 'exception' attribute. This allows you to inspect the
exception after the assertion::
with self.assertRaises(SomeException) as cm:
do_something()
the_exception = cm.exception
self.assertEqual(the_exception.error_code, 3)
"""
context = _AssertRaisesContext(expected_exception, self)
try:
return context.handle('assertRaises', args, kwargs)
finally:
# bpo-23890: manually break a reference cycle
context = None
assertRaisesMessage¶
Assert that expected_message is found in the message of a raised
exception.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
expected_exception | None | Exception class expected to be raised. | None |
expected_message | None | expected error message string value. | None |
args | None | Function to be called and extra positional args. | None |
kwargs | None | Extra kwargs. | None |
View Source
def assertRaisesMessage(
self, expected_exception, expected_message, *args, **kwargs
):
"""
Assert that expected_message is found in the message of a raised
exception.
Args:
expected_exception: Exception class expected to be raised.
expected_message: expected error message string value.
args: Function to be called and extra positional args.
kwargs: Extra kwargs.
"""
return self._assertFooMessage(
self.assertRaises,
"exception",
expected_exception,
expected_message,
*args,
**kwargs,
)
assertRaisesRegex¶
Asserts that the message in a raised exception matches a regex.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
expected_exception | None | Exception class expected to be raised. | None |
expected_regex | None | Regex (re.Pattern object or string) expected to be found in error message. | None |
args | None | Function to be called and extra positional args. | None |
kwargs | None | Extra kwargs. | None |
msg | None | Optional message used in case of failure. Can only be used when assertRaisesRegex is used as a context manager. | None |
View Source
def assertRaisesRegex(self, expected_exception, expected_regex,
*args, **kwargs):
"""Asserts that the message in a raised exception matches a regex.
Args:
expected_exception: Exception class expected to be raised.
expected_regex: Regex (re.Pattern object or string) expected
to be found in error message.
args: Function to be called and extra positional args.
kwargs: Extra kwargs.
msg: Optional message used in case of failure. Can only be used
when assertRaisesRegex is used as a context manager.
"""
context = _AssertRaisesContext(expected_exception, self, expected_regex)
return context.handle('assertRaisesRegex', args, kwargs)
assertRaisesRegexp¶
View Source
assertRedirects¶
def assertRedirects(
self,
response,
expected_url,
status_code=302,
target_status_code=200,
msg_prefix='',
fetch_redirect_response=True
)
Assert that a response redirected to a specific URL and that the
redirect URL can be loaded.
Won't work for external links since it uses the test client to do a request (use fetch_redirect_response=False to check such links without fetching them).
View Source
def assertRedirects(
self,
response,
expected_url,
status_code=302,
target_status_code=200,
msg_prefix="",
fetch_redirect_response=True,
):
"""
Assert that a response redirected to a specific URL and that the
redirect URL can be loaded.
Won't work for external links since it uses the test client to do a
request (use fetch_redirect_response=False to check such links without
fetching them).
"""
if msg_prefix:
msg_prefix += ": "
if hasattr(response, "redirect_chain"):
# The request was a followed redirect
self.assertTrue(
response.redirect_chain,
msg_prefix
+ (
"Response didn't redirect as expected: Response code was %d "
"(expected %d)"
)
% (response.status_code, status_code),
)
self.assertEqual(
response.redirect_chain[0][1],
status_code,
msg_prefix
+ (
"Initial response didn't redirect as expected: Response code was "
"%d (expected %d)"
)
% (response.redirect_chain[0][1], status_code),
)
url, status_code = response.redirect_chain[-1]
self.assertEqual(
response.status_code,
target_status_code,
msg_prefix
+ (
"Response didn't redirect as expected: Final Response code was %d "
"(expected %d)"
)
% (response.status_code, target_status_code),
)
else:
# Not a followed redirect
self.assertEqual(
response.status_code,
status_code,
msg_prefix
+ (
"Response didn't redirect as expected: Response code was %d "
"(expected %d)"
)
% (response.status_code, status_code),
)
url = response.url
scheme, netloc, path, query, fragment = urlsplit(url)
# Prepend the request path to handle relative path redirects.
if not path.startswith("/"):
url = urljoin(response.request["PATH_INFO"], url)
path = urljoin(response.request["PATH_INFO"], path)
if fetch_redirect_response:
# netloc might be empty, or in cases where Django tests the
# HTTP scheme, the convention is for netloc to be 'testserver'.
# Trust both as "internal" URLs here.
domain, port = split_domain_port(netloc)
if domain and not validate_host(domain, settings.ALLOWED_HOSTS):
raise ValueError(
"The test client is unable to fetch remote URLs (got %s). "
"If the host is served by Django, add '%s' to ALLOWED_HOSTS. "
"Otherwise, use "
"assertRedirects(..., fetch_redirect_response=False)."
% (url, domain)
)
# Get the redirection page, using the same client that was used
# to obtain the original response.
extra = response.client.extra or {}
redirect_response = response.client.get(
path,
QueryDict(query),
secure=(scheme == "https"),
**extra,
)
self.assertEqual(
redirect_response.status_code,
target_status_code,
msg_prefix
+ (
"Couldn't retrieve redirection page '%s': response code was %d "
"(expected %d)"
)
% (path, redirect_response.status_code, target_status_code),
)
self.assertURLEqual(
url,
expected_url,
msg_prefix
+ "Response redirected to '%s', expected '%s'" % (url, expected_url),
)
assertRegex¶
Fail the test unless the text matches the regular expression.
View Source
def assertRegex(self, text, expected_regex, msg=None):
"""Fail the test unless the text matches the regular expression."""
if isinstance(expected_regex, (str, bytes)):
assert expected_regex, "expected_regex must not be empty."
expected_regex = re.compile(expected_regex)
if not expected_regex.search(text):
standardMsg = "Regex didn't match: %r not found in %r" % (
expected_regex.pattern, text)
# _formatMessage ensures the longMessage option is respected
msg = self._formatMessage(msg, standardMsg)
raise self.failureException(msg)
assertRegexpMatches¶
View Source
assertSequenceEqual¶
An equality assertion for ordered sequences (like lists and tuples).
For the purposes of this function, a valid ordered sequence type is one which can be indexed, has a length, and has an equality operator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
seq1 | None | The first sequence to compare. | None |
seq2 | None | The second sequence to compare. | None |
seq_type | None | The expected datatype of the sequences, or None if no datatype should be enforced. | None |
msg | None | Optional message to use on failure instead of a list of differences. | None |
View Source
def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
"""An equality assertion for ordered sequences (like lists and tuples).
For the purposes of this function, a valid ordered sequence type is one
which can be indexed, has a length, and has an equality operator.
Args:
seq1: The first sequence to compare.
seq2: The second sequence to compare.
seq_type: The expected datatype of the sequences, or None if no
datatype should be enforced.
msg: Optional message to use on failure instead of a list of
differences.
"""
if seq_type is not None:
seq_type_name = seq_type.__name__
if not isinstance(seq1, seq_type):
raise self.failureException('First sequence is not a %s: %s'
% (seq_type_name, safe_repr(seq1)))
if not isinstance(seq2, seq_type):
raise self.failureException('Second sequence is not a %s: %s'
% (seq_type_name, safe_repr(seq2)))
else:
seq_type_name = "sequence"
differing = None
try:
len1 = len(seq1)
except (TypeError, NotImplementedError):
differing = 'First %s has no length. Non-sequence?' % (
seq_type_name)
if differing is None:
try:
len2 = len(seq2)
except (TypeError, NotImplementedError):
differing = 'Second %s has no length. Non-sequence?' % (
seq_type_name)
if differing is None:
if seq1 == seq2:
return
differing = '%ss differ: %s != %s\n' % (
(seq_type_name.capitalize(),) +
_common_shorten_repr(seq1, seq2))
for i in range(min(len1, len2)):
try:
item1 = seq1[i]
except (TypeError, IndexError, NotImplementedError):
differing += ('\nUnable to index element %d of first %s\n' %
(i, seq_type_name))
break
try:
item2 = seq2[i]
except (TypeError, IndexError, NotImplementedError):
differing += ('\nUnable to index element %d of second %s\n' %
(i, seq_type_name))
break
if item1 != item2:
differing += ('\nFirst differing element %d:\n%s\n%s\n' %
((i,) + _common_shorten_repr(item1, item2)))
break
else:
if (len1 == len2 and seq_type is None and
type(seq1) != type(seq2)):
# The sequences are the same, but have differing types.
return
if len1 > len2:
differing += ('\nFirst %s contains %d additional '
'elements.\n' % (seq_type_name, len1 - len2))
try:
differing += ('First extra element %d:\n%s\n' %
(len2, safe_repr(seq1[len2])))
except (TypeError, IndexError, NotImplementedError):
differing += ('Unable to index element %d '
'of first %s\n' % (len2, seq_type_name))
elif len1 < len2:
differing += ('\nSecond %s contains %d additional '
'elements.\n' % (seq_type_name, len2 - len1))
try:
differing += ('First extra element %d:\n%s\n' %
(len1, safe_repr(seq2[len1])))
except (TypeError, IndexError, NotImplementedError):
differing += ('Unable to index element %d '
'of second %s\n' % (len1, seq_type_name))
standardMsg = differing
diffMsg = '\n' + '\n'.join(
difflib.ndiff(pprint.pformat(seq1).splitlines(),
pprint.pformat(seq2).splitlines()))
standardMsg = self._truncateMessage(standardMsg, diffMsg)
msg = self._formatMessage(msg, standardMsg)
self.fail(msg)
assertSetEqual¶
A set-specific equality assertion.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
set1 | None | The first set to compare. | None |
set2 | None | The second set to compare. | None |
msg | None | Optional message to use on failure instead of a list of differences. | None |
View Source
def assertSetEqual(self, set1, set2, msg=None):
"""A set-specific equality assertion.
Args:
set1: The first set to compare.
set2: The second set to compare.
msg: Optional message to use on failure instead of a list of
differences.
assertSetEqual uses ducktyping to support different types of sets, and
is optimized for sets specifically (parameters must support a
difference method).
"""
try:
difference1 = set1.difference(set2)
except TypeError as e:
self.fail('invalid type when attempting set difference: %s' % e)
except AttributeError as e:
self.fail('first argument does not support set difference: %s' % e)
try:
difference2 = set2.difference(set1)
except TypeError as e:
self.fail('invalid type when attempting set difference: %s' % e)
except AttributeError as e:
self.fail('second argument does not support set difference: %s' % e)
if not (difference1 or difference2):
return
lines = []
if difference1:
lines.append('Items in the first set but not the second:')
for item in difference1:
lines.append(repr(item))
if difference2:
lines.append('Items in the second set but not the first:')
for item in difference2:
lines.append(repr(item))
standardMsg = '\n'.join(lines)
self.fail(self._formatMessage(msg, standardMsg))
assertTemplateNotUsed¶
Assert that the template with the provided name was NOT used in
rendering the response. Also usable as context manager.
View Source
def assertTemplateNotUsed(self, response=None, template_name=None, msg_prefix=""):
"""
Assert that the template with the provided name was NOT used in
rendering the response. Also usable as context manager.
"""
context_mgr_template, template_names, msg_prefix = self._assert_template_used(
response, template_name, msg_prefix
)
if context_mgr_template:
# Use assertTemplateNotUsed as context manager.
return _AssertTemplateNotUsedContext(self, context_mgr_template)
self.assertFalse(
template_name in template_names,
msg_prefix
+ "Template '%s' was used unexpectedly in rendering the response"
% template_name,
)
assertTemplateUsed¶
Assert that the template with the provided name was used in rendering
the response. Also usable as context manager.
View Source
def assertTemplateUsed(
self, response=None, template_name=None, msg_prefix="", count=None
):
"""
Assert that the template with the provided name was used in rendering
the response. Also usable as context manager.
"""
context_mgr_template, template_names, msg_prefix = self._assert_template_used(
response, template_name, msg_prefix
)
if context_mgr_template:
# Use assertTemplateUsed as context manager.
return _AssertTemplateUsedContext(self, context_mgr_template)
if not template_names:
self.fail(msg_prefix + "No templates used to render the response")
self.assertTrue(
template_name in template_names,
msg_prefix + "Template '%s' was not a template used to render"
" the response. Actual template(s) used: %s"
% (template_name, ", ".join(template_names)),
)
if count is not None:
self.assertEqual(
template_names.count(template_name),
count,
msg_prefix + "Template '%s' was expected to be rendered %d "
"time(s) but was actually rendered %d time(s)."
% (template_name, count, template_names.count(template_name)),
)
assertTrue¶
Check that the expression is true.
View Source
assertTupleEqual¶
A tuple-specific equality assertion.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tuple1 | None | The first tuple to compare. | None |
tuple2 | None | The second tuple to compare. | None |
msg | None | Optional message to use on failure instead of a list of differences. | None |
View Source
def assertTupleEqual(self, tuple1, tuple2, msg=None):
"""A tuple-specific equality assertion.
Args:
tuple1: The first tuple to compare.
tuple2: The second tuple to compare.
msg: Optional message to use on failure instead of a list of
differences.
"""
self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
assertURLEqual¶
Assert that two URLs are the same, ignoring the order of query string
parameters except for parameters with the same name.
For example, /path/?x=1&y=2 is equal to /path/?y=2&x=1, but /path/?a=1&a=2 isn't equal to /path/?a=2&a=1.
View Source
def assertURLEqual(self, url1, url2, msg_prefix=""):
"""
Assert that two URLs are the same, ignoring the order of query string
parameters except for parameters with the same name.
For example, /path/?x=1&y=2 is equal to /path/?y=2&x=1, but
/path/?a=1&a=2 isn't equal to /path/?a=2&a=1.
"""
def normalize(url):
"""Sort the URL's query string parameters."""
url = str(url) # Coerce reverse_lazy() URLs.
scheme, netloc, path, params, query, fragment = urlparse(url)
query_parts = sorted(parse_qsl(query))
return urlunparse(
(scheme, netloc, path, params, urlencode(query_parts), fragment)
)
self.assertEqual(
normalize(url1),
normalize(url2),
msg_prefix + "Expected '%s' to equal '%s'." % (url1, url2),
)
assertWarns¶
Fail unless a warning of class warnClass is triggered
by the callable when invoked with specified positional and keyword arguments. If a different type of warning is triggered, it will not be handled: depending on the other warning filtering rules in effect, it might be silenced, printed out, or raised as an exception.
If called with the callable and arguments omitted, will return a context object used like this::
with self.assertWarns(SomeWarning):
do_something()
An optional keyword argument 'msg' can be provided when assertWarns is used as a context object.
The context manager keeps a reference to the first matching warning as the 'warning' attribute; similarly, the 'filename' and 'lineno' attributes give you information about the line of Python code from which the warning was triggered. This allows you to inspect the warning after the assertion::
with self.assertWarns(SomeWarning) as cm:
do_something()
the_warning = cm.warning
self.assertEqual(the_warning.some_attribute, 147)
View Source
def assertWarns(self, expected_warning, *args, **kwargs):
"""Fail unless a warning of class warnClass is triggered
by the callable when invoked with specified positional and
keyword arguments. If a different type of warning is
triggered, it will not be handled: depending on the other
warning filtering rules in effect, it might be silenced, printed
out, or raised as an exception.
If called with the callable and arguments omitted, will return a
context object used like this::
with self.assertWarns(SomeWarning):
do_something()
An optional keyword argument 'msg' can be provided when assertWarns
is used as a context object.
The context manager keeps a reference to the first matching
warning as the 'warning' attribute; similarly, the 'filename'
and 'lineno' attributes give you information about the line
of Python code from which the warning was triggered.
This allows you to inspect the warning after the assertion::
with self.assertWarns(SomeWarning) as cm:
do_something()
the_warning = cm.warning
self.assertEqual(the_warning.some_attribute, 147)
"""
context = _AssertWarnsContext(expected_warning, self)
return context.handle('assertWarns', args, kwargs)
assertWarnsMessage¶
Same as assertRaisesMessage but for assertWarns() instead of
assertRaises().
View Source
assertWarnsRegex¶
Asserts that the message in a triggered warning matches a regexp.
Basic functioning is similar to assertWarns() with the addition that only warnings whose messages also match the regular expression are considered successful matches.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
expected_warning | None | Warning class expected to be triggered. | None |
expected_regex | None | Regex (re.Pattern object or string) expected to be found in error message. | None |
args | None | Function to be called and extra positional args. | None |
kwargs | None | Extra kwargs. | None |
msg | None | Optional message used in case of failure. Can only be used when assertWarnsRegex is used as a context manager. | None |
View Source
def assertWarnsRegex(self, expected_warning, expected_regex,
*args, **kwargs):
"""Asserts that the message in a triggered warning matches a regexp.
Basic functioning is similar to assertWarns() with the addition
that only warnings whose messages also match the regular expression
are considered successful matches.
Args:
expected_warning: Warning class expected to be triggered.
expected_regex: Regex (re.Pattern object or string) expected
to be found in error message.
args: Function to be called and extra positional args.
kwargs: Extra kwargs.
msg: Optional message used in case of failure. Can only be used
when assertWarnsRegex is used as a context manager.
"""
context = _AssertWarnsContext(expected_warning, self, expected_regex)
return context.handle('assertWarnsRegex', args, kwargs)
assertXMLEqual¶
Assert that two XML snippets are semantically the same.
Whitespace in most cases is ignored and attribute ordering is not significant. The arguments must be valid XML.
View Source
def assertXMLEqual(self, xml1, xml2, msg=None):
"""
Assert that two XML snippets are semantically the same.
Whitespace in most cases is ignored and attribute ordering is not
significant. The arguments must be valid XML.
"""
try:
result = compare_xml(xml1, xml2)
except Exception as e:
standardMsg = "First or second argument is not valid XML\n%s" % e
self.fail(self._formatMessage(msg, standardMsg))
else:
if not result:
standardMsg = "%s != %s" % (
safe_repr(xml1, True),
safe_repr(xml2, True),
)
diff = "\n" + "\n".join(
difflib.ndiff(xml1.splitlines(), xml2.splitlines())
)
standardMsg = self._truncateMessage(standardMsg, diff)
self.fail(self._formatMessage(msg, standardMsg))
assertXMLNotEqual¶
Assert that two XML snippets are not semantically equivalent.
Whitespace in most cases is ignored and attribute ordering is not significant. The arguments must be valid XML.
View Source
def assertXMLNotEqual(self, xml1, xml2, msg=None):
"""
Assert that two XML snippets are not semantically equivalent.
Whitespace in most cases is ignored and attribute ordering is not
significant. The arguments must be valid XML.
"""
try:
result = compare_xml(xml1, xml2)
except Exception as e:
standardMsg = "First or second argument is not valid XML\n%s" % e
self.fail(self._formatMessage(msg, standardMsg))
else:
if result:
standardMsg = "%s == %s" % (
safe_repr(xml1, True),
safe_repr(xml2, True),
)
self.fail(self._formatMessage(msg, standardMsg))
assert_¶
View Source
countTestCases¶
debug¶
Perform the same as call(), without catching the exception.
View Source
defaultTestResult¶
doCleanups¶
Execute all cleanup functions. Normally called for you after
tearDown.
View Source
def doCleanups(self):
"""Execute all cleanup functions. Normally called for you after
tearDown."""
outcome = self._outcome or _Outcome()
while self._cleanups:
function, args, kwargs = self._cleanups.pop()
with outcome.testPartExecutor(self):
self._callCleanup(function, *args, **kwargs)
# return this for backwards compatibility
# even though we no longer use it internally
return outcome.success
fail¶
Fail immediately, with the given message.
View Source
failIf¶
View Source
failIfAlmostEqual¶
View Source
failIfEqual¶
View Source
failUnless¶
View Source
failUnlessAlmostEqual¶
View Source
failUnlessEqual¶
View Source
failUnlessRaises¶
View Source
id¶
modify_settings¶
A context manager that temporarily applies changes to a list setting and
reverts back to the original value when exiting the context.
View Source
run¶
View Source
def run(self, result=None):
if result is None:
result = self.defaultTestResult()
startTestRun = getattr(result, 'startTestRun', None)
stopTestRun = getattr(result, 'stopTestRun', None)
if startTestRun is not None:
startTestRun()
else:
stopTestRun = None
result.startTest(self)
try:
testMethod = getattr(self, self._testMethodName)
if (getattr(self.__class__, "__unittest_skip__", False) or
getattr(testMethod, "__unittest_skip__", False)):
# If the class or method was skipped.
skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
or getattr(testMethod, '__unittest_skip_why__', ''))
self._addSkip(result, self, skip_why)
return result
expecting_failure = (
getattr(self, "__unittest_expecting_failure__", False) or
getattr(testMethod, "__unittest_expecting_failure__", False)
)
outcome = _Outcome(result)
try:
self._outcome = outcome
with outcome.testPartExecutor(self):
self._callSetUp()
if outcome.success:
outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True):
self._callTestMethod(testMethod)
outcome.expecting_failure = False
with outcome.testPartExecutor(self):
self._callTearDown()
self.doCleanups()
for test, reason in outcome.skipped:
self._addSkip(result, test, reason)
self._feedErrorsToResult(result, outcome.errors)
if outcome.success:
if expecting_failure:
if outcome.expectedFailure:
self._addExpectedFailure(result, outcome.expectedFailure)
else:
self._addUnexpectedSuccess(result)
else:
result.addSuccess(self)
return result
finally:
# explicitly break reference cycles:
# outcome.errors -> frame -> outcome -> outcome.errors
# outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
outcome.errors.clear()
outcome.expectedFailure = None
# clear the outcome, no more needed
self._outcome = None
finally:
result.stopTest(self)
if stopTestRun is not None:
stopTestRun()
setUp¶
Hook method for setting up the test fixture before exercising it.
settings¶
A context manager that temporarily sets a setting and reverts to the
original value when exiting the context.
View Source
shortDescription¶
Returns a one-line description of the test, or None if no
description has been provided.
The default implementation of this method returns the first line of the specified test method's docstring.
View Source
def shortDescription(self):
"""Returns a one-line description of the test, or None if no
description has been provided.
The default implementation of this method returns the first line of
the specified test method's docstring.
"""
doc = self._testMethodDoc
return doc.strip().split("\n")[0].strip() if doc else None
skipTest¶
Skip this test.
subTest¶
Return a context manager that will return the enclosed block
of code in a subtest identified by the optional message and keyword parameters. A failure in the subtest marks the test case as failed but resumes execution at the end of the enclosed block, allowing further test code to be executed.
View Source
@contextlib.contextmanager
def subTest(self, msg=_subtest_msg_sentinel, **params):
"""Return a context manager that will return the enclosed block
of code in a subtest identified by the optional message and
keyword parameters. A failure in the subtest marks the test
case as failed but resumes execution at the end of the enclosed
block, allowing further test code to be executed.
"""
if self._outcome is None or not self._outcome.result_supports_subtests:
yield
return
parent = self._subtest
if parent is None:
params_map = _OrderedChainMap(params)
else:
params_map = parent.params.new_child(params)
self._subtest = _SubTest(self, msg, params_map)
try:
with self._outcome.testPartExecutor(self._subtest, isTest=True):
yield
if not self._outcome.success:
result = self._outcome.result
if result is not None and result.failfast:
raise _ShouldStop
elif self._outcome.expectedFailure:
# If the test is expecting a failure, we really want to
# stop now and register the expected failure.
raise _ShouldStop
finally:
self._subtest = parent
tearDown¶
Hook method for deconstructing the test fixture after testing it.
View Source
test_delete_global_model_with_training_as_model_owner¶
View Source
def test_delete_global_model_with_training_as_model_owner(self):
model = Dummy.create_model(owner=self.user)
training = Dummy.create_training(model=model)
response = self.client.delete(f"{BASE_URL}/models/{model.id}/")
self.assertEqual(200, response.status_code)
body = response.json()
self.assertEqual("Model removed!", body["detail"])
self.assertRaises(ObjectDoesNotExist, Model.objects.get, pk=model.id)
# due to cascade delete (in the case of GlobalModel), training should also be deleted
self.assertRaises(ObjectDoesNotExist, Training.objects.get, pk=training.id)
test_delete_local_model_with_training_as_model_owner¶
View Source
def test_delete_local_model_with_training_as_model_owner(self):
    """Deleting an owned local model keeps the global model and its training."""
    base = Dummy.create_model()
    update = Dummy.create_model_update(base_model=base, owner=self.user)
    related_training = Dummy.create_training(model=base)
    response = self.client.delete(f"{BASE_URL}/models/{update.id}/")
    self.assertEqual(200, response.status_code)
    payload = response.json()
    self.assertEqual("Model removed!", payload["detail"])
    # only the local update is gone; base model and training survive
    self.assertRaises(ObjectDoesNotExist, Model.objects.get, pk=update.id)
    self.assertIsNotNone(Model.objects.get(pk=base.id))
    self.assertIsNotNone(Training.objects.get(pk=related_training.id))
test_delete_model_as_training_owner¶
View Source
def test_delete_model_as_training_owner(self):
    """The training actor may delete the model; the training is cascaded away."""
    target = Dummy.create_model()
    related_training = Dummy.create_training(model=target, actor=self.user)
    response = self.client.delete(f"{BASE_URL}/models/{target.id}/")
    self.assertEqual(200, response.status_code)
    payload = response.json()
    self.assertEqual("Model removed!", payload["detail"])
    self.assertRaises(ObjectDoesNotExist, Model.objects.get, pk=target.id)
    # due to cascade delete (in the case of GlobalModel), training should also be deleted
    self.assertRaises(ObjectDoesNotExist, Training.objects.get, pk=related_training.id)
test_delete_model_as_training_participant¶
View Source
def test_delete_model_as_training_participant(self):
    """Mere training participants are forbidden from deleting the model."""
    target = Dummy.create_model()
    Dummy.create_training(model=target, participants=[Dummy.create_client(), self.user, Dummy.create_client()])
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{target.id}/")
    self.assertEqual(403, response.status_code)
    payload = response.json()
    self.assertEqual(
        "You are neither the owner of the model nor the actor of the corresponding training.",
        payload["detail"]
    )
    # model must be untouched
    self.assertIsNotNone(Model.objects.get(pk=target.id))
test_delete_model_with_training_as_unrelated_user¶
View Source
def test_delete_model_with_training_as_unrelated_user(self):
    """A user unrelated to model and training cannot delete the model."""
    target = Dummy.create_model()
    Dummy.create_training(model=target)
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{target.id}/")
    self.assertEqual(403, response.status_code)
    payload = response.json()
    self.assertEqual(
        "You are neither the owner of the model nor the actor of the corresponding training.",
        payload["detail"]
    )
    # model must be untouched
    self.assertIsNotNone(Model.objects.get(pk=target.id))
test_delete_model_without_training_as_model_owner¶
View Source
def test_delete_model_without_training_as_model_owner(self):
    """The owner can delete a model that has no training attached."""
    owned_model = Dummy.create_model(owner=self.user)
    response = self.client.delete(f"{BASE_URL}/models/{owned_model.id}/")
    self.assertEqual(200, response.status_code)
    payload = response.json()
    self.assertEqual("Model removed!", payload["detail"])
    self.assertRaises(ObjectDoesNotExist, Model.objects.get, pk=owned_model.id)
test_delete_model_without_training_as_unrelated_user¶
View Source
def test_delete_model_without_training_as_unrelated_user(self):
    """An unrelated user cannot delete a model without training."""
    target = Dummy.create_model()
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{target.id}/")
    self.assertEqual(403, response.status_code)
    payload = response.json()
    self.assertEqual(
        "You are neither the owner of the model nor the actor of the corresponding training.",
        payload["detail"]
    )
    # model must be untouched
    self.assertIsNotNone(Model.objects.get(pk=target.id))
test_delete_non_existing_model¶
View Source
def test_delete_non_existing_model(self):
    """Deleting an unknown model id yields a 400 with a descriptive message."""
    unknown_id = str(uuid4())
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.delete(f"{BASE_URL}/models/{unknown_id}/")
    self.assertEqual(400, response.status_code)
    payload = response.json()
    self.assertEqual(f"Model {unknown_id} not found.", payload["detail"])
test_download_model_preprocessing¶
View Source
def test_download_model_preprocessing(self):
    """The preprocessing endpoint streams a loadable torchscript module."""
    preprocessing_bytes = from_torch_module(torch.jit.script(torch.nn.Sequential(
        transforms.Normalize(mean=(0.,), std=(1.,)),
    )))
    target = Dummy.create_model(owner=self.user, preprocessing=preprocessing_bytes)
    response = self.client.get(f"{BASE_URL}/models/{target.id}/preprocessing/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/octet-stream", response["content-type"])
    # the downloaded bytes must deserialize back into a torch module
    loaded = torch.jit.load(io.BytesIO(response.content))
    self.assertIsNotNone(loaded)
    self.assertTrue(isinstance(loaded, torch.nn.Module))
test_download_model_preprocessing_with_undefined_preprocessing¶
View Source
def test_download_model_preprocessing_with_undefined_preprocessing(self):
    """Requesting preprocessing of a model without one yields 404."""
    target = Dummy.create_model(owner=self.user, preprocessing=None)
    with self.assertLogs("django.request", level="WARNING") as cm:
        response = self.client.get(f"{BASE_URL}/models/{target.id}/preprocessing/")
    self.assertEqual(cm.output, [
        f"WARNING:django.request:Not Found: /api/models/{target.id}/preprocessing/",
    ])
    self.assertEqual(404, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual(f"Model '{target.id}' has no preprocessing model defined.", payload["detail"])
test_get_all_models¶
View Source
def test_get_all_models(self):
    """Listing models returns exactly the user-related ones."""
    # make user actor and client
    self.user.actor = True
    self.user.client = True
    self.user.save()
    # unrelated models/trainings must not appear in the listing
    [Dummy.create_model() for _ in range(2)]
    expected = [Dummy.create_model(owner=self.user) for _ in range(2)]
    [Dummy.create_training() for _ in range(2)]
    user_trainings = [Dummy.create_training(actor=self.user) for _ in range(2)]
    user_trainings += [Dummy.create_training(participants=[self.user]) for _ in range(2)]
    expected += [t.model for t in user_trainings]
    # get user related models
    response = self.client.get(f"{BASE_URL}/models/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    listed = response.json()
    self.assertEqual(len(expected), len(listed))
    self.assertEqual(
        sorted(str(m.id) for m in expected),
        sorted(entry["id"] for entry in listed)
    )
test_get_all_models_for_a_training¶
View Source
def test_get_all_models_for_a_training(self):
    """Training model listing includes the base model and every update."""
    # make user actor and client
    self.user.actor = True
    self.user.client = True
    self.user.save()
    # four participants with differing numbers of update rounds
    participants = [Dummy.create_user() for _ in range(4)]
    rounds_per_participant = [3, 4, 4, 3]
    # noise: trainings and updates that must not show up for this training
    [Dummy.create_training() for _ in range(2)]
    [Dummy.create_training(actor=self.user) for _ in range(2)]
    [Dummy.create_training(participants=[self.user]) for _ in range(2)]
    [Dummy.create_model_update() for _ in range(2)]
    [Dummy.create_model_update(owner=self.user) for _ in range(2)]
    training = Dummy.create_training(actor=self.user, participants=participants)
    # per-participant model updates against the training's base model
    base_model = training.model
    expected = [base_model]
    for participant, n_rounds in zip(participants, rounds_per_participant):
        expected.extend(
            Dummy.create_model_update(base_model=base_model, owner=participant, round=r + 1)
            for r in range(n_rounds)
        )
    # get user related models for a special training
    response = self.client.get(f"{BASE_URL}/trainings/{training.pk}/models/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    listed = response.json()
    self.assertEqual(len(expected), len(listed))
    self.assertEqual(
        sorted(str(m.id) for m in expected),
        sorted(entry["id"] for entry in listed)
    )
test_get_all_models_for_a_training_latest_only¶
View Source
def test_get_all_models_for_a_training_latest_only(self):
    """The 'latest' listing returns the expected models with their rounds."""
    # make user actor and client
    self.user.actor = True
    self.user.client = True
    self.user.save()
    # four participants with differing numbers of update rounds
    participants = [Dummy.create_user() for _ in range(4)]
    rounds_per_participant = [3, 4, 4, 3]
    # noise: trainings and updates unrelated to the training under test
    [Dummy.create_training() for _ in range(2)]
    [Dummy.create_training(actor=self.user) for _ in range(2)]
    [Dummy.create_training(participants=[self.user]) for _ in range(2)]
    [Dummy.create_model_update() for _ in range(2)]
    [Dummy.create_model_update(owner=self.user) for _ in range(2)]
    training = Dummy.create_training(actor=self.user, participants=participants)
    # per-participant model updates against the training's base model
    base_model = training.model
    models_latest = [base_model]
    for participant, n_rounds in zip(participants, rounds_per_participant):
        models_latest.extend(
            Dummy.create_model_update(base_model=base_model, owner=participant, round=r + 1)
            for r in range(n_rounds)
        )
    # get user related "latest" models for a special training
    response = self.client.get(f"{BASE_URL}/trainings/{training.pk}/models/latest/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    listed = response.json()
    self.assertEqual(len(models_latest), len(listed))
    # compare ids and rounds pairwise after sorting both sides by id
    models_latest = sorted(models_latest, key=lambda m: str(m.pk))
    listed_sorted = sorted(listed, key=lambda entry: entry["id"])
    self.assertEqual(
        [str(m.id) for m in models_latest],
        [entry["id"] for entry in listed_sorted]
    )
    self.assertEqual(
        [m.round for m in models_latest],
        [entry["round"] for entry in listed_sorted]
    )
test_get_global_model_metrics¶
View Source
def test_get_global_model_metrics(self):
    """Metrics of a global model are returned by the metrics endpoint."""
    target = Dummy.create_model(owner=self.user)
    expected_metric = Dummy.create_metric(model=target)
    response = self.client.get(f"{BASE_URL}/models/{target.id}/metrics/")
    self.assertEqual(200, response.status_code)
    payload = response.json()
    self.assertEqual(1, len(payload))
    self.assertEqual(expected_metric.value_float, payload[0]["value_float"])
    self.assertEqual(expected_metric.key, payload[0]["key"])
test_get_local_model_metrics¶
View Source
def test_get_local_model_metrics(self):
    """Metrics of a local model update are returned by the metrics endpoint."""
    target = Dummy.create_model_update(owner=self.user)
    expected_metric = Dummy.create_metric(model=target)
    response = self.client.get(f"{BASE_URL}/models/{target.id}/metrics/")
    self.assertEqual(200, response.status_code)
    payload = response.json()
    self.assertEqual(1, len(payload))
    self.assertEqual(expected_metric.value_float, payload[0]["value_float"])
    self.assertEqual(expected_metric.key, payload[0]["key"])
test_get_model¶
View Source
def test_get_model(self):
    """Downloading a model streams its raw weight bytes."""
    target = Dummy.create_model(weights=b"Hello World!")
    response = self.client.get(f"{BASE_URL}/models/{target.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/octet-stream", response["content-type"])
    self.assertEqual(b"Hello World!", response.getvalue())
test_get_model_and_unpickle¶
View Source
def test_get_model_and_unpickle(self):
    """Downloaded model bytes deserialize into a torch module."""
    target = Dummy.create_model()
    response = self.client.get(f"{BASE_URL}/models/{target.id}/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/octet-stream", response["content-type"])
    loaded = torch.jit.load(io.BytesIO(response.content))
    self.assertIsNotNone(loaded)
    self.assertTrue(isinstance(loaded, torch.nn.Module))
test_get_model_metadata¶
View Source
def test_get_model_metadata(self):
    """The metadata endpoint reports identity fields plus per-layer stats.

    The four per-layer assertion stanzas of the original were near-identical
    17-line copies; they are collapsed into a data-driven loop so a failing
    layer is pinpointed by ``subTest`` instead of by line number.
    """
    model_bytes = from_torch_module(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    model = Dummy.create_model(weights=model_bytes, input_shape=[None, 3])
    response = self.client.get(f"{BASE_URL}/models/{model.id}/metadata/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    response_json = response.json()
    self.assertEqual(str(model.id), response_json["id"])
    self.assertEqual(str(model.name), response_json["name"])
    self.assertEqual(str(model.description), response_json["description"])
    self.assertEqual(model.input_shape, response_json["input_shape"])
    self.assertFalse(response_json["has_preprocessing"])
    # check stats
    stats = response_json["stats"]
    self.assertIsNotNone(stats)
    self.assertEqual([[1, 3]], stats["input_size"])
    for key in ("total_input", "total_mult_adds", "total_output_bytes",
                "total_param_bytes", "total_params", "trainable_params"):
        self.assertIsNotNone(stats[key])
    # expected per-layer values:
    # (class_name, depth, depth_index, var_name, is_leaf_layer, input_size, output_size)
    expected_layers = [
        ("Sequential", 0, 1, "Sequential", False, [1, 3], [1, 1]),
        ("Linear", 1, 1, "0", True, [1, 3], [1, 64]),
        ("ELU", 1, 2, "1", True, [1, 64], [1, 64]),
        ("Linear", 1, 3, "2", True, [1, 64], [1, 1]),
    ]
    for idx, (cls_name, depth, depth_index, var_name, is_leaf, in_size, out_size) \
            in enumerate(expected_layers):
        with self.subTest(layer=idx, class_name=cls_name):
            layer = stats["summary_list"][idx]
            self.assertEqual(cls_name, layer["class_name"])
            self.assertEqual(depth, layer["depth"])
            self.assertEqual(depth_index, layer["depth_index"])
            self.assertEqual(True, layer["executed"])
            self.assertEqual(var_name, layer["var_name"])
            self.assertEqual(is_leaf, layer["is_leaf_layer"])
            self.assertEqual(False, layer["contains_lazy_param"])
            self.assertEqual(False, layer["is_recursive"])
            self.assertEqual(in_size, layer["input_size"])
            self.assertEqual(out_size, layer["output_size"])
            self.assertEqual(None, layer["kernel_size"])
            # size/count fields only need to be present, not exact
            for key in ("trainable_params", "num_params", "param_bytes",
                        "output_bytes", "macs"):
                self.assertIsNotNone(layer[key])
test_get_model_metadata_torchscript_model¶
View Source
def test_get_model_metadata_torchscript_model(self):
    """Metadata also works for torchscript-serialized models."""
    scripted_bytes = from_torch_module(torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    )))
    target = Dummy.create_model(weights=scripted_bytes, input_shape=[None, 3])
    response = self.client.get(f"{BASE_URL}/models/{target.id}/metadata/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    payload = response.json()
    self.assertEqual(str(target.id), payload["id"])
    self.assertEqual(str(target.name), payload["name"])
    self.assertEqual(str(target.description), payload["description"])
    self.assertEqual(target.input_shape, payload["input_shape"])
    # check stats
    stats = payload["stats"]
    self.assertIsNotNone(stats)
    self.assertEqual([[1, 3]], stats["input_size"])
    self.assertIsNotNone(stats["total_input"])
    self.assertIsNotNone(stats["total_mult_adds"])
    self.assertIsNotNone(stats["total_output_bytes"])
    self.assertIsNotNone(stats["total_param_bytes"])
    self.assertIsNotNone(stats["total_params"])
    self.assertIsNotNone(stats["trainable_params"])
    self.assertEqual(4, len(stats["summary_list"]))
test_get_model_metadata_with_preprocessing¶
View Source
def test_get_model_metadata_with_preprocessing(self):
    """Metadata flags has_preprocessing when a preprocessing model exists."""
    weight_bytes = from_torch_module(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    preprocessing_bytes = from_torch_module(transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ]))
    target = Dummy.create_model(weights=weight_bytes, preprocessing=preprocessing_bytes, input_shape=[None, 3])
    response = self.client.get(f"{BASE_URL}/models/{target.id}/metadata/")
    self.assertEqual(200, response.status_code)
    self.assertEqual("application/json", response["content-type"])
    payload = response.json()
    self.assertEqual(str(target.id), payload["id"])
    self.assertEqual(str(target.name), payload["name"])
    self.assertEqual(str(target.description), payload["description"])
    self.assertEqual(target.input_shape, payload["input_shape"])
    self.assertTrue(payload["has_preprocessing"])
test_unauthorized¶
View Source
def test_unauthorized(self):
    """Uploading without credentials is rejected with 401."""
    # drop the authorization header set up by setUp
    del self.client.defaults["HTTP_AUTHORIZATION"]
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/",
            {"model_file": b"Hello World!"}
        )
    self.assertEqual(401, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Authentication credentials were not provided.", payload["detail"])
test_upload¶
View Source
def test_upload(self):
    """A torchscript model upload creates a GlobalModel with the given shape."""
    scripted = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    upload = SimpleUploadedFile(
        "model.pt",
        from_torch_module(scripted),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/", {
        "model_file": upload,
        "name": "Test Model",
        "description": "Test Model Description - Test Model Description Test",
        "input_shape": [None, 3]
    })
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Upload Accepted", payload["detail"])
    model_id = payload["model_id"]
    self.assertIsNotNone(model_id)
    self.assertIsNot("", model_id)
    # the stored model is a GlobalModel with the requested input shape
    self.assertEqual(GlobalModel, type(Model.objects.get(id=model_id)))
    self.assertEqual([None, 3], Model.objects.get(id=model_id).input_shape)
test_upload_bad_metrics¶
View Source
def test_upload_bad_metrics(self):
    """Mismatched metric name/value lengths are rejected with 400."""
    target = Dummy.create_model(owner=self.user, round=0)
    # three names but only two values
    bad_metrics = {
        "metric_names": ["loss", "accuracy", "dummy_binary"],
        "metric_values": [1999.0, b"Hello World!"],
    }
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{target.id}/metrics/",
            bad_metrics,
        )
    self.assertEqual(400, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Metric names and values must have the same length", payload["detail"])
test_upload_global_model_metrics¶
View Source
def test_upload_global_model_metrics(self):
    """Metrics upload on a training-less global model succeeds but warns."""
    target = Dummy.create_model(owner=self.user, round=0)
    valid_metrics = {
        "metric_names": ["loss", "accuracy", "dummy_binary"],
        "metric_values": [1999.0, 0.12, b"Hello World!"],
    }
    with self.assertLogs("fl.server", level="WARNING") as cm:
        response = self.client.post(
            f"{BASE_URL}/models/{target.id}/metrics/",
            valid_metrics,
        )
    self.assertEqual(cm.output, [
        f"WARNING:fl.server:Global model {target.id} is not connected to any training.",
    ])
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Metrics Upload Accepted", payload["detail"])
    self.assertEqual(str(target.id), payload["model_id"])
test_upload_local_model_metrics¶
View Source
def test_upload_local_model_metrics(self):
    """Metrics upload on a local model update is accepted."""
    target = Dummy.create_model_update(owner=self.user)
    valid_metrics = {
        "metric_names": ["loss", "accuracy", "dummy_binary"],
        "metric_values": [1999.0, 0.12, b"Hello World!"],
    }
    response = self.client.post(
        f"{BASE_URL}/models/{target.id}/metrics/",
        valid_metrics,
    )
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Metrics Upload Accepted", payload["detail"])
    self.assertEqual(str(target.id), payload["model_id"])
test_upload_mean_model¶
View Source
def test_upload_mean_model(self):
    """Uploading type MEAN over existing models creates a MeanModel."""
    source_models = [Dummy.create_model(owner=self.user) for _ in range(10)]
    source_uuids = [str(m.id) for m in source_models]
    response = self.client.post(f"{BASE_URL}/models/", {
        "type": "MEAN",
        "name": "Test MEAN Model",
        "description": "Test MEAN Model Description - Test MEAN Model Description Test",
        "models": source_uuids,
    }, "application/json")
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Upload Accepted", payload["detail"])
    mean_model_id = payload["model_id"]
    self.assertIsNotNone(mean_model_id)
    self.assertIsNot("", mean_model_id)
    self.assertEqual(MeanModel, type(Model.objects.get(id=mean_model_id)))
test_upload_model_preprocessing¶
View Source
def test_upload_model_preprocessing(self):
    """A torchscript preprocessing upload is stored on the model."""
    target = Dummy.create_model(owner=self.user, preprocessing=None)
    scripted_preprocessing = torch.jit.script(torch.nn.Sequential(
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ))
    upload = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(scripted_preprocessing),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/{target.id}/preprocessing/", {
        "model_preprocessing_file": upload,
    })
    self.assertEqual(202, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Proprocessing Model Upload Accepted", payload["detail"])
    # the persisted model now carries a loadable preprocessing module
    target.refresh_from_db()
    self.assertIsNotNone(target)
    self.assertIsNotNone(target.preprocessing)
    self.assertTrue(isinstance(target.get_preprocessing_torch_model(), torch.nn.Module))
test_upload_model_preprocessing_v1_Compose_bad¶
View Source
def test_upload_model_preprocessing_v1_Compose_bad(self):
    """A v1 transforms.Compose upload is rejected: it is no torch module."""
    target = Dummy.create_model(owner=self.user, preprocessing=None)
    # torchvision.transforms.Compose (v1 not v2) does not inherit from torch.nn.Module!!
    v1_compose = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.,), std=(1.,)),
    ])
    upload = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(v1_compose),  # (normal) transforms.Compose
        content_type="application/octet-stream"
    )
    with self.assertLogs("fl.server", level="ERROR"):  # Loaded torch object is not of expected type.
        with self.assertLogs("django.request", level="WARNING"):  # Bad Request
            response = self.client.post(f"{BASE_URL}/models/{target.id}/preprocessing/", {
                "model_preprocessing_file": upload,
            })
    self.assertEqual(400, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual(
        "Invalid preprocessing file: Loaded torch object is not of expected type.",
        payload[0],
    )
test_upload_model_preprocessing_v2_Compose_good¶
View Source
def test_upload_model_preprocessing_v2_Compose_good(self):
    """A v2 transforms.Compose upload is accepted and stored."""
    # Maybe good now
    target = Dummy.create_model(owner=self.user, preprocessing=None)
    v2_compose = transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ])
    upload = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(v2_compose),  # (normal) transforms.Compose
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/{target.id}/preprocessing/", {
        "model_preprocessing_file": upload,
    })
    self.assertEqual(202, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Proprocessing Model Upload Accepted", payload["detail"])
    # the persisted model now carries a loadable preprocessing module
    target.refresh_from_db()
    self.assertIsNotNone(target)
    self.assertIsNotNone(target.preprocessing)
    self.assertTrue(isinstance(target.get_preprocessing_torch_model(), torch.nn.Module))
test_upload_swag_model¶
View Source
def test_upload_swag_model(self):
    """Uploading with type SWAG creates a SWAGModel."""
    scripted = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    upload = SimpleUploadedFile(
        "model.pt",
        from_torch_module(scripted),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/", {
        "type": "SWAG",
        "model_file": upload,
        "name": "Test SWAG Model",
        "description": "Test SWAG Model Description - Test SWAG Model Description Test",
    })
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Upload Accepted", payload["detail"])
    swag_model_id = payload["model_id"]
    self.assertIsNotNone(swag_model_id)
    self.assertIsNot("", swag_model_id)
    self.assertEqual(SWAGModel, type(Model.objects.get(id=swag_model_id)))
test_upload_swag_stats¶
View Source
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_swag_stats(self, apply_async: MagicMock):
    """Uploading SWAG moments triggers the SWAGRoundFinished trainer task."""
    target = Dummy.create_model(owner=self.user, round=0)
    training = Dummy.create_training(
        model=target,
        actor=self.user,
        state=TrainingState.SWAG_ROUND,
        participants=[self.user]
    )
    # helper producing a pickled state dict of a tiny network
    first_moment = SimpleUploadedFile(
        "first_moment.pkl",
        pickle.dumps(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ).state_dict()),
        content_type="application/octet-stream"
    )
    second_moment = SimpleUploadedFile(
        "second_moment.pkl",
        pickle.dumps(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ).state_dict()),
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/{target.id}/swag/", {
        "first_moment_file": first_moment,
        "second_moment_file": second_moment,
        "sample_size": 100,
        "round": 0
    })
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("SWAG Statistic Accepted", payload["detail"])
    # the trainer task must have been scheduled exactly once, no retry
    self.assertTrue(apply_async.called)
    apply_async.assert_called_once_with(
        (),
        {"training_uuid": training.id, "event_cls": SWAGRoundFinished},
        retry=False
    )
test_upload_update¶
View Source
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_update(self, apply_async: MagicMock):
    """A model update is accepted; aggregation waits for other participants."""
    target = Dummy.create_model(owner=self.user, round=0)
    # a second participant exists, so this single update must not aggregate
    Dummy.create_training(model=target, actor=self.user, state=TrainingState.ONGOING,
                          participants=[self.user, Dummy.create_user()])
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    response = self.client.post(
        f"{BASE_URL}/models/{target.id}/",
        {"model_file": update_file, "round": 0,
         "sample_size": 100}
    )
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Update Accepted", payload["detail"])
    self.assertFalse(apply_async.called)
test_upload_update_and_aggregate¶
View Source
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_update_and_aggregate(self, apply_async: MagicMock):
    """The last outstanding update triggers the TrainingRoundFinished task."""
    target = Dummy.create_model(owner=self.user, round=0)
    # sole participant: this update completes the round
    training = Dummy.create_training(model=target, actor=self.user, state=TrainingState.ONGOING,
                                     participants=[self.user])
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    response = self.client.post(
        f"{BASE_URL}/models/{target.id}/",
        {"model_file": update_file, "round": 0,
         "sample_size": 100}
    )
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Update Accepted", payload["detail"])
    self.assertTrue(apply_async.called)
    apply_async.assert_called_once_with(
        (),
        {"training_uuid": training.id, "event_cls": TrainingRoundFinished},
        retry=False
    )
test_upload_update_and_not_aggregate_since_training_is_locked¶
def test_upload_update_and_not_aggregate_since_training_is_locked(
self,
apply_async: unittest.mock.MagicMock
)
View Source
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_update_and_not_aggregate_since_training_is_locked(self, apply_async: MagicMock):
    """A locked training accepts updates but does not schedule aggregation."""
    target = Dummy.create_model(owner=self.user, round=0)
    training = Dummy.create_training(
        model=target, actor=self.user, state=TrainingState.ONGOING, participants=[self.user]
    )
    # lock the training so aggregation must be suppressed
    training.locked = True
    training.save()
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    response = self.client.post(
        f"{BASE_URL}/models/{target.id}/",
        {"model_file": update_file, "round": 0,
         "sample_size": 100}
    )
    self.assertEqual(201, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("Model Update Accepted", payload["detail"])
    self.assertFalse(apply_async.called)
test_upload_update_bad_keys¶
View Source
def test_upload_update_bad_keys(self):
    """A wrongly-named upload field is rejected with 400."""
    target = Dummy.create_model(owner=self.user, round=0)
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{target.id}/",
            {"xXx_model_file_xXx": update_file, "round": 0, "sample_size": 100}
        )
    self.assertEqual(400, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual("No uploaded file 'model_file' found.", payload["detail"])
test_upload_update_no_participant¶
View Source
def test_upload_update_no_participant(self):
    """Non-participants of a training may not upload model updates."""
    self.client.defaults["HTTP_ACCEPT"] = "application/json"
    actor = Dummy.create_actor()
    target = Dummy.create_model(owner=actor, round=0)
    # self.user is deliberately NOT among the participants
    training = Dummy.create_training(
        model=target, actor=actor, state=TrainingState.ONGOING,
        participants=[actor, Dummy.create_client()]
    )
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{target.id}/",
            {"model_file": update_file, "round": 0,
             "sample_size": 500}
        )
    self.assertEqual(403, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual(f"You are not a participant of training {training.id}!", payload["detail"])
test_upload_update_no_training¶
View Source
def test_upload_update_no_training(self):
    """Uploading an update for a model without training yields 404."""
    target = Dummy.create_model(owner=self.user, round=0)
    update_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))),
        content_type="application/octet-stream"
    )
    with self.assertLogs("django.request", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{target.id}/",
            {"model_file": update_file, "round": 0, "sample_size": 100}
        )
    self.assertEqual(404, response.status_code)
    payload = response.json()
    self.assertIsNotNone(payload)
    self.assertEqual(f"Model with ID {target.id} does not have a training process running", payload["detail"])
test_upload_update_with_metrics¶
View Source
def test_upload_update_with_metrics(self):
    """A participant can upload a model update together with named metrics (HTTP 201)."""
    model = Dummy.create_model(owner=self.user, round=0)
    # The authenticated user is enrolled as a participant of an ongoing training.
    Dummy.create_training(
        model=model,
        actor=self.user,
        state=TrainingState.ONGOING,
        participants=[self.user, Dummy.create_user()],
    )
    update_bytes = from_torch_module(torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 1),
        torch.nn.Sigmoid(),
    )))
    upload = SimpleUploadedFile(
        "model.pt", update_bytes, content_type="application/octet-stream"
    )
    payload = {
        "model_file": upload,
        "round": 0,
        # Parallel lists: one value per metric name; binary values are allowed.
        "metric_names": ["loss", "accuracy", "dummy_binary"],
        "metric_values": [1999.0, 0.12, b"Hello World!"],
        "sample_size": 50,
    }
    response = self.client.post(f"{BASE_URL}/models/{model.id}/", payload)
    self.assertEqual(201, response.status_code)
    body = response.json()
    self.assertIsNotNone(body)
    self.assertEqual("Model Update Accepted", body["detail"])
test_upload_update_with_metrics_bad¶
View Source
def test_upload_update_with_metrics_bad(self):
    """Malformed metric data (non-list `metric_names`) is rejected with HTTP 400."""
    model = Dummy.create_model(owner=self.user)
    Dummy.create_training(
        model=model,
        actor=self.user,
        state=TrainingState.ONGOING,
        participants=[self.user, Dummy.create_user()],
    )
    update_bytes = from_torch_module(torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 1),
        torch.nn.Sigmoid(),
    )))
    upload = SimpleUploadedFile(
        "model.pt", update_bytes, content_type="application/octet-stream"
    )
    # `metric_names` is deliberately an int instead of a list of names.
    with self.assertLogs("root", level="WARNING"):
        response = self.client.post(
            f"{BASE_URL}/models/{model.id}/",
            {"model_file": upload, "round": 0, "metric_names": 5, "sample_size": 500},
        )
    self.assertEqual(400, response.status_code)
test_upload_with_preprocessing¶
View Source
def test_upload_with_preprocessing(self):
    """Uploading a model together with a preprocessing model stores both torchscript files.

    Verifies the 201 response, the persisted ``GlobalModel`` row (type and
    ``input_shape``), and that both the weights and the preprocessing blob can
    be deserialized back into ``torch.nn.Module`` instances.
    """
    torch_model = torch.jit.script(torch.nn.Sequential(
        torch.nn.Linear(3, 64),
        torch.nn.ELU(),
        torch.nn.Linear(64, 1),
    ))
    model_file = SimpleUploadedFile(
        "model.pt",
        from_torch_module(torch_model),  # torchscript model
        content_type="application/octet-stream"
    )
    torch_model_preprocessing = torch.jit.script(torch.nn.Sequential(
        transforms.Normalize(mean=(0.,), std=(1.,)),
    ))
    model_preprocessing_file = SimpleUploadedFile(
        "preprocessing.pt",
        from_torch_module(torch_model_preprocessing),  # torchscript model
        content_type="application/octet-stream"
    )
    response = self.client.post(f"{BASE_URL}/models/", {
        "model_file": model_file,
        "model_preprocessing_file": model_preprocessing_file,
        "name": "Test Model",
        "description": "Test Model Description - Test Model Description Test",
        "input_shape": [None, 3]
    })
    self.assertEqual(201, response.status_code)
    response_json = response.json()
    self.assertIsNotNone(response_json)
    self.assertEqual("Model Upload Accepted", response_json["detail"])
    uuid = response_json["model_id"]
    self.assertIsNotNone(uuid)
    # BUGFIX: was `assertIsNot("", uuid)`, an identity check that is trivially
    # true for any string and therefore asserted nothing. Use value equality.
    self.assertNotEqual("", uuid)
    self.assertEqual(GlobalModel, type(Model.objects.get(id=uuid)))
    self.assertEqual([None, 3], Model.objects.get(id=uuid).input_shape)
    model = get_entity(GlobalModel, pk=uuid)
    self.assertIsNotNone(model)
    self.assertIsNotNone(model.weights)
    self.assertIsNotNone(model.preprocessing)
    # Both stored blobs must round-trip back into executable torch modules.
    self.assertTrue(isinstance(model.get_torch_model(), torch.nn.Module))
    self.assertTrue(isinstance(model.get_preprocessing_torch_model(), torch.nn.Module))