@@ -7,13 +7,13 @@
 #include <cublas_v2.h>

 #include "infini_train/include/common/cuda/common_cuda.h"
-#include "infini_train/include/common/cuda/gemm.cuh"
 #include "infini_train/include/common/cuda/kernel_helper.cuh"
 #include "infini_train/include/core/runtime/device_guard.h"
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
 #include "infini_train/src/core/runtime/cuda/cuda_dispatch.h"
 #include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h"
+#include "infini_train/src/kernels/cuda/common/gemm.cuh"

 namespace infini_train::kernels::cuda {

@@ -165,7 +165,7 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
     auto grad_output_promoted
         = grad_output_dtype == compute_dtype ? grad_output : std::make_shared<Tensor>(grad_output->To(compute_dtype));

-    // For bf16 compute, accumulate in fp32 to preserve precision.
+    // FIXME(cx): output dtype promotion is a temporary hack; revisit when autograd/autocast is fixed.
     auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
     // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output.
     auto grad_input = std::make_shared<Tensor>(input_dims, output_dtype, grad_output->GetDevice());
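Note on the beta=0 comment above: with cuBLAS, a GEMM over bf16 inputs can write straight into an fp32 buffer in one cublasGemmEx call, so the promoted output_dtype costs no extra conversion pass. A minimal sketch under stated assumptions — column-major operands, no transposes, and a made-up function name (the repo's real wrapper lives in gemm.cuh and is not shown in this diff):

    // Sketch only: bf16 A/B, fp32 C, fp32 accumulation; beta = 0 fully
    // overwrites C, so the output buffer needs no prior Fill(0).
    // Error checking omitted for brevity.
    #include <cublas_v2.h>
    #include <cuda_bf16.h>

    // C (m x n, fp32) = A (m x k, bf16) * B (k x n, bf16), column-major.
    void GemmBf16OutFp32(cublasHandle_t handle, int m, int n, int k,
                         const __nv_bfloat16 *A, const __nv_bfloat16 *B, float *C) {
        const float alpha = 1.0f;
        const float beta = 0.0f; // overwrite: C may be uninitialized
        cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                     A, CUDA_R_16BF, /*lda=*/m,
                     B, CUDA_R_16BF, /*ldb=*/k,
                     &beta,
                     C, CUDA_R_32F, /*ldc=*/m, // fp32 output preserves precision
                     CUBLAS_COMPUTE_32F,       // accumulate in fp32
                     CUBLAS_GEMM_DEFAULT);
    }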
@@ -234,7 +234,7 @@ std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &inpu
     auto grad_output_promoted
         = grad_output_dtype == compute_dtype ? grad_output : std::make_shared<Tensor>(grad_output->To(compute_dtype));

-    // For bf16 compute, accumulate in fp32 to preserve precision.
+    // FIXME(cx): output dtype promotion is a temporary hack; revisit when autograd/autocast is fixed.
     auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
     const std::vector<int64_t> weight_dims
         = transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
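The same ternary now appears verbatim in all three backward kernels. If the FIXME outlives this change, it could be centralized; a hypothetical sketch (PromoteGradDtype is not in this diff, and it assumes DataType is visible via the includes above):

    // Hypothetical helper, not part of this change: one place to hold the
    // bf16 -> fp32 gradient-dtype promotion until autograd/autocast is fixed.
    inline DataType PromoteGradDtype(DataType compute_dtype) {
        return compute_dtype == DataType::kBFLOAT16 ? DataType::kFLOAT32 : compute_dtype;
    }

Call sites would then shrink to auto output_dtype = PromoteGradDtype(compute_dtype);, keeping the FIXME in a single spot.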
@@ -285,7 +285,7 @@ std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_o
     const int64_t bs = std::accumulate(dims.rbegin() + 1, dims.rend(), 1, std::multiplies<int64_t>{});

     auto compute_dtype = grad_output->Dtype();
-    // For bf16 compute, accumulate in fp32 to preserve precision.
+    // FIXME(cx): output dtype promotion is a temporary hack; revisit when autograd/autocast is fixed.
     auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
     auto grad_bias
         = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, output_dtype, grad_output->GetDevice());
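For the bias gradient, grad_output is viewed as a (bs, out_features) matrix — bs is the product of every dimension except the last, e.g. dims {B, T, out_features} give bs = B * T — and grad_bias[j] is the sum of column j. A naive one-thread-per-output-feature sketch of that reduction, assuming bf16 grad_output and the fp32 output_dtype chosen above (the reduction actually used in this file is not shown in this hunk and likely differs):

    #include <cuda_bf16.h>

    // Sketch only: grad_bias[j] = sum over bs rows of grad_output[i][j],
    // accumulated in fp32 even though the inputs are bf16.
    __global__ void BiasBackwardNaive(const __nv_bfloat16 *grad_output, float *grad_bias,
                                      int64_t bs, int64_t out_features) {
        const int64_t j = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
        if (j >= out_features) {
            return;
        }
        float acc = 0.0f; // fp32 accumulator preserves precision
        for (int64_t i = 0; i < bs; ++i) {
            acc += __bfloat162float(grad_output[i * out_features + j]);
        }
        grad_bias[j] = acc;
    }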
|