Skip to content

Commit c227104

Browse files
committed
save code
1 parent 143b91e commit c227104

1 file changed

Lines changed: 15 additions & 19 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 15 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -325,25 +325,21 @@ inline float dDequantizeNF4(unsigned char val) {
325325
#if 1 //SLM
326326
#if 1
327327
auto dequant = [&] (int k_tile) {
328-
constexpr int N = decltype(cute::size<1>(mma_B))::value;
329-
constexpr int K = decltype(cute::size(mma_B))::value / N;
330-
331-
332-
using src_compress_type = uint64_t;
333-
using dst_compress_type = uint64_t;
334-
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
335-
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //16
336-
constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
337-
constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
338-
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
339-
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
340-
341-
alignas(16) ElementB* src = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * (K * 4); //for K=64, 4 is hardcode for 128B alignment.
342-
const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k / 2 + k_tile * BLK_K / 2;
343-
//reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr)[0];
344-
345-
346-
ElementMMA* dst_slm = reinterpret_cast<ElementMMA*>(src + K); // reuse src SLM buffer, for K=64, 每个线程一段 连续 128 B,天然 128 B 对齐
328+
constexpr int N = decltype(cute::size<1>(mma_B))::value;
329+
constexpr int K = decltype(cute::size(mma_B))::value / N;
330+
331+
using src_compress_type = uint64_t;
332+
using dst_compress_type = uint64_t;
333+
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
334+
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //16
335+
constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
336+
constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
337+
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
338+
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
339+
340+
alignas(16) ElementB* src = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * (K * 4); //for K=64, 4 is hardcode for 128B alignment.
341+
const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k / 2 + k_tile * BLK_K / 2;
342+
ElementMMA* dst_slm = reinterpret_cast<ElementMMA*>(src + K); // reuse src SLM buffer, for K=64, 每个线程一段 连续 128 B,天然 128 B 对齐
347343

348344
#pragma unroll
349345
for (int n = 0; n < N; n++) {

0 commit comments

Comments
 (0)