33from equinox .internal import ω
44
55from ..custom_types import Bool , DenseInfo , PyTree , Scalar
6+ from ..heuristics import is_sde
67from ..local_interpolation import LocalLinearInterpolation
78from ..solution import RESULTS
89from ..term import AbstractTerm
910from .base import AbstractImplicitSolver
1011
1112
12- _ErrorEstimate = None
1313_SolverState = None
1414
1515
@@ -22,15 +22,36 @@ def _implicit_relation(z1, nonlinear_solve_args):
2222class ImplicitEuler (AbstractImplicitSolver ):
2323 r"""Implicit Euler method.
2424
25- A-B-L stable 1st order SDIRK method. Does not support adaptive step sizing.
25+ A-B-L stable 1st order SDIRK method. Has an embedded 2nd order method for adaptive
26+ step sizing.
2627 """
2728
2829 term_structure = AbstractTerm
30+ # We actually have enough information to use 3rd order Hermite interpolation.
31+ #
32+ # We don't use it as this seems to be quite a bad choice for low-order solvers: it
33+ # produces very oscillatory interpolations.
2934 interpolation_cls = LocalLinearInterpolation
3035
3136 def order (self , terms ):
3237 return 1
3338
39+ def error_order (self , terms ):
40+ if is_sde (terms ):
41+ return None
42+ else :
43+ return 2
44+
45+ def init (
46+ self ,
47+ terms : AbstractTerm ,
48+ t0 : Scalar ,
49+ t1 : Scalar ,
50+ y0 : PyTree ,
51+ args : PyTree ,
52+ ) -> _SolverState :
53+ return None
54+
3455 def step (
3556 self ,
3657 terms : AbstractTerm ,
@@ -40,20 +61,28 @@ def step(
4061 args : PyTree ,
4162 solver_state : _SolverState ,
4263 made_jump : Bool ,
43- ) -> Tuple [PyTree , _ErrorEstimate , DenseInfo , _SolverState , RESULTS ]:
44- del solver_state , made_jump
64+ ) -> Tuple [PyTree , PyTree , DenseInfo , _SolverState , RESULTS ]:
65+ del made_jump
4566 control = terms .contr (t0 , t1 )
46- pred = terms .vf_prod (t0 , y0 , args , control )
67+ # Could use FSAL here but that would mean we'd need to switch to working with
68+ # `f0 = terms.vf(t0, y0, args)`, and that gets quite hairy quite quickly.
69+ # (C.f. `AbstractRungeKutta.step`.)
70+ # If we wanted FSAL then really the correct thing to do would just be to
71+ # write out a `ButcherTableau` and use `AbstractSDIRK`.
72+ k0 = terms .vf_prod (t0 , y0 , args , control )
4773 jac = self .nonlinear_solver .jac (
48- _implicit_relation , pred , (terms .vf_prod , t1 , y0 , args , control )
74+ _implicit_relation , k0 , (terms .vf_prod , t1 , y0 , args , control )
4975 )
5076 nonlinear_sol = self .nonlinear_solver (
51- _implicit_relation , pred , (terms .vf_prod , t1 , y0 , args , control ), jac
77+ _implicit_relation , k0 , (terms .vf_prod , t1 , y0 , args , control ), jac
5278 )
53- z1 = nonlinear_sol .root
54- y1 = (y0 ** ω + z1 ** ω ).ω
79+ k1 = nonlinear_sol .root
80+ y1 = (y0 ** ω + k1 ** ω ).ω
81+ # Use the trapezoidal rule for adaptive step sizing.
82+ y_error = (0.5 * (k1 ** ω - k0 ** ω )).ω
5583 dense_info = dict (y0 = y0 , y1 = y1 )
56- return y1 , None , dense_info , None , nonlinear_sol .result
84+ solver_state = None
85+ return y1 , y_error , dense_info , solver_state , nonlinear_sol .result
5786
5887 def func (
5988 self ,
0 commit comments