Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 32 additions & 31 deletions include/layers/PoolingLayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,18 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
std::vector<ValueType> res(this->outputShape_.count(), ValueType(0));

size_t spatial_dims = poolingShape_.dims();
// Builds the per-dimension element strides for `shape`: the last dimension
// gets stride 1 and every earlier dimension gets the product of the extents
// of all dimensions after it. A shape with zero or one dimension yields a
// vector of that size filled with 1 (the loop body never runs).
auto make_strides = [](const Shape& shape) {
  const size_t rank = shape.dims();
  std::vector<size_t> result(rank, 1);

  // Walk from the innermost dimension outwards, accumulating the running
  // product of extents seen so far.
  size_t running = 1;
  size_t axis = rank;
  while (axis > 1) {
    --axis;
    running *= shape[axis];
    result[axis - 1] = running;
  }
  return result;
};

const auto input_strides = make_strides(this->inputShape_);
int batch_dim = this->inputShape_.dims() > spatial_dims ? 0 : -1;
int channel_dim = this->inputShape_.dims() > spatial_dims + 1 ? 1 : -1;
size_t input_spatial_dim = this->inputShape_.dims() - spatial_dims;

size_t out_h = this->outputShape_[this->outputShape_.dims() - spatial_dims];
size_t out_w =
Expand All @@ -219,6 +229,19 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
size_t out_c = channel_dim >= 0 ? this->outputShape_[channel_dim] : 1;

size_t total_work = out_n * out_c * out_h * out_w;
size_t kernel_w = spatial_dims > 1 ? poolingShape_[1] : 1;
int input_h_limit = static_cast<int>(this->inputShape_[input_spatial_dim]);
int input_w_limit =
spatial_dims > 1
? static_cast<int>(this->inputShape_[input_spatial_dim + 1])
: 0;
size_t input_batch_stride =
batch_dim >= 0 ? input_strides[batch_dim] : static_cast<size_t>(0);
size_t input_channel_stride =
channel_dim >= 0 ? input_strides[channel_dim] : static_cast<size_t>(0);
size_t input_h_stride = input_strides[input_spatial_dim];
size_t input_w_stride =
spatial_dims > 1 ? input_strides[input_spatial_dim + 1] : 0;

parallel::Options options;
options.backend = parallel_backend_;
Expand Down Expand Up @@ -247,38 +270,26 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
int start_w = spatial_dims > 1 ? static_cast<int>(w * strides_[1]) -
static_cast<int>(pads_[2])
: 0;
size_t input_base = n * input_batch_stride + c * input_channel_stride;

auto sum = ValueType(0);
ValueType max_val = std::numeric_limits<ValueType>::lowest();
size_t count = 0;

for (size_t kh = 0; kh < poolingShape_[0]; kh++) {
for (size_t kw = 0; kw < (spatial_dims > 1 ? poolingShape_[1] : 1);
kw++) {
for (size_t kw = 0; kw < kernel_w; kw++) {
int pos_h = start_h + static_cast<int>(kh * dilations_[0]);
int pos_w = spatial_dims > 1
? start_w + static_cast<int>(kw * dilations_[1])
: 0;

if (pos_h >= 0 &&
pos_h < static_cast<int>(
this->inputShape_[this->inputShape_.dims() -
spatial_dims]) &&
(spatial_dims <= 1 ||
(pos_w >= 0 &&
pos_w < static_cast<int>(
this->inputShape_[this->inputShape_.dims() -
spatial_dims + 1])))) {
std::vector<size_t> input_coords(this->inputShape_.dims(), 0);
if (batch_dim >= 0) input_coords[batch_dim] = n;
if (channel_dim >= 0) input_coords[channel_dim] = c;
input_coords[this->inputShape_.dims() - spatial_dims] = pos_h;
if (pos_h >= 0 && pos_h < input_h_limit &&
(spatial_dims <= 1 || (pos_w >= 0 && pos_w < input_w_limit))) {
size_t input_index =
input_base + static_cast<size_t>(pos_h) * input_h_stride;
if (spatial_dims > 1) {
input_coords[this->inputShape_.dims() - spatial_dims + 1] =
pos_w;
input_index += static_cast<size_t>(pos_w) * input_w_stride;
}

size_t input_index = this->inputShape_.get_index(input_coords);
ValueType val = input[input_index];

if (this->poolingType_ == kMax) {
Expand All @@ -295,22 +306,12 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(

if (count == 0) return;

std::vector<size_t> output_coords(this->outputShape_.dims(), 0);
if (batch_dim >= 0) output_coords[batch_dim] = n;
if (channel_dim >= 0) output_coords[channel_dim] = c;
output_coords[this->outputShape_.dims() - spatial_dims] = h;
if (spatial_dims > 1) {
output_coords[this->outputShape_.dims() - spatial_dims + 1] = w;
}

size_t output_index = this->outputShape_.get_index(output_coords);

switch (this->poolingType_) {
case kAverage:
res[output_index] = sum / static_cast<ValueType>(count);
res[idx] = sum / static_cast<ValueType>(count);
break;
case kMax:
res[output_index] = max_val;
res[idx] = max_val;
break;
default:
throw std::runtime_error("Unknown pooling type");
Expand Down
Loading