diff --git a/src/compiler/codegen/passes/pipeline.jl b/src/compiler/codegen/passes/pipeline.jl index 450706f0..3a7abb7b 100644 --- a/src/compiler/codegen/passes/pipeline.jl +++ b/src/compiler/codegen/passes/pipeline.jl @@ -111,6 +111,31 @@ const FMA_RULES = RewriteRule[ fma_fusion_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, FMA_RULES) +#============================================================================= + Algebraic Simplification (rewrite) +=============================================================================# + +# Cancel inverse addi/subi pairs: x+c-c → x, x-c+c → x. +# Repeated ~c binds enforce that both operands are the same value. + +const ALGEBRA_RULES = RewriteRule[ + @rewrite Intrinsics.subi(Intrinsics.addi(~x, ~c), ~c) => ~x + @rewrite Intrinsics.addi(Intrinsics.subi(~x, ~c), ~c) => ~x +] + +algebra_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, ALGEBRA_RULES) + +#============================================================================= + Combined Rule Set +=============================================================================# + +const ALL_REWRITE_RULES = RewriteRule[ + NORMALIZE_RULES..., + ALGEBRA_RULES..., + SVE_RULES..., + FMA_RULES..., +] + #============================================================================= Pass Pipeline =============================================================================# @@ -122,10 +147,8 @@ Run the full pass pipeline on a StructuredIRCode. Called for both kernel and subprogram compilation. 
""" function run_passes!(sci::StructuredIRCode) - # Rewrite passes (order matters: normalize before optimize, SVE before FMA) - normalize_pass!(sci) - scalar_view_elim_pass!(sci) - fma_fusion_pass!(sci) + # All rewrite rules in one fixpoint pass + rewrite_patterns!(sci, ALL_REWRITE_RULES) # Memory ordering alias_result = alias_analysis_pass!(sci) diff --git a/src/compiler/codegen/passes/rewrite.jl b/src/compiler/codegen/passes/rewrite.jl index bfa15323..c09a4848 100644 --- a/src/compiler/codegen/passes/rewrite.jl +++ b/src/compiler/codegen/passes/rewrite.jl @@ -1,8 +1,10 @@ # Declarative IR Rewrite Pattern Framework # -# Inspired by MLIR's PDLL. Patterns compile into pattern/rewrite node trees. -# The framework handles matching (recursive SSA def-chain walking) and rewrite -# application. Cleanup of dead code is delegated to the pipeline's dce_pass!. +# Worklist-based fixpoint driver inspired by MLIR's GreedyPatternRewriteDriver. +# Patterns compile into pattern/rewrite node trees. The driver processes a LIFO +# worklist until fixpoint: when a rewrite fires, affected instructions are +# re-added to the worklist for further matching. Dead code cleanup is delegated +# to the pipeline's dce_pass!. # # Usage: # rules = RewriteRule[ @@ -33,17 +35,16 @@ struct RConst <: RewriteNode; val::Any; end """ RFunc(func) -Imperative rewrite node (MLIR-inspired). The function is called with -`(sci, block, inst, match, ctx)` and returns `true` if the rewrite was applied, -`false` to skip this rule and try the next one. `ctx` is the `MatchContext` -providing def-chain lookups via `ctx.defs`. +Imperative rewrite node. The function is called with +`(sci, block, inst, match, driver)` and returns `true` if the rewrite was +applied, `false` to skip this rule and try the next one. 
""" struct RFunc <: RewriteNode; func::Function; end struct RewriteRule lhs::PCall rhs::RewriteNode - guard::Union{Function, Nothing} # (match, ctx) -> Bool, or nothing + guard::Union{Function, Nothing} # (match, driver) -> Bool, or nothing end RewriteRule(lhs::PCall, rhs::RewriteNode) = RewriteRule(lhs, rhs, nothing) @@ -61,7 +62,7 @@ Compile a declarative rewrite rule. LHS: `func(args...)` matches calls, `~x` binds (repeated names require equality), `~x::T` binds with type constraint, `one_use(pat)` requires single use. RHS: `func(args...)` emits calls, `~x` references bindings, `\$(expr)` injects a literal constant. -Optional `guard` is a function `(match, ctx) -> Bool` checked after pattern match. +Optional `guard` is a function `(match, driver) -> Bool` checked after pattern match. """ macro rewrite(ex, guard=nothing) ex isa Expr && ex.head === :call && ex.args[1] === :(=>) || @@ -74,10 +75,9 @@ end @rewriter lhs => func Declarative pattern with imperative rewrite. LHS uses the same pattern syntax as -`@rewrite`. RHS is a function `(sci, block, inst, match, ctx) -> Bool` that +`@rewrite`. RHS is a function `(sci, block, inst, match, driver) -> Bool` that performs the rewrite and returns `true`, or returns `false` to skip and try the -next rule. Matched bindings are in `match.bindings`; def-chain lookups via -`ctx.defs`. +next rule. 
""" macro rewriter(ex) ex isa Expr && ex.head === :call && ex.args[1] === :(=>) || @@ -110,48 +110,131 @@ function _compile_rhs(ex) end #============================================================================= - Matching + Worklist =============================================================================# -struct MatchResult - bindings::Dict{Symbol, Any} - matched_ssas::Vector{Int} +mutable struct Worklist + list::Vector{SSAValue} # entries (SSAValue(-1) = removed sentinel) + member::Dict{SSAValue, Int} # val -> position in list end +const _SENTINEL = SSAValue(-1) + +Worklist() = Worklist(SSAValue[], Dict{SSAValue, Int}()) + +function Base.push!(wl::Worklist, val::SSAValue) + haskey(wl.member, val) && return + push!(wl.list, val) + wl.member[val] = length(wl.list) +end + +function Base.pop!(wl::Worklist) + while !isempty(wl.list) + val = pop!(wl.list) + val == _SENTINEL && continue + delete!(wl.member, val) + return val + end + return nothing +end + +function remove!(wl::Worklist, val::SSAValue) + pos = get(wl.member, val, 0) + pos == 0 && return + wl.list[pos] = _SENTINEL + delete!(wl.member, val) +end + +Base.isempty(wl::Worklist) = isempty(wl.member) + +#============================================================================= + Driver State +=============================================================================# + struct DefEntry block::Block - inst::Instruction + val::SSAValue func::Any - operands::Vector{Any} end -struct MatchContext - entry::Block # root block of the StructuredIRCode (for value_type lookups) - defs::Dict{Int, DefEntry} - use_index # UseIndex from IRStructurizer +"""Operands of a DefEntry, read from the live IR.""" +function _def_operands(entry::DefEntry) + pos = findfirst(==(entry.val.id), entry.block.body.ssa_idxes) + pos === nothing && return Any[] + call = resolve_call(entry.block.body.stmts[pos]) + call === nothing && return Any[] + _, ops = call + return ops end -function MatchContext(sci::StructuredIRCode) - defs = 
Dict{Int, DefEntry}() - for block in eachblock(sci) - for inst in instructions(block) - call = resolve_call(inst) - call === nothing && continue - func, operands = call - defs[inst.ssa_idx] = DefEntry(block, inst, func, collect(Any, operands)) - end - end - MatchContext(sci.entry, defs, uses(sci.entry)) +mutable struct RewriteDriver + sci::StructuredIRCode + defs::Dict{SSAValue, DefEntry} + dispatch::Dict{Any, Vector{RewriteRule}} + worklist::Worklist + max_rewrites::Int end -_use_count(ctx::MatchContext, val::SSAValue) = - haskey(ctx.use_index, val) ? length(ctx.use_index[val]) : 0 +"""Compute fresh use count for an SSA value.""" +_use_count(driver::RewriteDriver, val::SSAValue) = + length(uses(driver.sci.entry, val)) # Codegen no-ops that pattern matching traces through transparently. _is_transparent(func) = func === Intrinsics.to_scalar || func === Intrinsics.from_scalar || func === Intrinsics.broadcast +#============================================================================= + Notifications +=============================================================================# + +"""Add operand-producing instructions to the worklist (enables cascading).""" +function _add_operands_to_worklist!(driver::RewriteDriver, entry::DefEntry) + for op in _def_operands(entry) + op isa SSAValue || continue + haskey(driver.defs, op) && push!(driver.worklist, op) + end +end + +"""Add instructions that use `val` to the worklist (their operand changed).""" +function _add_users_to_worklist!(driver::RewriteDriver, val::SSAValue) + for inst in users(driver.sci.entry, val) + push!(driver.worklist, SSAValue(inst)) + end +end + +"""Erase an instruction and notify the worklist.""" +function _erase_op!(driver::RewriteDriver, entry::DefEntry) + _add_operands_to_worklist!(driver, entry) + pos = findfirst(==(entry.val.id), entry.block.body.ssa_idxes) + if pos !== nothing + deleteat!(entry.block.body.ssa_idxes, pos) + deleteat!(entry.block.body.stmts, pos) + deleteat!(entry.block.body.types, 
pos) + end + delete!(driver.defs, entry.val) + remove!(driver.worklist, entry.val) +end + +"""Register a newly inserted instruction.""" +function _notify_insert!(driver::RewriteDriver, block::Block, inst::Instruction) + val = SSAValue(inst) + call = resolve_call(inst) + call === nothing && return + func, _ = call + driver.defs[val] = DefEntry(block, val, func) + push!(driver.worklist, val) +end + +#============================================================================= + Matching +=============================================================================# + +struct MatchResult + bindings::Dict{Symbol, Any} + matched_ssas::Vector{SSAValue} +end + """Merge bindings, requiring repeated names to bind the same value (=== equality).""" function _merge_bindings!(dest::Dict{Symbol,Any}, src::Dict{Symbol,Any}) for (k, v) in src @@ -164,60 +247,67 @@ function _merge_bindings!(dest::Dict{Symbol,Any}, src::Dict{Symbol,Any}) return true end -function pattern_match(ctx::MatchContext, @nospecialize(val), pat::PCall, - block::Block=ctx.entry) +function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall, + block::Block=driver.sci.entry) val isa SSAValue || return nothing - entry = get(ctx.defs, val.id, nothing) + entry = get(driver.defs, val, nothing) entry === nothing && return nothing - if entry.func === pat.func && length(entry.operands) == length(pat.operands) - result = MatchResult(Dict{Symbol,Any}(), Int[val.id]) - for (op, sub) in zip(entry.operands, pat.operands) - m = pattern_match(ctx, op, sub, entry.block) - m === nothing && return nothing - _merge_bindings!(result.bindings, m.bindings) || return nothing - append!(result.matched_ssas, m.matched_ssas) + if entry.func === pat.func + ops = _def_operands(entry) + if length(ops) == length(pat.operands) + result = MatchResult(Dict{Symbol,Any}(), SSAValue[val]) + for (op, sub) in zip(ops, pat.operands) + m = pattern_match(driver, op, sub, entry.block) + m === nothing && return nothing + 
_merge_bindings!(result.bindings, m.bindings) || return nothing + append!(result.matched_ssas, m.matched_ssas) + end + return result end - return result end # Trace through single-use transparent ops to find the underlying operation - if _is_transparent(entry.func) && !isempty(entry.operands) - _use_count(ctx, val) == 1 || return nothing + if _is_transparent(entry.func) + _use_count(driver, val) == 1 || return nothing + ops = _def_operands(entry) + isempty(ops) && return nothing if entry.func === Intrinsics.broadcast - inner = entry.operands[1] + inner = ops[1] if inner isa SSAValue - inner_entry = get(ctx.defs, inner.id, nothing) + inner_entry = get(driver.defs, inner, nothing) if inner_entry !== nothing - it, ot = value_type(inner_entry.inst), value_type(entry.inst) - it <: Tile && ot <: Tile || return nothing - size(it) == size(ot) || return nothing + it = value_type(entry.block, inner) + ot = value_type(entry.block, val) + it !== nothing && ot !== nothing || return nothing + CC.widenconst(it) <: Tile && CC.widenconst(ot) <: Tile || return nothing + size(CC.widenconst(it)) == size(CC.widenconst(ot)) || return nothing end end end - result = pattern_match(ctx, entry.operands[1], pat, entry.block) + result = pattern_match(driver, ops[1], pat, entry.block) result === nothing && return nothing - push!(result.matched_ssas, val.id) + push!(result.matched_ssas, val) return result end return nothing end -pattern_match(ctx::MatchContext, @nospecialize(val), pat::PBind, block::Block=ctx.entry) = - MatchResult(Dict{Symbol,Any}(pat.name => val), Int[]) +pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PBind, block::Block=driver.sci.entry) = + MatchResult(Dict{Symbol,Any}(pat.name => val), SSAValue[]) -function pattern_match(ctx::MatchContext, @nospecialize(val), pat::PTypedBind, - block::Block=ctx.entry) +function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PTypedBind, + block::Block=driver.sci.entry) T = value_type(block, val) T === nothing && 
return nothing CC.widenconst(T) <: pat.type || return nothing - MatchResult(Dict{Symbol,Any}(pat.name => val), Int[]) + MatchResult(Dict{Symbol,Any}(pat.name => val), SSAValue[]) end -function pattern_match(ctx::MatchContext, @nospecialize(val), pat::POneUse, - block::Block=ctx.entry) - val isa SSAValue && _use_count(ctx, val) == 1 || return nothing - pattern_match(ctx, val, pat.inner, block) +function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::POneUse, + block::Block=driver.sci.entry) + val isa SSAValue && _use_count(driver, val) == 1 || return nothing + pattern_match(driver, val, pat.inner, block) end #============================================================================= @@ -225,52 +315,50 @@ end =============================================================================# """Resolve an RHS operand, inserting sub-calls before `ref` as needed.""" -_resolve_rhs(block, ref, op::RBind, bindings, typ) = bindings[op.name] -_resolve_rhs(block, ref, op::RConst, bindings, typ) = op.val -function _resolve_rhs(block, ref, op::RCall, bindings, typ) - operands = Any[_resolve_rhs(block, ref, sub, bindings, typ) for sub in op.operands] - SSAValue(insert_before!(block, ref, Expr(:call, op.func, operands...), typ)) +_resolve_rhs(driver, block, ref, op::RBind, bindings, typ) = bindings[op.name] +_resolve_rhs(driver, block, ref, op::RConst, bindings, typ) = op.val +function _resolve_rhs(driver::RewriteDriver, block, ref, op::RCall, bindings, typ) + operands = Any[_resolve_rhs(driver, block, ref, sub, bindings, typ) for sub in op.operands] + inst = insert_before!(block, ref, Expr(:call, op.func, operands...), typ) + _notify_insert!(driver, block, inst) + SSAValue(inst) end -function _apply_rewrite!(sci, block, inst, rule, match, ctx, consumed) +function _apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match) + entry = driver.defs[val] if rule.rhs isa RFunc - rule.rhs.func(sci, block, inst, match, ctx) || return false + # Look up live 
instruction for RFunc interface + pos = findfirst(==(val.id), block.body.ssa_idxes) + pos === nothing && return false + inst = Instruction(val.id, block.body.stmts[pos], block.body.types[pos]) + rule.rhs.func(driver.sci, block, inst, match, driver) || return false return true elseif rule.rhs isa RBind - # Forwarding: replace all uses of root with the bound value, delete root - replace_uses!(sci.entry, SSAValue(inst), match.bindings[rule.rhs.name]) - delete!(block, inst) - _cleanup_dead_operands!(sci, inst.ssa_idx, ctx, consumed) + # Forwarding: replace all uses of root with the bound value, delete root. + # Collect users BEFORE replace_uses! updates their operands. + _add_users_to_worklist!(driver, val) + replace_uses!(driver.sci.entry, val, match.bindings[rule.rhs.name]) + _erase_op!(driver, entry) else - # Substitution: delete matched intermediates, replace root statement in-place. - # No operand cleanup here — the new replacement may reference these operands. - for dead_ssa in match.matched_ssas - dead_ssa == inst.ssa_idx && continue - entry = ctx.defs[dead_ssa] - delete!(entry.block, entry.inst) + # Substitution: replace root in-place, clean up dead intermediates. + # Only delete intermediates with no remaining uses — transparent-op + # tracing may have added multi-use intermediates to matched_ssas. + for dead_val in match.matched_ssas + dead_val == val && continue + dead_entry = get(driver.defs, dead_val, nothing) + dead_entry === nothing && continue + _use_count(driver, dead_val) == 0 || continue + _erase_op!(driver, dead_entry) end - typ = value_type(inst) - operands = Any[_resolve_rhs(block, SSAValue(inst), op, match.bindings, typ) + typ = block.body.types[findfirst(==(val.id), block.body.ssa_idxes)] # root's type; its position is looked up again after RHS sub-calls are inserted + operands = Any[_resolve_rhs(driver, block, val, op, match.bindings, typ) for op in rule.rhs.operands] - pos = findfirst(==(inst.ssa_idx), block.body.ssa_idxes) + pos = findfirst(==(val.id), block.body.ssa_idxes) block.body.stmts[pos] = Expr(:call, rule.rhs.func, operands...) 
- end -end - -"""Recursively erase dead pure operands after an instruction is deleted (MLIR-style).""" -function _cleanup_dead_operands!(sci, ssa_idx, ctx, consumed) - entry = get(ctx.defs, ssa_idx, nothing) - entry === nothing && return - for op in entry.operands - op isa SSAValue || continue - op_entry = get(ctx.defs, op.id, nothing) - op_entry === nothing && continue - _is_transparent(op_entry.func) || continue - op.id in consumed && continue - isempty(uses(sci.entry, op)) || continue - push!(consumed, op.id) - delete!(op_entry.block, op_entry.inst) - _cleanup_dead_operands!(sci, op.id, ctx, consumed) + # Update defs, re-add self and users to worklist (statement changed) + driver.defs[val] = DefEntry(block, val, rule.rhs.func) + push!(driver.worklist, val) + _add_users_to_worklist!(driver, val) end end @@ -279,36 +367,79 @@ end =============================================================================# """ - rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}) + rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; max_rewrites=10_000) -Apply declarative rewrite rules to the structured IR. Dead code left behind -is cleaned up by the pipeline's `dce_pass!`. +Apply rewrite rules to the structured IR using a worklist-based fixpoint driver. +Rules are tried until no more matches fire or `max_rewrites` is reached. +Dead code left behind is cleaned up by the pipeline's `dce_pass!`. 
""" -function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}) - ctx = MatchContext(sci) +function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule}; + max_rewrites::Int=10_000) + # Build dispatch table dispatch = Dict{Any, Vector{RewriteRule}}() for rule in rules push!(get!(dispatch, root_func(rule), RewriteRule[]), rule) end - consumed = Set{Int}() + # Build defs index + defs = Dict{SSAValue, DefEntry}() for block in eachblock(sci) - for inst in collect(instructions(block)) - inst.ssa_idx in consumed && continue + for inst in instructions(block) call = resolve_call(inst) call === nothing && continue - applicable = get(dispatch, call[1], nothing) - applicable === nothing && continue - for rule in applicable - m = pattern_match(ctx, SSAValue(inst), rule.lhs) - m === nothing && continue - any(s in consumed for s in m.matched_ssas) && continue - rule.guard !== nothing && !rule.guard(m, ctx) && continue - result = _apply_rewrite!(sci, block, inst, rule, m, ctx, consumed) - result === false && continue # RFunc declined, try next rule - union!(consumed, m.matched_ssas) - break + func, _ = call + val = SSAValue(inst) + defs[val] = DefEntry(block, val, func) + end + end + + # Seed worklist (forward order → reversed by LIFO → processes top-down) + wl = Worklist() + for block in eachblock(sci) + for inst in instructions(block) + val = SSAValue(inst) + haskey(defs, val) && push!(wl, val) + end + end + + driver = RewriteDriver(sci, defs, dispatch, wl, max_rewrites) + + num_rewrites = 0 + while !isempty(driver.worklist) && num_rewrites < driver.max_rewrites + val = pop!(driver.worklist)::SSAValue + entry = get(driver.defs, val, nothing) + entry === nothing && continue + + # Verify instruction is still live in its block + pos = findfirst(==(val.id), entry.block.body.ssa_idxes) + pos === nothing && begin + delete!(driver.defs, val) + continue + end + + # Trivial dead-op elimination: if this op has no uses and is pure, + # erase it. 
This keeps use counts accurate for `one_use` patterns + # (e.g., FMA fusion needs mulf's dead transparent-op users removed + # so the mulf reads as single-use). Full DCE handles the rest. + if _use_count(driver, val) == 0 + stmt = entry.block.body.stmts[pos] + if !must_keep(stmt) + _erase_op!(driver, entry) + continue end end + + # Look up applicable rules by function + applicable = get(driver.dispatch, entry.func, nothing) + applicable === nothing && continue + + for rule in applicable + m = pattern_match(driver, val, rule.lhs) + m === nothing && continue + rule.guard !== nothing && !rule.guard(m, driver) && continue + _apply_rewrite!(driver, entry.block, val, rule, m) === false && continue # RFunc declined: try the next rule, do not count a rewrite + num_rewrites += 1 + break + end end end