Skip to content

Commit 76f7b9d

Browse files
authored
API: testing: relax condition that desired match xp in assertions (#785)
1 parent 28faf60 commit 76f7b9d

2 files changed

Lines changed: 15 additions & 12 deletions

File tree

src/array_api_extra/testing.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -598,22 +598,23 @@ def _check_ns_shape_dtype(
598598
np = _require_numpy()
599599

600600
actual_xp = array_namespace(actual) # Raises on Python scalars and lists
601-
desired_xp = array_namespace(desired)
602601

603602
if xp is not None:
604603
_msg = (
605-
"Namespace of desired array does not match the `xp` argument.\n"
606-
f"Desired array's namespace: {desired_xp.__name__}\n"
604+
"Namespace of actual array does not match the `xp` argument.\n"
605+
f"Actual array's namespace: {actual_xp.__name__}\n"
607606
f"Expected namespace: {xp.__name__}."
608607
)
609-
assert desired_xp == xp, _msg
610-
611-
_msg = (
612-
"Namespaces of actual and desired arrays do not match.\n"
613-
f"Actual: {actual_xp.__name__}\n"
614-
f"Desired: {desired_xp.__name__}."
615-
)
616-
assert actual_xp == desired_xp, _msg
608+
assert actual_xp == xp, _msg
609+
desired_xp = xp
610+
else:
611+
desired_xp = array_namespace(desired)
612+
_msg = (
613+
"Namespaces of actual and desired arrays do not match.\n"
614+
f"Actual: {actual_xp.__name__}\n"
615+
f"Desired: {desired_xp.__name__}."
616+
)
617+
assert actual_xp == desired_xp, _msg
617618

618619
if is_numpy_namespace(actual_xp) and check_scalar:
619620
# only NumPy distinguishes between scalars and arrays; we do if check_scalar.
@@ -650,6 +651,7 @@ def _check_ns_shape_dtype(
650651
msg = f"sizes do not match: {actual_size} != {desired_size}"
651652
assert actual_size == desired_size, msg
652653

654+
desired = desired_xp.asarray(desired)
653655
if check_dtype:
654656
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
655657
assert actual.dtype == desired.dtype, msg

tests/test_testing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,11 @@ def test_namespace(self, xp: ModuleType, func: Callable[..., None]):
8181
func(xp.asarray(0), 0)
8282
with pytest.raises(TypeError, match="list is not a supported array type"):
8383
func(xp.asarray([0]), [0])
84+
func(xp.asarray(0), xp.asarray(1 if func is assert_less else 0), xp=xp)
8485
with (
8586
pytest.raises(
8687
AssertionError,
87-
match="Namespace of desired array does not match the `xp` argument",
88+
match="Namespace of actual array does not match the `xp` argument",
8889
),
8990
):
9091
func(xp.asarray(0), xp.asarray(0), xp=np)

0 commit comments

Comments
 (0)