@@ -329,46 +329,6 @@ inline float dDequantizeNF4(unsigned char val) {
329329 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
330330 int prefetch_k = k_start_idx;
331331
332- #if 0
333- constexpr float VALUE0 = -1.0f;
334- constexpr float VALUE1 = -0.6961928f;
335- constexpr float VALUE2 = -0.52507305f;
336- constexpr float VALUE3 = -0.39491749f;
337- constexpr float VALUE4 = -0.28444138f;
338- constexpr float VALUE5 = -0.18477343f;
339- constexpr float VALUE6 = -0.09105004f;
340- constexpr float VALUE7 = 0.0f;
341- constexpr float VALUE8 = 0.0795803f;
342- constexpr float VALUE9 = 0.1609302f;
343- constexpr float VALUE10 = 0.2461123f;
344- constexpr float VALUE11 = 0.33791524f;
345- constexpr float VALUE12 = 0.44070983f;
346- constexpr float VALUE13 = 0.562617f;
347- constexpr float VALUE14 = 0.72295684f;
348- constexpr float VALUE15 = 1.0f;
349-
350- auto quant_map_alias = [&](uint8_t index) {
351- switch(index) {
352- case 0: return VALUE0;
353- case 1: return VALUE1;
354- case 2: return VALUE2;
355- case 3: return VALUE3;
356- case 4: return VALUE4;
357- case 5: return VALUE5;
358- case 6: return VALUE6;
359- case 7: return VALUE7;
360- case 8: return VALUE8;
361- case 9: return VALUE9;
362- case 10: return VALUE10;
363- case 11: return VALUE11;
364- case 12: return VALUE12;
365- case 13: return VALUE13;
366- case 14: return VALUE14;
367- case 15: return VALUE15;
368- }
369- };
370- #endif
371-
372332#if 0 //SLM
373333 #if 1
374334 auto dequant = [&] (int k_tile) {
@@ -426,11 +386,9 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
426386 };
427387 #endif
428388#else // register
429- #if 1 // vectorized load/store
430389 auto dequant = [&] {
431390 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
432391 constexpr int K = decltype (cute::size (mma_B))::value / N;
433- // if(cute::thread0) printf("scale num = %d\n", decltype(cute::size(fragment_scale))::value);
434392
435393 using src_compress_type = uint64_t ;
436394 using dst_compress_type = uint64_t ;
@@ -440,121 +398,37 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
440398 constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; // 16, 16 -> max vec_size of sycl::vec
441399 constexpr int src_loop_num = K / src_vec_size / src_compress_size;
442400 constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
443- #if 0
444- if(cute::thread0()) printf("N = %d, K = %d, src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_vec_size = %d, src_loop_num = %d, dst_loop_num = %d\n", N, K, src_compress_size, dst_compress_size, src_vec_size, dst_vec_size, src_loop_num, dst_loop_num);
445- #endif
401+
446402 src_compress_type src[src_vec_size];
447403 ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
448404
449- sycl::vec<float , 16 > loaded = *(sycl::vec<float , 16 >*)&quant_map[0 ];
450405
451406 #pragma unroll
452407 for (int n = 0 ; n < N; n++) {
453408 #pragma unroll
454409 for (int l = 0 ; l < src_loop_num; l++) {
455410 reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ] = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[n*src_loop_num + l];
456411
457- #if 0
458- if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
459- printf("n = %d, src_l = %d\n", n, l);
460- print("======================= src vectorization: \n");
461- print(" src_g_ptr : "); print(&(reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n * src_loop_num + l])); print("\n");
462- print(" src_ptr : "); print(&(reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0])); print("\n");
463- print("=======================\n");
464- }
465- #endif
466412 #pragma unroll
467413 for (int v = 0 ; v < src_vec_size; v++) {
468414 src_compress_type src_value = src[v];
469415 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
416+
470417 #pragma unroll
471418 for (int c = 0 ; c < src_compress_size; c++) {
472419 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
473- // float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) / GROUP_SIZE);
474420 float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
475421 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
476-
477- // uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
478- // uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
479- // float ts_high = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE);
480- // float ts_low = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE);
481- // dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
482- // dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
483- #if 0
484- //dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_alias(bit_value) * scale_value);
485-
486- constexpr uint8_t VEC_WIDTH = 4;
487- uint8_t base_offset = (bit_value / VEC_WIDTH) * VEC_WIDTH;
488- sycl::vec<float, VEC_WIDTH> loaded = *(sycl::vec<float, VEC_WIDTH>*)&quant_map[base_offset];
489- //auto mask = (sycl::vec<int, VEC_WIDTH>(0,1,2,3) == (bit_value % VEC_WIDTH));
490- //float convert_value = loaded[0] * static_cast<float>(mask[0]) +
491- // loaded[1] * static_cast<float>(mask[1]) +
492- // loaded[2] * static_cast<float>(mask[2]) +
493- // loaded[3] * static_cast<float>(mask[3]);
494- auto lane = bit_value % VEC_WIDTH;
495- float convert_value = loaded[lane];
496- dst[dst_base_idx + c] = static_cast<ElementMMA>(convert_value * scale_value);
497- //#endif
498- auto mask = (sycl::vec<uint8_t, 16>(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15) == sycl::vec<uint8_t, 16>(bit_value));
499- float convert_value = 0.0f;
500- #pragma unroll
501- for (int i = 0; i < 16; ++i) {
502- convert_value += loaded[i] * static_cast<float>(mask[i]);
503- }
504-
505- //auto sg = sycl::ext::oneapi::experimental::this_sub_group();
506- //float convert_value = sycl::select_from_group(
507- // sg,
508- // loaded,
509- // bit_value // use bit_value directly as the index
510- //);
511- dst[dst_base_idx + c] = static_cast<ElementMMA>(convert_value * scale_value);
512- #endif
513-
514- #if 0
515- if(thread_idx==60 && m_coord==0 && n_coord==0 && l_coord==0){
516- printf("tid = %d, m_coord = %d, n_coord = %d, l_coord = %d, n = %d, src_l = %d, dst_dx = %d, scale_idx = %d, scale_value = %f\n", thread_idx, m_coord, n_coord, l_coord, n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);
517- //print(" scale_value : "); print(scale_value); print("\n");
518- }
519- #endif
520422 }
521423 }
522424 }
523425
524426 #pragma unroll
525427 for (int l = 0 ; l < dst_loop_num; l++) {
526428 reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
527-
528- #if 0
529- if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
530- printf("n = %d, dst_l = %d\n", n, l);
531- print("======================= dst vectorization: \n");
532- print(" dst_g_ptr : "); print(&(reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l])); print("\n");
533- print(" dst_ptr : "); print(&(reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l])); print("\n");
534- print("=======================\n");
535- }
536- #endif
537429 }
538430 }
539431 };
540- #else //elemented load/store
541- auto dequant = [&] {
542- constexpr int N = decltype(cute::size<1>(mma_B))::value;
543- constexpr int K = decltype(cute::size(mma_B))::value / N;
544- float scale_value = fragment_scale(0);
545-
546- //#pragma unroll
547- //for(int i=0; i<K; i++) {
548- // mma_B[i] = static_cast<ElementMMA>(quant_map[(reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i/2] >> (4 * ((i+1)%2))) & 0xf] * scale_value);
549- //}
550-
551- #pragma unroll
552- for(int i=0; i<K/2; i++) {
553- mma_B[i*2] = static_cast<ElementMMA>(quant_map[(reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i] >> 4) & 0xf] * scale_value);
554- mma_B[i*2+1] = static_cast<ElementMMA>(quant_map[reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i] & 0xf] * scale_value);
555- }
556- };
557- #endif
558432#endif
559433
560434 CUTLASS_PRAGMA_UNROLL
0 commit comments