Skip to content

Commit 1420287

Browse files
committed
Fix y_error returned from implicit Euler step when non-linear solve fails
1 parent 7f0001c commit 1420287

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

diffrax/_solver/implicit_euler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33

44
import optimistix as optx
55
from equinox.internal import ω
6+
import jax.numpy as jnp
7+
import jax.tree_util as jtu
68

79
from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y
810
from .._heuristics import is_sde
911
from .._local_interpolation import LocalLinearInterpolation
1012
from .._root_finder import with_stepsize_controller_tols
11-
from .._solution import RESULTS
13+
from .._solution import is_okay, RESULTS
1214
from .._term import AbstractTerm
1315
from .base import AbstractAdaptiveSolver, AbstractImplicitSolver
1416

@@ -94,9 +96,13 @@ def step(
9496
y1 = (y0**ω + k1**ω).ω
9597
# Use the trapezoidal rule for adaptive step sizing.
9698
y_error = (0.5 * (k1**ω - k0**ω)).ω
99+
result = RESULTS.promote(nonlinear_sol.result)
100+
y_error = jtu.tree_map(
101+
lambda _y_error: jnp.where(is_okay(result), _y_error, jnp.inf),
102+
y_error,
103+
) # i.e. an implicit step failed to converge
97104
dense_info = dict(y0=y0, y1=y1)
98105
solver_state = None
99-
result = RESULTS.promote(nonlinear_sol.result)
100106
return y1, y_error, dense_info, solver_state, result
101107

102108
def func(

0 commit comments

Comments
 (0)