diff --git a/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl b/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl index 5a802b9..2398e58 100644 --- a/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl +++ b/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl @@ -152,18 +152,21 @@ function (iec::InferEquationClosure)(ieq::Int, eq::Equation, is_initialization_e end outdomain = output_timedomain(op) - @match outdomain begin - x::SciMLBase.AbstractClock => begin - push!(hyperedge, ClockVertex.Clock(x)) - end - InferredClock.Inferred() => nothing - InferredClock.InferredDiscrete(i) => begin - buffer = get(relative_hyperedges, i, nothing) - if buffer !== nothing - union!(hyperedge, buffer) - delete!(relative_hyperedges, i) + if outdomain isa SciMLBase.AbstractClock + push!(hyperedge, ClockVertex.Clock(outdomain)) + elseif outdomain isa InferredTimeDomain + @match outdomain begin + InferredClock.Inferred() => nothing + InferredClock.InferredDiscrete(i) => begin + buffer = get(relative_hyperedges, i, nothing) + if buffer !== nothing + union!(hyperedge, buffer) + delete!(relative_hyperedges, i) + end end end + else + error("Unreachable reached!") end for (_, relative_edge) in relative_hyperedges diff --git a/lib/ModelingToolkitTearing/src/clock_inference/interface.jl b/lib/ModelingToolkitTearing/src/clock_inference/interface.jl index bd92c9f..2024ff1 100644 --- a/lib/ModelingToolkitTearing/src/clock_inference/interface.jl +++ b/lib/ModelingToolkitTearing/src/clock_inference/interface.jl @@ -75,7 +75,7 @@ function system_subset(ts::TearingState, ieqs::Vector{Int}, iieqs::Vector{Int}, @set! ts.structure = system_subset(ts.structure, ieqs, ivars) if all(eq -> eq.rhs isa StateMachineOperator, MTKBase.get_eqs(ts.sys)) names = Symbol[] - for eq in get_eqs(ts.sys) + for eq in MTKBase.get_eqs(ts.sys) if eq.lhs isa Transition push!(names, first(MTKBase.namespace_hierarchy(nameof(eq.rhs.from)))) push!(names, first(MTKBase.namespace_hierarchy(nameof(eq.rhs.to)))) diff --git a/lib/ModelingToolkitTearing/src/clock_inference/state_machines.jl b/lib/ModelingToolkitTearing/src/clock_inference/state_machines.jl index ca01ea9..cd87ad1 100644 --- a/lib/ModelingToolkitTearing/src/clock_inference/state_machines.jl +++ b/lib/ModelingToolkitTearing/src/clock_inference/state_machines.jl @@ -76,14 +76,6 @@ for (s, T) in [(:timeInState, :Real), @eval begin $s(x) = wrap(term($s, x)) SymbolicUtils.promote_symtype(::typeof($s), ::Type{S}) where {S} = $T - function SymbolicUtils.show_call(io, ::typeof($s), args) - if isempty(args) - print(io, $s, "()") - else - arg = only(args) - print(io, $s, "(", arg isa Number ? arg : nameof(arg), ")") - end - end end if s != :activeState @eval $s() = wrap(term($s)) diff --git a/lib/ModelingToolkitTearing/src/reassemble.jl b/lib/ModelingToolkitTearing/src/reassemble.jl index b316ce6..375c41f 100644 --- a/lib/ModelingToolkitTearing/src/reassemble.jl +++ b/lib/ModelingToolkitTearing/src/reassemble.jl @@ -782,7 +782,7 @@ function codegen_equation!(eg::EquationGenerator, # current time. This works because we added one additional history element # in `add_additional_history!`. if isdisc - neweq = backshift_expr(neweq, idep) + neweq = backshift_expr(neweq, idep::SymbolicT)::Equation end push!(solved_eqs, neweq) push!(solved_vars, iv) @@ -794,7 +794,7 @@ function codegen_equation!(eg::EquationGenerator, neweq = make_algebraic_equation(eq, total_sub) # For the same reason as solved equations (they are effectively the same) if isdisc - neweq = backshift_expr(neweq, idep) + neweq = backshift_expr(neweq, idep::SymbolicT) end push!(neweqs′, neweq) push!(eq_ordering, ieq) @@ -982,7 +982,7 @@ function update_simplified_system!( end @set! sys.unknowns = unknowns - obs = tearing_hacks(sys, obs, unknowns, neweqs; array = array_hack) + obs = (@invokelatest tearing_hacks(sys, obs, unknowns, neweqs; array = array_hack))::Vector{Equation} @set! sys.eqs = neweqs @set! sys.observed = obs @@ -1240,7 +1240,7 @@ Backshift the given expression `ex`. function backshift_expr(ex, iv) ex isa SymbolicT || return ex return descend_lower_shift_varname_with_unit( - MTKBase.simplify_shifts(MTKBase.distribute_shift(Shift(iv, -1)(ex))), iv) + MTKBase.simplify_shifts(MTKBase.distribute_shift(Shift(iv, -1)(ex))), iv)::SymbolicT end function backshift_expr(ex::Equation, iv) @@ -1298,26 +1298,30 @@ function tearing_hacks(sys, obs, unknowns, neweqs; array = true) rhs = eq.rhs array || continue - iscall(lhs) || continue - operation(lhs) === getindex || continue - SU.shape(lhs) isa SU.Unknown && continue - arg1 = arguments(lhs)[1] - cnt = get(arr_obs_occurrences, arg1, 0) - arr_obs_occurrences[arg1] = cnt + 1 - continue + Moshi.Match.@match lhs begin + BSImpl.Term(; f, args) && if f === getindex end => begin + arg1 = args[1] + cnt = get(arr_obs_occurrences, arg1, 0) + arr_obs_occurrences[arg1] = cnt + 1 + end + _ => nothing + end end # count variables in unknowns if they are scalarized forms of variables # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)` # is an observed equation. for sym in unknowns - iscall(sym) || continue - operation(sym) === getindex || continue - SU.shape(sym) isa SU.Unknown && continue - arg1 = arguments(sym)[1] - cnt = get(arr_obs_occurrences, arg1, 0) - cnt == 0 && continue - arr_obs_occurrences[arg1] = cnt + 1 + Moshi.Match.@match sym begin + BSImpl.Term(; f, args, shape) && if f === getindex end => begin + shape isa SU.Unknown && continue + arg1 = args[1] + cnt = get(arr_obs_occurrences, arg1, 0) + cnt == 0 && continue + arr_obs_occurrences[arg1] = cnt + 1 + end + _ => nothing + end end obs_arr_eqs = Equation[] diff --git a/lib/ModelingToolkitTearing/src/tearingstate.jl b/lib/ModelingToolkitTearing/src/tearingstate.jl index f677a71..066c622 100644 --- a/lib/ModelingToolkitTearing/src/tearingstate.jl +++ b/lib/ModelingToolkitTearing/src/tearingstate.jl @@ -207,7 +207,7 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true) addvar!(vi, VARIABLE) end else - vv = collect(v)::Array{SymbolicT} + vv = vec(collect(v)::Array{SymbolicT})::Vector{SymbolicT} union!(incidence, vv) for vi in vv addvar!(vi, VARIABLE) @@ -456,7 +456,7 @@ end function lower_order_var(dervar::SymbolicT, t::SymbolicT) @match dervar begin BSImpl.Term(; f, args) && if f isa Differential end => begin - order = f.order::Int + order::Int = f.order::Int isone(order) && return args[1] return Differential(f.x, order - 1)(args[1]) end diff --git a/lib/ModelingToolkitTearing/src/utils.jl b/lib/ModelingToolkitTearing/src/utils.jl index 8c86347..c5f12a0 100644 --- a/lib/ModelingToolkitTearing/src/utils.jl +++ b/lib/ModelingToolkitTearing/src/utils.jl @@ -20,14 +20,18 @@ function descend_lower_shift_varname_with_unit(var, iv) MTKBase._with_unit(descend_lower_shift_varname, var, iv, iv) end function descend_lower_shift_varname(var, iv) - iscall(var) || return var - op = operation(var) - if op isa Shift - return MTKBase.shift2term(var) - else - args = arguments(var) - args = map(Base.Fix2(descend_lower_shift_varname, iv), args) - return SU.maketerm(SymbolicT, op, args, SU.metadata(var)) + @match var begin + BSImpl.Term(; f, args) && if f isa Shift end => MTKBase.shift2term(var) + if iscall(var) end => begin + args = arguments(var) + _args = SU.ArgsT{VartypeT}() + sizehint!(_args, length(args)) + for arg in args + push!(_args, descend_lower_shift_varname(arg, iv)) + end + return SU.maketerm(SymbolicT, operation(var), _args, SU.metadata(var)) + end + _ => return var end end