Skip to content

Commit 2709081

Browse files
authored
Use 100 ULP's to clip timesteps close to t1 (#660)
* Use 100 ULP's to clip timesteps close to t1 * test that t1-t0 > 100 ULP's * revert testing as t1 is traced * remove unnecessary pyright ignores
1 parent 514fc9e commit 2709081

1 file changed

Lines changed: 10 additions & 9 deletions

File tree

diffrax/_integrate.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,11 @@ def _save(
263263
)
264264

265265

266-
def _clip_to_end(tprev, tnext, t1, keep_step):
267-
# The tolerance means that we don't end up with too-small intervals for
268-
# dense output, which then gives numerically unstable answers due to floating
266+
def _clip_to_end(tprev, tnext, t1, t1_clip_floor, keep_step):
267+
# The tolerance of ~100 ULP's means that we don't end up with too-small intervals
268+
# for dense output, which then gives numerically unstable answers due to floating
269269
# point errors.
270-
if tnext.dtype is jnp.dtype("float64"):
271-
tol = 1e-10
272-
else:
273-
tol = 1e-6
274-
clip = tnext > t1 - tol
270+
clip = tnext > t1_clip_floor
275271
tclip = jnp.where(keep_step, t1, tprev + 0.5 * (t1 - tprev))
276272
return jnp.where(clip, tclip, tnext)
277273

@@ -308,6 +304,11 @@ def loop(
308304
outer_while_loop,
309305
progress_meter,
310306
):
307+
# Calculate in advance t1 - 100 ULP's: the threshold at which to round tnext to t1
308+
t1_clip_floor = t1
309+
for _ in range(100):
310+
t1_clip_floor = eqxi.prevbefore(t1_clip_floor)
311+
311312
if saveat.dense:
312313
dense_ts = init_state.dense_ts
313314
dense_ts = dense_ts.at[0].set(t0)
@@ -397,7 +398,7 @@ def body_fun_aux(state):
397398
#
398399

399400
tprev = jnp.minimum(tprev, t1)
400-
tnext = _clip_to_end(tprev, tnext, t1, keep_step)
401+
tnext = _clip_to_end(tprev, tnext, t1, t1_clip_floor, keep_step)
401402

402403
progress_meter_state = progress_meter.step(
403404
state.progress_meter_state, linear_rescale(t0, tprev, t1)

0 commit comments

Comments
 (0)