Skip to content

Commit 91afd70

Browse files
committed
save code
1 parent e3b8b2b commit 91afd70

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -252,7 +252,11 @@ class gemm_4bit_cutlass_kernel {
252252
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
253253
int prefetch_k = k_start_idx;
254254

255-
auto dequant = [&] (int start_lut_id){
255+
auto dequant = [&] (int start_lut_id, int k_tile, int k_s){
256+
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
257+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
258+
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
259+
256260
constexpr int N = decltype(cute::size<1>(mma_B))::value;
257261
constexpr int K = decltype(cute::size(mma_B))::value / N;
258262

@@ -303,11 +307,11 @@ class gemm_4bit_cutlass_kernel {
303307
int start_lut_id = sg_idx % LUT_NUM;
304308

305309
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
306-
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
307-
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
308-
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
310+
//copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
311+
//copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
312+
//copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
309313

310-
dequant(start_lut_id);
314+
dequant(start_lut_id, k_tile, k_s);
311315

312316
if (prefetch_k < k_tile_count) {
313317
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));

0 commit comments

Comments
 (0)