@@ -582,12 +582,7 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
582582 jtu .tree_structure ((0 , 0 )),
583583 event_values__mask ,
584584 )
585- had_event = False
586- event_mask_leaves = []
587- for event_mask_i in jtu .tree_leaves (event_mask ):
588- event_mask_leaves .append (event_mask_i & jnp .invert (had_event ))
589- had_event = event_mask_i | had_event
590- event_mask = jtu .tree_unflatten (event_structure , event_mask_leaves )
585+ had_event = jnp .any (jnp .stack (jtu .tree_leaves (event_mask ), axis = 0 ))
591586 result = RESULTS .where (
592587 had_event ,
593588 RESULTS .event_occurred ,
@@ -653,6 +648,7 @@ def body_fun(state):
653648 if event is None or event .root_finder is None :
654649 tfinal = final_state .tprev
655650 yfinal = final_state .y
651+ first_event_mask = final_state .event_mask
656652 else :
657653 # If we're on this branch, it means that an event may have triggered, and now we
658654 # may need to do a root find, in order to locate the event time.
@@ -663,19 +659,19 @@ def body_fun(state):
663659 event_happened = jnp .max (float_mask ) > 0.0
664660
665661 def _root_find ():
666- _interpolator = solver .interpolation_cls (
662+ interp = solver .interpolation_cls (
667663 t0 = final_state .event_tprev ,
668664 t1 = final_state .event_tnext ,
669665 ** final_state .event_dense_info ,
670666 )
671667
672- def _to_root_find ( _t , _ ):
673- _distance_from_t_end = final_state . event_tnext - _t
668+ flat_fns , fn_tree = jtu . tree_flatten ( event . cond_fn )
669+ flat_masks , _ = jtu . tree_flatten ( event_mask )
674670
675- def _call_real (_event_mask_i , _cond_fn_i ):
676- def _call_real_impl ():
677- # First evaluate the triggered event.
678- _y = _interpolator .evaluate (_t )
671+ def _call_real (_event_mask_i , _cond_fn_i ):
672+ def _find ():
673+ def f ( _t , _ ):
674+ _y = interp .evaluate (_t )
679675 _value = _cond_fn_i (
680676 t = _t ,
681677 y = _y ,
@@ -689,67 +685,56 @@ def _call_real_impl():
689685 stepsize_controller = stepsize_controller ,
690686 max_steps = max_steps ,
691687 )
692- # Second: if this is a boolean event, then normalise to a
693- # floating point number by having the root occur at the end of
694- # the last step, i.e. `event_tnext`.
695- _value_dtype = jnp .result_type (_value )
696- if jnp .issubdtype (_value_dtype , jnp .bool_ ):
697- _value = _distance_from_t_end
698- else :
699- assert jnp .issubdtype (_value_dtype , jnp .floating )
700- return _value
701-
702- # Only the triggered event actually gets to the decide what time the
703- # event occurs; everything else is zeroed out to automatically give
704- # a root.
705- #
706- # We allow this `lax.cond` to be inefficiently transformed into a
707- # `lax.select` when `_event_mask_i` is batched. There isn't any way
708- # to avoid this, I think.
709- _value = lax .cond (_event_mask_i , _call_real_impl , lambda : 0.0 )
710-
711- # Third: if no events triggered at all, then have the root occur at
712- # the end of the last step (which will be the `t1` of the overall
713- # solve).
714- _value = jnp .where (event_happened , _value , _distance_from_t_end )
715- return _value
716-
717- return jtu .tree_map (
718- _call_real ,
719- event_mask ,
720- event .cond_fn ,
721- )
688+ return (
689+ (final_state .event_tnext - _t )
690+ if jnp .issubdtype (_value .dtype , jnp .bool_ )
691+ else _value
692+ )
693+
694+ opts = {
695+ "lower" : final_state .event_tprev ,
696+ "upper" : final_state .event_tnext ,
697+ }
698+ res = optx .root_find (
699+ f ,
700+ event .root_finder ,
701+ y0 = final_state .event_tnext ,
702+ options = opts ,
703+ throw = False ,
704+ )
705+ return res .value
722706
723- _options = {
724- "lower" : final_state .event_tprev ,
725- "upper" : final_state .event_tnext ,
726- }
727- _event_root_find = optx .root_find (
728- _to_root_find ,
729- event .root_finder ,
730- y0 = final_state .event_tprev ,
731- options = _options ,
732- throw = False ,
707+ return lax .cond (_event_mask_i , _find , lambda : jnp .inf )
708+
709+ candidates = jnp .stack (
710+ [_call_real (m , fn ) for m , fn in zip (flat_masks , flat_fns )]
733711 )
734- _tfinal = _event_root_find .value
735- # TODO: we might need to change the way we evaluate `_yfinal` in order to
736- # get more accurate derivatives?
737- _yfinal = _interpolator .evaluate (_tfinal )
738- _result = RESULTS .where (
739- _event_root_find .result == optx .RESULTS .successful ,
712+
713+ t_event = jnp .min (candidates )
714+ t_event = jnp .where (jnp .isfinite (t_event ), t_event , final_state .event_tnext )
715+
716+ y_event = interp .evaluate (t_event )
717+
718+ first_idx = jnp .argmin (candidates )
719+ first_mask_arr = jnp .arange (candidates .shape [0 ]) == first_idx
720+ first_event_mask = jtu .tree_unflatten (fn_tree , list (first_mask_arr ))
721+
722+ new_result = RESULTS .where (
723+ jnp .any (jnp .stack (flat_masks )),
724+ RESULTS .event_occurred ,
740725 result ,
741- RESULTS .promote (_event_root_find .result ),
742726 )
743- return _tfinal , _yfinal , _result
727+
728+ return t_event , y_event , new_result , first_event_mask
744729
745730 # Fastpath: if no event happened anywhere at all, then skip the root-find
746731 # altogether.
747732 # Note that `_root_find` might still be called on batch elements which did not
748733 # have an event, so we still need to access `event_happened` inside of it.
749- tfinal , yfinal , result = lax .cond (
734+ tfinal , yfinal , result , first_event_mask = lax .cond (
750735 eqxi .unvmap_any (event_happened ),
751736 _root_find ,
752- lambda : (final_state .tprev , final_state .y , result ),
737+ lambda : (final_state .tprev , final_state .y , result , final_state . event_mask ),
753738 )
754739
755740 # We delete all the saved values after the event time.
@@ -824,9 +809,13 @@ def _save_t1(subsaveat, save_state):
824809 final_state = eqx .tree_at (
825810 lambda s : s .save_state , final_state , save_state , is_leaf = _is_none
826811 )
812+
827813 final_state = _handle_static (final_state )
828814 result = RESULTS .where (cond_fun (final_state ), RESULTS .max_steps_reached , result )
829815 aux_stats = dict () # TODO: put something in here?
816+
817+ # override event mask with first found event
818+ final_state = eqx .tree_at (lambda s : s .event_mask , final_state , first_event_mask )
830819 return eqx .tree_at (lambda s : s .result , final_state , result ), aux_stats
831820
832821
@@ -1339,18 +1328,13 @@ def _outer_cond_fn(cond_fn_i):
13391328 jtu .tree_structure ((0 , 0 )),
13401329 event_values__mask ,
13411330 )
1342- had_event = False
1343- event_mask_leaves = []
1344- for event_mask_i in jtu .tree_leaves (event_mask ):
1345- event_mask_leaves .append (event_mask_i & jnp .invert (had_event ))
1346- had_event = event_mask_i | had_event
1347- event_mask = jtu .tree_unflatten (event_structure , event_mask_leaves )
1331+ had_event = jnp .any (jnp .stack (jtu .tree_leaves (event_mask ), axis = 0 ))
13481332 result = RESULTS .where (
13491333 had_event ,
13501334 RESULTS .event_occurred ,
13511335 result ,
13521336 )
1353- del had_event , event_structure , event_mask_leaves , event_values__mask
1337+ del had_event , event_structure , event_values__mask
13541338
13551339 # Initialise state
13561340 init_state = State (
0 commit comments