Skip to content

Commit 3244cd1

Browse files
ChrisRackauckasMasonProtterclaude
committed
WIP: support multiple simultaneous events in a single VectorContinuousCallback
Port of SciML/DiffEqBase.jl#1229 to the DiffEqBase sublibrary in OrdinaryDiffEq. Changes: - Add `simultaneous_events` and `prev_simultaneous_events` fields to `CallbackCache` - Track which events fire at the same time in `find_callback_time` - Return full `bottom_sign` array instead of single sign for VectorContinuousCallback - Split `apply_callback!` into separate ContinuousCallback and VectorContinuousCallback methods, where the VectorContinuousCallback version iterates over all simultaneous events and calls affect!/affect_neg! for each - Use `prev_simultaneous_events` instead of single `nudged_idx`/`vector_event_last_time` for nudging and root-finding bracket selection Co-Authored-By: Mason Protter <mason.protter@icloud.com> Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 249cda3 commit 3244cd1

1 file changed

Lines changed: 112 additions & 20 deletions

File tree

lib/DiffEqBase/src/callbacks.jl

Lines changed: 112 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -196,41 +196,66 @@ end
196196
bottom_condition = get_condition(integrator, callback, integrator.tprev)
197197
@. bottom_sign = sign(bottom_condition)
198198

199+
prev_simultaneous_events = integrator.callback_cache.prev_simultaneous_events
199200
if integrator.event_last_time == callback_idx
200-
nudged_idx = integrator.vector_event_last_time
201201
# If there was a previous event, nudge tprev on the right
202202
# side of the root (if necessary) to avoid repeat detection
203203

204204
if callback.interp_points == 0
205205
addsteps!(integrator)
206206
end
207207

208+
# Find the condition value closest to zero across all triggered events
209+
min_condition_val = zero(eltype(bottom_condition))
210+
min_abs_condition = typemax(eltype(bottom_condition))
211+
for idx in 1:callback.len
212+
if prev_simultaneous_events[idx]
213+
cond_val = ArrayInterface.allowed_getindex(bottom_condition, idx)
214+
if abs(cond_val) < min_abs_condition
215+
min_abs_condition = abs(cond_val)
216+
min_condition_val = cond_val
217+
end
218+
end
219+
end
220+
208221
# Evaluate condition slightly in future
209-
nudged_t = nudge_tprev(integrator, callback, ArrayInterface.allowed_getindex(bottom_condition, nudged_idx))
222+
nudged_t = nudge_tprev(integrator, callback, min_condition_val)
210223
tmp_condition = get_condition(integrator, callback, nudged_t)
211224

212-
ArrayInterface.allowed_setindex!(bottom_sign, sign(ArrayInterface.allowed_getindex(tmp_condition, nudged_idx)), nudged_idx)
225+
for idx in 1:callback.len
226+
if prev_simultaneous_events[idx]
227+
ArrayInterface.allowed_setindex!(bottom_sign, sign(ArrayInterface.allowed_getindex(tmp_condition, idx)), idx)
228+
end
229+
end
213230
else
214-
nudged_idx = -1
215231
nudged_t = bottom_t
216232
end
217233

218234
# Check if an event occured
219235
event_occurred, event_idx, top_t, top_sign =
220236
check_event_occurence(integrator, callback, bottom_sign)
221237

238+
# Track simultaneous events
239+
(; simultaneous_events) = integrator.callback_cache
240+
if event_occurred
241+
prev_simultaneous_events .= simultaneous_events
242+
simultaneous_events .= false
243+
end
244+
222245
# Find callback time if occurence
223246
if !event_occurred
224247
callback_t = integrator.t
225248
min_event_idx = 1
226249
residual = zero(eltype(bottom_condition))
227250
elseif isdiscrete(integrator.alg) || callback.rootfind == SciMLBase.NoRootFind
228251
callback_t = top_t
229-
min_event_idx = 1
252+
min_event_idx = -1
230253
for i in 1:length(event_idx)
231254
if ArrayInterface.allowed_getindex(event_idx, i) == 1
232-
min_event_idx = i
233-
break
255+
if min_event_idx < 0
256+
min_event_idx = i
257+
end
258+
simultaneous_events[i] = true
234259
end
235260
end
236261
residual = zero(eltype(bottom_condition))
@@ -251,16 +276,20 @@ end
251276
if iszero(ArrayInterface.allowed_getindex(top_sign, idx))
252277
cbi_t = top_t
253278
else
254-
if idx == nudged_idx
279+
if integrator.event_last_time == callback_idx && prev_simultaneous_events[idx]
255280
cbi_t = find_root(zero_func, (nudged_t, top_t), callback.rootfind)
256281
else
257282
cbi_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind)
258283
end
259284
end
260285
if integrator.tdir * cbi_t < integrator.tdir * callback_t
286+
simultaneous_events .= false
287+
end
288+
if integrator.tdir * cbi_t <= integrator.tdir * callback_t
261289
min_event_idx = idx
262290
callback_t = cbi_t
263291
residual = zero_func(cbi_t)
292+
simultaneous_events[idx] = true
264293
end
265294
end
266295
end
@@ -270,7 +299,8 @@ end
270299
end
271300
end
272301

273-
return callback_t, ArrayInterface.allowed_getindex(bottom_sign, min_event_idx),
302+
# We still pass around the min_event_idx for now because some stuff in OrdinaryDiffEqCore expects it to be an Int
303+
return callback_t, bottom_sign,
274304
event_occurred::Bool, min_event_idx::Int, residual
275305
end
276306

@@ -446,7 +476,7 @@ end
446476

447477
function apply_callback!(
448478
integrator,
449-
callback::Union{ContinuousCallback, VectorContinuousCallback},
479+
callback::ContinuousCallback,
450480
cb_time, prev_sign, event_idx
451481
)
452482
if isadaptive(integrator)
@@ -478,15 +508,13 @@ function apply_callback!(
478508
if callback.affect! === nothing
479509
integrator.u_modified = false
480510
else
481-
callback isa VectorContinuousCallback ?
482-
callback.affect!(integrator, event_idx) : callback.affect!(integrator)
511+
callback.affect!(integrator)
483512
end
484513
elseif prev_sign > 0
485514
if callback.affect_neg! === nothing
486515
integrator.u_modified = false
487516
else
488-
callback isa VectorContinuousCallback ?
489-
callback.affect_neg!(integrator, event_idx) : callback.affect_neg!(integrator)
517+
callback.affect_neg!(integrator)
490518
end
491519
end
492520

@@ -498,10 +526,66 @@ function apply_callback!(
498526
@inbounds if callback.save_positions[2]
499527
savevalues!(integrator, true)
500528
if !isdefined(integrator.opts, :save_discretes) || integrator.opts.save_discretes
501-
if callback isa VectorContinuousCallback
502-
SciMLBase.save_discretes!(integrator, callback, event_idx)
503-
else
504-
SciMLBase.save_discretes!(integrator, callback)
529+
SciMLBase.save_discretes!(integrator, callback)
530+
end
531+
saved_in_cb = true
532+
end
533+
return true, saved_in_cb
534+
end
535+
return false, saved_in_cb
536+
end
537+
538+
function apply_callback!(
539+
integrator,
540+
callback::VectorContinuousCallback,
541+
cb_time, prev_sign, min_event_idx
542+
)
543+
if isadaptive(integrator)
544+
set_proposed_dt!(
545+
integrator,
546+
integrator.tdir * max(
547+
nextfloat(integrator.opts.dtmin),
548+
integrator.tdir * callback.dtrelax * integrator.dt
549+
)
550+
)
551+
end
552+
553+
change_t_via_interpolation!(
554+
integrator, cb_time, Val{:false}, callback.initializealg
555+
)
556+
557+
# handle saveat
558+
_, savedexactly = savevalues!(integrator)
559+
saved_in_cb = true
560+
561+
@inbounds if callback.save_positions[1]
562+
# if already saved then skip saving
563+
savedexactly || savevalues!(integrator, true)
564+
end
565+
566+
u_modified = false
567+
for (i, triggered) enumerate(integrator.callback_cache.simultaneous_events)
568+
if triggered
569+
if prev_sign[i] < 0 && callback.affect! !== nothing
570+
callback.affect!(integrator, i)
571+
u_modified = true
572+
elseif prev_sign[i] > 0 && callback.affect_neg! !== nothing
573+
callback.affect_neg!(integrator, i)
574+
u_modified = true
575+
end
576+
end
577+
end
578+
integrator.u_modified = u_modified
579+
if u_modified
580+
reeval_internals_due_to_modification!(
581+
integrator, callback_initializealg = callback.initializealg
582+
)
583+
584+
@inbounds if callback.save_positions[2]
585+
savevalues!(integrator, true)
586+
if !isdefined(integrator.opts, :save_discretes) || integrator.opts.save_discretes
587+
for i integrator.callback_cache.simultaneous_events
588+
SciMLBase.save_discretes!(integrator, callback, i)
505589
end
506590
end
507591
saved_in_cb = true
@@ -610,6 +694,8 @@ mutable struct CallbackCache{conditionType, signType}
610694
next_condition::conditionType
611695
next_sign::signType
612696
prev_sign::signType
697+
simultaneous_events::Vector{Bool}
698+
prev_simultaneous_events::Vector{Bool}
613699
end
614700

615701
function CallbackCache(
@@ -620,7 +706,10 @@ function CallbackCache(
620706
next_condition = similar(u, conditionType, max_len)
621707
next_sign = similar(u, signType, max_len)
622708
prev_sign = similar(u, signType, max_len)
623-
return CallbackCache(tmp_condition, next_condition, next_sign, prev_sign)
709+
simultaneous_events = zeros(Bool, max_len)
710+
prev_simultaneous_events = zeros(Bool, max_len)
711+
return CallbackCache(tmp_condition, next_condition, next_sign, prev_sign,
712+
simultaneous_events, prev_simultaneous_events)
624713
end
625714

626715
function CallbackCache(
@@ -631,5 +720,8 @@ function CallbackCache(
631720
next_condition = zeros(conditionType, max_len)
632721
next_sign = zeros(signType, max_len)
633722
prev_sign = zeros(signType, max_len)
634-
return CallbackCache(tmp_condition, next_condition, next_sign, prev_sign)
723+
simultaneous_events = zeros(Bool, max_len)
724+
prev_simultaneous_events = zeros(Bool, max_len)
725+
return CallbackCache(tmp_condition, next_condition, next_sign, prev_sign,
726+
simultaneous_events, prev_simultaneous_events)
635727
end

0 commit comments

Comments
 (0)