@@ -417,24 +417,24 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
417417 constexpr int src_loop_num = 1 ; // K / src_vec_size / src_compress_size;
418418 constexpr int dst_loop_num = 1 ; // K / dst_vec_size / dst_compress_size;
419419
420- src_compress_type src[src_loop_num * src_vec_size];
420+ // src_compress_type src[src_loop_num * src_vec_size];
421421 ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
422422
423- reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ] = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[0 ];
423+ // reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
424424 float scale_value = fragment_scale (0 );// (dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
425425
426426 #pragma unroll
427427 for (int v = 0 ; v < src_vec_size; v++) {
428428 int dst_base_idx = v * src_compress_size;
429429 int c = 0 ;
430- uint8_t bit_value = (src [v] >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
430+ uint8_t bit_value = (reinterpret_cast <src_compress_type*>( cute::raw_pointer_cast (dequant_frag. data ())) [v] >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
431431 float converted_value_1 = quant_map[bit_value];
432432 float converted_value_2 = 0 .f ;
433433 #pragma unroll
434434 for (; c < src_compress_size-1 ;) {
435435 converted_value_2 = converted_value_1;
436436 c++;
437- bit_value = (src [v] >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
437+ bit_value = (reinterpret_cast <src_compress_type*>( cute::raw_pointer_cast (dequant_frag. data ())) [v] >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
438438 converted_value_1 = quant_map[bit_value];
439439 dst[dst_base_idx + c-1 ] = static_cast <ElementMMA>(converted_value_2 * scale_value);
440440 }
0 commit comments