Skip to content

Commit dab7994

Browse files
committed
save code
1 parent de7bc02 commit dab7994

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ static constexpr float quant_map_static[16] = {
6161
};
6262
#endif
6363

64-
using TileShape = Shape<_32, _256, _64>;
64+
using TileShape = Shape<_32, _128, _64>;
6565
using TiledMma =
6666
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6767
Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
@@ -231,7 +231,7 @@ inline float dDequantizeNF4(unsigned char val) {
231231
? BlockIdxX() : BlockIdxY();
232232
const int l_coord = BlockIdxZ();
233233

234-
#if 1
234+
#if 0
235235
//float* quant_map;
236236
//static constexpr std::array<float, 16> quant_map{};
237237
// {
@@ -276,7 +276,7 @@ inline float dDequantizeNF4(unsigned char val) {
276276
Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
277277
Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
278278

279-
#if 1 //SLM: 0, register: 1
279+
#if 0 //SLM: 0, register: 1
280280
#if 1 //fragment register
281281
Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
282282
#else //common register
@@ -322,7 +322,7 @@ inline float dDequantizeNF4(unsigned char val) {
322322
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
323323
int prefetch_k = k_start_idx;
324324

325-
#if 0 //SLM
325+
#if 1 //SLM
326326
//alignas(16) ElementB* slm_B = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * (64 * 4) * k_tile_count;
327327
//const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * 1) * params.k/2;
328328
////using total_vec = 4*k_tile_count;
@@ -350,10 +350,10 @@ inline float dDequantizeNF4(unsigned char val) {
350350

351351
#pragma unroll
352352
for (int i = 0; i < vec_size; ++i) {
353-
//uint32_t src_value = reinterpret_cast<uint32_t*>(src)[i];
353+
uint32_t src_value = reinterpret_cast<uint32_t*>(src)[i];
354354
#pragma unroll
355355
for (int j = 0; j < compress_size; ++j) {
356-
uint8_t bit_value = (reinterpret_cast<uint32_t*>(src)[i] >> (4 * (((j+1) & 1) + (j >> 1) * 2))) & 0xF;
356+
uint8_t bit_value = (src_value >> (4 * (((j+1) & 1) + (j >> 1) * 2))) & 0xF;
357357
private_slm[i * compress_size + j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
358358
//dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
359359
}
@@ -439,7 +439,7 @@ inline float dDequantizeNF4(unsigned char val) {
439439
}
440440

441441
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
442-
#if 1 //SLM: 0, register: 1
442+
#if 0 //SLM: 0, register: 1
443443
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
444444
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) / k_reload_factor), frag_copy_Scale);
445445
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
@@ -486,8 +486,8 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
486486

487487
using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS>;
488488

489-
static constexpr int smem_size= (16) * sizeof(float);
490-
//static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementMMA) * 2 * 2; //aligned with 128B and will be reused for dequant src and dst.
489+
//static constexpr int smem_size= (16) * sizeof(float);
490+
static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementMMA) * 2 * 2; //aligned with 128B and will be reused for dequant src and dst.
491491
size_t max_slm_size = q.get_device().get_info<sycl::info::device::local_mem_size>();
492492
assert(smem_size <= max_slm_size);
493493

0 commit comments

Comments
 (0)