@@ -61,7 +61,7 @@ static constexpr float quant_map_static[16] = {
6161};
6262#endif
6363
// Tile-shape and tiled-MMA type aliases. This hunk shows both the removed ("-") and the
// added ("+") TileShape line: the third extent shrinks from _128 to _64, while the first
// two extents (_64, _128) are unchanged.
// NOTE(review): the third extent is presumably the K tile size -- the hard-coded
// dst_loop_num later in this diff drops from 2 to 1, which matches its
// "K / dst_vec_size / dst_compress_size" comment for K = 64 (64 / 16 / 4 = 1). Confirm.
64- using TileShape = Shape<_64, _128, _128 >;
64+ using TileShape = Shape<_64, _128, _64 >;
// TiledMma is the ::TiledMMA member of TiledMMAHelper, instantiated with the
// XE_8x16x16_F32BF16BF16F32_TT MMA atom, a Layout over TileShape, and a
// Shape<_2, _8, _1> / Stride<_8, _1, _0> layout.
// NOTE(review): that last layout presumably describes how subgroups tile the work-group -- confirm.
6565using TiledMma =
6666 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6767 Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
@@ -414,8 +414,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
414414 constexpr int dst_compress_size = 4 ; // cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
415415 constexpr int src_vec_size = 8 ; // (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
416416 constexpr int dst_vec_size = 16 ; // (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
417- constexpr int src_loop_num = 2 ; // K / src_vec_size / src_compress_size;
418- constexpr int dst_loop_num = 2 ; // K / dst_vec_size / dst_compress_size;
417+ constexpr int src_loop_num = 1 ; // K / src_vec_size / src_compress_size;
418+ constexpr int dst_loop_num = 1 ; // K / dst_vec_size / dst_compress_size;
419419
420420 src_compress_type src[src_loop_num * src_vec_size];
421421 ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
@@ -441,29 +441,29 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
441441 dst[dst_base_idx + c] = static_cast <ElementMMA>(converted_value_1 * scale_value);
442442 }
443443
444- reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[1 ] = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[1 ];
445- scale_value = fragment_scale (1 );
444+ // reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[1] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[1];
445+ // scale_value = fragment_scale(1);
446446 reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[0 ] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0 ];
447447
448- #pragma unroll
449- for (int v = src_vec_size; v < src_loop_num * src_vec_size; v++) {
450- int dst_base_idx = v * src_compress_size;
451- int c = 0 ;
452- uint8_t bit_value = (src[v] >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
453- float converted_value_1 = quant_map[bit_value];
454- float converted_value_2 = 0 .f ;
455- #pragma unroll
456- for (; c < src_compress_size-1 ;) {
457- converted_value_2 = converted_value_1;
458- c++;
459- bit_value = (src[v] >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
460- converted_value_1 = quant_map[bit_value];
461- dst[dst_base_idx + c-1 ] = static_cast <ElementMMA>(converted_value_2 * scale_value);
462- }
463- dst[dst_base_idx + c] = static_cast <ElementMMA>(converted_value_1 * scale_value);
464- }
465-
466- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[1 ] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[1 ];
448+ // #pragma unroll
449+ // for (int v = src_vec_size; v < src_loop_num * src_vec_size; v++) {
450+ // int dst_base_idx = v * src_compress_size;
451+ // int c = 0;
452+ // uint8_t bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
453+ // float converted_value_1 = quant_map[bit_value];
454+ // float converted_value_2 = 0.f;
455+ // #pragma unroll
456+ // for (; c < src_compress_size-1;) {
457+ // converted_value_2 = converted_value_1;
458+ // c++;
459+ // bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
460+ // converted_value_1 = quant_map[bit_value];
461+ // dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
462+ // }
463+ // dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
464+ // }
465+ //
466+ // reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[1] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[1];
467467
468468 };
469469#endif
0 commit comments