55
66#include " glog/logging.h"
77
8- #include " infini_train/include/autograd/linear.h"
98#include " infini_train/include/dispatcher.h"
109#include " infini_train/include/tensor.h"
1110
@@ -146,62 +145,55 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
146145 return output;
147146}
148147
149- // TODO(dcj): support linear without bias later
150- std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
151- LinearBackward (const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
152- int64_t in_features, int64_t out_features, const std::vector<int64_t > &input_dims,
153- const std::shared_ptr<Tensor> &grad_output, bool bias,
154- infini_train::autograd::LinearGradFlags grad_flags) {
148+ std::shared_ptr<Tensor> LinearBackwardInput (const std::shared_ptr<Tensor> &weight,
149+ const std::shared_ptr<Tensor> &grad_output, bool transpose,
150+ int64_t in_features, int64_t out_features,
151+ const std::vector<int64_t > &input_dims) {
155152 /*
156153 transpose: grad_input = grad_output * weight
157154 grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features]
158- grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]
159- grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
160155
161156 !transpose: grad_input = grad_output * weight^T
162157 grad_input[*, in_features] = grad_output[_, out_features] * weight[in_features, out_features]^T
163- grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
164- grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
165158 */
166- const auto compute_grad_input = grad_flags.input ;
167- const auto compute_grad_weight = grad_flags.weight ;
168- const auto compute_grad_bias = grad_flags.bias ;
169-
170159 CHECK_GE (input_dims.size (), 2 );
171-
172- std::vector<int64_t > weight_dims
173- = transpose ? std::vector<int64_t >{out_features, in_features} : std::vector<int64_t >{in_features, out_features};
174-
175- std::shared_ptr<Tensor> grad_input = nullptr ;
176- std::shared_ptr<Tensor> grad_weight = nullptr ;
177- std::shared_ptr<Tensor> grad_bias = nullptr ;
178-
179- if (compute_grad_input) {
180- CHECK (weight != nullptr ) << " compute_grad_input=true but weight is nullptr (selective save mismatch)" ;
181- grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32 );
182- if (transpose) {
183- grad_input->EigenMatrix () = grad_output->EigenMatrix () * weight->EigenMatrix ();
184- } else {
185- grad_input->EigenMatrix () = grad_output->EigenMatrix () * weight->EigenMatrix ().transpose ();
186- }
160+ auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32 );
161+ if (transpose) {
162+ grad_input->EigenMatrix () = grad_output->EigenMatrix () * weight->EigenMatrix ();
163+ } else {
164+ grad_input->EigenMatrix () = grad_output->EigenMatrix () * weight->EigenMatrix ().transpose ();
187165 }
166+ return grad_input;
167+ }
188168
189- if (compute_grad_weight) {
190- CHECK (input != nullptr ) << " compute_grad_weight=true but input is nullptr (selective save mismatch)" ;
191- grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32 );
192- if (transpose) {
193- grad_weight->EigenMatrix () = grad_output->EigenMatrix ().transpose () * input->EigenMatrix ();
194- } else {
195- grad_weight->EigenMatrix () = input->EigenMatrix ().transpose () * grad_output->EigenMatrix ();
196- }
197- }
169+ std::shared_ptr<Tensor> LinearBackwardWeight (const std::shared_ptr<Tensor> &input,
170+ const std::shared_ptr<Tensor> &grad_output, bool transpose,
171+ int64_t in_features, int64_t out_features) {
172+ /*
173+ transpose:
174+ grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]
198175
199- if (compute_grad_bias && bias) {
200- grad_bias = std::make_shared<Tensor>(std::vector<int64_t >{out_features}, DataType::kFLOAT32 );
201- grad_bias->EigenVector () = grad_output->EigenMatrix ().colwise ().sum ();
176+ !transpose:
177+ grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
178+ */
179+ std::vector<int64_t > weight_dims
180+ = transpose ? std::vector<int64_t >{out_features, in_features} : std::vector<int64_t >{in_features, out_features};
181+ auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32 );
182+ if (transpose) {
183+ grad_weight->EigenMatrix () = grad_output->EigenMatrix ().transpose () * input->EigenMatrix ();
184+ } else {
185+ grad_weight->EigenMatrix () = input->EigenMatrix ().transpose () * grad_output->EigenMatrix ();
202186 }
187+ return grad_weight;
188+ }
203189
204- return {grad_input, grad_weight, grad_bias};
190+ std::shared_ptr<Tensor> LinearBackwardBias (const std::shared_ptr<Tensor> &grad_output, int64_t out_features) {
191+ /*
192+ grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
193+ */
194+ auto grad_bias = std::make_shared<Tensor>(std::vector<int64_t >{out_features}, DataType::kFLOAT32 );
195+ grad_bias->EigenVector () = grad_output->EigenMatrix ().colwise ().sum ();
196+ return grad_bias;
205197}
206198} // namespace infini_train::kernels::cpu
207199
@@ -211,6 +203,8 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
// Register the CPU linear-layer kernels with the dispatcher.
// NOTE(review): the former fused LinearBackward kernel is split into three
// independent kernels (input / weight / bias), presumably so callers can
// compute only the gradients they need — confirm against the autograd caller.
REGISTER_CPU_LINEAR_KERNEL(MatmulForward)
REGISTER_CPU_LINEAR_KERNEL(MatmulBackward)
REGISTER_CPU_LINEAR_KERNEL(LinearForward)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardBias)

#undef REGISTER_CPU_LINEAR_KERNEL
0 commit comments