|
14 | 14 | from .ad import implicit_jvp |
15 | 15 | from .heuristics import is_sde, is_unsafe_sde |
16 | 16 | from .saveat import save_y, SaveAt, SubSaveAt |
17 | | -from .solver import AbstractItoSolver, AbstractStratonovichSolver |
| 17 | +from .solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver |
18 | 18 | from .term import AbstractTerm, AdjointTerm |
19 | 19 |
|
20 | 20 |
|
@@ -332,6 +332,7 @@ class DirectAdjoint(AbstractAdjoint): |
332 | 332 | def loop( |
333 | 333 | self, |
334 | 334 | *, |
| 335 | + solver, |
335 | 336 | max_steps, |
336 | 337 | terms, |
337 | 338 | throw, |
@@ -362,10 +363,15 @@ def loop( |
362 | 363 | else: |
363 | 364 | kind = "bounded" |
364 | 365 | msg = None |
| 366 | + # Support forward-mode autodiff. |
| 367 | + # TODO: remove this hack once we can JVP through custom_vjps. |
| 368 | + if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None: |
| 369 | + solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded") |
365 | 370 | inner_while_loop = ft.partial(_inner_loop, kind=kind) |
366 | 371 | outer_while_loop = ft.partial(_outer_loop, kind=kind) |
367 | 372 | final_state = self._loop( |
368 | 373 | **kwargs, |
| 374 | + solver=solver, |
369 | 375 | max_steps=max_steps, |
370 | 376 | terms=terms, |
371 | 377 | inner_while_loop=inner_while_loop, |
@@ -535,6 +541,8 @@ def _loop_backsolve_bwd( |
535 | 541 | zeros_like_diff_args = jtu.tree_map(jnp.zeros_like, diff_args) |
536 | 542 | zeros_like_diff_terms = jtu.tree_map(jnp.zeros_like, diff_terms) |
537 | 543 | del diff_args, diff_terms |
| 544 | + # TODO: have this look inside MultiTerms? Need to think about the math. i.e.: |
| 545 | + # is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm) |
538 | 546 | adjoint_terms = jtu.tree_map( |
539 | 547 | AdjointTerm, terms, is_leaf=lambda x: isinstance(x, AbstractTerm) |
540 | 548 | ) |
@@ -762,6 +770,11 @@ def loop( |
762 | 770 | "`BacksolveAdjoint` will only produce the correct solution for " |
763 | 771 | "Stratonovich SDEs." |
764 | 772 | ) |
| 773 | + if jtu.tree_structure(solver.term_structure) != jtu.tree_structure(0): |
| 774 | + raise NotImplementedError( |
| 775 | + "`diffrax.BacksolveAdjoint` is only compatible with solvers that take " |
| 776 | + "a single term." |
| 777 | + ) |
765 | 778 |
|
766 | 779 | y = init_state.y |
767 | 780 | init_state = eqx.tree_at(lambda s: s.y, init_state, object()) |
|
0 commit comments