-
-
Notifications
You must be signed in to change notification settings - Fork 179
Expand file tree
/
Copy pathtest_adaptive_stepsize_controller.py
More file actions
390 lines (342 loc) · 12 KB
/
test_adaptive_stepsize_controller.py
File metadata and controls
390 lines (342 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
from typing import cast
import diffrax
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import optimistix as optx
import pytest
from diffrax._step_size_controller.clip import _find_idx_with_hint
from jaxtyping import Array
from .helpers import tree_allclose
@pytest.mark.parametrize("backwards", [False, True])
def test_step_ts(backwards):
    """Every time listed in `step_ts` must appear among the saved step times.

    Solves ``dy/dt = -0.2 y`` with Dopri5 over [0, 5] (or [5, 0] when
    ``backwards``). The step size comes from a PID controller wrapped in
    `ClipStepSizeController` with ``step_ts=[3, 4]``.
    """
    start, end = (5, 0) if backwards else (0, 5)
    controller = diffrax.ClipStepSizeController(
        diffrax.PIDController(rtol=1e-4, atol=1e-6), step_ts=[3, 4]
    )
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(lambda t, y, args: -0.2 * y),
        diffrax.Dopri5(),
        start,
        end,
        None,
        1.0,
        stepsize_controller=controller,
        saveat=diffrax.SaveAt(steps=True),
    )
    step_times = cast(Array, sol.ts)
    for required in (3, 4):
        assert required in step_times
@pytest.mark.parametrize("backwards", [False, True])
def test_jump_ts(backwards):
    """Declaring a forcing discontinuity in `jump_ts` reduces steps and succeeds.

    Regression test for https://github.com/patrick-kidger/diffrax/issues/58.
    The system is a damped, forced oscillator whose forcing flips sign at
    t = 7.5.
    """
    # Tests no regression of https://github.com/patrick-kidger/diffrax/issues/58

    def vector_field(t, y, args):
        position, velocity = y
        # Forcing switches from +10 to -10 at t = 7.5.
        forcing = jnp.where(t < 7.5, 10, -10)
        acceleration = (
            -4 * jnp.pi**2 * position - 4 * jnp.pi * 0.05 * velocity + forcing
        )
        return velocity, acceleration

    t_start, t_end = (15, 0) if backwards else (0, 15)
    term = diffrax.ODETerm(vector_field)
    solver = diffrax.Dopri5()
    saveat = diffrax.SaveAt(steps=True)

    def solve(**clip_kwargs):
        # A fresh PID controller wrapped with the requested step/jump times.
        controller = diffrax.ClipStepSizeController(
            diffrax.PIDController(rtol=1e-4, atol=1e-6), **clip_kwargs
        )
        return diffrax.diffeqsolve(
            term,
            solver,
            t_start,
            t_end,
            None,
            (1.5, 0),
            stepsize_controller=controller,
            saveat=saveat,
        )

    baseline = solve()
    with_jump = solve(jump_ts=[7.5])
    # Knowing about the discontinuity in advance should cost fewer steps overall.
    assert baseline.stats["num_steps"] > with_jump.stats["num_steps"]
    assert with_jump.result == diffrax.RESULTS.successful

    # A time that is both a jump time and a step time must not break the solve.
    coincident = solve(jump_ts=[7.5], step_ts=[7.5])
    assert coincident.result == diffrax.RESULTS.successful

    mixed = solve(jump_ts=[7.5], step_ts=[3.5, 8])
    assert mixed.result == diffrax.RESULTS.successful
    mixed_ts = cast(Array, mixed.ts)
    assert 3.5 in mixed_ts
    assert 8 in mixed_ts
@pytest.mark.parametrize("backwards", [False, True])
def test_revisit_steps(backwards):
    """Rejected step end times are revisited and appear among the saved steps.

    Solves an SDE with a tight absolute tolerance so that many steps are
    rejected. A callback records the end time of every rejected step. Each
    recorded time must later show up in ``sol.ts``. By the end of the solve,
    the controller's rejected-step stack must be empty.
    """
    if backwards:
        t0, t1, dt0 = 5.0, 0.0, -0.5
    else:
        t0, t1, dt0 = 0.0, 5.0, 0.5
    y0 = 1.0

    def diffusion_vf(t, y, args):
        return jnp.ones((), dtype=y.dtype)

    brownian = diffrax.VirtualBrownianTree(
        min(t0, t1), max(t0, t1), 2**-8, (), jr.key(0)
    )
    terms = diffrax.MultiTerm(
        diffrax.ODETerm(lambda t, y, args: -0.2 * y),
        diffrax.ControlTerm(diffusion_vf, brownian),
    )

    rejected = []

    def on_reject(keep_step, t1):
        # Passed to the controller as `_callback_on_reject`; records the
        # proposed end time of every rejected step.
        if not keep_step:
            rejected.append(t1.item())
        return None

    n_stored = 10
    controller = diffrax.ClipStepSizeController(
        diffrax.PIDController(
            rtol=0, atol=1e-3, dtmin=2**-7, pcoeff=0.5, icoeff=0.8
        ),
        step_ts=[3, 4],
        store_rejected_steps=n_stored,
        _callback_on_reject=on_reject,
    )
    sol = diffrax.diffeqsolve(
        terms,
        diffrax.Heun(),
        t0,
        t1,
        dt0,
        y0,
        stepsize_controller=controller,
        saveat=diffrax.SaveAt(steps=True, controller_state=True),
    )
    assert sol.ts is not None

    rejected_ts = jnp.array(rejected)
    if backwards:
        # NOTE: when solving backwards the callback presumably sees internally
        # negated times, hence the sign flip before comparing with `sol.ts`.
        rejected_ts = -rejected_ts
    # Many rejections are expected; too few would make the check below vacuous.
    assert len(rejected_ts) > 10

    # Drop the `inf` entries (unused save slots) and make the times ascending
    # so that `searchsorted` can be used.
    accepted_ts = sol.ts[sol.ts != jnp.inf]
    if backwards:
        accepted_ts = accepted_ts[::-1]
    for t in rejected_ts:
        idx = jnp.searchsorted(accepted_ts, t)
        assert accepted_ts[idx] == t

    all_ts = cast(Array, sol.ts)
    assert 3 in all_ts
    assert 4 in all_ts

    # At the end of the run the rejected-step stack must be empty, which is
    # encoded as reject_index == store_rejected_steps.
    assert sol.controller_state is not None
    reject_index, _ = sol.controller_state.reject_info
    assert reject_index == n_stored
@pytest.mark.parametrize("use_clip", [True, False])
def test_backprop(use_clip):
    """Gradients through `adapt_step_size` contain no NaNs.

    Differentiates a scalar built from the controller's outputs with respect to
    ``(y0, y1_candidate, y_error)``. This is done for a zero, a finite and an
    infinite error estimate, both with and without a `ClipStepSizeController`
    wrapper.
    """
    t0 = jnp.asarray(0, dtype=jnp.float64)
    t1 = jnp.asarray(1, dtype=jnp.float64)

    @eqx.filter_jit
    @eqx.filter_grad
    def run(ys, controller, state):
        y0, y1_candidate, y_error = ys
        _, tprev, tnext, _, state, _ = controller.adapt_step_size(
            t0, t1, y0, y1_candidate, None, y_error, 5, state
        )
        # Fold every leaf of the new controller state into the output so its
        # gradient path is exercised too.
        with jax.numpy_dtype_promotion("standard"):
            state_total = sum(jnp.sum(leaf) for leaf in jtu.tree_leaves(state))
            return tprev + tnext + state_total

    controller = diffrax.PIDController(rtol=1e-4, atol=1e-4)
    if use_clip:
        controller = diffrax.ClipStepSizeController(
            controller, step_ts=[0.5], store_rejected_steps=20
        )
    y0 = jnp.array(1.0)
    y1_candidate = jnp.array(2.0)
    solver = diffrax.Tsit5()
    _, state = controller.init(
        diffrax.ODETerm(lambda t, y, args: -y),
        t0,
        t1,
        y0,
        0.1,
        None,
        solver.func,
        5,
    )
    for y_error in map(jnp.array, (0.0, 3.0, jnp.inf)):
        grads = run((y0, y1_candidate, y_error), controller, state)
        assert not any(jnp.isnan(g).any() for g in grads)
def test_grad_of_discontinuous_forcing():
    """Autodiff w.r.t. a forcing discontinuity time matches finite differences.

    The forcing is 0 before time ``t`` and 1 from ``t`` onwards. ``t`` is also
    given as a `step_ts` entry so that a step boundary lands on the
    discontinuity. The differentiated quantity is the second state component,
    which integrates ``y`` over [0, 1].
    """

    def vector_field(t, y, forcing):
        y, _ = y
        dy = -y + forcing(t)
        dintegral = y
        return dy, dintegral

    def run(t):
        term = diffrax.ODETerm(vector_field)
        solver = diffrax.Tsit5()
        t0 = 0
        t1 = 1
        dt0 = None
        y0 = 1.0
        pid_controller = diffrax.PIDController(
            rtol=1e-8,
            atol=1e-8,
        )
        # Force a step boundary exactly at the discontinuity time `t`.
        stepsize_controller = diffrax.ClipStepSizeController(
            pid_controller, step_ts=t[None]
        )

        def forcing(s):
            return jnp.where(s < t, 0, 1)

        sol = diffrax.diffeqsolve(
            term,
            solver,
            t0,
            t1,
            dt0,
            (y0, 0),
            args=forcing,
            stepsize_controller=stepsize_controller,
        )
        # Named `integral` rather than `sum` so the builtin is not shadowed.
        # The unpacking expects exactly one saved value (presumably the
        # default save at t1).
        _, integral = cast(Array, sol.ys)
        (integral,) = integral
        return integral

    run_jit = jax.jit(run)
    eps = 1e-5
    # One-sided (backward) finite difference in the discontinuity time.
    finite_diff = (run_jit(0.5) - run_jit(0.5 - eps)) / eps
    autodiff = jax.jit(jax.grad(run))(0.5)
    assert tree_allclose(finite_diff, autodiff)
def test_pid_meta():
    """`PIDController` given `step_ts`/`jump_ts` becomes a `ClipStepSizeController`.

    Without those arguments a plain `PIDController` is returned. With them, the
    result is a `ClipStepSizeController` that stores exactly the given times.
    """
    ts = jnp.array([3, 4], dtype=jnp.float64)
    pid1 = diffrax.PIDController(rtol=1e-4, atol=1e-6)
    pid2 = diffrax.PIDController(rtol=1e-4, atol=1e-6, step_ts=ts)  # pyright: ignore
    pid3 = diffrax.PIDController(rtol=1e-4, atol=1e-6, step_ts=ts, jump_ts=ts)  # pyright: ignore
    assert not isinstance(pid1, diffrax.ClipStepSizeController)
    assert isinstance(pid1, diffrax.PIDController)
    assert isinstance(pid2, diffrax.ClipStepSizeController)
    assert isinstance(pid3, diffrax.ClipStepSizeController)
    # `array_equal` checks shapes as well as values. `all(a == b)` would
    # broadcast, so a wrongly sized array (e.g. shape (1,)) could slip through.
    assert jnp.array_equal(pid2.step_ts, ts)
    assert jnp.array_equal(pid3.step_ts, ts)
    assert jnp.array_equal(pid3.jump_ts, ts)
def test_nested_clip_wrappers():
    """Two stacked `ClipStepSizeController`s honour the times of both wrappers.

    The inner wrapper has ``jump_ts=[3, 13]`` and ``step_ts=[23]``. The outer
    wrapper has ``step_ts=[2, 13]`` and ``jump_ts=[23]``. When a time is a
    jump time in one wrapper and a step time in the other (13 and 23), the
    controller stops just before it and then jumps past it.
    """
    pid = diffrax.PIDController(rtol=0, atol=1.0)
    inner = diffrax.ClipStepSizeController(
        pid, jump_ts=[3.0, 13.0], step_ts=[23.0]
    )
    outer = diffrax.ClipStepSizeController(
        inner, step_ts=[2.0, 13.0], jump_ts=[23.0]
    )

    def func(terms, t, y, args):
        return -y

    terms = diffrax.ODETerm(lambda t, y, args: -y)
    _, state = outer.init(terms, -1.0, 0.0, 0.0, 4.0, None, func, 5)

    def adapt(t0, t1, state):
        # Zero state and zero error estimate: only the clipping logic is probed.
        _, next_t0, next_t1, made_jump, state, _ = outer.adapt_step_size(
            t0, t1, 0.0, 0.0, None, 0.0, 5, state
        )
        return next_t0, next_t1, made_jump, state

    # Test 1: clip to step time 2 (outer), then stop just before jump time 3 (inner).
    next_t0, next_t1, made_jump, state = adapt(0.0, 1.0, state)
    assert next_t0 == 1
    assert next_t1 == 2
    assert not made_jump
    next_t0, next_t1, made_jump, state = adapt(next_t0, next_t1, state)
    assert next_t0 == 2
    assert next_t1 == eqxi.prevbefore(jnp.asarray(3.0))
    assert not made_jump

    # Test 2: 13 is a jump time (inner) and a step time (outer).
    next_t0, next_t1, made_jump, state = adapt(10.0, 11.0, state)
    assert next_t0 == 11
    assert next_t1 == eqxi.prevbefore(jnp.asarray(13.0))
    assert not made_jump
    next_t0, next_t1, made_jump, state = adapt(next_t0, next_t1, state)
    assert next_t0 == eqxi.nextafter(jnp.asarray(13.0))
    assert next_t1 == eqxi.prevbefore(jnp.asarray(23.0))
    assert made_jump

    # Test 3: 23 is a jump time (outer) and a step time (inner).
    next_t0, next_t1, made_jump, state = adapt(20.0, 21.0, state)
    assert next_t0 == 21
    assert next_t1 == eqxi.prevbefore(jnp.asarray(23.0))
    assert not made_jump
    next_t0, next_t1, made_jump, state = adapt(next_t0, next_t1, state)
    assert next_t0 == eqxi.nextafter(jnp.asarray(23.0))
    assert next_t1 > next_t0
    assert made_jump
def test_find_idx_with_hint():
    """`_find_idx_with_hint` returns the first index whose value is strictly greater.

    The answer must not depend on the hint, including a hint one past the end
    of ``ts``.
    """
    ts = jnp.arange(5.0)
    # (query, expected index). For the query 2 we want 3, not 2: the result is
    # the first value *strictly* greater than the query.
    cases = ((2.5, 3), (2, 3), (1.9, 2))
    for hint in (0, 2, 3, 5):
        for query, expected in cases:
            assert _find_idx_with_hint(query, ts, hint) == expected
# https://github.com/patrick-kidger/diffrax/issues/607
@pytest.mark.parametrize("new", (False, True))
def test_implicit_solver_with_clip_controller(new: bool):
    """An implicit solver (Kvaerno3) runs with a clip controller built either way.

    With ``new=True`` a `PIDController` is wrapped explicitly in
    `ClipStepSizeController`. With ``new=False``, `jump_ts` is passed straight
    to `PIDController`. There is no explicit assertion: the test passes as long
    as `diffeqsolve` completes without raising.
    """
    if not new:
        ssc = diffrax.PIDController(jump_ts=[0.5], rtol=1e-3, atol=1e-3)  # pyright: ignore[reportCallIssue]
    else:
        ssc = diffrax.ClipStepSizeController(
            diffrax.PIDController(rtol=1e-3, atol=1e-3), jump_ts=[0.5]
        )
    diffrax.diffeqsolve(
        diffrax.ODETerm(lambda t, y, args: -y),
        diffrax.Kvaerno3(),
        t0=0,
        t1=1,
        dt0=0.01,
        args=None,
        y0=1.0,
        stepsize_controller=ssc,
        max_steps=16384,
        saveat=diffrax.SaveAt(t1=True),
    )
# https://github.com/patrick-kidger/diffrax/issues/663
# `jump_ts` makes us step to `prevbefore` the time provided, rather than to the time itself.
# Clipping at t1 saves us! We need to clip with a tolerance of at least 1 ULP.
def test_jump_at_t1_with_large_t1_in_float32():
    """A jump placed exactly at a large float32 `t1` still ends the solve at `t1`."""
    f32 = jnp.float32
    t0 = jnp.array(0.0, dtype=f32)
    t1 = jnp.array(1e3, dtype=f32)
    dt0 = jnp.array(0.01, dtype=f32)
    y0 = jnp.array(1, dtype=f32)
    only_t1 = t1[None]
    controller = diffrax.ClipStepSizeController(
        diffrax.PIDController(atol=1e-6, rtol=1e-6), jump_ts=only_t1
    )
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(lambda t, y, args: -y),
        diffrax.Heun(),
        t0=t0,
        t1=t1,
        dt0=dt0,
        y0=y0,
        stepsize_controller=controller,
        saveat=diffrax.SaveAt(ts=only_t1),
    )
    # The single requested save time must come back as exactly t1.
    assert sol.ts == jnp.array([t1])
# https://github.com/patrick-kidger/diffrax/issues/713
def test_t0_at_jump_time():
    """A solve starting one float below a jump time terminates successfully.

    Regression test: ``t0`` is ``prevbefore(jump_time)``, with ``jump_time``
    declared in `jump_ts`. An event is also attached whose condition changes
    sign at ``jump_time``. The solve must end with a plain success, not via the
    event.
    """
    jump_time = 0.98
    controller = diffrax.ClipStepSizeController(
        diffrax.PIDController(rtol=1e-6, atol=1e-6), jump_ts=[jump_time]
    )
    event = diffrax.Event(
        cond_fn=lambda t, y, args, **kw: jump_time - t,
        root_finder=optx.Newton(atol=1e-4, rtol=1e-4),
        direction=True,
    )
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)),
        diffrax.Heun(),
        t0=eqxi.prevbefore(jnp.asarray(jump_time)),
        t1=1.2,
        dt0=None,
        y0=jnp.array([0, 0, 0, 0.0]),
        stepsize_controller=controller,
        event=event,
        max_steps=100,
    )
    # Successful, and in particular not terminated by the event. This used to
    # go badly wrong: the solve oscillated back and forth across the jump time.
    assert sol.result == diffrax.RESULTS.successful