@@ -527,19 +527,46 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
527527 for (int v = 0 ; v < src_vec_size; v++) {
528528 src_compress_type src_value = src[v];
529529 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
530- #pragma unroll
531- for (int c = 0 ; c < src_compress_size/2 ; c++) {
532- uint8_t high = (src_value >> (4 * (c * 2 + 1 ))) & 0xf ;
533- uint8_t low = (src_value >> (4 * (c * 2 ))) & 0xf ;
534- float ts_high = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE );
535- float ts_low = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE );
536-
537- uint16_t high_bits = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][high] * ts_high));
538- uint16_t low_bits = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][low] * ts_low));
539- reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast <uint32_t >(low_bits) << 16 ) | high_bits;
540-
541- lut_id = (lut_id + 1 ) % LUT_NUM ;
542- }
530+ int c = 0 ;
531+ uint16_t high_bits_1 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
532+ lut_id = (lut_id + 1 ) % LUT_NUM ;
533+ uint16_t low_bits_1 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
534+ c++;
535+ uint16_t high_bits_2 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
536+ lut_id = (lut_id + 1 ) % LUT_NUM ;
537+ uint16_t low_bits_2 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
538+ c++;
539+ uint16_t high_bits_3 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
540+ lut_id = (lut_id + 1 ) % LUT_NUM ;
541+ uint16_t low_bits_3 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
542+ c++;
543+ reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3 ] = (static_cast <uint32_t >(low_bits_1) << 16 ) | high_bits_1;
544+ uint16_t high_bits_4 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
545+ lut_id = (lut_id + 1 ) % LUT_NUM ;
546+ uint16_t low_bits_4 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
547+ c++;
548+ reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3 ] = (static_cast <uint32_t >(low_bits_2) << 16 ) | high_bits_2;
549+ uint16_t high_bits_5 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
550+ lut_id = (lut_id + 1 ) % LUT_NUM ;
551+ uint16_t low_bits_5 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
552+ c++;
553+ reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3 ] = (static_cast <uint32_t >(low_bits_3) << 16 ) | high_bits_3;
554+ uint16_t high_bits_6 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
555+ lut_id = (lut_id + 1 ) % LUT_NUM ;
556+ uint16_t low_bits_6 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
557+ c++;
558+ reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3 ] = (static_cast <uint32_t >(low_bits_4) << 16 ) | high_bits_4;
559+ uint16_t high_bits_7 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
560+ lut_id = (lut_id + 1 ) % LUT_NUM ;
561+ uint16_t low_bits_7 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
562+ c++;
563+ reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3 ] = (static_cast <uint32_t >(low_bits_5) << 16 ) | high_bits_5;
564+ uint16_t high_bits_8 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
565+ lut_id = (lut_id + 1 ) % LUT_NUM ;
566+ uint16_t low_bits_8 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
567+ reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-2 ] = (static_cast <uint32_t >(low_bits_6) << 16 ) | high_bits_6;
568+ reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-1 ] = (static_cast <uint32_t >(low_bits_7) << 16 ) | high_bits_7;
569+ reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast <uint32_t >(low_bits_8) << 16 ) | high_bits_8;
543570 }
544571 }
545572 }
0 commit comments