Skip to content

Commit 0135652

Browse files
[Codegen][AMDGPU] Pattern match transfer_read+transpose to global_transpose_load for RDNA4
Add TransferReadTransposeToGlobalTransposeLoad pattern to ROCDLLoadToTransposeLoad that matches: vector.transfer_read %src[row, col] : memref<..., global>, vector<1x8xT> vector.transpose %read, [1, 0] : vector<1x8xT> to vector<8x1xT> vector.transfer_write %transposed, %dst[n, k] : ..., workgroup and replaces it with amdgpu.global_transpose_load on gfx1200+ (RDNA4). The hardware 8x8 wave-level transpose (global_load_tr_b128 for bf16/f16) means each lane's result is written at a different N position within the K group. The write indices are corrected to produce contiguous K writes: N_new = N_base + K_single % N K_new = (K_single floordiv N) * N The pass now dispatches separately for gfx950 (LDS transpose, existing) and gfx1200+ (global transpose, new), so neither path interferes with the other. Part of: #24454 Co-authored-by: Claude Sonnet 4 (1M context) <noreply@anthropic.com> Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent ca84a96 commit 0135652

4 files changed

Lines changed: 329 additions & 12 deletions

File tree

compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLLoadToTransposeLoad.cpp

Lines changed: 229 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
99
#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
1010
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
11+
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
1112
#include "llvm/ADT/SetVector.h"
1213
#include "llvm/Support/Debug.h"
1314
#include "llvm/Support/DebugLog.h"
@@ -18,6 +19,7 @@
1819
#include "mlir/Dialect/Arith/IR/Arith.h"
1920
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2021
#include "mlir/Dialect/MemRef/IR/MemRef.h"
22+
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
2123
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2224
#include "mlir/IR/PatternMatch.h"
2325
#include "mlir/Pass/Pass.h"
@@ -34,6 +36,7 @@ namespace {
3436

3537
constexpr int64_t kTransposeLoadLaneGroupSize = 16;
3638
constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0);
39+
constexpr amdgpu::Chipset kGfx1200 = amdgpu::Chipset(12, 0, 0);
3740
constexpr llvm::StringLiteral kPassLocalHintAttr = "__pass_local_hint";
3841

3942
//===----------------------------------------------------------------------===//
@@ -751,6 +754,203 @@ struct TransferReadToTransposeLoad final
751754
}
752755
};
753756

757+
//===----------------------------------------------------------------------===//
758+
// Global Transfer Read + Transpose to Global Transpose Load Pattern
759+
//===----------------------------------------------------------------------===//
760+
761+
/// Returns true if the memref memory space is a flat global (not workgroup,
762+
/// not fat_raw_buffer). Accepts no memory space, gpu::Global, integer 0/1,
763+
/// or #hal.descriptor_type<storage_buffer>.
764+
static bool isGlobalMemorySpace(Attribute memSpace) {
765+
if (!memSpace) {
766+
return true;
767+
}
768+
if (auto gpuAttr = dyn_cast<gpu::AddressSpaceAttr>(memSpace)) {
769+
return gpuAttr.getValue() == gpu::AddressSpace::Global;
770+
}
771+
if (auto intAttr = dyn_cast<IntegerAttr>(memSpace)) {
772+
return intAttr.getInt() == 0 || intAttr.getInt() == 1;
773+
}
774+
// Accept HAL descriptor_type (flat global binding in IREE).
775+
if (isa<IREE::HAL::DescriptorTypeAttr>(memSpace)) {
776+
return true;
777+
}
778+
return false;
779+
}
780+
781+
/// Returns the required vector size (number of elements in the transposed
782+
/// dimension) for global_transpose_load given an element type, or nullopt if
783+
/// unsupported.
784+
static std::optional<int64_t>
785+
getGlobalTransposeLoadVectorSize(Type elementType) {
786+
unsigned bits = elementType.getIntOrFloatBitWidth();
787+
switch (bits) {
788+
case 8:
789+
case 16:
790+
return 8;
791+
default:
792+
return std::nullopt;
793+
}
794+
}
795+
796+
/// Matches:
797+
/// %read = vector.transfer_read %src[%row, %col] : memref<..., global>,
798+
/// vector<Nx1xT>
799+
/// %result = vector.transpose %read, [1, 0] : vector<Nx1xT> to vector<1xNxT>
800+
///
801+
/// and replaces with:
802+
/// %cast = memref.memory_space_cast %src : ... to memref<..., global>
803+
/// %tr = amdgpu.global_transpose_load %cast[%row, %col]
804+
/// : memref<..., global> -> vector<NxT>
805+
/// %result = vector.shape_cast %tr : vector<NxT> to vector<1xNxT>
806+
///
807+
/// Only fires on gfx1250+ targets.
808+
struct TransferReadTransposeToGlobalTransposeLoad final
809+
: OpRewritePattern<vector::TransposeOp> {
810+
using Base::Base;
811+
812+
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
813+
PatternRewriter &rewriter) const override {
814+
// Must be a simple [1, 0] transpose (2D).
815+
ArrayRef<int64_t> perm = transposeOp.getPermutation();
816+
if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0) {
817+
return rewriter.notifyMatchFailure(transposeOp,
818+
"not a 2D [1,0] transpose");
819+
}
820+
821+
// Input to transpose must be a transfer_read.
822+
auto transferOp =
823+
transposeOp.getVector().getDefiningOp<vector::TransferReadOp>();
824+
if (!transferOp) {
825+
return rewriter.notifyMatchFailure(transposeOp,
826+
"not fed by transfer_read");
827+
}
828+
829+
// Source memref must be flat global (not workgroup, not fat_raw_buffer).
830+
auto memrefType = cast<MemRefType>(transferOp.getBase().getType());
831+
if (!isGlobalMemorySpace(memrefType.getMemorySpace())) {
832+
return rewriter.notifyMatchFailure(transposeOp,
833+
"source is not flat global memory");
834+
}
835+
836+
// With (K-outer, N-inner) iteration in the linalg.generic copy, N is the
837+
// vectorized inner dimension. The transfer_read reads 8 contiguous
838+
// N-elements per lane, giving vector<1xNxT> (1 K row, N contiguous cols).
839+
// vector<1x8xT> [K_dim=1, N_dim=8]
840+
// After the software transpose [1,0] this becomes vector<8x1xT> [N,K],
841+
// which is written to alloc_8[N_base, K_single] along N (stride-8 write).
842+
// global_load_tr replaces the N read: each of 8 consecutive lanes provides
843+
// its own K-row address, the hardware transposes (8×8 block transpose),
844+
// and the result is written with the corrected stride-8 subview.
845+
VectorType readType = transferOp.getVectorType();
846+
if (readType.getRank() != 2 || readType.getDimSize(0) != 1) {
847+
return rewriter.notifyMatchFailure(transposeOp,
848+
"expected vector<1xNxT> from read");
849+
}
850+
851+
// Check element type and expected N (number of contiguous N elements read).
852+
Type elemType = readType.getElementType();
853+
std::optional<int64_t> expectedN =
854+
getGlobalTransposeLoadVectorSize(elemType);
855+
if (!expectedN) {
856+
return rewriter.notifyMatchFailure(transposeOp,
857+
"unsupported element type");
858+
}
859+
if (readType.getDimSize(1) != *expectedN) {
860+
return rewriter.notifyMatchFailure(
861+
transposeOp,
862+
"vector inner dim does not match global_transpose_load size");
863+
}
864+
865+
// Must be in_bounds.
866+
ArrayAttr inBounds = transferOp.getInBounds();
867+
if (!inBounds || !llvm::all_of(inBounds.getAsRange<BoolAttr>(),
868+
[](BoolAttr b) { return b.getValue(); })) {
869+
return rewriter.notifyMatchFailure(transposeOp,
870+
"transfer_read not in_bounds");
871+
}
872+
873+
// Permutation map must be identity (no broadcast).
874+
if (!transferOp.getPermutationMap().isIdentity()) {
875+
return rewriter.notifyMatchFailure(transposeOp,
876+
"non-identity permutation map");
877+
}
878+
879+
Location loc = transposeOp.getLoc();
880+
881+
// Cast memref to gpu::AddressSpace::Global if needed so that
882+
// amdgpu.global_transpose_load verifier is satisfied.
883+
Value src = transferOp.getBase();
884+
if (memrefType.getMemorySpace()) {
885+
auto globalSpace = gpu::AddressSpaceAttr::get(rewriter.getContext(),
886+
gpu::AddressSpace::Global);
887+
auto globalMemrefType =
888+
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
889+
memrefType.getLayout(), globalSpace);
890+
src = memref::MemorySpaceCastOp::create(rewriter, loc, globalMemrefType,
891+
src);
892+
}
893+
894+
// The transpose result must have exactly one use: a transfer_write to
895+
// workgroup memory at [N_base, K_single] with vector<8x1>.
896+
// With the K-inner tiling (UseGlobalTransposeLoadAttr, (N-outer, K-inner)
897+
// linalg.generic), global_load_tr's 8-lane wave-level 8×8 transpose means
898+
// lane K_single's result[i] = B[K_group_base+i, N_base + K_single%N].
899+
// This should be written to alloc_8[N_base + K_single%N, K_group_base..N-1]
900+
// as vector<1x8> (contiguous K) — no subview needed.
901+
if (!transposeOp->hasOneUse()) {
902+
return rewriter.notifyMatchFailure(transposeOp,
903+
"transpose result has multiple uses");
904+
}
905+
auto writeOp =
906+
dyn_cast<vector::TransferWriteOp>(*transposeOp->user_begin());
907+
if (!writeOp) {
908+
return rewriter.notifyMatchFailure(
909+
transposeOp, "transpose not consumed by transfer_write");
910+
}
911+
auto writeDst = cast<MemRefType>(writeOp.getBase().getType());
912+
if (!hasSharedMemoryAddressSpace(writeDst)) {
913+
return rewriter.notifyMatchFailure(
914+
transposeOp, "write destination is not workgroup memory");
915+
}
916+
917+
// Emit amdgpu.global_transpose_load.
918+
int64_t N = *expectedN;
919+
auto resultVecType = VectorType::get({N}, elemType);
920+
auto trLoad = amdgpu::GlobalTransposeLoadOp::create(
921+
rewriter, loc, resultVecType, src, transferOp.getIndices());
922+
923+
// Compute corrected write indices for contiguous K writes:
924+
// N_new = N_base + K_single % N (lane's N position within N_base
925+
// group) K_new = (K_single // N) * N (K-group base, aligned to N)
926+
// Write vector<1xNxT> at [N_new, K_new] → alloc_8[N_new, K_new..K_new+N-1]
927+
// This is contiguous K in alloc_8[N, K] (K is inner) — no subview needed.
928+
ValueRange writeIndices = writeOp.getIndices();
929+
assert(writeIndices.size() == 2 && "expected 2D write");
930+
Value nBase = writeIndices[0]; // N_group * N
931+
Value kSingle = writeIndices[1]; // K lane value (0..K_total-1)
932+
933+
AffineExpr dn = rewriter.getAffineDimExpr(0);
934+
AffineExpr dk = rewriter.getAffineDimExpr(1);
935+
AffineMap nNewMap = AffineMap::get(2, 0, dn + dk % N);
936+
AffineMap kNewMap = AffineMap::get(2, 0, (dk.floorDiv(N)) * N);
937+
938+
Value nNew = affine::AffineApplyOp::create(rewriter, loc, nNewMap,
939+
ValueRange{nBase, kSingle});
940+
Value kNew = affine::AffineApplyOp::create(rewriter, loc, kNewMap,
941+
ValueRange{nBase, kSingle});
942+
943+
VectorType writeVecType = VectorType::get({1, N}, elemType);
944+
Value castResult = vector::ShapeCastOp::create(rewriter, loc, writeVecType,
945+
trLoad.getResult());
946+
vector::TransferWriteOp::create(rewriter, loc, castResult,
947+
writeOp.getBase(), ValueRange{nNew, kNew},
948+
SmallVector<bool>{true, true});
949+
rewriter.eraseOp(writeOp);
950+
return success();
951+
}
952+
};
953+
754954
//===----------------------------------------------------------------------===//
755955
// Pass
756956
//===----------------------------------------------------------------------===//
@@ -768,36 +968,53 @@ struct ROCDLLoadToTransposeLoadPass final
768968
void runOnOperation() override {
769969
FunctionOpInterface funcOp = getOperation();
770970

771-
// Check if target supports transpose_load (currently gfx950 only)
772971
IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
773972
if (!target) {
774973
return;
775974
}
776975
FailureOr<amdgpu::Chipset> chipset =
777976
amdgpu::Chipset::parse(target.getArch());
778-
if (failed(chipset) || *chipset != kGfx950) {
977+
if (failed(chipset)) {
779978
return;
780979
}
781980

782-
IRRewriter rewriter(funcOp.getContext());
981+
bool isGfx950 = (*chipset == kGfx950);
982+
bool isRDNA4 = chipset->majorVersion == 12 && chipset->minorVersion <= 1;
783983

784-
// Phase 1: Seed hints on gpu.thread_id ops
785-
std::optional<SmallVector<int64_t>> workgroupSize =
786-
getWorkgroupSize(funcOp);
787-
if (workgroupSize) {
788-
seedThreadIdHints(funcOp, rewriter, *workgroupSize);
984+
if (!isGfx950 && !isRDNA4) {
985+
return;
789986
}
790987

791-
// Phase 2: Propagate hints and lower transfer_reads via greedy patterns
988+
IRRewriter rewriter(funcOp.getContext());
989+
792990
RewritePatternSet patterns(funcOp.getContext());
793-
patterns.add<PropagateHintThroughDelinearize, TransferReadToTransposeLoad>(
794-
funcOp.getContext());
991+
992+
if (isGfx950) {
993+
// Seed hints on gpu.thread_id ops for LDS transpose load.
994+
std::optional<SmallVector<int64_t>> workgroupSize =
995+
getWorkgroupSize(funcOp);
996+
if (workgroupSize) {
997+
seedThreadIdHints(funcOp, rewriter, *workgroupSize);
998+
}
999+
// Propagate hints and lower transfer_reads via greedy patterns.
1000+
patterns
1001+
.add<PropagateHintThroughDelinearize, TransferReadToTransposeLoad>(
1002+
funcOp.getContext());
1003+
}
1004+
1005+
if (isRDNA4) {
1006+
// Global memory transpose load: match vector<1x8> transfer_read +
1007+
// transpose [1,0] → vector<8x1> from flat global memory and replace
1008+
// with amdgpu.global_transpose_load.
1009+
patterns.add<TransferReadTransposeToGlobalTransposeLoad>(
1010+
funcOp.getContext());
1011+
}
7951012

7961013
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
7971014
return signalPassFailure();
7981015
}
7991016

800-
// Phase 3: Remove pass-local index_hint ops for IR cleanliness
1017+
// Remove pass-local index_hint ops for IR cleanliness (gfx950 path).
8011018
funcOp.walk([&](IREE::Codegen::IndexHintOp hintOp) {
8021019
if (hintOp->hasAttr(kPassLocalHintAttr)) {
8031020
rewriter.replaceAllUsesWith(hintOp.getResult(), hintOp.getOperand());

compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ iree_lit_test_suite(
6161
"reduction_pipeline_rocm.mlir",
6262
"reduction_pipeline_softmax_rocm.mlir",
6363
"reuse_shared_memory_allocs.mlir",
64+
"rocdl_global_transpose_load.mlir",
6465
"rocdl_load_to_transpose_load.mlir",
6566
"rocdl_pipeline_test.mlir",
6667
"sort_pipeline_test.mlir",

compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ iree_lit_test_suite(
5656
"reduction_pipeline_rocm.mlir"
5757
"reduction_pipeline_softmax_rocm.mlir"
5858
"reuse_shared_memory_allocs.mlir"
59+
"rocdl_global_transpose_load.mlir"
5960
"rocdl_load_to_transpose_load.mlir"
6061
"rocdl_pipeline_test.mlir"
6162
"sort_pipeline_test.mlir"

0 commit comments

Comments
 (0)