-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy pathlinear.cc
More file actions
216 lines (180 loc) · 9.11 KB
/
linear.cc
File metadata and controls
216 lines (180 loc) · 9.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <tuple>
#include <vector>

#include "glog/logging.h"

#include "infini_train/include/autograd/linear.h"
#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"
namespace infini_train::kernels::cpu {

// Naive CPU batched matrix multiply on float32 data.
//
// Requirements (enforced with CHECKs): `input` and `other` have the same rank (>= 2), identical leading
// (batch) dimensions, and input's last dim equals other's second-to-last dim.
// Returns a new float32 tensor whose dims are input's dims with the last one replaced by n.
// NOTE(review): indexing below assumes all buffers are contiguous and row-major — TODO confirm.
std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other) {
    /*
    output[*, m, n] = input[*, m, k] * other[*, k, n]
    */
    // TODO(dcj): support broadcast later
    const auto &input_dims = input->Dims();
    const auto &other_dims = other->Dims();
    CHECK_GE(input_dims.size(), 2);
    CHECK_GE(other_dims.size(), 2);
    CHECK_EQ(input_dims.size(), other_dims.size());

    const int64_t m = input_dims[input_dims.size() - 2];
    const int64_t k = input_dims[input_dims.size() - 1];
    CHECK_EQ(k, other_dims[other_dims.size() - 2]);
    const int64_t n = other_dims[other_dims.size() - 1];

    for (std::size_t i = 0; i + 2 < input_dims.size(); ++i) {
        CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
    }
    // Seed with int64_t{1}: std::accumulate keeps its running value in the type of the initial value,
    // so a plain `1` would truncate the product of the batch dims to int.
    const int64_t bs
        = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), int64_t{1}, std::multiplies<int64_t>{});

    std::vector<int64_t> output_dims = input_dims;
    output_dims[output_dims.size() - 1] = n;
    auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32);

    // Resolve the raw buffers once instead of on every inner-loop iteration.
    const float *input_data = static_cast<const float *>(input->DataPtr());
    const float *other_data = static_cast<const float *>(other->DataPtr());
    float *output_data = static_cast<float *>(output->DataPtr());

    for (int64_t b = 0; b < bs; ++b) {
        const float *lhs = input_data + b * m * k;
        const float *rhs = other_data + b * k * n;
        float *dst = output_data + b * m * n;
        for (int64_t i = 0; i < m; ++i) {
            for (int64_t j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (int64_t p = 0; p < k; ++p) {
                    acc += lhs[i * k + p] * rhs[p * n + j];
                }
                dst[i * n + j] = acc;
            }
        }
    }
    return output;
}

// Gradients of MatmulForward with respect to both operands (float32, naive loops).
//
// Requirements (enforced with CHECKs): input, other and grad_output share the same rank (>= 2) and batch
// dims; grad_output's trailing dims are [m, n].
// Returns {grad_input, grad_other}, shaped like `input` and `other` respectively.
// NOTE(review): indexing assumes contiguous row-major buffers — TODO confirm.
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
               const std::shared_ptr<Tensor> &grad_output) {
    /*
    grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T
    grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
    */
    const auto &input_dims = input->Dims();
    const auto &other_dims = other->Dims();
    const auto &grad_output_dims = grad_output->Dims();
    CHECK_GE(input_dims.size(), 2);
    CHECK_EQ(input_dims.size(), other_dims.size());
    CHECK_EQ(input_dims.size(), grad_output_dims.size());

    const int64_t m = input_dims[input_dims.size() - 2];
    const int64_t k = input_dims[input_dims.size() - 1];
    CHECK_EQ(k, other_dims[other_dims.size() - 2]);
    const int64_t n = other_dims[other_dims.size() - 1];
    CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]);
    CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]);

    for (std::size_t i = 0; i + 2 < input_dims.size(); ++i) {
        CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
        CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match";
    }
    // int64_t seed so the batch-size product is not truncated to int.
    const int64_t bs
        = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), int64_t{1}, std::multiplies<int64_t>{});

    auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
    auto grad_other = std::make_shared<Tensor>(other_dims, DataType::kFLOAT32);
    // Both gradients are accumulated with += below, so they must start at zero.
    grad_input->Fill<float>(0.0f);
    grad_other->Fill<float>(0.0f);

    const float *input_data = static_cast<const float *>(input->DataPtr());
    const float *other_data = static_cast<const float *>(other->DataPtr());
    const float *grad_output_data = static_cast<const float *>(grad_output->DataPtr());
    float *grad_input_data = static_cast<float *>(grad_input->DataPtr());
    float *grad_other_data = static_cast<float *>(grad_other->DataPtr());

    for (int64_t b = 0; b < bs; ++b) {
        for (int64_t i = 0; i < m; ++i) {
            for (int64_t j = 0; j < n; ++j) {
                const float grad = grad_output_data[b * m * n + i * n + j];
                for (int64_t p = 0; p < k; ++p) {
                    const int64_t input_idx = b * m * k + i * k + p;
                    const int64_t other_idx = b * k * n + p * n + j;
                    grad_input_data[input_idx] += grad * other_data[other_idx];
                    grad_other_data[other_idx] += grad * input_data[input_idx];
                }
            }
        }
    }
    return {grad_input, grad_other};
}

// Fully-connected forward pass on float32 data via Eigen.
//
// `transpose` selects the weight layout ([out_features, in_features] when true, [in_features, out_features]
// when false). `bias` may be null; when present it must be 1-D of length out_features.
// Returns a new tensor with input's dims except the last, which becomes out_features.
// NOTE(review): relies on Tensor::EigenMatrix() presenting an N-D tensor as a 2-D matrix whose rows
// collapse the leading dims — TODO confirm.
std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
                                      bool transpose, const std::shared_ptr<Tensor> &bias) {
    /*
    transpose:  output = input * weight^T + bias
    output[*, out_features] = input[*, in_features] * weight[out_features, in_features]^T + bias[out_features]

    !transpose: output = input * weight + bias
    output[*, out_features] = input[*, in_features] * weight[in_features, out_features] + bias[out_features]
    */
    const auto &input_dims = input->Dims();
    CHECK_GE(input_dims.size(), 2);
    const int64_t in_features = *input_dims.rbegin();

    const auto &weight_dims = weight->Dims();
    CHECK_EQ(weight_dims.size(), 2);
    CHECK_EQ(in_features, weight_dims[transpose ? 1 : 0]);
    // Keep the full int64_t width; the dimension must not be narrowed to int.
    const int64_t out_features = weight_dims[transpose ? 0 : 1];

    if (bias) {
        const auto &bias_dims = bias->Dims();
        CHECK_EQ(bias_dims.size(), 1);
        CHECK_EQ(bias_dims[0], out_features);
    }

    auto output_dims = input_dims;
    *output_dims.rbegin() = out_features;
    auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32);

    if (transpose) {
        output->EigenMatrix() = input->EigenMatrix() * weight->EigenMatrix().transpose();
    } else {
        output->EigenMatrix() = input->EigenMatrix() * weight->EigenMatrix();
    }
    if (bias) {
        output->EigenMatrix().rowwise() += bias->EigenVector();
    }
    return output;
}

// Fully-connected backward pass on float32 data via Eigen.
//
// Only the gradients requested in `grad_flags` are computed; the others are returned as nullptr.
// The bias is optional: when `bias` is false no bias gradient is produced even if grad_flags.bias is set.
// `weight` may be null when grad_flags.input is false, and `input` may be null when grad_flags.weight is
// false (callers may skip saving them); the CHECKs below catch a mismatch.
// Returns {grad_input, grad_weight, grad_bias}.
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
               int64_t in_features, int64_t out_features, const std::vector<int64_t> &input_dims,
               const std::shared_ptr<Tensor> &grad_output, bool bias,
               infini_train::autograd::LinearGradFlags grad_flags) {
    /*
    transpose:  grad_input = grad_output * weight
    grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features]
    grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]
    grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)

    !transpose: grad_input = grad_output * weight^T
    grad_input[*, in_features] = grad_output[*, out_features] * weight[in_features, out_features]^T
    grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
    grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
    */
    const auto compute_grad_input = grad_flags.input;
    const auto compute_grad_weight = grad_flags.weight;
    const auto compute_grad_bias = grad_flags.bias;

    CHECK_GE(input_dims.size(), 2);

    std::shared_ptr<Tensor> grad_input = nullptr;
    std::shared_ptr<Tensor> grad_weight = nullptr;
    std::shared_ptr<Tensor> grad_bias = nullptr;

    if (compute_grad_input) {
        CHECK(weight != nullptr) << "compute_grad_input=true but weight is nullptr (selective save mismatch)";
        grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
        if (transpose) {
            grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
        } else {
            grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
        }
    }

    if (compute_grad_weight) {
        CHECK(input != nullptr) << "compute_grad_weight=true but input is nullptr (selective save mismatch)";
        // The weight layout mirrors LinearForward: [out, in] when transposed, [in, out] otherwise.
        const std::vector<int64_t> weight_dims = transpose ? std::vector<int64_t>{out_features, in_features}
                                                           : std::vector<int64_t>{in_features, out_features};
        grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
        if (transpose) {
            grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
        } else {
            grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
        }
    }

    if (compute_grad_bias && bias) {
        grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
        // Sum grad_output over all rows (the flattened leading dims).
        grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
    }

    return {grad_input, grad_weight, grad_bias};
}

} // namespace infini_train::kernels::cpu
// Registers each CPU kernel from infini_train::kernels::cpu with the dispatcher, keyed by its bare name.
// REGISTER_KERNEL is not defined in this file (presumably it comes from dispatcher.h); the helper macro is
// #undef'd afterwards so it does not leak past this translation unit's registrations.
#define REGISTER_CPU_LINEAR_KERNEL(KERNEL) \
    REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, KERNEL, infini_train::kernels::cpu::KERNEL)

REGISTER_CPU_LINEAR_KERNEL(MatmulForward)
REGISTER_CPU_LINEAR_KERNEL(MatmulBackward)
REGISTER_CPU_LINEAR_KERNEL(LinearForward)
REGISTER_CPU_LINEAR_KERNEL(LinearBackward)

#undef REGISTER_CPU_LINEAR_KERNEL