@@ -258,15 +258,11 @@ def _save(
258258 )
259259
260260
261- def _clip_to_end (tprev , tnext , t1 , keep_step ):
262- # The tolerance means that we don't end up with too-small intervals for
263- # dense output, which then gives numerically unstable answers due to floating
261+ def _clip_to_end (tprev , tnext , t1 , t1_clip_floor , keep_step ):
262+ # The tolerance of ~100 ULP's means that we don't end up with too-small intervals
263+ # for dense output, which then gives numerically unstable answers due to floating
264264 # point errors.
265- if tnext .dtype is jnp .dtype ("float64" ):
266- tol = 1e-10
267- else :
268- tol = 1e-6
269- clip = tnext > t1 - tol
265+ clip = tnext > t1_clip_floor
270266 tclip = jnp .where (keep_step , t1 , tprev + 0.5 * (t1 - tprev ))
271267 return jnp .where (clip , tclip , tnext )
272268
@@ -303,6 +299,11 @@ def loop(
303299 outer_while_loop ,
304300 progress_meter ,
305301):
302+ # Calculate in advance t1 - 100 ULP's: the threshold at which to round tnext to t1
303+ t1_clip_floor = t1
304+ for _ in range (100 ):
305+ t1_clip_floor = eqxi .prevbefore (t1_clip_floor )
306+
306307 if saveat .dense :
307308 dense_ts = init_state .dense_ts
308309 dense_ts = dense_ts .at [0 ].set (t0 )
@@ -392,7 +393,7 @@ def body_fun_aux(state):
392393 #
393394
394395 tprev = jnp .minimum (tprev , t1 )
395- tnext = _clip_to_end (tprev , tnext , t1 , keep_step )
396+ tnext = _clip_to_end (tprev , tnext , t1 , t1_clip_floor , keep_step )
396397
397398 progress_meter_state = progress_meter .step (
398399 state .progress_meter_state , linear_rescale (t0 , tprev , t1 )
0 commit comments