11import abc
22import functools as ft
33import warnings
4- from collections .abc import Iterable
5- from typing import Any , Optional , Union
4+ from collections .abc import Callable , Iterable
5+ from typing import Any , cast , Optional , Union
66
77import equinox as eqx
88import equinox .internal as eqxi
2020from ._term import AbstractTerm , AdjointTerm
2121
2222
23+ ω = cast (Callable , ω )
24+
25+
2326def _is_none (x ):
2427 return x is None
2528
@@ -118,7 +121,7 @@ def loop(
118121 terms ,
119122 solver ,
120123 stepsize_controller ,
121- discrete_terminating_event ,
124+ event ,
122125 saveat ,
123126 t0 ,
124127 t1 ,
@@ -128,6 +131,7 @@ def loop(
128131 init_state ,
129132 passed_solver_state ,
130133 passed_controller_state ,
134+ progress_meter ,
131135 ) -> Any :
132136 """Runs the main solve loop. Subclasses can override this to provide custom
133137 backpropagation behaviour; see for example the implementation of
@@ -425,6 +429,14 @@ def _solve(inputs):
425429 )
426430
427431
432+ # Unwrap jaxtyping decorator during tests, so that these are global functions.
433+ # This is needed to ensure `optx.implicit_jvp` is happy.
434+ if _vf .__globals__ ["__name__" ].startswith ("jaxtyping" ):
435+ _vf = _vf .__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
436+ if _solve .__globals__ ["__name__" ].startswith ("jaxtyping" ):
437+ _solve = _solve .__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
438+
439+
428440def _frozenset (x : Union [object , Iterable [object ]]) -> frozenset [object ]:
429441 try :
430442 iter_x = iter (x ) # pyright: ignore
@@ -438,7 +450,8 @@ class ImplicitAdjoint(AbstractAdjoint):
438450 r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
439451
440452 This is used when solving towards a steady state, typically using
441- [`diffrax.SteadyStateEvent`][]. In this case, the output of the solver is $y(θ)$
453+ [`diffrax.Event`][] where the condition function is obtained by calling
454+ [`diffrax.steady_state_event`][]. In this case, the output of the solver is $y(θ)$
442455 for which $f(t, y(θ), θ) = 0$. (Where $θ$ corresponds to all parameters found
443456 through `terms` and `args`, but not `y0`.) Then we can skip backpropagating through
444457 the solver and instead directly compute
@@ -551,23 +564,24 @@ def _loop_backsolve_bwd(
551564 self ,
552565 solver ,
553566 stepsize_controller ,
554- discrete_terminating_event ,
567+ event ,
555568 saveat ,
556569 t0 ,
557570 t1 ,
558571 dt0 ,
559572 max_steps ,
560573 throw ,
561574 init_state ,
575+ progress_meter ,
562576):
563- assert discrete_terminating_event is None
577+ assert event is None
564578
565579 #
566580 # Unpack our various arguments. Delete a lot of things just to make sure we're not
567581 # using them later.
568582 #
569583
570- del perturbed , init_state , t1
584+ del perturbed , init_state , t1 , progress_meter
571585 ts , ys = residuals
572586 del residuals
573587 grad_final_state , _ = grad_final_state__aux_stats
@@ -774,7 +788,7 @@ def loop(
774788 init_state ,
775789 passed_solver_state ,
776790 passed_controller_state ,
777- discrete_terminating_event ,
791+ event ,
778792 ** kwargs ,
779793 ):
780794 if jtu .tree_structure (saveat .subs , is_leaf = _is_subsaveat ) != jtu .tree_structure (
@@ -816,7 +830,7 @@ def loop(
816830 "`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
817831 "a single term."
818832 )
819- if discrete_terminating_event is not None :
833+ if event is not None :
820834 raise NotImplementedError (
821835 "`diffrax.BacksolveAdjoint` is not compatible with events."
822836 )
@@ -833,7 +847,7 @@ def loop(
833847 saveat = saveat ,
834848 init_state = init_state ,
835849 solver = solver ,
836- discrete_terminating_event = discrete_terminating_event ,
850+ event = event ,
837851 ** kwargs ,
838852 )
839853 final_state = _only_transpose_ys (final_state )
0 commit comments