Skip to content

Commit ee07e61

Browse files
committed
fixing ufo function issue and observables in the output equations
1 parent 3cd707f commit ee07e61

1 file changed

Lines changed: 20 additions & 3 deletions

File tree

ext/ModelingToolkitSIExt.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,19 @@ using ModelingToolkitBase
1414

1515
# ------------------------------------------------------------------------------
1616

17+
# checking if it is a function of the form x(t), a bit dirty
18+
function isfunction(e::SymbolicUtils.BasicSymbolic)
19+
return length(Symbolics.arguments(e)) == 1 && "$(first(Symbolics.arguments(e)))" == "t"
20+
end
21+
1722
function StructuralIdentifiability.eval_at_nemo(e::Num, vals::Dict)
1823
e = Symbolics.value(e)
1924
return eval_at_nemo(e, vals)
2025
end
2126

2227
function StructuralIdentifiability.eval_at_nemo(e::SymbolicUtils.BasicSymbolic, vals::Dict)
2328
if Symbolics.iscall(e)
24-
# checking if it is a function of the form x(t), a bit dirty
25-
if length(Symbolics.arguments(e)) == 1 && "$(first(Symbolics.arguments(e)))" == "t"
29+
if isfunction(e)
2630
return vals[e]
2731
end
2832
# checking if this is a vector entry like x(t)[1]
@@ -213,12 +217,24 @@ function __mtk_to_si(
213217
de::ModelingToolkitBase.System,
214218
measured_quantities::Array{<:Tuple{String, <:SymbolicUtils.BasicSymbolic}},
215219
)
220+
221+
# checking if all the functions in the lhs are either states of outputs
222+
ufo_functions = filter(
223+
f -> !(ModelingToolkitBase.isoutput(f)) && isfunction(f),
224+
map(e -> e.lhs, ModelingToolkitBase.equations(de))
225+
)
226+
if !isempty(ufo_functions)
227+
throw(DomainError("The following functions on the lhs of the equations are neither states not outputs: $ufo_functions. Did you mean to compile the model?"))
228+
end
229+
216230
polytype = StructuralIdentifiability.Nemo.QQMPolyRingElem
217231
fractype = StructuralIdentifiability.Nemo.Generic.FracFieldElem{polytype}
218232
diff_eqs = filter(
219233
eq -> !(ModelingToolkitBase.isoutput(eq.lhs)),
220234
ModelingToolkitBase.equations(de),
221235
)
236+
output_eqs = [e[2] for e in measured_quantities]
237+
222238
# performing full structural simplification
223239
if length(observed(de)) > 0
224240
rules = Dict(s.lhs => s.rhs for s in observed(de))
@@ -230,6 +246,7 @@ function __mtk_to_si(
230246
rules = Dict(k => SymbolicUtils.substitute(v, rules) for (k, v) in rules)
231247
end
232248
diff_eqs = [SymbolicUtils.substitute(eq, rules) for eq in diff_eqs]
249+
output_eqs = [SymbolicUtils.substitute(eq, rules) for eq in output_eqs]
233250
end
234251

235252
y_functions = [each[2] for each in measured_quantities]
@@ -279,7 +296,7 @@ function __mtk_to_si(
279296
end
280297
end
281298
for i in 1:length(measured_quantities)
282-
out_eqn_dict[y_vars[i]] = eval_at_nemo(measured_quantities[i][2], symb2gens)
299+
out_eqn_dict[y_vars[i]] = eval_at_nemo(output_eqs[i], symb2gens)
283300
end
284301

285302
inputs_ = [symb2gens[each] for each in inputs]

0 commit comments

Comments
 (0)