Skip to content

Commit 7ea3127

Browse files
committed
fix test
1 parent 6e34acd commit 7ea3127

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

diffrax/_integrate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -760,8 +760,6 @@ def _promote(yi):
760760
_dtype = jnp.result_type(yi, time_dtype) # noqa: F821
761761
return jnp.asarray(yi, dtype=_dtype)
762762

763-
if isinstance(solver, KLSolver):
764-
y0 = (y0, 0.0)
765763
y0 = jtu.tree_map(_promote, y0)
766764
del timelikes
767765

@@ -810,6 +808,9 @@ def _promote(yi):
810808
"`UnsafeBrownianPath` cannot be used with adaptive step sizes."
811809
)
812810

811+
if isinstance(solver, KLSolver):
812+
y0 = (y0, 0.0)
813+
y0 = jtu.tree_map(_promote, y0)
813814
# Normalises time: if t0 > t1 then flip things around.
814815
direction = jnp.where(t0 < t1, 1, -1)
815816
t0 = t0 * direction

diffrax/_solver/kl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def vf(self, t: RealScalarLike, y: Y, args: Args) -> Tuple[VF, RealScalarLike]:
104104
self.drift1, self.drift2, self.diffusion, t, y, args, self.linear_solver
105105
)
106106

107-
def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> Control:
107+
def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> Control:
108108
return t1 - t0
109109

110110
def prod(self, vf: VF, control: RealScalarLike) -> Y:
@@ -120,7 +120,7 @@ def vf(self, t: RealScalarLike, y: Y, args: Args) -> Tuple[VF, RealScalarLike]:
120120
return vf, 0.0
121121

122122
def contr(
123-
self, t0: RealScalarLike, t1: RealScalarLike
123+
self, t0: RealScalarLike, t1: RealScalarLike, **kwargs
124124
) -> Tuple[Control, RealScalarLike]:
125125
return self.control_term.contr(t0, t1), 0.0
126126

0 commit comments

Comments
 (0)