@@ -207,8 +207,18 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
207207 std::vector<ValueType> res (this ->outputShape_ .count (), ValueType (0 ));
208208
209209 size_t spatial_dims = poolingShape_.dims ();
210+ auto make_strides = [](const Shape& shape) {
211+ std::vector<size_t > strides (shape.dims (), 1 );
212+ for (size_t i = shape.dims (); i > 1 ; --i) {
213+ strides[i - 2 ] = strides[i - 1 ] * shape[i - 1 ];
214+ }
215+ return strides;
216+ };
217+
218+ const auto input_strides = make_strides (this ->inputShape_ );
210219 int batch_dim = this ->inputShape_ .dims () > spatial_dims ? 0 : -1 ;
211220 int channel_dim = this ->inputShape_ .dims () > spatial_dims + 1 ? 1 : -1 ;
221+ size_t input_spatial_dim = this ->inputShape_ .dims () - spatial_dims;
212222
213223 size_t out_h = this ->outputShape_ [this ->outputShape_ .dims () - spatial_dims];
214224 size_t out_w =
@@ -219,6 +229,19 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
219229 size_t out_c = channel_dim >= 0 ? this ->outputShape_ [channel_dim] : 1 ;
220230
221231 size_t total_work = out_n * out_c * out_h * out_w;
232+ size_t kernel_w = spatial_dims > 1 ? poolingShape_[1 ] : 1 ;
233+ int input_h_limit = static_cast <int >(this ->inputShape_ [input_spatial_dim]);
234+ int input_w_limit =
235+ spatial_dims > 1
236+ ? static_cast <int >(this ->inputShape_ [input_spatial_dim + 1 ])
237+ : 0 ;
238+ size_t input_batch_stride =
239+ batch_dim >= 0 ? input_strides[batch_dim] : static_cast <size_t >(0 );
240+ size_t input_channel_stride =
241+ channel_dim >= 0 ? input_strides[channel_dim] : static_cast <size_t >(0 );
242+ size_t input_h_stride = input_strides[input_spatial_dim];
243+ size_t input_w_stride =
244+ spatial_dims > 1 ? input_strides[input_spatial_dim + 1 ] : 0 ;
222245
223246 parallel::Options options;
224247 options.backend = parallel_backend_;
@@ -247,38 +270,26 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
247270 int start_w = spatial_dims > 1 ? static_cast <int >(w * strides_[1 ]) -
248271 static_cast <int >(pads_[2 ])
249272 : 0 ;
273+ size_t input_base = n * input_batch_stride + c * input_channel_stride;
250274
251275 auto sum = ValueType (0 );
252276 ValueType max_val = std::numeric_limits<ValueType>::lowest ();
253277 size_t count = 0 ;
254278
255279 for (size_t kh = 0 ; kh < poolingShape_[0 ]; kh++) {
256- for (size_t kw = 0 ; kw < (spatial_dims > 1 ? poolingShape_[1 ] : 1 );
257- kw++) {
280+ for (size_t kw = 0 ; kw < kernel_w; kw++) {
258281 int pos_h = start_h + static_cast <int >(kh * dilations_[0 ]);
259282 int pos_w = spatial_dims > 1
260283 ? start_w + static_cast <int >(kw * dilations_[1 ])
261284 : 0 ;
262285
263- if (pos_h >= 0 &&
264- pos_h < static_cast <int >(
265- this ->inputShape_ [this ->inputShape_ .dims () -
266- spatial_dims]) &&
267- (spatial_dims <= 1 ||
268- (pos_w >= 0 &&
269- pos_w < static_cast <int >(
270- this ->inputShape_ [this ->inputShape_ .dims () -
271- spatial_dims + 1 ])))) {
272- std::vector<size_t > input_coords (this ->inputShape_ .dims (), 0 );
273- if (batch_dim >= 0 ) input_coords[batch_dim] = n;
274- if (channel_dim >= 0 ) input_coords[channel_dim] = c;
275- input_coords[this ->inputShape_ .dims () - spatial_dims] = pos_h;
286+ if (pos_h >= 0 && pos_h < input_h_limit &&
287+ (spatial_dims <= 1 || (pos_w >= 0 && pos_w < input_w_limit))) {
288+ size_t input_index =
289+ input_base + static_cast <size_t >(pos_h) * input_h_stride;
276290 if (spatial_dims > 1 ) {
277- input_coords[this ->inputShape_ .dims () - spatial_dims + 1 ] =
278- pos_w;
291+ input_index += static_cast <size_t >(pos_w) * input_w_stride;
279292 }
280-
281- size_t input_index = this ->inputShape_ .get_index (input_coords);
282293 ValueType val = input[input_index];
283294
284295 if (this ->poolingType_ == kMax ) {
@@ -295,22 +306,12 @@ std::vector<ValueType> PoolingLayerImpl<ValueType>::run(
295306
296307 if (count == 0 ) return ;
297308
298- std::vector<size_t > output_coords (this ->outputShape_ .dims (), 0 );
299- if (batch_dim >= 0 ) output_coords[batch_dim] = n;
300- if (channel_dim >= 0 ) output_coords[channel_dim] = c;
301- output_coords[this ->outputShape_ .dims () - spatial_dims] = h;
302- if (spatial_dims > 1 ) {
303- output_coords[this ->outputShape_ .dims () - spatial_dims + 1 ] = w;
304- }
305-
306- size_t output_index = this ->outputShape_ .get_index (output_coords);
307-
308309 switch (this ->poolingType_ ) {
309310 case kAverage :
310- res[output_index ] = sum / static_cast <ValueType>(count);
311+ res[idx ] = sum / static_cast <ValueType>(count);
311312 break ;
312313 case kMax :
313- res[output_index ] = max_val;
314+ res[idx ] = max_val;
314315 break ;
315316 default :
316317 throw std::runtime_error (" Unknown pooling type" );
0 commit comments