88from __future__ import annotations
99
1010import math
11+ from collections .abc import Generator
12+ from contextlib import contextmanager
13+ from contextvars import ContextVar
1114from types import ModuleType
1215from typing import Any , cast
1316
3033
3134__all__ = ["as_numpy_array" , "xp_assert_close" , "xp_assert_equal" , "xp_assert_less" ]
3235
36+ _default_xp_ctxvar : ContextVar [ModuleType ] = ContextVar ("_default_xp" )
37+
38+
39+ @contextmanager
40+ def default_xp (xp : ModuleType ) -> Generator [None , None , None ]:
41+ """In all ``xp_assert_*`` function calls executed within this
42+ context manager, test by default that the array namespace is
43+ the provided across all arrays, unless one explicitly passes the ``xp=``
44+ parameter.
45+
46+ Without this context manager, the default value for `xp` is the namespace
47+ for the desired array (the second parameter of the tests).
48+ """
49+ token = _default_xp_ctxvar .set (xp )
50+ try :
51+ yield
52+ finally :
53+ _default_xp_ctxvar .reset (token )
54+
55+
56+ def _assert_matching_namespace (actual : Array , desired : Array , xp : ModuleType ) -> None :
57+ desired_arr_space = array_namespace (desired )
58+ _msg = (
59+ "Namespace of desired array does not match expectations "
60+ "set by the `default_xp` context manager or by the `xp`"
61+ "pytest fixture.\n "
62+ f"Desired array's space: { desired_arr_space .__name__ } \n "
63+ f"Expected namespace: { xp .__name__ } "
64+ )
65+ assert desired_arr_space == xp , _msg
66+
67+ actual_arr_space = array_namespace (actual )
68+ _msg = (
69+ "Namespace of actual and desired arrays do not match.\n "
70+ f"Actual: { actual_arr_space .__name__ } \n "
71+ f"Desired: { xp .__name__ } "
72+ )
73+ assert actual_arr_space == xp , _msg
74+
3375
3476def _check_ns_shape_dtype (
3577 actual : Array ,
3678 desired : Array ,
3779 check_dtype : bool ,
3880 check_shape : bool ,
3981 check_scalar : bool ,
82+ xp : ModuleType | None = None ,
4083) -> tuple [Array , Array , ModuleType ]: # numpydoc ignore=RT03
4184 """
4285 Assert that namespace, shape and dtype of the two arrays match.
@@ -60,8 +103,12 @@ def _check_ns_shape_dtype(
60103 actual_xp = array_namespace (actual ) # Raises on Python scalars and lists
61104 desired_xp = array_namespace (desired )
62105
63- msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
64- assert actual_xp == desired_xp , msg
106+ if xp is None :
107+ try :
108+ xp = _default_xp_ctxvar .get ()
109+ except LookupError :
110+ xp = array_namespace (desired )
111+ _assert_matching_namespace (actual , desired , xp )
65112
66113 if is_numpy_namespace (actual_xp ) and check_scalar :
67114 # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
@@ -148,6 +195,7 @@ def xp_assert_equal(
148195 check_dtype : bool = True ,
149196 check_shape : bool = True ,
150197 check_scalar : bool = False ,
198+ xp : ModuleType | None = None ,
151199) -> None :
152200 """
153201 Array-API compatible version of `np.testing.assert_array_equal`.
@@ -174,7 +222,7 @@ def xp_assert_equal(
174222 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
175223 """
176224 actual , desired , xp = _check_ns_shape_dtype (
177- actual , desired , check_dtype , check_shape , check_scalar
225+ actual , desired , check_dtype , check_shape , check_scalar , xp
178226 )
179227 if not _is_materializable (actual ):
180228 return
@@ -194,6 +242,7 @@ def xp_assert_less(
194242 check_dtype : bool = True ,
195243 check_shape : bool = True ,
196244 check_scalar : bool = False ,
245+ xp : ModuleType | None = None ,
197246) -> None :
198247 """
199248 Array-API compatible version of `np.testing.assert_array_less`.
@@ -217,7 +266,7 @@ def xp_assert_less(
217266 xp_assert_close : Similar function for inexact equality checks.
218267 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
219268 """
220- x , y , xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar )
269+ x , y , xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar , xp )
221270 if not _is_materializable (x ):
222271 return
223272 x_np = as_numpy_array (x , xp = xp )
@@ -237,6 +286,7 @@ def xp_assert_close(
237286 check_dtype : bool = True ,
238287 check_shape : bool = True ,
239288 check_scalar : bool = False ,
289+ xp : ModuleType | None = None ,
240290) -> None :
241291 """
242292 Array-API compatible version of `np.testing.assert_allclose`.
@@ -276,7 +326,7 @@ def xp_assert_close(
276326 Array arguments to `atol` and `rtol` must be valid input to :py:func:`float`.
277327 """
278328 actual , desired , xp = _check_ns_shape_dtype (
279- actual , desired , check_dtype , check_shape , check_scalar
329+ actual , desired , check_dtype , check_shape , check_scalar , xp
280330 )
281331 if not _is_materializable (actual ):
282332 return
0 commit comments