@@ -208,28 +208,32 @@ class gemm_4bit_cutlass_kernel {
208208 Tensor tCgA = thr_mma.partition_A (gA );
209209 Tensor tCgB = thr_mma.partition_B (gB ); // values for each_thread (FrgV,(RestN,RestK),*)
210210
211- Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
211+ Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
212212 Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
213213 Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout ());
214214
215+ Tensor mma_A_b = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
215216 Tensor mma_B_b = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
216217 Tensor dequant_frag_b = make_tensor<ElementB>(mma_B_b.layout ());
217218
218219 static constexpr auto scale_shape_t = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
219220 static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
220221 static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE < 1 ? 1 : BLK_K / GROUP_SIZE ;
221222 using FragScaleLayout = Layout<Shape<Int<scale_shape_t >, Int<scale_shape_n>, Int<scale_shape_k>>>; // [1, dequant_N, block_num]
222- Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
223+ Tensor fragment_scale_a = make_tensor<ElementScale>(FragScaleLayout{});
224+ Tensor fragment_scale_b = make_tensor<ElementScale>(FragScaleLayout{});
223225
224226// static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
225227// static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
226228// static_assert(std::is_same_v<typename decltype(mma_B)::value_type, ElementMMA>);
227229
228- Tensor frag_copy_A = thr_copy_A.retile_D (mma_A );
230+ Tensor frag_copy_A_a = thr_copy_A.retile_D (mma_A_a );
229231 Tensor frag_copy_B_a = thr_copy_B.retile_D (dequant_frag_a);
230- Tensor frag_copy_Scale = thr_copy_scale.retile_D (fragment_scale );
232+ Tensor frag_copy_Scale_a = thr_copy_scale.retile_D (fragment_scale_a );
231233
234+ Tensor frag_copy_A_b = thr_copy_A.retile_D (mma_A_b);
232235 Tensor frag_copy_B_b = thr_copy_B.retile_D (dequant_frag_b);
236+ Tensor frag_copy_Scale_b = thr_copy_scale.retile_D (fragment_scale_b);
233237
234238 Tensor tAgA = thr_copy_A.retile_S (tCgA);
235239 Tensor tBgB = thr_copy_B.retile_S (tCgB);
@@ -284,7 +288,7 @@ class gemm_4bit_cutlass_kernel {
284288 #pragma unroll
285289 for (int c = 0 ; c < src_compress_size; c++) {
286290 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
287- float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
291+ float scale_value = fragment_scale_a ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
288292 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
289293 lut_id = (lut_id + 1 ) % LUT_NUM ;
290294 }
@@ -326,7 +330,7 @@ class gemm_4bit_cutlass_kernel {
326330 #pragma unroll
327331 for (int c = 0 ; c < src_compress_size; c++) {
328332 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
329- float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
333+ float scale_value = fragment_scale_b ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
330334 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
331335 lut_id = (lut_id + 1 ) % LUT_NUM ;
332336 }
@@ -349,33 +353,75 @@ class gemm_4bit_cutlass_kernel {
349353 int start_lut_id = sg_idx % LUT_NUM ;
350354
351355 copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), frag_copy_B_a);
352- copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + 0 ) * BLK_K /params.group_size ), frag_copy_Scale );
353- copy (params.tiled_copy_a , tAgA (_,_,_,k_start_idx), frag_copy_A );
356+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + 0 ) * BLK_K /params.group_size ), frag_copy_Scale_a );
357+ copy (params.tiled_copy_a , tAgA (_,_,_,k_start_idx), frag_copy_A_a );
354358
355- for (int k_tile = k_start_idx + 1 , k_s = 0 + 1 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
356- bool is_odd_tile = k_tile % 2 != 0 ;
359+ if (prefetch_k < k_tile_count) {
360+ prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
361+ prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
362+ }
363+
364+ prefetch_k++;
357365
358- if (is_odd_tile){
366+ for (int k_tile = k_start_idx + 1 , k_s = 0 + 1 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
367+ if (k_tile % 2 != 0 ){
368+ dequant_a (start_lut_id);
359369 copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B_b);
370+
371+ // dequant_a(start_lut_id);
372+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale_b);
373+
374+ // dequant_a(start_lut_id);
375+ copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A_b);
376+
377+ // dequant_a(start_lut_id);
378+ if (prefetch_k < k_tile_count) {
379+ prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
380+ prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
381+ }
382+
383+ // dequant_a(start_lut_id);
384+ cute::gemm (tiled_mma, mma_A_a, mma_B_a, accumulators);
385+ barrier_wait (3 );
386+
387+ // copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale_a);
388+ // copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A_a);
389+
390+ // if (prefetch_k < k_tile_count) {
391+ // prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
392+ // //prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
393+ // }
360394 } else {
395+ dequant_b (start_lut_id);
361396 copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B_a);
362- }
363397
364- is_odd_tile ? dequant_a (start_lut_id) : dequant_b (start_lut_id);
398+ // dequant_b(start_lut_id);
399+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale_a);
365400
366- if (prefetch_k < k_tile_count) {
367- prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
368- prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
369- }
370-
371- cute::gemm (tiled_mma, mma_A, is_odd_tile ? mma_B_a : mma_B_b, accumulators);
372- barrier_wait (3 );
401+ // dequant_b(start_lut_id);
402+ copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A_a);
373403
374- copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale);
375- copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
404+ // dequant_b(start_lut_id);
405+ if (prefetch_k < k_tile_count) {
406+ prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
407+ prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
408+ }
409+
410+ // dequant_b(start_lut_id);
411+ cute::gemm (tiled_mma, mma_A_b, mma_B_b, accumulators);
412+ barrier_wait (3 );
413+
414+ // copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale_a);
415+ // copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A_a);
416+
417+ // if (prefetch_k < k_tile_count) {
418+ // prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
419+ // //prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
420+ // }
421+ }
376422 }
377- cute::gemm (tiled_mma, mma_A , mma_B_b, accumulators);
378- barrier_wait (3 );
423+ cute::gemm (tiled_mma, mma_A_a , mma_B_b, accumulators);
424+ // barrier_wait(3);
379425
380426 static constexpr int FragsM = get<0 >(SubgroupTileShape{}) / get<0 >(MmaAtomShape ()); // atom numbers per thread; A frags per sub_group
381427 static constexpr int FragsN = get<1 >(SubgroupTileShape{}) / get<1 >(MmaAtomShape ()); // atom numbers per thread; B frags per sub_group
0 commit comments