Bases: TestCase
flowchart TD
fl_server_ai.tests.test_aggregation.AggregationTest[AggregationTest]
click fl_server_ai.tests.test_aggregation.AggregationTest href "" "fl_server_ai.tests.test_aggregation.AggregationTest"
Methods:
Source code in fl_server_ai/tests/test_aggregation.py
class AggregationTest(TestCase):
    """Tests for ``MeanAggregation`` applied to a set of TorchScript models."""

    def test_aggregate(self):
        """Equal sample sizes should yield the plain mean of the models.

        Ten models are built with indices 0..9 and each gets sample size 1.
        The expected value 4.5 is mean(0..9), which presumes that model ``i``
        has its checked weights filled with ``i``. TODO confirm against
        ``_create_torchscript_model_and_init``.
        """
        mean_aggregation = MeanAggregation()
        client_models = [_create_torchscript_model_and_init(idx) for idx in range(10)]
        aggregated = mean_aggregation.aggregate(client_models, [1] * 10)

        # A scripted module reports its Python class via ``original_name``;
        # eager modules expose it through ``__class__``.
        if is_torchscript_instance(aggregated):
            class_name = aggregated.original_name
        else:
            class_name = aggregated.__class__.__name__
        self.assertEqual(class_name, "Sequential")

        merged_state = aggregated.state_dict()
        self.assertEqual(len(client_models[0].state_dict()), len(merged_state))
        for key in ("0.weight", "3.weight"):
            torch.testing.assert_close(merged_state[key], torch.ones_like(merged_state[key]) * 4.5)

    def test_aggregate_sample_sizes(self):
        """Sample sizes should weight each model's contribution to the mean.

        With sample sizes ``[0, 1, 2]`` the expected value is
        ``(0*0 + 1*1 + 2*2) / (0 + 1 + 2) = 5/3``. This presumes that model
        ``i`` has its checked weights filled with ``i``. TODO confirm against
        ``_create_torchscript_model_and_init``.
        """
        mean_aggregation = MeanAggregation()
        client_models = [_create_torchscript_model_and_init(idx) for idx in range(3)]
        aggregated = mean_aggregation.aggregate(client_models, [0, 1, 2])

        if is_torchscript_instance(aggregated):
            class_name = aggregated.original_name
        else:
            class_name = aggregated.__class__.__name__
        self.assertEqual(class_name, "Sequential")

        self.assertEqual(
            len(list(client_models[0].parameters())),
            len(list(aggregated.parameters())),
        )
        merged_state = aggregated.state_dict()
        expected_value = 5 / 3
        for key in ("0.weight", "3.weight"):
            torch.testing.assert_close(merged_state[key], torch.ones_like(merged_state[key]) * expected_value)
|
Methods
test_aggregate
Source code in fl_server_ai/tests/test_aggregation.py
def test_aggregate(self):
    """Equal sample sizes should yield the plain mean of the models.

    Ten models are built with indices 0..9 and each gets sample size 1.
    The expected value 4.5 is mean(0..9), which presumes that model ``i``
    has its checked weights filled with ``i``. TODO confirm against
    ``_create_torchscript_model_and_init``.
    """
    mean_aggregation = MeanAggregation()
    client_models = [_create_torchscript_model_and_init(idx) for idx in range(10)]
    aggregated = mean_aggregation.aggregate(client_models, [1] * 10)

    # A scripted module reports its Python class via ``original_name``;
    # eager modules expose it through ``__class__``.
    if is_torchscript_instance(aggregated):
        class_name = aggregated.original_name
    else:
        class_name = aggregated.__class__.__name__
    self.assertEqual(class_name, "Sequential")

    merged_state = aggregated.state_dict()
    self.assertEqual(len(client_models[0].state_dict()), len(merged_state))
    for key in ("0.weight", "3.weight"):
        torch.testing.assert_close(merged_state[key], torch.ones_like(merged_state[key]) * 4.5)
|
test_aggregate_sample_sizes
test_aggregate_sample_sizes()
Source code in fl_server_ai/tests/test_aggregation.py
def test_aggregate_sample_sizes(self):
    """Sample sizes should weight each model's contribution to the mean.

    With sample sizes ``[0, 1, 2]`` the expected value is
    ``(0*0 + 1*1 + 2*2) / (0 + 1 + 2) = 5/3``. This presumes that model
    ``i`` has its checked weights filled with ``i``. TODO confirm against
    ``_create_torchscript_model_and_init``.
    """
    mean_aggregation = MeanAggregation()
    client_models = [_create_torchscript_model_and_init(idx) for idx in range(3)]
    aggregated = mean_aggregation.aggregate(client_models, [0, 1, 2])

    # A scripted module reports its Python class via ``original_name``;
    # eager modules expose it through ``__class__``.
    if is_torchscript_instance(aggregated):
        class_name = aggregated.original_name
    else:
        class_name = aggregated.__class__.__name__
    self.assertEqual(class_name, "Sequential")

    self.assertEqual(
        len(list(client_models[0].parameters())),
        len(list(aggregated.parameters())),
    )
    merged_state = aggregated.state_dict()
    expected_value = 5 / 3
    for key in ("0.weight", "3.weight"):
        torch.testing.assert_close(merged_state[key], torch.ones_like(merged_state[key]) * expected_value)
|