Skip to content

Commit d7cd99a

Browse files
committed
save code
1 parent 84c9cce commit d7cd99a

2 files changed

Lines changed: 21 additions & 13 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,14 @@ static constexpr float quant_map_static[16] = {
6161
};
6262
#endif
6363

64-
using TileShape = Shape<_32, _128, _64>;
64+
using TileShape = Shape<_32, _128, _128>;
6565
using TiledMma =
6666
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6767
Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
6868
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6969
using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
7070
constexpr int PipelineStages = 2;
71+
static constexpr auto GROUP_SIZE=64; //Block Quant Size
7172

7273
using MmaAtomShape = typename TiledMma::AtomShape_MNK;
7374
using WorkgroupTileShape = TileShape;
@@ -285,9 +286,10 @@ inline float dDequantizeNF4(unsigned char val) {
285286
#endif
286287
Tensor frag_copy_B = thr_copy_B.retile_D(dequant_frag);
287288
#endif
288-
static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
289-
static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
290-
using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
289+
static constexpr auto scale_shape_t = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
290+
static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
291+
static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE;
292+
using FragScaleLayout = Layout<Shape<Int<scale_shape_t>, Int<scale_shape_n>, Int<scale_shape_k>>>; //[1, dequant_N, block_num]
291293
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
292294

293295
//static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
@@ -314,8 +316,8 @@ inline float dDequantizeNF4(unsigned char val) {
314316

315317
auto tSgS = [&](){
316318
return make_tensor(make_inttuple_iter(make_coord(n_coord * BLK_N + get<2>(thr_mma.thr_vmnk_)*SG_QNT_WIDTH, 0, 0)),
317-
make_layout(make_shape(Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count/k_reload_factor),
318-
make_stride(E<0>{}*_16{}, E<0>{}*_16{}, _0{}, E<1>{}*_1{})));
319+
make_layout(make_shape(Int<scale_shape_t>{}, Int<scale_shape_n>{}, scale_shape_k, k_tile_count * BLK_K/params.group_size),
320+
make_stride(E<0>{}*_32{}, E<0>{}*_32{}, E<1>{}*_1{}, E<1>{}*_1{})));
319321

320322
}();
321323

@@ -340,28 +342,34 @@ inline float dDequantizeNF4(unsigned char val) {
340342
alignas(8) ElementB* src = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * K * 5; //for K=64, 4 is hardcode for 128B alignment.
341343
const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k / 2 + k_tile * BLK_K / 2;
342344
ElementMMA* dst_slm = reinterpret_cast<ElementMMA*>(src + K);
343-
//if(cute::thread0()) {
344-
// //printf("src = %x, gB_ptr = %x, dst_slm = %x\n", src, gB_ptr, dst_slm);
345+
#if 0
346+
if(cute::thread0()) {
347+
//printf("src = %x, gB_ptr = %x, dst_slm = %x\n", src, gB_ptr, dst_slm);
345348
// print("\n\n======================= SLM: \n");
346349
// print(" src : "); print(src); print("\n");
347350
// print(" gB_ptr : "); print(gB_ptr); print("\n");
348351
// print(" dst_slm : "); print(dst_slm); print("\n");
352+
// print(" fragment_scale: "); print(fragment_scale); print("\n");
349353
// print("\n\n=======================\n\n");
350-
//}
354+
}
355+
#endif
351356
#pragma unroll
352357
for (int n = 0; n < N; n++) {
353-
float scale_value = fragment_scale(n);
358+
//float scale_value = fragment_scale(n);
354359
#pragma unroll
355360
for (int l = 0; l < src_loop_num; l++) {
356361
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr)[n*src_loop_num + l];
357362
#pragma unroll
358363
for (int v = 0; v < src_vec_size; ++v) {
359364
src_compress_type src_value = reinterpret_cast<src_compress_type*>(src)[v];
360365
int dst_idx = v * src_compress_size;
366+
//float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + dst_idx / GROUP_SIZE);
361367
#pragma unroll
362368
for (int c = 0; c < src_compress_size; ++c) {
363369
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
370+
float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_idx+c) / GROUP_SIZE);
364371
dst_slm[dst_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
372+
//if(cute::thread0()) printf("dst_idx+c = %d, n * (BLK_K / GROUP_SIZE) + (dst_idx+c)/GROUP_SIZE) = %d, scale_value = %f\n",dst_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_idx+c)/GROUP_SIZE, scale_value);
365373
}
366374
}
367375
}
@@ -453,7 +461,7 @@ inline float dDequantizeNF4(unsigned char val) {
453461
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
454462
dequant();
455463
#else
456-
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) / k_reload_factor), frag_copy_Scale);
464+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
457465
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
458466
dequant(k_tile);
459467
#endif

run_case.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
#gdb -args python -m pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
3131
#pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
32-
#pytest -vs tests/test_xpu.py::TestXPU::test_gemv_4bit
33-
python tests/test_xpu_db.py
32+
pytest -vs tests/test_xpu.py::TestXPU::test_gemv_4bit
33+
#python tests/test_xpu_db.py
3434
#gdb -args python tests/test_xpu_db.py
3535
#pytest tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=256-uint8-bf16-fc1-nf4-DQ_True-xpu]

0 commit comments

Comments
 (0)