@@ -208,32 +208,28 @@ class gemm_4bit_cutlass_kernel {
208208 Tensor tCgA = thr_mma.partition_A (gA );
209209 Tensor tCgB = thr_mma.partition_B (gB ); // values for each_thread (FrgV,(RestN,RestK),*)
210210
211- Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
211+ Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
212212 Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
213213 Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout ());
214214
215- Tensor mma_A_b = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
216215 Tensor mma_B_b = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
217216 Tensor dequant_frag_b = make_tensor<ElementB>(mma_B_b.layout ());
218217
219218 static constexpr auto scale_shape_t = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
220219 static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
221220 static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE < 1 ? 1 : BLK_K / GROUP_SIZE ;
222221 using FragScaleLayout = Layout<Shape<Int<scale_shape_t >, Int<scale_shape_n>, Int<scale_shape_k>>>; // [1, dequant_N, block_num]
223- Tensor fragment_scale_a = make_tensor<ElementScale>(FragScaleLayout{});
224- Tensor fragment_scale_b = make_tensor<ElementScale>(FragScaleLayout{});
222+ Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
225223
226224// static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
227225// static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
228226// static_assert(std::is_same_v<typename decltype(mma_B)::value_type, ElementMMA>);
229227
230- Tensor frag_copy_A_a = thr_copy_A.retile_D (mma_A_a );
228+ Tensor frag_copy_A = thr_copy_A.retile_D (mma_A );
231229 Tensor frag_copy_B_a = thr_copy_B.retile_D (dequant_frag_a);
232- Tensor frag_copy_Scale_a = thr_copy_scale.retile_D (fragment_scale_a );
230+ Tensor frag_copy_Scale = thr_copy_scale.retile_D (fragment_scale );
233231
234- Tensor frag_copy_A_b = thr_copy_A.retile_D (mma_A_b);
235232 Tensor frag_copy_B_b = thr_copy_B.retile_D (dequant_frag_b);
236- Tensor frag_copy_Scale_b = thr_copy_scale.retile_D (fragment_scale_b);
237233
238234 Tensor tAgA = thr_copy_A.retile_S (tCgA);
239235 Tensor tBgB = thr_copy_B.retile_S (tCgB);
@@ -283,12 +279,12 @@ class gemm_4bit_cutlass_kernel {
283279
284280 #pragma unroll
285281 for (int v = 0 ; v < src_vec_size; v++) {
286- src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(k_tile % 2 != 0 ? cute::raw_pointer_cast ( dequant_frag_a.data ()): cute::raw_pointer_cast ( dequant_frag_b.data ()))[n*src_loop_num + l][v];
282+ src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast ( k_tile % 2 != 0 ? dequant_frag_a.data () : dequant_frag_b.data ()))[n*src_loop_num + l][v];
287283 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
288284 #pragma unroll
289285 for (int c = 0 ; c < src_compress_size; c++) {
290286 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
291- float scale_value = k_tile % 2 != 0 ? fragment_scale_a ((n * BLK_K + dst_base_idx + c) >> ( 31 - std::countl_zero< unsigned int >( GROUP_SIZE ))) : fragment_scale_b ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
287+ float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
292288 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
293289 lut_id = (lut_id + 1 ) % LUT_NUM ;
294290 }
@@ -311,32 +307,30 @@ class gemm_4bit_cutlass_kernel {
311307 int start_lut_id = sg_idx % LUT_NUM ;
312308
313309 copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), frag_copy_B_a);
314- copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + 0 ) * BLK_K /params.group_size ), frag_copy_Scale_a );
315- copy (params.tiled_copy_a , tAgA (_,_,_,0 ), frag_copy_A_a );
310+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + 0 ) * BLK_K /params.group_size ), frag_copy_Scale );
311+ copy (params.tiled_copy_a , tAgA (_,_,_,k_start_idx ), frag_copy_A );
316312
317313 for (int k_tile = k_start_idx + 1 , k_s = 0 + 1 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
318314 if (k_tile % 2 != 0 ){
319315 copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B_b);
320- copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale_b);
321- copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A_b);
322316 } else {
323- copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), frag_copy_B_a);
324- copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale_a);
325- copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A_a);
317+ copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B_a);
326318 }
327319
328-
329320 dequant (start_lut_id, k_tile);
330321
331322 if (prefetch_k < k_tile_count) {
332323 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
333324 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
334325 }
335326
336- k_tile % 2 != 0 ? cute::gemm (tiled_mma, mma_A_a, mma_B_a, accumulators) : cute::gemm (tiled_mma, mma_A_b, mma_B_b, accumulators);
327+ cute::gemm (tiled_mma, mma_A, k_tile % 2 != 0 ? mma_B_a : mma_B_b, accumulators);
337328 barrier_wait (3 );
329+
330+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale);
331+ copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
338332 }
339- cute::gemm (tiled_mma, mma_A_b , mma_B_b, accumulators);
333+ cute::gemm (tiled_mma, mma_A , mma_B_b, accumulators);
340334 barrier_wait (3 );
341335
342336 static constexpr int FragsM = get<0 >(SubgroupTileShape{}) / get<0 >(MmaAtomShape ()); // atom numbers per thread; A frags per sub_group
0 commit comments