Skip to content

Commit 5bbb92f

Browse files
committed
save code, fix bug
1 parent a8f256c commit 5bbb92f

2 files changed

Lines changed: 10 additions & 11 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ static constexpr float quant_map_static[16] = {
6161
};
6262
#endif
6363

64-
using TileShape = Shape<_32, _128, _32>;
64+
using TileShape = Shape<_32, _128, _128>;
6565
using TiledMma =
6666
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6767
Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
@@ -333,7 +333,7 @@ inline float dDequantizeNF4(unsigned char val) {
333333
using src_compress_type = uint64_t;
334334
using dst_compress_type = uint64_t;
335335
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
336-
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //16
336+
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
337337
constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
338338
constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
339339
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
@@ -344,7 +344,7 @@ inline float dDequantizeNF4(unsigned char val) {
344344
ElementMMA* dst_slm = reinterpret_cast<ElementMMA*>(src + K);
345345
#if 0
346346
if(cute::thread0()) {
347-
//printf("src = %x, gB_ptr = %x, dst_slm = %x\n", src, gB_ptr, dst_slm);
347+
printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_vec_size = %d, src_loop_num = %d, dst_loop_num = %d\n", src_compress_size, dst_compress_size, src_vec_size, dst_vec_size, src_loop_num, dst_loop_num);
348348
// print("\n\n======================= SLM: \n");
349349
// print(" src : "); print(src); print("\n");
350350
// print(" gB_ptr : "); print(gB_ptr); print("\n");
@@ -362,21 +362,20 @@ if(cute::thread0()) {
362362
#pragma unroll
363363
for (int v = 0; v < src_vec_size; ++v) {
364364
src_compress_type src_value = reinterpret_cast<src_compress_type*>(src)[v];
365-
int dst_idx = v * src_compress_size;
366-
//float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + dst_idx / GROUP_SIZE);
365+
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
367366
#pragma unroll
368367
for (int c = 0; c < src_compress_size; ++c) {
369368
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
370-
float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_idx+c) / GROUP_SIZE);
371-
dst_slm[dst_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
372-
//if(thread_idx==1 && m_coord==0 && n_coord==0 && l_coord==0) printf("dst_idx+c = %d, n * (BLK_K / GROUP_SIZE) + (dst_idx+c)/GROUP_SIZE) = %d, scale_value = %f\n",dst_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_idx+c)/GROUP_SIZE, scale_value);
369+
float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
370+
dst_slm[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
371+
//if(thread_idx==1 && m_coord==0 && n_coord==0 && l_coord==0) printf("dst_base_idx+c = %d, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE) = %d, scale_value = %f\n",dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);
373372
}
374373
}
375374
}
376375

377376
#pragma unroll
378377
for (int l = 0; l < dst_loop_num; l++) {
379-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<const sycl::vec<dst_compress_type, dst_vec_size>*>(dst_slm)[0];
378+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<const sycl::vec<dst_compress_type, dst_vec_size>*>(dst_slm)[l];
380379
}
381380
}
382381
};

run_case.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
#gdb -args python -m pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
3131
#pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
32-
pytest -vs tests/test_xpu.py::TestXPU::test_gemv_4bit
33-
#python tests/test_xpu_db.py
32+
#pytest -vs tests/test_xpu.py::TestXPU::test_gemv_4bit
33+
python tests/test_xpu_db.py
3434
#gdb -args python tests/test_xpu_db.py
3535
#pytest tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=256-uint8-bf16-fc1-nf4-DQ_True-xpu]

0 commit comments

Comments
 (0)