Skip to content

Commit 7f0e2c0

Browse files
committed
save code
1 parent feb23d2 commit 7f0e2c0

1 file changed

Lines changed: 31 additions & 18 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using namespace cutlass::gemm;
3636

3737
// Define Basic information
3838
//Weight-only-quant (B)
39-
using MmaType = cutlass::bfloat16_t;
39+
using MmaType = sycl::ext::oneapi::bfloat16; //cutlass::bfloat16_t;
4040
using QuantType = cutlass::uint4_t; //NF4,FP4
4141

4242
using ElementA = MmaType;
@@ -186,6 +186,7 @@ class gemm_4bit_cutlass_kernel {
186186
? BlockIdxX() : BlockIdxY();
187187
const int l_coord = BlockIdxZ();
188188

189+
#if 1
189190
float* quant_map;
190191
{
191192
// Load Dequatize LUT and save to SLM, 16 for 4bits
@@ -195,7 +196,14 @@ class gemm_4bit_cutlass_kernel {
195196
}
196197
barrier_arrive(3);
197198
}
198-
199+
#else
200+
constexpr float quant_map[16] = {
201+
-1.0f, -0.6961928f, -0.52507305f, -0.39491749f,
202+
-0.28444138f, -0.18477343f, -0.09105004f, 0.0f,
203+
0.0795803f, 0.1609302f, 0.2461123f, 0.33791524f,
204+
0.44070983f, 0.562617f, 0.72295684f, 1.0f
205+
};
206+
#endif
199207
Tensor mA_mkl = cute::get_pvc_tensor(make_shape(params.m, params.k, params.l));
200208
Tensor mB_nkl = cute::get_pvc_tensor(make_shape(params.n, params.k,1));
201209

@@ -260,33 +268,38 @@ class gemm_4bit_cutlass_kernel {
260268
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
261269
int prefetch_k = k_start_idx;
262270

271+
#if 0
272+
auto convert = [](uint8_t quant_idx, float scale) {
273+
const float range = 2.0f; // 假设量化范围[-1,1]
274+
return ((quant_idx / 7.5f) - 1.0f) * scale; // 7.5=15/2 (4-bit)
275+
};
276+
#endif
263277
auto dequant = [&] {
264278
constexpr int N = decltype(cute::size<1>(mma_B))::value;
265279
constexpr int K = decltype(cute::size(mma_B))::value / N;
266-
//if(cute::thread0()) printf("K = %d, N = %d\n", K, N);
267280

268281
using compress_type = uint32_t;
269282
constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
270-
constexpr auto vec_size = K / compress_size;
283+
constexpr int vec_size = K / compress_size;
271284

272-
using VecSrcType = cute::array<compress_type, vec_size>;
273-
using VecDstElemType = cute::array<ElementMMA, compress_size>;
274-
using VecDstType = cute::array<VecDstElemType, vec_size>;
285+
//if(cute::thread0()) printf("N = %d, K = %d, compress_size = %d, vec_size = %d\n", N, K, compress_size, vec_size);
286+
compress_type src[vec_size];
287+
ElementMMA dst[K];
275288

276289
float scale_value = fragment_scale(0);
277-
auto src = *(VecSrcType*)(cute::raw_pointer_cast(dequant_frag.data()));
278-
auto& dst = *(VecDstType*)(cute::raw_pointer_cast(mma_B.data()));
279-
VecDstType dst_val;
280-
#pragma unroll
281-
for (int i = 0; i < vec_size; i++) {
282-
VecDstElemType dst_elem;
290+
291+
reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
292+
293+
#pragma unroll
294+
for (int i = 0; i < vec_size; i++) {
283295
#pragma unroll
284296
for (int j = 0; j < compress_size; j++) {
285-
dst_elem[j] = static_cast<ElementMMA>(quant_map[(src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf] * scale_value);
297+
uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
298+
dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
299+
//dst[i*compress_size+j] = static_cast<ElementMMA>(convert(bit_value, scale_value));
286300
}
287-
dst_val[i] = dst_elem;
288-
}
289-
dst = dst_val;
301+
}
302+
reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[0];
290303
};
291304

292305
CUTLASS_PRAGMA_UNROLL
@@ -338,7 +351,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
338351

339352
using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS>;
340353

341-
static constexpr int smem_size= 16*32/8;
354+
static constexpr int smem_size= (16+1)*32/8;
342355

343356
auto problem_size = ProblemShape{m, n, k, l};
344357

0 commit comments

Comments
 (0)