Skip to content

Commit 7fc24f6

Browse files
committed
fixing MTK interface
1 parent 92f312a commit 7fc24f6

2 files changed

Lines changed: 105 additions & 21 deletions

File tree

ext/ModelingToolkitSIExt.jl

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,17 @@ function StructuralIdentifiability.eval_at_nemo(
8484
end
8585

8686
function get_measured_quantities(ode::ModelingToolkitBase.System)
87+
# filterings is to discard vectorial entities (with tearing and alike)
88+
scalar_observed = filter(e -> length(Symbolics.shape(e.rhs)) == 0, ModelingToolkitBase.observed(ode))
8789
outputs = filter(
8890
eq -> ModelingToolkitBase.isoutput(eq.lhs),
89-
vcat(ModelingToolkitBase.equations(ode), ModelingToolkitBase.observed(ode)),
91+
vcat(ModelingToolkitBase.equations(ode), scalar_observed),
9092
)
9193
if !isempty(outputs)
9294
return outputs
9395
elseif !isempty(ModelingToolkitBase.observed(ode))
9496
@warn "All `observed` variables from the MTK model are taken as outputs, make sure this is what you wanted"
95-
return ModelingToolkitBase.observed(ode)
97+
return scalar_observed
9698
else
9799
throw(
98100
error(
@@ -188,6 +190,45 @@ end
188190

189191
#------------------------------------------------------------------------------
190192

193+
function extract_shifts_and_derivatives(exs_list)
194+
result = Dict(
195+
:vars => [],
196+
:shifts => [],
197+
:derivatives => []
198+
)
199+
function _walk(ex)
200+
!Symbolics.iscall(ex) && return
201+
op = Symbolics.operation(ex)
202+
args = Symbolics.arguments(ex)
203+
204+
if op isa Shift
205+
push!(result[:shifts], first(args))
206+
return
207+
end
208+
209+
if op isa Differential
210+
push!(result[:derivatives], first(args))
211+
return
212+
end
213+
214+
if length(args) == 1
215+
push!(result[:vars], ex)
216+
return
217+
end
218+
219+
for arg in args
220+
_walk(arg)
221+
end
222+
return
223+
end
224+
for ex in exs_list
225+
_walk(ex)
226+
end
227+
return result
228+
end
229+
230+
#------------------------------------------------------------------------------
231+
191232
function scalarize(arr)
192233
result = []
193234
for a in arr
@@ -218,29 +259,22 @@ function __mtk_to_si(
218259
measured_quantities::Array{<:Tuple{String, <:SymbolicUtils.BasicSymbolic}},
219260
)
220261

221-
# checking if all the functions in the lhs are either states of outputs
222-
ufo_functions = filter(
223-
f -> !(ModelingToolkitBase.isoutput(f)) && isfunction(f),
224-
map(e -> e.lhs, ModelingToolkitBase.equations(de))
225-
)
226-
if !isempty(ufo_functions)
227-
throw(DomainError("The following functions on the lhs of the equations are neither states not outputs: $ufo_functions. Did you mean to compile the model?"))
228-
end
229-
230262
polytype = StructuralIdentifiability.Nemo.QQMPolyRingElem
231263
fractype = StructuralIdentifiability.Nemo.Generic.FracFieldElem{polytype}
232264
diff_eqs = filter(
233265
eq -> !(ModelingToolkitBase.isoutput(eq.lhs)),
234266
ModelingToolkitBase.equations(de),
235267
)
236-
output_eqs = [e[2] for e in measured_quantities]
268+
output_funcs = [e[2] for e in measured_quantities]
237269

238270
# performing full structural simplification
239-
if length(observed(de)) > 0 || length(bindings(de) > 0)
240-
rules = merge(
241-
Dict(s.lhs => s.rhs for s in observed(de)),
242-
Dict(k => v for (k, v) in bindings(de) if ModelingToolkitBase.isparameter(k))
243-
)
271+
if (length(observed(de)) > 0) || (length(bindings(de)) > 0)
272+
rules = Dict(s.lhs => s.rhs for s in observed(de))
273+
for (k, v) in bindings(de)
274+
if ModelingToolkitBase.isparameter(k)
275+
rules = merge(rules, Dict(scalarize(k) .=> scalarize(v)))
276+
end
277+
end
244278
while any(
245279
[
246280
length(intersect(get_variables(r), keys(rules))) > 0 for r in values(rules)
@@ -249,15 +283,29 @@ function __mtk_to_si(
249283
rules = Dict(k => SymbolicUtils.substitute(v, rules) for (k, v) in rules)
250284
end
251285
diff_eqs = [SymbolicUtils.substitute(eq, rules) for eq in diff_eqs]
252-
output_eqs = [SymbolicUtils.substitute(eq, rules) for eq in output_eqs]
286+
output_funcs = [SymbolicUtils.substitute(f, rules) for f in output_funcs]
287+
end
288+
289+
de_lhs = extract_shifts_and_derivatives([e.lhs for e in diff_eqs])
290+
de_rhs = extract_shifts_and_derivatives([e.rhs for e in diff_eqs])
291+
out_rhs = extract_shifts_and_derivatives(output_funcs)
292+
293+
(isempty(out_rhs[:shifts]) && isempty(out_rhs[:derivatives])) || throw(DomainError("Output expressions cannot contain neither shifts nor derivatives"))
294+
isempty(de_lhs[:shifts]) || throw(DomainError("Shifts are not allowed on the left-hand side"))
295+
296+
if !isempty(de_rhs[:shifts])
297+
(is_empty(de_rhs[:derivatives]) && is_empty(de_lhs[:derivatives])) || throw(DomainError("Derivatives and shifts cannot appear at the same time"))
298+
isempty(intersect(de_lhs[:vars], de_rhs[:vars])) || throw(DomainError("States in the right-hand side of the dynamics equations can appear only as shifts"))
299+
else
300+
isempty(de_lhs[:vars]) || throw(DomainError("States on the left-hand side must appear with derivations (and this is not the case for $(de_lhs[:vars])). Did you mean to mtkcompile the model?"))
253301
end
254302

255303
y_functions = [each[2] for each in measured_quantities]
256304
state_vars = filter(
257305
s -> !ModelingToolkitBase.isoutput(s),
258306
clean_calls(map(e -> e.lhs, diff_eqs)),
259307
)
260-
all_funcs = collect(Set(clean_calls(ModelingToolkitBase.unknowns(de))))
308+
all_funcs = collect(scalarize(ModelingToolkitBase.unknowns(de)))
261309
inputs = filter(s -> !ModelingToolkitBase.isoutput(s), setdiff(all_funcs, state_vars))
262310
params = reduce(vcat, SymbolicUtils.scalarize(ModelingToolkitBase.parameters(de)))
263311
t = ModelingToolkitBase.arguments(clean_calls([diff_eqs[1].lhs])[1])[1]
@@ -299,7 +347,7 @@ function __mtk_to_si(
299347
end
300348
end
301349
for i in 1:length(measured_quantities)
302-
out_eqn_dict[y_vars[i]] = eval_at_nemo(output_eqs[i], symb2gens)
350+
out_eqn_dict[y_vars[i]] = eval_at_nemo(output_funcs[i], symb2gens)
303351
end
304352

305353
inputs_ = [symb2gens[each] for each in inputs]

test/extensions/modelingtoolkit.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -838,10 +838,46 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt"
838838
@parameters k₁ k₂ μ₁max μ₂max
839839
eqs = [μ₁ ~ k₁ + μ₁max, D(X) ~ (μ₁ + μ₂max) * X, yX ~ X]
840840

841-
ode = System(eqs, t, name = :output_definition_case)
841+
@mtkcompile ode = System(eqs, t)
842842

843843
id_res = assess_identifiability(ode)
844844
@test 1 == count(v -> v == :globally, values(id_res))
845+
846+
# Examples from https://github.com/SciML/StructuralIdentifiability.jl/issues/496
847+
@variables x(t)[1:2] y(t)[1:2] [output = true]
848+
@parameters p[1:2] q[1:2]
849+
eqs = [
850+
D(x) ~ p .* x + y,
851+
y ~ q .* x,
852+
]
853+
854+
@mtkcompile sys = System(eqs, t; bindings = [y => x, q => 2 .* p])
855+
856+
res = mtk_to_si(sys, Equation[])
857+
858+
@test length(res[1].parameters) == 2
859+
@test length(res[1].x_vars) == 2
860+
@test length(res[1].y_vars) == 2
861+
@test length(res[1].u_vars) == 0
862+
863+
@variables x(t) y(t) z(t) [output = true]
864+
@parameters p q
865+
eqs = [
866+
D(x) ~ p * x + y
867+
y ~ q * x
868+
z ~ x + y
869+
]
870+
871+
@named sys_model = System(eqs, t)
872+
873+
@test_throws DomainError mtk_to_si(sys_model, Equation[])
874+
875+
@mtkcompile sys = System(eqs, t)
876+
res = mtk_to_si(sys, Equation[])
877+
@test length(res[1].parameters) == 2
878+
@test length(res[1].x_vars) == 1
879+
@test length(res[1].y_vars) == 1
880+
@test length(res[1].u_vars) == 0
845881
end
846882

847883
@testset "Identifiability of MTK models with known generic initial conditions" begin

0 commit comments

Comments
 (0)