Skip to content

Commit 1ff0a4c

Browse files
committed
save code
1 parent cad6505 commit 1ff0a4c

1 file changed

Lines changed: 70 additions & 24 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -208,28 +208,32 @@ class gemm_4bit_cutlass_kernel {
208208
Tensor tCgA = thr_mma.partition_A(gA);
209209
Tensor tCgB = thr_mma.partition_B(gB); //values for each_thread (FrgV,(RestN,RestK),*)
210210

211-
Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
211+
Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
212212
Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
213213
Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout());
214214

215+
Tensor mma_A_b = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
215216
Tensor mma_B_b = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
216217
Tensor dequant_frag_b = make_tensor<ElementB>(mma_B_b.layout());
217218

218219
static constexpr auto scale_shape_t = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
219220
static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
220221
static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE < 1 ? 1 : BLK_K / GROUP_SIZE;
221222
using FragScaleLayout = Layout<Shape<Int<scale_shape_t>, Int<scale_shape_n>, Int<scale_shape_k>>>; //[1, dequant_N, block_num]
222-
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
223+
Tensor fragment_scale_a = make_tensor<ElementScale>(FragScaleLayout{});
224+
Tensor fragment_scale_b = make_tensor<ElementScale>(FragScaleLayout{});
223225

224226
// static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
225227
// static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
226228
// static_assert(std::is_same_v<typename decltype(mma_B)::value_type, ElementMMA>);
227229

228-
Tensor frag_copy_A = thr_copy_A.retile_D(mma_A);
230+
Tensor frag_copy_A_a = thr_copy_A.retile_D(mma_A_a);
229231
Tensor frag_copy_B_a = thr_copy_B.retile_D(dequant_frag_a);
230-
Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
232+
Tensor frag_copy_Scale_a = thr_copy_scale.retile_D(fragment_scale_a);
231233

234+
Tensor frag_copy_A_b = thr_copy_A.retile_D(mma_A_b);
232235
Tensor frag_copy_B_b = thr_copy_B.retile_D(dequant_frag_b);
236+
Tensor frag_copy_Scale_b = thr_copy_scale.retile_D(fragment_scale_b);
233237

234238
Tensor tAgA = thr_copy_A.retile_S(tCgA);
235239
Tensor tBgB = thr_copy_B.retile_S(tCgB);
@@ -284,7 +288,7 @@ class gemm_4bit_cutlass_kernel {
284288
#pragma unroll
285289
for (int c = 0; c < src_compress_size; c++) {
286290
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
287-
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
291+
float scale_value = fragment_scale_a((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
288292
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
289293
lut_id = (lut_id + 1) % LUT_NUM;
290294
}
@@ -326,7 +330,7 @@ class gemm_4bit_cutlass_kernel {
326330
#pragma unroll
327331
for (int c = 0; c < src_compress_size; c++) {
328332
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
329-
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
333+
float scale_value = fragment_scale_b((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
330334
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
331335
lut_id = (lut_id + 1) % LUT_NUM;
332336
}
@@ -349,33 +353,75 @@ class gemm_4bit_cutlass_kernel {
349353
int start_lut_id = sg_idx % LUT_NUM;
350354

351355
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), frag_copy_B_a);
352-
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + 0) * BLK_K/params.group_size), frag_copy_Scale);
353-
copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), frag_copy_A);
356+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + 0) * BLK_K/params.group_size), frag_copy_Scale_a);
357+
copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), frag_copy_A_a);
354358

355-
for (int k_tile = k_start_idx + 1, k_s = 0 + 1; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
356-
bool is_odd_tile = k_tile % 2 != 0;
359+
if (prefetch_k < k_tile_count) {
360+
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
361+
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
362+
}
363+
364+
prefetch_k++;
357365

358-
if(is_odd_tile){
366+
for (int k_tile = k_start_idx + 1, k_s = 0 + 1; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
367+
if(k_tile % 2 != 0){
368+
dequant_a(start_lut_id);
359369
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B_b);
370+
371+
//dequant_a(start_lut_id);
372+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale_b);
373+
374+
//dequant_a(start_lut_id);
375+
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A_b);
376+
377+
//dequant_a(start_lut_id);
378+
if (prefetch_k < k_tile_count) {
379+
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
380+
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
381+
}
382+
383+
//dequant_a(start_lut_id);
384+
cute::gemm(tiled_mma, mma_A_a, mma_B_a, accumulators);
385+
barrier_wait(3);
386+
387+
//copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale_a);
388+
//copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A_a);
389+
390+
//if (prefetch_k < k_tile_count) {
391+
// prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
392+
// //prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
393+
//}
360394
} else {
395+
dequant_b(start_lut_id);
361396
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B_a);
362-
}
363397

364-
is_odd_tile ? dequant_a(start_lut_id) : dequant_b(start_lut_id);
398+
//dequant_b(start_lut_id);
399+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale_a);
365400

366-
if (prefetch_k < k_tile_count) {
367-
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
368-
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
369-
}
370-
371-
cute::gemm(tiled_mma, mma_A, is_odd_tile ? mma_B_a : mma_B_b, accumulators);
372-
barrier_wait(3);
401+
//dequant_b(start_lut_id);
402+
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A_a);
373403

374-
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
375-
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
404+
//dequant_b(start_lut_id);
405+
if (prefetch_k < k_tile_count) {
406+
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
407+
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
408+
}
409+
410+
//dequant_b(start_lut_id);
411+
cute::gemm(tiled_mma, mma_A_b, mma_B_b, accumulators);
412+
barrier_wait(3);
413+
414+
//copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale_a);
415+
//copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A_a);
416+
417+
//if (prefetch_k < k_tile_count) {
418+
// prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
419+
// //prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
420+
//}
421+
}
376422
}
377-
cute::gemm(tiled_mma, mma_A, mma_B_b, accumulators);
378-
barrier_wait(3);
423+
cute::gemm(tiled_mma, mma_A_a, mma_B_b, accumulators);
424+
//barrier_wait(3);
379425

380426
static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // atom numbers per thread; A frags per sub_group
381427
static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // atom numbers per thread; B frags per sub_group

0 commit comments

Comments
 (0)