Skip to content

Commit 6e34acd

Browse files
committed
Merge branch 'dev'
2 parents 633afbd + 55d3c0f commit 6e34acd

50 files changed

Lines changed: 4007 additions & 389 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
3-
rev: v0.1.7
3+
rev: v0.2.2
44
hooks:
55
- id: ruff # linter
66
types_or: [ python, pyi, jupyter ]
77
args: [ --fix ]
88
- id: ruff-format # formatter
99
types_or: [ python, pyi, jupyter ]
1010
- repo: https://github.com/RobertCraigie/pyright-python
11-
rev: v1.1.316
11+
rev: v1.1.350
1212
hooks:
1313
- id: pyright
1414
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions]

benchmarks/small_neural_ode.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
class FuncTorch(torch.nn.Module):
2121
def __init__(self):
2222
super().__init__()
23-
self.func = torch.jit.script( # pyright: ignore
23+
self.func = torch.jit.script(
2424
torch.nn.Sequential(
2525
torch.nn.Linear(4, 32),
2626
torch.nn.Softplus(),
@@ -30,7 +30,7 @@ def __init__(self):
3030
)
3131

3232
def forward(self, t, y):
33-
return self.func(y) # pyright: ignore
33+
return self.func(y)
3434

3535

3636
class FuncJax(eqx.Module):
@@ -177,10 +177,10 @@ def run(multiple, grad, batch_size=64, t1=100):
177177
with torch.no_grad():
178178
func_jax = neural_ode_diffrax.func.func
179179
func_torch = neural_ode_torch.func.func
180-
func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight))) # pyright: ignore
181-
func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias))) # pyright: ignore
182-
func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight))) # pyright: ignore
183-
func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias))) # pyright: ignore
180+
func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight)))
181+
func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias)))
182+
func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight)))
183+
func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias)))
184184

185185
y0_jax = jr.normal(jr.PRNGKey(1), (batch_size, 4))
186186
y0_torch = torch.tensor(np.asarray(y0_jax))

diffrax/__init__.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@
1313
UnsafeBrownianPath as UnsafeBrownianPath,
1414
VirtualBrownianTree as VirtualBrownianTree,
1515
)
16-
from ._custom_types import LevyVal as LevyVal
16+
from ._custom_types import (
17+
AbstractBrownianIncrement as AbstractBrownianIncrement,
18+
AbstractSpaceTimeLevyArea as AbstractSpaceTimeLevyArea,
19+
AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea,
20+
BrownianIncrement as BrownianIncrement,
21+
SpaceTimeLevyArea as SpaceTimeLevyArea,
22+
SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea,
23+
)
1724
from ._event import (
1825
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
1926
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
@@ -37,6 +44,12 @@
3744
)
3845
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
3946
from ._path import AbstractPath as AbstractPath
47+
from ._progress_meter import (
48+
AbstractProgressMeter as AbstractProgressMeter,
49+
NoProgressMeter as NoProgressMeter,
50+
TextProgressMeter as TextProgressMeter,
51+
TqdmProgressMeter as TqdmProgressMeter,
52+
)
4053
from ._root_finder import (
4154
VeryChord as VeryChord,
4255
with_stepsize_controller_tols as with_stepsize_controller_tols,
@@ -59,6 +72,7 @@
5972
AbstractRungeKutta as AbstractRungeKutta,
6073
AbstractSDIRK as AbstractSDIRK,
6174
AbstractSolver as AbstractSolver,
75+
AbstractSRK as AbstractSRK,
6276
AbstractStratonovichSolver as AbstractStratonovichSolver,
6377
AbstractWrappedSolver as AbstractWrappedSolver,
6478
Bosh3 as Bosh3,
@@ -68,6 +82,7 @@
6882
Dopri8 as Dopri8,
6983
Euler as Euler,
7084
EulerHeun as EulerHeun,
85+
GeneralShARK as GeneralShARK,
7186
HalfSolver as HalfSolver,
7287
Heun as Heun,
7388
ImplicitEuler as ImplicitEuler,
@@ -84,8 +99,14 @@
8499
MultiButcherTableau as MultiButcherTableau,
85100
Ralston as Ralston,
86101
ReversibleHeun as ReversibleHeun,
102+
SEA as SEA,
87103
SemiImplicitEuler as SemiImplicitEuler,
104+
ShARK as ShARK,
88105
Sil3 as Sil3,
106+
SlowRK as SlowRK,
107+
SPaRK as SPaRK,
108+
SRA1 as SRA1,
109+
StochasticButcherTableau as StochasticButcherTableau,
89110
StratonovichMilstein as StratonovichMilstein,
90111
Tsit5 as Tsit5,
91112
)

diffrax/_adjoint.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import abc
22
import functools as ft
33
import warnings
4-
from collections.abc import Iterable
5-
from typing import Any, Optional, Union
4+
from collections.abc import Callable, Iterable
5+
from typing import Any, cast, Optional, Union
66

77
import equinox as eqx
88
import equinox.internal as eqxi
@@ -20,6 +20,9 @@
2020
from ._term import AbstractTerm, AdjointTerm
2121

2222

23+
ω = cast(Callable, ω)
24+
25+
2326
def _is_none(x):
2427
return x is None
2528

@@ -128,6 +131,7 @@ def loop(
128131
init_state,
129132
passed_solver_state,
130133
passed_controller_state,
134+
progress_meter,
131135
) -> Any:
132136
"""Runs the main solve loop. Subclasses can override this to provide custom
133137
backpropagation behaviour; see for example the implementation of
@@ -559,6 +563,7 @@ def _loop_backsolve_bwd(
559563
max_steps,
560564
throw,
561565
init_state,
566+
progress_meter,
562567
):
563568
assert discrete_terminating_event is None
564569

@@ -567,7 +572,7 @@ def _loop_backsolve_bwd(
567572
# using them later.
568573
#
569574

570-
del perturbed, init_state, t1
575+
del perturbed, init_state, t1, progress_meter
571576
ts, ys = residuals
572577
del residuals
573578
grad_final_state, _ = grad_final_state__aux_stats

diffrax/_autocitation.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,23 @@
1616
from ._saveat import SubSaveAt
1717
from ._solver import (
1818
AbstractImplicitSolver,
19+
AbstractItoSolver,
20+
AbstractSRK,
21+
AbstractStratonovichSolver,
1922
Dopri5,
2023
Dopri8,
24+
GeneralShARK,
2125
Kvaerno3,
2226
Kvaerno4,
2327
Kvaerno5,
2428
LeapfrogMidpoint,
2529
ReversibleHeun,
30+
SEA,
2631
SemiImplicitEuler,
32+
ShARK,
33+
SlowRK,
34+
SPaRK,
35+
SRA1,
2736
Tsit5,
2837
)
2938
from ._step_size_controller import PIDController
@@ -374,7 +383,15 @@ def _backsolve_rms_norm(adjoint):
374383

375384
@citation_rules.append
376385
def _explicit_solver(solver, terms=None):
377-
if not isinstance(solver, AbstractImplicitSolver) and not is_sde(terms):
386+
if not isinstance(
387+
solver,
388+
(
389+
AbstractImplicitSolver,
390+
AbstractSRK,
391+
AbstractItoSolver,
392+
AbstractStratonovichSolver,
393+
),
394+
) and not is_sde(terms):
378395
return r"""
379396
% You are using an explicit solver, and may wish to cite the standard textbook:
380397
@book{hairer2008solving-i,
@@ -467,6 +484,12 @@ def _solvers(solver, saveat=None):
467484
Kvaerno5,
468485
ReversibleHeun,
469486
LeapfrogMidpoint,
487+
ShARK,
488+
SRA1,
489+
SlowRK,
490+
GeneralShARK,
491+
SPaRK,
492+
SEA,
470493
):
471494
return (
472495
r"""

diffrax/_brownian/base.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
import abc
2-
from typing import Optional, Union
2+
from typing import Optional, TypeVar, Union
33

44
from equinox.internal import AbstractVar
55
from jaxtyping import Array, PyTree
66

7-
from .._custom_types import LevyArea, LevyVal, RealScalarLike
7+
from .._custom_types import (
8+
AbstractBrownianIncrement,
9+
BrownianIncrement,
10+
RealScalarLike,
11+
SpaceTimeLevyArea,
12+
)
813
from .._path import AbstractPath
914

1015

11-
class AbstractBrownianPath(AbstractPath):
16+
_Control = TypeVar("_Control", bound=Union[PyTree[Array], AbstractBrownianIncrement])
17+
18+
19+
class AbstractBrownianPath(AbstractPath[_Control]):
1220
"""Abstract base class for all Brownian paths."""
1321

14-
levy_area: AbstractVar[LevyArea]
22+
levy_area: AbstractVar[type[Union[BrownianIncrement, SpaceTimeLevyArea]]]
1523

1624
@abc.abstractmethod
1725
def evaluate(
@@ -20,7 +28,7 @@ def evaluate(
2028
t1: Optional[RealScalarLike] = None,
2129
left: bool = True,
2230
use_levy: bool = False,
23-
) -> Union[PyTree[Array], LevyVal]:
31+
) -> _Control:
2432
r"""Samples a Brownian increment $w(t_1) - w(t_0)$.
2533
2634
Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.

diffrax/_brownian/path.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
import lineax.internal as lxi
1111
from jaxtyping import Array, PRNGKeyArray, PyTree
1212

13-
from .._custom_types import levy_tree_transpose, LevyArea, LevyVal, RealScalarLike
13+
from .._custom_types import (
14+
AbstractBrownianIncrement,
15+
BrownianIncrement,
16+
levy_tree_transpose,
17+
RealScalarLike,
18+
SpaceTimeLevyArea,
19+
)
1420
from .._misc import (
1521
force_bitcast_convert_type,
1622
is_tuple_of_ints,
@@ -42,25 +48,23 @@ class UnsafeBrownianPath(AbstractBrownianPath):
4248
"""
4349

4450
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
45-
levy_area: LevyArea = eqx.field(static=True)
51+
levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]] = eqx.field(
52+
static=True
53+
)
4654
key: PRNGKeyArray
4755

4856
def __init__(
4957
self,
5058
shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
5159
key: PRNGKeyArray,
52-
levy_area: LevyArea = "",
60+
levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]],
5361
):
5462
self.shape = (
5563
jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
5664
if is_tuple_of_ints(shape)
5765
else shape
5866
)
5967
self.key = key
60-
if levy_area not in ["", "space-time"]:
61-
raise ValueError(
62-
f"levy_area must be one of '', 'space-time', but got {levy_area}."
63-
)
6468
self.levy_area = levy_area
6569

6670
if any(
@@ -84,7 +88,7 @@ def evaluate(
8488
t1: Optional[RealScalarLike] = None,
8589
left: bool = True,
8690
use_levy: bool = False,
87-
) -> Union[PyTree[Array], LevyVal]:
91+
) -> Union[PyTree[Array], AbstractBrownianIncrement]:
8892
del left
8993
if t1 is None:
9094
dtype = jnp.result_type(t0)
@@ -111,8 +115,8 @@ def evaluate(
111115
self.shape,
112116
)
113117
if use_levy:
114-
out = levy_tree_transpose(self.shape, self.levy_area, out)
115-
assert isinstance(out, LevyVal)
118+
out = levy_tree_transpose(self.shape, out)
119+
assert isinstance(out, (BrownianIncrement, SpaceTimeLevyArea))
116120
return out
117121

118122
@staticmethod
@@ -121,25 +125,26 @@ def _evaluate_leaf(
121125
t1: RealScalarLike,
122126
key,
123127
shape: jax.ShapeDtypeStruct,
124-
levy_area: str,
128+
levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]],
125129
use_levy: bool,
126130
):
127131
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
132+
w = jr.normal(key, shape.shape, shape.dtype) * w_std
133+
dt = jnp.asarray(t1 - t0, dtype=shape.dtype)
128134

129-
if levy_area == "space-time":
135+
if levy_area is SpaceTimeLevyArea:
130136
key, key_hh = jr.split(key, 2)
131137
hh_std = w_std / math.sqrt(12)
132138
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
133-
elif levy_area == "":
134-
hh = None
139+
levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh)
140+
elif levy_area is BrownianIncrement:
141+
levy_val = BrownianIncrement(dt=dt, W=w)
135142
else:
136143
assert False
137-
w = jr.normal(key, shape.shape, shape.dtype) * w_std
138144

139145
if use_levy:
140-
return LevyVal(dt=t1 - t0, W=w, H=hh, bar_H=None, K=None, bar_K=None)
141-
else:
142-
return w
146+
return levy_val
147+
return w
143148

144149

145150
UnsafeBrownianPath.__init__.__doc__ = """

0 commit comments

Comments
 (0)