Skip to content

Module fl_server_core.models.training

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from django.core.serializers.json import DjangoJSONEncoder
from django.db import models
from django.db.models import (
    CASCADE, CharField, ForeignKey, IntegerField, JSONField, ManyToManyField,
    OneToOneField, UUIDField, BooleanField
)
from django.db.models.signals import post_save
from django.dispatch import receiver
from django.utils.translation import gettext_lazy as _
from uuid import uuid4

from .model import Model
from .user import User


class TrainingState(models.TextChoices):
    """
    Choices for the `state` of a Training.

    Every member pairs the single-character code that is stored in the database
    with a translatable, human-readable label.
    """

    # member = (stored code, display label)
    INITIAL = ("I", _("Initial"))
    ONGOING = ("O", _("Ongoing"))
    COMPLETED = ("C", _("Completed"))
    ERROR = ("E", _("Error"))
    SWAG_ROUND = ("S", _("SwagRound"))


class AggregationMethod(models.TextChoices):
    """
    Choices for the aggregation method of a Training.

    For every member the stored value and the (translatable) display label
    use the same text.
    """

    # member = (stored value, display label)
    FED_AVG = ("FedAvg", _("FedAvg"))
    FED_DC = ("FedDC", _("FedDC"))
    FED_PROX = ("FedProx", _("FedProx"))


class UncertaintyMethod(models.TextChoices):
    """
    Choices for the uncertainty method of a Training.

    Stored values are upper-case identifiers; the labels are translatable
    display names. `NONE` means no uncertainty method is used.
    """

    # member = (stored value, display label)
    NONE = ("NONE", _("None"))
    ENSEMBLE = ("ENSEMBLE", _("Ensemble"))
    MC_DROPOUT = ("MC_DROPOUT", _("MC Dropout"))
    SWAG = ("SWAG", _("SWAG"))


class Training(models.Model):
    """
    Training model class.

    A training ties exactly one `Model` to the user acting as its `actor` and to a set
    of participating users, and stores its state, the target number of updates, the
    uncertainty and aggregation methods in use, and free-form JSON `options`.
    """

    id: UUIDField = UUIDField(primary_key=True, editable=False, default=uuid4)
    """Unique identifier for the training."""
    model: OneToOneField = OneToOneField(Model, on_delete=CASCADE)
    """Model used in the training."""
    # NOTE(review): the reverse accessor `user.actors` yields the *trainings* a user is the
    # actor of. The name is misleading but is kept, because renaming it would break existing
    # reverse lookups and require a migration.
    actor: ForeignKey = ForeignKey(User, on_delete=CASCADE, related_name="actors")
    """User who is the actor of the training."""
    participants: ManyToManyField = ManyToManyField(User)
    """Users who are the participants of the training."""
    state: CharField = CharField(max_length=1, choices=TrainingState.choices)
    """State of the training."""
    target_num_updates: IntegerField = IntegerField()
    """Target number of updates for the training."""
    # `auto_now=True`: Django refreshes this timestamp on `save()` (unless it is left out of
    # an explicit `update_fields` list).
    last_update: models.DateTimeField = models.DateTimeField(auto_now=True)
    """Time of the last update."""
    uncertainty_method: CharField = CharField(
        max_length=32, choices=UncertaintyMethod.choices, default=UncertaintyMethod.NONE
    )
    """Uncertainty method used in the training."""
    aggregation_method: CharField = CharField(
        max_length=32, choices=AggregationMethod.choices, default=AggregationMethod.FED_AVG
    )
    """Aggregation method used in the training."""
    # HINT: https://docs.djangoproject.com/en/4.2/topics/db/queries/#querying-jsonfield
    # Free-form settings; e.g. `daisy_chain_period` is read by `post_save_training`.
    options: JSONField = JSONField(default=dict, encoder=DjangoJSONEncoder)
    """Options for the training."""
    locked: BooleanField = BooleanField(default=False)
    """Flag indicating whether the training is locked."""


@receiver(post_save, sender=Training)
def post_save_training(sender, instance=None, created=False, *args, **kwargs):
    """
    Ensure that the correct `target_num_updates` is set for every new training.

    This method is called after saving a training instance.
    If the training instance is newly created and daisy chaining is enabled
    (`options["daisy_chain_period"] > 0`), `target_num_updates` is multiplied by the
    daisy chain period and the instance is saved again.

    Raw saves (fixture loading) store the instance exactly as presented and are skipped.
    Otherwise re-loading dumped data would scale an already-scaled value a second time.

    Args:
        sender: The model class.
        instance (Training, optional): The actual instance being saved. Defaults to None.
        created (bool, optional): A boolean; True if a new record was created. Defaults to False.
        *args: Additional arguments.
        **kwargs: Arbitrary keyword arguments (Django passes e.g. `raw`, `using`, `update_fields`).
    """
    # Only freshly inserted rows need adjusting; never touch raw (fixture) saves.
    if not created or kwargs.get("raw", False) or instance is None:
        return
    # A missing key and an explicit JSON `null` both mean "daisy chaining disabled".
    daisy_chain_period = instance.options.get("daisy_chain_period") or 0
    if daisy_chain_period <= 0:
        return
    instance.target_num_updates = instance.target_num_updates * daisy_chain_period
    # Persist only the adjusted column. This nested save fires post_save again, but with
    # created=False, so this handler returns early and does not recurse.
    instance.save(update_fields=["target_num_updates"])

Functions

post_save_training

def post_save_training(
    sender,
    instance=None,
    created=False,
    *args,
    **kwargs
)

Ensure that the correct target_num_updates is set for every new training.

This method is called after saving a training instance. It is used to set the target_num_updates to the correct value if this training instance is newly created and daisy chaining is enabled.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| sender | — | The model class. | — |
| instance | Training | The actual instance being saved. | None |
| created | bool | A boolean; True if a new record was created. | False |
| *args | — | Additional arguments. | — |
| **kwargs | — | Arbitrary keyword arguments. | — |
View Source
@receiver(post_save, sender=Training)
def post_save_training(sender, instance=None, created=False, *args, **kwargs):
    """
    Ensure that the correct `target_num_updates` is set for every new training.

    This method is called after saving a training instance.
    If the training instance is newly created and daisy chaining is enabled
    (`options["daisy_chain_period"] > 0`), `target_num_updates` is multiplied by the
    daisy chain period and the instance is saved again.

    Raw saves (fixture loading) store the instance exactly as presented and are skipped.
    Otherwise re-loading dumped data would scale an already-scaled value a second time.

    Args:
        sender: The model class.
        instance (Training, optional): The actual instance being saved. Defaults to None.
        created (bool, optional): A boolean; True if a new record was created. Defaults to False.
        *args: Additional arguments.
        **kwargs: Arbitrary keyword arguments (Django passes e.g. `raw`, `using`, `update_fields`).
    """
    # Only freshly inserted rows need adjusting; never touch raw (fixture) saves.
    if not created or kwargs.get("raw", False) or instance is None:
        return
    # A missing key and an explicit JSON `null` both mean "daisy chaining disabled".
    daisy_chain_period = instance.options.get("daisy_chain_period") or 0
    if daisy_chain_period <= 0:
        return
    instance.target_num_updates = instance.target_num_updates * daisy_chain_period
    # Persist only the adjusted column. This nested save fires post_save again, but with
    # created=False, so this handler returns early and does not recurse.
    instance.save(update_fields=["target_num_updates"])

Classes

AggregationMethod

class AggregationMethod(
    /,
    *args,
    **kwargs
)

Aggregation method choices for a Training.

View Source
class AggregationMethod(models.TextChoices):
    """
    Choices for the aggregation method of a Training.

    For every member the stored value and the (translatable) display label
    use the same text.
    """

    # member = (stored value, display label)
    FED_AVG = ("FedAvg", _("FedAvg"))
    FED_DC = ("FedDC", _("FedDC"))
    FED_PROX = ("FedProx", _("FedProx"))

Ancestors (in MRO)

  • django.db.models.enums.TextChoices
  • builtins.str
  • django.db.models.enums.Choices
  • enum.Enum

Class variables

FED_AVG
FED_DC
FED_PROX

Training

class Training(
    *args,
    **kwargs
)

Training model class.

View Source
class Training(models.Model):
    """
    Training model class.

    A training ties exactly one `Model` to the user acting as its `actor` and to a set
    of participating users, and stores its state, the target number of updates, the
    uncertainty and aggregation methods in use, and free-form JSON `options`.
    """

    id: UUIDField = UUIDField(primary_key=True, editable=False, default=uuid4)
    """Unique identifier for the training."""
    model: OneToOneField = OneToOneField(Model, on_delete=CASCADE)
    """Model used in the training."""
    # NOTE(review): the reverse accessor `user.actors` yields the *trainings* a user is the
    # actor of. The name is misleading but is kept, because renaming it would break existing
    # reverse lookups and require a migration.
    actor: ForeignKey = ForeignKey(User, on_delete=CASCADE, related_name="actors")
    """User who is the actor of the training."""
    participants: ManyToManyField = ManyToManyField(User)
    """Users who are the participants of the training."""
    state: CharField = CharField(max_length=1, choices=TrainingState.choices)
    """State of the training."""
    target_num_updates: IntegerField = IntegerField()
    """Target number of updates for the training."""
    # `auto_now=True`: Django refreshes this timestamp on `save()` (unless it is left out of
    # an explicit `update_fields` list).
    last_update: models.DateTimeField = models.DateTimeField(auto_now=True)
    """Time of the last update."""
    uncertainty_method: CharField = CharField(
        max_length=32, choices=UncertaintyMethod.choices, default=UncertaintyMethod.NONE
    )
    """Uncertainty method used in the training."""
    aggregation_method: CharField = CharField(
        max_length=32, choices=AggregationMethod.choices, default=AggregationMethod.FED_AVG
    )
    """Aggregation method used in the training."""
    # HINT: https://docs.djangoproject.com/en/4.2/topics/db/queries/#querying-jsonfield
    # Free-form settings; e.g. `daisy_chain_period` is read by `post_save_training`.
    options: JSONField = JSONField(default=dict, encoder=DjangoJSONEncoder)
    """Options for the training."""
    locked: BooleanField = BooleanField(default=False)
    """Flag indicating whether the training is locked."""

Ancestors (in MRO)

  • django.db.models.base.Model

Class variables

DoesNotExist
MultipleObjectsReturned
actor
actor_id
last_update

Time of the last update.

model
model_id
objects
participants

Static methods

check

def check(
    **kwargs
)
View Source
    @classmethod
    def check(cls, **kwargs):
        """
        Run Django's model system checks for this class and return the collected messages.

        Inherited from ``django.db.models.base.Model`` and rendered here by the doc generator.
        ``kwargs`` is forwarded to the manager and field checks. Its optional ``"databases"``
        entry is passed on to the long-column-name, index and constraint checks.
        """
        errors = [
            *cls._check_swappable(),
            *cls._check_model(),
            *cls._check_managers(**kwargs),
        ]
        # All remaining checks are skipped for swapped-out models.
        if not cls._meta.swapped:
            databases = kwargs.get("databases") or []
            errors += [
                *cls._check_fields(**kwargs),
                *cls._check_m2m_through_same_relationship(),
                *cls._check_long_column_names(databases),
            ]
            clash_errors = (
                *cls._check_id_field(),
                *cls._check_field_name_clashes(),
                *cls._check_model_name_db_lookup_clashes(),
                *cls._check_property_name_related_field_accessor_clashes(),
                *cls._check_single_primary_key(),
            )
            errors.extend(clash_errors)
            # If there are field name clashes, hide consequent column name
            # clashes.
            if not clash_errors:
                errors.extend(cls._check_column_name_clashes())
            errors += [
                *cls._check_index_together(),
                *cls._check_unique_together(),
                *cls._check_indexes(databases),
                *cls._check_ordering(),
                *cls._check_constraints(databases),
                *cls._check_default_pk(),
            ]

        return errors

from_db

def from_db(
    db,
    field_names,
    values
)
View Source
    @classmethod
    def from_db(cls, db, field_names, values):
        """
        Build an instance from a row loaded from database ``db``.

        ``values`` may hold fewer entries than there are concrete fields (presumably when a
        query loaded only some columns). In that case each concrete field whose attname is
        not in ``field_names`` gets the ``DEFERRED`` placeholder, so ``cls(*values)`` still
        receives one positional value per concrete field.
        """
        if len(values) != len(cls._meta.concrete_fields):
            values_iter = iter(values)
            values = [
                next(values_iter) if f.attname in field_names else DEFERRED
                for f in cls._meta.concrete_fields
            ]
        new = cls(*values)
        # The instance mirrors an existing row: it is not being added, and it is bound to `db`.
        new._state.adding = False
        new._state.db = db
        return new

Instance variables

pk

Methods

aggregation_method

def aggregation_method(
    ...
)

A wrapper for a deferred-loading field. When the value is read from this object the first time, the query is executed.

clean

def clean(
    self
)

Hook for doing any extra model-wide validation after clean() has been called on every field by self.clean_fields. Any ValidationError raised by this method will not be associated with a particular field; it will have a special-case association with the field defined by NON_FIELD_ERRORS.

View Source
    def clean(self):
        """
        Hook for doing any extra model-wide validation after clean() has been
        called on every field by self.clean_fields. Any ValidationError raised
        by this method will not be associated with a particular field; it will
        have a special-case association with the field defined by NON_FIELD_ERRORS.
        """
        # Base implementation performs no validation; subclasses override it to add some.
        pass

clean_fields

def clean_fields(
    self,
    exclude=None
)

Clean all fields and raise a ValidationError containing a dict of all validation errors if any occur.

View Source
    def clean_fields(self, exclude=None):
        """
        Clean all fields and raise a ValidationError containing a dict
        of all validation errors if any occur.

        ``exclude`` is a container of field names to skip (``None`` means skip nothing).
        Each successfully cleaned value is written back onto the instance.
        """
        if exclude is None:
            exclude = []

        # Field name -> list of errors raised while cleaning that field.
        errors = {}
        for f in self._meta.fields:
            if f.name in exclude:
                continue
            # Skip validation for empty fields with blank=True. The developer
            # is responsible for making sure they have a valid value.
            raw_value = getattr(self, f.attname)
            if f.blank and raw_value in f.empty_values:
                continue
            try:
                setattr(self, f.attname, f.clean(raw_value, self))
            except ValidationError as e:
                errors[f.name] = e.error_list

        if errors:
            raise ValidationError(errors)

date_error_message

def date_error_message(
    self,
    lookup_type,
    field_name,
    unique_for
)
View Source
    def date_error_message(self, lookup_type, field_name, unique_for):
        """
        Build (not raise) the ValidationError for a failed ``unique_for_*`` date check.

        The message comes from the field's ``"unique_for_date"`` error message and is
        parametrised with the model/field names and labels, ``lookup_type`` and the
        ``unique_for`` date field and its label.
        """
        opts = self._meta
        field = opts.get_field(field_name)
        return ValidationError(
            message=field.error_messages["unique_for_date"],
            code="unique_for_date",
            params={
                "model": self,
                "model_name": capfirst(opts.verbose_name),
                "lookup_type": lookup_type,
                "field": field_name,
                "field_label": capfirst(field.verbose_name),
                "date_field": unique_for,
                "date_field_label": capfirst(opts.get_field(unique_for).verbose_name),
            },
        )

delete

def delete(
    self,
    using=None,
    keep_parents=False
)
View Source
    def delete(self, using=None, keep_parents=False):
        """
        Delete this instance, plus whatever related objects the ``Collector`` gathers.

        Raises ValueError if the instance has no primary key. ``using`` defaults to the
        router's write database for this instance. Returns the result of ``collector.delete()``.
        """
        if self.pk is None:
            raise ValueError(
                "%s object can't be deleted because its %s attribute is set "
                "to None." % (self._meta.object_name, self._meta.pk.attname)
            )
        using = using or router.db_for_write(self.__class__, instance=self)
        collector = Collector(using=using)
        collector.collect([self], keep_parents=keep_parents)
        return collector.delete()

full_clean

def full_clean(
    self,
    exclude=None,
    validate_unique=True
)

Call clean_fields(), clean(), and validate_unique() on the model. Raise a ValidationError for any errors that occur.

View Source
    def full_clean(self, exclude=None, validate_unique=True):
        """
        Call clean_fields(), clean(), and validate_unique() on the model.
        Raise a ValidationError for any errors that occur.
        """
        errors = {}
        # Copy `exclude` into a fresh list so the appends below never mutate the caller's object.
        if exclude is None:
            exclude = []
        else:
            exclude = list(exclude)

        try:
            self.clean_fields(exclude=exclude)
        except ValidationError as e:
            errors = e.update_error_dict(errors)

        # Form.clean() is run even if other validation fails, so do the
        # same with Model.clean() for consistency.
        try:
            self.clean()
        except ValidationError as e:
            errors = e.update_error_dict(errors)

        # Run unique checks, but only for fields that passed validation.
        if validate_unique:
            for name in errors:
                if name != NON_FIELD_ERRORS and name not in exclude:
                    exclude.append(name)
            try:
                self.validate_unique(exclude=exclude)
            except ValidationError as e:
                errors = e.update_error_dict(errors)

        if errors:
            raise ValidationError(errors)

get_aggregation_method_display

def get_aggregation_method_display(
    self,
    *,
    field=<django.db.models.fields.CharField: aggregation_method>
)
View Source
        def _method(cls_or_self, /, *args, **keywords):
            # Appears to be `functools.partialmethod`'s internal wrapper, rendered in place of
            # get_aggregation_method_display: it merges the pre-bound keywords (e.g. `field=`)
            # with the call-time ones and forwards everything to the wrapped function.
            keywords = {**self.keywords, **keywords}
            return self.func(cls_or_self, *self.args, *args, **keywords)

get_deferred_fields

def get_deferred_fields(
    self
)

Return a set containing names of deferred fields for this instance.

View Source
    def get_deferred_fields(self):
        """
        Return a set containing names of deferred fields for this instance.
        """
        # A concrete field counts as deferred when its attname has no entry in the
        # instance __dict__, i.e. its value was never loaded or assigned.
        return {
            f.attname
            for f in self._meta.concrete_fields
            if f.attname not in self.__dict__
        }

get_next_by_last_update

def get_next_by_last_update(
    self,
    *,
    field=<django.db.models.fields.DateTimeField: last_update>,
    is_next=True,
    **kwargs
)
View Source
        def _method(cls_or_self, /, *args, **keywords):
            # Appears to be `functools.partialmethod`'s internal wrapper, rendered in place of
            # get_next_by_last_update: it merges the pre-bound keywords (`field=`, `is_next=True`)
            # with the call-time ones and forwards everything to the wrapped function.
            keywords = {**self.keywords, **keywords}
            return self.func(cls_or_self, *self.args, *args, **keywords)

get_previous_by_last_update

def get_previous_by_last_update(
    self,
    *,
    field=<django.db.models.fields.DateTimeField: last_update>,
    is_next=False,
    **kwargs
)
View Source
        def _method(cls_or_self, /, *args, **keywords):
            # Appears to be `functools.partialmethod`'s internal wrapper, rendered in place of
            # get_previous_by_last_update: it merges the pre-bound keywords (`field=`,
            # `is_next=False`) with the call-time ones and forwards to the wrapped function.
            keywords = {**self.keywords, **keywords}
            return self.func(cls_or_self, *self.args, *args, **keywords)

get_state_display

def get_state_display(
    self,
    *,
    field=<django.db.models.fields.CharField: state>
)
View Source
        def _method(cls_or_self, /, *args, **keywords):
            # Appears to be `functools.partialmethod`'s internal wrapper, rendered in place of
            # get_state_display: it merges the pre-bound keywords (e.g. `field=`) with the
            # call-time ones and forwards everything to the wrapped function.
            keywords = {**self.keywords, **keywords}
            return self.func(cls_or_self, *self.args, *args, **keywords)

get_uncertainty_method_display

def get_uncertainty_method_display(
    self,
    *,
    field=<django.db.models.fields.CharField: uncertainty_method>
)
View Source
        def _method(cls_or_self, /, *args, **keywords):
            # Appears to be `functools.partialmethod`'s internal wrapper, rendered in place of
            # get_uncertainty_method_display: it merges the pre-bound keywords (e.g. `field=`)
            # with the call-time ones and forwards everything to the wrapped function.
            keywords = {**self.keywords, **keywords}
            return self.func(cls_or_self, *self.args, *args, **keywords)

id

def id(
    ...
)

A wrapper for a deferred-loading field. When the value is read from this object the first time, the query is executed.

locked

def locked(
    ...
)

A wrapper for a deferred-loading field. When the value is read from this object the first time, the query is executed.

options

def options(
    ...
)

A wrapper for a deferred-loading field. When the value is read from this object the first time, the query is executed.

prepare_database_save

def prepare_database_save(
    self,
    field
)
View Source
    def prepare_database_save(self, field):
        """
        Return the value that represents this instance when it is used in an ORM query.

        ``field`` is a relation field; the returned value is this instance's attribute for
        the related field's attname. Raises ValueError if the instance is unsaved (no pk).
        """
        if self.pk is None:
            raise ValueError(
                "Unsaved model instance %r cannot be used in an ORM query." % self
            )
        return getattr(self, field.remote_field.get_related_field().attname)

refresh_from_db

def refresh_from_db(
    self,
    using=None,
    fields=None
)

Reload field values from the database.

By default, the reloading happens from the database this instance was loaded from, or by the read router if this instance wasn't loaded from any database. The using parameter will override the default.

Fields can be used to specify which fields to reload. The fields should be an iterable of field attnames. If fields is None, then all non-deferred fields are reloaded.

When accessing deferred fields of an instance, the deferred loading of the field will call this method.

View Source
    def refresh_from_db(self, using=None, fields=None):
        """
        Reload field values from the database.

        By default, the reloading happens from the database this instance was
        loaded from, or by the read router if this instance wasn't loaded from
        any database. The using parameter will override the default.

        Fields can be used to specify which fields to reload. The fields
        should be an iterable of field attnames. If fields is None, then
        all non-deferred fields are reloaded.

        When accessing deferred fields of an instance, the deferred loading
        of the field will call this method.
        """
        if fields is None:
            # A full refresh also drops every prefetched relation cache.
            self._prefetched_objects_cache = {}
        else:
            prefetched_objects_cache = getattr(self, "_prefetched_objects_cache", ())
            # NOTE(review): `fields.remove(field)` mutates the list while it is being iterated,
            # so the element right after a removed one is skipped. It also mutates the caller's
            # argument and fails for tuples. Verify against newer Django releases, which
            # iterate over a copy here.
            for field in fields:
                if field in prefetched_objects_cache:
                    del prefetched_objects_cache[field]
                    fields.remove(field)
            if not fields:
                return
            if any(LOOKUP_SEP in f for f in fields):
                raise ValueError(
                    'Found "%s" in fields argument. Relations and transforms '
                    "are not allowed in fields." % LOOKUP_SEP
                )

        hints = {"instance": self}
        db_instance_qs = self.__class__._base_manager.db_manager(
            using, hints=hints
        ).filter(pk=self.pk)

        # Use provided fields, if not set then reload all non-deferred fields.
        deferred_fields = self.get_deferred_fields()
        if fields is not None:
            fields = list(fields)
            db_instance_qs = db_instance_qs.only(*fields)
        elif deferred_fields:
            fields = [
                f.attname
                for f in self._meta.concrete_fields
                if f.attname not in deferred_fields
            ]
            db_instance_qs = db_instance_qs.only(*fields)

        db_instance = db_instance_qs.get()
        non_loaded_fields = db_instance.get_deferred_fields()
        for field in self._meta.concrete_fields:
            if field.attname in non_loaded_fields:
                # This field wasn't refreshed - skip ahead.
                continue
            setattr(self, field.attname, getattr(db_instance, field.attname))
            # Clear cached foreign keys.
            if field.is_relation and field.is_cached(self):
                field.delete_cached_value(self)

        # Clear cached relations.
        for field in self._meta.related_objects:
            if field.is_cached(self):
                field.delete_cached_value(self)

        self._state.db = db_instance._state.db

save

def save(
    self,
    force_insert=False,
    force_update=False,
    using=None,
    update_fields=None
)

Save the current instance. Override this in a subclass if you want to control the saving process.

The 'force_insert' and 'force_update' parameters can be used to insist that the "save" must be an SQL insert or update (or equivalent for non-SQL backends), respectively. Normally, they should not be set.

View Source
    def save(
        self, force_insert=False, force_update=False, using=None, update_fields=None
    ):
        """
        Save the current instance. Override this in a subclass if you want to
        control the saving process.

        The 'force_insert' and 'force_update' parameters can be used to insist
        that the "save" must be an SQL insert or update (or equivalent for
        non-SQL backends), respectively. Normally, they should not be set.

        ``update_fields`` restricts the write to the named non-primary-key concrete
        fields; an empty iterable makes the call a no-op. Raises ValueError when an insert
        is forced together with an update, or when ``update_fields`` names unknown fields.
        """
        # Helper body not shown here; presumably validates related objects assigned to self.
        self._prepare_related_fields_for_save(operation_name="save")

        using = using or router.db_for_write(self.__class__, instance=self)
        if force_insert and (force_update or update_fields):
            raise ValueError("Cannot force both insert and updating in model saving.")

        deferred_fields = self.get_deferred_fields()
        if update_fields is not None:
            # If update_fields is empty, skip the save. We do also check for
            # no-op saves later on for inheritance cases. This bailout is
            # still needed for skipping signal sending.
            if not update_fields:
                return

            update_fields = frozenset(update_fields)
            field_names = set()

            # Both `name` and `attname` (e.g. `actor` and `actor_id`) are accepted.
            for field in self._meta.concrete_fields:
                if not field.primary_key:
                    field_names.add(field.name)

                    if field.name != field.attname:
                        field_names.add(field.attname)

            non_model_fields = update_fields.difference(field_names)

            if non_model_fields:
                raise ValueError(
                    "The following fields do not exist in this model, are m2m "
                    "fields, or are non-concrete fields: %s"
                    % ", ".join(non_model_fields)
                )

        # If saving to the same database, and this model is deferred, then
        # automatically do an "update_fields" save on the loaded fields.
        elif not force_insert and deferred_fields and using == self._state.db:
            field_names = set()
            for field in self._meta.concrete_fields:
                if not field.primary_key and not hasattr(field, "through"):
                    field_names.add(field.attname)
            loaded_fields = field_names.difference(deferred_fields)
            if loaded_fields:
                update_fields = frozenset(loaded_fields)

        self.save_base(
            using=using,
            force_insert=force_insert,
            force_update=force_update,
            update_fields=update_fields,
        )

save_base

def save_base(
    self,
    raw=False,
    force_insert=False,
    force_update=False,
    using=None,
    update_fields=None
)

Handle the parts of saving which should be done only once per save, yet need to be done in raw saves, too. This includes some sanity checks and signal sending.

The 'raw' argument is telling save_base not to save any parent models and not to do any changes to the values before save. This is used by fixture loading.

View Source
    def save_base(
        self,
        raw=False,
        force_insert=False,
        force_update=False,
        using=None,
        update_fields=None,
    ):
        """
        Handle the parts of saving which should be done only once per save,
        yet need to be done in raw saves, too. This includes some sanity
        checks and signal sending.

        The 'raw' argument is telling save_base not to save any parent
        models and not to do any changes to the values before save. This
        is used by fixture loading.
        """
        using = using or router.db_for_write(self.__class__, instance=self)
        assert not (force_insert and (force_update or update_fields))
        assert update_fields is None or update_fields
        cls = origin = self.__class__
        # Skip proxies, but keep the origin as the proxy model.
        if cls._meta.proxy:
            cls = cls._meta.concrete_model
        meta = cls._meta
        # Signals are sent with the original (possibly proxy) class as sender.
        if not meta.auto_created:
            pre_save.send(
                sender=origin,
                instance=self,
                raw=raw,
                using=using,
                update_fields=update_fields,
            )
        # A transaction isn't needed if one query is issued.
        if meta.parents:
            context_manager = transaction.atomic(using=using, savepoint=False)
        else:
            context_manager = transaction.mark_for_rollback_on_error(using=using)
        with context_manager:
            parent_inserted = False
            if not raw:
                parent_inserted = self._save_parents(cls, using, update_fields)
            updated = self._save_table(
                raw,
                cls,
                force_insert or parent_inserted,
                force_update,
                using,
                update_fields,
            )
        # Store the database on which the object was saved
        self._state.db = using
        # Once saved, this is no longer a to-be-added instance.
        self._state.adding = False

        # Signal that the save is complete
        # (`created` is derived as `not updated`; `raw` is forwarded to receivers.)
        if not meta.auto_created:
            post_save.send(
                sender=origin,
                instance=self,
                created=(not updated),
                update_fields=update_fields,
                raw=raw,
                using=using,
            )

serializable_value

def serializable_value(
    self,
    field_name
)

Return the value of the field name for this instance. If the field is a foreign key, return the id value instead of the object. If there's no Field object with this name on the model, return the model attribute's value.

Used to serialize a field's value (in the serializer, or form output, for example). Normally, you would just access the attribute directly and not use this method.

View Source
    def serializable_value(self, field_name):
        """
        Return the value of the field name for this instance. If the field is
        a foreign key, return the id value instead of the object. If there's
        no Field object with this name on the model, return the model
        attribute's value.

        Used to serialize a field's value (in the serializer, or form output,
        for example). Normally, you would just access the attribute directly
        and not use this method.
        """
        try:
            field = self._meta.get_field(field_name)
        except FieldDoesNotExist:
            # Not a model field: fall back to a plain attribute lookup.
            return getattr(self, field_name)
        # `attname` is the raw column attribute (e.g. `actor_id` rather than `actor`).
        return getattr(self, field.attname)

state

def state(
    ...
)

A wrapper for a deferred-loading field. When the value is read from this object the first time, the query is executed.

target_num_updates

def target_num_updates(
    ...
)

A wrapper for a deferred-loading field. When the value is read from this object the first time, the query is executed.

uncertainty_method

def uncertainty_method(
    ...
)

A wrapper for a deferred-loading field. When the value is read from this object the first time, the query is executed.

unique_error_message

def unique_error_message(
    self,
    model_class,
    unique_check
)
View Source
    def unique_error_message(self, model_class, unique_check):
        """
        Build (not raise) the ValidationError for a failed uniqueness check.

        ``unique_check`` is the sequence of field names that was checked. A single name
        uses the field's own ``"unique"`` message; several names produce a generic
        "already exists" message with code ``"unique_together"``.
        """
        opts = model_class._meta

        params = {
            "model": self,
            "model_class": model_class,
            "model_name": capfirst(opts.verbose_name),
            "unique_check": unique_check,
        }

        # A unique field
        if len(unique_check) == 1:
            field = opts.get_field(unique_check[0])
            params["field_label"] = capfirst(field.verbose_name)
            return ValidationError(
                message=field.error_messages["unique"],
                code="unique",
                params=params,
            )

        # unique_together
        else:
            field_labels = [
                capfirst(opts.get_field(f).verbose_name) for f in unique_check
            ]
            params["field_labels"] = get_text_list(field_labels, _("and"))
            return ValidationError(
                message=_("%(model_name)s with this %(field_labels)s already exists."),
                code="unique_together",
                params=params,
            )

validate_unique

def validate_unique(
    self,
    exclude=None
)

Check unique constraints on the model and raise ValidationError if any failed.

View Source
    def validate_unique(self, exclude=None):
        """
        Check unique constraints on the model and raise ValidationError if any
        failed.
        """
        unique_checks, date_checks = self._get_unique_checks(exclude=exclude)

        errors = self._perform_unique_checks(unique_checks)
        date_errors = self._perform_date_checks(date_checks)

        # Merge the date-based errors into the same field -> error-list mapping.
        for k, v in date_errors.items():
            errors.setdefault(k, []).extend(v)

        if errors:
            raise ValidationError(errors)

TrainingState

class TrainingState(
    /,
    *args,
    **kwargs
)

Training state choices for a Training.

View Source
class TrainingState(models.TextChoices):
    """
    Choices for the `state` of a Training.

    Every member pairs the single-character code that is stored in the database
    with a translatable, human-readable label.
    """

    # member = (stored code, display label)
    INITIAL = ("I", _("Initial"))
    ONGOING = ("O", _("Ongoing"))
    COMPLETED = ("C", _("Completed"))
    ERROR = ("E", _("Error"))
    SWAG_ROUND = ("S", _("SwagRound"))

Ancestors (in MRO)

  • django.db.models.enums.TextChoices
  • builtins.str
  • django.db.models.enums.Choices
  • enum.Enum

Class variables

COMPLETED
ERROR
INITIAL
ONGOING
SWAG_ROUND

UncertaintyMethod

class UncertaintyMethod(
    /,
    *args,
    **kwargs
)

Uncertainty method choices for a Training.

View Source
class UncertaintyMethod(models.TextChoices):
    """
    Choices for the uncertainty method of a Training.

    Stored values are upper-case identifiers; the labels are translatable
    display names. `NONE` means no uncertainty method is used.
    """

    # member = (stored value, display label)
    NONE = ("NONE", _("None"))
    ENSEMBLE = ("ENSEMBLE", _("Ensemble"))
    MC_DROPOUT = ("MC_DROPOUT", _("MC Dropout"))
    SWAG = ("SWAG", _("SWAG"))

Ancestors (in MRO)

  • django.db.models.enums.TextChoices
  • builtins.str
  • django.db.models.enums.Choices
  • enum.Enum

Class variables

ENSEMBLE
MC_DROPOUT
NONE
SWAG