Skip to content

Commit 61e5a56

Browse files
committed
small updates
1 parent 863342a commit 61e5a56

2 files changed

Lines changed: 58 additions & 60 deletions

File tree

diffrax/_delays.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class Delays(eqx.Module):
102102
max_discontinuities: IntScalarLike = 100
103103
recurrent_checking: bool = False
104104
sub_intervals: IntScalarLike = 10
105-
max_steps: IntScalarLike = 20
105+
max_steps: IntScalarLike = 10
106106
rtol: RealScalarLike = 10e-3
107107
atol: RealScalarLike = 10e-3
108108

@@ -136,7 +136,7 @@ def __call__(self, t, y, args):
136136
delays, treedef = jtu.tree_flatten(self.delays)
137137
if self.dense_interp is None:
138138
assert self.dense_info is None
139-
for delay in self.delays:
139+
for delay in delays:
140140
delay_val = delay(t, y, args)
141141
alpha_val = t - delay_val
142142
y0_val = self.y0_history(alpha_val)
@@ -265,7 +265,7 @@ def fn(dense_info, args):
265265
atol=delays.atol,
266266
norm=rms_norm,
267267
implicit_step=implicit_step,
268-
max_steps=100,
268+
max_steps=delays.max_steps,
269269
)
270270

271271
nonlinear_args = (

diffrax/_integrate.py

Lines changed: 55 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def body_fun_aux(state):
364364
state.solver_state,
365365
state.made_jump,
366366
)
367-
num_dde_explicit_step = num_dde_implicit_step = 0
367+
implicit_step = False
368368
else:
369369
min_delay = []
370370
flat_delays = jtu.tree_leaves(delays.delays)
@@ -439,64 +439,62 @@ def get_struct_dense_info(init_state):
439439
)
440440
assert jnp.result_type(keep_step) is jnp.dtype(bool)
441441
# Finding all of the potential discontinuity roots
442-
# if delays is not None:
443-
# _part_maybe_find_discontinuity = ft.partial(
444-
# maybe_find_discontinuity,
445-
# tprev,
446-
# tnext,
447-
# dense_info,
448-
# state,
449-
# delays,
450-
# solver,
451-
# args,
452-
# )
453-
454-
# tsearch = jnp.linspace(tprev, tnext, delays.sub_intervals)
455-
# batch_tprev, batch_tnext = tsearch[:-1], tsearch[1:]
456-
# vmap_maybe_find_discontinuity_wrapper = jax.vmap(
457-
# _part_maybe_find_discontinuity, (None, 0, 0)
458-
# )
459-
# if delays.recurrent_checking:
460-
# (
461-
# tnext_candidate,
462-
# batch_discont_update,
463-
# ) = vmap_maybe_find_discontinuity_wrapper(
464-
# False, batch_tprev, batch_tnext
465-
# )
466-
# else:
467-
# (
468-
# tnext_candidate,
469-
# batch_discont_update,
470-
# ) = vmap_maybe_find_discontinuity_wrapper(
471-
# keep_step, batch_tprev, batch_tnext
472-
# )
473-
474-
# proxy_tnext = jnp.where(batch_discont_update, tnext_candidate, jnp.inf)
475-
# proxy_tnext = jnp.min(proxy_tnext)
476-
477-
# tnext, discont_update = jax.lax.cond(
478-
# jnp.isinf(proxy_tnext),
479-
# lambda: (tnext, False),
480-
# lambda: (proxy_tnext, True),
481-
# )
482-
483-
# # Count the number of steps in DDEs, just for statistical purposes
484-
# num_dde_implicit_step = state.num_dde_implicit_step + (
485-
# keep_step & implicit_step
486-
# )
487-
# num_dde_explicit_step = state.num_dde_explicit_step + (
488-
# keep_step & jnp.invert(implicit_step)
489-
# )
490-
491-
# assert jnp.result_type(discont_update) is jnp.dtype(bool)
492-
493-
# assert jnp.result_type(keep_step) is jnp.dtype(bool)
442+
discont_update = False
443+
if delays is not None:
444+
# _part_maybe_find_discontinuity = ft.partial(
445+
# maybe_find_discontinuity,
446+
# tprev,
447+
# tnext,
448+
# dense_info,
449+
# state,
450+
# delays,
451+
# solver,
452+
# args,
453+
# )
454+
455+
# tsearch = jnp.linspace(tprev, tnext, delays.sub_intervals)
456+
# batch_tprev, batch_tnext = tsearch[:-1], tsearch[1:]
457+
# vmap_maybe_find_discontinuity_wrapper = jax.vmap(
458+
# _part_maybe_find_discontinuity, (None, 0, 0)
459+
# )
460+
# if delays.recurrent_checking:
461+
# (
462+
# tnext_candidate,
463+
# batch_discont_update,
464+
# ) = vmap_maybe_find_discontinuity_wrapper(
465+
# False, batch_tprev, batch_tnext
466+
# )
467+
# else:
468+
# (
469+
# tnext_candidate,
470+
# batch_discont_update,
471+
# ) = vmap_maybe_find_discontinuity_wrapper(
472+
# keep_step, batch_tprev, batch_tnext
473+
# )
474+
475+
# prox_tnext = jnp.where(batch_discont_update, tnext_candidate, jnp.inf)
476+
# prox_tnext = jnp.min(prox_tnext)
477+
478+
# tnext, discont_update = jax.lax.cond(
479+
# jnp.isinf(prox_tnext),
480+
# lambda: (tnext, False),
481+
# lambda: (prox_tnext, True),
482+
# )
483+
# assert jnp.result_type(discont_update) is jnp.dtype(bool)
484+
485+
# Count the number of steps in DDEs, just for statistical purposes
486+
num_dde_implicit_step = state.num_dde_implicit_step + (
487+
keep_step & implicit_step
488+
)
489+
num_dde_explicit_step = state.num_dde_explicit_step + (
490+
keep_step & jnp.invert(implicit_step)
491+
)
492+
493+
assert jnp.result_type(keep_step) is jnp.dtype(bool)
494494

495495
#
496496
# Do some book-keeping.
497497
#
498-
discont_update = False
499-
num_dde_explicit_step = num_dde_implicit_step = 0
500498
tprev = jnp.minimum(tprev, t1)
501499
tnext = _clip_to_end(tprev, tnext, t1, keep_step)
502500

@@ -736,8 +734,8 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
736734
event_dense_info=event_dense_info,
737735
event_values=event_values,
738736
event_mask=event_mask,
739-
num_dde_explicit_step=num_dde_explicit_step,
740-
num_dde_implicit_step=num_dde_implicit_step,
737+
num_dde_explicit_step=num_dde_explicit_step, # type: ignore
738+
num_dde_implicit_step=num_dde_implicit_step, # type: ignore
741739
discontinuities=discontinuities, # type: ignore
742740
discontinuities_save_index=discontinuities_save_index,
743741
)

0 commit comments

Comments
 (0)