Skip to content

Commit edd1250

Browse files
Merge pull request #251 from patrick-kidger/implicit-euler-adaptive
Implicit Euler is now adaptive (basically necessary for it to be usef…
2 parents 0b93a3c + 7377755 commit edd1250

10 files changed

Lines changed: 148 additions & 29 deletions

diffrax/local_interpolation.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,18 @@ def __init__(
4141
*,
4242
y0: PyTree[Array["dims":...]], # noqa: F821
4343
y1: PyTree[Array["dims":...]], # noqa: F821
44-
f0: PyTree[Array["dims":...]], # noqa: F821
45-
f1: PyTree[Array["dims":...]], # noqa: F821
44+
k0: PyTree[Array["dims":...]], # noqa: F821
45+
k1: PyTree[Array["dims":...]], # noqa: F821
4646
**kwargs
4747
):
4848
super().__init__(**kwargs)
4949

50-
def _calculate(_y0, _y1, _f0, _f1):
51-
_a = _f0 + _f1 + 2 * _y0 - 2 * _y1
52-
_b = -2 * _f0 - _f1 - 3 * _y0 + 3 * _y1
53-
return jnp.stack([_a, _b, _f0, _y0])
50+
def _calculate(_y0, _y1, _k0, _k1):
51+
_a = _k0 + _k1 + 2 * _y0 - 2 * _y1
52+
_b = -2 * _k0 - _k1 - 3 * _y0 + 3 * _y1
53+
return jnp.stack([_a, _b, _k0, _y0])
5454

55-
self.coeffs = jtu.tree_map(_calculate, y0, y1, f0, f1)
55+
self.coeffs = jtu.tree_map(_calculate, y0, y1, k0, k1)
5656

5757
@classmethod
5858
def from_k(
@@ -63,7 +63,7 @@ def from_k(
6363
k: PyTree[Array["order", "dims":...]], # noqa: F821
6464
**kwargs
6565
):
66-
return cls(y0=y0, y1=y1, f0=ω(k)[0].ω, f1=ω(k)[-1].ω, **kwargs)
66+
return cls(y0=y0, y1=y1, k0=ω(k)[0].ω, k1=ω(k)[-1].ω, **kwargs)
6767

6868
def evaluate(
6969
self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True

diffrax/solver/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def error_order(self, terms: PyTree[AbstractTerm]) -> Optional[Scalar]:
8585
else:
8686
return self.order(terms)
8787

88+
@abc.abstractmethod
8889
def init(
8990
self,
9091
terms: PyTree[AbstractTerm],
@@ -101,7 +102,6 @@ def init(
101102
102103
The initial solver state, which should be used the first time `step` is called.
103104
"""
104-
return None
105105

106106
@abc.abstractmethod
107107
def step(

diffrax/solver/euler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ def order(self, terms):
3030
def strong_order(self, terms):
3131
return 0.5
3232

33+
def init(
34+
self,
35+
terms: AbstractTerm,
36+
t0: Scalar,
37+
t1: Scalar,
38+
y0: PyTree,
39+
args: PyTree,
40+
) -> _SolverState:
41+
return None
42+
3343
def step(
3444
self,
3545
terms: AbstractTerm,

diffrax/solver/euler_heun.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@ def order(self, terms):
2828
def strong_order(self, terms):
2929
return 0.5
3030

31+
def init(
32+
self,
33+
terms: Tuple[ODETerm, AbstractTerm],
34+
t0: Scalar,
35+
t1: Scalar,
36+
y0: PyTree,
37+
args: PyTree,
38+
) -> _SolverState:
39+
return None
40+
3141
def step(
3242
self,
3343
terms: Tuple[ODETerm, AbstractTerm],

diffrax/solver/implicit_euler.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from equinox.internal import ω
44

55
from ..custom_types import Bool, DenseInfo, PyTree, Scalar
6+
from ..heuristics import is_sde
67
from ..local_interpolation import LocalLinearInterpolation
78
from ..solution import RESULTS
89
from ..term import AbstractTerm
910
from .base import AbstractImplicitSolver
1011

1112

12-
_ErrorEstimate = None
1313
_SolverState = None
1414

1515

@@ -22,15 +22,36 @@ def _implicit_relation(z1, nonlinear_solve_args):
2222
class ImplicitEuler(AbstractImplicitSolver):
2323
r"""Implicit Euler method.
2424
25-
A-B-L stable 1st order SDIRK method. Does not support adaptive step sizing.
25+
A-B-L stable 1st order SDIRK method. Has an embedded 2nd order method for adaptive
26+
step sizing.
2627
"""
2728

2829
term_structure = AbstractTerm
30+
# We actually have enough information to use 3rd order Hermite interpolation.
31+
#
32+
# We don't use it as this seems to be quite a bad choice for low-order solvers: it
33+
# produces very oscillatory interpolations.
2934
interpolation_cls = LocalLinearInterpolation
3035

3136
def order(self, terms):
3237
return 1
3338

39+
def error_order(self, terms):
40+
if is_sde(terms):
41+
return None
42+
else:
43+
return 2
44+
45+
def init(
46+
self,
47+
terms: AbstractTerm,
48+
t0: Scalar,
49+
t1: Scalar,
50+
y0: PyTree,
51+
args: PyTree,
52+
) -> _SolverState:
53+
return None
54+
3455
def step(
3556
self,
3657
terms: AbstractTerm,
@@ -40,20 +61,28 @@ def step(
4061
args: PyTree,
4162
solver_state: _SolverState,
4263
made_jump: Bool,
43-
) -> Tuple[PyTree, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
44-
del solver_state, made_jump
64+
) -> Tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]:
65+
del made_jump
4566
control = terms.contr(t0, t1)
46-
pred = terms.vf_prod(t0, y0, args, control)
67+
# Could use FSAL here but that would mean we'd need to switch to working with
68+
# `f0 = terms.vf(t0, y0, args)`, and that gets quite hairy quite quickly.
69+
# (C.f. `AbstractRungeKutta.step`.)
70+
# If we wanted FSAL then really the correct thing to do would just be to
71+
# write out a `ButcherTableau` and use `AbstractSDIRK`.
72+
k0 = terms.vf_prod(t0, y0, args, control)
4773
jac = self.nonlinear_solver.jac(
48-
_implicit_relation, pred, (terms.vf_prod, t1, y0, args, control)
74+
_implicit_relation, k0, (terms.vf_prod, t1, y0, args, control)
4975
)
5076
nonlinear_sol = self.nonlinear_solver(
51-
_implicit_relation, pred, (terms.vf_prod, t1, y0, args, control), jac
77+
_implicit_relation, k0, (terms.vf_prod, t1, y0, args, control), jac
5278
)
53-
z1 = nonlinear_sol.root
54-
y1 = (y0**ω + z1**ω).ω
79+
k1 = nonlinear_sol.root
80+
y1 = (y0**ω + k1**ω).ω
81+
# Use the trapezoidal rule for adaptive step sizing.
82+
y_error = (0.5 * (k1**ω - k0**ω)).ω
5583
dense_info = dict(y0=y0, y1=y1)
56-
return y1, None, dense_info, None, nonlinear_sol.result
84+
solver_state = None
85+
return y1, y_error, dense_info, solver_state, nonlinear_sol.result
5786

5887
def func(
5988
self,

diffrax/solver/milstein.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,19 @@ def order(self, terms):
4545
def strong_order(self, terms):
4646
return 1 # assuming commutative noise
4747

48+
def init(
49+
self,
50+
terms: Tuple[ODETerm, AbstractTerm],
51+
t0: Scalar,
52+
t1: Scalar,
53+
y0: PyTree,
54+
args: PyTree,
55+
) -> _SolverState:
56+
return None
57+
4858
def step(
4959
self,
50-
terms: Tuple[AbstractTerm, AbstractTerm],
60+
terms: Tuple[ODETerm, AbstractTerm],
5161
t0: Scalar,
5262
t1: Scalar,
5363
y0: PyTree,
@@ -103,9 +113,19 @@ def order(self, terms):
103113
def strong_order(self, terms):
104114
return 1 # assuming commutative noise
105115

116+
def init(
117+
self,
118+
terms: Tuple[ODETerm, AbstractTerm],
119+
t0: Scalar,
120+
t1: Scalar,
121+
y0: PyTree,
122+
args: PyTree,
123+
) -> _SolverState:
124+
return None
125+
106126
def step(
107127
self,
108-
terms: Tuple[AbstractTerm, AbstractTerm],
128+
terms: Tuple[ODETerm, AbstractTerm],
109129
t0: Scalar,
110130
t1: Scalar,
111131
y0: PyTree,

diffrax/solver/semi_implicit_euler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ class SemiImplicitEuler(AbstractSolver):
2525
def order(self, terms):
2626
return 1
2727

28+
def init(
29+
self,
30+
terms: Tuple[AbstractTerm, AbstractTerm],
31+
t0: Scalar,
32+
t1: Scalar,
33+
y0: PyTree,
34+
args: PyTree,
35+
) -> _SolverState:
36+
return None
37+
2838
def step(
2939
self,
3040
terms: Tuple[AbstractTerm, AbstractTerm],
@@ -35,7 +45,7 @@ def step(
3545
solver_state: _SolverState,
3646
made_jump: Bool,
3747
) -> Tuple[Tuple[PyTree, PyTree], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
38-
del made_jump
48+
del solver_state, made_jump
3949

4050
term_1, term_2 = terms
4151
y0_1, y0_2 = y0

test/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def clear_caches():
3939
if hasattr(obj, "cache_clear"):
4040
try:
4141
print(f"Clearing {obj}")
42-
obj.cache_clear()
42+
if "Weakref" not in type(obj).__name__:
43+
obj.cache_clear()
4344
except Exception:
4445
pass
4546
gc.collect()

test/test_local_interpolation.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ def test_local_linear_interpolation():
1212
t1_ = 2.9
1313
for y0 in (2.1, jnp.array(2.1), jnp.array([2.1, 3.1])):
1414
for y1 in (2.2, jnp.array(2.2), jnp.array([2.2, 3.2])):
15-
interp = diffrax.local_interpolation.LocalLinearInterpolation(
16-
t0=t0, t1=t1, y0=y0, y1=y1
17-
)
15+
interp = diffrax.LocalLinearInterpolation(t0=t0, t1=t1, y0=y0, y1=y1)
1816

1917
# evaluate position
2018
pred = interp.evaluate(t0_)
@@ -36,12 +34,27 @@ def test_local_linear_interpolation():
3634
assert shaped_allclose(pred, jnp.zeros_like(pred))
3735

3836
# evaluate over zero-length interval. Note t1=t0.
39-
interp = diffrax.local_interpolation.LocalLinearInterpolation(
40-
t0=t0, t1=t0, y0=y0, y1=y1
41-
)
37+
interp = diffrax.LocalLinearInterpolation(t0=t0, t1=t0, y0=y0, y1=y1)
4238
pred = interp.evaluate(t0)
4339
true, _ = jnp.broadcast_arrays(y0, y1)
4440
assert shaped_allclose(pred, true)
4541

4642
_, pred = jax.jvp(interp.evaluate, (t0,), (jnp.ones_like(t0),))
4743
assert shaped_allclose(pred, jnp.zeros_like(pred))
44+
45+
46+
def test_third_order_hermite():
47+
t0 = 2.0
48+
t1 = 3.9
49+
50+
def y(t):
51+
return 0.4 + 0.7 * t - 1.1 * t**2 + 0.4 * t**3
52+
53+
y0, f0 = jax.jvp(y, (t0,), (1.0,))
54+
y1, f1 = jax.jvp(y, (t1,), (1.0,))
55+
k0 = f0 * (t1 - t0)
56+
k1 = f1 * (t1 - t0)
57+
interp = diffrax.ThirdOrderHermitePolynomialInterpolation(
58+
t0=t0, t1=t1, y0=y0, y1=y1, k0=k0, k1=k1
59+
)
60+
assert shaped_allclose(interp.evaluate(2.6), y(2.6))

test/test_solver.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,29 @@ def test_half_solver():
1717
def test_instance_check():
1818
assert isinstance(diffrax.HalfSolver(diffrax.Euler()), diffrax.Euler)
1919
assert not isinstance(diffrax.HalfSolver(diffrax.Euler()), diffrax.Heun)
20+
21+
22+
def test_implicit_euler_adaptive():
23+
term = diffrax.ODETerm(lambda t, y, args: -10 * y**3)
24+
solver1 = diffrax.ImplicitEuler(
25+
nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-5, atol=1e-5)
26+
)
27+
solver2 = diffrax.ImplicitEuler()
28+
t0 = 0
29+
t1 = 1
30+
dt0 = 1
31+
y0 = 1.0
32+
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
33+
out1 = diffrax.diffeqsolve(term, solver1, t0, t1, dt0, y0, throw=False)
34+
out2 = diffrax.diffeqsolve(
35+
term,
36+
solver2,
37+
t0,
38+
t1,
39+
dt0,
40+
y0,
41+
stepsize_controller=stepsize_controller,
42+
throw=False,
43+
)
44+
assert out1.result == diffrax.RESULTS.implicit_nonconvergence
45+
assert out2.result == diffrax.RESULTS.successful

0 commit comments

Comments
 (0)