Skip to content

Commit 27702dd

Browse files
Merge pull request #257 from patrick-kidger/better-rk
Better rk
2 parents f101e75 + 77b1a60 commit 27702dd

51 files changed

Lines changed: 2186 additions & 849 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[flake8]
22
max-line-length = 88
3-
ignore = W291,W503,W504,E121,E123,E126,E203,E402,E701,E702,E731
3+
ignore = W291,W503,W504,E121,E123,E126,E203,E402,E701,E702,E731,F722
44
per-file-ignores = __init__.py: F401

.github/workflows/build_docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jobs:
99
build:
1010
strategy:
1111
matrix:
12-
python-version: [ 3.8 ]
12+
python-version: [ 3.11 ]
1313
os: [ ubuntu-latest ]
1414
runs-on: ${{ matrix.os }}
1515
steps:

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
- name: Release
1313
uses: patrick-kidger/action_update_python_project@v1
1414
with:
15-
python-version: "3.8"
15+
python-version: "3.11"
1616
test-script: |
1717
python -m pip install pytest psutil jax jaxlib equinox scipy optax
1818
cp -r ${{ github.workspace }}/test ./test

.github/workflows/run_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jobs:
77
run-tests:
88
strategy:
99
matrix:
10-
python-version: [ 3.8, 3.9 ]
10+
python-version: [ 3.9, 3.11 ]
1111
os: [ ubuntu-latest ]
1212
fail-fast: false
1313
runs-on: ${{ matrix.os }}

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty
2121
pip install diffrax
2222
```
2323

24-
Requires Python 3.8+, JAX 0.4.3+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.0+.
24+
Requires Python 3.9+, JAX 0.4.4+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.4+.
2525

2626
## Documentation
2727

benchmarks/scan_stages.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

benchmarks/scan_stages_cnf.py

Lines changed: 0 additions & 96 deletions
This file was deleted.

diffrax/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from .misc import adjoint_rms_seminorm
3232
from .nonlinear_solver import (
3333
AbstractNonlinearSolver,
34+
AffineNonlinearSolver,
3435
NewtonNonlinearSolver,
3536
NonlinearSolution,
3637
)
@@ -60,14 +61,19 @@
6061
Heun,
6162
ImplicitEuler,
6263
ItoMilstein,
64+
KenCarp3,
65+
KenCarp4,
66+
KenCarp5,
6367
Kvaerno3,
6468
Kvaerno4,
6569
Kvaerno5,
6670
LeapfrogMidpoint,
6771
Midpoint,
72+
MultiButcherTableau,
6873
Ralston,
6974
ReversibleHeun,
7075
SemiImplicitEuler,
76+
Sil3,
7177
StratonovichMilstein,
7278
Tsit5,
7379
)

diffrax/adjoint.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .ad import implicit_jvp
1515
from .heuristics import is_sde, is_unsafe_sde
1616
from .saveat import save_y, SaveAt, SubSaveAt
17-
from .solver import AbstractItoSolver, AbstractStratonovichSolver
17+
from .solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver
1818
from .term import AbstractTerm, AdjointTerm
1919

2020

@@ -332,6 +332,7 @@ class DirectAdjoint(AbstractAdjoint):
332332
def loop(
333333
self,
334334
*,
335+
solver,
335336
max_steps,
336337
terms,
337338
throw,
@@ -362,10 +363,15 @@ def loop(
362363
else:
363364
kind = "bounded"
364365
msg = None
366+
# Support forward-mode autodiff.
367+
# TODO: remove this hack once we can JVP through custom_vjps.
368+
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
369+
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded")
365370
inner_while_loop = ft.partial(_inner_loop, kind=kind)
366371
outer_while_loop = ft.partial(_outer_loop, kind=kind)
367372
final_state = self._loop(
368373
**kwargs,
374+
solver=solver,
369375
max_steps=max_steps,
370376
terms=terms,
371377
inner_while_loop=inner_while_loop,
@@ -535,6 +541,8 @@ def _loop_backsolve_bwd(
535541
zeros_like_diff_args = jtu.tree_map(jnp.zeros_like, diff_args)
536542
zeros_like_diff_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
537543
del diff_args, diff_terms
544+
# TODO: have this look inside MultiTerms? Need to think about the math. i.e.:
545+
# is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm)
538546
adjoint_terms = jtu.tree_map(
539547
AdjointTerm, terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
540548
)
@@ -762,6 +770,11 @@ def loop(
762770
"`BacksolveAdjoint` will only produce the correct solution for "
763771
"Stratonovich SDEs."
764772
)
773+
if jtu.tree_structure(solver.term_structure) != jtu.tree_structure(0):
774+
raise NotImplementedError(
775+
"`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
776+
"a single term."
777+
)
765778

766779
y = init_state.y
767780
init_state = eqx.tree_at(lambda s: s.y, init_state, object())

diffrax/custom_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import inspect
22
import typing
3-
from typing import Dict, Generic, Tuple, TypeVar, Union
3+
from typing import Any, Dict, Generic, Tuple, TypeVar, Union
44

5+
import equinox.internal as eqxi
56
import jax.tree_util as jtu
67

78

@@ -129,3 +130,4 @@ def __class_getitem__(cls, item):
129130

130131
DenseInfo = Dict[str, PyTree[Array]]
131132
DenseInfos = Dict[str, PyTree[Array["times", ...]]] # noqa: F821
133+
sentinel: Any = eqxi.doc_repr(object(), "sentinel")

0 commit comments

Comments
 (0)