Skip to content

Commit d3490e6

Browse files
Fixes for JAX 0.4.36 which changes the name of an error.
1 parent 3c21d15 commit d3490e6

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

test/test_adjoint.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import diffrax
55
import equinox as eqx
66
import jax
7-
import jax.interpreters.ad
7+
import jax._src.interpreters.ad
88
import jax.numpy as jnp
99
import jax.random as jr
1010
import jax.tree_util as jtu
@@ -21,6 +21,7 @@ class _VectorField(eqx.Module):
2121
diff_arg: float
2222

2323
def __call__(self, t, y, args):
24+
del t
2425
assert y.shape == (2,)
2526
diff_arg, nondiff_arg = args
2627
dya = diff_arg * y[0] + nondiff_arg * y[1]
@@ -29,7 +30,7 @@ def __call__(self, t, y, args):
2930

3031

3132
@pytest.mark.slow
32-
def test_against(getkey):
33+
def test_against():
3334
y0 = jnp.array([0.9, 5.4])
3435
args = (0.1, -1)
3536
term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1))
@@ -215,6 +216,7 @@ def test_closure_errors():
215216
@eqx.filter_value_and_grad
216217
def run(model):
217218
def f(t, y, args):
219+
del t, args
218220
return model(y)
219221

220222
sol = diffrax.diffeqsolve(
@@ -228,7 +230,7 @@ def f(t, y, args):
228230
)
229231
return jnp.sum(cast(Array, sol.ys))
230232

231-
with pytest.raises(jax.interpreters.ad.CustomVJPException):
233+
with pytest.raises(jax._src.interpreters.ad.CustomVJPException):
232234
run(mlp)
233235

234236

@@ -239,6 +241,7 @@ class VectorField(eqx.Module):
239241
model: Callable
240242

241243
def __call__(self, t, y, args):
244+
del t, args
242245
return self.model(y)
243246

244247
@eqx.filter_jit
@@ -307,12 +310,12 @@ def make_step(model, opt_state, target_steady_state):
307310
model = eqx.apply_updates(model, updates)
308311
return model, opt_state
309312

310-
for step in range(100):
313+
for _ in range(100):
311314
model, opt_state = make_step(model, opt_state, target_steady_state)
312315
assert tree_allclose(model.steady_state, target_steady_state, rtol=1e-2, atol=1e-2)
313316

314317

315-
def test_backprop_ts(getkey):
318+
def test_backprop_ts():
316319
mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))
317320

318321
@eqx.filter_jit
@@ -338,14 +341,17 @@ def run(model):
338341
)
339342
def test_sde_against(diffusion_fn, getkey):
340343
def f(t, y, args):
344+
del t
341345
k0, _ = args
342346
return -k0 * y
343347

344348
def g(t, y, args):
349+
del t
345350
_, k1 = args
346351
return k1 * y
347352

348353
def g_lx(t, y, args):
354+
del t
349355
_, k1 = args
350356
return lx.DiagonalLinearOperator(k1 * y)
351357

0 commit comments

Comments
 (0)