@@ -208,24 +208,32 @@ class gemm_4bit_cutlass_kernel {
208208 Tensor tCgA = thr_mma.partition_A (gA );
209209 Tensor tCgB = thr_mma.partition_B (gB ); // values for each_thread (FrgV,(RestN,RestK),*)
210210
211- Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
212- Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
211+ Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
212+ Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
213+ Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout ());
213214
214- Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout ());
215+ Tensor mma_A_b = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
216+ Tensor mma_B_b = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
217+ Tensor dequant_frag_b = make_tensor<ElementB>(mma_B_b.layout ());
215218
216219 static constexpr auto scale_shape_t = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
217220 static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
218221 static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE < 1 ? 1 : BLK_K / GROUP_SIZE ;
219222 using FragScaleLayout = Layout<Shape<Int<scale_shape_t >, Int<scale_shape_n>, Int<scale_shape_k>>>; // [1, dequant_N, block_num]
220- Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
223+ Tensor fragment_scale_a = make_tensor<ElementScale>(FragScaleLayout{});
224+ Tensor fragment_scale_b = make_tensor<ElementScale>(FragScaleLayout{});
221225
222226// static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
223227// static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
224228// static_assert(std::is_same_v<typename decltype(mma_B)::value_type, ElementMMA>);
225229
226- Tensor frag_copy_A = thr_copy_A.retile_D (mma_A);
227- Tensor frag_copy_B = thr_copy_B.retile_D (dequant_frag);
228- Tensor frag_copy_Scale = thr_copy_scale.retile_D (fragment_scale);
230+ Tensor frag_copy_A_a = thr_copy_A.retile_D (mma_A_a);
231+ Tensor frag_copy_B_a = thr_copy_B.retile_D (dequant_frag_a);
232+ Tensor frag_copy_Scale_a = thr_copy_scale.retile_D (fragment_scale_a);
233+
234+ Tensor frag_copy_A_b = thr_copy_A.retile_D (mma_A_b);
235+ Tensor frag_copy_B_b = thr_copy_B.retile_D (dequant_frag_b);
236+ Tensor frag_copy_Scale_b = thr_copy_scale.retile_D (fragment_scale_b);
229237
230238 Tensor tAgA = thr_copy_A.retile_S (tCgA);
231239 Tensor tBgB = thr_copy_B.retile_S (tCgB);
@@ -252,13 +260,9 @@ class gemm_4bit_cutlass_kernel {
252260 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
253261 int prefetch_k = k_start_idx;
254262
255- auto copy_and_dequant = [&] (int start_lut_id, int k_tile, int k_s){
256- copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B);
257- copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale);
258- copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
259-
260- constexpr int N = decltype (cute::size<1 >(mma_B))::value;
261- constexpr int K = decltype (cute::size (mma_B))::value / N;
263+ auto dequant = [&] (int start_lut_id, int k_tile){
264+ constexpr int N = decltype (cute::size<1 >(mma_B_a))::value;
265+ constexpr int K = decltype (cute::size (mma_B_a))::value / N;
262266
263267 using src_compress_type = uint32_t ;
264268 using dst_compress_type = uint32_t ;
@@ -279,12 +283,12 @@ class gemm_4bit_cutlass_kernel {
279283
280284 #pragma unroll
281285 for (int v = 0 ; v < src_vec_size; v++) {
282- src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag .data ()))[n*src_loop_num + l][v];
286+ src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(k_tile % 2 != 0 ? cute::raw_pointer_cast (dequant_frag_a. data ()): cute::raw_pointer_cast (dequant_frag_b .data ()))[n*src_loop_num + l][v];
283287 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
284288 #pragma unroll
285289 for (int c = 0 ; c < src_compress_size; c++) {
286290 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
287- float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
291+ float scale_value = k_tile % 2 != 0 ? fragment_scale_a ((n * BLK_K + dst_base_idx + c) >> ( 31 - std::countl_zero< unsigned int >( GROUP_SIZE ))) : fragment_scale_b ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
288292 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
289293 lut_id = (lut_id + 1 ) % LUT_NUM ;
290294 }
@@ -293,14 +297,9 @@ class gemm_4bit_cutlass_kernel {
293297
294298 #pragma unroll
295299 for (int l = 0 ; l < dst_loop_num; l++) {
296- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B .data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
300+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (k_tile % 2 != 0 ? mma_B_a. data () : mma_B_b .data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
297301 }
298302 }
299-
300- if (prefetch_k < k_tile_count) {
301- prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
302- prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
303- }
304303 };
305304
306305 CUTLASS_PRAGMA_UNROLL
@@ -311,21 +310,34 @@ class gemm_4bit_cutlass_kernel {
311310
312311 int start_lut_id = sg_idx % LUT_NUM ;
313312
314- for (int k_tile = k_start_idx, k_s = 0 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
315- // copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
316- // copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
317- // copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
// Two-buffer K loop: the prologue below loads the first tile into the *_a fragments; each
// iteration then loads tile k_tile into one fragment set (odd k_tile -> *_b, even -> *_a)
// while dequant() and cute::gemm() work on the other set (odd k_tile -> *_a, even -> *_b).
// NOTE(review): this pairing only lines up when k_start_idx is even. If it were odd, the first
// iteration would overwrite the *_a fragments loaded here and consume *_b, which has not been
// loaded yet -- confirm that k_start_idx is always even, or key the selection on k_s instead.
313+ copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), frag_copy_B_a);
314+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + 0 ) * BLK_K /params.group_size ), frag_copy_Scale_a);
// NOTE(review): A is read from tile 0 here while B and the scales above use k_start_idx;
// presumably this should be tAgA(_,_,_,k_start_idx) -- confirm.
315+ copy (params.tiled_copy_a , tAgA (_,_,_,0 ), frag_copy_A_a);
316+
317+ for (int k_tile = k_start_idx + 1 , k_s = 0 + 1 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
318+ if (k_tile % 2 != 0 ){
319+ copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B_b);
320+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale_b);
321+ copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A_b);
322+ } else {
// NOTE(review): this branch re-reads B tile k_start_idx on every even iteration, whereas the
// odd branch reads tile k_tile and the A/scale copies below use k_tile / k_s. Looks like a
// copy-paste slip from the prologue; presumably tBgB(_,_,_,k_tile) was intended -- confirm.
323+ copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), frag_copy_B_a);
324+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale_a);
325+ copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A_a);
326+ }
318327
319- copy_and_dequant (start_lut_id, k_tile, k_s);
320328
321- // if (prefetch_k < k_tile_count) {
322- // prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
323- // prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
324- // }
// dequant() picks its source/destination fragments from the parity of k_tile (odd -> *_a,
// even -> *_b), i.e. the set that was NOT loaded in this iteration.
329+ dequant (start_lut_id, k_tile);
325330
326- cute::gemm (tiled_mma, mma_A, mma_B, accumulators);
331+ if (prefetch_k < k_tile_count) {
332+ prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
333+ prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
334+ }
335+
// Multiply the fragment set that dequant() just filled (same parity selection as dequant()).
336+ k_tile % 2 != 0 ? cute::gemm (tiled_mma, mma_A_a, mma_B_a, accumulators) : cute::gemm (tiled_mma, mma_A_b, mma_B_b, accumulators);
327337 barrier_wait (3 );
328338 }
// NOTE(review): epilogue for the tile loaded by the last iteration. Two things to confirm:
// (1) no dequant() call is visible between the loop and this gemm, so mma_B_b would still hold
//     whatever an earlier dequant() wrote rather than the last loaded tile;
// (2) the *_b set is used unconditionally, which matches the last load only if the final
//     k_tile was odd (and, when the loop body never runs, only *_a has been loaded).
339+ cute::gemm (tiled_mma, mma_A_b, mma_B_b, accumulators);
340+ barrier_wait (3 );
329341
330342 static constexpr int FragsM = get<0 >(SubgroupTileShape{}) / get<0 >(MmaAtomShape ()); // atom numbers per thread; A frags per sub_group
331343 static constexpr int FragsN = get<1 >(SubgroupTileShape{}) / get<1 >(MmaAtomShape ()); // atom numbers per thread; B frags per sub_group
0 commit comments