Skip to content

Commit bc45998

Browse files
committed
save code
1 parent 3fa1119 commit bc45998

1 file changed

Lines changed: 49 additions & 41 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ class gemm_4bit_cutlass_kernel {
208208
Tensor tCgA = thr_mma.partition_A(gA);
209209
Tensor tCgB = thr_mma.partition_B(gB); //values for each_thread (FrgV,(RestN,RestK),*)
210210

211-
#if 1
211+
#if 0
212212
Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
213213
Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
214214
Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout());
@@ -239,26 +239,39 @@ class gemm_4bit_cutlass_kernel {
239239
auto layout_A = make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape());
240240
Tensor mma_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_A.shape(), Int<2>{}), cute::append(layout_A.stride(), Int<0>{})));
241241
#else
242-
auto layout_A = make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape());
243-
auto layout_B = make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape());
244-
245-
Tensor mma_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_A.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_A.stride(), 0)));
246-
Tensor mma_B = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_B.stride(), 0)));
247-
Tensor dequant_frag = make_tensor<ElementB>(cute::make_layout(cute::append(layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_B.stride(), 0)));
248-
242+
Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
243+
Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
244+
Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout());
245+
246+
Tensor mma_A_b = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
247+
Tensor mma_B_b = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
248+
Tensor dequant_frag_b = make_tensor<ElementB>(mma_B_b.layout());
249+
249250
static constexpr auto scale_shape_t = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
250251
static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
251252
static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE < 1 ? 1 : BLK_K / GROUP_SIZE;
252-
using FragScaleLayout = Layout<Shape<Int<scale_shape_t>, Int<scale_shape_n>, Int<scale_shape_k>>>;
253-
Tensor fragment_scale = make_tensor<ElementScale>(cute::make_layout(cute::append(FragScaleLayout{}.shape(), cute::make_shape(Int<2>{})), cute::make_stride(FragScaleLayout{}.stride(), 0)));
254-
255-
auto single_layout_A = thr_copy_A.retile_D(cute::make_tensor(mma_A.data(), layout_A)).layout();
256-
auto single_layout_B = thr_copy_B.retile_D(cute::make_tensor(dequant_frag.data(), layout_B)).layout();
257-
auto single_layout_Scale = thr_copy_scale.retile_D(cute::make_tensor(fragment_scale.data(), FragScaleLayout{})).layout();
258-
259-
Tensor frag_copy_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(single_layout_A.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_A.stride(), 0)));
260-
Tensor frag_copy_B = make_tensor<ElementB>(cute::make_layout(cute::append(single_layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_B.stride(), 0)));
261-
Tensor frag_copy_Scale = make_tensor<float>(cute::make_layout(cute::append(single_layout_Scale.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_Scale.stride(), 0)));
253+
using FragScaleLayout = Layout<Shape<Int<scale_shape_t>, Int<scale_shape_n>, Int<scale_shape_k>>>; //[1, dequant_N, block_num]
254+
Tensor fragment_scale_a = make_tensor<ElementScale>(FragScaleLayout{});
255+
Tensor fragment_scale_b = make_tensor<ElementScale>(FragScaleLayout{});
256+
257+
Tensor frag_copy_A_a = thr_copy_A.retile_D(mma_A_a);
258+
Tensor frag_copy_B_a = thr_copy_B.retile_D(dequant_frag_a);
259+
Tensor frag_copy_Scale_a = thr_copy_scale.retile_D(fragment_scale_a);
260+
261+
Tensor frag_copy_A_b = thr_copy_A.retile_D(mma_A_b);
262+
Tensor frag_copy_B_b = thr_copy_B.retile_D(dequant_frag_b);
263+
Tensor frag_copy_Scale_b = thr_copy_scale.retile_D(fragment_scale_b);
264+
265+
cute::tuple<decltype(mma_A_a), decltype(mma_A_b)> mma_A(mma_A_a, mma_A_b);
266+
cute::tuple<decltype(mma_B_a), decltype(mma_B_b)> mma_B(mma_B_a, mma_B_b);
267+
cute::tuple<decltype(dequant_frag_a), decltype(dequant_frag_b)> dequant_frag(dequant_frag_a, dequant_frag_b);
268+
cute::tuple<decltype(fragment_scale_a), decltype(fragment_scale_b)> fragment_scale(fragment_scale_a, fragment_scale_b);
269+
cute::tuple<decltype(frag_copy_A_a), decltype(frag_copy_A_b)> frag_copy_A(frag_copy_A_a, frag_copy_A_b);
270+
cute::tuple<decltype(frag_copy_B_a), decltype(frag_copy_B_b)> frag_copy_B(frag_copy_B_a, frag_copy_B_b);
271+
cute::tuple<decltype(frag_copy_Scale_a), decltype(frag_copy_Scale_b)> frag_copy_Scale(frag_copy_Scale_a, frag_copy_Scale_b);
272+
//auto& mma_A_0 = cute::get<0>(mma_A_tuple); // 引用 mma_A_a
273+
//auto& mma_A_1 = cute::get<1>(mma_A_tuple); // 引用 mma_A_b
274+
262275
#endif
263276
Tensor tAgA = thr_copy_A.retile_S(tCgA);
264277
Tensor tBgB = thr_copy_B.retile_S(tCgB);
@@ -293,8 +306,8 @@ class gemm_4bit_cutlass_kernel {
293306

294307
int start_lut_id = sg_idx % LUT_NUM;
295308

296-
#if 0
297-
auto dequant = [&](int start_lut_id, int buffer_idx) {
309+
#if 1
310+
auto dequant = [&](int start_lut_id, const int buffer_idx) {
298311
constexpr int N = decltype(cute::size<1>(mma_B))::value;
299312
constexpr int K = decltype(cute::size(mma_B))::value / N;
300313

@@ -309,14 +322,6 @@ class gemm_4bit_cutlass_kernel {
309322
constexpr int dst_vec_size = 4;
310323
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
311324

312-
size_t dequant_offset = buffer_idx * dequant_frag.size() / 2;
313-
size_t scale_offset = buffer_idx * fragment_scale.size() / 2;
314-
size_t mma_offset = buffer_idx * mma_B.size() / 2;
315-
316-
auto* dequant_ptr = cute::raw_pointer_cast(dequant_frag.data()) + dequant_offset;
317-
auto* scale_ptr = cute::raw_pointer_cast(fragment_scale.data()) + scale_offset;
318-
auto* mma_ptr = cute::raw_pointer_cast(mma_B.data()) + mma_offset;
319-
320325
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
321326

322327
int lut_id = start_lut_id;
@@ -328,13 +333,13 @@ class gemm_4bit_cutlass_kernel {
328333

329334
#pragma unroll
330335
for (int v = 0; v < src_vec_size; v++) {
331-
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(dequant_ptr)[n*src_loop_num + l][v];
336+
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(cute::get<buffer_idx>(dequant_frag).data()))[n*src_loop_num + l][v];
332337
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
333338

334339
#pragma unroll
335340
for (int c = 0; c < src_compress_size; c++) {
336341
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
337-
float scale_value = *reinterpret_cast<float*>(scale_ptr + ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE))));
342+
float scale_value = cute::get<buffer_idx>(fragment_scale_a)((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
338343

339344
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
340345
lut_id = (lut_id + 1) % LUT_NUM;
@@ -344,14 +349,14 @@ class gemm_4bit_cutlass_kernel {
344349

345350
#pragma unroll
346351
for (int l = 0; l < dst_loop_num; l++) {
347-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(mma_ptr)[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
352+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(cute::get<buffer_idx>(mma_B_a).data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
348353
}
349354
}
350355
};
351356

352-
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), frag_copy_B(_,_,_,0));
353-
copy(params.tiled_copy_scale, tSgS(_,_,_,k_start_idx * BLK_K/params.group_size), frag_copy_Scale(_,_,_,0));
354-
copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), frag_copy_A(_,_,_,0));
357+
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), cute::get<0>(frag_copy_B));
358+
copy(params.tiled_copy_scale, tSgS(_,_,_,k_start_idx * BLK_K/params.group_size), cute::get<0>(frag_copy_Scale));
359+
copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), cute::get<0>(frag_copy_A));
355360

356361
if (prefetch_k < k_tile_count) {
357362
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
@@ -360,24 +365,27 @@ class gemm_4bit_cutlass_kernel {
360365
prefetch_k++;
361366

362367
for (int k_tile = k_start_idx + 1, k_s = 1; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
363-
const int buf_idx = k_tile % 2;
368+
constexpr int buf_idx = k_tile % 2;
364369

365370
dequant(start_lut_id, buf_idx);
366371

367-
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B(_,_,_,buf_idx));
368-
copy(params.tiled_copy_scale, tSgS(_,_,_,(k_start_idx+k_s)*BLK_K/params.group_size), frag_copy_Scale(_,_,_,buf_idx));
369-
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A(_,_,_,buf_idx));
372+
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), cute::get<buf_idx>(frag_copy_B));
373+
copy(params.tiled_copy_scale, tSgS(_,_,_,(k_start_idx+k_s)*BLK_K/params.group_size), cute::get<buf_idx>(frag_copy_Scale));
374+
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), cute::get<buf_idx>(frag_copy_A));
370375

371376
if (prefetch_k < k_tile_count) {
372377
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
373378
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
374379
}
375-
376-
cute::gemm(tiled_mma, frag_copy_A(_,_,_,1-buf_idx), frag_copy_B(_,_,_,1-buf_idx), accumulators);
380+
381+
constexpr int idx = 1 - buf_idx;
382+
cute::gemm(tiled_mma, cute::get<idx>(frag_copy_A), cute::get<idx>(frag_copy_B), accumulators);
377383
barrier_wait(3);
378384
}
379-
cute::gemm(tiled_mma, frag_copy_A(_,_,_,1), frag_copy_B(_,_,_,1), accumulators);
385+
cute::gemm(tiled_mma, cute::get<1>(frag_copy_A), cute::get<1>(frag_copy_B), accumulators);
386+
380387
#else
388+
381389
auto dequant_a = [&] (int start_lut_id){
382390
constexpr int N = decltype(cute::size<1>(mma_B_a))::value;
383391
constexpr int K = decltype(cute::size(mma_B_a))::value / N;

0 commit comments

Comments
 (0)