|
3 | 3 |
|
4 | 4 | import optimistix as optx |
5 | 5 | from equinox.internal import ω |
| 6 | +import jax.numpy as jnp |
| 7 | +import jax.tree_util as jtu |
6 | 8 |
|
7 | 9 | from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y |
8 | 10 | from .._heuristics import is_sde |
9 | 11 | from .._local_interpolation import LocalLinearInterpolation |
10 | 12 | from .._root_finder import with_stepsize_controller_tols |
11 | | -from .._solution import RESULTS |
| 13 | +from .._solution import is_okay, RESULTS |
12 | 14 | from .._term import AbstractTerm |
13 | 15 | from .base import AbstractAdaptiveSolver, AbstractImplicitSolver |
14 | 16 |
|
@@ -94,9 +96,13 @@ def step( |
94 | 96 | y1 = (y0**ω + k1**ω).ω |
95 | 97 | # Use the trapezoidal rule for adaptive step sizing. |
96 | 98 | y_error = (0.5 * (k1**ω - k0**ω)).ω |
| 99 | + result = RESULTS.promote(nonlinear_sol.result) |
| 100 | + y_error = jtu.tree_map( |
| 101 | + lambda _y_error: jnp.where(is_okay(result), _y_error, jnp.inf), |
| 102 | + y_error, |
| 103 | + ) # i.e. an implicit step failed to converge |
97 | 104 | dense_info = dict(y0=y0, y1=y1) |
98 | 105 | solver_state = None |
99 | | - result = RESULTS.promote(nonlinear_sol.result) |
100 | 106 | return y1, y_error, dense_info, solver_state, result |
101 | 107 |
|
102 | 108 | def func( |
|
0 commit comments