Skip to content

Commit a894893

Browse files
authored
Address MaxUnpool shortcomings msrc116345 (#28550)
This pull request enhances the robustness of the `MaxUnpool` operator in ONNX Runtime by adding input validation and expanding test coverage for invalid input scenarios. The changes improve error handling for mismatched shapes, invalid dimensions, and other edge cases, ensuring the operator fails gracefully and predictably when given incorrect inputs.

**Operator input validation improvements:**

* Added runtime checks in `MaxUnpool::Compute` to ensure the `kernel_shape` rank matches the expected pooling dimensions, and that the indices tensor is present and correctly shaped.
* Added validation to ensure that computed output dimensions are positive, with descriptive error messages if not.
* Enforced that the `output_shape` tensor, if provided, must have the same number of elements as the rank of the input tensor.

**Test coverage enhancements:**

* Introduced multiple new tests in `unpool_op_test.cc` to cover invalid input cases, including mismatched indices shapes, rank-0 and rank-2 input tensors, negative indices, and incorrect `output_shape` element counts. These tests confirm that the operator fails with appropriate error messages in these scenarios.

**References:**

* onnx/onnx#7997
* #28524
1 parent 67156b8 commit a894893

3 files changed

Lines changed: 382 additions & 12 deletions

File tree

onnxruntime/core/providers/cpu/nn/Unpool.cc

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,32 +46,43 @@ Status MaxUnpool::Compute(OpKernelContext* context) const {
4646
const TensorShape& X_shape = X->Shape();
4747
const auto* X_data = X->Data<float>();
4848

49+
// Spec: "Dimensions ... are in the form of (N x C x D1 x D2 ... Dn)" — minimum rank is 3.
4950
ORT_RETURN_IF_NOT(X_shape.NumDimensions() >= 3, "Input dimension cannot be less than 3.");
5051

51-
// Supported sizes check
52+
// Implementation limitation: only 1D/2D/3D spatial pooling supported.
5253
size_t pooling_dims = X_shape.NumDimensions() - 2;
5354
if (pooling_dims > 3) {
5455
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported pooling size.");
5556
}
5657

58+
// Spec: "The size of the kernel along each axis" — must match number of spatial dims.
59+
ORT_RETURN_IF_NOT(kernel_shape_.size() == pooling_dims,
60+
"kernel_shape rank mismatch: expected ", pooling_dims, " got ", kernel_shape_.size());
61+
5762
// Get pooled index tensor
5863
const auto* I = context->Input<Tensor>(1);
5964
const TensorShape& I_shape = I->Shape();
6065
const auto* I_data = I->Data<int64_t>();
6166

67+
// Spec: Input I "Dimensions must be the same as input tensor X."
6268
ORT_RETURN_IF_NOT(I_shape == X_shape, "Index tensor shape should be same as that of the input data tensor to unpool.");
6369

6470
// Calculate output tensor shape from attributes
65-
std::vector<int64_t> inferred_output_dims(X_shape.NumDimensions());
71+
TensorShapeVector inferred_output_dims(X_shape.NumDimensions());
6672

6773
// Copy batch and channel dims
6874
inferred_output_dims[0] = X_shape[0];
6975
inferred_output_dims[1] = X_shape[1];
7076

7177
// For feature dims calculate reversing the formula used for MaxPool
7278
for (size_t dim = 0; dim < kernel_shape_.size(); ++dim) {
73-
inferred_output_dims[dim + 2] =
74-
(X_shape[dim + 2] - 1) * strides_[dim] - (pads_[dim] + pads_[kernel_shape_.size() + dim]) + kernel_shape_[dim];
79+
int64_t dim_value = (X_shape[dim + 2] - 1) * strides_[dim] -
80+
(pads_[dim] + pads_[kernel_shape_.size() + dim]) + kernel_shape_[dim];
81+
// Each inferred spatial dim must be positive for a valid unpooling configuration.
82+
ORT_RETURN_IF_NOT(dim_value > 0,
83+
"Computed output dimension is not positive for axis ", dim + 2,
84+
". Check kernel_shape, strides, and pads attributes.");
85+
inferred_output_dims[dim + 2] = dim_value;
7586
}
7687

7788
TensorShape shape(inferred_output_dims);
@@ -80,14 +91,29 @@ Status MaxUnpool::Compute(OpKernelContext* context) const {
8091
auto tensor_shape = context->Input<Tensor>(2);
8192
if (tensor_shape == nullptr)
8293
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
94+
// Spec: output_shape is a 1-D tensor of int64.
8395
ORT_RETURN_IF_NOT(tensor_shape->Shape().GetDims().size() == 1,
8496
"Shape must be 1 dimensional as it's tensor data of a shape");
8597

98+
// Spec: output_shape specifies the full output shape (N x C x D1 x ... x Dn) — same rank as X.
99+
ORT_RETURN_IF_NOT(
100+
static_cast<size_t>(tensor_shape->Shape().Size()) == X_shape.NumDimensions(),
101+
"output_shape must have the same number of elements as the rank of input tensor X."
102+
" Got ",
103+
tensor_shape->Shape().Size(), ", expected ", X_shape.NumDimensions());
104+
86105
// Turn the shape tensor data into an actual shape
87-
const auto* p_shape = tensor_shape->Data<int64_t>();
88-
std::vector<int64_t> given_output_dims(p_shape, p_shape + tensor_shape->Shape().Size());
89-
TensorShape given_shape(given_output_dims);
106+
auto output_shape_span = tensor_shape->DataAsSpan<int64_t>();
107+
TensorShape given_shape(output_shape_span);
108+
109+
// Spec: output shape is (N x C x D1 x ... x Dn) — batch and channel must match input.
110+
ORT_RETURN_IF_NOT(given_shape[0] == X_shape[0] && given_shape[1] == X_shape[1],
111+
"output_shape batch and channel dimensions must match input. "
112+
"Expected [",
113+
X_shape[0], ", ", X_shape[1], "], got [",
114+
given_shape[0], ", ", given_shape[1], "].");
90115

116+
// Spec: output_shape disambiguates size — must be at least as large as the inferred minimum.
91117
ORT_RETURN_IF_NOT(given_shape.Size() >= shape.Size(),
92118
"output_shape is smaller than minimum required. output_shape:", given_shape,
93119
" inferred output shape:", shape);
@@ -97,18 +123,17 @@ Status MaxUnpool::Compute(OpKernelContext* context) const {
97123

98124
// unpool
99125
size_t total_elements = narrow<size_t>(X_shape.Size());
100-
size_t output_size = narrow<size_t>(shape.Size());
101126

102127
Tensor* Y = context->Output(0, shape);
103-
auto* Y_data = Y->MutableData<float>();
104-
auto out = gsl::make_span(Y_data, output_size);
128+
auto out = Y->MutableDataAsSpan<float>();
105129
std::fill_n(out.data(), out.size(), 0.f);
106130

107131
for (size_t cur_elem = 0; cur_elem < total_elements; ++cur_elem) {
108132
const int64_t idx = I_data[cur_elem];
109-
if (idx < 0 || idx >= static_cast<int64_t>(output_size)) {
133+
// Spec: "the values in indices are in the range [0, N x C x D1 x ... x Dn)."
134+
if (idx < 0 || idx >= static_cast<int64_t>(out.size())) {
110135
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
111-
"Index value out of bounds. Got: ", idx, ". Valid range is [0, ", output_size, ").");
136+
"Index value out of bounds. Got: ", idx, ". Valid range is [0, ", out.size(), ").");
112137
}
113138

114139
out[static_cast<size_t>(idx)] = X_data[cur_elem];

onnxruntime/core/providers/cpu/nn/unpool.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class MaxUnpool : public OpKernel {
3030
strides_.resize(kernel_shape_.size(), 1);
3131
}
3232

33+
ORT_ENFORCE(pads_.size() == kernel_shape_.size() * 2,
34+
"Pads attribute size must be twice the kernel_shape size. Got: ", pads_.size(),
35+
", expected: ", kernel_shape_.size() * 2);
36+
3337
for (size_t dim = 0; dim < kernel_shape_.size(); ++dim) {
3438
ORT_ENFORCE(kernel_shape_[dim] > 0);
3539
ORT_ENFORCE(pads_[dim] < kernel_shape_[dim] && pads_[dim + kernel_shape_.size()] < kernel_shape_[dim],

0 commit comments

Comments
 (0)