diff --git a/diffrax/__init__.py b/diffrax/__init__.py index d35a7fac..207f5fe8 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -1,5 +1,12 @@ import importlib.metadata +from equinox.internal import ( + AbstractProgressMeter as AbstractProgressMeter, + NoProgressMeter as NoProgressMeter, + TextProgressMeter as TextProgressMeter, + TqdmProgressMeter as TqdmProgressMeter, +) + from ._adjoint import ( AbstractAdjoint as AbstractAdjoint, BacksolveAdjoint as BacksolveAdjoint, @@ -49,12 +56,6 @@ ) from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm from ._path import AbstractPath as AbstractPath -from ._progress_meter import ( - AbstractProgressMeter as AbstractProgressMeter, - NoProgressMeter as NoProgressMeter, - TextProgressMeter as TextProgressMeter, - TqdmProgressMeter as TqdmProgressMeter, -) from ._root_finder import ( VeryChord as VeryChord, with_stepsize_controller_tols as with_stepsize_controller_tols, diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 8241fa9b..74fa9ce1 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -19,6 +19,10 @@ import numpy as np import optimistix as optx import wadler_lindig as wl +from equinox.internal import ( + AbstractProgressMeter, + NoProgressMeter, +) from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint @@ -38,10 +42,6 @@ from ._global_interpolation import DenseInterpolation from ._heuristics import is_sde, is_unsafe_sde from ._misc import linear_rescale, static_select -from ._progress_meter import ( - AbstractProgressMeter, - NoProgressMeter, -) from ._root_finder import use_stepsize_tol from ._saveat import save_y, SaveAt, SubSaveAt from ._solution import is_okay, is_successful, RESULTS, Solution diff --git a/diffrax/_progress_meter.py b/diffrax/_progress_meter.py deleted file mode 100644 index a7f9574a..00000000 --- a/diffrax/_progress_meter.py +++ /dev/null @@ -1,316 +0,0 @@ -import abc -import importlib.util -import threading -from collections.abc import Callable -from typing import Any, cast, Generic, TypeVar - -import equinox as eqx -import equinox.internal as eqxi -import jax -import jax.numpy as jnp -import numpy as np -from jax.experimental import io_callback -from jaxtyping import Array, PyTree - -from ._custom_types import FloatScalarLike, IntScalarLike, RealScalarLike - - -_State = TypeVar("_State", bound=PyTree[Array]) - - -class AbstractProgressMeter(eqx.Module, Generic[_State]): - """Progress meters used to indicate how far along a solve is. Typically these - perform some kind of printout as the solve progresses. - """ - - @abc.abstractmethod - def init(self) -> _State: - """Initialises the state for a new progress meter. - - **Arguments:** - - Nothing. - - **Returns:** - - The initial state for the progress meter. - """ - - @abc.abstractmethod - def step(self, state: _State, progress: FloatScalarLike) -> _State: - """Updates the progress meter. Called on every numerical step of a differential - equation solve. - - **Arguments:** - - - `state`: the state from the previous step. - - `progress`: how far along the solve is, as a number in `[0, 1]`. - - **Returns:** - - The updated state. In addition, the meter is expected to update as a - side-effect. - """ - - @abc.abstractmethod - def close(self, state: _State): - """Closes the progress meter. Called at the end of a differential equation - solve. - - **Arguments:** - - - `state`: the final state from the end of the solve. - - *Returns:** - - None. - """ - - -class NoProgressMeter(AbstractProgressMeter): - """Indicates that no progress meter should be displayed during the solve.""" - - def init(self) -> None: - return None - - def step(self, state, progress: FloatScalarLike) -> None: - del progress - return state - - def close(self, state): - del state - - -NoProgressMeter.__init__.__doc__ = """**Arguments:** - -Nothing. -""" - - -def _unvmap_min(x): # No `eqxi.unvmap_min` at the moment. - return -eqxi.unvmap_max(-x) - - -class _TextProgressMeterState(eqx.Module): - progress: FloatScalarLike - meter_idx: IntScalarLike - - -class TextProgressMeter(AbstractProgressMeter): - """A text progress meter, printing out e.g.: - ``` - 0.00% - 2.00% - 5.30% - ... - 100.00% - ``` - """ - - minimum_increase: RealScalarLike = 0.02 - - @staticmethod - def _init_bar() -> list[float]: - print("0.00%") - return [0.0] - - def init(self) -> _TextProgressMeterState: - meter_idx = _progress_meter_manager.init(self._init_bar) - return _TextProgressMeterState(meter_idx=meter_idx, progress=jnp.array(0.0)) - - @staticmethod - def _step_bar(bar: list[float], progress: FloatScalarLike) -> None: - if eqx.is_array(progress): - # May not be an array when called with `JAX_DISABLE_JIT=1` - progress = cast(Array | np.ndarray, progress) - progress = cast(float, progress.item()) - else: - progress = cast(float, progress) - bar[0] = progress - print(f"{100 * progress:.2f}%") - - def step( - self, state: _TextProgressMeterState, progress: FloatScalarLike - ) -> _TextProgressMeterState: - # When `diffeqsolve(..., t0=..., t1=...)` are batched, then both - # `state.progress` and `progress` will pick up a batch tracer. - # (For the former, because the condition for the while-loop-over-steps becomes - # batched, so necessarily everything in the body of the loop is as well.) - pred = eqxi.unvmap_all( - (progress - state.progress > self.minimum_increase) | (progress == 1) - ) - - # We only print if the progress has increased by at least `minimum_increase` to - # avoid flooding the user with too many updates. - next_progress, meter_idx = jax.lax.cond( - eqxi.nonbatchable(pred), - lambda _idx: ( - progress, - _progress_meter_manager.step(self._step_bar, progress, _idx), - ), - lambda _idx: (state.progress, _idx), - state.meter_idx, - ) - - return _TextProgressMeterState(progress=next_progress, meter_idx=meter_idx) - - @staticmethod - def _close_bar(bar: list[float]): - if bar[0] != 1: - print("100.00%") - - def close(self, state: _TextProgressMeterState): - _progress_meter_manager.close(self._close_bar, state.meter_idx) - - -TextProgressMeter.__init__.__doc__ = """**Arguments:** - -- `minimum_increase`: the minimum amount the progress has to have increased in order to - print out a new line. The progress starts at 0 at the beginning of the solve, and - increases to 1 at the end of the solve. Defaults to `0.02`, so that a new line is - printed each time the progress increases another 2%. -""" - - -class _TqdmProgressMeterState(eqx.Module): - meter_idx: IntScalarLike - step: IntScalarLike - - -class TqdmProgressMeter(AbstractProgressMeter): - """Uses tqdm to display a progress bar for the solve.""" - - refresh_steps: int = 20 - - def __check_init__(self): - if importlib.util.find_spec("tqdm") is None: - raise ValueError( - "Cannot use `diffrax.TqdmProgressMeter` without `tqdm` installed. " - "Install it via `pip install tqdm`." - ) - - @staticmethod - def _init_bar() -> "tqdm.tqdm": # pyright: ignore[reportUndefinedVariable] # noqa: F821 - import tqdm - - bar_format = ( - "{percentage:.2f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}{postfix}]" - ) - return tqdm.tqdm( - total=100, - unit="%", - bar_format=bar_format, - ) - - def init(self) -> _TqdmProgressMeterState: - meter_idx = _progress_meter_manager.init(self._init_bar) - return _TqdmProgressMeterState(meter_idx=meter_idx, step=jnp.array(0)) - - @staticmethod - def _step_bar(bar: "tqdm.tqdm", progress: FloatScalarLike) -> None: # pyright: ignore # noqa: F821 - bar.n = round(100 * float(progress), 2) - bar.update(n=0) - bar.refresh() - - def step( - self, - state: _TqdmProgressMeterState, - progress: FloatScalarLike, - ) -> _TqdmProgressMeterState: - # Here we update every `refresh_rate` steps in order to limit expensive - # callbacks. - # The `unvmap_max` is because batch values for `state.step` start off in sync, - # and then eventually will freeze their values as that batch element finishes - # its solve. So take a `max` to get the true number of overall solve steps for - # the batched system. - meter_idx = jax.lax.cond( - eqxi.nonbatchable(eqxi.unvmap_max(state.step) % self.refresh_steps == 0), - lambda _idx: _progress_meter_manager.step(self._step_bar, progress, _idx), - lambda _idx: _idx, - state.meter_idx, - ) - return _TqdmProgressMeterState(meter_idx=meter_idx, step=state.step + 1) - - @staticmethod - def _close_bar(bar: "tqdm.tqdm"): # pyright: ignore # noqa: F821 - bar.n = 100.0 - bar.update(n=0) - bar.close() - - def close(self, state: _TqdmProgressMeterState): - _progress_meter_manager.close(self._close_bar, state.meter_idx) - - -TqdmProgressMeter.__init__.__doc__ = """**Arguments:** - -- `refresh_steps`: the number of numerical steps between refreshing the bar. Used to - limit how frequently the (potentially computationally expensive) bar update is - performed. -""" - - -class _ProgressMeterManager: - """Host-side progress meter manager.""" - - def __init__(self): - self.idx = 0 - self.bars = {} - # Not sure how important a lock really is, but included just in case. - self.lock = threading.Lock() - - def init(self, init_bar: Callable[[], Any]) -> IntScalarLike: - def _init() -> IntScalarLike: - with self.lock: - bar = init_bar() - self.idx += 1 - self.bars[self.idx] = bar - return np.array(self.idx, dtype=jnp.int32) - - # Not `pure_callback` because it's not a deterministic function of its input - # arguments. - # Not `debug.callback` because it has a return value. - meter_idx = io_callback(_init, jax.ShapeDtypeStruct((), jnp.int32)) - return eqxi.nonbatchable(meter_idx) - - def step( - self, - step_bar: Callable[[Any, FloatScalarLike], None], - progress: FloatScalarLike, - idx: IntScalarLike, - ) -> IntScalarLike: - # Track the slowest batch element. - progress = _unvmap_min(progress) - - def _step(_progress, _idx): - with self.lock: - try: - # This may pick up a spurious batch tracer from a batched condition, - # so we need to handle that. We do this by using an `np.unique`. - # It should always be the case that `_idx` has precisely one value! - bar = self.bars[np.unique(_idx).item()] - except KeyError: - pass # E.g. the backward pass after a forward pass. - else: - # As above, `_idx` may have a spurious batch tracer. Correspondingly - # `_progress` may pick up spurious length-1 batch dimensions from - # `vmap_method="expand_dims"` below. Remove them now. - step_bar(bar, np.array(_progress).reshape(())) - # Return the idx to thread the callbacks in the correct order. - return _idx - - return jax.pure_callback(_step, idx, progress, idx, vmap_method="expand_dims") - - def close(self, close_bar: Callable[[Any], None], idx: IntScalarLike): - def _close(_idx): - with self.lock: - _idx = _idx.item() - bar = self.bars[_idx] - close_bar(bar) - del self.bars[_idx] - - # Unlike in `step`, we do the `unvmap_max` here. For mysterious reasons this - # callback does not trigger at all otherwise. - io_callback(_close, None, eqxi.unvmap_max(idx)) - - -_progress_meter_manager = _ProgressMeterManager() diff --git a/docs/api/progress_meter.md b/docs/api/progress_meter.md index 86a923ce..b43e6974 100644 --- a/docs/api/progress_meter.md +++ b/docs/api/progress_meter.md @@ -4,26 +4,24 @@ As the solve progresses, progress meters offer the ability to have some kind of ??? abstract "`diffrax.AbstractProgressMeter`" - ::: diffrax.AbstractProgressMeter - options: - members: - - init - - step - - close + An abstract base class for all progress meters. + + **Methods:** + + - `init()` + - `step()` + - `close()` --- -::: diffrax.NoProgressMeter - options: - members: - - __init__ +### `diffrax.NoProgressMeter` + +A progress meter that does nothing. + +### `diffrax.TextProgressMeter` + +A progress meter that prints text to the console. -::: diffrax.TextProgressMeter - options: - members: - - __init__ +### `diffrax.TqdmProgressMeter` -::: diffrax.TqdmProgressMeter - options: - members: - - __init__ +A progress meter that displays a tqdm progress bar. diff --git a/mkdocs.yml b/mkdocs.yml index 06338eb4..9ee4f73c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -81,6 +81,10 @@ plugins: - lineax.AutoLinearSolver - optimistix.AbstractRootFinder - optimistix.Chord + - equinox.internal._progress_meter.AbstractProgressMeter + - equinox.internal._progress_meter.NoProgressMeter + - equinox.internal._progress_meter.TextProgressMeter + - equinox.internal._progress_meter.TqdmProgressMeter - mkdocstrings: handlers: python: