Skip to content

Commit 9dc75fc

Browse files
committed
save code
1 parent dab7994 commit 9dc75fc

1 file changed

Lines changed: 17 additions & 20 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -323,45 +323,42 @@ inline float dDequantizeNF4(unsigned char val) {
323323
int prefetch_k = k_start_idx;
324324

325325
#if 1 //SLM
326-
//alignas(16) ElementB* slm_B = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * (64 * 4) * k_tile_count;
327-
//const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * 1) * params.k/2;
328-
////using total_vec = 4*k_tile_count;
329-
//reinterpret_cast<sycl::vec<uint64_t, 16>*>(slm_B)[0] = reinterpret_cast<const sycl::vec<uint64_t, 16>*>(gB_ptr)[0];
330326
#if 1
331327
auto dequant = [&] (int k_tile) {
332328
constexpr int N = decltype(cute::size<1>(mma_B))::value;
333329
constexpr int K = decltype(cute::size(mma_B))::value / N;
334-
335-
using compress_type = uint32_t;
336-
constexpr int compress_size = 32 / cute::sizeof_bits_v<ElementB>;
337-
constexpr int vec_size = K / compress_size;
338-
330+
331+
332+
using src_compress_type = uint64_t;
333+
using dst_compress_type = uint64_t;
334+
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
335+
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //16
336+
constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
337+
constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
338+
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
339+
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
340+
339341
alignas(16) ElementB* src = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * (K * 4); //for K=64, 4 is hardcode for 128B alignment.
340342
const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k/2 + k_tile * BLK_K/2;
341-
reinterpret_cast<sycl::vec<uint64_t, 4>*>(src)[0] = reinterpret_cast<const sycl::vec<uint64_t, 4>*>(gB_ptr)[0];
343+
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr)[0];
342344

343-
//compress_type src[vec_size];
344-
//reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<compress_type, vec_size>*>(slm_B)[0];
345345

346346
ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(src + K); // reuse src SLM buffer, for K=64, 每个线程一段 连续 128 B,天然 128 B 对齐
347-
//ElementMMA dst[K];
348347

349348
float scale_value = fragment_scale(0);
350349

351350
#pragma unroll
352-
for (int i = 0; i < vec_size; ++i) {
353-
uint32_t src_value = reinterpret_cast<uint32_t*>(src)[i];
351+
for (int i = 0; i < src_vec_size; ++i) {
352+
src_compress_type src_value = reinterpret_cast<src_compress_type*>(src)[i];
354353
#pragma unroll
355-
for (int j = 0; j < compress_size; ++j) {
354+
for (int j = 0; j < src_compress_size; ++j) {
356355
uint8_t bit_value = (src_value >> (4 * (((j+1) & 1) + (j >> 1) * 2))) & 0xF;
357-
private_slm[i * compress_size + j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
358-
//dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
356+
private_slm[i * src_compress_size + j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
359357
}
360358
}
361359

362360
for(int i=0; i<K/4/16; i++){
363-
reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[i] = reinterpret_cast<const sycl::vec<int64_t, 16>*>(private_slm)[i];
364-
//reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[i] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[i];
361+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[i] = reinterpret_cast<const sycl::vec<dst_compress_type, dst_vec_size>*>(private_slm)[i];
365362
}
366363
};
367364
#endif

0 commit comments

Comments
 (0)