Commit 12916e9
authored
Check weight shape dimensions in ConvTranspose shape inference msrc116345 (#28524)
This pull request introduces comprehensive validation and error handling
improvements for the ConvTranspose operator across CPU, CUDA, WebGPU,
and XNNPACK backends, as well as in shape inference and unit tests. The
main focus is to ensure that invalid input shapes (especially rank-0 or
rank-1 tensors) are properly detected and reported, preventing undefined
behavior and improving robustness. Additionally, error messages are
clarified, and several helper functions now return `Status` for better
error propagation.
**Validation and Error Handling Improvements:**
* All ConvTranspose implementations (CPU, CUDA, WebGPU) now explicitly
check that input `X` and filter `W` tensors have at least 3 dimensions,
returning clear error messages if not.
(`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R65-R79)`,
`[[2]](diffhunk://#diff-d1bbcb0542b5acea587ac929cd6362cfd11172c522505c6db8b457a9d470c63dR273-R289)`,
`[[3]](diffhunk://#diff-b615243d0702e9613bd815173108306495b0f690294001e606823b77322f6fafR22-L28)`)
* The shape inference function for `ConvTransposeWithDynamicPads` now
fails gracefully with descriptive errors if input or weight tensors have
fewer than 2 dimensions.
(`[onnxruntime/core/graph/contrib_ops/contrib_defs.ccL62-R67](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L62-R67)`)
* Additional validation ensures that `output_padding` and dynamic pads
have correct sizes, and that `output_padding` values are within
ONNX-specified limits.
(`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R138-R153)`,
`[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R171-R187)`)
**Refactoring for Robustness:**
* Helper functions such as `ComputePadsAndOutputShape` and
`ComputeTransposePadAndOutputShape` now return `Status`, allowing errors
to propagate and be handled appropriately rather than causing crashes or
silent failures.
(`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L165-R234)`,
`[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L194-R262)`,
`[[3]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L220-R282)`,
`[[4]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R291-R302)`)
* All call sites (CPU, CUDA, WebGPU, XNNPACK) are updated to handle and
propagate these errors using `ORT_RETURN_IF_ERROR` or
`ORT_THROW_IF_ERROR`.
(`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R171-R187)`,
`[[2]](diffhunk://#diff-d1bbcb0542b5acea587ac929cd6362cfd11172c522505c6db8b457a9d470c63dL362-R379)`,
`[[3]](diffhunk://#diff-b615243d0702e9613bd815173108306495b0f690294001e606823b77322f6fafL48-R60)`,
`[[4]](diffhunk://#diff-6a2f8672090f25850b90b266aff3c7212552fc81b14bb7b539e9e5161c9fd526L494-R497)`)
**Unit Test Enhancements:**
* New negative tests are added to verify that rank-0 and rank-1 weight
tensors are properly rejected and produce the expected error messages,
increasing test coverage and reliability.
(`[onnxruntime/test/contrib_ops/conv_transpose_with_dynamic_pads_test.ccR22-R56](diffhunk://#diff-cb5bfc51d0c8096922eb674d142f0e970d5becd140b47bdfd7729a06a818b598R22-R56)`)
**Minor Code Quality Improvements:**
* Improved memory management in the CPU implementation by wrapping the
allocated buffer in `BufferUniquePtr` immediately to prevent leaks if
exceptions are thrown.
(`[onnxruntime/core/providers/cpu/nn/conv_transpose.ccR79-R89](diffhunk://#diff-0dcb5a9c8ba0c4e67940e9d77f77cb706bbf82d67bf6757967099b0a69c797b5R79-R89)`)
* Minor includes and type safety improvements (e.g., use of `SafeInt`
for overflow protection).
(`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R22)`,
`[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R291-R302)`)
**Summary of Most Important Changes:**
**1. Validation and Error Handling**
- All ConvTranspose implementations now check that input and filter
tensors have at least 3 dimensions, returning clear errors if not.
(`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R65-R79)`,
`[[2]](diffhunk://#diff-d1bbcb0542b5acea587ac929cd6362cfd11172c522505c6db8b457a9d470c63dR273-R289)`,
`[[3]](diffhunk://#diff-b615243d0702e9613bd815173108306495b0f690294001e606823b77322f6fafR22-L28)`)
- Shape inference for `ConvTransposeWithDynamicPads` fails with
descriptive errors for invalid input or weight tensor ranks.
(`[onnxruntime/core/graph/contrib_ops/contrib_defs.ccL62-R67](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L62-R67)`)
- Additional checks for `output_padding` and dynamic pads sizes and
values, with ONNX spec compliance.
(`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R138-R153)`,
`[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R171-R187)`)
**2. Error Propagation and Refactoring**
- Helper functions now return `Status` and propagate errors; all call
sites updated to handle these errors.
(`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L165-R234)`,
`[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L194-R262)`,
`[[3]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L220-R282)`,
`[[4]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R291-R302)`,
`[[5]](diffhunk://#diff-d1bbcb0542b5acea587ac929cd6362cfd11172c522505c6db8b457a9d470c63dL362-R379)`,
`[[6]](diffhunk://#diff-b615243d0702e9613bd815173108306495b0f690294001e606823b77322f6fafL48-R60)`,
`[[7]](diffhunk://#diff-6a2f8672090f25850b90b266aff3c7212552fc81b14bb7b539e9e5161c9fd526L494-R497)`)
**3. Unit Testing**
- Added tests to ensure invalid weight tensor ranks are rejected with
proper error messages.
(`[onnxruntime/test/contrib_ops/conv_transpose_with_dynamic_pads_test.ccR22-R56](diffhunk://#diff-cb5bfc51d0c8096922eb674d142f0e970d5becd140b47bdfd7729a06a818b598R22-R56)`)
**4. Code Quality**
- Improved buffer management and type safety in CPU backend.
(`[[1]](diffhunk://#diff-0dcb5a9c8ba0c4e67940e9d77f77cb706bbf82d67bf6757967099b0a69c797b5R79-R89)`,
`[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R22)`,
`[[3]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R291-R302)`)1 parent c85f6eb commit 12916e9
9 files changed
Lines changed: 462 additions & 42 deletions
File tree
- onnxruntime
- core
- graph/contrib_ops
- providers
- cpu/nn
- cuda/nn
- webgpu/nn
- xnnpack/nn
- test
- contrib_ops
- providers
- cpu/nn
- cuda/nhwc
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
56 | 56 | | |
57 | 57 | | |
58 | 58 | | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
59 | 62 | | |
60 | 63 | | |
61 | | - | |
62 | | - | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
63 | 75 | | |
64 | 76 | | |
65 | 77 | | |
| |||
147 | 159 | | |
148 | 160 | | |
149 | 161 | | |
150 | | - | |
| 162 | + | |
151 | 163 | | |
152 | 164 | | |
153 | 165 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
76 | 76 | | |
77 | 77 | | |
78 | 78 | | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
79 | 82 | | |
80 | 83 | | |
81 | 84 | | |
82 | 85 | | |
83 | 86 | | |
84 | | - | |
85 | | - | |
86 | 87 | | |
87 | 88 | | |
88 | | - | |
| 89 | + | |
89 | 90 | | |
90 | 91 | | |
91 | 92 | | |
| |||
Lines changed: 123 additions & 21 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
| 21 | + | |
| 22 | + | |
21 | 23 | | |
| 24 | + | |
22 | 25 | | |
23 | 26 | | |
24 | 27 | | |
| |||
61 | 64 | | |
62 | 65 | | |
63 | 66 | | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
64 | 82 | | |
65 | 83 | | |
66 | 84 | | |
| |||
119 | 137 | | |
120 | 138 | | |
121 | 139 | | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
122 | 145 | | |
123 | 146 | | |
124 | 147 | | |
125 | | - | |
126 | | - | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
127 | 166 | | |
128 | 167 | | |
129 | 168 | | |
| |||
140 | 179 | | |
141 | 180 | | |
142 | 181 | | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
143 | 206 | | |
144 | 207 | | |
145 | | - | |
146 | | - | |
| 208 | + | |
| 209 | + | |
147 | 210 | | |
148 | 211 | | |
149 | 212 | | |
| |||
162 | 225 | | |
163 | 226 | | |
164 | 227 | | |
165 | | - | |
166 | | - | |
167 | | - | |
168 | | - | |
169 | | - | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
170 | 233 | | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
171 | 248 | | |
172 | 249 | | |
173 | 250 | | |
174 | 251 | | |
175 | 252 | | |
176 | 253 | | |
177 | | - | |
178 | 254 | | |
179 | 255 | | |
180 | 256 | | |
181 | 257 | | |
182 | 258 | | |
183 | 259 | | |
184 | 260 | | |
185 | | - | |
| 261 | + | |
186 | 262 | | |
187 | 263 | | |
188 | 264 | | |
189 | 265 | | |
190 | 266 | | |
191 | 267 | | |
192 | | - | |
193 | | - | |
194 | | - | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
195 | 271 | | |
196 | | - | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
197 | 277 | | |
198 | 278 | | |
199 | 279 | | |
200 | 280 | | |
201 | 281 | | |
| 282 | + | |
202 | 283 | | |
203 | 284 | | |
204 | 285 | | |
205 | 286 | | |
206 | 287 | | |
207 | 288 | | |
208 | | - | |
| 289 | + | |
209 | 290 | | |
210 | 291 | | |
211 | 292 | | |
| |||
217 | 298 | | |
218 | 299 | | |
219 | 300 | | |
220 | | - | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
221 | 305 | | |
222 | 306 | | |
223 | 307 | | |
224 | 308 | | |
225 | | - | |
| 309 | + | |
226 | 310 | | |
227 | 311 | | |
228 | 312 | | |
229 | 313 | | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
230 | 328 | | |
231 | 329 | | |
232 | 330 | | |
233 | 331 | | |
| 332 | + | |
234 | 333 | | |
235 | | - | |
| 334 | + | |
236 | 335 | | |
237 | 336 | | |
238 | 337 | | |
239 | | - | |
240 | | - | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
241 | 343 | | |
242 | 344 | | |
243 | 345 | | |
| |||
0 commit comments