@@ -509,6 +509,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
509509 constexpr int src_loop_num = K / src_vec_size / src_compress_size;
510510 constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
511511 src_compress_type src[src_vec_size];
512+ ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
512513
513514 int lut_id = start_lut_id;
514515// if(sg_idx == 0){
@@ -517,7 +518,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
517518// }
518519// }
519520
520- int pre_num = 2 ;
521+ int pre_num = 1 ;
521522 #pragma unroll
522523 for (int n = 0 ; n < N; n++) {
523524 #pragma unroll
@@ -526,55 +527,75 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
526527
527528 #pragma unroll
528529 for (int v = 0 ; v < src_vec_size; v++) {
530+ // src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l][v]; //src[v];
529531 src_compress_type src_value = src[v];
530532 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
533+ #if 0
531534 int c = 0;
532535 uint16_t high_bits_1 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
533536 lut_id = (lut_id + 1) % LUT_NUM;
534537 uint16_t low_bits_1 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
535538
536539 c++;
540+ reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_1) << 16) | high_bits_1;
537541 uint16_t high_bits_2 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
538542 lut_id = (lut_id + 1) % LUT_NUM;
539543 uint16_t low_bits_2 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
540544
541545 c++;
542- reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast <uint32_t >(low_bits_1 ) << 16 ) | high_bits_1 ;
546+ reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_2 ) << 16) | high_bits_2 ;
543547 uint16_t high_bits_3 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
544548 lut_id = (lut_id + 1) % LUT_NUM;
545549 uint16_t low_bits_3 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
546550
547551 c++;
548- reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast <uint32_t >(low_bits_2 ) << 16 ) | high_bits_2 ;
552+ reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_3 ) << 16) | high_bits_3 ;
549553 uint16_t high_bits_4 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
550554 lut_id = (lut_id + 1) % LUT_NUM;
551555 uint16_t low_bits_4 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
552556
553557 c++;
554- reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast <uint32_t >(low_bits_3 ) << 16 ) | high_bits_3 ;
558+ reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_4 ) << 16) | high_bits_4 ;
555559 uint16_t high_bits_5 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
556560 lut_id = (lut_id + 1) % LUT_NUM;
557561 uint16_t low_bits_5 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
558562
559563 c++;
560- reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast <uint32_t >(low_bits_4 ) << 16 ) | high_bits_4 ;
564+ reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_5 ) << 16) | high_bits_5 ;
561565 uint16_t high_bits_6 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
562566 lut_id = (lut_id + 1) % LUT_NUM;
563567 uint16_t low_bits_6 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
564568
565569 c++;
566- reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast <uint32_t >(low_bits_5 ) << 16 ) | high_bits_5 ;
570+ reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_6 ) << 16) | high_bits_6 ;
567571 uint16_t high_bits_7 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
568572 lut_id = (lut_id + 1) % LUT_NUM;
569573 uint16_t low_bits_7 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
570574
571575 c++;
572- reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num ] = (static_cast <uint32_t >(low_bits_6 ) << 16 ) | high_bits_6 ;
576+ reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-1 ] = (static_cast<uint32_t>(low_bits_7 ) << 16) | high_bits_7 ;
573577 uint16_t high_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
574578 lut_id = (lut_id + 1) % LUT_NUM;
575579 uint16_t low_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
576- reinterpret_cast <uint32_t *>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-1 ] = (static_cast <uint32_t >(low_bits_7) << 16 ) | high_bits_7;
577580 reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits_8) << 16) | high_bits_8;
581+ #else
582+ #if 0
583+ for (int c = 0; c < src_compress_size/2; c++) {
584+ uint16_t high_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
585+ lut_id = (lut_id + 1) % LUT_NUM;
586+ uint16_t low_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
587+ reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits) << 16) | high_bits;
588+ }
589+ #else
590+ for (int c = 0 ; c < src_compress_size; c++) {
591+ uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
592+ float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
593+ dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
594+ lut_id = (lut_id + 1 ) % LUT_NUM ;
595+ }
596+ reinterpret_cast <sycl::vec<dst_compress_type, 4 >*>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast <sycl::vec<dst_compress_type, 4 >*>(dst)[v];
597+ #endif
598+ #endif
578599 }
579600 }
580601 }
0 commit comments