@@ -263,15 +263,11 @@ def _save(
263263 )
264264
265265
266- def _clip_to_end (tprev , tnext , t1 , keep_step ):
267- # The tolerance means that we don't end up with too-small intervals for
268- # dense output, which then gives numerically unstable answers due to floating
266+ def _clip_to_end (tprev , tnext , t1 , t1_clip_floor , keep_step ):
267+ # The tolerance of ~100 ULP's means that we don't end up with too-small intervals
268+ # for dense output, which then gives numerically unstable answers due to floating
269269 # point errors.
270- if tnext .dtype is jnp .dtype ("float64" ):
271- tol = 1e-10
272- else :
273- tol = 1e-6
274- clip = tnext > t1 - tol
270+ clip = tnext > t1_clip_floor
275271 tclip = jnp .where (keep_step , t1 , tprev + 0.5 * (t1 - tprev ))
276272 return jnp .where (clip , tclip , tnext )
277273
@@ -308,6 +304,11 @@ def loop(
308304 outer_while_loop ,
309305 progress_meter ,
310306):
307+ # Calculate in advance t1 - 100 ULP's: the threshold at which to round tnext to t1
308+ t1_clip_floor = t1
309+ for _ in range (100 ):
310+ t1_clip_floor = eqxi .prevbefore (t1_clip_floor )
311+
311312 if saveat .dense :
312313 dense_ts = init_state .dense_ts
313314 dense_ts = dense_ts .at [0 ].set (t0 )
@@ -397,7 +398,7 @@ def body_fun_aux(state):
397398 #
398399
399400 tprev = jnp .minimum (tprev , t1 )
400- tnext = _clip_to_end (tprev , tnext , t1 , keep_step )
401+ tnext = _clip_to_end (tprev , tnext , t1 , t1_clip_floor , keep_step )
401402
402403 progress_meter_state = progress_meter .step (
403404 state .progress_meter_state , linear_rescale (t0 , tprev , t1 )
0 commit comments