@@ -531,34 +531,41 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
531531 uint16_t high_bits_1 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
532532 lut_id = (lut_id + 1 ) % LUT_NUM ;
533533 uint16_t low_bits_1 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
534+
534535 c++;
535536 uint16_t high_bits_2 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
536537 lut_id = (lut_id + 1 ) % LUT_NUM ;
537538 uint16_t low_bits_2 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
539+
538540 c++;
539541 uint16_t high_bits_3 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
540542 lut_id = (lut_id + 1 ) % LUT_NUM ;
541543 uint16_t low_bits_3 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
544+
542545 c++;
543546 reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3 ] = (static_cast <uint32_t >(low_bits_1) << 16 ) | high_bits_1;
544547 uint16_t high_bits_4 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
545548 lut_id = (lut_id + 1 ) % LUT_NUM ;
546549 uint16_t low_bits_4 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
550+
547551 c++;
548552 reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3 ] = (static_cast <uint32_t >(low_bits_2) << 16 ) | high_bits_2;
549553 uint16_t high_bits_5 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
550554 lut_id = (lut_id + 1 ) % LUT_NUM ;
551555 uint16_t low_bits_5 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
556+
552557 c++;
553558 reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3 ] = (static_cast <uint32_t >(low_bits_3) << 16 ) | high_bits_3;
554559 uint16_t high_bits_6 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
555560 lut_id = (lut_id + 1 ) % LUT_NUM ;
556561 uint16_t low_bits_6 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
562+
557563 c++;
558564 reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3 ] = (static_cast <uint32_t >(low_bits_4) << 16 ) | high_bits_4;
559565 uint16_t high_bits_7 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
560566 lut_id = (lut_id + 1 ) % LUT_NUM ;
561567 uint16_t low_bits_7 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE )));
568+
562569 c++;
563570 reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3 ] = (static_cast <uint32_t >(low_bits_5) << 16 ) | high_bits_5;
564571 uint16_t high_bits_8 = sycl::bit_cast<uint16_t >(static_cast <ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1 ))) & 0xf ] * fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE )));
0 commit comments