Skip to content

Commit ae4b5e2

Browse files
committed
a,b buffer
1 parent 05efb69 commit ae4b5e2

1 file changed

Lines changed: 44 additions & 32 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -208,24 +208,32 @@ class gemm_4bit_cutlass_kernel {
208208
Tensor tCgA = thr_mma.partition_A(gA);
209209
Tensor tCgB = thr_mma.partition_B(gB); //values for each_thread (FrgV,(RestN,RestK),*)
210210

211-
Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
212-
Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
211+
Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
212+
Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
213+
Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout());
213214

214-
Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
215+
Tensor mma_A_b = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
216+
Tensor mma_B_b = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
217+
Tensor dequant_frag_b = make_tensor<ElementB>(mma_B_b.layout());
215218

216219
static constexpr auto scale_shape_t = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
217220
static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
218221
static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE < 1 ? 1 : BLK_K / GROUP_SIZE;
219222
using FragScaleLayout = Layout<Shape<Int<scale_shape_t>, Int<scale_shape_n>, Int<scale_shape_k>>>; //[1, dequant_N, block_num]
220-
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
223+
Tensor fragment_scale_a = make_tensor<ElementScale>(FragScaleLayout{});
224+
Tensor fragment_scale_b = make_tensor<ElementScale>(FragScaleLayout{});
221225

222226
// static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
223227
// static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
224228
// static_assert(std::is_same_v<typename decltype(mma_B)::value_type, ElementMMA>);
225229

226-
Tensor frag_copy_A = thr_copy_A.retile_D(mma_A);
227-
Tensor frag_copy_B = thr_copy_B.retile_D(dequant_frag);
228-
Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
230+
Tensor frag_copy_A_a = thr_copy_A.retile_D(mma_A_a);
231+
Tensor frag_copy_B_a = thr_copy_B.retile_D(dequant_frag_a);
232+
Tensor frag_copy_Scale_a = thr_copy_scale.retile_D(fragment_scale_a);
233+
234+
Tensor frag_copy_A_b = thr_copy_A.retile_D(mma_A_b);
235+
Tensor frag_copy_B_b = thr_copy_B.retile_D(dequant_frag_b);
236+
Tensor frag_copy_Scale_b = thr_copy_scale.retile_D(fragment_scale_b);
229237

230238
Tensor tAgA = thr_copy_A.retile_S(tCgA);
231239
Tensor tBgB = thr_copy_B.retile_S(tCgB);
@@ -252,13 +260,9 @@ class gemm_4bit_cutlass_kernel {
252260
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
253261
int prefetch_k = k_start_idx;
254262

255-
auto copy_and_dequant = [&] (int start_lut_id, int k_tile, int k_s){
256-
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
257-
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
258-
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
259-
260-
constexpr int N = decltype(cute::size<1>(mma_B))::value;
261-
constexpr int K = decltype(cute::size(mma_B))::value / N;
263+
auto dequant = [&] (int start_lut_id, int k_tile){
264+
constexpr int N = decltype(cute::size<1>(mma_B_a))::value;
265+
constexpr int K = decltype(cute::size(mma_B_a))::value / N;
262266

263267
using src_compress_type = uint32_t;
264268
using dst_compress_type = uint32_t;
@@ -279,12 +283,12 @@ class gemm_4bit_cutlass_kernel {
279283

280284
#pragma unroll
281285
for (int v = 0; v < src_vec_size; v++) {
282-
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l][v];
286+
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(k_tile % 2 != 0 ? cute::raw_pointer_cast(dequant_frag_a.data()): cute::raw_pointer_cast(dequant_frag_b.data()))[n*src_loop_num + l][v];
283287
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
284288
#pragma unroll
285289
for (int c = 0; c < src_compress_size; c++) {
286290
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
287-
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
291+
float scale_value = k_tile % 2 != 0 ? fragment_scale_a((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE))) : fragment_scale_b((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
288292
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
289293
lut_id = (lut_id + 1) % LUT_NUM;
290294
}
@@ -293,14 +297,9 @@ class gemm_4bit_cutlass_kernel {
293297

294298
#pragma unroll
295299
for (int l = 0; l < dst_loop_num; l++) {
296-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
300+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(k_tile % 2 != 0 ? mma_B_a.data() : mma_B_b.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
297301
}
298302
}
299-
300-
if (prefetch_k < k_tile_count) {
301-
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
302-
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
303-
}
304303
};
305304

306305
CUTLASS_PRAGMA_UNROLL
@@ -311,21 +310,34 @@ class gemm_4bit_cutlass_kernel {
311310

312311
int start_lut_id = sg_idx % LUT_NUM;
313312

314-
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
315-
//copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
316-
//copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
317-
//copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
313+
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), frag_copy_B_a);
314+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + 0) * BLK_K/params.group_size), frag_copy_Scale_a);
315+
copy(params.tiled_copy_a, tAgA(_,_,_,0), frag_copy_A_a);
316+
317+
for (int k_tile = k_start_idx + 1, k_s = 0 + 1; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
318+
if(k_tile % 2 != 0){
319+
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B_b);
320+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale_b);
321+
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A_b);
322+
} else {
323+
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), frag_copy_B_a);
324+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale_a);
325+
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A_a);
326+
}
318327

319-
copy_and_dequant(start_lut_id, k_tile, k_s);
320328

321-
//if (prefetch_k < k_tile_count) {
322-
// prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
323-
// prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
324-
//}
329+
dequant(start_lut_id, k_tile);
325330

326-
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
331+
if (prefetch_k < k_tile_count) {
332+
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
333+
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
334+
}
335+
336+
k_tile % 2 != 0 ? cute::gemm(tiled_mma, mma_A_a, mma_B_a, accumulators) : cute::gemm(tiled_mma, mma_A_b, mma_B_b, accumulators);
327337
barrier_wait(3);
328338
}
339+
cute::gemm(tiled_mma, mma_A_b, mma_B_b, accumulators);
340+
barrier_wait(3);
329341

330342
static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // atom numbers per thread; A frags per sub_group
331343
static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // atom numbers per thread; B frags per sub_group

0 commit comments

Comments
 (0)