Skip to content

Commit 15be0d2

Browse files
committed
refactor(gemm): remove blas_handle from GemmParams/SgemvParams; add device param
GemmParams and SgemvParams are pure problem descriptions and should not carry runtime state. Move handle acquisition into GemmCuda/SgemvCuda via a device parameter, inlining the dynamic_cast directly. Remove the public GetCublasHandle/GetCudaStream helpers from gemm.cuh.
1 parent a7e1b99 commit 15be0d2

5 files changed

Lines changed: 234 additions & 266 deletions

File tree

infini_train/include/common/cuda/gemm.cuh

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,12 @@
11
#pragma once
22

33
#include <cublas_v2.h>
4-
#include <cuda_runtime_api.h>
54

65
#include "infini_train/include/datatype.h"
76
#include "infini_train/include/device.h"
87

98
namespace infini_train::kernels::cuda {
109

11-
/**
12-
* Return the cuBLAS handle associated with the given device.
13-
* Shared by linear.cu, matmul.cu, and any future GEMM-using kernels.
14-
*/
15-
cublasHandle_t GetCublasHandle(const Device &device);
16-
17-
/**
18-
* Return the CUDA stream associated with the given device.
19-
* Shared by kernels that need to launch device-side code directly.
20-
*/
21-
cudaStream_t GetCudaStream(const Device &device);
22-
2310
/**
2411
* Parameter bundle for a single GEMM call:
2512
* C = alpha * op(A) * op(B) + beta * C
@@ -56,8 +43,6 @@ struct GemmParams {
5643

5744
DataType input_dtype; // dtype of A and B
5845
DataType output_dtype; // dtype of C (may differ, e.g. bf16 in → fp32 out)
59-
60-
cublasHandle_t blas_handle = nullptr;
6146
};
6247

6348
/**
@@ -67,7 +52,7 @@ struct GemmParams {
6752
* Uses CUBLAS_COMPUTE_32F for all input dtypes to ensure precision.
6853
* Aborts on cuBLAS error (via CUBLAS_CHECK / LOG(FATAL)).
6954
*/
70-
void GemmCuda(const GemmParams &p);
55+
void GemmCuda(const Device &device, const GemmParams &p);
7156

7257
/**
7358
* Parameter bundle for a single SGEMV call (fp32 only):
@@ -88,9 +73,8 @@ struct SgemvParams {
8873
int incy = 1;
8974
float alpha = 1.0f;
9075
float beta = 0.0f;
91-
cublasHandle_t blas_handle = nullptr;
9276
};
9377

94-
void SgemvCuda(const SgemvParams &p);
78+
void SgemvCuda(const Device &device, const SgemvParams &p);
9579

9680
} // namespace infini_train::kernels::cuda

infini_train/src/kernels/cuda/gemm.cu

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,6 @@
1212

1313
namespace infini_train::kernels::cuda {
1414

15-
cublasHandle_t GetCublasHandle(const Device &device) {
16-
return dynamic_cast<infini_train::core::cuda::CudaBlasHandle *>(
17-
infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device))
18-
->cublas_handle();
19-
}
20-
21-
cudaStream_t GetCudaStream(const Device &device) {
22-
return dynamic_cast<infini_train::core::cuda::CudaStream *>(
23-
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
24-
->cuda_stream();
25-
}
26-
2715
namespace {
2816

2917
cudaDataType_t ToCudaDataType(DataType dt) {
@@ -42,8 +30,10 @@ cudaDataType_t ToCudaDataType(DataType dt) {
4230

4331
} // namespace
4432

45-
void GemmCuda(const GemmParams &p) {
46-
DCHECK(p.blas_handle != nullptr);
33+
void GemmCuda(const Device &device, const GemmParams &p) {
34+
const cublasHandle_t blas_handle = dynamic_cast<infini_train::core::cuda::CudaBlasHandle *>(
35+
infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device))
36+
->cublas_handle();
4737

4838
if (p.batch_count == 1) {
4939
// strides are unused in the non-batched path; assert they are left at 0
@@ -61,19 +51,20 @@ void GemmCuda(const GemmParams &p) {
6151
const cublasComputeType_t ctype = CUBLAS_COMPUTE_32F;
6252

6353
if (p.batch_count == 1) {
64-
CUBLAS_CHECK(cublasGemmEx(p.blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, p.lda, p.B,
54+
CUBLAS_CHECK(cublasGemmEx(blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, p.lda, p.B,
6555
type_b, p.ldb, &p.beta, p.C, type_c, p.ldc, ctype, CUBLAS_GEMM_DEFAULT));
6656
} else {
67-
CUBLAS_CHECK(cublasGemmStridedBatchedEx(p.blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A,
68-
type_a, p.lda, p.stride_a, p.B, type_b, p.ldb, p.stride_b, &p.beta, p.C,
69-
type_c, p.ldc, p.stride_c, p.batch_count, ctype, CUBLAS_GEMM_DEFAULT));
57+
CUBLAS_CHECK(cublasGemmStridedBatchedEx(blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a,
58+
p.lda, p.stride_a, p.B, type_b, p.ldb, p.stride_b, &p.beta, p.C, type_c,
59+
p.ldc, p.stride_c, p.batch_count, ctype, CUBLAS_GEMM_DEFAULT));
7060
}
7161
}
7262

73-
void SgemvCuda(const SgemvParams &p) {
74-
DCHECK(p.blas_handle != nullptr);
75-
CUBLAS_CHECK(
76-
cublasSgemv(p.blas_handle, p.trans, p.m, p.n, &p.alpha, p.A, p.lda, p.x, p.incx, &p.beta, p.y, p.incy));
63+
void SgemvCuda(const Device &device, const SgemvParams &p) {
64+
const cublasHandle_t blas_handle = dynamic_cast<infini_train::core::cuda::CudaBlasHandle *>(
65+
infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device))
66+
->cublas_handle();
67+
CUBLAS_CHECK(cublasSgemv(blas_handle, p.trans, p.m, p.n, &p.alpha, p.A, p.lda, p.x, p.incx, &p.beta, p.y, p.incy));
7768
}
7869

7970
} // namespace infini_train::kernels::cuda

infini_train/src/kernels/cuda/linear.cu

Lines changed: 84 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
#include "infini_train/include/common/cuda/common_cuda.h"
1010
#include "infini_train/include/common/cuda/gemm.cuh"
1111
#include "infini_train/include/common/cuda/kernel_helper.cuh"
12+
#include "infini_train/include/core/runtime/device_guard.h"
1213
#include "infini_train/include/dispatcher.h"
1314
#include "infini_train/include/tensor.h"
1415
#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h"
16+
#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h"
1517

1618
namespace infini_train::kernels::cuda {
1719

@@ -58,7 +60,9 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
5860
auto output = std::make_shared<Tensor>(output_dims, dtype, input->GetDevice());
5961

6062
auto device = input->GetDevice();
61-
const auto cuda_stream = GetCudaStream(device);
63+
const auto cuda_stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
64+
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
65+
->cuda_stream();
6266

6367
if (bias) {
6468
CHECK_EQ(bias->Dims().size(), 1);
@@ -80,18 +84,17 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
8084
// When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
8185
// cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
8286
if (bs == 1 && dtype == DataType::kFLOAT32) {
83-
SgemvCuda(SgemvParams{
84-
.trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
85-
.m = static_cast<int>(transpose ? in_features : out_features),
86-
.n = static_cast<int>(transpose ? out_features : in_features),
87-
.A = static_cast<const float *>(weight->DataPtr()),
88-
.lda = static_cast<int>(transpose ? in_features : out_features),
89-
.x = static_cast<const float *>(input->DataPtr()),
90-
.y = static_cast<float *>(output->DataPtr()),
91-
.alpha = 1.0f,
92-
.beta = 1.0f, // output already initialized with bias or zero above
93-
.blas_handle = GetCublasHandle(device),
94-
});
87+
SgemvCuda(device, SgemvParams{
88+
.trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
89+
.m = static_cast<int>(transpose ? in_features : out_features),
90+
.n = static_cast<int>(transpose ? out_features : in_features),
91+
.A = static_cast<const float *>(weight->DataPtr()),
92+
.lda = static_cast<int>(transpose ? in_features : out_features),
93+
.x = static_cast<const float *>(input->DataPtr()),
94+
.y = static_cast<float *>(output->DataPtr()),
95+
.alpha = 1.0f,
96+
.beta = 1.0f, // output already initialized with bias or zero above
97+
});
9598
} else {
9699
// cuBLAS is column-major
97100
// - if a is transposed:
@@ -106,25 +109,24 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
106109
// C = output.T[out_features, bs]
107110
// A = weight.T[out_features, in_features]
108111
// B = input.T[in_features, bs]
109-
GemmCuda(GemmParams{
110-
.trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
111-
.trans_b = CUBLAS_OP_N,
112-
.m = static_cast<int>(out_features),
113-
.n = static_cast<int>(bs),
114-
.k = static_cast<int>(in_features),
115-
.A = weight->DataPtr(),
116-
.lda = static_cast<int>(transpose ? in_features : out_features),
117-
.B = input->DataPtr(),
118-
.ldb = static_cast<int>(in_features),
119-
.C = output->DataPtr(),
120-
.ldc = static_cast<int>(out_features),
121-
.alpha = 1.0f,
122-
.beta = 1.0f, // bias already written into output; beta=1 accumulates
123-
.batch_count = 1,
124-
.input_dtype = dtype,
125-
.output_dtype = dtype,
126-
.blas_handle = GetCublasHandle(device),
127-
});
112+
GemmCuda(device, GemmParams{
113+
.trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
114+
.trans_b = CUBLAS_OP_N,
115+
.m = static_cast<int>(out_features),
116+
.n = static_cast<int>(bs),
117+
.k = static_cast<int>(in_features),
118+
.A = weight->DataPtr(),
119+
.lda = static_cast<int>(transpose ? in_features : out_features),
120+
.B = input->DataPtr(),
121+
.ldb = static_cast<int>(in_features),
122+
.C = output->DataPtr(),
123+
.ldc = static_cast<int>(out_features),
124+
.alpha = 1.0f,
125+
.beta = 1.0f, // bias already written into output; beta=1 accumulates
126+
.batch_count = 1,
127+
.input_dtype = dtype,
128+
.output_dtype = dtype,
129+
});
128130
}
129131

130132
return output;
@@ -171,18 +173,17 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
171173
// When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
172174
// cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
173175
if (bs == 1 && compute_dtype == DataType::kFLOAT32) {
174-
SgemvCuda(SgemvParams{
175-
.trans = transpose ? CUBLAS_OP_N : CUBLAS_OP_T,
176-
.m = static_cast<int>(transpose ? in_features : out_features),
177-
.n = static_cast<int>(transpose ? out_features : in_features),
178-
.A = static_cast<const float *>(weight->DataPtr()),
179-
.lda = static_cast<int>(transpose ? in_features : out_features),
180-
.x = static_cast<const float *>(grad_output_promoted->DataPtr()),
181-
.y = static_cast<float *>(grad_input->DataPtr()),
182-
.alpha = 1.0f,
183-
.beta = 0.0f,
184-
.blas_handle = GetCublasHandle(grad_output->GetDevice()),
185-
});
176+
SgemvCuda(grad_output->GetDevice(), SgemvParams{
177+
.trans = transpose ? CUBLAS_OP_N : CUBLAS_OP_T,
178+
.m = static_cast<int>(transpose ? in_features : out_features),
179+
.n = static_cast<int>(transpose ? out_features : in_features),
180+
.A = static_cast<const float *>(weight->DataPtr()),
181+
.lda = static_cast<int>(transpose ? in_features : out_features),
182+
.x = static_cast<const float *>(grad_output_promoted->DataPtr()),
183+
.y = static_cast<float *>(grad_input->DataPtr()),
184+
.alpha = 1.0f,
185+
.beta = 0.0f,
186+
});
186187
} else {
187188
// - if transpose:
188189
// weight is [out_features, in_features] here
@@ -197,25 +198,24 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
197198
// C = d_input.T[in_features, bs]
198199
// A = weight.T[out_features, in_features]
199200
// B = d_output.T[out_features, bs]
200-
GemmCuda(GemmParams{
201-
.trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T,
202-
.trans_b = CUBLAS_OP_N,
203-
.m = static_cast<int>(in_features),
204-
.n = static_cast<int>(bs),
205-
.k = static_cast<int>(out_features),
206-
.A = weight->DataPtr(),
207-
.lda = static_cast<int>(transpose ? in_features : out_features),
208-
.B = grad_output_promoted->DataPtr(),
209-
.ldb = static_cast<int>(out_features),
210-
.C = grad_input->DataPtr(),
211-
.ldc = static_cast<int>(in_features),
212-
.alpha = 1.0f,
213-
.beta = 0.0f,
214-
.batch_count = 1,
215-
.input_dtype = compute_dtype,
216-
.output_dtype = output_dtype,
217-
.blas_handle = GetCublasHandle(grad_output->GetDevice()),
218-
});
201+
GemmCuda(grad_output->GetDevice(), GemmParams{
202+
.trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T,
203+
.trans_b = CUBLAS_OP_N,
204+
.m = static_cast<int>(in_features),
205+
.n = static_cast<int>(bs),
206+
.k = static_cast<int>(out_features),
207+
.A = weight->DataPtr(),
208+
.lda = static_cast<int>(transpose ? in_features : out_features),
209+
.B = grad_output_promoted->DataPtr(),
210+
.ldb = static_cast<int>(out_features),
211+
.C = grad_input->DataPtr(),
212+
.ldc = static_cast<int>(in_features),
213+
.alpha = 1.0f,
214+
.beta = 0.0f,
215+
.batch_count = 1,
216+
.input_dtype = compute_dtype,
217+
.output_dtype = output_dtype,
218+
});
219219
}
220220

221221
return grad_input;
@@ -257,25 +257,24 @@ std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &inpu
257257
const int lda = static_cast<int>(transpose ? in_features : out_features);
258258
const int ldb = static_cast<int>(transpose ? out_features : in_features);
259259

260-
GemmCuda(GemmParams{
261-
.trans_a = CUBLAS_OP_N,
262-
.trans_b = CUBLAS_OP_T,
263-
.m = static_cast<int>(transpose ? in_features : out_features),
264-
.n = static_cast<int>(transpose ? out_features : in_features),
265-
.k = static_cast<int>(bs),
266-
.A = a,
267-
.lda = lda,
268-
.B = b,
269-
.ldb = ldb,
270-
.C = grad_weight->DataPtr(),
271-
.ldc = static_cast<int>(transpose ? in_features : out_features),
272-
.alpha = 1.0f,
273-
.beta = 0.0f,
274-
.batch_count = 1,
275-
.input_dtype = compute_dtype,
276-
.output_dtype = output_dtype,
277-
.blas_handle = GetCublasHandle(grad_output->GetDevice()),
278-
});
260+
GemmCuda(grad_output->GetDevice(), GemmParams{
261+
.trans_a = CUBLAS_OP_N,
262+
.trans_b = CUBLAS_OP_T,
263+
.m = static_cast<int>(transpose ? in_features : out_features),
264+
.n = static_cast<int>(transpose ? out_features : in_features),
265+
.k = static_cast<int>(bs),
266+
.A = a,
267+
.lda = lda,
268+
.B = b,
269+
.ldb = ldb,
270+
.C = grad_weight->DataPtr(),
271+
.ldc = static_cast<int>(transpose ? in_features : out_features),
272+
.alpha = 1.0f,
273+
.beta = 0.0f,
274+
.batch_count = 1,
275+
.input_dtype = compute_dtype,
276+
.output_dtype = output_dtype,
277+
});
279278

280279
return grad_weight;
281280
}
@@ -292,7 +291,9 @@ std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_o
292291
= std::make_shared<Tensor>(std::vector<int64_t>{out_features}, output_dtype, grad_output->GetDevice());
293292

294293
auto device = grad_output->GetDevice();
295-
const auto cuda_stream = GetCudaStream(device);
294+
const auto cuda_stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
295+
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
296+
->cuda_stream();
296297

297298
// d_bias = \sum_i(i=0, bs-1) d_output[i]
298299
// TODO(dcj): use thrust::fill or reduce kernel do this

0 commit comments

Comments
 (0)