66#include < optional>
77#include < vector>
88
9+ inline void prepare_matrix_for_cublas (
10+ infiniopTensorDescriptor_t tensor,
11+ bool &transpose_tensor) {
12+
13+ auto strides = tensor->strides ();
14+ auto sizes = tensor->shape ();
15+
16+ if ((strides[0 ] == 1 ) && (strides[1 ] >= std::max<int64_t >(1 , sizes[0 ]))) {
17+
18+ transpose_tensor = false ;
19+ return ;
20+ }
21+ if ((strides[1 ] == 1 ) && (strides[0 ] >= std::max<int64_t >(1 , sizes[1 ]))) {
22+
23+ transpose_tensor = true ;
24+ return ;
25+ }
26+ transpose_tensor = true ;
27+ }
28+
929namespace op ::gptq_qyblas_gemm {
1030
1131class GptqQyblasGemmInfo {
1232 GptqQyblasGemmInfo () = default ;
1333
1434public:
15- infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype;
35+ infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype ;
1636 size_t M, K, N, scales_size_0, scales_size_1;
1737 ptrdiff_t lda, ldb, result_ld;
18- bool transpose_mat_1, transpose_mat_2, transpose_result;
38+ bool transpose_result;
39+ char transa, transb;
1940
2041 static utils::Result<GptqQyblasGemmInfo> createGptqQyblasGemmInfo (
2142 infiniopTensorDescriptor_t out_desc,
@@ -27,17 +48,38 @@ class GptqQyblasGemmInfo {
2748 auto dtype = a_desc->dtype ();
2849
2950 CHECK_DTYPE (dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
30- CHECK_DTYPE (dtype, out_desc->dtype ());
51+ auto out_dtype = out_desc->dtype ();
52+ CHECK_DTYPE (dtype, out_dtype);
3153
3254 const infiniDtype_t weight_dtype = b_desc->dtype ();
33- CHECK_DTYPE (weight_dtype, INFINI_DTYPE_F8, INFINI_DTYPE_U8, INFINI_DTYPE_I8);
55+ // CHECK_DTYPE(weight_dtype, INFINI_DTYPE_F8, INFINI_DTYPE_U8, INFINI_DTYPE_I8);
3456
3557 const infiniDtype_t scales_dtype = b_scales_desc->dtype ();
3658 const infiniDtype_t zeros_dtype = b_zeros_desc->dtype ();
3759
38- size_t M = out_desc->shape ()[0 ];
39- size_t N = out_desc->shape ()[1 ];
40- size_t K = a_desc->shape ()[1 ];
60+ bool transpose_result = false ;
61+ bool transpose_mat_1 = false ;
62+ bool transpose_mat_2 = false ;
63+
64+ prepare_matrix_for_cublas (out_desc, transpose_result);
65+
66+ auto mata = (transpose_result ? b_desc : a_desc);
67+ prepare_matrix_for_cublas (transpose_result ? b_desc : a_desc, transpose_mat_1);
68+ auto matb = (transpose_result ? a_desc : b_desc);
69+ prepare_matrix_for_cublas (transpose_result ? a_desc : b_desc, transpose_mat_2);
70+
71+ auto mat1_sizes = a_desc->shape ();
72+ auto mat2_sizes = b_desc->shape ();
73+ if (transpose_result) {
74+ transpose_mat_1 = !transpose_mat_1;
75+ transpose_mat_2 = !transpose_mat_2;
76+ mat1_sizes = mata->shape ();
77+ mat2_sizes = matb->shape ();
78+ }
79+
80+ size_t M = mat1_sizes[transpose_result ? 1 : 0 ];
81+ size_t K = mat1_sizes[transpose_result ? 0 : 1 ];
82+ size_t N = mat2_sizes[transpose_result ? 0 : 1 ];
4183
4284 size_t scales_size_0 = b_scales_desc->shape ()[0 ];
4385 size_t scales_size_1 = b_scales_desc->shape ()[1 ];
@@ -50,40 +92,23 @@ class GptqQyblasGemmInfo {
5092 && b_zeros_desc->ndim () == ndim,
5193 INFINI_STATUS_BAD_TENSOR_SHAPE);
5294
53- bool transpose_result = false ;
54- if (out_desc->strides ()[0 ] == 1 && out_desc->strides ()[1 ] >= std::max<int64_t >(1 , out_desc->shape ()[0 ])) {
55- transpose_result = true ;
56- } else if (out_desc->strides ()[1 ] == 1 && out_desc->strides ()[0 ] >= std::max<int64_t >(1 , out_desc->shape ()[1 ])) {
57- transpose_result = false ;
58- } else {
59- transpose_result = false ;
60- }
61- bool transpose_mat_1 = false ;
62- if (a_desc->strides ()[0 ] == 1 && a_desc->strides ()[1 ] >= std::max<int64_t >(1 , a_desc->shape ()[0 ])) {
63- transpose_mat_1 = true ;
64- } else if (a_desc->strides ()[1 ] == 1 && a_desc->strides ()[0 ] >= std::max<int64_t >(1 , a_desc->shape ()[1 ])) {
65- transpose_mat_1 = false ;
66- } else {
67- transpose_mat_1 = false ;
68- }
69- bool transpose_mat_2 = false ;
70- if (b_desc->strides ()[0 ] == 1 && b_desc->strides ()[1 ] >= std::max<int64_t >(1 , b_desc->shape ()[0 ])) {
71- transpose_mat_2 = true ;
72- } else if (b_desc->strides ()[1 ] == 1 && b_desc->strides ()[0 ] >= std::max<int64_t >(1 , b_desc->shape ()[1 ])) {
73- transpose_mat_2 = false ;
74- } else {
75- transpose_mat_2 = false ;
76- }
95+ ptrdiff_t lda = mata->strides ()[(transpose_mat_1 == transpose_result)
96+ ? 1
97+ : 0 ];
98+ ptrdiff_t ldb = matb->strides ()[(transpose_mat_2 == transpose_result)
99+ ? 1
100+ : 0 ];
101+ ptrdiff_t result_ld = out_desc->strides ()[transpose_result ? 0 : 1 ];
77102
78- ptrdiff_t lda = a_desc->strides ()[transpose_mat_1 ? 1 : 0 ];
79- ptrdiff_t ldb = b_desc->strides ()[transpose_mat_2 ? 1 : 0 ];
80- ptrdiff_t result_ld = out_desc->strides ()[transpose_result ? 1 : 0 ];
103+ char transa = transpose_mat_1 ? ' t' : ' n' ;
104+ char transb = transpose_mat_2 ? ' t' : ' n' ;
81105
82106 return utils::Result<GptqQyblasGemmInfo>(GptqQyblasGemmInfo{
83- dtype, weight_dtype, scales_dtype, zeros_dtype,
107+ dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype,
84108 M, K, N, scales_size_0, scales_size_1,
85109 lda, ldb, result_ld,
86- transpose_mat_1, transpose_mat_2, transpose_result});
110+ transpose_result,
111+ transa, transb});
87112 }
88113};
89114
0 commit comments