Skip to content

Commit c9bf530

Browse files
Merge pull request #30 from JuliaComputing/as/validate-inferred-discrete
fix: validate that `InferredDiscrete` ends up in a discrete partition
2 parents 526628f + b9883c5 commit c9bf530

1 file changed

Lines changed: 56 additions & 0 deletions

File tree

lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Equation(Int)
44
InitEquation(Int)
55
Clock(SciMLBase.AbstractClock)
6+
AssertDiscrete
67
end
78

89
struct ClockInference{S <: StateSelection.TransformationState}
@@ -147,6 +148,10 @@ function (iec::InferEquationClosure)(ieq::Int, eq::Equation, is_initialization_e
147148
InferredClock.InferredDiscrete(i) => begin
148149
relative_edge = get!(Set{ClockVertex.Type}, relative_hyperedges, i)
149150
union!(relative_edge, arg_hyperedge)
151+
# Ensure that this clock partition will be discrete. This is a separate
152+
# variant because I don't want to give `InferredDiscrete` too many meanings.
153+
push!(arg_hyperedge, ClockVertex.AssertDiscrete())
154+
add_edge!(inference_graph, arg_hyperedge)
150155
end
151156
end
152157
end
@@ -163,6 +168,7 @@ function (iec::InferEquationClosure)(ieq::Int, eq::Equation, is_initialization_e
163168
union!(hyperedge, buffer)
164169
delete!(relative_hyperedges, i)
165170
end
171+
push!(hyperedge, ClockVertex.AssertDiscrete())
166172
end
167173
end
168174
else
@@ -220,6 +226,9 @@ function infer_clocks!(ci::ClockInference)
220226
for partition in clock_partitions
221227
clockidxs = findall(Base.Fix2(Moshi.Data.isa_variant, ClockVertex.Clock), partition)
222228
if isempty(clockidxs)
229+
if any(isequal(ClockVertex.AssertDiscrete()), partition)
230+
throw(ExpectedDiscreteClockPartitionError(ts, partition, true))
231+
end
223232
push!(partition, ClockVertex.Clock(SciMLBase.ContinuousClock()))
224233
push!(clockidxs, length(partition))
225234
end
@@ -237,12 +246,16 @@ function infer_clocks!(ci::ClockInference)
237246
clock = Moshi.Match.@match partition[only(clockidxs)] begin
238247
ClockVertex.Clock(clk) => clk
239248
end
249+
if clock == SciMLBase.ContinuousClock() && any(isequal(ClockVertex.AssertDiscrete()), partition)
250+
throw(ExpectedDiscreteClockPartitionError(ts, partition, false))
251+
end
240252
for vert in partition
241253
Moshi.Match.@match vert begin
242254
ClockVertex.Variable(i) => (var_domain[i] = clock)
243255
ClockVertex.Equation(i) => (eq_domain[i] = clock)
244256
ClockVertex.InitEquation(i) => (init_eq_domain[i] = clock)
245257
ClockVertex.Clock(_) => nothing
258+
ClockVertex.AssertDiscrete() => nothing
246259
end
247260
end
248261
end
@@ -251,6 +264,49 @@ function infer_clocks!(ci::ClockInference)
251264
return ci
252265
end
253266

267+
struct ExpectedDiscreteClockPartitionError <: Exception
268+
state::TearingState
269+
partition::Vector{ClockVertex.Type}
270+
has_no_clock::Bool
271+
end
272+
273+
function Base.showerror(io::IO, err::ExpectedDiscreteClockPartitionError)
274+
if err.has_no_clock
275+
println(io, """
276+
Found a clock partition that must be discrete (due to the presence of an \
277+
`InferredDiscrete`) but does not have any associated clock (and would otherwise \
278+
then default to being on the continuous clock). This likely means that the \
279+
partition was not assigned a valid discrete clock and the model is incorrect.
280+
""")
281+
else
282+
println(io, """
283+
Found a clock partition that must be discrete (due to the presence of an \
284+
`InferredDiscrete`) but is associated with a continuous clock. This is likely \
285+
a modeling error.
286+
""")
287+
end
288+
289+
vars = filter(Base.Fix2(Moshi.Data.isa_variant, ClockVertex.Variable), err.partition)
290+
println(io, "Variables in the partition:")
291+
for var in vars
292+
println(io, " ", err.state.fullvars[var.:1])
293+
end
294+
println(io)
295+
296+
eqs = filter(Base.Fix2(Moshi.Data.isa_variant, ClockVertex.Equation), err.partition)
297+
println(io, "Equations in the partition:")
298+
for eq in eqs
299+
println(io, " ", equations(err.state)[eq.:1])
300+
end
301+
println(io)
302+
303+
ieqs = filter(Base.Fix2(Moshi.Data.isa_variant, ClockVertex.InitEquation), err.partition)
304+
println(io, "Initialization equations in the partition:")
305+
for ieq in ieqs
306+
println(io, " ", initialization_equations(err.state.sys)[ieq.:1])
307+
end
308+
end
309+
254310
function resize_or_push!(v, val, idx)
255311
n = length(v)
256312
if idx > n

0 commit comments

Comments
 (0)