Skip to content

Commit 3edc9fd

Browse files
authored
Merge pull request #500 from SciML/mtk_interface_fix
Fixing observables and bindings (addressing #496)
2 parents 3cd707f + 7fc24f6 commit 3edc9fd

2 files changed

Lines changed: 112 additions & 8 deletions

File tree

ext/ModelingToolkitSIExt.jl

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,19 @@ using ModelingToolkitBase
1414

1515
# ------------------------------------------------------------------------------
1616

17+
# checking if it is a function of the form x(t), a bit dirty
18+
function isfunction(e::SymbolicUtils.BasicSymbolic)
19+
return length(Symbolics.arguments(e)) == 1 && "$(first(Symbolics.arguments(e)))" == "t"
20+
end
21+
1722
function StructuralIdentifiability.eval_at_nemo(e::Num, vals::Dict)
1823
e = Symbolics.value(e)
1924
return eval_at_nemo(e, vals)
2025
end
2126

2227
function StructuralIdentifiability.eval_at_nemo(e::SymbolicUtils.BasicSymbolic, vals::Dict)
2328
if Symbolics.iscall(e)
24-
# checking if it is a function of the form x(t), a bit dirty
25-
if length(Symbolics.arguments(e)) == 1 && "$(first(Symbolics.arguments(e)))" == "t"
29+
if isfunction(e)
2630
return vals[e]
2731
end
2832
# checking if this is a vector entry like x(t)[1]
@@ -80,15 +84,17 @@ function StructuralIdentifiability.eval_at_nemo(
8084
end
8185

8286
function get_measured_quantities(ode::ModelingToolkitBase.System)
87+
# filterings is to discard vectorial entities (with tearing and alike)
88+
scalar_observed = filter(e -> length(Symbolics.shape(e.rhs)) == 0, ModelingToolkitBase.observed(ode))
8389
outputs = filter(
8490
eq -> ModelingToolkitBase.isoutput(eq.lhs),
85-
vcat(ModelingToolkitBase.equations(ode), ModelingToolkitBase.observed(ode)),
91+
vcat(ModelingToolkitBase.equations(ode), scalar_observed),
8692
)
8793
if !isempty(outputs)
8894
return outputs
8995
elseif !isempty(ModelingToolkitBase.observed(ode))
9096
@warn "All `observed` variables from the MTK model are taken as outputs, make sure this is what you wanted"
91-
return ModelingToolkitBase.observed(ode)
97+
return scalar_observed
9298
else
9399
throw(
94100
error(
@@ -184,6 +190,45 @@ end
184190

185191
#------------------------------------------------------------------------------
186192

193+
function extract_shifts_and_derivatives(exs_list)
194+
result = Dict(
195+
:vars => [],
196+
:shifts => [],
197+
:derivatives => []
198+
)
199+
function _walk(ex)
200+
!Symbolics.iscall(ex) && return
201+
op = Symbolics.operation(ex)
202+
args = Symbolics.arguments(ex)
203+
204+
if op isa Shift
205+
push!(result[:shifts], first(args))
206+
return
207+
end
208+
209+
if op isa Differential
210+
push!(result[:derivatives], first(args))
211+
return
212+
end
213+
214+
if length(args) == 1
215+
push!(result[:vars], ex)
216+
return
217+
end
218+
219+
for arg in args
220+
_walk(arg)
221+
end
222+
return
223+
end
224+
for ex in exs_list
225+
_walk(ex)
226+
end
227+
return result
228+
end
229+
230+
#------------------------------------------------------------------------------
231+
187232
function scalarize(arr)
188233
result = []
189234
for a in arr
@@ -213,15 +258,23 @@ function __mtk_to_si(
213258
de::ModelingToolkitBase.System,
214259
measured_quantities::Array{<:Tuple{String, <:SymbolicUtils.BasicSymbolic}},
215260
)
261+
216262
polytype = StructuralIdentifiability.Nemo.QQMPolyRingElem
217263
fractype = StructuralIdentifiability.Nemo.Generic.FracFieldElem{polytype}
218264
diff_eqs = filter(
219265
eq -> !(ModelingToolkitBase.isoutput(eq.lhs)),
220266
ModelingToolkitBase.equations(de),
221267
)
268+
output_funcs = [e[2] for e in measured_quantities]
269+
222270
# performing full structural simplification
223-
if length(observed(de)) > 0
271+
if (length(observed(de)) > 0) || (length(bindings(de)) > 0)
224272
rules = Dict(s.lhs => s.rhs for s in observed(de))
273+
for (k, v) in bindings(de)
274+
if ModelingToolkitBase.isparameter(k)
275+
rules = merge(rules, Dict(scalarize(k) .=> scalarize(v)))
276+
end
277+
end
225278
while any(
226279
[
227280
length(intersect(get_variables(r), keys(rules))) > 0 for r in values(rules)
@@ -230,14 +283,29 @@ function __mtk_to_si(
230283
rules = Dict(k => SymbolicUtils.substitute(v, rules) for (k, v) in rules)
231284
end
232285
diff_eqs = [SymbolicUtils.substitute(eq, rules) for eq in diff_eqs]
286+
output_funcs = [SymbolicUtils.substitute(f, rules) for f in output_funcs]
287+
end
288+
289+
de_lhs = extract_shifts_and_derivatives([e.lhs for e in diff_eqs])
290+
de_rhs = extract_shifts_and_derivatives([e.rhs for e in diff_eqs])
291+
out_rhs = extract_shifts_and_derivatives(output_funcs)
292+
293+
(isempty(out_rhs[:shifts]) && isempty(out_rhs[:derivatives])) || throw(DomainError("Output expressions cannot contain neither shifts nor derivatives"))
294+
isempty(de_lhs[:shifts]) || throw(DomainError("Shifts are not allowed on the left-hand side"))
295+
296+
if !isempty(de_rhs[:shifts])
297+
(is_empty(de_rhs[:derivatives]) && is_empty(de_lhs[:derivatives])) || throw(DomainError("Derivatives and shifts cannot appear at the same time"))
298+
isempty(intersect(de_lhs[:vars], de_rhs[:vars])) || throw(DomainError("States in the right-hand side of the dynamics equations can appear only as shifts"))
299+
else
300+
isempty(de_lhs[:vars]) || throw(DomainError("States on the left-hand side must appear with derivations (and this is not the case for $(de_lhs[:vars])). Did you mean to mtkcompile the model?"))
233301
end
234302

235303
y_functions = [each[2] for each in measured_quantities]
236304
state_vars = filter(
237305
s -> !ModelingToolkitBase.isoutput(s),
238306
clean_calls(map(e -> e.lhs, diff_eqs)),
239307
)
240-
all_funcs = collect(Set(clean_calls(ModelingToolkitBase.unknowns(de))))
308+
all_funcs = collect(scalarize(ModelingToolkitBase.unknowns(de)))
241309
inputs = filter(s -> !ModelingToolkitBase.isoutput(s), setdiff(all_funcs, state_vars))
242310
params = reduce(vcat, SymbolicUtils.scalarize(ModelingToolkitBase.parameters(de)))
243311
t = ModelingToolkitBase.arguments(clean_calls([diff_eqs[1].lhs])[1])[1]
@@ -279,7 +347,7 @@ function __mtk_to_si(
279347
end
280348
end
281349
for i in 1:length(measured_quantities)
282-
out_eqn_dict[y_vars[i]] = eval_at_nemo(measured_quantities[i][2], symb2gens)
350+
out_eqn_dict[y_vars[i]] = eval_at_nemo(output_funcs[i], symb2gens)
283351
end
284352

285353
inputs_ = [symb2gens[each] for each in inputs]

test/extensions/modelingtoolkit.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -838,10 +838,46 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt"
838838
@parameters k₁ k₂ μ₁max μ₂max
839839
eqs = [μ₁ ~ k₁ + μ₁max, D(X) ~ (μ₁ + μ₂max) * X, yX ~ X]
840840

841-
ode = System(eqs, t, name = :output_definition_case)
841+
@mtkcompile ode = System(eqs, t)
842842

843843
id_res = assess_identifiability(ode)
844844
@test 1 == count(v -> v == :globally, values(id_res))
845+
846+
# Examples from https://github.com/SciML/StructuralIdentifiability.jl/issues/496
847+
@variables x(t)[1:2] y(t)[1:2] [output = true]
848+
@parameters p[1:2] q[1:2]
849+
eqs = [
850+
D(x) ~ p .* x + y,
851+
y ~ q .* x,
852+
]
853+
854+
@mtkcompile sys = System(eqs, t; bindings = [y => x, q => 2 .* p])
855+
856+
res = mtk_to_si(sys, Equation[])
857+
858+
@test length(res[1].parameters) == 2
859+
@test length(res[1].x_vars) == 2
860+
@test length(res[1].y_vars) == 2
861+
@test length(res[1].u_vars) == 0
862+
863+
@variables x(t) y(t) z(t) [output = true]
864+
@parameters p q
865+
eqs = [
866+
D(x) ~ p * x + y
867+
y ~ q * x
868+
z ~ x + y
869+
]
870+
871+
@named sys_model = System(eqs, t)
872+
873+
@test_throws DomainError mtk_to_si(sys_model, Equation[])
874+
875+
@mtkcompile sys = System(eqs, t)
876+
res = mtk_to_si(sys, Equation[])
877+
@test length(res[1].parameters) == 2
878+
@test length(res[1].x_vars) == 1
879+
@test length(res[1].y_vars) == 1
880+
@test length(res[1].u_vars) == 0
845881
end
846882

847883
@testset "Identifiability of MTK models with known generic initial conditions" begin

0 commit comments

Comments
 (0)