@@ -224,7 +224,7 @@ class gemm_4bit_cutlass_kernel {
224224 auto thr_mma = tiled_mma.get_slice (first_thread_in_sg_idx);
225225
226226 Tensor tCgA = thr_mma.partition_A (gA );
227- Tensor tCgB = thr_mma.partition_B (gB );
227+ Tensor tCgB = thr_mma.partition_B (gB ); // values for each_thread (FrgV,(RestN,RestK),*)
228228
229229 Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
230230 Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
@@ -366,7 +366,58 @@ auto dequant = [&] {
366366 };
367367#endif
368368 #endif
369-
369+ #if 0
370+ if (cute::thread0()){ //thread_idx==0 && n_coord == 0 && l_coord==0) {
371+ print("\n\n======================= A: \n");
372+ print(" gA : "); print(gA); print("\n");
373+ print(" tCgA : "); print(tCgA); print("\n");
374+ print(" tAgA : "); print(tAgA); print("\n");
375+ print(" mma_A : "); print(mma_A); print("\n");
376+ print(" frag_copy_A : "); print(frag_copy_A); print("\n");
377+
378+ print("===================== B :\n");
379+ print(" gB : "); print(gB); print("\n");
380+ print(" tCgB : "); print(tCgB); print("\n");
381+ print(" tBgB : "); print(tBgB); print("\n");
382+ print(" mma_B : "); print(mma_B); print("\n");
383+ //print(" frag_copy_B : "); print(frag_copy_B); print("\n");
384+ //print(" dequant_frag : "); print(dequant_frag); print("\n");
385+
386+ print("===================== Scale :\n");
387+ print(" tiled_copy_scale : "); print(params.tiled_copy_scale); print("\n");
388+ print(" fragment_scale : "); print(fragment_scale); print("\n");
389+ print(" frag_copy_Scale : "); print(frag_copy_Scale); print("\n");
390+ print(" tSgS : "); print(tSgS); print("\n");
391+
392+ print("===================== D :\n");
393+ print(" accumulators : "); print(accumulators); print("\n");
394+
395+ print("===================== Config: \n");
396+ print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n");
397+ print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n");
398+
399+ print("===================== Config: \n");
400+ print(" tiled_mma : "); print(tiled_mma); print("\n");
401+
402+ print("===================== Config: \n");
403+ print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n");
404+
405+ print("===================== Config: \n");
406+ print(" thr_mma : "); print(thr_mma); print("\n");
407+
408+ print("===================== Config: \n");
409+ print(" tiled_prefetch_a : "); print(tiled_prefetch_a); print("\n");
410+
411+ print("===================== Config: \n");
412+ print(" tiled_prefetch_b : "); print(tiled_prefetch_b); print("\n");
413+
414+ print("===================== Config: \n");
415+ print(" pAgA : "); print(pAgA); print("\n");
416+
417+ print("===================== Config: \n");
418+ print(" pBgB : "); print(pBgB); print("\n\n\n");
419+ }
420+ #endif
370421
371422 CUTLASS_PRAGMA_UNROLL
372423 for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
@@ -378,10 +429,14 @@ auto dequant = [&] {
378429 // copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
379430 copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) / k_reload_factor), frag_copy_Scale);
380431 // barrier_wait(3);
381- const uint8_t * gB_ptr = params.B + cute::get<0 >(gB .layout ()(make_coord (n_coord, k_tile, 0 )));
382432auto dequant = [&] {
383433 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
384434 constexpr int K = decltype (cute::size (mma_B))::value / N;
435+ const uint8_t * gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k /2 + k_tile * BLK_K /2 ;
436+ // if(thread_idx==8 && int(BlockIdxX())==0 && int(BlockIdxY())==0 && int(BlockIdxZ())==0){
437+ // printf("BLK_N = %d, BLK_K = %d, thread_idx = %d, N = %d, params.k = %d, params.B = %x, n_coord = %d, k_tile = %d, gB_ptr = %x\n",static_cast<int>(BLK_N), static_cast<int>(BLK_K), thread_idx, N, params.k, params.B, n_coord, k_tile, gB_ptr);
438+ // print(" gB_ptr: "); print(gB_ptr); print("\n");
439+ // }
385440
386441 using compress_type = uint32_t ;
387442 constexpr int compress_size = 32 / cute::sizeof_bits_v<ElementB>;
@@ -392,13 +447,14 @@ auto dequant = [&] {
392447 constexpr int ELEMS_PER_BANK = (ELEMS_PER_THREAD + BANK_NUM - 1 ) / BANK_NUM ; // 2
393448
394449 // const ElementB* gB_ptr = params.B + gB(n_coord, idx2crd(k_tile, make_shape(params.k)), 0).offset();
395- ElementB* slm_B = reinterpret_cast <ElementB*>(smem_buf) + thread_idx * 512 ;
450+ alignas ( 16 ) ElementB* slm_B = reinterpret_cast <ElementB*>(smem_buf) + thread_idx * 64 * 4 ; // 512 ;
396451 *reinterpret_cast <sycl::vec<uint64_t , 4 >*>(slm_B) = *reinterpret_cast <const sycl::vec<uint64_t , 4 >*>(gB_ptr );
397452
398453 compress_type src[vec_size];
399454 *reinterpret_cast <sycl::vec<compress_type, vec_size>*>(src) = *reinterpret_cast <const sycl::vec<compress_type, vec_size>*>(slm_B);
400455
401- ElementMMA* private_slm = reinterpret_cast <ElementMMA*>(slm_B) + thread_idx * ELEMS_PER_THREAD ; // 每个线程一段 **连续** 128 B,天然 128 B 对齐
456+ // ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(slm_B) + thread_idx * ELEMS_PER_THREAD; // 每个线程一段 **连续** 128 B,天然 128 B 对齐
457+ ElementMMA dst[K];
402458
403459 float scale_value = fragment_scale (0 );
404460
@@ -408,12 +464,13 @@ auto dequant = [&] {
408464 for (int j = 0 ; j < compress_size; ++j) {
409465 uint8_t bit_value = (src[i] >> (4 * (((j+1 ) & 1 ) + (j >> 1 ) * 2 ))) & 0xF ;
410466 // uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
411- private_slm[i * compress_size + j] =
412- static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
467+ // private_slm[i * compress_size + j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
468+ dst[i*compress_size+j] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
413469 }
414470 }
415471
416- *reinterpret_cast <sycl::vec<int64_t , 16 >*>(cute::raw_pointer_cast (mma_B.data ())) = *reinterpret_cast <const sycl::vec<int64_t , 16 >*>(private_slm);
472+ // *reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data())) = *reinterpret_cast<const sycl::vec<int64_t, 16>*>(private_slm);
473+ reinterpret_cast <sycl::vec<int64_t , 16 >*>(cute::raw_pointer_cast (mma_B.data ()))[0 ] = reinterpret_cast <sycl::vec<int64_t , 16 >*>(dst)[0 ];
417474};
418475 dequant ();
419476 copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
@@ -455,7 +512,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
455512 using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS >;
456513
457514 // static constexpr int smem_size= BLK_N * BLK_K * 16/8; //(16+1)*32/8;
458- static constexpr int smem_size = BLK_N * BLK_K * sizeof (ElementB) + BLK_N * BLK_K * sizeof (ElementMMA);
515+ static constexpr int smem_size = BLK_N * BLK_K * sizeof (ElementB) * 4 ; // + BLK_N * BLK_K * sizeof(ElementMMA);
459516 size_t max_slm_size = q.get_device ().get_info <sycl::info::device::local_mem_size>();
460517 assert (smem_size <= max_slm_size);
461518
0 commit comments