11# Loop-Invariant Code Motion (LICM)
22#
3- # Hoists loop-invariant loads (and their dependency chains) out of loops.
3+ # Hoists loop-invariant operations out of loops. Runs AFTER token_order_pass!
4+ # so that token dependencies correctly prevent unsafe hoisting of aliasing loads.
45#
5- # Pure operations (arithmetic, broadcasts, view constructors) are NOT hoisted
6- # here — MLIR's built-in LICM handles those at optLevel >= 2.
6+ # Operations classified as stores (store_partition_view, store_ptr_tko, atomics,
7+ # print_tko) and control flow exits (return) are never hoisted. All other
8+ # operations — including loads, arithmetic, partition views, token nodes — are
9+ # hoisted when all their data dependencies are defined outside the loop.
710#
8- # This pass targets what MLIR cannot hoist: memory loads. After token ordering,
9- # loads have token dependencies that anchor them inside loops. By hoisting
10- # before token insertion, we avoid creating unnecessary token carries.
11- #
12- # Safety: a load is hoistable only when (1) all its operands are loop-invariant,
13- # and (2) no store in the loop body aliases with the load's memory region.
14- # Alias information comes from alias_analysis_pass!, which must run first.
11+ # This mirrors cuTile Python's code_motion.py:hoist_loop_invariants.
12+
# Classifies how freely a block could be relocated, judging only by the kinds of
# operations it contains (side effects, jumps). Data dependencies are handled
# separately. Values are ordered from least to most movable, so taking `min` of two
# classifications yields the more restrictive one.
@enum _BlockMobility::Int8 begin
    _IMMOVABLE = 0           # pinned by side effects (in the block or an ancestor)
    _CAN_MOVE_WITH_LOOP = 1  # contains Continue/Break: only movable with its loop
    _CAN_MOVE = 2            # movable, subject to data dependencies
end
1524
16- """
17- licm_pass!(sci::StructuredIRCode, alias_result::Dict{Any, AliasSet})
# Summary returned after processing one block.
struct _BlockResult
    # Mobility classification derived from the operations inside the block.
    mobility::_BlockMobility
    # Deepest outside-defined dependency among the operations that had to stay in
    # the block.
    min_depth::Int
end
1829
19- Hoist loop-invariant loads out of loops. Must run after alias_analysis_pass!
20- and before token_order_pass!.
# Per-operation accumulator for data dependency information.
mutable struct _DepInfo
    # True once the operation is known to be pinned to its current block.
    must_stay::Bool
    # Deepest definition depth seen among dependencies defined outside the block.
    max_outside_depth::Int
end
2135
22- A load is hoistable when:
23- - All operands are defined outside the loop
24- - No store in the loop body writes to an aliasing memory region
25- """
26- function licm_pass! (sci:: StructuredIRCode , alias_result:: Dict{Any, AliasSet} )
27- def_depth = Dict {Any, Int} ()
28- for i in 1 : length (sci. argtypes)
29- def_depth[Argument (i)] = 0
# Fold one dependency, defined at depth `dep_depth`, into `di` for an operation that
# currently lives at depth `cur_depth`.
function _update_dep!(di::_DepInfo, dep_depth::Int, cur_depth::Int)
    if dep_depth < cur_depth
        # Defined outside the current block: remember the deepest such definition.
        di.max_outside_depth = max(di.max_outside_depth, dep_depth)
    else
        # Defined at the current depth or deeper: the operation cannot leave.
        di.must_stay = true
    end
end
3543
36- # Collect alias sets of all stores in a block (recursively through nested CFs).
37- function _collect_store_aliases (block:: Block , alias_result:: Dict{Any, AliasSet} )
38- store_aliases = AliasSet[]
39- for inst in instructions (block)
40- s = stmt (inst)
41- if s isa ControlFlowOp
42- for b in blocks (s)
43- append! (store_aliases, _collect_store_aliases (b, alias_result))
44- end
45- else
46- call = resolve_call (block, s)
47- call === nothing && continue
48- resolved_func, operands = call
49- if classify_memory_op (resolved_func) == MEM_STORE
50- aset = get_alias_set_for_operand (alias_result, first (operands))
51- push! (store_aliases, aset)
52- end
53- end
54- end
55- return store_aliases
# Account for a single operand `val` (SSA value, argument or literal) in `di`.
# Operands with no entry in `def_depth` are treated as constants/literals and
# impose no constraint.
function _check_val!(di::_DepInfo, val, def_depth::Dict{Any,Int}, cur_depth::Int)
    known_depth = get(def_depth, val, nothing)
    if known_depth === nothing
        # Not a tracked definition: always available, nothing to record.
        return
    end
    _update_dep!(di, known_depth, cur_depth)
end
5750
58- # Check if a load's alias set conflicts with any store alias set in the loop.
59- function _aliases_with_store (load_alias:: AliasSet , store_aliases:: Vector{AliasSet} )
60- for sa in store_aliases
61- if load_alias isa AliasUniverse || sa isa AliasUniverse
62- return true
# Record every SSA dependency of statement `s` in `di`.
#
# Handles `Expr` operands, the tokens of a `JoinTokensNode`, and the memory op
# referenced by a `TokenResultNode`. Any other statement kind contributes no
# dependencies.
function _check_stmt_deps!(di::_DepInfo, @nospecialize(s), def_depth::Dict{Any,Int},
                           cur_depth::Int)
    if s isa Expr
        # Leading slots are not operands: two for :invoke, one otherwise
        # (presumably method instance + callee vs. callee only — TODO confirm).
        n_skip = s.head === :invoke ? 2 : 1
        for operand in Iterators.drop(s.args, n_skip)
            _check_val!(di, operand, def_depth, cur_depth)
        end
    elseif s isa JoinTokensNode
        for joined in s.tokens
            _check_val!(di, joined, def_depth, cur_depth)
        end
    elseif s isa TokenResultNode
        _check_val!(di, SSAValue(s.mem_op_ssa), def_depth, cur_depth)
    end
end
68+
# One level of the hoisting stack: the statements collected for a block plus a flag
# telling whether that block is a loop body.
struct _StackItem
    # Collected statements as (ssa_idx, stmt, type) tuples.
    entries::Vector{Tuple{Int,Any,Any}}
    # Whether the block at this level is the body of a loop.
    is_loop_body::Bool
end
73+
74+ """
75+ licm_pass!(sci::StructuredIRCode)
76+
77+ Hoist loop-invariant operations out of loops. Must run after token_order_pass!.
78+ """
function licm_pass!(sci::StructuredIRCode)
    # Every function argument is defined at depth 0, the outermost level.
    def_depth = Dict{Any,Int}(Argument(i) => 0 for i in 1:length(sci.argtypes))
    # Walk the entry block with an empty stack; the entry block is not a loop body.
    _hoist!(sci.entry, _StackItem[], def_depth, false)
    return nothing
end
7087
71- function _hoist_loads! (block:: Block , stack:: Vector{Vector{Tuple{Int,Any,Any}}} ,
72- def_depth:: Dict{Any,Int} , alias_result:: Dict{Any, AliasSet} ,
73- is_loop_body:: Bool )
88+ function _hoist! (block:: Block , stack:: Vector{_StackItem} , def_depth:: Dict{Any,Int} ,
89+ is_loop_body:: Bool )
7490 depth = length (stack)
75- push! (stack, Tuple{Int,Any,Any}[])
91+ push! (stack, _StackItem ( Tuple{Int,Any,Any}[], is_loop_body) )
7692
77- # Register block args at current depth
7893 for ba in block. args
7994 def_depth[ba] = depth
8095 end
8196
82- # If this is a loop body, collect store alias sets for the load safety check
83- store_aliases = is_loop_body ? _collect_store_aliases (block, alias_result) : AliasSet[]
97+ mobility = _CAN_MOVE
98+ min_depth = 0
8499
85100 for inst in instructions (block)
86101 s = stmt (inst)
87- hoisted = false
102+ di = _DepInfo ( ! is_loop_body, 0 )
88103
89- if s isa ForOp || s isa LoopOp
90- body = s. body
91- if s isa ForOp
92- def_depth[s. iv_arg] = depth + 1
104+ if s isa ForOp
105+ def_depth[s. iv_arg] = depth + 1
106+ for ba in s. body. args
107+ def_depth[ba] = depth + 1
108+ end
109+ body_res = _hoist! (s. body, stack, def_depth, true )
110+ if body_res. mobility == _IMMOVABLE
111+ mobility = _IMMOVABLE
112+ di. must_stay = true
113+ end
114+ for v in s. init_values
115+ _check_val! (di, v, def_depth, depth)
93116 end
94- for ba in body. args
117+ _update_dep! (di, body_res. min_depth, depth)
118+
119+ elseif s isa LoopOp
120+ for ba in s. body. args
95121 def_depth[ba] = depth + 1
96122 end
97- _hoist_loads! (body, stack, def_depth, alias_result, true )
123+ body_res = _hoist! (s. body, stack, def_depth, true )
124+ if body_res. mobility == _IMMOVABLE
125+ mobility = _IMMOVABLE
126+ di. must_stay = true
127+ end
128+ for v in s. init_values
129+ _check_val! (di, v, def_depth, depth)
130+ end
131+ _update_dep! (di, body_res. min_depth, depth)
98132
99133 elseif s isa WhileOp
100134 for ba in s. before. args
@@ -103,75 +137,86 @@ function _hoist_loads!(block::Block, stack::Vector{Vector{Tuple{Int,Any,Any}}},
103137 for ba in s. after. args
104138 def_depth[ba] = depth + 1
105139 end
106- _hoist_loads! (s. before, stack, def_depth, alias_result, true )
107- _hoist_loads! (s. after, stack, def_depth, alias_result, true )
140+ before_res = _hoist! (s. before, stack, def_depth, true )
141+ after_res = _hoist! (s. after, stack, def_depth, true )
142+ if min (before_res. mobility, after_res. mobility) == _IMMOVABLE
143+ mobility = _IMMOVABLE
144+ di. must_stay = true
145+ end
146+ for v in s. init_values
147+ _check_val! (di, v, def_depth, depth)
148+ end
149+ _update_dep! (di, before_res. min_depth, depth)
150+ _update_dep! (di, after_res. min_depth, depth)
108151
109152 elseif s isa IfOp
110- _hoist_loads! (s. then_region, stack, def_depth, alias_result, false )
111- _hoist_loads! (s. else_region, stack, def_depth, alias_result, false )
112-
113- elseif is_loop_body && _is_hoistable_load (block, s, def_depth, depth,
114- alias_result, store_aliases)
115- # Hoist this load to the enclosing scope
116- target_depth = depth - 1
117- while target_depth > 0 && _can_hoist_to (stack, target_depth)
118- target_depth -= 1
153+ _check_val! (di, s. condition, def_depth, depth)
154+ for region in (s. then_region, s. else_region)
155+ branch_res = _hoist! (region, stack, def_depth, false )
156+ _update_dep! (di, branch_res. min_depth, depth)
157+ if branch_res. mobility != _CAN_MOVE
158+ mobility = min (mobility, branch_res. mobility)
159+ di. must_stay = true
160+ end
119161 end
120- push! (stack[target_depth + 1 ], (inst. ssa_idx, s, inst. typ))
121- def_depth[SSAValue (inst. ssa_idx)] = target_depth
122- hoisted = true
162+
163+ elseif _is_store (block, s)
164+ mobility = _IMMOVABLE
165+ di. must_stay = true
166+
167+ elseif s isa ContinueOp || s isa BreakOp
168+ mobility = min (mobility, _CAN_MOVE_WITH_LOOP)
169+ di. must_stay = true
170+
171+ elseif s isa YieldOp || s isa ConditionOp || s isa ReturnNode
172+ di. must_stay = true
173+ # Track deps for YieldOp/ConditionOp so min_depth is correct
174+ if s isa YieldOp
175+ for v in s. values
176+ _check_val! (di, v, def_depth, depth)
177+ end
178+ elseif s isa ConditionOp
179+ _check_val! (di, s. condition, def_depth, depth)
180+ for v in s. args
181+ _check_val! (di, v, def_depth, depth)
182+ end
183+ end
184+
185+ else
186+ # Movable operation: loads, arithmetic, make_partition_view, etc.
187+ _check_stmt_deps! (di, s, def_depth, depth)
123188 end
124189
125- if ! hoisted
126- # Keep at current depth
127- push! (stack[depth + 1 ], (inst. ssa_idx, s, inst. typ))
128- def_depth[SSAValue (inst. ssa_idx)] = depth
190+ # Determine target depth
191+ target_depth = depth
192+ if di. must_stay
193+ min_depth = max (min_depth, di. max_outside_depth)
194+ else
195+ while target_depth > di. max_outside_depth && stack[target_depth]. is_loop_body
196+ target_depth -= 1
197+ end
129198 end
199+
200+ push! (stack[target_depth + 1 ]. entries, (inst. ssa_idx, s, inst. typ))
201+
202+ # Record definition depth AFTER hoisting so subsequent ops see the new depth
203+ def_depth[SSAValue (inst. ssa_idx)] = target_depth
130204 end
131205
132206 # Rebuild block body from collected entries
133- entries = pop! (stack)
207+ entries = pop! (stack). entries
134208 empty! (block)
135209 for (idx, s, typ) in entries
136210 push! (block, idx, s, typ)
137211 end
138- end
139212
140- # Check if a stack entry is a loop body (for multi-level hoisting)
141- function _can_hoist_to (stack:: Vector{Vector{Tuple{Int,Any,Any}}} , target_depth:: Int )
142- # We'd need to track is_loop_body per stack entry to do multi-level hoisting.
143- # For now, only hoist one level out.
144- return false
213+ return _BlockResult (mobility, min_depth)
145214end
146215
147- # Check if a statement is a load that can be safely hoisted.
148- function _is_hoistable_load (block:: Block , @nospecialize (s), def_depth:: Dict{Any,Int} ,
149- cur_depth:: Int , alias_result:: Dict{Any, AliasSet} ,
150- store_aliases:: Vector{AliasSet} )
151- s isa Expr || return false
# Return `true` when statement `s` resolves to a call that `classify_memory_op`
# classifies as `MEM_STORE`; statements that do not resolve to a call are not stores.
function _is_store(block::Block, @nospecialize(s))
    resolved = resolve_call(block, s)
    if resolved === nothing
        return false
    end
    callee, _ = resolved
    return classify_memory_op(callee) == MEM_STORE
end
0 commit comments