Skip to content

Commit e8efca1

Browse files
committed
Use 100 ULP's to clip timesteps close to t1
1 parent d3c1430 commit e8efca1

1 file changed

Lines changed: 10 additions & 9 deletions

File tree

diffrax/_integrate.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -258,15 +258,11 @@ def _save(
258258
)
259259

260260

261-
def _clip_to_end(tprev, tnext, t1, keep_step):
262-
# The tolerance means that we don't end up with too-small intervals for
263-
# dense output, which then gives numerically unstable answers due to floating
261+
def _clip_to_end(tprev, tnext, t1, t1_clip_floor, keep_step):
262+
# The tolerance of ~100 ULP's means that we don't end up with too-small intervals
263+
# for dense output, which then gives numerically unstable answers due to floating
264264
# point errors.
265-
if tnext.dtype is jnp.dtype("float64"):
266-
tol = 1e-10
267-
else:
268-
tol = 1e-6
269-
clip = tnext > t1 - tol
265+
clip = tnext > t1_clip_floor
270266
tclip = jnp.where(keep_step, t1, tprev + 0.5 * (t1 - tprev))
271267
return jnp.where(clip, tclip, tnext)
272268

@@ -303,6 +299,11 @@ def loop(
303299
outer_while_loop,
304300
progress_meter,
305301
):
302+
# Calculate in advance t1 - 100 ULP's: the threshold at which to round tnext to t1
303+
t1_clip_floor = t1
304+
for _ in range(100):
305+
t1_clip_floor = eqxi.prevbefore(t1_clip_floor)
306+
306307
if saveat.dense:
307308
dense_ts = init_state.dense_ts
308309
dense_ts = dense_ts.at[0].set(t0)
@@ -392,7 +393,7 @@ def body_fun_aux(state):
392393
#
393394

394395
tprev = jnp.minimum(tprev, t1)
395-
tnext = _clip_to_end(tprev, tnext, t1, keep_step)
396+
tnext = _clip_to_end(tprev, tnext, t1, t1_clip_floor, keep_step)
396397

397398
progress_meter_state = progress_meter.step(
398399
state.progress_meter_state, linear_rescale(t0, tprev, t1)

0 commit comments

Comments
 (0)