88#include " iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
99#include " iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
1010#include " iree/compiler/Codegen/Utils/GPUUtils.h"
11+ #include " iree/compiler/Dialect/HAL/IR/HALTypes.h"
1112#include " llvm/ADT/SetVector.h"
1213#include " llvm/Support/Debug.h"
1314#include " llvm/Support/DebugLog.h"
1819#include " mlir/Dialect/Arith/IR/Arith.h"
1920#include " mlir/Dialect/GPU/IR/GPUDialect.h"
2021#include " mlir/Dialect/MemRef/IR/MemRef.h"
22+ #include " mlir/Dialect/MemRef/Transforms/Transforms.h"
2123#include " mlir/Dialect/Vector/IR/VectorOps.h"
2224#include " mlir/IR/PatternMatch.h"
2325#include " mlir/Pass/Pass.h"
@@ -34,6 +36,7 @@ namespace {
3436
3537constexpr int64_t kTransposeLoadLaneGroupSize = 16 ;
3638constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9 , 5 , 0 );
39+ constexpr amdgpu::Chipset kGfx1200 = amdgpu::Chipset(12 , 0 , 0 );
3740constexpr llvm::StringLiteral kPassLocalHintAttr = " __pass_local_hint" ;
3841
3942// ===----------------------------------------------------------------------===//
@@ -751,6 +754,203 @@ struct TransferReadToTransposeLoad final
751754 }
752755};
753756
757+ // ===----------------------------------------------------------------------===//
758+ // Global Transfer Read + Transpose to Global Transpose Load Pattern
759+ // ===----------------------------------------------------------------------===//
760+
761+ // / Returns true if the memref memory space is a flat global (not workgroup,
762+ // / not fat_raw_buffer). Accepts no memory space, gpu::Global, integer 0/1,
763+ // / or #hal.descriptor_type<storage_buffer>.
764+ static bool isGlobalMemorySpace (Attribute memSpace) {
765+ if (!memSpace) {
766+ return true ;
767+ }
768+ if (auto gpuAttr = dyn_cast<gpu::AddressSpaceAttr>(memSpace)) {
769+ return gpuAttr.getValue () == gpu::AddressSpace::Global;
770+ }
771+ if (auto intAttr = dyn_cast<IntegerAttr>(memSpace)) {
772+ return intAttr.getInt () == 0 || intAttr.getInt () == 1 ;
773+ }
774+ // Accept HAL descriptor_type (flat global binding in IREE).
775+ if (isa<IREE::HAL::DescriptorTypeAttr>(memSpace)) {
776+ return true ;
777+ }
778+ return false ;
779+ }
780+
781+ // / Returns the required vector size (number of elements in the transposed
782+ // / dimension) for global_transpose_load given an element type, or nullopt if
783+ // / unsupported.
784+ static std::optional<int64_t >
785+ getGlobalTransposeLoadVectorSize (Type elementType) {
786+ unsigned bits = elementType.getIntOrFloatBitWidth ();
787+ switch (bits) {
788+ case 8 :
789+ case 16 :
790+ return 8 ;
791+ default :
792+ return std::nullopt ;
793+ }
794+ }
795+
796+ // / Matches:
797+ // / %read = vector.transfer_read %src[%row, %col] : memref<..., global>,
798+ // / vector<Nx1xT>
799+ // / %result = vector.transpose %read, [1, 0] : vector<Nx1xT> to vector<1xNxT>
800+ // /
801+ // / and replaces with:
802+ // / %cast = memref.memory_space_cast %src : ... to memref<..., global>
803+ // / %tr = amdgpu.global_transpose_load %cast[%row, %col]
804+ // / : memref<..., global> -> vector<NxT>
805+ // / %result = vector.shape_cast %tr : vector<NxT> to vector<1xNxT>
806+ // /
807+ // / Only fires on gfx1250+ targets.
808+ struct TransferReadTransposeToGlobalTransposeLoad final
809+ : OpRewritePattern<vector::TransposeOp> {
810+ using Base::Base;
811+
812+ LogicalResult matchAndRewrite (vector::TransposeOp transposeOp,
813+ PatternRewriter &rewriter) const override {
814+ // Must be a simple [1, 0] transpose (2D).
815+ ArrayRef<int64_t > perm = transposeOp.getPermutation ();
816+ if (perm.size () != 2 || perm[0 ] != 1 || perm[1 ] != 0 ) {
817+ return rewriter.notifyMatchFailure (transposeOp,
818+ " not a 2D [1,0] transpose" );
819+ }
820+
821+ // Input to transpose must be a transfer_read.
822+ auto transferOp =
823+ transposeOp.getVector ().getDefiningOp <vector::TransferReadOp>();
824+ if (!transferOp) {
825+ return rewriter.notifyMatchFailure (transposeOp,
826+ " not fed by transfer_read" );
827+ }
828+
829+ // Source memref must be flat global (not workgroup, not fat_raw_buffer).
830+ auto memrefType = cast<MemRefType>(transferOp.getBase ().getType ());
831+ if (!isGlobalMemorySpace (memrefType.getMemorySpace ())) {
832+ return rewriter.notifyMatchFailure (transposeOp,
833+ " source is not flat global memory" );
834+ }
835+
836+ // With (K-outer, N-inner) iteration in the linalg.generic copy, N is the
837+ // vectorized inner dimension. The transfer_read reads 8 contiguous
838+ // N-elements per lane, giving vector<1xNxT> (1 K row, N contiguous cols).
839+ // vector<1x8xT> [K_dim=1, N_dim=8]
840+ // After the software transpose [1,0] this becomes vector<8x1xT> [N,K],
841+ // which is written to alloc_8[N_base, K_single] along N (stride-8 write).
842+ // global_load_tr replaces the N read: each of 8 consecutive lanes provides
843+ // its own K-row address, the hardware transposes (8×8 block transpose),
844+ // and the result is written with the corrected stride-8 subview.
845+ VectorType readType = transferOp.getVectorType ();
846+ if (readType.getRank () != 2 || readType.getDimSize (0 ) != 1 ) {
847+ return rewriter.notifyMatchFailure (transposeOp,
848+ " expected vector<1xNxT> from read" );
849+ }
850+
851+ // Check element type and expected N (number of contiguous N elements read).
852+ Type elemType = readType.getElementType ();
853+ std::optional<int64_t > expectedN =
854+ getGlobalTransposeLoadVectorSize (elemType);
855+ if (!expectedN) {
856+ return rewriter.notifyMatchFailure (transposeOp,
857+ " unsupported element type" );
858+ }
859+ if (readType.getDimSize (1 ) != *expectedN) {
860+ return rewriter.notifyMatchFailure (
861+ transposeOp,
862+ " vector inner dim does not match global_transpose_load size" );
863+ }
864+
865+ // Must be in_bounds.
866+ ArrayAttr inBounds = transferOp.getInBounds ();
867+ if (!inBounds || !llvm::all_of (inBounds.getAsRange <BoolAttr>(),
868+ [](BoolAttr b) { return b.getValue (); })) {
869+ return rewriter.notifyMatchFailure (transposeOp,
870+ " transfer_read not in_bounds" );
871+ }
872+
873+ // Permutation map must be identity (no broadcast).
874+ if (!transferOp.getPermutationMap ().isIdentity ()) {
875+ return rewriter.notifyMatchFailure (transposeOp,
876+ " non-identity permutation map" );
877+ }
878+
879+ Location loc = transposeOp.getLoc ();
880+
881+ // Cast memref to gpu::AddressSpace::Global if needed so that
882+ // amdgpu.global_transpose_load verifier is satisfied.
883+ Value src = transferOp.getBase ();
884+ if (memrefType.getMemorySpace ()) {
885+ auto globalSpace = gpu::AddressSpaceAttr::get (rewriter.getContext (),
886+ gpu::AddressSpace::Global);
887+ auto globalMemrefType =
888+ MemRefType::get (memrefType.getShape (), memrefType.getElementType (),
889+ memrefType.getLayout (), globalSpace);
890+ src = memref::MemorySpaceCastOp::create (rewriter, loc, globalMemrefType,
891+ src);
892+ }
893+
894+ // The transpose result must have exactly one use: a transfer_write to
895+ // workgroup memory at [N_base, K_single] with vector<8x1>.
896+ // With the K-inner tiling (UseGlobalTransposeLoadAttr, (N-outer, K-inner)
897+ // linalg.generic), global_load_tr's 8-lane wave-level 8×8 transpose means
898+ // lane K_single's result[i] = B[K_group_base+i, N_base + K_single%N].
899+ // This should be written to alloc_8[N_base + K_single%N, K_group_base..N-1]
900+ // as vector<1x8> (contiguous K) — no subview needed.
901+ if (!transposeOp->hasOneUse ()) {
902+ return rewriter.notifyMatchFailure (transposeOp,
903+ " transpose result has multiple uses" );
904+ }
905+ auto writeOp =
906+ dyn_cast<vector::TransferWriteOp>(*transposeOp->user_begin ());
907+ if (!writeOp) {
908+ return rewriter.notifyMatchFailure (
909+ transposeOp, " transpose not consumed by transfer_write" );
910+ }
911+ auto writeDst = cast<MemRefType>(writeOp.getBase ().getType ());
912+ if (!hasSharedMemoryAddressSpace (writeDst)) {
913+ return rewriter.notifyMatchFailure (
914+ transposeOp, " write destination is not workgroup memory" );
915+ }
916+
917+ // Emit amdgpu.global_transpose_load.
918+ int64_t N = *expectedN;
919+ auto resultVecType = VectorType::get ({N}, elemType);
920+ auto trLoad = amdgpu::GlobalTransposeLoadOp::create (
921+ rewriter, loc, resultVecType, src, transferOp.getIndices ());
922+
923+ // Compute corrected write indices for contiguous K writes:
924+ // N_new = N_base + K_single % N (lane's N position within N_base
925+ // group) K_new = (K_single // N) * N (K-group base, aligned to N)
926+ // Write vector<1xNxT> at [N_new, K_new] → alloc_8[N_new, K_new..K_new+N-1]
927+ // This is contiguous K in alloc_8[N, K] (K is inner) — no subview needed.
928+ ValueRange writeIndices = writeOp.getIndices ();
929+ assert (writeIndices.size () == 2 && " expected 2D write" );
930+ Value nBase = writeIndices[0 ]; // N_group * N
931+ Value kSingle = writeIndices[1 ]; // K lane value (0..K_total-1)
932+
933+ AffineExpr dn = rewriter.getAffineDimExpr (0 );
934+ AffineExpr dk = rewriter.getAffineDimExpr (1 );
935+ AffineMap nNewMap = AffineMap::get (2 , 0 , dn + dk % N);
936+ AffineMap kNewMap = AffineMap::get (2 , 0 , (dk.floorDiv (N)) * N);
937+
938+ Value nNew = affine::AffineApplyOp::create (rewriter, loc, nNewMap,
939+ ValueRange{nBase, kSingle });
940+ Value kNew = affine::AffineApplyOp::create (rewriter, loc, kNewMap ,
941+ ValueRange{nBase, kSingle });
942+
943+ VectorType writeVecType = VectorType::get ({1 , N}, elemType);
944+ Value castResult = vector::ShapeCastOp::create (rewriter, loc, writeVecType,
945+ trLoad.getResult ());
946+ vector::TransferWriteOp::create (rewriter, loc, castResult,
947+ writeOp.getBase (), ValueRange{nNew, kNew },
948+ SmallVector<bool >{true , true });
949+ rewriter.eraseOp (writeOp);
950+ return success ();
951+ }
952+ };
953+
754954// ===----------------------------------------------------------------------===//
755955// Pass
756956// ===----------------------------------------------------------------------===//
@@ -768,36 +968,53 @@ struct ROCDLLoadToTransposeLoadPass final
768968 void runOnOperation () override {
769969 FunctionOpInterface funcOp = getOperation ();
770970
771- // Check if target supports transpose_load (currently gfx950 only)
772971 IREE::GPU::TargetAttr target = getGPUTargetAttr (funcOp);
773972 if (!target) {
774973 return ;
775974 }
776975 FailureOr<amdgpu::Chipset> chipset =
777976 amdgpu::Chipset::parse (target.getArch ());
778- if (failed (chipset) || *chipset != kGfx950 ) {
977+ if (failed (chipset)) {
779978 return ;
780979 }
781980
782- IRRewriter rewriter (funcOp.getContext ());
981+ bool isGfx950 = (*chipset == kGfx950 );
982+ bool isRDNA4 = chipset->majorVersion == 12 && chipset->minorVersion <= 1 ;
783983
784- // Phase 1: Seed hints on gpu.thread_id ops
785- std::optional<SmallVector<int64_t >> workgroupSize =
786- getWorkgroupSize (funcOp);
787- if (workgroupSize) {
788- seedThreadIdHints (funcOp, rewriter, *workgroupSize);
984+ if (!isGfx950 && !isRDNA4) {
985+ return ;
789986 }
790987
791- // Phase 2: Propagate hints and lower transfer_reads via greedy patterns
988+ IRRewriter rewriter (funcOp.getContext ());
989+
792990 RewritePatternSet patterns (funcOp.getContext ());
793- patterns.add <PropagateHintThroughDelinearize, TransferReadToTransposeLoad>(
794- funcOp.getContext ());
991+
992+ if (isGfx950) {
993+ // Seed hints on gpu.thread_id ops for LDS transpose load.
994+ std::optional<SmallVector<int64_t >> workgroupSize =
995+ getWorkgroupSize (funcOp);
996+ if (workgroupSize) {
997+ seedThreadIdHints (funcOp, rewriter, *workgroupSize);
998+ }
999+ // Propagate hints and lower transfer_reads via greedy patterns.
1000+ patterns
1001+ .add <PropagateHintThroughDelinearize, TransferReadToTransposeLoad>(
1002+ funcOp.getContext ());
1003+ }
1004+
1005+ if (isRDNA4) {
1006+ // Global memory transpose load: match vector<1x8> transfer_read +
1007+ // transpose [1,0] → vector<8x1> from flat global memory and replace
1008+ // with amdgpu.global_transpose_load.
1009+ patterns.add <TransferReadTransposeToGlobalTransposeLoad>(
1010+ funcOp.getContext ());
1011+ }
7951012
7961013 if (failed (applyPatternsGreedily (funcOp, std::move (patterns)))) {
7971014 return signalPassFailure ();
7981015 }
7991016
800- // Phase 3: Remove pass-local index_hint ops for IR cleanliness
1017+ // Remove pass-local index_hint ops for IR cleanliness (gfx950 path).
8011018 funcOp.walk ([&](IREE::Codegen::IndexHintOp hintOp) {
8021019 if (hintOp->hasAttr (kPassLocalHintAttr )) {
8031020 rewriter.replaceAllUsesWith (hintOp.getResult (), hintOp.getOperand ());
0 commit comments