@@ -208,7 +208,7 @@ class gemm_4bit_cutlass_kernel {
208208 Tensor tCgA = thr_mma.partition_A (gA );
209209 Tensor tCgB = thr_mma.partition_B (gB ); // values for each_thread (FrgV,(RestN,RestK),*)
210210
211- #if 1
211+ #if 0
212212 Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
213213 Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
214214 Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout());
@@ -239,26 +239,39 @@ class gemm_4bit_cutlass_kernel {
239239 auto layout_A = make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape());
240240 Tensor mma_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_A.shape(), Int<2>{}), cute::append(layout_A.stride(), Int<0>{})));
241241#else
242- auto layout_A = make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape());
243- auto layout_B = make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape());
244-
245- Tensor mma_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_A.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_A.stride(), 0)));
246- Tensor mma_B = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_B.stride(), 0)));
247- Tensor dequant_frag = make_tensor<ElementB>(cute::make_layout(cute::append(layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_B.stride(), 0)));
248-
242+ Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
243+ Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
244+ Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout ());
245+
246+ Tensor mma_A_b = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
247+ Tensor mma_B_b = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
248+ Tensor dequant_frag_b = make_tensor<ElementB>(mma_B_b.layout ());
249+
249250 static constexpr auto scale_shape_t = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
250251 static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
251252 static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE < 1 ? 1 : BLK_K / GROUP_SIZE ;
252- using FragScaleLayout = Layout<Shape<Int<scale_shape_t>, Int<scale_shape_n>, Int<scale_shape_k>>>;
253- Tensor fragment_scale = make_tensor<ElementScale>(cute::make_layout(cute::append(FragScaleLayout{}.shape(), cute::make_shape(Int<2>{})), cute::make_stride(FragScaleLayout{}.stride(), 0)));
254-
255- auto single_layout_A = thr_copy_A.retile_D(cute::make_tensor(mma_A.data(), layout_A)).layout();
256- auto single_layout_B = thr_copy_B.retile_D(cute::make_tensor(dequant_frag.data(), layout_B)).layout();
257- auto single_layout_Scale = thr_copy_scale.retile_D(cute::make_tensor(fragment_scale.data(), FragScaleLayout{})).layout();
258-
259- Tensor frag_copy_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(single_layout_A.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_A.stride(), 0)));
260- Tensor frag_copy_B = make_tensor<ElementB>(cute::make_layout(cute::append(single_layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_B.stride(), 0)));
261- Tensor frag_copy_Scale = make_tensor<float>(cute::make_layout(cute::append(single_layout_Scale.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_Scale.stride(), 0)));
253+ using FragScaleLayout = Layout<Shape<Int<scale_shape_t >, Int<scale_shape_n>, Int<scale_shape_k>>>; // [1, dequant_N, block_num]
254+ Tensor fragment_scale_a = make_tensor<ElementScale>(FragScaleLayout{});
255+ Tensor fragment_scale_b = make_tensor<ElementScale>(FragScaleLayout{});
256+
257+ Tensor frag_copy_A_a = thr_copy_A.retile_D (mma_A_a);
258+ Tensor frag_copy_B_a = thr_copy_B.retile_D (dequant_frag_a);
259+ Tensor frag_copy_Scale_a = thr_copy_scale.retile_D (fragment_scale_a);
260+
261+ Tensor frag_copy_A_b = thr_copy_A.retile_D (mma_A_b);
262+ Tensor frag_copy_B_b = thr_copy_B.retile_D (dequant_frag_b);
263+ Tensor frag_copy_Scale_b = thr_copy_scale.retile_D (fragment_scale_b);
264+
265+ cute::tuple<decltype (mma_A_a), decltype (mma_A_b)> mma_A (mma_A_a, mma_A_b);
266+ cute::tuple<decltype (mma_B_a), decltype (mma_B_b)> mma_B (mma_B_a, mma_B_b);
267+ cute::tuple<decltype (dequant_frag_a), decltype (dequant_frag_b)> dequant_frag (dequant_frag_a, dequant_frag_b);
268+ cute::tuple<decltype (fragment_scale_a), decltype (fragment_scale_b)> fragment_scale (fragment_scale_a, fragment_scale_b);
269+ cute::tuple<decltype (frag_copy_A_a), decltype (frag_copy_A_b)> frag_copy_A (frag_copy_A_a, frag_copy_A_b);
270+ cute::tuple<decltype (frag_copy_B_a), decltype (frag_copy_B_b)> frag_copy_B (frag_copy_B_a, frag_copy_B_b);
271+ cute::tuple<decltype (frag_copy_Scale_a), decltype (frag_copy_Scale_b)> frag_copy_Scale (frag_copy_Scale_a, frag_copy_Scale_b);
272+ // auto& mma_A_0 = cute::get<0>(mma_A_tuple); // 引用 mma_A_a
273+ // auto& mma_A_1 = cute::get<1>(mma_A_tuple); // 引用 mma_A_b
274+
262275#endif
263276 Tensor tAgA = thr_copy_A.retile_S (tCgA);
264277 Tensor tBgB = thr_copy_B.retile_S (tCgB);
@@ -293,8 +306,8 @@ class gemm_4bit_cutlass_kernel {
293306
294307 int start_lut_id = sg_idx % LUT_NUM ;
295308
296- #if 0
297- auto dequant = [&](int start_lut_id, int buffer_idx) {
309+ #if 1
310+ auto dequant = [&](int start_lut_id, const int buffer_idx) {
298311 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
299312 constexpr int K = decltype (cute::size (mma_B))::value / N;
300313
@@ -309,14 +322,6 @@ class gemm_4bit_cutlass_kernel {
309322 constexpr int dst_vec_size = 4 ;
310323 constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
311324
312- size_t dequant_offset = buffer_idx * dequant_frag.size() / 2;
313- size_t scale_offset = buffer_idx * fragment_scale.size() / 2;
314- size_t mma_offset = buffer_idx * mma_B.size() / 2;
315-
316- auto* dequant_ptr = cute::raw_pointer_cast(dequant_frag.data()) + dequant_offset;
317- auto* scale_ptr = cute::raw_pointer_cast(fragment_scale.data()) + scale_offset;
318- auto* mma_ptr = cute::raw_pointer_cast(mma_B.data()) + mma_offset;
319-
320325 ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
321326
322327 int lut_id = start_lut_id;
@@ -328,13 +333,13 @@ class gemm_4bit_cutlass_kernel {
328333
329334 #pragma unroll
330335 for (int v = 0 ; v < src_vec_size; v++) {
331- src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(dequant_ptr )[n*src_loop_num + l][v];
336+ src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (cute::get<buffer_idx>(dequant_frag). data ()) )[n*src_loop_num + l][v];
332337 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
333338
334339 #pragma unroll
335340 for (int c = 0 ; c < src_compress_size; c++) {
336341 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
337- float scale_value = *reinterpret_cast<float*>(scale_ptr + ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE))));
342+ float scale_value = cute::get<buffer_idx>(fragment_scale_a) ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
338343
339344 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
340345 lut_id = (lut_id + 1 ) % LUT_NUM ;
@@ -344,14 +349,14 @@ class gemm_4bit_cutlass_kernel {
344349
345350 #pragma unroll
346351 for (int l = 0 ; l < dst_loop_num; l++) {
347- reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(mma_ptr )[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
352+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (cute::get<buffer_idx>(mma_B_a). data ()) )[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
348353 }
349354 }
350355 };
351356
352- copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), frag_copy_B(_,_,_,0 ));
353- copy(params.tiled_copy_scale, tSgS(_,_,_,k_start_idx * BLK_K/params.group_size), frag_copy_Scale(_,_,_,0 ));
354- copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), frag_copy_A(_,_,_,0 ));
357+ copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), cute::get< 0 >(frag_copy_B ));
358+ copy (params.tiled_copy_scale , tSgS (_,_,_,k_start_idx * BLK_K /params.group_size ), cute::get< 0 >(frag_copy_Scale ));
359+ copy (params.tiled_copy_a , tAgA (_,_,_,k_start_idx), cute::get< 0 >(frag_copy_A ));
355360
356361 if (prefetch_k < k_tile_count) {
357362 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
@@ -360,24 +365,27 @@ class gemm_4bit_cutlass_kernel {
360365 prefetch_k++;
361366
362367 for (int k_tile = k_start_idx + 1 , k_s = 1 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
363- const int buf_idx = k_tile % 2;
368+ constexpr int buf_idx = k_tile % 2 ;
364369
365370 dequant (start_lut_id, buf_idx);
366371
367- copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B(_,_,_, buf_idx));
368- copy(params.tiled_copy_scale, tSgS(_,_,_,(k_start_idx+k_s)*BLK_K/params.group_size), frag_copy_Scale(_,_,_, buf_idx));
369- copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A(_,_,_, buf_idx));
372+ copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), cute::get< buf_idx>(frag_copy_B ));
373+ copy (params.tiled_copy_scale , tSgS (_,_,_,(k_start_idx+k_s)*BLK_K /params.group_size ), cute::get< buf_idx>(frag_copy_Scale ));
374+ copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), cute::get< buf_idx>(frag_copy_A ));
370375
371376 if (prefetch_k < k_tile_count) {
372377 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
373378 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
374379 }
375-
376- cute::gemm(tiled_mma, frag_copy_A(_,_,_,1-buf_idx), frag_copy_B(_,_,_,1-buf_idx), accumulators);
380+
381+ constexpr int idx = 1 - buf_idx;
382+ cute::gemm (tiled_mma, cute::get<idx>(frag_copy_A), cute::get<idx>(frag_copy_B), accumulators);
377383 barrier_wait (3 );
378384 }
379- cute::gemm(tiled_mma, frag_copy_A(_,_,_,1), frag_copy_B(_,_,_,1), accumulators);
385+ cute::gemm (tiled_mma, cute::get<1 >(frag_copy_A), cute::get<1 >(frag_copy_B), accumulators);
386+
380387#else
388+
381389 auto dequant_a = [&] (int start_lut_id){
382390 constexpr int N = decltype(cute::size<1>(mma_B_a))::value;
383391 constexpr int K = decltype(cute::size(mma_B_a))::value / N;
0 commit comments