@@ -234,23 +234,29 @@ class kgemm_4bit_inference_cutlass_dequant {
234234 using DstType = typename EngineOut::value_type;
235235 using ScaleType = typename EngineScales::value_type;
236236#if 0
237- int numbers = decltype(size(in))::value;
238- for(int i=0; i<numbers; i++){
239- //auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
240- //out[i] = static_cast<DstType>(quant_map[in_ptr_8[i].data()]);
241- uint8_t value = in[i].get();
242- out[i] = static_cast<DstType>(quant_map[value]);
243- int thread_idx = int(ThreadIdxX());
244- if(cute::thread0()){
245- //if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
246- //printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
247- }
248- }
249- int scale_number = decltype(size(tCrS_input))::value;
250- for(int i=0; i<scale_number; i++){
237+ static constexpr auto N = decltype(size<1>(in))::value;
238+ static constexpr auto loop_cnt = decltype(size(out))::value / N;
239+ for (int n = 0; n < N; n++) {
251240 auto s_value = tCrS_input(i);
252- if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
253- }
241+ for (int l = 0; s < loop_cnt; l++) {
242+
243+ // int numbers = decltype(size(in))::value;
244+ // for(int i=0; i<numbers / N; i++){
245+ // //auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
246+ // //out[i] = static_cast<DstType>(quant_map[in_ptr_8[i].data()]);
247+ // uint8_t value = in[i].get();
248+ // out[i] = static_cast<DstType>(quant_map[value]);
249+ // int thread_idx = int(ThreadIdxX());
250+ // if(cute::thread0()){
251+ // //if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
252+ // //printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
253+ // }
254+ // }
255+ // int scale_number = decltype(size(tCrS_input))::value;
256+ // for(int i=0; i<scale_number; i++){
257+ // auto s_value = tCrS_input(i);
258+ // if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
259+ // }
254260#else
255261 static constexpr auto N = decltype (size<1 >(in))::value;
256262
@@ -269,7 +275,11 @@ class kgemm_4bit_inference_cutlass_dequant {
269275 auto s_tensor = make_tensor ((format_type*)(raw_pointer_cast (in.data ())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
270276 auto d_tensor = make_tensor (out.data (), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
271277
272- // if(cute::thread0())
278+ int scale_number = decltype (size (tCrS_input))::value;
279+ for (int i=0 ; i<scale_number; i++){
280+ auto s_value = tCrS_input (i);
281+ if (cute::thread0 ()) printf (" scale_number = %d, tCrS_input[%d] = %f\n " ,scale_number, i, static_cast <float >(s_value));
282+ }
273283// printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
274284
275285 for (int n = 0 ; n < N; n++) {
@@ -285,8 +295,13 @@ class kgemm_4bit_inference_cutlass_dequant {
285295
286296 for (int i = 0 ; i < vec_size; i++) {
287297 uint8_t value = (format_data >> (src_bits * i)) & 0xf ;
288- dst[i] = static_cast <DstType>(quant_map[value] * static_cast <float >(ts));
289- // if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
298+ if (i % 2 != 0 ) { // 1,3, high_4bit
299+ dst[i-1 ] = static_cast <DstType>(quant_map[value] * static_cast <float >(ts));
300+ } else {
301+ dst[i+1 ] = static_cast <DstType>(quant_map[value] * static_cast <float >(ts));
302+ }
303+ if (cute::thread0 ())
304+ printf (" tid = %d, n = %d, s = %d, i = %d, format_data = %d, value = %d, quant_map[value] = %f, ts = %f, dst = %f\n " ,ThreadIdxX (), n, s, i, static_cast <int >(format_data), static_cast <int >(value), quant_map[value], static_cast <float >(ts), static_cast <float >(dst[i]));
290305 }
291306 }
292307 }
@@ -500,29 +515,38 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
500515 }
501516 #undef PRINT
502517#endif
503- const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (K));
518+ const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (K));
504519 int prefetch_k = k_start_idx;
505520
521+ #if 1
522+ const int k_reload_factor = ceil_div (params.group_size , BLK_K );
523+ if (cute::thread0 ()) printf (" params.group_size = %d, BLK_K = %d, k_reload_factor = %f\n " ,params.group_size , BLK_K , k_reload_factor);
524+ #endif
506525 CUTLASS_PRAGMA_UNROLL
507526 for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
508527 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
509528 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
510529 }
511530
512- for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) {
531+ for (int k_tile = k_start_idx, k_s = 0 ; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++, k_s ++) {
513532 barrier_arrive (2 );
514533
515534 // Copy gmem to rmem for the first k_tile
516535 copy (tiled_copy_a, tAgA (_,_,_,k_tile), frag_copy_A);
517536 copy (tiled_copy_b, tBgB (_,_,_,k_tile), frag_copy_B);
518-
537+ #if 1
538+ const int s_step = k_start_idx + (k_s / k_reload_factor); // 1 + k_tile / k_reload_factor;
539+ if (cute::thread0 ()) printf (" k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n " ,k_start_idx, k_s, k_reload_factor, s_step);
540+ copy (tiled_copy_scale, copy_iter_s (_, _, _, s_step), frag_copy_Scale);
541+ #else
519542 const int k_reload_factor = ceil_div(params.group_size, BLK_K);
520543 //const int k_reload_factor = params.group_size / BLK_K;
521544
522- if (cute::thread0 ()) printf (" params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n " ,params.group_size , BLK_K , k_reload_factor);
545+ //if(cute::thread0())
546+ printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, BLK_K, k_reload_factor);
523547
524548 copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
525-
549+ # endif
526550 if (prefetch_k < k_tile_count) {
527551 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
528552 }
@@ -563,12 +587,10 @@ if (cute::thread0()) {
563587// 打印输出
564588debug_print("Accumulators (After GEMM)", accumulators);
565589
566- barrier_wait(2);
567590}
568591#endif
569592#if 0
570593cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
571- barrier_wait(2);
572594
573595for (int i = 0; i < accumulators.size(); ++i) {
574596 printf("Thread (%d, %d): accumulators[%d] =%f\n", syclcompat::global_id::x() , syclcompat::global_id::y(), i, static_cast<float>(accumulators[i]));
0 commit comments