From 37f140e445c6093a34a064bce8265ef1af42f10c Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 28 May 2026 15:18:12 +0200 Subject: [PATCH 1/3] Cache user lookups. --- src/compiler/transform/pipeline.jl | 5 +- src/compiler/transform/rewrite.jl | 259 ++++++++++++++++++++++++----- 2 files changed, 220 insertions(+), 44 deletions(-) diff --git a/src/compiler/transform/pipeline.jl b/src/compiler/transform/pipeline.jl index d0ec4ada..0ba89a98 100644 --- a/src/compiler/transform/pipeline.jl +++ b/src/compiler/transform/pipeline.jl @@ -111,8 +111,9 @@ function commute_arith_transparent(sci, block, inst, match, driver) # Replace root with transparent_op(op_result, s). Func changes # (subi/addi → reshape/broadcast), so recompute the flag from the new # func's declared effects — the inferred bits describe the OLD op. - block[val.id] = (stmt=Expr(:call, transparent_func, SSAValue(op), match.bindings[:s]), - flag=inferred_flags(transparent_func)) + update_stmt!(driver, block, val, + Expr(:call, transparent_func, SSAValue(op), match.bindings[:s]); + flag=inferred_flags(transparent_func)) driver.defs[val] = DefEntry(block, val, transparent_func) push!(driver.worklist, val) add_users_to_worklist!(driver, val) diff --git a/src/compiler/transform/rewrite.jl b/src/compiler/transform/rewrite.jl index fbcd0d56..3008d003 100644 --- a/src/compiler/transform/rewrite.jl +++ b/src/compiler/transform/rewrite.jl @@ -223,12 +223,173 @@ mutable struct RewriteDriver worklist::Worklist constants::Union{Nothing, ConstantInfo} modified::Set{SSAValue} # instructions whose operands were modified by forwarding + users::Dict{SSAValue, Vector{SSAValue}} # val -> SSAs whose stmts reference it + extra_uses::Dict{SSAValue, Int} # val -> count of uses inside block terminators max_rewrites::Int end -"""Compute fresh use count for an SSA value.""" +"""Users (by SSA) of `val`, or an empty list if `val` has none recorded.""" +users_of(driver::RewriteDriver, val::SSAValue) = + get(driver.users, val, SSAValue[]) +users_of(::RewriteDriver, ::Any) = SSAValue[] + +"""Total use count for `val`, including terminator uses.""" use_count(driver::RewriteDriver, val::SSAValue) = - length(uses(driver.sci.entry, val)) + length(users_of(driver, val)) + get(driver.extra_uses, val, 0) + +# Iterate the SSA-positional operands of an Expr (skipping head/callee slot). +function for_expr_operands(f, expr::Expr) + start = expr.head === :invoke ? 3 : 2 + for i in start:length(expr.args) + f(expr.args[i]) + end +end + +function add_use!(driver::RewriteDriver, @nospecialize(operand), user::SSAValue) + operand isa SSAValue || return + push!(get!(Vector{SSAValue}, driver.users, operand), user) +end + +function remove_use!(driver::RewriteDriver, @nospecialize(operand), user::SSAValue) + operand isa SSAValue || return + users = get(driver.users, operand, nothing) + users === nothing && return + i = findfirst(==(user), users) + i === nothing && return + deleteat!(users, i) + isempty(users) && delete!(driver.users, operand) +end + +# Register/unregister operand uses for a stmt owned by `val`. The rewriter +# only ever creates or replaces `Expr` stmts (CF ops and terminators are +# untouched), so we only need the Expr path here. Initial-build populates the +# other stmt kinds in `build_use_index!`. +function register_stmt_uses!(driver::RewriteDriver, val::SSAValue, @nospecialize(stmt)) + stmt isa Expr || return + for_expr_operands(stmt) do op + add_use!(driver, op, val) + end +end + +function unregister_stmt_uses!(driver::RewriteDriver, val::SSAValue, @nospecialize(stmt)) + stmt isa Expr || return + for_expr_operands(stmt) do op + remove_use!(driver, op, val) + end +end + +""" +Replace the stmt at `val` with `new_stmt`, updating the use-index. If `flag` +is given, also update the IR_FLAG_* bitmask; otherwise the existing flag and +type are preserved (matches the existing partial-NamedTuple `block[val.id] =` +idiom). Called by every driver-side mutation that swaps an Expr in place. +""" +function update_stmt!(driver::RewriteDriver, block::Block, val::SSAValue, + new_stmt; flag::Union{UInt32, Nothing}=nothing) + haskey(block, val.id) || return + old_stmt = block[val.id][:stmt] + unregister_stmt_uses!(driver, val, old_stmt) + if flag === nothing + block[val.id] = (stmt=new_stmt,) + else + block[val.id] = (stmt=new_stmt, flag=flag) + end + register_stmt_uses!(driver, val, new_stmt) +end + +# Build the initial use-index by walking the entire SCI once. Mirrors the +# dispatch in `IRStructurizer.walk_uses!(::Block)` so every operand kind that +# the IR could reference is captured: Expr args, control-flow op fields, +# ReturnNode/PiNode/value-like stmt operands, and block terminators. +function build_use_index!(driver::RewriteDriver) + for block in eachblock(driver.sci) + for i in 1:length(block.body.ssa_idxes) + owner = SSAValue(block.body.ssa_idxes[i]) + register_owned_stmt_uses!(driver, owner, block.body.stmts[i]) + end + register_terminator_uses!(driver, block.terminator) + end +end + +function register_owned_stmt_uses!(driver::RewriteDriver, owner::SSAValue, @nospecialize(stmt)) + if stmt isa Expr + for_expr_operands(stmt) do op + add_use!(driver, op, owner) + end + elseif stmt isa Core.ReturnNode + isdefined(stmt, :val) && add_use!(driver, stmt.val, owner) + elseif stmt isa Core.PiNode + add_use!(driver, stmt.val, owner) + elseif stmt isa SSAValue + # Alias/forwarding stmt: the stmt slot itself IS the operand. + add_use!(driver, stmt, owner) + elseif stmt isa ControlFlowOp + register_cf_op_uses!(driver, owner, stmt) + end +end + +function register_cf_op_uses!(driver::RewriteDriver, owner::SSAValue, op::IfOp) + add_use!(driver, op.condition, owner) +end + +function register_cf_op_uses!(driver::RewriteDriver, owner::SSAValue, op::ForOp) + add_use!(driver, op.lower, owner) + add_use!(driver, op.upper, owner) + add_use!(driver, op.step, owner) + for iv in op.init_values + add_use!(driver, iv, owner) + end +end + +function register_cf_op_uses!(driver::RewriteDriver, owner::SSAValue, + op::Union{WhileOp, LoopOp}) + for iv in op.init_values + add_use!(driver, iv, owner) + end +end + +# Terminator operands have no SSA owner; record them as anonymous counts so +# `use_count` reflects them (needed for the dead-op elim shortcut). +function register_terminator_uses!(driver::RewriteDriver, term) + if term isa Core.ReturnNode + isdefined(term, :val) && bump_extra_use!(driver, term.val) + elseif term isa ConditionOp + bump_extra_use!(driver, term.condition) + for arg in term.args + bump_extra_use!(driver, arg) + end + elseif term isa Union{ContinueOp, BreakOp, YieldOp} + for v in term.values + bump_extra_use!(driver, v) + end + end +end + +function bump_extra_use!(driver::RewriteDriver, @nospecialize(val)) + val isa SSAValue || return + driver.extra_uses[val] = get(driver.extra_uses, val, 0) + 1 +end + +# Transfer index entries from `old` to `new_val` after a RAUW. The IR-level +# `replace_uses!` has already rewritten the user stmts; we just move the +# bookkeeping so subsequent queries find users under `new_val`. +function transfer_uses!(driver::RewriteDriver, old::SSAValue, @nospecialize(new_val)) + old_users = get(driver.users, old, nothing) + if old_users !== nothing + if new_val isa SSAValue + dest = get!(Vector{SSAValue}, driver.users, new_val) + append!(dest, old_users) + end + delete!(driver.users, old) + end + old_extra = get(driver.extra_uses, old, 0) + if old_extra > 0 + if new_val isa SSAValue + driver.extra_uses[new_val] = get(driver.extra_uses, new_val, 0) + old_extra + end + delete!(driver.extra_uses, old) + end +end #============================================================================= Notifications @@ -244,25 +405,40 @@ end """Add instructions that use `val` to the worklist (their operand changed).""" function add_users_to_worklist!(driver::RewriteDriver, val::SSAValue) - for inst in users(driver.sci.entry, val) - push!(driver.worklist, SSAValue(inst)) + for u in users_of(driver, val) + push!(driver.worklist, u) end end """Erase an instruction and notify the worklist.""" function erase_op!(driver::RewriteDriver, entry::DefEntry) - add_operands_to_worklist!(driver, entry) - if haskey(entry.block, entry.val.id) - delete!(entry.block, entry.val.id) + val = entry.val + if haskey(entry.block, val.id) + old_stmt = entry.block[val.id][:stmt] + # Deregister this op's operand uses AND enqueue operand-defs (their + # use count just dropped — may enable further dead-op elim). + if old_stmt isa Expr + for_expr_operands(old_stmt) do op + op isa SSAValue || return + remove_use!(driver, op, val) + haskey(driver.defs, op) && push!(driver.worklist, op) + end + end + delete!(entry.block, val.id) end - delete!(driver.defs, entry.val) - remove!(driver.worklist, entry.val) + # Drop val's own entries; any leftover claims would now be stale. + delete!(driver.users, val) + delete!(driver.extra_uses, val) + delete!(driver.defs, val) + remove!(driver.worklist, val) end """Register a newly inserted instruction.""" function notify_insert!(driver::RewriteDriver, block::Block, inst::Instruction) val = SSAValue(inst) - call = resolve_call(block, inst) + stmt = inst[:stmt] + register_stmt_uses!(driver, val, stmt) + call = resolve_call(block, stmt) call === nothing && return func, _ = call driver.defs[val] = DefEntry(block, val, func) @@ -406,11 +582,8 @@ function apply_inplace_rewrite!(driver::RewriteDriver, block, val::SSAValue, rul # the flag from the new func's declared effects (`efunc` overrides), # mirroring inference's `flags_for_effects` — the inferred bits on the # old call describe the OLD op and don't carry over. - if rule.rhs.func === driver.defs[val].func - block[val.id] = (stmt=new_stmt,) - else - block[val.id] = (stmt=new_stmt, flag=inferred_flags(rule.rhs.func)) - end + flag = rule.rhs.func === driver.defs[val].func ? nothing : inferred_flags(rule.rhs.func) + update_stmt!(driver, block, val, new_stmt; flag) driver.defs[val] = DefEntry(block, val, rule.rhs.func) push!(driver.worklist, val) add_users_to_worklist!(driver, val) @@ -438,8 +611,9 @@ function resolve_inplace_rhs(driver, bindings, op::RCall, lhs_op::PCall) for (sub_rhs, sub_lhs) in zip(op.operands, lhs_op.operands)] # Sub-call func is enforced equal to the matched lhs func above # (`op.func === lhs_op.func`), so the flag is still valid — only operands - # changed. Partial-NamedTuple setindex preserves type and flag. - entry.block[matched_ssa.id] = (stmt=Expr(:call, op.func, new_ops...),) + # changed. `update_stmt!` preserves type and flag via partial-NamedTuple + # setindex and keeps the use-index in sync. + update_stmt!(driver, entry.block, matched_ssa, Expr(:call, op.func, new_ops...)) push!(driver.worklist, matched_ssa) return matched_ssa end @@ -454,25 +628,22 @@ The matched_ssas in MatchResult are ordered root-first, but we need to find the specific SSA for a sub-pattern. We do this by looking up the first operand's binding and finding the op that defines it.""" function find_matched_ssa(driver, pat::PCall, bindings) - entry = driver.sci.entry for sub in pat.operands if sub isa PBind bound = get(bindings, sub.name, nothing) bound isa SSAValue || continue - for inst in users(entry, bound) - call = resolve_call(entry, inst) - call === nothing && continue - func, _ = call - func === pat.func && return SSAValue(inst) + for u in users_of(driver, bound) + def_entry = get(driver.defs, u, nothing) + def_entry === nothing && continue + def_entry.func === pat.func && return u end elseif sub isa PCall inner_ssa = find_matched_ssa(driver, sub, bindings) if inner_ssa !== nothing - for inst in users(entry, inner_ssa) - call = resolve_call(entry, inst) - call === nothing && continue - func, _ = call - func === pat.func && return SSAValue(inst) + for u in users_of(driver, inner_ssa) + def_entry = get(driver.defs, u, nothing) + def_entry === nothing && continue + def_entry.func === pat.func && return u end end end @@ -499,11 +670,16 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match # When these are later popped from the worklist without a match, the # driver propagates to THEIR users (see modified check in main loop). # This gives MLIR-style notifyOperationModified cascading. - for inst in users(driver.sci.entry, val) - push!(driver.modified, SSAValue(inst)) + new_val = match.bindings[rule.rhs.name] + for u in users_of(driver, val) + push!(driver.modified, u) + push!(driver.worklist, u) end - add_users_to_worklist!(driver, val) - replace_uses!(driver.sci.entry, val, match.bindings[rule.rhs.name]) + # IR-level RAUW walks the SCI and rewrites every use-site; the + # use-index is then updated to reflect that those user stmts now + # reference `new_val` instead. + replace_uses!(driver.sci.entry, val, new_val) + transfer_uses!(driver, val, new_val) erase_op!(driver, entry) else # Substitution: replace root in-place, clean up dead intermediates. @@ -531,11 +707,8 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match # describe the OLD op; recompute from the new func's `efunc` effects # so downstream gates (CSE, LICM) see fresh, correct information. new_stmt = Expr(:call, rule.rhs.func, operands...) - if rule.rhs.func === driver.defs[val].func - block[val.id] = (stmt=new_stmt,) - else - block[val.id] = (stmt=new_stmt, flag=inferred_flags(rule.rhs.func)) - end + flag = rule.rhs.func === driver.defs[val].func ? nothing : inferred_flags(rule.rhs.func) + update_stmt!(driver, block, val, new_stmt; flag) # Update defs, re-add self and users to worklist (statement changed) driver.defs[val] = DefEntry(block, val, rule.rhs.func) push!(driver.worklist, val) @@ -584,7 +757,10 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; end end - driver = RewriteDriver(sci, defs, dispatch, wl, constants, Set{SSAValue}(), max_rewrites) + driver = RewriteDriver(sci, defs, dispatch, wl, constants, Set{SSAValue}(), + Dict{SSAValue, Vector{SSAValue}}(), Dict{SSAValue, Int}(), + max_rewrites) + build_use_index!(driver) num_rewrites = 0 while !isempty(driver.worklist) && num_rewrites < driver.max_rewrites @@ -634,10 +810,9 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; # cascade continues through the fixpoint. if !matched && val in driver.modified delete!(driver.modified, val) - for inst in users(driver.sci.entry, val) - uv = SSAValue(inst) - push!(driver.modified, uv) - haskey(driver.defs, uv) && push!(driver.worklist, uv) + for u in users_of(driver, val) + push!(driver.modified, u) + haskey(driver.defs, u) && push!(driver.worklist, u) end end end From ebce35955291e051f7386afe31cb357635ab521b Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 28 May 2026 15:18:21 +0200 Subject: [PATCH 2/3] Introduce Rewriter abstraction. --- src/compiler/transform/pipeline.jl | 15 +- src/compiler/transform/rewrite.jl | 299 +++++++---------------------- src/compiler/transform/rewriter.jl | 291 ++++++++++++++++++++++++++++ src/cuTile.jl | 3 +- 4 files changed, 369 insertions(+), 239 deletions(-) create mode 100644 src/compiler/transform/rewriter.jl diff --git a/src/compiler/transform/pipeline.jl b/src/compiler/transform/pipeline.jl index 0ba89a98..9fd59ac0 100644 --- a/src/compiler/transform/pipeline.jl +++ b/src/compiler/transform/pipeline.jl @@ -95,26 +95,25 @@ function commute_arith_transparent(sci, block, inst, match, driver) # Insert broadcast of the scalar to x's shape and register as constant x_shape = size(xT) bc_type = Tile{eltype(xT), Tuple{x_shape...}} - bc = insert_before!(block, val, Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type; + bc = insert_before!(driver.rewriter, block, val, + Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type; flag=inferred_flags(Intrinsics.broadcast)) - notify_insert!(driver, block, bc) # Side-inject the freshly synthesized constant into the dataflow result so # downstream pattern matches see it. Bypasses tmerge (this is a brand-new # SSA value, not a merge). driver.constants[SSAValue(bc)] = convert(eltype(xT), scalar) # Insert op(x, broadcast) with x's type - op = insert_before!(block, val, Expr(:call, root_func, x, SSAValue(bc)), xT; + op = insert_before!(driver.rewriter, block, val, + Expr(:call, root_func, x, SSAValue(bc)), xT; flag=inferred_flags(root_func)) - notify_insert!(driver, block, op) # Replace root with transparent_op(op_result, s). Func changes # (subi/addi → reshape/broadcast), so recompute the flag from the new # func's declared effects — the inferred bits describe the OLD op. - update_stmt!(driver, block, val, - Expr(:call, transparent_func, SSAValue(op), match.bindings[:s]); - flag=inferred_flags(transparent_func)) - driver.defs[val] = DefEntry(block, val, transparent_func) + replace_stmt!(driver.rewriter, block, val, + Expr(:call, transparent_func, SSAValue(op), match.bindings[:s]); + flag=inferred_flags(transparent_func)) push!(driver.worklist, val) add_users_to_worklist!(driver, val) return true diff --git a/src/compiler/transform/rewrite.jl b/src/compiler/transform/rewrite.jl index 3008d003..eba6d65c 100644 --- a/src/compiler/transform/rewrite.jl +++ b/src/compiler/transform/rewrite.jl @@ -217,191 +217,71 @@ function def_operands(entry::DefEntry) end mutable struct RewriteDriver - sci::StructuredIRCode + rewriter::Rewriter defs::Dict{SSAValue, DefEntry} dispatch::Dict{Any, Vector{RewriteRule}} worklist::Worklist constants::Union{Nothing, ConstantInfo} modified::Set{SSAValue} # instructions whose operands were modified by forwarding - users::Dict{SSAValue, Vector{SSAValue}} # val -> SSAs whose stmts reference it - extra_uses::Dict{SSAValue, Int} # val -> count of uses inside block terminators max_rewrites::Int end -"""Users (by SSA) of `val`, or an empty list if `val` has none recorded.""" -users_of(driver::RewriteDriver, val::SSAValue) = - get(driver.users, val, SSAValue[]) -users_of(::RewriteDriver, ::Any) = SSAValue[] - -"""Total use count for `val`, including terminator uses.""" -use_count(driver::RewriteDriver, val::SSAValue) = - length(users_of(driver, val)) + get(driver.extra_uses, val, 0) - -# Iterate the SSA-positional operands of an Expr (skipping head/callee slot). -function for_expr_operands(f, expr::Expr) - start = expr.head === :invoke ? 3 : 2 - for i in start:length(expr.args) - f(expr.args[i]) - end -end - -function add_use!(driver::RewriteDriver, @nospecialize(operand), user::SSAValue) - operand isa SSAValue || return - push!(get!(Vector{SSAValue}, driver.users, operand), user) -end - -function remove_use!(driver::RewriteDriver, @nospecialize(operand), user::SSAValue) - operand isa SSAValue || return - users = get(driver.users, operand, nothing) - users === nothing && return - i = findfirst(==(user), users) - i === nothing && return - deleteat!(users, i) - isempty(users) && delete!(driver.users, operand) -end - -# Register/unregister operand uses for a stmt owned by `val`. The rewriter -# only ever creates or replaces `Expr` stmts (CF ops and terminators are -# untouched), so we only need the Expr path here. Initial-build populates the -# other stmt kinds in `build_use_index!`. -function register_stmt_uses!(driver::RewriteDriver, val::SSAValue, @nospecialize(stmt)) - stmt isa Expr || return - for_expr_operands(stmt) do op - add_use!(driver, op, val) - end -end - -function unregister_stmt_uses!(driver::RewriteDriver, val::SSAValue, @nospecialize(stmt)) - stmt isa Expr || return - for_expr_operands(stmt) do op - remove_use!(driver, op, val) - end -end - -""" -Replace the stmt at `val` with `new_stmt`, updating the use-index. If `flag` -is given, also update the IR_FLAG_* bitmask; otherwise the existing flag and -type are preserved (matches the existing partial-NamedTuple `block[val.id] =` -idiom). Called by every driver-side mutation that swaps an Expr in place. -""" -function update_stmt!(driver::RewriteDriver, block::Block, val::SSAValue, - new_stmt; flag::Union{UInt32, Nothing}=nothing) - haskey(block, val.id) || return - old_stmt = block[val.id][:stmt] - unregister_stmt_uses!(driver, val, old_stmt) - if flag === nothing - block[val.id] = (stmt=new_stmt,) - else - block[val.id] = (stmt=new_stmt, flag=flag) - end - register_stmt_uses!(driver, val, new_stmt) -end - -# Build the initial use-index by walking the entire SCI once. Mirrors the -# dispatch in `IRStructurizer.walk_uses!(::Block)` so every operand kind that -# the IR could reference is captured: Expr args, control-flow op fields, -# ReturnNode/PiNode/value-like stmt operands, and block terminators. -function build_use_index!(driver::RewriteDriver) - for block in eachblock(driver.sci) - for i in 1:length(block.body.ssa_idxes) - owner = SSAValue(block.body.ssa_idxes[i]) - register_owned_stmt_uses!(driver, owner, block.body.stmts[i]) - end - register_terminator_uses!(driver, block.terminator) - end -end - -function register_owned_stmt_uses!(driver::RewriteDriver, owner::SSAValue, @nospecialize(stmt)) - if stmt isa Expr - for_expr_operands(stmt) do op - add_use!(driver, op, owner) - end - elseif stmt isa Core.ReturnNode - isdefined(stmt, :val) && add_use!(driver, stmt.val, owner) - elseif stmt isa Core.PiNode - add_use!(driver, stmt.val, owner) - elseif stmt isa SSAValue - # Alias/forwarding stmt: the stmt slot itself IS the operand. - add_use!(driver, stmt, owner) - elseif stmt isa ControlFlowOp - register_cf_op_uses!(driver, owner, stmt) - end -end - -function register_cf_op_uses!(driver::RewriteDriver, owner::SSAValue, op::IfOp) - add_use!(driver, op.condition, owner) -end - -function register_cf_op_uses!(driver::RewriteDriver, owner::SSAValue, op::ForOp) - add_use!(driver, op.lower, owner) - add_use!(driver, op.upper, owner) - add_use!(driver, op.step, owner) - for iv in op.init_values - add_use!(driver, iv, owner) - end -end - -function register_cf_op_uses!(driver::RewriteDriver, owner::SSAValue, - op::Union{WhileOp, LoopOp}) - for iv in op.init_values - add_use!(driver, iv, owner) - end -end +#============================================================================= + Rewriter listener — driver-side bookkeeping reacts to IR mutations +=============================================================================# -# Terminator operands have no SSA owner; record them as anonymous counts so -# `use_count` reflects them (needed for the dead-op elim shortcut). -function register_terminator_uses!(driver::RewriteDriver, term) - if term isa Core.ReturnNode - isdefined(term, :val) && bump_extra_use!(driver, term.val) - elseif term isa ConditionOp - bump_extra_use!(driver, term.condition) - for arg in term.args - bump_extra_use!(driver, arg) - end - elseif term isa Union{ContinueOp, BreakOp, YieldOp} - for v in term.values - bump_extra_use!(driver, v) - end - end -end +# The driver maintains its own auxiliary state (a `defs` map keyed by func, +# a worklist, and a `modified` set for cascading propagation) on top of the +# Rewriter's IR-level state. The Rewriter calls these `notify_*` hooks at +# every mutation; the driver uses them to keep its state in sync without +# having to wrap each mutation site itself. Mirrors MLIR's Listener pattern +# (`RewriterBase::Listener::notifyOperation{Inserted,Modified,Erased}`). -function bump_extra_use!(driver::RewriteDriver, @nospecialize(val)) - val isa SSAValue || return - driver.extra_uses[val] = get(driver.extra_uses, val, 0) + 1 -end - -# Transfer index entries from `old` to `new_val` after a RAUW. The IR-level -# `replace_uses!` has already rewritten the user stmts; we just move the -# bookkeeping so subsequent queries find users under `new_val`. -function transfer_uses!(driver::RewriteDriver, old::SSAValue, @nospecialize(new_val)) - old_users = get(driver.users, old, nothing) - if old_users !== nothing - if new_val isa SSAValue - dest = get!(Vector{SSAValue}, driver.users, new_val) - append!(dest, old_users) - end - delete!(driver.users, old) - end - old_extra = get(driver.extra_uses, old, 0) - if old_extra > 0 - if new_val isa SSAValue - driver.extra_uses[new_val] = get(driver.extra_uses, new_val, 0) + old_extra +function notify_inserted!(d::RewriteDriver, block::Block, inst::Instruction) + stmt = inst[:stmt] + call = resolve_call(block, stmt) + call === nothing && return + func, _ = call + val = SSAValue(inst) + d.defs[val] = DefEntry(block, val, func) + push!(d.worklist, val) +end + +function notify_modified!(d::RewriteDriver, block::Block, val::SSAValue, + @nospecialize(old_stmt), @nospecialize(new_stmt)) + # Func may have changed (substitution rewrites); refresh the def entry + # so the next worklist visit sees the new func for rule dispatch. + call = resolve_call(block, new_stmt) + if call !== nothing + func, _ = call + d.defs[val] = DefEntry(block, val, func) + end + # No worklist cascade here — callers (`apply_rewrite!`, + # `apply_inplace_rewrite!`, `commute_arith_transparent`) handle re-seeding + # of `val` + its users explicitly after the mutation returns. +end + +function notify_erased!(d::RewriteDriver, ::Block, val::SSAValue, + @nospecialize(old_stmt)) + # Operand-defs of the erased op may now be dead — cascade to worklist. + if old_stmt isa Expr + for_expr_operands(old_stmt) do op + op isa SSAValue || return + haskey(d.defs, op) && push!(d.worklist, op) end - delete!(driver.extra_uses, old) end + delete!(d.defs, val) + remove!(d.worklist, val) + delete!(d.modified, val) end #============================================================================= - Notifications + Driver-side query helpers (thin proxies to the Rewriter) =============================================================================# -"""Add operand-producing instructions to the worklist (enables cascading).""" -function add_operands_to_worklist!(driver::RewriteDriver, entry::DefEntry) - for op in def_operands(entry) - op isa SSAValue || continue - haskey(driver.defs, op) && push!(driver.worklist, op) - end -end +users_of(driver::RewriteDriver, val::SSAValue) = users(driver.rewriter, val) + +use_count(driver::RewriteDriver, val::SSAValue) = use_count(driver.rewriter, val) """Add instructions that use `val` to the worklist (their operand changed).""" function add_users_to_worklist!(driver::RewriteDriver, val::SSAValue) @@ -410,41 +290,6 @@ function add_users_to_worklist!(driver::RewriteDriver, val::SSAValue) end end -"""Erase an instruction and notify the worklist.""" -function erase_op!(driver::RewriteDriver, entry::DefEntry) - val = entry.val - if haskey(entry.block, val.id) - old_stmt = entry.block[val.id][:stmt] - # Deregister this op's operand uses AND enqueue operand-defs (their - # use count just dropped — may enable further dead-op elim). - if old_stmt isa Expr - for_expr_operands(old_stmt) do op - op isa SSAValue || return - remove_use!(driver, op, val) - haskey(driver.defs, op) && push!(driver.worklist, op) - end - end - delete!(entry.block, val.id) - end - # Drop val's own entries; any leftover claims would now be stale. - delete!(driver.users, val) - delete!(driver.extra_uses, val) - delete!(driver.defs, val) - remove!(driver.worklist, val) -end - -"""Register a newly inserted instruction.""" -function notify_insert!(driver::RewriteDriver, block::Block, inst::Instruction) - val = SSAValue(inst) - stmt = inst[:stmt] - register_stmt_uses!(driver, val, stmt) - call = resolve_call(block, stmt) - call === nothing && return - func, _ = call - driver.defs[val] = DefEntry(block, val, func) - push!(driver.worklist, val) -end - #============================================================================= Matching =============================================================================# @@ -467,7 +312,7 @@ function merge_bindings!(dest::Dict{Symbol,Any}, src::Dict{Symbol,Any}) end function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall, - block::Block=driver.sci.entry) + block::Block=driver.rewriter.sci.entry) val isa SSAValue || return nothing entry = get(driver.defs, val, nothing) entry === nothing && return nothing @@ -498,11 +343,11 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall, return nothing end -pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PBind, block::Block=driver.sci.entry) = +pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PBind, block::Block=driver.rewriter.sci.entry) = MatchResult(Dict{Symbol,Any}(pat.name => val), SSAValue[]) function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PTypedBind, - block::Block=driver.sci.entry) + block::Block=driver.rewriter.sci.entry) T = value_type(block, val) T === nothing && return nothing CC.widenconst(T) <: pat.type || return nothing @@ -510,7 +355,7 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PTypedBin end function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::POneUse, - block::Block=driver.sci.entry) + block::Block=driver.rewriter.sci.entry) val isa SSAValue && use_count(driver, val) == 1 || return nothing pattern_match(driver, val, pat.inner, block) end @@ -519,7 +364,7 @@ end # For non-SSA operands (enum constants, predicates): checks ===. # For SSA operands: routed through const_value on the ConstantAnalysis result. function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PLiteral, - block::Block=driver.sci.entry) + block::Block=driver.rewriter.sci.entry) val === pat.val && return MatchResult(Dict{Symbol,Any}(), SSAValue[]) if val isa SSAValue c = const_value(driver.constants, val) @@ -560,9 +405,8 @@ function resolve_rhs(driver::RewriteDriver, block, ref, op::RCall, bindings, roo typ = CC.widenconst(t) break end - inst = insert_before!(block, ref, Expr(:call, op.func, operands...), typ; + inst = insert_before!(driver.rewriter, block, ref, Expr(:call, op.func, operands...), typ; flag=inferred_flags(op.func)) - notify_insert!(driver, block, inst) SSAValue(inst) end @@ -583,8 +427,7 @@ function apply_inplace_rewrite!(driver::RewriteDriver, block, val::SSAValue, rul # mirroring inference's `flags_for_effects` — the inferred bits on the # old call describe the OLD op and don't carry over. flag = rule.rhs.func === driver.defs[val].func ? nothing : inferred_flags(rule.rhs.func) - update_stmt!(driver, block, val, new_stmt; flag) - driver.defs[val] = DefEntry(block, val, rule.rhs.func) + replace_stmt!(driver.rewriter, block, val, new_stmt; flag) push!(driver.worklist, val) add_users_to_worklist!(driver, val) return true @@ -611,9 +454,9 @@ function resolve_inplace_rhs(driver, bindings, op::RCall, lhs_op::PCall) for (sub_rhs, sub_lhs) in zip(op.operands, lhs_op.operands)] # Sub-call func is enforced equal to the matched lhs func above # (`op.func === lhs_op.func`), so the flag is still valid — only operands - # changed. `update_stmt!` preserves type and flag via partial-NamedTuple - # setindex and keeps the use-index in sync. - update_stmt!(driver, entry.block, matched_ssa, Expr(:call, op.func, new_ops...)) + # changed. The rewriter's `replace_stmt!` preserves type and flag via + # partial-NamedTuple setindex. + replace_stmt!(driver.rewriter, entry.block, matched_ssa, Expr(:call, op.func, new_ops...)) push!(driver.worklist, matched_ssa) return matched_ssa end @@ -662,7 +505,7 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match # Look up live instruction for RFunc interface haskey(block, val.id) || return false inst = block[val.id] - rule.rhs.func(driver.sci, block, inst, match, driver) || return false + rule.rhs.func(driver.rewriter.sci, block, inst, match, driver) || return false return true elseif rule.rhs isa RBind # Forwarding: replace all uses of root with the bound value, delete root. @@ -675,12 +518,8 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match push!(driver.modified, u) push!(driver.worklist, u) end - # IR-level RAUW walks the SCI and rewrites every use-site; the - # use-index is then updated to reflect that those user stmts now - # reference `new_val` instead. - replace_uses!(driver.sci.entry, val, new_val) - transfer_uses!(driver, val, new_val) - erase_op!(driver, entry) + replace_uses!(driver.rewriter, val, new_val) + erase!(driver.rewriter, entry.block, val) else # Substitution: replace root in-place, clean up dead intermediates. # Only delete intermediates with no remaining uses — transparent-op @@ -690,7 +529,7 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match dead_entry = get(driver.defs, dead_val, nothing) dead_entry === nothing && continue use_count(driver, dead_val) == 0 || continue - erase_op!(driver, dead_entry) + erase!(driver.rewriter, dead_entry.block, dead_val) end typ = block[val.id][:type] # Build operands, flattening RSplat nodes into multiple operands @@ -708,9 +547,7 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match # so downstream gates (CSE, LICM) see fresh, correct information. new_stmt = Expr(:call, rule.rhs.func, operands...) flag = rule.rhs.func === driver.defs[val].func ? nothing : inferred_flags(rule.rhs.func) - update_stmt!(driver, block, val, new_stmt; flag) - # Update defs, re-add self and users to worklist (statement changed) - driver.defs[val] = DefEntry(block, val, rule.rhs.func) + replace_stmt!(driver.rewriter, block, val, new_stmt; flag) push!(driver.worklist, val) add_users_to_worklist!(driver, val) end @@ -757,10 +594,12 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; end end - driver = RewriteDriver(sci, defs, dispatch, wl, constants, Set{SSAValue}(), - Dict{SSAValue, Vector{SSAValue}}(), Dict{SSAValue, Int}(), + rewriter = Rewriter(sci) + driver = RewriteDriver(rewriter, defs, dispatch, wl, constants, Set{SSAValue}(), max_rewrites) - build_use_index!(driver) + # Wire the driver as the rewriter's listener so its `notify_*` methods + # fire on every IR mutation. + rewriter.listener = driver num_rewrites = 0 while !isempty(driver.worklist) && num_rewrites < driver.max_rewrites @@ -781,7 +620,7 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; if use_count(driver, val) == 0 stmt = entry.block[val.id][:stmt] if !must_keep(entry.block, stmt) - erase_op!(driver, entry) + erase!(driver.rewriter, entry.block, val) continue end end diff --git a/src/compiler/transform/rewriter.jl b/src/compiler/transform/rewriter.jl new file mode 100644 index 00000000..f4ff9670 --- /dev/null +++ b/src/compiler/transform/rewriter.jl @@ -0,0 +1,291 @@ +# IR Rewriter — mirrors MLIR's RewriterBase/IRRewriter. +# +# `Rewriter` wraps a `StructuredIRCode` and is the single channel through +# which all IR mutations flow during a rewrite session. It maintains a +# use-def index incrementally as a side effect of the mutation API; clients +# don't touch the index. Each mutation also fires `notify_*` listener hooks, +# letting higher-level drivers (e.g. `RewriteDriver` for greedy pattern +# rewriting) maintain their own bookkeeping (defs, worklist, …) without +# having to reach into the IR themselves. +# +# Mirrors: +# MLIR RewriterBase → Rewriter +# MLIR IROperand intrinsic use-list → (users, extra_uses) maps +# MLIR Listener / RewriterBase::Listener → notify_inserted!/_modified!/_erased! + +using Core: SSAValue + +#============================================================================= + Listener protocol +=============================================================================# + +# Default no-op listener callbacks. Drivers override via dispatch on their +# listener type — e.g. `notify_inserted!(d::RewriteDriver, block, inst)`. +"""Called after an instruction is inserted by the Rewriter.""" +notify_inserted!(::Any, ::Block, ::Instruction) = nothing + +"""Called after an instruction's stmt is replaced by the Rewriter. +`old_stmt` is the stmt that was in place before; `new_stmt` is the one now +written. Use this to react to func changes, mark users dirty, etc.""" +notify_modified!(::Any, ::Block, ::SSAValue, @nospecialize(old_stmt), @nospecialize(new_stmt)) = nothing + +"""Called after an instruction is erased by the Rewriter. `old_stmt` is the +stmt that was removed; the SSA value is no longer live.""" +notify_erased!(::Any, ::Block, ::SSAValue, @nospecialize(old_stmt)) = nothing + +#============================================================================= + Rewriter +=============================================================================# + +""" + Rewriter(sci::StructuredIRCode; listener=nothing) + +IR mutation handle for `sci`. Maintains an incremental use-def index — query +with `users(rewriter, val)` / `use_count(rewriter, val)`. All IR mutations +should go through `insert_before!` / `replace_stmt!` / `erase!` / +`replace_uses!` on the rewriter so the index stays consistent and the +listener gets fired. + +`listener` is an opaque object whose type drives the `notify_*` callbacks +via Julia multi-dispatch. Defaults to `nothing` (no-op). +""" +mutable struct Rewriter + sci::StructuredIRCode + # SSA value → SSAs of stmts that reference it. Captures every operand + # site reached by the build-time SCI walk (Expr args, CF op fields, + # ReturnNode/PiNode/alias stmts). Terminator uses live in `extra_uses` + # because terminators have no SSA owner. + users::Dict{SSAValue, Vector{SSAValue}} + extra_uses::Dict{SSAValue, Int} + listener::Any +end + +function Rewriter(sci::StructuredIRCode; listener=nothing) + r = Rewriter(sci, Dict{SSAValue, Vector{SSAValue}}(), + Dict{SSAValue, Int}(), listener) + build_use_index!(r) + return r +end + +#============================================================================= + Queries +=============================================================================# + +"""Users (by SSA) of `val`, or an empty list if `val` has none recorded.""" +users(r::Rewriter, val::SSAValue) = get(r.users, val, SSAValue[]) +users(::Rewriter, ::Any) = SSAValue[] + +"""Total use count for `val`, including terminator uses.""" +use_count(r::Rewriter, val::SSAValue) = + length(users(r, val)) + get(r.extra_uses, val, 0) + +#============================================================================= + Mutation API +=============================================================================# + +""" + insert_before!(r::Rewriter, block::Block, ref, stmt, type; flag=UInt32(0)) + +Insert `stmt` at `ref` in `block`. Registers operand uses and fires +`notify_inserted!` on the listener. Returns the newly created `Instruction`. +""" +function insert_before!(r::Rewriter, block::Block, ref, @nospecialize(stmt), + @nospecialize(type); flag::UInt32=UInt32(0)) + inst = IRStructurizer.insert_before!(block, ref, stmt, type; flag=flag) + register_stmt_uses!(r, SSAValue(inst), stmt) + notify_inserted!(r.listener, block, inst) + return inst +end + +""" + replace_stmt!(r::Rewriter, block::Block, val::SSAValue, new_stmt; flag=nothing) + +Replace the stmt at `val` with `new_stmt`. Updates the use-index (deregister +old operands, register new) and fires `notify_modified!`. If `flag` is given, +the IR_FLAG_* bitmask is updated too; otherwise only the stmt slot changes +(type and flag preserved). +""" +function replace_stmt!(r::Rewriter, block::Block, val::SSAValue, + @nospecialize(new_stmt); flag::Union{UInt32, Nothing}=nothing) + haskey(block, val.id) || return + old_stmt = block[val.id][:stmt] + unregister_stmt_uses!(r, val, old_stmt) + if flag === nothing + block[val.id] = (stmt=new_stmt,) + else + block[val.id] = (stmt=new_stmt, flag=flag) + end + register_stmt_uses!(r, val, new_stmt) + notify_modified!(r.listener, block, val, old_stmt, new_stmt) +end + +""" + erase!(r::Rewriter, block::Block, val::SSAValue) + +Erase the stmt at `val` from `block`. Deregisters this op's operand uses and +fires `notify_erased!`. The listener may inspect `old_stmt` to enqueue +operand-defs for dead-op elim cascading. +""" +function erase!(r::Rewriter, block::Block, val::SSAValue) + haskey(block, val.id) || return + old_stmt = block[val.id][:stmt] + unregister_stmt_uses!(r, val, old_stmt) + delete!(block, val.id) + # Any leftover index entry for `val` would now be stale (nothing produces + # val anymore); drop it so subsequent queries report empty. + delete!(r.users, val) + delete!(r.extra_uses, val) + notify_erased!(r.listener, block, val, old_stmt) +end + +""" + replace_uses!(r::Rewriter, val::SSAValue, new_val) + +Replace every use of `val` in the SCI with `new_val`. The IR walk handles +every operand kind (Expr args, CF op fields, terminators) and the use-index +entries are transferred to `new_val`. No listener event is fired — only the +users' operand slots changed; their stmts as a whole did not. +""" +function replace_uses!(r::Rewriter, val::SSAValue, @nospecialize(new_val)) + IRStructurizer.replace_uses!(r.sci.entry, val, new_val) + transfer_uses!(r, val, new_val) +end + +#============================================================================= + Use-index maintenance (internal) +=============================================================================# + +# Iterate the SSA-positional operands of an Expr (skipping head/callee slot). +function for_expr_operands(f, expr::Expr) + start = expr.head === :invoke ? 3 : 2 + for i in start:length(expr.args) + f(expr.args[i]) + end +end + +function add_use!(r::Rewriter, @nospecialize(operand), user::SSAValue) + operand isa SSAValue || return + push!(get!(Vector{SSAValue}, r.users, operand), user) +end + +function remove_use!(r::Rewriter, @nospecialize(operand), user::SSAValue) + operand isa SSAValue || return + list = get(r.users, operand, nothing) + list === nothing && return + i = findfirst(==(user), list) + i === nothing && return + deleteat!(list, i) + isempty(list) && delete!(r.users, operand) +end + +# Register/unregister operand uses for a stmt owned by `val`. The mutation +# API only ever creates or replaces `Expr` stmts (CF ops and terminators are +# not produced by drivers), so we only need the Expr path here. Initial-build +# populates the other stmt kinds via `register_owned_stmt_uses!`. +function register_stmt_uses!(r::Rewriter, val::SSAValue, @nospecialize(stmt)) + stmt isa Expr || return + for_expr_operands(stmt) do op + add_use!(r, op, val) + end +end + +function unregister_stmt_uses!(r::Rewriter, val::SSAValue, @nospecialize(stmt)) + stmt isa Expr || return + for_expr_operands(stmt) do op + remove_use!(r, op, val) + end +end + +# Move all index entries from `old` to `new_val` after a RAUW. The IR-level +# `replace_uses!` has already rewritten the user stmts; this just shifts the +# bookkeeping so subsequent queries find users under `new_val`. +function transfer_uses!(r::Rewriter, old::SSAValue, @nospecialize(new_val)) + old_users = get(r.users, old, nothing) + if old_users !== nothing + if new_val isa SSAValue + dest = get!(Vector{SSAValue}, r.users, new_val) + append!(dest, old_users) + end + delete!(r.users, old) + end + old_extra = get(r.extra_uses, old, 0) + if old_extra > 0 + if new_val isa SSAValue + r.extra_uses[new_val] = get(r.extra_uses, new_val, 0) + old_extra + end + delete!(r.extra_uses, old) + end +end + +# Build the initial use-index by walking the entire SCI once. Mirrors the +# dispatch in `IRStructurizer.walk_uses!(::Block)` so every operand kind that +# the IR could reference is captured: Expr args, CF op fields, ReturnNode / +# PiNode / alias stmt operands, and block terminators. +function build_use_index!(r::Rewriter) + for block in eachblock(r.sci) + for i in 1:length(block.body.ssa_idxes) + owner = SSAValue(block.body.ssa_idxes[i]) + register_owned_stmt_uses!(r, owner, block.body.stmts[i]) + end + register_terminator_uses!(r, block.terminator) + end +end + +function register_owned_stmt_uses!(r::Rewriter, owner::SSAValue, @nospecialize(stmt)) + if stmt isa Expr + for_expr_operands(stmt) do op + add_use!(r, op, owner) + end + elseif stmt isa Core.ReturnNode + isdefined(stmt, :val) && add_use!(r, stmt.val, owner) + elseif stmt isa Core.PiNode + add_use!(r, stmt.val, owner) + elseif stmt isa SSAValue + # Alias/forwarding stmt: the stmt slot itself IS the operand. + add_use!(r, stmt, owner) + elseif stmt isa ControlFlowOp + register_cf_op_uses!(r, owner, stmt) + end +end + +function register_cf_op_uses!(r::Rewriter, owner::SSAValue, op::IfOp) + add_use!(r, op.condition, owner) +end + +function register_cf_op_uses!(r::Rewriter, owner::SSAValue, op::ForOp) + add_use!(r, op.lower, owner) + add_use!(r, op.upper, owner) + add_use!(r, op.step, owner) + for iv in op.init_values + add_use!(r, iv, owner) + end +end + +function register_cf_op_uses!(r::Rewriter, owner::SSAValue, + op::Union{WhileOp, LoopOp}) + for iv in op.init_values + add_use!(r, iv, owner) + end +end + +# Terminator operands have no SSA owner; record them as anonymous counts so +# `use_count` reflects them (needed for the dead-op elim shortcut). +function register_terminator_uses!(r::Rewriter, term) + if term isa Core.ReturnNode + isdefined(term, :val) && bump_extra_use!(r, term.val) + elseif term isa ConditionOp + bump_extra_use!(r, term.condition) + for arg in term.args + bump_extra_use!(r, arg) + end + elseif term isa Union{ContinueOp, BreakOp, YieldOp} + for v in term.values + bump_extra_use!(r, v) + end + end +end + +function bump_extra_use!(r::Rewriter, @nospecialize(val)) + val isa SSAValue || return + r.extra_uses[val] = get(r.extra_uses, val, 0) + 1 +end diff --git a/src/cuTile.jl b/src/cuTile.jl index c1eed906..784e1918 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -5,7 +5,7 @@ using IRStructurizer: Block, ControlFlowOp, BlockArgument, YieldOp, ContinueOp, BreakOp, ConditionOp, IfOp, ForOp, WhileOp, LoopOp, Undef, SourceLocation -import IRStructurizer: operands +import IRStructurizer: operands, replace_uses!, insert_before! using Base: compilerbarrier, donotdelete using Core: MethodInstance, CodeInfo, SSAValue, Argument, SlotNumber, @@ -51,6 +51,7 @@ include("compiler/analysis/effects.jl") include("compiler/analysis/divisibility.jl") include("compiler/analysis/bounds.jl") include("compiler/analysis/assume.jl") +include("compiler/transform/rewriter.jl") include("compiler/transform/rewrite.jl") include("compiler/transform/canonicalize.jl") include("compiler/transform/control_flow.jl") From cf5b4864abee9035e634abc6fdb11fda972e02b0 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 28 May 2026 15:32:25 +0200 Subject: [PATCH 3/3] Rewrite comments. --- src/compiler/transform/rewrite.jl | 29 ++++---- src/compiler/transform/rewriter.jl | 107 +++++++++++++---------------- 2 files changed, 59 insertions(+), 77 deletions(-) diff --git a/src/compiler/transform/rewrite.jl b/src/compiler/transform/rewrite.jl index eba6d65c..7ab25b82 100644 --- a/src/compiler/transform/rewrite.jl +++ b/src/compiler/transform/rewrite.jl @@ -227,15 +227,14 @@ mutable struct RewriteDriver end #============================================================================= - Rewriter listener — driver-side bookkeeping reacts to IR mutations + Rewriter listener: driver-side bookkeeping reacts to IR mutations =============================================================================# -# The driver maintains its own auxiliary state (a `defs` map keyed by func, -# a worklist, and a `modified` set for cascading propagation) on top of the -# Rewriter's IR-level state. The Rewriter calls these `notify_*` hooks at -# every mutation; the driver uses them to keep its state in sync without -# having to wrap each mutation site itself. Mirrors MLIR's Listener pattern -# (`RewriterBase::Listener::notifyOperation{Inserted,Modified,Erased}`). +# The driver has its own state on top of the Rewriter's index: a `defs` map +# keyed by func, the `worklist`, and a `modified` set for cascading. These +# `notify_*` hooks let the Rewriter keep that state in sync, so the driver +# doesn't need to wrap every mutation callsite. Same pattern as MLIR's +# `RewriterBase::Listener::notifyOperation{Inserted,Modified,Erased}`. function notify_inserted!(d::RewriteDriver, block::Block, inst::Instruction) stmt = inst[:stmt] @@ -249,21 +248,20 @@ end function notify_modified!(d::RewriteDriver, block::Block, val::SSAValue, @nospecialize(old_stmt), @nospecialize(new_stmt)) - # Func may have changed (substitution rewrites); refresh the def entry - # so the next worklist visit sees the new func for rule dispatch. + # Refresh the def entry so worklist dispatch picks the new func. call = resolve_call(block, new_stmt) if call !== nothing func, _ = call d.defs[val] = DefEntry(block, val, func) end - # No worklist cascade here — callers (`apply_rewrite!`, - # `apply_inplace_rewrite!`, `commute_arith_transparent`) handle re-seeding - # of `val` + its users explicitly after the mutation returns. + # Worklist cascading is done by the callers (`apply_rewrite!`, + # `apply_inplace_rewrite!`, `commute_arith_transparent`), which re-seed + # `val` and its users themselves. end function notify_erased!(d::RewriteDriver, ::Block, val::SSAValue, @nospecialize(old_stmt)) - # Operand-defs of the erased op may now be dead — cascade to worklist. + # Operand-defs may now be dead; cascade them to the worklist. if old_stmt isa Expr for_expr_operands(old_stmt) do op op isa SSAValue || return @@ -454,8 +452,7 @@ function resolve_inplace_rhs(driver, bindings, op::RCall, lhs_op::PCall) for (sub_rhs, sub_lhs) in zip(op.operands, lhs_op.operands)] # Sub-call func is enforced equal to the matched lhs func above # (`op.func === lhs_op.func`), so the flag is still valid — only operands - # changed. The rewriter's `replace_stmt!` preserves type and flag via - # partial-NamedTuple setindex. + # changed. replace_stmt!(driver.rewriter, entry.block, matched_ssa, Expr(:call, op.func, new_ops...)) push!(driver.worklist, matched_ssa) return matched_ssa @@ -597,8 +594,6 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; rewriter = Rewriter(sci) driver = RewriteDriver(rewriter, defs, dispatch, wl, constants, Set{SSAValue}(), max_rewrites) - # Wire the driver as the rewriter's listener so its `notify_*` methods - # fire on every IR mutation. rewriter.listener = driver num_rewrites = 0 diff --git a/src/compiler/transform/rewriter.jl b/src/compiler/transform/rewriter.jl index f4ff9670..a3fbdab8 100644 --- a/src/compiler/transform/rewriter.jl +++ b/src/compiler/transform/rewriter.jl @@ -1,17 +1,12 @@ -# IR Rewriter — mirrors MLIR's RewriterBase/IRRewriter. +# IR Rewriter, modeled on MLIR's RewriterBase/IRRewriter. # -# `Rewriter` wraps a `StructuredIRCode` and is the single channel through -# which all IR mutations flow during a rewrite session. It maintains a -# use-def index incrementally as a side effect of the mutation API; clients -# don't touch the index. Each mutation also fires `notify_*` listener hooks, -# letting higher-level drivers (e.g. `RewriteDriver` for greedy pattern -# rewriting) maintain their own bookkeeping (defs, worklist, …) without -# having to reach into the IR themselves. +# All IR mutations during a rewrite session go through a `Rewriter`, which +# in return keeps an incremental use-def index up to date and fires +# `notify_*` hooks so higher-level drivers can react. Correspondences: # -# Mirrors: -# MLIR RewriterBase → Rewriter -# MLIR IROperand intrinsic use-list → (users, extra_uses) maps -# MLIR Listener / RewriterBase::Listener → notify_inserted!/_modified!/_erased! +# RewriterBase -> Rewriter +# IROperand intrinsic use-list -> (users, extra_uses) maps +# RewriterBase::Listener -> notify_inserted!/_modified!/_erased! using Core: SSAValue @@ -19,18 +14,16 @@ using Core: SSAValue Listener protocol =============================================================================# -# Default no-op listener callbacks. Drivers override via dispatch on their -# listener type — e.g. `notify_inserted!(d::RewriteDriver, block, inst)`. -"""Called after an instruction is inserted by the Rewriter.""" +# Default no-op listener callbacks. Drivers override by adding a method on +# their listener type (e.g. `notify_inserted!(d::RewriteDriver, block, inst)`). +"""Fires after the Rewriter inserts an instruction.""" notify_inserted!(::Any, ::Block, ::Instruction) = nothing -"""Called after an instruction's stmt is replaced by the Rewriter. -`old_stmt` is the stmt that was in place before; `new_stmt` is the one now -written. Use this to react to func changes, mark users dirty, etc.""" +"""Fires when the Rewriter replaces a stmt. Drivers use this to react to +func changes or mark users dirty.""" notify_modified!(::Any, ::Block, ::SSAValue, @nospecialize(old_stmt), @nospecialize(new_stmt)) = nothing -"""Called after an instruction is erased by the Rewriter. `old_stmt` is the -stmt that was removed; the SSA value is no longer live.""" +"""Fires once an instruction has been erased; the SSA value is no longer live.""" notify_erased!(::Any, ::Block, ::SSAValue, @nospecialize(old_stmt)) = nothing #============================================================================= @@ -40,21 +33,19 @@ notify_erased!(::Any, ::Block, ::SSAValue, @nospecialize(old_stmt)) = nothing """ Rewriter(sci::StructuredIRCode; listener=nothing) -IR mutation handle for `sci`. Maintains an incremental use-def index — query -with `users(rewriter, val)` / `use_count(rewriter, val)`. All IR mutations -should go through `insert_before!` / `replace_stmt!` / `erase!` / -`replace_uses!` on the rewriter so the index stays consistent and the -listener gets fired. +IR mutation handle for `sci`. Holds an incremental use-def index; query it +with `users(rewriter, val)` and `use_count(rewriter, val)`. Mutate the IR +only through `insert_before!`, `replace_stmt!`, `erase!`, or +`replace_uses!` on the rewriter, which keeps the index consistent and fires +the listener. -`listener` is an opaque object whose type drives the `notify_*` callbacks -via Julia multi-dispatch. Defaults to `nothing` (no-op). +`listener` is dispatched on its type via the `notify_*` generic functions. +Defaults to `nothing` (no-op). """ mutable struct Rewriter sci::StructuredIRCode - # SSA value → SSAs of stmts that reference it. Captures every operand - # site reached by the build-time SCI walk (Expr args, CF op fields, - # ReturnNode/PiNode/alias stmts). Terminator uses live in `extra_uses` - # because terminators have no SSA owner. + # `users`: SSA value -> SSAs of stmts that reference it. `extra_uses` + # holds the count of terminator uses, which have no SSA owner. users::Dict{SSAValue, Vector{SSAValue}} extra_uses::Dict{SSAValue, Int} listener::Any @@ -86,8 +77,8 @@ use_count(r::Rewriter, val::SSAValue) = """ insert_before!(r::Rewriter, block::Block, ref, stmt, type; flag=UInt32(0)) -Insert `stmt` at `ref` in `block`. Registers operand uses and fires -`notify_inserted!` on the listener. Returns the newly created `Instruction`. +Insert `stmt` at `ref` in `block`. Updates the use-index for the new stmt's +operands and fires `notify_inserted!`. Returns the new `Instruction`. """ function insert_before!(r::Rewriter, block::Block, ref, @nospecialize(stmt), @nospecialize(type); flag::UInt32=UInt32(0)) @@ -100,10 +91,10 @@ end """ replace_stmt!(r::Rewriter, block::Block, val::SSAValue, new_stmt; flag=nothing) -Replace the stmt at `val` with `new_stmt`. Updates the use-index (deregister -old operands, register new) and fires `notify_modified!`. If `flag` is given, -the IR_FLAG_* bitmask is updated too; otherwise only the stmt slot changes -(type and flag preserved). +Replace the stmt at `val` with `new_stmt`. Updates the use-index for both +sets of operands and fires `notify_modified!`. Pass `flag` to also update +the IR_FLAG_* bitmask; otherwise only the stmt slot changes (type and flag +are preserved). """ function replace_stmt!(r::Rewriter, block::Block, val::SSAValue, @nospecialize(new_stmt); flag::Union{UInt32, Nothing}=nothing) @@ -122,17 +113,15 @@ end """ erase!(r::Rewriter, block::Block, val::SSAValue) -Erase the stmt at `val` from `block`. Deregisters this op's operand uses and -fires `notify_erased!`. The listener may inspect `old_stmt` to enqueue -operand-defs for dead-op elim cascading. +Erase the stmt at `val` from `block`. Drops the op's operand uses from the +use-index, then fires `notify_erased!`. Listeners can inspect `old_stmt` to +cascade operand-defs through their own worklist for dead-op elim. """ function erase!(r::Rewriter, block::Block, val::SSAValue) haskey(block, val.id) || return old_stmt = block[val.id][:stmt] unregister_stmt_uses!(r, val, old_stmt) delete!(block, val.id) - # Any leftover index entry for `val` would now be stale (nothing produces - # val anymore); drop it so subsequent queries report empty. delete!(r.users, val) delete!(r.extra_uses, val) notify_erased!(r.listener, block, val, old_stmt) @@ -141,10 +130,10 @@ end """ replace_uses!(r::Rewriter, val::SSAValue, new_val) -Replace every use of `val` in the SCI with `new_val`. The IR walk handles -every operand kind (Expr args, CF op fields, terminators) and the use-index -entries are transferred to `new_val`. No listener event is fired — only the -users' operand slots changed; their stmts as a whole did not. +RAUW: replace every use of `val` in the SCI with `new_val`, and move the +use-index entries to `new_val`. The walk covers every operand kind (Expr +args, CF op fields, terminators). No listener event fires; only operand +slots change, not the user stmts themselves. """ function replace_uses!(r::Rewriter, val::SSAValue, @nospecialize(new_val)) IRStructurizer.replace_uses!(r.sci.entry, val, new_val) @@ -155,7 +144,6 @@ end Use-index maintenance (internal) =============================================================================# -# Iterate the SSA-positional operands of an Expr (skipping head/callee slot). function for_expr_operands(f, expr::Expr) start = expr.head === :invoke ? 3 : 2 for i in start:length(expr.args) @@ -178,10 +166,9 @@ function remove_use!(r::Rewriter, @nospecialize(operand), user::SSAValue) isempty(list) && delete!(r.users, operand) end -# Register/unregister operand uses for a stmt owned by `val`. The mutation -# API only ever creates or replaces `Expr` stmts (CF ops and terminators are -# not produced by drivers), so we only need the Expr path here. Initial-build -# populates the other stmt kinds via `register_owned_stmt_uses!`. +# Incremental maintenance only needs the Expr path: drivers never produce +# CF ops or terminators. `register_owned_stmt_uses!` handles those once at +# initial build time. function register_stmt_uses!(r::Rewriter, val::SSAValue, @nospecialize(stmt)) stmt isa Expr || return for_expr_operands(stmt) do op @@ -196,9 +183,9 @@ function unregister_stmt_uses!(r::Rewriter, val::SSAValue, @nospecialize(stmt)) end end -# Move all index entries from `old` to `new_val` after a RAUW. The IR-level -# `replace_uses!` has already rewritten the user stmts; this just shifts the -# bookkeeping so subsequent queries find users under `new_val`. +# Reassign every index entry from `old` to `new_val` so subsequent queries +# find the users under `new_val`. The IR-level mutation lives in +# `replace_uses!`; this only touches the bookkeeping. function transfer_uses!(r::Rewriter, old::SSAValue, @nospecialize(new_val)) old_users = get(r.users, old, nothing) if old_users !== nothing @@ -217,10 +204,10 @@ function transfer_uses!(r::Rewriter, old::SSAValue, @nospecialize(new_val)) end end -# Build the initial use-index by walking the entire SCI once. Mirrors the -# dispatch in `IRStructurizer.walk_uses!(::Block)` so every operand kind that -# the IR could reference is captured: Expr args, CF op fields, ReturnNode / -# PiNode / alias stmt operands, and block terminators. +# One-shot SCI walk that seeds the use-index. Dispatch mirrors +# `IRStructurizer.walk_uses!(::Block)` so every operand kind is covered: +# Expr args, CF op fields, ReturnNode/PiNode/alias stmt operands, and +# block terminators. function build_use_index!(r::Rewriter) for block in eachblock(r.sci) for i in 1:length(block.body.ssa_idxes) @@ -268,8 +255,8 @@ function register_cf_op_uses!(r::Rewriter, owner::SSAValue, end end -# Terminator operands have no SSA owner; record them as anonymous counts so -# `use_count` reflects them (needed for the dead-op elim shortcut). +# Terminator operands have no SSA owner, so they're recorded as anonymous +# counts. The dead-op-elim shortcut in `use_count` needs them. function register_terminator_uses!(r::Rewriter, term) if term isa Core.ReturnNode isdefined(term, :val) && bump_extra_use!(r, term.val)