@@ -228,8 +228,8 @@ def xp_assert_close(
228228 actual : Array ,
229229 desired : Array ,
230230 * ,
231- rtol : float | None = None ,
232- atol : float = 0 ,
231+ rtol : float | Array | None = None ,
232+ atol : float | Array = 0 ,
233233 equal_nan : bool = True ,
234234 err_msg : str = "" ,
235235 verbose : bool = True ,
@@ -246,9 +246,9 @@ def xp_assert_close(
246246 The array produced by the tested function.
247247 desired : Array
248248 The expected array (typically hardcoded).
249- rtol : float, optional
249+ rtol : float or Array , optional
250250 Relative tolerance. Default: dtype-dependent.
251- atol : float, optional
251+ atol : float or Array , optional
252252 Absolute tolerance. Default: 0.
253253 equal_nan : bool, default: True
254254 Whether to consider NaNs in corresponding locations as equal.
@@ -271,6 +271,8 @@ def xp_assert_close(
271271 Notes
272272 -----
273273 The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
274+
275+ Array arguments to `atol` and `rtol` must be valid input to :py:func:`float`.
274276 """
275277 actual , desired , xp = _check_ns_shape_dtype (
276278 actual , desired , check_dtype , check_shape , check_scalar
@@ -286,13 +288,17 @@ def xp_assert_close(
286288 rtol = xp .finfo (actual .dtype ).eps ** 0.5 * 4
287289 else :
288290 rtol = 1e-7
291+ else :
292+ rtol = float (rtol )
293+
294+ atol = float (atol )
289295
290296 actual_np = as_numpy_array (actual , xp = xp )
291297 desired_np = as_numpy_array (desired , xp = xp )
292298 np .testing .assert_allclose ( # pyright: ignore[reportCallIssue]
293299 actual_np ,
294300 desired_np ,
295- rtol = rtol , # pyright: ignore[reportArgumentType]
301+ rtol = rtol , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
296302 atol = atol ,
297303 equal_nan = equal_nan ,
298304 err_msg = err_msg ,
0 commit comments