Skip to content

Commit 975ceee

Browse files
committed
save code, only B
1 parent ae4b5e2 commit 975ceee

1 file changed

Lines changed: 14 additions & 20 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -208,32 +208,28 @@ class gemm_4bit_cutlass_kernel {
208208
Tensor tCgA = thr_mma.partition_A(gA);
209209
Tensor tCgB = thr_mma.partition_B(gB); //values for each_thread (FrgV,(RestN,RestK),*)
210210

211-
Tensor mma_A_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
211+
Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
212212
Tensor mma_B_a = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
213213
Tensor dequant_frag_a = make_tensor<ElementB>(mma_B_a.layout());
214214

215-
Tensor mma_A_b = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
216215
Tensor mma_B_b = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
217216
Tensor dequant_frag_b = make_tensor<ElementB>(mma_B_b.layout());
218217

219218
static constexpr auto scale_shape_t = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
220219
static constexpr auto scale_shape_n = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
221220
static constexpr auto scale_shape_k = BLK_K / GROUP_SIZE < 1 ? 1 : BLK_K / GROUP_SIZE;
222221
using FragScaleLayout = Layout<Shape<Int<scale_shape_t>, Int<scale_shape_n>, Int<scale_shape_k>>>; //[1, dequant_N, block_num]
223-
Tensor fragment_scale_a = make_tensor<ElementScale>(FragScaleLayout{});
224-
Tensor fragment_scale_b = make_tensor<ElementScale>(FragScaleLayout{});
222+
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
225223

226224
// static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
227225
// static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
228226
// static_assert(std::is_same_v<typename decltype(mma_B)::value_type, ElementMMA>);
229227

230-
Tensor frag_copy_A_a = thr_copy_A.retile_D(mma_A_a);
228+
Tensor frag_copy_A = thr_copy_A.retile_D(mma_A);
231229
Tensor frag_copy_B_a = thr_copy_B.retile_D(dequant_frag_a);
232-
Tensor frag_copy_Scale_a = thr_copy_scale.retile_D(fragment_scale_a);
230+
Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
233231

234-
Tensor frag_copy_A_b = thr_copy_A.retile_D(mma_A_b);
235232
Tensor frag_copy_B_b = thr_copy_B.retile_D(dequant_frag_b);
236-
Tensor frag_copy_Scale_b = thr_copy_scale.retile_D(fragment_scale_b);
237233

238234
Tensor tAgA = thr_copy_A.retile_S(tCgA);
239235
Tensor tBgB = thr_copy_B.retile_S(tCgB);
@@ -283,12 +279,12 @@ class gemm_4bit_cutlass_kernel {
283279

284280
#pragma unroll
285281
for (int v = 0; v < src_vec_size; v++) {
286-
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(k_tile % 2 != 0 ? cute::raw_pointer_cast(dequant_frag_a.data()): cute::raw_pointer_cast(dequant_frag_b.data()))[n*src_loop_num + l][v];
282+
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(k_tile % 2 != 0 ? dequant_frag_a.data() : dequant_frag_b.data()))[n*src_loop_num + l][v];
287283
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
288284
#pragma unroll
289285
for (int c = 0; c < src_compress_size; c++) {
290286
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
291-
float scale_value = k_tile % 2 != 0 ? fragment_scale_a((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE))) : fragment_scale_b((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
287+
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
292288
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
293289
lut_id = (lut_id + 1) % LUT_NUM;
294290
}
@@ -311,32 +307,30 @@ class gemm_4bit_cutlass_kernel {
311307
int start_lut_id = sg_idx % LUT_NUM;
312308

313309
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), frag_copy_B_a);
314-
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + 0) * BLK_K/params.group_size), frag_copy_Scale_a);
315-
copy(params.tiled_copy_a, tAgA(_,_,_,0), frag_copy_A_a);
310+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + 0) * BLK_K/params.group_size), frag_copy_Scale);
311+
copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), frag_copy_A);
316312

317313
for (int k_tile = k_start_idx + 1, k_s = 0 + 1; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
318314
if(k_tile % 2 != 0){
319315
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B_b);
320-
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale_b);
321-
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A_b);
322316
} else {
323-
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), frag_copy_B_a);
324-
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale_a);
325-
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A_a);
317+
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B_a);
326318
}
327319

328-
329320
dequant(start_lut_id, k_tile);
330321

331322
if (prefetch_k < k_tile_count) {
332323
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
333324
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
334325
}
335326

336-
k_tile % 2 != 0 ? cute::gemm(tiled_mma, mma_A_a, mma_B_a, accumulators) : cute::gemm(tiled_mma, mma_A_b, mma_B_b, accumulators);
327+
cute::gemm(tiled_mma, mma_A, k_tile % 2 != 0 ? mma_B_a : mma_B_b, accumulators);
337328
barrier_wait(3);
329+
330+
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
331+
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
338332
}
339-
cute::gemm(tiled_mma, mma_A_b, mma_B_b, accumulators);
333+
cute::gemm(tiled_mma, mma_A, mma_B_b, accumulators);
340334
barrier_wait(3);
341335

342336
static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // atom numbers per thread; A frags per sub_group

0 commit comments

Comments
 (0)