Skip to content

Commit 16978a9

Browse files
committed
save code
1 parent a9e6d94 commit 16978a9

1 file changed

Lines changed: 40 additions & 13 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -527,19 +527,46 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
527527
for (int v = 0; v < src_vec_size; v++) {
528528
src_compress_type src_value = src[v];
529529
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
530-
#pragma unroll
531-
for (int c = 0; c < src_compress_size/2; c++) {
532-
uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
533-
uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
534-
float ts_high = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE);
535-
float ts_low = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE);
536-
537-
uint16_t high_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][high] * ts_high));
538-
uint16_t low_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][low] * ts_low));
539-
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits) << 16) | high_bits;
540-
541-
lut_id = (lut_id + 1) % LUT_NUM;
542-
}
530+
int c = 0;
531+
uint16_t high_bits_1 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
532+
lut_id = (lut_id + 1) % LUT_NUM;
533+
uint16_t low_bits_1 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
534+
c++;
535+
uint16_t high_bits_2 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
536+
lut_id = (lut_id + 1) % LUT_NUM;
537+
uint16_t low_bits_2 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
538+
c++;
539+
uint16_t high_bits_3 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
540+
lut_id = (lut_id + 1) % LUT_NUM;
541+
uint16_t low_bits_3 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
542+
c++;
543+
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3] = (static_cast<uint32_t>(low_bits_1) << 16) | high_bits_1;
544+
uint16_t high_bits_4 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
545+
lut_id = (lut_id + 1) % LUT_NUM;
546+
uint16_t low_bits_4 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
547+
c++;
548+
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3] = (static_cast<uint32_t>(low_bits_2) << 16) | high_bits_2;
549+
uint16_t high_bits_5 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
550+
lut_id = (lut_id + 1) % LUT_NUM;
551+
uint16_t low_bits_5 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
552+
c++;
553+
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3] = (static_cast<uint32_t>(low_bits_3) << 16) | high_bits_3;
554+
uint16_t high_bits_6 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
555+
lut_id = (lut_id + 1) % LUT_NUM;
556+
uint16_t low_bits_6 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
557+
c++;
558+
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3] = (static_cast<uint32_t>(low_bits_4) << 16) | high_bits_4;
559+
uint16_t high_bits_7 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
560+
lut_id = (lut_id + 1) % LUT_NUM;
561+
uint16_t low_bits_7 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
562+
c++;
563+
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3] = (static_cast<uint32_t>(low_bits_5) << 16) | high_bits_5;
564+
uint16_t high_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
565+
lut_id = (lut_id + 1) % LUT_NUM;
566+
uint16_t low_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
567+
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-2] = (static_cast<uint32_t>(low_bits_6) << 16) | high_bits_6;
568+
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-1] = (static_cast<uint32_t>(low_bits_7) << 16) | high_bits_7;
569+
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits_8) << 16) | high_bits_8;
543570
}
544571
}
545572
}

0 commit comments

Comments
 (0)