Skip to content

Commit 8e105b3

Browse files
authored
[mlir][XeGPU][Transform] Update the xegpu-vector-linearize test with named captures in FileCheck. (llvm#186314)
These tests used to contain specific numbered SSA names (%0, %1, %2, etc.), this may cause unnecessary issue if a test is updated with new ops. Update the tests to use named captures instead for future adaptability.
1 parent 3ecede5 commit 8e105b3

1 file changed

Lines changed: 50 additions & 50 deletions

File tree

mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -155,49 +155,49 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
155155
}
156156

157157
// CHECK-LABEL: func.func @gather_memref_2d
158-
// CHECK-SAME: (%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
158+
// CHECK-SAME: (%[[BASE:.*]]: memref<?x?xf32>, %[[IDX:.*]]: vector<2x3xindex>, %[[MASK:.*]]: vector<2x3xi1>, %[[PASS:.*]]: vector<2x3xf32>) -> vector<2x3xf32>
159159

160-
// CHECK: %0 = ub.poison : vector<6xf32>
161-
// CHECK: %c1 = arith.constant 1 : index
162-
// CHECK: %c0 = arith.constant 0 : index
163-
// CHECK: %1 = vector.shape_cast %arg3 : vector<2x3xf32> to vector<6xf32>
160+
// CHECK: %[[POISON:.*]] = ub.poison : vector<6xf32>
161+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
162+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
163+
// CHECK: %[[PASS_CAST:.*]] = vector.shape_cast %[[PASS]] : vector<2x3xf32> to vector<6xf32>
164164

165165
// First shuffle + if ladder for row 0
166-
// CHECK: %2 = vector.shuffle %1, %1 [0, 1, 2]
167-
// CHECK: %3 = vector.extract %arg2[0, 0]
168-
// CHECK: %4 = vector.extract %arg1[0, 0]
169-
// CHECK: %5 = arith.addi %4, %c1
170-
// CHECK: %6 = scf.if %3 -> (vector<3xf32>) {
171-
// CHECK: %{{.*}} = vector.load %arg0[%c0, %5] : memref<?x?xf32>, vector<1xf32>
172-
// CHECK: %{{.*}} = vector.extract {{.*}}[0] : f32
173-
// CHECK: %{{.*}} = vector.insert {{.*}}, %2 [0] : f32 into vector<3xf32>
174-
// CHECK: scf.yield {{.*}} : vector<3xf32>
166+
// CHECK: %[[ROW0_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[PASS_CAST]] [0, 1, 2]
167+
// CHECK: %[[MASK_0_0:.*]] = vector.extract %[[MASK]][0, 0]
168+
// CHECK: %[[IDX_0_0:.*]] = vector.extract %[[IDX]][0, 0]
169+
// CHECK: %[[OFF_0_0:.*]] = arith.addi %[[IDX_0_0]], %[[C1]]
170+
// CHECK: %[[IF_0_0:.*]] = scf.if %[[MASK_0_0]] -> (vector<3xf32>) {
171+
// CHECK: %[[LOAD_0_0:.*]] = vector.load %[[BASE]][%[[C0]], %[[OFF_0_0]]] : memref<?x?xf32>, vector<1xf32>
172+
// CHECK: %[[ELEM_0_0:.*]] = vector.extract %[[LOAD_0_0]][0] : f32
173+
// CHECK: %[[INS_0_0:.*]] = vector.insert %[[ELEM_0_0]], %[[ROW0_INIT]] [0] : f32 into vector<3xf32>
174+
// CHECK: scf.yield %[[INS_0_0]] : vector<3xf32>
175175
// CHECK: } else {
176-
// CHECK: scf.yield %2 : vector<3xf32>
176+
// CHECK: scf.yield %[[ROW0_INIT]] : vector<3xf32>
177177
// CHECK: }
178178

179-
// CHECK: %7 = vector.extract %arg2[0, 1]
180-
// CHECK: %8 = vector.extract %arg1[0, 1]
181-
// CHECK: %9 = arith.addi %8, %c1
182-
// CHECK: %10 = scf.if %7 -> (vector<3xf32>)
179+
// CHECK: %[[MASK_0_1:.*]] = vector.extract %[[MASK]][0, 1]
180+
// CHECK: %[[IDX_0_1:.*]] = vector.extract %[[IDX]][0, 1]
181+
// CHECK: %[[OFF_0_1:.*]] = arith.addi %[[IDX_0_1]], %[[C1]]
182+
// CHECK: %[[IF_0_1:.*]] = scf.if %[[MASK_0_1]] -> (vector<3xf32>)
183183

184184
// … (similar checks for the rest of row 0, then row 1)
185185

186-
// CHECK: %15 = vector.shuffle %0, %{{.*}} [6, 7, 8, 3, 4, 5]
187-
// CHECK: %16 = vector.shuffle %1, %1 [3, 4, 5]
186+
// CHECK: %[[ROW_SHUFFLE:.*]] = vector.shuffle %[[POISON]], {{.*}} [6, 7, 8, 3, 4, 5]
187+
// CHECK: %[[ROW1_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[PASS_CAST]] [3, 4, 5]
188188

189189
// Row 1 if ladder checks
190-
// CHECK: %17 = vector.extract %arg2[1, 0]
191-
// CHECK: %18 = vector.extract %arg1[1, 0]
192-
// CHECK: %19 = arith.addi %18, %c1
193-
// CHECK: %20 = scf.if %17 -> (vector<3xf32>)
190+
// CHECK: %[[MASK_1_0:.*]] = vector.extract %[[MASK]][1, 0]
191+
// CHECK: %[[IDX_1_0:.*]] = vector.extract %[[IDX]][1, 0]
192+
// CHECK: %[[OFF_1_0:.*]] = arith.addi %[[IDX_1_0]], %[[C1]]
193+
// CHECK: %[[IF_1_0:.*]] = scf.if %[[MASK_1_0]] -> (vector<3xf32>)
194194

195195
// … (similar checks for remaining row 1 inserts)
196196

197197
// Final reshuffle and cast
198-
// CHECK: %29 = vector.shuffle %15, %{{.*}} [0, 1, 2, 6, 7, 8]
199-
// CHECK: %30 = vector.shape_cast %29 : vector<6xf32> to vector<2x3xf32>
200-
// CHECK: return %30 : vector<2x3xf32>
198+
// CHECK: %[[FINAL_SHUFFLE:.*]] = vector.shuffle %[[ROW_SHUFFLE]], {{.*}} [0, 1, 2, 6, 7, 8]
199+
// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[FINAL_SHUFFLE]] : vector<6xf32> to vector<2x3xf32>
200+
// CHECK: return %[[RESULT]] : vector<2x3xf32>
201201
func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
202202
%c0 = arith.constant 0 : index
203203
%c1 = arith.constant 1 : index
@@ -209,32 +209,32 @@ func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask
209209
// Check for vector linearization interoperability with XeGPU dialect ops.
210210
// The `xegpu-vector-linearize` pass does not itself affect the XeGPU ops.
211211

212-
// CHECK: gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel {
213-
// CHECK: %c0 = arith.constant 0 : index
214-
// CHECK: %cst = arith.constant dense<0.000000e+00> : vector<64xf16>
215-
// CHECK: %cst_0 = arith.constant dense<5.000000e+00> : vector<64xf32>
212+
// CHECK: gpu.func @test_kernel(%[[A:.*]]: memref<8x16xf16>, %[[B:.*]]: memref<16x16xf16>, %[[C:.*]]: memref<8x16xf32>) kernel {
213+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
214+
// CHECK: %[[CST_A:.*]] = arith.constant dense<0.000000e+00> : vector<64xf16>
215+
// CHECK: %[[CST_C:.*]] = arith.constant dense<5.000000e+00> : vector<64xf32>
216216

217-
// CHECK: %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0]
218-
// CHECK: %1 = xegpu.load_nd %0
219-
// CHECK: %2 = vector.shape_cast %1 : vector<8x16xf16> to vector<128xf16>
220-
// CHECK: %3 = vector.shuffle %2, %cst {{.*}} : vector<128xf16>, vector<64xf16>
221-
// CHECK: %4 = vector.shape_cast %3 : vector<128xf16> to vector<8x16xf16>
217+
// CHECK: %[[A_TDESC:.*]] = xegpu.create_nd_tdesc %[[A]][%[[C0]], %[[C0]]]
218+
// CHECK: %[[A_VAL:.*]] = xegpu.load_nd %[[A_TDESC]]
219+
// CHECK: %[[A_CAST:.*]] = vector.shape_cast %[[A_VAL]] : vector<8x16xf16> to vector<128xf16>
220+
// CHECK: %[[A_SHUFFLE:.*]] = vector.shuffle %[[A_CAST]], %[[CST_A]] {{.*}} : vector<128xf16>, vector<64xf16>
221+
// CHECK: %[[A_RESULT:.*]] = vector.shape_cast %[[A_SHUFFLE]] : vector<128xf16> to vector<8x16xf16>
222222

223-
// CHECK: %5 = xegpu.create_nd_tdesc %arg1[%c0, %c0]
224-
// CHECK: %6 = xegpu.load_nd %5
225-
// CHECK: %7 = vector.shape_cast %6 : vector<16x16xf16> to vector<256xf16>
226-
// CHECK: %8 = vector.shuffle %7, %cst {{.*}} : vector<256xf16>, vector<64xf16>
227-
// CHECK: %9 = vector.shape_cast %8 : vector<256xf16> to vector<16x16xf16>
223+
// CHECK: %[[B_TDESC:.*]] = xegpu.create_nd_tdesc %[[B]][%[[C0]], %[[C0]]]
224+
// CHECK: %[[B_VAL:.*]] = xegpu.load_nd %[[B_TDESC]]
225+
// CHECK: %[[B_CAST:.*]] = vector.shape_cast %[[B_VAL]] : vector<16x16xf16> to vector<256xf16>
226+
// CHECK: %[[B_SHUFFLE:.*]] = vector.shuffle %[[B_CAST]], %[[CST_A]] {{.*}} : vector<256xf16>, vector<64xf16>
227+
// CHECK: %[[B_RESULT:.*]] = vector.shape_cast %[[B_SHUFFLE]] : vector<256xf16> to vector<16x16xf16>
228228

229-
// CHECK: %10 = xegpu.dpas %4, %9 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
230-
// CHECK: %11 = vector.shape_cast %10 : vector<8x16xf32> to vector<128xf32>
231-
// CHECK: %12 = vector.shuffle %11, %11 {{.*}} : vector<128xf32>, vector<128xf32>
232-
// CHECK: %13 = arith.addf %12, %cst_0 : vector<64xf32>
233-
// CHECK: %14 = vector.shuffle %11, %13 {{.*}} : vector<128xf32>, vector<64xf32>
234-
// CHECK: %15 = vector.shape_cast %14 : vector<128xf32> to vector<8x16xf32>
229+
// CHECK: %[[DPAS:.*]] = xegpu.dpas %[[A_RESULT]], %[[B_RESULT]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
230+
// CHECK: %[[DPAS_CAST:.*]] = vector.shape_cast %[[DPAS]] : vector<8x16xf32> to vector<128xf32>
231+
// CHECK: %[[EXTRACT_SHUFFLE:.*]] = vector.shuffle %[[DPAS_CAST]], %[[DPAS_CAST]] {{.*}} : vector<128xf32>, vector<128xf32>
232+
// CHECK: %[[ADDF:.*]] = arith.addf %[[EXTRACT_SHUFFLE]], %[[CST_C]] : vector<64xf32>
233+
// CHECK: %[[INSERT_SHUFFLE:.*]] = vector.shuffle %[[DPAS_CAST]], %[[ADDF]] {{.*}} : vector<128xf32>, vector<64xf32>
234+
// CHECK: %[[C_RESULT:.*]] = vector.shape_cast %[[INSERT_SHUFFLE]] : vector<128xf32> to vector<8x16xf32>
235235

236-
// CHECK: %16 = xegpu.create_nd_tdesc %arg2[%c0, %c0]
237-
// CHECK: xegpu.store_nd %15, %16
236+
// CHECK: %[[C_TDESC:.*]] = xegpu.create_nd_tdesc %[[C]][%[[C0]], %[[C0]]]
237+
// CHECK: xegpu.store_nd %[[C_RESULT]], %[[C_TDESC]]
238238
// CHECK: gpu.return
239239

240240
gpu.module @test_kernel {

0 commit comments

Comments
 (0)