Skip to content

Commit 8f776f5

Browse files
committed
save code
1 parent 1fd7c06 commit 8f776f5

1 file changed

Lines changed: 12 additions & 5 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ using TiledMma =
6868
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6969
using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
7070
constexpr int PipelineStages = 2;
71-
static constexpr auto GROUP_SIZE=64; //Block Quant Size
71+
static constexpr auto GROUP_SIZE=32; //Block Quant Size
7272

7373
using MmaAtomShape = typename TiledMma::AtomShape_MNK;
7474
using WorkgroupTileShape = TileShape;
@@ -295,6 +295,7 @@ inline float dDequantizeNF4(unsigned char val) {
295295
//static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
296296
static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
297297
static_assert(std::is_same_v<typename decltype(mma_B)::value_type, ElementMMA>);
298+
//static_assert(params.group_size, GROUP_SIZE);
298299

299300
Tensor frag_copy_A = thr_copy_A.retile_D(mma_A);
300301
//Tensor frag_copy_B = thr_copy_B.retile_D(dequant_frag);
@@ -316,10 +317,11 @@ inline float dDequantizeNF4(unsigned char val) {
316317

317318
auto tSgS = [&](){
318319
return make_tensor(make_inttuple_iter(make_coord(n_coord * BLK_N + get<2>(thr_mma.thr_vmnk_)*SG_QNT_WIDTH, 0, 0)),
319-
make_layout(make_shape(Int<scale_shape_t>{}, Int<scale_shape_n>{}, scale_shape_k, k_tile_count * BLK_K/params.group_size),
320-
make_stride(E<0>{}*(scale_shape_n * scale_shape_k * DispatchPolicy::SubgroupSize), E<0>{}*(scale_shape_k * DispatchPolicy::SubgroupSize), E<1>{}*_1{}, E<1>{}*_1{})));
320+
make_layout(make_shape(Int<scale_shape_t>{}, Int<scale_shape_n>{}, 1, k_tile_count * BLK_K/params.group_size),
321+
make_stride(E<0>{}*(scale_shape_n * k_tile_count * BLK_K/params.group_size), E<0>{}*(k_tile_count * BLK_K/params.group_size), E<0>{}*_0{}, E<1>{}*_1{})));
321322

322323
}();
324+
if(cute::thread0()) printf("scale_shape_t = %d, scale_shape_n = %d, scale_shape_k = %d, k_tile_count = %d, k_tile_count * BLK_K/params.group_size = %d, scale_shape_n * scale_shape_k * DispatchPolicy::SubgroupSize = %d, scale_shape_k * DispatchPolicy::SubgroupSize = %d\n",/*static_cast<int>(get<2>(thr_mma.thr_vmnk_)), static_cast<int>(SG_QNT_WIDTH),*/ scale_shape_t, scale_shape_n, scale_shape_k, k_tile_count, k_tile_count * BLK_K/params.group_size, scale_shape_n * scale_shape_k * DispatchPolicy::SubgroupSize, scale_shape_k * DispatchPolicy::SubgroupSize);
323325

324326
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
325327
int prefetch_k = k_start_idx;
@@ -385,6 +387,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
385387
auto dequant = [&] {
386388
constexpr int N = decltype(cute::size<1>(mma_B))::value;
387389
constexpr int K = decltype(cute::size(mma_B))::value / N;
390+
//if(cute::thread0) printf("scale num = %d\n", decltype(cute::size(fragment_scale))::value);
388391

389392
using src_compress_type = uint64_t;
390393
using dst_compress_type = uint64_t;
@@ -422,9 +425,12 @@ if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
422425
#pragma unroll
423426
for (int c = 0; c < src_compress_size; c++) {
424427
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
425-
float scale_value = 1.0f; //fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
428+
float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
426429
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
427-
//if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) printf("n = %d, src_l = %d, dst_base_idx+c = %d, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE) = %d, scale_value = %f\n", n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);
430+
if(1){ //thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0){
431+
printf("tid = %d, m_coord = %d, n_coord = %d, l_coord = %d, n = %d, src_l = %d, dst_dx = %d, scale_idx = %d, scale_value = %f\n", thread_idx, m_coord, n_coord, l_coord, n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);
432+
//print(" scale_value : "); print(scale_value); print("\n");
433+
}
428434
}
429435
}
430436
}
@@ -536,6 +542,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
536542
sycl::queue q = *stream;
537543

538544
using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS>;
545+
std::cout<<"group_size = "<<blocksize<<std::endl;
539546

540547
#if 1
541548
static constexpr int smem_size= (16) * sizeof(float);

0 commit comments

Comments
 (0)