Skip to content

Commit bdadc9d

Browse files
committed
clean code
1 parent a2bd43b commit bdadc9d

1 file changed

Lines changed: 2 additions & 128 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 2 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -329,46 +329,6 @@ inline float dDequantizeNF4(unsigned char val) {
329329
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
330330
int prefetch_k = k_start_idx;
331331

332-
#if 0
// Compiled-out alternative lookup: sixteen float constants (VALUE0..VALUE15,
// ascending from -1.0f to 1.0f) plus a lambda that selects one of them by index.
// NOTE(review): the values look like a 4-bit NF4-style code book (the enclosing
// hunk header mentions dDequantizeNF4) -- confirm against the quant_map contents.
333-
constexpr float VALUE0 = -1.0f;
334-
constexpr float VALUE1 = -0.6961928f;
335-
constexpr float VALUE2 = -0.52507305f;
336-
constexpr float VALUE3 = -0.39491749f;
337-
constexpr float VALUE4 = -0.28444138f;
338-
constexpr float VALUE5 = -0.18477343f;
339-
constexpr float VALUE6 = -0.09105004f;
340-
constexpr float VALUE7 = 0.0f;
341-
constexpr float VALUE8 = 0.0795803f;
342-
constexpr float VALUE9 = 0.1609302f;
343-
constexpr float VALUE10 = 0.2461123f;
344-
constexpr float VALUE11 = 0.33791524f;
345-
constexpr float VALUE12 = 0.44070983f;
346-
constexpr float VALUE13 = 0.562617f;
347-
constexpr float VALUE14 = 0.72295684f;
348-
constexpr float VALUE15 = 1.0f;
349-
350-
// Returns VALUE<index> for index in 0..15 via a switch instead of an array load.
// NOTE(review): there is no default case and no return after the switch, so an
// index outside 0..15 would flow off the end of a value-returning lambda.
auto quant_map_alias = [&](uint8_t index) {
351-
switch(index) {
352-
case 0: return VALUE0;
353-
case 1: return VALUE1;
354-
case 2: return VALUE2;
355-
case 3: return VALUE3;
356-
case 4: return VALUE4;
357-
case 5: return VALUE5;
358-
case 6: return VALUE6;
359-
case 7: return VALUE7;
360-
case 8: return VALUE8;
361-
case 9: return VALUE9;
362-
case 10: return VALUE10;
363-
case 11: return VALUE11;
364-
case 12: return VALUE12;
365-
case 13: return VALUE13;
366-
case 14: return VALUE14;
367-
case 15: return VALUE15;
368-
}
369-
};
370-
#endif
371-
372332
#if 0 //SLM
373333
#if 1
374334
auto dequant = [&] (int k_tile) {
@@ -426,11 +386,9 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
426386
};
427387
#endif
428388
#else //register
429-
#if 1 //vectorized load/store
430389
auto dequant = [&] {
431390
constexpr int N = decltype(cute::size<1>(mma_B))::value;
432391
constexpr int K = decltype(cute::size(mma_B))::value / N;
433-
//if(cute::thread0) printf("scale num = %d\n", decltype(cute::size(fragment_scale))::value);
434392

435393
using src_compress_type = uint64_t;
436394
using dst_compress_type = uint64_t;
@@ -440,121 +398,37 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
440398
constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
441399
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
442400
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
443-
#if 0
444-
if(cute::thread0()) printf("N = %d, K = %d, src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_vec_size = %d, src_loop_num = %d, dst_loop_num = %d\n", N, K, src_compress_size, dst_compress_size, src_vec_size, dst_vec_size, src_loop_num, dst_loop_num);
445-
#endif
401+
446402
src_compress_type src[src_vec_size];
447403
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
448404

449-
sycl::vec<float, 16> loaded = *(sycl::vec<float, 16>*)&quant_map[0];
450405

451406
#pragma unroll
452407
for (int n = 0; n < N; n++) {
453408
#pragma unroll
454409
for (int l = 0; l < src_loop_num; l++) {
455410
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l];
456411

457-
#if 0
458-
if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
459-
printf("n = %d, src_l = %d\n", n, l);
460-
print("======================= src vectorization: \n");
461-
print(" src_g_ptr : "); print(&(reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n * src_loop_num + l])); print("\n");
462-
print(" src_ptr : "); print(&(reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0])); print("\n");
463-
print("=======================\n");
464-
}
465-
#endif
466412
#pragma unroll
467413
for (int v = 0; v < src_vec_size; v++) {
468414
src_compress_type src_value = src[v];
469415
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
416+
470417
#pragma unroll
471418
for (int c = 0; c < src_compress_size; c++) {
472419
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
473-
//float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) / GROUP_SIZE);
474420
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
475421
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
476-
477-
//uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
478-
//uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
479-
//float ts_high = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE);
480-
//float ts_low = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE);
481-
//dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
482-
//dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
483-
#if 0
484-
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_alias(bit_value) * scale_value);
485-
486-
constexpr uint8_t VEC_WIDTH = 4;
487-
uint8_t base_offset = (bit_value / VEC_WIDTH) * VEC_WIDTH;
488-
sycl::vec<float, VEC_WIDTH> loaded = *(sycl::vec<float, VEC_WIDTH>*)&quant_map[base_offset];
489-
//auto mask = (sycl::vec<int, VEC_WIDTH>(0,1,2,3) == (bit_value % VEC_WIDTH));
490-
//float convert_value = loaded[0] * static_cast<float>(mask[0]) +
491-
// loaded[1] * static_cast<float>(mask[1]) +
492-
// loaded[2] * static_cast<float>(mask[2]) +
493-
// loaded[3] * static_cast<float>(mask[3]);
494-
auto lane = bit_value % VEC_WIDTH;
495-
float convert_value = loaded[lane];
496-
dst[dst_base_idx + c] = static_cast<ElementMMA>(convert_value * scale_value);
497-
//#endif
498-
auto mask = (sycl::vec<uint8_t, 16>(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15) == sycl::vec<uint8_t, 16>(bit_value));
499-
float convert_value = 0.0f;
500-
#pragma unroll
501-
for (int i = 0; i < 16; ++i) {
502-
convert_value += loaded[i] * static_cast<float>(mask[i]);
503-
}
504-
505-
//auto sg = sycl::ext::oneapi::experimental::this_sub_group();
506-
//float convert_value = sycl::select_from_group(
507-
// sg,
508-
// loaded,
509-
// bit_value // use bit_value directly as the index
510-
//);
511-
dst[dst_base_idx + c] = static_cast<ElementMMA>(convert_value * scale_value);
512-
#endif
513-
514-
#if 0
515-
if(thread_idx==60 && m_coord==0 && n_coord==0 && l_coord==0){
516-
printf("tid = %d, m_coord = %d, n_coord = %d, l_coord = %d, n = %d, src_l = %d, dst_dx = %d, scale_idx = %d, scale_value = %f\n", thread_idx, m_coord, n_coord, l_coord, n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);
517-
//print(" scale_value : "); print(scale_value); print("\n");
518-
}
519-
#endif
520422
}
521423
}
522424
}
523425

524426
#pragma unroll
525427
for (int l = 0; l < dst_loop_num; l++) {
526428
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
527-
528-
#if 0
529-
if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
530-
printf("n = %d, dst_l = %d\n", n, l);
531-
print("======================= dst vectorization: \n");
532-
print(" dst_g_ptr : "); print(&(reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l])); print("\n");
533-
print(" dst_ptr : "); print(&(reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l])); print("\n");
534-
print("=======================\n");
535-
}
536-
#endif
537429
}
538430
}
539431
};
540-
#else //elemented load/store
541-
// Element-wise dequantization variant: unpacks 4-bit codes from dequant_frag one
// byte at a time, maps each code through quant_map, multiplies by a scale and
// stores the result into mma_B as ElementMMA.
// NOTE(review): only fragment_scale(0) is read and only mma_B[0 .. K-1] is
// written, so this appears to assume a single scale and N == 1 -- confirm.
auto dequant = [&] {
542-
// N: static extent of mode 1 of mma_B; K: total static size of mma_B divided by N.
constexpr int N = decltype(cute::size<1>(mma_B))::value;
543-
constexpr int K = decltype(cute::size(mma_B))::value / N;
544-
// One scale value is applied to every element produced below.
float scale_value = fragment_scale(0);
545-
546-
// Earlier per-element form, kept for reference: element i comes from byte i/2,
// using the high nibble when i is even and the low nibble when i is odd.
//#pragma unroll
547-
//for(int i=0; i<K; i++) {
548-
// mma_B[i] = static_cast<ElementMMA>(quant_map[(reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i/2] >> (4 * ((i+1)%2))) & 0xf] * scale_value);
549-
//}
550-
551-
// Per-byte form: each packed byte i yields two outputs.
#pragma unroll
552-
for(int i=0; i<K/2; i++) {
553-
// High nibble of byte i -> even output index i*2.
mma_B[i*2] = static_cast<ElementMMA>(quant_map[(reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i] >> 4) & 0xf] * scale_value);
554-
// Low nibble of byte i -> odd output index i*2+1.
mma_B[i*2+1] = static_cast<ElementMMA>(quant_map[reinterpret_cast<uint8_t*>(cute::raw_pointer_cast(dequant_frag.data()))[i] & 0xf] * scale_value);
555-
}
556-
};
557-
#endif
558432
#endif
559433

560434
CUTLASS_PRAGMA_UNROLL

0 commit comments

Comments
 (0)