Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/compiler/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ include("codegen/passes/rewrite.jl") # @rewrite, rewrite_patterns! fr
include("codegen/passes/alias_analysis.jl") # alias_analysis_pass!
include("codegen/passes/token_order.jl") # token_order_pass!
include("codegen/passes/dce.jl") # dce_pass!
include("codegen/passes/index_normalize.jl") # index_lower_pass!
include("codegen/passes/pipeline.jl") # run_passes!
include("codegen/kernel.jl")
include("codegen/control_flow.jl")
Expand Down
76 changes: 76 additions & 0 deletions src/compiler/codegen/passes/index_normalize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Index Lowering Pass
#
# Converts 1-based Julia indices to 0-based Tile IR indices for
# load_partition_view and store_partition_view.
#
# Uses the RFunc (callable RHS) extension to the @rewrite framework:
# for each index element in the indices tuple, inserts subi(elem, 1).

using Core: SSAValue

#=============================================================================
Implementation
=============================================================================#

# Callable right-hand side shared by the load/store index-lowering rules.
#
# Rewrites the matched call in place: the operand bound to `:idx` is replaced
# by a freshly built tuple whose elements are `subi(elem, 1)` for every element
# of the original indices tuple. All new statements are inserted immediately
# before the matched instruction.
#
# Returns `false` (nothing changed) when the indices operand is not
# tuple-typed, and `true` otherwise — including the empty-tuple case, where
# there is nothing to rewrite.
function _lower_indices_rhs(sci, block, inst, match)
    indices = match.bindings[:idx]
    anchor = SSAValue(inst)

    # The widened tuple type supplies both the arity and each element's type.
    # NOTE(review): `fieldcount` throws for tuple types without a definite
    # length; presumably the indices tuple always has fixed arity here — confirm.
    tuple_type = CC.widenconst(value_type(block, indices))
    if !(tuple_type <: Tuple)
        return false
    end
    arity = fieldcount(tuple_type)
    if arity == 0
        return true
    end

    # Every new statement goes right before the matched instruction, so the
    # order of `emit` calls is the order of the generated statements.
    emit(ex, T) = SSAValue(insert_before!(block, anchor, ex, T))

    # Per element: getfield(indices, k), then subi(elem, 1) typed as the element.
    shifted = Any[
        begin
            elem_type = fieldtype(tuple_type, k)
            field = emit(Expr(:call, GlobalRef(Core, :getfield), indices, k), elem_type)
            emit(Expr(:call, GlobalRef(Intrinsics, :subi), field, one(elem_type)), elem_type)
        end for k in 1:arity
    ]

    # Re-pack the shifted elements into a tuple of the original tuple type.
    shifted_tuple = emit(Expr(:call, GlobalRef(Core, :tuple), shifted...), tuple_type)

    # Swap every operand identical to the old indices value for the new tuple.
    call_args = stmt(inst).args
    for pos in eachindex(call_args)
        if call_args[pos] === indices
            call_args[pos] = shifted_tuple
        end
    end

    return true
end

#=============================================================================
Rewrite rules
=============================================================================#

# Rewrite rule: 1-based → 0-based indices on `load_partition_view` calls.
# The pattern binds the call's four operands in order; the match is then
# handed to `_lower_indices_rhs` through `RFunc`.
const LOAD_INDEX_LOWER = RewriteRule(
    PCall(
        :load_partition_view,
        [
            PBind(:pv),   # partition view
            PBind(:lat),  # latency
            PBind(:tma),  # allow_tma
            PBind(:idx),  # indices tuple
        ],
    ),
    RFunc(_lower_indices_rhs),
)

# Rewrite rule: 1-based → 0-based indices on `store_partition_view` calls.
# The pattern binds the call's five operands in order; the match is then
# handed to `_lower_indices_rhs` through `RFunc`.
const STORE_INDEX_LOWER = RewriteRule(
    PCall(
        :store_partition_view,
        [
            PBind(:pv),    # partition view
            PBind(:tile),  # tile
            PBind(:lat),   # latency
            PBind(:tma),   # allow_tma
            PBind(:idx),   # indices tuple
        ],
    ),
    RFunc(_lower_indices_rhs),
)

# Rule set handed to `rewrite_patterns!` by the index-lowering pass.
const INDEX_LOWER_RULES = RewriteRule[
    LOAD_INDEX_LOWER,
    STORE_INDEX_LOWER,
]

#=============================================================================
Driver
=============================================================================#

"""
    index_lower_pass!(sci::StructuredIRCode)

Apply the rewrite rules in `INDEX_LOWER_RULES` to `sci`, converting the 1-based Julia
indices of load/store partition-view ops into 0-based Tile IR indices.

Returns the result of `rewrite_patterns!`.
"""
function index_lower_pass!(sci::StructuredIRCode)
    return rewrite_patterns!(sci, INDEX_LOWER_RULES)
end
1 change: 1 addition & 0 deletions src/compiler/codegen/passes/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ and subprogram compilation.
function run_passes!(sci::StructuredIRCode)
# Rewrite passes (order matters: normalize before optimize, SVE before FMA)
normalize_pass!(sci)
index_lower_pass!(sci)
scalar_view_elim_pass!(sci)
fma_fusion_pass!(sci)

Expand Down
6 changes: 3 additions & 3 deletions src/language/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ tile = ct.load(arr, (bidx, bidy), (TM, TN); order=(2, 1))
matched = _match_shape(Val(shape), Val(ndims(arr)))
tv = Intrinsics.make_tensor_view(arr)
pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) .- One())
tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...))
reshape(tile, shape)
end

Expand All @@ -148,7 +148,7 @@ end
tv = Intrinsics.make_tensor_view(arr)
shape = ntuple(_ -> 1, Val(N))
pv = Intrinsics.make_partition_view(tv, shape, PaddingMode.Undetermined, nothing)
tile = Intrinsics.load_partition_view(pv, nothing, nothing, promote(indices...) .- One())
tile = Intrinsics.load_partition_view(pv, nothing, nothing, promote(indices...))
Intrinsics.to_scalar(reshape(tile, ()))
end

Expand Down Expand Up @@ -203,7 +203,7 @@ Returns the stored tile (enables chaining and helps constant folding).
latency::Union{Int, Nothing}=nothing,
allow_tma::Union{Bool, Nothing}=nothing) where {T}
reshaped = _reshape_to_rank(tile, Val(ndims(arr)))
_store_reshaped(arr, reshaped, order, latency, allow_tma, promote(index...) .- One())
_store_reshaped(arr, reshaped, order, latency, allow_tma, promote(index...))
return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
end

Expand Down
Loading