From 8638299c285f7090ade3de8150c74ee86d4207b9 Mon Sep 17 00:00:00 2001 From: Aleksei Nurmukhametov Date: Thu, 9 Apr 2026 10:17:13 -0500 Subject: [PATCH] [ROCm] Add TDM (Tensor Descriptor Memory) support for gfx1250 gfx1250 introduces device-side tensor descriptors (TDM), constructed on-device via tt.make_tensor_descriptor. Wire TDM through the XLA Triton pipeline so xtile.extract/insert lower to descriptor-based loads on TDM-capable hardware, with a pointer-based fallback for TDM-incompatible tile shapes. --- .../codegen/triton/compilation_pipeline.cc | 6 +- .../triton/compilation_pipeline_rocm.cc | 11 +- .../gpu/codegen/triton/transforms/passes.h | 2 + .../gpu/codegen/triton/transforms/passes.td | 2 + .../transforms/tests/triton_pipeline_tdm.mlir | 28 +++ .../triton_xla_extract_insert_to_triton.mlir | 85 ++++++++++ ...riton_xla_extract_insert_to_triton_pass.cc | 160 +++++++++++++++++- 7 files changed, 285 insertions(+), 9 deletions(-) create mode 100644 xla/backends/gpu/codegen/triton/transforms/tests/triton_pipeline_tdm.mlir diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline.cc b/xla/backends/gpu/codegen/triton/compilation_pipeline.cc index 327d7b617da47..4209a0cb46d9b 100644 --- a/xla/backends/gpu/codegen/triton/compilation_pipeline.cc +++ b/xla/backends/gpu/codegen/triton/compilation_pipeline.cc @@ -50,6 +50,9 @@ void CreateTritonXlaPipeline( auto* cuda_cc = gpu_cc.cuda_compute_capability(); bool is_at_least_hopper = cuda_cc != nullptr && cuda_cc->IsAtLeastHopper(); + auto* rocm_cc = gpu_cc.rocm_compute_capability(); + bool rocm_supports_tdm = rocm_cc != nullptr && rocm_cc->has_tdm_support(); + if (rewrite_int4) { pm->addPass(mlir::triton::xla::CreateInt4ToPackedInt4RewritePass( /*enable_bf16x2=*/is_at_least_hopper)); @@ -59,7 +62,8 @@ void CreateTritonXlaPipeline( pm->addPass(CreateInsertPDLPass()); } pm->addPass(mlir::triton::xla::CreateTritonXLAExtractInsertToTritonPass( - /*allow_tma=*/allow_tma && is_at_least_hopper, num_stages)); + 
/*allow_tma=*/allow_tma && is_at_least_hopper, + /*allow_tdm=*/rocm_supports_tdm, num_stages)); if (enable_pdl) { pm->addPass(emitters::CreateLowerPdlWaitPass()); } diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc index 4d2108a490c6d..2bbd610db2dac 100644 --- a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc +++ b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc @@ -42,10 +42,12 @@ namespace mt = ::mlir::triton; // Based on make_ttir() in // @triton//:third_party/amd/backend/compiler.py -static void MakeTTIR(mlir::OpPassManager* pm) { +static void MakeTTIR(mlir::OpPassManager* pm, + const stream_executor::RocmComputeCapability& rocm_cc) { pm->addPass(mlir::createInlinerPass()); - // if not amd.supports_tdm(arch) - // pm->addPass(mt::createTritonRewriteTensorDescriptorToPointer()); + if (!rocm_cc.has_tdm_support()) { + pm->addPass(mt::createTritonRewriteTensorDescriptorToPointer()); + } pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mt::createTritonCombineOps()); pm->addPass(mt::createTritonReorderBroadcast()); @@ -102,6 +104,7 @@ static void MakeTTGIR(mlir::OpPassManager* pm, bool use_block_pingpong = is_pingpong_schedule_enabled(rocm_cc, use_async_copy); + pm->addPass(mlir::createTritonAMDGPUOptimizeDescriptorEncoding()); pm->addPass(mlir::createTritonAMDGPUScheduleLoops({num_stages})); pm->addPass( mlir::createTritonAMDGPUPipeline({use_async_copy, use_block_pingpong})); @@ -193,7 +196,7 @@ void CreateTritonRocmPipeline( mlir::OpPassManager* pm, const stream_executor::RocmComputeCapability& rocm_cc, int num_warps, int num_ctas, int num_stages) { - MakeTTIR(pm); + MakeTTIR(pm, rocm_cc); MakeTTGIR(pm, rocm_cc, num_warps, num_ctas, num_stages); MakeLLIR(pm, rocm_cc, num_stages); } diff --git a/xla/backends/gpu/codegen/triton/transforms/passes.h b/xla/backends/gpu/codegen/triton/transforms/passes.h index 0ae024e0ba383..3b31f7ca6164e 100644 --- 
a/xla/backends/gpu/codegen/triton/transforms/passes.h +++ b/xla/backends/gpu/codegen/triton/transforms/passes.h @@ -34,6 +34,8 @@ namespace mlir::triton::xla { std::unique_ptr CreateTritonXLAExtractInsertToTritonPass(); std::unique_ptr CreateTritonXLAExtractInsertToTritonPass( bool allow_tma, int num_stages); +std::unique_ptr CreateTritonXLAExtractInsertToTritonPass( + bool allow_tma, bool allow_tdm, int num_stages); std::unique_ptr CreateTritonXLASqueezeDimsPass(); std::unique_ptr CreateTritonXLAFoldTransposePass(); std::unique_ptr CreateGeneralizeKernelSignaturePass(); diff --git a/xla/backends/gpu/codegen/triton/transforms/passes.td b/xla/backends/gpu/codegen/triton/transforms/passes.td index 9787dff0977cd..3d4badf5a93dd 100644 --- a/xla/backends/gpu/codegen/triton/transforms/passes.td +++ b/xla/backends/gpu/codegen/triton/transforms/passes.td @@ -28,6 +28,8 @@ def TritonXLAExtractInsertToTritonPass : Pass<"triton-xla-extract-insert-to-trit let options = [ Option<"allow_tma_", "allow_tma", "bool", "false", "Whether to permit lowering to TMA.">, + Option<"allow_tdm_", "allow_tdm", "bool", "false", + "Whether to permit lowering to TDM (device-side tensor descriptors).">, Option<"num_stages_", "num_stages", "int", "1", "Number of stages for pipelining.">, ]; diff --git a/xla/backends/gpu/codegen/triton/transforms/tests/triton_pipeline_tdm.mlir b/xla/backends/gpu/codegen/triton/transforms/tests/triton_pipeline_tdm.mlir new file mode 100644 index 0000000000000..87f37f6c1f025 --- /dev/null +++ b/xla/backends/gpu/codegen/triton/transforms/tests/triton_pipeline_tdm.mlir @@ -0,0 +1,28 @@ +// RUN: xla-opt %s --triton-xla-pipeline='target=gfx1250' \ +// RUN: | FileCheck %s --check-prefix=CHECK-TDM +// +// RUN: xla-opt %s --triton-xla-pipeline='target=gfx950' \ +// RUN: | FileCheck %s --check-prefix=CHECK-NOTDM + +// Verifies that the full Triton XLA + AMD lowering pipeline emits TDM +// intrinsics on gfx1250 and pointer-arithmetic buffer ops on non-TDM arches. 
+ +func.func @lower_extract_insert(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %extracted_tensor = triton_xla.extract from %arg0 + as memref<256x256xbf16, #xtile.layout<[1, 0]>> + [0, 0] [16, 64] [1, 1] : tensor<16x64xbf16> + triton_xla.insert %extracted_tensor into %arg1 + as memref<256x256xbf16, #xtile.layout<[1, 0]>> + [0, 0] [16, 64] [1, 1] : tensor<16x64xbf16> + func.return +} + +// CHECK-TDM-LABEL: llvm.func @lower_extract_insert +// CHECK-TDM: tensor.load.to.lds +// CHECK-TDM: s.wait.tensorcnt +// CHECK-TDM: tensor.store.from.lds + +// CHECK-NOTDM-LABEL: llvm.func @lower_extract_insert +// CHECK-NOTDM-NOT: tensor.load.to.lds +// CHECK-NOTDM-NOT: tensor.store.from.lds +// CHECK-NOTDM: raw.ptr.buffer.load diff --git a/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir b/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir index 57f06299b843b..eb284424b0156 100644 --- a/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir +++ b/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir @@ -6,6 +6,10 @@ // RUN: -triton-xla-extract-insert-to-triton="allow_tma=1 num_stages=3" \ // RUN: | FileCheck %s --check-prefix=CHECK-TMA +// RUN: xla-opt %s -split-input-file \ +// RUN: -triton-xla-extract-insert-to-triton="allow_tdm=1" \ +// RUN: | FileCheck %s --check-prefix=CHECK-TDM + func.func @lower_extract_insert(%arg0: !tt.ptr, %arg1: !tt.ptr) { %extracted_tensor = triton_xla.extract from %arg0 as memref<512x8x128xbf16, #xtile.layout<[2, 1, 0]>> @@ -30,6 +34,17 @@ func.func @lower_extract_insert(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], // CHECK-TMA: tt.return +// Middle singleton dim is TDM-incompatible, so fall back to pointer loads. 
+// CHECK-TDM-LABEL: tt.func @lower_extract_insert( +// CHECK-TDM-SAME: %arg0: !tt.ptr {tt.divisibility = 16 : i32}, +// CHECK-TDM-SAME: %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { +// CHECK-TDM-NOT: tt.make_tensor_descriptor +// CHECK-TDM-NOT: tt.descriptor_load +// CHECK-TDM-NOT: tt.descriptor_store +// CHECK-TDM: %[[LOAD:.*]] = tt.load +// CHECK-TDM: tt.store {{.*}}, %[[LOAD]] +// CHECK-TDM: tt.return + // ----- func.func @non_perfect_tile_shape(%arg0: !tt.ptr, %arg1: !tt.ptr) { @@ -46,6 +61,12 @@ func.func @non_perfect_tile_shape(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK: %[[LOAD:.*]] = tt.load {{.*}}, %{{.*}}, %{{.*}} : // CHECK: tt.store {{.*}}, %[[LOAD]], %{{.*}} : +// CHECK-TDM-LABEL: tt.func @non_perfect_tile_shape +// CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0 +// CHECK-TDM: tt.descriptor_load %[[DESC0]] +// CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1 +// CHECK-TDM: tt.descriptor_store %[[DESC1]] + // ----- func.func @incompatible_tma_global_strides(%arg0: !tt.ptr, %arg1: !tt.ptr) { @@ -62,6 +83,11 @@ func.func @incompatible_tma_global_strides(%arg0: !tt.ptr, %arg1: !tt.ptr< // CHECK-TMA: tt.load // CHECK-TMA: tt.store +// CHECK-TDM-LABEL: tt.func @incompatible_tma_global_strides +// CHECK-TDM-NOT: tt.make_tensor_descriptor +// CHECK-TDM: tt.load +// CHECK-TDM: tt.store + // ----- #indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]"> @@ -91,6 +117,11 @@ module { // CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}} // CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}} +// CHECK-TDM-LABEL: tt.func @slice_with_tiling_that_needs_padding_has_boundary_checks +// CHECK-TDM: tt.descriptor_load +// CHECK-TDM: tt.descriptor_store +// CHECK-TDM: tt.descriptor_store + // ----- #indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]"> @@ -120,6 +151,11 @@ module { // CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}} // CHECK: tt.store {{.*}}, %{{.*}} : +// CHECK-TDM-LABEL: tt.func 
@slice_with_extra_output_that_can_reuse_tile_due_to_padding +// CHECK-TDM: tt.descriptor_load +// CHECK-TDM: tt.descriptor_store +// CHECK-TDM: tt.descriptor_store + // ----- func.func @extract_with_non_unit_minor_dim_stride(%arg0: !tt.ptr, @@ -137,6 +173,10 @@ func.func @extract_with_non_unit_minor_dim_stride(%arg0: !tt.ptr, // CHECK-TMA: tt.load // CHECK-TMA: tt.descriptor_store +// CHECK-TDM-LABEL: tt.func @extract_with_non_unit_minor_dim_stride +// CHECK-TDM: tt.load +// CHECK-TDM: tt.descriptor_store + // ----- func.func @lower_extract_insert_1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { @@ -163,6 +203,15 @@ func.func @lower_extract_insert_1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], %[[LOAD]] // CHECK-TMA: tt.return +// CHECK-TDM-LABEL: tt.func @lower_extract_insert_1d( +// CHECK-TDM-SAME: %arg0: !tt.ptr {tt.divisibility = 16 : i32}, +// CHECK-TDM-SAME: %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { +// CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0 +// CHECK-TDM: %[[LOAD:.*]] = tt.descriptor_load %[[DESC0]] +// CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1 +// CHECK-TDM: tt.descriptor_store %[[DESC1]][{{.*}}], %[[LOAD]] +// CHECK-TDM: tt.return + // ----- func.func @lower_extract_insert_5d(%arg0: !tt.ptr, %arg1: !tt.ptr) { @@ -189,6 +238,15 @@ func.func @lower_extract_insert_5d(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], %[[LOAD]] // CHECK-TMA: tt.return +// CHECK-TDM-LABEL: tt.func @lower_extract_insert_5d( +// CHECK-TDM-SAME: %arg0: !tt.ptr {tt.divisibility = 16 : i32}, +// CHECK-TDM-SAME: %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { +// CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0 +// CHECK-TDM: %[[LOAD:.*]] = tt.descriptor_load %[[DESC0]] +// CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1 +// CHECK-TDM: tt.descriptor_store %[[DESC1]][{{.*}}], %[[LOAD]] +// CHECK-TDM: tt.return + // ----- func.func 
@extract_insert_with_zero_stride(%arg0: !tt.ptr, %arg1: !tt.ptr) { @@ -205,6 +263,11 @@ func.func @extract_insert_with_zero_stride(%arg0: !tt.ptr, %arg1: !tt.ptr< // CHECK-TMA-SAME: %arg0: !tt.tensordesc<1x64xbf16> // CHECK-TMA-SAME: %arg1: !tt.tensordesc<1x64xbf16> +// CHECK-TDM-LABEL: tt.func @extract_insert_with_zero_stride +// CHECK-TDM-NOT: tt.make_tensor_descriptor +// CHECK-TDM: tt.load +// CHECK-TDM: tt.store + // ----- func.func @incompatible_tma_const_offset_not_divisible_by_16_bytes( @@ -222,6 +285,11 @@ func.func @incompatible_tma_const_offset_not_divisible_by_16_bytes( // CHECK-TMA: tt.load // CHECK-TMA: tt.descriptor_store +// CHECK-TDM-LABEL: tt.func @incompatible_tma_const_offset_not_divisible_by_16_bytes +// CHECK-TDM-NOT: tt.make_tensor_descriptor +// CHECK-TDM: tt.load +// CHECK-TDM: tt.store + // ----- #indexing_map = #xla.indexing_map<"(pid_0) -> ((pid_0 mod 9) * 16 + (pid_0 floordiv 9) * 130), domain: pid_0 in [0, 575]"> @@ -251,6 +319,10 @@ module { // CHECK-TMA: tt.load // CHECK-TMA: tt.descriptor_store +// CHECK-TDM-LABEL: tt.func @incompatible_tma_dynamic_offset_not_divisible_by_16_bytes +// CHECK-TDM: tt.descriptor_load +// CHECK-TDM: tt.store + // ----- func.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma( @@ -276,6 +348,11 @@ func.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma( // CHECK-TMA-NOT: tt.descriptor_load %arg0 // CHECK-TMA: tt.descriptor_load %arg1 +// CHECK-TDM-LABEL: tt.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma +// CHECK-TDM: tt.descriptor_load +// CHECK-TDM: tt.descriptor_load +// CHECK-TDM: tt.descriptor_store + // ----- #indexing_map_unaligned = #xla.indexing_map<"(d0) -> (d0 * 2816), domain: d0 in [0, 2047]"> @@ -301,6 +378,10 @@ module { // CHECK: %[[MASK:.*]] = arith.cmpi slt // CHECK: tt.load {{.*}}, %[[MASK]], {{.*}} +// CHECK-TDM-LABEL: tt.func @apply_mask_to_unaligned_offset_with_perfect_total_size +// CHECK-TDM: tt.descriptor_load +// 
CHECK-TDM: tt.descriptor_store + // ----- #indexing_map_aligned_with_oob_at_end = #xla.indexing_map<"(pid, d1) -> ((pid floordiv 64) * 384 + d1 * 32), domain: pid in [0, 1023], d1 in [0, 11]"> @@ -328,3 +409,7 @@ module { // CHECK-LABEL: tt.func @apply_mask_to_aligned_offset_with_out_of_bounds_reads_at_end // CHECK: %[[MASK:.*]] = arith.cmpi slt // CHECK: tt.load {{.*}}, %[[MASK]], {{.*}} + +// CHECK-TDM-LABEL: tt.func @apply_mask_to_aligned_offset_with_out_of_bounds_reads_at_end +// CHECK-TDM: tt.descriptor_load +// CHECK-TDM: tt.descriptor_store diff --git a/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc b/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc index 08249234ec3ea..f9ee283211e2e 100644 --- a/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc +++ b/xla/backends/gpu/codegen/triton/transforms/triton_xla_extract_insert_to_triton_pass.cc @@ -259,6 +259,98 @@ SmallVector GetMajorToMinorOrder(ValueRange values, return GetMajorToMinorOrder(ArrayRef(llvm::to_vector(values)), layout); } +// Whether the tile shape is compatible with AMD TDM lowering. Rejects: +// - dynamic tile sizes or strides +// - non-unit tile strides +// - non-trailing singleton dims (violates upstream Triton's TDM legalizer +// contract: shared order must be [rank-1, ..., 0] with stride-1 dims +// consecutive trailing). A smarter pass could canonicalize the singleton +// out and lower to TDM. We reject conservatively for now. +// +// Not checked here, deferred to upstream Triton / hardware legalization: +// - Hardware rank cap (TDM supports ranks 1 to 5). +// - Innermost dim byte alignment requirements. +// - Per-dim box size limits. +// - Padding mode compatibility (this pass hardcodes PAD_ZERO). 
+bool CanUseTdm(bool allow_tdm, ArrayRef<int64_t> tile_sizes,
+               ArrayRef<int64_t> tile_strides,
+               ArrayRef<int64_t> minor_to_major_layout) {
+  if (!allow_tdm) {
+    return false;
+  }
+  // Dynamic sizes or strides would feed sentinel values into the i32/i64
+  // descriptor constants and silently produce a malformed descriptor.
+  if (mlir::ShapedType::isDynamicShape(tile_sizes) ||
+      mlir::ShapedType::isDynamicShape(tile_strides)) {
+    VLOG(1) << "TDM is not compatible: dynamic tile sizes or strides.";
+    return false;
+  }
+  // TDM descriptors describe contiguous boxes; non-unit (or zero) tile strides
+  // cannot be expressed and would silently produce a contiguous load.
+  for (int64_t stride : tile_strides) {
+    if (stride != 1) {
+      VLOG(1) << "TDM is not compatible: non-unit tile stride.";
+      return false;
+    }
+  }
+  auto ordered_sizes = GetMajorToMinorOrder(tile_sizes, minor_to_major_layout);
+  bool seen_singleton = false;
+  for (int64_t size : ordered_sizes) {
+    if (size == 1) {
+      seen_singleton = true;
+    } else if (seen_singleton) {  // Non-singleton dim after a singleton dim.
+      VLOG(1) << "TDM is not compatible: non-trailing singleton dim in tile.";
+      return false;
+    }
+  }
+  return true;
+}
+
+// Builds a tt.make_tensor_descriptor for a contiguous box load/store from
+// `pointer`. Global shape and strides come from `shape` + `layout` (the source
+// memref geometry); the descriptor's block dims come from `tile_sizes`. All
+// arrays are reordered to major-to-minor as required by Triton's TDM
+// legalizer. Padding is PAD_ZERO. CHECK-fails on dims that overflow i32.
+MakeTensorDescOp BuildTensorDescriptor(ImplicitLocOpBuilder& builder,
+                                       Value pointer, ArrayRef<int64_t> shape,
+                                       ArrayRef<int64_t> layout,
+                                       ArrayRef<int64_t> tile_sizes) {
+  // Global shape as range-checked i32 SSA values, in major-to-minor order.
+  auto ordered_shape = GetMajorToMinorOrder(shape, layout);
+  SmallVector<Value> shape_values;
+  for (int64_t dim : ordered_shape) {
+    CHECK_LE(dim, INT32_MAX) << "global dim " << dim << " exceeds i32 range";
+    shape_values.push_back(
+        arith::ConstantOp::create(builder, builder.getI32IntegerAttr(dim)));
+  }
+
+  // Global strides as i64 SSA values, in major-to-minor order.
+  auto global_strides = xtriton::ComputeStrides(shape, layout);
+  auto ordered_strides =
+      GetMajorToMinorOrder(ArrayRef<int64_t>(global_strides), layout);
+  SmallVector<Value> stride_values;
+  for (int64_t stride : ordered_strides) {
+    stride_values.push_back(
+        arith::ConstantOp::create(builder, builder.getI64IntegerAttr(stride)));
+  }
+
+  // Block shape in major-to-minor order as i32.
+  auto ordered_sizes = GetMajorToMinorOrder(tile_sizes, layout);
+  SmallVector<int32_t> block_shape;
+  for (int64_t size : ordered_sizes) {
+    CHECK_LE(size, INT32_MAX) << "tile dim " << size << " exceeds i32 range";
+    block_shape.push_back(static_cast<int32_t>(size));
+  }
+
+  auto element_type = cast<PointerType>(pointer.getType()).getPointeeType();
+  bool is_signed_integer =
+      mlir::isa<IntegerType>(element_type) && !element_type.isUnsignedInteger();
+
+  return MakeTensorDescOp::create(builder, pointer, shape_values, stride_values,
+                                  block_shape, is_signed_integer,
+                                  PaddingOption::PAD_ZERO);
+}
+
 // Given the layout of a tensor, return the inverse permutation required to
 // transpose an already major-to-minor tensor to the original tensor.
SmallVector GetInverseLayoutPermutation(ArrayRef layout) { @@ -357,9 +449,11 @@ class RewriteFuncOp : public mlir::OpRewritePattern { class RewriteExtract : public mlir::OpRewritePattern { public: - RewriteExtract(mlir::MLIRContext* context, bool allow_tma, int num_stages) + RewriteExtract(mlir::MLIRContext* context, bool allow_tma, bool allow_tdm, + int num_stages) : OpRewritePattern(context), allow_tma_(allow_tma), + allow_tdm_(allow_tdm), num_stages_(num_stages) {} using OpRewritePattern::OpRewritePattern; @@ -428,6 +522,31 @@ class RewriteExtract : public mlir::OpRewritePattern { return mlir::success(); } + if (CanUseTdm(allow_tdm_, sizes, strides, src_layout)) { + auto ordered_offsets = GetMajorToMinorOrder(offsets, src_layout); + auto ordered_type = + tile_type.clone(GetMajorToMinorOrder(sizes, src_layout)); + + auto desc = BuildTensorDescriptor(builder, op.getSrc(), src_shape, + src_layout, sizes); + + Value result = DescriptorLoadOp::create( + builder, ordered_type, desc.getResult(), + xtriton::IndexCast(builder, builder.getI32Type(), ordered_offsets)); + + if (!IsMajorToMinorLayout(src_layout)) { + result = TransOp::create(builder, result, + GetInverseLayoutPermutation(src_layout)); + } + if (sizes.size() != tile_shape.size()) { + result = ReshapeOp::create(builder, tile_shape, result, + /*allowReorder=*/false); + } + + rewriter.replaceOp(op, result); + return mlir::success(); + } + // Compute the set of reduced dimensions. 
auto reduction_mask = mlir::computeRankReductionMask(sizes, tile_shape); if (!reduction_mask) { @@ -453,14 +572,17 @@ class RewriteExtract : public mlir::OpRewritePattern { } const bool allow_tma_; + const bool allow_tdm_; const int num_stages_; }; class RewriteInsert : public mlir::OpRewritePattern { public: - RewriteInsert(mlir::MLIRContext* context, bool allow_tma, int num_stages) + RewriteInsert(mlir::MLIRContext* context, bool allow_tma, bool allow_tdm, + int num_stages) : OpRewritePattern(context), allow_tma_(allow_tma), + allow_tdm_(allow_tdm), num_stages_(num_stages) {} using OpRewritePattern::OpRewritePattern; @@ -532,6 +654,25 @@ class RewriteInsert : public mlir::OpRewritePattern { DescriptorStoreOp::create( builder, cast_to_tensor_desc.getResult(0), src, xtriton::IndexCast(builder, builder.getI32Type(), ordered_offsets)); + } else if (CanUseTdm(allow_tdm_, sizes, strides, dst_layout)) { + auto ordered_offsets = GetMajorToMinorOrder(offsets, dst_layout); + + auto desc = BuildTensorDescriptor(builder, op.getDst(), dst_shape, + dst_layout, sizes); + + Value src = op.getSrc(); + for (auto dim : reduced_dims) { + src = ExpandDimsOp::create(builder, src, dim); + } + if (!IsMajorToMinorLayout(dst_layout)) { + auto transpose_order = llvm::to_vector_of(dst_layout); + std::reverse(transpose_order.begin(), transpose_order.end()); + src = TransOp::create(builder, src, transpose_order); + } + + DescriptorStoreOp::create( + builder, desc.getResult(), src, + xtriton::IndexCast(builder, builder.getI32Type(), ordered_offsets)); } else { auto [ptr, mask] = xtriton::CreateTensorOfPointersAndMask( builder, op.getDst(), dst_shape, dst_layout, offsets, sizes, strides, @@ -544,6 +685,7 @@ class RewriteInsert : public mlir::OpRewritePattern { } const bool allow_tma_; + const bool allow_tdm_; const int num_stages_; }; @@ -606,7 +748,8 @@ class TritonXLAExtractInsertToTritonPass mlir::MLIRContext* mlir_context = &getContext(); mlir::RewritePatternSet patterns(mlir_context); 
patterns.add( - mlir_context, allow_tma_.getValue(), num_stages_.getValue()); + mlir_context, allow_tma_.getValue(), allow_tdm_.getValue(), + num_stages_.getValue()); patterns.add(mlir_context); if (mlir::failed( mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { @@ -631,7 +774,16 @@ std::unique_ptr CreateTritonXLAExtractInsertToTritonPass() { std::unique_ptr CreateTritonXLAExtractInsertToTritonPass( bool allow_tma, int num_stages) { return std::make_unique( - TritonXLAExtractInsertToTritonPassOptions{allow_tma, num_stages}); + TritonXLAExtractInsertToTritonPassOptions{allow_tma, + /*allow_tdm=*/false, + num_stages}); +} + +std::unique_ptr CreateTritonXLAExtractInsertToTritonPass( + bool allow_tma, bool allow_tdm, int num_stages) { + return std::make_unique( + TritonXLAExtractInsertToTritonPassOptions{allow_tma, allow_tdm, + num_stages}); } } // namespace mlir::triton::xla