Skip to content

Commit 2d9253a

Browse files
committed
save code
1 parent 16978a9 commit 2d9253a

1 file changed

Lines changed: 7 additions & 0 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -531,34 +531,41 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
531531
uint16_t high_bits_1 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
532532
lut_id = (lut_id + 1) % LUT_NUM;
533533
uint16_t low_bits_1 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
534+
534535
c++;
535536
uint16_t high_bits_2 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
536537
lut_id = (lut_id + 1) % LUT_NUM;
537538
uint16_t low_bits_2 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
539+
538540
c++;
539541
uint16_t high_bits_3 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
540542
lut_id = (lut_id + 1) % LUT_NUM;
541543
uint16_t low_bits_3 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
544+
542545
c++;
543546
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3] = (static_cast<uint32_t>(low_bits_1) << 16) | high_bits_1;
544547
uint16_t high_bits_4 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
545548
lut_id = (lut_id + 1) % LUT_NUM;
546549
uint16_t low_bits_4 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
550+
547551
c++;
548552
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3] = (static_cast<uint32_t>(low_bits_2) << 16) | high_bits_2;
549553
uint16_t high_bits_5 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
550554
lut_id = (lut_id + 1) % LUT_NUM;
551555
uint16_t low_bits_5 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
556+
552557
c++;
553558
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3] = (static_cast<uint32_t>(low_bits_3) << 16) | high_bits_3;
554559
uint16_t high_bits_6 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
555560
lut_id = (lut_id + 1) % LUT_NUM;
556561
uint16_t low_bits_6 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
562+
557563
c++;
558564
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3] = (static_cast<uint32_t>(low_bits_4) << 16) | high_bits_4;
559565
uint16_t high_bits_7 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
560566
lut_id = (lut_id + 1) % LUT_NUM;
561567
uint16_t low_bits_7 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
568+
562569
c++;
563570
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-3] = (static_cast<uint32_t>(low_bits_5) << 16) | high_bits_5;
564571
uint16_t high_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));

0 commit comments

Comments
 (0)