@@ -145,10 +145,16 @@ using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
145145using val_layout_load_B = decltype (make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{})));
146146using Copy_B = decltype (make_tiled_copy(atom_load_B{}, Layout<CopyThreadShape>{}, val_layout_load_B{}));
147147
148- using GmemTiledCopyScale = XE_2D_U16x1x32_LD_N; // XE_2D_U16x1x16_LD_N;
149- static constexpr auto SG_QNT_WIDTH = Int< SG_N >{};
148+ // using GmemTiledCopyScale = XE_2D_U16x1x32_LD_N;
149+ using GmemTiledCopyScale = XE_2D_U16x1x16_LD_N;
150150using StrideScale = cute::Stride<_1, int64_t , int64_t >; // dynamic stride
151151using traits_load_scale = Copy_Traits<GmemTiledCopyScale, StrideScale>;
152+ // using AtomLayout = Layout<
153+ // Shape<_16, _2>, // 匹配 XE_2D_U16x1x32_LD_N 的 BlockShape
154+ // Stride<_1, _16> // 连续存储,步长 16
155+ // >;
156+ // using atom_load_scale = Copy_Atom<traits_load_scale, ElementScale, AtomLayout>;
157+ // using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout<CopyThreadShapeRev>{}, AtomLayout{})); //group-wise scale
152158using atom_load_scale = Copy_Atom<traits_load_scale, ElementScale>;
153159using val_layout_load_scale = decltype (make_layout(shape_div(typename traits_load_scale::BlockShape{}, CopyThreadShapeRev{})));
154160using Copy_Scale = decltype (make_tiled_copy(atom_load_scale{}, Layout<CopyThreadShapeRev>{}, val_layout_load_scale{})); // group-wise scale
@@ -228,7 +234,7 @@ class kgemm_4bit_inference_cutlass_dequant {
228234 using SrcType = typename EngineIn::value_type;
229235 using DstType = typename EngineOut::value_type;
230236 using ScaleType = typename EngineScales::value_type;
231- #if 1
237+ #if 0
232238 int numbers = decltype(size(in))::value;
233239 for(int i=0; i<numbers; i++){
234240 //auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
@@ -281,7 +287,7 @@ class kgemm_4bit_inference_cutlass_dequant {
281287 for (int i = 0 ; i < vec_size; i++) {
282288 uint8_t value = (format_data >> (src_bits * i)) & 0xf ;
283289 dst[i] = static_cast <DstType>(quant_map[value] * static_cast <float >(ts));
284- if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
290+ // if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
285291 }
286292 }
287293 }
@@ -401,6 +407,7 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
401407 static constexpr auto scale_traits_size = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize; // SubgroupSize;
402408 static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
403409 using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
410+ // using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>, Stride<_1,_1,_0>>;
404411 Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
405412 if (cute::thread0 ()) printf (" scale_traits_size = %d, scale_traits_num = %d, SG_QNT_WIDTH = %d, BlockShape = %d, BlockShape_1= %d\n " , scale_traits_size, scale_traits_num, SG_QNT_WIDTH , decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value, decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value);
406413
@@ -412,7 +419,16 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
412419 Tensor frag_copy_A = thr_copy_A.retile_D (mma_A);
413420 Tensor frag_copy_B = thr_copy_B.retile_D (dequant_frag);
414421 Tensor frag_copy_Scale = thr_copy_scale.retile_D (fragment_scale);
415-
422+ // auto frag_layout = make_layout(
423+ // make_shape(_2{}, _1{}, _1{}), // 形状 (_2, _1, _1)
424+ // make_stride(_1{}, _1{}, _0{}) // 步长 (_1, _1, _0)
425+ // );
426+ // Tensor frag_copy_Scale = thr_copy_scale.retile_D(make_tensor(fragment_scale.data(), frag_layout));
427+
428+ // using FragLayout = Layout<Shape<_2,_1,_1>, Stride<_1,_1,_0>>;
429+ // Tensor fragment_scale = make_tensor<ElementScale>(FragLayout{});
430+ // Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
431+
416432// // Retile global counting tensors for copies:
417433 Tensor tAgA = thr_copy_A.retile_S (tCgA);
418434 Tensor tBgB = thr_copy_B.retile_S (tCgB);
@@ -441,8 +457,11 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
441457
442458 }();
443459
444- // using ExpectedLayout = typename decltype(tiled_copy_scale)::TiledLayout::dst_layout; //decltype(tiled_copy_scale.dst_layout()); //decltype(tiled_copy_scale.atom_layout_dst());
445- // static_assert(is_same<decltype(frag_copy_Scale.layout()), ExpectedLayout>::value, "布局不匹配");
460+ // auto copy_iter_s = [&](){
461+ // return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
462+ // make_layout(make_shape(Int<decltype(size<0>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, Int<decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, _1{}, k_tile_count),
463+ // make_stride(_16{}, _32{}, _0{}, _1{})));
464+ // }();
446465
447466#if 1
448467 #define PRINT (x ) print(#x " : " ); print(x); print(" \n " );
@@ -464,6 +483,7 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
464483
465484 print (" ===================== D :\n " );
466485 print (" tiled_copy_scale : " ); print (tiled_copy_scale); print (" \n " );
486+ print (" fragment_scale : " ); print (fragment_scale); print (" \n " );
467487 print (" frag_copy_Scale : " ); print (frag_copy_Scale); print (" \n " );
468488 print (" copy_iter_s: " ); print (copy_iter_s); print (" \n " );
469489
@@ -638,7 +658,6 @@ std::cout << std::endl;
638658 const int scale_k = cute::ceil_div (k, blocksize);
639659 StrideScale stride_S = cutlass::make_cute_packed_stride (StrideScale{}, cute::make_shape (n, scale_k, l));
640660 std::cout<<" n = " <<n<<" k = " <<k<<" blocksize = " <<blocksize<<" scale_k = " <<scale_k<<std::endl;
641-
642661 auto mScale = make_tensor (
643662 make_gmem_ptr (absmax_),
644663 make_layout (make_shape (n, scale_k, l), stride_S));
0 commit comments