Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion xla/backends/gpu/codegen/triton/compilation_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ void CreateTritonXlaPipeline(
auto* cuda_cc = gpu_cc.cuda_compute_capability();
bool is_at_least_hopper = cuda_cc != nullptr && cuda_cc->IsAtLeastHopper();

auto* rocm_cc = gpu_cc.rocm_compute_capability();
bool rocm_supports_tdm = rocm_cc != nullptr && rocm_cc->has_tdm_support();

if (rewrite_int4) {
pm->addPass(mlir::triton::xla::CreateInt4ToPackedInt4RewritePass(
/*enable_bf16x2=*/is_at_least_hopper));
Expand All @@ -59,7 +62,8 @@ void CreateTritonXlaPipeline(
pm->addPass(CreateInsertPDLPass());
}
pm->addPass(mlir::triton::xla::CreateTritonXLAExtractInsertToTritonPass(
/*allow_tma=*/allow_tma && is_at_least_hopper, num_stages));
/*allow_tma=*/allow_tma && is_at_least_hopper,
/*allow_tdm=*/rocm_supports_tdm, num_stages));
if (enable_pdl) {
pm->addPass(emitters::CreateLowerPdlWaitPass());
}
Expand Down
11 changes: 7 additions & 4 deletions xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ namespace mt = ::mlir::triton;

// Based on make_ttir() in
// @triton//:third_party/amd/backend/compiler.py
static void MakeTTIR(mlir::OpPassManager* pm) {
static void MakeTTIR(mlir::OpPassManager* pm,
const stream_executor::RocmComputeCapability& rocm_cc) {
pm->addPass(mlir::createInlinerPass());
// if not amd.supports_tdm(arch)
// pm->addPass(mt::createTritonRewriteTensorDescriptorToPointer());
if (!rocm_cc.has_tdm_support()) {
pm->addPass(mt::createTritonRewriteTensorDescriptorToPointer());
}
pm->addPass(mlir::createCanonicalizerPass());
pm->addPass(mt::createTritonCombineOps());
pm->addPass(mt::createTritonReorderBroadcast());
Expand Down Expand Up @@ -102,6 +104,7 @@ static void MakeTTGIR(mlir::OpPassManager* pm,
bool use_block_pingpong =
is_pingpong_schedule_enabled(rocm_cc, use_async_copy);

pm->addPass(mlir::createTritonAMDGPUOptimizeDescriptorEncoding());
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: descriptor optimization pass is added unconditionally for all ROCm architectures

createTritonAMDGPUOptimizeDescriptorEncoding() is added to the pipeline for all ROCm targets, not just TDM-capable ones. While the upstream pass is likely a no-op when there are no tensor descriptors in the IR, it would be more explicit to guard it with if (rocm_cc.supports_tdm()), consistent with the descriptor-to-pointer rewrite guard in MakeTTIR.

pm->addPass(mlir::createTritonAMDGPUScheduleLoops({num_stages}));
pm->addPass(
mlir::createTritonAMDGPUPipeline({use_async_copy, use_block_pingpong}));
Expand Down Expand Up @@ -193,7 +196,7 @@ void CreateTritonRocmPipeline(
mlir::OpPassManager* pm,
const stream_executor::RocmComputeCapability& rocm_cc, int num_warps,
int num_ctas, int num_stages) {
MakeTTIR(pm);
MakeTTIR(pm, rocm_cc);
MakeTTGIR(pm, rocm_cc, num_warps, num_ctas, num_stages);
MakeLLIR(pm, rocm_cc, num_stages);
}
Expand Down
2 changes: 2 additions & 0 deletions xla/backends/gpu/codegen/triton/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace mlir::triton::xla {
std::unique_ptr<mlir::Pass> CreateTritonXLAExtractInsertToTritonPass();
std::unique_ptr<mlir::Pass> CreateTritonXLAExtractInsertToTritonPass(
bool allow_tma, int num_stages);
std::unique_ptr<mlir::Pass> CreateTritonXLAExtractInsertToTritonPass(
bool allow_tma, bool allow_tdm, int num_stages);
std::unique_ptr<mlir::Pass> CreateTritonXLASqueezeDimsPass();
std::unique_ptr<mlir::Pass> CreateTritonXLAFoldTransposePass();
std::unique_ptr<mlir::Pass> CreateGeneralizeKernelSignaturePass();
Expand Down
2 changes: 2 additions & 0 deletions xla/backends/gpu/codegen/triton/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def TritonXLAExtractInsertToTritonPass : Pass<"triton-xla-extract-insert-to-trit
let options = [
Option<"allow_tma_", "allow_tma", "bool", "false",
"Whether to permit lowering to TMA.">,
Option<"allow_tdm_", "allow_tdm", "bool", "false",
"Whether to permit lowering to TDM (device-side tensor descriptors).">,
Option<"num_stages_", "num_stages", "int", "1",
"Number of stages for pipelining.">,
];
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: xla-opt %s --triton-xla-pipeline='target=gfx1250' \
// RUN: | FileCheck %s --check-prefix=CHECK-TDM
//
// RUN: xla-opt %s --triton-xla-pipeline='target=gfx950' \
// RUN: | FileCheck %s --check-prefix=CHECK-NOTDM

// Verifies that the full Triton XLA + AMD lowering pipeline emits TDM
// intrinsics on gfx1250 and pointer-arithmetic buffer ops on non-TDM arches.

func.func @lower_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
%extracted_tensor = triton_xla.extract from %arg0
as memref<256x256xbf16, #xtile.layout<[1, 0]>>
[0, 0] [16, 64] [1, 1] : tensor<16x64xbf16>
triton_xla.insert %extracted_tensor into %arg1
as memref<256x256xbf16, #xtile.layout<[1, 0]>>
[0, 0] [16, 64] [1, 1] : tensor<16x64xbf16>
func.return
}

// CHECK-TDM-LABEL: llvm.func @lower_extract_insert
// CHECK-TDM: tensor.load.to.lds
// CHECK-TDM: s.wait.tensorcnt
// CHECK-TDM: tensor.store.from.lds

// CHECK-NOTDM-LABEL: llvm.func @lower_extract_insert
// CHECK-NOTDM-NOT: tensor.load.to.lds
// CHECK-NOTDM-NOT: tensor.store.from.lds
// CHECK-NOTDM: raw.ptr.buffer.load
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
// RUN: -triton-xla-extract-insert-to-triton="allow_tma=1 num_stages=3" \
// RUN: | FileCheck %s --check-prefix=CHECK-TMA

// RUN: xla-opt %s -split-input-file \
// RUN: -triton-xla-extract-insert-to-triton="allow_tdm=1" \
// RUN: | FileCheck %s --check-prefix=CHECK-TDM

func.func @lower_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
%extracted_tensor = triton_xla.extract from %arg0
as memref<512x8x128xbf16, #xtile.layout<[2, 1, 0]>>
Expand All @@ -30,6 +34,17 @@ func.func @lower_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
// CHECK-TMA: tt.descriptor_store %arg1[{{.*}}],
// CHECK-TMA: tt.return

// Middle singleton dim is TDM-incompatible, so fall back to pointer loads.
// CHECK-TDM-LABEL: tt.func @lower_extract_insert(
// CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
// CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
// CHECK-TDM-NOT: tt.make_tensor_descriptor
// CHECK-TDM-NOT: tt.descriptor_load
// CHECK-TDM-NOT: tt.descriptor_store
// CHECK-TDM: %[[LOAD:.*]] = tt.load
// CHECK-TDM: tt.store {{.*}}, %[[LOAD]]
// CHECK-TDM: tt.return

// -----
Comment on lines 36 to 48
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing gap: several test cases lack CHECK-TDM assertions

CHECK-TDM assertions are only present for 3 of the 10+ test cases (TDM-incompatible singleton, 1D, and 5D). Notably missing TDM coverage:

  • non_perfect_tile_shape — does TDM handle boundary checks for non-perfectly-tiling memrefs?
  • extract_insert_with_zero_stride — zero strides (broadcasting) behavior under TDM
  • extract_with_non_unit_minor_dim_stride — non-unit stride fallback
  • incompatible_tma_global_strides — stride compatibility differences between TMA and TDM

Even if TDM correctly falls back to pointer ops for some of these, the expected behavior should be tested to prevent future regressions.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have covered better here.


func.func @non_perfect_tile_shape(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
Expand All @@ -46,6 +61,12 @@ func.func @non_perfect_tile_shape(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
// CHECK: %[[LOAD:.*]] = tt.load {{.*}}, %{{.*}}, %{{.*}} :
// CHECK: tt.store {{.*}}, %[[LOAD]], %{{.*}} :

// CHECK-TDM-LABEL: tt.func @non_perfect_tile_shape
// CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0
// CHECK-TDM: tt.descriptor_load %[[DESC0]]
// CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1
// CHECK-TDM: tt.descriptor_store %[[DESC1]]

// -----

func.func @incompatible_tma_global_strides(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
Expand All @@ -62,6 +83,11 @@ func.func @incompatible_tma_global_strides(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<
// CHECK-TMA: tt.load
// CHECK-TMA: tt.store

// CHECK-TDM-LABEL: tt.func @incompatible_tma_global_strides
// CHECK-TDM-NOT: tt.make_tensor_descriptor
// CHECK-TDM: tt.load
// CHECK-TDM: tt.store

// -----

#indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]">
Expand Down Expand Up @@ -91,6 +117,11 @@ module {
// CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}}
// CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}}

// CHECK-TDM-LABEL: tt.func @slice_with_tiling_that_needs_padding_has_boundary_checks
// CHECK-TDM: tt.descriptor_load
// CHECK-TDM: tt.descriptor_store
// CHECK-TDM: tt.descriptor_store

// -----

#indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]">
Expand Down Expand Up @@ -120,6 +151,11 @@ module {
// CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}}
// CHECK: tt.store {{.*}}, %{{.*}} :

// CHECK-TDM-LABEL: tt.func @slice_with_extra_output_that_can_reuse_tile_due_to_padding
// CHECK-TDM: tt.descriptor_load
// CHECK-TDM: tt.descriptor_store
// CHECK-TDM: tt.descriptor_store

// -----

func.func @extract_with_non_unit_minor_dim_stride(%arg0: !tt.ptr<bf16>,
Expand All @@ -137,6 +173,10 @@ func.func @extract_with_non_unit_minor_dim_stride(%arg0: !tt.ptr<bf16>,
// CHECK-TMA: tt.load
// CHECK-TMA: tt.descriptor_store

// CHECK-TDM-LABEL: tt.func @extract_with_non_unit_minor_dim_stride
// CHECK-TDM: tt.load
// CHECK-TDM: tt.descriptor_store

// -----

func.func @lower_extract_insert_1d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
Expand All @@ -163,6 +203,15 @@ func.func @lower_extract_insert_1d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
// CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], %[[LOAD]]
// CHECK-TMA: tt.return

// CHECK-TDM-LABEL: tt.func @lower_extract_insert_1d(
// CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
// CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
// CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0
// CHECK-TDM: %[[LOAD:.*]] = tt.descriptor_load %[[DESC0]]
// CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1
// CHECK-TDM: tt.descriptor_store %[[DESC1]][{{.*}}], %[[LOAD]]
// CHECK-TDM: tt.return

// -----

func.func @lower_extract_insert_5d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
Expand All @@ -189,6 +238,15 @@ func.func @lower_extract_insert_5d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
// CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], %[[LOAD]]
// CHECK-TMA: tt.return

// CHECK-TDM-LABEL: tt.func @lower_extract_insert_5d(
// CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
// CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
// CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0
// CHECK-TDM: %[[LOAD:.*]] = tt.descriptor_load %[[DESC0]]
// CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1
// CHECK-TDM: tt.descriptor_store %[[DESC1]][{{.*}}], %[[LOAD]]
// CHECK-TDM: tt.return

// -----

func.func @extract_insert_with_zero_stride(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
Expand All @@ -205,6 +263,11 @@ func.func @extract_insert_with_zero_stride(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<
// CHECK-TMA-SAME: %arg0: !tt.tensordesc<1x64xbf16>
// CHECK-TMA-SAME: %arg1: !tt.tensordesc<1x64xbf16>

// CHECK-TDM-LABEL: tt.func @extract_insert_with_zero_stride
// CHECK-TDM-NOT: tt.make_tensor_descriptor
// CHECK-TDM: tt.load
// CHECK-TDM: tt.store

// -----

func.func @incompatible_tma_const_offset_not_divisible_by_16_bytes(
Expand All @@ -222,6 +285,11 @@ func.func @incompatible_tma_const_offset_not_divisible_by_16_bytes(
// CHECK-TMA: tt.load
// CHECK-TMA: tt.descriptor_store

// CHECK-TDM-LABEL: tt.func @incompatible_tma_const_offset_not_divisible_by_16_bytes
// CHECK-TDM-NOT: tt.make_tensor_descriptor
// CHECK-TDM: tt.load
// CHECK-TDM: tt.store

// -----

#indexing_map = #xla.indexing_map<"(pid_0) -> ((pid_0 mod 9) * 16 + (pid_0 floordiv 9) * 130), domain: pid_0 in [0, 575]">
Expand Down Expand Up @@ -251,6 +319,10 @@ module {
// CHECK-TMA: tt.load
// CHECK-TMA: tt.descriptor_store

// CHECK-TDM-LABEL: tt.func @incompatible_tma_dynamic_offset_not_divisible_by_16_bytes
// CHECK-TDM: tt.descriptor_load
// CHECK-TDM: tt.store

// -----

func.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma(
Expand All @@ -276,6 +348,11 @@ func.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma(
// CHECK-TMA-NOT: tt.descriptor_load %arg0
// CHECK-TMA: tt.descriptor_load %arg1

// CHECK-TDM-LABEL: tt.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma
// CHECK-TDM: tt.descriptor_load
// CHECK-TDM: tt.descriptor_load
// CHECK-TDM: tt.descriptor_store

// -----

#indexing_map_unaligned = #xla.indexing_map<"(d0) -> (d0 * 2816), domain: d0 in [0, 2047]">
Expand All @@ -301,6 +378,10 @@ module {
// CHECK: %[[MASK:.*]] = arith.cmpi slt
// CHECK: tt.load {{.*}}, %[[MASK]], {{.*}}

// CHECK-TDM-LABEL: tt.func @apply_mask_to_unaligned_offset_with_perfect_total_size
// CHECK-TDM: tt.descriptor_load
// CHECK-TDM: tt.descriptor_store

// -----

#indexing_map_aligned_with_oob_at_end = #xla.indexing_map<"(pid, d1) -> ((pid floordiv 64) * 384 + d1 * 32), domain: pid in [0, 1023], d1 in [0, 11]">
Expand Down Expand Up @@ -328,3 +409,7 @@ module {
// CHECK-LABEL: tt.func @apply_mask_to_aligned_offset_with_out_of_bounds_reads_at_end
// CHECK: %[[MASK:.*]] = arith.cmpi slt
// CHECK: tt.load {{.*}}, %[[MASK]], {{.*}}

// CHECK-TDM-LABEL: tt.func @apply_mask_to_aligned_offset_with_out_of_bounds_reads_at_end
// CHECK-TDM: tt.descriptor_load
// CHECK-TDM: tt.descriptor_store
Loading
Loading