Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions tmva/sofie/inc/TMVA/ROperator_Pool.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,14 @@ public:
size_t input2 = (fDim > 1) ? input[0][3] : 1;
size_t input3 = (fDim > 2) ? input[0][4] : 1;

// use ceiling division when ceil_mode=1, floor otherwise
auto poolOutDim = [this](size_t in, size_t pad, size_t kern, size_t stride) -> size_t {
size_t n = in + pad - kern;
return (fAttrCeilMode ? (n + stride - 1) / stride : n / stride) + 1;
};

size_t pad1 = fAttrPads[0] + fAttrPads[i1];
size_t output1 = (input1 + pad1 - fAttrKernelShape[0]) / fAttrStrides[0] + 1;
size_t output1 = poolOutDim(input1, pad1, fAttrKernelShape[0], fAttrStrides[0]);

size_t batch_size = input[0][0]; // first element in input tensor
size_t output_channels = input[0][1]; // first element in output tensor
Expand All @@ -186,14 +192,14 @@ public:
return ret;

size_t pad2 = fAttrPads[1] + fAttrPads[i2];
size_t output2 = (input2 + pad2 - fAttrKernelShape[1]) / fAttrStrides[1] + 1;
size_t output2 = poolOutDim(input2, pad2, fAttrKernelShape[1], fAttrStrides[1]);
// output is N x C x OH x OW
ret[0].push_back(output2);
if (fDim == 2)
return ret;

size_t pad3 = fAttrPads[2] + fAttrPads[i3];
size_t output3 = (input3 + pad3 - fAttrKernelShape[2] ) / fAttrStrides[2] + 1;
size_t output3 = poolOutDim(input3, pad3, fAttrKernelShape[2], fAttrStrides[2]);

// output is N x C x OH x OW x OD
ret[0].push_back(output3);
Expand Down Expand Up @@ -283,20 +289,21 @@ public:
assert(fAttrKernelShape.size() == 3);
// find lower bounds of filtered area
int hmin = - fAttrPads[0]; // minimum lower bound value of filter area
int hmax = fShapeX[2] + fAttrPads[1] - fAttrKernelShape[0] +1; // maximum lower bound value + 1
// use stride instead of 1 when ceil_mode=1, so the loop covers the extra partial window
int hmax = fShapeX[2] + fAttrPads[1] - fAttrKernelShape[0] + (fAttrCeilMode ? (int)fAttrStrides[0] : 1);
int wmin,wmax,dmin,dmax;

if(fDim >= 2){
wmin = - fAttrPads[2]; // minimum lower bound value of filter area
wmax = fShapeX[3] + fAttrPads[3] - fAttrKernelShape[1] +1; // maximum lower bound value + 1
wmax = fShapeX[3] + fAttrPads[3] - fAttrKernelShape[1] + (fAttrCeilMode ? (int)fAttrStrides[1] : 1);
}
else{
wmin=1;
wmax=1;
}
if(fDim == 3){
dmin = - fAttrPads[4]; // minimum lower bound value of filter area
dmax = fShapeX[4] + fAttrPads[5] - fAttrKernelShape[2] +1; // maximum lower bound value + 1
dmax = fShapeX[4] + fAttrPads[5] - fAttrKernelShape[2] + (fAttrCeilMode ? (int)fAttrStrides[2] : 1);
}
else{
dmin=1;
Expand Down
19 changes: 19 additions & 0 deletions tmva/sofie/test/TestCustomModelsFromONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ constexpr auto modelDataSuffix = "_FromONNX.dat";
#include "input_models/references/ConvWithAsymmetricPadding.ref.hxx"
#include "input_models/references/MaxPool1d.ref.hxx"
#include "input_models/references/MaxPool2d.ref.hxx"
#include "input_models/references/MaxPool2d_CeilMode.ref.hxx"
#include "input_models/references/MaxPool3d.ref.hxx"
#include "input_models/references/Max.ref.hxx"
#include "input_models/references/MaxMultidirectionalBroadcast.ref.hxx"
Expand Down Expand Up @@ -787,6 +788,24 @@ TEST(ONNX, MaxPool2d){

}

TEST(ONNX, MaxPool2d_CeilMode)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// 1x1x5x5 input: values 0..24; MaxPool kernel=2x2 stride=2 ceil_mode=1 -> 1x1x3x3 output
std::vector<float> input(25);
for (int i = 0; i < 25; i++)
input[i] = static_cast<float>(i);

ASSERT_INCLUDE_AND_RUN(std::vector<float>, "MaxPool2d_CeilMode", input);
EXPECT_EQ(output.size(), sizeof(MaxPool2d_CeilMode_ExpectedOutput::output) / sizeof(float));

float *correct = MaxPool2d_CeilMode_ExpectedOutput::output;
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
}
}

TEST(ONNX, MaxPool3d){
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace MaxPool2d_CeilMode_ExpectedOutput {
float output[] = {6, 8, 9, 16, 18, 19, 21, 23, 24};
} // namespace MaxPool2d_CeilMode_ExpectedOutput