From c9345f7530c148e3576c7c03679881a1c16a10a3 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 30 Mar 2026 15:22:03 +0200 Subject: [PATCH 1/4] Add algebraic simplification pass to cancel addi/subi pairs. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an algebra_pass! with two rewrite rules that cancel inverse addi/subi pairs (x+c-c → x, x-c+c → x). This eliminates the redundant subi instructions generated by the 1-based to 0-based index conversion in load/store operations (e.g. bid() + One() - One()). Two supporting changes to the rewrite framework: - DefEntry no longer caches a stale operands copy; pattern matching now reads live operands from the instruction via resolve_call, so bindings reflect updates from prior rewrites within the same pass. - RBind consumed tracking only marks the root instruction, leaving shared intermediates matchable by subsequent rules (e.g. a single addi used by multiple subi sites). Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/passes/pipeline.jl | 15 +++++++ src/compiler/codegen/passes/rewrite.jl | 52 +++++++++++++++++-------- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/src/compiler/codegen/passes/pipeline.jl b/src/compiler/codegen/passes/pipeline.jl index 450706f0..fc7631d6 100644 --- a/src/compiler/codegen/passes/pipeline.jl +++ b/src/compiler/codegen/passes/pipeline.jl @@ -111,6 +111,20 @@ const FMA_RULES = RewriteRule[ fma_fusion_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, FMA_RULES) +#============================================================================= + Algebraic Simplification (rewrite) +=============================================================================# + +# Cancel inverse addi/subi pairs: x+c-c → x, x-c+c → x. +# Repeated ~c binds enforce that both operands are the same value. + +const ALGEBRA_RULES = RewriteRule[ + @rewrite Intrinsics.subi(Intrinsics.addi(~x, ~c), ~c) => ~x + @rewrite Intrinsics.addi(Intrinsics.subi(~x, ~c), ~c) => ~x +] + +algebra_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, ALGEBRA_RULES) + #============================================================================= Pass Pipeline =============================================================================# @@ -124,6 +138,7 @@ and subprogram compilation. function run_passes!(sci::StructuredIRCode) # Rewrite passes (order matters: normalize before optimize, SVE before FMA) normalize_pass!(sci) + algebra_pass!(sci) scalar_view_elim_pass!(sci) fma_fusion_pass!(sci) diff --git a/src/compiler/codegen/passes/rewrite.jl b/src/compiler/codegen/passes/rewrite.jl index bfa15323..26fc580a 100644 --- a/src/compiler/codegen/passes/rewrite.jl +++ b/src/compiler/codegen/passes/rewrite.jl @@ -122,7 +122,14 @@ struct DefEntry block::Block inst::Instruction func::Any - operands::Vector{Any} +end + +"""Operands of a DefEntry, read from the live instruction.""" +function _def_operands(entry::DefEntry) + call = resolve_call(entry.inst) + call === nothing && return Any[] + _, ops = call + return ops end struct MatchContext @@ -137,8 +144,8 @@ function MatchContext(sci::StructuredIRCode) for inst in instructions(block) call = resolve_call(inst) call === nothing && continue - func, operands = call - defs[inst.ssa_idx] = DefEntry(block, inst, func, collect(Any, operands)) + func, _ = call + defs[inst.ssa_idx] = DefEntry(block, inst, func) end end MatchContext(sci.entry, defs, uses(sci.entry)) @@ -170,22 +177,27 @@ function pattern_match(ctx::MatchContext, @nospecialize(val), pat::PCall, entry = get(ctx.defs, val.id, nothing) entry === nothing && return nothing - if entry.func === pat.func && length(entry.operands) == length(pat.operands) - result = MatchResult(Dict{Symbol,Any}(), Int[val.id]) - for (op, sub) in zip(entry.operands, pat.operands) - m = pattern_match(ctx, op, sub, entry.block) - m === nothing && return nothing - _merge_bindings!(result.bindings, m.bindings) || return nothing - append!(result.matched_ssas, m.matched_ssas) + if entry.func === pat.func + ops = _def_operands(entry) + if length(ops) == length(pat.operands) + result = MatchResult(Dict{Symbol,Any}(), Int[val.id]) + for (op, sub) in zip(ops, pat.operands) + m = pattern_match(ctx, op, sub, entry.block) + m === nothing && return nothing + _merge_bindings!(result.bindings, m.bindings) || return nothing + append!(result.matched_ssas, m.matched_ssas) + end + return result end - return result end # Trace through single-use transparent ops to find the underlying operation - if _is_transparent(entry.func) && !isempty(entry.operands) + if _is_transparent(entry.func) _use_count(ctx, val) == 1 || return nothing + ops = _def_operands(entry) + isempty(ops) && return nothing if entry.func === Intrinsics.broadcast - inner = entry.operands[1] + inner = ops[1] if inner isa SSAValue inner_entry = get(ctx.defs, inner.id, nothing) if inner_entry !== nothing @@ -195,7 +207,7 @@ function pattern_match(ctx::MatchContext, @nospecialize(val), pat::PCall, end end end - result = pattern_match(ctx, entry.operands[1], pat, entry.block) + result = pattern_match(ctx, ops[1], pat, entry.block) result === nothing && return nothing push!(result.matched_ssas, val.id) return result @@ -243,7 +255,6 @@ function _apply_rewrite!(sci, block, inst, rule, match, ctx, consumed) _cleanup_dead_operands!(sci, inst.ssa_idx, ctx, consumed) else # Substitution: delete matched intermediates, replace root statement in-place. - # No operand cleanup here — the new replacement may reference these operands. for dead_ssa in match.matched_ssas dead_ssa == inst.ssa_idx && continue entry = ctx.defs[dead_ssa] @@ -261,7 +272,7 @@ end function _cleanup_dead_operands!(sci, ssa_idx, ctx, consumed) entry = get(ctx.defs, ssa_idx, nothing) entry === nothing && return - for op in entry.operands + for op in _def_operands(entry) op isa SSAValue || continue op_entry = get(ctx.defs, op.id, nothing) op_entry === nothing && continue @@ -306,7 +317,14 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}) rule.guard !== nothing && !rule.guard(m, ctx) && continue result = _apply_rewrite!(sci, block, inst, rule, m, ctx, consumed) result === false && continue # RFunc declined, try next rule - union!(consumed, m.matched_ssas) + # RBind only deletes the root; intermediates stay live and + # must remain matchable. Substitution deletes intermediates + # so they must be consumed to prevent the driver revisiting. + if rule.rhs isa RBind + push!(consumed, inst.ssa_idx) + else + union!(consumed, m.matched_ssas) + end break end end From 01146a5d6b3eda8ba234c3812dddae940bca35b3 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 30 Mar 2026 16:29:55 +0200 Subject: [PATCH 2/4] Worklist-based fixpoint rewrite driver. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the single-pass linear scan driver with a LIFO worklist that processes until fixpoint, inspired by MLIR's GreedyPatternRewriteDriver. Key changes: - No stale MatchContext: DefEntry reads live operands from the IR via resolve_call. Use counts computed on-demand via uses() — always fresh. - Worklist with notifications: when a rewrite fires, affected instructions (users, operand-producers) are re-added to the worklist, enabling cascading rewrites across rule sets. - Unified rule set: all rewrite rules (normalize, algebra, SVE, FMA) run in a single fixpoint invocation instead of separate passes. - Trivial dead-op elimination on worklist pop: keeps use counts accurate for one_use patterns (e.g. FMA fusion after SVE removes transparent op chains). Full DCE still runs after for complex dead code. - Safe intermediate deletion: substitution rewrites only delete matched intermediates that have no remaining uses, fixing a bug where transparent-op tracing could add multi-use intermediates to matched_ssas. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/passes/pipeline.jl | 18 +- src/compiler/codegen/passes/rewrite.jl | 343 ++++++++++++++++-------- 2 files changed, 241 insertions(+), 120 deletions(-) diff --git a/src/compiler/codegen/passes/pipeline.jl b/src/compiler/codegen/passes/pipeline.jl index fc7631d6..3a7abb7b 100644 --- a/src/compiler/codegen/passes/pipeline.jl +++ b/src/compiler/codegen/passes/pipeline.jl @@ -125,6 +125,17 @@ const ALGEBRA_RULES = RewriteRule[ algebra_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, ALGEBRA_RULES) +#============================================================================= + Combined Rule Set +=============================================================================# + +const ALL_REWRITE_RULES = RewriteRule[ + NORMALIZE_RULES..., + ALGEBRA_RULES..., + SVE_RULES..., + FMA_RULES..., +] + #============================================================================= Pass Pipeline =============================================================================# @@ -136,11 +147,8 @@ Run the full pass pipeline on a StructuredIRCode. Called for both kernel and subprogram compilation. """ function run_passes!(sci::StructuredIRCode) - # Rewrite passes (order matters: normalize before optimize, SVE before FMA) - normalize_pass!(sci) - algebra_pass!(sci) - scalar_view_elim_pass!(sci) - fma_fusion_pass!(sci) + # All rewrite rules in one fixpoint pass + rewrite_patterns!(sci, ALL_REWRITE_RULES) # Memory ordering alias_result = alias_analysis_pass!(sci) diff --git a/src/compiler/codegen/passes/rewrite.jl b/src/compiler/codegen/passes/rewrite.jl index 26fc580a..1d56a734 100644 --- a/src/compiler/codegen/passes/rewrite.jl +++ b/src/compiler/codegen/passes/rewrite.jl @@ -1,8 +1,10 @@ # Declarative IR Rewrite Pattern Framework # -# Inspired by MLIR's PDLL. Patterns compile into pattern/rewrite node trees. -# The framework handles matching (recursive SSA def-chain walking) and rewrite -# application. Cleanup of dead code is delegated to the pipeline's dce_pass!. +# Worklist-based fixpoint driver inspired by MLIR's GreedyPatternRewriteDriver. +# Patterns compile into pattern/rewrite node trees. The driver processes a LIFO +# worklist until fixpoint: when a rewrite fires, affected instructions are +# re-added to the worklist for further matching. Dead code cleanup is delegated +# to the pipeline's dce_pass!. # # Usage: # rules = RewriteRule[ @@ -33,17 +35,16 @@ struct RConst <: RewriteNode; val::Any; end """ RFunc(func) -Imperative rewrite node (MLIR-inspired). The function is called with -`(sci, block, inst, match, ctx)` and returns `true` if the rewrite was applied, -`false` to skip this rule and try the next one. `ctx` is the `MatchContext` -providing def-chain lookups via `ctx.defs`. +Imperative rewrite node. The function is called with +`(sci, block, inst, match, driver)` and returns `true` if the rewrite was +applied, `false` to skip this rule and try the next one. """ struct RFunc <: RewriteNode; func::Function; end struct RewriteRule lhs::PCall rhs::RewriteNode - guard::Union{Function, Nothing} # (match, ctx) -> Bool, or nothing + guard::Union{Function, Nothing} # (match, driver) -> Bool, or nothing end RewriteRule(lhs::PCall, rhs::RewriteNode) = RewriteRule(lhs, rhs, nothing) @@ -61,7 +62,7 @@ Compile a declarative rewrite rule. LHS: `func(args...)` matches calls, `~x` binds (repeated names require equality), `~x::T` binds with type constraint, `one_use(pat)` requires single use. RHS: `func(args...)` emits calls, `~x` references bindings, `\$(expr)` injects a literal constant. -Optional `guard` is a function `(match, ctx) -> Bool` checked after pattern match. +Optional `guard` is a function `(match, driver) -> Bool` checked after pattern match. """ macro rewrite(ex, guard=nothing) ex isa Expr && ex.head === :call && ex.args[1] === :(=>) || @@ -74,10 +75,9 @@ end @rewriter lhs => func Declarative pattern with imperative rewrite. LHS uses the same pattern syntax as -`@rewrite`. RHS is a function `(sci, block, inst, match, ctx) -> Bool` that +`@rewrite`. RHS is a function `(sci, block, inst, match, driver) -> Bool` that performs the rewrite and returns `true`, or returns `false` to skip and try the -next rule. Matched bindings are in `match.bindings`; def-chain lookups via -`ctx.defs`. +next rule. """ macro rewriter(ex) ex isa Expr && ex.head === :call && ex.args[1] === :(=>) || @@ -110,55 +110,132 @@ function _compile_rhs(ex) end #============================================================================= - Matching + Worklist =============================================================================# -struct MatchResult - bindings::Dict{Symbol, Any} - matched_ssas::Vector{Int} +mutable struct Worklist + list::Vector{Int} # SSA indices (-1 = removed sentinel) + member::Dict{Int, Int} # ssa_idx -> position in list +end + +Worklist() = Worklist(Int[], Dict{Int, Int}()) + +function Base.push!(wl::Worklist, ssa_idx::Int) + haskey(wl.member, ssa_idx) && return + push!(wl.list, ssa_idx) + wl.member[ssa_idx] = length(wl.list) +end + +function Base.pop!(wl::Worklist) + while !isempty(wl.list) + idx = pop!(wl.list) + idx == -1 && continue + delete!(wl.member, idx) + return idx + end + return nothing +end + +function remove!(wl::Worklist, ssa_idx::Int) + pos = get(wl.member, ssa_idx, 0) + pos == 0 && return + wl.list[pos] = -1 + delete!(wl.member, ssa_idx) end +Base.isempty(wl::Worklist) = isempty(wl.member) + +#============================================================================= + Driver State +=============================================================================# + struct DefEntry block::Block - inst::Instruction + ssa_idx::Int func::Any end -"""Operands of a DefEntry, read from the live instruction.""" +"""Operands of a DefEntry, read from the live IR.""" function _def_operands(entry::DefEntry) - call = resolve_call(entry.inst) + pos = findfirst(==(entry.ssa_idx), entry.block.body.ssa_idxes) + pos === nothing && return Any[] + call = resolve_call(entry.block.body.stmts[pos]) call === nothing && return Any[] _, ops = call return ops end -struct MatchContext - entry::Block # root block of the StructuredIRCode (for value_type lookups) +mutable struct RewriteDriver + sci::StructuredIRCode defs::Dict{Int, DefEntry} - use_index # UseIndex from IRStructurizer -end - -function MatchContext(sci::StructuredIRCode) - defs = Dict{Int, DefEntry}() - for block in eachblock(sci) - for inst in instructions(block) - call = resolve_call(inst) - call === nothing && continue - func, _ = call - defs[inst.ssa_idx] = DefEntry(block, inst, func) - end - end - MatchContext(sci.entry, defs, uses(sci.entry)) + dispatch::Dict{Any, Vector{RewriteRule}} + worklist::Worklist + max_rewrites::Int end -_use_count(ctx::MatchContext, val::SSAValue) = - haskey(ctx.use_index, val) ? length(ctx.use_index[val]) : 0 +"""Compute fresh use count for an SSA value.""" +_use_count(driver::RewriteDriver, val::SSAValue) = + length(uses(driver.sci.entry, val)) # Codegen no-ops that pattern matching traces through transparently. _is_transparent(func) = func === Intrinsics.to_scalar || func === Intrinsics.from_scalar || func === Intrinsics.broadcast +#============================================================================= + Notifications +=============================================================================# + +"""Add operand-producing instructions to the worklist (enables cascading).""" +function _add_operands_to_worklist!(driver::RewriteDriver, entry::DefEntry) + for op in _def_operands(entry) + op isa SSAValue || continue + haskey(driver.defs, op.id) && push!(driver.worklist, op.id) + end +end + +"""Add instructions that use `val` to the worklist (their operand changed).""" +function _add_users_to_worklist!(driver::RewriteDriver, ssa_idx::Int) + val = SSAValue(ssa_idx) + for (id, entry) in driver.defs + id == ssa_idx && continue + for op in _def_operands(entry) + if op isa SSAValue && op.id == ssa_idx + push!(driver.worklist, id) + break + end + end + end +end + +"""Erase an instruction and notify the worklist.""" +function _erase_op!(driver::RewriteDriver, entry::DefEntry) + _add_operands_to_worklist!(driver, entry) + pos = findfirst(==(entry.ssa_idx), entry.block.body.ssa_idxes) + if pos !== nothing + deleteat!(entry.block.body.ssa_idxes, pos) + deleteat!(entry.block.body.stmts, pos) + deleteat!(entry.block.body.types, pos) + end + delete!(driver.defs, entry.ssa_idx) + remove!(driver.worklist, entry.ssa_idx) +end + +"""Register a newly inserted instruction.""" +function _notify_insert!(driver::RewriteDriver, block::Block, ssa_idx::Int, func) + driver.defs[ssa_idx] = DefEntry(block, ssa_idx, func) + push!(driver.worklist, ssa_idx) +end + +#============================================================================= + Matching +=============================================================================# + +struct MatchResult + bindings::Dict{Symbol, Any} + matched_ssas::Vector{Int} +end + """Merge bindings, requiring repeated names to bind the same value (=== equality).""" function _merge_bindings!(dest::Dict{Symbol,Any}, src::Dict{Symbol,Any}) for (k, v) in src @@ -171,10 +248,10 @@ function _merge_bindings!(dest::Dict{Symbol,Any}, src::Dict{Symbol,Any}) return true end -function pattern_match(ctx::MatchContext, @nospecialize(val), pat::PCall, - block::Block=ctx.entry) +function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall, + block::Block=driver.sci.entry) val isa SSAValue || return nothing - entry = get(ctx.defs, val.id, nothing) + entry = get(driver.defs, val.id, nothing) entry === nothing && return nothing if entry.func === pat.func @@ -182,7 +259,7 @@ function pattern_match(ctx::MatchContext, @nospecialize(val), pat::PCall, if length(ops) == length(pat.operands) result = MatchResult(Dict{Symbol,Any}(), Int[val.id]) for (op, sub) in zip(ops, pat.operands) - m = pattern_match(ctx, op, sub, entry.block) + m = pattern_match(driver, op, sub, entry.block) m === nothing && return nothing _merge_bindings!(result.bindings, m.bindings) || return nothing append!(result.matched_ssas, m.matched_ssas) @@ -193,21 +270,23 @@ function pattern_match(ctx::MatchContext, @nospecialize(val), pat::PCall, # Trace through single-use transparent ops to find the underlying operation if _is_transparent(entry.func) - _use_count(ctx, val) == 1 || return nothing + _use_count(driver, val) == 1 || return nothing ops = _def_operands(entry) isempty(ops) && return nothing if entry.func === Intrinsics.broadcast inner = ops[1] if inner isa SSAValue - inner_entry = get(ctx.defs, inner.id, nothing) + inner_entry = get(driver.defs, inner.id, nothing) if inner_entry !== nothing - it, ot = value_type(inner_entry.inst), value_type(entry.inst) - it <: Tile && ot <: Tile || return nothing - size(it) == size(ot) || return nothing + it = value_type(entry.block, inner) + ot = value_type(entry.block, val) + it !== nothing && ot !== nothing || return nothing + CC.widenconst(it) <: Tile && CC.widenconst(ot) <: Tile || return nothing + size(CC.widenconst(it)) == size(CC.widenconst(ot)) || return nothing end end end - result = pattern_match(ctx, ops[1], pat, entry.block) + result = pattern_match(driver, ops[1], pat, entry.block) result === nothing && return nothing push!(result.matched_ssas, val.id) return result @@ -215,21 +294,21 @@ function pattern_match(ctx::MatchContext, @nospecialize(val), pat::PCall, return nothing end -pattern_match(ctx::MatchContext, @nospecialize(val), pat::PBind, block::Block=ctx.entry) = +pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PBind, block::Block=driver.sci.entry) = MatchResult(Dict{Symbol,Any}(pat.name => val), Int[]) -function pattern_match(ctx::MatchContext, @nospecialize(val), pat::PTypedBind, - block::Block=ctx.entry) +function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PTypedBind, + block::Block=driver.sci.entry) T = value_type(block, val) T === nothing && return nothing CC.widenconst(T) <: pat.type || return nothing MatchResult(Dict{Symbol,Any}(pat.name => val), Int[]) end -function pattern_match(ctx::MatchContext, @nospecialize(val), pat::POneUse, - block::Block=ctx.entry) - val isa SSAValue && _use_count(ctx, val) == 1 || return nothing - pattern_match(ctx, val, pat.inner, block) +function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::POneUse, + block::Block=driver.sci.entry) + val isa SSAValue && _use_count(driver, val) == 1 || return nothing + pattern_match(driver, val, pat.inner, block) end #============================================================================= @@ -237,51 +316,51 @@ end =============================================================================# """Resolve an RHS operand, inserting sub-calls before `ref` as needed.""" -_resolve_rhs(block, ref, op::RBind, bindings, typ) = bindings[op.name] -_resolve_rhs(block, ref, op::RConst, bindings, typ) = op.val -function _resolve_rhs(block, ref, op::RCall, bindings, typ) - operands = Any[_resolve_rhs(block, ref, sub, bindings, typ) for sub in op.operands] - SSAValue(insert_before!(block, ref, Expr(:call, op.func, operands...), typ)) +_resolve_rhs(driver, block, ref, op::RBind, bindings, typ) = bindings[op.name] +_resolve_rhs(driver, block, ref, op::RConst, bindings, typ) = op.val +function _resolve_rhs(driver::RewriteDriver, block, ref, op::RCall, bindings, typ) + operands = Any[_resolve_rhs(driver, block, ref, sub, bindings, typ) for sub in op.operands] + inst = insert_before!(block, ref, Expr(:call, op.func, operands...), typ) + _notify_insert!(driver, block, inst.ssa_idx, op.func) + SSAValue(inst.ssa_idx) end -function _apply_rewrite!(sci, block, inst, rule, match, ctx, consumed) +function _apply_rewrite!(driver::RewriteDriver, block, inst_ssa, rule, match) + entry = driver.defs[inst_ssa] if rule.rhs isa RFunc - rule.rhs.func(sci, block, inst, match, ctx) || return false + # Look up live instruction for RFunc interface + pos = findfirst(==(inst_ssa), block.body.ssa_idxes) + pos === nothing && return false + inst = Instruction(inst_ssa, block.body.stmts[pos], block.body.types[pos]) + rule.rhs.func(driver.sci, block, inst, match, driver) || return false return true elseif rule.rhs isa RBind - # Forwarding: replace all uses of root with the bound value, delete root - replace_uses!(sci.entry, SSAValue(inst), match.bindings[rule.rhs.name]) - delete!(block, inst) - _cleanup_dead_operands!(sci, inst.ssa_idx, ctx, consumed) + # Forwarding: replace all uses of root with the bound value, delete root. + # Collect users BEFORE replace_uses! updates their operands. + _add_users_to_worklist!(driver, inst_ssa) + replace_uses!(driver.sci.entry, SSAValue(inst_ssa), match.bindings[rule.rhs.name]) + _erase_op!(driver, entry) else - # Substitution: delete matched intermediates, replace root statement in-place. + # Substitution: delete matched intermediates, replace root in-place. + # Only delete intermediates that have no remaining uses. + # Transparent-op tracing may have added intermediates to matched_ssas + # that have uses outside the matched chain. for dead_ssa in match.matched_ssas - dead_ssa == inst.ssa_idx && continue - entry = ctx.defs[dead_ssa] - delete!(entry.block, entry.inst) + dead_ssa == inst_ssa && continue + dead_entry = get(driver.defs, dead_ssa, nothing) + dead_entry === nothing && continue + _use_count(driver, SSAValue(dead_ssa)) == 0 || continue + _erase_op!(driver, dead_entry) end - typ = value_type(inst) - operands = Any[_resolve_rhs(block, SSAValue(inst), op, match.bindings, typ) + pos = findfirst(==(inst_ssa), block.body.ssa_idxes) + typ = block.body.types[pos] + operands = Any[_resolve_rhs(driver, block, SSAValue(inst_ssa), op, match.bindings, typ) for op in rule.rhs.operands] - pos = findfirst(==(inst.ssa_idx), block.body.ssa_idxes) block.body.stmts[pos] = Expr(:call, rule.rhs.func, operands...) - end -end - -"""Recursively erase dead pure operands after an instruction is deleted (MLIR-style).""" -function _cleanup_dead_operands!(sci, ssa_idx, ctx, consumed) - entry = get(ctx.defs, ssa_idx, nothing) - entry === nothing && return - for op in _def_operands(entry) - op isa SSAValue || continue - op_entry = get(ctx.defs, op.id, nothing) - op_entry === nothing && continue - _is_transparent(op_entry.func) || continue - op.id in consumed && continue - isempty(uses(sci.entry, op)) || continue - push!(consumed, op.id) - delete!(op_entry.block, op_entry.inst) - _cleanup_dead_operands!(sci, op.id, ctx, consumed) + # Update defs, re-add self and users to worklist (statement changed) + driver.defs[inst_ssa] = DefEntry(block, inst_ssa, rule.rhs.func) + push!(driver.worklist, inst_ssa) + _add_users_to_worklist!(driver, inst_ssa) end end @@ -290,43 +369,77 @@ end =============================================================================# """ - rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}) + rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; max_rewrites=10_000) -Apply declarative rewrite rules to the structured IR. Dead code left behind -is cleaned up by the pipeline's `dce_pass!`. +Apply rewrite rules to the structured IR using a worklist-based fixpoint driver. +Rules are tried until no more matches fire or `max_rewrites` is reached. +Dead code left behind is cleaned up by the pipeline's `dce_pass!`. """ -function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}) - ctx = MatchContext(sci) +function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; + max_rewrites::Int=10_000) + # Build dispatch table dispatch = Dict{Any, Vector{RewriteRule}}() for rule in rules push!(get!(dispatch, root_func(rule), RewriteRule[]), rule) end - consumed = Set{Int}() + # Build defs index + defs = Dict{Int, DefEntry}() for block in eachblock(sci) - for inst in collect(instructions(block)) - inst.ssa_idx in consumed && continue + for inst in instructions(block) call = resolve_call(inst) call === nothing && continue - applicable = get(dispatch, call[1], nothing) - applicable === nothing && continue - for rule in applicable - m = pattern_match(ctx, SSAValue(inst), rule.lhs) - m === nothing && continue - any(s in consumed for s in m.matched_ssas) && continue - rule.guard !== nothing && !rule.guard(m, ctx) && continue - result = _apply_rewrite!(sci, block, inst, rule, m, ctx, consumed) - result === false && continue # RFunc declined, try next rule - # RBind only deletes the root; intermediates stay live and - # must remain matchable. Substitution deletes intermediates - # so they must be consumed to prevent the driver revisiting. - if rule.rhs isa RBind - push!(consumed, inst.ssa_idx) - else - union!(consumed, m.matched_ssas) - end - break + func, _ = call + defs[inst.ssa_idx] = DefEntry(block, inst.ssa_idx, func) + end + end + + # Seed worklist (forward order → reversed by LIFO → processes top-down) + wl = Worklist() + for block in eachblock(sci) + for inst in instructions(block) + haskey(defs, inst.ssa_idx) && push!(wl, inst.ssa_idx) + end + end + + driver = RewriteDriver(sci, defs, dispatch, wl, max_rewrites) + + num_rewrites = 0 + while !isempty(driver.worklist) && num_rewrites < driver.max_rewrites + ssa_idx = pop!(driver.worklist)::Int + entry = get(driver.defs, ssa_idx, nothing) + entry === nothing && continue + + # Verify instruction is still live in its block + pos = findfirst(==(ssa_idx), entry.block.body.ssa_idxes) + pos === nothing && begin + delete!(driver.defs, ssa_idx) + continue + end + + # Trivial dead-op elimination: if this op has no uses and is pure, + # erase it. This keeps use counts accurate for `one_use` patterns + # (e.g., FMA fusion needs mulf's dead transparent-op users removed + # so the mulf reads as single-use). Full DCE handles the rest. + if _use_count(driver, SSAValue(ssa_idx)) == 0 + stmt = entry.block.body.stmts[pos] + if !must_keep(stmt) + _erase_op!(driver, entry) + continue end end + + # Look up applicable rules by function + applicable = get(driver.dispatch, entry.func, nothing) + applicable === nothing && continue + + for rule in applicable + m = pattern_match(driver, SSAValue(ssa_idx), rule.lhs) + m === nothing && continue + rule.guard !== nothing && !rule.guard(m, driver) && continue + _apply_rewrite!(driver, entry.block, ssa_idx, rule, m) + num_rewrites += 1 + break + end end end From c4e0b0b0a251131907d946f87eee37e8dbc9c323 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 30 Mar 2026 16:55:06 +0200 Subject: [PATCH 3/4] Use IRStructurizer's users() API in rewrite driver. Replace the manual defs-scanning workaround in _add_users_to_worklist! with the new users(block, val) API from IRStructurizer. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/passes/rewrite.jl | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/compiler/codegen/passes/rewrite.jl b/src/compiler/codegen/passes/rewrite.jl index 1d56a734..7005c1d9 100644 --- a/src/compiler/codegen/passes/rewrite.jl +++ b/src/compiler/codegen/passes/rewrite.jl @@ -196,15 +196,8 @@ end """Add instructions that use `val` to the worklist (their operand changed).""" function _add_users_to_worklist!(driver::RewriteDriver, ssa_idx::Int) - val = SSAValue(ssa_idx) - for (id, entry) in driver.defs - id == ssa_idx && continue - for op in _def_operands(entry) - if op isa SSAValue && op.id == ssa_idx - push!(driver.worklist, id) - break - end - end + for inst in users(driver.sci.entry, SSAValue(ssa_idx)) + push!(driver.worklist, inst.ssa_idx) end end From c11c65d14ae2b0984d9efbe61c45778c91eb342f Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 30 Mar 2026 18:07:33 +0200 Subject: [PATCH 4/4] Use SSAValue consistently instead of bare Int for SSA indices. The Worklist, DefEntry, defs dict, matched_ssas, and all notification/apply functions now use SSAValue instead of Int, eliminating the same kind of type confusion between IR references and literal integers that was fixed in IRStructurizer's normalize_key. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/passes/rewrite.jl | 139 +++++++++++++------------ 1 file changed, 73 insertions(+), 66 deletions(-) diff --git a/src/compiler/codegen/passes/rewrite.jl b/src/compiler/codegen/passes/rewrite.jl index 7005c1d9..c09a4848 100644 --- a/src/compiler/codegen/passes/rewrite.jl +++ b/src/compiler/codegen/passes/rewrite.jl @@ -114,33 +114,35 @@ end =============================================================================# mutable struct Worklist - list::Vector{Int} # SSA indices (-1 = removed sentinel) - member::Dict{Int, Int} # ssa_idx -> position in list + list::Vector{SSAValue} # entries (SSAValue(-1) = removed sentinel) + member::Dict{SSAValue, Int} # val -> position in list end -Worklist() = Worklist(Int[], Dict{Int, Int}()) +const _SENTINEL = SSAValue(-1) -function Base.push!(wl::Worklist, ssa_idx::Int) - haskey(wl.member, ssa_idx) && return - push!(wl.list, ssa_idx) - wl.member[ssa_idx] = length(wl.list) +Worklist() = Worklist(SSAValue[], Dict{SSAValue, Int}()) + +function Base.push!(wl::Worklist, val::SSAValue) + haskey(wl.member, val) && return + push!(wl.list, val) + wl.member[val] = length(wl.list) end function Base.pop!(wl::Worklist) while !isempty(wl.list) - idx = pop!(wl.list) - idx == -1 && continue - delete!(wl.member, idx) - return idx + val = pop!(wl.list) + val == _SENTINEL && continue + delete!(wl.member, val) + return val end return nothing end -function remove!(wl::Worklist, ssa_idx::Int) - pos = get(wl.member, ssa_idx, 0) +function remove!(wl::Worklist, val::SSAValue) + pos = get(wl.member, val, 0) pos == 0 && return - wl.list[pos] = -1 - delete!(wl.member, ssa_idx) + wl.list[pos] = _SENTINEL + delete!(wl.member, val) end Base.isempty(wl::Worklist) = isempty(wl.member) @@ -151,13 +153,13 @@ Base.isempty(wl::Worklist) = isempty(wl.member) struct DefEntry block::Block - ssa_idx::Int + val::SSAValue func::Any end """Operands of a DefEntry, read from the live IR.""" function _def_operands(entry::DefEntry) - pos = findfirst(==(entry.ssa_idx), entry.block.body.ssa_idxes) + pos = findfirst(==(entry.val.id), entry.block.body.ssa_idxes) pos === nothing && return Any[] call = resolve_call(entry.block.body.stmts[pos]) call === nothing && return Any[] @@ -167,7 +169,7 @@ end mutable struct RewriteDriver sci::StructuredIRCode - defs::Dict{Int, DefEntry} + defs::Dict{SSAValue, DefEntry} dispatch::Dict{Any, Vector{RewriteRule}} worklist::Worklist max_rewrites::Int @@ -190,34 +192,38 @@ _is_transparent(func) = func === Intrinsics.to_scalar || function _add_operands_to_worklist!(driver::RewriteDriver, entry::DefEntry) for op in _def_operands(entry) op isa SSAValue || continue - haskey(driver.defs, op.id) && push!(driver.worklist, op.id) + haskey(driver.defs, op) && push!(driver.worklist, op) end end """Add instructions that use `val` to the worklist (their operand changed).""" -function _add_users_to_worklist!(driver::RewriteDriver, ssa_idx::Int) - for inst in users(driver.sci.entry, SSAValue(ssa_idx)) - push!(driver.worklist, inst.ssa_idx) +function _add_users_to_worklist!(driver::RewriteDriver, val::SSAValue) + for inst in users(driver.sci.entry, val) + push!(driver.worklist, SSAValue(inst)) end end """Erase an instruction and notify the worklist.""" function _erase_op!(driver::RewriteDriver, entry::DefEntry) _add_operands_to_worklist!(driver, entry) - pos = findfirst(==(entry.ssa_idx), entry.block.body.ssa_idxes) + pos = findfirst(==(entry.val.id), entry.block.body.ssa_idxes) if pos !== nothing deleteat!(entry.block.body.ssa_idxes, pos) deleteat!(entry.block.body.stmts, pos) deleteat!(entry.block.body.types, pos) end - delete!(driver.defs, entry.ssa_idx) - remove!(driver.worklist, entry.ssa_idx) + delete!(driver.defs, entry.val) + remove!(driver.worklist, entry.val) end """Register a newly inserted instruction.""" -function _notify_insert!(driver::RewriteDriver, block::Block, ssa_idx::Int, func) - driver.defs[ssa_idx] = DefEntry(block, ssa_idx, func) - push!(driver.worklist, ssa_idx) +function _notify_insert!(driver::RewriteDriver, block::Block, inst::Instruction) + val = SSAValue(inst) + call = resolve_call(inst) + call === nothing && return + func, _ = call + driver.defs[val] = DefEntry(block, val, func) + push!(driver.worklist, val) end #============================================================================= @@ -226,7 +232,7 @@ end struct MatchResult bindings::Dict{Symbol, Any} - matched_ssas::Vector{Int} + matched_ssas::Vector{SSAValue} end """Merge bindings, requiring repeated names to bind the same value (=== equality).""" @@ -244,13 +250,13 @@ end function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall, block::Block=driver.sci.entry) val isa SSAValue || return nothing - entry = get(driver.defs, val.id, nothing) + entry = get(driver.defs, val, nothing) entry === nothing && return nothing if entry.func === pat.func ops = _def_operands(entry) if length(ops) == length(pat.operands) - result = MatchResult(Dict{Symbol,Any}(), Int[val.id]) + result = MatchResult(Dict{Symbol,Any}(), SSAValue[val]) for (op, sub) in zip(ops, pat.operands) m = pattern_match(driver, op, sub, entry.block) m === nothing && return nothing @@ -269,7 +275,7 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall, if entry.func === Intrinsics.broadcast inner = ops[1] if inner isa SSAValue - inner_entry = get(driver.defs, inner.id, nothing) + inner_entry = get(driver.defs, inner, nothing) if inner_entry !== nothing it = value_type(entry.block, inner) ot = value_type(entry.block, val) @@ -281,21 +287,21 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall, end result = pattern_match(driver, ops[1], pat, entry.block) result === nothing && return nothing - push!(result.matched_ssas, val.id) + push!(result.matched_ssas, val) return result end return nothing end pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PBind, block::Block=driver.sci.entry) = - MatchResult(Dict{Symbol,Any}(pat.name => val), Int[]) + MatchResult(Dict{Symbol,Any}(pat.name => val), SSAValue[]) function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PTypedBind, block::Block=driver.sci.entry) T = value_type(block, val) T === nothing && return nothing CC.widenconst(T) <: pat.type || return nothing - MatchResult(Dict{Symbol,Any}(pat.name => val), Int[]) + MatchResult(Dict{Symbol,Any}(pat.name => val), SSAValue[]) end function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::POneUse, @@ -314,46 +320,45 @@ _resolve_rhs(driver, block, ref, op::RConst, bindings, typ) = op.val function _resolve_rhs(driver::RewriteDriver, block, ref, op::RCall, bindings, typ) operands = Any[_resolve_rhs(driver, block, ref, sub, bindings, typ) for sub in op.operands] inst = insert_before!(block, ref, Expr(:call, op.func, operands...), typ) - _notify_insert!(driver, block, inst.ssa_idx, op.func) - SSAValue(inst.ssa_idx) + _notify_insert!(driver, block, inst) + SSAValue(inst) end -function _apply_rewrite!(driver::RewriteDriver, block, inst_ssa, rule, match) - entry = driver.defs[inst_ssa] +function _apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match) + entry = driver.defs[val] if rule.rhs isa RFunc # Look up live instruction for RFunc interface - pos = findfirst(==(inst_ssa), block.body.ssa_idxes) + pos = findfirst(==(val.id), block.body.ssa_idxes) pos === nothing && return false - inst = Instruction(inst_ssa, block.body.stmts[pos], block.body.types[pos]) + inst = Instruction(val.id, block.body.stmts[pos], block.body.types[pos]) rule.rhs.func(driver.sci, block, inst, match, driver) || return false return true elseif rule.rhs isa RBind # Forwarding: replace all uses of root with the bound value, delete root. # Collect users BEFORE replace_uses! updates their operands. - _add_users_to_worklist!(driver, inst_ssa) - replace_uses!(driver.sci.entry, SSAValue(inst_ssa), match.bindings[rule.rhs.name]) + _add_users_to_worklist!(driver, val) + replace_uses!(driver.sci.entry, val, match.bindings[rule.rhs.name]) _erase_op!(driver, entry) else - # Substitution: delete matched intermediates, replace root in-place. - # Only delete intermediates that have no remaining uses. - # Transparent-op tracing may have added intermediates to matched_ssas - # that have uses outside the matched chain. - for dead_ssa in match.matched_ssas - dead_ssa == inst_ssa && continue - dead_entry = get(driver.defs, dead_ssa, nothing) + # Substitution: replace root in-place, clean up dead intermediates. + # Only delete intermediates with no remaining uses — transparent-op + # tracing may have added multi-use intermediates to matched_ssas. + for dead_val in match.matched_ssas + dead_val == val && continue + dead_entry = get(driver.defs, dead_val, nothing) dead_entry === nothing && continue - _use_count(driver, SSAValue(dead_ssa)) == 0 || continue + _use_count(driver, dead_val) == 0 || continue _erase_op!(driver, dead_entry) end - pos = findfirst(==(inst_ssa), block.body.ssa_idxes) + pos = findfirst(==(val.id), block.body.ssa_idxes) typ = block.body.types[pos] - operands = Any[_resolve_rhs(driver, block, SSAValue(inst_ssa), op, match.bindings, typ) + operands = Any[_resolve_rhs(driver, block, val, op, match.bindings, typ) for op in rule.rhs.operands] block.body.stmts[pos] = Expr(:call, rule.rhs.func, operands...) # Update defs, re-add self and users to worklist (statement changed) - driver.defs[inst_ssa] = DefEntry(block, inst_ssa, rule.rhs.func) - push!(driver.worklist, inst_ssa) - _add_users_to_worklist!(driver, inst_ssa) + driver.defs[val] = DefEntry(block, val, rule.rhs.func) + push!(driver.worklist, val) + _add_users_to_worklist!(driver, val) end end @@ -377,13 +382,14 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; end # Build defs index - defs = Dict{Int, DefEntry}() + defs = Dict{SSAValue, DefEntry}() for block in eachblock(sci) for inst in instructions(block) call = resolve_call(inst) call === nothing && continue func, _ = call - defs[inst.ssa_idx] = DefEntry(block, inst.ssa_idx, func) + val = SSAValue(inst) + defs[val] = DefEntry(block, val, func) end end @@ -391,7 +397,8 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; wl = Worklist() for block in eachblock(sci) for inst in instructions(block) - haskey(defs, inst.ssa_idx) && push!(wl, inst.ssa_idx) + val = SSAValue(inst) + haskey(defs, val) && push!(wl, val) end end @@ -399,14 +406,14 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; num_rewrites = 0 while !isempty(driver.worklist) && num_rewrites < driver.max_rewrites - ssa_idx = pop!(driver.worklist)::Int - entry = get(driver.defs, ssa_idx, nothing) + val = pop!(driver.worklist)::SSAValue + entry = get(driver.defs, val, nothing) entry === nothing && continue # Verify instruction is still live in its block - pos = findfirst(==(ssa_idx), entry.block.body.ssa_idxes) + pos = findfirst(==(val.id), entry.block.body.ssa_idxes) pos === nothing && begin - delete!(driver.defs, ssa_idx) + delete!(driver.defs, val) continue end @@ -414,7 +421,7 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; # erase it. This keeps use counts accurate for `one_use` patterns # (e.g., FMA fusion needs mulf's dead transparent-op users removed # so the mulf reads as single-use). Full DCE handles the rest. - if _use_count(driver, SSAValue(ssa_idx)) == 0 + if _use_count(driver, val) == 0 stmt = entry.block.body.stmts[pos] if !must_keep(stmt) _erase_op!(driver, entry) @@ -427,10 +434,10 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; applicable === nothing && continue for rule in applicable - m = pattern_match(driver, SSAValue(ssa_idx), rule.lhs) + m = pattern_match(driver, val, rule.lhs) m === nothing && continue rule.guard !== nothing && !rule.guard(m, driver) && continue - _apply_rewrite!(driver, entry.block, ssa_idx, rule, m) + _apply_rewrite!(driver, entry.block, val, rule, m) num_rewrites += 1 break end