@@ -287,21 +287,20 @@ def xp_assert_close(
287287 else :
288288 rtol = 1e-7
289289
290- if hasattr (atol , "ndim" ):
291- if atol .ndim == 0 :
292- atol = as_numpy_array (atol , xp = xp )
290+ if hasattr (atol , "ndim" ) and atol .ndim == 0 : # pyright: ignore[reportAttributeAccessIssue]
291+ atol = cast (Array , as_numpy_array (cast (Array , atol ), xp = xp )) # pyright: ignore[reportInvalidCast]
293292
294- if hasattr (rtol , "ndim" ):
295- if rtol .ndim == 0 :
296- rtol = as_numpy_array (rtol , xp = xp )
293+ if hasattr (rtol , "ndim" ) and rtol .ndim == 0 : # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
294+ rtol = cast (Array , as_numpy_array (cast (Array , rtol ), xp = xp )) # pyright: ignore[reportInvalidCast]
297295
298296 actual_np = as_numpy_array (actual , xp = xp )
299297 desired_np = as_numpy_array (desired , xp = xp )
300- np .testing .assert_allclose ( # pyright: ignore[reportCallIssue]
298+ np .testing .assert_allclose ( # pyright: ignore[reportCallIssue] # pyrefly: ignore[no-matching-overload]
301299 actual_np ,
302300 desired_np ,
303- rtol = rtol , # pyright: ignore[reportArgumentType]
304- atol = atol ,
301+ # https://github.com/numpy/numpy/issues/31449
302+ rtol = rtol , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
303+ atol = atol , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
305304 equal_nan = equal_nan ,
306305 err_msg = err_msg ,
307306 verbose = verbose ,
0 commit comments