Skip to content

Commit b9a9b91

Browse files
shreyas-omkarmaleadtclaude
authored
Add alias-aware token ordering pass. (#89)
Add alias-aware token ordering pass. - Alias analysis groups pointers into per-argument alias sets. - token_order_pass! inserts token IR nodes (MakeToken, JoinTokens, TokenResult) and threads per-alias-set token carries through loops, branches, and terminators. - Memory intrinsics receive input tokens as arguments, store result tokens via ctx.result_tokens. - Release/acquire memory ordering for atomics. - Codegen emits what the pass added to the IR. Enables independent memory operations to execute without unnecessary serialization. --------- Co-authored-by: Tim Besard <tim.besard@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 850a808 commit b9a9b91

13 files changed

Lines changed: 1414 additions & 307 deletions

File tree

src/compiler/codegen.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Codegen: Julia IR -> Tile IR bytecode
22

33
include("codegen/utils.jl")
4+
include("codegen/irutils.jl") # SSAMap/Block mutation helpers
5+
include("codegen/passes/token_keys.jl") # TokenKey, TokenRole, ACQUIRE_TOKEN_KEY
6+
include("codegen/passes/alias_analysis.jl") # alias_analysis_pass!
7+
include("codegen/passes/token_order.jl") # token_order_pass!
48
include("codegen/kernel.jl")
59
include("codegen/control_flow.jl")
610
include("codegen/statements.jl")

src/compiler/codegen/control_flow.jl

Lines changed: 125 additions & 236 deletions
Large diffs are not rendered by default.

src/compiler/codegen/irutils.jl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# StructuredIRCode / SSAMap mutation utilities
2+
#
3+
# Helpers for passes that modify the structured IR in place.
4+
# Inspired by Julia's IncrementalCompact (Compiler/src/ssair/ir.jl).
5+
"""
    new_ssa_idx!(sci::StructuredIRCode) -> Int

Reserve the next unused SSA index: bump `sci.max_ssa_idx` by one and return
the updated counter value.
"""
function new_ssa_idx!(sci::StructuredIRCode)
    next_idx = sci.max_ssa_idx + 1
    sci.max_ssa_idx = next_idx
    # Read back from the field so the returned value is exactly what is stored.
    return sci.max_ssa_idx
end
15+
"""
    new_block_arg!(block::Block, sci::StructuredIRCode, @nospecialize(typ)) -> BlockArg

Create a `BlockArg` of type `typ` whose id is a freshly allocated SSA index
from `sci`, append it to `block.args`, and return it.
"""
function new_block_arg!(block::Block, sci::StructuredIRCode, @nospecialize(typ))
    fresh_arg = BlockArg(new_ssa_idx!(sci), typ)
    push!(block.args, fresh_arg)
    return fresh_arg
end
27+
"""
    Base.pushfirst!(m::SSAMap, (idx, stmt, typ)::Tuple{Int,Any,Any})

Insert the entry `(idx, stmt, typ)` at the front of `m` by prepending each
component to the matching parallel vector (`ssa_idxes`, `stmts`, `types`).
Returns `nothing`.
"""
function Base.pushfirst!(m::SSAMap, (idx, stmt, typ)::Tuple{Int,Any,Any})
    # The three vectors are kept in lockstep; update them in a fixed order.
    for (column, value) in ((m.ssa_idxes, idx), (m.stmts, stmt), (m.types, typ))
        pushfirst!(column, value)
    end
    return nothing
end
39+
"""
    insert_before!(m::SSAMap, before_idx::Int, new_idx::Int, stmt, typ)

Place the entry `(new_idx, stmt, typ)` immediately ahead of the entry whose
SSA index is `before_idx`. Throws `KeyError(before_idx)` if no such entry
exists. Returns `nothing`.
"""
function insert_before!(m::SSAMap, before_idx::Int, new_idx::Int, stmt, typ)
    slot = findfirst(==(before_idx), m.ssa_idxes)
    if slot === nothing
        throw(KeyError(before_idx))
    end
    # Inserting at `slot` shifts the anchor entry (and everything after) back.
    insert!(m.ssa_idxes, slot, new_idx)
    insert!(m.stmts, slot, stmt)
    insert!(m.types, slot, typ)
    return nothing
end
53+
"""
    insert_after!(m::SSAMap, after_idx::Int, new_idx::Int, stmt, typ)

Place the entry `(new_idx, stmt, typ)` immediately behind the entry whose
SSA index is `after_idx`. Throws `KeyError(after_idx)` if no such entry
exists. Returns `nothing`.
"""
function insert_after!(m::SSAMap, after_idx::Int, new_idx::Int, stmt, typ)
    anchor = findfirst(==(after_idx), m.ssa_idxes)
    if anchor === nothing
        throw(KeyError(after_idx))
    end
    slot = anchor + 1  # position directly following the anchor entry
    insert!(m.ssa_idxes, slot, new_idx)
    insert!(m.stmts, slot, stmt)
    insert!(m.types, slot, typ)
    return nothing
end
67+
"""
    update_type!(m::SSAMap, ssa_idx::Int, @nospecialize(new_type))

Overwrite the type recorded for the entry with SSA index `ssa_idx`.
Throws `KeyError(ssa_idx)` if `m` has no such entry. Returns `nothing`.
"""
function update_type!(m::SSAMap, ssa_idx::Int, @nospecialize(new_type))
    slot = findfirst(==(ssa_idx), m.ssa_idxes)
    if slot === nothing
        throw(KeyError(ssa_idx))
    end
    m.types[slot] = new_type
    return nothing
end
79+
"""
    resolve_call(stmt) -> (resolved_func, operands) or nothing

Extract the resolved function and operands from a `:call` or `:invoke` `Expr`.
Shared by alias analysis and token ordering passes.

The callee is `stmt.args[1]` for `:call` and `stmt.args[2]` for `:invoke`;
`operands` is a view of the arguments that follow it. A `GlobalRef` callee is
looked up in its module; any other callee is returned unchanged.

Returns `nothing` when `stmt` is not an `Expr`, has a different head, has no
callee slot, when a `GlobalRef` callee is not defined in its module, or when
the callee resolves to `nothing`.
"""
function resolve_call(stmt)
    stmt isa Expr || return nothing

    # Position of the callee inside `stmt.args`.
    callee_pos = if stmt.head === :call
        1
    elseif stmt.head === :invoke
        2
    else
        return nothing
    end

    # Malformed expression without a callee: nothing to resolve.
    length(stmt.args) >= callee_pos || return nothing

    func_ref = stmt.args[callee_pos]
    operands = @view stmt.args[(callee_pos + 1):end]

    resolved = if func_ref isa GlobalRef
        # Check the binding explicitly instead of catching the lookup error.
        if isdefined(func_ref.mod, func_ref.name)
            getfield(func_ref.mod, func_ref.name)
        else
            nothing
        end
    else
        func_ref
    end
    resolved === nothing && return nothing
    return (resolved, operands)
end

src/compiler/codegen/kernel.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,17 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
141141
create_tensor_views!(ctx, arg_idx, argtype, Int[])
142142
end
143143

144-
# Create memory ordering token
145-
token_type = Token(tt)
146-
ctx.token_type = token_type
147-
ctx.token = encode_MakeTokenOp!(cb, token_type)
148-
149-
# Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp)
144+
# Hoist early returns BEFORE token ordering — hoist_returns! rewrites
145+
# ReturnNode terminators to YieldOp, which the token pass then extends.
150146
hoist_returns!(ctx.sci.entry)
151147

148+
# Run alias analysis and token ordering pass on the structured IR.
149+
alias_result = alias_analysis_pass!(sci)
150+
token_order_pass!(sci, alias_result)
151+
152+
# Cache the token bytecode type for codegen
153+
ctx.token_type = Token(tt)
154+
152155
# Emit the structured IR (uses original Julia SSA indices everywhere)
153156
emit_block!(ctx, ctx.sci.entry)
154157

@@ -314,7 +317,7 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector,
314317

315318
# 3. Create sub-context
316319
sub_ctx = CGCtx(; ctx.cb, ctx.tt, sci,
317-
ctx.token, ctx.token_type,
320+
ctx.token_type,
318321
ctx.type_cache, ctx.sm_arch,
319322
ctx.cache)
320323

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Alias Analysis Pass
2+
#
3+
# Fixed-point alias analysis over StructuredIRCode. Determines which memory
4+
# operations may access the same underlying data (i.e., which SSA values
5+
# point into the same allocation).
6+
#
7+
# WHY: The token ordering pass needs alias information to decide which memory
8+
# operations require token dependencies between them. Without alias analysis,
9+
# all memory ops would be serialized through a single token chain — correct,
10+
# but overly conservative. With per-alias-set information, independent memory
11+
# regions (e.g., separate kernel arguments) get independent token chains,
12+
# enabling more parallelism in the generated Tile IR.
13+
#
14+
# HOW: Each pointer-containing kernel argument starts in its own alias set.
15+
# Alias sets propagate forward through:
16+
# - getfield (for TileArray.ptr field access)
17+
# - pointer arithmetic (+, -)
18+
# - view constructors (make_tensor_view, make_partition_view)
19+
# - pointer passthroughs (bitcast, assume_aligned, etc.)
20+
# Unknown operations conservatively produce ALIAS_UNIVERSE (may alias anything).
21+
# Fixed-point iteration handles back-edges from loops.
22+
#
23+
# OUTPUT: Dict{Any, AliasSet} mapping SSA values and Arguments to their alias
24+
# sets, consumed by token_order_pass!.
25+
"""
    AliasTracker

Working state of the fixed-point alias analysis: the table of alias sets
computed so far, plus a `dirty` flag recording whether the table changed.
"""
mutable struct AliasTracker
    dirty::Bool                   # raised by `setindex!` whenever an entry changes
    aliases::Dict{Any, AliasSet}  # SSAValue/Argument/SlotNumber -> AliasSet
end

# Start with a clean flag and an empty alias table.
function AliasTracker()
    return AliasTracker(false, Dict{Any, AliasSet}())
end
37+
38+
# Look up the alias set for `key`; keys with no entry read as `ALIAS_UNIVERSE`.
Base.getindex(tracker::AliasTracker, key) = get(tracker.aliases, key, ALIAS_UNIVERSE)
41+
42+
# Store `value` as the alias set for `key`. The write (and the `dirty` flag)
# only happens when `value` is not the identical object already stored, so
# re-assigning the same set object does not mark the tracker as changed.
function Base.setindex!(tracker::AliasTracker, value::AliasSet, key)
    previous = get(tracker.aliases, key, nothing)
    previous === value && return
    tracker.dirty = true
    tracker.aliases[key] = value
    return
end
50+
"""
    alias_analysis_pass!(sci::StructuredIRCode; max_iterations::Int = 100) -> Dict{Any, AliasSet}

Perform fixed-point alias analysis on structured IR.
Returns mapping from SSA values and arguments to alias sets.

Every argument whose widened type satisfies `contains_pointers` is seeded with
a singleton alias set containing itself. The entry block is then re-analyzed
until a full sweep leaves the tracker unchanged, or until `max_iterations`
sweeps have run.

# Keywords
- `max_iterations`: upper bound on the number of sweeps over the IR. If the
  bound is reached while the tracker is still changing, a warning is logged
  and the current (not yet stable) table is returned.
"""
function alias_analysis_pass!(sci::StructuredIRCode; max_iterations::Int = 100)
    tracker = AliasTracker()

    # Initialize: each pointer-carrying argument gets its own alias set
    for (idx, argtype) in enumerate(sci.argtypes)
        argtype_unwrapped = CC.widenconst(argtype)
        if contains_pointers(argtype_unwrapped)
            arg_ref = Argument(idx)
            tracker[arg_ref] = Set{Any}([arg_ref])
        end
    end

    # Fixed-point iteration; force at least one sweep (when the bound allows).
    iteration = 0
    tracker.dirty = true
    while tracker.dirty && iteration < max_iterations
        tracker.dirty = false
        iteration += 1

        analyze_block!(tracker, sci.entry)
    end

    if tracker.dirty
        # The loop stopped because of the iteration bound, not because the
        # table stabilized — do not report this as convergence.
        @warn "Alias analysis did not converge after $iteration iterations"
    else
        @debug "Alias analysis converged in $iteration iterations"
    end

    return tracker.aliases
end
85+
"""
    propagate!(tracker::AliasTracker, from, to)

Flow the alias set of `from` into `to`.

- If `from` reads as `ALIAS_UNIVERSE`, `to` becomes `ALIAS_UNIVERSE`.
- If `to` has no entry yet, it receives `from`'s alias set directly.
- Otherwise `to` is replaced by the union of both sets, but only when that
  union differs from what `to` already holds.
"""
function propagate!(tracker::AliasTracker, from, to)
    source = tracker[from]

    # Propagating UNIVERSE is always conservative.
    if source === ALIAS_UNIVERSE
        tracker[to] = ALIAS_UNIVERSE
        return
    end

    # Target not yet analyzed: assign directly.
    if !haskey(tracker.aliases, to)
        tracker[to] = source
        return
    end

    # Target already has an alias set: union with it, writing back only on
    # an actual change so the tracker is not marked dirty needlessly.
    existing = tracker.aliases[to]
    merged = union(source, existing)
    if merged != existing
        tracker[to] = merged
    end
    return
end
114+
"""
    analyze_block!(tracker::AliasTracker, block)

Walk every entry of `block.body` in order. Control-flow ops are handed to
`analyze_control_flow!`; every other statement goes to `analyze_statement!`
together with the `SSAValue` built from its index.
"""
function analyze_block!(tracker::AliasTracker, block)
    for (idx, item) in block.body
        if !(item.stmt isa ControlFlowOp)
            analyze_statement!(tracker, SSAValue(idx), item.stmt)
        else
            analyze_control_flow!(tracker, item.stmt)
        end
    end
    return
end
130+
131+
# Descend into the regions owned by a control flow op, one method per op kind.

# IfOp: the `then` region first, then the `else` region.
analyze_control_flow!(tracker::AliasTracker, op::IfOp) =
    (analyze_block!(tracker, op.then_region); analyze_block!(tracker, op.else_region))

# ForOp: the single loop body.
analyze_control_flow!(tracker::AliasTracker, op::ForOp) =
    analyze_block!(tracker, op.body)

# WhileOp: the `before` region first, then the `after` region.
analyze_control_flow!(tracker::AliasTracker, op::WhileOp) =
    (analyze_block!(tracker, op.before); analyze_block!(tracker, op.after))

# LoopOp: the single loop body.
analyze_control_flow!(tracker::AliasTracker, op::LoopOp) =
    analyze_block!(tracker, op.body)

# Any other control flow op: nothing to analyze.
analyze_control_flow!(::AliasTracker, ::ControlFlowOp) = nothing
154+
"""
    analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)

Analyze a single statement and record the alias set of its result `ssa`.

Calls are classified by their *resolved* callee (see `resolve_call`), so the
same rules apply to `:call` and `:invoke` forms and to any `GlobalRef` that
resolves to the same function:

- `getfield(x, :ptr)`: `ssa` inherits the alias set of `x`; any other field
  access yields `ALIAS_UNIVERSE`.
- `+` / `-`: `ssa` inherits from the first operand that has a known
  (non-universe) alias set; if there is none, `ssa` is left without an entry
  and therefore reads as `ALIAS_UNIVERSE`.
- view constructors and pointer passthroughs: `ssa` inherits from the first
  operand.
- every other call: `ALIAS_UNIVERSE`.

A `ReturnNode` records nothing. Any other statement — including calls that
`resolve_call` cannot resolve — yields `ALIAS_UNIVERSE`.
"""
function analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
    call = resolve_call(stmt)

    if call === nothing
        # ReturnNode produces no value to track; everything else is treated
        # conservatively.
        if !(stmt isa ReturnNode)
            tracker[ssa] = ALIAS_UNIVERSE
        end
        return
    end

    resolved_func, operands = call

    # getfield: propagate from parent
    if resolved_func === Core.getfield && length(operands) >= 1
        field = length(operands) >= 2 ? operands[2] : nothing

        # For TileArray.ptr field access, propagate pointer alias
        if field isa QuoteNode && field.value === :ptr
            propagate!(tracker, operands[1], ssa)
        else
            # Conservatively mark as UNIVERSE for non-pointer fields
            tracker[ssa] = ALIAS_UNIVERSE
        end

    # Pointer arithmetic: propagate from pointer operand
    elseif resolved_func === Base.:+ || resolved_func === Base.:-
        for arg in operands
            # Find the pointer argument (first operand with a known alias
            # set) and propagate from it only.
            arg_aliases = tracker[arg]
            if arg_aliases !== ALIAS_UNIVERSE && arg_aliases isa Set
                propagate!(tracker, arg, ssa)
                break
            end
        end

    # View construction / passthrough: propagate alias from first operand
    elseif is_view_constructor(resolved_func) || is_pointer_passthrough(resolved_func)
        if length(operands) >= 1
            propagate!(tracker, operands[1], ssa)
        end

    # Default: unknown operation -> UNIVERSE
    else
        tracker[ssa] = ALIAS_UNIVERSE
    end
    return
end
212+
213+
# Helper functions
214+
# Whether type `T` is a `Ptr`, a `TileArray`, or a `Tile` whose element type
# is a `Ptr`. The checks run in that order and stop at the first match.
function contains_pointers(T)
    T <: Ptr && return true
    T <: TileArray && return true
    return T <: Tile && eltype(T) <: Ptr
end
215+
"""
    is_view_constructor(func) -> Bool

Return `true` when `func` is `Intrinsics.make_tensor_view` or
`Intrinsics.make_partition_view`. The alias analysis treats these as
propagating alias identity from their first operand.
"""
function is_view_constructor(func)
    func === Intrinsics.make_tensor_view && return true
    return func === Intrinsics.make_partition_view
end
226+
227+
# Whether `func` is treated as forwarding its pointer operand unchanged.
# Accepts the unresolved `GlobalRef` to `Core.Intrinsics.bitcast`, or any
# intrinsic/function whose name is one of `bitcast`, `assume_div_by`,
# `assume_bounded`, `assume_aligned`.
function is_pointer_passthrough(func)
    if func === GlobalRef(Core.Intrinsics, :bitcast)
        return true
    elseif !(func isa Core.IntrinsicFunction || func isa Function)
        return false
    end

    # Match by name rather than by binding, to avoid an UndefVarError if the
    # intrinsics aren't exposed.
    return nameof(func) in (:bitcast, :assume_div_by, :assume_bounded, :assume_aligned)
end

0 commit comments

Comments
 (0)