@@ -339,27 +339,39 @@ inline float dDequantizeNF4(unsigned char val) {
339339 constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
340340
341341 alignas (16 ) ElementB* src = reinterpret_cast <ElementB*>(smem_buf) + thread_idx * (K * 4 ); // for K=64, the factor 4 is hard-coded for 128 B alignment.
342- const uint8_t * gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k / 2 + k_tile * BLK_K / 2 ;
343- reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ] = reinterpret_cast <const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr )[0 ];
342+ const uint8_t * gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k / 2 + k_tile * BLK_K / 2 ;
343+ // reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr)[0];
344344
345345
346- ElementMMA* private_slm = reinterpret_cast <ElementMMA*>(src + K); // reuse src SLM buffer, for K=64, 每个线程一段 连续 128 B,天然 128 B 对齐
347-
348- float scale_value = fragment_scale (0 );
346+ ElementMMA* dst_slm = reinterpret_cast <ElementMMA*>(src + K); // reuse src SLM buffer; for K=64, each thread owns one contiguous 128 B segment, which is naturally 128 B aligned
349347
348+ #pragma unroll
349+ for (int n = 0 ; n < N; n++) {
350+ float scale_value = fragment_scale (n);
350351 #pragma unroll
351- for (int i = 0 ; i < src_vec_size; ++i) {
352- src_compress_type src_value = reinterpret_cast <src_compress_type*>(src)[i];
352+ for (int l = 0 ; l < src_loop_num; l++) {
353+ reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ] = reinterpret_cast <const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr )[n*src_loop_num + l];
354+ #pragma unroll
355+ for (int v = 0 ; v < src_vec_size; ++v) {
356+ src_compress_type src_value = reinterpret_cast <src_compress_type*>(src)[v];
357+ int dst_idx = v * src_compress_size;
353358 #pragma unroll
354- for (int j = 0 ; j < src_compress_size; ++j ) {
355- uint8_t bit_value = (src_value >> (4 * (((j+ 1 ) & 1 ) + (j >> 1 ) * 2 ))) & 0xF ;
356- private_slm[i * src_compress_size + j ] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
359+ for (int c = 0 ; c < src_compress_size; ++c ) {
360+ uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
361+ dst_slm[dst_idx + c ] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
357362 }
363+ }
358364 }
359-
360- for (int i=0 ; i<K/4 /16 ; i++){
361- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[i] = reinterpret_cast <const sycl::vec<dst_compress_type, dst_vec_size>*>(private_slm)[i];
365+
366+ #pragma unroll
367+ for (int l = 0 ; l < dst_loop_num; l++) {
368+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n*dst_loop_num + l] = reinterpret_cast <const sycl::vec<dst_compress_type, dst_vec_size>*>(dst_slm)[0 ];
362369 }
370+ }
371+
372+ // for(int i=0; i<K/4/16; i++){
373+ // reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[i] = reinterpret_cast<const sycl::vec<dst_compress_type, dst_vec_size>*>(private_slm)[i];
374+ // }
363375 };
364376 #endif
365377#else //register
0 commit comments