Skip to content

Module fl_server_ai.tests.test_aggregation

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from django.test import TestCase
import torch

from fl_server_core.utils.torch_serialization import is_torchscript_instance

from ..aggregation.mean import MeanAggregation


def _create_torchscript_model_and_init(init: float) -> torch.jit.ScriptModule:
    init = float(init)
    model = torch.nn.Sequential(
        torch.nn.Linear(1, 5),
        torch.nn.Tanh(),
        torch.nn.BatchNorm1d(5),
        torch.nn.Linear(5, 3)
    )
    torch.nn.init.constant_(model[0].weight, init)
    torch.nn.init.constant_(model[0].bias, init)
    torch.nn.init.constant_(model[3].weight, init)
    torch.nn.init.constant_(model[3].bias, init)
    return torch.jit.script(model)


class AggregationTest(TestCase):
    """Check that MeanAggregation averages small TorchScript models correctly."""

    @staticmethod
    def _class_name(model) -> str:
        """Return the Python class name of ``model``, also for scripted modules."""
        if is_torchscript_instance(model):
            # Scripted modules report the wrapped class via ``original_name``.
            return model.original_name
        return model.__class__.__name__

    def _assert_filled(self, state, keys, expected):
        """Assert that ``state[key]`` is filled with ``expected`` for every key."""
        for key in keys:
            torch.testing.assert_close(state[key], torch.full_like(state[key], expected))

    def test_aggregate(self):
        # Ten models filled with 0, 1, ..., 9 and equal sample sizes: mean is 4.5.
        aggregator = MeanAggregation()
        models = [_create_torchscript_model_and_init(value) for value in range(10)]
        merged = aggregator.aggregate(models, [1] * 10)
        self.assertEqual(self._class_name(merged), "Sequential")
        merged_state = merged.state_dict()
        self.assertEqual(len(models[0].state_dict()), len(merged_state))
        self._assert_filled(merged_state, ("0.weight", "3.weight"), 4.5)

    def test_aggregate_sample_sizes(self):
        # Fill values 0, 1, 2 weighted by sample sizes 0, 1, 2:
        # (0*0 + 1*1 + 2*2) / (0 + 1 + 2) = 5/3.
        aggregator = MeanAggregation()
        models = [_create_torchscript_model_and_init(value) for value in range(3)]
        merged = aggregator.aggregate(models, [0, 1, 2])
        self.assertEqual(self._class_name(merged), "Sequential")
        self.assertEqual(len(list(models[0].parameters())), len(list(merged.parameters())))
        self._assert_filled(merged.state_dict(), ("0.weight", "3.weight"), 5 / 3)

Classes

AggregationTest

class AggregationTest(
    methodName='runTest'
)

Similar to TransactionTestCase, but uses transaction.atomic() to achieve test isolation.

In most situations, TestCase should be preferred to TransactionTestCase as it allows faster execution. However, there are some situations where using TransactionTestCase might be necessary (e.g. testing some transactional behavior).

On database backends with no transaction support, TestCase behaves as TransactionTestCase.

View Source
class AggregationTest(TestCase):
    """Check that MeanAggregation averages small TorchScript models correctly."""

    @staticmethod
    def _class_name(model) -> str:
        """Return the Python class name of ``model``, also for scripted modules."""
        if is_torchscript_instance(model):
            # Scripted modules report the wrapped class via ``original_name``.
            return model.original_name
        return model.__class__.__name__

    def _assert_filled(self, state, keys, expected):
        """Assert that ``state[key]`` is filled with ``expected`` for every key."""
        for key in keys:
            torch.testing.assert_close(state[key], torch.full_like(state[key], expected))

    def test_aggregate(self):
        # Ten models filled with 0, 1, ..., 9 and equal sample sizes: mean is 4.5.
        aggregator = MeanAggregation()
        models = [_create_torchscript_model_and_init(value) for value in range(10)]
        merged = aggregator.aggregate(models, [1] * 10)
        self.assertEqual(self._class_name(merged), "Sequential")
        merged_state = merged.state_dict()
        self.assertEqual(len(models[0].state_dict()), len(merged_state))
        self._assert_filled(merged_state, ("0.weight", "3.weight"), 4.5)

    def test_aggregate_sample_sizes(self):
        # Fill values 0, 1, 2 weighted by sample sizes 0, 1, 2:
        # (0*0 + 1*1 + 2*2) / (0 + 1 + 2) = 5/3.
        aggregator = MeanAggregation()
        models = [_create_torchscript_model_and_init(value) for value in range(3)]
        merged = aggregator.aggregate(models, [0, 1, 2])
        self.assertEqual(self._class_name(merged), "Sequential")
        self.assertEqual(len(list(models[0].parameters())), len(list(merged.parameters())))
        self._assert_filled(merged.state_dict(), ("0.weight", "3.weight"), 5 / 3)

Ancestors (in MRO)

  • django.test.testcases.TestCase
  • django.test.testcases.TransactionTestCase
  • django.test.testcases.SimpleTestCase
  • unittest.case.TestCase

Class variables

async_client_class
available_apps
client_class
databases
failureException
fixtures
longMessage
maxDiff
reset_sequences
serialized_rollback

Static methods

addClassCleanup

def addClassCleanup(
    function,
    /,
    *args,
    **kwargs
)

Same as addCleanup, except the cleanup items are called even if setUpClass fails (unlike tearDownClass).

View Source
    @classmethod
    def addClassCleanup(cls, function, /, *args, **kwargs):
        """Register ``function(*args, **kwargs)`` as a class-level cleanup.

        Works like addCleanup, but the registered items also run when
        setUpClass fails (which is not the case for tearDownClass).
        """
        cleanup_entry = (function, args, kwargs)
        cls._class_cleanups.append(cleanup_entry)

captureOnCommitCallbacks

def captureOnCommitCallbacks(
    *,
    using='default',
    execute=False
)

Context manager to capture transaction.on_commit() callbacks.

View Source
    @classmethod
    @contextmanager
    def captureOnCommitCallbacks(cls, *, using=DEFAULT_DB_ALIAS, execute=False):
        """Context manager to capture transaction.on_commit() callbacks.

        Yields a list. When the ``with`` block exits (normally or via an
        exception), that list is filled with every callback appended to
        ``connections[using].run_on_commit`` after entry. If ``execute`` is
        true, each captured callback is also called as it is collected.
        """
        callbacks = []
        # Entries queued before the block are not captured.
        start_count = len(connections[using].run_on_commit)
        try:
            yield callbacks
        finally:
            # Repeat until a pass leaves the queue length unchanged: the queue
            # may grow while callbacks run (e.g. an executed callback registering
            # another one), and such late entries are captured on the next pass.
            while True:
                callback_count = len(connections[using].run_on_commit)
                # Queue entries are unpacked as pairs; only the second item is kept.
                for _, callback in connections[using].run_on_commit[start_count:]:
                    callbacks.append(callback)
                    if execute:
                        callback()

                if callback_count == len(connections[using].run_on_commit):
                    break
                start_count = callback_count

doClassCleanups

def doClassCleanups()

Execute all class cleanup functions. Normally called for you after tearDownClass.

View Source
    @classmethod
    def doClassCleanups(cls):
        """Run every registered class cleanup, most recently added first.

        Normally invoked automatically after tearDownClass. An exception raised
        by one cleanup does not stop the others; its ``sys.exc_info()`` tuple is
        recorded in ``cls.tearDown_exceptions``, which is reset on each call.
        """
        cls.tearDown_exceptions = []
        while cls._class_cleanups:
            func, positional, keyword = cls._class_cleanups.pop()
            try:
                func(*positional, **keyword)
            except Exception:
                cls.tearDown_exceptions.append(sys.exc_info())

setUpClass

def setUpClass()

Hook method for setting up class fixture before running tests in the class.

View Source
    @classmethod
    def setUpClass(cls):
        """Set up the class-wide fixture shared by all tests in the class.

        When the databases support transactions, this opens class-level atomic
        blocks, loads ``cls.fixtures`` into each database returned by
        ``_databases_names(include_mirrors=False)``, and calls setUpTestData().
        Class attributes added or rebound by setUpTestData() are wrapped in
        ``TestData``. If anything fails, the atomic blocks are rolled back
        (where they were entered), the durability check is re-enabled and the
        exception is re-raised.
        """
        super().setUpClass()
        if not cls._databases_support_transactions():
            # Without transaction support none of the steps below are performed.
            return
        # Disable the durability check to allow testing durable atomic blocks
        # in a transaction for performance reasons.
        transaction.Atomic._ensure_durability = False
        try:
            cls.cls_atomics = cls._enter_atomics()

            if cls.fixtures:
                for db_name in cls._databases_names(include_mirrors=False):
                    try:
                        call_command(
                            "loaddata",
                            *cls.fixtures,
                            **{"verbosity": 0, "database": db_name},
                        )
                    except Exception:
                        # Undo the class-level atomic blocks before propagating.
                        cls._rollback_atomics(cls.cls_atomics)
                        raise
            # Snapshot class attributes so those created or rebound by
            # setUpTestData() can be detected below.
            pre_attrs = cls.__dict__.copy()
            try:
                cls.setUpTestData()
            except Exception:
                cls._rollback_atomics(cls.cls_atomics)
                raise
            # Identity comparison: any attribute whose object changed (or that
            # is new) is wrapped; presumably TestData isolates per-test access
            # to class-level data — confirm in TestData.
            for name, value in cls.__dict__.items():
                if value is not pre_attrs.get(name):
                    setattr(cls, name, TestData(name, value))
        except Exception:
            # Restore the durability check disabled above before propagating.
            transaction.Atomic._ensure_durability = True
            raise

setUpTestData

def setUpTestData()

Load initial data for the TestCase.

View Source
    @classmethod
    def setUpTestData(cls):
        """Load initial data for the TestCase.

        Hook for class-level test data; the base implementation does nothing.
        """

tearDownClass

def tearDownClass()

Hook method for deconstructing the class fixture after running all tests in the class.

View Source
    @classmethod
    def tearDownClass(cls):
        """Dismantle the class-wide fixture after all tests in the class ran.

        Re-enables the atomic-block durability check. When the databases
        support transactions, it rolls back ``cls.cls_atomics`` and closes
        every connection. Finally it defers to the parent implementation.
        """
        transaction.Atomic._ensure_durability = True
        supports_transactions = cls._databases_support_transactions()
        if supports_transactions:
            cls._rollback_atomics(cls.cls_atomics)
            for db_conn in connections.all():
                db_conn.close()
        super().tearDownClass()

Methods

addCleanup

def addCleanup(
    self,
    function,
    /,
    *args,
    **kwargs
)

Add a function, with arguments, to be called when the test is completed. Functions added are called on a LIFO basis and are called after tearDown on test failure or success.

Cleanup items are called even if setUp fails (unlike tearDown).

View Source
    def addCleanup(self, function, /, *args, **kwargs):
        """Register ``function(*args, **kwargs)`` to run when the test completes.

        Cleanups run in LIFO order after tearDown, whether the test passed or
        failed, and they still run if setUp fails (unlike tearDown).
        """
        entry = (function, args, kwargs)
        self._cleanups.append(entry)

addTypeEqualityFunc

def addTypeEqualityFunc(
    self,
    typeobj,
    function
)

Add a type specific assertEqual style function to compare a type.

This method is for use by TestCase subclasses that need to register their own type equality functions to provide nicer error messages.

Parameters:

- typeobj: The data type to call this function on when both values are of the same type in assertEqual().
- function: The callable taking two arguments and an optional msg= argument that raises self.failureException with a useful error message when the two arguments are not equal.
View Source
    def addTypeEqualityFunc(self, typeobj, function):
        """Register a type-specific assertEqual-style comparison function.

        Intended for TestCase subclasses that want nicer failure messages
        for their own types.

        Args:
            typeobj: Type for which ``function`` is used when both values
                    given to assertEqual() are of this type.
            function: Callable taking two positional arguments and an
                    optional ``msg=`` keyword; it must raise
                    ``self.failureException`` with a useful message when the
                    two arguments differ.
        """
        registry = self._type_equality_funcs
        registry[typeobj] = function

assertAlmostEqual

def assertAlmostEqual(
    self,
    first,
    second,
    places=None,
    msg=None,
    delta=None
)

Fail if the two objects are unequal as determined by their difference rounded to the given number of decimal places (default 7) and comparing to zero, or by comparing that the difference between the two objects is more than the given delta.

Note that decimal places (from zero) are usually not the same as significant digits (measured from the most significant digit).

If the two objects compare equal then they will automatically compare almost equal.

View Source
    def assertAlmostEqual(self, first, second, places=None, msg=None,
                          delta=None):
        """Fail unless ``first`` and ``second`` are approximately equal.

           By default the difference is rounded to ``places`` decimal places
           (7 unless given) and compared to zero. If ``delta`` is given
           instead, the absolute difference must not exceed ``delta``.
           Passing both ``places`` and ``delta`` raises TypeError.

           Decimal places (counted from zero) usually differ from
           significant digits (counted from the most significant digit).
           Objects that compare equal are always almost equal.
        """
        if first == second:
            # Equal values are trivially almost equal; skip the arithmetic.
            return
        if delta is not None and places is not None:
            raise TypeError("specify delta or places not both")

        difference = abs(first - second)
        if delta is None:
            digits = 7 if places is None else places
            if round(difference, digits) == 0:
                return
            standardMsg = '%s != %s within %r places (%s difference)' % (
                safe_repr(first), safe_repr(second), digits,
                safe_repr(difference))
        else:
            if difference <= delta:
                return
            standardMsg = '%s != %s within %s delta (%s difference)' % (
                safe_repr(first), safe_repr(second), safe_repr(delta),
                safe_repr(difference))
        raise self.failureException(self._formatMessage(msg, standardMsg))

assertAlmostEquals

def assertAlmostEquals(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            # Warn (attributed to the caller) that the alias is deprecated,
            # then delegate to the wrapped function unchanged.
            replacement = original_func.__name__
            warnings.warn('Please use {0} instead.'.format(replacement),
                          category=DeprecationWarning, stacklevel=2)
            return original_func(*args, **kwargs)

assertContains

def assertContains(
    self,
    response,
    text,
    count=None,
    status_code=200,
    msg_prefix='',
    html=False
)

Assert that a response indicates that some content was retrieved successfully (i.e., the HTTP status code was as expected) and that text occurs count times in the content of the response. If count is None, the count doesn't matter - the assertion is true if the text occurs at least once in the response.

View Source
    def assertContains(
        self, response, text, count=None, status_code=200, msg_prefix="", html=False
    ):
        """
        Assert that ``response`` has the expected status code and contains ``text``.

        With ``count`` given, ``text`` must occur exactly ``count`` times in the
        response content; with ``count`` None, at least one occurrence is
        required. ``html`` and ``status_code`` are forwarded to the shared
        ``_assert_contains`` helper.
        """
        found_repr, occurrences, prefix = self._assert_contains(
            response, text, status_code, msg_prefix, html
        )

        if count is None:
            failure = prefix + "Couldn't find %s in response" % found_repr
            self.assertTrue(occurrences != 0, failure)
        else:
            failure = prefix + (
                "Found %d instances of %s in response (expected %d)"
                % (occurrences, found_repr, count)
            )
            self.assertEqual(occurrences, count, failure)

assertCountEqual

def assertCountEqual(
    self,
    first,
    second,
    msg=None
)

Asserts that two iterables have the same elements, the same number of times, without regard to order. This is equivalent to:

self.assertEqual(Counter(list(first)),
                 Counter(list(second)))

Examples:

- [0, 1, 1] and [1, 0, 1] compare equal.
- [0, 0, 1] and [0, 1] compare unequal.

View Source
    def assertCountEqual(self, first, second, msg=None):
        """Assert that two iterables hold the same elements with the same
        multiplicities, ignoring order.

        Equivalent to ``assertEqual(Counter(list(first)), Counter(list(second)))``,
        but also works for unhashable elements. For example, [0, 1, 1] and
        [1, 0, 1] compare equal, while [0, 0, 1] and [0, 1] do not.
        """
        seq_a = list(first)
        seq_b = list(second)
        try:
            counts_a = collections.Counter(seq_a)
            counts_b = collections.Counter(seq_b)
        except TypeError:
            # Unhashable elements cannot be counted with Counter; use the
            # slower all-purpose comparison instead.
            differences = _count_diff_all_purpose(seq_a, seq_b)
        else:
            if counts_a == counts_b:
                return
            differences = _count_diff_hashable(seq_a, seq_b)

        if not differences:
            return
        header = 'Element counts were not equal:\n'
        details = '\n'.join(
            'First has %d, Second has %d:  %r' % diff for diff in differences
        )
        self.fail(self._formatMessage(msg, self._truncateMessage(header, details)))

assertDictContainsSubset

def assertDictContainsSubset(
    self,
    subset,
    dictionary,
    msg=None
)

Checks whether dictionary is a superset of subset.

View Source
    def assertDictContainsSubset(self, subset, dictionary, msg=None):
        """Fail unless every key of ``subset`` is in ``dictionary`` with an equal value.

        Deprecated: issues a DeprecationWarning on every call.
        """
        warnings.warn('assertDictContainsSubset is deprecated',
                      DeprecationWarning,
                      stacklevel=2)
        missing_keys = []
        mismatch_notes = []
        for key, expected in subset.items():
            if key not in dictionary:
                missing_keys.append(key)
            elif expected != dictionary[key]:
                mismatch_notes.append(
                    '%s, expected: %s, actual: %s'
                    % (safe_repr(key), safe_repr(expected),
                       safe_repr(dictionary[key])))

        if not missing_keys and not mismatch_notes:
            return

        parts = []
        if missing_keys:
            parts.append('Missing: %s' % ','.join(safe_repr(m) for m in missing_keys))
        if mismatch_notes:
            parts.append('Mismatched values: %s' % ','.join(mismatch_notes))
        self.fail(self._formatMessage(msg, '; '.join(parts)))

assertDictEqual

def assertDictEqual(
    self,
    d1,
    d2,
    msg=None
)
View Source
    def assertDictEqual(self, d1, d2, msg=None):
        """Fail unless ``d1`` and ``d2`` are dicts that compare equal.

        On mismatch the message shows a shortened repr of both dicts plus an
        ndiff of their pretty-printed forms.
        """
        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')

        if d1 != d2:
            header = '%s != %s' % _common_shorten_repr(d1, d2)
            left_lines = pprint.pformat(d1).splitlines()
            right_lines = pprint.pformat(d2).splitlines()
            diff = '\n' + '\n'.join(difflib.ndiff(left_lines, right_lines))
            standardMsg = self._truncateMessage(header, diff)
            self.fail(self._formatMessage(msg, standardMsg))

assertEqual

def assertEqual(
    self,
    first,
    second,
    msg=None
)

Fail if the two objects are unequal as determined by the '==' operator.

View Source
    def assertEqual(self, first, second, msg=None):
        """Fail unless ``first == second``.

        The check is delegated to the type-specific equality function that
        ``_getAssertEqualityFunc`` selects for the two operands.
        """
        self._getAssertEqualityFunc(first, second)(first, second, msg=msg)

assertEquals

def assertEquals(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            # Warn (attributed to the caller) that the alias is deprecated,
            # then delegate to the wrapped function unchanged.
            replacement = original_func.__name__
            warnings.warn('Please use {0} instead.'.format(replacement),
                          category=DeprecationWarning, stacklevel=2)
            return original_func(*args, **kwargs)

assertFalse

def assertFalse(
    self,
    expr,
    msg=None
)

Check that the expression is false.

View Source
    def assertFalse(self, expr, msg=None):
        """Check that the expression is false (fails when ``expr`` is truthy)."""
        if not expr:
            return
        raise self.failureException(
            self._formatMessage(msg, "%s is not false" % safe_repr(expr)))

assertFieldOutput

def assertFieldOutput(
    self,
    fieldclass,
    valid,
    invalid,
    field_args=None,
    field_kwargs=None,
    empty_value=''
)

Assert that a form field behaves correctly with various inputs.

Parameters:

- fieldclass: the class of the field to be tested.
- valid: a dictionary mapping valid inputs to their expected cleaned values.
- invalid: a dictionary mapping invalid inputs to one or more raised error messages.
- field_args: the args passed to instantiate the field.
- field_kwargs: the kwargs passed to instantiate the field.
- empty_value: the expected clean output for inputs in empty_values.
View Source
    def assertFieldOutput(
        self,
        fieldclass,
        valid,
        invalid,
        field_args=None,
        field_kwargs=None,
        empty_value="",
    ):
        """
        Check that a form field class cleans valid, invalid and empty input correctly.

        A required and an optional (``required=False``) instance are built
        from ``field_args``/``field_kwargs`` and both are exercised.

        Args:
            fieldclass: the class of the field to be tested.
            valid: a dictionary mapping valid inputs to their expected
                    cleaned values.
            invalid: a dictionary mapping invalid inputs to one or more
                    raised error messages.
            field_args: the args passed to instantiate the field
            field_kwargs: the kwargs passed to instantiate the field
            empty_value: the expected clean output for inputs in empty_values
        """
        field_args = [] if field_args is None else field_args
        field_kwargs = {} if field_kwargs is None else field_kwargs
        required = fieldclass(*field_args, **field_kwargs)
        optional = fieldclass(*field_args, **{**field_kwargs, "required": False})

        # Valid inputs clean to the expected value on both instances.
        for raw, expected in valid.items():
            self.assertEqual(required.clean(raw), expected)
            self.assertEqual(optional.clean(raw), expected)

        # Invalid inputs raise ValidationError with exactly these messages on both.
        for raw, expected_errors in invalid.items():
            for field in (required, optional):
                with self.assertRaises(ValidationError) as raised:
                    field.clean(raw)
                self.assertEqual(raised.exception.messages, expected_errors)

        # Empty values: rejected by the required field, mapped to empty_value otherwise.
        required_error = [required.error_messages["required"]]
        for blank in required.empty_values:
            with self.assertRaises(ValidationError) as raised:
                required.clean(blank)
            self.assertEqual(raised.exception.messages, required_error)
            self.assertEqual(optional.clean(blank), empty_value)

        # CharField subclasses must accept min_length/max_length keyword arguments.
        # Note: this updates field_kwargs in place, as the original code does.
        if issubclass(fieldclass, CharField):
            field_kwargs.update({"min_length": 2, "max_length": 20})
            self.assertIsInstance(fieldclass(*field_args, **field_kwargs), fieldclass)

assertFormError

def assertFormError(
    self,
    response,
    form,
    field,
    errors,
    msg_prefix=''
)

Assert that a form used to render the response has a specific field error.

View Source
    def assertFormError(self, response, form, field, errors, msg_prefix=""):
        """
        Assert that a form used to render the response has a specific field
        error.

        Args:
            response: Response whose ``context`` (one or several contexts) is
                searched for ``form``.
            form: Context key under which the form is looked up.
            field: Field name to check; if falsy, the form's non-field
                errors are checked instead.
            errors: Expected error message(s), normalized with ``to_list``;
                every one must be present.
            msg_prefix: Optional prefix for failure messages; ": " is
                appended when it is non-empty.

        Fails if no context was used, if no context contains ``form``, or if
        any expected error is missing in a context that does contain it.
        """
        if msg_prefix:
            msg_prefix += ": "

        # Put context(s) into a list to simplify processing.
        contexts = to_list(response.context)
        if not contexts:
            self.fail(
                msg_prefix + "Response did not use any contexts to render the response"
            )

        # Put error(s) into a list to simplify processing.
        errors = to_list(errors)

        # Search all contexts for the error.
        found_form = False
        for i, context in enumerate(contexts):
            if form not in context:
                continue
            found_form = True
            for err in errors:
                if field:
                    if field in context[form].errors:
                        field_errors = context[form].errors[field]
                        self.assertTrue(
                            err in field_errors,
                            msg_prefix + "The field '%s' on form '%s' in"
                            " context %d does not contain the error '%s'"
                            " (actual errors: %s)"
                            % (field, form, i, err, repr(field_errors)),
                        )
                    # The field exists but has no errors at all.
                    elif field in context[form].fields:
                        self.fail(
                            msg_prefix
                            + (
                                "The field '%s' on form '%s' in context %d contains no "
                                "errors"
                            )
                            % (field, form, i)
                        )
                    # The form does not define this field.
                    else:
                        self.fail(
                            msg_prefix
                            + (
                                "The form '%s' in context %d does not contain the "
                                "field '%s'"
                            )
                            % (form, i, field)
                        )
                else:
                    # No field given: look for the error among non-field errors.
                    non_field_errors = context[form].non_field_errors()
                    self.assertTrue(
                        err in non_field_errors,
                        msg_prefix + "The form '%s' in context %d does not"
                        " contain the non-field error '%s'"
                        " (actual errors: %s)"
                        % (form, i, err, non_field_errors or "none"),
                    )
        if not found_form:
            self.fail(
                msg_prefix + "The form '%s' was not used to render the response" % form
            )

assertFormsetError

def assertFormsetError(
    self,
    response,
    formset,
    form_index,
    field,
    errors,
    msg_prefix=''
)

Assert that a formset used to render the response has a specific error.

For field errors, specify the form_index and the field. For non-field errors, specify the form_index and the field as None. For non-form errors, specify form_index as None and the field as None.

View Source
    def assertFormsetError(
        self, response, formset, form_index, field, errors, msg_prefix=""
    ):
        """
        Assert that a formset used to render the response has a specific error.

        For field errors, specify the ``form_index`` and the ``field``.
        For non-field errors, specify the ``form_index`` and the ``field`` as
        None.
        For non-form errors, specify ``form_index`` as None and the ``field``
        as None.

        ``errors`` may be one message or several (normalized with
        ``to_list``); every one must be present. Context entries under
        ``formset`` without a ``forms`` attribute are skipped. Fails if no
        context was used or if no context contains a matching formset.
        """
        # Add punctuation to msg_prefix
        if msg_prefix:
            msg_prefix += ": "

        # Put context(s) into a list to simplify processing.
        contexts = to_list(response.context)
        if not contexts:
            self.fail(
                msg_prefix + "Response did not use any contexts to "
                "render the response"
            )

        # Put error(s) into a list to simplify processing.
        errors = to_list(errors)

        # Search all contexts for the error.
        found_formset = False
        for i, context in enumerate(contexts):
            if formset not in context or not hasattr(context[formset], "forms"):
                continue
            found_formset = True
            for err in errors:
                # Field error on the form at form_index.
                if field is not None:
                    if field in context[formset].forms[form_index].errors:
                        field_errors = context[formset].forms[form_index].errors[field]
                        self.assertTrue(
                            err in field_errors,
                            msg_prefix + "The field '%s' on formset '%s', "
                            "form %d in context %d does not contain the "
                            "error '%s' (actual errors: %s)"
                            % (field, formset, form_index, i, err, repr(field_errors)),
                        )
                    elif field in context[formset].forms[form_index].fields:
                        self.fail(
                            msg_prefix
                            + (
                                "The field '%s' on formset '%s', form %d in context "
                                "%d contains no errors"
                            )
                            % (field, formset, form_index, i)
                        )
                    else:
                        self.fail(
                            msg_prefix
                            + (
                                "The formset '%s', form %d in context %d does not "
                                "contain the field '%s'"
                            )
                            % (formset, form_index, i, field)
                        )
                # Non-field error on the form at form_index.
                elif form_index is not None:
                    non_field_errors = (
                        context[formset].forms[form_index].non_field_errors()
                    )
                    self.assertFalse(
                        not non_field_errors,
                        msg_prefix + "The formset '%s', form %d in context %d "
                        "does not contain any non-field errors."
                        % (formset, form_index, i),
                    )
                    self.assertTrue(
                        err in non_field_errors,
                        msg_prefix + "The formset '%s', form %d in context %d "
                        "does not contain the non-field error '%s' (actual errors: %s)"
                        % (formset, form_index, i, err, repr(non_field_errors)),
                    )
                # Neither field nor form_index: formset-level (non-form) error.
                else:
                    non_form_errors = context[formset].non_form_errors()
                    self.assertFalse(
                        not non_form_errors,
                        msg_prefix + "The formset '%s' in context %d does not "
                        "contain any non-form errors." % (formset, i),
                    )
                    self.assertTrue(
                        err in non_form_errors,
                        msg_prefix + "The formset '%s' in context %d does not "
                        "contain the non-form error '%s' (actual errors: %s)"
                        % (formset, i, err, repr(non_form_errors)),
                    )
        if not found_formset:
            self.fail(
                msg_prefix
                + "The formset '%s' was not used to render the response" % formset
            )

assertGreater

def assertGreater(
    self,
    a,
    b,
    msg=None
)

Just like self.assertTrue(a > b), but with a nicer default message.

View Source
    def assertGreater(self, a, b, msg=None):
        """Fail unless ``a > b``; like assertTrue(a > b) with a clearer message."""
        if a > b:
            return
        self.fail(self._formatMessage(
            msg, '%s not greater than %s' % (safe_repr(a), safe_repr(b))))

assertGreaterEqual

def assertGreaterEqual(
    self,
    a,
    b,
    msg=None
)

Just like self.assertTrue(a >= b), but with a nicer default message.

View Source
    def assertGreaterEqual(self, a, b, msg=None):
        """Fail unless ``a >= b``; like assertTrue(a >= b) with a clearer message."""
        if a >= b:
            return
        self.fail(self._formatMessage(
            msg,
            '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))))

assertHTMLEqual

def assertHTMLEqual(
    self,
    html1,
    html2,
    msg=None
)

Assert that two HTML snippets are semantically the same.

Whitespace in most cases is ignored, and attribute ordering is not significant. The arguments must be valid HTML.

View Source
    def assertHTMLEqual(self, html1, html2, msg=None):
        """
        Assert that two HTML snippets are semantically the same.

        Both arguments are parsed first (and must be valid HTML); whitespace
        is mostly ignored and attribute order does not matter. On mismatch
        the message includes an ndiff of the parsed documents.
        """
        dom1 = assert_and_parse_html(self, html1, msg, "First argument is not valid HTML:")
        dom2 = assert_and_parse_html(self, html2, msg, "Second argument is not valid HTML:")

        if dom1 != dom2:
            header = "%s != %s" % (safe_repr(dom1, True), safe_repr(dom2, True))
            left_lines = str(dom1).splitlines()
            right_lines = str(dom2).splitlines()
            diff = "\n" + "\n".join(difflib.ndiff(left_lines, right_lines))
            standardMsg = self._truncateMessage(header, diff)
            self.fail(self._formatMessage(msg, standardMsg))

assertHTMLNotEqual

def assertHTMLNotEqual(
    self,
    html1,
    html2,
    msg=None
)

Assert that two HTML snippets are not semantically equivalent.

View Source
    def assertHTMLNotEqual(self, html1, html2, msg=None):
        """Fail if two HTML snippets are semantically equivalent.

        Both arguments must be valid HTML; each is parsed through
        ``assert_and_parse_html`` with an argument-specific error prefix.
        """
        parsed_first = assert_and_parse_html(
            self, html1, msg, "First argument is not valid HTML:"
        )
        parsed_second = assert_and_parse_html(
            self, html2, msg, "Second argument is not valid HTML:"
        )

        if parsed_first == parsed_second:
            summary = f"{safe_repr(parsed_first, True)} == {safe_repr(parsed_second, True)}"
            self.fail(self._formatMessage(msg, summary))

assertIn

def assertIn(
    self,
    member,
    container,
    msg=None
)

Just like self.assertTrue(a in b), but with a nicer default message.

View Source
    def assertIn(self, member, container, msg=None):
        """Fail unless ``member in container``.

        Equivalent to ``self.assertTrue(member in container)`` but with a
        descriptive default message.
        """
        if member in container:
            return
        standard_msg = f"{safe_repr(member)} not found in {safe_repr(container)}"
        self.fail(self._formatMessage(msg, standard_msg))

assertInHTML

def assertInHTML(
    self,
    needle,
    haystack,
    count=None,
    msg_prefix=''
)
View Source
    def assertInHTML(self, needle, haystack, count=None, msg_prefix=""):
        """
        Check that the HTML fragment ``needle`` occurs in ``haystack``.

        Both arguments are parsed with ``assert_and_parse_html`` first. If
        ``count`` is given, the number of occurrences must equal it exactly;
        otherwise at least one occurrence is required. ``msg_prefix`` is
        prepended to the failure message.
        """
        parsed_needle = assert_and_parse_html(
            self, needle, None, "First argument is not valid HTML:"
        )
        parsed_haystack = assert_and_parse_html(
            self, haystack, None, "Second argument is not valid HTML:"
        )
        occurrences = parsed_haystack.count(parsed_needle)

        if count is None:
            self.assertTrue(
                occurrences != 0,
                msg_prefix + "Couldn't find '%s' in response" % parsed_needle,
            )
            return

        self.assertEqual(
            occurrences,
            count,
            msg_prefix
            + "Found %d instances of '%s' in response (expected %d)"
            % (occurrences, parsed_needle, count),
        )

assertIs

def assertIs(
    self,
    expr1,
    expr2,
    msg=None
)

Just like self.assertTrue(a is b), but with a nicer default message.

View Source
    def assertIs(self, expr1, expr2, msg=None):
        """Fail unless ``expr1`` and ``expr2`` are the very same object.

        Equivalent to ``self.assertTrue(expr1 is expr2)`` but with a
        descriptive default message.
        """
        if expr1 is expr2:
            return
        standard_msg = f"{safe_repr(expr1)} is not {safe_repr(expr2)}"
        self.fail(self._formatMessage(msg, standard_msg))

assertIsInstance

def assertIsInstance(
    self,
    obj,
    cls,
    msg=None
)

Same as self.assertTrue(isinstance(obj, cls)), but with a nicer default message.

View Source
    def assertIsInstance(self, obj, cls, msg=None):
        """Fail unless ``isinstance(obj, cls)``.

        Equivalent to ``self.assertTrue(isinstance(obj, cls))`` but with a
        descriptive default message.
        """
        if isinstance(obj, cls):
            return
        standard_msg = f"{safe_repr(obj)} is not an instance of {cls!r}"
        self.fail(self._formatMessage(msg, standard_msg))

assertIsNone

def assertIsNone(
    self,
    obj,
    msg=None
)

Same as self.assertTrue(obj is None), with a nicer default message.

View Source
    def assertIsNone(self, obj, msg=None):
        """Fail unless ``obj is None``, with a descriptive default message."""
        if obj is None:
            return
        standard_msg = f"{safe_repr(obj)} is not None"
        self.fail(self._formatMessage(msg, standard_msg))

assertIsNot

def assertIsNot(
    self,
    expr1,
    expr2,
    msg=None
)

Just like self.assertTrue(a is not b), but with a nicer default message.

View Source
    def assertIsNot(self, expr1, expr2, msg=None):
        """Fail if ``expr1`` and ``expr2`` are the very same object.

        Equivalent to ``self.assertTrue(expr1 is not expr2)`` but with a
        descriptive default message.
        """
        if expr1 is not expr2:
            return
        standard_msg = f"unexpectedly identical: {safe_repr(expr1)}"
        self.fail(self._formatMessage(msg, standard_msg))

assertIsNotNone

def assertIsNotNone(
    self,
    obj,
    msg=None
)

Included for symmetry with assertIsNone.

View Source
    def assertIsNotNone(self, obj, msg=None):
        """Fail if ``obj is None``; the counterpart of ``assertIsNone``."""
        if obj is not None:
            return
        self.fail(self._formatMessage(msg, 'unexpectedly None'))

assertJSONEqual

def assertJSONEqual(
    self,
    raw,
    expected_data,
    msg=None
)

Assert that the JSON fragments raw and expected_data are equal.

Whitespace that is insignificant in JSON is ignored, because parsing is delegated to the json library.

View Source
    def assertJSONEqual(self, raw, expected_data, msg=None):
        """
        Fail unless the JSON document ``raw`` decodes to ``expected_data``.

        ``expected_data`` may be an already-decoded object or a JSON string,
        in which case it is decoded as well. Decoding is done by the ``json``
        module, so whitespace that is insignificant in JSON is ignored.
        """
        try:
            decoded = json.loads(raw)
        except json.JSONDecodeError:
            self.fail("First argument is not valid JSON: %r" % raw)

        expected = expected_data
        if isinstance(expected, str):
            try:
                expected = json.loads(expected)
            except ValueError:  # JSONDecodeError is a ValueError subclass
                self.fail("Second argument is not valid JSON: %r" % expected)

        self.assertEqual(decoded, expected, msg=msg)

assertJSONNotEqual

def assertJSONNotEqual(
    self,
    raw,
    expected_data,
    msg=None
)

Assert that the JSON fragments raw and expected_data are not equal.

Whitespace that is insignificant in JSON is ignored, because parsing is delegated to the json library.

View Source
    def assertJSONNotEqual(self, raw, expected_data, msg=None):
        """
        Fail if the JSON document ``raw`` decodes to ``expected_data``.

        ``expected_data`` may be an already-decoded object or a JSON string,
        in which case it is decoded as well. Decoding is done by the ``json``
        module, so whitespace that is insignificant in JSON is ignored.
        """
        try:
            decoded = json.loads(raw)
        except json.JSONDecodeError:
            self.fail("First argument is not valid JSON: %r" % raw)

        expected = expected_data
        if isinstance(expected, str):
            try:
                expected = json.loads(expected)
            except json.JSONDecodeError:
                self.fail("Second argument is not valid JSON: %r" % expected)

        self.assertNotEqual(decoded, expected, msg=msg)

assertLess

def assertLess(
    self,
    a,
    b,
    msg=None
)

Just like self.assertTrue(a < b), but with a nicer default message.

View Source
    def assertLess(self, a, b, msg=None):
        """Fail unless ``a < b``.

        Equivalent to ``self.assertTrue(a < b)`` but the default failure
        message shows both operands.
        """
        if a < b:
            return
        standard_msg = f"{safe_repr(a)} not less than {safe_repr(b)}"
        self.fail(self._formatMessage(msg, standard_msg))

assertLessEqual

def assertLessEqual(
    self,
    a,
    b,
    msg=None
)

Just like self.assertTrue(a <= b), but with a nicer default message.

View Source
    def assertLessEqual(self, a, b, msg=None):
        """Fail unless ``a <= b``.

        Equivalent to ``self.assertTrue(a <= b)`` but the default failure
        message shows both operands.
        """
        if a <= b:
            return
        standard_msg = f"{safe_repr(a)} not less than or equal to {safe_repr(b)}"
        self.fail(self._formatMessage(msg, standard_msg))

assertListEqual

def assertListEqual(
    self,
    list1,
    list2,
    msg=None
)

A list-specific equality assertion.

Parameters:

| Name | Description | Default |
|------|-------------|---------|
| list1 | The first list to compare. | required |
| list2 | The second list to compare. | required |
| msg | Optional message to use on failure instead of a list of differences. | None |
View Source
    def assertListEqual(self, list1, list2, msg=None):
        """Assert that two lists are equal.

        Both arguments must be ``list`` instances; the element comparison
        and the failure report are delegated to ``assertSequenceEqual``.

        Args:
            list1: The first list to compare.
            list2: The second list to compare.
            msg: Optional message to use on failure instead of a list of
                differences.
        """
        self.assertSequenceEqual(
            list1,
            list2,
            msg,
            seq_type=list,
        )

assertLogs

def assertLogs(
    self,
    logger=None,
    level=None
)

Fail unless a log message of level `level` or higher is emitted on `logger` or its children. If omitted, `level` defaults to INFO and `logger` defaults to the root logger.

This method must be used as a context manager, and will yield a recording object with two attributes: output and records. At the end of the context manager, the output attribute will be a list of the matching formatted log messages and the records attribute will be a list of the corresponding LogRecord objects.

Example::

with self.assertLogs('foo', level='INFO') as cm:
    logging.getLogger('foo').info('first message')
    logging.getLogger('foo.bar').error('second message')
self.assertEqual(cm.output, ['INFO:foo:first message',
                             'ERROR:foo.bar:second message'])
View Source
    def assertLogs(self, logger=None, level=None):
        """Fail unless a log message of level *level* or higher is emitted
        on *logger* or one of its children.

        When omitted, *level* defaults to INFO and *logger* to the root logger.

        Use this only as a context manager. It yields a recording object with
        two attributes, filled in when the ``with`` block exits: ``output``
        (the matching messages, formatted) and ``records`` (the corresponding
        LogRecord objects).

        Example::

            with self.assertLogs('foo', level='INFO') as cm:
                logging.getLogger('foo').info('first message')
                logging.getLogger('foo.bar').error('second message')
            self.assertEqual(cm.output, ['INFO:foo:first message',
                                         'ERROR:foo.bar:second message'])
        """
        # Deferred import: logging is only pulled in when this assertion is
        # actually used.
        from ._log import _AssertLogsContext
        context = _AssertLogsContext(self, logger, level, no_logs=False)
        return context

assertMultiLineEqual

def assertMultiLineEqual(
    self,
    first,
    second,
    msg=None
)

Assert that two multi-line strings are equal.

View Source
    def assertMultiLineEqual(self, first, second, msg=None):
        """Assert that two strings, typically spanning several lines, are equal.

        Both arguments must be ``str``. On mismatch the failure message gets
        an ``ndiff`` of the lines. If either string is longer than
        ``self._diffThreshold``, the comparison is handed to
        ``_baseAssertEqual`` instead, which is expected to raise because the
        strings differ.
        """
        self.assertIsInstance(first, str, 'First argument is not a string')
        self.assertIsInstance(second, str, 'Second argument is not a string')

        if first != second:
            # Do not run difflib on strings above the diff threshold.
            if max(len(first), len(second)) > self._diffThreshold:
                self._baseAssertEqual(first, second, msg)
            first_lines = first.splitlines(keepends=True)
            second_lines = second.splitlines(keepends=True)
            is_single_bare_line = (
                len(first_lines) == 1 and first.strip('\r\n') == first
            )
            if is_single_bare_line:
                # Terminate the lone line so ndiff shows both sides cleanly.
                first_lines, second_lines = [first + '\n'], [second + '\n']
            summary = '%s != %s' % _common_shorten_repr(first, second)
            detail = '\n' + ''.join(difflib.ndiff(first_lines, second_lines))
            self.fail(self._formatMessage(msg, self._truncateMessage(summary, detail)))

assertNoLogs

def assertNoLogs(
    self,
    logger=None,
    level=None
)

Fail unless no log messages of level `level` or higher are emitted on `logger` or its children.

This method must be used as a context manager.

View Source
    def assertNoLogs(self, logger=None, level=None):
        """Fail if any log message of level *level* or higher is emitted on
        *logger* or one of its children.

        Use this only as a context manager.
        """
        # Deferred import, as in assertLogs.
        from ._log import _AssertLogsContext
        context = _AssertLogsContext(self, logger, level, no_logs=True)
        return context

assertNotAlmostEqual

def assertNotAlmostEqual(
    self,
    first,
    second,
    places=None,
    msg=None,
    delta=None
)

Fail if the two objects are equal as determined by their difference rounded to the given number of decimal places (default 7) and compared to zero, or, when a delta is given, by the difference between the two objects being at most that delta.

Note that decimal places (from zero) are usually not the same as significant digits (measured from the most significant digit).

Objects that are equal automatically fail.

View Source
    def assertNotAlmostEqual(self, first, second, places=None, msg=None,
                             delta=None):
        """Fail if ``first`` and ``second`` are (almost) equal.

        With ``delta``, they count as equal when ``abs(first - second)`` is at
        most ``delta``. Otherwise the difference is rounded to ``places``
        decimal places (default 7) and compared to zero. Decimal places are
        not significant digits. Objects that compare equal always fail.

        Raises:
            TypeError: If both ``delta`` and ``places`` are given.
        """
        if delta is not None and places is not None:
            raise TypeError("specify delta or places not both")
        diff = abs(first - second)
        if delta is None:
            if places is None:
                places = 7
            if not (first == second) and round(diff, places) != 0:
                return
            standard_msg = '%s == %s within %r places' % (
                safe_repr(first), safe_repr(second), places)
        else:
            if not (first == second) and diff > delta:
                return
            standard_msg = '%s == %s within %s delta (%s difference)' % (
                safe_repr(first), safe_repr(second),
                safe_repr(delta), safe_repr(diff))
        raise self.failureException(self._formatMessage(msg, standard_msg))

assertNotAlmostEquals

def assertNotAlmostEquals(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            # Deprecated alias: warn (attributed to the caller's line via
            # stacklevel=2), then forward every argument unchanged.
            replacement_name = original_func.__name__
            warnings.warn('Please use {0} instead.'.format(replacement_name),
                          category=DeprecationWarning, stacklevel=2)
            return original_func(*args, **kwargs)

assertNotContains

def assertNotContains(
    self,
    response,
    text,
    status_code=200,
    msg_prefix='',
    html=False
)

Assert that a response indicates that some content was retrieved successfully (i.e., the HTTP status code was as expected) and that text doesn't occur in the content of the response.

View Source
    def assertNotContains(
        self, response, text, status_code=200, msg_prefix="", html=False
    ):
        """
        Check that ``response`` has the expected status code and that
        ``text`` does not occur anywhere in its content.

        ``_assert_contains`` (defined elsewhere) checks the status and counts
        occurrences; this method then requires that count to be zero.
        """
        text_repr, real_count, msg_prefix = self._assert_contains(
            response, text, status_code, msg_prefix, html
        )
        failure_message = msg_prefix + "Response should not contain %s" % text_repr
        self.assertEqual(real_count, 0, failure_message)

assertNotEqual

def assertNotEqual(
    self,
    first,
    second,
    msg=None
)

Fail if the two objects are equal as determined by the '!=' operator.

View Source
    def assertNotEqual(self, first, second, msg=None):
        """Fail unless ``first != second`` holds."""
        if first != second:
            return
        standard_msg = f"{safe_repr(first)} == {safe_repr(second)}"
        raise self.failureException(self._formatMessage(msg, standard_msg))

assertNotEquals

def assertNotEquals(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            # Deprecated alias: warn (attributed to the caller's line via
            # stacklevel=2), then forward every argument unchanged.
            replacement_name = original_func.__name__
            warnings.warn('Please use {0} instead.'.format(replacement_name),
                          category=DeprecationWarning, stacklevel=2)
            return original_func(*args, **kwargs)

assertNotIn

def assertNotIn(
    self,
    member,
    container,
    msg=None
)

Just like self.assertTrue(a not in b), but with a nicer default message.

View Source
    def assertNotIn(self, member, container, msg=None):
        """Fail if ``member in container``.

        Equivalent to ``self.assertTrue(member not in container)`` but with a
        descriptive default message.
        """
        if member not in container:
            return
        standard_msg = f"{safe_repr(member)} unexpectedly found in {safe_repr(container)}"
        self.fail(self._formatMessage(msg, standard_msg))

assertNotIsInstance

def assertNotIsInstance(
    self,
    obj,
    cls,
    msg=None
)

Included for symmetry with assertIsInstance.

View Source
    def assertNotIsInstance(self, obj, cls, msg=None):
        """Fail if ``isinstance(obj, cls)``; the counterpart of ``assertIsInstance``."""
        if not isinstance(obj, cls):
            return
        standard_msg = f"{safe_repr(obj)} is an instance of {cls!r}"
        self.fail(self._formatMessage(msg, standard_msg))

assertNotRegex

def assertNotRegex(
    self,
    text,
    unexpected_regex,
    msg=None
)

Fail the test if the text matches the regular expression.

View Source
    def assertNotRegex(self, text, unexpected_regex, msg=None):
        """Fail if ``unexpected_regex`` matches anywhere in ``text``.

        ``unexpected_regex`` may be a compiled pattern or a ``str``/``bytes``
        pattern, which is compiled here.
        """
        pattern = unexpected_regex
        if isinstance(pattern, (str, bytes)):
            pattern = re.compile(pattern)
        found = pattern.search(text)
        if not found:
            return
        matched_text = text[found.start():found.end()]
        standard_msg = 'Regex matched: %r matches %r in %r' % (
            matched_text, pattern.pattern, text)
        # Route through _formatMessage so the longMessage setting is honoured.
        raise self.failureException(self._formatMessage(msg, standard_msg))

assertNotRegexpMatches

def assertNotRegexpMatches(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            # Deprecated alias: warn (attributed to the caller's line via
            # stacklevel=2), then forward every argument unchanged.
            replacement_name = original_func.__name__
            warnings.warn('Please use {0} instead.'.format(replacement_name),
                          category=DeprecationWarning, stacklevel=2)
            return original_func(*args, **kwargs)

assertNumQueries

def assertNumQueries(
    self,
    num,
    func=None,
    *args,
    using='default',
    **kwargs
)
View Source
    def assertNumQueries(self, num, func=None, *args, using=DEFAULT_DB_ALIAS, **kwargs):
        """
        Check the queries run on the ``using`` connection against ``num``.

        If ``func`` is omitted, return the checking context manager for use in
        a ``with`` block. Otherwise call ``func(*args, **kwargs)`` inside that
        context and return None. The counting and comparison are done by
        ``_AssertNumQueriesContext`` (defined elsewhere).
        """
        checker = _AssertNumQueriesContext(self, num, connections[using])
        if func is not None:
            with checker:
                func(*args, **kwargs)
            return None
        return checker

assertQuerysetEqual

def assertQuerysetEqual(
    self,
    qs,
    values,
    transform=None,
    ordered=True,
    msg=None
)
View Source
    def assertQuerysetEqual(self, qs, values, transform=None, ordered=True, msg=None):
        """
        Assert that the items of ``qs`` equal ``values``.

        If ``transform`` is given, each item of ``qs`` is mapped through it
        first. With ``ordered=False`` both sides are compared as multisets
        (``Counter``); otherwise they are compared as lists, in order.

        Raises:
            ValueError: In ordered mode, when ``qs`` has a false ``ordered``
                attribute and more than one value is expected.
        """
        values = list(values)
        # RemovedInDjango41Warning: legacy implicit repr() transform. Applies
        # when no transform is given, the expected values are strings, and
        # qs items are not strings (so qs is not a flattened values_list).
        if transform is None and values and isinstance(values[0], str):
            if qs and not isinstance(qs[0], str):
                warnings.warn(
                    "In Django 4.1, repr() will not be called automatically "
                    "on a queryset when compared to string values. Set an "
                    "explicit 'transform' to silence this warning.",
                    category=RemovedInDjango41Warning,
                    stacklevel=2,
                )
                transform = repr

        items = qs if transform is None else map(transform, qs)

        if not ordered:
            return self.assertDictEqual(Counter(items), Counter(values), msg=msg)

        # qs may be a plain iterator (e.g. qs.iterator()) with no 'ordered'.
        if len(values) > 1 and hasattr(qs, "ordered") and not qs.ordered:
            raise ValueError(
                "Trying to compare non-ordered queryset against more than one "
                "ordered value."
            )
        return self.assertEqual(list(items), values, msg=msg)

assertRaises

def assertRaises(
    self,
    expected_exception,
    *args,
    **kwargs
)

Fail unless an exception of class expected_exception is raised by the callable when invoked with the specified positional and keyword arguments. If a different type of exception is raised, it will not be caught, and the test case will be deemed to have suffered an error, exactly as for an unexpected exception.

If called with the callable and arguments omitted, will return a context object used like this::

 with self.assertRaises(SomeException):
     do_something()

An optional keyword argument 'msg' can be provided when assertRaises is used as a context object.

The context manager keeps a reference to the exception as the 'exception' attribute. This allows you to inspect the exception after the assertion::

with self.assertRaises(SomeException) as cm:
    do_something()
the_exception = cm.exception
self.assertEqual(the_exception.error_code, 3)
View Source
    def assertRaises(self, expected_exception, *args, **kwargs):
        """Fail unless ``expected_exception`` is raised.

        Called with a callable (plus positional and keyword arguments), the
        callable is invoked and must raise ``expected_exception``. Any other
        exception type is not caught, so the test errors out as it would for
        any unexpected exception.

        Called with only the exception class, this returns a context object::

            with self.assertRaises(SomeException):
                do_something()

        In that form an optional ``msg`` keyword argument may be given. The
        context object keeps the caught exception in its ``exception``
        attribute for inspection after the ``with`` block::

            with self.assertRaises(SomeException) as cm:
                do_something()
            the_exception = cm.exception
            self.assertEqual(the_exception.error_code, 3)
        """
        ctx = _AssertRaisesContext(expected_exception, self)
        try:
            return ctx.handle('assertRaises', args, kwargs)
        finally:
            # bpo-23890: drop the local reference to manually break a
            # reference cycle.
            del ctx

assertRaisesMessage

def assertRaisesMessage(
    self,
    expected_exception,
    expected_message,
    *args,
    **kwargs
)

Assert that expected_message is found in the message of a raised exception.

Parameters:

| Name | Description | Default |
|------|-------------|---------|
| expected_exception | Exception class expected to be raised. | required |
| expected_message | Expected error message string value. | required |
| args | Function to be called and extra positional args. | none |
| kwargs | Extra kwargs. | none |
View Source
    def assertRaisesMessage(
        self, expected_exception, expected_message, *args, **kwargs
    ):
        """
        Check that ``expected_exception`` is raised and that
        ``expected_message`` is found in its message.

        This is the ``assertRaises``-based variant of ``_assertFooMessage``
        (defined elsewhere); ``args`` and ``kwargs`` are forwarded unchanged.

        Args:
            expected_exception: Exception class expected to be raised.
            expected_message: Text expected to appear in the error message.
            args: Function to be called, followed by its positional args.
            kwargs: Extra keyword arguments.
        """
        return self._assertFooMessage(
            self.assertRaises, "exception",
            expected_exception, expected_message,
            *args, **kwargs,
        )

assertRaisesRegex

def assertRaisesRegex(
    self,
    expected_exception,
    expected_regex,
    *args,
    **kwargs
)

Asserts that the message in a raised exception matches a regex.

Parameters:

| Name | Description | Default |
|------|-------------|---------|
| expected_exception | Exception class expected to be raised. | required |
| expected_regex | Regex (re.Pattern object or string) expected to be found in the error message. | required |
| args | Function to be called and extra positional args. | none |
| kwargs | Extra kwargs. | none |
| msg | Optional message used in case of failure. Can only be used when assertRaisesRegex is used as a context manager. | None |
View Source
    def assertRaisesRegex(self, expected_exception, expected_regex,
                          *args, **kwargs):
        """Check that ``expected_exception`` is raised with a message matching a regex.

        Args:
            expected_exception: Exception class expected to be raised.
            expected_regex: Regex (``re.Pattern`` object or string) that must
                be found in the error message.
            args: Function to be called, followed by its positional args.
            kwargs: Extra keyword arguments.
            msg: Optional message used in case of failure. Only accepted
                when assertRaisesRegex is used as a context manager.
        """
        return _AssertRaisesContext(
            expected_exception, self, expected_regex
        ).handle('assertRaisesRegex', args, kwargs)

assertRaisesRegexp

def assertRaisesRegexp(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            # Deprecated alias: warn (attributed to the caller's line via
            # stacklevel=2), then forward every argument unchanged.
            replacement_name = original_func.__name__
            warnings.warn('Please use {0} instead.'.format(replacement_name),
                          category=DeprecationWarning, stacklevel=2)
            return original_func(*args, **kwargs)

assertRedirects

def assertRedirects(
    self,
    response,
    expected_url,
    status_code=302,
    target_status_code=200,
    msg_prefix='',
    fetch_redirect_response=True
)

Assert that a response redirected to a specific URL and that the redirect URL can be loaded.

Won't work for external links since it uses the test client to do a request (use fetch_redirect_response=False to check such links without fetching them).

View Source
    def assertRedirects(
        self,
        response,
        expected_url,
        status_code=302,
        target_status_code=200,
        msg_prefix="",
        fetch_redirect_response=True,
    ):
        """
        Assert that a response redirected to a specific URL and that the
        redirect URL can be loaded.

        Won't work for external links since it uses the test client to do a
        request (use fetch_redirect_response=False to check such links without
        fetching them).

        Args:
            response: The response to check. If it has a ``redirect_chain``
                attribute it is treated as the end of a followed redirect
                chain. Otherwise it is treated as the redirect response itself,
                and its ``url`` and ``request["PATH_INFO"]`` are read.
            expected_url: URL the redirect should point to; compared with
                ``assertURLEqual``.
            status_code: Expected status code of the (first) redirect.
            target_status_code: Expected status code of the final response
                (followed case) or of the fetched redirect target.
            msg_prefix: If non-empty, ": " is appended and the result is
                prepended to every failure message.
            fetch_redirect_response: Only consulted when the redirect was not
                followed. If true, GET the target with the same test client
                and check its status code.

        Raises:
            ValueError: When fetching the target is requested but its host
                does not pass ``validate_host`` against
                ``settings.ALLOWED_HOSTS``.
        """
        if msg_prefix:
            msg_prefix += ": "

        if hasattr(response, "redirect_chain"):
            # The request was a followed redirect
            # An empty chain means no redirect happened at all.
            self.assertTrue(
                response.redirect_chain,
                msg_prefix
                + (
                    "Response didn't redirect as expected: Response code was %d "
                    "(expected %d)"
                )
                % (response.status_code, status_code),
            )

            # The first hop's status must be the expected redirect code.
            self.assertEqual(
                response.redirect_chain[0][1],
                status_code,
                msg_prefix
                + (
                    "Initial response didn't redirect as expected: Response code was "
                    "%d (expected %d)"
                )
                % (response.redirect_chain[0][1], status_code),
            )

            # The last hop gives the URL finally redirected to. NOTE: this
            # rebinds status_code; the parameter is not used after this point.
            url, status_code = response.redirect_chain[-1]

            # After following, the final response must have the target status.
            self.assertEqual(
                response.status_code,
                target_status_code,
                msg_prefix
                + (
                    "Response didn't redirect as expected: Final Response code was %d "
                    "(expected %d)"
                )
                % (response.status_code, target_status_code),
            )

        else:
            # Not a followed redirect
            self.assertEqual(
                response.status_code,
                status_code,
                msg_prefix
                + (
                    "Response didn't redirect as expected: Response code was %d "
                    "(expected %d)"
                )
                % (response.status_code, status_code),
            )

            # Redirect target as reported by the response; split it so the
            # path/query can be fetched and the host checked.
            url = response.url
            scheme, netloc, path, query, fragment = urlsplit(url)

            # Prepend the request path to handle relative path redirects.
            if not path.startswith("/"):
                url = urljoin(response.request["PATH_INFO"], url)
                path = urljoin(response.request["PATH_INFO"], path)

            if fetch_redirect_response:
                # netloc might be empty, or in cases where Django tests the
                # HTTP scheme, the convention is for netloc to be 'testserver'.
                # Trust both as "internal" URLs here.
                domain, port = split_domain_port(netloc)
                if domain and not validate_host(domain, settings.ALLOWED_HOSTS):
                    raise ValueError(
                        "The test client is unable to fetch remote URLs (got %s). "
                        "If the host is served by Django, add '%s' to ALLOWED_HOSTS. "
                        "Otherwise, use "
                        "assertRedirects(..., fetch_redirect_response=False)."
                        % (url, domain)
                    )
                # Get the redirection page, using the same client that was used
                # to obtain the original response.
                extra = response.client.extra or {}
                redirect_response = response.client.get(
                    path,
                    QueryDict(query),
                    secure=(scheme == "https"),
                    **extra,
                )
                self.assertEqual(
                    redirect_response.status_code,
                    target_status_code,
                    msg_prefix
                    + (
                        "Couldn't retrieve redirection page '%s': response code was %d "
                        "(expected %d)"
                    )
                    % (path, redirect_response.status_code, target_status_code),
                )

        # Both branches have now bound url to the redirect target.
        self.assertURLEqual(
            url,
            expected_url,
            msg_prefix
            + "Response redirected to '%s', expected '%s'" % (url, expected_url),
        )

assertRegex

def assertRegex(
    self,
    text,
    expected_regex,
    msg=None
)

Fail the test unless the text matches the regular expression.

View Source
    def assertRegex(self, text, expected_regex, msg=None):
        """Fail the test unless ``text`` matches ``expected_regex``.

        ``expected_regex`` may be a compiled pattern or a ``str``/``bytes``
        pattern, which is compiled here. An empty ``str``/``bytes`` pattern is
        rejected because it matches any text and would make the assertion
        vacuous.

        Raises:
            AssertionError: If a ``str``/``bytes`` pattern is empty.
            self.failureException: If the pattern is not found in ``text``.
        """
        if isinstance(expected_regex, (str, bytes)):
            # Explicit check rather than an ``assert`` statement, so the guard
            # is not stripped when Python runs with -O.
            if not expected_regex:
                raise AssertionError("expected_regex must not be empty.")
            expected_regex = re.compile(expected_regex)
        if not expected_regex.search(text):
            standardMsg = "Regex didn't match: %r not found in %r" % (
                expected_regex.pattern, text)
            # _formatMessage ensures the longMessage option is respected
            msg = self._formatMessage(msg, standardMsg)
            raise self.failureException(msg)

assertRegexpMatches

def assertRegexpMatches(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            warnings.warn(
                'Please use {0} instead.'.format(original_func.__name__),
                DeprecationWarning, 2)
            return original_func(*args, **kwargs)

assertSequenceEqual

def assertSequenceEqual(
    self,
    seq1,
    seq2,
    msg=None,
    seq_type=None
)

An equality assertion for ordered sequences (like lists and tuples).

For the purposes of this function, a valid ordered sequence type is one which can be indexed, has a length, and has an equality operator.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
seq1 | — | The first sequence to compare. | required
seq2 | — | The second sequence to compare. | required
seq_type | — | The expected datatype of the sequences, or None if no datatype should be enforced. | None
msg | — | Optional message to use on failure instead of a list of differences. | None
View Source
    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
        """An equality assertion for ordered sequences (like lists and tuples).

        For the purposes of this function, a valid ordered sequence type is one
        which can be indexed, has a length, and has an equality operator.

        Args:
            seq1: The first sequence to compare.
            seq2: The second sequence to compare.
            seq_type: The expected datatype of the sequences, or None if no
                    datatype should be enforced.
            msg: Optional message to use on failure instead of a list of
                    differences.

        Passes (returns None) when ``seq1 == seq2``. It also passes when no
        ``seq_type`` is given and the sequences have equal length and
        pairwise-equal items but different types (e.g. a list vs. a tuple).

        Raises ``self.failureException`` directly when an argument is not an
        instance of ``seq_type``. Other failures go through ``self.fail``
        with a summary of the first difference followed by an ndiff of the
        pretty-printed sequences.
        """
        # Optional type enforcement; seq_type_name is only used in messages.
        if seq_type is not None:
            seq_type_name = seq_type.__name__
            if not isinstance(seq1, seq_type):
                raise self.failureException('First sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq1)))
            if not isinstance(seq2, seq_type):
                raise self.failureException('Second sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq2)))
        else:
            seq_type_name = "sequence"

        # ``differing`` accumulates the failure description; it stays None
        # while no problem has been found.
        differing = None
        try:
            len1 = len(seq1)
        except (TypeError, NotImplementedError):
            differing = 'First %s has no length.    Non-sequence?' % (
                    seq_type_name)

        if differing is None:
            try:
                len2 = len(seq2)
            except (TypeError, NotImplementedError):
                differing = 'Second %s has no length.    Non-sequence?' % (
                        seq_type_name)

        # Both arguments have a length: compare them, then describe where
        # they diverge.
        if differing is None:
            if seq1 == seq2:
                return

            differing = '%ss differ: %s != %s\n' % (
                    (seq_type_name.capitalize(),) +
                    _common_shorten_repr(seq1, seq2))

            # Scan the common prefix and report the first index that differs
            # or cannot be indexed; ``break`` stops at the first finding.
            for i in range(min(len1, len2)):
                try:
                    item1 = seq1[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of first %s\n' %
                                 (i, seq_type_name))
                    break

                try:
                    item2 = seq2[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of second %s\n' %
                                 (i, seq_type_name))
                    break

                if item1 != item2:
                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
                                 ((i,) + _common_shorten_repr(item1, item2)))
                    break
            else:
                # for-else: no difference was found in the common prefix.
                if (len1 == len2 and seq_type is None and
                    type(seq1) != type(seq2)):
                    # The sequences are the same, but have differing types.
                    return

            # Describe surplus trailing elements on whichever side is longer.
            if len1 > len2:
                differing += ('\nFirst %s contains %d additional '
                             'elements.\n' % (seq_type_name, len1 - len2))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len2, safe_repr(seq1[len2])))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of first %s\n' % (len2, seq_type_name))
            elif len1 < len2:
                differing += ('\nSecond %s contains %d additional '
                             'elements.\n' % (seq_type_name, len2 - len1))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len1, safe_repr(seq2[len1])))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of second %s\n' % (len1, seq_type_name))
        # Append a full ndiff of the pprint()-formatted sequences;
        # _truncateMessage() (defined elsewhere) decides how much is kept.
        standardMsg = differing
        diffMsg = '\n' + '\n'.join(
            difflib.ndiff(pprint.pformat(seq1).splitlines(),
                          pprint.pformat(seq2).splitlines()))

        standardMsg = self._truncateMessage(standardMsg, diffMsg)
        msg = self._formatMessage(msg, standardMsg)
        self.fail(msg)

assertSetEqual

def assertSetEqual(
    self,
    set1,
    set2,
    msg=None
)

A set-specific equality assertion.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
set1 | — | The first set to compare. | required
set2 | — | The second set to compare. | required
msg | — | Optional message to use on failure instead of a list of differences. | None
View Source
    def assertSetEqual(self, set1, set2, msg=None):
        """Assert that two set-like objects contain exactly the same items.

        Args:
            set1: The first set to compare.
            set2: The second set to compare.
            msg: Optional message to use on failure instead of a list of
                    differences.

        Any object with a ``difference`` method is accepted (duck typing),
        so set and frozenset can be mixed freely. On failure the message
        lists the items found on only one side.
        """
        try:
            only_in_first = set1.difference(set2)
        except TypeError as exc:
            self.fail('invalid type when attempting set difference: %s' % exc)
        except AttributeError as exc:
            self.fail('first argument does not support set difference: %s' % exc)

        try:
            only_in_second = set2.difference(set1)
        except TypeError as exc:
            self.fail('invalid type when attempting set difference: %s' % exc)
        except AttributeError as exc:
            self.fail('second argument does not support set difference: %s' % exc)

        if not only_in_first and not only_in_second:
            return

        report = []
        if only_in_first:
            report.append('Items in the first set but not the second:')
            report.extend(map(repr, only_in_first))
        if only_in_second:
            report.append('Items in the second set but not the first:')
            report.extend(map(repr, only_in_second))

        self.fail(self._formatMessage(msg, '\n'.join(report)))

assertTemplateNotUsed

def assertTemplateNotUsed(
    self,
    response=None,
    template_name=None,
    msg_prefix=''
)

Assert that the template with the provided name was NOT used in rendering the response. Also usable as a context manager.

View Source
    def assertTemplateNotUsed(self, response=None, template_name=None, msg_prefix=""):
        """
        Assert that ``template_name`` was NOT among the templates used to
        render ``response``.

        Also usable as a context manager. _assert_template_used() decides
        which form applies and returns the rendered template names together
        with the msg_prefix to use.
        """
        context_mgr_template, template_names, msg_prefix = self._assert_template_used(
            response, template_name, msg_prefix
        )
        if context_mgr_template:
            # Context-manager form: hand off to the dedicated context object.
            return _AssertTemplateNotUsedContext(self, context_mgr_template)

        was_used = template_name in template_names
        self.assertFalse(
            was_used,
            msg_prefix
            + "Template '%s' was used unexpectedly in rendering the response"
            % template_name,
        )

assertTemplateUsed

def assertTemplateUsed(
    self,
    response=None,
    template_name=None,
    msg_prefix='',
    count=None
)

Assert that the template with the provided name was used in rendering the response. Also usable as a context manager.

View Source
    def assertTemplateUsed(
        self, response=None, template_name=None, msg_prefix="", count=None
    ):
        """
        Assert that ``template_name`` was among the templates used to render
        ``response``. When ``count`` is given, also assert it was rendered
        exactly that many times.

        Also usable as a context manager. _assert_template_used() decides
        which form applies and returns the rendered template names together
        with the msg_prefix to use.
        """
        context_mgr_template, template_names, msg_prefix = self._assert_template_used(
            response, template_name, msg_prefix
        )

        if context_mgr_template:
            # Context-manager form: hand off to the dedicated context object.
            return _AssertTemplateUsedContext(self, context_mgr_template)

        if not template_names:
            self.fail(msg_prefix + "No templates used to render the response")

        was_used = template_name in template_names
        self.assertTrue(
            was_used,
            msg_prefix
            + "Template '%s' was not a template used to render the response. "
            "Actual template(s) used: %s" % (template_name, ", ".join(template_names)),
        )

        if count is None:
            return

        rendered_times = template_names.count(template_name)
        self.assertEqual(
            rendered_times,
            count,
            msg_prefix
            + "Template '%s' was expected to be rendered %d time(s) but was "
            "actually rendered %d time(s)." % (template_name, count, rendered_times),
        )

assertTrue

def assertTrue(
    self,
    expr,
    msg=None
)

Check that the expression is true.

View Source
    def assertTrue(self, expr, msg=None):
        """Fail unless ``expr`` is truthy."""
        if expr:
            return
        raise self.failureException(
            self._formatMessage(msg, "%s is not true" % safe_repr(expr))
        )

assertTupleEqual

def assertTupleEqual(
    self,
    tuple1,
    tuple2,
    msg=None
)

A tuple-specific equality assertion.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
tuple1 | — | The first tuple to compare. | required
tuple2 | — | The second tuple to compare. | required
msg | — | Optional message to use on failure instead of a list of differences. | None
View Source
    def assertTupleEqual(self, tuple1, tuple2, msg=None):
        """Assert that two tuples are equal.

        Delegates to assertSequenceEqual() with ``seq_type=tuple``, so both
        arguments must be tuples and differences are reported element-wise.

        Args:
            tuple1: The first tuple to compare.
            tuple2: The second tuple to compare.
            msg: Optional message to use on failure instead of a list of
                    differences.
        """
        self.assertSequenceEqual(tuple1, tuple2, msg=msg, seq_type=tuple)

assertURLEqual

def assertURLEqual(
    self,
    url1,
    url2,
    msg_prefix=''
)

Assert that two URLs are the same, ignoring the order of query string parameters except for parameters with the same name.

For example, /path/?x=1&y=2 is equal to /path/?y=2&x=1, but /path/?a=1&a=2 isn't equal to /path/?a=2&a=1.

View Source
    def assertURLEqual(self, url1, url2, msg_prefix=""):
        """
        Assert that two URLs are the same, ignoring the order of query string
        parameters except for parameters with the same name.

        For example, /path/?x=1&y=2 is equal to /path/?y=2&x=1, but
        /path/?a=1&a=2 isn't equal to /path/?a=2&a=1.

        Args:
            url1: First URL; anything convertible with str() (e.g. the
                result of reverse_lazy()).
            url2: Second URL, same rules as url1.
            msg_prefix: Text prepended to the failure message.
        """

        def normalize(url):
            """Sort query parameters by name, keeping same-name order intact."""
            url = str(url)  # Coerce reverse_lazy() URLs.
            scheme, netloc, path, params, query, fragment = urlparse(url)
            # Sort on the parameter name only. sorted() is stable, so repeated
            # parameters keep their original relative order. Sorting on the
            # full (name, value) pair would make ?a=1&a=2 and ?a=2&a=1 compare
            # equal, contradicting the documented contract above.
            # NOTE: parse_qsl() is used with its defaults, so parameters with
            # blank values are dropped before comparison.
            query_parts = sorted(parse_qsl(query), key=lambda pair: pair[0])
            return urlunparse(
                (scheme, netloc, path, params, urlencode(query_parts), fragment)
            )

        self.assertEqual(
            normalize(url1),
            normalize(url2),
            msg_prefix + "Expected '%s' to equal '%s'." % (url1, url2),
        )

assertWarns

def assertWarns(
    self,
    expected_warning,
    *args,
    **kwargs
)

Fail unless a warning of class expected_warning is triggered by the callable when invoked with the specified positional and keyword arguments. If a different type of warning is triggered, it will not be handled: depending on the other warning filtering rules in effect, it might be silenced, printed out, or raised as an exception.

If called with the callable and arguments omitted, will return a context object used like this::

    with self.assertWarns(SomeWarning):
        do_something()

An optional keyword argument 'msg' can be provided when assertWarns is used as a context object.

The context manager keeps a reference to the first matching warning as the 'warning' attribute; similarly, the 'filename' and 'lineno' attributes give you information about the line of Python code from which the warning was triggered. This allows you to inspect the warning after the assertion::

    with self.assertWarns(SomeWarning) as cm:
        do_something()
    the_warning = cm.warning
    self.assertEqual(the_warning.some_attribute, 147)
View Source
    def assertWarns(self, expected_warning, *args, **kwargs):
        """Fail unless a warning of class ``expected_warning`` is triggered.

        When a callable (plus positional and keyword arguments) is passed, it
        is invoked and must trigger the warning. A warning of a different
        type is not handled: depending on the other warning filtering rules
        in effect, it might be silenced, printed out, or raised as an
        exception.

        When the callable and arguments are omitted, a context object is
        returned instead::

            with self.assertWarns(SomeWarning):
                do_something()

        In that form an optional ``msg`` keyword argument can be provided.

        The context manager keeps a reference to the first matching warning
        as its ``warning`` attribute. Its ``filename`` and ``lineno``
        attributes give the line of Python code that triggered it, so the
        warning can be inspected after the assertion::

            with self.assertWarns(SomeWarning) as cm:
                do_something()
            the_warning = cm.warning
            self.assertEqual(the_warning.some_attribute, 147)
        """
        return _AssertWarnsContext(expected_warning, self).handle(
            'assertWarns', args, kwargs
        )

assertWarnsMessage

def assertWarnsMessage(
    self,
    expected_warning,
    expected_message,
    *args,
    **kwargs
)

Same as assertRaisesMessage() but for assertWarns() instead of assertRaises().

View Source
    def assertWarnsMessage(self, expected_warning, expected_message, *args, **kwargs):
        """
        Warning counterpart of assertRaisesMessage(): behaves like
        assertWarns() and additionally checks the warning's message against
        ``expected_message``. The actual check is delegated to
        _assertFooMessage().
        """
        checker = self.assertWarns
        return self._assertFooMessage(
            checker, "warning", expected_warning, expected_message, *args, **kwargs
        )

assertWarnsRegex

def assertWarnsRegex(
    self,
    expected_warning,
    expected_regex,
    *args,
    **kwargs
)

Asserts that the message in a triggered warning matches a regexp.

Basic functioning is similar to assertWarns() with the addition that only warnings whose messages also match the regular expression are considered successful matches.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
expected_warning | — | Warning class expected to be triggered. | required
expected_regex | — | Regex (re.Pattern object or string) expected to be found in the warning message. | required
args | — | Function to be called and extra positional args. | —
kwargs | — | Extra kwargs. | —
msg | — | Optional message used in case of failure. Can only be used when assertWarnsRegex is used as a context manager. | None
View Source
    def assertWarnsRegex(self, expected_warning, expected_regex, *args, **kwargs):
        """Assert that a triggered warning's message matches a regex.

        Works like assertWarns(), except that a warning only counts as a
        successful match when its message also matches ``expected_regex``.

        Args:
            expected_warning: Warning class expected to be triggered.
            expected_regex: Regex (re.Pattern object or string) expected
                    to be found in the warning message.
            args: Function to be called and extra positional args.
            kwargs: Extra kwargs.
            msg: Optional message used in case of failure. Can only be used
                    when assertWarnsRegex is used as a context manager.
        """
        return _AssertWarnsContext(expected_warning, self, expected_regex).handle(
            'assertWarnsRegex', args, kwargs
        )

assertXMLEqual

def assertXMLEqual(
    self,
    xml1,
    xml2,
    msg=None
)

Assert that two XML snippets are semantically the same.

Whitespace in most cases is ignored and attribute ordering is not significant. The arguments must be valid XML.

View Source
    def assertXMLEqual(self, xml1, xml2, msg=None):
        """
        Assert that the XML snippets ``xml1`` and ``xml2`` are semantically
        equivalent according to compare_xml(). Whitespace is mostly ignored
        and attribute order does not matter. If compare_xml() raises (e.g. on
        invalid XML), the assertion fails with that error in the message.
        """
        try:
            equivalent = compare_xml(xml1, xml2)
        except Exception as exc:
            self.fail(
                self._formatMessage(
                    msg, "First or second argument is not valid XML\n%s" % exc
                )
            )
        else:
            if equivalent:
                return
            summary = "%s != %s" % (safe_repr(xml1, True), safe_repr(xml2, True))
            # Append a line-by-line ndiff; _truncateMessage() may shorten it.
            line_diff = "\n" + "\n".join(
                difflib.ndiff(xml1.splitlines(), xml2.splitlines())
            )
            self.fail(
                self._formatMessage(msg, self._truncateMessage(summary, line_diff))
            )

assertXMLNotEqual

def assertXMLNotEqual(
    self,
    xml1,
    xml2,
    msg=None
)

Assert that two XML snippets are not semantically equivalent.

Whitespace in most cases is ignored and attribute ordering is not significant. The arguments must be valid XML.

View Source
    def assertXMLNotEqual(self, xml1, xml2, msg=None):
        """
        Assert that the XML snippets ``xml1`` and ``xml2`` are NOT semantically
        equivalent according to compare_xml(). Whitespace is mostly ignored
        and attribute order does not matter. If compare_xml() raises (e.g. on
        invalid XML), the assertion fails with that error in the message.
        """
        try:
            equivalent = compare_xml(xml1, xml2)
        except Exception as exc:
            self.fail(
                self._formatMessage(
                    msg, "First or second argument is not valid XML\n%s" % exc
                )
            )
        else:
            if not equivalent:
                return
            summary = "%s == %s" % (safe_repr(xml1, True), safe_repr(xml2, True))
            self.fail(self._formatMessage(msg, summary))

assert_

def assert_(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            """Warn that this name is deprecated, then forward the call to ``original_func``."""
            # original_func is bound by the enclosing (not shown) alias factory.
            # stacklevel=2 attributes the warning to the caller of this alias.
            notice = 'Please use {0} instead.'.format(original_func.__name__)
            warnings.warn(notice, DeprecationWarning, 2)
            return original_func(*args, **kwargs)

countTestCases

def countTestCases(
    self
)
View Source
    def countTestCases(self):
        """Return the number of test cases represented (a single test: 1)."""
        single_case = 1
        return single_case

debug

def debug(
    self
)

Perform the same as __call__(), without catching the exception.

View Source
    def debug(self):
        """Run the test like __call__(), but without catching the exception.

        A fresh _DebugResult is handed to _setup_and_call() with debug=True.
        """
        self._setup_and_call(_DebugResult(), debug=True)

defaultTestResult

def defaultTestResult(
    self
)
View Source
    def defaultTestResult(self):
        """Create and return a new, empty ``result.TestResult`` instance."""
        fresh_result = result.TestResult()
        return fresh_result

doCleanups

def doCleanups(
    self
)

Execute all cleanup functions. Normally called for you after tearDown.

View Source
    def doCleanups(self):
        """Execute all registered cleanup functions, most recently added first.

        Normally called for you after tearDown. Each cleanup is invoked via
        _callCleanup() inside ``outcome.testPartExecutor(self)``. When no
        outcome is active, a fresh _Outcome() is used.

        Returns:
            ``outcome.success``. This is kept only for backwards compatibility
            and is no longer used internally.
        """
        outcome = self._outcome or _Outcome()
        while self._cleanups:
            # pop() takes from the end, so cleanups run in LIFO order
            # (assuming _cleanups is a list).
            cleanup, cleanup_args, cleanup_kwargs = self._cleanups.pop()
            with outcome.testPartExecutor(self):
                self._callCleanup(cleanup, *cleanup_args, **cleanup_kwargs)

        return outcome.success

fail

def fail(
    self,
    msg=None
)

Fail immediately, with the given message.

View Source
    def fail(self, msg=None):
        """Unconditionally fail by raising ``self.failureException(msg)``."""
        exception_class = self.failureException
        raise exception_class(msg)

failIf

def failIf(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            """Warn that this name is deprecated, then forward the call to ``original_func``."""
            # original_func is bound by the enclosing (not shown) alias factory.
            # stacklevel=2 attributes the warning to the caller of this alias.
            notice = 'Please use {0} instead.'.format(original_func.__name__)
            warnings.warn(notice, DeprecationWarning, 2)
            return original_func(*args, **kwargs)

failIfAlmostEqual

def failIfAlmostEqual(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            """Warn that this name is deprecated, then forward the call to ``original_func``."""
            # original_func is bound by the enclosing (not shown) alias factory.
            # stacklevel=2 attributes the warning to the caller of this alias.
            notice = 'Please use {0} instead.'.format(original_func.__name__)
            warnings.warn(notice, DeprecationWarning, 2)
            return original_func(*args, **kwargs)

failIfEqual

def failIfEqual(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            """Warn that this name is deprecated, then forward the call to ``original_func``."""
            # original_func is bound by the enclosing (not shown) alias factory.
            # stacklevel=2 attributes the warning to the caller of this alias.
            notice = 'Please use {0} instead.'.format(original_func.__name__)
            warnings.warn(notice, DeprecationWarning, 2)
            return original_func(*args, **kwargs)

failUnless

def failUnless(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            """Warn that this name is deprecated, then forward the call to ``original_func``."""
            # original_func is bound by the enclosing (not shown) alias factory.
            # stacklevel=2 attributes the warning to the caller of this alias.
            notice = 'Please use {0} instead.'.format(original_func.__name__)
            warnings.warn(notice, DeprecationWarning, 2)
            return original_func(*args, **kwargs)

failUnlessAlmostEqual

def failUnlessAlmostEqual(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            """Warn that this name is deprecated, then forward the call to ``original_func``."""
            # original_func is bound by the enclosing (not shown) alias factory.
            # stacklevel=2 attributes the warning to the caller of this alias.
            notice = 'Please use {0} instead.'.format(original_func.__name__)
            warnings.warn(notice, DeprecationWarning, 2)
            return original_func(*args, **kwargs)

failUnlessEqual

def failUnlessEqual(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            """Warn that this name is deprecated, then forward the call to ``original_func``."""
            # original_func is bound by the enclosing (not shown) alias factory.
            # stacklevel=2 attributes the warning to the caller of this alias.
            notice = 'Please use {0} instead.'.format(original_func.__name__)
            warnings.warn(notice, DeprecationWarning, 2)
            return original_func(*args, **kwargs)

failUnlessRaises

def failUnlessRaises(
    *args,
    **kwargs
)
View Source
        def deprecated_func(*args, **kwargs):
            """Warn that this name is deprecated, then forward the call to ``original_func``."""
            # original_func is bound by the enclosing (not shown) alias factory.
            # stacklevel=2 attributes the warning to the caller of this alias.
            notice = 'Please use {0} instead.'.format(original_func.__name__)
            warnings.warn(notice, DeprecationWarning, 2)
            return original_func(*args, **kwargs)

id

def id(
    self
)
View Source
    def id(self):
        """Return ``"<strclass(class)>.<test method name>"`` for this test."""
        owner = strclass(self.__class__)
        return "%s.%s" % (owner, self._testMethodName)

modify_settings

def modify_settings(
    self,
    **kwargs
)

A context manager that temporarily applies changes to a list setting and reverts to the original value when exiting the context.

View Source
    def modify_settings(self, **kwargs):
        """
        Return a context manager that temporarily applies changes to list
        settings and restores the original values when the context exits.

        This is a thin wrapper around the module-level ``modify_settings``
        helper.
        """
        settings_modifier = modify_settings(**kwargs)
        return settings_modifier

run

def run(
    self,
    result=None
)
View Source
    def run(self, result=None):
        """Run the test and report its outcome to ``result``.

        Order of events: skip check, setUp, test method, tearDown (only if
        setUp succeeded), cleanups (always), then reporting of skips, errors,
        expected failure / unexpected success, or success.

        Args:
            result: Object receiving the outcome. If None, one is created via
                defaultTestResult(), and its startTestRun()/stopTestRun()
                hooks (when present) are called around this single test.

        Returns:
            The result object that was reported to.
        """
        if result is None:
            # No result supplied: create one and, when it supports them, call
            # its startTestRun()/stopTestRun() hooks around this test.
            result = self.defaultTestResult()
            startTestRun = getattr(result, 'startTestRun', None)
            stopTestRun = getattr(result, 'stopTestRun', None)
            if startTestRun is not None:
                startTestRun()
        else:
            # The caller owns the result, so it also owns the run-level hooks.
            stopTestRun = None

        result.startTest(self)
        try:
            testMethod = getattr(self, self._testMethodName)
            if (getattr(self.__class__, "__unittest_skip__", False) or
                getattr(testMethod, "__unittest_skip__", False)):
                # If the class or method was skipped.
                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
                            or getattr(testMethod, '__unittest_skip_why__', ''))
                self._addSkip(result, self, skip_why)
                return result

            # Either the test case or the test method may be marked as an
            # expected failure.
            expecting_failure = (
                getattr(self, "__unittest_expecting_failure__", False) or
                getattr(testMethod, "__unittest_expecting_failure__", False)
            )
            outcome = _Outcome(result)
            try:
                self._outcome = outcome

                with outcome.testPartExecutor(self):
                    self._callSetUp()
                if outcome.success:
                    # The expected-failure flag only covers the test method
                    # itself, not setUp/tearDown.
                    outcome.expecting_failure = expecting_failure
                    with outcome.testPartExecutor(self, isTest=True):
                        self._callTestMethod(testMethod)
                    outcome.expecting_failure = False
                    with outcome.testPartExecutor(self):
                        self._callTearDown()

                # Cleanups run whether or not setUp succeeded.
                self.doCleanups()
                for test, reason in outcome.skipped:
                    self._addSkip(result, test, reason)
                self._feedErrorsToResult(result, outcome.errors)
                if outcome.success:
                    if expecting_failure:
                        if outcome.expectedFailure:
                            self._addExpectedFailure(result, outcome.expectedFailure)
                        else:
                            self._addUnexpectedSuccess(result)
                    else:
                        result.addSuccess(self)
                return result
            finally:
                # explicitly break reference cycles:
                # outcome.errors -> frame -> outcome -> outcome.errors
                # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
                outcome.errors.clear()
                outcome.expectedFailure = None

                # clear the outcome, no more needed
                self._outcome = None

        finally:
            # Always close the per-test bracket, and the run bracket if this
            # method opened it.
            result.stopTest(self)
            if stopTestRun is not None:
                stopTestRun()

setUp

def setUp(
    self
)

Hook method for setting up the test fixture before exercising it.

View Source
    def setUp(self):
        """Hook for preparing the test fixture; does nothing by default."""

settings

def settings(
    self,
    **kwargs
)

A context manager that temporarily sets a setting and reverts to the original value when exiting the context.

View Source
    def settings(self, **kwargs):
        """
        Return a context manager that temporarily overrides the given settings
        and restores the original values when the context exits.
        """
        overrides = override_settings(**kwargs)
        return overrides

shortDescription

def shortDescription(
    self
)

Returns a one-line description of the test, or None if no description has been provided.

The default implementation of this method returns the first line of the specified test method's docstring.

View Source
    def shortDescription(self):
        """Return the first line of the test method's docstring, or None.

        Surrounding whitespace is stripped from both the docstring and the
        returned line. A missing or empty docstring yields None.
        """
        doc = self._testMethodDoc
        if not doc:
            return None
        first_line = doc.strip().partition("\n")[0]
        return first_line.strip()

skipTest

def skipTest(
    self,
    reason
)

Skip this test.

View Source
    def skipTest(self, reason):
        """Signal that the current test should be skipped, by raising SkipTest(reason)."""
        skip_signal = SkipTest(reason)
        raise skip_signal

subTest

def subTest(
    self,
    msg=_subtest_msg_sentinel,
    **params
)

Return a context manager that will run the enclosed block of code in a subtest identified by the optional message and keyword parameters. A failure in the subtest marks the test case as failed but resumes execution at the end of the enclosed block, allowing further test code to be executed.

View Source
    @contextlib.contextmanager
    def subTest(self, msg=_subtest_msg_sentinel, **params):
        """Return a context manager that runs the enclosed block of code as a
        subtest identified by the optional message and keyword parameters.

        A failure in the subtest marks the test case as failed but resumes
        execution at the end of the enclosed block, allowing further test
        code to be executed.

        Args:
            msg: Optional subtest description. The default is the sentinel
                _subtest_msg_sentinel (presumably meaning "no message given").
            **params: Keyword parameters identifying the subtest. In nested
                subtests they are layered on top of the parent's parameters
                via new_child().
        """
        # Without an active outcome, or when the result cannot record
        # subtests, the block simply runs as ordinary test code.
        if self._outcome is None or not self._outcome.result_supports_subtests:
            yield
            return
        parent = self._subtest
        if parent is None:
            params_map = _OrderedChainMap(params)
        else:
            params_map = parent.params.new_child(params)
        self._subtest = _SubTest(self, msg, params_map)
        try:
            # The block runs under the executor with the subtest as the test
            # part, presumably so failures are attributed to the subtest.
            with self._outcome.testPartExecutor(self._subtest, isTest=True):
                yield
            if not self._outcome.success:
                result = self._outcome.result
                # Honour failfast: raise _ShouldStop after a failed subtest.
                if result is not None and result.failfast:
                    raise _ShouldStop
            elif self._outcome.expectedFailure:
                # If the test is expecting a failure, we really want to
                # stop now and register the expected failure.
                raise _ShouldStop
        finally:
            # Restore the enclosing subtest (or None) even when stopping early.
            self._subtest = parent

tearDown

def tearDown(
    self
)

Hook method for deconstructing the test fixture after testing it.

View Source
    def tearDown(self):
        """Hook for dismantling the test fixture; does nothing by default."""

test_aggregate

def test_aggregate(
    self
)
View Source
    def test_aggregate(self):
        # Ten models whose linear-layer weights and biases are all set to the
        # constant i (i = 0..9), aggregated with equal sample sizes, so every
        # averaged value should be mean(0..9) = 4.5.
        aggr = MeanAggregation()
        models = [_create_torchscript_model_and_init(i) for i in range(10)]
        model = aggr.aggregate(models, [1] * 10)
        cls_name = model.original_name if is_torchscript_instance(model) else model.__class__.__name__
        self.assertEqual(cls_name, "Sequential")
        res = model.state_dict()
        self.assertEqual(len(models[0].state_dict()), len(res))
        # The helper initialises the biases to the same constant as the
        # weights, so they must average to the same value; check them too.
        for key in ("0.weight", "0.bias", "3.weight", "3.bias"):
            with self.subTest(key=key):
                torch.testing.assert_close(res[key], torch.ones_like(res[key]) * 4.5)

test_aggregate_sample_sizes

def test_aggregate_sample_sizes(
    self
)
View Source
    def test_aggregate_sample_sizes(self):
        # Three models initialised to the constants 0, 1 and 2 with sample
        # sizes 0, 1 and 2. The expected sample-size-weighted mean is
        # (0*0 + 1*1 + 2*2) / (0 + 1 + 2) = 5/3.
        aggr = MeanAggregation()
        models = [_create_torchscript_model_and_init(i) for i in range(3)]
        model = aggr.aggregate(models, [0, 1, 2])
        cls_name = model.original_name if is_torchscript_instance(model) else model.__class__.__name__
        self.assertEqual(cls_name, "Sequential")
        self.assertEqual(len(list(models[0].parameters())), len(list(model.parameters())))
        res = model.state_dict()
        # The helper initialises the biases to the same constant as the
        # weights, so they must be weighted identically; check them too.
        for key in ("0.weight", "0.bias", "3.weight", "3.bias"):
            with self.subTest(key=key):
                torch.testing.assert_close(res[key], torch.ones_like(res[key]) * (5 / 3))