Skip to content

Commit 3fa1119

Browse files
committed
save code
1 parent 1ff0a4c commit 3fa1119

1 file changed

Lines changed: 121 additions & 10 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 121 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ class gemm_4bit_cutlass_kernel {
207207

208208
Tensor tCgA = thr_mma.partition_A(gA);
209209
Tensor tCgB = thr_mma.partition_B(gB); //values for each_thread (FrgV,(RestN,RestK),*)
210-
210+
211+
#if 1
211212
Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
212213
Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
213214
Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout());
@@ -235,6 +236,30 @@ class gemm_4bit_cutlass_kernel {
235236
Tensor frag_copy_B_b = thr_copy_B.retile_D(dequant_frag_b);
236237
Tensor frag_copy_Scale_b = thr_copy_scale.retile_D(fragment_scale_b);
237238

239+
auto layout_A = make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape());
240+
Tensor mma_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_A.shape(), Int<2>{}), cute::append(layout_A.stride(), Int<0>{})));
241+
#else
242+
auto layout_A = make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape());
243+
auto layout_B = make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape());
244+
245+
Tensor mma_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_A.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_A.stride(), 0)));
246+
Tensor mma_B = make_tensor<ElementMMA>(cute::make_layout(cute::append(layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_B.stride(), 0)));
247+
Tensor dequant_frag = make_tensor<ElementB>(cute::make_layout(cute::append(layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(layout_B.stride(), 0)));
248+
249+
static constexpr auto scale_shape_t = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
250+
static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
251+
static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE < 1 ? 1 : BLK_K / GROUP_SIZE;
252+
using FragScaleLayout = Layout<Shape<Int<scale_shape_t>, Int<scale_shape_n>, Int<scale_shape_k>>>;
253+
Tensor fragment_scale = make_tensor<ElementScale>(cute::make_layout(cute::append(FragScaleLayout{}.shape(), cute::make_shape(Int<2>{})), cute::make_stride(FragScaleLayout{}.stride(), 0)));
254+
255+
auto single_layout_A = thr_copy_A.retile_D(cute::make_tensor(mma_A.data(), layout_A)).layout();
256+
auto single_layout_B = thr_copy_B.retile_D(cute::make_tensor(dequant_frag.data(), layout_B)).layout();
257+
auto single_layout_Scale = thr_copy_scale.retile_D(cute::make_tensor(fragment_scale.data(), FragScaleLayout{})).layout();
258+
259+
Tensor frag_copy_A = make_tensor<ElementMMA>(cute::make_layout(cute::append(single_layout_A.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_A.stride(), 0)));
260+
Tensor frag_copy_B = make_tensor<ElementB>(cute::make_layout(cute::append(single_layout_B.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_B.stride(), 0)));
261+
Tensor frag_copy_Scale = make_tensor<float>(cute::make_layout(cute::append(single_layout_Scale.shape(), cute::make_shape(Int<2>{})), cute::make_stride(single_layout_Scale.stride(), 0)));
262+
#endif
238263
Tensor tAgA = thr_copy_A.retile_S(tCgA);
239264
Tensor tBgB = thr_copy_B.retile_S(tCgB);
240265

@@ -260,10 +285,103 @@ class gemm_4bit_cutlass_kernel {
260285
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
261286
int prefetch_k = k_start_idx;
262287

288+
CUTLASS_PRAGMA_UNROLL
289+
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
290+
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
291+
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
292+
}
293+
294+
int start_lut_id = sg_idx % LUT_NUM;
295+
296+
#if 0
297+
// Dequantizes one buffer half of `dequant_frag` into the matching half of `mma_B`.
// Each packed 4-bit field is looked up in `quant_map_`, multiplied by a scale read
// from `fragment_scale`, and cast to ElementMMA.
//   start_lut_id : first row of `quant_map_` to use; the row index advances by one
//                  (mod LUT_NUM) after every element produced.
//   buffer_idx   : selects which half of the fragments to work on (offset = idx * size / 2).
// NOTE(review): `mma_B`, `dequant_frag` and `fragment_scale` are captured by reference; in
// the visible code they are declared only in the `#else` branch of the earlier `#if 1`.
auto dequant = [&](int start_lut_id, int buffer_idx) {
298+
constexpr int N = decltype(cute::size<1>(mma_B))::value;
299+
// K = number of elements handled per index of mode 1 (total size divided by N).
constexpr int K = decltype(cute::size(mma_B))::value / N;
300+
301+
using src_compress_type = uint32_t;
302+
using dst_compress_type = uint32_t;
303+
304+
// Packed ElementB values per 32-bit source word.
// NOTE(review): original note says 16, which holds only for a 2-bit ElementB; the `& 0xF` extraction below suggests 4-bit fields (-> 8) — confirm.
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
305+
// ElementMMA values per 32-bit destination word.
// NOTE(review): original note says 4, which holds only for an 8-bit ElementMMA — confirm.
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 4
306+
// Source words are read through sycl::vec<uint32_t, 4>.
constexpr int src_vec_size = 4;
307+
308+
// Number of 4-word source vectors per n. Integer division: presumably K is an exact multiple — TODO confirm.
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
309+
constexpr int dst_vec_size = 4;
310+
// Number of 4-word destination vectors per n.
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
311+
312+
// Element offsets of the selected half inside each fragment.
// NOTE(review): the visible declarations give the appended extent-2 mode a stride of 0;
// confirm the backing storage really holds a second half at these offsets.
size_t dequant_offset = buffer_idx * dequant_frag.size() / 2;
313+
size_t scale_offset = buffer_idx * fragment_scale.size() / 2;
314+
size_t mma_offset = buffer_idx * mma_B.size() / 2;
315+
316+
auto* dequant_ptr = cute::raw_pointer_cast(dequant_frag.data()) + dequant_offset;
317+
auto* scale_ptr = cute::raw_pointer_cast(fragment_scale.data()) + scale_offset;
318+
auto* mma_ptr = cute::raw_pointer_cast(mma_B.data()) + mma_offset;
319+
320+
// Staging array for the converted values of one n; copied to mma_ptr in vector-sized chunks below.
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
321+
322+
// lut_id is not reset per n: it keeps rotating across the whole fragment.
int lut_id = start_lut_id;
323+
#pragma unroll
324+
for (int n = 0; n < N; n++) {
325+
326+
#pragma unroll
327+
for (int l = 0; l < src_loop_num; l++) {
328+
329+
#pragma unroll
330+
for (int v = 0; v < src_vec_size; v++) {
331+
// Load word v of the (n, l)-th 4-word vector of packed source data.
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(dequant_ptr)[n*src_loop_num + l][v];
332+
// Index in dst[] of the first element unpacked from this word.
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
333+
334+
#pragma unroll
335+
for (int c = 0; c < src_compress_size; c++) {
336+
// Shift is 4 * (2*(c/2) + 1) for even c and 4 * (2*(c/2)) for odd c:
// within each byte the high nibble is taken first, then the low nibble.
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
337+
// 31 - countl_zero(GROUP_SIZE) is floor(log2(GROUP_SIZE)) for a 32-bit value, so the shift
// divides the flat element index by GROUP_SIZE — exact only if GROUP_SIZE is a power of two (TODO confirm).
// NOTE(review): the scale is read as float through a reinterpret_cast; verify ElementScale is float.
float scale_value = *reinterpret_cast<float*>(scale_ptr + ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE))));
338+
339+
// Table lookup, scale, then narrow to the MMA element type.
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
340+
lut_id = (lut_id + 1) % LUT_NUM;
341+
}
342+
}
343+
}
344+
345+
// Write the staged values for this n to mma_ptr, one 4 x uint32_t vector at a time.
#pragma unroll
346+
for (int l = 0; l < dst_loop_num; l++) {
347+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(mma_ptr)[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
348+
}
349+
}
350+
};
351+
352+
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), frag_copy_B(_,_,_,0));
353+
copy(params.tiled_copy_scale, tSgS(_,_,_,k_start_idx * BLK_K/params.group_size), frag_copy_Scale(_,_,_,0));
354+
copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), frag_copy_A(_,_,_,0));
355+
356+
if (prefetch_k < k_tile_count) {
357+
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
358+
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
359+
}
360+
prefetch_k++;
361+
362+
for (int k_tile = k_start_idx + 1, k_s = 1; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
363+
const int buf_idx = k_tile % 2;
364+
365+
dequant(start_lut_id, buf_idx);
366+
367+
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B(_,_,_,buf_idx));
368+
copy(params.tiled_copy_scale, tSgS(_,_,_,(k_start_idx+k_s)*BLK_K/params.group_size), frag_copy_Scale(_,_,_,buf_idx));
369+
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A(_,_,_,buf_idx));
370+
371+
if (prefetch_k < k_tile_count) {
372+
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
373+
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
374+
}
375+
376+
cute::gemm(tiled_mma, frag_copy_A(_,_,_,1-buf_idx), frag_copy_B(_,_,_,1-buf_idx), accumulators);
377+
barrier_wait(3);
378+
}
379+
cute::gemm(tiled_mma, frag_copy_A(_,_,_,1), frag_copy_B(_,_,_,1), accumulators);
380+
#else
263381
auto dequant_a = [&] (int start_lut_id){
264382
constexpr int N = decltype(cute::size<1>(mma_B_a))::value;
265383
constexpr int K = decltype(cute::size(mma_B_a))::value / N;
266-
384+
267385
using src_compress_type = uint32_t;
268386
using dst_compress_type = uint32_t;
269387
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
@@ -344,14 +462,6 @@ class gemm_4bit_cutlass_kernel {
344462
}
345463
};
346464

347-
CUTLASS_PRAGMA_UNROLL
348-
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
349-
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
350-
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
351-
}
352-
353-
int start_lut_id = sg_idx % LUT_NUM;
354-
355465
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), frag_copy_B_a);
356466
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + 0) * BLK_K/params.group_size), frag_copy_Scale_a);
357467
copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), frag_copy_A_a);
@@ -422,6 +532,7 @@ class gemm_4bit_cutlass_kernel {
422532
}
423533
cute::gemm(tiled_mma, mma_A_a, mma_B_b, accumulators);
424534
//barrier_wait(3);
535+
#endif
425536

426537
static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // atom numbers per thread; A frags per sub_group
427538
static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // atom numbers per thread; B frags per sub_group

0 commit comments

Comments
 (0)