Skip to content

Commit a154f0f

Browse files
committed
revisions
1 parent 65512cb commit a154f0f

1 file changed

Lines changed: 24 additions & 15 deletions

File tree

src/array_api_extra/testing.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@
3131
from ._lib._utils._helpers import jax_autojit, pickle_flatten, pickle_unflatten
3232
from ._lib._utils._typing import Array, Device
3333

34-
__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
34+
__all__ = [
35+
"assert_close",
36+
"assert_equal",
37+
"assert_less",
38+
"lazy_xp_function",
39+
"patch_lazy_xp_functions",
40+
]
3541

3642
if TYPE_CHECKING: # pragma: no cover
3743
# TODO import override from typing (requires Python >=3.12)
@@ -467,7 +473,7 @@ def revert_on_exit() -> Generator[None]: # numpydoc ignore=GL08
467473
return revert_on_exit()
468474

469475

470-
class CountingDaskScheduler(SchedulerGetCallable):
476+
class _CountingDaskScheduler(SchedulerGetCallable):
471477
"""
472478
Dask scheduler that counts how many times `dask.compute` is called.
473479
@@ -527,7 +533,7 @@ def _dask_wrap(
527533

528534
@wraps(func)
529535
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
530-
scheduler = CountingDaskScheduler(n, msg)
536+
scheduler = _CountingDaskScheduler(n, msg)
531537
with dask.config.set({"scheduler": scheduler}): # pyright: ignore[reportPrivateImportUsage]
532538
out = func(*args, **kwargs)
533539

@@ -541,7 +547,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
541547
return wrapper
542548

543549

544-
def require_numpy() -> ModuleType: # numpydoc ignore=RT01
550+
def _require_numpy() -> ModuleType: # numpydoc ignore=RT01
545551
"""
546552
Import and return `numpy` if it is available, otherwise raise informative error.
547553
"""
@@ -588,7 +594,7 @@ def _check_ns_shape_dtype(
588594
-------
589595
Actual array, desired array, their array namespace, the numpy module.
590596
"""
591-
np = require_numpy()
597+
np = _require_numpy()
592598

593599
actual_xp = array_namespace(actual) # Raises on Python scalars and lists
594600
desired_xp = array_namespace(desired)
@@ -631,7 +637,7 @@ def _check_ns_shape_dtype(
631637
desired_shape = cast(tuple[float, ...], desired.shape)
632638

633639
if check_shape:
634-
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
640+
msg = f"shapes do not match: {actual_shape} != {desired_shape}"
635641
assert actual_shape == desired_shape, msg
636642
elif desired.ndim > 0:
637643
# Ignore shape, but check flattened size. This is normally done by
@@ -640,7 +646,7 @@ def _check_ns_shape_dtype(
640646
# This check excludes 0d arrays as they are special-cased in NumPy.
641647
actual_size = math.prod(actual_shape)
642648
desired_size = math.prod(desired_shape)
643-
msg = f"sizes do not match: {actual_size} != f{desired_size}"
649+
msg = f"sizes do not match: {actual_size} != {desired_size}"
644650
assert actual_size == desired_size, msg
645651

646652
if check_dtype:
@@ -665,7 +671,7 @@ def _as_numpy_array( # numpydoc ignore=PR01,RT01
665671
"""
666672
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
667673
"""
668-
np = require_numpy()
674+
np = _require_numpy()
669675
if is_cupy_namespace(xp):
670676
return xp.asnumpy(array)
671677
if is_pydata_sparse_namespace(xp):
@@ -708,7 +714,7 @@ def assert_close(
708714
xp: ModuleType | None = None,
709715
) -> None:
710716
"""
711-
Check that two arrays are close, up to a tolerance.
717+
Check that two arrays are close, up to tolerance ``atol + rtol * abs(desired)``.
712718
713719
This is an interface to :func:`numpy.testing.assert_allclose` which accepts
714720
any standard-compatible array and performs additional array namespace,
@@ -755,9 +761,12 @@ def assert_close(
755761
-----
756762
The default `atol` and `rtol` differ from ``xp.all(xpx.isclose(a, b))``.
757763
For inexact dtypes, the default `rtol` is
758-
``xp.finfo(actual.dtype).eps ** 0.5 * 4``, which is roughly halfway between
759-
:math:`\\sqrt{\\epsilon}` and the default for :func:`numpy.testing.assert_allclose`,
760-
``1e-7``. For other dtypes, the default ``1e-7`` is used.
764+
``xp.finfo(actual.dtype).eps ** 0.5 * 4``, which for ``float64`` is roughly halfway
765+
between :math:`\\sqrt{\\epsilon}` and the default for
766+
:func:`numpy.testing.assert_allclose`, ``1e-7``.
767+
This gives a more reasonable default for lower precision dtypes,
768+
for example approximately ``1e-3`` for ``float32``.
769+
For exact dtypes, the default ``1e-7`` is used.
761770
762771
Array arguments to `atol` and `rtol` must be valid input to :class:`float`.
763772
"""
@@ -880,14 +889,14 @@ def assert_less(
880889
verbose : bool, default: True
881890
Whether to include the conflicting arrays in the error message on failure.
882891
check_dtype : bool, default: True
883-
Whether to check agreement between actual and desired dtypes.
892+
Whether to check agreement between the dtypes of `x` and `y`.
884893
check_shape : bool, default: True
885-
Whether to check agreement between actual and desired shapes.
894+
Whether to check agreement between the shapes of `x` and `y`.
886895
check_scalar : bool, default: False
887896
NumPy only: whether to check agreement between actual and desired types —
888897
0-D :class:`numpy.ndarray` vs scalar (e.g. :class:`numpy.double`).
889898
xp : array_namespace, optional
890-
A standard-compatible namespace which `actual` and `desired` must match.
899+
A standard-compatible namespace which `x` and `y` must match.
891900
892901
Raises
893902
------

0 commit comments

Comments
 (0)