Skip to content

Commit 12916e9

Browse files
authored
Check weight shape dimensions in ConvTranspose shape inference msrc116345 (#28524)
This pull request introduces comprehensive validation and error handling improvements for the ConvTranspose operator across CPU, CUDA, WebGPU, and XNNPACK backends, as well as in shape inference and unit tests. The main focus is to ensure that invalid input shapes (especially rank-0 or rank-1 tensors) are properly detected and reported, preventing undefined behavior and improving robustness. Additionally, error messages are clarified, and several helper functions now return `Status` for better error propagation. **Validation and Error Handling Improvements:** * All ConvTranspose implementations (CPU, CUDA, WebGPU) now explicitly check that input `X` and filter `W` tensors have at least 3 dimensions, returning clear error messages if not. (`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R65-R79)`, `[[2]](diffhunk://#diff-d1bbcb0542b5acea587ac929cd6362cfd11172c522505c6db8b457a9d470c63dR273-R289)`, `[[3]](diffhunk://#diff-b615243d0702e9613bd815173108306495b0f690294001e606823b77322f6fafR22-L28)`) * The shape inference function for `ConvTransposeWithDynamicPads` now fails gracefully with descriptive errors if input or weight tensors have fewer than 2 dimensions. (`[onnxruntime/core/graph/contrib_ops/contrib_defs.ccL62-R67](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L62-R67)`) * Additional validation ensures that `output_padding` and dynamic pads have correct sizes, and that `output_padding` values are within ONNX-specified limits. (`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R138-R153)`, `[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R171-R187)`) **Refactoring for Robustness:** * Helper functions such as `ComputePadsAndOutputShape` and `ComputeTransposePadAndOutputShape` now return `Status`, allowing errors to propagate and be handled appropriately rather than causing crashes or silent failures. (`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L165-R234)`, `[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L194-R262)`, `[[3]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L220-R282)`, `[[4]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R291-R302)`) * All call sites (CPU, CUDA, WebGPU, XNNPACK) are updated to handle and propagate these errors using `ORT_RETURN_IF_ERROR` or `ORT_THROW_IF_ERROR`. (`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R171-R187)`, `[[2]](diffhunk://#diff-d1bbcb0542b5acea587ac929cd6362cfd11172c522505c6db8b457a9d470c63dL362-R379)`, `[[3]](diffhunk://#diff-b615243d0702e9613bd815173108306495b0f690294001e606823b77322f6fafL48-R60)`, `[[4]](diffhunk://#diff-6a2f8672090f25850b90b266aff3c7212552fc81b14bb7b539e9e5161c9fd526L494-R497)`) **Unit Test Enhancements:** * New negative tests are added to verify that rank-0 and rank-1 weight tensors are properly rejected and produce the expected error messages, increasing test coverage and reliability. (`[onnxruntime/test/contrib_ops/conv_transpose_with_dynamic_pads_test.ccR22-R56](diffhunk://#diff-cb5bfc51d0c8096922eb674d142f0e970d5becd140b47bdfd7729a06a818b598R22-R56)`) **Minor Code Quality Improvements:** * Improved memory management in the CPU implementation by wrapping the allocated buffer in `BufferUniquePtr` immediately to prevent leaks if exceptions are thrown. (`[onnxruntime/core/providers/cpu/nn/conv_transpose.ccR79-R89](diffhunk://#diff-0dcb5a9c8ba0c4e67940e9d77f77cb706bbf82d67bf6757967099b0a69c797b5R79-R89)`) * Minor includes and type safety improvements (e.g., use of `SafeInt` for overflow protection). (`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R22)`, `[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R291-R302)`) **Summary of Most Important Changes:** **1. Validation and Error Handling** - All ConvTranspose implementations now check that input and filter tensors have at least 3 dimensions, returning clear errors if not. (`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R65-R79)`, `[[2]](diffhunk://#diff-d1bbcb0542b5acea587ac929cd6362cfd11172c522505c6db8b457a9d470c63dR273-R289)`, `[[3]](diffhunk://#diff-b615243d0702e9613bd815173108306495b0f690294001e606823b77322f6fafR22-L28)`) - Shape inference for `ConvTransposeWithDynamicPads` fails with descriptive errors for invalid input or weight tensor ranks. (`[onnxruntime/core/graph/contrib_ops/contrib_defs.ccL62-R67](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L62-R67)`) - Additional checks for `output_padding` and dynamic pads sizes and values, with ONNX spec compliance. (`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R138-R153)`, `[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R171-R187)`) **2. Error Propagation and Refactoring** - Helper functions now return `Status` and propagate errors; all call sites updated to handle these errors. (`[[1]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L165-R234)`, `[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L194-R262)`, `[[3]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8L220-R282)`, `[[4]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R291-R302)`, `[[5]](diffhunk://#diff-d1bbcb0542b5acea587ac929cd6362cfd11172c522505c6db8b457a9d470c63dL362-R379)`, `[[6]](diffhunk://#diff-b615243d0702e9613bd815173108306495b0f690294001e606823b77322f6fafL48-R60)`, `[[7]](diffhunk://#diff-6a2f8672090f25850b90b266aff3c7212552fc81b14bb7b539e9e5161c9fd526L494-R497)`) **3. Unit Testing** - Added tests to ensure invalid weight tensor ranks are rejected with proper error messages. (`[onnxruntime/test/contrib_ops/conv_transpose_with_dynamic_pads_test.ccR22-R56](diffhunk://#diff-cb5bfc51d0c8096922eb674d142f0e970d5becd140b47bdfd7729a06a818b598R22-R56)`) **4. Code Quality** - Improved buffer management and type safety in CPU backend. (`[[1]](diffhunk://#diff-0dcb5a9c8ba0c4e67940e9d77f77cb706bbf82d67bf6757967099b0a69c797b5R79-R89)`, `[[2]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R22)`, `[[3]](diffhunk://#diff-72fa27d94d5d92dd1e78ff510ef9a84d1ad74426c19af9722cf6511f8d38a5a8R291-R302)`)
1 parent c85f6eb commit 12916e9

9 files changed

Lines changed: 462 additions & 42 deletions

File tree

onnxruntime/core/graph/contrib_ops/contrib_defs.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,22 @@ void convTransposeWithDynamicPadsShapeInference(InferenceContext& ctx) {
5656
}
5757

5858
int64_t group = getAttribute(ctx, "group", 1);
59+
if (group <= 0) {
60+
fail_shape_inference("group attribute must be positive. Got: ", group);
61+
}
5962

6063
auto input_shape = ctx.getInputType(0)->tensor_type().shape();
61-
if (input_shape.dim_size() < 2) {
62-
return; // Input tensor should have at least two dimensions.
64+
// ConvTranspose requires X=(N x C x D1...Dn) and W=(C x M/group x k1...kn), both rank >= 3.
65+
// The upstream ONNX ConvTranspose shape inference only checks rank >= 2, which allows rank-2
66+
// inputs to pass shape inference but crash at kernel execution time. We tighten the check here
67+
// to fail early at model load with a clear error. Fixing ONNX upstream is tracked separately.
68+
if (input_shape.dim_size() < 3) {
69+
fail_shape_inference("Input tensor must have at least 3 dimensions. Got: ", input_shape.dim_size());
70+
}
71+
72+
auto weight_shape = ctx.getInputType(1)->tensor_type().shape();
73+
if (weight_shape.dim_size() < 3) {
74+
fail_shape_inference("Weight tensor must have at least 3 dimensions. Got: ", weight_shape.dim_size());
6375
}
6476

6577
// first dim is the batch axis and the next is the number of channels.
@@ -147,7 +159,7 @@ void convTransposeWithDynamicPadsShapeInference(InferenceContext& ctx) {
147159

148160
*final_output_shape->add_dim() = input_shape.dim(0);
149161
*final_output_shape->add_dim() =
150-
ctx.getInputType(1)->tensor_type().shape().dim(1) *
162+
weight_shape.dim(1) *
151163
group; // channels should be the second dim of second input multiply
152164
// group.
153165

onnxruntime/core/providers/cpu/nn/conv_transpose.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,17 @@ Status ConvTranspose<float>::PrePack(const Tensor& tensor, int input_idx, Alloca
7676
size_t packed_filter_data_size = SafeInt<size_t>(packed_elements_per_group) * sizeof(float) * conv_transpose_attrs_.group;
7777
auto* packed_filter_data = alloc->Alloc(packed_filter_data_size);
7878

79+
// Wrap in BufferUniquePtr immediately to prevent leaks.
80+
transposed_filter_ = BufferUniquePtr(packed_filter_data, BufferDeleter(std::move(alloc)));
81+
7982
// Initialize memory to 0 as there could be some padding associated with pre-packed
8083
// buffer memory and we don not want it uninitialized and generate different hashes
8184
// if and when we try to cache this pre-packed buffer for sharing between sessions.
8285
memset(packed_filter_data, 0, packed_filter_data_size);
8386

84-
transposed_filter_ = BufferUniquePtr(packed_filter_data, BufferDeleter(std::move(alloc)));
85-
8687
for (int64_t group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) {
8788
MlasTranspose(tensor.Data<float>() + (group_id * N * K),
88-
((float*)packed_filter_data) + (group_id * packed_elements_per_group),
89+
static_cast<float*>(packed_filter_data) + (group_id * packed_elements_per_group),
8990
K, N, nullptr);
9091
}
9192

onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h

Lines changed: 123 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818

1919
#pragma once
2020

21+
#include <algorithm>
22+
2123
#include "core/providers/cpu/nn/conv_attributes.h"
24+
#include "core/common/safeint.h"
2225

2326
namespace onnxruntime {
2427

@@ -61,6 +64,21 @@ struct ConvTransposeAttributes : public ConvAttributes {
6164
const Tensor* B = has_bias ? (dynamic_padding ? context->Input<Tensor>(3) : context->Input<Tensor>(2)) : nullptr;
6265

6366
const int rank = static_cast<int>(X->Shape().NumDimensions());
67+
68+
// ConvTranspose requires X shape (N x C x D1...Dn) and W shape (C x M/group x k1...kn),
69+
// both must have at least 3 dimensions.
70+
if (rank < 3) {
71+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
72+
"Input X must have at least 3 dimensions (N x C x D1...Dn).",
73+
" X: ", X->Shape().ToString().c_str());
74+
}
75+
76+
if (static_cast<int>(F_Shape.NumDimensions()) < 3) {
77+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
78+
"Filter W must have at least 3 dimensions (C x M/group x k1...kn).",
79+
" W: ", F_Shape.ToString().c_str());
80+
}
81+
6482
TensorShape input_shape = X->Shape().Slice(is_nhwc ? 1 : 2, is_nhwc ? rank - 1 : rank);
6583
const int64_t num_input_channels = is_nhwc ? X->Shape()[rank - 1] : X->Shape()[1];
6684
const int64_t N = X->Shape()[0];
@@ -119,11 +137,32 @@ struct ConvTransposeAttributes : public ConvAttributes {
119137
if (local_output_padding.empty()) {
120138
local_output_padding.resize(kernel_shape.size(), 0);
121139
}
140+
if (local_output_padding.size() != kernel_shape.size()) {
141+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
142+
"output_padding size (", local_output_padding.size(),
143+
") does not match the number of spatial dimensions (", kernel_shape.size(), ").");
144+
}
122145
ConvPadVector local_pads;
123146
local_pads.reserve(2 * (input_shape.NumDimensions()));
124147
if (dynamic_padding) {
125-
for (int64_t i = 0; i < Pads->Shape().SizeFromDimension(0); ++i) {
126-
local_pads.push_back(Pads->Data<int64_t>()[i]);
148+
if (Pads->Shape().NumDimensions() != 1) {
149+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
150+
"Dynamic pads tensor must be 1-D. Got rank: ", Pads->Shape().NumDimensions());
151+
}
152+
const int64_t expected_pads_size = SafeInt<int64_t>(kernel_shape.size()) * 2;
153+
const int64_t actual_pads_size = Pads->Shape().SizeFromDimension(0);
154+
if (actual_pads_size != expected_pads_size) {
155+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
156+
"Dynamic pads tensor size (", actual_pads_size,
157+
") does not match expected size (2 * spatial_dims = ", expected_pads_size, ").");
158+
}
159+
const auto* pads_data = Pads->Data<int64_t>();
160+
for (int64_t i = 0; i < actual_pads_size; ++i) {
161+
if (pads_data[i] < 0) {
162+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
163+
"Dynamic pads must be non-negative. Got pads[", i, "] = ", pads_data[i]);
164+
}
165+
local_pads.push_back(pads_data[i]);
127166
}
128167
} else {
129168
local_pads.assign(pads.begin(), pads.end());
@@ -140,10 +179,34 @@ struct ConvTransposeAttributes : public ConvAttributes {
140179
local_strides.resize(kernel_shape.size(), 1);
141180
}
142181

182+
if (local_strides.size() != kernel_shape.size()) {
183+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
184+
"strides size (", local_strides.size(),
185+
") does not match the number of spatial dimensions (", kernel_shape.size(), ").");
186+
}
187+
if (local_dilations.size() != kernel_shape.size()) {
188+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
189+
"dilations size (", local_dilations.size(),
190+
") does not match the number of spatial dimensions (", kernel_shape.size(), ").");
191+
}
192+
193+
// ONNX spec: "output_padding[i] should be less than max(stride[i], dilation[i])".
194+
// This constraint ensures the output_padding is unambiguous — larger values would shift
195+
// the output by more than one stride/dilation step, making the inverse of Conv ill-defined.
196+
for (size_t i = 0; i < local_output_padding.size(); ++i) {
197+
int64_t limit = std::max(local_strides[i], local_dilations[i]);
198+
if (local_output_padding[i] >= limit) {
199+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
200+
"output_padding[", i, "] (", local_output_padding[i],
201+
") must be less than max(stride, dilation) (", limit,
202+
") for spatial dimension ", i, ".");
203+
}
204+
}
205+
143206
TensorShapeVector Y_dims;
144207

145-
ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape,
146-
local_strides, local_dilations, local_output_padding, N, &local_pads, &Y_dims, is_nhwc);
208+
ORT_RETURN_IF_ERROR(ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape,
209+
local_strides, local_dilations, local_output_padding, N, &local_pads, &Y_dims, is_nhwc));
147210
TensorShape Yshape(Y_dims);
148211
Tensor* Y = context->Output(0, Yshape);
149212

@@ -162,50 +225,68 @@ struct ConvTransposeAttributes : public ConvAttributes {
162225
return Status::OK();
163226
}
164227

165-
void ComputePadsAndOutputShape(TensorShape input_shape, int64_t output_channel,
166-
const TensorShapeVector& kernel_shape, const TensorShapeVector& p_strides,
167-
const TensorShapeVector& p_dilations, const TensorShapeVector& p_output_padding, const int64_t N,
168-
ConvPadVector* p_pads, TensorShapeVector* output_shape_p,
169-
bool is_nhwc = false) const {
228+
Status ComputePadsAndOutputShape(TensorShape input_shape, int64_t output_channel,
229+
const TensorShapeVector& kernel_shape, const TensorShapeVector& p_strides,
230+
const TensorShapeVector& p_dilations, const TensorShapeVector& p_output_padding, const int64_t N,
231+
ConvPadVector* p_pads, TensorShapeVector* output_shape_p,
232+
bool is_nhwc = false) const {
170233
size_t output_shape_size = output_shape.size();
234+
size_t rank = input_shape.NumDimensions();
235+
236+
if (p_pads->size() != 2 * rank) {
237+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
238+
"pads size (", p_pads->size(), ") does not match expected size (2 * ", rank, ").");
239+
}
240+
241+
// output_shape attribute, if specified, must have either 'rank' or 'rank + 2' elements
242+
if (output_shape_size != 0 && output_shape_size != rank && output_shape_size != rank + 2) {
243+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
244+
"output_shape attribute has ", output_shape_size,
245+
" elements, expected ", rank, " or ", rank + 2, ".");
246+
}
247+
171248
if (is_nhwc) {
172249
output_shape_p->insert(output_shape_p->begin(), {N});
173250
} else {
174251
output_shape_p->insert(output_shape_p->begin(), {N, output_channel});
175252
}
176253

177-
size_t rank = input_shape.NumDimensions();
178254
for (size_t dim = 0; dim < rank; ++dim) {
179255
int64_t dim_size = -1;
180256

181257
if (output_shape_size != 0) {
182258
dim_size = output_shape_size == rank ? output_shape[dim] : output_shape[dim + 2];
183259
}
184260

185-
ComputeTransposePadAndOutputShape(
261+
ORT_RETURN_IF_ERROR(ComputeTransposePadAndOutputShape(
186262
input_shape[dim],
187263
p_strides[dim],
188264
kernel_shape[dim],
189265
p_dilations[dim],
190266
p_output_padding[dim],
191267
auto_pad,
192-
&p_pads->at(dim),
193-
&p_pads->at(input_shape.NumDimensions() + dim),
194-
&dim_size);
268+
&(*p_pads)[dim],
269+
&(*p_pads)[input_shape.NumDimensions() + dim],
270+
&dim_size));
195271

196-
ORT_ENFORCE(dim_size > 0, "Invalid input shape: ", input_shape.ToString());
272+
if (dim_size <= 0) {
273+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
274+
"Computed output dimension is <= 0 for dim ", dim,
275+
". Input shape: ", input_shape.ToString());
276+
}
197277
output_shape_p->push_back(dim_size);
198278
}
199279
if (is_nhwc) {
200280
output_shape_p->push_back(output_channel);
201281
}
282+
return Status::OK();
202283
}
203284

204285
TensorShapeVector output_padding;
205286
TensorShapeVector output_shape;
206287

207288
private:
208-
void ComputeTransposePadAndOutputShape(
289+
Status ComputeTransposePadAndOutputShape(
209290
const int64_t in_size,
210291
const int64_t stride,
211292
const int64_t kernel,
@@ -217,27 +298,48 @@ struct ConvTransposeAttributes : public ConvAttributes {
217298
int64_t* out_size) const {
218299
// Output shape is explicitly provided - pad values will have to be computed
219300
if (*out_size != -1) {
220-
ORT_ENFORCE(*out_size >= 0);
301+
if (*out_size < 0) {
302+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
303+
"Explicit output size is negative: ", *out_size);
304+
}
221305
// total pad
222306
auto total_pad = ComputeTotalPad(in_size, stride, adj,
223307
kernel, dilation, *out_size);
224308
DistributePadding(pad_type, total_pad, *pad_head, *pad_tail);
225-
return;
309+
return Status::OK();
226310
}
227311

228312
// Output shape is not provided - it needs to be computed along with pad values (if applicable)
229313

314+
// Validate that stride, kernel, and dilation are positive
315+
if (stride <= 0) {
316+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Stride must be positive. Got: ", stride);
317+
}
318+
if (kernel <= 0) {
319+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Kernel size must be positive. Got: ", kernel);
320+
}
321+
if (dilation <= 0) {
322+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Dilation must be positive. Got: ", dilation);
323+
}
324+
if (adj < 0) {
325+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Output padding must be non-negative. Got: ", adj);
326+
}
327+
230328
// Compute padding if the auto_pad attribute is SAME_UPPER/SAME_LOWER
231329
if (pad_type == AutoPadType::SAME_UPPER || pad_type == AutoPadType::SAME_LOWER) {
232330
// The ONNX spec says if `auto_pad` attribute is set, pad until the `out_size`
233331
// is `in_size * stride`
332+
int64_t auto_out_size = SafeInt<int64_t>(in_size) * stride;
234333
auto total_pad = ComputeTotalPad(in_size, stride, adj,
235-
kernel, dilation, /*out_size = */ in_size * stride);
334+
kernel, dilation, auto_out_size);
236335
DistributePadding(pad_type, total_pad, *pad_head, *pad_tail);
237336
}
238337

239-
*out_size =
240-
(in_size - 1) * stride + adj + (kernel - 1) * dilation + 1 - *pad_head - *pad_tail;
338+
// *out_size = (in_size - 1) * stride + adj + (kernel - 1) * dilation + 1 - *pad_head - *pad_tail
339+
*out_size = SafeInt<int64_t>(in_size - 1) * stride + adj +
340+
SafeInt<int64_t>(kernel - 1) * dilation + 1 -
341+
*pad_head - *pad_tail;
342+
return Status::OK();
241343
}
242344
};
243345

0 commit comments

Comments
 (0)