Skip to content

Commit 66e45dc

Browse files
committed
refactor(matmul): split MatmulBackward kernel into 2 independent kernels
Move the needs_input_grad logic from the kernel into the autograd layer: the monolithic MatmulBackward kernel is replaced by MatmulBackwardInput1 and MatmulBackwardInput2, which are now dispatched only when the corresponding gradient is requested. SetupContext likewise saves (and casts) only the tensors actually needed — input2 only when grad_input1 is required, and input1 only when grad_input2 is required.
1 parent 4534154 commit 66e45dc

File tree

3 files changed

+180
-98
lines changed

3 files changed

+180
-98
lines changed

infini_train/src/autograd/matmul.cc

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,17 @@ void Matmul::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
3131
// FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be
3232
// determined by autocast, not derived from output->Dtype().
3333
auto compute_dtype = output->Dtype();
34-
saved_tensors_ = {
35-
input1->Dtype() == compute_dtype ? input1 : std::make_shared<Tensor>(input1->To(compute_dtype)),
36-
input2->Dtype() == compute_dtype ? input2 : std::make_shared<Tensor>(input2->To(compute_dtype)),
34+
35+
// grad_input1 = grad_output @ input2^T, so input2 is needed
36+
// grad_input2 = grad_output^T @ input1, so input1 is needed
37+
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
38+
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];
39+
40+
auto cast = [&](const std::shared_ptr<Tensor> &t) {
41+
return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
3742
};
43+
44+
saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr};
3845
out_features_ = output->Dims()[0];
3946
}
4047

@@ -45,10 +52,24 @@ std::vector<std::shared_ptr<Tensor>> Matmul::Backward(const std::vector<std::sha
4552
CHECK_EQ(grad_outputs.size(), 1);
4653
const auto &grad_output = grad_outputs[0];
4754

55+
CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Matmul::Backward";
56+
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
57+
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];
58+
4859
auto device = input1->GetDevice().type();
49-
auto [grad_input1, grad_input2]
50-
= Dispatcher::Instance().Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
51-
{device, "MatmulBackward"}, input1, input2, grad_output);
60+
61+
std::shared_ptr<Tensor> grad_input1 = nullptr;
62+
std::shared_ptr<Tensor> grad_input2 = nullptr;
63+
64+
if (need_grad_input1) {
65+
grad_input1 = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput1"}, input2,
66+
grad_output, input1->Dims());
67+
}
68+
if (need_grad_input2) {
69+
grad_input2 = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput2"}, input1,
70+
grad_output, input2->Dims());
71+
}
72+
5273
return {grad_input1, grad_input2};
5374
}
5475
} // namespace infini_train::autograd

infini_train/src/kernels/cpu/linear.cc

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,38 +50,71 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
5050
return {output};
5151
}
5252

53-
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
54-
MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
55-
const std::shared_ptr<Tensor> &grad_output) {
53+
std::shared_ptr<Tensor> MatmulBackwardInput1(const std::shared_ptr<Tensor> &other,
54+
const std::shared_ptr<Tensor> &grad_output,
55+
const std::vector<int64_t> &input_dims) {
5656
/*
5757
grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T
58-
grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
5958
*/
60-
const auto &input_dims = input->Dims();
6159
const auto &other_dims = other->Dims();
6260
const auto &grad_output_dims = grad_output->Dims();
6361

62+
CHECK_GE(other_dims.size(), 2);
63+
CHECK_EQ(other_dims.size(), grad_output_dims.size());
64+
65+
const int64_t m = grad_output_dims[grad_output_dims.size() - 2];
66+
const int64_t k = other_dims[other_dims.size() - 2];
67+
const int64_t n = grad_output_dims[grad_output_dims.size() - 1];
68+
69+
const int64_t bs
70+
= std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies<int64_t>{});
71+
for (int64_t i = 0; i < grad_output_dims.size() - 2; ++i) {
72+
CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match";
73+
}
74+
75+
auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
76+
grad_input->Fill<float>(0.0f);
77+
78+
for (int64_t b = 0; b < bs; ++b) {
79+
for (int64_t i = 0; i < m; ++i) {
80+
for (int64_t j = 0; j < n; ++j) {
81+
const float grad = static_cast<float *>(grad_output->DataPtr())[b * m * n + i * n + j];
82+
for (int64_t p = 0; p < k; ++p) {
83+
const auto other_idx = b * k * n + p * n + j;
84+
static_cast<float *>(grad_input->DataPtr())[b * m * k + i * k + p]
85+
+= grad * static_cast<const float *>(other->DataPtr())[other_idx];
86+
}
87+
}
88+
}
89+
}
90+
return grad_input;
91+
}
92+
93+
std::shared_ptr<Tensor> MatmulBackwardInput2(const std::shared_ptr<Tensor> &input1,
94+
const std::shared_ptr<Tensor> &grad_output,
95+
const std::vector<int64_t> &other_dims) {
96+
/*
97+
grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
98+
*/
99+
const auto &input_dims = input1->Dims();
100+
const auto &grad_output_dims = grad_output->Dims();
101+
64102
CHECK_GE(input_dims.size(), 2);
65-
CHECK_EQ(input_dims.size(), other_dims.size());
66103
CHECK_EQ(input_dims.size(), grad_output_dims.size());
67104

68105
const int64_t m = input_dims[input_dims.size() - 2];
69106
const int64_t k = input_dims[input_dims.size() - 1];
70-
CHECK_EQ(k, other_dims[other_dims.size() - 2]);
71-
const int64_t n = other_dims[other_dims.size() - 1];
72-
107+
const int64_t n = grad_output_dims[grad_output_dims.size() - 1];
73108
CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]);
74-
CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]);
109+
CHECK_EQ(k, other_dims[other_dims.size() - 2]);
75110

76111
const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies<int64_t>{});
77112
for (int64_t i = 0; i < input_dims.size() - 2; ++i) {
78-
CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
79113
CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match";
114+
CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
80115
}
81116

82-
auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
83117
auto grad_other = std::make_shared<Tensor>(other_dims, DataType::kFLOAT32);
84-
grad_input->Fill<float>(0.0f);
85118
grad_other->Fill<float>(0.0f);
86119

87120
for (int64_t b = 0; b < bs; ++b) {
@@ -90,16 +123,13 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
90123
const float grad = static_cast<float *>(grad_output->DataPtr())[b * m * n + i * n + j];
91124
for (int64_t p = 0; p < k; ++p) {
92125
const auto input_idx = b * m * k + i * k + p;
93-
const auto other_idx = b * k * n + p * n + j;
94-
static_cast<float *>(grad_input->DataPtr())[input_idx]
95-
+= grad * static_cast<const float *>(other->DataPtr())[other_idx];
96-
static_cast<float *>(grad_other->DataPtr())[other_idx]
97-
+= grad * static_cast<const float *>(input->DataPtr())[input_idx];
126+
static_cast<float *>(grad_other->DataPtr())[b * k * n + p * n + j]
127+
+= grad * static_cast<const float *>(input1->DataPtr())[input_idx];
98128
}
99129
}
100130
}
101131
}
102-
return {grad_input, grad_other};
132+
return grad_other;
103133
}
104134

105135
std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
@@ -201,7 +231,8 @@ std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_o
201231
REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name)
202232

203233
REGISTER_CPU_LINEAR_KERNEL(MatmulForward)
204-
REGISTER_CPU_LINEAR_KERNEL(MatmulBackward)
234+
REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput1)
235+
REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput2)
205236
REGISTER_CPU_LINEAR_KERNEL(LinearForward)
206237
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput)
207238
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight)

0 commit comments

Comments (0)