Skip to content

Commit 05efb69

Browse files
committed
save code
1 parent 865a62f commit 05efb69

1 file changed

Lines changed: 40 additions & 40 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -252,56 +252,56 @@ class gemm_4bit_cutlass_kernel {
252252
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
253253
int prefetch_k = k_start_idx;
254254

255-
// copy_and_dequant(start_lut_id, k_tile, k_s): copies one tile of B, the matching
// scales and one tile of A into their fragments, expands the packed 4-bit fields
// read from dequant_frag into ElementMMA values (quant_map_ lookup times scale),
// stores them into mma_B, and finally issues the A/B prefetch for prefetch_k
// while prefetch_k is still below k_tile_count.
//   start_lut_id - first row of quant_map_ to use; the row index advances by one
//                  (mod LUT_NUM) for every decoded element.
//   k_tile       - index passed as the last coordinate of tBgB and tAgA.
//   k_s          - offset added to k_start_idx when selecting the scale tile.
// This page shows the lambda twice (removed "-" lines and added "+" lines); both
// copies carry the same statements, so the notes below sit on the "+" side only.
auto copy_and_dequant = [&] (int start_lut_id, int k_tile, int k_s){
256-
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
257-
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
258-
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
255+
auto copy_and_dequant = [&] (int start_lut_id, int k_tile, int k_s){
256+
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
257+
// The scale tile index is the absolute k position (in BLK_K units) divided by the group size.
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
258+
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
259259

260-
constexpr int N = decltype(cute::size<1>(mma_B))::value;
261-
constexpr int K = decltype(cute::size(mma_B))::value / N;
260+
// N: extent of mode 1 of mma_B; K: number of mma_B elements per n (total size / N).
constexpr int N = decltype(cute::size<1>(mma_B))::value;
261+
constexpr int K = decltype(cute::size(mma_B))::value / N;
262262

263-
using src_compress_type = uint32_t;
264-
using dst_compress_type = uint32_t;
265-
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
266-
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
267-
constexpr int src_vec_size = 4;
268-
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
269-
270-
constexpr int dst_vec_size = 4; //src_vec_size;
271-
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size; //src_loop_num;
272-
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
273-
274-
int lut_id = start_lut_id;
263+
// Both the packed source and the converted destination are moved as 32-bit words.
using src_compress_type = uint32_t;
264+
using dst_compress_type = uint32_t;
265+
// Number of ElementB values held by one 32-bit source word.
// NOTE(review): the decode loop below shifts the word by 4 * field index, so 16
// fields per word would shift a uint32_t by up to 60 bits; the trailing "16"
// looks stale for a 4-bit ElementB - confirm against the real ElementB width.
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
266+
// Number of ElementMMA values held by one 32-bit destination word.
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
267+
constexpr int src_vec_size = 4;
268+
// Count of 4-word source vectors needed to cover K elements of one n.
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
269+
270+
constexpr int dst_vec_size = 4; //src_vec_size;
271+
// Count of 4-word destination vectors needed to cover K elements of one n.
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size; //src_loop_num;
272+
// Staging buffer for the dequantized values of one n; it is refilled on every outer iteration.
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
273+
274+
int lut_id = start_lut_id;
275+
#pragma unroll
276+
for (int n = 0; n < N; n++) {
275277
#pragma unroll
276-
for (int n = 0; n < N; n++) {
277-
#pragma unroll
278-
for (int l = 0; l < src_loop_num; l++) {
278+
for (int l = 0; l < src_loop_num; l++) {
279279

280+
#pragma unroll
281+
for (int v = 0; v < src_vec_size; v++) {
282+
// View dequant_frag as 4-wide vectors of 32-bit words and take word v of vector (n, l).
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l][v];
283+
// Position inside dst of the first element decoded from this word.
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
280284
#pragma unroll
281-
for (int v = 0; v < src_vec_size; v++) {
282-
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l][v];
283-
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
284-
#pragma unroll
285-
for (int c = 0; c < src_compress_size; c++) {
286-
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
287-
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
288-
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
289-
lut_id = (lut_id + 1) % LUT_NUM;
290-
}
285+
for (int c = 0; c < src_compress_size; c++) {
286+
// 4-bit fields are read pairwise swapped: c = 0,1,2,3,... selects field 1,0,3,2,...
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
287+
// The right shift by (31 - countl_zero(GROUP_SIZE)) divides the element index by GROUP_SIZE;
// assumes a 32-bit unsigned int and a power-of-two GROUP_SIZE - TODO confirm.
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
288+
// Dequantize: look the 4-bit code up in the current quant_map_ row, then apply the scale.
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
289+
// The lookup-table row rotates after every decoded element.
lut_id = (lut_id + 1) % LUT_NUM;
291290
}
292291
}
293-
294-
#pragma unroll
295-
for (int l = 0; l < dst_loop_num; l++) {
296-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
297-
}
298292
}
299293

300-
if (prefetch_k < k_tile_count) {
301-
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
302-
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
294+
// Write the staged values of this n into mma_B, one vector of four 32-bit words at a time.
#pragma unroll
295+
for (int l = 0; l < dst_loop_num; l++) {
296+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
303297
}
304-
};
298+
}
299+
300+
// Issue the A and B prefetch for tile prefetch_k only while it is below k_tile_count.
if (prefetch_k < k_tile_count) {
301+
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
302+
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
303+
}
304+
};
305305

306306
CUTLASS_PRAGMA_UNROLL
307307
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {

0 commit comments

Comments
 (0)