@@ -155,49 +155,49 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
155155}
156156
157157// CHECK-LABEL: func.func @gather_memref_2d
158- // CHECK-SAME: (%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
158+ // CHECK-SAME: (%[[BASE:.*]]: memref<?x?xf32>, %[[IDX:.*]]: vector<2x3xindex>, %[[MASK:.*]]: vector<2x3xi1>, %[[PASS:.*]]: vector<2x3xf32>) -> vector<2x3xf32>
159159
160- // CHECK: %0 = ub.poison : vector<6xf32>
161- // CHECK: %c1 = arith.constant 1 : index
162- // CHECK: %c0 = arith.constant 0 : index
163- // CHECK: %1 = vector.shape_cast %arg3 : vector<2x3xf32> to vector<6xf32>
160+ // CHECK: %[[POISON:.*]] = ub.poison : vector<6xf32>
161+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
162+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
163+ // CHECK: %[[PASS_CAST:.*]] = vector.shape_cast %[[PASS]] : vector<2x3xf32> to vector<6xf32>
164164
165165// First shuffle + if ladder for row 0
166- // CHECK: %2 = vector.shuffle %1 , %1 [0, 1, 2]
167- // CHECK: %3 = vector.extract %arg2 [0, 0]
168- // CHECK: %4 = vector.extract %arg1 [0, 0]
169- // CHECK: %5 = arith.addi %4 , %c1
170- // CHECK: %6 = scf.if %3 -> (vector<3xf32>) {
171- // CHECK: %{{.*}} = vector.load %arg0[%c0 , %5 ] : memref<?x?xf32>, vector<1xf32>
172- // CHECK: %{{.*}} = vector.extract {{.*}} [0] : f32
173- // CHECK: %{{.*}} = vector.insert {{.*}} , %2 [0] : f32 into vector<3xf32>
174- // CHECK: scf.yield {{.*}} : vector<3xf32>
166+ // CHECK: %[[ROW0_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[PASS_CAST]] [0, 1, 2]
167+ // CHECK: %[[MASK_0_0:.*]] = vector.extract %[[MASK]][0, 0]
168+ // CHECK: %[[IDX_0_0:.*]] = vector.extract %[[IDX]][0, 0]
169+ // CHECK: %[[OFF_0_0:.*]] = arith.addi %[[IDX_0_0]], %[[C1]]
170+ // CHECK: %[[IF_0_0:.*]] = scf.if %[[MASK_0_0]] -> (vector<3xf32>) {
171+ // CHECK: %[[LOAD_0_0:.*]] = vector.load %[[BASE]][%[[C0]], %[[OFF_0_0]]] : memref<?x?xf32>, vector<1xf32>
172+ // CHECK: %[[ELEM_0_0:.*]] = vector.extract %[[LOAD_0_0]][0] : f32
173+ // CHECK: %[[INS_0_0:.*]] = vector.insert %[[ELEM_0_0]], %[[ROW0_INIT]] [0] : f32 into vector<3xf32>
174+ // CHECK: scf.yield %[[INS_0_0]] : vector<3xf32>
175175// CHECK: } else {
176- // CHECK: scf.yield %2 : vector<3xf32>
176+ // CHECK: scf.yield %[[ROW0_INIT]] : vector<3xf32>
177177// CHECK: }
178178
179- // CHECK: %7 = vector.extract %arg2 [0, 1]
180- // CHECK: %8 = vector.extract %arg1 [0, 1]
181- // CHECK: %9 = arith.addi %8 , %c1
182- // CHECK: %10 = scf.if %7 -> (vector<3xf32>)
179+ // CHECK: %[[MASK_0_1:.*]] = vector.extract %[[MASK]][0, 1]
180+ // CHECK: %[[IDX_0_1:.*]] = vector.extract %[[IDX]][0, 1]
181+ // CHECK: %[[OFF_0_1:.*]] = arith.addi %[[IDX_0_1]], %[[C1]]
182+ // CHECK: %[[IF_0_1:.*]] = scf.if %[[MASK_0_1]] -> (vector<3xf32>)
183183
184184// … (similar checks for the rest of row 0; row 1 is checked separately below)
185185
186- // CHECK: %15 = vector.shuffle %0, % {{.*}} [6, 7, 8, 3, 4, 5]
187- // CHECK: %16 = vector.shuffle %1 , %1 [3, 4, 5]
186+ // CHECK: %[[ROW_SHUFFLE:.*]] = vector.shuffle %[[POISON]], {{.*}} [6, 7, 8, 3, 4, 5]
187+ // CHECK: %[[ROW1_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[PASS_CAST]] [3, 4, 5]
188188
189189// Row 1 if ladder checks
190- // CHECK: %17 = vector.extract %arg2 [1, 0]
191- // CHECK: %18 = vector.extract %arg1 [1, 0]
192- // CHECK: %19 = arith.addi %18 , %c1
193- // CHECK: %20 = scf.if %17 -> (vector<3xf32>)
190+ // CHECK: %[[MASK_1_0:.*]] = vector.extract %[[MASK]][1, 0]
191+ // CHECK: %[[IDX_1_0:.*]] = vector.extract %[[IDX]][1, 0]
192+ // CHECK: %[[OFF_1_0:.*]] = arith.addi %[[IDX_1_0]], %[[C1]]
193+ // CHECK: %[[IF_1_0:.*]] = scf.if %[[MASK_1_0]] -> (vector<3xf32>)
194194
195195// … (similar checks for remaining row 1 inserts)
196196
197197// Final reshuffle and cast
198- // CHECK: %29 = vector.shuffle %15, % {{.*}} [0, 1, 2, 6, 7, 8]
199- // CHECK: %30 = vector.shape_cast %29 : vector<6xf32> to vector<2x3xf32>
200- // CHECK: return %30 : vector<2x3xf32>
198+ // CHECK: %[[FINAL_SHUFFLE:.*]] = vector.shuffle %[[ROW_SHUFFLE]], {{.*}} [0, 1, 2, 6, 7, 8]
199+ // CHECK: %[[RESULT:.*]] = vector.shape_cast %[[FINAL_SHUFFLE]] : vector<6xf32> to vector<2x3xf32>
200+ // CHECK: return %[[RESULT]] : vector<2x3xf32>
201201func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
202202 %c0 = arith.constant 0 : index
203203 %c1 = arith.constant 1 : index
@@ -209,32 +209,32 @@ func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask
209209// Check for vector linearization interoperability with XeGPU dialect ops.
210210// The `xegpu-vector-linearize` pass does not itself affect the XeGPU ops.
211211
212- // CHECK: gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2 : memref<8x16xf32>) kernel {
213- // CHECK: %c0 = arith.constant 0 : index
214- // CHECK: %cst = arith.constant dense<0.000000e+00> : vector<64xf16>
215- // CHECK: %cst_0 = arith.constant dense<5.000000e+00> : vector<64xf32>
212+ // CHECK: gpu.func @test_kernel(%[[A:.*]]: memref<8x16xf16>, %[[B:.*]]: memref<16x16xf16>, %[[C:.*]]: memref<8x16xf32>) kernel {
213+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
214+ // CHECK: %[[CST_A:.*]] = arith.constant dense<0.000000e+00> : vector<64xf16>
215+ // CHECK: %[[CST_C:.*]] = arith.constant dense<5.000000e+00> : vector<64xf32>
216216
217- // CHECK: %0 = xegpu.create_nd_tdesc %arg0[%c0 , %c0 ]
218- // CHECK: %1 = xegpu.load_nd %0
219- // CHECK: %2 = vector.shape_cast %1 : vector<8x16xf16> to vector<128xf16>
220- // CHECK: %3 = vector.shuffle %2 , %cst {{.*}} : vector<128xf16>, vector<64xf16>
221- // CHECK: %4 = vector.shape_cast %3 : vector<128xf16> to vector<8x16xf16>
217+ // CHECK: %[[A_TDESC:.*]] = xegpu.create_nd_tdesc %[[A]][%[[C0]], %[[C0]]]
218+ // CHECK: %[[A_VAL:.*]] = xegpu.load_nd %[[A_TDESC]]
219+ // CHECK: %[[A_CAST:.*]] = vector.shape_cast %[[A_VAL]] : vector<8x16xf16> to vector<128xf16>
220+ // CHECK: %[[A_SHUFFLE:.*]] = vector.shuffle %[[A_CAST]], %[[CST_A]] {{.*}} : vector<128xf16>, vector<64xf16>
221+ // CHECK: %[[A_RESULT:.*]] = vector.shape_cast %[[A_SHUFFLE]] : vector<128xf16> to vector<8x16xf16>
222222
223- // CHECK: %5 = xegpu.create_nd_tdesc %arg1[%c0 , %c0 ]
224- // CHECK: %6 = xegpu.load_nd %5
225- // CHECK: %7 = vector.shape_cast %6 : vector<16x16xf16> to vector<256xf16>
226- // CHECK: %8 = vector.shuffle %7 , %cst {{.*}} : vector<256xf16>, vector<64xf16>
227- // CHECK: %9 = vector.shape_cast %8 : vector<256xf16> to vector<16x16xf16>
223+ // CHECK: %[[B_TDESC:.*]] = xegpu.create_nd_tdesc %[[B]][%[[C0]], %[[C0]]]
224+ // CHECK: %[[B_VAL:.*]] = xegpu.load_nd %[[B_TDESC]]
225+ // CHECK: %[[B_CAST:.*]] = vector.shape_cast %[[B_VAL]] : vector<16x16xf16> to vector<256xf16>
226+ // CHECK: %[[B_SHUFFLE:.*]] = vector.shuffle %[[B_CAST]], %[[CST_A]] {{.*}} : vector<256xf16>, vector<64xf16>
227+ // CHECK: %[[B_RESULT:.*]] = vector.shape_cast %[[B_SHUFFLE]] : vector<256xf16> to vector<16x16xf16>
228228
229- // CHECK: %10 = xegpu.dpas %4 , %9 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
230- // CHECK: %11 = vector.shape_cast %10 : vector<8x16xf32> to vector<128xf32>
231- // CHECK: %12 = vector.shuffle %11 , %11 {{.*}} : vector<128xf32>, vector<128xf32>
232- // CHECK: %13 = arith.addf %12 , %cst_0 : vector<64xf32>
233- // CHECK: %14 = vector.shuffle %11 , %13 {{.*}} : vector<128xf32>, vector<64xf32>
234- // CHECK: %15 = vector.shape_cast %14 : vector<128xf32> to vector<8x16xf32>
229+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[A_RESULT]], %[[B_RESULT]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
230+ // CHECK: %[[DPAS_CAST:.*]] = vector.shape_cast %[[DPAS]] : vector<8x16xf32> to vector<128xf32>
231+ // CHECK: %[[EXTRACT_SHUFFLE:.*]] = vector.shuffle %[[DPAS_CAST]], %[[DPAS_CAST]] {{.*}} : vector<128xf32>, vector<128xf32>
232+ // CHECK: %[[ADDF:.*]] = arith.addf %[[EXTRACT_SHUFFLE]], %[[CST_C]] : vector<64xf32>
233+ // CHECK: %[[INSERT_SHUFFLE:.*]] = vector.shuffle %[[DPAS_CAST]], %[[ADDF]] {{.*}} : vector<128xf32>, vector<64xf32>
234+ // CHECK: %[[C_RESULT:.*]] = vector.shape_cast %[[INSERT_SHUFFLE]] : vector<128xf32> to vector<8x16xf32>
235235
236- // CHECK: %16 = xegpu.create_nd_tdesc %arg2[%c0 , %c0 ]
237- // CHECK: xegpu.store_nd %15 , %16
236+ // CHECK: %[[C_TDESC:.*]] = xegpu.create_nd_tdesc %[[C]][%[[C0]], %[[C0]]]
237+ // CHECK: xegpu.store_nd %[[C_RESULT]], %[[C_TDESC]]
238238// CHECK: gpu.return
239239
240240gpu.module @test_kernel {
0 commit comments