@@ -84,15 +84,17 @@ function StructuralIdentifiability.eval_at_nemo(
8484end
8585
8686function get_measured_quantities (ode:: ModelingToolkitBase.System )
87+ # filterings is to discard vectorial entities (with tearing and alike)
88+ scalar_observed = filter (e -> length (Symbolics. shape (e. rhs)) == 0 , ModelingToolkitBase. observed (ode))
8789 outputs = filter (
8890 eq -> ModelingToolkitBase. isoutput (eq. lhs),
89- vcat (ModelingToolkitBase. equations (ode), ModelingToolkitBase . observed (ode) ),
91+ vcat (ModelingToolkitBase. equations (ode), scalar_observed ),
9092 )
9193 if ! isempty (outputs)
9294 return outputs
9395 elseif ! isempty (ModelingToolkitBase. observed (ode))
9496 @warn " All `observed` variables from the MTK model are taken as outputs, make sure this is what you wanted"
95- return ModelingToolkitBase . observed (ode)
97+ return scalar_observed
9698 else
9799 throw (
98100 error (
188190
189191# ------------------------------------------------------------------------------
190192
193+ function extract_shifts_and_derivatives (exs_list)
194+ result = Dict (
195+ :vars => [],
196+ :shifts => [],
197+ :derivatives => []
198+ )
199+ function _walk (ex)
200+ ! Symbolics. iscall (ex) && return
201+ op = Symbolics. operation (ex)
202+ args = Symbolics. arguments (ex)
203+
204+ if op isa Shift
205+ push! (result[:shifts ], first (args))
206+ return
207+ end
208+
209+ if op isa Differential
210+ push! (result[:derivatives ], first (args))
211+ return
212+ end
213+
214+ if length (args) == 1
215+ push! (result[:vars ], ex)
216+ return
217+ end
218+
219+ for arg in args
220+ _walk (arg)
221+ end
222+ return
223+ end
224+ for ex in exs_list
225+ _walk (ex)
226+ end
227+ return result
228+ end
229+
230+ # ------------------------------------------------------------------------------
231+
191232function scalarize (arr)
192233 result = []
193234 for a in arr
@@ -218,29 +259,22 @@ function __mtk_to_si(
218259 measured_quantities:: Array{<:Tuple{String, <:SymbolicUtils.BasicSymbolic}} ,
219260 )
220261
221- # checking if all the functions in the lhs are either states of outputs
222- ufo_functions = filter (
223- f -> ! (ModelingToolkitBase. isoutput (f)) && isfunction (f),
224- map (e -> e. lhs, ModelingToolkitBase. equations (de))
225- )
226- if ! isempty (ufo_functions)
227- throw (DomainError (" The following functions on the lhs of the equations are neither states not outputs: $ufo_functions . Did you mean to compile the model?" ))
228- end
229-
230262 polytype = StructuralIdentifiability. Nemo. QQMPolyRingElem
231263 fractype = StructuralIdentifiability. Nemo. Generic. FracFieldElem{polytype}
232264 diff_eqs = filter (
233265 eq -> ! (ModelingToolkitBase. isoutput (eq. lhs)),
234266 ModelingToolkitBase. equations (de),
235267 )
236- output_eqs = [e[2 ] for e in measured_quantities]
268+ output_funcs = [e[2 ] for e in measured_quantities]
237269
238270 # performing full structural simplification
239- if length (observed (de)) > 0 || length (bindings (de) > 0 )
240- rules = merge (
241- Dict (s. lhs => s. rhs for s in observed (de)),
242- Dict (k => v for (k, v) in bindings (de) if ModelingToolkitBase. isparameter (k))
243- )
271+ if (length (observed (de)) > 0 ) || (length (bindings (de)) > 0 )
272+ rules = Dict (s. lhs => s. rhs for s in observed (de))
273+ for (k, v) in bindings (de)
274+ if ModelingToolkitBase. isparameter (k)
275+ rules = merge (rules, Dict (scalarize (k) .=> scalarize (v)))
276+ end
277+ end
244278 while any (
245279 [
246280 length (intersect (get_variables (r), keys (rules))) > 0 for r in values (rules)
@@ -249,15 +283,29 @@ function __mtk_to_si(
249283 rules = Dict (k => SymbolicUtils. substitute (v, rules) for (k, v) in rules)
250284 end
251285 diff_eqs = [SymbolicUtils. substitute (eq, rules) for eq in diff_eqs]
252- output_eqs = [SymbolicUtils. substitute (eq, rules) for eq in output_eqs]
286+ output_funcs = [SymbolicUtils. substitute (f, rules) for f in output_funcs]
287+ end
288+
289+ de_lhs = extract_shifts_and_derivatives ([e. lhs for e in diff_eqs])
290+ de_rhs = extract_shifts_and_derivatives ([e. rhs for e in diff_eqs])
291+ out_rhs = extract_shifts_and_derivatives (output_funcs)
292+
293+ (isempty (out_rhs[:shifts ]) && isempty (out_rhs[:derivatives ])) || throw (DomainError (" Output expressions cannot contain neither shifts nor derivatives" ))
294+ isempty (de_lhs[:shifts ]) || throw (DomainError (" Shifts are not allowed on the left-hand side" ))
295+
296+ if ! isempty (de_rhs[:shifts ])
297+ (is_empty (de_rhs[:derivatives ]) && is_empty (de_lhs[:derivatives ])) || throw (DomainError (" Derivatives and shifts cannot appear at the same time" ))
298+ isempty (intersect (de_lhs[:vars ], de_rhs[:vars ])) || throw (DomainError (" States in the right-hand side of the dynamics equations can appear only as shifts" ))
299+ else
300+ isempty (de_lhs[:vars ]) || throw (DomainError (" States on the left-hand side must appear with derivations (and this is not the case for $(de_lhs[:vars ]) ). Did you mean to mtkcompile the model?" ))
253301 end
254302
255303 y_functions = [each[2 ] for each in measured_quantities]
256304 state_vars = filter (
257305 s -> ! ModelingToolkitBase. isoutput (s),
258306 clean_calls (map (e -> e. lhs, diff_eqs)),
259307 )
260- all_funcs = collect (Set ( clean_calls ( ModelingToolkitBase. unknowns (de) )))
308+ all_funcs = collect (scalarize ( ModelingToolkitBase. unknowns (de)))
261309 inputs = filter (s -> ! ModelingToolkitBase. isoutput (s), setdiff (all_funcs, state_vars))
262310 params = reduce (vcat, SymbolicUtils. scalarize (ModelingToolkitBase. parameters (de)))
263311 t = ModelingToolkitBase. arguments (clean_calls ([diff_eqs[1 ]. lhs])[1 ])[1 ]
@@ -299,7 +347,7 @@ function __mtk_to_si(
299347 end
300348 end
301349 for i in 1 : length (measured_quantities)
302- out_eqn_dict[y_vars[i]] = eval_at_nemo (output_eqs [i], symb2gens)
350+ out_eqn_dict[y_vars[i]] = eval_at_nemo (output_funcs [i], symb2gens)
303351 end
304352
305353 inputs_ = [symb2gens[each] for each in inputs]
0 commit comments