Skip to content

Commit 2ec7cf5

Browse files
authored
ENH: testing: support 0-D arrays for rtol and atol (#743)
1 parent e0bae44 commit 2ec7cf5

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

src/array_api_extra/_lib/_testing.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,8 @@ def xp_assert_close(
228228
actual: Array,
229229
desired: Array,
230230
*,
231-
rtol: float | None = None,
232-
atol: float = 0,
231+
rtol: float | Array | None = None,
232+
atol: float | Array = 0,
233233
equal_nan: bool = True,
234234
err_msg: str = "",
235235
verbose: bool = True,
@@ -246,9 +246,9 @@ def xp_assert_close(
246246
The array produced by the tested function.
247247
desired : Array
248248
The expected array (typically hardcoded).
249-
rtol : float, optional
249+
rtol : float or Array, optional
250250
Relative tolerance. Default: dtype-dependent.
251-
atol : float, optional
251+
atol : float or Array, optional
252252
Absolute tolerance. Default: 0.
253253
equal_nan : bool, default: True
254254
Whether to consider NaNs in corresponding locations as equal.
@@ -271,6 +271,8 @@ def xp_assert_close(
271271
Notes
272272
-----
273273
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
274+
275+
Array arguments to `atol` and `rtol` must be valid input to :py:func:`float`.
274276
"""
275277
actual, desired, xp = _check_ns_shape_dtype(
276278
actual, desired, check_dtype, check_shape, check_scalar
@@ -286,13 +288,17 @@ def xp_assert_close(
286288
rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
287289
else:
288290
rtol = 1e-7
291+
else:
292+
rtol = float(rtol)
293+
294+
atol = float(atol)
289295

290296
actual_np = as_numpy_array(actual, xp=xp)
291297
desired_np = as_numpy_array(desired, xp=xp)
292298
np.testing.assert_allclose( # pyright: ignore[reportCallIssue]
293299
actual_np,
294300
desired_np,
295-
rtol=rtol, # pyright: ignore[reportArgumentType]
301+
rtol=rtol, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
296302
atol=atol,
297303
equal_nan=equal_nan,
298304
err_msg=err_msg,

0 commit comments

Comments
 (0)