
declearn.test_utils.assert_batch_equal

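Utility function to test that a batch of the declearn.typing.Batch type is equal to an expected, numpy-based declearn.typing.Batch output. Non-None elements of the result are first converted to numpy arrays using the given framework, then compared element-wise with the expected arrays.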

Source code in declearn/test_utils/_assertions.py
def assert_batch_equal(
    result: Sequence, expected: Sequence, framework: str
) -> None:
    """Utility function to test that a batch of the declearn.typing.Batch
    type is equal to an expected, numpy-based declearn.typing.Batch output.
    """
    # Flatten and assert type and shape of the arbitrarily nested batch
    gen = flatten_and_assert(result, expected)
    # Check all elements are equal
    for out in gen:
        res, exp = out
        # batch element is None
        if res is None:
            assert exp is None
        # batch element is a tensor
        else:
            res = to_numpy(res, framework)
            assert_array_equal(res, exp)
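The comparison logic above can be illustrated with a minimal, numpy-only sketch. It is not the declearn implementation: flatten_pairs is a hypothetical stand-in for declearn's flatten_and_assert helper, assumed here to walk two nested lists/tuples in parallel, check that container types and lengths match, and yield aligned leaf pairs. The framework conversion step (to_numpy) is replaced by np.asarray, since every input in this sketch is already a numpy array.

```python
from typing import Any, Iterator, Sequence, Tuple

import numpy as np
from numpy.testing import assert_array_equal


def flatten_pairs(result: Any, expected: Any) -> Iterator[Tuple[Any, Any]]:
    """Hypothetical stand-in for declearn's flatten_and_assert helper.

    Recursively walks two nested batches in parallel, asserting that their
    container types and lengths match, and yields aligned leaf pairs.
    """
    if isinstance(expected, (list, tuple)):
        assert isinstance(result, type(expected)), "container type mismatch"
        assert len(result) == len(expected), "container length mismatch"
        for res, exp in zip(result, expected):
            yield from flatten_pairs(res, exp)
    else:
        yield result, expected


def check_batch_equal(result: Sequence, expected: Sequence) -> None:
    """Sketch of the comparison loop, for numpy-only inputs."""
    for res, exp in flatten_pairs(result, expected):
        if res is None:
            # A None element in the result must be None in the expected batch.
            assert exp is None
        else:
            # A framework-aware version would convert `res` to numpy here.
            assert_array_equal(np.asarray(res), exp)


# Usage: an (inputs, targets, weights) batch with no sample weights.
inputs = np.arange(6.0).reshape(3, 2)
labels = np.array([0, 1, 0])
check_batch_equal((inputs, labels, None), (inputs.copy(), labels.copy(), None))
```

A mismatch in values, in None-ness, or in nesting structure raises an AssertionError, which is what makes this pattern convenient inside test functions.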