Skip to content

Commit 62190ab

Browse files
committed
effective change
1 parent 0dec96e commit 62190ab

4 files changed

Lines changed: 59 additions & 37 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,8 @@ using ElementOutput = float;
5959
using ProblemShape = Shape<int, int, int, int>;
6060

6161
using TileShape = Shape<_16, _64, _64>;
62-
using TileShape_half = Shape<_16, _64, _32>;
6362
using TiledMma =
64-
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
63+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6564
Layout<Shape<_1, _2, _1>, Stride<_2, _1, _0>>>::TiledMMA;
6665

6766
using WorkgroupTileShape = TileShape;
@@ -237,9 +236,10 @@ class kgemm_4bit_inference_cutlass_dequant {
237236
uint8_t value = in[i].get();
238237
out[i] = static_cast<DstType>(quant_map[value]);
239238
int thread_idx = int(ThreadIdxX());
240-
//if(thread_idx == 0)
241-
if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
242-
printf("thread_idx = %d, i = %d, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",thread_idx, i, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
239+
if(cute::thread0()){
240+
//if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
241+
//printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
242+
}
243243
}
244244
#else
245245
static constexpr auto N = decltype(size<1>(in))::value;
@@ -419,7 +419,7 @@ if(cute::thread0())
419419
// make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1>{} * _1{})));
420420
//
421421
// }();
422-
#if 0
422+
#if 1
423423
#define PRINT(x) print(#x ": "); print(x); print("\n");
424424
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
425425
print("======================= A: \n");
@@ -437,6 +437,9 @@ if(cute::thread0())
437437
print(" frag_copy_B : "); print(frag_copy_B); print("\n");
438438
print(" dequant_frag : "); print(dequant_frag); print("\n");
439439

440+
print("===================== D :\n");
441+
print(" accumulators : "); print(accumulators); print("\n");
442+
440443
print("===================== Config: \n");
441444
print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n");
442445
print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n");
@@ -456,7 +459,7 @@ if(cute::thread0())
456459
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
457460
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
458461
}
459-
462+
//k_tile_count=1;
460463
for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) {
461464
barrier_arrive(2);
462465

@@ -477,39 +480,55 @@ if(cute::thread0())
477480

478481
dequant(dequant_frag, mma_B, /*fragment_scale,*/ quant_map);
479482

483+
//barrier_wait(1);
484+
480485
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
486+
barrier_wait(2);
487+
#if 0
488+
// Add print logic before and after the gemm call
489+
auto debug_print = [&](const char* name, auto& tensor) {
490+
int numbers = decltype(size(tensor))::value;
491+
printf("\n----- %s ----- numbers = %d\n", name, numbers);
492+
for (int i = 0; i < numbers; ++i) {
493+
printf("%s[%d] = %6.2f\n", name, i , static_cast<float>(tensor[i]));
494+
}
495+
printf("\n\n");
496+
barrier_wait(1);
497+
};
481498

482-
//// Add print logic before and after the gemm call
483-
//auto debug_print = [&](const char* name, auto& tensor) {
484-
// if (thread_idx == 0) {
485-
// printf("----- %s -----\n", name);
486-
// for (int i = 0; i < size<0>(tensor); ++i) {
487-
// for (int j = 0; j < size<1>(tensor); ++j) {
488-
// printf("%6.2f ", static_cast<float>(tensor(i, j)));
489-
// }
490-
// printf("\n");
491-
// }
492-
// }
493-
// barrier_wait(2);
494-
//};
495-
//
496-
//// Print the inputs
497-
//debug_print("Input A (mma_A)", mma_A);
498-
//debug_print("Input B (mma_B)", mma_B);
499-
//debug_print("Accumulators (Before GEMM)", accumulators);
500-
//
501-
//// Run the GEMM
502-
//cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
503-
//
504-
//// Print the output
505-
//debug_print("Accumulators (After GEMM)", accumulators);
499+
if (cute::thread0()) {
500+
// Print the inputs
501+
debug_print("Input A (mma_A)", mma_A);
502+
barrier_wait(1);
503+
debug_print("Input B (mma_B)", mma_B);
504+
barrier_wait(1);
505+
debug_print("Accumulators (Before GEMM)", accumulators);
506+
barrier_wait(1);
507+
}
508+
// Run the GEMM
509+
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
506510

507-
barrier_wait(2);
511+
if (cute::thread0()) {
512+
// Print the output
513+
debug_print("Accumulators (After GEMM)", accumulators);
514+
515+
barrier_wait(2);
516+
}
517+
#endif
518+
#if 0
519+
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
520+
barrier_wait(2);
521+
522+
for (int i = 0; i < accumulators.size(); ++i) {
523+
printf("Thread (%d, %d): accumulators[%d] =%f\n", syclcompat::global_id::x() , syclcompat::global_id::y(), i, static_cast<float>(accumulators[i]));
524+
}
525+
printf("\n");
526+
#endif
508527
}
509528

510529
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>((char*)nullptr);
511530
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
512-
auto problem_shape_MNKL = problem_size; //append<4>(problem_size, 1);
531+
auto problem_shape_MNKL = append<4>(problem_size, 1);
513532
epilogue(
514533
problem_shape_MNKL,
515534
subgroup_tile_shape,
@@ -573,7 +592,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
573592
//int k_half = k/2;
574593
//StrideB stride_B = make_stride(int64_t{1}, int64_t{n}, int64_t{n * k});
575594
StrideB stride_B = make_stride(int64_t{n}, cute::Int<1>{}, int64_t{0});
576-
auto mB_nkl = make_tensor(cute::subbyte_iterator<uint4_t>(B), make_layout(make_shape(n, k, l), stride_B));
595+
auto mB_nkl = make_tensor(cute::subbyte_iterator<ElementB>(B), make_layout(make_shape(n, k, l), stride_B));
577596
Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};
578597

579598
#define PRINT(x) print(#x ": "); print(x); print("\n");

include/cute/container/array_subbyte.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ struct subbyte_reference
8484

8585
friend struct subbyte_iterator<T>;
8686

87+
public:
8788
// Pointer to storage element
8889
storage_type* ptr_ = nullptr;
8990

@@ -262,6 +263,8 @@ struct subbyte_iterator
262263
k = sizeof_bits_v<value_type> * k + idx_;
263264
ptr_ += k / sizeof_bits_v<storage_type>;
264265
idx_ = k % sizeof_bits_v<storage_type>;
266+
if(idx_ == 4)
267+
//printf("k = %d, ptr_ = %x, idx_ = %d\n", k, ptr_, idx_);
265268
return *this;
266269
}
267270

run_case.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929

3030
#gdb -args python -m pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
31-
#pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
32-
python tests/test_xpu_db.py
31+
pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
32+
#python tests/test_xpu_db.py
3333
#gdb -args python tests/test_xpu_db.py
3434
#pytest tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=256-uint8-bf16-fc1-nf4-DQ_True-xpu]

tests/test_xpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
8989
##pdb.set_trace()
9090
C3 = torch.matmul(A, B.t())
9191
#pdb.set_trace()
92-
C2 = F.gemv_4bit(A, qB.t(), state=state)
92+
C2 = F.gemv_4bit(A, qB, state=state)
9393
#pdb.set_trace()
9494
print("C3.sum() = ", C3.sum())
9595
print("C2.sum() = ", C2.sum())

0 commit comments

Comments
 (0)