@@ -383,8 +383,8 @@ inline float dDequantizeNF4(unsigned char val) {
383383
384384 // if(cute::thread0()) printf("params.group_size = %d, k_reload_factor = %d, k_tile_count = %d, N = %d, K = %d, src_compress_size = %d, src_vec_size = %d, dst_compress_size = %d, dst_vec_size = %d\n",params.group_size, k_reload_factor, k_tile_count, N, K, src_compress_size, src_vec_size, dst_compress_size, dst_vec_size);
385385
386- src_compress_type src[N*src_loop_num* src_vec_size];
387- ElementMMA dst[N*K ];
386+ src_compress_type src[src_vec_size];
387+ ElementMMA dst[dst_compress_size * dst_vec_size ];
388388
389389 #pragma unroll
390390 for (int n = 0 ; n < N; n++) {
@@ -393,11 +393,11 @@ inline float dDequantizeNF4(unsigned char val) {
393393 for (int l = 0 ; l < src_loop_num; l++) {
394394 // src_compress_type src[src_vec_size];
395395 // ElementMMA dst[K/dst_loop_num];
396- reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[n*src_loop_num + l ] = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[n*src_loop_num + l];
396+ reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ] = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[n*src_loop_num + l];
397397 #pragma unroll
398398 for (int v = 0 ; v < src_vec_size; v++) {
399- src_compress_type src_value = src[(n*src_loop_num + l)*src_vec_size + v];
400- int dst_idx = ((n*src_loop_num + l)* src_vec_size + v) * src_compress_size;
399+ src_compress_type src_value = src[v];
400+ int dst_idx = v * src_compress_size;
401401 #pragma unroll
402402 for (int c = 0 ; c < src_compress_size; c++) {
403403 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
@@ -408,7 +408,7 @@ inline float dDequantizeNF4(unsigned char val) {
408408
409409 #pragma unroll
410410 for (int l = 0 ; l < dst_loop_num; l++) {
411- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n*dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[n*dst_loop_num + l ];
411+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n*dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0 ];
412412 }
413413 }
414414 };
0 commit comments