Skip to content

fl_server_ai.tests.test_aggregation

Classes:

Name | Description
AggregationTest | Unit tests for `MeanAggregation` aggregating TorchScript `Sequential` models.

Classes

AggregationTest

Bases: TestCase


              flowchart TD
              fl_server_ai.tests.test_aggregation.AggregationTest[AggregationTest]

              

              click fl_server_ai.tests.test_aggregation.AggregationTest href "" "fl_server_ai.tests.test_aggregation.AggregationTest"
            

Methods:

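Name | Description
test_aggregate | Aggregates ten models with equal sample sizes and checks the weights equal their plain mean (4.5).
test_aggregate_sample_sizes | Aggregates three models with sample sizes [0, 1, 2] and checks the weights equal the sample-size-weighted mean (5/3).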
Source code in fl_server_ai/tests/test_aggregation.py
class AggregationTest(TestCase):
    """Unit tests for ``MeanAggregation`` applied to ``Sequential`` models.

    Every test builds its input models with ``_create_torchscript_model_and_init(i)``.
    NOTE(review): the expected values below only hold if model ``i`` has every
    checked weight equal to ``i`` -- inferred from the asserted means, confirm
    against the helper.
    """

    def test_aggregate(self):
        """Uniform sample sizes: the aggregated weights equal the plain mean 4.5 of 0..9."""
        aggregator = MeanAggregation()
        num_models = 10
        clients = [_create_torchscript_model_and_init(idx) for idx in range(num_models)]
        merged = aggregator.aggregate(clients, [1] * num_models)

        # Scripted modules report their class via ``original_name``; eager ones via the type.
        if is_torchscript_instance(merged):
            merged_name = merged.original_name
        else:
            merged_name = merged.__class__.__name__
        self.assertEqual(merged_name, "Sequential")

        merged_state = merged.state_dict()
        # The aggregate must carry exactly as many state entries as an input model.
        self.assertEqual(len(clients[0].state_dict()), len(merged_state))
        expected = 4.5  # (0 + 1 + ... + 9) / 10
        for key in ("0.weight", "3.weight"):
            tensor = merged_state[key]
            torch.testing.assert_close(tensor, torch.ones_like(tensor) * expected)

    def test_aggregate_sample_sizes(self):
        """Non-uniform sample sizes: weights become the sample-size-weighted mean 5/3."""
        aggregator = MeanAggregation()
        clients = [_create_torchscript_model_and_init(idx) for idx in range(3)]
        sample_sizes = [0, 1, 2]
        merged = aggregator.aggregate(clients, sample_sizes)

        # Scripted modules report their class via ``original_name``; eager ones via the type.
        if is_torchscript_instance(merged):
            merged_name = merged.original_name
        else:
            merged_name = merged.__class__.__name__
        self.assertEqual(merged_name, "Sequential")

        # The aggregate must expose the same number of parameter tensors as an input model.
        self.assertEqual(len(list(clients[0].parameters())), len(list(merged.parameters())))
        merged_state = merged.state_dict()
        # (0*0 + 1*1 + 2*2) / (0 + 1 + 2): the model with sample size 0 contributes nothing.
        expected = 5 / 3
        for key in ("0.weight", "3.weight"):
            tensor = merged_state[key]
            torch.testing.assert_close(tensor, torch.ones_like(tensor) * expected)

Functions

test_aggregate
test_aggregate()
Source code in fl_server_ai/tests/test_aggregation.py
def test_aggregate(self):
    """Uniform sample sizes: the aggregated weights equal the plain mean 4.5 of 0..9.

    NOTE(review): assumes ``_create_torchscript_model_and_init(i)`` sets every
    checked weight to ``i`` -- inferred from the expected mean, confirm.
    """
    aggregator = MeanAggregation()
    num_models = 10
    clients = [_create_torchscript_model_and_init(idx) for idx in range(num_models)]
    merged = aggregator.aggregate(clients, [1] * num_models)

    # Scripted modules report their class via ``original_name``; eager ones via the type.
    if is_torchscript_instance(merged):
        merged_name = merged.original_name
    else:
        merged_name = merged.__class__.__name__
    self.assertEqual(merged_name, "Sequential")

    merged_state = merged.state_dict()
    # The aggregate must carry exactly as many state entries as an input model.
    self.assertEqual(len(clients[0].state_dict()), len(merged_state))
    expected = 4.5  # (0 + 1 + ... + 9) / 10
    for key in ("0.weight", "3.weight"):
        tensor = merged_state[key]
        torch.testing.assert_close(tensor, torch.ones_like(tensor) * expected)
test_aggregate_sample_sizes
test_aggregate_sample_sizes()
Source code in fl_server_ai/tests/test_aggregation.py
def test_aggregate_sample_sizes(self):
    """Non-uniform sample sizes: weights become the sample-size-weighted mean 5/3.

    NOTE(review): assumes ``_create_torchscript_model_and_init(i)`` sets every
    checked weight to ``i`` -- inferred from the expected mean, confirm.
    """
    aggregator = MeanAggregation()
    clients = [_create_torchscript_model_and_init(idx) for idx in range(3)]
    sample_sizes = [0, 1, 2]
    merged = aggregator.aggregate(clients, sample_sizes)

    # Scripted modules report their class via ``original_name``; eager ones via the type.
    if is_torchscript_instance(merged):
        merged_name = merged.original_name
    else:
        merged_name = merged.__class__.__name__
    self.assertEqual(merged_name, "Sequential")

    # The aggregate must expose the same number of parameter tensors as an input model.
    self.assertEqual(len(list(clients[0].parameters())), len(list(merged.parameters())))
    merged_state = merged.state_dict()
    # (0*0 + 1*1 + 2*2) / (0 + 1 + 2): the model with sample size 0 contributes nothing.
    expected = 5 / 3
    for key in ("0.weight", "3.weight"):
        tensor = merged_state[key]
        torch.testing.assert_close(tensor, torch.ones_like(tensor) * expected)