Skip to content

Commit 69661fc

Browse files
Fix #720; bool event + root find + terminate on first step
1 parent 62bf876 commit 69661fc

2 files changed

Lines changed: 35 additions & 3 deletions

File tree

diffrax/_integrate.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,11 @@ def _call_real_impl():
761761
_tfinal = _event_root_find.value
762762
# TODO: we might need to change the way we evaluate `_yfinal` in order to
763763
# get more accurate derivatives?
764-
_yfinal = _interpolator.evaluate(_tfinal)
764+
_yfinal = lax.cond(
765+
final_state.num_steps == 0,
766+
lambda: final_state.y,
767+
lambda: _interpolator.evaluate(_tfinal),
768+
)
765769
_result = RESULTS.where(
766770
_event_root_find.result == optx.RESULTS.successful,
767771
result,
@@ -1323,7 +1327,7 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
13231327
event_mask = None
13241328
else:
13251329
event_tprev = tprev
1326-
event_tnext = tnext
1330+
event_tnext = tprev
13271331
# Fill the dense-info with dummy values on the first step, when we haven't yet
13281332
# made any steps.
13291333
# Note that we're threading a needle here! What if we terminate on the very
@@ -1334,8 +1338,9 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
13341338
# to the end of the interval).
13351339
# - A floating event can't terminate on the first step (it requires a sign
13361340
# change).
1341+
# c.f. https://github.com/patrick-kidger/diffrax/issues/720
13371342
event_dense_info = jtu.tree_map(
1338-
lambda x: jnp.empty(x.shape, x.dtype),
1343+
lambda x: jnp.zeros(x.shape, x.dtype),
13391344
dense_info_struct, # pyright: ignore[reportPossiblyUnboundVariable]
13401345
)
13411346

test/test_event.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,3 +833,30 @@ def run(event):
833833
t_final, y_final = run(event)
834834
assert jnp.allclose(t_final, 10.0)
835835
assert jnp.allclose(y_final, jnp.array([10.0, 11.0]))
836+
837+
838+
# https://github.com/patrick-kidger/diffrax/issues/720
839+
def test_boolean_with_root_find_terminating_on_first_step():
840+
controller = diffrax.PIDController(rtol=1e-6, atol=1e-6)
841+
steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6)
842+
root_finder = optx.Newton(atol=1e-4, rtol=1e-4)
843+
844+
sol = diffrax.diffeqsolve(
845+
diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)),
846+
diffrax.Kvaerno5(),
847+
t0=0.0,
848+
t1=1.2,
849+
dt0=None,
850+
y0=jnp.array([10.0]),
851+
stepsize_controller=controller,
852+
event=diffrax.Event(
853+
cond_fn=steady_state_event,
854+
root_finder=root_finder,
855+
),
856+
saveat=diffrax.SaveAt(t1=True),
857+
max_steps=100,
858+
)
859+
assert sol.ts is not None
860+
assert sol.ys is not None
861+
assert jnp.allclose(sol.ts, jnp.array([0.0]))
862+
assert jnp.allclose(sol.ys, jnp.array([[10.0]]))

0 commit comments

Comments
 (0)