Skip to content

Commit b1a952f

Browse files
authored
Move idx arithmetics in pooling outside the parallel loop (#280)
1 parent a04122b commit b1a952f

1 file changed

Lines changed: 32 additions & 31 deletions

File tree

include/layers/PoolingLayer.hpp

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,18 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
207207
std::vector<ValueType> res(this->outputShape_.count(), ValueType(0));
208208

209209
size_t spatial_dims = poolingShape_.dims();
210+
auto make_strides = [](const Shape& shape) {
211+
std::vector<size_t> strides(shape.dims(), 1);
212+
for (size_t i = shape.dims(); i > 1; --i) {
213+
strides[i - 2] = strides[i - 1] * shape[i - 1];
214+
}
215+
return strides;
216+
};
217+
218+
const auto input_strides = make_strides(this->inputShape_);
210219
int batch_dim = this->inputShape_.dims() > spatial_dims ? 0 : -1;
211220
int channel_dim = this->inputShape_.dims() > spatial_dims + 1 ? 1 : -1;
221+
size_t input_spatial_dim = this->inputShape_.dims() - spatial_dims;
212222

213223
size_t out_h = this->outputShape_[this->outputShape_.dims() - spatial_dims];
214224
size_t out_w =
@@ -219,6 +229,19 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
219229
size_t out_c = channel_dim >= 0 ? this->outputShape_[channel_dim] : 1;
220230

221231
size_t total_work = out_n * out_c * out_h * out_w;
232+
size_t kernel_w = spatial_dims > 1 ? poolingShape_[1] : 1;
233+
int input_h_limit = static_cast<int>(this->inputShape_[input_spatial_dim]);
234+
int input_w_limit =
235+
spatial_dims > 1
236+
? static_cast<int>(this->inputShape_[input_spatial_dim + 1])
237+
: 0;
238+
size_t input_batch_stride =
239+
batch_dim >= 0 ? input_strides[batch_dim] : static_cast<size_t>(0);
240+
size_t input_channel_stride =
241+
channel_dim >= 0 ? input_strides[channel_dim] : static_cast<size_t>(0);
242+
size_t input_h_stride = input_strides[input_spatial_dim];
243+
size_t input_w_stride =
244+
spatial_dims > 1 ? input_strides[input_spatial_dim + 1] : 0;
222245

223246
parallel::Options options;
224247
options.backend = parallel_backend_;
@@ -247,38 +270,26 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
247270
int start_w = spatial_dims > 1 ? static_cast<int>(w * strides_[1]) -
248271
static_cast<int>(pads_[2])
249272
: 0;
273+
size_t input_base = n * input_batch_stride + c * input_channel_stride;
250274

251275
auto sum = ValueType(0);
252276
ValueType max_val = std::numeric_limits<ValueType>::lowest();
253277
size_t count = 0;
254278

255279
for (size_t kh = 0; kh < poolingShape_[0]; kh++) {
256-
for (size_t kw = 0; kw < (spatial_dims > 1 ? poolingShape_[1] : 1);
257-
kw++) {
280+
for (size_t kw = 0; kw < kernel_w; kw++) {
258281
int pos_h = start_h + static_cast<int>(kh * dilations_[0]);
259282
int pos_w = spatial_dims > 1
260283
? start_w + static_cast<int>(kw * dilations_[1])
261284
: 0;
262285

263-
if (pos_h >= 0 &&
264-
pos_h < static_cast<int>(
265-
this->inputShape_[this->inputShape_.dims() -
266-
spatial_dims]) &&
267-
(spatial_dims <= 1 ||
268-
(pos_w >= 0 &&
269-
pos_w < static_cast<int>(
270-
this->inputShape_[this->inputShape_.dims() -
271-
spatial_dims + 1])))) {
272-
std::vector<size_t> input_coords(this->inputShape_.dims(), 0);
273-
if (batch_dim >= 0) input_coords[batch_dim] = n;
274-
if (channel_dim >= 0) input_coords[channel_dim] = c;
275-
input_coords[this->inputShape_.dims() - spatial_dims] = pos_h;
286+
if (pos_h >= 0 && pos_h < input_h_limit &&
287+
(spatial_dims <= 1 || (pos_w >= 0 && pos_w < input_w_limit))) {
288+
size_t input_index =
289+
input_base + static_cast<size_t>(pos_h) * input_h_stride;
276290
if (spatial_dims > 1) {
277-
input_coords[this->inputShape_.dims() - spatial_dims + 1] =
278-
pos_w;
291+
input_index += static_cast<size_t>(pos_w) * input_w_stride;
279292
}
280-
281-
size_t input_index = this->inputShape_.get_index(input_coords);
282293
ValueType val = input[input_index];
283294

284295
if (this->poolingType_ == kMax) {
@@ -295,22 +306,12 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
295306

296307
if (count == 0) return;
297308

298-
std::vector<size_t> output_coords(this->outputShape_.dims(), 0);
299-
if (batch_dim >= 0) output_coords[batch_dim] = n;
300-
if (channel_dim >= 0) output_coords[channel_dim] = c;
301-
output_coords[this->outputShape_.dims() - spatial_dims] = h;
302-
if (spatial_dims > 1) {
303-
output_coords[this->outputShape_.dims() - spatial_dims + 1] = w;
304-
}
305-
306-
size_t output_index = this->outputShape_.get_index(output_coords);
307-
308309
switch (this->poolingType_) {
309310
case kAverage:
310-
res[output_index] = sum / static_cast<ValueType>(count);
311+
res[idx] = sum / static_cast<ValueType>(count);
311312
break;
312313
case kMax:
313-
res[output_index] = max_val;
314+
res[idx] = max_val;
314315
break;
315316
default:
316317
throw std::runtime_error("Unknown pooling type");

0 commit comments

Comments
 (0)