Skip to content

Commit 91c245a

Browse files
Adjust save-every-step logic.
1 parent 3c9e0db commit 91c245a

2 files changed

Lines changed: 59 additions & 46 deletions

File tree

diffrax/_integrate.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,8 @@ def _save(
242242
ys = save_state.ys
243243
save_index = save_state.save_index
244244

245-
ts = lax.dynamic_update_slice_in_dim(
246-
ts, jnp.broadcast_to(t, (repeat,)), save_index, axis=0
247-
)
245+
t_to_save = jnp.broadcast_to(static_select(pred, t, ts[save_index]), (repeat,))
246+
ts = lax.dynamic_update_slice_in_dim(ts, t_to_save, save_index, axis=0)
248247
y_to_save = lax.cond(
249248
pred,
250249
lambda: fn(t, y, args),
@@ -484,33 +483,29 @@ def _body_fun(_save_state):
484483
save_ts, saveat.subs, save_state, is_leaf=_is_subsaveat
485484
)
486485

487-
def maybe_inplace(i, u, x):
488-
return eqxi.buffer_at_set(x, i, u, pred=keep_step)
489-
490486
def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
491487
if subsaveat.steps != 0:
492-
save_step = (state.num_accepted_steps % subsaveat.steps) == 0
488+
save_step = (num_accepted_steps % subsaveat.steps) == 0
493489
should_save = keep_step & save_step
494490

495-
def save_fn(tprev, y, args):
496-
return subsaveat.fn(tprev, y, args)
497-
# TODO: Enable this, but I am not sure if possible? How do we know
498-
# the output shape of `.fn`? We should do a dummy call to it?
499-
if subsaveat.steps == 1:
500-
return subsaveat.fn(tprev, y, args)
501-
else:
502-
return lax.cond(
503-
should_save,
504-
lambda: subsaveat.fn(tprev, y, args),
505-
lambda: jtu.tree_map(
506-
lambda y: jnp.zeros(y.shape[1:], y.dtype), save_state.ys
507-
),
508-
)
491+
if subsaveat.steps == 1:
492+
y_to_save = subsaveat.fn(tprev, y, args)
493+
else:
494+
struct = eqx.filter_eval_shape(subsaveat.fn, tprev, y, args)
495+
y_to_save = lax.cond(
496+
eqxi.unvmap_any(should_save),
497+
lambda: subsaveat.fn(tprev, y, args),
498+
lambda: jtu.tree_map(jnp.zeros_like, struct),
499+
)
509500

510-
ts = maybe_inplace(save_state.save_index, tprev, save_state.ts)
501+
ts = eqxi.buffer_at_set(
502+
save_state.ts, save_state.save_index, tprev, pred=should_save
503+
)
511504
ys = jtu.tree_map(
512-
ft.partial(maybe_inplace, save_state.save_index),
513-
save_fn(tprev, y, args),
505+
lambda _y, _ys: eqxi.buffer_at_set(
506+
_ys, save_state.save_index, _y, pred=should_save
507+
),
508+
y_to_save,
514509
save_state.ys,
515510
)
516511
save_index = save_state.save_index + jnp.where(should_save, 1, 0)
@@ -525,9 +520,13 @@ def save_fn(tprev, y, args):
525520
save_steps, saveat.subs, save_state, is_leaf=_is_subsaveat
526521
)
527522
if saveat.dense:
528-
dense_ts = maybe_inplace(dense_save_index + 1, tprev, dense_ts)
523+
dense_ts = eqxi.buffer_at_set(
524+
dense_ts, dense_save_index + 1, tprev, pred=keep_step
525+
)
529526
dense_infos = jtu.tree_map(
530-
ft.partial(maybe_inplace, dense_save_index),
527+
lambda _i, _is: eqxi.buffer_at_set(
528+
_is, dense_save_index, _i, pred=keep_step
529+
),
531530
dense_info,
532531
dense_infos,
533532
)

test/test_saveat_solution.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,12 @@ def test_saveat_solution():
191191

192192

193193
def test_saveat_solution_skip_steps():
194-
def _step_integrate(saveat: diffrax.SaveAt):
194+
def _step_integrate(saveat: diffrax.SaveAt, with_7: bool):
195195
term = diffrax.ODETerm(lambda t, y, args: -0.5 * y)
196-
ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
196+
if with_7:
197+
ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
198+
else:
199+
ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
197200
sol_ts = diffrax.diffeqsolve(
198201
term,
199202
t0=ts[0],
@@ -208,24 +211,35 @@ def _step_integrate(saveat: diffrax.SaveAt):
208211
assert sol_ts is not None
209212
return sol_ts[jnp.isfinite(sol_ts)]
210213

211-
saveat = diffrax.SaveAt(steps=2)
212-
ts = _step_integrate(saveat)
213-
assert jnp.allclose(ts, jnp.array([1.0, 3.0, 5.0]))
214-
saveat = diffrax.SaveAt(steps=2, t1=True)
215-
ts = _step_integrate(saveat)
216-
assert jnp.allclose(ts, jnp.array([1.0, 3.0, 5.0, 6.0]))
217-
saveat = diffrax.SaveAt(steps=2, t1=True, t0=True)
218-
ts = _step_integrate(saveat)
219-
assert jnp.allclose(ts, jnp.array([0.0, 1.0, 3.0, 5.0, 6.0]))
220-
saveat = diffrax.SaveAt(steps=3)
221-
ts = _step_integrate(saveat)
222-
assert jnp.allclose(ts, jnp.array([1.0, 4.0]))
223-
saveat = diffrax.SaveAt(steps=3, t1=True)
224-
ts = _step_integrate(saveat)
225-
assert jnp.allclose(ts, jnp.array([1.0, 4.0, 6.0]))
226-
saveat = diffrax.SaveAt(steps=3, t1=True, t0=True)
227-
ts = _step_integrate(saveat)
228-
assert jnp.allclose(ts, jnp.array([0.0, 1.0, 4.0, 6.0]))
214+
ts = _step_integrate(diffrax.SaveAt(steps=2), with_7=True)
215+
assert jnp.allclose(ts, jnp.array([2.0, 4.0, 6.0]))
216+
ts = _step_integrate(diffrax.SaveAt(steps=2), with_7=False)
217+
assert jnp.allclose(ts, jnp.array([2.0, 4.0, 6.0]))
218+
219+
ts = _step_integrate(diffrax.SaveAt(steps=2, t1=True), with_7=True)
220+
assert jnp.allclose(ts, jnp.array([2.0, 4.0, 6.0, 7.0]))
221+
ts = _step_integrate(diffrax.SaveAt(steps=2, t1=True), with_7=False)
222+
assert jnp.allclose(ts, jnp.array([2.0, 4.0, 6.0]))
223+
224+
ts = _step_integrate(diffrax.SaveAt(steps=2, t1=True, t0=True), with_7=True)
225+
assert jnp.allclose(ts, jnp.array([0.0, 2.0, 4.0, 6.0, 7.0]))
226+
ts = _step_integrate(diffrax.SaveAt(steps=2, t1=True, t0=True), with_7=False)
227+
assert jnp.allclose(ts, jnp.array([0.0, 2.0, 4.0, 6.0]))
228+
229+
ts = _step_integrate(diffrax.SaveAt(steps=3), with_7=True)
230+
assert jnp.allclose(ts, jnp.array([3.0, 6.0]))
231+
ts = _step_integrate(diffrax.SaveAt(steps=3), with_7=False)
232+
assert jnp.allclose(ts, jnp.array([3.0, 6.0]))
233+
234+
ts = _step_integrate(diffrax.SaveAt(steps=3, t1=True), with_7=True)
235+
assert jnp.allclose(ts, jnp.array([3.0, 6.0, 7.0]))
236+
ts = _step_integrate(diffrax.SaveAt(steps=3, t1=True), with_7=False)
237+
assert jnp.allclose(ts, jnp.array([3.0, 6.0]))
238+
239+
ts = _step_integrate(diffrax.SaveAt(steps=3, t1=True, t0=True), with_7=True)
240+
assert jnp.allclose(ts, jnp.array([0.0, 3.0, 6.0, 7.0]))
241+
ts = _step_integrate(diffrax.SaveAt(steps=3, t1=True, t0=True), with_7=False)
242+
assert jnp.allclose(ts, jnp.array([0.0, 3.0, 6.0]))
229243

230244

231245
def test_saveat_solution_skip_vs_saveat():

0 commit comments

Comments
 (0)