@@ -14,15 +14,19 @@ using ModelingToolkitBase
1414
1515# ------------------------------------------------------------------------------
1616
17+ # checking if it is a function of the form x(t), a bit dirty
18+ function isfunction (e:: SymbolicUtils.BasicSymbolic )
19+ return length (Symbolics. arguments (e)) == 1 && " $(first (Symbolics. arguments (e))) " == " t"
20+ end
21+
1722function StructuralIdentifiability. eval_at_nemo (e:: Num , vals:: Dict )
1823 e = Symbolics. value (e)
1924 return eval_at_nemo (e, vals)
2025end
2126
2227function StructuralIdentifiability. eval_at_nemo (e:: SymbolicUtils.BasicSymbolic , vals:: Dict )
2328 if Symbolics. iscall (e)
24- # checking if it is a function of the form x(t), a bit dirty
25- if length (Symbolics. arguments (e)) == 1 && " $(first (Symbolics. arguments (e))) " == " t"
29+ if isfunction (e)
2630 return vals[e]
2731 end
2832 # checking if this is a vector entry like x(t)[1]
@@ -80,15 +84,17 @@ function StructuralIdentifiability.eval_at_nemo(
8084end
8185
8286function get_measured_quantities (ode:: ModelingToolkitBase.System )
87+ # filterings is to discard vectorial entities (with tearing and alike)
88+ scalar_observed = filter (e -> length (Symbolics. shape (e. rhs)) == 0 , ModelingToolkitBase. observed (ode))
8389 outputs = filter (
8490 eq -> ModelingToolkitBase. isoutput (eq. lhs),
85- vcat (ModelingToolkitBase. equations (ode), ModelingToolkitBase . observed (ode) ),
91+ vcat (ModelingToolkitBase. equations (ode), scalar_observed ),
8692 )
8793 if ! isempty (outputs)
8894 return outputs
8995 elseif ! isempty (ModelingToolkitBase. observed (ode))
9096 @warn " All `observed` variables from the MTK model are taken as outputs, make sure this is what you wanted"
91- return ModelingToolkitBase . observed (ode)
97+ return scalar_observed
9298 else
9399 throw (
94100 error (
184190
185191# ------------------------------------------------------------------------------
186192
193+ function extract_shifts_and_derivatives (exs_list)
194+ result = Dict (
195+ :vars => [],
196+ :shifts => [],
197+ :derivatives => []
198+ )
199+ function _walk (ex)
200+ ! Symbolics. iscall (ex) && return
201+ op = Symbolics. operation (ex)
202+ args = Symbolics. arguments (ex)
203+
204+ if op isa Shift
205+ push! (result[:shifts ], first (args))
206+ return
207+ end
208+
209+ if op isa Differential
210+ push! (result[:derivatives ], first (args))
211+ return
212+ end
213+
214+ if length (args) == 1
215+ push! (result[:vars ], ex)
216+ return
217+ end
218+
219+ for arg in args
220+ _walk (arg)
221+ end
222+ return
223+ end
224+ for ex in exs_list
225+ _walk (ex)
226+ end
227+ return result
228+ end
229+
230+ # ------------------------------------------------------------------------------
231+
187232function scalarize (arr)
188233 result = []
189234 for a in arr
@@ -213,15 +258,23 @@ function __mtk_to_si(
213258 de:: ModelingToolkitBase.System ,
214259 measured_quantities:: Array{<:Tuple{String, <:SymbolicUtils.BasicSymbolic}} ,
215260 )
261+
216262 polytype = StructuralIdentifiability. Nemo. QQMPolyRingElem
217263 fractype = StructuralIdentifiability. Nemo. Generic. FracFieldElem{polytype}
218264 diff_eqs = filter (
219265 eq -> ! (ModelingToolkitBase. isoutput (eq. lhs)),
220266 ModelingToolkitBase. equations (de),
221267 )
268+ output_funcs = [e[2 ] for e in measured_quantities]
269+
222270 # performing full structural simplification
223- if length (observed (de)) > 0
271+ if ( length (observed (de)) > 0 ) || ( length ( bindings (de)) > 0 )
224272 rules = Dict (s. lhs => s. rhs for s in observed (de))
273+ for (k, v) in bindings (de)
274+ if ModelingToolkitBase. isparameter (k)
275+ rules = merge (rules, Dict (scalarize (k) .=> scalarize (v)))
276+ end
277+ end
225278 while any (
226279 [
227280 length (intersect (get_variables (r), keys (rules))) > 0 for r in values (rules)
@@ -230,14 +283,29 @@ function __mtk_to_si(
230283 rules = Dict (k => SymbolicUtils. substitute (v, rules) for (k, v) in rules)
231284 end
232285 diff_eqs = [SymbolicUtils. substitute (eq, rules) for eq in diff_eqs]
286+ output_funcs = [SymbolicUtils. substitute (f, rules) for f in output_funcs]
287+ end
288+
289+ de_lhs = extract_shifts_and_derivatives ([e. lhs for e in diff_eqs])
290+ de_rhs = extract_shifts_and_derivatives ([e. rhs for e in diff_eqs])
291+ out_rhs = extract_shifts_and_derivatives (output_funcs)
292+
293+ (isempty (out_rhs[:shifts ]) && isempty (out_rhs[:derivatives ])) || throw (DomainError (" Output expressions cannot contain neither shifts nor derivatives" ))
294+ isempty (de_lhs[:shifts ]) || throw (DomainError (" Shifts are not allowed on the left-hand side" ))
295+
296+ if ! isempty (de_rhs[:shifts ])
297+ (is_empty (de_rhs[:derivatives ]) && is_empty (de_lhs[:derivatives ])) || throw (DomainError (" Derivatives and shifts cannot appear at the same time" ))
298+ isempty (intersect (de_lhs[:vars ], de_rhs[:vars ])) || throw (DomainError (" States in the right-hand side of the dynamics equations can appear only as shifts" ))
299+ else
300+ isempty (de_lhs[:vars ]) || throw (DomainError (" States on the left-hand side must appear with derivations (and this is not the case for $(de_lhs[:vars ]) ). Did you mean to mtkcompile the model?" ))
233301 end
234302
235303 y_functions = [each[2 ] for each in measured_quantities]
236304 state_vars = filter (
237305 s -> ! ModelingToolkitBase. isoutput (s),
238306 clean_calls (map (e -> e. lhs, diff_eqs)),
239307 )
240- all_funcs = collect (Set ( clean_calls ( ModelingToolkitBase. unknowns (de) )))
308+ all_funcs = collect (scalarize ( ModelingToolkitBase. unknowns (de)))
241309 inputs = filter (s -> ! ModelingToolkitBase. isoutput (s), setdiff (all_funcs, state_vars))
242310 params = reduce (vcat, SymbolicUtils. scalarize (ModelingToolkitBase. parameters (de)))
243311 t = ModelingToolkitBase. arguments (clean_calls ([diff_eqs[1 ]. lhs])[1 ])[1 ]
@@ -279,7 +347,7 @@ function __mtk_to_si(
279347 end
280348 end
281349 for i in 1 : length (measured_quantities)
282- out_eqn_dict[y_vars[i]] = eval_at_nemo (measured_quantities[i][ 2 ], symb2gens)
350+ out_eqn_dict[y_vars[i]] = eval_at_nemo (output_funcs[i ], symb2gens)
283351 end
284352
285353 inputs_ = [symb2gens[each] for each in inputs]
0 commit comments