@@ -330,6 +330,7 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
330330
331331 // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior).
332332 auto output_dtype = (compute_dtype == DataType::kBFLOAT16 ) ? DataType::kFLOAT32 : compute_dtype;
333+ // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output.
333334 auto grad_input = std::make_shared<Tensor>(input_dims, output_dtype, grad_output->GetDevice ());
334335
335336 auto device = grad_output->GetDevice ();
@@ -339,6 +340,7 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
339340 infini_train::core::GetDeviceGuardImpl (device.type ())->GetBlasHandle (device))
340341 ->cublas_handle ();
341342
343+ // TODO(zbl): use cublasSgemv if possible
342344 // - if transpose:
343345 // weight is [out_features, in_features] here
344346 // d_input = d_output * weight --> d_input.T = weight.T * d_output.T
@@ -393,6 +395,7 @@ std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &inpu
393395 auto output_dtype = (compute_dtype == DataType::kBFLOAT16 ) ? DataType::kFLOAT32 : compute_dtype;
394396 const std::vector<int64_t > weight_dims
395397 = transpose ? std::vector<int64_t >{out_features, in_features} : std::vector<int64_t >{in_features, out_features};
398+ // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output.
396399 auto grad_weight = std::make_shared<Tensor>(weight_dims, output_dtype, grad_output->GetDevice ());
397400
398401 auto device = grad_output->GetDevice ();
@@ -460,6 +463,7 @@ std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_o
460463 ->cuda_stream ();
461464
462465 // d_bias = \sum_i(i=0, bs-1) d_output[i]
466+ // TODO(dcj): use thrust::fill or a reduce kernel to do this
463467 constexpr int BLOCK_SIZE = 256 ;
464468 switch (compute_dtype) {
465469 DISPATCH_CASE (WRAP ({
0 commit comments