
Commit d385edb (parent: ea0edc1)

fixup! Refactor(linear): split LinearBackward kernel into 3 independent kernels

1 file changed, 4 insertions(+), 0 deletions(-)

diff --git a/infini_train/src/kernels/cuda/linear.cu b/infini_train/src/kernels/cuda/linear.cu
--- a/infini_train/src/kernels/cuda/linear.cu
+++ b/infini_train/src/kernels/cuda/linear.cu
@@ -330,6 +330,7 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
 
     // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior).
     auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
+    // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output.
     auto grad_input = std::make_shared<Tensor>(input_dims, output_dtype, grad_output->GetDevice());
 
     auto device = grad_output->GetDevice();
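The added comment leans on a cuBLAS GEMM guarantee: when beta is 0, C is write-only, i.e. C = alpha * op(A) * op(B) with the previous contents of C never read, so a freshly allocated (uninitialized) grad_input buffer is safe. A minimal sketch of that contract with plain fp32 cublasSgemm; the helper name and parameters are illustrative, not from the commit:

#include <cublas_v2.h>

// Sketch: with beta == 0.0f, cuBLAS computes C = alpha * A * B without ever
// reading C, so C may point at uninitialized memory (no Fill(0) required).
// All matrices are column-major; m, n, k follow the usual GEMM convention.
void gemm_overwrite(cublasHandle_t handle, int m, int n, int k,
                    const float *A, const float *B, float *C) {
    const float alpha = 1.0f;
    const float beta = 0.0f; // C is write-only when beta == 0
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                &alpha, A, /*lda=*/m, B, /*ldb=*/k, &beta, C, /*ldc=*/m);
}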
@@ -339,6 +340,7 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
                       infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device))
                       ->cublas_handle();
 
+    // TODO(zbl): use cublasSgemv if possible
     // - if transpose:
     //   weight is [out_features, in_features] here
     //   d_input = d_output * weight --> d_input.T = weight.T * d_output.T
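The comment block here encodes the standard row-major-to-column-major trick: cuBLAS expects column-major storage, and a row-major matrix reinterpreted as column-major is exactly its transpose, so row-major d_input = d_output * weight is computed as column-major d_input.T = weight.T * d_output.T with no data movement. A sketch for the transpose branch the hunk describes (weight stored row-major as [out_features, in_features]); the function and parameter names are assumptions for illustration:

// Sketch of d_input = d_output * weight, assuming row-major
// d_output [bs, out_features] and weight [out_features, in_features].
// A row-major buffer viewed column-major is its transpose, so passing the
// buffers unchanged computes d_input.T = weight.T * d_output.T, which lands
// in memory as row-major d_input [bs, in_features].
void linear_backward_input_fp32(cublasHandle_t handle, int bs, int in_features,
                                int out_features, const float *d_output,
                                const float *weight, float *d_input) {
    const float alpha = 1.0f, beta = 0.0f;
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                /*m=*/in_features, /*n=*/bs, /*k=*/out_features,
                &alpha, weight, /*lda=*/in_features,
                d_output, /*ldb=*/out_features,
                &beta, d_input, /*ldc=*/in_features);
}

The TODO(zbl) presumably targets the bs == 1 case, where this GEMM degenerates to a matrix-vector product and cublasSgemv would suffice.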
@@ -393,6 +395,7 @@ std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &inpu
     auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
     const std::vector<int64_t> weight_dims
         = transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
+    // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output.
     auto grad_weight = std::make_shared<Tensor>(weight_dims, output_dtype, grad_output->GetDevice());
 
     auto device = grad_output->GetDevice();
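Both LinearBackwardInput and LinearBackwardWeight widen the output dtype to fp32 when the compute dtype is bf16. In cuBLAS terms that corresponds to a mixed-precision GEMM with bf16 inputs, an fp32 output, and fp32 accumulation, which cublasGemmEx expresses via CUBLAS_COMPUTE_32F. A hedged sketch; the helper name and its layout assumptions (column-major, no transposes) are illustrative, not the repo's actual call:

#include <cublas_v2.h>
#include <cuda_bf16.h>

// Sketch: bf16 inputs, fp32 output, fp32 accumulation (and beta == 0.0f, so
// the output buffer is fully overwritten, matching the added comment).
cublasStatus_t gemm_bf16_fp32_acc(cublasHandle_t handle, int m, int n, int k,
                                  const __nv_bfloat16 *A,
                                  const __nv_bfloat16 *B, float *C) {
    const float alpha = 1.0f;
    const float beta = 0.0f;
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                        &alpha, A, CUDA_R_16BF, /*lda=*/m,
                        B, CUDA_R_16BF, /*ldb=*/k,
                        &beta, C, CUDA_R_32F, /*ldc=*/m,
                        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}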
@@ -460,6 +463,7 @@ std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_o
                      ->cuda_stream();
 
     // d_bias = \sum_{i=0}^{bs-1} d_output[i]
+    // TODO(dcj): use thrust::fill or a reduce kernel to do this
     constexpr int BLOCK_SIZE = 256;
     switch (compute_dtype) {
         DISPATCH_CASE(WRAP({
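The comment above the added TODO defines the bias gradient as a column-wise sum over the batch: d_bias[j] = sum over i of d_output[i][j]. Until the TODO's reduce kernel lands, the simplest correct version is one thread per output feature looping over the batch. A naive sketch; the kernel name and fp32-only signature are assumptions (the real code dispatches over compute_dtype):

#include <cuda_runtime.h>

// Naive sketch of d_bias[j] = sum_{i=0}^{bs-1} d_output[i][j] for row-major
// d_output [bs, out_features]. One thread per output feature; a production
// kernel would reduce across the batch cooperatively (cf. the TODO).
__global__ void LinearBackwardBiasNaive(const float *d_output, float *d_bias,
                                        int bs, int out_features) {
    const int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (j >= out_features) {
        return;
    }
    float acc = 0.0f;
    for (int i = 0; i < bs; ++i) {
        acc += d_output[i * out_features + j];
    }
    d_bias[j] = acc; // plain store: the output is fully overwritten here too
}

// Launch with the hunk's BLOCK_SIZE of 256:
//   const int grid = (out_features + BLOCK_SIZE - 1) / BLOCK_SIZE;
//   LinearBackwardBiasNaive<<<grid, BLOCK_SIZE, 0, stream>>>(
//       d_output, d_bias, bs, out_features);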
