@@ -207,7 +207,8 @@ class gemm_4bit_cutlass_kernel {
207207
208208 Tensor tCgA = thr_mma.partition_A (gA );
209209 Tensor tCgB = thr_mma.partition_B (gB ); // values for each_thread (FrgV,(RestN,RestK),*)
210-
210+
211+ #if 1
211212 Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
212213 Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
213214 Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout ());
@@ -235,6 +236,30 @@ class gemm_4bit_cutlass_kernel {
235236 Tensor frag_copy_B_b = thr_copy_B.retile_D (dequant_frag_b);
236237 Tensor frag_copy_Scale_b = thr_copy_scale.retile_D (fragment_scale_b);
237238
239+ auto layout_A = make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ());
240+ Tensor mma_A = make_tensor<ElementMMA>(cute::make_layout (cute::append (layout_A.shape (), Int<2 >{}), cute::append (layout_A.stride (), Int<0 >{})));
241+ #else
242+ auto layout_A = make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape());
243+ auto layout_B = make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape());
244+
245+ Tensor mma_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_A.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_A.stride(), 0)));
246+ Tensor mma_B = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_B.stride(), 0)));
247+ Tensor dequant_frag = make_tensor<ElementB>(cute::make_layout(cute::append(layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_B.stride(), 0)));
248+
249+ static constexpr auto scale_shape_t = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
250+ static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
251+ static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE < 1 ? 1 : BLK_K / GROUP_SIZE;
252+ using FragScaleLayout = Layout<Shape<Int<scale_shape_t>, Int<scale_shape_n>, Int<scale_shape_k>>>;
253+ Tensor fragment_scale = make_tensor<ElementScale>(cute::make_layout(cute::append(FragScaleLayout{}.shape(), cute::make_shape(Int<2>{})), cute::make_stride(FragScaleLayout{}.stride(), 0)));
254+
255+ auto single_layout_A = thr_copy_A.retile_D(cute::make_tensor(mma_A.data(), layout_A)).layout();
256+ auto single_layout_B = thr_copy_B.retile_D(cute::make_tensor(dequant_frag.data(), layout_B)).layout();
257+ auto single_layout_Scale = thr_copy_scale.retile_D(cute::make_tensor(fragment_scale.data(), FragScaleLayout{})).layout();
258+
259+ Tensor frag_copy_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(single_layout_A.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_A.stride(), 0)));
260+ Tensor frag_copy_B = make_tensor<ElementB>(cute::make_layout(cute::append(single_layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_B.stride(), 0)));
261+ Tensor frag_copy_Scale = make_tensor<float>(cute::make_layout(cute::append(single_layout_Scale.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_Scale.stride(), 0)));
262+ #endif
238263 Tensor tAgA = thr_copy_A.retile_S (tCgA);
239264 Tensor tBgB = thr_copy_B.retile_S (tCgB);
240265
@@ -260,10 +285,103 @@ class gemm_4bit_cutlass_kernel {
260285 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
261286 int prefetch_k = k_start_idx;
262287
288+ CUTLASS_PRAGMA_UNROLL
289+ for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
290+ prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
291+ prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
292+ }
293+
294+ int start_lut_id = sg_idx % LUT_NUM ;
295+
296+ #if 0
297+ auto dequant = [&](int start_lut_id, int buffer_idx) {
298+ constexpr int N = decltype(cute::size<1>(mma_B))::value;
299+ constexpr int K = decltype(cute::size(mma_B))::value / N;
300+
301+ using src_compress_type = uint32_t;
302+ using dst_compress_type = uint32_t;
303+
304+ constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
305+ constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 4
306+ constexpr int src_vec_size = 4;
307+
308+ constexpr int src_loop_num = K / src_vec_size / src_compress_size;
309+ constexpr int dst_vec_size = 4;
310+ constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
311+
312+ size_t dequant_offset = buffer_idx * dequant_frag.size() / 2;
313+ size_t scale_offset = buffer_idx * fragment_scale.size() / 2;
314+ size_t mma_offset = buffer_idx * mma_B.size() / 2;
315+
316+ auto* dequant_ptr = cute::raw_pointer_cast(dequant_frag.data()) + dequant_offset;
317+ auto* scale_ptr = cute::raw_pointer_cast(fragment_scale.data()) + scale_offset;
318+ auto* mma_ptr = cute::raw_pointer_cast(mma_B.data()) + mma_offset;
319+
320+ ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
321+
322+ int lut_id = start_lut_id;
323+ #pragma unroll
324+ for (int n = 0; n < N; n++) {
325+
326+ #pragma unroll
327+ for (int l = 0; l < src_loop_num; l++) {
328+
329+ #pragma unroll
330+ for (int v = 0; v < src_vec_size; v++) {
331+ src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(dequant_ptr)[n*src_loop_num + l][v];
332+ int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
333+
334+ #pragma unroll
335+ for (int c = 0; c < src_compress_size; c++) {
336+ uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
337+ float scale_value = *reinterpret_cast<float*>(scale_ptr + ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE))));
338+
339+ dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
340+ lut_id = (lut_id + 1) % LUT_NUM;
341+ }
342+ }
343+ }
344+
345+ #pragma unroll
346+ for (int l = 0; l < dst_loop_num; l++) {
347+ reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(mma_ptr)[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
348+ }
349+ }
350+ };
351+
352+ copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), frag_copy_B(_,_,_,0));
353+ copy(params.tiled_copy_scale, tSgS(_,_,_,k_start_idx * BLK_K/params.group_size), frag_copy_Scale(_,_,_,0));
354+ copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), frag_copy_A(_,_,_,0));
355+
356+ if (prefetch_k < k_tile_count) {
357+ prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
358+ prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
359+ }
360+ prefetch_k++;
361+
362+ for (int k_tile = k_start_idx + 1, k_s = 1; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
363+ const int buf_idx = k_tile % 2;
364+
365+ dequant(start_lut_id, buf_idx);
366+
367+ copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B(_,_,_,buf_idx));
368+ copy(params.tiled_copy_scale, tSgS(_,_,_,(k_start_idx+k_s)*BLK_K/params.group_size), frag_copy_Scale(_,_,_,buf_idx));
369+ copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A(_,_,_,buf_idx));
370+
371+ if (prefetch_k < k_tile_count) {
372+ prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
373+ prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
374+ }
375+
376+ cute::gemm(tiled_mma, frag_copy_A(_,_,_,1-buf_idx), frag_copy_B(_,_,_,1-buf_idx), accumulators);
377+ barrier_wait(3);
378+ }
379+ cute::gemm(tiled_mma, frag_copy_A(_,_,_,1), frag_copy_B(_,_,_,1), accumulators);
380+ #else
263381 auto dequant_a = [&] (int start_lut_id){
264382 constexpr int N = decltype (cute::size<1 >(mma_B_a))::value;
265383 constexpr int K = decltype (cute::size (mma_B_a))::value / N;
266-
384+
267385 using src_compress_type = uint32_t ;
268386 using dst_compress_type = uint32_t ;
269387 constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
@@ -344,14 +462,6 @@ class gemm_4bit_cutlass_kernel {
344462 }
345463 };
346464
347- CUTLASS_PRAGMA_UNROLL
348- for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
349- prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
350- prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
351- }
352-
353- int start_lut_id = sg_idx % LUT_NUM ;
354-
355465 copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), frag_copy_B_a);
356466 copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + 0 ) * BLK_K /params.group_size ), frag_copy_Scale_a);
357467 copy (params.tiled_copy_a , tAgA (_,_,_,k_start_idx), frag_copy_A_a);
@@ -422,6 +532,7 @@ class gemm_4bit_cutlass_kernel {
422532 }
423533 cute::gemm (tiled_mma, mma_A_a, mma_B_b, accumulators);
424534 // barrier_wait(3);
535+ #endif
425536
426537 static constexpr int FragsM = get<0 >(SubgroupTileShape{}) / get<0 >(MmaAtomShape ()); // atom numbers per thread; A frags per sub_group
427538 static constexpr int FragsN = get<1 >(SubgroupTileShape{}) / get<1 >(MmaAtomShape ()); // atom numbers per thread; B frags per sub_group
0 commit comments