Skip to content

Commit 2589876

Browse files
Silence warning from pure_callback(vectorized=True)
1 parent 10bff5a commit 2589876

2 files changed

Lines changed: 6 additions & 3 deletions

File tree

diffrax/_progress_meter.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,14 @@ def _step(_progress, _idx):
291291
except KeyError:
292292
pass # E.g. the backward pass after a forward pass.
293293
else:
294-
step_bar(bar, _progress)
294+
# As above, `_idx` may have a spurious batch tracer. Correspondingly
295+
# `_progress` may pick up spurious length-1 batch dimensions from
296+
# `vmap_method="expand_dims"` below. Remove them now.
297+
step_bar(bar, np.array(_progress).reshape(()))
295298
# Return the idx to thread the callbacks in the correct order.
296299
return _idx
297300

298-
return jax.pure_callback(_step, idx, progress, idx, vectorized=True)
301+
return jax.pure_callback(_step, idx, progress, idx, vmap_method="expand_dims")
299302

300303
def close(self, close_bar: Callable[[Any], None], idx: IntScalarLike):
301304
def _close(_idx):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Topic :: Scientific/Engineering :: Mathematics",
2424
]
2525
urls = {repository = "https://github.com/patrick-kidger/diffrax" }
26-
dependencies = ["jax>=0.4.28", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.7"]
26+
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.7"]
2727

2828
[build-system]
2929
requires = ["hatchling"]

0 commit comments

Comments
 (0)