11# Loop-Invariant Code Motion (LICM)
22#
3- # Single-pass depth-tracking algorithm that hoists loop-invariant operations
4- # out of loops. Port of cuTile Python's `hoist_loop_invariants` (code_motion.py).
3+ # Hoists loop-invariant loads (and their dependency chains) out of loops.
54#
6- # The algorithm walks the IR recursively while tracking the definition depth
7- # of each value. An operation whose data dependencies all resolve to depths
8- # *less than* its containing loop can be hoisted above that loop. A stack of
9- # instruction lists collects operations at their target depth; at the end of
10- # each block, the original body is rebuilt from the filtered list.
11-
12- # Whether a block can be moved, based on the operations it contains.
13- @enum BlockMobility begin
14- IMMOVABLE # contains stores, returns, or nested IMMOVABLE blocks
15- CAN_MOVE_WITH_LOOP # contains continue/break
16- CAN_MOVE # pure operations only
17- end
18-
19- struct BlockResult
20- mobility:: BlockMobility
21- min_depth:: Int # minimum depth any op in this block needs
22- end
23-
24- mutable struct DependencyInfo
25- must_stay:: Bool
26- max_outside_depth:: Int
27- end
28-
29- function update! (di:: DependencyInfo , dep_depth:: Int , cur_depth:: Int )
30- if dep_depth >= cur_depth
31- di. must_stay = true
32- else
33- di. max_outside_depth = max (di. max_outside_depth, dep_depth)
34- end
35- end
36-
37- struct StackItem
38- entries:: Vector{Tuple{Int,Any,Any}} # (ssa_idx, stmt, typ) triples
39- is_loop_body:: Bool
40- end
5+ # Pure operations (arithmetic, broadcasts, view constructors) are NOT hoisted
6+ # here — MLIR's built-in LICM handles those at optLevel >= 2.
7+ #
8+ # This pass targets what MLIR cannot hoist: memory loads. After token ordering,
9+ # loads have token dependencies that anchor them inside loops. By hoisting
10+ # before token insertion, we avoid creating unnecessary token carries.
11+ #
12+ # Safety: a load is hoistable only when (1) all its operands are loop-invariant,
13+ # and (2) no store in the loop body aliases with the load's memory region.
14+ # Alias information comes from alias_analysis_pass!, which must run first.
4115
4216"""
43- licm_pass!(sci::StructuredIRCode)
17+ licm_pass!(sci::StructuredIRCode, alias_result::Dict{Any, AliasSet})
18+
19+ Hoist loop-invariant loads out of loops. Must run after alias_analysis_pass!
20+ and before token_order_pass!.
4421
45- Hoist loop-invariant operations out of loops. Must run after rewrite_patterns!
46- and before token_order_pass! (which inserts tokens that should not be moved).
22+ A load is hoistable when:
23+ - All operands are defined outside the loop
24+ - No store in the loop body writes to an aliasing memory region
4725"""
48- function licm_pass! (sci:: StructuredIRCode )
26+ function licm_pass! (sci:: StructuredIRCode , alias_result :: Dict{Any, AliasSet} )
4927 def_depth = Dict {Any, Int} ()
5028 for i in 1 : length (sci. argtypes)
5129 def_depth[Argument (i)] = 0
5230 end
53- _hoist! (sci. entry, StackItem[], def_depth, false )
31+ _hoist_loads! (sci. entry, Vector {Vector{Tuple{Int,Any,Any}}} (), def_depth,
32+ alias_result, false )
5433 return
5534end
5635
57- function _hoist! (block:: Block , stack:: Vector{StackItem} , def_depth:: Dict{Any,Int} ,
58- is_loop_body:: Bool )
59- depth = length (stack)
60- push! (stack, StackItem (Tuple{Int,Any,Any}[], is_loop_body))
36+ # Collect alias sets of all stores in a block (recursively through nested CFs).
37+ function _collect_store_aliases (block:: Block , alias_result:: Dict{Any, AliasSet} )
38+ store_aliases = AliasSet[]
39+ for inst in instructions (block)
40+ s = stmt (inst)
41+ if s isa ControlFlowOp
42+ for b in blocks (s)
43+ append! (store_aliases, _collect_store_aliases (b, alias_result))
44+ end
45+ else
46+ call = resolve_call (block, s)
47+ call === nothing && continue
48+ resolved_func, operands = call
49+ if classify_memory_op (resolved_func) == MEM_STORE
50+ aset = get_alias_set_for_operand (alias_result, first (operands))
51+ push! (store_aliases, aset)
52+ end
53+ end
54+ end
55+ return store_aliases
56+ end
57+
58+ # Check if a load's alias set conflicts with any store alias set in the loop.
59+ function _aliases_with_store (load_alias:: AliasSet , store_aliases:: Vector{AliasSet} )
60+ for sa in store_aliases
61+ if load_alias isa AliasUniverse || sa isa AliasUniverse
62+ return true
63+ end
64+ if ! isempty (intersect (load_alias, sa))
65+ return true
66+ end
67+ end
68+ return false
69+ end
6170
62- mobility = CAN_MOVE
63- min_depth = 0
71+ function _hoist_loads! (block:: Block , stack:: Vector{Vector{Tuple{Int,Any,Any}}} ,
72+ def_depth:: Dict{Any,Int} , alias_result:: Dict{Any, AliasSet} ,
73+ is_loop_body:: Bool )
74+ depth = length (stack)
75+ push! (stack, Tuple{Int,Any,Any}[])
6476
6577 # Register block args at current depth
6678 for ba in block. args
6779 def_depth[ba] = depth
6880 end
6981
82+ # If this is a loop body, collect store alias sets for the load safety check
83+ store_aliases = is_loop_body ? _collect_store_aliases (block, alias_result) : AliasSet[]
84+
7085 for inst in instructions (block)
7186 s = stmt (inst)
72- depinfo = DependencyInfo ( ! is_loop_body, 0 )
87+ hoisted = false
7388
7489 if s isa ForOp || s isa LoopOp
7590 body = s. body
76- # ForOp's iv_arg is separate from body.args (which holds only carries)
7791 if s isa ForOp
7892 def_depth[s. iv_arg] = depth + 1
7993 end
8094 for ba in body. args
8195 def_depth[ba] = depth + 1
8296 end
83- body_result = _hoist! (body, stack, def_depth, true )
84- if body_result. mobility == IMMOVABLE
85- mobility = IMMOVABLE
86- depinfo. must_stay = true
87- end
88- for v in s. init_values
89- _update_from_value! (depinfo, def_depth, v, depth)
90- end
91- if s isa ForOp
92- for v in (s. lower, s. upper, s. step)
93- _update_from_value! (depinfo, def_depth, v, depth)
94- end
95- end
96- update! (depinfo, body_result. min_depth, depth)
97+ _hoist_loads! (body, stack, def_depth, alias_result, true )
9798
9899 elseif s isa WhileOp
99100 for ba in s. before. args
@@ -102,98 +103,75 @@ function _hoist!(block::Block, stack::Vector{StackItem}, def_depth::Dict{Any,Int
102103 for ba in s. after. args
103104 def_depth[ba] = depth + 1
104105 end
105- before_result = _hoist! (s. before, stack, def_depth, true )
106- after_result = _hoist! (s. after, stack, def_depth, true )
107- worst = min (before_result. mobility, after_result. mobility)
108- if worst == IMMOVABLE
109- mobility = IMMOVABLE
110- depinfo. must_stay = true
111- end
112- for v in s. init_values
113- _update_from_value! (depinfo, def_depth, v, depth)
114- end
115- update! (depinfo, before_result. min_depth, depth)
116- update! (depinfo, after_result. min_depth, depth)
106+ _hoist_loads! (s. before, stack, def_depth, alias_result, true )
107+ _hoist_loads! (s. after, stack, def_depth, alias_result, true )
117108
118109 elseif s isa IfOp
119- _update_from_value! (depinfo, def_depth, s. condition, depth)
120- then_result = _hoist! (s. then_region, stack, def_depth, false )
121- else_result = _hoist! (s. else_region, stack, def_depth, false )
122- update! (depinfo, then_result. min_depth, depth)
123- update! (depinfo, else_result. min_depth, depth)
124- for r in (then_result, else_result)
125- if r. mobility != CAN_MOVE
126- mobility = min (mobility, r. mobility)
127- depinfo. must_stay = true
128- end
129- end
130-
131- elseif _is_memory_store (block, s)
132- mobility = IMMOVABLE
133- depinfo. must_stay = true
134- else
135- # Pure operation: check operand depths
136- _update_operand_depths! (depinfo, def_depth, s, depth)
137- end
138-
139- # Determine target depth
140- target_depth = depth
141- if depinfo. must_stay
142- min_depth = max (min_depth, depinfo. max_outside_depth)
143- else
144- while target_depth > depinfo. max_outside_depth && stack[target_depth + 1 ]. is_loop_body
110+ _hoist_loads! (s. then_region, stack, def_depth, alias_result, false )
111+ _hoist_loads! (s. else_region, stack, def_depth, alias_result, false )
112+
113+ elseif is_loop_body && _is_hoistable_load (block, s, def_depth, depth,
114+ alias_result, store_aliases)
115+ # Hoist this load to the enclosing scope
116+ target_depth = depth - 1
117+ while target_depth > 0 && _can_hoist_to (stack, target_depth)
145118 target_depth -= 1
146119 end
120+ push! (stack[target_depth + 1 ], (inst. ssa_idx, s, inst. typ))
121+ def_depth[SSAValue (inst. ssa_idx)] = target_depth
122+ hoisted = true
147123 end
148124
149- # Place at target depth
150- push! (stack[target_depth + 1 ]. entries, (inst. ssa_idx, s, inst. typ))
151-
152- # Record definition depth AFTER hoisting (enables cascading hoists)
153- def_depth[SSAValue (inst. ssa_idx)] = target_depth
154- end
155-
156- # Handle terminator for mobility
157- term = block. terminator
158- if term isa ContinueOp || term isa BreakOp
159- mobility = min (mobility, CAN_MOVE_WITH_LOOP)
125+ if ! hoisted
126+ # Keep at current depth
127+ push! (stack[depth + 1 ], (inst. ssa_idx, s, inst. typ))
128+ def_depth[SSAValue (inst. ssa_idx)] = depth
129+ end
160130 end
161131
162132 # Rebuild block body from collected entries
163- entries = pop! (stack). entries
133+ entries = pop! (stack)
164134 empty! (block)
165135 for (idx, s, typ) in entries
166136 push! (block, idx, s, typ)
167137 end
138+ end
168139
169- return BlockResult (mobility, min_depth)
140+ # Check if a stack entry is a loop body (for multi-level hoisting)
141+ function _can_hoist_to (stack:: Vector{Vector{Tuple{Int,Any,Any}}} , target_depth:: Int )
142+ # We'd need to track is_loop_body per stack entry to do multi-level hoisting.
143+ # For now, only hoist one level out.
144+ return false
170145end
171146
172- # Check if a statement is a memory store (IMMOVABLE for LICM purposes).
173- # Loads are hoistable (they're pure if operands are invariant).
174- function _is_memory_store (block:: Block , @nospecialize (s))
147+ # Check if a statement is a load that can be safely hoisted.
148+ function _is_hoistable_load (block:: Block , @nospecialize (s), def_depth:: Dict{Any,Int} ,
149+ cur_depth:: Int , alias_result:: Dict{Any, AliasSet} ,
150+ store_aliases:: Vector{AliasSet} )
175151 s isa Expr || return false
176152 call = resolve_call (block, s)
177153 call === nothing && return false
178- resolved_func, _ = call
179- effect = classify_memory_op (resolved_func)
180- return effect == MEM_STORE
181- end
154+ resolved_func, operands = call
182155
183- # Update DependencyInfo from a single IR value
184- function _update_from_value! (di:: DependencyInfo , def_depth:: Dict{Any,Int} , @nospecialize (val), cur_depth:: Int )
185- d = get (def_depth, val, nothing )
186- d != = nothing && update! (di, d, cur_depth)
156+ # Must be a load operation
157+ classify_memory_op (resolved_func) == MEM_LOAD || return false
158+
159+ # All operands must be defined outside this loop
160+ _all_operands_outside (s, def_depth, cur_depth) || return false
161+
162+ # Load must not alias with any store in the loop
163+ load_alias = get_alias_set_for_operand (alias_result, first (operands))
164+ return ! _aliases_with_store (load_alias, store_aliases)
187165end
188166
189- # Update DependencyInfo from all operands of a statement
190- function _update_operand_depths! (di:: DependencyInfo , def_depth:: Dict{Any,Int} , @nospecialize (s), cur_depth:: Int )
191- if s isa Expr
192- start = s. head === :invoke ? 3 : 2
193- for i in start: length (s. args)
194- _update_from_value! (di, def_depth, s. args[i], cur_depth)
195- end
196- elseif s isa Core. PiNode
197- _update_from_value! (di, def_depth, s. val, cur_depth)
167+ # Check that all operands of a statement are defined at depth < cur_depth.
168+ function _all_operands_outside (@nospecialize (s), def_depth:: Dict{Any,Int} , cur_depth:: Int )
169+ s isa Expr || return true
170+ start = s. head === :invoke ? 3 : 2
171+ for i in start: length (s. args)
172+ d = get (def_depth, s. args[i], nothing )
173+ d === nothing && continue # constants/literals are always available
174+ d >= cur_depth && return false
198175 end
176+ return true
199177end
0 commit comments