Skip to content

Commit 283d083

Browse files
committed
refactor(gemm): extract shared GemmCuda primitive; split matmul kernels; rename MatmulBackwardInput1/2
- Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or cublasGemmStridedBatchedEx based on batch_count), GetCublasHandle(), GetCudaStream() shared across all GEMM-using kernels - Split matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated matmul.cc / matmul.cu; linear.* now only contains the four Linear kernels - Rename MatmulBackwardInput1 → MatmulBackwardInput, MatmulBackwardInput2 → MatmulBackwardOther for semantic clarity matching MatmulForward(input, other) parameter names - Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths); keep cublasSgemv for the fp32 backward path (more efficient, bf16 unsupported)
1 parent be6eed3 commit 283d083

File tree

8 files changed

+694
-510
lines changed

8 files changed

+694
-510
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#pragma once

#include <cublas_v2.h>
#include <cuda_runtime_api.h>

#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"

namespace infini_train::kernels::cuda {

/**
 * Return the cuBLAS handle associated with the given device.
 * Shared by linear.cu, matmul.cu, and any future GEMM-using kernels.
 */
cublasHandle_t GetCublasHandle(const Device &device);

/**
 * Return the CUDA stream associated with the given device.
 * Shared by kernels that need to launch device-side code directly.
 */
cudaStream_t GetCudaStream(const Device &device);

/**
 * Parameter bundle for a single GEMM call:
 *     C = alpha * op(A) * op(B) + beta * C
 *
 * batch_count == 1 -> non-batched path (cublasGemmEx)
 * batch_count >  1 -> strided-batched path (cublasGemmStridedBatchedEx)
 *
 * When batch_count == 1, stride_a/b/c are unused and must be left at 0.
 * NOTE(review): m/n/k and lda/ldb/ldc follow cuBLAS (column-major)
 * conventions — confirm against the GemmCuda() implementation when mapping
 * row-major tensors onto these fields.
 */
struct GemmParams {
    cublasOperation_t trans_a = CUBLAS_OP_N;
    cublasOperation_t trans_b = CUBLAS_OP_N;

    int m = 0; // rows of op(A) and C
    int n = 0; // cols of op(B) and C
    int k = 0; // cols of op(A) == rows of op(B)

    const void *A = nullptr;
    int lda = 0;
    const void *B = nullptr;
    int ldb = 0;
    void *C = nullptr;
    int ldc = 0;

    float alpha = 1.0f;
    float beta = 0.0f;

    // batch_count=1: non-batched (Linear path); stride_a/b/c must be 0
    // batch_count>1: strided-batched (Matmul path)
    int batch_count = 1;
    long long stride_a = 0;
    long long stride_b = 0;
    long long stride_c = 0;

    // Defaulted so a default-constructed GemmParams never carries
    // indeterminate dtype values; previously these were the only members of
    // this aggregate without an initializer. Callers that set them explicitly
    // are unaffected.
    DataType input_dtype = DataType::kFLOAT32;  // dtype of A and B
    DataType output_dtype = DataType::kFLOAT32; // dtype of C (may differ, e.g. bf16 in -> fp32 out)

    cublasHandle_t blas_handle = nullptr;
};

/**
 * Execute the GEMM described by `p` via cuBLAS.
 * Dispatches to cublasGemmEx (batch_count==1) or
 * cublasGemmStridedBatchedEx (batch_count>1).
 * Uses CUBLAS_COMPUTE_32F for all input dtypes to ensure precision.
 * Aborts on cuBLAS error (via CUBLAS_CHECK / LOG(FATAL)).
 */
void GemmCuda(const GemmParams &p);

} // namespace infini_train::kernels::cuda

infini_train/src/autograd/matmul.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,18 @@ std::vector<std::shared_ptr<Tensor>> Matmul::Backward(const std::vector<std::sha
5858

5959
auto device = input1->GetDevice().type();
6060

61-
std::shared_ptr<Tensor> grad_input1 = nullptr;
62-
std::shared_ptr<Tensor> grad_input2 = nullptr;
61+
std::shared_ptr<Tensor> grad_input = nullptr;
62+
std::shared_ptr<Tensor> grad_other = nullptr;
6363

6464
if (need_grad_input1) {
65-
grad_input1 = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput1"}, input2,
66-
grad_output, input1->Dims());
65+
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput"}, input2,
66+
grad_output, input1->Dims());
6767
}
6868
if (need_grad_input2) {
69-
grad_input2 = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput2"}, input1,
70-
grad_output, input2->Dims());
69+
grad_other = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardOther"}, input1,
70+
grad_output, input2->Dims());
7171
}
7272

73-
return {grad_input1, grad_input2};
73+
return {grad_input, grad_other};
7474
}
7575
} // namespace infini_train::autograd

infini_train/src/kernels/cpu/linear.cc

Lines changed: 1 addition & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -9,128 +9,6 @@
99
#include "infini_train/include/tensor.h"
1010

1111
namespace infini_train::kernels::cpu {
12-
std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other) {
13-
/*
14-
output[*, m, n] = input[*, m, k] * other[*, k, n]
15-
*/
16-
// TODO(dcj): support broadcast later
17-
const auto &input_dims = input->Dims();
18-
const auto &other_dims = other->Dims();
19-
20-
CHECK_GE(input_dims.size(), 2);
21-
CHECK_GE(other_dims.size(), 2);
22-
CHECK_EQ(input_dims.size(), other_dims.size());
23-
24-
const int64_t m = input_dims[input_dims.size() - 2];
25-
const int64_t k = input_dims[input_dims.size() - 1];
26-
CHECK_EQ(k, other_dims[other_dims.size() - 2]);
27-
const int64_t n = other_dims[other_dims.size() - 1];
28-
29-
const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies<int64_t>{});
30-
for (int64_t i = 0; i < input_dims.size() - 2; ++i) {
31-
CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
32-
}
33-
34-
std::vector<int64_t> output_dims = input_dims;
35-
output_dims[output_dims.size() - 1] = n;
36-
auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32);
37-
38-
for (int64_t b = 0; b < bs; ++b) {
39-
for (int64_t i = 0; i < m; ++i) {
40-
for (int64_t j = 0; j < n; ++j) {
41-
float acc = 0.0f;
42-
for (int64_t p = 0; p < k; ++p) {
43-
acc += static_cast<const float *>(input->DataPtr())[b * m * k + i * k + p]
44-
* static_cast<const float *>(other->DataPtr())[b * k * n + p * n + j];
45-
}
46-
static_cast<float *>(output->DataPtr())[b * m * n + i * n + j] = acc;
47-
}
48-
}
49-
}
50-
return {output};
51-
}
52-
53-
std::shared_ptr<Tensor> MatmulBackwardInput1(const std::shared_ptr<Tensor> &other,
54-
const std::shared_ptr<Tensor> &grad_output,
55-
const std::vector<int64_t> &input_dims) {
56-
/*
57-
grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T
58-
*/
59-
const auto &other_dims = other->Dims();
60-
const auto &grad_output_dims = grad_output->Dims();
61-
62-
CHECK_GE(other_dims.size(), 2);
63-
CHECK_EQ(other_dims.size(), grad_output_dims.size());
64-
65-
const int64_t m = grad_output_dims[grad_output_dims.size() - 2];
66-
const int64_t k = other_dims[other_dims.size() - 2];
67-
const int64_t n = grad_output_dims[grad_output_dims.size() - 1];
68-
69-
const int64_t bs
70-
= std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies<int64_t>{});
71-
for (int64_t i = 0; i < grad_output_dims.size() - 2; ++i) {
72-
CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match";
73-
}
74-
75-
auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
76-
grad_input->Fill<float>(0.0f);
77-
78-
for (int64_t b = 0; b < bs; ++b) {
79-
for (int64_t i = 0; i < m; ++i) {
80-
for (int64_t j = 0; j < n; ++j) {
81-
const float grad = static_cast<float *>(grad_output->DataPtr())[b * m * n + i * n + j];
82-
for (int64_t p = 0; p < k; ++p) {
83-
const auto other_idx = b * k * n + p * n + j;
84-
static_cast<float *>(grad_input->DataPtr())[b * m * k + i * k + p]
85-
+= grad * static_cast<const float *>(other->DataPtr())[other_idx];
86-
}
87-
}
88-
}
89-
}
90-
return grad_input;
91-
}
92-
93-
std::shared_ptr<Tensor> MatmulBackwardInput2(const std::shared_ptr<Tensor> &input1,
94-
const std::shared_ptr<Tensor> &grad_output,
95-
const std::vector<int64_t> &other_dims) {
96-
/*
97-
grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
98-
*/
99-
const auto &input_dims = input1->Dims();
100-
const auto &grad_output_dims = grad_output->Dims();
101-
102-
CHECK_GE(input_dims.size(), 2);
103-
CHECK_EQ(input_dims.size(), grad_output_dims.size());
104-
105-
const int64_t m = input_dims[input_dims.size() - 2];
106-
const int64_t k = input_dims[input_dims.size() - 1];
107-
const int64_t n = grad_output_dims[grad_output_dims.size() - 1];
108-
CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]);
109-
CHECK_EQ(k, other_dims[other_dims.size() - 2]);
110-
111-
const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies<int64_t>{});
112-
for (int64_t i = 0; i < input_dims.size() - 2; ++i) {
113-
CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match";
114-
CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
115-
}
116-
117-
auto grad_other = std::make_shared<Tensor>(other_dims, DataType::kFLOAT32);
118-
grad_other->Fill<float>(0.0f);
119-
120-
for (int64_t b = 0; b < bs; ++b) {
121-
for (int64_t i = 0; i < m; ++i) {
122-
for (int64_t j = 0; j < n; ++j) {
123-
const float grad = static_cast<float *>(grad_output->DataPtr())[b * m * n + i * n + j];
124-
for (int64_t p = 0; p < k; ++p) {
125-
const auto input_idx = b * m * k + i * k + p;
126-
static_cast<float *>(grad_other->DataPtr())[b * k * n + p * n + j]
127-
+= grad * static_cast<const float *>(input1->DataPtr())[input_idx];
128-
}
129-
}
130-
}
131-
}
132-
return grad_other;
133-
}
13412

13513
std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
13614
bool transpose, const std::shared_ptr<Tensor> &bias) {
@@ -225,14 +103,12 @@ std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_o
225103
grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
226104
return grad_bias;
227105
}
106+
228107
} // namespace infini_train::kernels::cpu
229108

230109
#define REGISTER_CPU_LINEAR_KERNEL(kernel_name) \
231110
REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name)
232111

233-
REGISTER_CPU_LINEAR_KERNEL(MatmulForward)
234-
REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput1)
235-
REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput2)
236112
REGISTER_CPU_LINEAR_KERNEL(LinearForward)
237113
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput)
238114
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight)
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#include <cstdint>
2+
#include <memory>
3+
#include <numeric>
4+
#include <vector>
5+
6+
#include "glog/logging.h"
7+
8+
#include "infini_train/include/dispatcher.h"
9+
#include "infini_train/include/tensor.h"
10+
11+
namespace infini_train::kernels::cpu {
12+
13+
std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other) {
14+
/*
15+
output[*, m, n] = input[*, m, k] * other[*, k, n]
16+
*/
17+
// TODO(dcj): support broadcast later
18+
const auto &input_dims = input->Dims();
19+
const auto &other_dims = other->Dims();
20+
21+
CHECK_GE(input_dims.size(), 2);
22+
CHECK_GE(other_dims.size(), 2);
23+
CHECK_EQ(input_dims.size(), other_dims.size());
24+
25+
const int64_t m = input_dims[input_dims.size() - 2];
26+
const int64_t k = input_dims[input_dims.size() - 1];
27+
CHECK_EQ(k, other_dims[other_dims.size() - 2]);
28+
const int64_t n = other_dims[other_dims.size() - 1];
29+
30+
const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies<int64_t>{});
31+
for (int64_t i = 0; i < static_cast<int64_t>(input_dims.size()) - 2; ++i) {
32+
CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
33+
}
34+
35+
std::vector<int64_t> output_dims = input_dims;
36+
output_dims[output_dims.size() - 1] = n;
37+
auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32);
38+
39+
for (int64_t b = 0; b < bs; ++b) {
40+
for (int64_t i = 0; i < m; ++i) {
41+
for (int64_t j = 0; j < n; ++j) {
42+
float acc = 0.0f;
43+
for (int64_t p = 0; p < k; ++p) {
44+
acc += static_cast<const float *>(input->DataPtr())[b * m * k + i * k + p]
45+
* static_cast<const float *>(other->DataPtr())[b * k * n + p * n + j];
46+
}
47+
static_cast<float *>(output->DataPtr())[b * m * n + i * n + j] = acc;
48+
}
49+
}
50+
}
51+
return {output};
52+
}
53+
54+
std::shared_ptr<Tensor> MatmulBackwardInput(const std::shared_ptr<Tensor> &other,
55+
const std::shared_ptr<Tensor> &grad_output,
56+
const std::vector<int64_t> &input_dims) {
57+
/*
58+
grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T
59+
*/
60+
const auto &other_dims = other->Dims();
61+
const auto &grad_output_dims = grad_output->Dims();
62+
63+
CHECK_GE(other_dims.size(), 2);
64+
CHECK_EQ(other_dims.size(), grad_output_dims.size());
65+
66+
const int64_t m = grad_output_dims[grad_output_dims.size() - 2];
67+
const int64_t k = other_dims[other_dims.size() - 2];
68+
const int64_t n = grad_output_dims[grad_output_dims.size() - 1];
69+
70+
const int64_t bs
71+
= std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies<int64_t>{});
72+
for (int64_t i = 0; i < static_cast<int64_t>(grad_output_dims.size()) - 2; ++i) {
73+
CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match";
74+
}
75+
76+
auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
77+
grad_input->Fill<float>(0.0f);
78+
79+
for (int64_t b = 0; b < bs; ++b) {
80+
for (int64_t i = 0; i < m; ++i) {
81+
for (int64_t j = 0; j < n; ++j) {
82+
const float grad = static_cast<float *>(grad_output->DataPtr())[b * m * n + i * n + j];
83+
for (int64_t p = 0; p < k; ++p) {
84+
const auto other_idx = b * k * n + p * n + j;
85+
static_cast<float *>(grad_input->DataPtr())[b * m * k + i * k + p]
86+
+= grad * static_cast<const float *>(other->DataPtr())[other_idx];
87+
}
88+
}
89+
}
90+
}
91+
return grad_input;
92+
}
93+
94+
std::shared_ptr<Tensor> MatmulBackwardOther(const std::shared_ptr<Tensor> &input1,
95+
const std::shared_ptr<Tensor> &grad_output,
96+
const std::vector<int64_t> &other_dims) {
97+
/*
98+
grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
99+
*/
100+
const auto &input_dims = input1->Dims();
101+
const auto &grad_output_dims = grad_output->Dims();
102+
103+
CHECK_GE(input_dims.size(), 2);
104+
CHECK_EQ(input_dims.size(), grad_output_dims.size());
105+
106+
const int64_t m = input_dims[input_dims.size() - 2];
107+
const int64_t k = input_dims[input_dims.size() - 1];
108+
const int64_t n = grad_output_dims[grad_output_dims.size() - 1];
109+
CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]);
110+
CHECK_EQ(k, other_dims[other_dims.size() - 2]);
111+
112+
const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies<int64_t>{});
113+
for (int64_t i = 0; i < static_cast<int64_t>(input_dims.size()) - 2; ++i) {
114+
CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match";
115+
CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
116+
}
117+
118+
auto grad_other = std::make_shared<Tensor>(other_dims, DataType::kFLOAT32);
119+
grad_other->Fill<float>(0.0f);
120+
121+
for (int64_t b = 0; b < bs; ++b) {
122+
for (int64_t i = 0; i < m; ++i) {
123+
for (int64_t j = 0; j < n; ++j) {
124+
const float grad = static_cast<float *>(grad_output->DataPtr())[b * m * n + i * n + j];
125+
for (int64_t p = 0; p < k; ++p) {
126+
const auto input_idx = b * m * k + i * k + p;
127+
static_cast<float *>(grad_other->DataPtr())[b * k * n + p * n + j]
128+
+= grad * static_cast<const float *>(input1->DataPtr())[input_idx];
129+
}
130+
}
131+
}
132+
}
133+
return grad_other;
134+
}
135+
136+
} // namespace infini_train::kernels::cpu
137+
138+
// Registration glue: expose each CPU matmul kernel to the dispatcher under its
// function name; the autograd layer resolves these by string key
// ("MatmulForward", "MatmulBackwardInput", "MatmulBackwardOther").
#define REGISTER_CPU_MATMUL_KERNEL(kernel_name)                                                                        \
    REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name)

REGISTER_CPU_MATMUL_KERNEL(MatmulForward)
REGISTER_CPU_MATMUL_KERNEL(MatmulBackwardInput)
REGISTER_CPU_MATMUL_KERNEL(MatmulBackwardOther)

// Keep the helper macro local to this translation unit.
#undef REGISTER_CPU_MATMUL_KERNEL

0 commit comments

Comments
 (0)