Skip to content

Commit f6ec6ca

Browse files
committed
optimize dequant
1 parent 224af41 commit f6ec6ca

2 files changed

Lines changed: 19 additions & 15 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,18 @@ using ElementOutput = float;
5252

5353
using ProblemShape = Shape<int, int, int, int>;
5454

55+
#if 1
5556
using TileShape = Shape<_256, _256, _32>;
57+
//using TileShape = Shape<_128, _256, _32>;
5658
using TiledMma =
5759
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
5860
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
61+
#else
62+
using TileShape = Shape<_16, _64, _64>;
63+
using TiledMma =
64+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
65+
Layout<Shape<_1, _2, _1>, Stride<_2, _1, _0>>>::TiledMMA;
66+
#endif
5967

6068
using WorkgroupTileShape = TileShape;
6169
static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); //256 //16
@@ -204,8 +212,7 @@ class gemm_4bit_cutlass_kernel {
204212
Tensor<EngineIn, LayoutIn> const& in,
205213
Tensor<EngineOut, LayoutOut>& out,
206214
Tensor<EngineScales, LayoutScales>& tCrS_input,
207-
T* quant_map,
208-
int n_coord, int thread_idx, int k_start_idx, int k_s, int k_reload_factor, int s_idx
215+
T* quant_map
209216
) {
210217
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
211218
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
@@ -243,13 +250,9 @@ class gemm_4bit_cutlass_kernel {
243250
auto& dst = *(cute::array<DstType, vec_size>*)(d_tensor(_, s, n).data());
244251

245252
CUTLASS_PRAGMA_UNROLL
246-
for (int i = 0; i < vec_size; i++) {
247-
uint8_t value = (format_data >> (src_bits * i)) & 0xf;
248-
if(i % 2 != 0) { //1,3, high_4bit
249-
dst[i-1] = static_cast<DstType>(quant_map[value] * ts);
250-
} else {
251-
dst[i+1] = static_cast<DstType>(quant_map[value] * ts);
252-
}
253+
for (int i = 0; i < vec_size/2; i++) {
254+
dst[i * 2] = static_cast<DstType>(quant_map[(format_data >> (src_bits * (i * 2 + 1))) & 0xf] * ts);
255+
dst[i * 2 + 1] = static_cast<DstType>(quant_map[(format_data >> (src_bits * (i * 2))) & 0xf] * ts);
253256
}
254257
}
255258
}
@@ -280,7 +283,7 @@ class gemm_4bit_cutlass_kernel {
280283
if (thread_idx < 16) {
281284
quant_map[thread_idx] = T(datatype[thread_idx]);
282285
}
283-
barrier_wait(1);
286+
barrier_arrive(3);
284287

285288
auto blk_shape = TileShape{};
286289
int m_coord, n_coord, l_coord;
@@ -362,7 +365,7 @@ class gemm_4bit_cutlass_kernel {
362365
}();
363366

364367
#if 0
365-
if (thread_idx==16 && n_coord == 0 && l_coord==1) {
368+
if (thread_idx==0 && n_coord == 0 && l_coord==0) {
366369
print("\n\n======================= A: \n");
367370
print(" gA : "); print(gA); print("\n");
368371
print(" tCgA : "); print(tCgA); print("\n");
@@ -423,7 +426,7 @@ class gemm_4bit_cutlass_kernel {
423426
}
424427

425428
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++, k_s++) {
426-
barrier_arrive(2);
429+
//barrier_arrive(3);
427430

428431
copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
429432
copy(tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
@@ -438,11 +441,11 @@ class gemm_4bit_cutlass_kernel {
438441
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
439442
}
440443

441-
dequant(dequant_frag, mma_B, fragment_scale, quant_map, n_coord, thread_idx, k_start_idx, k_s, k_reload_factor, s_idx);
444+
dequant(dequant_frag, mma_B, fragment_scale, quant_map);
442445

443446

444447
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
445-
barrier_wait(2);
448+
barrier_wait(3);
446449
}
447450

448451
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>((char*)nullptr);
@@ -467,7 +470,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
467470

468471
using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS>;
469472

470-
static constexpr int smem_size= 256;
473+
static constexpr int smem_size= 16*16/2;
471474

472475
auto problem_size = ProblemShape{m, n, k, l};
473476

tests/test_xpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
306306
#pdb.set_trace()
307307
C2 = F.gemv_4bit(A, qB.t(), state=state)
308308
#print("C2[0] = ", C2[0])
309+
#pdb.set_trace()
309310
#A.requires_grad = True
310311
C1 = bnb.matmul_4bit(A, qB.t(), state)
311312
#pdb.set_trace()

0 commit comments

Comments
 (0)