Skip to content

Commit 43704d9

Browse files
committed
slm dst, draft
1 parent b2e2125 commit 43704d9

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ inline float dDequantizeNF4(unsigned char val) {
243243
// }
244244
// barrier_arrive(3);
245245
//PVC SLM 64 banks -> 4 LUTs
246-
alignas(128) float (*quant_map_)[16] = reinterpret_cast<float(*)[16]>(smem_buf);
246+
alignas(64) float (*quant_map_)[16] = reinterpret_cast<float(*)[16]>(smem_buf);
247247
if (thread_idx < 16 * LUT_NUM) {
248248
quant_map_[thread_idx / 16][thread_idx % 16] = params.datatype[thread_idx % 16];
249249
}
@@ -509,7 +509,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
509509
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
510510
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
511511
src_compress_type src[src_vec_size];
512-
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
512+
//ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
513+
alignas(64) ElementMMA* dst = reinterpret_cast<ElementMMA*>(smem_buf + 16 * sizeof(float) * LUT_NUM + thread_idx * decltype(cute::size(mma_B))::value * sizeof(ElementMMA));
513514

514515
int lut_id = start_lut_id;
515516

@@ -633,7 +634,8 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
633634
//std::cout<<"group_size = "<<blocksize<<std::endl;
634635

635636
#if 1
636-
static constexpr int smem_size= (16) * sizeof(float) * LUT_NUM;// * 2;
637+
//static constexpr int smem_size= (16) * sizeof(float) * LUT_NUM;
638+
static constexpr int smem_size= (16) * sizeof(float) * LUT_NUM + BLK_N * BLK_K * sizeof(ElementMMA)*2;
637639
#else
638640
static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementMMA) * 2 * 2; //aligned with 128B and will be reused for dequant src and dst.
639641
#endif

0 commit comments

Comments
 (0)