Skip to content

Commit 4a65b5d

Browse files
committed
Adding case for non float rtol and atol inputs
Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com>
1 parent e01426c commit 4a65b5d

1 file changed

Lines changed: 16 additions & 4 deletions

File tree

src/array_api_extra/_lib/_testing.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,8 @@ def xp_assert_close(
228228
actual: Array,
229229
desired: Array,
230230
*,
231-
rtol: float | None = None,
232-
atol: float = 0,
231+
rtol: float | Array | None = None,
232+
atol: float | Array = 0,
233233
equal_nan: bool = True,
234234
err_msg: str = "",
235235
verbose: bool = True,
@@ -246,9 +246,9 @@ def xp_assert_close(
246246
The array produced by the tested function.
247247
desired : Array
248248
The expected array (typically hardcoded).
249-
rtol : float, optional
249+
rtol : float or Array, optional
250250
Relative tolerance. Default: dtype-dependent.
251-
atol : float, optional
251+
atol : float or Array, optional
252252
Absolute tolerance. Default: 0.
253253
equal_nan : bool, default: True
254254
Whether to consider NaNs in corresponding locations as equal.
@@ -287,6 +287,18 @@ def xp_assert_close(
287287
else:
288288
rtol = 1e-7
289289

290+
if not isinstance(atol, float):
291+
atol = as_numpy_array(atol, xp=xp)
292+
if atol.ndim > 0:
293+
msg = "atol must be a scalar or 0-D array"
294+
raise TypeError(msg)
295+
296+
if not isinstance(rtol, float):
297+
rtol = as_numpy_array(rtol, xp=xp)
298+
if rtol.ndim > 0:
299+
msg = "rtol must be a scalar or 0-D array"
300+
raise TypeError(msg)
301+
290302
actual_np = as_numpy_array(actual, xp=xp)
291303
desired_np = as_numpy_array(desired, xp=xp)
292304
np.testing.assert_allclose( # pyright: ignore[reportCallIssue]

0 commit comments

Comments
 (0)