Skip to content

Commit 49723b0

Browse files
committed
adding check_namespace option
Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com>
1 parent 5045afa commit 49723b0

1 file changed

Lines changed: 32 additions & 5 deletions

File tree

src/array_api_extra/_lib/_testing.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
def _check_ns_shape_dtype(
3535
actual: Array,
3636
desired: Array,
37+
check_namespace: bool,
3738
check_dtype: bool,
3839
check_shape: bool,
3940
check_scalar: bool,
@@ -47,6 +48,8 @@ def _check_ns_shape_dtype(
4748
The array produced by the tested function.
4849
desired : Array
4950
The expected array (typically hardcoded).
51+
check_namespace : bool, default: True
52+
Whether to check agreement between actual and desired namespace.
5053
check_dtype, check_shape : bool, default: True
5154
Whether to check agreement between actual and desired dtypes and shapes
5255
check_scalar : bool, default: False
@@ -60,8 +63,9 @@ def _check_ns_shape_dtype(
6063
actual_xp = array_namespace(actual) # Raises on scalars and lists
6164
desired_xp = array_namespace(desired)
6265

63-
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
64-
assert actual_xp == desired_xp, msg
66+
if check_namespace:
67+
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
68+
assert actual_xp == desired_xp, msg
6569

6670
# Dask uses nan instead of None for unknown shapes
6771
actual_shape = cast(tuple[float, ...], actual.shape)
@@ -139,6 +143,7 @@ def xp_assert_equal(
139143
desired: Array,
140144
*,
141145
err_msg: str = "",
146+
check_namespace: bool = True,
142147
check_dtype: bool = True,
143148
check_shape: bool = True,
144149
check_scalar: bool = False,
@@ -154,6 +159,8 @@ def xp_assert_equal(
154159
The expected array (typically hardcoded).
155160
err_msg : str, optional
156161
Error message to display on failure.
162+
check_namespace : bool, default: True
163+
Whether to check agreement between actual and desired namespace.
157164
check_dtype, check_shape : bool, default: True
158165
Whether to check agreement between actual and desired dtypes and shapes
159166
check_scalar : bool, default: False
@@ -165,7 +172,14 @@ def xp_assert_equal(
165172
xp_assert_close : Similar function for inexact equality checks.
166173
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
167174
"""
168-
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
175+
xp = _check_ns_shape_dtype(
176+
actual,
177+
desired,
178+
check_namespace,
179+
check_dtype,
180+
check_shape,
181+
check_scalar,
182+
)
169183
if not _is_materializable(actual):
170184
return
171185
actual_np = as_numpy_array(actual, xp=xp)
@@ -178,6 +192,7 @@ def xp_assert_less(
178192
y: Array,
179193
*,
180194
err_msg: str = "",
195+
check_namespace: bool = True,
181196
check_dtype: bool = True,
182197
check_shape: bool = True,
183198
check_scalar: bool = False,
@@ -191,6 +206,8 @@ def xp_assert_less(
191206
The arrays to compare according to ``x < y`` (elementwise).
192207
err_msg : str, optional
193208
Error message to display on failure.
209+
check_namespace : bool, default: True
210+
Whether to check agreement between actual and desired namespace.
194211
check_dtype, check_shape : bool, default: True
195212
Whether to check agreement between actual and desired dtypes and shapes
196213
check_scalar : bool, default: False
@@ -202,7 +219,7 @@ def xp_assert_less(
202219
xp_assert_close : Similar function for inexact equality checks.
203220
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
204221
"""
205-
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
222+
xp = _check_ns_shape_dtype(x, y, check_namespace, check_dtype, check_shape, check_scalar)
206223
if not _is_materializable(x):
207224
return
208225
x_np = as_numpy_array(x, xp=xp)
@@ -217,6 +234,7 @@ def xp_assert_close(
217234
rtol: float | None = None,
218235
atol: float = 0,
219236
err_msg: str = "",
237+
check_namespace: bool = True,
220238
check_dtype: bool = True,
221239
check_shape: bool = True,
222240
check_scalar: bool = False,
@@ -236,6 +254,8 @@ def xp_assert_close(
236254
Absolute tolerance. Default: 0.
237255
err_msg : str, optional
238256
Error message to display on failure.
257+
check_namespace : bool, default: True
258+
Whether to check agreement between actual and desired namespace.
239259
check_dtype, check_shape : bool, default: True
240260
Whether to check agreement between actual and desired dtypes and shapes
241261
check_scalar : bool, default: False
@@ -252,7 +272,14 @@ def xp_assert_close(
252272
-----
253273
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
254274
"""
255-
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
275+
xp = _check_ns_shape_dtype(
276+
actual,
277+
desired,
278+
check_namespace,
279+
check_dtype,
280+
check_shape,
281+
check_scalar,
282+
)
256283
if not _is_materializable(actual):
257284
return
258285

0 commit comments

Comments
 (0)