@@ -59,9 +59,8 @@ using ElementOutput = float;
5959using ProblemShape = Shape<int , int , int , int >;
6060
6161using TileShape = Shape<_16, _64, _64>;
62- using TileShape_half = Shape<_16, _64, _32>;
6362using TiledMma =
64- typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT >, Layout<TileShape>,
63+ typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT >, Layout<TileShape>,
6564 Layout<Shape<_1, _2, _1>, Stride<_2, _1, _0>>>::TiledMMA;
6665
6766using WorkgroupTileShape = TileShape;
@@ -237,9 +236,10 @@ class kgemm_4bit_inference_cutlass_dequant {
237236 uint8_t value = in[i].get ();
238237 out[i] = static_cast <DstType>(quant_map[value]);
239238 int thread_idx = int (ThreadIdxX ());
240- // if(thread_idx == 0)
241- if (syclcompat::global_id::x () == 2 && syclcompat::global_id::y () ==0 && syclcompat::global_id::z () ==0 )
242- printf (" thread_idx = %d, i = %d, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n " ,thread_idx, i, value, static_cast <int >(value), quant_map[value], static_cast <float >(out[i]));
239+ if (cute::thread0 ()){
240+ // if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
241+ // printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
242+ }
243243 }
244244#else
245245 static constexpr auto N = decltype(size<1>(in))::value;
@@ -419,7 +419,7 @@ if(cute::thread0())
419419// make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1>{} * _1{})));
420420//
421421// }();
422- #if 0
422+ #if 1
423423 #define PRINT (x ) print(#x " : " ); print(x); print(" \n " );
424424 if (cutlass::thread (LOG_THREAD , LOG_GROUP )) {
425425 print (" ======================= A: \n " );
@@ -437,6 +437,9 @@ if(cute::thread0())
437437 print (" frag_copy_B : " ); print (frag_copy_B); print (" \n " );
438438 print (" dequant_frag : " ); print (dequant_frag); print (" \n " );
439439
440+ print (" ===================== D :\n " );
441+ print (" accumulators : " ); print (accumulators); print (" \n " );
442+
440443 print (" ===================== Config: \n " );
441444 print (" threads per workgroup : " ); print (MaxThreadsPerBlock); print (" \n " );
442445 print (" SubgroupTileShape : " ); print (SubgroupTileShape{}); print (" \n " );
@@ -456,7 +459,7 @@ if(cute::thread0())
456459 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
457460 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
458461 }
459-
462+ // k_tile_count=1;
460463 for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) {
461464 barrier_arrive (2 );
462465
@@ -477,39 +480,55 @@ if(cute::thread0())
477480
478481 dequant (dequant_frag, mma_B, /* fragment_scale,*/ quant_map);
479482
483+ // barrier_wait(1);
484+
480485 cute::gemm (tiled_mma, mma_A, mma_B, accumulators);
486+ barrier_wait (2 );
487+ #if 0
488+ // Add print logic before and after the gemm call
489+ auto debug_print = [&](const char* name, auto& tensor) {
490+ int numbers = decltype(size(tensor))::value;
491+ printf("\n----- %s ----- numbers = %d\n", name, numbers);
492+ for (int i = 0; i < numbers; ++i) {
493+ printf("%s[%d] = %6.2f\n", name, i , static_cast<float>(tensor[i]));
494+ }
495+ printf("\n\n");
496+ barrier_wait(1);
497+ };
481498
482- // // 在调用gemm前后添加打印逻辑
483- // auto debug_print = [&](const char* name, auto& tensor) {
484- // if (thread_idx == 0) {
485- // printf("----- %s -----\n", name);
486- // for (int i = 0; i < size<0>(tensor); ++i) {
487- // for (int j = 0; j < size<1>(tensor); ++j) {
488- // printf("%6.2f ", static_cast<float>(tensor(i, j)));
489- // }
490- // printf("\n");
491- // }
492- // }
493- // barrier_wait(2);
494- // };
495- //
496- // // 打印输入
497- // debug_print("Input A (mma_A)", mma_A);
498- // debug_print("Input B (mma_B)", mma_B);
499- // debug_print("Accumulators (Before GEMM)", accumulators);
500- //
501- // // 执行GEMM
502- // cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
503- //
504- // // 打印输出
505- // debug_print("Accumulators (After GEMM)", accumulators);
499+ if (cute::thread0()) {
500+ // Print the inputs
501+ debug_print("Input A (mma_A)", mma_A);
502+ barrier_wait(1);
503+ debug_print("Input B (mma_B)", mma_B);
504+ barrier_wait(1);
505+ debug_print("Accumulators (Before GEMM)", accumulators);
506+ barrier_wait(1);
507+ }
508+ // Run the GEMM
509+ cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
506510
507- barrier_wait (2 );
511+ if (cute::thread0()) {
512+ // Print the output
513+ debug_print("Accumulators (After GEMM)", accumulators);
514+
515+ barrier_wait(2);
516+ }
517+ #endif
518+ #if 0
519+ cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
520+ barrier_wait(2);
521+
522+ for (int i = 0; i < accumulators.size(); ++i) {
523+ printf("Thread (%d, %d): accumulators[%d] =%f\n", syclcompat::global_id::x() , syclcompat::global_id::y(), i, static_cast<float>(accumulators[i]));
524+ }
525+ printf("\n");
526+ #endif
508527 }
509528
510529 SharedStorage& shared_storage = *reinterpret_cast <SharedStorage*>((char *)nullptr );
511530 CollectiveEpilogue epilogue{params.epilogue , shared_storage.epilogue };
512- auto problem_shape_MNKL = problem_size; // append<4>(problem_size, 1);
531+ auto problem_shape_MNKL = append<4 >(problem_size, 1 );
513532 epilogue (
514533 problem_shape_MNKL,
515534 subgroup_tile_shape,
@@ -573,7 +592,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
573592 // int k_half = k/2;
574593 // StrideB stride_B = make_stride(int64_t{1}, int64_t{n}, int64_t{n * k});
575594 StrideB stride_B = make_stride (int64_t {n}, cute::Int<1 >{}, int64_t {0 });
576- auto mB_nkl = make_tensor (cute::subbyte_iterator<uint4_t >(B), make_layout (make_shape (n, k, l), stride_B));
595+ auto mB_nkl = make_tensor (cute::subbyte_iterator<ElementB >(B), make_layout (make_shape (n, k, l), stride_B));
577596 Copy_B tiled_copy_b{Copy_B{}.with (mB_nkl )};
578597
579598 #define PRINT (x ) print(#x " : " ); print(x); print(" \n " );
0 commit comments