Skip to content

Commit feb23d2

Browse files
committed
save code
1 parent 446ad8c commit feb23d2

1 file changed

Lines changed: 14 additions & 28 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ using TiledMma =
6767
Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
6868
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6969
using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
70-
constexpr int PipelineStages = 4;
70+
constexpr int PipelineStages = 2;
7171

7272
using MmaAtomShape = typename TiledMma::AtomShape_MNK;
7373
using WorkgroupTileShape = TileShape;
@@ -239,8 +239,8 @@ class gemm_4bit_cutlass_kernel {
239239
Tensor tAgA = thr_copy_A.retile_S(tCgA);
240240
Tensor tBgB = thr_copy_B.retile_S(tCgB);
241241

242-
auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M>,Int<BLK_K>>, Num_SGs>(params.tiled_copy_a);;
243-
auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N>,Int<BLK_K>>, Num_SGs>(params.tiled_copy_b);;
242+
auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M>,Int<BLK_K>>, Num_SGs>(params.tiled_copy_a);
243+
auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N>,Int<BLK_K>>, Num_SGs>(params.tiled_copy_b);
244244
auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx);
245245
auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx);
246246

@@ -273,33 +273,20 @@ class gemm_4bit_cutlass_kernel {
273273
using VecDstElemType = cute::array<ElementMMA, compress_size>;
274274
using VecDstType = cute::array<VecDstElemType, vec_size>;
275275

276-
auto s_tensor = cute::make_tensor((VecSrcType*)(cute::raw_pointer_cast(dequant_frag.data())), cute::make_shape(cute::Int<K / (compress_size * vec_size)>{}, cute::Int<N>{}));
277-
auto d_tensor = cute::make_tensor((VecDstType*)(cute::raw_pointer_cast(mma_B.data())), cute::make_shape(cute::Int<K / (compress_size * vec_size)>{}, cute::Int<N>{}));
278-
279-
//auto src_ = *(cute::array<VecSrcType, K / (compress_size * vec_size, N)>*)(s_tensor.data());
280-
//auto dst_ = *(cute::array<VecDstType, K / (compress_size * vec_size, N)>*)(d_tensor.data());
276+
float scale_value = fragment_scale(0);
277+
auto src = *(VecSrcType*)(cute::raw_pointer_cast(dequant_frag.data()));
278+
auto& dst = *(VecDstType*)(cute::raw_pointer_cast(mma_B.data()));
279+
VecDstType dst_val;
281280
#pragma unroll
282-
for (int n = 0; n < N; n++) {
283-
float scale_value = fragment_scale(n);
284-
auto src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(s_tensor(_, n).data());
285-
auto& dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(d_tensor(_, n).data());
286-
//auto& src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(src_[n]);
287-
//auto& dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(dst_[n]);
288-
#pragma unroll
289-
for (int k = 0; k < K / (compress_size * vec_size); k++) {
290-
VecDstType dst_val;
281+
for (int i = 0; i < vec_size; i++) {
282+
VecDstElemType dst_elem;
291283
#pragma unroll
292-
for (int i = 0; i < vec_size; i++) {
293-
VecDstElemType dst_elem;
294-
#pragma unroll
295-
for (int j = 0; j < compress_size; j++) {
296-
dst_elem[j] = static_cast<ElementMMA>(quant_map[(src[k][i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf] * scale_value);
297-
}
298-
dst_val[i] = dst_elem;
284+
for (int j = 0; j < compress_size; j++) {
285+
dst_elem[j] = static_cast<ElementMMA>(quant_map[(src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf] * scale_value);
299286
}
300-
dst[k] = dst_val;
301-
}
287+
dst_val[i] = dst_elem;
302288
}
289+
dst = dst_val;
303290
};
304291

305292
CUTLASS_PRAGMA_UNROLL
@@ -308,7 +295,7 @@ class gemm_4bit_cutlass_kernel {
308295
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
309296
}
310297

311-
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count; k_tile++, k_s++) {
298+
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
312299
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
313300
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) / k_reload_factor), frag_copy_Scale);
314301
//barrier_wait(3);
@@ -318,7 +305,6 @@ class gemm_4bit_cutlass_kernel {
318305
if (prefetch_k < k_tile_count) {
319306
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
320307
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
321-
prefetch_k++;
322308
}
323309

324310
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);

0 commit comments

Comments
 (0)