Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions diffrax/_solver/implicit_euler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from collections.abc import Callable
from typing import ClassVar, TypeAlias

import jax.numpy as jnp
import jax.tree_util as jtu
import optimistix as optx
from equinox.internal import ω

from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y
from .._heuristics import is_sde
from .._local_interpolation import LocalLinearInterpolation
from .._root_finder import with_stepsize_controller_tols
from .._solution import RESULTS
from .._solution import is_okay, RESULTS
from .._term import AbstractTerm
from .base import AbstractAdaptiveSolver, AbstractImplicitSolver

Expand Down Expand Up @@ -94,9 +96,13 @@ def step(
y1 = (y0**ω + k1**ω).ω
# Use the trapezoidal rule for adaptive step sizing.
y_error = (0.5 * (k1**ω - k0**ω)).ω
result = RESULTS.promote(nonlinear_sol.result)
y_error = jtu.tree_map(
lambda _y_error: jnp.where(is_okay(result), _y_error, jnp.inf),
y_error,
) # i.e. an implicit step failed to converge
dense_info = dict(y0=y0, y1=y1)
solver_state = None
result = RESULTS.promote(nonlinear_sol.result)
return y1, y_error, dense_info, solver_state, result

def func(
Expand Down
Loading