Skip to content

Commit cde62c4

Browse files
committed
Add LICM pass.
1 parent a873277 commit cde62c4

3 files changed

Lines changed: 199 additions & 1 deletion

File tree

src/compiler/passes/licm.jl

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# Loop-Invariant Code Motion (LICM)
2+
#
3+
# Single-pass depth-tracking algorithm that hoists loop-invariant operations
4+
# out of loops. Port of cuTile Python's `hoist_loop_invariants` (code_motion.py).
5+
#
6+
# The algorithm walks the IR recursively while tracking the definition depth
7+
# of each value. An operation whose data dependencies all resolve to depths
8+
# *less than* its containing loop can be hoisted above that loop. A stack of
9+
# SSAMaps collects operations at their target depth; at the end of each block,
10+
# the original body is replaced with the (filtered) rebuilt map.
11+
12+
"""
    BlockMobility

Whether a block can be moved, based on the operations it contains.

The members are declared from most to least restrictive, so taking `min` of two
mobilities yields the more restrictive one.
"""
@enum BlockMobility begin
    IMMOVABLE = 0           # contains stores, returns, or nested IMMOVABLE blocks
    CAN_MOVE_WITH_LOOP = 1  # contains continue/break
    CAN_MOVE = 2            # pure operations only
end
18+
19+
"""
    BlockResult

Summary returned by `_hoist!` after it has processed one block.
"""
struct BlockResult
    # How freely the operation that owns this block may be relocated.
    mobility::BlockMobility
    # Minimum depth any op in this block needs: the deepest outside definition
    # that an operation remaining in the block still depends on.
    min_depth::Int
end
23+
24+
"""
    DependencyInfo

Mutable accumulator describing where the dependencies of a single operation live,
relative to the depth of the block that currently contains the operation.
"""
mutable struct DependencyInfo
    # `true` once the operation is known to be pinned to its current block.
    must_stay::Bool
    # Deepest definition depth seen among dependencies defined outside that block.
    max_outside_depth::Int
end
28+
29+
"""
    update!(di::DependencyInfo, dep_depth::Int, cur_depth::Int)

Record one dependency defined at `dep_depth` for an operation that currently sits
at `cur_depth`.

A dependency defined strictly outside the current depth raises
`di.max_outside_depth`; any other dependency pins the operation by setting
`di.must_stay`.
"""
function update!(di::DependencyInfo, dep_depth::Int, cur_depth::Int)
    if dep_depth < cur_depth
        # Defined in an enclosing block: remember the deepest such definition.
        di.max_outside_depth = max(di.max_outside_depth, dep_depth)
    else
        # Defined at (or below) the current depth: the operation cannot leave.
        di.must_stay = true
    end
end
36+
37+
"""
    StackItem

One entry of the block stack maintained by `_hoist!`; the entry at index `d + 1`
belongs to the block being rebuilt at depth `d`.
"""
struct StackItem
    # Rebuilt body that collects every operation placed at this depth.
    new_body::SSAMap
    # Whether the block is a loop body; hoisting only crosses loop bodies.
    is_loop_body::Bool
end
41+
42+
"""
    licm_pass!(sci::StructuredIRCode)

Hoist loop-invariant operations out of loops, rewriting the block bodies of `sci`
in place. Returns `nothing`.

Must run after rewrite_patterns! and before token_order_pass! (which inserts
tokens that should not be moved).
"""
function licm_pass!(sci::StructuredIRCode)
    # Function arguments are defined outside every block, i.e. at depth 0.
    def_depth = Dict{Any, Int}(Argument(i) => 0 for i in eachindex(sci.argtypes))
    block_stack = StackItem[]
    _hoist!(sci.entry, block_stack, def_depth, false)
    return nothing
end
56+
57+
# Fold the operands of a block terminator into `di`.
#
# The terminator's fields are scanned by reflection so that no particular field
# layout is assumed for the terminator types: fields holding a vector or tuple
# are treated as operand lists, every other field as a single operand. Values
# without a recorded definition depth are ignored by `_update_from_value!`.
function _update_from_terminator!(di::DependencyInfo, def_depth::Dict{Any,Int},
                                  @nospecialize(term), cur_depth::Int)
    term === nothing && return nothing
    for i in 1:nfields(term)
        # Skip unassigned fields instead of throwing on access.
        isdefined(term, i) || continue
        operand = getfield(term, i)
        if operand isa Union{AbstractVector, Tuple}
            for v in operand
                _update_from_value!(di, def_depth, v, cur_depth)
            end
        else
            _update_from_value!(di, def_depth, operand, cur_depth)
        end
    end
    return nothing
end

"""
    _hoist!(block, stack, def_depth, is_loop_body) -> BlockResult

Rebuild `block` in place, moving loop-invariant statements into enclosing blocks.

# Arguments
- `block`: the block to process; its `body` is replaced by the rebuilt map.
- `stack`: one `StackItem` per enclosing block. The depth of `block` is the stack
  length on entry, and its own item is pushed for the duration of the call.
- `def_depth`: maps each value (arguments, block arguments, SSA values) to the
  depth at which it is defined; updated as statements are placed.
- `is_loop_body`: whether `block` is the body of a loop. Statements are only
  hoisted out of loop bodies; in any other block they always stay.

# Returns
A `BlockResult` with the block's mobility and the deepest outside depth that
anything remaining in the block (statements or terminator) depends on.
"""
function _hoist!(block::Block, stack::Vector{StackItem}, def_depth::Dict{Any,Int},
                 is_loop_body::Bool)
    depth = length(stack)
    new_body = SSAMap()
    push!(stack, StackItem(new_body, is_loop_body))

    mobility = CAN_MOVE
    min_depth = 0

    # Register block args at current depth
    for ba in block.args
        def_depth[ba] = depth
    end

    for inst in instructions(block)
        s = stmt(inst)
        # Outside a loop body nothing is hoisted, so every statement starts pinned.
        depinfo = DependencyInfo(!is_loop_body, 0)

        if s isa ForOp || s isa LoopOp
            body = s.body
            # ForOp's iv_arg is separate from body.args (which holds only carries)
            if s isa ForOp
                def_depth[s.iv_arg] = depth + 1
            end
            for ba in body.args
                def_depth[ba] = depth + 1
            end
            body_result = _hoist!(body, stack, def_depth, true)
            if body_result.mobility == IMMOVABLE
                mobility = IMMOVABLE
                depinfo.must_stay = true
            end
            for v in s.init_values
                _update_from_value!(depinfo, def_depth, v, depth)
            end
            if s isa ForOp
                for v in (s.lower, s.upper, s.step)
                    _update_from_value!(depinfo, def_depth, v, depth)
                end
            end
            # The loop depends on whatever its body still needs from outside.
            update!(depinfo, body_result.min_depth, depth)

        elseif s isa WhileOp
            for ba in s.before.args
                def_depth[ba] = depth + 1
            end
            for ba in s.after.args
                def_depth[ba] = depth + 1
            end
            before_result = _hoist!(s.before, stack, def_depth, true)
            after_result = _hoist!(s.after, stack, def_depth, true)
            worst = min(before_result.mobility, after_result.mobility)
            if worst == IMMOVABLE
                mobility = IMMOVABLE
                depinfo.must_stay = true
            end
            for v in s.init_values
                _update_from_value!(depinfo, def_depth, v, depth)
            end
            update!(depinfo, before_result.min_depth, depth)
            update!(depinfo, after_result.min_depth, depth)

        elseif s isa IfOp
            _update_from_value!(depinfo, def_depth, s.condition, depth)
            then_result = _hoist!(s.then_region, stack, def_depth, false)
            else_result = _hoist!(s.else_region, stack, def_depth, false)
            update!(depinfo, then_result.min_depth, depth)
            update!(depinfo, else_result.min_depth, depth)
            for r in (then_result, else_result)
                # A branch with continue/break (or worse) ties the `if` to its
                # position and propagates the restriction to this block.
                if r.mobility != CAN_MOVE
                    mobility = min(mobility, r.mobility)
                    depinfo.must_stay = true
                end
            end

        elseif _is_memory_store(block, s)
            mobility = IMMOVABLE
            depinfo.must_stay = true
        else
            # Pure operation: check operand depths
            _update_operand_depths!(depinfo, def_depth, s, depth)
        end

        # Determine target depth
        target_depth = depth
        if depinfo.must_stay
            min_depth = max(min_depth, depinfo.max_outside_depth)
        else
            # Climb out of enclosing loop bodies, stopping at the depth of the
            # deepest dependency or at the first block that is not a loop body.
            while target_depth > depinfo.max_outside_depth &&
                  stack[target_depth + 1].is_loop_body
                target_depth -= 1
            end
        end

        # Place at target depth
        push!(stack[target_depth + 1].new_body, (inst.ssa_idx, s, inst.typ))

        # Record definition depth AFTER hoisting (enables cascading hoists)
        def_depth[SSAValue(inst.ssa_idx)] = target_depth
    end

    # The terminator never leaves the block, so it restricts mobility and its
    # operands count towards min_depth like those of any statement that stays.
    term = block.terminator
    if term isa ContinueOp || term isa BreakOp
        mobility = min(mobility, CAN_MOVE_WITH_LOOP)
    elseif term isa Core.ReturnNode
        # A block that returns must not be relocated.
        mobility = IMMOVABLE
    end
    term_deps = DependencyInfo(true, 0)
    _update_from_terminator!(term_deps, def_depth, term, depth)
    # Without this, a loop or `if` whose terminator uses a value defined in an
    # enclosing loop could be hoisted above that definition.
    min_depth = max(min_depth, term_deps.max_outside_depth)

    pop!(stack)
    block.body = new_body
    return BlockResult(mobility, min_depth)
end
167+
168+
# Return `true` when statement `s` of `block` resolves to a call that
# `classify_memory_op` reports as `MEM_STORE`; such statements are IMMOVABLE for
# LICM purposes. Anything that is not an `Expr`, or that `resolve_call` cannot
# resolve, is not a store. Loads are treated as hoistable (pure if their operands
# are invariant).
# NOTE(review): hoisting a load past a store to the same memory inside the loop
# looks unsafe unless something else rules it out — confirm.
function _is_memory_store(block::Block, @nospecialize(s))
    if !(s isa Expr)
        return false
    end
    resolved = resolve_call(block, s)
    if resolved === nothing
        return false
    end
    callee, _ = resolved
    return classify_memory_op(callee) == MEM_STORE
end
178+
179+
# Fold a single IR value into `di`. Values with no recorded definition depth
# (constants, literals, anything never registered in `def_depth`) are ignored.
function _update_from_value!(di::DependencyInfo, def_depth::Dict{Any,Int}, @nospecialize(val), cur_depth::Int)
    known_depth = get(def_depth, val, nothing)
    if known_depth === nothing
        return false
    end
    return update!(di, known_depth, cur_depth)
end
184+
185+
# Update DependencyInfo from all operands of a statement.
#
# Every argument of an `Expr` is inspected, including the callee slot: a callee
# that is itself an SSA value defined inside the loop is a dependency just like
# any other operand. Entries with no recorded definition depth (method instances,
# global references, literals) are ignored by `_update_from_value!`, so scanning
# them is harmless. A statement that is not an `Expr` or `PiNode` is checked as a
# value itself, which covers a statement that merely forwards a tracked value.
function _update_operand_depths!(di::DependencyInfo, def_depth::Dict{Any,Int}, @nospecialize(s), cur_depth::Int)
    if s isa Expr
        for arg in s.args
            _update_from_value!(di, def_depth, arg, cur_depth)
        end
    elseif s isa Core.PiNode
        _update_from_value!(di, def_depth, s.val, cur_depth)
    else
        _update_from_value!(di, def_depth, s, cur_depth)
    end
    return nothing
end

src/compiler/passes/pipeline.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,8 @@ function run_passes!(sci::StructuredIRCode)
327327
constants = propagate_constants(sci)
328328
rewrite_patterns!(sci, OPTIMIZATION_RULES; constants)
329329

330+
licm_pass!(sci)
331+
330332
alias_result = alias_analysis_pass!(sci)
331333
token_order_pass!(sci, alias_result)
332334

src/cuTile.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module cuTile
22

33
using IRStructurizer
4-
using IRStructurizer: Block, ControlFlowOp, BlockArgument,
4+
using IRStructurizer: Block, ControlFlowOp, BlockArgument, SSAMap,
55
YieldOp, ContinueOp, BreakOp, ConditionOp,
66
IfOp, ForOp, WhileOp, LoopOp, Undef,
77
SourceLocation, source_location
@@ -44,6 +44,7 @@ include("compiler/passes/canonicalize.jl")
4444
include("compiler/passes/alias_analysis.jl")
4545
include("compiler/passes/token_keys.jl")
4646
include("compiler/passes/token_order.jl")
47+
include("compiler/passes/licm.jl")
4748
include("compiler/passes/dce.jl")
4849
include("compiler/passes/pipeline.jl")
4950
include("compiler/codegen/debug.jl")

0 commit comments

Comments
 (0)