Skip to content

Commit df97df2

Browse files
committed
update for hipblasVersionMajor >=3
1 parent 35266ea commit df97df2

1 file changed

Lines changed: 20 additions & 0 deletions

File tree

csrc/ops.hip

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,13 +255,23 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in
255255
const void * beta = &fbeta;
256256
hipblasStatus_t status;
257257

258+
#if hipblasVersionMajor >= 3
259+
status = hipblasGemmEx(context->m_handle,
260+
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
261+
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
262+
m, n, k,
263+
alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta,
264+
C, HIP_R_32I, ldc,
265+
HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT);
266+
#else
258267
status = hipblasGemmEx(context->m_handle,
259268
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
260269
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
261270
m, n, k,
262271
alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta,
263272
C, HIPBLAS_R_32I, ldc,
264273
HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
274+
#endif
265275

266276
if (status != HIPBLAS_STATUS_SUCCESS)
267277
{
@@ -285,13 +295,23 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
285295
//printf("%i %i %i\n", strideA, strideB, strideC);
286296
//printf("%i\n", batchCount);
287297

298+
#if hipblasVersionMajor >= 3
299+
status = hipblasGemmStridedBatchedEx(context->m_handle,
300+
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
301+
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
302+
m, n, k,
303+
alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta,
304+
C, HIP_R_32I, ldc, (long long int)strideC, batchCount,
305+
HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT);
306+
#else
288307
status = hipblasGemmStridedBatchedEx(context->m_handle,
289308
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
290309
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
291310
m, n, k,
292311
alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta,
293312
C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount,
294313
HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
314+
#endif
295315

296316
if (status != HIPBLAS_STATUS_SUCCESS)
297317
{

0 commit comments

Comments
 (0)