@@ -232,7 +232,7 @@ inline float dDequantizeNF4(unsigned char val) {
232232 ? BlockIdxX () : BlockIdxY ();
233233 const int l_coord = BlockIdxZ ();
234234
235- #if 0
235+ #if 1
236236 // float* quant_map;
237237 // static constexpr std::array<float, 16> quant_map{};
238238 // {
@@ -277,7 +277,7 @@ inline float dDequantizeNF4(unsigned char val) {
277277 Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
278278 Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
279279
280- #if 0 //SLM: 0, register: 1
280+ #if 1 // SLM: 0, register: 1
281281 #if 1 // fragement register
282282 Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout ());
283283 #else //common register
@@ -324,7 +324,7 @@ inline float dDequantizeNF4(unsigned char val) {
324324 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
325325 int prefetch_k = k_start_idx;
326326
327- #if 1 // SLM
327+ #if 0 //SLM
328328 #if 1
329329 auto dequant = [&] (int k_tile) {
330330 constexpr int N = decltype(cute::size<1>(mma_B))::value;
@@ -386,11 +386,10 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
386386 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
387387 constexpr int K = decltype (cute::size (mma_B))::value / N;
388388
389-
390389 using src_compress_type = uint64_t ;
391390 using dst_compress_type = uint64_t ;
392391 constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
393- constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //16
392+ constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 4
394393 constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; // 4, 16 -> max vec_size of sycl::vec
395394 constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; // 16, 16 -> max vec_size of sycl::vec
396395 constexpr int src_loop_num = K / src_vec_size / src_compress_size;
@@ -399,11 +398,11 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
399398 // if(cute::thread0()) printf("params.group_size = %d, k_reload_factor = %d, k_tile_count = %d, N = %d, K = %d, src_compress_size = %d, src_vec_size = %d, dst_compress_size = %d, dst_vec_size = %d\n",params.group_size, k_reload_factor, k_tile_count, N, K, src_compress_size, src_vec_size, dst_compress_size, dst_vec_size);
400399
401400 src_compress_type src[src_vec_size];
402- ElementMMA dst[dst_compress_size * dst_vec_size];
401+ ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
403402
404403 #pragma unroll
405404 for (int n = 0 ; n < N; n++) {
406- float scale_value = fragment_scale(n);
405+ // float scale_value = fragment_scale(n);
407406 #pragma unroll
408407 for (int l = 0 ; l < src_loop_num; l++) {
409408 // src_compress_type src[src_vec_size];
@@ -412,18 +411,19 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
412411 #pragma unroll
413412 for (int v = 0 ; v < src_vec_size; v++) {
414413 src_compress_type src_value = src[v];
415- int dst_idx = v * src_compress_size;
414+ int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
416415 #pragma unroll
417416 for (int c = 0 ; c < src_compress_size; c++) {
418417 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
419- dst[dst_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
418+ float scale_value = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + c) / GROUP_SIZE );
419+ dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
420420 }
421421 }
422422 }
423423
424424 #pragma unroll
425425 for (int l = 0 ; l < dst_loop_num; l++) {
426- reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0 ];
426+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n*dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l ];
427427 }
428428 }
429429 };
@@ -454,9 +454,9 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
454454 }
455455
456456 for (int k_tile = k_start_idx, k_s = 0 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
457- #if 0 //SLM: 0, register: 1
457+ #if 1 // SLM: 0, register: 1
458458 copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B);
459- copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) / k_reload_factor ), frag_copy_Scale);
459+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params. group_size ), frag_copy_Scale);
460460 copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
461461 dequant ();
462462#else
@@ -501,8 +501,11 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
501501
502502 using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS >;
503503
504- // static constexpr int smem_size= (16) * sizeof(float);
504+ #if 1
505+ static constexpr int smem_size= (16 ) * sizeof (float );
506+ #else
505507 static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementMMA) * 2 * 2; //aligned with 128B and will be reused for dequant src and dst.
508+ #endif
506509 size_t max_slm_size = q.get_device ().get_info <sycl::info::device::local_mem_size>();
507510 assert (smem_size <= max_slm_size);
508511
0 commit comments