Skip to content

Commit 7206605

Browse files
committed
save code
1 parent 5bbb92f commit 7206605

1 file changed

Lines changed: 16 additions & 13 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ inline float dDequantizeNF4(unsigned char val) {
232232
? BlockIdxX() : BlockIdxY();
233233
const int l_coord = BlockIdxZ();
234234

235-
#if 0
235+
#if 1
236236
//float* quant_map;
237237
//static constexpr std::array<float, 16> quant_map{};
238238
// {
@@ -277,7 +277,7 @@ inline float dDequantizeNF4(unsigned char val) {
277277
Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
278278
Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
279279

280-
#if 0 //SLM: 0, register: 1
280+
#if 1 //SLM: 0, register: 1
281281
#if 1 //fragement register
282282
Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
283283
#else //common register
@@ -324,7 +324,7 @@ inline float dDequantizeNF4(unsigned char val) {
324324
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
325325
int prefetch_k = k_start_idx;
326326

327-
#if 1 //SLM
327+
#if 0 //SLM
328328
#if 1
329329
auto dequant = [&] (int k_tile) {
330330
constexpr int N = decltype(cute::size<1>(mma_B))::value;
@@ -386,11 +386,10 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
386386
constexpr int N = decltype(cute::size<1>(mma_B))::value;
387387
constexpr int K = decltype(cute::size(mma_B))::value / N;
388388

389-
390389
using src_compress_type = uint64_t;
391390
using dst_compress_type = uint64_t;
392391
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
393-
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //16
392+
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
394393
constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
395394
constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
396395
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
@@ -399,11 +398,11 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
399398
//if(cute::thread0()) printf("params.group_size = %d, k_reload_factor = %d, k_tile_count = %d, N = %d, K = %d, src_compress_size = %d, src_vec_size = %d, dst_compress_size = %d, dst_vec_size = %d\n",params.group_size, k_reload_factor, k_tile_count, N, K, src_compress_size, src_vec_size, dst_compress_size, dst_vec_size);
400399

401400
src_compress_type src[src_vec_size];
402-
ElementMMA dst[dst_compress_size * dst_vec_size];
401+
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
403402

404403
#pragma unroll
405404
for (int n = 0; n < N; n++) {
406-
float scale_value = fragment_scale(n);
405+
//float scale_value = fragment_scale(n);
407406
#pragma unroll
408407
for (int l = 0; l < src_loop_num; l++) {
409408
//src_compress_type src[src_vec_size];
@@ -412,18 +411,19 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
412411
#pragma unroll
413412
for (int v = 0; v < src_vec_size; v++) {
414413
src_compress_type src_value = src[v];
415-
int dst_idx = v * src_compress_size;
414+
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
416415
#pragma unroll
417416
for (int c = 0; c < src_compress_size; c++) {
418417
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
419-
dst[dst_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
418+
float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
419+
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
420420
}
421421
}
422422
}
423423

424424
#pragma unroll
425425
for (int l = 0; l < dst_loop_num; l++) {
426-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0];
426+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
427427
}
428428
}
429429
};
@@ -454,9 +454,9 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
454454
}
455455

456456
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
457-
#if 0 //SLM: 0, register: 1
457+
#if 1 //SLM: 0, register: 1
458458
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
459-
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) / k_reload_factor), frag_copy_Scale);
459+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
460460
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
461461
dequant();
462462
#else
@@ -501,8 +501,11 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
501501

502502
using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS>;
503503

504-
//static constexpr int smem_size= (16) * sizeof(float);
504+
#if 1
505+
static constexpr int smem_size= (16) * sizeof(float);
506+
#else
505507
static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementMMA) * 2 * 2; //aligned with 128B and will be reused for dequant src and dst.
508+
#endif
506509
size_t max_slm_size = q.get_device().get_info<sycl::info::device::local_mem_size>();
507510
assert(smem_size <= max_slm_size);
508511

0 commit comments

Comments
 (0)