Skip to content

Commit 56d8c0b

Browse files
committed
In progress commit on branch delay.
1 parent 3020bb5 commit 56d8c0b

24 files changed

Lines changed: 12388 additions & 107 deletions

diffrax/__init__.py

Lines changed: 93 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,114 @@
11
import importlib.metadata
22

3-
from ._adjoint import (
4-
AbstractAdjoint as AbstractAdjoint,
5-
BacksolveAdjoint as BacksolveAdjoint,
6-
DirectAdjoint as DirectAdjoint,
7-
ImplicitAdjoint as ImplicitAdjoint,
8-
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
9-
)
10-
from ._autocitation import citation as citation, citation_rules as citation_rules
11-
from ._brownian import (
12-
AbstractBrownianPath as AbstractBrownianPath,
13-
UnsafeBrownianPath as UnsafeBrownianPath,
14-
VirtualBrownianTree as VirtualBrownianTree,
15-
)
3+
from ._adjoint import AbstractAdjoint as AbstractAdjoint
4+
from ._adjoint import BacksolveAdjoint as BacksolveAdjoint
5+
from ._adjoint import DirectAdjoint as DirectAdjoint
6+
from ._adjoint import ImplicitAdjoint as ImplicitAdjoint
7+
from ._adjoint import RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint
8+
from ._autocitation import citation as citation
9+
from ._autocitation import citation_rules as citation_rules
10+
from ._brownian import AbstractBrownianPath
11+
from ._brownian import AbstractBrownianPath as AbstractBrownianPath
12+
from ._brownian import UnsafeBrownianPath
13+
from ._brownian import UnsafeBrownianPath as UnsafeBrownianPath
14+
from ._brownian import VirtualBrownianTree
15+
from ._brownian import VirtualBrownianTree as VirtualBrownianTree
1616
from ._custom_types import LevyVal as LevyVal
17-
from ._event import (
18-
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
19-
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
20-
SteadyStateEvent as SteadyStateEvent,
21-
)
17+
from ._delays import Delays as Delays
18+
from ._delays import bind_history as bind_history
19+
from ._delays import history_extrapolation_implicit as history_extrapolation_implicit
20+
from ._delays import maybe_find_discontinuity as maybe_find_discontinuity
21+
from ._event import AbstractDiscreteTerminatingEvent
22+
from ._event import AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent
23+
from ._event import DiscreteTerminatingEvent
24+
from ._event import DiscreteTerminatingEvent as DiscreteTerminatingEvent
25+
from ._event import SteadyStateEvent
26+
from ._event import SteadyStateEvent as SteadyStateEvent
2227
from ._global_interpolation import (
2328
AbstractGlobalInterpolation as AbstractGlobalInterpolation,
29+
)
30+
from ._global_interpolation import CubicInterpolation as CubicInterpolation
31+
from ._global_interpolation import DenseInterpolation as DenseInterpolation
32+
from ._global_interpolation import LinearInterpolation as LinearInterpolation
33+
from ._global_interpolation import (
2434
backward_hermite_coefficients as backward_hermite_coefficients,
25-
CubicInterpolation as CubicInterpolation,
26-
DenseInterpolation as DenseInterpolation,
27-
linear_interpolation as linear_interpolation,
28-
LinearInterpolation as LinearInterpolation,
35+
)
36+
from ._global_interpolation import linear_interpolation as linear_interpolation
37+
from ._global_interpolation import (
2938
rectilinear_interpolation as rectilinear_interpolation,
3039
)
3140
from ._integrate import diffeqsolve as diffeqsolve
3241
from ._local_interpolation import (
3342
AbstractLocalInterpolation as AbstractLocalInterpolation,
43+
)
44+
from ._local_interpolation import (
3445
FourthOrderPolynomialInterpolation as FourthOrderPolynomialInterpolation,
35-
LocalLinearInterpolation as LocalLinearInterpolation,
36-
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation, # noqa: E501
3746
)
47+
from ._local_interpolation import LocalLinearInterpolation as LocalLinearInterpolation
48+
from ._local_interpolation import (
49+
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation,
50+
) # noqa: E501
3851
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
3952
from ._path import AbstractPath as AbstractPath
40-
from ._root_finder import (
41-
VeryChord as VeryChord,
42-
with_stepsize_controller_tols as with_stepsize_controller_tols,
43-
)
44-
from ._saveat import SaveAt as SaveAt, SubSaveAt as SubSaveAt
45-
from ._solution import (
46-
is_event as is_event,
47-
is_okay as is_okay,
48-
is_successful as is_successful,
49-
RESULTS as RESULTS,
50-
Solution as Solution,
51-
)
52-
from ._solver import (
53-
AbstractAdaptiveSolver as AbstractAdaptiveSolver,
54-
AbstractDIRK as AbstractDIRK,
55-
AbstractERK as AbstractERK,
56-
AbstractESDIRK as AbstractESDIRK,
57-
AbstractImplicitSolver as AbstractImplicitSolver,
58-
AbstractItoSolver as AbstractItoSolver,
59-
AbstractRungeKutta as AbstractRungeKutta,
60-
AbstractSDIRK as AbstractSDIRK,
61-
AbstractSolver as AbstractSolver,
62-
AbstractStratonovichSolver as AbstractStratonovichSolver,
63-
AbstractWrappedSolver as AbstractWrappedSolver,
64-
Bosh3 as Bosh3,
65-
ButcherTableau as ButcherTableau,
66-
CalculateJacobian as CalculateJacobian,
67-
Dopri5 as Dopri5,
68-
Dopri8 as Dopri8,
69-
Euler as Euler,
70-
EulerHeun as EulerHeun,
71-
HalfSolver as HalfSolver,
72-
Heun as Heun,
73-
ImplicitEuler as ImplicitEuler,
74-
ItoMilstein as ItoMilstein,
75-
KenCarp3 as KenCarp3,
76-
KenCarp4 as KenCarp4,
77-
KenCarp5 as KenCarp5,
78-
Kvaerno3 as Kvaerno3,
79-
Kvaerno4 as Kvaerno4,
80-
Kvaerno5 as Kvaerno5,
81-
LeapfrogMidpoint as LeapfrogMidpoint,
82-
Midpoint as Midpoint,
83-
MultiButcherTableau as MultiButcherTableau,
84-
Ralston as Ralston,
85-
ReversibleHeun as ReversibleHeun,
86-
SemiImplicitEuler as SemiImplicitEuler,
87-
Sil3 as Sil3,
88-
StratonovichMilstein as StratonovichMilstein,
89-
Tsit5 as Tsit5,
90-
)
53+
from ._root_finder import VeryChord as VeryChord
54+
from ._root_finder import with_stepsize_controller_tols as with_stepsize_controller_tols
55+
from ._saveat import SaveAt as SaveAt
56+
from ._saveat import SubSaveAt as SubSaveAt
57+
from ._solution import RESULTS as RESULTS
58+
from ._solution import Solution as Solution
59+
from ._solution import is_event as is_event
60+
from ._solution import is_okay as is_okay
61+
from ._solution import is_successful as is_successful
62+
from ._solver import AbstractAdaptiveSolver as AbstractAdaptiveSolver
63+
from ._solver import AbstractDIRK as AbstractDIRK
64+
from ._solver import AbstractERK as AbstractERK
65+
from ._solver import AbstractESDIRK as AbstractESDIRK
66+
from ._solver import AbstractImplicitSolver as AbstractImplicitSolver
67+
from ._solver import AbstractItoSolver as AbstractItoSolver
68+
from ._solver import AbstractRungeKutta as AbstractRungeKutta
69+
from ._solver import AbstractSDIRK as AbstractSDIRK
70+
from ._solver import AbstractSolver as AbstractSolver
71+
from ._solver import AbstractStratonovichSolver as AbstractStratonovichSolver
72+
from ._solver import AbstractWrappedSolver as AbstractWrappedSolver
73+
from ._solver import Bosh3 as Bosh3
74+
from ._solver import ButcherTableau as ButcherTableau
75+
from ._solver import CalculateJacobian as CalculateJacobian
76+
from ._solver import Dopri5 as Dopri5
77+
from ._solver import Dopri8 as Dopri8
78+
from ._solver import Euler as Euler
79+
from ._solver import EulerHeun as EulerHeun
80+
from ._solver import HalfSolver as HalfSolver
81+
from ._solver import Heun as Heun
82+
from ._solver import ImplicitEuler as ImplicitEuler
83+
from ._solver import ItoMilstein as ItoMilstein
84+
from ._solver import KenCarp3 as KenCarp3
85+
from ._solver import KenCarp4 as KenCarp4
86+
from ._solver import KenCarp5 as KenCarp5
87+
from ._solver import Kvaerno3 as Kvaerno3
88+
from ._solver import Kvaerno4 as Kvaerno4
89+
from ._solver import Kvaerno5 as Kvaerno5
90+
from ._solver import LeapfrogMidpoint as LeapfrogMidpoint
91+
from ._solver import Midpoint as Midpoint
92+
from ._solver import MultiButcherTableau as MultiButcherTableau
93+
from ._solver import Ralston as Ralston
94+
from ._solver import ReversibleHeun as ReversibleHeun
95+
from ._solver import SemiImplicitEuler as SemiImplicitEuler
96+
from ._solver import Sil3 as Sil3
97+
from ._solver import StratonovichMilstein as StratonovichMilstein
98+
from ._solver import Tsit5 as Tsit5
9199
from ._step_size_controller import (
92100
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
93-
AbstractStepSizeController as AbstractStepSizeController,
94-
ConstantStepSize as ConstantStepSize,
95-
PIDController as PIDController,
96-
StepTo as StepTo,
97101
)
98-
from ._term import (
99-
AbstractTerm as AbstractTerm,
100-
ControlTerm as ControlTerm,
101-
MultiTerm as MultiTerm,
102-
ODETerm as ODETerm,
103-
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,
102+
from ._step_size_controller import (
103+
AbstractStepSizeController as AbstractStepSizeController,
104104
)
105-
105+
from ._step_size_controller import ConstantStepSize as ConstantStepSize
106+
from ._step_size_controller import PIDController as PIDController
107+
from ._step_size_controller import StepTo as StepTo
108+
from ._term import AbstractTerm as AbstractTerm
109+
from ._term import ControlTerm as ControlTerm
110+
from ._term import MultiTerm as MultiTerm
111+
from ._term import ODETerm as ODETerm
112+
from ._term import WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm
106113

107114
__version__ = importlib.metadata.version("diffrax")

diffrax/_adjoint.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def loop(
119119
solver,
120120
stepsize_controller,
121121
discrete_terminating_event,
122+
delays,
122123
saveat,
123124
t0,
124125
t1,
@@ -128,6 +129,7 @@ def loop(
128129
init_state,
129130
passed_solver_state,
130131
passed_controller_state,
132+
y0_history,
131133
) -> Any:
132134
"""Runs the main solve loop. Subclasses can override this to provide custom
133135
backpropagation behaviour; see for example the implementation of
@@ -552,13 +554,15 @@ def _loop_backsolve_bwd(
552554
solver,
553555
stepsize_controller,
554556
discrete_terminating_event,
557+
delays,
555558
saveat,
556559
t0,
557560
t1,
558561
dt0,
559562
max_steps,
560563
throw,
561564
init_state,
565+
y0_history,
562566
):
563567
assert discrete_terminating_event is None
564568

@@ -596,6 +600,8 @@ def _loop_backsolve_bwd(
596600
adjoint=self,
597601
solver=solver,
598602
stepsize_controller=stepsize_controller,
603+
discrete_terminating_event=discrete_terminating_event,
604+
delays=delays,
599605
terms=adjoint_terms,
600606
dt0=None if dt0 is None else -dt0,
601607
max_steps=max_steps,
@@ -775,6 +781,7 @@ def loop(
775781
passed_solver_state,
776782
passed_controller_state,
777783
discrete_terminating_event,
784+
delays,
778785
**kwargs,
779786
):
780787
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
@@ -820,6 +827,10 @@ def loop(
820827
raise NotImplementedError(
821828
"`diffrax.BacksolveAdjoint` is not compatible with events."
822829
)
830+
if delays is not None:
831+
raise NotImplementedError(
832+
"Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
833+
)
823834

824835
y = init_state.y
825836
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
@@ -834,6 +845,7 @@ def loop(
834845
init_state=init_state,
835846
solver=solver,
836847
discrete_terminating_event=discrete_terminating_event,
848+
delays=delays,
837849
**kwargs,
838850
)
839851
final_state = _only_transpose_ys(final_state)

0 commit comments

Comments
 (0)