Skip to content

Commit 05ff59c

Browse files
authored
CUDA: batch out_prod inner loop with cublasSgemmStridedBatched (ggml-org#22651)
* CUDA: batch out_prod inner loop with cublasSgemmStridedBatched
* CUDA: batch out_prod inner loop with cublasSgemmStridedBatched
* CUDA: add cublasSgemmStridedBatched mapping for HIP and MUSA backends
1 parent aaf4a4d commit 05ff59c

4 files changed

Lines changed: 31 additions & 7 deletions

File tree

ggml/src/ggml-cuda/out-prod.cu

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,31 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
5454
const int64_t dps2 = ne2 / ne02;
5555
const int64_t dps3 = ne3 / ne03;
5656

57-
// TODO batched matrix multiplication
58-
for (int64_t i3 = 0; i3 < ne3; ++i3) {
59-
for (int64_t i2 = 0; i2 < ne2; ++i2) {
57+
if (dps2 == 1 && ne2 > 1) {
58+
// src0 has uniform stride s02 along dim 2; batch the inner loop with a strided GEMM
59+
GGML_ASSERT(ne2 <= std::numeric_limits<int>::max());
60+
const int batch_count = (int) ne2;
61+
for (int64_t i3 = 0; i3 < ne3; ++i3) {
6062
CUBLAS_CHECK(
61-
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
63+
cublasSgemmStridedBatched(handle, CUBLAS_OP_N, src1_cublas_op,
6264
ne0, ne1, ne01,
63-
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
64-
src1_d + i3 *s13 + i2 *s12, ldb,
65-
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
65+
&alpha, src0_d + (i3/dps3)*s03, lda, s02,
66+
src1_d + i3 *s13, ldb, s12,
67+
&beta, dst_d + i3 *s3, ldc, s2,
68+
batch_count));
69+
}
70+
} else {
71+
// Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
72+
// with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
73+
for (int64_t i3 = 0; i3 < ne3; ++i3) {
74+
for (int64_t i2 = 0; i2 < ne2; ++i2) {
75+
CUBLAS_CHECK(
76+
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
77+
ne0, ne1, ne01,
78+
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
79+
src1_d + i3 *s13 + i2 *s12, ldb,
80+
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
81+
}
6682
}
6783
}
6884
}

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
4949
#define cublasSetStream hipblasSetStream
5050
#define cublasSgemm hipblasSgemm
51+
#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
5152
#define cublasStatus_t hipblasStatus_t
5253
#define cublasOperation_t hipblasOperation_t
5354
#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#define cublasSetMathMode mublasSetMathMode
3333
#define cublasSetStream mublasSetStream
3434
#define cublasSgemm mublasSgemm
35+
#define cublasSgemmStridedBatched mublasSgemmStridedBatched
3536
#define cublasStatus_t mublasStatus_t
3637
#define cublasOperation_t mublasOperation_t
3738
#define cublasGetStatusString mublasGetStatusString

tests/test-backend-ops.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8385,6 +8385,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
83858385
}
83868386
}
83878387

8388+
// ne2 sweep to cover the cublasSgemmStridedBatched path (dps2 == 1, ne2 > 1)
8389+
for (int64_t ne2 : {1, 8, 16, 32}) {
8390+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32,
8391+
256, 16, 16, {ne2, 1}, {1, 1}));
8392+
}
8393+
83888394
// add_id
83898395
for (ggml_type type_a : {GGML_TYPE_F32}) {
83908396
for (ggml_type type_b : {GGML_TYPE_F32}) {

0 commit comments

Comments
 (0)