44import diffrax
55import equinox as eqx
66import jax
7- import jax .interpreters .ad
7+ import jax ._src . interpreters .ad
88import jax .numpy as jnp
99import jax .random as jr
1010import jax .tree_util as jtu
@@ -21,6 +21,7 @@ class _VectorField(eqx.Module):
2121 diff_arg : float
2222
2323 def __call__ (self , t , y , args ):
24+ del t
2425 assert y .shape == (2 ,)
2526 diff_arg , nondiff_arg = args
2627 dya = diff_arg * y [0 ] + nondiff_arg * y [1 ]
@@ -29,7 +30,7 @@ def __call__(self, t, y, args):
2930
3031
3132@pytest .mark .slow
32- def test_against (getkey ):
33+ def test_against ():
3334 y0 = jnp .array ([0.9 , 5.4 ])
3435 args = (0.1 , - 1 )
3536 term = diffrax .ODETerm (_VectorField (nondiff_arg = 1 , diff_arg = - 0.1 ))
@@ -215,6 +216,7 @@ def test_closure_errors():
215216 @eqx .filter_value_and_grad
216217 def run (model ):
217218 def f (t , y , args ):
219+ del t , args
218220 return model (y )
219221
220222 sol = diffrax .diffeqsolve (
@@ -228,7 +230,7 @@ def f(t, y, args):
228230 )
229231 return jnp .sum (cast (Array , sol .ys ))
230232
231- with pytest .raises (jax .interpreters .ad .CustomVJPException ):
233+ with pytest .raises (jax ._src . interpreters .ad .CustomVJPException ):
232234 run (mlp )
233235
234236
@@ -239,6 +241,7 @@ class VectorField(eqx.Module):
239241 model : Callable
240242
241243 def __call__ (self , t , y , args ):
244+ del t , args
242245 return self .model (y )
243246
244247 @eqx .filter_jit
@@ -307,12 +310,12 @@ def make_step(model, opt_state, target_steady_state):
307310 model = eqx .apply_updates (model , updates )
308311 return model , opt_state
309312
310- for step in range (100 ):
313+ for _ in range (100 ):
311314 model , opt_state = make_step (model , opt_state , target_steady_state )
312315 assert tree_allclose (model .steady_state , target_steady_state , rtol = 1e-2 , atol = 1e-2 )
313316
314317
315- def test_backprop_ts (getkey ):
318+ def test_backprop_ts ():
316319 mlp = eqx .nn .MLP (1 , 1 , 8 , 2 , key = jr .PRNGKey (0 ))
317320
318321 @eqx .filter_jit
@@ -338,14 +341,17 @@ def run(model):
338341)
339342def test_sde_against (diffusion_fn , getkey ):
340343 def f (t , y , args ):
344+ del t
341345 k0 , _ = args
342346 return - k0 * y
343347
344348 def g (t , y , args ):
349+ del t
345350 _ , k1 = args
346351 return k1 * y
347352
348353 def g_lx (t , y , args ):
354+ del t
349355 _ , k1 = args
350356 return lx .DiagonalLinearOperator (k1 * y )
351357
0 commit comments