Skip to content

Commit aca5a16

Browse files
Work around JAX issue in 0.4.29
1 parent 5e57351 commit aca5a16

2 files changed

Lines changed: 3 additions & 2 deletions

File tree

diffrax/_integrate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ def _check(term_cls, term, term_contr_kwargs, yi):
144144
raise ValueError
145145

146146
contr = ft.partial(term.contr, **term_contr_kwargs)
147-
control_type = jax.eval_shape(contr, 0.0, 0.0)
147+
# Work around https://github.com/google/jax/issues/21825
148+
control_type = eqx.filter_eval_shape(contr, 0.0, 0.0)
148149
control_type_compatible = eqx.filter_eval_shape(
149150
better_isinstance, control_type, control_type_expected
150151
)

diffrax/_term.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def is_vf_expensive(
493493
],
494494
args: Args,
495495
) -> bool:
496-
control_struct = jax.eval_shape(self.contr, t0, t1)
496+
control_struct = eqx.filter_eval_shape(self.contr, t0, t1)
497497
if sum(c.size for c in jtu.tree_leaves(control_struct)) in (0, 1):
498498
return False
499499
else:

0 commit comments

Comments
 (0)