@@ -61,7 +61,7 @@ static constexpr float quant_map_static[16] = {
6161};
6262#endif
6363
64- using TileShape = Shape<_32, _128, _32 >;
64+ using TileShape = Shape<_32, _128, _128 >;
6565using TiledMma =
6666 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6767 Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
@@ -333,7 +333,7 @@ inline float dDequantizeNF4(unsigned char val) {
333333 using src_compress_type = uint64_t ;
334334 using dst_compress_type = uint64_t ;
335335 constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
336- constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 16
336+ constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 4
337337 constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; // 4, 16 -> max vec_size of sycl::vec
338338 constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; // 16, 16 -> max vec_size of sycl::vec
339339 constexpr int src_loop_num = K / src_vec_size / src_compress_size;
@@ -344,7 +344,7 @@ inline float dDequantizeNF4(unsigned char val) {
344344 ElementMMA* dst_slm = reinterpret_cast <ElementMMA*>(src + K);
345345#if 0
346346if(cute::thread0()) {
347- // printf("src = %x, gB_ptr = %x, dst_slm = %x \n", src, gB_ptr, dst_slm );
347+ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_vec_size = %d, src_loop_num = %d, dst_loop_num = %d \n", src_compress_size, dst_compress_size, src_vec_size, dst_vec_size, src_loop_num, dst_loop_num );
348348// print("\n\n======================= SLM: \n");
349349// print(" src : "); print(src); print("\n");
350350// print(" gB_ptr : "); print(gB_ptr); print("\n");
@@ -362,21 +362,20 @@ if(cute::thread0()) {
362362 #pragma unroll
363363 for (int v = 0 ; v < src_vec_size; ++v) {
364364 src_compress_type src_value = reinterpret_cast <src_compress_type*>(src)[v];
365- int dst_idx = v * src_compress_size;
366- // float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + dst_idx / GROUP_SIZE);
365+ int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
367366 #pragma unroll
368367 for (int c = 0 ; c < src_compress_size; ++c) {
369368 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
370- float scale_value = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_idx+ c) / GROUP_SIZE );
371- dst_slm[dst_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
372- // if(thread_idx==1 && m_coord==0 && n_coord==0 && l_coord==0) printf("dst_idx +c = %d, n * (BLK_K / GROUP_SIZE) + (dst_idx +c)/GROUP_SIZE) = %d, scale_value = %f\n",dst_idx +c, n * (BLK_K / GROUP_SIZE) + (dst_idx +c)/GROUP_SIZE, scale_value);
369+ float scale_value = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + c) / GROUP_SIZE );
370+ dst_slm[dst_base_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
371+ // if(thread_idx==1 && m_coord==0 && n_coord==0 && l_coord==0) printf("dst_base_idx +c = %d, n * (BLK_K / GROUP_SIZE) + (dst_base_idx +c)/GROUP_SIZE) = %d, scale_value = %f\n",dst_base_idx +c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx +c)/GROUP_SIZE, scale_value);
373372 }
374373 }
375374 }
376375
377376 #pragma unroll
378377 for (int l = 0 ; l < dst_loop_num; l++) {
379- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n*dst_loop_num + l] = reinterpret_cast <const sycl::vec<dst_compress_type, dst_vec_size>*>(dst_slm)[0 ];
378+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n*dst_loop_num + l] = reinterpret_cast <const sycl::vec<dst_compress_type, dst_vec_size>*>(dst_slm)[l ];
380379 }
381380 }
382381 };
0 commit comments