Skip to content

Commit a60ffa3

Browse files
committed
save code
1 parent 94e44e3 commit a60ffa3

1 file changed

Lines changed: 66 additions & 9 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ class gemm_4bit_cutlass_kernel {
224224
auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
225225

226226
Tensor tCgA = thr_mma.partition_A(gA);
227-
Tensor tCgB = thr_mma.partition_B(gB);
227+
Tensor tCgB = thr_mma.partition_B(gB); //values for each_thread (FrgV,(RestN,RestK),*)
228228

229229
Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
230230
Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
@@ -366,7 +366,58 @@ auto dequant = [&] {
366366
};
367367
#endif
368368
#endif
369-
369+
#if 0
370+
if (cute::thread0()){ //thread_idx==0 && n_coord == 0 && l_coord==0) {
371+
print("\n\n======================= A: \n");
372+
print(" gA : "); print(gA); print("\n");
373+
print(" tCgA : "); print(tCgA); print("\n");
374+
print(" tAgA : "); print(tAgA); print("\n");
375+
print(" mma_A : "); print(mma_A); print("\n");
376+
print(" frag_copy_A : "); print(frag_copy_A); print("\n");
377+
378+
print("===================== B :\n");
379+
print(" gB : "); print(gB); print("\n");
380+
print(" tCgB : "); print(tCgB); print("\n");
381+
print(" tBgB : "); print(tBgB); print("\n");
382+
print(" mma_B : "); print(mma_B); print("\n");
383+
//print(" frag_copy_B : "); print(frag_copy_B); print("\n");
384+
//print(" dequant_frag : "); print(dequant_frag); print("\n");
385+
386+
print("===================== Scale :\n");
387+
print(" tiled_copy_scale : "); print(params.tiled_copy_scale); print("\n");
388+
print(" fragment_scale : "); print(fragment_scale); print("\n");
389+
print(" frag_copy_Scale : "); print(frag_copy_Scale); print("\n");
390+
print(" tSgS : "); print(tSgS); print("\n");
391+
392+
print("===================== D :\n");
393+
print(" accumulators : "); print(accumulators); print("\n");
394+
395+
print("===================== Config: \n");
396+
print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n");
397+
print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n");
398+
399+
print("===================== Config: \n");
400+
print(" tiled_mma : "); print(tiled_mma); print("\n");
401+
402+
print("===================== Config: \n");
403+
print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n");
404+
405+
print("===================== Config: \n");
406+
print(" thr_mma : "); print(thr_mma); print("\n");
407+
408+
print("===================== Config: \n");
409+
print(" tiled_prefetch_a : "); print(tiled_prefetch_a); print("\n");
410+
411+
print("===================== Config: \n");
412+
print(" tiled_prefetch_b : "); print(tiled_prefetch_b); print("\n");
413+
414+
print("===================== Config: \n");
415+
print(" pAgA : "); print(pAgA); print("\n");
416+
417+
print("===================== Config: \n");
418+
print(" pBgB : "); print(pBgB); print("\n\n\n");
419+
}
420+
#endif
370421

371422
CUTLASS_PRAGMA_UNROLL
372423
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
@@ -378,10 +429,14 @@ auto dequant = [&] {
378429
//copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
379430
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) / k_reload_factor), frag_copy_Scale);
380431
//barrier_wait(3);
381-
const uint8_t* gB_ptr = params.B + cute::get<0>(gB.layout()(make_coord(n_coord, k_tile, 0)));
382432
auto dequant = [&] {
383433
constexpr int N = decltype(cute::size<1>(mma_B))::value;
384434
constexpr int K = decltype(cute::size(mma_B))::value / N;
435+
const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k/2 + k_tile * BLK_K/2;
436+
//if(thread_idx==8 && int(BlockIdxX())==0 && int(BlockIdxY())==0 && int(BlockIdxZ())==0){
437+
// printf("BLK_N = %d, BLK_K = %d, thread_idx = %d, N = %d, params.k = %d, params.B = %x, n_coord = %d, k_tile = %d, gB_ptr = %x\n",static_cast<int>(BLK_N), static_cast<int>(BLK_K), thread_idx, N, params.k, params.B, n_coord, k_tile, gB_ptr);
438+
// print(" gB_ptr: "); print(gB_ptr); print("\n");
439+
//}
385440

386441
using compress_type = uint32_t;
387442
constexpr int compress_size = 32 / cute::sizeof_bits_v<ElementB>;
@@ -392,13 +447,14 @@ auto dequant = [&] {
392447
constexpr int ELEMS_PER_BANK = (ELEMS_PER_THREAD + BANK_NUM - 1) / BANK_NUM; // 2
393448

394449
//const ElementB* gB_ptr = params.B + gB(n_coord, idx2crd(k_tile, make_shape(params.k)), 0).offset();
395-
ElementB* slm_B = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * 512 ;
450+
alignas(16) ElementB* slm_B = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * 64 * 4; //512 ;
396451
*reinterpret_cast<sycl::vec<uint64_t, 4>*>(slm_B) = *reinterpret_cast<const sycl::vec<uint64_t, 4>*>(gB_ptr);
397452

398453
compress_type src[vec_size];
399454
*reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src) = *reinterpret_cast<const sycl::vec<compress_type, vec_size>*>(slm_B);
400455

401-
ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(slm_B) + thread_idx * ELEMS_PER_THREAD; // 每个线程一段 **连续** 128 B,天然 128 B 对齐
456+
//ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(slm_B) + thread_idx * ELEMS_PER_THREAD; // each thread gets one **contiguous** 128 B segment, naturally 128 B aligned
457+
ElementMMA dst[K];
402458

403459
float scale_value = fragment_scale(0);
404460

@@ -408,12 +464,13 @@ auto dequant = [&] {
408464
for (int j = 0; j < compress_size; ++j) {
409465
uint8_t bit_value = (src[i] >> (4 * (((j+1) & 1) + (j >> 1) * 2))) & 0xF;
410466
//uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
411-
private_slm[i * compress_size + j] =
412-
static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
467+
//private_slm[i * compress_size + j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
468+
dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
413469
}
414470
}
415471

416-
*reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data())) = *reinterpret_cast<const sycl::vec<int64_t, 16>*>(private_slm);
472+
//*reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data())) = *reinterpret_cast<const sycl::vec<int64_t, 16>*>(private_slm);
473+
reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[0];
417474
};
418475
dequant();
419476
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
@@ -455,7 +512,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
455512
using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS>;
456513

457514
//static constexpr int smem_size= BLK_N * BLK_K * 16/8; //(16+1)*32/8;
458-
static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementB) + BLK_N * BLK_K * sizeof(ElementMMA);
515+
static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementB) * 4;// + BLK_N * BLK_K * sizeof(ElementMMA);
459516
size_t max_slm_size = q.get_device().get_info<sycl::info::device::local_mem_size>();
460517
assert(smem_size <= max_slm_size);
461518

0 commit comments

Comments
 (0)