We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2b56424 commit 633afbdCopy full SHA for 633afbd
1 file changed
diffrax/_solver/kl.py
@@ -1,5 +1,5 @@
1
import operator
2
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
3
4
import equinox as eqx
5
import jax
@@ -20,8 +20,8 @@
20
from .._heuristics import is_sde
21
from .._solution import RESULTS
22
from .._term import (
23
- _ControlTerm,
24
AbstractTerm,
+ ControlTerm,
25
MultiTerm,
26
ODETerm,
27
WeaklyDiagonalControlTerm,
@@ -33,6 +33,9 @@
33
)
34
35
36
+_ControlTerm = Union[ControlTerm, WeaklyDiagonalControlTerm]
37
+
38
39
def _compute_kl_integral(
40
drift_term1: ODETerm,
41
drift_term2: ODETerm,
0 commit comments