Skip to content

Commit 9264ffa

Browse files
committed
save code
1 parent 8909af8 commit 9264ffa

1 file changed

Lines changed: 20 additions & 11 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -371,30 +371,39 @@ inline float dDequantizeNF4(unsigned char val) {
371371
constexpr int N = decltype(cute::size<1>(mma_B))::value;
372372
constexpr int K = decltype(cute::size(mma_B))::value / N;
373373

374-
using compress_type = uint64_t;
375-
constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
376-
constexpr int vec_size = K / compress_size;
377-
374+
// using compress_type = uint64_t;
375+
// constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
376+
// constexpr int vec_size = K / compress_size;
377+
378+
using src_compress_type = uint64_t;
379+
using dst_compress_type = uint64_t;
380+
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
381+
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //16
382+
constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
383+
constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
384+
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
385+
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
386+
378387
//if(cute::thread0()) printf("N = %d, K = %d, compress_size = %d, vec_size = %d\n", N, K, compress_size, vec_size);
379-
compress_type src[vec_size];
388+
src_compress_type src[src_vec_size];
380389
ElementMMA dst[K];
381390

382391
float scale_value = fragment_scale(0);
383392

384-
reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
393+
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
385394

386395
#pragma unroll
387-
for (int i = 0; i < vec_size; i++) {
388-
compress_type src_value = src[i];
396+
for (int i = 0; i < src_vec_size; i++) {
397+
src_compress_type src_value = src[i];
389398
#pragma unroll
390-
for (int j = 0; j < compress_size; j++) {
399+
for (int j = 0; j < src_compress_size; j++) {
391400
unsigned char bit_value = (src_value >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
392401
// __builtin_assume(bit_value >= 0 && bit_value < 16);
393-
dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
402+
dst[i*src_compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
394403
//dst[i*compress_size+j] = static_cast<ElementMMA>(dDequantizeNF4(bit_value) * scale_value);
395404
}
396405
}
397-
reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[0];
406+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0];
398407
};
399408
#else //elemented load/store
400409
auto dequant = [&] {

0 commit comments

Comments
 (0)