Skip to content

Commit 5280c97

Browse files
andyElkingpatrick-kidger
authored andcommitted
Split out jump/step clipping in stepsize controllers.
1 parent 4228819 commit 5280c97

11 files changed

Lines changed: 1267 additions & 224 deletions

File tree

benchmarks/jump_step_timing.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from warnings import simplefilter
2+
3+
4+
simplefilter(action="ignore", category=FutureWarning)
5+
6+
import timeit
7+
from functools import partial
8+
9+
import diffrax
10+
import equinox as eqx
11+
import jax
12+
import jax.numpy as jnp
13+
import jax.random as jr
14+
from old_pid_controller import OldPIDController
15+
16+
17+
t0 = 0
18+
t1 = 5
19+
dt0 = 0.5
20+
y0 = 1.0
21+
drift = diffrax.ODETerm(lambda t, y, args: -0.2 * y)
22+
23+
24+
def diffusion_vf(t, y, args):
25+
return jnp.ones((), dtype=y.dtype)
26+
27+
28+
def get_terms(key):
29+
bm = diffrax.VirtualBrownianTree(t0, t1, 2**-5, (), key)
30+
diffusion = diffrax.ControlTerm(diffusion_vf, bm)
31+
return diffrax.MultiTerm(drift, diffusion)
32+
33+
34+
solver = diffrax.Heun()
35+
step_ts = jnp.linspace(t0, t1, 129, endpoint=True)
36+
pid_controller = diffrax.PIDController(
37+
rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7
38+
)
39+
new_controller = diffrax.JumpStepWrapper(
40+
pid_controller,
41+
step_ts=step_ts,
42+
rejected_step_buffer_len=None,
43+
)
44+
old_controller = OldPIDController(
45+
rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7, step_ts=step_ts
46+
)
47+
48+
49+
@eqx.filter_jit
50+
@partial(jax.vmap, in_axes=(0, None))
51+
def solve(key, controller):
52+
term = get_terms(key)
53+
return diffrax.diffeqsolve(
54+
term,
55+
solver,
56+
t0,
57+
t1,
58+
dt0,
59+
y0,
60+
stepsize_controller=controller,
61+
saveat=diffrax.SaveAt(ts=step_ts),
62+
)
63+
64+
65+
num_samples = 100
66+
keys = jr.split(jr.PRNGKey(0), num_samples)
67+
68+
69+
def do_timing(controller):
70+
@jax.jit
71+
@eqx.debug.assert_max_traces(max_traces=1)
72+
def time_controller_fun():
73+
sols = solve(keys, controller)
74+
assert sols.ys is not None
75+
assert sols.ys.shape == (num_samples, len(step_ts))
76+
return sols.ys
77+
78+
def time_controller():
79+
jax.block_until_ready(time_controller_fun())
80+
81+
return min(timeit.repeat(time_controller, number=3, repeat=20))
82+
83+
84+
time_new = do_timing(new_controller)
85+
86+
time_old = do_timing(old_controller)
87+
88+
print(f"New controller: {time_new:.5} s, Old controller: {time_old:.5} s")
89+
90+
# How expensive is revisiting rejected steps?
91+
revisiting_controller_short = diffrax.JumpStepWrapper(
92+
pid_controller,
93+
step_ts=step_ts,
94+
rejected_step_buffer_len=10,
95+
)
96+
97+
revisiting_controller_long = diffrax.JumpStepWrapper(
98+
pid_controller,
99+
step_ts=step_ts,
100+
rejected_step_buffer_len=4096,
101+
)
102+
103+
time_revisiting_short = do_timing(revisiting_controller_short)
104+
time_revisiting_long = do_timing(revisiting_controller_long)
105+
106+
print(
107+
f"Revisiting controller\n"
108+
f"with buffer len 10: {time_revisiting_short:.5} s\n"
109+
f"with buffer len 4096: {time_revisiting_long:.5} s"
110+
)
111+
112+
# ======= RESULTS =======
113+
# New controller: 0.23506 s, Old controller: 0.30735 s
114+
# Revisiting controller
115+
# with buffer len 10: 0.23636 s
116+
# with buffer len 4096: 0.23965 s

0 commit comments

Comments
 (0)