Skip to content

Commit 825e4e0

Browse files
Fixed where a nonbatchable check was being called.
1 parent 1ae1d58 commit 825e4e0

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

diffrax/_solver/runge_kutta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ def eval_k_jac():
964964
assert implicit_tableau.a_diagonal[0] == 0 # pyright: ignore
965965
assert len(set(implicit_tableau.a_diagonal[1:])) == 1 # pyright: ignore
966966
jac_stage_index = 1
967-
stage_index = eqxi.nonbatchable(stage_index)
967+
stage_index = eqxi.nonbatchable(stage_index)
968968
# These `stop_gradients` are needed to work around the lack of
969969
# symbolic zeros in `custom_vjp`s.
970970
if eval_fs:

test/test_global_interpolation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def _test_dense_interpolation(solver, key, t1):
340340

341341

342342
@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers)
343-
def test_dense_interpolation(solver, getkey):
343+
def test_dense_interpolation(solver):
344344
solver = implicit_tol(solver)
345345
key = jr.PRNGKey(5678)
346346
vals, true_vals, derivs, true_derivs = _test_dense_interpolation(solver, key, 1)

0 commit comments

Comments
 (0)