@@ -578,26 +578,42 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
578578 lut_id = (lut_id + 1) % LUT_NUM;
579579 uint16_t low_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
580580 reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits_8) << 16) | high_bits_8;
581- #else
582- #if 0
581+ #elif 0
583582 for (int c = 0; c < src_compress_size/2; c++) {
584583 uint16_t high_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
585584 lut_id = (lut_id + 1) % LUT_NUM;
586585 uint16_t low_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
587586 reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits) << 16) | high_bits;
588587 }
589- # else
588+ #elif 0
590589 for (int c = 0; c < src_compress_size; c++) {
591590 uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
592591 float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
593592 dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
594593 lut_id = (lut_id + 1) % LUT_NUM;
595594 }
596595 reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(dst)[v];
597- #endif
598- #endif
599596 }
600597 }
598+ #else
599+ #pragma unroll
600+ for (int c = 0 ; c < src_compress_size; c++) {
601+ uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
602+ float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
603+ dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
604+ lut_id = (lut_id + 1 ) % LUT_NUM ;
605+ }
606+ }
607+ }
608+
609+ #pragma unroll
610+ for (int l = 0 ; l < dst_loop_num / 4 ; l++) {
611+ // reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
612+ // reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
613+ reinterpret_cast <sycl::vec<dst_compress_type, 4 >*>(cute::raw_pointer_cast (mma_B.data ()))[n*dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, 4 >*>(dst)[l];
614+
615+ }
616+ #endif
601617 }
602618 };
603619#endif
0 commit comments