Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,18 +152,21 @@ function (iec::InferEquationClosure)(ieq::Int, eq::Equation, is_initialization_e
end

outdomain = output_timedomain(op)
@match outdomain begin
x::SciMLBase.AbstractClock => begin
push!(hyperedge, ClockVertex.Clock(x))
end
InferredClock.Inferred() => nothing
InferredClock.InferredDiscrete(i) => begin
buffer = get(relative_hyperedges, i, nothing)
if buffer !== nothing
union!(hyperedge, buffer)
delete!(relative_hyperedges, i)
if outdomain isa SciMLBase.AbstractClock
push!(hyperedge, ClockVertex.Clock(outdomain))
elseif outdomain isa InferredTimeDomain
@match outdomain begin
InferredClock.Inferred() => nothing
InferredClock.InferredDiscrete(i) => begin
buffer = get(relative_hyperedges, i, nothing)
if buffer !== nothing
union!(hyperedge, buffer)
delete!(relative_hyperedges, i)
end
end
end
else
error("Unreachable reached!")
end

for (_, relative_edge) in relative_hyperedges
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function system_subset(ts::TearingState, ieqs::Vector{Int}, iieqs::Vector{Int},
@set! ts.structure = system_subset(ts.structure, ieqs, ivars)
if all(eq -> eq.rhs isa StateMachineOperator, MTKBase.get_eqs(ts.sys))
names = Symbol[]
for eq in get_eqs(ts.sys)
for eq in MTKBase.get_eqs(ts.sys)
if eq.lhs isa Transition
push!(names, first(MTKBase.namespace_hierarchy(nameof(eq.rhs.from))))
push!(names, first(MTKBase.namespace_hierarchy(nameof(eq.rhs.to))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,6 @@ for (s, T) in [(:timeInState, :Real),
@eval begin
$s(x) = wrap(term($s, x))
SymbolicUtils.promote_symtype(::typeof($s), ::Type{S}) where {S} = $T
function SymbolicUtils.show_call(io, ::typeof($s), args)
if isempty(args)
print(io, $s, "()")
else
arg = only(args)
print(io, $s, "(", arg isa Number ? arg : nameof(arg), ")")
end
end
end
if s != :activeState
@eval $s() = wrap(term($s))
Expand Down
40 changes: 22 additions & 18 deletions lib/ModelingToolkitTearing/src/reassemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ function codegen_equation!(eg::EquationGenerator,
# current time. This works because we added one additional history element
# in `add_additional_history!`.
if isdisc
neweq = backshift_expr(neweq, idep)
neweq = backshift_expr(neweq, idep::SymbolicT)::Equation
end
push!(solved_eqs, neweq)
push!(solved_vars, iv)
Expand All @@ -794,7 +794,7 @@ function codegen_equation!(eg::EquationGenerator,
neweq = make_algebraic_equation(eq, total_sub)
# For the same reason as solved equations (they are effectively the same)
if isdisc
neweq = backshift_expr(neweq, idep)
neweq = backshift_expr(neweq, idep::SymbolicT)
end
push!(neweqs′, neweq)
push!(eq_ordering, ieq)
Expand Down Expand Up @@ -982,7 +982,7 @@ function update_simplified_system!(
end
@set! sys.unknowns = unknowns

obs = tearing_hacks(sys, obs, unknowns, neweqs; array = array_hack)
obs = (@invokelatest tearing_hacks(sys, obs, unknowns, neweqs; array = array_hack))::Vector{Equation}

@set! sys.eqs = neweqs
@set! sys.observed = obs
Expand Down Expand Up @@ -1240,7 +1240,7 @@ Backshift the given expression `ex`.
function backshift_expr(ex, iv)
ex isa SymbolicT || return ex
return descend_lower_shift_varname_with_unit(
MTKBase.simplify_shifts(MTKBase.distribute_shift(Shift(iv, -1)(ex))), iv)
MTKBase.simplify_shifts(MTKBase.distribute_shift(Shift(iv, -1)(ex))), iv)::SymbolicT
end

function backshift_expr(ex::Equation, iv)
Expand Down Expand Up @@ -1298,26 +1298,30 @@ function tearing_hacks(sys, obs, unknowns, neweqs; array = true)
rhs = eq.rhs

array || continue
iscall(lhs) || continue
operation(lhs) === getindex || continue
SU.shape(lhs) isa SU.Unknown && continue
arg1 = arguments(lhs)[1]
cnt = get(arr_obs_occurrences, arg1, 0)
arr_obs_occurrences[arg1] = cnt + 1
continue
Moshi.Match.@match lhs begin
BSImpl.Term(; f, args) && if f === getindex end => begin
arg1 = args[1]
cnt = get(arr_obs_occurrences, arg1, 0)
arr_obs_occurrences[arg1] = cnt + 1
end
_ => nothing
end
end

# count variables in unknowns if they are scalarized forms of variables
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
# is an observed equation.
for sym in unknowns
iscall(sym) || continue
operation(sym) === getindex || continue
SU.shape(sym) isa SU.Unknown && continue
arg1 = arguments(sym)[1]
cnt = get(arr_obs_occurrences, arg1, 0)
cnt == 0 && continue
arr_obs_occurrences[arg1] = cnt + 1
Moshi.Match.@match sym begin
BSImpl.Term(; f, args, shape) && if f === getindex end => begin
shape isa SU.Unknown && continue
arg1 = args[1]
cnt = get(arr_obs_occurrences, arg1, 0)
cnt == 0 && continue
arr_obs_occurrences[arg1] = cnt + 1
end
_ => nothing
end
end

obs_arr_eqs = Equation[]
Expand Down
4 changes: 2 additions & 2 deletions lib/ModelingToolkitTearing/src/tearingstate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true)
addvar!(vi, VARIABLE)
end
else
vv = collect(v)::Array{SymbolicT}
vv = vec(collect(v)::Array{SymbolicT})::Vector{SymbolicT}
union!(incidence, vv)
for vi in vv
addvar!(vi, VARIABLE)
Expand Down Expand Up @@ -456,7 +456,7 @@ end
function lower_order_var(dervar::SymbolicT, t::SymbolicT)
@match dervar begin
BSImpl.Term(; f, args) && if f isa Differential end => begin
order = f.order::Int
order::Int = f.order::Int
isone(order) && return args[1]
return Differential(f.x, order - 1)(args[1])
end
Expand Down
20 changes: 12 additions & 8 deletions lib/ModelingToolkitTearing/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@ function descend_lower_shift_varname_with_unit(var, iv)
MTKBase._with_unit(descend_lower_shift_varname, var, iv, iv)
end
function descend_lower_shift_varname(var, iv)
iscall(var) || return var
op = operation(var)
if op isa Shift
return MTKBase.shift2term(var)
else
args = arguments(var)
args = map(Base.Fix2(descend_lower_shift_varname, iv), args)
return SU.maketerm(SymbolicT, op, args, SU.metadata(var))
@match var begin
BSImpl.Term(; f, args) && if f isa Shift end => MTKBase.shift2term(var)
if iscall(var) end => begin
args = arguments(var)
_args = SU.ArgsT{VartypeT}()
sizehint!(_args, length(args))
for arg in args
push!(_args, descend_lower_shift_varname(arg, iv))
end
return SU.maketerm(SymbolicT, operation(var), _args, SU.metadata(var))
end
_ => return var
end
end

Expand Down
Loading