@@ -68,7 +68,7 @@ using TiledMma =
6868using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6969using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
7070constexpr int PipelineStages = 2 ;
71- static constexpr auto GROUP_SIZE =64 ; // Block Quant Size
71+ static constexpr auto GROUP_SIZE =32 ; // Block Quant Size
7272
7373using MmaAtomShape = typename TiledMma::AtomShape_MNK;
7474using WorkgroupTileShape = TileShape;
@@ -295,6 +295,7 @@ inline float dDequantizeNF4(unsigned char val) {
295295 // static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
296296 static_assert (std::is_same_v<typename decltype (mma_A)::value_type, ElementMMA>);
297297 static_assert (std::is_same_v<typename decltype (mma_B)::value_type, ElementMMA>);
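298+ // TODO(review): the dequant step indexes scales with the compile-time GROUP_SIZE, while the scale layout uses params.group_size; these must agree. static_assert(a, b) is not an equality check (the second argument is a message), and params.group_size looks like a runtime value (confirm), so check params.group_size == GROUP_SIZE at runtime instead.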
298299
299300 Tensor frag_copy_A = thr_copy_A.retile_D (mma_A);
300301 // Tensor frag_copy_B = thr_copy_B.retile_D(dequant_frag);
@@ -316,10 +317,11 @@ inline float dDequantizeNF4(unsigned char val) {
316317
317318 auto tSgS = [&](){
318319 return make_tensor (make_inttuple_iter (make_coord (n_coord * BLK_N + get<2 >(thr_mma.thr_vmnk_ )*SG_QNT_WIDTH , 0 , 0 )),
319- make_layout (make_shape (Int<scale_shape_t >{}, Int<scale_shape_n>{}, scale_shape_k , k_tile_count * BLK_K /params.group_size ),
320- make_stride (E<0 >{}*(scale_shape_n * scale_shape_k * DispatchPolicy::SubgroupSize ), E<0 >{}*(scale_shape_k * DispatchPolicy::SubgroupSize ), E<1 >{}*_1 {}, E<1 >{}*_1{})));
320+ make_layout (make_shape (Int<scale_shape_t >{}, Int<scale_shape_n>{}, 1 , k_tile_count * BLK_K /params.group_size ),
321+ make_stride (E<0 >{}*(scale_shape_n * k_tile_count * BLK_K /params. group_size ), E<0 >{}*(k_tile_count * BLK_K /params. group_size ), E<0 >{}*_0 {}, E<1 >{}*_1{})));
321322
322323 }();
324+ if (cute::thread0 ()) printf (" scale_shape_t = %d, scale_shape_n = %d, scale_shape_k = %d, k_tile_count = %d, k_tile_count * BLK_K/params.group_size = %d, scale_shape_n * scale_shape_k * DispatchPolicy::SubgroupSize = %d, scale_shape_k * DispatchPolicy::SubgroupSize = %d\n " ,/* static_cast<int>(get<2>(thr_mma.thr_vmnk_)), static_cast<int>(SG_QNT_WIDTH),*/ scale_shape_t , scale_shape_n, scale_shape_k, k_tile_count, k_tile_count * BLK_K /params.group_size , scale_shape_n * scale_shape_k * DispatchPolicy::SubgroupSize, scale_shape_k * DispatchPolicy::SubgroupSize);
323325
324326 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
325327 int prefetch_k = k_start_idx;
@@ -385,6 +387,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
385387 auto dequant = [&] {
386388 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
387389 constexpr int K = decltype (cute::size (mma_B))::value / N;
390+ // if(cute::thread0()) printf("scale num = %d\n", decltype(cute::size(fragment_scale))::value);
388391
389392 using src_compress_type = uint64_t ;
390393 using dst_compress_type = uint64_t ;
@@ -422,9 +425,12 @@ if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
422425 #pragma unroll
423426 for (int c = 0 ; c < src_compress_size; c++) {
424427 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
425- float scale_value = 1 . 0f ; // fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
428+ float scale_value = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + c) / GROUP_SIZE );
426429 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
427- // if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) printf("n = %d, src_l = %d, dst_base_idx+c = %d, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE) = %d, scale_value = %f\n", n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);
430+ if (1 ){ // thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0){
431+ printf (" tid = %d, m_coord = %d, n_coord = %d, l_coord = %d, n = %d, src_l = %d, dst_dx = %d, scale_idx = %d, scale_value = %f\n " , thread_idx, m_coord, n_coord, l_coord, n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE ) + (dst_base_idx+c)/GROUP_SIZE , scale_value);
432+ // print(" scale_value : "); print(scale_value); print("\n");
433+ }
428434 }
429435 }
430436 }
@@ -536,6 +542,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
536542 sycl::queue q = *stream;
537543
538544 using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS >;
545+ std::cout<<" group_size = " <<blocksize<<std::endl;
539546
540547#if 1
541548 static constexpr int smem_size= (16 ) * sizeof (float );
0 commit comments