@@ -59,6 +59,7 @@ using ElementOutput = float;
5959using ProblemShape = Shape<int , int , int , int >;
6060
6161using TileShape = Shape<_16, _64, _64>;
62+ using TileShape_half = Shape<_16, _64, _32>;
6263using TiledMma =
6364 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
6465 Layout<Shape<_1, _2, _1>, Stride<_2, _1, _0>>>::TiledMMA;
@@ -138,6 +139,7 @@ using Copy_A = decltype(make_tiled_copy(atom_load_A{}, Layout<CopyThreadShape>{}
138139
139140using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
140141using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::ColumnMajor>;
142+ // using StrideB = Stride<int64_t, int64_t, int64_t>;
141143// using Copy_B = typename Copy_Traits<GmemTiledCopyB, StrideB>::template DefaultTiledCopy<ElementB>;
142144using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
143145using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
@@ -234,7 +236,10 @@ class kgemm_4bit_inference_cutlass_dequant {
234236 // out[i] = static_cast<DstType>(quant_map[in_ptr_8[i].data()]);
235237 uint8_t value = in[i].get ();
236238 out[i] = static_cast <DstType>(quant_map[value]);
237- if (cute::thread0 ()) printf (" thread_idx = %d, i = %d, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n " ,int (ThreadIdxX ()), i, value, static_cast <int >(value), quant_map[value], static_cast <float >(out[i]));
239+ int thread_idx = int (ThreadIdxX ());
240+ // if(thread_idx == 0)
241+ if (syclcompat::global_id::x () == 2 && syclcompat::global_id::y () ==0 && syclcompat::global_id::z () ==0 )
242+ printf (" thread_idx = %d, i = %d, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n " ,thread_idx, i, value, static_cast <int >(value), quant_map[value], static_cast <float >(out[i]));
238243 }
239244#else
240245 static constexpr auto N = decltype(size<1>(in))::value;
@@ -330,7 +335,7 @@ if(cute::thread0())
330335 l_coord = BlockIdxZ ();
331336 }
332337 auto blk_coord_mnkl = make_coord (m_coord, n_coord, _, l_coord);
333- if (cute::thread0 ()) {
338+ if (0 ){ // cute::thread0()) {
334339 printf (" M = %d, N=%d, K=%d, L=%d\n " , M, N, K, L);
335340 // }
336341 printf (" thread_idx = %d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n " ,thread_idx, m_coord, n_coord, l_coord, BlockIdxX (), BlockIdxY (), BlockIdxZ ());
@@ -414,7 +419,7 @@ if(cute::thread0())
414419// make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1>{} * _1{})));
415420//
416421// }();
417-
422+ # if 0
418423 #define PRINT(x) print(#x ": "); print(x); print("\n");
419424 if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
420425 print("======================= A: \n");
@@ -442,7 +447,7 @@ if(cute::thread0())
442447 print(" pBgB : "); print(pBgB); print("\n");
443448 }
444449 #undef PRINT
445-
450+ # endif
446451 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (K));
447452 int prefetch_k = k_start_idx;
448453
@@ -466,7 +471,7 @@ if(cute::thread0())
466471 if (prefetch_k < k_tile_count) {
467472 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
468473 }
469- if (prefetch_k < k_tile_count / 2 ) {
474+ if (prefetch_k < k_tile_count) {
470475 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
471476 }
472477
@@ -517,14 +522,11 @@ if(cute::thread0())
517522};
518523
519524template <typename T, int BITS >
520- void gemm_4bit_inference_cutlass_dequant (int m, int n, int k_ , T *A, unsigned char *B,
525+ void gemm_4bit_inference_cutlass_dequant (int m, int n, int k , T *A, unsigned char *B,
521526 T *absmax_, float *datatype, float *out, int lda,
522527 int ldb, int ldc, int blocksize, sycl::queue *stream) {
523528 std::cout<<" this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n " ;
524529
525- int k = k_;
526-
527-
528530 sycl::queue q = *stream;
529531 using GemmKernel = kgemm_4bit_inference_cutlass_dequant<T, BITS >;
530532
@@ -555,7 +557,7 @@ int k = k_;
555557 auto mA_mkl = make_tensor (make_gmem_ptr (A), make_layout (make_shape (m, k, l), stride_A));
556558 Copy_A tiled_copy_a{Copy_A{}.with (mA_mkl )};
557559
558- StrideB stride_B = cutlass::make_cute_packed_stride (StrideB{}, cute::make_shape (n, k, l));
560+ // StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n/2 , k, l));
559561// auto stride_B_custom = cute::make_stride(
560562// cute::Int<1>{}, // 连续维度步幅(字节)
561563// (n * 4 + 7) / 8, // pitch = ceil(n * 4bit / 8bit)
@@ -568,9 +570,20 @@ int k = k_;
568570// (n * 4 ) / 8,
569571// (n * k * 4 ) / 8
570572// );
573+ // int k_half = k/2;
574+ // StrideB stride_B = make_stride(int64_t{1}, int64_t{n}, int64_t{n * k});
575+ StrideB stride_B = make_stride (int64_t {n}, cute::Int<1 >{}, int64_t {0 });
571576 auto mB_nkl = make_tensor (cute::subbyte_iterator<uint4_t >(B), make_layout (make_shape (n, k, l), stride_B));
572577 Copy_B tiled_copy_b{Copy_B{}.with (mB_nkl )};
573578
579+ #define PRINT (x ) print(#x " : " ); print(x); print(" \n " );
580+ if (cutlass::thread (LOG_THREAD , LOG_GROUP )) {
581+ print (" ===================== B :\n " );
582+ print (" stride_B : " ); print (stride_B); print (" \n " );
583+ print (" ===================== B :\n " );
584+ }
585+ #undef PRINT
586+
574587 params.tiled_copy_a = tiled_copy_a;
575588 params.tiled_copy_b = tiled_copy_b;
576589
0 commit comments