Skip to content

Commit 8c0d6f0

Browse files
Added test for grad-of-discontinuous-forcing
1 parent b71f8fa commit 8c0d6f0

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

test/test_adaptive_stepsize_controller.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import diffrax
22
import equinox as eqx
3+
import jax
34
import jax.numpy as jnp
45
import jax.tree_util as jtu
56

7+
from .helpers import shaped_allclose
8+
69

710
def test_step_ts():
811
term = diffrax.ODETerm(lambda t, y, args: -0.2 * y)
@@ -90,3 +93,45 @@ def run(ys, controller, state):
9093
ys = (y0, y1_candidate, y_error)
9194
grads = run(ys, stepsize_controller, state)
9295
assert not any(jnp.isnan(grad).any() for grad in grads)
96+
97+
98+
def test_grad_of_discontinuous_forcing():
99+
def vector_field(t, y, forcing):
100+
y, _ = y
101+
dy = -y + forcing(t)
102+
dsum = y
103+
return dy, dsum
104+
105+
def run(t):
106+
term = diffrax.ODETerm(vector_field)
107+
solver = diffrax.Tsit5()
108+
t0 = 0
109+
t1 = 1
110+
dt0 = None
111+
y0 = 1.0
112+
stepsize_controller = diffrax.PIDController(
113+
rtol=1e-8, atol=1e-8, step_ts=t[None]
114+
)
115+
116+
def forcing(s):
117+
return jnp.where(s < t, 0, 1)
118+
119+
sol = diffrax.diffeqsolve(
120+
term,
121+
solver,
122+
t0,
123+
t1,
124+
dt0,
125+
(y0, 0),
126+
args=forcing,
127+
stepsize_controller=stepsize_controller,
128+
)
129+
_, sum = sol.ys
130+
(sum,) = sum
131+
return sum
132+
133+
r = jax.jit(run)
134+
eps = 1e-5
135+
finite_diff = (r(0.5) - r(0.5 - eps)) / eps
136+
autodiff = jax.jit(jax.grad(run))(0.5)
137+
assert shaped_allclose(finite_diff, autodiff)

0 commit comments

Comments
 (0)