@@ -252,7 +252,11 @@ class gemm_4bit_cutlass_kernel {
252252 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
253253 int prefetch_k = k_start_idx;
254254
255- auto dequant = [&] (int start_lut_id){
255+ auto dequant = [&] (int start_lut_id, int k_tile, int k_s){
256+ copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B);
257+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale);
258+ copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
259+
256260 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
257261 constexpr int K = decltype (cute::size (mma_B))::value / N;
258262
@@ -303,11 +307,11 @@ class gemm_4bit_cutlass_kernel {
303307 int start_lut_id = sg_idx % LUT_NUM ;
304308
305309 for (int k_tile = k_start_idx, k_s = 0 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
306- copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B);
307- copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale);
308- copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
310+ // copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
311+ // copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
312+ // copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
309313
310- dequant (start_lut_id);
314+ dequant (start_lut_id, k_tile, k_s );
311315
312316 if (prefetch_k < k_tile_count) {
313317 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
0 commit comments