Skip to content

Commit b941379

Browse files
committed
save code
1 parent 4248edf commit b941379

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
101101
//static constexpr auto FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape());
102102
//static constexpr auto FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape());
103103
//static constexpr auto FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;
104-
104+
static constexpr int LUT_NUM = 4;
105+
105106
// Design Scheduler
106107
using TileScheduler_ = PersistentScheduler;
107108
static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>, "Intel PVC does not support specializing the tile scheduler.");
@@ -243,7 +244,7 @@ inline float dDequantizeNF4(unsigned char val) {
243244
// barrier_arrive(3);
244245
//PVC SLM 64 banks -> 4 LUTs
245246
alignas(128) float (*quant_map_)[16] = reinterpret_cast<float(*)[16]>(smem_buf);
246-
if (thread_idx < 64) {
247+
if (thread_idx < 16 * LUT_NUM) {
247248
quant_map_[thread_idx / 16][thread_idx % 16] = params.datatype[thread_idx % 16];
248249
}
249250
barrier_arrive(3);
@@ -532,7 +533,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
532533
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
533534
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value + (dst_base_idx + c) % 4 * 16] * scale_value);
534535
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
535-
//lut_id = (lut_id + 1) % 4;
536+
lut_id = (lut_id + 1) % LUT_NUM;
536537
//dst[dst_base_idx + c] = static_cast<ElementMMA>(params.quant_map_const[bit_value] * scale_value);
537538

538539
// uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
@@ -542,7 +543,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
542543
// dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
543544
// dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
544545
}
545-
lut_id = (lut_id + 1) % 4;
546+
//lut_id = (lut_id + 1) % LUT_NUM;
546547
}
547548
}
548549

@@ -565,7 +566,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
565566
//int map_offset = 16 * (sg_idx % 4);
566567
//int map_offset = 16 * ((sg_idx ^ (sg_idx >> 2)) % 4);
567568
//int lut_id = sg_idx % 4;
568-
int start_lut_id = sg_idx % 4;
569+
int start_lut_id = sg_idx % LUT_NUM;
569570

570571
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
571572
#if 1 //SLM: 0, register: 1
@@ -638,7 +639,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
638639
//std::cout<<"group_size = "<<blocksize<<std::endl;
639640

640641
#if 1
641-
static constexpr int smem_size= (32) * sizeof(float) * 2;// * 2;
642+
static constexpr int smem_size= (16) * sizeof(float) * LUT_NUM;// * 2;
642643
#else
643644
static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementMMA) * 2 * 2; //aligned with 128B and will be reused for dequant src and dst.
644645
#endif

0 commit comments

Comments
 (0)