3131from ._lib ._utils ._helpers import jax_autojit , pickle_flatten , pickle_unflatten
3232from ._lib ._utils ._typing import Array , Device
3333
34- __all__ = ["lazy_xp_function" , "patch_lazy_xp_functions" ]
34+ __all__ = [
35+ "assert_close" ,
36+ "assert_equal" ,
37+ "assert_less" ,
38+ "lazy_xp_function" ,
39+ "patch_lazy_xp_functions" ,
40+ ]
3541
3642if TYPE_CHECKING : # pragma: no cover
3743 # TODO import override from typing (requires Python >=3.12)
@@ -467,7 +473,7 @@ def revert_on_exit() -> Generator[None]: # numpydoc ignore=GL08
467473 return revert_on_exit ()
468474
469475
470- class CountingDaskScheduler (SchedulerGetCallable ):
476+ class _CountingDaskScheduler (SchedulerGetCallable ):
471477 """
472478 Dask scheduler that counts how many times `dask.compute` is called.
473479
@@ -527,7 +533,7 @@ def _dask_wrap(
527533
528534 @wraps (func )
529535 def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T : # numpydoc ignore=GL08
530- scheduler = CountingDaskScheduler (n , msg )
536+ scheduler = _CountingDaskScheduler (n , msg )
531537 with dask .config .set ({"scheduler" : scheduler }): # pyright: ignore[reportPrivateImportUsage]
532538 out = func (* args , ** kwargs )
533539
@@ -541,7 +547,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
541547 return wrapper
542548
543549
544- def require_numpy () -> ModuleType : # numpydoc ignore=RT01
550+ def _require_numpy () -> ModuleType : # numpydoc ignore=RT01
545551 """
546552 Import and return `numpy` if it is available, otherwise raise informative error.
547553 """
@@ -588,7 +594,7 @@ def _check_ns_shape_dtype(
588594 -------
589595 Actual array, desired array, their array namespace, the numpy module.
590596 """
591- np = require_numpy ()
597+ np = _require_numpy ()
592598
593599 actual_xp = array_namespace (actual ) # Raises on Python scalars and lists
594600 desired_xp = array_namespace (desired )
@@ -631,7 +637,7 @@ def _check_ns_shape_dtype(
631637 desired_shape = cast (tuple [float , ...], desired .shape )
632638
633639 if check_shape :
634- msg = f"shapes do not match: { actual_shape } != f { desired_shape } "
640+ msg = f"shapes do not match: { actual_shape } != { desired_shape } "
635641 assert actual_shape == desired_shape , msg
636642 elif desired .ndim > 0 :
637643 # Ignore shape, but check flattened size. This is normally done by
@@ -640,7 +646,7 @@ def _check_ns_shape_dtype(
640646 # This check excludes 0d arrays as they are special-cased in NumPy.
641647 actual_size = math .prod (actual_shape )
642648 desired_size = math .prod (desired_shape )
643- msg = f"sizes do not match: { actual_size } != f { desired_size } "
649+ msg = f"sizes do not match: { actual_size } != { desired_size } "
644650 assert actual_size == desired_size , msg
645651
646652 if check_dtype :
@@ -665,7 +671,7 @@ def _as_numpy_array( # numpydoc ignore=PR01,RT01
665671 """
666672 Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
667673 """
668- np = require_numpy ()
674+ np = _require_numpy ()
669675 if is_cupy_namespace (xp ):
670676 return xp .asnumpy (array )
671677 if is_pydata_sparse_namespace (xp ):
@@ -708,7 +714,7 @@ def assert_close(
708714 xp : ModuleType | None = None ,
709715) -> None :
710716 """
711- Check that two arrays are close, up to a tolerance.
717+ Check that two arrays are close, up to tolerance ``atol + rtol * abs(desired)`` .
712718
713719 This is an interface to :func:`numpy.testing.assert_allclose` which accepts
714720 any standard-compatible array and performs additional array namespace,
@@ -755,9 +761,12 @@ def assert_close(
755761 -----
756762 The default `atol` and `rtol` differ from ``xp.all(xpx.isclose(a, b))``.
757763 For inexact dtypes, the default `rtol` is
758- ``xp.finfo(actual.dtype).eps ** 0.5 * 4``, which is roughly halfway between
759- :math:`\\ sqrt{\\ epsilon}` and the default for :func:`numpy.testing.assert_allclose`,
760- ``1e-7``. For other dtypes, the default ``1e-7`` is used.
764+ ``xp.finfo(actual.dtype).eps ** 0.5 * 4``, which for ``float64`` is roughly halfway
765+ between :math:`\\ sqrt{\\ epsilon}` and the default for
766+ :func:`numpy.testing.assert_allclose`, ``1e-7``.
767+ This gives a more reasonable default for lower precision dtypes,
768+ for example approximately ``1e-3`` for ``float32``.
769+ For exact dtypes, the default ``1e-7`` is used.
761770
762771 Array arguments to `atol` and `rtol` must be valid input to :class:`float`.
763772 """
@@ -880,14 +889,14 @@ def assert_less(
880889 verbose : bool, default: True
881890 Whether to include the conflicting arrays in the error message on failure.
882891 check_dtype : bool, default: True
883- Whether to check agreement between actual and desired dtypes .
892+ Whether to check agreement between the dtypes of `x` and `y` .
884893 check_shape : bool, default: True
885- Whether to check agreement between actual and desired shapes .
894+ Whether to check agreement between the shapes of `x` and `y` .
886895 check_scalar : bool, default: False
887896 NumPy only: whether to check agreement between actual and desired types —
888897 0-D :class:`numpy.ndarray` vs scalar (e.g. :class:`numpy.double`).
889898 xp : array_namespace, optional
890- A standard-compatible namespace which `actual ` and `desired ` must match.
899+ A standard-compatible namespace which `x ` and `y ` must match.
891900
892901 Raises
893902 ------
0 commit comments