Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ CUDA_Tile_jll = "13.1"
CompilerCaching = "0.2"
EnumX = "1.0"
GPUArrays = "11"
IRStructurizer = "0.5.1"
IRStructurizer = "0.5.3"
julia = "1.11"
74 changes: 74 additions & 0 deletions src/compiler/passes/licm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Loop-Invariant Code Motion (LICM)
#
# Hoists loop-invariant operations out of loops. Runs AFTER token_order_pass!
# so that token dependencies correctly prevent unsafe hoisting of aliasing loads.
#
# Operations classified as stores (store_partition_view, store_ptr_tko, atomics,
# print_tko), nested control-flow ops, and control flow exits (return) are never
# hoisted. All other operations — including loads, arithmetic, partition views,
# token nodes — are hoisted when all their data dependencies are defined outside
# the loop.
#
# Uses IRStructurizer's `is_defined_outside`, `move_before!`, and `operands`
# primitives. Processes innermost loops first and repeats until fixpoint.
#
# This mirrors cuTile Python's code_motion.py:hoist_loop_invariants.

"""
licm_pass!(sci::StructuredIRCode)

Hoist loop-invariant operations out of loops. Must run after token_order_pass!.
"""
function licm_pass!(sci::StructuredIRCode)
    # `collect_loops` lists nested loops ahead of the loops that contain them, so
    # each inner loop is processed before its enclosing loop.
    foreach(collect_loops(sci.entry)) do (loop_inst, loop_op)
        hoist_from_loop!(loop_inst, loop_op)
    end
    return nothing
end

# Gather every loop reachable from `root` as (instruction, loop_op) pairs.
# Nested loops come before the loops that contain them (post-order).
function collect_loops(root::Block)
    found = Vector{Tuple{Instruction, Union{ForOp, LoopOp, WhileOp}}}()
    collect_loops!(found, root)
    return found
end

# Recursive worker for `collect_loops`: walks `block` and appends each loop it
# finds to `result` after first descending into that loop's own regions.
function collect_loops!(result, block::Block)
    for inst in instructions(block)
        op = stmt(inst)
        if op isa Union{ForOp, LoopOp}
            # Single-region loops: descend into the body.
            collect_loops!(result, op.body)
        elseif op isa WhileOp
            # While loops carry two regions; visit `before`, then `after`.
            collect_loops!(result, op.before)
            collect_loops!(result, op.after)
        elseif op isa ControlFlowOp
            # Any other control-flow op is not a loop itself, but may contain loops.
            foreach(nested -> collect_loops!(result, nested), blocks(op))
            continue
        else
            continue
        end
        # Reached only for loop ops; pushed after their regions were visited.
        push!(result, (inst, op))
    end
    return nothing
end

# Move every hoistable instruction of `loop_op` to just before `loop_inst`.
# Sweeps all of the loop's blocks repeatedly until a sweep moves nothing.
function hoist_from_loop!(loop_inst::Instruction, loop_op)
    invariant(v) = is_defined_outside(v, loop_op)
    while true
        moved = false
        for region in blocks(loop_op)
            # Iterate over a snapshot, since `move_before!` is called while walking.
            for candidate in collect(instructions(region))
                s = stmt(candidate)
                # Nested control flow and stores always stay where they are.
                (s isa ControlFlowOp || is_store(region, s)) && continue
                if all(invariant, operands(region, candidate))
                    move_before!(candidate, loop_inst)
                    moved = true
                end
            end
        end
        moved || break
    end
    return nothing
end

# Report whether statement `s` is a store/atomic, i.e. a side-effecting memory write.
# Statements that `resolve_call` cannot resolve are treated as non-stores.
function is_store(block::Block, @nospecialize(s))
    resolved = resolve_call(block, s)
    if resolved === nothing
        return false
    end
    callee, _ = resolved
    return classify_memory_op(callee) == MEM_STORE
end
3 changes: 3 additions & 0 deletions src/compiler/passes/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,10 @@ function run_passes!(sci::StructuredIRCode)
rewrite_patterns!(sci, OPTIMIZATION_RULES; constants)

alias_result = alias_analysis_pass!(sci)

token_order_pass!(sci, alias_result)

licm_pass!(sci)

dce_pass!(sci)
end
5 changes: 5 additions & 0 deletions src/compiler/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ IRStructurizer.walk_uses!(f, node::JoinTokensNode) =
IRStructurizer.walk_uses!(f, ::TokenResultNode) = nothing
IRStructurizer.walk_uses!(f, ::MakeTokenNode) = nothing

# `operands` methods for the cuTile-specific IR nodes.

# JoinTokensNode: its operands are the entries of its `tokens` field.
function operands(::Block, node::JoinTokensNode)
    return node.tokens
end

# TokenResultNode: its single operand is the SSA value named by `mem_op_ssa`.
function operands(::Block, node::TokenResultNode)
    return Any[SSAValue(node.mem_op_ssa)]
end

# MakeTokenNode: no operands.
function operands(::Block, ::MakeTokenNode)
    return Any[]
end


"""
is_token_type(typ) -> Bool
Expand Down
4 changes: 3 additions & 1 deletion src/cuTile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ using IRStructurizer
using IRStructurizer: Block, ControlFlowOp, BlockArgument,
YieldOp, ContinueOp, BreakOp, ConditionOp,
IfOp, ForOp, WhileOp, LoopOp, Undef,
SourceLocation, source_location
SourceLocation
import IRStructurizer: operands

using Base: compilerbarrier, donotdelete
using Core: MethodInstance, CodeInfo, SSAValue, Argument, SlotNumber,
Expand Down Expand Up @@ -44,6 +45,7 @@ include("compiler/passes/canonicalize.jl")
include("compiler/passes/alias_analysis.jl")
include("compiler/passes/token_keys.jl")
include("compiler/passes/token_order.jl")
include("compiler/passes/licm.jl")
include("compiler/passes/dce.jl")
include("compiler/passes/pipeline.jl")
include("compiler/codegen/debug.jl")
Expand Down
32 changes: 32 additions & 0 deletions test/codegen/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,38 @@ end
end
end

@testset "loop-invariant load (manually hoisted)" begin
    # Test that a manually-hoisted loop-invariant load appears before the loop.
    # Pattern: Y[n, m] = X[n, m] * W[m], iterating over N-tiles.
    # W[bid_m] doesn't depend on the loop variable, so the user hoists it.
    # The kernel below already issues the W load ahead of the `for`, so this
    # exercises the hand-hoisted form of the pattern.
    spec2d = ct.ArraySpec{2}(16, true)
    spec1d = ct.ArraySpec{1}(16, true)
    @test @filecheck begin
        # Directives are listed in the order their matches are expected to appear.
        @check_label "entry"
        # W load must appear BEFORE the for loop
        @check "load_view_tko"
        @check "for %loopIdx in"
        # Inside the loop: only the X load
        # NOTE(review): there is no negative check here, so a second load inside
        # the loop would presumably still pass — confirm if "only" must be enforced.
        @check "load_view_tko"
        @check "mulf"
        @check "store_view_tko"
        # Kernel arguments, in order: X (2-D), W (1-D), Y (2-D), TILE_N (constant 1024).
        code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,1,spec1d},
                         ct.TileArray{Float32,2,spec2d}, ct.Constant{Int,1024}}) do X, W, Y, TILE_N
            bid_m = ct.bid(1)
            num_tiles = ct.num_tiles(X, 1, (TILE_N, 1))
            # Hoisted: W load before loop
            w = ct.load(W; index=bid_m, shape=(1,))
            for j in Int32(1):num_tiles
                # Only this load and the store depend on the loop index `j`.
                x = ct.load(X; index=(j, bid_m), shape=(TILE_N, 1),
                            padding_mode=ct.PaddingMode.Zero)
                y = x .* w
                ct.store(Y; index=(j, bid_m), tile=y)
            end
            return
        end
    end
end

#=========================================================================
Gather/Scatter Operations
=========================================================================#
Expand Down
Loading