@@ -323,45 +323,42 @@ inline float dDequantizeNF4(unsigned char val) {
323323 int prefetch_k = k_start_idx;
324324
325325#if 1 // SLM
326- // alignas(16) ElementB* slm_B = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * (64 * 4) * k_tile_count;
327- // const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * 1) * params.k/2;
328- // //using total_vec = 4*k_tile_count;
329- // reinterpret_cast<sycl::vec<uint64_t, 16>*>(slm_B)[0] = reinterpret_cast<const sycl::vec<uint64_t, 16>*>(gB_ptr)[0];
330326 #if 1
331327 auto dequant = [&] (int k_tile) {
332328 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
333329 constexpr int K = decltype (cute::size (mma_B))::value / N;
334-
335- using compress_type = uint32_t ;
336- constexpr int compress_size = 32 / cute::sizeof_bits_v<ElementB>;
337- constexpr int vec_size = K / compress_size;
338-
330+
331+
332+ using src_compress_type = uint64_t ;
333+ using dst_compress_type = uint64_t ;
334+ constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
335+ constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 16
336+ constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; // 4, 16 -> max vec_size of sycl::vec
337+ constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; // 16, 16 -> max vec_size of sycl::vec
338+ constexpr int src_loop_num = K / src_vec_size / src_compress_size;
339+ constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
340+
339341 alignas (16 ) ElementB* src = reinterpret_cast <ElementB*>(smem_buf) + thread_idx * (K * 4 ); // for K=64, 4 is hardcode for 128B alignment.
340342 const uint8_t * gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k /2 + k_tile * BLK_K /2 ;
341- reinterpret_cast <sycl::vec<uint64_t , 4 >*>(src)[0 ] = reinterpret_cast <const sycl::vec<uint64_t , 4 >*>(gB_ptr )[0 ];
343+ reinterpret_cast <sycl::vec<src_compress_type, src_vec_size >*>(src)[0 ] = reinterpret_cast <const sycl::vec<src_compress_type, src_vec_size >*>(gB_ptr )[0 ];
342344
343- // compress_type src[vec_size];
344- // reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<compress_type, vec_size>*>(slm_B)[0];
345345
346346 ElementMMA* private_slm = reinterpret_cast <ElementMMA*>(src + K); // reuse src SLM buffer, for K=64, 每个线程一段 连续 128 B,天然 128 B 对齐
347- // ElementMMA dst[K];
348347
349348 float scale_value = fragment_scale (0 );
350349
351350 #pragma unroll
352- for (int i = 0 ; i < vec_size ; ++i) {
353- uint32_t src_value = reinterpret_cast <uint32_t *>(src)[i];
351+ for (int i = 0 ; i < src_vec_size ; ++i) {
352+ src_compress_type src_value = reinterpret_cast <src_compress_type *>(src)[i];
354353 #pragma unroll
355- for (int j = 0 ; j < compress_size ; ++j) {
354+ for (int j = 0 ; j < src_compress_size ; ++j) {
356355 uint8_t bit_value = (src_value >> (4 * (((j+1 ) & 1 ) + (j >> 1 ) * 2 ))) & 0xF ;
357- private_slm[i * compress_size + j] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
358- // dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
356+ private_slm[i * src_compress_size + j] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
359357 }
360358 }
361359
362360 for (int i=0 ; i<K/4 /16 ; i++){
363- reinterpret_cast <sycl::vec<int64_t , 16 >*>(cute::raw_pointer_cast (mma_B.data ()))[i] = reinterpret_cast <const sycl::vec<int64_t , 16 >*>(private_slm)[i];
364- // reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[i] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[i];
361+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[i] = reinterpret_cast <const sycl::vec<dst_compress_type, dst_vec_size>*>(private_slm)[i];
365362 }
366363 };
367364 #endif
0 commit comments