 #include "infini_train/include/common/cuda/common_cuda.h"
 #include "infini_train/include/common/cuda/gemm.cuh"
 #include "infini_train/include/common/cuda/kernel_helper.cuh"
+#include "infini_train/include/core/runtime/device_guard.h"
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
 #include "infini_train/src/core/runtime/cuda/cuda_dispatch.h"
+#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h"
 
 namespace infini_train::kernels::cuda {
 
@@ -58,7 +60,9 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
     auto output = std::make_shared<Tensor>(output_dims, dtype, input->GetDevice());
 
     auto device = input->GetDevice();
-    const auto cuda_stream = GetCudaStream(device);
+    const auto cuda_stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
+                                 infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
+                                 ->cuda_stream();
 
     if (bias) {
         CHECK_EQ(bias->Dims().size(), 1);
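Note: the added stream lookup above (repeated again in LinearBackwardBias below) can be read as a small convenience wrapper. The sketch below is illustrative only and not part of the patch; the wrapper name GetCurrentCudaStream and the Device parameter type are assumptions based on the identifiers used in this diff.

// Hypothetical helper (not in this patch) showing what the new expression does:
// fetch the device-guard implementation for the CUDA backend, ask it for the
// stream bound to `device`, and unwrap the raw cudaStream_t.
inline cudaStream_t GetCurrentCudaStream(const Device &device) {
    auto *stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
        infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device));
    return stream->cuda_stream();
}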
@@ -80,18 +84,17 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
     // When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
     // cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
     if (bs == 1 && dtype == DataType::kFLOAT32) {
-        SgemvCuda(SgemvParams{
-            .trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
-            .m = static_cast<int>(transpose ? in_features : out_features),
-            .n = static_cast<int>(transpose ? out_features : in_features),
-            .A = static_cast<const float *>(weight->DataPtr()),
-            .lda = static_cast<int>(transpose ? in_features : out_features),
-            .x = static_cast<const float *>(input->DataPtr()),
-            .y = static_cast<float *>(output->DataPtr()),
-            .alpha = 1.0f,
-            .beta = 1.0f, // output already initialized with bias or zero above
-            .blas_handle = GetCublasHandle(device),
-        });
+        SgemvCuda(device, SgemvParams{
+                              .trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
+                              .m = static_cast<int>(transpose ? in_features : out_features),
+                              .n = static_cast<int>(transpose ? out_features : in_features),
+                              .A = static_cast<const float *>(weight->DataPtr()),
+                              .lda = static_cast<int>(transpose ? in_features : out_features),
+                              .x = static_cast<const float *>(input->DataPtr()),
+                              .y = static_cast<float *>(output->DataPtr()),
+                              .alpha = 1.0f,
+                              .beta = 1.0f, // output already initialized with bias or zero above
+                          });
     } else {
         // cuBLAS is column-major
         // - if a is transposed:
@@ -106,25 +109,24 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
         //       C = output.T[out_features, bs]
         //       A = weight.T[out_features, in_features]
         //       B = input.T[in_features, bs]
-        GemmCuda(GemmParams{
-            .trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
-            .trans_b = CUBLAS_OP_N,
-            .m = static_cast<int>(out_features),
-            .n = static_cast<int>(bs),
-            .k = static_cast<int>(in_features),
-            .A = weight->DataPtr(),
-            .lda = static_cast<int>(transpose ? in_features : out_features),
-            .B = input->DataPtr(),
-            .ldb = static_cast<int>(in_features),
-            .C = output->DataPtr(),
-            .ldc = static_cast<int>(out_features),
-            .alpha = 1.0f,
-            .beta = 1.0f, // bias already written into output; beta=1 accumulates
-            .batch_count = 1,
-            .input_dtype = dtype,
-            .output_dtype = dtype,
-            .blas_handle = GetCublasHandle(device),
-        });
+        GemmCuda(device, GemmParams{
+                             .trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
+                             .trans_b = CUBLAS_OP_N,
+                             .m = static_cast<int>(out_features),
+                             .n = static_cast<int>(bs),
+                             .k = static_cast<int>(in_features),
+                             .A = weight->DataPtr(),
+                             .lda = static_cast<int>(transpose ? in_features : out_features),
+                             .B = input->DataPtr(),
+                             .ldb = static_cast<int>(in_features),
+                             .C = output->DataPtr(),
+                             .ldc = static_cast<int>(out_features),
+                             .alpha = 1.0f,
+                             .beta = 1.0f, // bias already written into output; beta=1 accumulates
+                             .batch_count = 1,
+                             .input_dtype = dtype,
+                             .output_dtype = dtype,
+                         });
     }
 
     return output;
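Note on the layout mapping in the comments above: cuBLAS assumes column-major storage, so a row-major tensor reinterpreted in column-major order is its own transpose. Computing output^T = op(weight^T) * input^T therefore writes the row-major output directly. A minimal fp32 sketch of the non-transposed case follows; it is illustrative only (the real call goes through GemmCuda with a cublasGemmEx-style path), and `handle` stands for whatever per-device cuBLAS handle the project provides.

// Assumption: fp32, weight stored row-major as [in_features, out_features].
// Column-major views: weight -> [out_features, in_features], input -> [in_features, bs],
// output -> [out_features, bs], so a plain NN GEMM produces output^T in place.
const float alpha = 1.0f;
const float beta = 1.0f; // beta = 1: bias was already copied into output
cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            /*m=*/static_cast<int>(out_features), /*n=*/static_cast<int>(bs),
            /*k=*/static_cast<int>(in_features), &alpha,
            static_cast<const float *>(weight->DataPtr()), /*lda=*/static_cast<int>(out_features),
            static_cast<const float *>(input->DataPtr()), /*ldb=*/static_cast<int>(in_features),
            &beta,
            static_cast<float *>(output->DataPtr()), /*ldc=*/static_cast<int>(out_features));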
@@ -171,18 +173,17 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
     // When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
     // cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
     if (bs == 1 && compute_dtype == DataType::kFLOAT32) {
-        SgemvCuda(SgemvParams{
-            .trans = transpose ? CUBLAS_OP_N : CUBLAS_OP_T,
-            .m = static_cast<int>(transpose ? in_features : out_features),
-            .n = static_cast<int>(transpose ? out_features : in_features),
-            .A = static_cast<const float *>(weight->DataPtr()),
-            .lda = static_cast<int>(transpose ? in_features : out_features),
-            .x = static_cast<const float *>(grad_output_promoted->DataPtr()),
-            .y = static_cast<float *>(grad_input->DataPtr()),
-            .alpha = 1.0f,
-            .beta = 0.0f,
-            .blas_handle = GetCublasHandle(grad_output->GetDevice()),
-        });
+        SgemvCuda(grad_output->GetDevice(), SgemvParams{
+                                                .trans = transpose ? CUBLAS_OP_N : CUBLAS_OP_T,
+                                                .m = static_cast<int>(transpose ? in_features : out_features),
+                                                .n = static_cast<int>(transpose ? out_features : in_features),
+                                                .A = static_cast<const float *>(weight->DataPtr()),
+                                                .lda = static_cast<int>(transpose ? in_features : out_features),
+                                                .x = static_cast<const float *>(grad_output_promoted->DataPtr()),
+                                                .y = static_cast<float *>(grad_input->DataPtr()),
+                                                .alpha = 1.0f,
+                                                .beta = 0.0f,
+                                            });
     } else {
         // - if transpose:
         //       weight is [out_features, in_features] here
@@ -197,25 +198,24 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
         //       C = d_input.T[in_features, bs]
         //       A = weight.T[out_features, in_features]
         //       B = d_output.T[out_features, bs]
-        GemmCuda(GemmParams{
-            .trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T,
-            .trans_b = CUBLAS_OP_N,
-            .m = static_cast<int>(in_features),
-            .n = static_cast<int>(bs),
-            .k = static_cast<int>(out_features),
-            .A = weight->DataPtr(),
-            .lda = static_cast<int>(transpose ? in_features : out_features),
-            .B = grad_output_promoted->DataPtr(),
-            .ldb = static_cast<int>(out_features),
-            .C = grad_input->DataPtr(),
-            .ldc = static_cast<int>(in_features),
-            .alpha = 1.0f,
-            .beta = 0.0f,
-            .batch_count = 1,
-            .input_dtype = compute_dtype,
-            .output_dtype = output_dtype,
-            .blas_handle = GetCublasHandle(grad_output->GetDevice()),
-        });
+        GemmCuda(grad_output->GetDevice(), GemmParams{
+                                               .trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T,
+                                               .trans_b = CUBLAS_OP_N,
+                                               .m = static_cast<int>(in_features),
+                                               .n = static_cast<int>(bs),
+                                               .k = static_cast<int>(out_features),
+                                               .A = weight->DataPtr(),
+                                               .lda = static_cast<int>(transpose ? in_features : out_features),
+                                               .B = grad_output_promoted->DataPtr(),
+                                               .ldb = static_cast<int>(out_features),
+                                               .C = grad_input->DataPtr(),
+                                               .ldc = static_cast<int>(in_features),
+                                               .alpha = 1.0f,
+                                               .beta = 0.0f,
+                                               .batch_count = 1,
+                                               .input_dtype = compute_dtype,
+                                               .output_dtype = output_dtype,
+                                           });
     }
 
     return grad_input;
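As a cross-check of the bs==1 fast path above, here is the same backward-input computation written as a raw cublasSgemv call for the non-transposed fp32 case. This is only a sketch; `handle` stands for the per-device cuBLAS handle, obtained however the project does it.

// Assumption: fp32, bs == 1, non-transposed weight [in_features, out_features].
// Its column-major view A is [out_features, in_features] with lda = out_features, so
// CUBLAS_OP_T computes grad_input[in_features] = A^T * grad_output[out_features].
const float alpha = 1.0f;
const float beta = 0.0f;
cublasSgemv(handle, CUBLAS_OP_T,
            /*m=*/static_cast<int>(out_features), /*n=*/static_cast<int>(in_features),
            &alpha,
            static_cast<const float *>(weight->DataPtr()), /*lda=*/static_cast<int>(out_features),
            static_cast<const float *>(grad_output_promoted->DataPtr()), /*incx=*/1,
            &beta,
            static_cast<float *>(grad_input->DataPtr()), /*incy=*/1);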
@@ -257,25 +257,24 @@ std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &inpu
     const int lda = static_cast<int>(transpose ? in_features : out_features);
     const int ldb = static_cast<int>(transpose ? out_features : in_features);
 
-    GemmCuda(GemmParams{
-        .trans_a = CUBLAS_OP_N,
-        .trans_b = CUBLAS_OP_T,
-        .m = static_cast<int>(transpose ? in_features : out_features),
-        .n = static_cast<int>(transpose ? out_features : in_features),
-        .k = static_cast<int>(bs),
-        .A = a,
-        .lda = lda,
-        .B = b,
-        .ldb = ldb,
-        .C = grad_weight->DataPtr(),
-        .ldc = static_cast<int>(transpose ? in_features : out_features),
-        .alpha = 1.0f,
-        .beta = 0.0f,
-        .batch_count = 1,
-        .input_dtype = compute_dtype,
-        .output_dtype = output_dtype,
-        .blas_handle = GetCublasHandle(grad_output->GetDevice()),
-    });
+    GemmCuda(grad_output->GetDevice(), GemmParams{
+                                           .trans_a = CUBLAS_OP_N,
+                                           .trans_b = CUBLAS_OP_T,
+                                           .m = static_cast<int>(transpose ? in_features : out_features),
+                                           .n = static_cast<int>(transpose ? out_features : in_features),
+                                           .k = static_cast<int>(bs),
+                                           .A = a,
+                                           .lda = lda,
+                                           .B = b,
+                                           .ldb = ldb,
+                                           .C = grad_weight->DataPtr(),
+                                           .ldc = static_cast<int>(transpose ? in_features : out_features),
+                                           .alpha = 1.0f,
+                                           .beta = 0.0f,
+                                           .batch_count = 1,
+                                           .input_dtype = compute_dtype,
+                                           .output_dtype = output_dtype,
+                                       });
 
     return grad_weight;
 }
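For orientation (not part of the patch): in the non-transposed layout, where weight is stored row-major as [in_features, out_features], the weight gradient is grad_weight = input^T * grad_output. Assuming `a` and `b` above are bound to grad_output and input for that case (their assignment sits just above this hunk), the column-major views are op(A) = grad_output^T of shape [out_features, bs] and, via trans_b = CUBLAS_OP_T, op(B) = input of shape [bs, in_features]; their product is grad_weight^T in column-major storage, which read back row-major is exactly grad_weight.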
@@ -292,7 +291,9 @@ std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_o
         = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, output_dtype, grad_output->GetDevice());
 
     auto device = grad_output->GetDevice();
-    const auto cuda_stream = GetCudaStream(device);
+    const auto cuda_stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
+                                 infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
+                                 ->cuda_stream();
 
     // d_bias = \sum_{i=0}^{bs-1} d_output[i]
     // TODO(dcj): use thrust::fill or a reduce kernel to do this
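The TODO above leaves the reduction to a kernel elsewhere in this file. As a rough sketch of the shape of that computation (one thread per output feature; the actual kernel, dtype handling, and launch configuration may differ):

// Illustrative only: sums d_output over the batch dimension to produce d_bias.
// Assumes row-major d_output[bs, out_features] in fp32; launched on cuda_stream with,
// e.g., <<<(out_features + 255) / 256, 256, 0, cuda_stream>>>.
__global__ void BiasGradSketchKernel(const float *d_output, float *d_bias,
                                     int64_t bs, int64_t out_features) {
    const int64_t j = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
    if (j >= out_features) {
        return;
    }
    float sum = 0.0f;
    for (int64_t i = 0; i < bs; ++i) {
        sum += d_output[i * out_features + j];
    }
    d_bias[j] = sum;
}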