Skip to content

Commit b0658df

Browse files
maleadtclaude
andcommitted
Rewrite LICM to hoist all loop-invariant ops after token ordering.
The previous LICM only targeted loads and ran before token ordering, but failed to hoist anything because load dependencies (make_partition_view, Core.tuple) were always generated inline inside the loop body. The new approach mirrors cuTile Python's code_motion.py: run after token_order_pass! and hoist ALL loop-invariant operations based on data dependencies. Token dependencies naturally prevent unsafe hoisting of loads that alias with stores — no separate alias analysis needed for LICM. This correctly hoists loop-invariant loads and their entire dependency chain (tensor_view → partition_view → load → reshape → broadcast). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bfe76a3 commit b0658df

3 files changed

Lines changed: 199 additions & 122 deletions

File tree

src/compiler/passes/licm.jl

Lines changed: 165 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,100 +1,134 @@
11
# Loop-Invariant Code Motion (LICM)
22
#
3-
# Hoists loop-invariant loads (and their dependency chains) out of loops.
3+
# Hoists loop-invariant operations out of loops. Runs AFTER token_order_pass!
4+
# so that token dependencies correctly prevent unsafe hoisting of aliasing loads.
45
#
5-
# Pure operations (arithmetic, broadcasts, view constructors) are NOT hoisted
6-
# here — MLIR's built-in LICM handles those at optLevel >= 2.
6+
# Operations classified as stores (store_partition_view, store_ptr_tko, atomics,
7+
# print_tko) and control flow exits (return) are never hoisted. All other
8+
# operations — including loads, arithmetic, partition views, token nodes — are
9+
# hoisted when all their data dependencies are defined outside the loop.
710
#
8-
# This pass targets what MLIR cannot hoist: memory loads. After token ordering,
9-
# loads have token dependencies that anchor them inside loops. By hoisting
10-
# before token insertion, we avoid creating unnecessary token carries.
11-
#
12-
# Safety: a load is hoistable only when (1) all its operands are loop-invariant,
13-
# and (2) no store in the loop body aliases with the load's memory region.
14-
# Alias information comes from alias_analysis_pass!, which must run first.
11+
# This mirrors cuTile Python's code_motion.py:hoist_loop_invariants.
12+
13+
# Indicates whether a block could in theory be moved, based on the operations
14+
# it contains (side effects, jumps). Does not consider data dependencies.
15+
@enum _BlockMobility::Int8 begin
16+
# The block (or any ancestor) cannot be moved due to side effects.
17+
_IMMOVABLE = 0
18+
# The block itself can't be hoisted alone, but its containing loop can.
19+
# Happens when the block contains Continue or Break.
20+
_CAN_MOVE_WITH_LOOP = 1
21+
# The block can move (subject to data dependencies).
22+
_CAN_MOVE = 2
23+
end
1524

16-
"""
17-
licm_pass!(sci::StructuredIRCode, alias_result::Dict{Any, AliasSet})
25+
struct _BlockResult
26+
mobility::_BlockMobility
27+
min_depth::Int # deepest outside dependency of any hoisted-out op
28+
end
1829

19-
Hoist loop-invariant loads out of loops. Must run after alias_analysis_pass!
20-
and before token_order_pass!.
30+
# Helper for accumulating data dependency information per operation.
31+
mutable struct _DepInfo
32+
must_stay::Bool
33+
max_outside_depth::Int
34+
end
2135

22-
A load is hoistable when:
23-
- All operands are defined outside the loop
24-
- No store in the loop body writes to an aliasing memory region
25-
"""
26-
function licm_pass!(sci::StructuredIRCode, alias_result::Dict{Any, AliasSet})
27-
def_depth = Dict{Any, Int}()
28-
for i in 1:length(sci.argtypes)
29-
def_depth[Argument(i)] = 0
36+
function _update_dep!(di::_DepInfo, dep_depth::Int, cur_depth::Int)
37+
if dep_depth >= cur_depth
38+
di.must_stay = true
39+
else
40+
di.max_outside_depth = max(di.max_outside_depth, dep_depth)
3041
end
31-
_hoist_loads!(sci.entry, Vector{Vector{Tuple{Int,Any,Any}}}(), def_depth,
32-
alias_result, false)
33-
return
3442
end
3543

36-
# Collect alias sets of all stores in a block (recursively through nested CFs).
37-
function _collect_store_aliases(block::Block, alias_result::Dict{Any, AliasSet})
38-
store_aliases = AliasSet[]
39-
for inst in instructions(block)
40-
s = stmt(inst)
41-
if s isa ControlFlowOp
42-
for b in blocks(s)
43-
append!(store_aliases, _collect_store_aliases(b, alias_result))
44-
end
45-
else
46-
call = resolve_call(block, s)
47-
call === nothing && continue
48-
resolved_func, operands = call
49-
if classify_memory_op(resolved_func) == MEM_STORE
50-
aset = get_alias_set_for_operand(alias_result, first(operands))
51-
push!(store_aliases, aset)
52-
end
53-
end
54-
end
55-
return store_aliases
44+
# Update dependency info from an SSA value or literal.
45+
function _check_val!(di::_DepInfo, val, def_depth::Dict{Any,Int}, cur_depth::Int)
46+
d = get(def_depth, val, nothing)
47+
d === nothing && return # constants/literals always available
48+
_update_dep!(di, d, cur_depth)
5649
end
5750

58-
# Check if a load's alias set conflicts with any store alias set in the loop.
59-
function _aliases_with_store(load_alias::AliasSet, store_aliases::Vector{AliasSet})
60-
for sa in store_aliases
61-
if load_alias isa AliasUniverse || sa isa AliasUniverse
62-
return true
51+
# Extract all SSA dependencies from a statement.
52+
function _check_stmt_deps!(di::_DepInfo, @nospecialize(s), def_depth::Dict{Any,Int},
53+
cur_depth::Int)
54+
if s isa Expr
55+
start = s.head === :invoke ? 3 : 2
56+
for i in start:length(s.args)
57+
_check_val!(di, s.args[i], def_depth, cur_depth)
6358
end
64-
if !isempty(intersect(load_alias, sa))
65-
return true
59+
elseif s isa JoinTokensNode
60+
for tok in s.tokens
61+
_check_val!(di, tok, def_depth, cur_depth)
6662
end
63+
elseif s isa TokenResultNode
64+
_check_val!(di, SSAValue(s.mem_op_ssa), def_depth, cur_depth)
65+
end
66+
# MakeTokenNode, PiNode, GlobalRef, literals: no SSA deps
67+
end
68+
69+
struct _StackItem
70+
entries::Vector{Tuple{Int,Any,Any}} # (ssa_idx, stmt, type)
71+
is_loop_body::Bool
72+
end
73+
74+
"""
75+
licm_pass!(sci::StructuredIRCode)
76+
77+
Hoist loop-invariant operations out of loops. Must run after token_order_pass!.
78+
"""
79+
function licm_pass!(sci::StructuredIRCode)
80+
def_depth = Dict{Any,Int}()
81+
for i in 1:length(sci.argtypes)
82+
def_depth[Argument(i)] = 0
6783
end
68-
return false
84+
_hoist!(sci.entry, _StackItem[], def_depth, false)
85+
return
6986
end
7087

71-
function _hoist_loads!(block::Block, stack::Vector{Vector{Tuple{Int,Any,Any}}},
72-
def_depth::Dict{Any,Int}, alias_result::Dict{Any, AliasSet},
73-
is_loop_body::Bool)
88+
function _hoist!(block::Block, stack::Vector{_StackItem}, def_depth::Dict{Any,Int},
89+
is_loop_body::Bool)
7490
depth = length(stack)
75-
push!(stack, Tuple{Int,Any,Any}[])
91+
push!(stack, _StackItem(Tuple{Int,Any,Any}[], is_loop_body))
7692

77-
# Register block args at current depth
7893
for ba in block.args
7994
def_depth[ba] = depth
8095
end
8196

82-
# If this is a loop body, collect store alias sets for the load safety check
83-
store_aliases = is_loop_body ? _collect_store_aliases(block, alias_result) : AliasSet[]
97+
mobility = _CAN_MOVE
98+
min_depth = 0
8499

85100
for inst in instructions(block)
86101
s = stmt(inst)
87-
hoisted = false
102+
di = _DepInfo(!is_loop_body, 0)
88103

89-
if s isa ForOp || s isa LoopOp
90-
body = s.body
91-
if s isa ForOp
92-
def_depth[s.iv_arg] = depth + 1
104+
if s isa ForOp
105+
def_depth[s.iv_arg] = depth + 1
106+
for ba in s.body.args
107+
def_depth[ba] = depth + 1
108+
end
109+
body_res = _hoist!(s.body, stack, def_depth, true)
110+
if body_res.mobility == _IMMOVABLE
111+
mobility = _IMMOVABLE
112+
di.must_stay = true
113+
end
114+
for v in s.init_values
115+
_check_val!(di, v, def_depth, depth)
93116
end
94-
for ba in body.args
117+
_update_dep!(di, body_res.min_depth, depth)
118+
119+
elseif s isa LoopOp
120+
for ba in s.body.args
95121
def_depth[ba] = depth + 1
96122
end
97-
_hoist_loads!(body, stack, def_depth, alias_result, true)
123+
body_res = _hoist!(s.body, stack, def_depth, true)
124+
if body_res.mobility == _IMMOVABLE
125+
mobility = _IMMOVABLE
126+
di.must_stay = true
127+
end
128+
for v in s.init_values
129+
_check_val!(di, v, def_depth, depth)
130+
end
131+
_update_dep!(di, body_res.min_depth, depth)
98132

99133
elseif s isa WhileOp
100134
for ba in s.before.args
@@ -103,75 +137,86 @@ function _hoist_loads!(block::Block, stack::Vector{Vector{Tuple{Int,Any,Any}}},
103137
for ba in s.after.args
104138
def_depth[ba] = depth + 1
105139
end
106-
_hoist_loads!(s.before, stack, def_depth, alias_result, true)
107-
_hoist_loads!(s.after, stack, def_depth, alias_result, true)
140+
before_res = _hoist!(s.before, stack, def_depth, true)
141+
after_res = _hoist!(s.after, stack, def_depth, true)
142+
if min(before_res.mobility, after_res.mobility) == _IMMOVABLE
143+
mobility = _IMMOVABLE
144+
di.must_stay = true
145+
end
146+
for v in s.init_values
147+
_check_val!(di, v, def_depth, depth)
148+
end
149+
_update_dep!(di, before_res.min_depth, depth)
150+
_update_dep!(di, after_res.min_depth, depth)
108151

109152
elseif s isa IfOp
110-
_hoist_loads!(s.then_region, stack, def_depth, alias_result, false)
111-
_hoist_loads!(s.else_region, stack, def_depth, alias_result, false)
112-
113-
elseif is_loop_body && _is_hoistable_load(block, s, def_depth, depth,
114-
alias_result, store_aliases)
115-
# Hoist this load to the enclosing scope
116-
target_depth = depth - 1
117-
while target_depth > 0 && _can_hoist_to(stack, target_depth)
118-
target_depth -= 1
153+
_check_val!(di, s.condition, def_depth, depth)
154+
for region in (s.then_region, s.else_region)
155+
branch_res = _hoist!(region, stack, def_depth, false)
156+
_update_dep!(di, branch_res.min_depth, depth)
157+
if branch_res.mobility != _CAN_MOVE
158+
mobility = min(mobility, branch_res.mobility)
159+
di.must_stay = true
160+
end
119161
end
120-
push!(stack[target_depth + 1], (inst.ssa_idx, s, inst.typ))
121-
def_depth[SSAValue(inst.ssa_idx)] = target_depth
122-
hoisted = true
162+
163+
elseif _is_store(block, s)
164+
mobility = _IMMOVABLE
165+
di.must_stay = true
166+
167+
elseif s isa ContinueOp || s isa BreakOp
168+
mobility = min(mobility, _CAN_MOVE_WITH_LOOP)
169+
di.must_stay = true
170+
171+
elseif s isa YieldOp || s isa ConditionOp || s isa ReturnNode
172+
di.must_stay = true
173+
# Track deps for YieldOp/ConditionOp so min_depth is correct
174+
if s isa YieldOp
175+
for v in s.values
176+
_check_val!(di, v, def_depth, depth)
177+
end
178+
elseif s isa ConditionOp
179+
_check_val!(di, s.condition, def_depth, depth)
180+
for v in s.args
181+
_check_val!(di, v, def_depth, depth)
182+
end
183+
end
184+
185+
else
186+
# Movable operation: loads, arithmetic, make_partition_view, etc.
187+
_check_stmt_deps!(di, s, def_depth, depth)
123188
end
124189

125-
if !hoisted
126-
# Keep at current depth
127-
push!(stack[depth + 1], (inst.ssa_idx, s, inst.typ))
128-
def_depth[SSAValue(inst.ssa_idx)] = depth
190+
# Determine target depth
191+
target_depth = depth
192+
if di.must_stay
193+
min_depth = max(min_depth, di.max_outside_depth)
194+
else
195+
while target_depth > di.max_outside_depth && stack[target_depth].is_loop_body
196+
target_depth -= 1
197+
end
129198
end
199+
200+
push!(stack[target_depth + 1].entries, (inst.ssa_idx, s, inst.typ))
201+
202+
# Record definition depth AFTER hoisting so subsequent ops see the new depth
203+
def_depth[SSAValue(inst.ssa_idx)] = target_depth
130204
end
131205

132206
# Rebuild block body from collected entries
133-
entries = pop!(stack)
207+
entries = pop!(stack).entries
134208
empty!(block)
135209
for (idx, s, typ) in entries
136210
push!(block, idx, s, typ)
137211
end
138-
end
139212

140-
# Check if a stack entry is a loop body (for multi-level hoisting)
141-
function _can_hoist_to(stack::Vector{Vector{Tuple{Int,Any,Any}}}, target_depth::Int)
142-
# We'd need to track is_loop_body per stack entry to do multi-level hoisting.
143-
# For now, only hoist one level out.
144-
return false
213+
return _BlockResult(mobility, min_depth)
145214
end
146215

147-
# Check if a statement is a load that can be safely hoisted.
148-
function _is_hoistable_load(block::Block, @nospecialize(s), def_depth::Dict{Any,Int},
149-
cur_depth::Int, alias_result::Dict{Any, AliasSet},
150-
store_aliases::Vector{AliasSet})
151-
s isa Expr || return false
216+
# Check if a statement is a store/atomic (side-effecting memory write).
217+
function _is_store(block::Block, @nospecialize(s))
152218
call = resolve_call(block, s)
153219
call === nothing && return false
154-
resolved_func, operands = call
155-
156-
# Must be a load operation
157-
classify_memory_op(resolved_func) == MEM_LOAD || return false
158-
159-
# All operands must be defined outside this loop
160-
_all_operands_outside(s, def_depth, cur_depth) || return false
161-
162-
# Load must not alias with any store in the loop
163-
load_alias = get_alias_set_for_operand(alias_result, first(operands))
164-
return !_aliases_with_store(load_alias, store_aliases)
165-
end
166-
167-
# Check that all operands of a statement are defined at depth < cur_depth.
168-
function _all_operands_outside(@nospecialize(s), def_depth::Dict{Any,Int}, cur_depth::Int)
169-
s isa Expr || return true
170-
start = s.head === :invoke ? 3 : 2
171-
for i in start:length(s.args)
172-
d = get(def_depth, s.args[i], nothing)
173-
d === nothing && continue # constants/literals are always available
174-
d >= cur_depth && return false
175-
end
176-
return true
220+
resolved_func, _ = call
221+
return classify_memory_op(resolved_func) == MEM_STORE
177222
end

src/compiler/passes/pipeline.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,9 @@ function run_passes!(sci::StructuredIRCode)
329329

330330
alias_result = alias_analysis_pass!(sci)
331331

332-
licm_pass!(sci, alias_result)
333-
334332
token_order_pass!(sci, alias_result)
335333

334+
licm_pass!(sci)
335+
336336
dce_pass!(sci)
337337
end

test/codegen/integration.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,38 @@ end
420420
end
421421
end
422422

423+
@testset "loop-invariant load (manually hoisted)" begin
424+
# Test that a manually-hoisted loop-invariant load appears before the loop.
425+
# Pattern: Y[n, m] = X[n, m] * W[m], iterating over N-tiles.
426+
# W[bid_m] doesn't depend on the loop variable, so the user hoists it.
427+
spec2d = ct.ArraySpec{2}(16, true)
428+
spec1d = ct.ArraySpec{1}(16, true)
429+
@test @filecheck begin
430+
@check_label "entry"
431+
# W load must appear BEFORE the for loop
432+
@check "load_view_tko"
433+
@check "for %loopIdx in"
434+
# Inside the loop: only the X load
435+
@check "load_view_tko"
436+
@check "mulf"
437+
@check "store_view_tko"
438+
code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,1,spec1d},
439+
ct.TileArray{Float32,2,spec2d}, ct.Constant{Int,1024}}) do X, W, Y, TILE_N
440+
bid_m = ct.bid(1)
441+
num_tiles = ct.num_tiles(X, 1, (TILE_N, 1))
442+
# Hoisted: W load before loop
443+
w = ct.load(W; index=bid_m, shape=(1,))
444+
for j in Int32(1):num_tiles
445+
x = ct.load(X; index=(j, bid_m), shape=(TILE_N, 1),
446+
padding_mode=ct.PaddingMode.Zero)
447+
y = x .* w
448+
ct.store(Y; index=(j, bid_m), tile=y)
449+
end
450+
return
451+
end
452+
end
453+
end
454+
423455
#=========================================================================
424456
Gather/Scatter Operations
425457
=========================================================================#

0 commit comments

Comments
 (0)