@@ -61,13 +61,14 @@ static constexpr float quant_map_static[16] = {
6161};
6262#endif
6363
// Kernel tile configuration (shown as a diff hunk: '-' is the removed line, '+' the added line).
// The change enlarges the third TileShape extent from _64 to _128 (presumably the K extent — confirm
// against how BLK_K is derived, which is not visible here).
64- using TileShape = Shape<_32, _128, _64 >;
64+ using TileShape = Shape<_32, _128, _128 >;
// MMA configuration built from the XE_8x16x16_F32BF16BF16F32_TT atom over TileShape with a
// Shape<_1, _8, _1> / Stride<_8, _1, _0> layout argument.
6565using TiledMma =
6666 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6767 Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
// Global-memory copy atoms selected for operand A and operand B.
6868using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6969using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
7070constexpr int PipelineStages = 2 ;
// Compile-time quantization group size; it is used as BLK_K / GROUP_SIZE when sizing the scale
// fragment and when indexing it in the dequant loop.
// NOTE(review): the global scale tensor (tSgS) and its per-tile index divide by the runtime
// params.group_size instead — confirm that params.group_size is always equal to GROUP_SIZE.
71+ static constexpr auto GROUP_SIZE =64 ; // Block Quant Size
7172
// Convenience aliases derived from the MMA and tile configuration above.
7273using MmaAtomShape = typename TiledMma::AtomShape_MNK;
7374using WorkgroupTileShape = TileShape;
@@ -285,9 +286,10 @@ inline float dDequantizeNF4(unsigned char val) {
285286 #endif
286287 Tensor frag_copy_B = thr_copy_B.retile_D(dequant_frag);
287288#endif
// Layout of the per-thread scale fragment (diff hunk: '-' lines are the old definitions, '+' the new).
// The old layout was (scale_traits_size, scale_traits_num, _1); the new one renames the first two
// extents and replaces the unit third mode with BLK_K / GROUP_SIZE entries.
288- static constexpr auto scale_traits_size = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
289- static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
290- using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
// scale_shape_t: total size of the scale copy block shape divided by the subgroup size.
289+ static constexpr auto scale_shape_t = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
// scale_shape_n: SG_QNT_WIDTH divided by mode 1 of the scale copy block shape.
290+ static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
// scale_shape_k: BLK_K / GROUP_SIZE (integer division).
// NOTE(review): assumes BLK_K is a multiple of GROUP_SIZE — confirm, or guard with a static_assert.
291+ static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE ;
// NOTE(review): no explicit stride is given here, while the dequant loop reads this fragment with a
// single linear index of the form n * (BLK_K / GROUP_SIZE) + k_group. Confirm that this linearization
// matches the default stride order of the layout (presumably first-mode-fastest — verify against CuTe).
292+ using FragScaleLayout = Layout<Shape<Int<scale_shape_t >, Int<scale_shape_n>, Int<scale_shape_k>>>; // [1, dequant_N, block_num]
291293 Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
292294
293295 // static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
@@ -314,8 +316,8 @@ inline float dDequantizeNF4(unsigned char val) {
314316
// Builds the coordinate tensor used as the source of the scale copy (diff hunk: '-' old, '+' new).
// The iterator starts at coordinate (n_coord * BLK_N + get<2>(thr_mma.thr_vmnk_) * SG_QNT_WIDTH, 0, 0).
// The lambda is invoked immediately, so tSgS holds the resulting tensor rather than a callable.
315317 auto tSgS = [&](){
316318 return make_tensor (make_inttuple_iter (make_coord (n_coord * BLK_N + get<2 >(thr_mma.thr_vmnk_ )*SG_QNT_WIDTH , 0 , 0 )),
// Old layout: unit third mode with stride _0, last mode sized k_tile_count / k_reload_factor.
317- make_layout (make_shape (Int<scale_traits_size >{}, Int<scale_traits_num >{}, _1{} , k_tile_count/k_reload_factor ),
318- make_stride (E<0 >{}*_16 {}, E<0 >{}*_16 {}, _0 {}, E<1 >{}*_1{})));
// New layout: the third mode has scale_shape_k entries and the last mode has
// k_tile_count * BLK_K / params.group_size entries; both step coordinate 1 by one, and the first two
// modes now step coordinate 0 by 32 instead of 16.
// NOTE(review): scale_shape_k is passed as a plain value here, whereas FragScaleLayout wraps it as
// Int<scale_shape_k> — confirm whether this extent is meant to be a static Int<> here as well.
// NOTE(review): the last extent uses the runtime params.group_size while scale_shape_k uses the
// compile-time GROUP_SIZE — confirm the two are guaranteed to match.
319+ make_layout (make_shape (Int<scale_shape_t >{}, Int<scale_shape_n >{}, scale_shape_k , k_tile_count * BLK_K /params. group_size ),
320+ make_stride (E<0 >{}*_32 {}, E<0 >{}*_32 {}, E< 1 >{}*_1 {}, E<1 >{}*_1{})));
319321
320322 }();
321323
@@ -340,28 +342,34 @@ inline float dDequantizeNF4(unsigned char val) {
340342 alignas (8 ) ElementB* src = reinterpret_cast <ElementB*>(smem_buf) + thread_idx * K * 5 ; // for K=64, 4 is hardcode for 128B alignment.
341343 const uint8_t * gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k / 2 + k_tile * BLK_K / 2 ;
342344 ElementMMA* dst_slm = reinterpret_cast <ElementMMA*>(src + K);
343- // if(cute::thread0()) {
344- // //printf("src = %x, gB_ptr = %x, dst_slm = %x\n", src, gB_ptr, dst_slm);
345+ #if 0
346+ if(cute::thread0()) {
347+ //printf("src = %x, gB_ptr = %x, dst_slm = %x\n", src, gB_ptr, dst_slm);
345348// print("\n\n======================= SLM: \n");
346349// print(" src : "); print(src); print("\n");
347350// print(" gB_ptr : "); print(gB_ptr); print("\n");
348351// print(" dst_slm : "); print(dst_slm); print("\n");
352+ // print(" fragment_scale: "); print(fragment_scale); print("\n");
349353// print("\n\n=======================\n\n");
350- // }
354+ }
355+ #endif
351356 #pragma unroll
352357 for (int n = 0 ; n < N; n++) {
353- float scale_value = fragment_scale (n);
358+ // float scale_value = fragment_scale(n);
354359 #pragma unroll
355360 for (int l = 0 ; l < src_loop_num; l++) {
356361 reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ] = reinterpret_cast <const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr )[n*src_loop_num + l];
357362 #pragma unroll
358363 for (int v = 0 ; v < src_vec_size; ++v) {
359364 src_compress_type src_value = reinterpret_cast <src_compress_type*>(src)[v];
360365 int dst_idx = v * src_compress_size;
366+ // float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + dst_idx / GROUP_SIZE);
361367 #pragma unroll
362368 for (int c = 0 ; c < src_compress_size; ++c) {
363369 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
370+ float scale_value = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_idx+c) / GROUP_SIZE );
364371 dst_slm[dst_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
372+ // if(cute::thread0()) printf("dst_idx+c = %d, n * (BLK_K / GROUP_SIZE) + (dst_idx+c)/GROUP_SIZE) = %d, scale_value = %f\n",dst_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_idx+c)/GROUP_SIZE, scale_value);
365373 }
366374 }
367375 }
@@ -453,7 +461,7 @@ inline float dDequantizeNF4(unsigned char val) {
453461 copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
454462 dequant();
455463#else
456- copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) / k_reload_factor ), frag_copy_Scale);
464+ copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) * BLK_K /params. group_size ), frag_copy_Scale);
457465 copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
458466 dequant (k_tile);
459467#endif
0 commit comments