Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/compiler/transform/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,25 @@ function commute_arith_transparent(sci, block, inst, match, driver)
# Insert broadcast of the scalar to x's shape and register as constant
x_shape = size(xT)
bc_type = Tile{eltype(xT), Tuple{x_shape...}}
bc = insert_before!(block, val, Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type;
bc = insert_before!(driver.rewriter, block, val,
Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type;
flag=inferred_flags(Intrinsics.broadcast))
notify_insert!(driver, block, bc)
# Side-inject the freshly synthesized constant into the dataflow result so
# downstream pattern matches see it. Bypasses tmerge (this is a brand-new
# SSA value, not a merge).
driver.constants[SSAValue(bc)] = convert(eltype(xT), scalar)

# Insert op(x, broadcast) with x's type
op = insert_before!(block, val, Expr(:call, root_func, x, SSAValue(bc)), xT;
op = insert_before!(driver.rewriter, block, val,
Expr(:call, root_func, x, SSAValue(bc)), xT;
flag=inferred_flags(root_func))
notify_insert!(driver, block, op)

# Replace root with transparent_op(op_result, s). Func changes
# (subi/addi → reshape/broadcast), so recompute the flag from the new
# func's declared effects — the inferred bits describe the OLD op.
block[val.id] = (stmt=Expr(:call, transparent_func, SSAValue(op), match.bindings[:s]),
flag=inferred_flags(transparent_func))
driver.defs[val] = DefEntry(block, val, transparent_func)
replace_stmt!(driver.rewriter, block, val,
Expr(:call, transparent_func, SSAValue(op), match.bindings[:s]);
flag=inferred_flags(transparent_func))
push!(driver.worklist, val)
add_users_to_worklist!(driver, val)
return true
Expand Down
163 changes: 86 additions & 77 deletions src/compiler/transform/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ function def_operands(entry::DefEntry)
end

mutable struct RewriteDriver
sci::StructuredIRCode
rewriter::Rewriter
defs::Dict{SSAValue, DefEntry}
dispatch::Dict{Any, Vector{RewriteRule}}
worklist::Worklist
Expand All @@ -226,47 +226,66 @@ mutable struct RewriteDriver
max_rewrites::Int
end

"""Compute fresh use count for an SSA value."""
use_count(driver::RewriteDriver, val::SSAValue) =
length(uses(driver.sci.entry, val))

#=============================================================================
Notifications
Rewriter listener: driver-side bookkeeping reacts to IR mutations
=============================================================================#

"""Add operand-producing instructions to the worklist (enables cascading)."""
function add_operands_to_worklist!(driver::RewriteDriver, entry::DefEntry)
for op in def_operands(entry)
op isa SSAValue || continue
haskey(driver.defs, op) && push!(driver.worklist, op)
end
# The driver has its own state on top of the Rewriter's index: a `defs` map
# keyed by func, the `worklist`, and a `modified` set for cascading. These
# `notify_*` hooks let the Rewriter keep that state in sync, so the driver
# doesn't need to wrap every mutation callsite. Same pattern as MLIR's
# `RewriterBase::Listener::notifyOperation{Inserted,Modified,Erased}`.

"""
Hook for a newly inserted instruction. If the instruction's statement resolves
to a call, record a `DefEntry` for it in `d.defs` (keyed by its `SSAValue`) and
queue it on the worklist. Statements that do not resolve to a call are ignored.
"""
function notify_inserted!(d::RewriteDriver, block::Block, inst::Instruction)
    resolved = resolve_call(block, inst[:stmt])
    # Nothing to track when the statement is not a resolvable call.
    resolved === nothing && return
    callee, _ = resolved
    ssa = SSAValue(inst)
    d.defs[ssa] = DefEntry(block, ssa, callee)
    push!(d.worklist, ssa)
end

"""Add instructions that use `val` to the worklist (their operand changed)."""
function add_users_to_worklist!(driver::RewriteDriver, val::SSAValue)
for inst in users(driver.sci.entry, val)
push!(driver.worklist, SSAValue(inst))
"""
Hook for an in-place statement replacement at `val`. When `new_stmt` resolves
to a call, the entry in `d.defs` is overwritten so that later dispatch sees the
new func; otherwise `d.defs` is left untouched. `old_stmt` is not consulted.
"""
function notify_modified!(d::RewriteDriver, block::Block, val::SSAValue,
                          @nospecialize(old_stmt), @nospecialize(new_stmt))
    resolved = resolve_call(block, new_stmt)
    resolved === nothing && return nothing
    callee, _ = resolved
    # No worklist pushes here: cascading is left to the callers
    # (`apply_rewrite!`, `apply_inplace_rewrite!`, `commute_arith_transparent`),
    # which re-seed `val` and its users themselves.
    d.defs[val] = DefEntry(block, val, callee)
end

"""Erase an instruction and notify the worklist."""
function erase_op!(driver::RewriteDriver, entry::DefEntry)
add_operands_to_worklist!(driver, entry)
if haskey(entry.block, entry.val.id)
delete!(entry.block, entry.val.id)
"""
Hook for an erased instruction `val` whose statement was `old_stmt`.

If `old_stmt` is an `Expr`, every `SSAValue` operand that has an entry in
`d.defs` is pushed onto the worklist, since it may have become dead. `val` is
then dropped from `d.defs`, from the worklist, and from `d.modified`. The block
argument is not used.
"""
function notify_erased!(d::RewriteDriver, ::Block, val::SSAValue,
                        @nospecialize(old_stmt))
    # Operand-defs may now be dead; cascade them to the worklist.
    if old_stmt isa Expr
        for_expr_operands(old_stmt) do op
            op isa SSAValue || return
            haskey(d.defs, op) && push!(d.worklist, op)
        end
    end
    # All driver-side bookkeeping goes through `d` and `val`, the only names
    # in scope here; no `driver`/`entry` bindings exist in this function.
    delete!(d.defs, val)
    remove!(d.worklist, val)
    delete!(d.modified, val)
end

"""Register a newly inserted instruction."""
function notify_insert!(driver::RewriteDriver, block::Block, inst::Instruction)
val = SSAValue(inst)
call = resolve_call(block, inst)
call === nothing && return
func, _ = call
driver.defs[val] = DefEntry(block, val, func)
push!(driver.worklist, val)
#=============================================================================
Driver-side query helpers (thin proxies to the Rewriter)
=============================================================================#

"""Return the users of `val` by delegating to `users` on the driver's `Rewriter`."""
function users_of(driver::RewriteDriver, val::SSAValue)
    return users(driver.rewriter, val)
end

"""Return the use count of `val` by delegating to `use_count` on the driver's `Rewriter`."""
function use_count(driver::RewriteDriver, val::SSAValue)
    return use_count(driver.rewriter, val)
end

"""Add instructions that use `val` to the worklist (their operand changed)."""
function add_users_to_worklist!(driver::RewriteDriver, val::SSAValue)
    # Every user reported by `users_of` is queued, in iteration order.
    foreach(users_of(driver, val)) do user
        push!(driver.worklist, user)
    end
    return nothing
end

#=============================================================================
Expand All @@ -291,7 +310,7 @@ function merge_bindings!(dest::Dict{Symbol,Any}, src::Dict{Symbol,Any})
end

function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall,
block::Block=driver.sci.entry)
block::Block=driver.rewriter.sci.entry)
val isa SSAValue || return nothing
entry = get(driver.defs, val, nothing)
entry === nothing && return nothing
Expand Down Expand Up @@ -322,19 +341,19 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall,
return nothing
end

"""
Match a `PBind` pattern: always succeeds, binding `val` under `pat.name` with
no matched SSA values. `block` is accepted for signature parity with the other
`pattern_match` methods and is not read here; it defaults to the entry block of
the IR held by the driver's `Rewriter`.
"""
pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PBind,
              block::Block=driver.rewriter.sci.entry) =
    MatchResult(Dict{Symbol,Any}(pat.name => val), SSAValue[])

"""
Match a `PTypedBind` pattern: bind `val` under `pat.name` only when its type in
`block` is known and, after `CC.widenconst`, is a subtype of `pat.type`.

Returns `nothing` when `value_type` yields `nothing` or the type check fails;
otherwise a `MatchResult` with the single binding and no matched SSA values.
`block` defaults to the entry block of the IR held by the driver's `Rewriter`.
"""
function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PTypedBind,
                       block::Block=driver.rewriter.sci.entry)
    T = value_type(block, val)
    T === nothing && return nothing
    CC.widenconst(T) <: pat.type || return nothing
    return MatchResult(Dict{Symbol,Any}(pat.name => val), SSAValue[])
end

"""
Match a `POneUse` pattern: succeed only when `val` is an `SSAValue` whose use
count (per `use_count(driver, val)`) is exactly 1, then defer to `pat.inner`.

Returns `nothing` when the one-use condition fails; otherwise whatever matching
`pat.inner` against `val` returns. `block` defaults to the entry block of the
IR held by the driver's `Rewriter` and is forwarded to the inner match.
"""
function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::POneUse,
                       block::Block=driver.rewriter.sci.entry)
    val isa SSAValue && use_count(driver, val) == 1 || return nothing
    return pattern_match(driver, val, pat.inner, block)
end
Expand All @@ -343,7 +362,7 @@ end
# For non-SSA operands (enum constants, predicates): checks ===.
# For SSA operands: routed through const_value on the ConstantAnalysis result.
function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PLiteral,
block::Block=driver.sci.entry)
block::Block=driver.rewriter.sci.entry)
val === pat.val && return MatchResult(Dict{Symbol,Any}(), SSAValue[])
if val isa SSAValue
c = const_value(driver.constants, val)
Expand Down Expand Up @@ -384,9 +403,8 @@ function resolve_rhs(driver::RewriteDriver, block, ref, op::RCall, bindings, roo
typ = CC.widenconst(t)
break
end
inst = insert_before!(block, ref, Expr(:call, op.func, operands...), typ;
inst = insert_before!(driver.rewriter, block, ref, Expr(:call, op.func, operands...), typ;
flag=inferred_flags(op.func))
notify_insert!(driver, block, inst)
SSAValue(inst)
end

Expand All @@ -406,12 +424,8 @@ function apply_inplace_rewrite!(driver::RewriteDriver, block, val::SSAValue, rul
# the flag from the new func's declared effects (`efunc` overrides),
# mirroring inference's `flags_for_effects` — the inferred bits on the
# old call describe the OLD op and don't carry over.
if rule.rhs.func === driver.defs[val].func
block[val.id] = (stmt=new_stmt,)
else
block[val.id] = (stmt=new_stmt, flag=inferred_flags(rule.rhs.func))
end
driver.defs[val] = DefEntry(block, val, rule.rhs.func)
flag = rule.rhs.func === driver.defs[val].func ? nothing : inferred_flags(rule.rhs.func)
replace_stmt!(driver.rewriter, block, val, new_stmt; flag)
push!(driver.worklist, val)
add_users_to_worklist!(driver, val)
return true
Expand All @@ -438,8 +452,8 @@ function resolve_inplace_rhs(driver, bindings, op::RCall, lhs_op::PCall)
for (sub_rhs, sub_lhs) in zip(op.operands, lhs_op.operands)]
# Sub-call func is enforced equal to the matched lhs func above
# (`op.func === lhs_op.func`), so the flag is still valid — only operands
# changed. Partial-NamedTuple setindex preserves type and flag.
entry.block[matched_ssa.id] = (stmt=Expr(:call, op.func, new_ops...),)
# changed.
replace_stmt!(driver.rewriter, entry.block, matched_ssa, Expr(:call, op.func, new_ops...))
push!(driver.worklist, matched_ssa)
return matched_ssa
end
Expand All @@ -454,25 +468,22 @@ The matched_ssas in MatchResult are ordered root-first, but we need to find
the specific SSA for a sub-pattern. We do this by looking up the first operand's
binding and finding the op that defines it."""
function find_matched_ssa(driver, pat::PCall, bindings)
entry = driver.sci.entry
for sub in pat.operands
if sub isa PBind
bound = get(bindings, sub.name, nothing)
bound isa SSAValue || continue
for inst in users(entry, bound)
call = resolve_call(entry, inst)
call === nothing && continue
func, _ = call
func === pat.func && return SSAValue(inst)
for u in users_of(driver, bound)
def_entry = get(driver.defs, u, nothing)
def_entry === nothing && continue
def_entry.func === pat.func && return u
end
elseif sub isa PCall
inner_ssa = find_matched_ssa(driver, sub, bindings)
if inner_ssa !== nothing
for inst in users(entry, inner_ssa)
call = resolve_call(entry, inst)
call === nothing && continue
func, _ = call
func === pat.func && return SSAValue(inst)
for u in users_of(driver, inner_ssa)
def_entry = get(driver.defs, u, nothing)
def_entry === nothing && continue
def_entry.func === pat.func && return u
end
end
end
Expand All @@ -491,20 +502,21 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match
# Look up live instruction for RFunc interface
haskey(block, val.id) || return false
inst = block[val.id]
rule.rhs.func(driver.sci, block, inst, match, driver) || return false
rule.rhs.func(driver.rewriter.sci, block, inst, match, driver) || return false
return true
elseif rule.rhs isa RBind
# Forwarding: replace all uses of root with the bound value, delete root.
# Mark immediate users as modified — their operands are about to change.
# When these are later popped from the worklist without a match, the
# driver propagates to THEIR users (see modified check in main loop).
# This gives MLIR-style notifyOperationModified cascading.
for inst in users(driver.sci.entry, val)
push!(driver.modified, SSAValue(inst))
new_val = match.bindings[rule.rhs.name]
for u in users_of(driver, val)
push!(driver.modified, u)
push!(driver.worklist, u)
end
add_users_to_worklist!(driver, val)
replace_uses!(driver.sci.entry, val, match.bindings[rule.rhs.name])
erase_op!(driver, entry)
replace_uses!(driver.rewriter, val, new_val)
erase!(driver.rewriter, entry.block, val)
else
# Substitution: replace root in-place, clean up dead intermediates.
# Only delete intermediates with no remaining uses — transparent-op
Expand All @@ -514,7 +526,7 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match
dead_entry = get(driver.defs, dead_val, nothing)
dead_entry === nothing && continue
use_count(driver, dead_val) == 0 || continue
erase_op!(driver, dead_entry)
erase!(driver.rewriter, dead_entry.block, dead_val)
end
typ = block[val.id][:type]
# Build operands, flattening RSplat nodes into multiple operands
Expand All @@ -531,13 +543,8 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match
# describe the OLD op; recompute from the new func's `efunc` effects
# so downstream gates (CSE, LICM) see fresh, correct information.
new_stmt = Expr(:call, rule.rhs.func, operands...)
if rule.rhs.func === driver.defs[val].func
block[val.id] = (stmt=new_stmt,)
else
block[val.id] = (stmt=new_stmt, flag=inferred_flags(rule.rhs.func))
end
# Update defs, re-add self and users to worklist (statement changed)
driver.defs[val] = DefEntry(block, val, rule.rhs.func)
flag = rule.rhs.func === driver.defs[val].func ? nothing : inferred_flags(rule.rhs.func)
replace_stmt!(driver.rewriter, block, val, new_stmt; flag)
push!(driver.worklist, val)
add_users_to_worklist!(driver, val)
end
Expand Down Expand Up @@ -584,7 +591,10 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule};
end
end

driver = RewriteDriver(sci, defs, dispatch, wl, constants, Set{SSAValue}(), max_rewrites)
rewriter = Rewriter(sci)
driver = RewriteDriver(rewriter, defs, dispatch, wl, constants, Set{SSAValue}(),
max_rewrites)
rewriter.listener = driver

num_rewrites = 0
while !isempty(driver.worklist) && num_rewrites < driver.max_rewrites
Expand All @@ -605,7 +615,7 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule};
if use_count(driver, val) == 0
stmt = entry.block[val.id][:stmt]
if !must_keep(entry.block, stmt)
erase_op!(driver, entry)
erase!(driver.rewriter, entry.block, val)
continue
end
end
Expand Down Expand Up @@ -634,10 +644,9 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule};
# cascade continues through the fixpoint.
if !matched && val in driver.modified
delete!(driver.modified, val)
for inst in users(driver.sci.entry, val)
uv = SSAValue(inst)
push!(driver.modified, uv)
haskey(driver.defs, uv) && push!(driver.worklist, uv)
for u in users_of(driver, val)
push!(driver.modified, u)
haskey(driver.defs, u) && push!(driver.worklist, u)
end
end
end
Expand Down
Loading