@@ -196,41 +196,66 @@ end
196196 bottom_condition = get_condition (integrator, callback, integrator. tprev)
197197 @. bottom_sign = sign (bottom_condition)
198198
199+ prev_simultaneous_events = integrator. callback_cache. prev_simultaneous_events
199200 if integrator. event_last_time == callback_idx
200- nudged_idx = integrator. vector_event_last_time
201201 # If there was a previous event, nudge tprev on the right
202202 # side of the root (if necessary) to avoid repeat detection
203203
204204 if callback. interp_points == 0
205205 addsteps! (integrator)
206206 end
207207
208+ # Find the condition value closest to zero across all triggered events
209+ min_condition_val = zero (eltype (bottom_condition))
210+ min_abs_condition = typemax (eltype (bottom_condition))
211+ for idx in 1 : callback. len
212+ if prev_simultaneous_events[idx]
213+ cond_val = ArrayInterface. allowed_getindex (bottom_condition, idx)
214+ if abs (cond_val) < min_abs_condition
215+ min_abs_condition = abs (cond_val)
216+ min_condition_val = cond_val
217+ end
218+ end
219+ end
220+
208221 # Evaluate condition slightly in future
209- nudged_t = nudge_tprev (integrator, callback, ArrayInterface . allowed_getindex (bottom_condition, nudged_idx) )
222+ nudged_t = nudge_tprev (integrator, callback, min_condition_val )
210223 tmp_condition = get_condition (integrator, callback, nudged_t)
211224
212- ArrayInterface. allowed_setindex! (bottom_sign, sign (ArrayInterface. allowed_getindex (tmp_condition, nudged_idx)), nudged_idx)
225+ for idx in 1 : callback. len
226+ if prev_simultaneous_events[idx]
227+ ArrayInterface. allowed_setindex! (bottom_sign, sign (ArrayInterface. allowed_getindex (tmp_condition, idx)), idx)
228+ end
229+ end
213230 else
214- nudged_idx = - 1
215231 nudged_t = bottom_t
216232 end
217233
218234 # Check if an event occured
219235 event_occurred, event_idx, top_t, top_sign =
220236 check_event_occurence (integrator, callback, bottom_sign)
221237
238+ # Track simultaneous events
239+ (; simultaneous_events) = integrator. callback_cache
240+ if event_occurred
241+ prev_simultaneous_events .= simultaneous_events
242+ simultaneous_events .= false
243+ end
244+
222245 # Find callback time if occurence
223246 if ! event_occurred
224247 callback_t = integrator. t
225248 min_event_idx = 1
226249 residual = zero (eltype (bottom_condition))
227250 elseif isdiscrete (integrator. alg) || callback. rootfind == SciMLBase. NoRootFind
228251 callback_t = top_t
229- min_event_idx = 1
252+ min_event_idx = - 1
230253 for i in 1 : length (event_idx)
231254 if ArrayInterface. allowed_getindex (event_idx, i) == 1
232- min_event_idx = i
233- break
255+ if min_event_idx < 0
256+ min_event_idx = i
257+ end
258+ simultaneous_events[i] = true
234259 end
235260 end
236261 residual = zero (eltype (bottom_condition))
@@ -251,16 +276,20 @@ end
251276 if iszero (ArrayInterface. allowed_getindex (top_sign, idx))
252277 cbi_t = top_t
253278 else
254- if idx == nudged_idx
279+ if integrator . event_last_time == callback_idx && prev_simultaneous_events[idx]
255280 cbi_t = find_root (zero_func, (nudged_t, top_t), callback. rootfind)
256281 else
257282 cbi_t = find_root (zero_func, (bottom_t, top_t), callback. rootfind)
258283 end
259284 end
260285 if integrator. tdir * cbi_t < integrator. tdir * callback_t
286+ simultaneous_events .= false
287+ end
288+ if integrator. tdir * cbi_t <= integrator. tdir * callback_t
261289 min_event_idx = idx
262290 callback_t = cbi_t
263291 residual = zero_func (cbi_t)
292+ simultaneous_events[idx] = true
264293 end
265294 end
266295 end
270299 end
271300 end
272301
273- return callback_t, ArrayInterface. allowed_getindex (bottom_sign, min_event_idx),
302+ # We still pass around the min_event_idx for now because some stuff in OrdinaryDiffEqCore expects it to be an Int
303+ return callback_t, bottom_sign,
274304 event_occurred:: Bool , min_event_idx:: Int , residual
275305end
276306
446476
447477function apply_callback! (
448478 integrator,
449- callback:: Union{ ContinuousCallback, VectorContinuousCallback} ,
479+ callback:: ContinuousCallback ,
450480 cb_time, prev_sign, event_idx
451481 )
452482 if isadaptive (integrator)
@@ -478,15 +508,13 @@ function apply_callback!(
478508 if callback. affect! === nothing
479509 integrator. u_modified = false
480510 else
481- callback isa VectorContinuousCallback ?
482- callback. affect! (integrator, event_idx) : callback. affect! (integrator)
511+ callback. affect! (integrator)
483512 end
484513 elseif prev_sign > 0
485514 if callback. affect_neg! === nothing
486515 integrator. u_modified = false
487516 else
488- callback isa VectorContinuousCallback ?
489- callback. affect_neg! (integrator, event_idx) : callback. affect_neg! (integrator)
517+ callback. affect_neg! (integrator)
490518 end
491519 end
492520
@@ -498,10 +526,66 @@ function apply_callback!(
498526 @inbounds if callback. save_positions[2 ]
499527 savevalues! (integrator, true )
500528 if ! isdefined (integrator. opts, :save_discretes ) || integrator. opts. save_discretes
501- if callback isa VectorContinuousCallback
502- SciMLBase. save_discretes! (integrator, callback, event_idx)
503- else
504- SciMLBase. save_discretes! (integrator, callback)
529+ SciMLBase. save_discretes! (integrator, callback)
530+ end
531+ saved_in_cb = true
532+ end
533+ return true , saved_in_cb
534+ end
535+ return false , saved_in_cb
536+ end
537+
538+ function apply_callback! (
539+ integrator,
540+ callback:: VectorContinuousCallback ,
541+ cb_time, prev_sign, min_event_idx
542+ )
543+ if isadaptive (integrator)
544+ set_proposed_dt! (
545+ integrator,
546+ integrator. tdir * max (
547+ nextfloat (integrator. opts. dtmin),
548+ integrator. tdir * callback. dtrelax * integrator. dt
549+ )
550+ )
551+ end
552+
553+ change_t_via_interpolation! (
554+ integrator, cb_time, Val{:false }, callback. initializealg
555+ )
556+
557+ # handle saveat
558+ _, savedexactly = savevalues! (integrator)
559+ saved_in_cb = true
560+
561+ @inbounds if callback. save_positions[1 ]
562+ # if already saved then skip saving
563+ savedexactly || savevalues! (integrator, true )
564+ end
565+
566+ u_modified = false
567+ for (i, triggered) ∈ enumerate (integrator. callback_cache. simultaneous_events)
568+ if triggered
569+ if prev_sign[i] < 0 && callback. affect! != = nothing
570+ callback. affect! (integrator, i)
571+ u_modified = true
572+ elseif prev_sign[i] > 0 && callback. affect_neg! != = nothing
573+ callback. affect_neg! (integrator, i)
574+ u_modified = true
575+ end
576+ end
577+ end
578+ integrator. u_modified = u_modified
579+ if u_modified
580+ reeval_internals_due_to_modification! (
581+ integrator, callback_initializealg = callback. initializealg
582+ )
583+
584+ @inbounds if callback. save_positions[2 ]
585+ savevalues! (integrator, true )
586+ if ! isdefined (integrator. opts, :save_discretes ) || integrator. opts. save_discretes
587+ for i ∈ integrator. callback_cache. simultaneous_events
588+ SciMLBase. save_discretes! (integrator, callback, i)
505589 end
506590 end
507591 saved_in_cb = true
@@ -610,6 +694,8 @@ mutable struct CallbackCache{conditionType, signType}
610694 next_condition:: conditionType
611695 next_sign:: signType
612696 prev_sign:: signType
697+ simultaneous_events:: Vector{Bool}
698+ prev_simultaneous_events:: Vector{Bool}
613699end
614700
615701function CallbackCache (
@@ -620,7 +706,10 @@ function CallbackCache(
620706 next_condition = similar (u, conditionType, max_len)
621707 next_sign = similar (u, signType, max_len)
622708 prev_sign = similar (u, signType, max_len)
623- return CallbackCache (tmp_condition, next_condition, next_sign, prev_sign)
709+ simultaneous_events = zeros (Bool, max_len)
710+ prev_simultaneous_events = zeros (Bool, max_len)
711+ return CallbackCache (tmp_condition, next_condition, next_sign, prev_sign,
712+ simultaneous_events, prev_simultaneous_events)
624713end
625714
626715function CallbackCache (
@@ -631,5 +720,8 @@ function CallbackCache(
631720 next_condition = zeros (conditionType, max_len)
632721 next_sign = zeros (signType, max_len)
633722 prev_sign = zeros (signType, max_len)
634- return CallbackCache (tmp_condition, next_condition, next_sign, prev_sign)
723+ simultaneous_events = zeros (Bool, max_len)
724+ prev_simultaneous_events = zeros (Bool, max_len)
725+ return CallbackCache (tmp_condition, next_condition, next_sign, prev_sign,
726+ simultaneous_events, prev_simultaneous_events)
635727end
0 commit comments