@@ -833,3 +833,30 @@ def run(event):
833833 t_final , y_final = run (event )
834834 assert jnp .allclose (t_final , 10.0 )
835835 assert jnp .allclose (y_final , jnp .array ([10.0 , 11.0 ]))
836+
837+
838+ # https://github.com/patrick-kidger/diffrax/issues/720
839+ def test_boolean_with_root_find_terminating_on_first_step ():
840+ controller = diffrax .PIDController (rtol = 1e-6 , atol = 1e-6 )
841+ steady_state_event = diffrax .steady_state_event (rtol = 1e-6 , atol = 1e-6 )
842+ root_finder = optx .Newton (atol = 1e-4 , rtol = 1e-4 )
843+
844+ sol = diffrax .diffeqsolve (
845+ diffrax .ODETerm (lambda t , y , args : jnp .zeros_like (y )),
846+ diffrax .Kvaerno5 (),
847+ t0 = 0.0 ,
848+ t1 = 1.2 ,
849+ dt0 = None ,
850+ y0 = jnp .array ([10.0 ]),
851+ stepsize_controller = controller ,
852+ event = diffrax .Event (
853+ cond_fn = steady_state_event ,
854+ root_finder = root_finder ,
855+ ),
856+ saveat = diffrax .SaveAt (t1 = True ),
857+ max_steps = 100 ,
858+ )
859+ assert sol .ts is not None
860+ assert sol .ys is not None
861+ assert jnp .allclose (sol .ts , jnp .array ([0.0 ]))
862+ assert jnp .allclose (sol .ys , jnp .array ([[10.0 ]]))
0 commit comments