Skip to content

Commit da1df6e

Browse files
committed
enable scaling
1 parent d238a6a commit da1df6e

2 files changed

Lines changed: 29 additions & 10 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,16 @@ using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
145145
using val_layout_load_B = decltype(make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{})));
146146
using Copy_B = decltype(make_tiled_copy(atom_load_B{}, Layout<CopyThreadShape>{}, val_layout_load_B{}));
147147

148-
using GmemTiledCopyScale = XE_2D_U16x1x32_LD_N; //XE_2D_U16x1x16_LD_N;
149-
static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
148+
//using GmemTiledCopyScale = XE_2D_U16x1x32_LD_N;
149+
using GmemTiledCopyScale = XE_2D_U16x1x16_LD_N;
150150
using StrideScale = cute::Stride<_1, int64_t, int64_t>; //dynamic stride
151151
using traits_load_scale = Copy_Traits<GmemTiledCopyScale, StrideScale>;
152+
//using AtomLayout = Layout<
153+
// Shape<_16, _2>, // matches the BlockShape of XE_2D_U16x1x32_LD_N
154+
// Stride<_1, _16> // contiguous storage, stride 16
155+
//>;
156+
//using atom_load_scale = Copy_Atom<traits_load_scale, ElementScale, AtomLayout>;
157+
//using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout<CopyThreadShapeRev>{}, AtomLayout{})); //group-wise scale
152158
using atom_load_scale = Copy_Atom<traits_load_scale, ElementScale>;
153159
using val_layout_load_scale = decltype(make_layout(shape_div(typename traits_load_scale::BlockShape{}, CopyThreadShapeRev{})));
154160
using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout<CopyThreadShapeRev>{}, val_layout_load_scale{})); //group-wise scale
@@ -228,7 +234,7 @@ class kgemm_4bit_inference_cutlass_dequant {
228234
using SrcType = typename EngineIn::value_type;
229235
using DstType = typename EngineOut::value_type;
230236
using ScaleType = typename EngineScales::value_type;
231-
#if 1
237+
#if 0
232238
int numbers = decltype(size(in))::value;
233239
for(int i=0; i<numbers; i++){
234240
//auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
@@ -281,7 +287,7 @@ class kgemm_4bit_inference_cutlass_dequant {
281287
for (int i = 0; i < vec_size; i++) {
282288
uint8_t value = (format_data >> (src_bits * i)) & 0xf;
283289
dst[i] = static_cast<DstType>(quant_map[value] * static_cast<float>(ts));
284-
if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
290+
//if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
285291
}
286292
}
287293
}
@@ -401,6 +407,7 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
401407
static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize; //SubgroupSize;
402408
static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
403409
using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
410+
//using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>, Stride<_1,_1,_0>>;
404411
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
405412
if(cute::thread0()) printf("scale_traits_size = %d, scale_traits_num = %d, SG_QNT_WIDTH = %d, BlockShape = %d, BlockShape_1= %d\n", scale_traits_size, scale_traits_num, SG_QNT_WIDTH, decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value, decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value);
406413

@@ -412,7 +419,16 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
412419
Tensor frag_copy_A = thr_copy_A.retile_D(mma_A);
413420
Tensor frag_copy_B = thr_copy_B.retile_D(dequant_frag);
414421
Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
415-
422+
//auto frag_layout = make_layout(
423+
// make_shape(_2{}, _1{}, _1{}), // shape (_2, _1, _1)
424+
// make_stride(_1{}, _1{}, _0{}) // stride (_1, _1, _0)
425+
//);
426+
//Tensor frag_copy_Scale = thr_copy_scale.retile_D(make_tensor(fragment_scale.data(), frag_layout));
427+
428+
//using FragLayout = Layout<Shape<_2,_1,_1>, Stride<_1,_1,_0>>;
429+
//Tensor fragment_scale = make_tensor<ElementScale>(FragLayout{});
430+
//Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
431+
416432
//// Retile global counting tensors for copies:
417433
Tensor tAgA = thr_copy_A.retile_S(tCgA);
418434
Tensor tBgB = thr_copy_B.retile_S(tCgB);
@@ -441,8 +457,11 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
441457

442458
}();
443459

444-
//using ExpectedLayout = typename decltype(tiled_copy_scale)::TiledLayout::dst_layout; //decltype(tiled_copy_scale.dst_layout()); //decltype(tiled_copy_scale.atom_layout_dst());
445-
//static_assert(is_same<decltype(frag_copy_Scale.layout()), ExpectedLayout>::value, "布局不匹配");
460+
//auto copy_iter_s = [&](){
461+
// return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
462+
// make_layout(make_shape(Int<decltype(size<0>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, Int<decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, _1{}, k_tile_count),
463+
// make_stride(_16{}, _32{}, _0{}, _1{})));
464+
//}();
446465

447466
#if 1
448467
#define PRINT(x) print(#x ": "); print(x); print("\n");
@@ -464,6 +483,7 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
464483

465484
print("===================== D :\n");
466485
print(" tiled_copy_scale : "); print(tiled_copy_scale); print("\n");
486+
print(" fragment_scale : "); print(fragment_scale); print("\n");
467487
print(" frag_copy_Scale : "); print(frag_copy_Scale); print("\n");
468488
print(" copy_iter_s: "); print(copy_iter_s); print("\n");
469489

@@ -638,7 +658,6 @@ std::cout << std::endl;
638658
const int scale_k = cute::ceil_div(k, blocksize);
639659
StrideScale stride_S = cutlass::make_cute_packed_stride(StrideScale{}, cute::make_shape(n, scale_k, l));
640660
std::cout<<"n = "<<n<<" k = "<<k<<" blocksize = "<<blocksize<<" scale_k = "<<scale_k<<std::endl;
641-
642661
auto mScale = make_tensor(
643662
make_gmem_ptr(absmax_),
644663
make_layout(make_shape(n, scale_k, l), stride_S));

tests/test_xpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
6666
#for i in range(iters):
6767
#pdb.set_trace()
6868
if kind == "fc1":
69-
A = torch.ones(32, dim, dtype=dtype, device=device)
70-
B = torch.ones(dim, dim, dtype=dtype, device=device) / math.sqrt(dim)
69+
A = torch.randn(32, dim, dtype=dtype, device=device)
70+
B = torch.randn(dim, dim, dtype=dtype, device=device) / math.sqrt(dim)
7171
elif kind == "fc2":
7272
A = torch.randn(1, 4 * dim, dtype=dtype, device=device)
7373
B = torch.randn(dim, 4 * dim, dtype=dtype, device=device) / math.sqrt(dim)

0 commit comments

Comments
 (0)