@@ -325,25 +325,21 @@ inline float dDequantizeNF4(unsigned char val) {
325325#if 1 // SLM
326326 #if 1
327327 auto dequant = [&] (int k_tile) {
328- constexpr int N = decltype (cute::size<1 >(mma_B))::value;
329- constexpr int K = decltype (cute::size (mma_B))::value / N;
330-
331-
332- using src_compress_type = uint64_t ;
333- using dst_compress_type = uint64_t ;
334- constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
335- constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 16
336- constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; // 4, 16 -> max vec_size of sycl::vec
337- constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; // 16, 16 -> max vec_size of sycl::vec
338- constexpr int src_loop_num = K / src_vec_size / src_compress_size;
339- constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
340-
341- alignas (16 ) ElementB* src = reinterpret_cast <ElementB*>(smem_buf) + thread_idx * (K * 4 ); // for K=64, 4 is hardcode for 128B alignment.
342- const uint8_t * gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k / 2 + k_tile * BLK_K / 2 ;
343- // reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr)[0];
344-
345-
346- ElementMMA* dst_slm = reinterpret_cast <ElementMMA*>(src + K); // reuse src SLM buffer, for K=64, 每个线程一段 连续 128 B,天然 128 B 对齐
328+ constexpr int N = decltype (cute::size<1 >(mma_B))::value;
329+ constexpr int K = decltype (cute::size (mma_B))::value / N;
330+
331+ using src_compress_type = uint64_t ;
332+ using dst_compress_type = uint64_t ;
333+ constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
334+ constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 16
335+ constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; // 4, 16 -> max vec_size of sycl::vec
336+ constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; // 16, 16 -> max vec_size of sycl::vec
337+ constexpr int src_loop_num = K / src_vec_size / src_compress_size;
338+ constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
339+
340+ alignas (16 ) ElementB* src = reinterpret_cast <ElementB*>(smem_buf) + thread_idx * (K * 4 ); // for K=64, 4 is hardcode for 128B alignment.
341+ const uint8_t * gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k / 2 + k_tile * BLK_K / 2 ;
342+ ElementMMA* dst_slm = reinterpret_cast <ElementMMA*>(src + K); // reuse src SLM buffer, for K=64, 每个线程一段 连续 128 B,天然 128 B 对齐
347343
348344 #pragma unroll
349345 for (int n = 0 ; n < N; n++) {
0 commit comments