Skip to content

Commit 609d285

Browse files
committed
save code
1 parent 7f0e2c0 commit 609d285

3 files changed

Lines changed: 92 additions & 12 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class gemm_4bit_cutlass_kernel {
186186
? BlockIdxX() : BlockIdxY();
187187
const int l_coord = BlockIdxZ();
188188

189-
#if 1
189+
#if 0
190190
float* quant_map;
191191
{
192192
// Load the dequantize LUT and save it to SLM; 16 entries for 4 bits
@@ -274,6 +274,7 @@ class gemm_4bit_cutlass_kernel {
274274
return ((quant_idx / 7.5f) - 1.0f) * scale; // 7.5=15/2 (4-bit)
275275
};
276276
#endif
277+
#if 0
277278
auto dequant = [&] {
278279
constexpr int N = decltype(cute::size<1>(mma_B))::value;
279280
constexpr int K = decltype(cute::size(mma_B))::value / N;
@@ -284,23 +285,100 @@ class gemm_4bit_cutlass_kernel {
284285

285286
//if(cute::thread0()) printf("N = %d, K = %d, compress_size = %d, vec_size = %d\n", N, K, compress_size, vec_size);
286287
compress_type src[vec_size];
287-
ElementMMA dst[K];
288+
reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
288289

289290
float scale_value = fragment_scale(0);
290291

291-
reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
292+
auto* dst = reinterpret_cast<sycl::vec<int64_t, 16>*>(&smem_buf[thread_idx * decltype(cute::size(mma_B))::value * 2]);
292293

293294
#pragma unroll
294295
for (int i = 0; i < vec_size; i++) {
296+
//compress_type src = src_[i];//(*src_).get(i);
297+
295298
#pragma unroll
296-
for (int j = 0; j < compress_size; j++) {
297-
uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
298-
dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
299-
//dst[i*compress_size+j] = static_cast<ElementMMA>(convert(bit_value, scale_value));
299+
for (int j = 0; j < compress_size/2; j++) {
300+
uint8_t high = (src[i]>> (4 * (j * 2 + 1))) & 0xf;
301+
uint8_t low = (src[i] >> (4 * (j * 2))) & 0xf;
302+
dst[0][i*compress_size+j*2] = static_cast<ElementMMA>(quant_map[high] * scale_value);
303+
dst[0][i*compress_size+j*2+1] = static_cast<ElementMMA>(quant_map[low] * scale_value);
300304
}
301305
}
302306
reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[0];
303-
};
307+
#else
308+
#if 0
309+
// dequant (bank-interleaved variant; in this diff it sits between `#if 0` and `#endif`).
// Unpacks the packed words held in dequant_frag one nibble at a time, maps each
// nibble through quant_map, multiplies by fragment_scale(0), casts to ElementMMA,
// stages the results in this thread's slice of smem_buf, and finally copies that
// slice into mma_B with a single 16 x uint64_t vector assignment.
auto dequant = [&] {
310+
constexpr int N = decltype(cute::size<1>(mma_B))::value;
311+
constexpr int K = decltype(cute::size(mma_B))::value / N;
312+
using compress_type = uint32_t;
313+
// Number of ElementB values packed into one compress_type word.
constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
314+
// Number of packed words needed to cover K elements.
constexpr int vec_size = K / compress_size;
315+
316+
compress_type src[vec_size];
317+
// Copy all vec_size packed words out of dequant_frag's storage in one vector assignment.
reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
318+
319+
const int tid = thread_idx;
320+
constexpr int BANK_NUM = 32;
321+
constexpr int ELEMS_PER_THREAD = vec_size * compress_size;
322+
// Ceiling division of ELEMS_PER_THREAD by BANK_NUM.
constexpr int ELEMS_PER_BANK = (ELEMS_PER_THREAD + BANK_NUM - 1) / BANK_NUM;
323+
324+
// Start of this thread's slice: BANK_NUM * ELEMS_PER_BANK ElementMMA values per thread.
ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(smem_buf) + tid * BANK_NUM * ELEMS_PER_BANK;
325+
//auto* private_slm = reinterpret_cast<sycl::vec<int64_t, 16>*>(&smem_buf[thread_idx * BANK_NUM * ELEMS_PER_BANK * 2]);
326+
//if(cute::thread0()) printf("ELEMS_PER_THREAD = %d, ELEMS_PER_BANK = %d\n", ELEMS_PER_THREAD, ELEMS_PER_BANK);
327+
// A single scale (index 0 of fragment_scale) is applied to every element.
float scale_value = fragment_scale(0);
328+
#pragma unroll
329+
for (int i = 0; i < vec_size; i++) {
330+
#pragma unroll
331+
for (int j = 0; j < compress_size; j++) {
332+
// Nibble index is 1,0,3,2,... for j = 0,1,2,3,...: within each byte the
// high nibble is read before the low nibble.
uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
333+
334+
const int linear_idx = i * compress_size + j;
335+
const int bank = linear_idx % BANK_NUM;
336+
const int offset = linear_idx / BANK_NUM;
337+
//if(cute::thread0()) printf("i = %d, j = %d, linear_idx = %d, bank = %d, offset = %d, bank * ELEMS_PER_BANK + offset = %d\n",i,j,linear_idx,bank,offset, bank * ELEMS_PER_BANK + offset);
338+
339+
// NOTE(review): the store index is bank-major, so whenever ELEMS_PER_BANK > 1
// the staged order differs from linear_idx order; the copy below reads the
// slice linearly, so mma_B receives the permuted order -- confirm this is intended.
private_slm[bank * ELEMS_PER_BANK + offset] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
340+
}
341+
}
342+
343+
// Copy 16 x 64-bit words from the staged slice to the address of mma_B.
// NOTE(review): this targets &mma_B itself rather than
// cute::raw_pointer_cast(mma_B.data()) -- confirm the two are equivalent here.
reinterpret_cast<sycl::vec<uint64_t, 16>*>(&mma_B)[0] = *reinterpret_cast<sycl::vec<uint64_t, 16>*>(private_slm);
344+
};
345+
#endif
346+
// dequant (contiguous per-thread slice variant).
// Unpacks the packed words held in dequant_frag one nibble at a time, maps each
// nibble through quant_map, multiplies by fragment_scale(0), casts to ElementMMA,
// stages the results linearly in this thread's slice of smem_buf, and finally
// copies that slice into mma_B's storage with a single 16 x int64_t vector assignment.
auto dequant = [&] {
347+
constexpr int N = decltype(cute::size<1>(mma_B))::value;
348+
constexpr int K = decltype(cute::size(mma_B))::value / N;
349+
350+
using compress_type = uint32_t;
351+
// Number of ElementB values packed into one 32-bit word.
constexpr int compress_size = 32 / cute::sizeof_bits_v<ElementB>;
352+
// Number of packed words needed to cover K elements.
constexpr int vec_size = K / compress_size;
353+
354+
// BANK_NUM and ELEMS_PER_BANK are not referenced again in this lambda.
constexpr int BANK_NUM = 32; // number of Intel SLM banks
355+
constexpr int ELEMS_PER_THREAD = vec_size * compress_size; // 64
356+
constexpr int ELEMS_PER_BANK = (ELEMS_PER_THREAD + BANK_NUM - 1) / BANK_NUM; // 2
357+
358+
compress_type src[vec_size];
359+
// Copy all vec_size packed words out of dequant_frag's storage in one vector assignment.
*reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src) =
360+
*reinterpret_cast<const sycl::vec<compress_type, vec_size>*>(
361+
cute::raw_pointer_cast(dequant_frag.data()));
362+
363+
const int tid = thread_idx;
364+
ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(smem_buf) + tid * ELEMS_PER_THREAD; // each thread gets one **contiguous** 128 B segment, naturally 128 B aligned (author's note -- assumes 64 two-byte elements per thread; TODO confirm)
365+
366+
// A single scale (index 0 of fragment_scale) is applied to every element.
float scale_value = fragment_scale(0);
367+
368+
#pragma unroll
369+
for (int i = 0; i < vec_size; ++i) {
370+
#pragma unroll
371+
for (int j = 0; j < compress_size; ++j) {
372+
// Nibble index is 1,0,3,2,... for j = 0,1,2,3,...: within each byte the
// high nibble is read before the low nibble.
uint8_t bit_value = (src[i] >> (4 * (((j+1) & 1) + (j >> 1) * 2))) & 0xF;
373+
//uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
374+
// Results are stored in linear order: element j of word i lands at i * compress_size + j.
private_slm[i * compress_size + j] =
375+
static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
376+
}
377+
}
378+
379+
// Copy 16 x 64-bit words (128 bytes) from the staged slice into mma_B's storage.
*reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data())) = *reinterpret_cast<const sycl::vec<int64_t, 16>*>(private_slm);
380+
};
381+
#endif
304382

305383
CUTLASS_PRAGMA_UNROLL
306384
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
@@ -351,7 +429,9 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
351429

352430
using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS>;
353431

354-
static constexpr int smem_size= (16+1)*32/8;
432+
static constexpr int smem_size= BLK_N * BLK_K * 16/8; //(16+1)*32/8;
433+
size_t max_slm_size = q.get_device().get_info<sycl::info::device::local_mem_size>();
434+
assert(smem_size <= max_slm_size);
355435

356436
auto problem_size = ProblemShape{m, n, k, l};
357437

run_case.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
#gdb -args python -m pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
3131
#pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
32-
pytest -vs tests/test_xpu.py::TestXPU::test_gemv_4bit
33-
#python tests/test_xpu_db.py
32+
#pytest -vs tests/test_xpu.py::TestXPU::test_gemv_4bit
33+
python tests/test_xpu_db.py
3434
#gdb -args python tests/test_xpu_db.py
3535
#pytest tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=256-uint8-bf16-fc1-nf4-DQ_True-xpu]

tests/test_xpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
282282
for i in range(iters):
283283
#pdb.set_trace()
284284
if kind == "fc1":
285-
A = torch.randn(2, dim, dim, dtype=dtype, device=device)
285+
A = torch.randn(dim, dim, dtype=dtype, device=device)
286286
B = torch.randn(dim * 4, dim, dtype=dtype, device=device) / math.sqrt(dim)
287287
elif kind == "fc2":
288288
A = torch.randn(dim, 4 * dim, dtype=dtype, device=device)

0 commit comments

Comments
 (0)