Skip to content

Commit 1fae65b

Browse files
committed
Adding default_xp context manager for xp_assert functions
Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com>
1 parent 2076183 commit 1fae65b

2 files changed

Lines changed: 67 additions & 6 deletions

File tree

src/array_api_extra/_lib/_testing.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from __future__ import annotations
99

1010
import math
11+
from collections.abc import Generator
12+
from contextlib import contextmanager
13+
from contextvars import ContextVar
1114
from types import ModuleType
1215
from typing import Any, cast
1316

@@ -30,13 +33,53 @@
3033

3134
__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"]
3235

36+
_default_xp_ctxvar: ContextVar[ModuleType] = ContextVar("_default_xp")
37+
38+
39+
@contextmanager
40+
def default_xp(xp: ModuleType) -> Generator[None, None, None]:
41+
"""In all ``xp_assert_*`` function calls executed within this
42+
context manager, test by default that the array namespace is
43+
the provided across all arrays, unless one explicitly passes the ``xp=``
44+
parameter.
45+
46+
Without this context manager, the default value for `xp` is the namespace
47+
for the desired array (the second parameter of the tests).
48+
"""
49+
token = _default_xp_ctxvar.set(xp)
50+
try:
51+
yield
52+
finally:
53+
_default_xp_ctxvar.reset(token)
54+
55+
56+
def _assert_matching_namespace(actual: Array, desired: Array, xp: ModuleType) -> None:
57+
desired_arr_space = array_namespace(desired)
58+
_msg = (
59+
"Namespace of desired array does not match expectations "
60+
"set by the `default_xp` context manager or by the `xp`"
61+
"pytest fixture.\n"
62+
f"Desired array's space: {desired_arr_space.__name__}\n"
63+
f"Expected namespace: {xp.__name__}"
64+
)
65+
assert desired_arr_space == xp, _msg
66+
67+
actual_arr_space = array_namespace(actual)
68+
_msg = (
69+
"Namespace of actual and desired arrays do not match.\n"
70+
f"Actual: {actual_arr_space.__name__}\n"
71+
f"Desired: {xp.__name__}"
72+
)
73+
assert actual_arr_space == xp, _msg
74+
3375

3476
def _check_ns_shape_dtype(
3577
actual: Array,
3678
desired: Array,
3779
check_dtype: bool,
3880
check_shape: bool,
3981
check_scalar: bool,
82+
xp: ModuleType | None = None,
4083
) -> tuple[Array, Array, ModuleType]: # numpydoc ignore=RT03
4184
"""
4285
Assert that namespace, shape and dtype of the two arrays match.
@@ -60,8 +103,12 @@ def _check_ns_shape_dtype(
60103
actual_xp = array_namespace(actual) # Raises on Python scalars and lists
61104
desired_xp = array_namespace(desired)
62105

63-
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
64-
assert actual_xp == desired_xp, msg
106+
if xp is None:
107+
try:
108+
xp = _default_xp_ctxvar.get()
109+
except LookupError:
110+
xp = array_namespace(desired)
111+
_assert_matching_namespace(actual, desired, xp)
65112

66113
if is_numpy_namespace(actual_xp) and check_scalar:
67114
# only NumPy distinguishes between scalars and arrays; we do if check_scalar.
@@ -148,6 +195,7 @@ def xp_assert_equal(
148195
check_dtype: bool = True,
149196
check_shape: bool = True,
150197
check_scalar: bool = False,
198+
xp: ModuleType | None = None,
151199
) -> None:
152200
"""
153201
Array-API compatible version of `np.testing.assert_array_equal`.
@@ -174,7 +222,7 @@ def xp_assert_equal(
174222
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
175223
"""
176224
actual, desired, xp = _check_ns_shape_dtype(
177-
actual, desired, check_dtype, check_shape, check_scalar
225+
actual, desired, check_dtype, check_shape, check_scalar, xp
178226
)
179227
if not _is_materializable(actual):
180228
return
@@ -194,6 +242,7 @@ def xp_assert_less(
194242
check_dtype: bool = True,
195243
check_shape: bool = True,
196244
check_scalar: bool = False,
245+
xp: ModuleType | None = None,
197246
) -> None:
198247
"""
199248
Array-API compatible version of `np.testing.assert_array_less`.
@@ -217,7 +266,7 @@ def xp_assert_less(
217266
xp_assert_close : Similar function for inexact equality checks.
218267
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
219268
"""
220-
x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
269+
x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar, xp)
221270
if not _is_materializable(x):
222271
return
223272
x_np = as_numpy_array(x, xp=xp)
@@ -237,6 +286,7 @@ def xp_assert_close(
237286
check_dtype: bool = True,
238287
check_shape: bool = True,
239288
check_scalar: bool = False,
289+
xp: ModuleType | None = None,
240290
) -> None:
241291
"""
242292
Array-API compatible version of `np.testing.assert_allclose`.
@@ -276,7 +326,7 @@ def xp_assert_close(
276326
Array arguments to `atol` and `rtol` must be valid input to :py:func:`float`.
277327
"""
278328
actual, desired, xp = _check_ns_shape_dtype(
279-
actual, desired, check_dtype, check_shape, check_scalar
329+
actual, desired, check_dtype, check_shape, check_scalar, xp
280330
)
281331
if not _is_materializable(actual):
282332
return

tests/test_testing.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from array_api_extra._lib._backends import Backend
1010
from array_api_extra._lib._testing import (
1111
as_numpy_array,
12+
default_xp,
1213
xp_assert_close,
1314
xp_assert_equal,
1415
xp_assert_less,
@@ -63,12 +64,22 @@ def test_shape_dtype(self, xp: ModuleType, func: Callable[..., None]):
6364
)
6465
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
6566
def test_namespace(self, xp: ModuleType, func: Callable[..., None]):
66-
with pytest.raises(AssertionError, match="namespaces do not match"):
67+
with pytest.raises(
68+
AssertionError, match="Namespace of actual and desired arrays do not match"
69+
):
6770
func(xp.asarray(0), np.asarray(0))
6871
with pytest.raises(TypeError, match=r"array_namespace requires .* array input"):
6972
func(xp.asarray(0), 0)
7073
with pytest.raises(TypeError, match="list is not a supported array type"):
7174
func(xp.asarray([0]), [0])
75+
with (
76+
default_xp(np),
77+
pytest.raises(
78+
AssertionError,
79+
match="Namespace of desired array does not match expectations",
80+
),
81+
):
82+
func(xp.asarray(0), xp.asarray(0))
7283

7384
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
7485
def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):

0 commit comments

Comments
 (0)