declearn.test_utils.assert_dict_equal

Assert that two (possibly nested) dicts are equal.

This function is a more complex equivalent of assert dict_a == dict_b. It also supports comparing numpy array values, and it can optionally treat a tuple and a list with identical contents as equal. Plain Python equality always considers a tuple and a list to be different, even when their contents are the same.
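To see why a dedicated helper is useful, the snippet below uses only plain Python and NumPy (it does not call declearn) to show the two situations the description mentions. A bare == between dicts holding multi-element NumPy arrays raises an error instead of returning a boolean, and a tuple never equals a list with the same items.

```python
import numpy as np

dict_a = {"weights": np.array([1.0, 2.0]), "shape": (2,)}
dict_b = {"weights": np.array([1.0, 2.0]), "shape": (2,)}

# Dict equality compares values with ==. For NumPy arrays this yields an
# element-wise boolean array, and converting that array to a single bool
# is ambiguous, so NumPy raises ValueError.
try:
    dict_a == dict_b
    raised = False
except ValueError:
    raised = True
print(raised)  # True

# A tuple and a list holding the same items are never equal in Python.
print((2,) == [2])  # False
```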

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dict_a | Dict[str, Any] | First dict to compare. | required |
| dict_b | Dict[str, Any] | Second dict to compare. | required |
| strict_tuple | bool | Whether to cast tuples to lists prior to comparing them (tolerating some tuple-list type differences between the two compared dicts). | False |
| np_tolerance | Optional[float] | Optional absolute tolerance on differences between numpy arrays or float values (compared using np.allclose(a, b, rtol=0, atol=np_tolerance)). | None |
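As stated for np_tolerance, the tolerance is applied as an absolute one, with the relative tolerance disabled. The NumPy-only snippet below illustrates what that implies: with rtol=0, np.allclose(a, b, rtol=0, atol=tol) reduces to checking |a - b| <= tol element-wise. By contrast, NumPy's own defaults (rtol=1e-05, atol=1e-08) also scale the tolerance with the magnitude of b. The arrays and tolerances are illustrative values only.

```python
import numpy as np

a = np.array([1.0, 100.0])
b = np.array([1.0005, 100.0005])

# With rtol=0 the check is purely absolute: |a - b| <= atol element-wise.
within_1e3 = np.allclose(a, b, rtol=0, atol=1e-3)  # both gaps are ~5e-4
within_1e4 = np.allclose(a, b, rtol=0, atol=1e-4)  # ~5e-4 exceeds 1e-4

# Equivalent explicit check, for comparison.
manual = bool(np.all(np.abs(a - b) <= 1e-3))

print(within_1e3, within_1e4, manual)  # True False True
```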

Raises:

| Type | Description |
|---|---|
| AssertionError | If the two dicts are not equal. |

Source code in declearn/test_utils/_assertions.py
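from typing import Any, Dict, Optional

# assert_values_equal (called below) is resolved from the same module,
# declearn/test_utils/_assertions.py; its source is not shown here.


def assert_dict_equal(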
    dict_a: Dict[str, Any],
    dict_b: Dict[str, Any],
    strict_tuple: bool = False,
    np_tolerance: Optional[float] = None,
) -> None:
    """Assert that two (possibly nested) dicts are equal.

    This function is a more complex equivalent of `assert dict_a == dict_b`
    that also supports comparing numpy array values, and can optionally
    treat a tuple and a list with identical contents as equal, whereas
    plain equality always considers them different.

    Parameters
    ----------
    dict_a: dict
        First dict to compare.
    dict_b: dict
        Second dict to compare.
    strict_tuple: bool, default=False
        Whether to cast tuples to lists prior to comparing them
        (tolerating some tuple-list type differences between the
        two compared dicts).
    np_tolerance: float or None, default=None
        Optional absolute tolerance on differences between numpy arrays
        or float values (using `np.allclose(a, b, rtol=0, atol=np_tolerance)`).

    Raises
    ------
    AssertionError
        If the two dicts are not equal.
    """
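    # Both dicts must hold exactly the same set of keys.
    assert dict_a.keys() == dict_b.keys()
    # Compare values key by key; nested values (e.g. sub-dicts or numpy
    # arrays) are handled by assert_values_equal.
    for key, val_a in dict_a.items():
        val_b = dict_b[key]
        assert_values_equal(val_a, val_b, strict_tuple, np_tolerance)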
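The function itself only checks keys and delegates each value to assert_values_equal, whose source is not part of this page. The sketch below is a hypothetical, NumPy-based illustration of that delegation pattern, not declearn's implementation. The names sketch_dict_equal and sketch_values_equal are invented for this example. Its tuple/list handling is always lenient, so the strict_tuple switch is not reproduced, and its atol argument plays the role of np_tolerance.

```python
from typing import Any, Dict, Optional

import numpy as np


def sketch_values_equal(val_a: Any, val_b: Any, atol: Optional[float] = None) -> None:
    """Hypothetical stand-in for assert_values_equal (not declearn's code)."""
    if isinstance(val_a, dict):
        assert isinstance(val_b, dict)
        sketch_dict_equal(val_a, val_b, atol)
    elif isinstance(val_a, np.ndarray):
        assert isinstance(val_b, np.ndarray)
        assert val_a.shape == val_b.shape
        if atol is None:
            assert np.array_equal(val_a, val_b)  # exact element-wise match
        else:
            assert np.allclose(val_a, val_b, rtol=0, atol=atol)
    elif isinstance(val_a, (list, tuple)):
        # Always lenient: items are compared regardless of container type.
        assert isinstance(val_b, (list, tuple))
        assert len(val_a) == len(val_b)
        for item_a, item_b in zip(val_a, val_b):
            sketch_values_equal(item_a, item_b, atol)
    elif isinstance(val_a, float) and atol is not None:
        assert isinstance(val_b, (int, float))
        assert abs(val_a - val_b) <= atol
    else:
        assert val_a == val_b


def sketch_dict_equal(
    dict_a: Dict[str, Any], dict_b: Dict[str, Any], atol: Optional[float] = None
) -> None:
    """Same structure as assert_dict_equal: key check, then per-key values."""
    assert dict_a.keys() == dict_b.keys()
    for key, val_a in dict_a.items():
        sketch_values_equal(val_a, dict_b[key], atol)


state_a = {"model": {"w": np.array([0.1, 0.2]), "dims": (2,)}, "lr": 0.01}
state_b = {"model": {"w": np.array([0.1, 0.2000001]), "dims": [2]}, "lr": 0.01}

# Passes: the 1e-7 gap in "w" is within atol, and (2,) matches [2].
sketch_dict_equal(state_a, state_b, atol=1e-5)

# Fails without a tolerance, because the "w" arrays differ slightly.
try:
    sketch_dict_equal(state_a, state_b)
    exact_ok = True
except AssertionError:
    exact_ok = False
```

Routing every value through a single dispatcher keeps the dict-level function trivial and lets nested dicts recurse naturally, which mirrors how assert_dict_equal hands each value to assert_values_equal.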