|
1 | 1 | import diffrax |
2 | 2 | import equinox as eqx |
| 3 | +import jax |
3 | 4 | import jax.numpy as jnp |
4 | 5 | import jax.tree_util as jtu |
5 | 6 |
|
| 7 | +from .helpers import shaped_allclose |
| 8 | + |
6 | 9 |
|
7 | 10 | def test_step_ts(): |
8 | 11 | term = diffrax.ODETerm(lambda t, y, args: -0.2 * y) |
@@ -90,3 +93,45 @@ def run(ys, controller, state): |
90 | 93 | ys = (y0, y1_candidate, y_error) |
91 | 94 | grads = run(ys, stepsize_controller, state) |
92 | 95 | assert not any(jnp.isnan(grad).any() for grad in grads) |
| 96 | + |
| 97 | + |
| 98 | +def test_grad_of_discontinuous_forcing(): |
| 99 | + def vector_field(t, y, forcing): |
| 100 | + y, _ = y |
| 101 | + dy = -y + forcing(t) |
| 102 | + dsum = y |
| 103 | + return dy, dsum |
| 104 | + |
| 105 | + def run(t): |
| 106 | + term = diffrax.ODETerm(vector_field) |
| 107 | + solver = diffrax.Tsit5() |
| 108 | + t0 = 0 |
| 109 | + t1 = 1 |
| 110 | + dt0 = None |
| 111 | + y0 = 1.0 |
| 112 | + stepsize_controller = diffrax.PIDController( |
| 113 | + rtol=1e-8, atol=1e-8, step_ts=t[None] |
| 114 | + ) |
| 115 | + |
| 116 | + def forcing(s): |
| 117 | + return jnp.where(s < t, 0, 1) |
| 118 | + |
| 119 | + sol = diffrax.diffeqsolve( |
| 120 | + term, |
| 121 | + solver, |
| 122 | + t0, |
| 123 | + t1, |
| 124 | + dt0, |
| 125 | + (y0, 0), |
| 126 | + args=forcing, |
| 127 | + stepsize_controller=stepsize_controller, |
| 128 | + ) |
| 129 | + _, sum = sol.ys |
| 130 | + (sum,) = sum |
| 131 | + return sum |
| 132 | + |
| 133 | + r = jax.jit(run) |
| 134 | + eps = 1e-5 |
| 135 | + finite_diff = (r(0.5) - r(0.5 - eps)) / eps |
| 136 | + autodiff = jax.jit(jax.grad(run))(0.5) |
| 137 | + assert shaped_allclose(finite_diff, autodiff) |
0 commit comments