Skip to content

Commit ba4f9a7

Browse files
committed
Add alias-aware token threading for memory operations.
Introduce alias analysis–based token threading: - Group pointers into alias sets. - Maintain per-alias-set token chains. - Thread tokens only between potentially aliasing operations. - Conservatively fall back to the global set for unknown pointers. - Preserve existing control-flow token merging semantics. Enables independent memory operations to execute without unnecessary serialization.
1 parent 473b7d5 commit ba4f9a7

9 files changed

Lines changed: 562 additions & 17 deletions

File tree

src/compiler/codegen.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Codegen: Julia IR -> Tile IR bytecode
22

33
include("codegen/utils.jl")
4+
include("codegen/token_keys.jl") # Defines TokenKey, TokenRole, ACQUIRE_TOKEN_KEY
5+
include("codegen/alias_analysis.jl") # Defines alias_analysis_pass!
6+
include("codegen/token_order.jl") # Defines get_alias_set, get_input_token!
47
include("codegen/kernel.jl")
58
include("codegen/control_flow.jl")
69
include("codegen/statements.jl")
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""
2+
AliasTracker
3+
4+
Tracks alias sets for each SSA value during fixed-point analysis.
5+
"""
6+
mutable struct AliasTracker
7+
dirty::Bool
8+
aliases::Dict{Any, AliasSet} # SSAValue/Argument/SlotNumber -> AliasSet
9+
end
10+
11+
AliasTracker() = AliasTracker(false, Dict{Any, AliasSet}())
12+
13+
function Base.getindex(tracker::AliasTracker, key)
14+
return get(tracker.aliases, key, ALIAS_UNIVERSE)
15+
end
16+
17+
function Base.setindex!(tracker::AliasTracker, value::AliasSet, key)
18+
current = get(tracker.aliases, key, nothing)
19+
if current !== value
20+
tracker.dirty = true
21+
tracker.aliases[key] = value
22+
end
23+
return
24+
end
25+
26+
"""
27+
alias_analysis_pass!(sci::StructuredIRCode) -> Dict{Any, AliasSet}
28+
29+
Perform fixed-point alias analysis on structured IR.
30+
Returns mapping from SSA values to alias sets.
31+
"""
32+
function alias_analysis_pass!(sci::StructuredIRCode)
33+
tracker = AliasTracker()
34+
35+
# Initialize: each argument gets its own alias set
36+
for (idx, argtype) in enumerate(sci.argtypes)
37+
argtype_unwrapped = CC.widenconst(argtype)
38+
if contains_pointers(argtype_unwrapped)
39+
arg_ref = Argument(idx)
40+
tracker[arg_ref] = Set{Any}([arg_ref])
41+
end
42+
end
43+
44+
# Fixed-point iteration
45+
iteration = 0
46+
max_iterations = 100
47+
48+
tracker.dirty = true
49+
while tracker.dirty && iteration < max_iterations
50+
tracker.dirty = false
51+
iteration += 1
52+
53+
analyze_block!(tracker, sci.entry)
54+
end
55+
56+
@debug "Alias analysis converged in $iteration iterations"
57+
58+
return tracker.aliases
59+
end
60+
61+
"""
62+
propagate!(tracker::AliasTracker, from, to)
63+
64+
Propagate alias set from `from` to `to`.
65+
Uses direct assignment when `to` is uninitialized, union otherwise.
66+
"""
67+
function propagate!(tracker::AliasTracker, from, to)
68+
from_aliases = tracker[from]
69+
70+
if from_aliases === ALIAS_UNIVERSE
71+
# Propagating UNIVERSE is always conservative
72+
tracker[to] = ALIAS_UNIVERSE
73+
return
74+
end
75+
76+
if haskey(tracker.aliases, to)
77+
# Target already has an alias set union with it
78+
to_aliases = tracker.aliases[to]
79+
new_aliases = union(from_aliases, to_aliases)
80+
if new_aliases != to_aliases
81+
tracker[to] = new_aliases
82+
end
83+
else
84+
# Target not yet analyzed assign directly
85+
tracker[to] = from_aliases
86+
end
87+
return
88+
end
89+
90+
"""
91+
analyze_block!(tracker::AliasTracker, block)
92+
93+
Analyze all statements in a block, recursing into nested control flow.
94+
"""
95+
function analyze_block!(tracker::AliasTracker, block)
96+
for (ssa_idx, entry) in block.body
97+
if entry.stmt isa ControlFlowOp
98+
analyze_control_flow!(tracker, entry.stmt)
99+
else
100+
analyze_statement!(tracker, SSAValue(ssa_idx), entry.stmt)
101+
end
102+
end
103+
return
104+
end
105+
106+
# Recurse into nested control flow regions
107+
function analyze_control_flow!(tracker::AliasTracker, op::IfOp)
108+
analyze_block!(tracker, op.then_region)
109+
return analyze_block!(tracker, op.else_region)
110+
end
111+
112+
function analyze_control_flow!(tracker::AliasTracker, op::ForOp)
113+
return analyze_block!(tracker, op.body)
114+
end
115+
116+
function analyze_control_flow!(tracker::AliasTracker, op::WhileOp)
117+
analyze_block!(tracker, op.before)
118+
return analyze_block!(tracker, op.after)
119+
end
120+
121+
function analyze_control_flow!(tracker::AliasTracker, op::LoopOp)
122+
return analyze_block!(tracker, op.body)
123+
end
124+
125+
# Fallback for unknown control flow ops
126+
function analyze_control_flow!(::AliasTracker, ::ControlFlowOp)
127+
return
128+
end
129+
130+
"""
131+
analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
132+
133+
Analyze a single statement and propagate aliases.
134+
Handles both `:call` and `:invoke` expression forms.
135+
"""
136+
function analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
137+
if stmt isa Expr && (stmt.head === :call || stmt.head === :invoke)
138+
# Normalize :call and :invoke into (func, operands)
139+
# :call -> args = [func, operands...]
140+
# :invoke -> args = [MethodInstance, func, operands...]
141+
if stmt.head === :call
142+
func = stmt.args[1]
143+
operands = @view stmt.args[2:end]
144+
else # :invoke
145+
func = stmt.args[2]
146+
operands = @view stmt.args[3:end]
147+
end
148+
149+
# Resolve func to its runtime value for intrinsic matching.
150+
# In :invoke, func may already be the function object (not a GlobalRef).
151+
resolved_func = if func isa GlobalRef
152+
try
153+
getfield(func.mod, func.name)
154+
catch
155+
nothing
156+
end
157+
else
158+
func # Direct function value (common in :invoke)
159+
end
160+
161+
# getfield: propagate from parent
162+
if func === GlobalRef(Core, :getfield) && length(operands) >= 1
163+
field = length(operands) >= 2 ? operands[2] : nothing
164+
165+
# For TileArray.ptr field access, propagate pointer alias
166+
if field isa QuoteNode && field.value === :ptr
167+
propagate!(tracker, operands[1], ssa)
168+
else
169+
# Conservatively mark as UNIVERSE for non-pointer fields
170+
tracker[ssa] = ALIAS_UNIVERSE
171+
end
172+
173+
# Pointer arithmetic: propagate from pointer operand
174+
elseif func === GlobalRef(Base, :+) || func === GlobalRef(Base, :-)
175+
for arg in operands
176+
# Find the pointer argument and propagate
177+
arg_aliases = tracker[arg]
178+
if arg_aliases !== ALIAS_UNIVERSE && arg_aliases isa Set
179+
propagate!(tracker, arg, ssa)
180+
break
181+
end
182+
end
183+
184+
# View construction: propagate alias from first operand
185+
elseif is_view_constructor(resolved_func)
186+
if length(operands) >= 1
187+
propagate!(tracker, operands[1], ssa)
188+
end
189+
190+
# Default: unknown operation -> UNIVERSE
191+
else
192+
tracker[ssa] = ALIAS_UNIVERSE
193+
end
194+
195+
elseif stmt isa ReturnNode
196+
# No alias propagation needed
197+
198+
else
199+
# Unknown statement type -> conservative
200+
tracker[ssa] = ALIAS_UNIVERSE
201+
end
202+
return
203+
end
204+
205+
# Helper functions
206+
contains_pointers(T) = T <: Ptr || T <: TileArray || (T <: Tile && eltype(T) <: Ptr)
207+
208+
"""
209+
is_view_constructor(func) -> Bool
210+
211+
Check if a resolved function is a tensor/partition view constructor.
212+
These propagate alias identity from their first operand.
213+
"""
214+
function is_view_constructor(func)
215+
return func === Intrinsics.make_tensor_view ||
216+
func === Intrinsics.make_partition_view
217+
end

src/compiler/codegen/control_flow.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,14 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
8888
# Save token before branches
8989
token_before = ctx.token
9090

91+
# Save token_map before branches
92+
token_map_before = copy(ctx.token_map)
93+
9194
# Emit IfOp with callback-based region building
9295
then_body = function(_)
9396
saved_block_args = copy(ctx.block_args)
9497
ctx.token = token_before # Reset to pre-branch token
98+
ctx.token_map = copy(token_map_before) # Reset token_map too
9599
emit_block!(ctx, then_blk)
96100
if then_blk.terminator === nothing
97101
encode_YieldOp!(ctx.cb, [ctx.token])
@@ -102,6 +106,7 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
102106
else_body = function(_)
103107
saved_block_args = copy(ctx.block_args)
104108
ctx.token = token_before # Reset to pre-branch token
109+
ctx.token_map = copy(token_map_before) # Reset token_map too
105110
emit_block!(ctx, else_blk)
106111
if else_blk.terminator === nothing
107112
encode_YieldOp!(ctx.cb, [ctx.token])
@@ -114,6 +119,12 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
114119
# Last result is the merged token from both branches
115120
ctx.token = results[end]
116121

122+
# Merge token_map from both branches
123+
# Conservatively reset to token_before for all keys
124+
for key in keys(ctx.token_map)
125+
ctx.token_map[key] = results[end]
126+
end
127+
117128
# Store results at IfOp's SSA index (may be empty for void-returning ifs)
118129
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
119130
end
@@ -164,6 +175,9 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
164175
# Number of user result types (excluding token)
165176
n_user_results = n_carries
166177

178+
# Save token_map before loop
179+
token_map_before = copy(ctx.token_map)
180+
167181
# Emit ForOp with callback-based region building
168182
body_builder = function(block_args)
169183
saved_block_args = copy(ctx.block_args)
@@ -196,6 +210,12 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
196210
# Last result is the token
197211
ctx.token = results[end]
198212

213+
# Update token_map after loop
214+
# Conservatively update all keys to the merged token
215+
for key in keys(token_map_before)
216+
ctx.token_map[key] = results[end]
217+
end
218+
199219
# Store results at the loop's SSA index (may be empty for void-returning loops)
200220
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
201221
end
@@ -230,6 +250,9 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
230250
# Number of user result types (excluding token)
231251
n_user_results = n_carries
232252

253+
# Save token_map before loop
254+
token_map_before = copy(ctx.token_map)
255+
233256
# Emit LoopOp with callback-based region building
234257
body_builder = function(block_args)
235258
saved_block_args = copy(ctx.block_args)
@@ -266,6 +289,12 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
266289
# Last result is the token
267290
ctx.token = results[end]
268291

292+
# Update token_map after loop
293+
# Conservatively update all keys to the merged token
294+
for key in keys(token_map_before)
295+
ctx.token_map[key] = results[end]
296+
end
297+
269298
# Store results at the loop's SSA index (may be empty for void-returning loops)
270299
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
271300
end
@@ -301,6 +330,9 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
301330
# Number of user result types (excluding token)
302331
n_user_results = n_carries
303332

333+
# Save token_map before loop
334+
token_map_before = copy(ctx.token_map)
335+
304336
# Emit WhileOp as cuda_tile.loop with conditional break pattern
305337
# MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals }
306338
# Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue }
@@ -396,6 +428,12 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
396428
# Last result is the token
397429
ctx.token = results[end]
398430

431+
# Update token_map after loop
432+
# Conservatively update all keys to the merged token
433+
for key in keys(token_map_before)
434+
ctx.token_map[key] = results[end]
435+
end
436+
399437
# Store results at the loop's SSA index (may be empty for void-returning loops)
400438
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
401439
end

src/compiler/codegen/kernel.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,30 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
152152
cache_tensor_view!(ctx, arg_idx)
153153
end
154154

155+
# Run alias analysis FIRST
156+
alias_result = alias_analysis_pass!(sci)
157+
ctx.alias_result = alias_result
158+
155159
# Create memory ordering token
156160
token_type = Token(tt)
157161
ctx.token_type = token_type
158-
ctx.token = encode_MakeTokenOp!(cb, token_type)
162+
root_token = encode_MakeTokenOp!(cb, token_type)
163+
164+
ctx.global_token = root_token
165+
ctx.token = root_token
166+
167+
# Initialize token map with root token for all alias sets
168+
# Default: all tokens start at root
169+
ctx.token_map = Dict{TokenKey, Value}()
170+
171+
unique_alias_sets = Set(values(alias_result))
172+
for alias_set in unique_alias_sets
173+
ctx.token_map[last_op_key(alias_set)] = root_token
174+
ctx.token_map[last_store_key(alias_set)] = root_token
175+
end
176+
177+
# ACQUIRE token also starts at root
178+
ctx.token_map[ACQUIRE_TOKEN_KEY] = root_token
159179

160180
# Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp)
161181
hoist_returns!(ctx.sci.entry)

0 commit comments

Comments
 (0)