66// RUN: -triton-xla-extract-insert-to-triton="allow_tma=1 num_stages=3" \
77// RUN: | FileCheck %s --check-prefix=CHECK-TMA
88
9+ // RUN: xla-opt %s -split-input-file \
10+ // RUN: -triton-xla-extract-insert-to-triton="allow_tdm=1" \
11+ // RUN: | FileCheck %s --check-prefix=CHECK-TDM
12+
913func.func @lower_extract_insert (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >) {
1014 %extracted_tensor = triton_xla.extract from %arg0
1115 as memref <512 x8 x128 xbf16 , #xtile.layout <[2 , 1 , 0 ]>>
@@ -30,6 +34,17 @@ func.func @lower_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
3034// CHECK-TMA: tt.descriptor_store %arg1[{{.*}}],
3135// CHECK-TMA: tt.return
3236
37+ // The middle singleton dim is TDM-incompatible, so the pass emits no tensor descriptors and falls back to pointer-based loads and stores.
38+ // CHECK-TDM-LABEL: tt.func @lower_extract_insert(
39+ // CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
40+ // CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
41+ // CHECK-TDM-NOT: tt.make_tensor_descriptor
42+ // CHECK-TDM-NOT: tt.descriptor_load
43+ // CHECK-TDM-NOT: tt.descriptor_store
44+ // CHECK-TDM: %[[LOAD:.*]] = tt.load
45+ // CHECK-TDM: tt.store {{.*}}, %[[LOAD]]
46+ // CHECK-TDM: tt.return
47+
3348// -----
3449
3550func.func @non_perfect_tile_shape (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >) {
@@ -46,6 +61,12 @@ func.func @non_perfect_tile_shape(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
4661// CHECK: %[[LOAD:.*]] = tt.load {{.*}}, %{{.*}}, %{{.*}} :
4762// CHECK: tt.store {{.*}}, %[[LOAD]], %{{.*}} :
4863
64+ // CHECK-TDM-LABEL: tt.func @non_perfect_tile_shape
65+ // CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0
66+ // CHECK-TDM: tt.descriptor_load %[[DESC0]]
67+ // CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1
68+ // CHECK-TDM: tt.descriptor_store %[[DESC1]]
69+
4970// -----
5071
5172func.func @incompatible_tma_global_strides (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >) {
@@ -62,6 +83,11 @@ func.func @incompatible_tma_global_strides(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<
6283// CHECK-TMA: tt.load
6384// CHECK-TMA: tt.store
6485
86+ // CHECK-TDM-LABEL: tt.func @incompatible_tma_global_strides
87+ // CHECK-TDM-NOT: tt.make_tensor_descriptor
88+ // CHECK-TDM: tt.load
89+ // CHECK-TDM: tt.store
90+
6591// -----
6692
6793#indexing_map = #xla.indexing_map <" (pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]" >
@@ -91,6 +117,11 @@ module {
91117// CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}}
92118// CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}}
93119
120+ // CHECK-TDM-LABEL: tt.func @slice_with_tiling_that_needs_padding_has_boundary_checks
121+ // CHECK-TDM: tt.descriptor_load
122+ // CHECK-TDM: tt.descriptor_store
123+ // CHECK-TDM: tt.descriptor_store
124+
94125// -----
95126
96127#indexing_map = #xla.indexing_map <" (pid_0) -> (pid_0 * 32), domain: pid_0 in [0, 1]" >
@@ -120,6 +151,11 @@ module {
120151// CHECK: tt.store {{.*}}, %{{.*}}, %{{.*}}
121152// CHECK: tt.store {{.*}}, %{{.*}} :
122153
154+ // CHECK-TDM-LABEL: tt.func @slice_with_extra_output_that_can_reuse_tile_due_to_padding
155+ // CHECK-TDM: tt.descriptor_load
156+ // CHECK-TDM: tt.descriptor_store
157+ // CHECK-TDM: tt.descriptor_store
158+
123159// -----
124160
125161func.func @extract_with_non_unit_minor_dim_stride (%arg0: !tt.ptr <bf16 >,
@@ -137,6 +173,10 @@ func.func @extract_with_non_unit_minor_dim_stride(%arg0: !tt.ptr<bf16>,
137173// CHECK-TMA: tt.load
138174// CHECK-TMA: tt.descriptor_store
139175
176+ // CHECK-TDM-LABEL: tt.func @extract_with_non_unit_minor_dim_stride
177+ // CHECK-TDM: tt.load
178+ // CHECK-TDM: tt.descriptor_store
179+
140180// -----
141181
142182func.func @lower_extract_insert_1d (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >) {
@@ -163,6 +203,15 @@ func.func @lower_extract_insert_1d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
163203// CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], %[[LOAD]]
164204// CHECK-TMA: tt.return
165205
206+ // CHECK-TDM-LABEL: tt.func @lower_extract_insert_1d(
207+ // CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
208+ // CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
209+ // CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0
210+ // CHECK-TDM: %[[LOAD:.*]] = tt.descriptor_load %[[DESC0]]
211+ // CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1
212+ // CHECK-TDM: tt.descriptor_store %[[DESC1]][{{.*}}], %[[LOAD]]
213+ // CHECK-TDM: tt.return
214+
166215// -----
167216
168217func.func @lower_extract_insert_5d (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >) {
@@ -189,6 +238,15 @@ func.func @lower_extract_insert_5d(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
189238// CHECK-TMA: tt.descriptor_store %arg1[{{.*}}], %[[LOAD]]
190239// CHECK-TMA: tt.return
191240
241+ // CHECK-TDM-LABEL: tt.func @lower_extract_insert_5d(
242+ // CHECK-TDM-SAME: %arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
243+ // CHECK-TDM-SAME: %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
244+ // CHECK-TDM: %[[DESC0:.*]] = tt.make_tensor_descriptor %arg0
245+ // CHECK-TDM: %[[LOAD:.*]] = tt.descriptor_load %[[DESC0]]
246+ // CHECK-TDM: %[[DESC1:.*]] = tt.make_tensor_descriptor %arg1
247+ // CHECK-TDM: tt.descriptor_store %[[DESC1]][{{.*}}], %[[LOAD]]
248+ // CHECK-TDM: tt.return
249+
192250// -----
193251
194252func.func @extract_insert_with_zero_stride (%arg0: !tt.ptr <bf16 >, %arg1: !tt.ptr <bf16 >) {
@@ -205,6 +263,11 @@ func.func @extract_insert_with_zero_stride(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<
205263// CHECK-TMA-SAME: %arg0: !tt.tensordesc<1x64xbf16>
206264// CHECK-TMA-SAME: %arg1: !tt.tensordesc<1x64xbf16>
207265
266+ // CHECK-TDM-LABEL: tt.func @extract_insert_with_zero_stride
267+ // CHECK-TDM-NOT: tt.make_tensor_descriptor
268+ // CHECK-TDM: tt.load
269+ // CHECK-TDM: tt.store
270+
208271// -----
209272
210273func.func @incompatible_tma_const_offset_not_divisible_by_16_bytes (
@@ -222,6 +285,11 @@ func.func @incompatible_tma_const_offset_not_divisible_by_16_bytes(
222285// CHECK-TMA: tt.load
223286// CHECK-TMA: tt.descriptor_store
224287
288+ // CHECK-TDM-LABEL: tt.func @incompatible_tma_const_offset_not_divisible_by_16_bytes
289+ // CHECK-TDM-NOT: tt.make_tensor_descriptor
290+ // CHECK-TDM: tt.load
291+ // CHECK-TDM: tt.store
292+
225293// -----
226294
227295#indexing_map = #xla.indexing_map <" (pid_0) -> ((pid_0 mod 9) * 16 + (pid_0 floordiv 9) * 130), domain: pid_0 in [0, 575]" >
@@ -251,6 +319,10 @@ module {
251319// CHECK-TMA: tt.load
252320// CHECK-TMA: tt.descriptor_store
253321
322+ // CHECK-TDM-LABEL: tt.func @incompatible_tma_dynamic_offset_not_divisible_by_16_bytes
323+ // CHECK-TDM: tt.descriptor_load
324+ // CHECK-TDM: tt.store
325+
254326// -----
255327
256328func.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma (
@@ -276,6 +348,11 @@ func.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma(
276348// CHECK-TMA-NOT: tt.descriptor_load %arg0
277349// CHECK-TMA: tt.descriptor_load %arg1
278350
351+ // CHECK-TDM-LABEL: tt.func @parameter_into_broadcast_with_3_or_more_stages_does_not_use_tma
352+ // CHECK-TDM: tt.descriptor_load
353+ // CHECK-TDM: tt.descriptor_load
354+ // CHECK-TDM: tt.descriptor_store
355+
279356// -----
280357
281358#indexing_map_unaligned = #xla.indexing_map <" (d0) -> (d0 * 2816), domain: d0 in [0, 2047]" >
@@ -301,6 +378,10 @@ module {
301378// CHECK: %[[MASK:.*]] = arith.cmpi slt
302379// CHECK: tt.load {{.*}}, %[[MASK]], {{.*}}
303380
381+ // CHECK-TDM-LABEL: tt.func @apply_mask_to_unaligned_offset_with_perfect_total_size
382+ // CHECK-TDM: tt.descriptor_load
383+ // CHECK-TDM: tt.descriptor_store
384+
304385// -----
305386
306387#indexing_map_aligned_with_oob_at_end = #xla.indexing_map <" (pid, d1) -> ((pid floordiv 64) * 384 + d1 * 32), domain: pid in [0, 1023], d1 in [0, 11]" >
@@ -328,3 +409,7 @@ module {
328409// CHECK-LABEL: tt.func @apply_mask_to_aligned_offset_with_out_of_bounds_reads_at_end
329410// CHECK: %[[MASK:.*]] = arith.cmpi slt
330411// CHECK: tt.load {{.*}}, %[[MASK]], {{.*}}
412+
413+ // CHECK-TDM-LABEL: tt.func @apply_mask_to_aligned_offset_with_out_of_bounds_reads_at_end
414+ // CHECK-TDM: tt.descriptor_load
415+ // CHECK-TDM: tt.descriptor_store
0 commit comments