Skip to content

Commit e03c095

Browse files
authored
Align output shape formulas across conv implementations (#237)
1 parent 541f549 commit e03c095

1 file changed

Lines changed: 22 additions & 36 deletions

File tree

include/layers/ConvLayer.hpp

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88

99
namespace it_lab_ai {
1010

11+
inline size_t ComputeConvOutputDim(size_t input_size, size_t kernel_size,
12+
size_t stride, size_t padding,
13+
size_t dilation) {
14+
const size_t effective_kernel = dilation * (kernel_size - 1) + 1;
15+
if (stride == 0 || input_size + 2 * padding < effective_kernel) {
16+
return 0;
17+
}
18+
return (input_size + 2 * padding - effective_kernel) / stride + 1;
19+
}
20+
1121
class ConvolutionalLayer : public Layer {
1222
private:
1323
size_t stride_;
@@ -165,12 +175,10 @@ void Conv4D(const Tensor& input, const Tensor& kernel_, const Tensor& bias_,
165175
}
166176
}
167177

168-
size_t out_height =
169-
(in_height + 2 * pads_ - dilations_ * (kernel_height - 1) - 1) / stride_ +
170-
1;
178+
size_t out_height = ComputeConvOutputDim(in_height, kernel_height, stride_,
179+
pads_, dilations_);
171180
size_t out_width =
172-
(in_width + 2 * pads_ - dilations_ * (kernel_width - 1) - 1) / stride_ +
173-
1;
181+
ComputeConvOutputDim(in_width, kernel_width, stride_, pads_, dilations_);
174182

175183
std::vector<std::vector<std::vector<std::vector<ValueType>>>> padded_input(
176184
batch_size,
@@ -352,20 +360,10 @@ void Conv4DSTL(const Tensor& input, const Tensor& kernel_, const Tensor& bias_,
352360
for (auto& t : threads) t.join();
353361
threads.clear();
354362

355-
size_t crat = 0;
356-
if ((in_height + 2 * pads_ - dilations_ * (kernel_height - 1)) % stride_ != 0)
357-
crat = 1;
358-
359-
size_t out_height =
360-
(in_height + 2 * pads_ - dilations_ * (kernel_height - 1)) / stride_ +
361-
crat;
362-
363-
crat = 0;
364-
if ((in_width + 2 * pads_ - dilations_ * (kernel_width - 1)) % stride_ != 0)
365-
crat = 1;
366-
363+
size_t out_height = ComputeConvOutputDim(in_height, kernel_height, stride_,
364+
pads_, dilations_);
367365
size_t out_width =
368-
(in_width + 2 * pads_ - dilations_ * (kernel_width - 1)) / stride_ + crat;
366+
ComputeConvOutputDim(in_width, kernel_width, stride_, pads_, dilations_);
369367

370368
std::vector<std::vector<std::vector<std::vector<ValueType>>>> output_tensor(
371369
batch_size, std::vector<std::vector<std::vector<ValueType>>>(
@@ -474,12 +472,10 @@ void DepthwiseConv4D(const Tensor& input, const Tensor& kernel_,
474472
throw std::runtime_error("Invalid kernel shape for depthwise convolution");
475473
}
476474

477-
size_t out_height =
478-
(in_height + 2 * pads_ - dilations_ * (kernel_height - 1) - 1) / stride_ +
479-
1;
475+
size_t out_height = ComputeConvOutputDim(in_height, kernel_height, stride_,
476+
pads_, dilations_);
480477
size_t out_width =
481-
(in_width + 2 * pads_ - dilations_ * (kernel_width - 1) - 1) / stride_ +
482-
1;
478+
ComputeConvOutputDim(in_width, kernel_width, stride_, pads_, dilations_);
483479

484480
Tensor output_tensor(Shape({batch_size, channels, out_height, out_width}),
485481
input.get_type());
@@ -568,20 +564,10 @@ void Conv4D_Legacy(const Tensor& input, const Tensor& kernel_,
568564
}
569565
}
570566

571-
size_t crat = 0;
572-
if ((in_height + 2 * pads_ - dilations_ * (kernel_height - 1)) % stride_ != 0)
573-
crat = 1;
574-
575-
size_t out_height =
576-
(in_height + 2 * pads_ - dilations_ * (kernel_height - 1)) / stride_ +
577-
crat;
578-
579-
crat = 0;
580-
if ((in_width + 2 * pads_ - dilations_ * (kernel_width - 1)) % stride_ != 0)
581-
crat = 1;
582-
567+
size_t out_height = ComputeConvOutputDim(in_height, kernel_height, stride_,
568+
pads_, dilations_);
583569
size_t out_width =
584-
(in_width + 2 * pads_ - dilations_ * (kernel_width - 1)) / stride_ + crat;
570+
ComputeConvOutputDim(in_width, kernel_width, stride_, pads_, dilations_);
585571

586572
std::vector<std::vector<std::vector<std::vector<ValueType>>>> output_tensor(
587573
batch_size, std::vector<std::vector<std::vector<ValueType>>>(

0 commit comments

Comments
 (0)