Skip to content

Commit 1039524

Browse files
committed
run root finders over all potential first events
1 parent 506c5be commit 1039524

1 file changed

Lines changed: 54 additions & 70 deletions

File tree

diffrax/_integrate.py

Lines changed: 54 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -582,12 +582,7 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
582582
jtu.tree_structure((0, 0)),
583583
event_values__mask,
584584
)
585-
had_event = False
586-
event_mask_leaves = []
587-
for event_mask_i in jtu.tree_leaves(event_mask):
588-
event_mask_leaves.append(event_mask_i & jnp.invert(had_event))
589-
had_event = event_mask_i | had_event
590-
event_mask = jtu.tree_unflatten(event_structure, event_mask_leaves)
585+
had_event = jnp.any(jnp.stack(jtu.tree_leaves(event_mask), axis=0))
591586
result = RESULTS.where(
592587
had_event,
593588
RESULTS.event_occurred,
@@ -653,6 +648,7 @@ def body_fun(state):
653648
if event is None or event.root_finder is None:
654649
tfinal = final_state.tprev
655650
yfinal = final_state.y
651+
first_event_mask = final_state.event_mask
656652
else:
657653
# If we're on this branch, it means that an event may have triggered, and now we
658654
# may need to do a root find, in order to locate the event time.
@@ -663,19 +659,19 @@ def body_fun(state):
663659
event_happened = jnp.max(float_mask) > 0.0
664660

665661
def _root_find():
666-
_interpolator = solver.interpolation_cls(
662+
interp = solver.interpolation_cls(
667663
t0=final_state.event_tprev,
668664
t1=final_state.event_tnext,
669665
**final_state.event_dense_info,
670666
)
671667

672-
def _to_root_find(_t, _):
673-
_distance_from_t_end = final_state.event_tnext - _t
668+
flat_fns, fn_tree = jtu.tree_flatten(event.cond_fn)
669+
flat_masks, _ = jtu.tree_flatten(event_mask)
674670

675-
def _call_real(_event_mask_i, _cond_fn_i):
676-
def _call_real_impl():
677-
# First evaluate the triggered event.
678-
_y = _interpolator.evaluate(_t)
671+
def _call_real(_event_mask_i, _cond_fn_i):
672+
def _find():
673+
def f(_t, _):
674+
_y = interp.evaluate(_t)
679675
_value = _cond_fn_i(
680676
t=_t,
681677
y=_y,
@@ -689,67 +685,56 @@ def _call_real_impl():
689685
stepsize_controller=stepsize_controller,
690686
max_steps=max_steps,
691687
)
692-
# Second: if this is a boolean event, then normalise to a
693-
# floating point number by having the root occur at the end of
694-
# the last step, i.e. `event_tnext`.
695-
_value_dtype = jnp.result_type(_value)
696-
if jnp.issubdtype(_value_dtype, jnp.bool_):
697-
_value = _distance_from_t_end
698-
else:
699-
assert jnp.issubdtype(_value_dtype, jnp.floating)
700-
return _value
701-
702-
# Only the triggered event actually gets to the decide what time the
703-
# event occurs; everything else is zeroed out to automatically give
704-
# a root.
705-
#
706-
# We allow this `lax.cond` to be inefficiently transformed into a
707-
# `lax.select` when `_event_mask_i` is batched. There isn't any way
708-
# to avoid this, I think.
709-
_value = lax.cond(_event_mask_i, _call_real_impl, lambda: 0.0)
710-
711-
# Third: if no events triggered at all, then have the root occur at
712-
# the end of the last step (which will be the `t1` of the overall
713-
# solve).
714-
_value = jnp.where(event_happened, _value, _distance_from_t_end)
715-
return _value
716-
717-
return jtu.tree_map(
718-
_call_real,
719-
event_mask,
720-
event.cond_fn,
721-
)
688+
return (
689+
(final_state.event_tnext - _t)
690+
if jnp.issubdtype(_value.dtype, jnp.bool_)
691+
else _value
692+
)
693+
694+
opts = {
695+
"lower": final_state.event_tprev,
696+
"upper": final_state.event_tnext,
697+
}
698+
res = optx.root_find(
699+
f,
700+
event.root_finder,
701+
y0=final_state.event_tnext,
702+
options=opts,
703+
throw=False,
704+
)
705+
return res.value
722706

723-
_options = {
724-
"lower": final_state.event_tprev,
725-
"upper": final_state.event_tnext,
726-
}
727-
_event_root_find = optx.root_find(
728-
_to_root_find,
729-
event.root_finder,
730-
y0=final_state.event_tprev,
731-
options=_options,
732-
throw=False,
707+
return lax.cond(_event_mask_i, _find, lambda: jnp.inf)
708+
709+
candidates = jnp.stack(
710+
[_call_real(m, fn) for m, fn in zip(flat_masks, flat_fns)]
733711
)
734-
_tfinal = _event_root_find.value
735-
# TODO: we might need to change the way we evaluate `_yfinal` in order to
736-
# get more accurate derivatives?
737-
_yfinal = _interpolator.evaluate(_tfinal)
738-
_result = RESULTS.where(
739-
_event_root_find.result == optx.RESULTS.successful,
712+
713+
t_event = jnp.min(candidates)
714+
t_event = jnp.where(jnp.isfinite(t_event), t_event, final_state.event_tnext)
715+
716+
y_event = interp.evaluate(t_event)
717+
718+
first_idx = jnp.argmin(candidates)
719+
first_mask_arr = jnp.arange(candidates.shape[0]) == first_idx
720+
first_event_mask = jtu.tree_unflatten(fn_tree, list(first_mask_arr))
721+
722+
new_result = RESULTS.where(
723+
jnp.any(jnp.stack(flat_masks)),
724+
RESULTS.event_occurred,
740725
result,
741-
RESULTS.promote(_event_root_find.result),
742726
)
743-
return _tfinal, _yfinal, _result
727+
728+
return t_event, y_event, new_result, first_event_mask
744729

745730
# Fastpath: if no event happened anywhere at all, then skip the root-find
746731
# altogether.
747732
# Note that `_root_find` might still be called on batch elements which did not
748733
# have an event, so we still need to access `event_happened` inside of it.
749-
tfinal, yfinal, result = lax.cond(
734+
tfinal, yfinal, result, first_event_mask = lax.cond(
750735
eqxi.unvmap_any(event_happened),
751736
_root_find,
752-
lambda: (final_state.tprev, final_state.y, result),
737+
lambda: (final_state.tprev, final_state.y, result, final_state.event_mask),
753738
)
754739

755740
# We delete all the saved values after the event time.
@@ -824,9 +809,13 @@ def _save_t1(subsaveat, save_state):
824809
final_state = eqx.tree_at(
825810
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
826811
)
812+
827813
final_state = _handle_static(final_state)
828814
result = RESULTS.where(cond_fun(final_state), RESULTS.max_steps_reached, result)
829815
aux_stats = dict() # TODO: put something in here?
816+
817+
# override event mask with first found event
818+
final_state = eqx.tree_at(lambda s: s.event_mask, final_state, first_event_mask)
830819
return eqx.tree_at(lambda s: s.result, final_state, result), aux_stats
831820

832821

@@ -1339,18 +1328,13 @@ def _outer_cond_fn(cond_fn_i):
13391328
jtu.tree_structure((0, 0)),
13401329
event_values__mask,
13411330
)
1342-
had_event = False
1343-
event_mask_leaves = []
1344-
for event_mask_i in jtu.tree_leaves(event_mask):
1345-
event_mask_leaves.append(event_mask_i & jnp.invert(had_event))
1346-
had_event = event_mask_i | had_event
1347-
event_mask = jtu.tree_unflatten(event_structure, event_mask_leaves)
1331+
had_event = jnp.any(jnp.stack(jtu.tree_leaves(event_mask), axis=0))
13481332
result = RESULTS.where(
13491333
had_event,
13501334
RESULTS.event_occurred,
13511335
result,
13521336
)
1353-
del had_event, event_structure, event_mask_leaves, event_values__mask
1337+
del had_event, event_structure, event_values__mask
13541338

13551339
# Initialise state
13561340
init_state = State(

0 commit comments

Comments
 (0)