@@ -118,30 +118,15 @@ infiniStatus_t Descriptor::calculate(void *workspace,
118118
119119 bool transpose_mat_1 = _info.transpose_mat_1 ;
120120 bool transpose_mat_2 = _info.transpose_mat_2 ;
121- int64_t M;
122- int64_t N;
123- int64_t lda;
124- int64_t ldb;
125- cublasOperation_t transa;
126- cublasOperation_t transb;
127-
128- if (transpose_mat_2) {
129- M = static_cast <int64_t >(_info.N );
130- N = static_cast <int64_t >(_info.M );
131- lda = (bit == 4 ? static_cast <int64_t >(_info.ldb ) * 2 : static_cast <int64_t >(_info.ldb ));
132- ldb = static_cast <int64_t >(_info.lda );
133- std::swap (a, b);
134- std::swap (kernel_Atype_, kernel_Btype_);
135- transa = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N;
136- transb = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N;
137- } else {
138- M = static_cast <int64_t >(_info.M );
139- N = static_cast <int64_t >(_info.N );
140- lda = static_cast <int64_t >(_info.lda );
141- ldb = static_cast <int64_t >(_info.ldb );
142- transa = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N;
143- transb = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N;
144- }
121+
122+ int64_t M = static_cast <int64_t >(_info.M );
123+ int64_t N = static_cast <int64_t >(_info.N );
124+ int64_t lda = static_cast <int64_t >(_info.lda );
125+ int64_t ldb = ((bit == 4 && transpose_mat_2) ? 2 * static_cast <int64_t >(_info.ldb ) : static_cast <int64_t >(_info.ldb ));
126+
127+ cublasOperation_t transa = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N;
128+ cublasOperation_t transb = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N;
129+
145130 int64_t scales_size_0 = static_cast <int64_t >(_info.scales_size_0 );
146131 int64_t scales_size_1 = static_cast <int64_t >(_info.scales_size_1 );
147132
@@ -150,7 +135,7 @@ infiniStatus_t Descriptor::calculate(void *workspace,
150135 dlblasExtQuantParametersV2_t extParameters;
151136
152137 if (quant_type == 0 ) {
153- extParameters.a_group_size_m = M / scales_size_1;
138+ extParameters.a_group_size_m = N / scales_size_1;
154139 extParameters.a_group_size_k = K / scales_size_0;
155140 extParameters.a_zeropoints_type = kernel_Ztype_;
156141 extParameters.a_zeropoints = b_zeros;
@@ -166,13 +151,13 @@ infiniStatus_t Descriptor::calculate(void *workspace,
166151 } else if (quant_type == 2 || quant_type == 3 ) {
167152 // calculate block_shape according weight/scales shape
168153 int block_shape = 128 ;
169- while ((M + block_shape - 1 ) / block_shape < scales_size_0) {
154+ while ((N + block_shape - 1 ) / block_shape < scales_size_0) {
170155 block_shape /= 2 ;
171156 if (block_shape < 32 ) {
172157 fprintf (stderr,
173158 " INTERNAL ASSERT FAILED: block_shape >= 32\n "
174159 " Invalid fp blockwise linear arguments. Weight: [%d, %d]. Scales: [%d, %d].\n " ,
175- (int )M , (int )K, (int )scales_size_0, (int )scales_size_1);
160+ (int )N , (int )K, (int )scales_size_0, (int )scales_size_1);
176161 abort ();
177162 }
178163 }
@@ -187,7 +172,12 @@ infiniStatus_t Descriptor::calculate(void *workspace,
187172 extParameters.a_zeropoints = nullptr ;
188173 extParameters.a_scales = b_scales;
189174 }
190-
175+ printf (" a=%s, b=%s, c=%s\n " ,
176+ _info.transpose_mat_1 ? " true" : " false" ,
177+ _info.transpose_mat_2 ? " true" : " false" ,
178+ _info.transpose_result ? " true" : " false" );
179+ printf (" M-K-N:[%ld, %ld, %ld], lda-ldb-ldc:[%ld, %ld, %ld]\n " , M, K, N, lda, ldb, result_ld);
180+ printf (" quant type:%ld, bit:%ld\n " , quant_type, bit);
191181 if (_info.dtype == INFINI_DTYPE_F16 || _info.dtype == INFINI_DTYPE_BF16) {
192182 CHECK_STATUS (_opaque->internal ->useCublas (
193183 (cudaStream_t)stream,
@@ -196,16 +186,16 @@ infiniStatus_t Descriptor::calculate(void *workspace,
196186 dlblasGemmExV2 (handle,
197187 transa,
198188 transb,
199- M,
200189 N,
190+ M,
201191 K,
202192 &alpha,
203- a,
204- kernel_Atype_,
205- lda,
206193 b,
207194 kernel_Btype_,
208195 ldb,
196+ a,
197+ kernel_Atype_,
198+ lda,
209199 &beta,
210200 out,
211201 kernel_Ctype_,
0 commit comments