@@ -229,7 +229,9 @@ class gemm_4bit_cutlass_kernel {
229229 Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_a , tCgA (_,_,_,0 ).shape ()));
230230 Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout (params.tiled_copy_b , tCgB (_,_,_,0 ).shape ()));
231231
232- Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout ());
232+ // Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
233+ using DequantLayout = Layout<Shape<_16, _1, _4>>;
234+ Tensor dequant_frag = make_tensor<ElementB>(DequantLayout{});
233235
234236 static constexpr auto scale_traits_size = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
235237 static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
@@ -268,81 +270,7 @@ class gemm_4bit_cutlass_kernel {
268270 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
269271 int prefetch_k = k_start_idx;
270272
271- #if 0
272- auto convert = [](uint8_t quant_idx, float scale) { // Maps a quantization index linearly onto [-1, 1] (for indices 0..15) and multiplies by `scale`.
273- const float range = 2.0f; // Assumes the quantization range is [-1, 1]. NOTE(review): `range` is not used by the return expression below.
274- return ((quant_idx / 7.5f) - 1.0f) * scale; // 7.5 = 15 / 2 (4-bit): index 0 -> -scale, index 15 -> +scale
275- };
276- #endif
277- #if 0
278- auto dequant = [&] {
279- constexpr int N = decltype(cute::size<1>(mma_B))::value;
280- constexpr int K = decltype(cute::size(mma_B))::value / N;
281-
282- using compress_type = uint32_t;
283- constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
284- constexpr int vec_size = K / compress_size;
285-
286- //if(cute::thread0()) printf("N = %d, K = %d, compress_size = %d, vec_size = %d\n", N, K, compress_size, vec_size);
287- compress_type src[vec_size];
288- reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
289-
290- float scale_value = fragment_scale(0);
291-
292- auto* dst = reinterpret_cast<sycl::vec<int64_t, 16>*>(&smem_buf[thread_idx * decltype(cute::size(mma_B))::value * 2]);
293-
294- #pragma unroll
295- for (int i = 0; i < vec_size; i++) {
296- //compress_type src = src_[i];//(*src_).get(i);
297-
298- #pragma unroll
299- for (int j = 0; j < compress_size/2; j++) {
300- uint8_t high = (src[i]>> (4 * (j * 2 + 1))) & 0xf;
301- uint8_t low = (src[i] >> (4 * (j * 2))) & 0xf;
302- dst[0][i*compress_size+j*2] = static_cast<ElementMMA>(quant_map[high] * scale_value);
303- dst[0][i*compress_size+j*2+1] = static_cast<ElementMMA>(quant_map[low] * scale_value);
304- }
305- }
306- reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[0];
307- #else
308- #if 0
309- auto dequant = [&] { // Variant: unpack 4-bit codes from dequant_frag, look each up in quant_map, scale it, stage the results in a per-thread slice of smem_buf, then bulk-copy them out.
310- constexpr int N = decltype(cute::size<1>(mma_B))::value; // extent of mode 1 of mma_B
311- constexpr int K = decltype(cute::size(mma_B))::value / N; // total element count of mma_B divided by N
312- using compress_type = uint32_t;
313- constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>; // number of ElementB values packed into one 32-bit word
314- constexpr int vec_size = K / compress_size; // number of 32-bit words needed for K packed values
315-
316- compress_type src[vec_size];
317- reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0]; // load all packed words from dequant_frag with a single vector copy
318-
319- const int tid = thread_idx;
320- constexpr int BANK_NUM = 32;
321- constexpr int ELEMS_PER_THREAD = vec_size * compress_size;
322- constexpr int ELEMS_PER_BANK = (ELEMS_PER_THREAD + BANK_NUM - 1) / BANK_NUM; // ceil(ELEMS_PER_THREAD / BANK_NUM)
323-
324- ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(smem_buf) + tid * BANK_NUM * ELEMS_PER_BANK; // each thread gets BANK_NUM * ELEMS_PER_BANK ElementMMA slots inside smem_buf
325- //auto* private_slm = reinterpret_cast<sycl::vec<int64_t, 16>*>(&smem_buf[thread_idx * BANK_NUM * ELEMS_PER_BANK * 2]);
326- //if(cute::thread0()) printf("ELEMS_PER_THREAD = %d, ELEMS_PER_BANK = %d\n", ELEMS_PER_THREAD, ELEMS_PER_BANK);
327- float scale_value = fragment_scale(0); // one scale (element 0 of fragment_scale) is applied to every value
328- #pragma unroll
329- for (int i = 0; i < vec_size; i++) {
330- #pragma unroll
331- for (int j = 0; j < compress_size; j++) {
332- uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf; // shift is 4, 0, 12, 8, ...: within each byte the high nibble is taken first, then the low nibble
333-
334- const int linear_idx = i * compress_size + j;
335- const int bank = linear_idx % BANK_NUM;
336- const int offset = linear_idx / BANK_NUM;
337- //if(cute::thread0()) printf("i = %d, j = %d, linear_idx = %d, bank = %d, offset = %d, bank * ELEMS_PER_BANK + offset = %d\n",i,j,linear_idx,bank,offset, bank * ELEMS_PER_BANK + offset);
338-
339- private_slm[bank * ELEMS_PER_BANK + offset] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value); // values are scattered by (bank, offset), not stored at linear_idx
340- }
341- }
342-
343- reinterpret_cast<sycl::vec<uint64_t, 16>*>(&mma_B)[0] = *reinterpret_cast<sycl::vec<uint64_t, 16>*>(private_slm); // copies 16 x uint64_t (128 bytes). NOTE(review): the destination is the address of the mma_B object itself, not cute::raw_pointer_cast(mma_B.data()) - confirm this was intended.
344- };
345- #endif
273+ #if 1
346274auto dequant = [&] {
347275 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
348276 constexpr int K = decltype (cute::size (mma_B))::value / N;
@@ -378,7 +306,55 @@ auto dequant = [&] {
378306
379307 *reinterpret_cast <sycl::vec<int64_t , 16 >*>(cute::raw_pointer_cast (mma_B.data ())) = *reinterpret_cast <const sycl::vec<int64_t , 16 >*>(private_slm);
380308};
309+ #else
310+ #if 1
311+ auto dequant = [&] { // Variant: unpack 4-bit codes from dequant_frag into a local array, dequantize via quant_map * scale, then bulk-copy into mma_B. NOTE(review): sits after the `#else` of an `#if 1`, so it appears to be compiled out - confirm.
312+ constexpr int N = decltype(cute::size<1>(mma_B))::value; // extent of mode 1 of mma_B
313+ constexpr int K = decltype(cute::size(mma_B))::value / N; // total element count of mma_B divided by N
314+
315+ using compress_type = uint32_t;
316+ constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>; // number of ElementB values packed into one 32-bit word
317+ constexpr int vec_size = K / compress_size; // number of 32-bit words needed for K packed values
318+
319+ //if(cute::thread0()) printf("N = %d, K = %d, compress_size = %d, vec_size = %d\n", N, K, compress_size, vec_size);
320+ compress_type src[vec_size]; // packed input words
321+ ElementMMA dst[K]; // dequantized output, one entry per unpacked value
322+
323+ float scale_value = fragment_scale(0); // one scale (element 0 of fragment_scale) is applied to every value
324+
325+ reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0]; // load all packed words from dequant_frag with a single vector copy
326+
327+ #pragma unroll
328+ for (int i = 0; i < vec_size; i++) {
329+ #pragma unroll
330+ for (int j = 0; j < compress_size; j++) {
331+ uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf; // shift is 4, 0, 12, 8, ...: within each byte the high nibble is taken first, then the low nibble
332+ dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value); // table lookup, scale, convert to the MMA element type
333+ //dst[i*compress_size+j] = static_cast<ElementMMA>(convert(bit_value, scale_value));
334+ }
335+ }
336+ reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[0]; // copies 16 x int64_t (128 bytes) regardless of K; presumably K * sizeof(ElementMMA) == 128 - TODO confirm
337+ };
338+ #else
339+ auto dequant = [&] { // Variant: read dequant_frag byte by byte and write two dequantized values per byte straight into mma_B. NOTE(review): sits in the `#else` of an `#if 1`, so it appears to be compiled out - confirm.
340+ constexpr int N = decltype(cute::size<1>(mma_B))::value; // extent of mode 1 of mma_B; only used to derive K
341+ constexpr int K = decltype(cute::size(mma_B))::value / N; // total element count of mma_B divided by N
342+ float scale_value = fragment_scale(0); // one scale (element 0 of fragment_scale) is applied to every value
343+
344+ //#pragma unroll
345+ //for(int i=0; i<K; i++) {
346+ // mma_B[i] = static_cast<ElementMMA>(quant_map[(reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i/2] >> (4 * ((i+1)%2))) & 0xf] * scale_value);
347+ //}
348+
349+ #pragma unroll
350+ for(int i=0; i<K/2; i++) { // one iteration per packed byte
351+ mma_B[i*2] = static_cast<ElementMMA>(quant_map[(reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i] >> 4) & 0xf] * scale_value); // high nibble of byte i -> even output index
352+ mma_B[i*2+1] = static_cast<ElementMMA>(quant_map[reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i] & 0xf] * scale_value); // low nibble of byte i -> odd output index
353+ }
354+ };
381355#endif
356+ #endif
357+
382358
383359 CUTLASS_PRAGMA_UNROLL
384360 for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
0 commit comments