@@ -80,18 +80,18 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
8080 // When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
8181 // cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
8282 if (bs == 1 && dtype == DataType::kFLOAT32 ) {
83- SgemvParams p;
84- p .trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N ;
85- p .m = static_cast <int >(transpose ? in_features : out_features);
86- p .n = static_cast <int >(transpose ? out_features : in_features);
87- p .A = static_cast <const float *>(weight->DataPtr ());
88- p .lda = static_cast <int >(transpose ? in_features : out_features);
89- p .x = static_cast <const float *>(input->DataPtr ());
90- p .y = static_cast <float *>(output->DataPtr ());
91- p .alpha = 1 .0f ;
92- p .beta = 1 .0f ; // output already initialized with bias or zero above
93- p .blas_handle = GetCublasHandle (device);
94- SgemvCuda (p );
83+ SgemvCuda ( SgemvParams{
84+ .trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N ,
85+ .m = static_cast <int >(transpose ? in_features : out_features),
86+ .n = static_cast <int >(transpose ? out_features : in_features),
87+ .A = static_cast <const float *>(weight->DataPtr ()),
88+ .lda = static_cast <int >(transpose ? in_features : out_features),
89+ .x = static_cast <const float *>(input->DataPtr ()),
90+ .y = static_cast <float *>(output->DataPtr ()),
91+ .alpha = 1 .0f ,
92+ .beta = 1 .0f , // output already initialized with bias or zero above
93+ .blas_handle = GetCublasHandle (device),
94+ } );
9595 } else {
9696 // cuBLAS is column-major
9797 // - if a is transposed:
@@ -106,26 +106,25 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
106106 // C = output.T[out_features, bs]
107107 // A = weight.T[out_features, in_features]
108108 // B = input.T[in_features, bs]
109- GemmParams p;
110- p.trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N ;
111- p.trans_b = CUBLAS_OP_N ;
112- p.m = static_cast <int >(out_features);
113- p.n = static_cast <int >(bs);
114- p.k = static_cast <int >(in_features);
115- p.A = weight->DataPtr ();
116- p.lda = static_cast <int >(transpose ? in_features : out_features);
117- p.B = input->DataPtr ();
118- p.ldb = static_cast <int >(in_features);
119- p.C = output->DataPtr ();
120- p.ldc = static_cast <int >(out_features);
121- p.alpha = 1 .0f ;
122- p.beta = 1 .0f ; // bias already written into output; beta=1 accumulates
123- p.batch_count = 1 ;
124- p.input_dtype = dtype;
125- p.output_dtype = dtype;
126- p.blas_handle = GetCublasHandle (device);
127-
128- GemmCuda (p);
109+ GemmCuda (GemmParams{
110+ .trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N ,
111+ .trans_b = CUBLAS_OP_N ,
112+ .m = static_cast <int >(out_features),
113+ .n = static_cast <int >(bs),
114+ .k = static_cast <int >(in_features),
115+ .A = weight->DataPtr (),
116+ .lda = static_cast <int >(transpose ? in_features : out_features),
117+ .B = input->DataPtr (),
118+ .ldb = static_cast <int >(in_features),
119+ .C = output->DataPtr (),
120+ .ldc = static_cast <int >(out_features),
121+ .alpha = 1 .0f ,
122+ .beta = 1 .0f , // bias already written into output; beta=1 accumulates
123+ .batch_count = 1 ,
124+ .input_dtype = dtype,
125+ .output_dtype = dtype,
126+ .blas_handle = GetCublasHandle (device),
127+ });
129128 }
130129
131130 return output;
@@ -172,18 +171,18 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
172171 // When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
173172 // cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
174173 if (bs == 1 && compute_dtype == DataType::kFLOAT32 ) {
175- SgemvParams p;
176- p .trans = transpose ? CUBLAS_OP_N : CUBLAS_OP_T ;
177- p .m = static_cast <int >(transpose ? in_features : out_features);
178- p .n = static_cast <int >(transpose ? out_features : in_features);
179- p .A = static_cast <const float *>(weight->DataPtr ());
180- p .lda = static_cast <int >(transpose ? in_features : out_features);
181- p .x = static_cast <const float *>(grad_output_promoted->DataPtr ());
182- p .y = static_cast <float *>(grad_input->DataPtr ());
183- p .alpha = 1 .0f ;
184- p .beta = 0 .0f ;
185- p .blas_handle = GetCublasHandle (grad_output->GetDevice ());
186- SgemvCuda (p );
174+ SgemvCuda ( SgemvParams{
175+ .trans = transpose ? CUBLAS_OP_N : CUBLAS_OP_T ,
176+ .m = static_cast <int >(transpose ? in_features : out_features),
177+ .n = static_cast <int >(transpose ? out_features : in_features),
178+ .A = static_cast <const float *>(weight->DataPtr ()),
179+ .lda = static_cast <int >(transpose ? in_features : out_features),
180+ .x = static_cast <const float *>(grad_output_promoted->DataPtr ()),
181+ .y = static_cast <float *>(grad_input->DataPtr ()),
182+ .alpha = 1 .0f ,
183+ .beta = 0 .0f ,
184+ .blas_handle = GetCublasHandle (grad_output->GetDevice ()),
185+ } );
187186 } else {
188187 // - if transpose:
189188 // weight is [out_features, in_features] here
@@ -198,26 +197,25 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
198197 // C = d_input.T[in_features, bs]
199198 // A = weight.T[out_features, in_features]
200199 // B = d_output.T[out_features, bs]
201- GemmParams p;
202- p.trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T ;
203- p.trans_b = CUBLAS_OP_N ;
204- p.m = static_cast <int >(in_features);
205- p.n = static_cast <int >(bs);
206- p.k = static_cast <int >(out_features);
207- p.A = weight->DataPtr ();
208- p.lda = static_cast <int >(transpose ? in_features : out_features);
209- p.B = grad_output_promoted->DataPtr ();
210- p.ldb = static_cast <int >(out_features);
211- p.C = grad_input->DataPtr ();
212- p.ldc = static_cast <int >(in_features);
213- p.alpha = 1 .0f ;
214- p.beta = 0 .0f ;
215- p.batch_count = 1 ;
216- p.input_dtype = compute_dtype;
217- p.output_dtype = output_dtype;
218- p.blas_handle = GetCublasHandle (grad_output->GetDevice ());
219-
220- GemmCuda (p);
200+ GemmCuda (GemmParams{
201+ .trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T ,
202+ .trans_b = CUBLAS_OP_N ,
203+ .m = static_cast <int >(in_features),
204+ .n = static_cast <int >(bs),
205+ .k = static_cast <int >(out_features),
206+ .A = weight->DataPtr (),
207+ .lda = static_cast <int >(transpose ? in_features : out_features),
208+ .B = grad_output_promoted->DataPtr (),
209+ .ldb = static_cast <int >(out_features),
210+ .C = grad_input->DataPtr (),
211+ .ldc = static_cast <int >(in_features),
212+ .alpha = 1 .0f ,
213+ .beta = 0 .0f ,
214+ .batch_count = 1 ,
215+ .input_dtype = compute_dtype,
216+ .output_dtype = output_dtype,
217+ .blas_handle = GetCublasHandle (grad_output->GetDevice ()),
218+ });
221219 }
222220
223221 return grad_input;
@@ -259,26 +257,25 @@ std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &inpu
259257 const int lda = static_cast <int >(transpose ? in_features : out_features);
260258 const int ldb = static_cast <int >(transpose ? out_features : in_features);
261259
262- GemmParams p;
263- p.trans_a = CUBLAS_OP_N ;
264- p.trans_b = CUBLAS_OP_T ;
265- p.m = static_cast <int >(transpose ? in_features : out_features);
266- p.n = static_cast <int >(transpose ? out_features : in_features);
267- p.k = static_cast <int >(bs);
268- p.A = a;
269- p.lda = lda;
270- p.B = b;
271- p.ldb = ldb;
272- p.C = grad_weight->DataPtr ();
273- p.ldc = static_cast <int >(transpose ? in_features : out_features);
274- p.alpha = 1 .0f ;
275- p.beta = 0 .0f ;
276- p.batch_count = 1 ;
277- p.input_dtype = compute_dtype;
278- p.output_dtype = output_dtype;
279- p.blas_handle = GetCublasHandle (grad_output->GetDevice ());
280-
281- GemmCuda (p);
260+ GemmCuda (GemmParams{
261+ .trans_a = CUBLAS_OP_N ,
262+ .trans_b = CUBLAS_OP_T ,
263+ .m = static_cast <int >(transpose ? in_features : out_features),
264+ .n = static_cast <int >(transpose ? out_features : in_features),
265+ .k = static_cast <int >(bs),
266+ .A = a,
267+ .lda = lda,
268+ .B = b,
269+ .ldb = ldb,
270+ .C = grad_weight->DataPtr (),
271+ .ldc = static_cast <int >(transpose ? in_features : out_features),
272+ .alpha = 1 .0f ,
273+ .beta = 0 .0f ,
274+ .batch_count = 1 ,
275+ .input_dtype = compute_dtype,
276+ .output_dtype = output_dtype,
277+ .blas_handle = GetCublasHandle (grad_output->GetDevice ()),
278+ });
282279
283280 return grad_weight;
284281}
0 commit comments