Skip to content

Commit f561120

Browse files
Fix for Optimistix 0.1.0
1 parent 2aa6345 commit f561120

2 files changed

Lines changed: 4 additions & 1 deletion

File tree

diffrax/_root_finder/_verychord.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def terminate(
162162
converged = _converged(factor, self.kappa)
163163
terminate = at_least_two & (small | diverged | converged)
164164
terminate_result = optx.RESULTS.where(
165-
jnp.invert(small) & (diverged | jnp.invert(converged)),
165+
at_least_two & jnp.invert(small) & (diverged | jnp.invert(converged)),
166166
optx.RESULTS.nonlinear_divergence,
167167
optx.RESULTS.successful,
168168
)

test/test_very_chord.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def _fn2(x, args):
2424
@jax.jit
2525
def _fn3(x, args):
2626
mlp = eqx.nn.MLP(4, 4, 256, 2, key=jr.PRNGKey(678))
27+
dynamic, static = eqx.partition(mlp, eqx.is_array)
28+
dynamic = jtu.tree_map(lambda x: x * 0.1, dynamic)
29+
mlp = eqx.combine(dynamic, static)
2730
return mlp(x) - x
2831

2932

0 commit comments

Comments
 (0)