Skip to content

Commit 4534154

Browse files
committed
Refactor(linear): split LinearBackward kernel into 3 independent kernels
Move grad_flags logic from kernel to autograd layer. The monolithic LinearBackward kernel is replaced by LinearBackwardInput, LinearBackwardWeight, and LinearBackwardBias — each a pure compute operation with no autograd-related parameters.
1 parent 3d61e10 commit 4534154

File tree

4 files changed

+203
-214
lines changed

4 files changed

+203
-214
lines changed

infini_train/include/autograd/linear.h

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,6 @@ class Tensor;
1212

1313
namespace infini_train::autograd {
1414

15-
struct LinearGradFlags {
16-
bool input = false;
17-
bool weight = false;
18-
bool bias = false;
19-
};
20-
2115
class Linear : public Function {
2216
public:
2317
static constexpr char kType[] = "LinearFunction";

infini_train/src/autograd/linear.cc

Lines changed: 21 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -56,17 +56,29 @@ std::vector<std::shared_ptr<Tensor>> Linear::Backward(const std::vector<std::sha
5656
const auto &grad_output = grad_outputs[0];
5757

5858
CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Linear::Backward";
59-
LinearGradFlags grad_flags = {.input = needs_input_grad_[0],
60-
.weight = needs_input_grad_.size() > 1 && needs_input_grad_[1],
61-
.bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]};
59+
bool need_grad_input = needs_input_grad_[0];
60+
bool need_grad_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1];
61+
bool need_grad_bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2];
6262

6363
auto device = grad_output->GetDevice().type();
64-
// TODO: skip autograd graph construction entirely when no input requires grad
65-
auto [grad_input, grad_weight, grad_bias]
66-
= Dispatcher::Instance()
67-
.Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
68-
{device, "LinearBackward"}, input, weight, transpose_, in_features_, out_features_, input_dims_,
69-
grad_output, bias_, grad_flags);
64+
65+
std::shared_ptr<Tensor> grad_input = nullptr;
66+
std::shared_ptr<Tensor> grad_weight = nullptr;
67+
std::shared_ptr<Tensor> grad_bias = nullptr;
68+
69+
if (need_grad_input) {
70+
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
71+
{device, "LinearBackwardInput"}, weight, grad_output, transpose_, in_features_, out_features_, input_dims_);
72+
}
73+
if (need_grad_weight) {
74+
grad_weight = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
75+
{device, "LinearBackwardWeight"}, input, grad_output, transpose_, in_features_, out_features_);
76+
}
77+
if (need_grad_bias) {
78+
grad_bias = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "LinearBackwardBias"}, grad_output,
79+
out_features_);
80+
}
81+
7082
if (bias_) {
7183
return {grad_input, grad_weight, grad_bias};
7284
} else {

infini_train/src/kernels/cpu/linear.cc

Lines changed: 39 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55

66
#include "glog/logging.h"
77

8-
#include "infini_train/include/autograd/linear.h"
98
#include "infini_train/include/dispatcher.h"
109
#include "infini_train/include/tensor.h"
1110

@@ -146,62 +145,55 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
146145
return output;
147146
}
148147

149-
// TODO(dcj): support linear without bias later
150-
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
151-
LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
152-
int64_t in_features, int64_t out_features, const std::vector<int64_t> &input_dims,
153-
const std::shared_ptr<Tensor> &grad_output, bool bias,
154-
infini_train::autograd::LinearGradFlags grad_flags) {
148+
std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weight,
149+
const std::shared_ptr<Tensor> &grad_output, bool transpose,
150+
int64_t in_features, int64_t out_features,
151+
const std::vector<int64_t> &input_dims) {
155152
/*
156153
transpose: grad_input = grad_output * weight
157154
grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features]
158-
grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]
159-
grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
160155
161156
!transpose: grad_input = grad_output * weight^T
162157
grad_input[*, in_features] = grad_output[_, out_features] * weight[in_features, out_features]^T
163-
grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
164-
grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
165158
*/
166-
const auto compute_grad_input = grad_flags.input;
167-
const auto compute_grad_weight = grad_flags.weight;
168-
const auto compute_grad_bias = grad_flags.bias;
169-
170159
CHECK_GE(input_dims.size(), 2);
171-
172-
std::vector<int64_t> weight_dims
173-
= transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
174-
175-
std::shared_ptr<Tensor> grad_input = nullptr;
176-
std::shared_ptr<Tensor> grad_weight = nullptr;
177-
std::shared_ptr<Tensor> grad_bias = nullptr;
178-
179-
if (compute_grad_input) {
180-
CHECK(weight != nullptr) << "compute_grad_input=true but weight is nullptr (selective save mismatch)";
181-
grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
182-
if (transpose) {
183-
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
184-
} else {
185-
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
186-
}
160+
auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
161+
if (transpose) {
162+
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
163+
} else {
164+
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
187165
}
166+
return grad_input;
167+
}
188168

189-
if (compute_grad_weight) {
190-
CHECK(input != nullptr) << "compute_grad_weight=true but input is nullptr (selective save mismatch)";
191-
grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
192-
if (transpose) {
193-
grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
194-
} else {
195-
grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
196-
}
197-
}
169+
std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &input,
170+
const std::shared_ptr<Tensor> &grad_output, bool transpose,
171+
int64_t in_features, int64_t out_features) {
172+
/*
173+
transpose:
174+
grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]
198175
199-
if (compute_grad_bias && bias) {
200-
grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
201-
grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
176+
!transpose:
177+
grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
178+
*/
179+
std::vector<int64_t> weight_dims
180+
= transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
181+
auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
182+
if (transpose) {
183+
grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
184+
} else {
185+
grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
202186
}
187+
return grad_weight;
188+
}
203189

204-
return {grad_input, grad_weight, grad_bias};
190+
std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_output, int64_t out_features) {
191+
/*
192+
grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
193+
*/
194+
auto grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
195+
grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
196+
return grad_bias;
205197
}
206198
} // namespace infini_train::kernels::cpu
207199

@@ -211,6 +203,8 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
211203
REGISTER_CPU_LINEAR_KERNEL(MatmulForward)
212204
REGISTER_CPU_LINEAR_KERNEL(MatmulBackward)
213205
REGISTER_CPU_LINEAR_KERNEL(LinearForward)
214-
REGISTER_CPU_LINEAR_KERNEL(LinearBackward)
206+
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput)
207+
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight)
208+
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardBias)
215209

216210
#undef REGISTER_CPU_LINEAR_KERNEL

0 commit comments

Comments (0)