Utility function to test that a batch of the declearn.typing.Batch
type is equal to an expected, numpy-based declearn.typing.Batch output.
Source code in `declearn/test_utils/_assertions.py` (lines 235-252):
def assert_batch_equal(
    result: Sequence, expected: Sequence, framework: str
) -> None:
    """Utility function to test that a batch of the declearn.typing.Batch
    type is equal to an expected, numpy-based declearn.typing.Batch output.

    Parameters
    ----------
    result: Sequence
        Batch to check. Its non-None leaf elements are tensors that are
        converted to numpy using `to_numpy(..., framework)`.
    expected: Sequence
        Expected, numpy-based batch, with the same nesting as `result`.
    framework: str
        Name of the framework of `result`'s tensors, passed to `to_numpy`.

    Raises
    ------
    AssertionError
        If a leaf element is None in one batch but not in the other, or if
        two tensor leaves are not array-equal. Type and shape checks on the
        nested structure are delegated to `flatten_and_assert`.
    """
    # Flatten the arbitrarily-nested batches (asserting matching type and
    # shape along the way) and iterate over pairs of leaf elements.
    for res, exp in flatten_and_assert(result, expected):
        # Batch element is None: the expected element must be None too.
        if res is None:
            assert exp is None
            continue
        # Batch element is a tensor: the expected element must not be None.
        # This is checked explicitly because `assert_array_equal` broadcasts
        # a 0-d operand such as None, so an empty tensor compared with None
        # would otherwise pass silently.
        assert exp is not None
        assert_array_equal(to_numpy(res, framework), exp)
|