Skip to content

Commit 3848cf2

Browse files
committed
typing
1 parent 89dd532 commit 3848cf2

1 file changed

Lines changed: 8 additions & 9 deletions

File tree

src/array_api_extra/_lib/_testing.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -287,21 +287,20 @@ def xp_assert_close(
287287
else:
288288
rtol = 1e-7
289289

290-
if hasattr(atol, "ndim"):
291-
if atol.ndim == 0:
292-
atol = as_numpy_array(atol, xp=xp)
290+
if hasattr(atol, "ndim") and atol.ndim == 0: # pyright: ignore[reportAttributeAccessIssue]
291+
atol = cast(Array, as_numpy_array(cast(Array, atol), xp=xp)) # pyright: ignore[reportInvalidCast]
293292

294-
if hasattr(rtol, "ndim"):
295-
if rtol.ndim == 0:
296-
rtol = as_numpy_array(rtol, xp=xp)
293+
if hasattr(rtol, "ndim") and rtol.ndim == 0: # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
294+
rtol = cast(Array, as_numpy_array(cast(Array, rtol), xp=xp)) # pyright: ignore[reportInvalidCast]
297295

298296
actual_np = as_numpy_array(actual, xp=xp)
299297
desired_np = as_numpy_array(desired, xp=xp)
300-
np.testing.assert_allclose( # pyright: ignore[reportCallIssue]
298+
np.testing.assert_allclose( # pyright: ignore[reportCallIssue] # pyrefly: ignore[no-matching-overload]
301299
actual_np,
302300
desired_np,
303-
rtol=rtol, # pyright: ignore[reportArgumentType]
304-
atol=atol,
301+
# https://github.com/numpy/numpy/issues/31449
302+
rtol=rtol, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
303+
atol=atol, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
305304
equal_nan=equal_nan,
306305
err_msg=err_msg,
307306
verbose=verbose,

0 commit comments

Comments
 (0)