3333)
3434
3535
36- _ControlTerm = Union [ControlTerm , WeaklyDiagonalControlTerm ]
36+ DiffusionTerm = Union [ControlTerm , WeaklyDiagonalControlTerm ]
3737
3838
3939def _compute_kl_integral (
4040 drift_term1 : ODETerm ,
4141 drift_term2 : ODETerm ,
42- diffusion_term : _ControlTerm ,
42+ diffusion_term : DiffusionTerm ,
4343 t0 : RealScalarLike ,
4444 y0 : Y ,
4545 args : Args ,
@@ -95,7 +95,7 @@ def _compute_kl_integral(
9595class _KLDrift (AbstractTerm ):
9696 drift1 : ODETerm
9797 drift2 : ODETerm
98- diffusion : _ControlTerm
98+ diffusion : DiffusionTerm
9999 linear_solver : lx .AbstractLinearSolver
100100
101101 def vf (self , t : RealScalarLike , y : Y , args : Args ) -> Tuple [VF , RealScalarLike ]:
@@ -112,7 +112,7 @@ def prod(self, vf: VF, control: RealScalarLike) -> Y:
112112
113113
114114class _KLControlTerm (AbstractTerm ):
115- control_term : _ControlTerm
115+ control_term : DiffusionTerm
116116
117117 def vf (self , t : RealScalarLike , y : Y , args : Args ) -> Tuple [VF , RealScalarLike ]:
118118 y , _ = y
@@ -160,7 +160,7 @@ class KLSolver(AbstractWrappedSolver[_SolverState]):
160160 The input must be a `MultiTerm` composed of the first SDE with drift `f`
161161 and diffusion `g` and the second either a SDE or just the drift term
162162 (since the diffusion is assumed to be the same). For example, a type
163- of: `MuliTerm(MultiTerm(ODETerm, _ControlTerm ), ODETerm)`.
163+ of: `MuliTerm(MultiTerm(ODETerm, DiffusionTerm ), ODETerm)`.
164164
165165 ??? cite "References"
166166
@@ -260,12 +260,12 @@ def step(
260260 drift_term1 , drift_term2 = drift_term1 [0 ], drift_term2 [0 ]
261261
262262 diffusion_term = jtu .tree_map (
263- lambda x : x if isinstance (x , _ControlTerm ) else None ,
263+ lambda x : x if isinstance (x , DiffusionTerm ) else None ,
264264 terms1 ,
265- is_leaf = lambda x : isinstance (x , _ControlTerm ),
265+ is_leaf = lambda x : isinstance (x , DiffusionTerm ),
266266 )
267267 diffusion_term = jtu .tree_leaves (
268- diffusion_term , is_leaf = lambda x : isinstance (x , _ControlTerm )
268+ diffusion_term , is_leaf = lambda x : isinstance (x , DiffusionTerm )
269269 )
270270
271271 diffusion_term = eqx .error_if (
0 commit comments