@@ -255,13 +255,23 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in
255255 const void * beta = &fbeta;
256256 hipblasStatus_t status;
257257
258+ #if hipblasVersionMajor >= 3
259+ status = hipblasGemmEx (context->m_handle ,
260+ transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
261+ transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
262+ m, n, k,
263+ alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta,
264+ C, HIP_R_32I, ldc,
265+ HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT);
266+ #else
258267 status = hipblasGemmEx (context->m_handle ,
259268 transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
260269 transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
261270 m, n, k,
262271 alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta,
263272 C, HIPBLAS_R_32I, ldc,
264273 HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
274+ #endif
265275
266276 if (status != HIPBLAS_STATUS_SUCCESS)
267277 {
@@ -285,13 +295,23 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
285295 // printf("%i %i %i\n", strideA, strideB, strideC);
286296 // printf("%i\n", batchCount);
287297
298+ #if hipblasVersionMajor >= 3
299+ status = hipblasGemmStridedBatchedEx (context->m_handle ,
300+ transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
301+ transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
302+ m, n, k,
303+ alpha, A, HIP_R_8I, lda, (long long int )strideA, B, HIP_R_8I, ldb, (long long int )strideB, beta,
304+ C, HIP_R_32I, ldc, (long long int )strideC, batchCount,
305+ HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT);
306+ #else
288307 status = hipblasGemmStridedBatchedEx (context->m_handle ,
289308 transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
290309 transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
291310 m, n, k,
292311 alpha, A, HIPBLAS_R_8I, lda, (long long int )strideA, B, HIPBLAS_R_8I, ldb, (long long int )strideB, beta,
293312 C, HIPBLAS_R_32I, ldc, (long long int )strideC, batchCount,
294313 HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
314+ #endif
295315
296316 if (status != HIPBLAS_STATUS_SUCCESS)
297317 {
0 commit comments