Skip to content

Commit 7b7abec

Browse files
committed
add new method
1 parent 609d285 commit 7b7abec

1 file changed

Lines changed: 52 additions & 76 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 52 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ class gemm_4bit_cutlass_kernel {
229229
Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_a, tCgA(_,_,_,0).shape()));
230230
Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout(params.tiled_copy_b, tCgB(_,_,_,0).shape()));
231231

232-
Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
232+
//Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
233+
using DequantLayout = Layout<Shape<_16, _1, _4>>;
234+
Tensor dequant_frag = make_tensor<ElementB>(DequantLayout{});
233235

234236
static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
235237
static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
@@ -268,81 +270,7 @@ class gemm_4bit_cutlass_kernel {
268270
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
269271
int prefetch_k = k_start_idx;
270272

271-
#if 0
272-
auto convert = [](uint8_t quant_idx, float scale) {
273-
const float range = 2.0f; // 假设量化范围[-1,1]
274-
return ((quant_idx / 7.5f) - 1.0f) * scale; // 7.5=15/2 (4-bit)
275-
};
276-
#endif
277-
#if 0
278-
auto dequant = [&] {
279-
constexpr int N = decltype(cute::size<1>(mma_B))::value;
280-
constexpr int K = decltype(cute::size(mma_B))::value / N;
281-
282-
using compress_type = uint32_t;
283-
constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
284-
constexpr int vec_size = K / compress_size;
285-
286-
//if(cute::thread0()) printf("N = %d, K = %d, compress_size = %d, vec_size = %d\n", N, K, compress_size, vec_size);
287-
compress_type src[vec_size];
288-
reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
289-
290-
float scale_value = fragment_scale(0);
291-
292-
auto* dst = reinterpret_cast<sycl::vec<int64_t, 16>*>(&smem_buf[thread_idx * decltype(cute::size(mma_B))::value * 2]);
293-
294-
#pragma unroll
295-
for (int i = 0; i < vec_size; i++) {
296-
//compress_type src = src_[i];//(*src_).get(i);
297-
298-
#pragma unroll
299-
for (int j = 0; j < compress_size/2; j++) {
300-
uint8_t high = (src[i]>> (4 * (j * 2 + 1))) & 0xf;
301-
uint8_t low = (src[i] >> (4 * (j * 2))) & 0xf;
302-
dst[0][i*compress_size+j*2] = static_cast<ElementMMA>(quant_map[high] * scale_value);
303-
dst[0][i*compress_size+j*2+1] = static_cast<ElementMMA>(quant_map[low] * scale_value);
304-
}
305-
}
306-
reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[0];
307-
#else
308-
#if 0
309-
auto dequant = [&] {
310-
constexpr int N = decltype(cute::size<1>(mma_B))::value;
311-
constexpr int K = decltype(cute::size(mma_B))::value / N;
312-
using compress_type = uint32_t;
313-
constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
314-
constexpr int vec_size = K / compress_size;
315-
316-
compress_type src[vec_size];
317-
reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
318-
319-
const int tid = thread_idx;
320-
constexpr int BANK_NUM = 32;
321-
constexpr int ELEMS_PER_THREAD = vec_size * compress_size;
322-
constexpr int ELEMS_PER_BANK = (ELEMS_PER_THREAD + BANK_NUM - 1) / BANK_NUM;
323-
324-
ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(smem_buf) + tid * BANK_NUM * ELEMS_PER_BANK;
325-
//auto* private_slm = reinterpret_cast<sycl::vec<int64_t, 16>*>(&smem_buf[thread_idx * BANK_NUM * ELEMS_PER_BANK * 2]);
326-
//if(cute::thread0()) printf("ELEMS_PER_THREAD = %d, ELEMS_PER_BANK = %d\n", ELEMS_PER_THREAD, ELEMS_PER_BANK);
327-
float scale_value = fragment_scale(0);
328-
#pragma unroll
329-
for (int i = 0; i < vec_size; i++) {
330-
#pragma unroll
331-
for (int j = 0; j < compress_size; j++) {
332-
uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
333-
334-
const int linear_idx = i * compress_size + j;
335-
const int bank = linear_idx % BANK_NUM;
336-
const int offset = linear_idx / BANK_NUM;
337-
//if(cute::thread0()) printf("i = %d, j = %d, linear_idx = %d, bank = %d, offset = %d, bank * ELEMS_PER_BANK + offset = %d\n",i,j,linear_idx,bank,offset, bank * ELEMS_PER_BANK + offset);
338-
339-
private_slm[bank * ELEMS_PER_BANK + offset] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
340-
}
341-
}
342-
343-
reinterpret_cast<sycl::vec<uint64_t, 16>*>(&mma_B)[0] = *reinterpret_cast<sycl::vec<uint64_t, 16>*>(private_slm);
344-
};
345-
#endif
273+
#if 1
346274
auto dequant = [&] {
347275
constexpr int N = decltype(cute::size<1>(mma_B))::value;
348276
constexpr int K = decltype(cute::size(mma_B))::value / N;
@@ -378,7 +306,55 @@ auto dequant = [&] {
378306

379307
*reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data())) = *reinterpret_cast<const sycl::vec<int64_t, 16>*>(private_slm);
380308
};
309+
#else
310+
#if 1
311+
auto dequant = [&] {
312+
constexpr int N = decltype(cute::size<1>(mma_B))::value;
313+
constexpr int K = decltype(cute::size(mma_B))::value / N;
314+
315+
using compress_type = uint32_t;
316+
constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
317+
constexpr int vec_size = K / compress_size;
318+
319+
//if(cute::thread0()) printf("N = %d, K = %d, compress_size = %d, vec_size = %d\n", N, K, compress_size, vec_size);
320+
compress_type src[vec_size];
321+
ElementMMA dst[K];
322+
323+
float scale_value = fragment_scale(0);
324+
325+
reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
326+
327+
#pragma unroll
328+
for (int i = 0; i < vec_size; i++) {
329+
#pragma unroll
330+
for (int j = 0; j < compress_size; j++) {
331+
uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
332+
dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
333+
//dst[i*compress_size+j] = static_cast<ElementMMA>(convert(bit_value, scale_value));
334+
}
335+
}
336+
reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[0];
337+
};
338+
#else
339+
auto dequant = [&] {
340+
constexpr int N = decltype(cute::size<1>(mma_B))::value;
341+
constexpr int K = decltype(cute::size(mma_B))::value / N;
342+
float scale_value = fragment_scale(0);
343+
344+
//#pragma unroll
345+
//for(int i=0; i<K; i++) {
346+
// mma_B[i] = static_cast<ElementMMA>(quant_map[(reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i/2] >> (4 * ((i+1)%2))) & 0xf] * scale_value);
347+
//}
348+
349+
#pragma unroll
350+
for(int i=0; i<K/2; i++) {
351+
mma_B[i*2] = static_cast<ElementMMA>(quant_map[(reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i] >> 4) & 0xf] * scale_value);
352+
mma_B[i*2+1] = static_cast<ElementMMA>(quant_map[reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i] & 0xf] * scale_value);
353+
}
354+
};
381355
#endif
356+
#endif
357+
382358

383359
CUTLASS_PRAGMA_UNROLL
384360
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {

0 commit comments

Comments
 (0)