Skip to content

Commit 0b15442

Browse files
committed
[ROCm] Add TDM (Tensor Descriptor Memory) support for gfx1250
gfx1250 introduces device-side tensor descriptors (TDM), constructed on-device via tt.make_tensor_descriptor. Wire TDM through the XLA Triton pipeline so xtile.extract/insert lower to descriptor-based loads on TDM-capable hardware, with a pointer-based fallback for TDM-incompatible tile shapes.
1 parent 32b0167 commit 0b15442

7 files changed

Lines changed: 288 additions & 11 deletions

File tree

xla/backends/gpu/codegen/triton/compilation_pipeline.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ void CreateTritonXlaPipeline(
5050
auto* cuda_cc = gpu_cc.cuda_compute_capability();
5151
bool is_at_least_hopper = cuda_cc != nullptr && cuda_cc->IsAtLeastHopper();
5252

53+
auto* rocm_cc = gpu_cc.rocm_compute_capability();
54+
bool rocm_supports_tdm = rocm_cc != nullptr && rocm_cc->has_tdm_support();
55+
5356
if (rewrite_int4) {
5457
pm->addPass(mlir::triton::xla::CreateInt4ToPackedInt4RewritePass(
5558
/*enable_bf16x2=*/is_at_least_hopper));
@@ -59,7 +62,8 @@ void CreateTritonXlaPipeline(
5962
pm->addPass(CreateInsertPDLPass());
6063
}
6164
pm->addPass(mlir::triton::xla::CreateTritonXLAExtractInsertToTritonPass(
62-
/*allow_tma=*/allow_tma && is_at_least_hopper, num_stages));
65+
/*allow_tma=*/allow_tma && is_at_least_hopper,
66+
/*allow_tdm=*/rocm_supports_tdm, num_stages));
6367
if (enable_pdl) {
6468
pm->addPass(emitters::CreateLowerPdlWaitPass());
6569
}

xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@ limitations under the License.
1818
#include <string>
1919

2020
#include "absl/strings/str_cat.h"
21-
#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
22-
#include "third_party/amd/include/TritonAMDGPUTransforms/Passes.h"
2321
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
2422
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
2523
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
2624
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
2725
#include "mlir/Pass/PassManager.h"
2826
#include "mlir/Transforms/Passes.h"
27+
#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
28+
#include "third_party/amd/include/TritonAMDGPUTransforms/Passes.h"
2929
#include "xla/stream_executor/device_description.h"
30+
#include "xla/stream_executor/rocm/rocm_compute_capability.h"
3031
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
3132
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
3233
#include "triton/Dialect/Triton/Transforms/Passes.h"
@@ -39,10 +40,12 @@ namespace mt = ::mlir::triton;
3940

4041
// Based on make_ttir() in
4142
// @triton//:third_party/amd/backend/compiler.py
42-
static void MakeTTIR(mlir::OpPassManager* pm) {
43+
static void MakeTTIR(mlir::OpPassManager* pm,
44+
const stream_executor::RocmComputeCapability& rocm_cc) {
4345
pm->addPass(mlir::createInlinerPass());
44-
// if not amd.supports_tdm(arch)
45-
// pm->addPass(mt::createTritonRewriteTensorDescriptorToPointer());
46+
if (!rocm_cc.has_tdm_support()) {
47+
pm->addPass(mt::createTritonRewriteTensorDescriptorToPointer());
48+
}
4649
pm->addPass(mlir::createCanonicalizerPass());
4750
pm->addPass(mt::createTritonCombineOps());
4851
pm->addPass(mt::createTritonReorderBroadcast());
@@ -99,6 +102,7 @@ static void MakeTTGIR(mlir::OpPassManager* pm,
99102
bool use_block_pingpong =
100103
is_pingpong_schedule_enabled(rocm_cc, use_async_copy);
101104

105+
pm->addPass(mlir::createTritonAMDGPUOptimizeDescriptorEncoding());
102106
pm->addPass(mlir::createTritonAMDGPUScheduleLoops({num_stages}));
103107
pm->addPass(
104108
mlir::createTritonAMDGPUPipeline({use_async_copy, use_block_pingpong}));
@@ -185,7 +189,7 @@ void CreateTritonRocmPipeline(
185189
mlir::OpPassManager* pm,
186190
const stream_executor::RocmComputeCapability& rocm_cc, int num_warps,
187191
int num_ctas, int num_stages) {
188-
MakeTTIR(pm);
192+
MakeTTIR(pm, rocm_cc);
189193
MakeTTGIR(pm, rocm_cc, num_warps, num_ctas, num_stages);
190194
MakeLLIR(pm, rocm_cc, num_stages);
191195
}

xla/backends/gpu/codegen/triton/transforms/passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ namespace mlir::triton::xla {
3333
std::unique_ptr<mlir::Pass> CreateTritonXLAExtractInsertToTritonPass();
3434
std::unique_ptr<mlir::Pass> CreateTritonXLAExtractInsertToTritonPass(
3535
bool allow_tma, int num_stages);
36+
std::unique_ptr<mlir::Pass> CreateTritonXLAExtractInsertToTritonPass(
37+
bool allow_tma, bool allow_tdm, int num_stages);
3638
std::unique_ptr<mlir::Pass> CreateTritonXLASqueezeDimsPass();
3739
std::unique_ptr<mlir::Pass> CreateTritonXLAFoldTransposePass();
3840
std::unique_ptr<mlir::Pass> CreateGeneralizeKernelSignaturePass();

xla/backends/gpu/codegen/triton/transforms/passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def TritonXLAExtractInsertToTritonPass : Pass<"triton-xla-extract-insert-to-trit
2828
let options = [
2929
Option<"allow_tma_", "allow_tma", "bool", "false",
3030
"Whether to permit lowering to TMA.">,
31+
Option<"allow_tdm_", "allow_tdm", "bool", "false",
32+
"Whether to permit lowering to TDM (device-side tensor descriptors).">,
3133
Option<"num_stages_", "num_stages", "int", "1",
3234
"Number of stages for pipelining.">,
3335
];
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: xla-opt %s --triton-xla-pipeline='target=gfx1250' \
2+
// RUN: | FileCheck %s --check-prefix=CHECK-TDM
3+
//
4+
// RUN: xla-opt %s --triton-xla-pipeline='target=gfx950' \
5+
// RUN: | FileCheck %s --check-prefix=CHECK-NOTDM
6+
7+
// Verifies that the full Triton XLA + AMD lowering pipeline emits TDM
8+
// intrinsics on gfx1250 and pointer-arithmetic buffer ops on non-TDM arches.
9+
10+
func.func @lower_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
11+
%extracted_tensor = triton_xla.extract from %arg0
12+
as memref<256x256xbf16, #xtile.layout<[1, 0]>>
13+
[0, 0] [16, 64] [1, 1] : tensor<16x64xbf16>
14+
triton_xla.insert %extracted_tensor into %arg1
15+
as memref<256x256xbf16, #xtile.layout<[1, 0]>>
16+
[0, 0] [16, 64] [1, 1] : tensor<16x64xbf16>
17+
func.return
18+
}
19+
20+
// CHECK-TDM-LABEL: llvm.func @lower_extract_insert
21+
// CHECK-TDM: tensor.load.to.lds
22+
// CHECK-TDM: s.wait.tensorcnt
23+
// CHECK-TDM: tensor.store.from.lds
24+
25+
// CHECK-NOTDM-LABEL: llvm.func @lower_extract_insert
26+
// CHECK-NOTDM-NOT: tensor.load.to.lds
27+
// CHECK-NOTDM-NOT: tensor.store.from.lds
28+
// CHECK-NOTDM: raw.ptr.buffer.load

xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_extract_insert_to_triton.mlir

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
// RUN: -triton-xla-extract-insert-to-triton="allow_tma=1 num_stages=3" \
77
// RUN: | FileCheck %s --check-prefix=CHECK-TMA
88

9+
// RUN: xla-opt %s -split-input-file \
10+
// RUN: -triton-xla-extract-insert-to-triton="allow_tdm=1" \
11+
// RUN: | FileCheck %s --check-prefix=CHECK-TDM
12+
913
func.func @lower_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
1014
%extracted_tensor = triton_xla.extract from %arg0
1115
as memref<512x8x128xbf16, #xtile.layout<[2, 1, 0]>>
@@ -30,6 +34,17 @@ func.func @lower_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
3034
// CHECK-TMA: tt.descriptor_store %arg1[{{.*}}],
3135
// CHECK-TMA: tt.return
3236

37+
// Middle singleton dim is TDM-incompatible, so fall back to pointer loads.
38+
// CHECK-TDM-LABEL: tt.func @lower_extract_insert(
39+
// CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
40+
// CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
41+
// CHECK-TDM-NOT: tt.make_tensor_descriptor
42+
// CHECK-TDM-NOT: tt.descriptor_load
43+
// CHECK-TDM-NOT: tt.descriptor_store
44+
// CHECK-TDM: %[[LOAD:.*]] = tt.load
45+
// CHECK-TDM: tt.store {{.*}}, %[[LOAD]]
46+
// CHECK-TDM: tt.return
47+
3348
// -----
3449

3550
func.func @non_perfect_tile_shape(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
@@ -46,6 +61,12 @@ func.func @non_perfect_tile_shape(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
4661
// CHECK: %[[LOAD:.*]] = tt.load {{.*}}, %{{.*}}, %{{.*}} :
4762
// CHECK: tt.store {{.*}}, %[[LOAD]], %{{.*}} :
4863

64+
// CHECK-TDM-LABEL: tt.func @non_perfect_tile_shape
65+
// CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0
66+
// CHECK-TDM: tt.descriptor_load %[[DESC0]]
67+
// CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1
68+
// CHECK-TDM: tt.descriptor_store %[[DESC1]]
69+
4970
// -----
5071

5172
func.func @incompatible_tma_global_strides(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
@@ -62,6 +83,11 @@ func.func @incompatible_tma_global_strides(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<
6283
// CHECK-TMA: tt.load
6384
// CHECK-TMA: tt.store
6485

86+
// CHECK-TDM-LABEL: tt.func @incompatible_tma_global_strides
87+
// CHECK-TDM-NOT: tt.make_tensor_descriptor
88+
// CHECK-TDM: tt.load
89+
// CHECK-TDM: tt.store
90+
6591
// -----
6692

6793
#indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]">
@@ -91,6 +117,11 @@ module {
91117
// CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}}
92118
// CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}}
93119

120+
// CHECK-TDM-LABEL: tt.func @slice_with_tiling_that_needs_padding_has_boundary_checks
121+
// CHECK-TDM: tt.descriptor_load
122+
// CHECK-TDM: tt.descriptor_store
123+
// CHECK-TDM: tt.descriptor_store
124+
94125
// -----
95126

96127
#indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]">
@@ -120,6 +151,11 @@ module {
120151
// CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}}
121152
// CHECK: tt.store {{.*}}, %{{.*}} :
122153

154+
// CHECK-TDM-LABEL: tt.func @slice_with_extra_output_that_can_reuse_tile_due_to_padding
155+
// CHECK-TDM: tt.descriptor_load
156+
// CHECK-TDM: tt.descriptor_store
157+
// CHECK-TDM: tt.descriptor_store
158+
123159
// -----
124160

125161
func.func @extract_with_non_unit_minor_dim_stride(%arg0: !tt.ptr<bf16>,
@@ -137,6 +173,10 @@ func.func @extract_with_non_unit_minor_dim_stride(%arg0: !tt.ptr<bf16>,
137173
// CHECK-TMA: tt.load
138174
// CHECK-TMA: tt.descriptor_store
139175

176+
// CHECK-TDM-LABEL: tt.func @extract_with_non_unit_minor_dim_stride
177+
// CHECK-TDM: tt.load
178+
// CHECK-TDM: tt.descriptor_store
179+
140180
// -----
141181

142182
func.func @lower_extract_insert_1d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
@@ -163,6 +203,15 @@ func.func @lower_extract_insert_1d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
163203
// CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], %[[LOAD]]
164204
// CHECK-TMA: tt.return
165205

206+
// CHECK-TDM-LABEL: tt.func @lower_extract_insert_1d(
207+
// CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
208+
// CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
209+
// CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0
210+
// CHECK-TDM: %[[LOAD:.*]] = tt.descriptor_load %[[DESC0]]
211+
// CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1
212+
// CHECK-TDM: tt.descriptor_store %[[DESC1]][{{.*}}], %[[LOAD]]
213+
// CHECK-TDM: tt.return
214+
166215
// -----
167216

168217
func.func @lower_extract_insert_5d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
@@ -189,6 +238,15 @@ func.func @lower_extract_insert_5d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
189238
// CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], %[[LOAD]]
190239
// CHECK-TMA: tt.return
191240

241+
// CHECK-TDM-LABEL: tt.func @lower_extract_insert_5d(
242+
// CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
243+
// CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
244+
// CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0
245+
// CHECK-TDM: %[[LOAD:.*]] = tt.descriptor_load %[[DESC0]]
246+
// CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1
247+
// CHECK-TDM: tt.descriptor_store %[[DESC1]][{{.*}}], %[[LOAD]]
248+
// CHECK-TDM: tt.return
249+
192250
// -----
193251

194252
func.func @extract_insert_with_zero_stride(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
@@ -205,6 +263,11 @@ func.func @extract_insert_with_zero_stride(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<
205263
// CHECK-TMA-SAME: %arg0: !tt.tensordesc<1x64xbf16>
206264
// CHECK-TMA-SAME: %arg1: !tt.tensordesc<1x64xbf16>
207265

266+
// CHECK-TDM-LABEL: tt.func @extract_insert_with_zero_stride
267+
// CHECK-TDM-NOT: tt.make_tensor_descriptor
268+
// CHECK-TDM: tt.load
269+
// CHECK-TDM: tt.store
270+
208271
// -----
209272

210273
func.func @incompatible_tma_const_offset_not_divisible_by_16_bytes(
@@ -222,6 +285,11 @@ func.func @incompatible_tma_const_offset_not_divisible_by_16_bytes(
222285
// CHECK-TMA: tt.load
223286
// CHECK-TMA: tt.descriptor_store
224287

288+
// CHECK-TDM-LABEL: tt.func @incompatible_tma_const_offset_not_divisible_by_16_bytes
289+
// CHECK-TDM-NOT: tt.make_tensor_descriptor
290+
// CHECK-TDM: tt.load
291+
// CHECK-TDM: tt.store
292+
225293
// -----
226294

227295
#indexing_map = #xla.indexing_map<"(pid_0) -> ((pid_0 mod 9) * 16 + (pid_0 floordiv 9) * 130), domain: pid_0 in [0, 575]">
@@ -251,6 +319,10 @@ module {
251319
// CHECK-TMA: tt.load
252320
// CHECK-TMA: tt.descriptor_store
253321

322+
// CHECK-TDM-LABEL: tt.func @incompatible_tma_dynamic_offset_not_divisible_by_16_bytes
323+
// CHECK-TDM: tt.descriptor_load
324+
// CHECK-TDM: tt.store
325+
254326
// -----
255327

256328
func.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma(
@@ -276,6 +348,11 @@ func.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma(
276348
// CHECK-TMA-NOT: tt.descriptor_load %arg0
277349
// CHECK-TMA: tt.descriptor_load %arg1
278350

351+
// CHECK-TDM-LABEL: tt.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma
352+
// CHECK-TDM: tt.descriptor_load
353+
// CHECK-TDM: tt.descriptor_load
354+
// CHECK-TDM: tt.descriptor_store
355+
279356
// -----
280357

281358
#indexing_map_unaligned = #xla.indexing_map<"(d0) -> (d0 * 2816), domain: d0 in [0, 2047]">
@@ -301,6 +378,10 @@ module {
301378
// CHECK: %[[MASK:.*]] = arith.cmpi slt
302379
// CHECK: tt.load {{.*}}, %[[MASK]], {{.*}}
303380

381+
// CHECK-TDM-LABEL: tt.func @apply_mask_to_unaligned_offset_with_perfect_total_size
382+
// CHECK-TDM: tt.descriptor_load
383+
// CHECK-TDM: tt.descriptor_store
384+
304385
// -----
305386

306387
#indexing_map_aligned_with_oob_at_end = #xla.indexing_map<"(pid, d1) -> ((pid floordiv 64) * 384 + d1 * 32), domain: pid in [0, 1023], d1 in [0, 11]">
@@ -328,3 +409,7 @@ module {
328409
// CHECK-LABEL: tt.func @apply_mask_to_aligned_offset_with_out_of_bounds_reads_at_end
329410
// CHECK: %[[MASK:.*]] = arith.cmpi slt
330411
// CHECK: tt.load {{.*}}, %[[MASK]], {{.*}}
412+
413+
// CHECK-TDM-LABEL: tt.func @apply_mask_to_aligned_offset_with_out_of_bounds_reads_at_end
414+
// CHECK-TDM: tt.descriptor_load
415+
// CHECK-TDM: tt.descriptor_store

0 commit comments

Comments
 (0)