@@ -446,6 +446,8 @@ if(cute::thread0()) printf("N = %d, K = %d, src_compress_size = %d, dst_compress
446446 src_compress_type src[src_vec_size];
447447 ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
448448
449+ sycl::vec<float , 16 > loaded = *(sycl::vec<float , 16 >*)&quant_map[0 ];
450+
449451 #pragma unroll
450452 for (int n = 0 ; n < N; n++) {
451453 #pragma unroll
@@ -470,7 +472,37 @@ if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
470472 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
471473 float scale_value = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + c) / GROUP_SIZE );
472474 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
475+ #if 0
473476 //dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_alias(bit_value) * scale_value);
477+
478+ constexpr uint8_t VEC_WIDTH = 4;
479+ uint8_t base_offset = (bit_value / VEC_WIDTH) * VEC_WIDTH;
480+ sycl::vec<float, VEC_WIDTH> loaded = *(sycl::vec<float, VEC_WIDTH>*)&quant_map[base_offset];
481+ //auto mask = (sycl::vec<int, VEC_WIDTH>(0,1,2,3) == (bit_value % VEC_WIDTH));
482+ //float convert_value = loaded[0] * static_cast<float>(mask[0]) +
483+ // loaded[1] * static_cast<float>(mask[1]) +
484+ // loaded[2] * static_cast<float>(mask[2]) +
485+ // loaded[3] * static_cast<float>(mask[3]);
486+ auto lane = bit_value % VEC_WIDTH;
487+ float convert_value = loaded[lane];
488+ dst[dst_base_idx + c] = static_cast<ElementMMA>(convert_value * scale_value);
489+ //#endif
490+ auto mask = (sycl::vec<uint8_t, 16>(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15) == sycl::vec<uint8_t, 16>(bit_value));
491+ float convert_value = 0.0f;
492+ #pragma unroll
493+ for (int i = 0; i < 16; ++i) {
494+ convert_value += loaded[i] * static_cast<float>(mask[i]);
495+ }
496+
497+ //auto sg = sycl::ext::oneapi::experimental::this_sub_group();
498+ //float convert_value = sycl::select_from_group(
499+ // sg,
500+ // loaded,
501+ // bit_value // use bit_value directly as the index
502+ //);
503+ dst[dst_base_idx + c] = static_cast<ElementMMA>(convert_value * scale_value);
504+ #endif
505+
474506#if 0
475507if(thread_idx==60 && m_coord==0 && n_coord==0 && l_coord==0){
476508 printf("tid = %d, m_coord = %d, n_coord = %d, l_coord = %d, n = %d, src_l = %d, dst_dx = %d, scale_idx = %d, scale_value = %f\n", thread_idx, m_coord, n_coord, l_coord, n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);
0 commit comments