@@ -228,8 +228,8 @@ def xp_assert_close(
228228 actual : Array ,
229229 desired : Array ,
230230 * ,
231- rtol : float | None = None ,
232- atol : float = 0 ,
231+ rtol : float | Array | None = None ,
232+ atol : float | Array = 0 ,
233233 equal_nan : bool = True ,
234234 err_msg : str = "" ,
235235 verbose : bool = True ,
@@ -246,9 +246,9 @@ def xp_assert_close(
246246 The array produced by the tested function.
247247 desired : Array
248248 The expected array (typically hardcoded).
249- rtol : float, optional
249+ rtol : float or Array , optional
250250 Relative tolerance. Default: dtype-dependent.
251- atol : float, optional
251+ atol : float or Array , optional
252252 Absolute tolerance. Default: 0.
253253 equal_nan : bool, default: True
254254 Whether to consider NaNs in corresponding locations as equal.
@@ -287,6 +287,18 @@ def xp_assert_close(
287287 else :
288288 rtol = 1e-7
289289
290+ if not isinstance (atol , float ):
291+ atol = as_numpy_array (atol , xp = xp )
292+ if atol .ndim > 0 :
293+ msg = "atol must be a scalar or 0-D array"
294+ raise TypeError (msg )
295+
296+ if not isinstance (rtol , float ):
297+ rtol = as_numpy_array (rtol , xp = xp )
298+ if rtol .ndim > 0 :
299+ msg = "rtol must be a scalar or 0-D array"
300+ raise TypeError (msg )
301+
290302 actual_np = as_numpy_array (actual , xp = xp )
291303 desired_np = as_numpy_array (desired , xp = xp )
292304 np .testing .assert_allclose ( # pyright: ignore[reportCallIssue]
0 commit comments