Skip to content

Commit def8e1d

Browse files
Reworked JumpStepWrapper.
1 parent 5280c97 commit def8e1d

11 files changed

Lines changed: 511 additions & 536 deletions

File tree

benchmarks/jump_step_timing.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ def get_terms(key):
3636
pid_controller = diffrax.PIDController(
3737
rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7
3838
)
39-
new_controller = diffrax.JumpStepWrapper(
39+
new_controller = diffrax.ClipStepSizeController(
4040
pid_controller,
4141
step_ts=step_ts,
42-
rejected_step_buffer_len=None,
42+
store_rejected_steps=None,
4343
)
4444
old_controller = OldPIDController(
4545
rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7, step_ts=step_ts
@@ -88,16 +88,16 @@ def time_controller():
8888
print(f"New controller: {time_new:.5} s, Old controller: {time_old:.5} s")
8989

9090
# How expensive is revisiting rejected steps?
91-
revisiting_controller_short = diffrax.JumpStepWrapper(
91+
revisiting_controller_short = diffrax.ClipStepSizeController(
9292
pid_controller,
9393
step_ts=step_ts,
94-
rejected_step_buffer_len=10,
94+
store_rejected_steps=10,
9595
)
9696

97-
revisiting_controller_long = diffrax.JumpStepWrapper(
97+
revisiting_controller_long = diffrax.ClipStepSizeController(
9898
pid_controller,
9999
step_ts=step_ts,
100-
rejected_step_buffer_len=4096,
100+
store_rejected_steps=4096,
101101
)
102102

103103
time_revisiting_short = do_timing(revisiting_controller_short)

diffrax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@
121121
from ._step_size_controller import (
122122
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
123123
AbstractStepSizeController as AbstractStepSizeController,
124+
ClipStepSizeController as ClipStepSizeController,
124125
ConstantStepSize as ConstantStepSize,
125-
JumpStepWrapper as JumpStepWrapper,
126126
PIDController as PIDController,
127127
StepTo as StepTo,
128128
)

diffrax/_autocitation.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
SRA1,
3737
Tsit5,
3838
)
39-
from ._step_size_controller import PIDController
39+
from ._step_size_controller import ClipStepSizeController, PIDController
4040

4141

4242
def citation(*args, **kwargs):
@@ -134,7 +134,7 @@ def citation(*args, **kwargs):
134134

135135

136136
_thesis_cite = r"""
137-
phdthesis{kidger2021on,
137+
@phdthesis{kidger2021on,
138138
title={{O}n {N}eural {D}ifferential {E}quations},
139139
author={Patrick Kidger},
140140
year={2021},
@@ -352,10 +352,10 @@ def _virtual_brownian_tree(terms):
352352
return (
353353
r"""
354354
% You are simulating Brownian motion using a virtual Brownian tree, which was introduced
355-
% in:
355+
% in the following two papers:
356356
"""
357357
+ vbt_ref
358-
+ "\n\n"
358+
+ "\n"
359359
+ single_seed_ref
360360
)
361361

@@ -570,6 +570,17 @@ def _auto_dt0(dt0):
570570
"""
571571

572572

573+
@citation_rules.append
574+
def _clip_controller(terms, stepsize_controller):
575+
if type(stepsize_controller) is ClipStepSizeController:
576+
if stepsize_controller.store_rejected_steps is not None and is_sde(terms):
577+
return r"""
578+
% You are adaptively solving an SDE whilst revisiting rejected time points. This is a
579+
% subtle point required for the correctness of adaptive noncommutative SDE solves, as
580+
% found in:
581+
""" + _parse_reference(ClipStepSizeController)
582+
583+
573584
@citation_rules.append
574585
def _pid_controller(stepsize_controller, terms=None):
575586
if type(stepsize_controller) is PIDController:

diffrax/_solution.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ class RESULTS(optx.RESULTS): # pyright: ignore
2121
event_occurred = (
2222
"Terminating differential equation solve because an event occurred."
2323
)
24+
max_steps_rejected = (
25+
"Maximum number of rejected steps was reached. Consider increasing "
26+
"`diffrax.ClipStepSizeController(store_rejected_steps==...)`."
27+
)
28+
internal_error = (
29+
"An internal error occurred in Diffrax. This is a bug! Please open a GitHub "
30+
"issue with a minimum working example. (<50 lines of code is ideal)"
31+
)
2432

2533

2634
# Backward compatibility

diffrax/_step_size_controller/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
33
AbstractStepSizeController as AbstractStepSizeController,
44
)
5+
from .clip import ClipStepSizeController as ClipStepSizeController
56
from .constant import ConstantStepSize as ConstantStepSize, StepTo as StepTo
6-
from .jump_step_wrapper import JumpStepWrapper as JumpStepWrapper
77
from .pid import (
88
PIDController as PIDController,
99
)

0 commit comments

Comments
 (0)