@@ -61,7 +61,7 @@ static constexpr float quant_map_static[16] = {
6161};
6262#endif
6363
64- using TileShape = Shape<_32, _256 , _64>;
64+ using TileShape = Shape<_32, _128 , _64>;
6565using TiledMma =
6666 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6767 Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
@@ -231,7 +231,7 @@ inline float dDequantizeNF4(unsigned char val) {
231231 ? BlockIdxX () : BlockIdxY ();
232232 const int l_coord = BlockIdxZ ();
233233
234- #if 1
234+ #if 0
235235 //float* quant_map;
236236 //static constexpr std::array<float, 16> quant_map{};
237237 // {
@@ -276,7 +276,7 @@ inline float dDequantizeNF4(unsigned char val) {
276276 Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
277277 Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
278278
279- #if 1 // SLM: 0, register: 1
279+ #if 0 //SLM: 0, register: 1
280280 #if 1 //fragment register
281281 Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
282282 #else //common register
@@ -322,7 +322,7 @@ inline float dDequantizeNF4(unsigned char val) {
322322 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
323323 int prefetch_k = k_start_idx;
324324
325- #if 0 //SLM
325+ #if 1 // SLM
326326 // alignas(16) ElementB* slm_B = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * (64 * 4) * k_tile_count;
327327 // const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * 1) * params.k/2;
328328 // //using total_vec = 4*k_tile_count;
@@ -350,10 +350,10 @@ inline float dDequantizeNF4(unsigned char val) {
350350
351351 #pragma unroll
352352 for (int i = 0 ; i < vec_size; ++i) {
353- // uint32_t src_value = reinterpret_cast<uint32_t*>(src)[i];
353+ uint32_t src_value = reinterpret_cast <uint32_t *>(src)[i];
354354 #pragma unroll
355355 for (int j = 0 ; j < compress_size; ++j) {
356- uint8_t bit_value = (reinterpret_cast<uint32_t*>(src)[i] >> (4 * (((j+1) & 1) + (j >> 1) * 2))) & 0xF;
356+ uint8_t bit_value = (src_value >> (4 * (((j+1 ) & 1 ) + (j >> 1 ) * 2 ))) & 0xF ;
357357 private_slm[i * compress_size + j] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
358358 // dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
359359 }
@@ -439,7 +439,7 @@ inline float dDequantizeNF4(unsigned char val) {
439439 }
440440
441441 for (int k_tile = k_start_idx, k_s = 0 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
442- #if 1 // SLM: 0, register: 1
442+ #if 0 //SLM: 0, register: 1
443443 copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
444444 copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) / k_reload_factor), frag_copy_Scale);
445445 copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
@@ -486,8 +486,8 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
486486
487487 using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS >;
488488
489- static constexpr int smem_size= (16 ) * sizeof (float );
490- // static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementMMA) * 2 * 2; //aligned with 128B and will be reused for dequant src and dst.
489+ // static constexpr int smem_size= (16) * sizeof(float);
490+ static constexpr int smem_size = BLK_N * BLK_K * sizeof (ElementMMA) * 2 * 2 ; // aligned with 128B and will be reused for dequant src and dst.
491491 size_t max_slm_size = q.get_device ().get_info <sycl::info::device::local_mem_size>();
492492 assert (smem_size <= max_slm_size);
493493
0 commit comments