Skip to content

Commit 8e9d99f

Browse files
Fixes 681
1 parent b7dc392 commit 8e9d99f

1 file changed

Lines changed: 2 additions & 8 deletions

File tree

diffrax/_integrate.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
161161
pass
162162
elif n_term_args == 2:
163163
vf_type_expected, control_type_expected = term_args
164-
try:
165-
vf_type = eqx.filter_eval_shape(term.vf, t, yi, args)
166-
except Exception as e:
167-
raise ValueError(f"Error while tracing {term}.vf: " + str(e))
164+
vf_type = eqx.filter_eval_shape(term.vf, t, yi, args)
168165
vf_type_compatible = eqx.filter_eval_shape(
169166
better_isinstance, vf_type, vf_type_expected
170167
)
@@ -173,10 +170,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
173170

174171
contr = ft.partial(term.contr, **term_contr_kwargs)
175172
# Work around https://github.com/google/jax/issues/21825
176-
try:
177-
control_type = eqx.filter_eval_shape(contr, t, t)
178-
except Exception as e:
179-
raise ValueError(f"Error while tracing {term}.contr: " + str(e))
173+
control_type = eqx.filter_eval_shape(contr, t, t)
180174
control_type_compatible = eqx.filter_eval_shape(
181175
better_isinstance, control_type, control_type_expected
182176
)

0 commit comments

Comments
 (0)