@@ -588,7 +588,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
588588 uint16_t low_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
589589 reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits) << 16) | high_bits;
590590 }
591- #elif 1
591+ #elif 0
592592 #pragma unroll
593593 for (int c = 0; c < src_compress_size; c++) {
594594 uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
@@ -607,24 +607,21 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
607607 //reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
608608
609609 }
610- //#else
611- // #pragma unroll
612- // for (int c = 0; c < src_compress_size; c++) {
613- // uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
614- // float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
615- // dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
616- // lut_id = (lut_id + 1) % LUT_NUM;
617- // }
618- // }
619- // }
620- //
621- // #pragma unroll
622- // for (int l = 0; l < dst_loop_num; l++) {
623- // reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
624- // //reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
625- // //reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
626- //
627- // }
610+ #else
611+ #pragma unroll
612+ for (int c = 0 ; c < src_compress_size; c++) {
613+ uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
614+ float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
615+ dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
616+ lut_id = (lut_id + 1 ) % LUT_NUM ;
617+ }
618+ }
619+ }
620+
621+ #pragma unroll
622+ for (int l = 0 ; l < dst_loop_num; l++) {
623+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
624+ }
628625#endif
629626 }
630627 };
0 commit comments