Skip to content

Commit d6e3899

Browse files
committed
refactor(gemm): move gemm.cuh/gemm.cu to src/kernels/cuda/common/
include/ is for public-facing interfaces only; gemm primitives are internal, so relocate them under src/. Update all include paths. Also rename ctype -> compute_type and add FIXME on bf16 output dtype promotion hack in linear backward passes.
1 parent 15be0d2 commit d6e3899

5 files changed

Lines changed: 10 additions & 10 deletions

File tree

infini_train/src/kernels/cuda/gemm.cu renamed to infini_train/src/kernels/cuda/common/gemm.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "infini_train/include/common/cuda/gemm.cuh"
1+
#include "infini_train/src/kernels/cuda/common/gemm.cuh"
22

33
#include <cublas_v2.h>
44

@@ -48,15 +48,15 @@ void GemmCuda(const Device &device, const GemmParams &p) {
4848
const cudaDataType_t type_c = ToCudaDataType(p.output_dtype);
4949
// Always use CUBLAS_COMPUTE_32F: required for bf16/fp16 correctness,
5050
// and fine for fp32 (same compute path).
51-
const cublasComputeType_t ctype = CUBLAS_COMPUTE_32F;
51+
const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
5252

5353
if (p.batch_count == 1) {
5454
CUBLAS_CHECK(cublasGemmEx(blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, p.lda, p.B,
55-
type_b, p.ldb, &p.beta, p.C, type_c, p.ldc, ctype, CUBLAS_GEMM_DEFAULT));
55+
type_b, p.ldb, &p.beta, p.C, type_c, p.ldc, compute_type, CUBLAS_GEMM_DEFAULT));
5656
} else {
5757
CUBLAS_CHECK(cublasGemmStridedBatchedEx(blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a,
5858
p.lda, p.stride_a, p.B, type_b, p.ldb, p.stride_b, &p.beta, p.C, type_c,
59-
p.ldc, p.stride_c, p.batch_count, ctype, CUBLAS_GEMM_DEFAULT));
59+
p.ldc, p.stride_c, p.batch_count, compute_type, CUBLAS_GEMM_DEFAULT));
6060
}
6161
}
6262

File renamed without changes.

infini_train/src/kernels/cuda/linear.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
#include <cublas_v2.h>
88

99
#include "infini_train/include/common/cuda/common_cuda.h"
10-
#include "infini_train/include/common/cuda/gemm.cuh"
1110
#include "infini_train/include/common/cuda/kernel_helper.cuh"
1211
#include "infini_train/include/core/runtime/device_guard.h"
1312
#include "infini_train/include/dispatcher.h"
1413
#include "infini_train/include/tensor.h"
1514
#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h"
1615
#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h"
16+
#include "infini_train/src/kernels/cuda/common/gemm.cuh"
1717

1818
namespace infini_train::kernels::cuda {
1919

@@ -165,7 +165,7 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
165165
auto grad_output_promoted
166166
= grad_output_dtype == compute_dtype ? grad_output : std::make_shared<Tensor>(grad_output->To(compute_dtype));
167167

168-
// For bf16 compute, accumulate in fp32 to preserve precision.
168+
// FIXME(cx): output dtype promotion is a temporary hack; revisit when autograd/autocast is fixed.
169169
auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
170170
// No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output.
171171
auto grad_input = std::make_shared<Tensor>(input_dims, output_dtype, grad_output->GetDevice());
@@ -234,7 +234,7 @@ std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &inpu
234234
auto grad_output_promoted
235235
= grad_output_dtype == compute_dtype ? grad_output : std::make_shared<Tensor>(grad_output->To(compute_dtype));
236236

237-
// For bf16 compute, accumulate in fp32 to preserve precision.
237+
// FIXME(cx): output dtype promotion is a temporary hack; revisit when autograd/autocast is fixed.
238238
auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
239239
const std::vector<int64_t> weight_dims
240240
= transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
@@ -285,7 +285,7 @@ std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_o
285285
const int64_t bs = std::accumulate(dims.rbegin() + 1, dims.rend(), 1, std::multiplies<int64_t>{});
286286

287287
auto compute_dtype = grad_output->Dtype();
288-
// For bf16 compute, accumulate in fp32 to preserve precision.
288+
// FIXME(cx): output dtype promotion is a temporary hack; revisit when autograd/autocast is fixed.
289289
auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
290290
auto grad_bias
291291
= std::make_shared<Tensor>(std::vector<int64_t>{out_features}, output_dtype, grad_output->GetDevice());

infini_train/src/kernels/cuda/matmul.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
#include <cublas_v2.h>
77

88
#include "infini_train/include/common/cuda/common_cuda.h"
9-
#include "infini_train/include/common/cuda/gemm.cuh"
109
#include "infini_train/include/dispatcher.h"
1110
#include "infini_train/include/tensor.h"
11+
#include "infini_train/src/kernels/cuda/common/gemm.cuh"
1212

1313
namespace infini_train::kernels::cuda {
1414

infini_train/src/kernels/cuda/outer.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
#include "glog/logging.h"
88

99
#include "infini_train/include/common/cuda/common_cuda.h"
10-
#include "infini_train/include/common/cuda/gemm.cuh"
1110
#include "infini_train/include/dispatcher.h"
1211
#include "infini_train/include/tensor.h"
12+
#include "infini_train/src/kernels/cuda/common/gemm.cuh"
1313

1414
namespace infini_train::kernels::cuda {
1515

0 commit comments

Comments (0)