diff --git a/src/compiler/transform/pipeline.jl b/src/compiler/transform/pipeline.jl index d0ec4ada..9fd59ac0 100644 --- a/src/compiler/transform/pipeline.jl +++ b/src/compiler/transform/pipeline.jl @@ -95,25 +95,25 @@ function commute_arith_transparent(sci, block, inst, match, driver) # Insert broadcast of the scalar to x's shape and register as constant x_shape = size(xT) bc_type = Tile{eltype(xT), Tuple{x_shape...}} - bc = insert_before!(block, val, Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type; + bc = insert_before!(driver.rewriter, block, val, + Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type; flag=inferred_flags(Intrinsics.broadcast)) - notify_insert!(driver, block, bc) # Side-inject the freshly synthesized constant into the dataflow result so # downstream pattern matches see it. Bypasses tmerge (this is a brand-new # SSA value, not a merge). driver.constants[SSAValue(bc)] = convert(eltype(xT), scalar) # Insert op(x, broadcast) with x's type - op = insert_before!(block, val, Expr(:call, root_func, x, SSAValue(bc)), xT; + op = insert_before!(driver.rewriter, block, val, + Expr(:call, root_func, x, SSAValue(bc)), xT; flag=inferred_flags(root_func)) - notify_insert!(driver, block, op) # Replace root with transparent_op(op_result, s). Func changes # (subi/addi → reshape/broadcast), so recompute the flag from the new # func's declared effects — the inferred bits describe the OLD op. - block[val.id] = (stmt=Expr(:call, transparent_func, SSAValue(op), match.bindings[:s]), - flag=inferred_flags(transparent_func)) - driver.defs[val] = DefEntry(block, val, transparent_func) + replace_stmt!(driver.rewriter, block, val, + Expr(:call, transparent_func, SSAValue(op), match.bindings[:s]); + flag=inferred_flags(transparent_func)) push!(driver.worklist, val) add_users_to_worklist!(driver, val) return true diff --git a/src/compiler/transform/rewrite.jl b/src/compiler/transform/rewrite.jl index fbcd0d56..7ab25b82 100644 --- a/src/compiler/transform/rewrite.jl +++ b/src/compiler/transform/rewrite.jl @@ -217,7 +217,7 @@ function def_operands(entry::DefEntry) end mutable struct RewriteDriver - sci::StructuredIRCode + rewriter::Rewriter defs::Dict{SSAValue, DefEntry} dispatch::Dict{Any, Vector{RewriteRule}} worklist::Worklist @@ -226,47 +226,66 @@ mutable struct RewriteDriver max_rewrites::Int end -"""Compute fresh use count for an SSA value.""" -use_count(driver::RewriteDriver, val::SSAValue) = - length(uses(driver.sci.entry, val)) - #============================================================================= - Notifications + Rewriter listener: driver-side bookkeeping reacts to IR mutations =============================================================================# -"""Add operand-producing instructions to the worklist (enables cascading).""" -function add_operands_to_worklist!(driver::RewriteDriver, entry::DefEntry) - for op in def_operands(entry) - op isa SSAValue || continue - haskey(driver.defs, op) && push!(driver.worklist, op) - end +# The driver has its own state on top of the Rewriter's index: a `defs` map +# keyed by func, the `worklist`, and a `modified` set for cascading. These +# `notify_*` hooks let the Rewriter keep that state in sync, so the driver +# doesn't need to wrap every mutation callsite. Same pattern as MLIR's +# `RewriterBase::Listener::notifyOperation{Inserted,Modified,Erased}`. + +function notify_inserted!(d::RewriteDriver, block::Block, inst::Instruction) + stmt = inst[:stmt] + call = resolve_call(block, stmt) + call === nothing && return + func, _ = call + val = SSAValue(inst) + d.defs[val] = DefEntry(block, val, func) + push!(d.worklist, val) end -"""Add instructions that use `val` to the worklist (their operand changed).""" -function add_users_to_worklist!(driver::RewriteDriver, val::SSAValue) - for inst in users(driver.sci.entry, val) - push!(driver.worklist, SSAValue(inst)) +function notify_modified!(d::RewriteDriver, block::Block, val::SSAValue, + @nospecialize(old_stmt), @nospecialize(new_stmt)) + # Refresh the def entry so worklist dispatch picks the new func. + call = resolve_call(block, new_stmt) + if call !== nothing + func, _ = call + d.defs[val] = DefEntry(block, val, func) end + # Worklist cascading is done by the callers (`apply_rewrite!`, + # `apply_inplace_rewrite!`, `commute_arith_transparent`), which re-seed + # `val` and its users themselves. end -"""Erase an instruction and notify the worklist.""" -function erase_op!(driver::RewriteDriver, entry::DefEntry) - add_operands_to_worklist!(driver, entry) - if haskey(entry.block, entry.val.id) - delete!(entry.block, entry.val.id) +function notify_erased!(d::RewriteDriver, ::Block, val::SSAValue, + @nospecialize(old_stmt)) + # Operand-defs may now be dead; cascade them to the worklist. + if old_stmt isa Expr + for_expr_operands(old_stmt) do op + op isa SSAValue || return + haskey(d.defs, op) && push!(d.worklist, op) + end end - delete!(driver.defs, entry.val) - remove!(driver.worklist, entry.val) + delete!(d.defs, val) + remove!(d.worklist, val) + delete!(d.modified, val) end -"""Register a newly inserted instruction.""" -function notify_insert!(driver::RewriteDriver, block::Block, inst::Instruction) - val = SSAValue(inst) - call = resolve_call(block, inst) - call === nothing && return - func, _ = call - driver.defs[val] = DefEntry(block, val, func) - push!(driver.worklist, val) +#============================================================================= + Driver-side query helpers (thin proxies to the Rewriter) +=============================================================================# + +users_of(driver::RewriteDriver, val::SSAValue) = users(driver.rewriter, val) + +use_count(driver::RewriteDriver, val::SSAValue) = use_count(driver.rewriter, val) + +"""Add instructions that use `val` to the worklist (their operand changed).""" +function add_users_to_worklist!(driver::RewriteDriver, val::SSAValue) + for u in users_of(driver, val) + push!(driver.worklist, u) + end end #============================================================================= @@ -291,7 +310,7 @@ function merge_bindings!(dest::Dict{Symbol,Any}, src::Dict{Symbol,Any}) end function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall, - block::Block=driver.sci.entry) + block::Block=driver.rewriter.sci.entry) val isa SSAValue || return nothing entry = get(driver.defs, val, nothing) entry === nothing && return nothing @@ -322,11 +341,11 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall, return nothing end -pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PBind, block::Block=driver.sci.entry) = +pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PBind, block::Block=driver.rewriter.sci.entry) = MatchResult(Dict{Symbol,Any}(pat.name => val), SSAValue[]) function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PTypedBind, - block::Block=driver.sci.entry) + block::Block=driver.rewriter.sci.entry) T = value_type(block, val) T === nothing && return nothing CC.widenconst(T) <: pat.type || return nothing @@ -334,7 +353,7 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PTypedBin end function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::POneUse, - block::Block=driver.sci.entry) + block::Block=driver.rewriter.sci.entry) val isa SSAValue && use_count(driver, val) == 1 || return nothing pattern_match(driver, val, pat.inner, block) end @@ -343,7 +362,7 @@ end # For non-SSA operands (enum constants, predicates): checks ===. # For SSA operands: routed through const_value on the ConstantAnalysis result. function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PLiteral, - block::Block=driver.sci.entry) + block::Block=driver.rewriter.sci.entry) val === pat.val && return MatchResult(Dict{Symbol,Any}(), SSAValue[]) if val isa SSAValue c = const_value(driver.constants, val) @@ -384,9 +403,8 @@ function resolve_rhs(driver::RewriteDriver, block, ref, op::RCall, bindings, roo typ = CC.widenconst(t) break end - inst = insert_before!(block, ref, Expr(:call, op.func, operands...), typ; + inst = insert_before!(driver.rewriter, block, ref, Expr(:call, op.func, operands...), typ; flag=inferred_flags(op.func)) - notify_insert!(driver, block, inst) SSAValue(inst) end @@ -406,12 +424,8 @@ function apply_inplace_rewrite!(driver::RewriteDriver, block, val::SSAValue, rul # the flag from the new func's declared effects (`efunc` overrides), # mirroring inference's `flags_for_effects` — the inferred bits on the # old call describe the OLD op and don't carry over. - if rule.rhs.func === driver.defs[val].func - block[val.id] = (stmt=new_stmt,) - else - block[val.id] = (stmt=new_stmt, flag=inferred_flags(rule.rhs.func)) - end - driver.defs[val] = DefEntry(block, val, rule.rhs.func) + flag = rule.rhs.func === driver.defs[val].func ? nothing : inferred_flags(rule.rhs.func) + replace_stmt!(driver.rewriter, block, val, new_stmt; flag) push!(driver.worklist, val) add_users_to_worklist!(driver, val) return true @@ -438,8 +452,8 @@ function resolve_inplace_rhs(driver, bindings, op::RCall, lhs_op::PCall) for (sub_rhs, sub_lhs) in zip(op.operands, lhs_op.operands)] # Sub-call func is enforced equal to the matched lhs func above # (`op.func === lhs_op.func`), so the flag is still valid — only operands - # changed. Partial-NamedTuple setindex preserves type and flag. - entry.block[matched_ssa.id] = (stmt=Expr(:call, op.func, new_ops...),) + # changed. + replace_stmt!(driver.rewriter, entry.block, matched_ssa, Expr(:call, op.func, new_ops...)) push!(driver.worklist, matched_ssa) return matched_ssa end @@ -454,25 +468,22 @@ The matched_ssas in MatchResult are ordered root-first, but we need to find the specific SSA for a sub-pattern. We do this by looking up the first operand's binding and finding the op that defines it.""" function find_matched_ssa(driver, pat::PCall, bindings) - entry = driver.sci.entry for sub in pat.operands if sub isa PBind bound = get(bindings, sub.name, nothing) bound isa SSAValue || continue - for inst in users(entry, bound) - call = resolve_call(entry, inst) - call === nothing && continue - func, _ = call - func === pat.func && return SSAValue(inst) + for u in users_of(driver, bound) + def_entry = get(driver.defs, u, nothing) + def_entry === nothing && continue + def_entry.func === pat.func && return u end elseif sub isa PCall inner_ssa = find_matched_ssa(driver, sub, bindings) if inner_ssa !== nothing - for inst in users(entry, inner_ssa) - call = resolve_call(entry, inst) - call === nothing && continue - func, _ = call - func === pat.func && return SSAValue(inst) + for u in users_of(driver, inner_ssa) + def_entry = get(driver.defs, u, nothing) + def_entry === nothing && continue + def_entry.func === pat.func && return u end end end @@ -491,7 +502,7 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match # Look up live instruction for RFunc interface haskey(block, val.id) || return false inst = block[val.id] - rule.rhs.func(driver.sci, block, inst, match, driver) || return false + rule.rhs.func(driver.rewriter.sci, block, inst, match, driver) || return false return true elseif rule.rhs isa RBind # Forwarding: replace all uses of root with the bound value, delete root. @@ -499,12 +510,13 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match # When these are later popped from the worklist without a match, the # driver propagates to THEIR users (see modified check in main loop). # This gives MLIR-style notifyOperationModified cascading. - for inst in users(driver.sci.entry, val) - push!(driver.modified, SSAValue(inst)) + new_val = match.bindings[rule.rhs.name] + for u in users_of(driver, val) + push!(driver.modified, u) + push!(driver.worklist, u) end - add_users_to_worklist!(driver, val) - replace_uses!(driver.sci.entry, val, match.bindings[rule.rhs.name]) - erase_op!(driver, entry) + replace_uses!(driver.rewriter, val, new_val) + erase!(driver.rewriter, entry.block, val) else # Substitution: replace root in-place, clean up dead intermediates. # Only delete intermediates with no remaining uses — transparent-op @@ -514,7 +526,7 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match dead_entry = get(driver.defs, dead_val, nothing) dead_entry === nothing && continue use_count(driver, dead_val) == 0 || continue - erase_op!(driver, dead_entry) + erase!(driver.rewriter, dead_entry.block, dead_val) end typ = block[val.id][:type] # Build operands, flattening RSplat nodes into multiple operands @@ -531,13 +543,8 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match # describe the OLD op; recompute from the new func's `efunc` effects # so downstream gates (CSE, LICM) see fresh, correct information. new_stmt = Expr(:call, rule.rhs.func, operands...) - if rule.rhs.func === driver.defs[val].func - block[val.id] = (stmt=new_stmt,) - else - block[val.id] = (stmt=new_stmt, flag=inferred_flags(rule.rhs.func)) - end - # Update defs, re-add self and users to worklist (statement changed) - driver.defs[val] = DefEntry(block, val, rule.rhs.func) + flag = rule.rhs.func === driver.defs[val].func ? nothing : inferred_flags(rule.rhs.func) + replace_stmt!(driver.rewriter, block, val, new_stmt; flag) push!(driver.worklist, val) add_users_to_worklist!(driver, val) end @@ -584,7 +591,10 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; end end - driver = RewriteDriver(sci, defs, dispatch, wl, constants, Set{SSAValue}(), max_rewrites) + rewriter = Rewriter(sci) + driver = RewriteDriver(rewriter, defs, dispatch, wl, constants, Set{SSAValue}(), + max_rewrites) + rewriter.listener = driver num_rewrites = 0 while !isempty(driver.worklist) && num_rewrites < driver.max_rewrites @@ -605,7 +615,7 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; if use_count(driver, val) == 0 stmt = entry.block[val.id][:stmt] if !must_keep(entry.block, stmt) - erase_op!(driver, entry) + erase!(driver.rewriter, entry.block, val) continue end end @@ -634,10 +644,9 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; # cascade continues through the fixpoint. if !matched && val in driver.modified delete!(driver.modified, val) - for inst in users(driver.sci.entry, val) - uv = SSAValue(inst) - push!(driver.modified, uv) - haskey(driver.defs, uv) && push!(driver.worklist, uv) + for u in users_of(driver, val) + push!(driver.modified, u) + haskey(driver.defs, u) && push!(driver.worklist, u) end end end diff --git a/src/compiler/transform/rewriter.jl b/src/compiler/transform/rewriter.jl new file mode 100644 index 00000000..a3fbdab8 --- /dev/null +++ b/src/compiler/transform/rewriter.jl @@ -0,0 +1,278 @@ +# IR Rewriter, modeled on MLIR's RewriterBase/IRRewriter. +# +# All IR mutations during a rewrite session go through a `Rewriter`, which +# in return keeps an incremental use-def index up to date and fires +# `notify_*` hooks so higher-level drivers can react. Correspondences: +# +# RewriterBase -> Rewriter +# IROperand intrinsic use-list -> (users, extra_uses) maps +# RewriterBase::Listener -> notify_inserted!/_modified!/_erased! + +using Core: SSAValue + +#============================================================================= + Listener protocol +=============================================================================# + +# Default no-op listener callbacks. Drivers override by adding a method on +# their listener type (e.g. `notify_inserted!(d::RewriteDriver, block, inst)`). +"""Fires after the Rewriter inserts an instruction.""" +notify_inserted!(::Any, ::Block, ::Instruction) = nothing + +"""Fires when the Rewriter replaces a stmt. Drivers use this to react to +func changes or mark users dirty.""" +notify_modified!(::Any, ::Block, ::SSAValue, @nospecialize(old_stmt), @nospecialize(new_stmt)) = nothing + +"""Fires once an instruction has been erased; the SSA value is no longer live.""" +notify_erased!(::Any, ::Block, ::SSAValue, @nospecialize(old_stmt)) = nothing + +#============================================================================= + Rewriter +=============================================================================# + +""" + Rewriter(sci::StructuredIRCode; listener=nothing) + +IR mutation handle for `sci`. Holds an incremental use-def index; query it +with `users(rewriter, val)` and `use_count(rewriter, val)`. Mutate the IR +only through `insert_before!`, `replace_stmt!`, `erase!`, or +`replace_uses!` on the rewriter, which keeps the index consistent and fires +the listener. + +`listener` is dispatched on its type via the `notify_*` generic functions. +Defaults to `nothing` (no-op). +""" +mutable struct Rewriter + sci::StructuredIRCode + # `users`: SSA value -> SSAs of stmts that reference it. `extra_uses` + # holds the count of terminator uses, which have no SSA owner. + users::Dict{SSAValue, Vector{SSAValue}} + extra_uses::Dict{SSAValue, Int} + listener::Any +end + +function Rewriter(sci::StructuredIRCode; listener=nothing) + r = Rewriter(sci, Dict{SSAValue, Vector{SSAValue}}(), + Dict{SSAValue, Int}(), listener) + build_use_index!(r) + return r +end + +#============================================================================= + Queries +=============================================================================# + +"""Users (by SSA) of `val`, or an empty list if `val` has none recorded.""" +users(r::Rewriter, val::SSAValue) = get(r.users, val, SSAValue[]) +users(::Rewriter, ::Any) = SSAValue[] + +"""Total use count for `val`, including terminator uses.""" +use_count(r::Rewriter, val::SSAValue) = + length(users(r, val)) + get(r.extra_uses, val, 0) + +#============================================================================= + Mutation API +=============================================================================# + +""" + insert_before!(r::Rewriter, block::Block, ref, stmt, type; flag=UInt32(0)) + +Insert `stmt` at `ref` in `block`. Updates the use-index for the new stmt's +operands and fires `notify_inserted!`. Returns the new `Instruction`. +""" +function insert_before!(r::Rewriter, block::Block, ref, @nospecialize(stmt), + @nospecialize(type); flag::UInt32=UInt32(0)) + inst = IRStructurizer.insert_before!(block, ref, stmt, type; flag=flag) + register_stmt_uses!(r, SSAValue(inst), stmt) + notify_inserted!(r.listener, block, inst) + return inst +end + +""" + replace_stmt!(r::Rewriter, block::Block, val::SSAValue, new_stmt; flag=nothing) + +Replace the stmt at `val` with `new_stmt`. Updates the use-index for both +sets of operands and fires `notify_modified!`. Pass `flag` to also update +the IR_FLAG_* bitmask; otherwise only the stmt slot changes (type and flag +are preserved). +""" +function replace_stmt!(r::Rewriter, block::Block, val::SSAValue, + @nospecialize(new_stmt); flag::Union{UInt32, Nothing}=nothing) + haskey(block, val.id) || return + old_stmt = block[val.id][:stmt] + unregister_stmt_uses!(r, val, old_stmt) + if flag === nothing + block[val.id] = (stmt=new_stmt,) + else + block[val.id] = (stmt=new_stmt, flag=flag) + end + register_stmt_uses!(r, val, new_stmt) + notify_modified!(r.listener, block, val, old_stmt, new_stmt) +end + +""" + erase!(r::Rewriter, block::Block, val::SSAValue) + +Erase the stmt at `val` from `block`. Drops the op's operand uses from the +use-index, then fires `notify_erased!`. Listeners can inspect `old_stmt` to +cascade operand-defs through their own worklist for dead-op elim. +""" +function erase!(r::Rewriter, block::Block, val::SSAValue) + haskey(block, val.id) || return + old_stmt = block[val.id][:stmt] + unregister_stmt_uses!(r, val, old_stmt) + delete!(block, val.id) + delete!(r.users, val) + delete!(r.extra_uses, val) + notify_erased!(r.listener, block, val, old_stmt) +end + +""" + replace_uses!(r::Rewriter, val::SSAValue, new_val) + +RAUW: replace every use of `val` in the SCI with `new_val`, and move the +use-index entries to `new_val`. The walk covers every operand kind (Expr +args, CF op fields, terminators). No listener event fires; only operand +slots change, not the user stmts themselves. +""" +function replace_uses!(r::Rewriter, val::SSAValue, @nospecialize(new_val)) + IRStructurizer.replace_uses!(r.sci.entry, val, new_val) + transfer_uses!(r, val, new_val) +end + +#============================================================================= + Use-index maintenance (internal) +=============================================================================# + +function for_expr_operands(f, expr::Expr) + start = expr.head === :invoke ? 3 : 2 + for i in start:length(expr.args) + f(expr.args[i]) + end +end + +function add_use!(r::Rewriter, @nospecialize(operand), user::SSAValue) + operand isa SSAValue || return + push!(get!(Vector{SSAValue}, r.users, operand), user) +end + +function remove_use!(r::Rewriter, @nospecialize(operand), user::SSAValue) + operand isa SSAValue || return + list = get(r.users, operand, nothing) + list === nothing && return + i = findfirst(==(user), list) + i === nothing && return + deleteat!(list, i) + isempty(list) && delete!(r.users, operand) +end + +# Incremental maintenance only needs the Expr path: drivers never produce +# CF ops or terminators. `register_owned_stmt_uses!` handles those once at +# initial build time. +function register_stmt_uses!(r::Rewriter, val::SSAValue, @nospecialize(stmt)) + stmt isa Expr || return + for_expr_operands(stmt) do op + add_use!(r, op, val) + end +end + +function unregister_stmt_uses!(r::Rewriter, val::SSAValue, @nospecialize(stmt)) + stmt isa Expr || return + for_expr_operands(stmt) do op + remove_use!(r, op, val) + end +end + +# Reassign every index entry from `old` to `new_val` so subsequent queries +# find the users under `new_val`. The IR-level mutation lives in +# `replace_uses!`; this only touches the bookkeeping. +function transfer_uses!(r::Rewriter, old::SSAValue, @nospecialize(new_val)) + old_users = get(r.users, old, nothing) + if old_users !== nothing + if new_val isa SSAValue + dest = get!(Vector{SSAValue}, r.users, new_val) + append!(dest, old_users) + end + delete!(r.users, old) + end + old_extra = get(r.extra_uses, old, 0) + if old_extra > 0 + if new_val isa SSAValue + r.extra_uses[new_val] = get(r.extra_uses, new_val, 0) + old_extra + end + delete!(r.extra_uses, old) + end +end + +# One-shot SCI walk that seeds the use-index. Dispatch mirrors +# `IRStructurizer.walk_uses!(::Block)` so every operand kind is covered: +# Expr args, CF op fields, ReturnNode/PiNode/alias stmt operands, and +# block terminators. +function build_use_index!(r::Rewriter) + for block in eachblock(r.sci) + for i in 1:length(block.body.ssa_idxes) + owner = SSAValue(block.body.ssa_idxes[i]) + register_owned_stmt_uses!(r, owner, block.body.stmts[i]) + end + register_terminator_uses!(r, block.terminator) + end +end + +function register_owned_stmt_uses!(r::Rewriter, owner::SSAValue, @nospecialize(stmt)) + if stmt isa Expr + for_expr_operands(stmt) do op + add_use!(r, op, owner) + end + elseif stmt isa Core.ReturnNode + isdefined(stmt, :val) && add_use!(r, stmt.val, owner) + elseif stmt isa Core.PiNode + add_use!(r, stmt.val, owner) + elseif stmt isa SSAValue + # Alias/forwarding stmt: the stmt slot itself IS the operand. + add_use!(r, stmt, owner) + elseif stmt isa ControlFlowOp + register_cf_op_uses!(r, owner, stmt) + end +end + +function register_cf_op_uses!(r::Rewriter, owner::SSAValue, op::IfOp) + add_use!(r, op.condition, owner) +end + +function register_cf_op_uses!(r::Rewriter, owner::SSAValue, op::ForOp) + add_use!(r, op.lower, owner) + add_use!(r, op.upper, owner) + add_use!(r, op.step, owner) + for iv in op.init_values + add_use!(r, iv, owner) + end +end + +function register_cf_op_uses!(r::Rewriter, owner::SSAValue, + op::Union{WhileOp, LoopOp}) + for iv in op.init_values + add_use!(r, iv, owner) + end +end + +# Terminator operands have no SSA owner, so they're recorded as anonymous +# counts. The dead-op-elim shortcut in `use_count` needs them. +function register_terminator_uses!(r::Rewriter, term) + if term isa Core.ReturnNode + isdefined(term, :val) && bump_extra_use!(r, term.val) + elseif term isa ConditionOp + bump_extra_use!(r, term.condition) + for arg in term.args + bump_extra_use!(r, arg) + end + elseif term isa Union{ContinueOp, BreakOp, YieldOp} + for v in term.values + bump_extra_use!(r, v) + end + end +end + +function bump_extra_use!(r::Rewriter, @nospecialize(val)) + val isa SSAValue || return + r.extra_uses[val] = get(r.extra_uses, val, 0) + 1 +end diff --git a/src/cuTile.jl b/src/cuTile.jl index c1eed906..784e1918 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -5,7 +5,7 @@ using IRStructurizer: Block, ControlFlowOp, BlockArgument, YieldOp, ContinueOp, BreakOp, ConditionOp, IfOp, ForOp, WhileOp, LoopOp, Undef, SourceLocation -import IRStructurizer: operands +import IRStructurizer: operands, replace_uses!, insert_before! using Base: compilerbarrier, donotdelete using Core: MethodInstance, CodeInfo, SSAValue, Argument, SlotNumber, @@ -51,6 +51,7 @@ include("compiler/analysis/effects.jl") include("compiler/analysis/divisibility.jl") include("compiler/analysis/bounds.jl") include("compiler/analysis/assume.jl") +include("compiler/transform/rewriter.jl") include("compiler/transform/rewrite.jl") include("compiler/transform/canonicalize.jl") include("compiler/transform/control_flow.jl")