3434def _check_ns_shape_dtype (
3535 actual : Array ,
3636 desired : Array ,
37+ check_namespace : bool ,
3738 check_dtype : bool ,
3839 check_shape : bool ,
3940 check_scalar : bool ,
@@ -47,6 +48,8 @@ def _check_ns_shape_dtype(
4748 The array produced by the tested function.
4849 desired : Array
4950 The expected array (typically hardcoded).
51+ check_namespace : bool, default: True
52+ Whether to check agreement between actual and desired namespace.
5053 check_dtype, check_shape : bool, default: True
5154 Whether to check agreement between actual and desired dtypes and shapes
5255 check_scalar : bool, default: False
@@ -60,8 +63,9 @@ def _check_ns_shape_dtype(
6063 actual_xp = array_namespace (actual ) # Raises on scalars and lists
6164 desired_xp = array_namespace (desired )
6265
63- msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
64- assert actual_xp == desired_xp , msg
66+ if check_namespace :
67+ msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
68+ assert actual_xp == desired_xp , msg
6569
6670 # Dask uses nan instead of None for unknown shapes
6771 actual_shape = cast (tuple [float , ...], actual .shape )
@@ -139,6 +143,7 @@ def xp_assert_equal(
139143 desired : Array ,
140144 * ,
141145 err_msg : str = "" ,
146+ check_namespace : bool = True ,
142147 check_dtype : bool = True ,
143148 check_shape : bool = True ,
144149 check_scalar : bool = False ,
@@ -154,6 +159,8 @@ def xp_assert_equal(
154159 The expected array (typically hardcoded).
155160 err_msg : str, optional
156161 Error message to display on failure.
162+ check_namespace : bool, default: True
163+ Whether to check agreement between actual and desired namespace.
157164 check_dtype, check_shape : bool, default: True
158165 Whether to check agreement between actual and desired dtypes and shapes
159166 check_scalar : bool, default: False
@@ -165,7 +172,14 @@ def xp_assert_equal(
165172 xp_assert_close : Similar function for inexact equality checks.
166173 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
167174 """
168- xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
175+ xp = _check_ns_shape_dtype (
176+ actual ,
177+ desired ,
178+ check_namespace ,
179+ check_dtype ,
180+ check_shape ,
181+ check_scalar ,
182+ )
169183 if not _is_materializable (actual ):
170184 return
171185 actual_np = as_numpy_array (actual , xp = xp )
@@ -178,6 +192,7 @@ def xp_assert_less(
178192 y : Array ,
179193 * ,
180194 err_msg : str = "" ,
195+ check_namespace : bool = True ,
181196 check_dtype : bool = True ,
182197 check_shape : bool = True ,
183198 check_scalar : bool = False ,
@@ -191,6 +206,8 @@ def xp_assert_less(
191206 The arrays to compare according to ``x < y`` (elementwise).
192207 err_msg : str, optional
193208 Error message to display on failure.
209+ check_namespace : bool, default: True
210+ Whether to check agreement between actual and desired namespace.
194211 check_dtype, check_shape : bool, default: True
195212 Whether to check agreement between actual and desired dtypes and shapes
196213 check_scalar : bool, default: False
@@ -202,7 +219,7 @@ def xp_assert_less(
202219 xp_assert_close : Similar function for inexact equality checks.
203220 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
204221 """
205- xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar )
222+ xp = _check_ns_shape_dtype (x , y , check_namespace , check_dtype , check_shape , check_scalar )
206223 if not _is_materializable (x ):
207224 return
208225 x_np = as_numpy_array (x , xp = xp )
@@ -217,6 +234,7 @@ def xp_assert_close(
217234 rtol : float | None = None ,
218235 atol : float = 0 ,
219236 err_msg : str = "" ,
237+ check_namespace : bool = True ,
220238 check_dtype : bool = True ,
221239 check_shape : bool = True ,
222240 check_scalar : bool = False ,
@@ -236,6 +254,8 @@ def xp_assert_close(
236254 Absolute tolerance. Default: 0.
237255 err_msg : str, optional
238256 Error message to display on failure.
257+ check_namespace : bool, default: True
258+ Whether to check agreement between actual and desired namespace.
239259 check_dtype, check_shape : bool, default: True
240260 Whether to check agreement between actual and desired dtypes and shapes
241261 check_scalar : bool, default: False
@@ -252,7 +272,14 @@ def xp_assert_close(
252272 -----
253273 The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
254274 """
255- xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
275+ xp = _check_ns_shape_dtype (
276+ actual ,
277+ desired ,
278+ check_namespace ,
279+ check_dtype ,
280+ check_shape ,
281+ check_scalar ,
282+ )
256283 if not _is_materializable (actual ):
257284 return
258285
0 commit comments