Skip to content

Commit 2a65cd3

Browse files
Merge pull request #52 from JuliaComputing/as/backshift-additional-observed
fix: backshift `additional_observed` when generating equations
2 parents 8a412df + 6474cdd commit 2a65cd3

3 files changed

Lines changed: 32 additions & 1 deletion

File tree

lib/ModelingToolkitTearing/src/reassemble.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -939,9 +939,13 @@ function update_simplified_system!(
939939
MTKBase.isdiffeq(eq) || continue
940940
obs_sub[eq.lhs] = eq.rhs
941941
end
942+
(; additional_observed) = state
943+
if StateSelection.is_only_discrete(structure)
944+
additional_observed = map(Base.Fix2(backshift_expr, iv), additional_observed)
945+
end
942946
# TODO: compute the dependency correctly so that we don't have to do this
943947
obs = [substitute(observed(sys), obs_sub); solved_eqs;
944-
substitute(state.additional_observed, obs_sub)]
948+
substitute(additional_observed, obs_sub)]
945949

946950
filterer = let diff_to_var = diff_to_var, ispresent = ispresent, fullvars = fullvars,
947951
solved_vars = solved_vars

lib/ModelingToolkitTearing/src/tearingstate.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,15 @@ function shift_discrete_system(ts::TearingState)
591591
eqs[i] = MTKBase.simplify_shifts(substitute(
592592
eqs[i], discmap; filterer = Symbolics.FPSubFilterer{Union{Sample, Hold, Pre}}()))
593593
end
594+
595+
original_eqs = copy(ts.original_eqs)
596+
for i in eachindex(original_eqs)
597+
original_eqs[i] = MTKBase.simplify_shifts(substitute(
598+
original_eqs[i], discmap; filterer = Symbolics.FPSubFilterer{Union{Sample, Hold, Pre}}()))
599+
end
600+
594601
@set! ts.sys.eqs = eqs
602+
@set! ts.original_eqs = original_eqs
595603
@set! ts.fullvars = fullvars
596604
return ts
597605
end

lib/ModelingToolkitTearing/test/runtests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,22 @@ end
110110
@test isempty(unknowns(sys))
111111
@test length(observed(sys)) == 2
112112
end
113+
114+
@testset "`additional_observed` works correctly for discrete systems" begin
115+
@variables x(t) y(t)
116+
k = ShiftIndex(t)
117+
@named sys = System([x(k) ~ x(k-1) + 1, y(k) ~ x(k)], t)
118+
ts = TearingState(sys)
119+
# Original equations should mirror those of the system
120+
@test issetequal(ts.original_eqs, equations(sys))
121+
# `mark_discrete` should shift `original_eqs` forward too
122+
ts = MTKTearing.mark_discrete(ts)
123+
@test issetequal(ts.original_eqs, [x(k+1) ~ x(k) + 1, y(k+1) ~ x(k + 1)])
124+
# `trivial_tearing!` should then move them to `additional_observed`
125+
StateSelection.trivial_tearing!(ts)
126+
@test issetequal(ts.additional_observed, [y(k+1) ~ x(k+1)])
127+
# MTK calls `trivial_tearing!`, and the reassemble process should backshift
128+
# `additional_observed`
129+
ss = mtkcompile(sys)
130+
@test any(isequal(y(k) ~ x(k)), observed(ss))
131+
end

0 commit comments

Comments
 (0)