Skip to content

Commit 3f63567

Browse files
committed
save code
1 parent 6262c79 commit 3f63567

1 file changed

Lines changed: 24 additions & 24 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ static constexpr float quant_map_static[16] = {
6161
};
6262
#endif
6363

64-
using TileShape = Shape<_64, _128, _128>;
64+
using TileShape = Shape<_64, _128, _64>;
6565
using TiledMma =
6666
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
6767
Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
@@ -414,8 +414,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
414414
constexpr int dst_compress_size = 4; //cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
415415
constexpr int src_vec_size = 8; //(K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
416416
constexpr int dst_vec_size = 16; //(K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
417-
constexpr int src_loop_num = 2; //K / src_vec_size / src_compress_size;
418-
constexpr int dst_loop_num = 2; //K / dst_vec_size / dst_compress_size;
417+
constexpr int src_loop_num = 1; //K / src_vec_size / src_compress_size;
418+
constexpr int dst_loop_num = 1; //K / dst_vec_size / dst_compress_size;
419419

420420
src_compress_type src[src_loop_num * src_vec_size];
421421
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
@@ -441,29 +441,29 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
441441
dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
442442
}
443443

444-
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[1] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[1];
445-
scale_value = fragment_scale(1);
444+
// reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[1] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[1];
445+
// scale_value = fragment_scale(1);
446446
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0];
447447

448-
#pragma unroll
449-
for (int v = src_vec_size; v < src_loop_num * src_vec_size; v++) {
450-
int dst_base_idx = v * src_compress_size;
451-
int c = 0;
452-
uint8_t bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
453-
float converted_value_1 = quant_map[bit_value];
454-
float converted_value_2 = 0.f;
455-
#pragma unroll
456-
for (; c < src_compress_size-1;) {
457-
converted_value_2 = converted_value_1;
458-
c++;
459-
bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
460-
converted_value_1 = quant_map[bit_value];
461-
dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
462-
}
463-
dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
464-
}
465-
466-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[1] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[1];
448+
// #pragma unroll
449+
// for (int v = src_vec_size; v < src_loop_num * src_vec_size; v++) {
450+
// int dst_base_idx = v * src_compress_size;
451+
// int c = 0;
452+
// uint8_t bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
453+
// float converted_value_1 = quant_map[bit_value];
454+
// float converted_value_2 = 0.f;
455+
// #pragma unroll
456+
// for (; c < src_compress_size-1;) {
457+
// converted_value_2 = converted_value_1;
458+
// c++;
459+
// bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
460+
// converted_value_1 = quant_map[bit_value];
461+
// dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
462+
// }
463+
// dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
464+
// }
465+
//
466+
// reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[1] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[1];
467467

468468
};
469469
#endif

0 commit comments

Comments
 (0)