@@ -252,56 +252,56 @@ class gemm_4bit_cutlass_kernel {
252252 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
253253 int prefetch_k = k_start_idx;
254254
// copy_and_dequant(start_lut_id, k_tile, k_s): lambda capturing the enclosing scope by reference.
// As visible in this hunk it does three things:
//   1. copies the k_tile slice of tBgB / tAgA and one slice of tSgS into frag_copy_B /
//      frag_copy_A / frag_copy_Scale;
//   2. for every n in [0, N) unpacks 4-bit codes from dequant_frag, maps each code through
//      quant_map_[lut_id], multiplies by a scale read from fragment_scale, casts to ElementMMA
//      and stores the result into mma_B;
//   3. issues prefetches for A and B at prefetch_k while prefetch_k < k_tile_count.
// NOTE(review): this text is a diff rendering ("-" = old line, "+" = new line); the two versions
// appear to differ only in indentation -- confirm against the real file.
// NOTE(review): presumably dequant_frag and fragment_scale view the data loaded into
// frag_copy_B / frag_copy_Scale above; that aliasing is not shown here -- confirm.
255- auto copy_and_dequant = [&] (int start_lut_id, int k_tile, int k_s){
256- copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B);
257- copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale);
258- copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
255+ auto copy_and_dequant = [&] (int start_lut_id, int k_tile, int k_s){
256+ copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B);
// The scale slice index is (k_start_idx + k_s) * BLK_K / params.group_size (integer division);
// presumably this converts a k-tile position into a quantization-group index -- TODO confirm.
257+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params.group_size ), frag_copy_Scale);
258+ copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
259259
260- constexpr int N = decltype (cute::size<1 >(mma_B))::value;
261- constexpr int K = decltype (cute::size (mma_B))::value / N;
// N is the static extent of mode 1 of mma_B; K is the total static size of mma_B divided by N.
260+ constexpr int N = decltype (cute::size<1 >(mma_B))::value;
261+ constexpr int K = decltype (cute::size (mma_B))::value / N;
262262
263- using src_compress_type = uint32_t ;
264- using dst_compress_type = uint32_t ;
265- constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
266- constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 4
267- constexpr int src_vec_size = 4 ;
268- constexpr int src_loop_num = K / src_vec_size / src_compress_size;
269-
270- constexpr int dst_vec_size = 4 ; // src_vec_size;
271- constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size; // src_loop_num;
272- ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
273-
274- int lut_id = start_lut_id;
// Both source and destination are moved as sycl::vec of four 32-bit words.
263+ using src_compress_type = uint32_t ;
264+ using dst_compress_type = uint32_t ;
// Number of ElementB / ElementMMA values packed in one 32-bit word.
// NOTE(review): the decode loop below shifts a uint32_t by up to 4 * (src_compress_size - 1) bits,
// which only stays inside the word when src_compress_size <= 8; the trailing "// 16" and "// 4"
// annotations may be stale -- confirm against the actual ElementB / ElementMMA types.
265+ constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
266+ constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 4
267+ constexpr int src_vec_size = 4 ;
// Number of source vectors read per n (integer division; assumes K divides evenly -- TODO confirm).
268+ constexpr int src_loop_num = K / src_vec_size / src_compress_size;
269+
270+ constexpr int dst_vec_size = 4 ; // src_vec_size;
271+ constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size; // src_loop_num;
// Local staging buffer for one n worth of dequantized values; its length equals K whenever
// K is a multiple of dst_vec_size * dst_compress_size.
272+ ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
273+
// lut_id starts at start_lut_id and advances once per decoded element (modulo LUT_NUM);
// it is not reset between iterations of the n loop.
274+ int lut_id = start_lut_id;
275+ #pragma unroll
276+ for (int n = 0 ; n < N; n++) {
275277 #pragma unroll
276- for (int n = 0 ; n < N; n++) {
277- #pragma unroll
278- for (int l = 0 ; l < src_loop_num; l++) {
278+ for (int l = 0 ; l < src_loop_num; l++) {
279279
280+ #pragma unroll
281+ for (int v = 0 ; v < src_vec_size; v++) {
// Read lane v of the (n*src_loop_num + l)-th 4 x uint32 vector over dequant_frag's raw storage.
282+ src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[n*src_loop_num + l][v];
// Offset inside dst of the first element decoded from this 32-bit word.
283+ int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
280284 #pragma unroll
281- for (int v = 0 ; v < src_vec_size; v++) {
282- src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[n*src_loop_num + l][v];
283- int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
284- #pragma unroll
285- for (int c = 0 ; c < src_compress_size; c++) {
286- uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
287- float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
288- dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
289- lut_id = (lut_id + 1 ) % LUT_NUM ;
290- }
285+ for (int c = 0 ; c < src_compress_size; c++) {
// Shift amounts run 4, 0, 12, 8, 20, 16, ...: within each byte the high nibble is taken
// before the low nibble.
286+ uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
// For a 32-bit unsigned, 31 - countl_zero(GROUP_SIZE) is floor(log2(GROUP_SIZE)), so the right
// shift divides the element index by GROUP_SIZE when GROUP_SIZE is a power of two.
287+ float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
// Look up the 4-bit code in the current table, apply the scale, and narrow to ElementMMA.
288+ dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
289+ lut_id = (lut_id + 1 ) % LUT_NUM ;
291290 }
292291 }
293-
294- #pragma unroll
295- for (int l = 0 ; l < dst_loop_num; l++) {
296- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
297- }
298292 }
299293
300- if (prefetch_k < k_tile_count) {
301- prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
302- prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k)) ;
// Still inside the n loop: write the staged dst values into mma_B's raw storage, one
// 4 x uint32 vector at a time, at vector index n * dst_loop_num + l.
294+ # pragma unroll
295+ for ( int l = 0 ; l < dst_loop_num; l++) {
296+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>( cute::raw_pointer_cast (mma_B. data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l] ;
303297 }
304- };
298+ }
299+
// Prefetch the A and B tiles at prefetch_k only while that index is below k_tile_count;
// prefetch_k itself is not modified inside this lambda.
300+ if (prefetch_k < k_tile_count) {
301+ prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
302+ prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
303+ }
304+ };
305305
306306 CUTLASS_PRAGMA_UNROLL
307307 for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
0 commit comments