@@ -371,30 +371,39 @@ inline float dDequantizeNF4(unsigned char val) {
371371 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
372372 constexpr int K = decltype (cute::size (mma_B))::value / N;
373373
374- using compress_type = uint64_t ;
375- constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
376- constexpr int vec_size = K / compress_size;
377-
374+ // using compress_type = uint64_t;
375+ // constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
376+ // constexpr int vec_size = K / compress_size;
377+
378+ using src_compress_type = uint64_t ;
379+ using dst_compress_type = uint64_t ;
380+ constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
381+ constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 16
382+ constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; // 4, 16 -> max vec_size of sycl::vec
383+ constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; // 16, 16 -> max vec_size of sycl::vec
384+ constexpr int src_loop_num = K / src_vec_size / src_compress_size;
385+ constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
386+
378387 // if(cute::thread0()) printf("N = %d, K = %d, src_compress_size = %d, src_vec_size = %d\n", N, K, src_compress_size, src_vec_size);
379- compress_type src[vec_size ];
388+ src_compress_type src[src_vec_size ];
380389 ElementMMA dst[K];
381390
382391 float scale_value = fragment_scale (0 );
383392
384- reinterpret_cast <sycl::vec<compress_type, vec_size >*>(src)[0 ] = reinterpret_cast <sycl::vec<compress_type, vec_size >*>(cute::raw_pointer_cast (dequant_frag.data ()))[0 ];
393+ reinterpret_cast <sycl::vec<src_compress_type, src_vec_size >*>(src)[0 ] = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size >*>(cute::raw_pointer_cast (dequant_frag.data ()))[0 ];
385394
386395 #pragma unroll
387- for (int i = 0 ; i < vec_size ; i++) {
388- compress_type src_value = src[i];
396+ for (int i = 0 ; i < src_vec_size ; i++) {
397+ src_compress_type src_value = src[i];
389398 #pragma unroll
390- for (int j = 0 ; j < compress_size ; j++) {
399+ for (int j = 0 ; j < src_compress_size ; j++) {
391400 unsigned char bit_value = (src_value >> (4 * ((j+1 )%2 + (j/2 )*2 ))) & 0xf ;
392401 // __builtin_assume(bit_value >= 0 && bit_value < 16);
393- dst[i*compress_size +j] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
402+ dst[i*src_compress_size +j] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
394403 // dst[i*src_compress_size+j] = static_cast<ElementMMA>(dDequantizeNF4(bit_value) * scale_value);
395404 }
396405 }
397- reinterpret_cast <sycl::vec<int64_t , 16 >*>(cute::raw_pointer_cast (mma_B.data ()))[0 ] = reinterpret_cast <sycl::vec<int64_t , 16 >*>(dst)[0 ];
406+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size >*>(cute::raw_pointer_cast (mma_B.data ()))[0 ] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size >*>(dst)[0 ];
398407 };
399408 #else //elemented load/store
400409 auto dequant = [&] {
0 commit comments