@@ -364,7 +364,7 @@ def body_fun_aux(state):
364364 state .solver_state ,
365365 state .made_jump ,
366366 )
367- num_dde_explicit_step = num_dde_implicit_step = 0
367+ implicit_step = False
368368 else :
369369 min_delay = []
370370 flat_delays = jtu .tree_leaves (delays .delays )
@@ -439,64 +439,62 @@ def get_struct_dense_info(init_state):
439439 )
440440 assert jnp .result_type (keep_step ) is jnp .dtype (bool )
441441 # Finding all of the potential discontinuity roots
442- # if delays is not None:
443- # _part_maybe_find_discontinuity = ft.partial(
444- # maybe_find_discontinuity,
445- # tprev ,
446- # tnext ,
447- # dense_info ,
448- # state ,
449- # delays ,
450- # solver ,
451- # args ,
452- # )
453-
454- # tsearch = jnp.linspace(tprev, tnext, delays.sub_intervals)
455- # batch_tprev, batch_tnext = tsearch[:-1], tsearch[1:]
456- # vmap_maybe_find_discontinuity_wrapper = jax.vmap(
457- # _part_maybe_find_discontinuity, (None, 0, 0)
458- # )
459- # if delays.recurrent_checking:
460- # (
461- # tnext_candidate,
462- # batch_discont_update ,
463- # ) = vmap_maybe_find_discontinuity_wrapper(
464- # False, batch_tprev, batch_tnext
465- # )
466- # else:
467- # (
468- # tnext_candidate,
469- # batch_discont_update ,
470- # ) = vmap_maybe_find_discontinuity_wrapper(
471- # keep_step, batch_tprev, batch_tnext
472- # )
473-
474- # proxy_tnext = jnp.where(batch_discont_update, tnext_candidate, jnp.inf)
475- # proxy_tnext = jnp.min(proxy_tnext )
476-
477- # tnext, discont_update = jax.lax.cond(
478- # jnp.isinf(proxy_tnext),
479- # lambda: (tnext, False ),
480- # lambda: (proxy_tnext, True ),
481- # )
482-
483- # # Count the number of steps in DDEs, just for statistical purposes
484- # num_dde_implicit_step = state.num_dde_implicit_step + (
485- # keep_step & implicit_step
486- # )
487- # num_dde_explicit_step = state.num_dde_explicit_step + (
488- # keep_step & jnp.invert(implicit_step )
489- # )
490-
491- # assert jnp.result_type(discont_update) is jnp.dtype(bool )
492-
493- # assert jnp.result_type(keep_step) is jnp.dtype(bool)
442+ discont_update = False
443+ if delays is not None :
444+ # _part_maybe_find_discontinuity = ft.partial(
445+ # maybe_find_discontinuity ,
446+ # tprev ,
447+ # tnext ,
448+ # dense_info ,
449+ # state ,
450+ # delays ,
451+ # solver ,
452+ # args,
453+ # )
454+
455+ # tsearch = jnp.linspace(tprev, tnext, delays.sub_intervals)
456+ # batch_tprev, batch_tnext = tsearch[:-1], tsearch[1:]
457+ # vmap_maybe_find_discontinuity_wrapper = jax.vmap(
458+ # _part_maybe_find_discontinuity, (None, 0, 0 )
459+ # )
460+ # if delays.recurrent_checking:
461+ # (
462+ # tnext_candidate ,
463+ # batch_discont_update,
464+ # ) = vmap_maybe_find_discontinuity_wrapper(
465+ # False, batch_tprev, batch_tnext
466+ # )
467+ # else:
468+ # (
469+ # tnext_candidate ,
470+ # batch_discont_update,
471+ # ) = vmap_maybe_find_discontinuity_wrapper(
472+ # keep_step, batch_tprev, batch_tnext
473+ # )
474+
475+ # prox_tnext = jnp.where(batch_discont_update, tnext_candidate, jnp.inf )
476+ # prox_tnext = jnp.min(prox_tnext)
477+
478+ # tnext, discont_update = jax.lax.cond(
479+ # jnp.isinf(prox_tnext ),
480+ # lambda: (tnext, False ),
481+ # lambda: (prox_tnext, True),
482+ # )
483+ # assert jnp.result_type(discont_update) is jnp.dtype(bool)
484+
485+ # Count the number of steps in DDEs, just for statistical purposes
486+ num_dde_implicit_step = state . num_dde_implicit_step + (
487+ keep_step & implicit_step
488+ )
489+ num_dde_explicit_step = state . num_dde_explicit_step + (
490+ keep_step & jnp . invert ( implicit_step )
491+ )
492+
493+ assert jnp .result_type (keep_step ) is jnp .dtype (bool )
494494
495495 #
496496 # Do some book-keeping.
497497 #
498- discont_update = False
499- num_dde_explicit_step = num_dde_implicit_step = 0
500498 tprev = jnp .minimum (tprev , t1 )
501499 tnext = _clip_to_end (tprev , tnext , t1 , keep_step )
502500
@@ -736,8 +734,8 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
736734 event_dense_info = event_dense_info ,
737735 event_values = event_values ,
738736 event_mask = event_mask ,
739- num_dde_explicit_step = num_dde_explicit_step ,
740- num_dde_implicit_step = num_dde_implicit_step ,
737+ num_dde_explicit_step = num_dde_explicit_step , # type: ignore
738+ num_dde_implicit_step = num_dde_implicit_step , # type: ignore
741739 discontinuities = discontinuities , # type: ignore
742740 discontinuities_save_index = discontinuities_save_index ,
743741 )
0 commit comments