Skip to content

Commit e4a26b3

Browse files
Merge pull request #4382 from SciML/as/initialize-save-discretes
feat: support `initialize_save_discretes` in symbolic callbacks
2 parents 930d703 + aff08c6 commit e4a26b3

3 files changed

Lines changed: 47 additions & 14 deletions

File tree

lib/ModelingToolkitBase/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ ReadOnlyDicts = "1.0.0"
151151
RecursiveArrayTools = "3.26"
152152
Reexport = "0.2, 1"
153153
RuntimeGeneratedFunctions = "0.5.12"
154-
SciMLBase = "2.144.0"
154+
SciMLBase = "2.149.0"
155155
SciMLPublic = "1.0.0"
156156
SciMLStructures = "1.7"
157157
Serialization = "1"

lib/ModelingToolkitBase/src/systems/callbacks.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,8 @@ const Affect = Union{AffectSystem, ImperativeAffect}
317317

318318
"""
319319
SymbolicContinuousCallback(eqs::Vector{Equation}, affect = nothing, iv = nothing;
320-
affect_neg = affect, initialize = nothing, finalize = nothing, rootfind = SciMLBase.LeftRootFind)
320+
affect_neg = affect, initialize = nothing, finalize = nothing,
321+
rootfind = SciMLBase.LeftRootFind, initialize_save_discretes = true)
321322
322323
A [`ContinuousCallback`](@ref SciMLBase.ContinuousCallback) specified symbolically. Takes a vector of equations `eq`
323324
as well as the positive-edge `affect` and negative-edge `affect_neg` that apply when *any* of `eq` are satisfied.
@@ -364,6 +365,9 @@ and combined with the remaining `Equation`s.
364365
- Symbolic affects have reinitialization built in. In this case the algorithm will default to SciMLBase.NoInit(), and should **not** be provided.
365366
- Functional and imperative affects will default to SciMLBase.CheckInit(), which will error if the system is not properly reinitialized after the callback. If your system is a DAE, pass in an algorithm like SciMLBase.BrownBasicFullInit() to properly re-initialize.
366367
368+
`initialize_save_discretes` is a flag indicating whether the discrete variables modified by this
369+
callback should be saved at the start of the integration (when the `initialize` runs).
370+
367371
Initial and final affects can also be specified identically to positive and negative edge affects. Initialization affects
368372
will run as soon as the solver starts, while finalization affects will be executed after termination.
369373
"""
@@ -376,6 +380,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
376380
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
377381
reinitializealg::SciMLBase.DAEInitializationAlgorithm
378382
zero_crossing_id::Symbol
383+
initialize_save_discretes::Bool
379384
end
380385

381386
function SymbolicContinuousCallback(
@@ -387,6 +392,7 @@ function SymbolicContinuousCallback(
387392
rootfind = SciMLBase.LeftRootFind,
388393
reinitializealg = nothing,
389394
zero_crossing_id = gensym(),
395+
initialize_save_discretes = true,
390396
kwargs...
391397
)
392398
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
@@ -408,7 +414,7 @@ function SymbolicContinuousCallback(
408414
SymbolicAffect(initialize; kwargs...), SymbolicAffect(
409415
finalize; kwargs...
410416
),
411-
rootfind, reinitializealg, zero_crossing_id
417+
rootfind, reinitializealg, zero_crossing_id, initialize_save_discretes
412418
)
413419
end # Default affect to nothing
414420

@@ -429,7 +435,8 @@ function complete(cb::SymbolicContinuousCallback; kwargs...)
429435
return SymbolicContinuousCallback(
430436
cb.conditions, make_affect(cb.affect; kwargs...),
431437
make_affect(cb.affect_neg; kwargs...), make_affect(cb.initialize; kwargs...),
432-
make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg, cb.zero_crossing_id
438+
make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg,
439+
cb.zero_crossing_id, cb.initialize_save_discretes
433440
)
434441
end
435442

@@ -540,12 +547,13 @@ struct SymbolicDiscreteCallback <: AbstractCallback
540547
initialize::Union{Affect, SymbolicAffect, Nothing}
541548
finalize::Union{Affect, SymbolicAffect, Nothing}
542549
reinitializealg::SciMLBase.DAEInitializationAlgorithm
550+
initialize_save_discretes::Bool
543551
end
544552

545553
function SymbolicDiscreteCallback(
546554
condition::Union{SymbolicT, Number, Vector{<:Number}}, affect = nothing;
547555
initialize = nothing, finalize = nothing,
548-
reinitializealg = nothing, kwargs...
556+
reinitializealg = nothing, initialize_save_discretes = true, kwargs...
549557
)
550558
# Manual error check (to prevent events like `[X < 5.0] => [X ~ Pre(X) + 10.0]` from being created).
551559
(condition isa Vector) && (eltype(condition) <: Num) &&
@@ -567,7 +575,8 @@ function SymbolicDiscreteCallback(
567575
return SymbolicDiscreteCallback(
568576
c, SymbolicAffect(affect; kwargs...),
569577
SymbolicAffect(initialize; kwargs...),
570-
SymbolicAffect(finalize; kwargs...), reinitializealg
578+
SymbolicAffect(finalize; kwargs...), reinitializealg,
579+
initialize_save_discretes
571580
)
572581
end # Default affect to nothing
573582

@@ -588,7 +597,8 @@ function complete(cb::SymbolicDiscreteCallback; kwargs...)
588597
return SymbolicDiscreteCallback(
589598
cb.conditions, make_affect(cb.affect; kwargs...),
590599
make_affect(cb.initialize; kwargs...),
591-
make_affect(cb.finalize; kwargs...), cb.reinitializealg
600+
make_affect(cb.finalize; kwargs...), cb.reinitializealg,
601+
cb.initialize_save_discretes
592602
)
593603
end
594604

@@ -809,15 +819,15 @@ function generate_continuous_callbacks(
809819
cbs = continuous_events(sys)
810820
isempty(cbs) && return nothing
811821
cb_classes = Dict{
812-
Tuple{SciMLBase.RootfindOpt, SciMLBase.DAEInitializationAlgorithm},
822+
Tuple{SciMLBase.RootfindOpt, SciMLBase.DAEInitializationAlgorithm, Bool},
813823
Vector{SymbolicContinuousCallback},
814824
}()
815825

816826
# Sort the callbacks by their rootfinding method
817827
for cb in cbs
818828
_cbs = get!(
819829
() -> SymbolicContinuousCallback[],
820-
cb_classes, (cb.rootfind, cb.reinitializealg)
830+
cb_classes, (cb.rootfind, cb.reinitializealg, cb.initialize_save_discretes)
821831
)
822832
push!(_cbs, cb)
823833
end
@@ -927,7 +937,7 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
927937
return VectorContinuousCallback(
928938
trigger, affect, affect_neg, length(eqs); initialize, finalize,
929939
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg,
930-
saved_clock_partitions
940+
saved_clock_partitions, initialize_save_discretes = cbs[1].initialize_save_discretes
931941
)
932942
end
933943

@@ -964,23 +974,26 @@ function generate_callback(cb, sys; kwargs...)
964974
if is_timed && conditions(cb) isa AbstractVector
965975
return PresetTimeCallback(
966976
trigger, affect; initialize,
967-
finalize, initializealg = cb.reinitializealg, saved_clock_partitions
977+
finalize, initializealg = cb.reinitializealg, saved_clock_partitions,
978+
initialize_save_discretes = cb.initialize_save_discretes
968979
)
969980
elseif is_timed
970981
return PeriodicCallback(
971982
affect, trigger; initialize, finalize, initializealg = cb.reinitializealg,
972-
saved_clock_partitions
983+
saved_clock_partitions, initialize_save_discretes = cb.initialize_save_discretes
973984
)
974985
else
975986
return DiscreteCallback(
976987
trigger, affect; initialize,
977-
finalize, initializealg = cb.reinitializealg, saved_clock_partitions
988+
finalize, initializealg = cb.reinitializealg, saved_clock_partitions,
989+
initialize_save_discretes = cb.initialize_save_discretes
978990
)
979991
end
980992
else
981993
return ContinuousCallback(
982994
trigger, affect, affect_neg; initialize, finalize,
983-
rootfind = cb.rootfind, initializealg = cb.reinitializealg, saved_clock_partitions
995+
rootfind = cb.rootfind, initializealg = cb.reinitializealg, saved_clock_partitions,
996+
initialize_save_discretes = cb.initialize_save_discretes
984997
)
985998
end
986999
end

lib/ModelingToolkitBase/test/symbolic_events.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1891,3 +1891,23 @@ if @isdefined(ModelingToolkit)
18911891
@test_nowarn mtkcompile(sys)
18921892
end
18931893
end
1894+
1895+
if !@isdefined(ModelingToolkit)
1896+
@testset "`initialize_save_discretes` support" begin
1897+
@variables x(t)
1898+
@discretes d1(t) d2(t)
1899+
cevt = SymbolicContinuousCallback(
1900+
[x ~ 0.5], [d1 ~ Pre(d1) + 0.1]; discrete_parameters = [d1], initialize_save_discretes = false
1901+
)
1902+
devt = SymbolicDiscreteCallback(
1903+
0.1, [d2 ~ Pre(d2) + 0.1]; discrete_parameters = [d2], initialize_save_discretes = false
1904+
)
1905+
@mtkcompile sys = System(
1906+
[D(x) ~ sin(d1 * d2 * t)], t; continuous_events = [cevt], discrete_events = [devt]
1907+
)
1908+
prob = ODEProblem(sys, [x => 0.0, d1 => 1.0, d2 => 1.0], (0.0, 10.0))
1909+
sol = solve(prob, Tsit5())
1910+
@test sol.discretes[1].t[1] > 0.0
1911+
@test sol.discretes[2].t[1] > 0.0
1912+
end
1913+
end

0 commit comments

Comments
 (0)