We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 941bd07 commit 89dd532Copy full SHA for 89dd532
1 file changed
src/array_api_extra/_lib/_testing.py
@@ -288,16 +288,12 @@ def xp_assert_close(
288
rtol = 1e-7
289
290
if hasattr(atol, "ndim"):
291
- atol = as_numpy_array(atol, xp=xp)
292
- if atol.ndim > 0:
293
- msg = "atol must be a scalar or 0-D array"
294
- raise TypeError(msg)
+ if atol.ndim == 0:
+ atol = as_numpy_array(atol, xp=xp)
295
296
if hasattr(rtol, "ndim"):
297
- rtol = as_numpy_array(rtol, xp=xp)
298
- if rtol.ndim > 0:
299
- msg = "rtol must be a scalar or 0-D array"
300
+ if rtol.ndim == 0:
+ rtol = as_numpy_array(rtol, xp=xp)
301
302
actual_np = as_numpy_array(actual, xp=xp)
303
desired_np = as_numpy_array(desired, xp=xp)
0 commit comments