Skip to content

Commit 68ce746

Browse files
Merge pull request #4495 from SciML/as/new-callback-syntax
[WIP] feat: support new callback syntax in OrdinaryDiffEq@7
2 parents 4df69f0 + 6ffeb61 commit 68ce746

26 files changed

Lines changed: 127 additions & 56 deletions

Project.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,11 @@ OrdinaryDiffEq = "6.82.0, 7"
9797
OrdinaryDiffEqBDF = "1, 2"
9898
OrdinaryDiffEqCore = "1.34.0, 2, 3, 4"
9999
OrdinaryDiffEqDefault = "1.2, 2"
100+
OrdinaryDiffEqFIRK = "1, 2"
101+
OrdinaryDiffEqFunctionMap = "1, 2"
100102
OrdinaryDiffEqNonlinearSolve = "1.5.0, 2"
101103
OrdinaryDiffEqRosenbrock = "1, 2"
104+
OrdinaryDiffEqSDIRK = "1, 2"
102105
PreallocationTools = "0.4.27, 1"
103106
PrecompileTools = "1.2.1"
104107
REPL = "1"
@@ -147,9 +150,14 @@ OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
147150
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
148151
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
149152
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
153+
OrdinaryDiffEqBDF = "6ad6398a-0878-4a85-9266-38940aa047c8"
150154
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
151155
OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
156+
OrdinaryDiffEqFIRK = "5960d6e9-dd7a-4743-88e7-cf307b64f125"
157+
OrdinaryDiffEqFunctionMap = "d3585ca7-f5d3-4ba6-8057-292ed1abd90f"
152158
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
159+
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
160+
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
153161
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
154162
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
155163
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -169,4 +177,4 @@ TestEnv = "1e6cf692-eddd-4d53-88a5-2d735e33781b"
169177
URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
170178

171179
[targets]
172-
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "Pkg", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve", "Latexify", "Distributed", "DiffEqNoiseProcess", "DynamicQuantities", "DiffEqCallbacks", "URIs", "JumpProcesses", "RecursiveArrayTools", "SciMLStructures", "SpecialFunctions", "SciCompDSL"]
180+
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "Pkg", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve", "Latexify", "Distributed", "DiffEqNoiseProcess", "DynamicQuantities", "DiffEqCallbacks", "URIs", "JumpProcesses", "RecursiveArrayTools", "SciMLStructures", "SpecialFunctions", "SciCompDSL", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqSDIRK", "OrdinaryDiffEqBDF", "OrdinaryDiffEqFunctionMap", "OrdinaryDiffEqFIRK"]

lib/ModelingToolkitBase/Project.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ ConstructionBase = "1"
106106
DataInterpolations = "8.7"
107107
DataStructures = "0.18, 0.19"
108108
DelayDiffEq = "5.61, 6"
109-
DiffEqBase = "6.210.0, 7"
109+
DiffEqBase = "6.210.0, 7.2"
110110
DiffEqCallbacks = "2.16, 3, 4"
111111
DiffEqNoiseProcess = "5"
112112
DiffRules = "0.1, 1.0"
@@ -141,8 +141,12 @@ NonlinearSolve = "4.3"
141141
OffsetArrays = "1"
142142
OrderedCollections = "1"
143143
OrdinaryDiffEq = "6.82.0, 7"
144+
OrdinaryDiffEqBDF = "1, 2"
144145
OrdinaryDiffEqDefault = "1.2, 2"
146+
OrdinaryDiffEqFunctionMap = "1, 2"
145147
OrdinaryDiffEqNonlinearSolve = "1.5.0, 2"
148+
OrdinaryDiffEqRosenbrock = "1, 2"
149+
OrdinaryDiffEqSDIRK = "1, 2"
146150
PreallocationTools = "0.4.27, 1"
147151
PrecompileTools = "1.2.1"
148152
Pyomo = "0.1.0"
@@ -194,8 +198,12 @@ OptimizationIpopt = "43fad042-7963-4b32-ab19-e2a4f9a67124"
194198
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
195199
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
196200
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
201+
OrdinaryDiffEqBDF = "6ad6398a-0878-4a85-9266-38940aa047c8"
197202
OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
203+
OrdinaryDiffEqFunctionMap = "d3585ca7-f5d3-4ba6-8057-292ed1abd90f"
198204
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
205+
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
206+
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
199207
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
200208
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
201209
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -209,4 +217,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
209217
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
210218

211219
[targets]
212-
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationIpopt", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "Pkg", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve", "Latexify", "Distributed", "DiffEqNoiseProcess", "DynamicQuantities"]
220+
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationIpopt", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "Pkg", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve", "Latexify", "Distributed", "DiffEqNoiseProcess", "DynamicQuantities", "OrdinaryDiffEqSDIRK", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqBDF", "OrdinaryDiffEqFunctionMap"]

lib/ModelingToolkitBase/src/systems/callbacks.jl

Lines changed: 79 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -882,43 +882,63 @@ function (ia::ImplicitAffect)(integ)
882882
return ia.reset_jumps && reset_aggregated_jumps!(integ)
883883
end
884884

885-
"""
886-
VectorAffect{E2A, AFFS}
887-
888-
Callable struct for the positive-edge arm of a `VectorContinuousCallback`. Routes an
889-
integrator call to the appropriate per-equation affect based on the equation index `idx`.
890-
Created inside [`generate_callback`](@ref) for vectors of `SymbolicContinuousCallback`s.
891-
892-
# Fields
893-
- `eq2affect`: maps condition equation index → affect index
894-
- `affects`: vector of compiled affect callables, one per callback in the group
895-
"""
896-
struct VectorAffect{E2A, AFFS}
897-
eq2affect::E2A
898-
affects::AFFS
899-
end
900-
901-
(va::VectorAffect)(integ, idx) = va.affects[va.eq2affect[idx]](integ)
902-
903-
"""
904-
VectorAffectNeg{E2A, AFFS}
885+
@static if pkgversion(SciMLBase) < v"3"
886+
"""
887+
VectorAffect{E2A, AFFS}
888+
889+
Callable struct for a `VectorContinuousCallback`. Routes an
890+
integrator call to the appropriate per-equation affect based on the equation index `idx`.
891+
Created inside [`generate_callback`](@ref) for vectors of `SymbolicContinuousCallback`s.
892+
Skips `nothing` affects.
893+
894+
# Fields
895+
- `eq2affect`: maps condition equation index → affect index
896+
- `affects`: vector of compiled affect callables (entries may be `nothing`)
897+
"""
898+
struct VectorAffect{E2A, AFFS}
899+
eq2affect::E2A
900+
affects::AFFS
901+
end
905902

906-
Callable struct for the negative-edge arm of a `VectorContinuousCallback`. Like
907-
[`VectorAffect`](@ref) but skips `nothing` entries (callbacks with no negative-edge affect).
903+
function (va::VectorAffect)(integ, idx)
904+
f = va.affects[va.eq2affect[idx]]
905+
f === nothing && return
906+
return f(integ)
907+
end
908908

909-
# Fields
910-
- `eq2affect`: maps condition equation index → affect index
911-
- `affect_negs`: vector of compiled negative-edge affect callables (entries may be `nothing`)
912-
"""
913-
struct VectorAffectNeg{E2A, AFFS}
914-
eq2affect::E2A
915-
affect_negs::AFFS
916-
end
909+
else
910+
"""
911+
$TYPEDEF
912+
913+
Callable struct for a `VectorContinuousCallback`. Routes an
914+
integrator call to the appropriate per-equation affect based on the equation index `idx`.
915+
Created inside [`generate_callback`](@ref) for vectors of `SymbolicContinuousCallback`s.
916+
Skips `nothing` affects.
917+
918+
# Fields
919+
- `eq2affect`: maps condition equation index → affect index
920+
- `affects`: vector of compiled positive-edge affect callables (entries may be `nothing`)
921+
- `affect_negs`: vector of compiled negative-edge affect callables (entries may be `nothing`)
922+
"""
923+
struct VectorAffect{E2A, AFFS, NAFFS}
924+
eq2affect::E2A
925+
affects::AFFS
926+
affect_negs::NAFFS
927+
end
917928

918-
function (va::VectorAffectNeg)(integ, idx)
919-
f = va.affect_negs[va.eq2affect[idx]]
920-
f === nothing && return
921-
return f(integ)
929+
function (va::VectorAffect)(integ, evts)
930+
for (i, evt) in enumerate(evts)
931+
if evt == 1
932+
f = va.affects[va.eq2affect[i]]
933+
f === nothing && continue
934+
f(integ)
935+
elseif evt == -1
936+
f = va.affect_negs[va.eq2affect[i]]
937+
f === nothing && continue
938+
f(integ)
939+
end
940+
end
941+
end
922942
end
923943

924944
"""
@@ -1073,7 +1093,7 @@ one equation) from a homogeneous group of `SymbolicContinuousCallback`s that sha
10731093
rootfinding class. Delegates to the single-callback overload when `sum(num_eqs) == 1`.
10741094
10751095
Affect routing (from condition equation index to per-callback affect) is encoded in
1076-
[`VectorAffect`](@ref) and [`VectorAffectNeg`](@ref) callable structs.
1096+
the [`VectorAffect`](@ref) callable structs.
10771097
Initialize/finalize are wrapped in [`VectorOptionalAffect`](@ref) via
10781098
[`wrap_vector_optional_affect`](@ref).
10791099
"""
@@ -1101,17 +1121,30 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
11011121
eq2affect = reduce(vcat, [fill(i, num_eqs[i]) for i in eachindex(compiled.affects)])
11021122
eqs = reduce(vcat, eqs)
11031123

1104-
affect = VectorAffect(eq2affect, compiled.affects)
1105-
affect_neg = VectorAffectNeg(eq2affect, compiled.affect_negs)
1106-
initialize = wrap_vector_optional_affect(compiled.inits, SciMLBase.INITIALIZE_DEFAULT)
1107-
finalize = wrap_vector_optional_affect(compiled.finals, SciMLBase.FINALIZE_DEFAULT)
1108-
1109-
return VectorContinuousCallback(
1110-
trigger, affect, affect_neg, length(eqs); initialize, finalize,
1111-
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg,
1112-
saved_clock_partitions = compiled.saved_clock_partitions,
1113-
initialize_save_discretes = cbs[1].initialize_save_discretes
1114-
)
1124+
@static if pkgversion(SciMLBase) < v"3"
1125+
affect = VectorAffect(eq2affect, compiled.affects)
1126+
affect_neg = VectorAffect(eq2affect, compiled.affect_negs)
1127+
initialize = wrap_vector_optional_affect(compiled.inits, SciMLBase.INITIALIZE_DEFAULT)
1128+
finalize = wrap_vector_optional_affect(compiled.finals, SciMLBase.FINALIZE_DEFAULT)
1129+
1130+
return VectorContinuousCallback(
1131+
trigger, affect, affect_neg, length(eqs); initialize, finalize,
1132+
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg,
1133+
saved_clock_partitions = compiled.saved_clock_partitions,
1134+
initialize_save_discretes = cbs[1].initialize_save_discretes
1135+
)
1136+
else
1137+
affect = VectorAffect(eq2affect, compiled.affects, compiled.affect_negs)
1138+
initialize = wrap_vector_optional_affect(compiled.inits, SciMLBase.INITIALIZE_DEFAULT)
1139+
finalize = wrap_vector_optional_affect(compiled.finals, SciMLBase.FINALIZE_DEFAULT)
1140+
1141+
return VectorContinuousCallback(
1142+
trigger, affect, length(eqs); initialize, finalize,
1143+
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg,
1144+
saved_clock_partitions = compiled.saved_clock_partitions,
1145+
initialize_save_discretes = cbs[1].initialize_save_discretes
1146+
)
1147+
end
11151148
end
11161149

11171150
"""

lib/ModelingToolkitBase/test/analysis_points.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using ModelingToolkitStandardLibrary.Mechanical.Rotational
33
using ModelingToolkitStandardLibrary.Blocks
44
using SymbolicIndexingInterface
55
using OrdinaryDiffEq, LinearAlgebra
6+
using OrdinaryDiffEqRosenbrock
7+
using SciMLBase
68
using Test
79
using ModelingToolkitBase: t_nounits as t, D_nounits as D, AnalysisPoint, AbstractSystem
810
import ModelingToolkitBase as MTK

lib/ModelingToolkitBase/test/code_generation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ModelingToolkitBase, OrdinaryDiffEq, SymbolicIndexingInterface
2+
using SciMLBase
23
using SymbolicUtils: unwrap
34
using ModelingToolkitBase: t_nounits as t, D_nounits as D
45
using Test

lib/ModelingToolkitBase/test/components.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Test
22
using ModelingToolkitBase, OrdinaryDiffEq
3+
using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF
34
using ModelingToolkitBase: get_component_type, complete
45
using ModelingToolkitBase: t_nounits as t, D_nounits as D, value
56
using ModelingToolkitStandardLibrary.Electrical

lib/ModelingToolkitBase/test/constants.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
MT = ModelingToolkitBase
44

55
@constants a = 1
6-
@test isconstant(a)
6+
@test MT.isconstant(a)
77
@test !istunable(a)
88

99
@independent_variables t

lib/ModelingToolkitBase/test/discrete_system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ prob_map = DiscreteProblem(
6868
@test prob_map.f.sys === syss
6969

7070
# Solution
71-
using OrdinaryDiffEq
71+
using OrdinaryDiffEqFunctionMap
7272
sol_map = solve(prob_map, FunctionMap());
7373
@test sol_map[S] isa Vector
7474
@test sol_map[S(k - 1)] isa Vector

lib/ModelingToolkitBase/test/index_cache.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using ModelingToolkitBase, SymbolicIndexingInterface, SciMLStructures
22
using ModelingToolkitBase: t_nounits as t, D_nounits as D, SymbolicDiscreteCallback
33
using OrdinaryDiffEq
4+
using OrdinaryDiffEqRosenbrock
5+
using SciMLBase
46
import SymbolicUtils as SU
57
using PreallocationTools: DiffCache
68
using ForwardDiff

lib/ModelingToolkitBase/test/initial_values.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ModelingToolkitBase
22
using ModelingToolkitBase: t_nounits as t, D_nounits as D, get_u0
3+
using SciMLBase
34
using OrdinaryDiffEq
45
using DataInterpolations
56
using StaticArrays

0 commit comments

Comments
 (0)