Skip to content

Commit e8b8e5c

Browse files
authored
add direction for saveat + test (#427)
* add direction for saveat + test * remove direction, make changes before * fix `save_y` test * expand tests * fix inf initialization * fix test
1 parent c6cc85c commit e8b8e5c

2 files changed

Lines changed: 56 additions & 4 deletions

File tree

diffrax/_integrate.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
NoProgressMeter,
3838
)
3939
from ._root_finder import use_stepsize_tol
40-
from ._saveat import SaveAt, SubSaveAt
40+
from ._saveat import save_y, SaveAt, SubSaveAt
4141
from ._solution import is_okay, is_successful, RESULTS, Solution
4242
from ._solver import (
4343
AbstractImplicitSolver,
@@ -881,6 +881,18 @@ def _check_subsaveat_ts(ts):
881881

882882
saveat = eqx.tree_at(_get_subsaveat_ts, saveat, replace_fn=_check_subsaveat_ts)
883883

884+
def _subsaveat_direction_fn(x):
885+
if _is_subsaveat(x):
886+
if x.fn is not save_y:
887+
direction_fn = lambda t, y, args: x.fn(direction * t, y, args)
888+
return eqx.tree_at(lambda x: x.fn, x, direction_fn)
889+
else:
890+
return x
891+
else:
892+
return x
893+
894+
saveat = jtu.tree_map(_subsaveat_direction_fn, saveat, is_leaf=_is_subsaveat)
895+
884896
# Initialise states
885897
tprev = t0
886898
error_order = solver.error_order(terms)
@@ -924,7 +936,7 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
924936
out_size += 1
925937
saveat_ts_index = 0
926938
save_index = 0
927-
ts = jnp.full(out_size, jnp.inf, dtype=time_dtype)
939+
ts = jnp.full(out_size, direction * jnp.inf, dtype=time_dtype)
928940
struct = eqx.filter_eval_shape(subsaveat.fn, t0, y0, args)
929941
ys = jtu.tree_map(
930942
lambda y: jnp.full((out_size,) + y.shape, jnp.inf, dtype=y.dtype), struct
@@ -1013,7 +1025,6 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
10131025
#
10141026

10151027
progress_meter.close(final_state.progress_meter_state)
1016-
10171028
is_save_state = lambda x: isinstance(x, SaveState)
10181029
ts = jtu.tree_map(
10191030
lambda s: s.ts * direction, final_state.save_state, is_leaf=is_save_state

test/test_integrate.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,18 @@ def g(t, y, args):
389389
or saveat.subs.ts is not None
390390
or saveat.subs.steps
391391
):
392-
assert tree_allclose(sol1.ts, -cast(Array, sol2.ts), equal_nan=True)
392+
assert sol1.ts is not None
393+
assert sol2.ts is not None
394+
assert tree_allclose(
395+
sol1.ts[~jnp.isinf(sol1.ts)],
396+
-cast(Array, sol2.ts[~jnp.isinf(sol2.ts)]),
397+
equal_nan=True,
398+
)
399+
assert tree_allclose(
400+
sol1.ts[jnp.isinf(sol1.ts)],
401+
cast(Array, sol2.ts[jnp.isinf(sol2.ts)]),
402+
equal_nan=True,
403+
)
393404
assert tree_allclose(sol1.ys, sol2.ys, equal_nan=True)
394405
if saveat.dense:
395406
t = jnp.linspace(0.3, 4, 20)
@@ -398,6 +409,36 @@ def g(t, y, args):
398409
assert tree_allclose(sol1.derivative(ti), -sol2.derivative(-ti))
399410

400411

412+
@pytest.mark.parametrize(
413+
"saveat",
414+
(
415+
diffrax.SaveAt(t0=True, fn=lambda t, y, args: t),
416+
diffrax.SaveAt(t1=True, fn=lambda t, y, args: t),
417+
diffrax.SaveAt(dense=True, fn=lambda t, y, args: t),
418+
diffrax.SaveAt(steps=True, fn=lambda t, y, args: t),
419+
diffrax.SaveAt(ts=jnp.linspace(3.0, 1.0, 5), fn=lambda t, y, args: t),
420+
),
421+
)
422+
def test_reverse_time_saveat(saveat):
423+
def f(t, y, args):
424+
return -y
425+
426+
t0 = 4
427+
t1 = 0.3
428+
dt0 = -1 / 50
429+
y0 = 1.0
430+
sol1 = diffrax.diffeqsolve(
431+
diffrax.ODETerm(f),
432+
diffrax.Euler(),
433+
t0,
434+
t1,
435+
dt0,
436+
y0,
437+
saveat=saveat,
438+
)
439+
assert tree_allclose(sol1.ys, sol1.ts)
440+
441+
401442
def test_semi_implicit_euler():
402443
term1 = diffrax.ODETerm(lambda t, y, args: -y)
403444
term2 = diffrax.ODETerm(lambda t, y, args: y)

0 commit comments

Comments
 (0)