diff --git a/src/compiler/codegen.jl b/src/compiler/codegen.jl index fc461698..57e333b1 100644 --- a/src/compiler/codegen.jl +++ b/src/compiler/codegen.jl @@ -5,6 +5,7 @@ include("codegen/passes/rewrite.jl") # @rewrite, rewrite_patterns! fr include("codegen/passes/alias_analysis.jl") # alias_analysis_pass! include("codegen/passes/token_order.jl") # token_order_pass! include("codegen/passes/dce.jl") # dce_pass! +include("codegen/passes/index_normalize.jl") # index_lower_pass! include("codegen/passes/pipeline.jl") # run_passes! include("codegen/kernel.jl") include("codegen/control_flow.jl") diff --git a/src/compiler/codegen/passes/index_normalize.jl b/src/compiler/codegen/passes/index_normalize.jl new file mode 100644 index 00000000..b0549508 --- /dev/null +++ b/src/compiler/codegen/passes/index_normalize.jl @@ -0,0 +1,76 @@ +# Index Lowering Pass +# +# Converts 1-based Julia indices to 0-based Tile IR indices for +# load_partition_view and store_partition_view. +# +# Uses the RFunc (callable RHS) extension to the @rewrite framework: +# for each index element in the indices tuple, inserts subi(elem, 1). 
+
+using Core: SSAValue
+
+#=============================================================================
+ Implementation
+=============================================================================#
+
+# RFunc callback shared by the load and store rules: inserts `subi(elem, 1)` for
+# every element of the matched index tuple and retargets the call to the result.
+# Returns `false` without touching the IR when the operand is not tuple-typed.
+function _lower_indices_rhs(sci, block, inst, match)
+    indices = match.bindings[:idx]
+    anchor = SSAValue(inst)
+    tuple_type = CC.widenconst(value_type(block, indices))
+    tuple_type <: Tuple || return false
+    # NOTE(review): fieldcount throws for tuple types of indefinite length
+    # (e.g. Vararg) — confirm the index tuple always has a fixed arity here.
+    arity = fieldcount(tuple_type)
+    arity == 0 && return true  # empty tuple: no elements to shift
+
+    # New statements go through insert_before!, anchored at the matched call, in order.
+    emit(ex, T) = SSAValue(insert_before!(block, anchor, ex, T))
+    shifted = map(1:arity) do k
+        elem_type = fieldtype(tuple_type, k)
+        field = emit(Expr(:call, GlobalRef(Core, :getfield), indices, k), elem_type)
+        emit(Expr(:call, GlobalRef(Intrinsics, :subi), field, one(elem_type)), elem_type)
+    end
+    zero_based = emit(Expr(:call, GlobalRef(Core, :tuple), shifted...), tuple_type)
+
+    # Swap each occurrence of the original tuple operand for the lowered one.
+    call = stmt(inst)
+    for slot in eachindex(call.args)
+        call.args[slot] === indices && (call.args[slot] = zero_based)
+    end
+    return true
+end
+
+#=============================================================================
+ Rewrite rules
+=============================================================================#
+
+# Rule for load_partition_view: shift the 1-based index tuple to 0-based.
+# Operand order: (pv, latency, allow_tma, indices_tuple).
+const LOAD_INDEX_LOWER = RewriteRule(
+    PCall(
+        :load_partition_view, [PBind(:pv), PBind(:lat), PBind(:tma), PBind(:idx)],
+    ),
+    RFunc(_lower_indices_rhs),
+)
+
+# Lower 1-based indices → 0-based for store_partition_view.
+# args: (pv, tile, latency, allow_tma, indices_tuple)
+const STORE_INDEX_LOWER = RewriteRule(
+    PCall(:store_partition_view,
+          [PBind(:pv), PBind(:tile), PBind(:lat), PBind(:tma), PBind(:idx)]),
+    RFunc(_lower_indices_rhs))
+const INDEX_LOWER_RULES = RewriteRule[LOAD_INDEX_LOWER, STORE_INDEX_LOWER]
+
+#=============================================================================
+ Driver
+=============================================================================#
+"""
+    index_lower_pass!(sci::StructuredIRCode)
+
+Rewrite `sci` with `INDEX_LOWER_RULES`, making load/store view indices 0-based.
+"""
+function index_lower_pass!(sci::StructuredIRCode)
+    return rewrite_patterns!(sci, INDEX_LOWER_RULES)
+end
diff --git a/src/compiler/codegen/passes/pipeline.jl b/src/compiler/codegen/passes/pipeline.jl
index 450706f0..7665056d 100644
--- a/src/compiler/codegen/passes/pipeline.jl
+++ b/src/compiler/codegen/passes/pipeline.jl
@@ -124,6 +124,7 @@ and subprogram compilation.
 function run_passes!(sci::StructuredIRCode)
     # Rewrite passes (order matters: normalize before optimize, SVE before FMA)
     normalize_pass!(sci)
+    index_lower_pass!(sci)
     scalar_view_elim_pass!(sci)
     fma_fusion_pass!(sci)
 
diff --git a/src/language/operations.jl b/src/language/operations.jl
index 1d8edce3..405f7032 100644
--- a/src/language/operations.jl
+++ b/src/language/operations.jl
@@ -134,7 +134,7 @@ tile = ct.load(arr, (bidx, bidy), (TM, TN); order=(2, 1))
     matched = _match_shape(Val(shape), Val(ndims(arr)))
     tv = Intrinsics.make_tensor_view(arr)
     pv = Intrinsics.make_partition_view(tv, matched, padding_mode, order)
-    tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...) 
.- One()) + tile = Intrinsics.load_partition_view(pv, latency, allow_tma, promote(index...)) reshape(tile, shape) end @@ -148,7 +148,7 @@ end tv = Intrinsics.make_tensor_view(arr) shape = ntuple(_ -> 1, Val(N)) pv = Intrinsics.make_partition_view(tv, shape, PaddingMode.Undetermined, nothing) - tile = Intrinsics.load_partition_view(pv, nothing, nothing, promote(indices...) .- One()) + tile = Intrinsics.load_partition_view(pv, nothing, nothing, promote(indices...)) Intrinsics.to_scalar(reshape(tile, ())) end @@ -203,7 +203,7 @@ Returns the stored tile (enables chaining and helps constant folding). latency::Union{Int, Nothing}=nothing, allow_tma::Union{Bool, Nothing}=nothing) where {T} reshaped = _reshape_to_rank(tile, Val(ndims(arr))) - _store_reshaped(arr, reshaped, order, latency, allow_tma, promote(index...) .- One()) + _store_reshaped(arr, reshaped, order, latency, allow_tma, promote(index...)) return tile # XXX: enables constant folding; remove when possible (see "constant folding" test) end