-
Notifications
You must be signed in to change notification settings - Fork 8
[WIP][ROCm] Add TDM (Tensor Descriptor Memory) support for gfx1250 #826
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| // RUN: xla-opt %s --triton-xla-pipeline='target=gfx1250' \ | ||
| // RUN: | FileCheck %s --check-prefix=CHECK-TDM | ||
| // | ||
| // RUN: xla-opt %s --triton-xla-pipeline='target=gfx950' \ | ||
| // RUN: | FileCheck %s --check-prefix=CHECK-NOTDM | ||
|
|
||
| // Verifies that the full Triton XLA + AMD lowering pipeline emits TDM | ||
| // intrinsics on gfx1250 and pointer-arithmetic buffer ops on non-TDM arches. | ||
|
|
||
| // Test input: extracts a tensor<16x64xbf16> tile from %arg0 and inserts it into | ||
| // %arg1, with both pointers viewed as 256x256 bf16 memrefs with layout [1, 0]. | ||
| // NOTE(review): the bracketed lists are presumably offsets, sizes and strides | ||
| // (the sizes match the 16x64 tile type) -- confirm against the triton_xla ops. | ||
| func.func @lower_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | ||
| %extracted_tensor = triton_xla.extract from %arg0 | ||
| as memref<256x256xbf16, #xtile.layout<[1, 0]>> | ||
| [0, 0] [16, 64] [1, 1] : tensor<16x64xbf16> | ||
| triton_xla.insert %extracted_tensor into %arg1 | ||
| as memref<256x256xbf16, #xtile.layout<[1, 0]>> | ||
| [0, 0] [16, 64] [1, 1] : tensor<16x64xbf16> | ||
| func.return | ||
| } | ||
|
|
||
| // CHECK-TDM-LABEL: llvm.func @lower_extract_insert | ||
| // CHECK-TDM: tensor.load.to.lds | ||
| // CHECK-TDM: s.wait.tensorcnt | ||
| // CHECK-TDM: tensor.store.from.lds | ||
|
|
||
| // CHECK-NOTDM-LABEL: llvm.func @lower_extract_insert | ||
| // CHECK-NOTDM-NOT: tensor.load.to.lds | ||
| // CHECK-NOTDM-NOT: tensor.store.from.lds | ||
| // CHECK-NOTDM: raw.ptr.buffer.load | ||
| // FileCheck only applies the -NOT checks above to the output that precedes the | ||
| // raw.ptr.buffer.load match, so repeat them here to also reject TDM intrinsics | ||
| // emitted after the load (which is where a store would appear). | ||
| // CHECK-NOTDM-NOT: tensor.load.to.lds | ||
| // CHECK-NOTDM-NOT: tensor.store.from.lds |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,10 @@ | |
| // RUN: -triton-xla-extract-insert-to-triton="allow_tma=1 num_stages=3" \ | ||
| // RUN: | FileCheck %s --check-prefix=CHECK-TMA | ||
|
|
||
| // RUN: xla-opt %s -split-input-file \ | ||
| // RUN: -triton-xla-extract-insert-to-triton="allow_tdm=1" \ | ||
| // RUN: | FileCheck %s --check-prefix=CHECK-TDM | ||
|
|
||
| func.func @lower_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | ||
| %extracted_tensor = triton_xla.extract from %arg0 | ||
| as memref<512x8x128xbf16, #xtile.layout<[2, 1, 0]>> | ||
|
|
@@ -30,6 +34,17 @@ func.func @lower_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | |
| // CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], | ||
| // CHECK-TMA: tt.return | ||
|
|
||
| // Middle singleton dim is TDM-incompatible, so fall back to pointer loads. | ||
| // CHECK-TDM-LABEL: tt.func @lower_extract_insert( | ||
| // CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, | ||
| // CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) { | ||
| // CHECK-TDM-NOT: tt.make_tensor_descriptor | ||
| // CHECK-TDM-NOT: tt.descriptor_load | ||
| // CHECK-TDM-NOT: tt.descriptor_store | ||
| // CHECK-TDM: %[[LOAD:.*]] = tt.load | ||
| // CHECK-TDM: tt.store {{.*}}, %[[LOAD]] | ||
| // CHECK-TDM: tt.return | ||
|
|
||
| // ----- | ||
|
Comment on lines
36
to
48
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Testing gap: several test cases lack `CHECK-TDM` checks.
Even if TDM correctly falls back to pointer ops for some of these, the expected behavior should be tested to prevent future regressions.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have covered this better here. |
||
|
|
||
| func.func @non_perfect_tile_shape(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | ||
|
|
@@ -46,6 +61,12 @@ func.func @non_perfect_tile_shape(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | |
| // CHECK: %[[LOAD:.*]] = tt.load {{.*}}, %{{.*}}, %{{.*}} : | ||
| // CHECK: tt.store {{.*}}, %[[LOAD]], %{{.*}} : | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @non_perfect_tile_shape | ||
| // CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0 | ||
| // CHECK-TDM: tt.descriptor_load %[[DESC0]] | ||
| // CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1 | ||
| // CHECK-TDM: tt.descriptor_store %[[DESC1]] | ||
|
|
||
| // ----- | ||
|
|
||
| func.func @incompatible_tma_global_strides(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | ||
|
|
@@ -62,6 +83,11 @@ func.func @incompatible_tma_global_strides(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr< | |
| // CHECK-TMA: tt.load | ||
| // CHECK-TMA: tt.store | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @incompatible_tma_global_strides | ||
| // CHECK-TDM-NOT: tt.make_tensor_descriptor | ||
| // CHECK-TDM: tt.load | ||
| // CHECK-TDM: tt.store | ||
|
|
||
| // ----- | ||
|
|
||
| #indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]"> | ||
|
|
@@ -91,6 +117,11 @@ module { | |
| // CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}} | ||
| // CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}} | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @slice_with_tiling_that_needs_padding_has_boundary_checks | ||
| // CHECK-TDM: tt.descriptor_load | ||
| // CHECK-TDM: tt.descriptor_store | ||
| // CHECK-TDM: tt.descriptor_store | ||
|
|
||
| // ----- | ||
|
|
||
| #indexing_map = #xla.indexing_map<"(pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]"> | ||
|
|
@@ -120,6 +151,11 @@ module { | |
| // CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}} | ||
| // CHECK: tt.store {{.*}}, %{{.*}} : | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @slice_with_extra_output_that_can_reuse_tile_due_to_padding | ||
| // CHECK-TDM: tt.descriptor_load | ||
| // CHECK-TDM: tt.descriptor_store | ||
| // CHECK-TDM: tt.descriptor_store | ||
|
|
||
| // ----- | ||
|
|
||
| func.func @extract_with_non_unit_minor_dim_stride(%arg0: !tt.ptr<bf16>, | ||
|
|
@@ -137,6 +173,10 @@ func.func @extract_with_non_unit_minor_dim_stride(%arg0: !tt.ptr<bf16>, | |
| // CHECK-TMA: tt.load | ||
| // CHECK-TMA: tt.descriptor_store | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @extract_with_non_unit_minor_dim_stride | ||
| // CHECK-TDM: tt.load | ||
| // CHECK-TDM: tt.descriptor_store | ||
|
|
||
| // ----- | ||
|
|
||
| func.func @lower_extract_insert_1d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | ||
|
|
@@ -163,6 +203,15 @@ func.func @lower_extract_insert_1d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | |
| // CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], %[[LOAD]] | ||
| // CHECK-TMA: tt.return | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @lower_extract_insert_1d( | ||
| // CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, | ||
| // CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) { | ||
| // CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0 | ||
| // CHECK-TDM: %[[LOAD:.*]] = tt.descriptor_load %[[DESC0]] | ||
| // CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1 | ||
| // CHECK-TDM: tt.descriptor_store %[[DESC1]][{{.*}}], %[[LOAD]] | ||
| // CHECK-TDM: tt.return | ||
|
|
||
| // ----- | ||
|
|
||
| func.func @lower_extract_insert_5d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | ||
|
|
@@ -189,6 +238,15 @@ func.func @lower_extract_insert_5d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | |
| // CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], %[[LOAD]] | ||
| // CHECK-TMA: tt.return | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @lower_extract_insert_5d( | ||
| // CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, | ||
| // CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) { | ||
| // CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0 | ||
| // CHECK-TDM: %[[LOAD:.*]] = tt.descriptor_load %[[DESC0]] | ||
| // CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1 | ||
| // CHECK-TDM: tt.descriptor_store %[[DESC1]][{{.*}}], %[[LOAD]] | ||
| // CHECK-TDM: tt.return | ||
|
|
||
| // ----- | ||
|
|
||
| func.func @extract_insert_with_zero_stride(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) { | ||
|
|
@@ -205,6 +263,11 @@ func.func @extract_insert_with_zero_stride(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr< | |
| // CHECK-TMA-SAME: %arg0: !tt.tensordesc<1x64xbf16> | ||
| // CHECK-TMA-SAME: %arg1: !tt.tensordesc<1x64xbf16> | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @extract_insert_with_zero_stride | ||
| // CHECK-TDM-NOT: tt.make_tensor_descriptor | ||
| // CHECK-TDM: tt.load | ||
| // CHECK-TDM: tt.store | ||
|
|
||
| // ----- | ||
|
|
||
| func.func @incompatible_tma_const_offset_not_divisible_by_16_bytes( | ||
|
|
@@ -222,6 +285,11 @@ func.func @incompatible_tma_const_offset_not_divisible_by_16_bytes( | |
| // CHECK-TMA: tt.load | ||
| // CHECK-TMA: tt.descriptor_store | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @incompatible_tma_const_offset_not_divisible_by_16_bytes | ||
| // CHECK-TDM-NOT: tt.make_tensor_descriptor | ||
| // CHECK-TDM: tt.load | ||
| // CHECK-TDM: tt.store | ||
|
|
||
| // ----- | ||
|
|
||
| #indexing_map = #xla.indexing_map<"(pid_0) -> ((pid_0 mod 9) * 16 + (pid_0 floordiv 9) * 130), domain: pid_0 in [0, 575]"> | ||
|
|
@@ -251,6 +319,10 @@ module { | |
| // CHECK-TMA: tt.load | ||
| // CHECK-TMA: tt.descriptor_store | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @incompatible_tma_dynamic_offset_not_divisible_by_16_bytes | ||
| // CHECK-TDM: tt.descriptor_load | ||
| // CHECK-TDM: tt.store | ||
|
|
||
| // ----- | ||
|
|
||
| func.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma( | ||
|
|
@@ -276,6 +348,11 @@ func.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma( | |
| // CHECK-TMA-NOT: tt.descriptor_load %arg0 | ||
| // CHECK-TMA: tt.descriptor_load %arg1 | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma | ||
| // CHECK-TDM: tt.descriptor_load | ||
| // CHECK-TDM: tt.descriptor_load | ||
| // CHECK-TDM: tt.descriptor_store | ||
|
|
||
| // ----- | ||
|
|
||
| #indexing_map_unaligned = #xla.indexing_map<"(d0) -> (d0 * 2816), domain: d0 in [0, 2047]"> | ||
|
|
@@ -301,6 +378,10 @@ module { | |
| // CHECK: %[[MASK:.*]] = arith.cmpi slt | ||
| // CHECK: tt.load {{.*}}, %[[MASK]], {{.*}} | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @apply_mask_to_unaligned_offset_with_perfect_total_size | ||
| // CHECK-TDM: tt.descriptor_load | ||
| // CHECK-TDM: tt.descriptor_store | ||
|
|
||
| // ----- | ||
|
|
||
| #indexing_map_aligned_with_oob_at_end = #xla.indexing_map<"(pid, d1) -> ((pid floordiv 64) * 384 + d1 * 32), domain: pid in [0, 1023], d1 in [0, 11]"> | ||
|
|
@@ -328,3 +409,7 @@ module { | |
| // CHECK-LABEL: tt.func @apply_mask_to_aligned_offset_with_out_of_bounds_reads_at_end | ||
| // CHECK: %[[MASK:.*]] = arith.cmpi slt | ||
| // CHECK: tt.load {{.*}}, %[[MASK]], {{.*}} | ||
|
|
||
| // CHECK-TDM-LABEL: tt.func @apply_mask_to_aligned_offset_with_out_of_bounds_reads_at_end | ||
| // CHECK-TDM: tt.descriptor_load | ||
| // CHECK-TDM: tt.descriptor_store | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: descriptor optimization pass is added unconditionally for all ROCm architectures
`createTritonAMDGPUOptimizeDescriptorEncoding()` is added to the pipeline for all ROCm targets, not just TDM-capable ones. While the upstream pass is likely a no-op when there are no tensor descriptors in the IR, it would be more explicit to guard it with `if (rocm_cc.supports_tdm())`, consistent with the descriptor-to-pointer rewrite guard in MakeTTIR.