@@ -145,6 +145,7 @@ using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
// Value layout used when loading B: a layout built from the load atom's
// BlockShape divided (shape_div) by CopyThreadShape.
145145using val_layout_load_B = decltype (make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{})));
// Tiled-copy type for B, assembled from the B load atom, a Layout over
// CopyThreadShape, and the value layout declared just above.
146146using Copy_B = decltype (make_tiled_copy(atom_load_B{}, Layout<CopyThreadShape>{}, val_layout_load_B{}));
147147
// Disabled alternative for GmemTiledCopyScale (the 1x32 variant), kept for
// reference only; the alias that is actually in effect is the next line.
148+ // using GmemTiledCopyScale = XE_2D_U16x1x32_LD_N; //XE_2D_U16x1x16_LD_N;
// Active copy operation for GmemTiledCopyScale.
// NOTE(review): presumably the global-memory load used for the scale tensor — confirm against its users.
148149using GmemTiledCopyScale = XE_2D_U16x1x16_LD_N;
// Compile-time integral constant whose value is SG_N.
149150static constexpr auto SG_QNT_WIDTH = Int<SG_N >{};
// Three-mode stride type: the first mode is the static unit stride `_1`; the
// remaining two modes are runtime int64_t values.
150151using StrideScale = cute::Stride<_1, int64_t , int64_t >; // dynamic stride
@@ -171,7 +172,7 @@ class kgemm_4bit_inference_cutlass_dequant {
171172 T* A;
172173 uint8_t * B;
173174 float * out;
174- T *absmax;
175+ // T *absmax;
175176 float *datatype; // LUT
176177 int group_size;
177178
@@ -228,7 +229,7 @@ class kgemm_4bit_inference_cutlass_dequant {
228229 using SrcType = typename EngineIn::value_type;
229230 using DstType = typename EngineOut::value_type;
230231 using ScaleType = typename EngineScales::value_type;
231- #if 0
232+ #if 1
232233 int numbers = decltype (size (in))::value;
233234 for (int i=0 ; i<numbers; i++){
234235 // auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
@@ -240,6 +241,11 @@ class kgemm_4bit_inference_cutlass_dequant {
240241 // if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
241242 // printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
242243 }
244+ int scale_number = decltype (size (tCrS_input))::value;
245+ for (int i=0 ; i<scale_number; i++){
246+ auto s_value = tCrS_input[i];
247+ if (cute::thread0 ()) printf (" scale_number = %d, tCrS_input[%d] = %f\n " ,scale_number, i, static_cast <float >(s_value));
248+ }
243249 }
244250#else
245251 static constexpr auto N = decltype(size<1>(in))::value;
@@ -275,8 +281,8 @@ class kgemm_4bit_inference_cutlass_dequant {
275281
276282 for (int i = 0; i < vec_size; i++) {
277283 uint8_t value = (format_data >> (src_bits * i)) & 0xf;
278- dst[i] = ( static_cast <DstType>(quant_map[value])); // * ts ;
279- // if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, dst = %f\n", n, s, i, static_cast<int>(value), static_cast<float>(dst[i]));
284+ dst[i] = static_cast<DstType>(quant_map[value] * static_cast<float>(ts)) ;
285+ if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts ), static_cast<float>(dst[i]));
280286 }
281287 }
282288 }
@@ -334,11 +340,9 @@ class kgemm_4bit_inference_cutlass_dequant {
334340 l_coord = BlockIdxZ ();
335341 }
336342 auto blk_coord_mnkl = make_coord (m_coord, n_coord, _, l_coord);
337- if (0 ){ // cute::thread0()) {
343+ if (cute::thread0 ()) {
338344 printf (" M = %d, N=%d, K=%d, L=%d\n " , M, N, K, L);
339- // }
340345 printf (" thread_idx = %d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n " ,thread_idx, m_coord, n_coord, l_coord, BlockIdxX (), BlockIdxY (), BlockIdxZ ());
341-
342346 }
343347 constexpr auto workgroup_shape = WorkgroupTileShape{}; // 256, 256, 32
344348 constexpr auto subgroup_tile_shape = SubgroupTileShape{}; // 32, 64, 32 (number of atom level workgroup: 256/8=32, 256/4=64, 32/2=32)
@@ -383,7 +387,8 @@ class kgemm_4bit_inference_cutlass_dequant {
383387 static constexpr auto scale_traits_num = SG_QNT_WIDTH / size<1 >(typename GmemTiledCopyScale::BlockShape{});
384388 using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
385389 Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
386-
390+ if (cute::thread0 ()) printf (" scale_traits_size = %d, scale_traits_num = %d, SG_QNT_WIDTH = %d\n " , scale_traits_size, scale_traits_num, SG_QNT_WIDTH );
391+
387392 static_assert (std::is_same_v<typename decltype (dequant_frag)::value_type, ElementQuant>);
388393 static_assert (std::is_same_v<typename decltype (mma_A)::value_type, ElementMMA>);
389394 static_assert (std::is_same_v<typename decltype (mma_B)::value_type, ElementMMA>);
@@ -412,6 +417,8 @@ class kgemm_4bit_inference_cutlass_dequant {
412417 const int n_coord_s = n_idx * BLK_N + (get_sub_group_id () % ATOM_N ) * SG_N ;
413418 const int l_coord_s = l_idx;
414419
420+ if (cute::thread0 ()) printf (" m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n " ,m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
421+
415422 auto copy_iter_s = [&](){
416423 return make_tensor (make_inttuple_iter (make_coord (n_coord_s, 0 , l_coord_s)),
417424 make_layout (make_shape (Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
@@ -436,6 +443,10 @@ class kgemm_4bit_inference_cutlass_dequant {
436443 print (" frag_copy_B : " ); print (frag_copy_B); print (" \n " );
437444 print (" dequant_frag : " ); print (dequant_frag); print (" \n " );
438445
446+ print (" ===================== D :\n " );
447+ print (" frag_copy_ScaleB : " ); print (frag_copy_Scale); print (" \n " );
448+ print (" copy_iter_s: " ); print (copy_iter_s); print (" \n " );
449+
439450 print (" ===================== D :\n " );
440451 print (" accumulators : " ); print (accumulators); print (" \n " );
441452
@@ -468,6 +479,8 @@ class kgemm_4bit_inference_cutlass_dequant {
468479
469480 const int k_reload_factor = params.group_size / BLK_K ;
470481
482+ if (cute::thread0 ()) printf (" params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n " ,params.group_size , BLK_K , k_reload_factor);
483+
471484 copy (tiled_copy_scale, copy_iter_s (_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
472485
473486 if (prefetch_k < k_tile_count) {
@@ -574,7 +587,6 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
574587 auto mA_mkl = make_tensor (make_gmem_ptr (A), make_layout (make_shape (m, k, l), stride_A));
575588 Copy_A tiled_copy_a{Copy_A{}.with (mA_mkl )};
576589
577- // StrideB stride_B = make_stride(int64_t{n}, cute::Int<1>{}, int64_t{0});
578590 StrideB stride_B = cutlass::make_cute_packed_stride (StrideB{}, cute::make_shape (n, k, l));
579591 auto mB_nkl = make_tensor (cute::subbyte_iterator<ElementB>(B), make_layout (make_shape (n, k, l), stride_B));
580592 Copy_B tiled_copy_b{Copy_B{}.with (mB_nkl )};
@@ -595,7 +607,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
595607 std::cout<<" n = " <<n<<" k = " <<k<<" blocksize = " <<blocksize<<" scale_k = " <<scale_k<<std::endl;
596608
597609 auto mScale = make_tensor (
598- make_gmem_ptr (absmax_), // static_cast <ElementScale *>(absmax )),
610+ make_gmem_ptr (reinterpret_cast <ElementScale *>(absmax_ )),
599611 make_layout (make_shape (n, scale_k, l), stride_S));
600612 Copy_Scale tiled_copy_scale = {Copy_Scale{}.with (mScale )};
601613
0 commit comments