Skip to content

Module fl_server_core.tests

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from .dummy import BASE_URL, Dummy


# Public API of the test-helper package: re-export the fixture factory and
# BASE_URL from the `dummy` submodule so tests can import them from here.
__all__ = ["BASE_URL", "Dummy"]

Sub-modules

fl_server_core.tests.dummy

Variables

BASE_URL

Classes

Dummy

class Dummy(
    /,
    *args,
    **kwargs
)
View Source
class Dummy:
    """Factory helpers that create and save fixture objects for tests.

    Each ``create_*`` classmethod builds default field values (randomised via
    :attr:`fake` where applicable). Caller keyword arguments override those
    defaults, and the result is saved with the model manager's ``objects.create``.
    Related objects the caller does not supply (owners, base models,
    participants, ...) are created on demand through the other factories.
    """

    # Shared Faker generator for names, e-mails and phrases.
    fake = Faker("en_US")
    _logger = getLogger("fl.server.core")

    @classmethod
    def create_model(cls, model_cls: Type[Model] = GlobalModel, **kwargs) -> GlobalModel:
        """Create and save an instance of ``model_cls``.

        Defaults:
            - a random name and description;
            - ``round=0``;
            - weights produced by ``from_torch_module`` from a TorchScript
              ``Linear(3, 1) -> Sigmoid`` network;
            - ``input_shape=[None, 3]``, or ``None`` when the caller passes ``weights``.

        A new user is created as ``owner`` unless one is given.

        Args:
            model_cls: Model class whose manager creates the row.
            **kwargs: Field values overriding the defaults.

        Returns:
            The saved model instance.
        """
        torchscript_model = torch.jit.script(torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid()
        ))
        args = dict(
            name=f"{cls.fake.company()} Model",
            description=f"Model created for {cls.fake.catch_phrase()}.",
            round=0,
            weights=from_torch_module(torchscript_model),  # torchscript model
            # [None, 3] matches the default 3-feature network only; custom weights get no default shape.
            input_shape=None if "weights" in kwargs else [None, 3],
        )
        args.update(kwargs)
        if "owner" not in args:
            args["owner"] = cls.create_user()
        model = model_cls.objects.create(**args)
        cls._logger.debug(f"Creating Dummy Model with id {model.id}")
        return model

    @classmethod
    def create_model_update(cls, **kwargs) -> LocalModel:
        """Create and save a ``LocalModel`` (a client's model update).

        Defaults are ``round=1``, ``sample_size=10`` and weights from a fresh
        TorchScript ``Linear(3, 1) -> Sigmoid`` network. A new global model is
        created as ``base_model`` and a new client as ``owner`` unless given.

        Args:
            **kwargs: Field values overriding the defaults.

        Returns:
            The saved ``LocalModel``.
        """
        args = dict(
            round=1,
            weights=from_torch_module(torch.jit.script(torch.nn.Sequential(
                torch.nn.Linear(3, 1),
                torch.nn.Sigmoid()
            ))),
            sample_size=10,
        )
        args.update(kwargs)
        if "base_model" not in args:
            args["base_model"] = cls.create_model()
        if "owner" not in args:
            args["owner"] = cls.create_client()
        model = LocalModel.objects.create(**args)
        cls._logger.debug(f"Creating Dummy Model Update with id {model.id}")
        return model

    @classmethod
    def create_broken_model(cls, **kwargs):
        """Create a model whose ``weights`` are a pickled string instead of a torch module.

        Args:
            **kwargs: Forwarded to :meth:`create_model`; may override ``weights``.

        Returns:
            The saved model instance.
        """
        args = dict(weights=pickle.dumps("I am not a torch.nn.Module!"))
        args.update(kwargs)
        return cls.create_model(**args)

    @classmethod
    def create_training(cls, **kwargs):
        """Create and save a ``Training`` and attach its participants.

        Defaults:
            - state ``INITIAL``;
            - ``target_num_updates=0``;
            - no uncertainty method;
            - ``FED_AVG`` aggregation.

        If ``actor`` is given and ``model`` is not, the default model is owned by
        that actor. If ``actor`` is omitted, a new actor is created, and the
        default model gets its own newly created owner. ``participants`` defaults
        to two newly created clients.

        Args:
            **kwargs: Training field values overriding the defaults, plus the
                optional ``participants`` iterable (not a direct create argument).

        Returns:
            The saved ``Training``.
        """
        args = dict(
            state=TrainingState.INITIAL,
            target_num_updates=0,
            uncertainty_method=UncertaintyMethod.NONE,
            aggregation_method=AggregationMethod.FED_AVG,
        )
        args.update(kwargs)
        if "actor" in args:
            model_kwargs = {"owner": args["actor"]}
        else:
            args["actor"] = cls.create_actor()
            model_kwargs = {}

        if "model" not in args:
            args["model"] = cls.create_model(**model_kwargs)
        # Only build default participants when none were supplied: using the list as the
        # default of dict.pop() evaluated it eagerly and inserted two unused client users.
        if "participants" in args:
            participants = args.pop("participants")
        else:
            participants = [cls.create_client(), cls.create_client()]
        training = Training.objects.create(**args)
        cls._logger.debug(f"Creating Dummy Training with id {training.id}")
        for participant in participants:
            training.participants.add(participant)
        training.save()
        return training

    @classmethod
    def _create_user(cls, **kwargs) -> User:
        """Create a ``User`` from ``kwargs``, then store ``kwargs["password"]`` via ``set_password``.

        NOTE(review): ``objects.create`` first writes the raw password value;
        the second ``save()`` replaces it with whatever ``set_password`` stores.

        Raises:
            KeyError: If ``password`` is missing from ``kwargs``.
        """
        user = User.objects.create(**kwargs)
        user.set_password(kwargs["password"])
        user.save()
        cls._logger.debug(f"Creating Dummy User with id {user.id}")
        return user

    @classmethod
    def create_user(cls, **kwargs):
        """Create a user with random names/e-mail, password ``"secret"`` and both role flags off.

        Args:
            **kwargs: User field values overriding the defaults.

        Returns:
            The saved ``User``.
        """
        args = dict(
            message_endpoint="https://" + cls.fake.safe_email().replace("@", "."),
            actor=False,
            client=False,
            username=cls.fake.user_name(),
            first_name=cls.fake.first_name(),
            last_name=cls.fake.last_name(),
            email=cls.fake.safe_email(),
            password="secret",
        )
        args.update(kwargs)
        return cls._create_user(**args)

    @classmethod
    def create_actor(cls, **kwargs):
        """Create a user with ``actor=True`` and ``client=False`` (these override ``kwargs``)."""
        kwargs["actor"] = True
        kwargs["client"] = False
        return cls.create_user(**kwargs)

    @classmethod
    def create_client(cls, **kwargs):
        """Create a user with ``actor=False`` and ``client=True`` (these override ``kwargs``)."""
        kwargs["actor"] = False
        kwargs["client"] = True
        return cls.create_user(**kwargs)

    @classmethod
    def create_user_and_authenticate(cls, client: Client, **kwargs):
        """Create a user and set HTTP Basic credentials for it on the test ``client``.

        Only ``username`` (random) and ``password`` (``"secret"``) are defaulted.
        This calls ``_create_user`` directly, so the other ``create_user``
        defaults are not applied.

        Args:
            client: Test client whose ``defaults`` receive ``HTTP_AUTHORIZATION``.
            **kwargs: User field values overriding the defaults.

        Returns:
            The saved ``User``.
        """
        args = dict(
            username=cls.fake.user_name(),
            password="secret",
        )
        args.update(kwargs)
        credentials = base64.b64encode(
            f'{args["username"]}:{args["password"]}'.encode("utf-8")
        ).decode("utf-8")
        client.defaults["HTTP_AUTHORIZATION"] = "Basic " + credentials
        return cls._create_user(**args)

    @classmethod
    def create_group(cls, **kwargs):
        """Create and save a ``Group`` named after a random catch phrase unless ``name`` is given."""
        args = dict(
            name=cls.fake.catch_phrase(),
        )
        args.update(kwargs)
        group = Group.objects.create(**args)
        cls._logger.debug(f"Creating Dummy Group with id {group.id}")
        return group

    @classmethod
    def create_metric(cls, **kwargs):
        """Create and save a ``Metric`` with synthetic loss-like values.

        Defaults:
            - ``identifier``: the first (8-hex-digit) group of a UUID4;
            - ``key``: ``"NLLoss"``;
            - ``step``: a random epoch in [0, 100].

        When neither ``value_float`` nor ``value_binary`` is given, ``value_float``
        is a value that decreases with the sampled epoch, plus noise.
        A new model and a new client reporter are created unless given.

        Returns:
            The saved ``Metric``.
        """
        epoch = random.randint(0, 100)
        args = dict(
            identifier=str(uuid4()).split("-")[0],
            key="NLLoss",
            step=epoch,
        )
        args.update(kwargs)
        if "value_float" not in args and "value_binary" not in args:
            # Uses the sampled epoch even if the caller overrode ``step``.
            args["value_float"] = 27.354/(epoch+1) + random.randint(0, 100) / 100
        if "model" not in args:
            args["model"] = cls.create_model()
        if "reporter" not in args:
            args["reporter"] = cls.create_client()
        metric = Metric.objects.create(**args)
        cls._logger.debug(f"Creating Dummy Metric with id {metric.id}")
        return metric

Class variables

fake

Class methods

create_actor

def create_actor(
    **kwargs
)
View Source
    @classmethod
    def create_actor(cls, **kwargs):
        """Create a user that is an actor and not a client.

        The role flags are forced, overriding any ``actor``/``client`` values
        in ``kwargs``; everything else is forwarded to ``create_user``.
        """
        kwargs.update(actor=True, client=False)
        return cls.create_user(**kwargs)

create_broken_model

def create_broken_model(
    **kwargs
)
View Source
    @classmethod
    def create_broken_model(cls, **kwargs):
        """Create a model whose weights are a pickled string, deliberately not a torch module.

        Keyword arguments (including ``weights``) override the default and are
        forwarded to ``create_model``.
        """
        bogus_weights = pickle.dumps("I am not a torch.nn.Module!")
        return cls.create_model(**{"weights": bogus_weights, **kwargs})

create_client

def create_client(
    **kwargs
)
View Source
    @classmethod
    def create_client(cls, **kwargs):
        """Create a user that is a client and not an actor.

        The role flags are forced, overriding any ``actor``/``client`` values
        in ``kwargs``; everything else is forwarded to ``create_user``.
        """
        kwargs.update(actor=False, client=True)
        return cls.create_user(**kwargs)

create_group

def create_group(
    **kwargs
)
View Source
    @classmethod
    def create_group(cls, **kwargs):
        """Create and save a ``Group``; its name defaults to a random catch phrase."""
        fields = {"name": cls.fake.catch_phrase(), **kwargs}
        new_group = Group.objects.create(**fields)
        cls._logger.debug(f"Creating Dummy Group with id {new_group.id}")
        return new_group

create_metric

def create_metric(
    **kwargs
)
View Source
    @classmethod
    def create_metric(cls, **kwargs):
        """Create and save a ``Metric`` with synthetic values.

        Defaults are a short UUID-derived ``identifier``, ``key="NLLoss"`` and a
        random ``step``. A noisy, epoch-decreasing ``value_float`` is added when
        no value field is given. A new model and client reporter are created
        when not supplied.
        """
        sampled_epoch = random.randint(0, 100)
        fields = {
            "identifier": str(uuid4()).split("-")[0],
            "key": "NLLoss",
            "step": sampled_epoch,
            **kwargs,
        }
        has_value = "value_float" in fields or "value_binary" in fields
        if not has_value:
            fields["value_float"] = 27.354 / (sampled_epoch + 1) + random.randint(0, 100) / 100
        if "model" not in fields:
            fields["model"] = cls.create_model()
        if "reporter" not in fields:
            fields["reporter"] = cls.create_client()
        new_metric = Metric.objects.create(**fields)
        cls._logger.debug(f"Creating Dummy Metric with id {new_metric.id}")
        return new_metric

create_model

def create_model(
    model_cls: Type[fl_server_core.models.model.Model] = fl_server_core.models.model.GlobalModel,
    **kwargs
) -> fl_server_core.models.model.GlobalModel
View Source
    @classmethod
    def create_model(cls, model_cls: Type[Model] = GlobalModel, **kwargs) -> GlobalModel:
        """Create and save an instance of ``model_cls`` with random metadata.

        Default weights come from a TorchScript ``Linear(3, 1) -> Sigmoid``
        network. ``input_shape`` defaults to ``[None, 3]``, or ``None`` if the
        caller supplies ``weights``. An owner user is created when ``owner`` is
        absent.
        """
        network = torch.nn.Sequential(torch.nn.Linear(3, 1), torch.nn.Sigmoid())
        scripted = torch.jit.script(network)
        company = cls.fake.company()
        phrase = cls.fake.catch_phrase()
        caller_sets_weights = "weights" in kwargs
        fields = {
            "name": f"{company} Model",
            "description": f"Model created for {phrase}.",
            "round": 0,
            "weights": from_torch_module(scripted),
            "input_shape": None if caller_sets_weights else [None, 3],
            **kwargs,
        }
        if "owner" not in fields:
            fields["owner"] = cls.create_user()
        instance = model_cls.objects.create(**fields)
        cls._logger.debug(f"Creating Dummy Model with id {instance.id}")
        return instance

create_model_update

def create_model_update(
    **kwargs
) -> fl_server_core.models.model.LocalModel
View Source
    @classmethod
    def create_model_update(cls, **kwargs) -> LocalModel:
        """Create and save a ``LocalModel`` update.

        Defaults are ``round=1``, ``sample_size=10`` and weights from a fresh
        TorchScript ``Linear(3, 1) -> Sigmoid`` network. A base model and a
        client owner are created when not supplied.
        """
        scripted = torch.jit.script(torch.nn.Sequential(torch.nn.Linear(3, 1), torch.nn.Sigmoid()))
        fields = {
            "round": 1,
            "weights": from_torch_module(scripted),
            "sample_size": 10,
            **kwargs,
        }
        if "base_model" not in fields:
            fields["base_model"] = cls.create_model()
        if "owner" not in fields:
            fields["owner"] = cls.create_client()
        update = LocalModel.objects.create(**fields)
        cls._logger.debug(f"Creating Dummy Model Update with id {update.id}")
        return update

create_training

def create_training(
    **kwargs
)
View Source
    @classmethod
    def create_training(cls, **kwargs):
        """Create and save a ``Training`` and attach its participants.

        Defaults:
            - state ``INITIAL``;
            - ``target_num_updates=0``;
            - no uncertainty method;
            - ``FED_AVG`` aggregation.

        If ``actor`` is given and ``model`` is not, the default model is owned by
        that actor. If ``actor`` is omitted, a new actor is created, and the
        default model gets its own newly created owner. ``participants`` defaults
        to two newly created clients.

        Args:
            **kwargs: Training field values overriding the defaults, plus the
                optional ``participants`` iterable (not a direct create argument).

        Returns:
            The saved ``Training``.
        """
        args = dict(
            state=TrainingState.INITIAL,
            target_num_updates=0,
            uncertainty_method=UncertaintyMethod.NONE,
            aggregation_method=AggregationMethod.FED_AVG,
        )
        args.update(kwargs)
        if "actor" in args:
            model_kwargs = {"owner": args["actor"]}
        else:
            args["actor"] = cls.create_actor()
            model_kwargs = {}

        if "model" not in args:
            args["model"] = cls.create_model(**model_kwargs)
        # Only build default participants when none were supplied: using the list as the
        # default of dict.pop() evaluated it eagerly and inserted two unused client users.
        if "participants" in args:
            participants = args.pop("participants")
        else:
            participants = [cls.create_client(), cls.create_client()]
        training = Training.objects.create(**args)
        cls._logger.debug(f"Creating Dummy Training with id {training.id}")
        for participant in participants:
            training.participants.add(participant)
        training.save()
        return training

create_user

def create_user(
    **kwargs
)
View Source
    @classmethod
    def create_user(cls, **kwargs):
        """Create a user with random personal data and password ``"secret"``.

        Both the ``actor`` and ``client`` flags default to ``False``. Any keyword
        argument overrides the matching default.
        """
        faker = cls.fake
        endpoint_host = faker.safe_email().replace("@", ".")
        defaults = {
            "message_endpoint": f"https://{endpoint_host}",
            "actor": False,
            "client": False,
            "username": faker.user_name(),
            "first_name": faker.first_name(),
            "last_name": faker.last_name(),
            "email": faker.safe_email(),
            "password": "secret",
        }
        return cls._create_user(**{**defaults, **kwargs})

create_user_and_authenticate

def create_user_and_authenticate(
    client: django.test.client.Client,
    **kwargs
)
View Source
    @classmethod
    def create_user_and_authenticate(cls, client: Client, **kwargs):
        """Create a user and configure ``client`` to send its HTTP Basic credentials.

        Only ``username`` (random) and ``password`` (``"secret"``) are defaulted.
        The user is built with ``_create_user`` directly, so the other
        ``create_user`` defaults do not apply.
        """
        fields = {"username": cls.fake.user_name(), "password": "secret", **kwargs}
        raw_pair = f'{fields["username"]}:{fields["password"]}'.encode("utf-8")
        encoded_pair = base64.b64encode(raw_pair).decode("utf-8")
        client.defaults["HTTP_AUTHORIZATION"] = "Basic " + encoded_pair
        return cls._create_user(**fields)