Skip to content

Commit 9d7e6d5

Browse files
[CPU/CUDA ep] Improve DeformConv op performance (#27824)
### Description Improve DeformConv op performance ### Motivation and Context This PR consolidates a series of optimizations targeting the `DeformConv` (Deformable Convolution) operator across both CPU and CUDA execution providers. * **For CPU:** The previous implementation suffered from bottlenecks due to redundant computations, lack of vectorization in bilinear sampling, and sub-optimal thread pool utilization. This overhaul redesigns the memory layout and execution pipeline to maximize SIMD opportunities and harden memory safety. * **For GPU:** The batched GEMM operation previously relied on an intermediate buffer and a custom scatter kernel to format the output, which consumed extra memory and kernel launch overhead. This update introduces a zero-copy approach. --- #### 1. CPU Optimizations & Refactoring The CPU execution path has been heavily refactored to minimize branching in hot paths, maximize vectorization, and safely handle edge cases. | Feature / Optimization | Description | Key Benefit | | :--- | :--- | :--- | | **AoSoA Bilinear Sampling Plan** | Replaced on-the-fly interpolation with a precomputed sampling plan using an 8-lane Array-of-Structures-of-Arrays (AoSoA) layout (`kPlanAoSoALanes`). | Perfectly aligns with 256-bit AVX2 vectors, enabling highly efficient SIMD unrolling during the `im2col` gathering phase. | | **Kernel Metadata Caching** | Introduced `DeformConvKernelMetaCacheData` to cache static convolution geometry (e.g., `kH`, `kW`, `padding`, `dilation`). | Eliminates the O(kernel_size) overhead of reallocating and recomputing base offsets on every single `Compute()` step. | | **Fast Math & Branchless Logic** | Implemented a custom `DeformConvFastFloor` and utilized an inverted bounds check with bitwise operations to evaluate all four corners simultaneously. | Removes expensive `std::floor` calls and unpredictable branches from the operator's hottest path. | | **Enhanced Parallelization** | Flattened the bilinear sampling plan build tasks across spatial pixels. | Allows `concurrency::ThreadPool::TryParallelFor` to split fine-grained work effectively, drastically improving thread pool scaling. | | **Hardened Bounds Checking** | Introduced compute-time bounds checks using `CheckedMulSizeT` and `CheckedBatchSpan`. | Ensures batch indexing and stride calculations stay within the addressable `size_t` range, preventing integer overflow vulnerabilities. | | **Bias Addition Refactoring** | Refactored bias addition to avoid expensive `div`/`mod` operations, applying `ORT_CPU_RESTRICT` and force-inlining. | Maximizes memory throughput and instruction pipelining during the final bias addition phase. | --- #### 2. GPU (CUDA) Optimizations The CUDA implementation was optimized to reduce memory footprint and eliminate unnecessary kernel launches. * **Zero-Copy GEMM Output:** Removed the temporary `gemm_output_buffer` allocation entirely. By carefully configuring the `stride_c` parameter (`stride_c_y = M * output_image_size`), the `cublasGemmStridedBatchedHelper` now writes the computed output directly into the correct NCHW memory layout of the final `Y` tensor. * **Kernel Elimination:** Completely removed the `DeformConvCopyGemmOutputRowMajorToNCHW` custom kernel and its associated dispatch logic. This reduces kernel launch overhead, lowers GPU memory bandwidth pressure, and simplifies the overall CUDA execution pipeline. * **Reduced Memory Footprint:** Updated the `bytes_per_image` calculation for workspace memory to reflect the removal of the GEMM output buffer. This allows the operator to potentially process more images in parallel under the same memory constraints. --- #### 3. Changed - **Batch chunking:** Chunk size `k` is chosen so that the number of outer rounds is minimized under the temp-memory cap; **`k` does not have to divide `N`**. The host loop uses `cur_parallel = min(k, N - b)`, so the last chunk may be smaller. This is the intended default behavior for this EP (not yet in a formal release). - **Kernel-size templates:** Im2col is specialized for **1×1, 3×3, and 7×7**; other sizes (including **5×5**) use the **dynamic** `kH`/`kW` path. Rationale: 5×5 is less common in current stacks (often replaced by stacked 3×3); specializing 7×7 targets common large-kernel cases. Older DCN/detection models that still use **5×5** deformable conv will take the dynamic path—correctness is unchanged; only compile-time unrolling differs. - **Add aliasing flags:** Updated DeformConv aliasing comments to make the stronger guarantee explicit: if output `Y` overlaps any input buffer, results can be incorrect regardless of `restrict`, because output writes may clobber source elements before they are fully consumed. `restrict` further tightens this by introducing undefined behavior when aliasing assumptions are violated. --- ### Summary In the current implementation, CPU performance is 33x (main branch is 15x) that of TorchVision. If we were to implement AVX2/AVX512 optimizations from scratch, we could achieve a 36x performance boost. However, I haven’t found any similar reference code in the ONNX Runtime repository. This PR also significantly improves parallelism: <img width="540" height="332" alt="image" src="https://github.com/user-attachments/assets/d4f670bd-dde3-43f1-b597-4471bfde005b" /> _Both ort and tv are configured with 16 threads_ ### Open Question for Reviewers **Regarding CUDA Temporary Memory Allocation:** Currently, the effective maximum temporary memory for CUDA is calculated using a heuristic (`total_global_mem * 0.1` or similar logic in `GetDeformConvEffectiveMaxTempBytes`). While the removal of `gemm_output_buffer` has reduced the memory footprint per image, I am not entirely certain if this 10% threshold is still the most appropriate value for balancing parallel image processing (`n_parallel_imgs`) against overall VRAM consumption in large models. I would appreciate any feedback or suggestions on whether we should tune this threshold, or if there's a more robust way to dynamically determine the optimal temporary workspace size for `DeformConv` in ORT.
1 parent 9e3614b commit 9d7e6d5

6 files changed

Lines changed: 1379 additions & 490 deletions

File tree

onnxruntime/core/providers/cpu/nn/deform_conv.cc

Lines changed: 690 additions & 171 deletions
Large diffs are not rendered by default.

onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#pragma once
55

6-
#include <climits>
6+
#include <limits>
77

88
#include "core/common/common.h"
99
#include "core/framework/op_kernel.h"
@@ -73,6 +73,42 @@ struct DeformConvParams {
7373
bool use_mask{false}; // Whether optional mask input is provided
7474
};
7575

76+
// Common derived dimensions used by both CPU and CUDA kernels.
77+
struct DeformConvCommonDims {
78+
int64_t kernel_size{0}; // kH * kW
79+
int64_t output_image_size{0}; // out_h * out_w
80+
int64_t input_image_size{0}; // H * W_in
81+
int64_t kernel_dim{0}; // (C / group) * kernel_size
82+
};
83+
84+
// Validates shared runtime bounds and computes common derived dimensions.
85+
// This helper is backend-agnostic and intended to be reused by both CPU/CUDA
86+
// after DeformConvValidateAndParse() succeeds.
87+
inline Status DeformConvValidateAndComputeCommonDims(const DeformConvParams& params,
88+
DeformConvCommonDims& dims) {
89+
const int64_t int64_max = std::numeric_limits<int64_t>::max();
90+
ORT_RETURN_IF_NOT(params.N > 0 && params.C > 0 && params.M > 0 &&
91+
params.group > 0 && params.offset_group > 0 &&
92+
params.kH > 0 && params.kW > 0 &&
93+
params.H > 0 && params.W_in > 0 &&
94+
params.out_h > 0 && params.out_w > 0,
95+
"Invalid deform conv dimensions.");
96+
97+
ORT_RETURN_IF_NOT(params.kH <= int64_max / params.kW, "kernel_size overflows int64.");
98+
dims.kernel_size = params.kH * params.kW;
99+
100+
ORT_RETURN_IF_NOT(params.out_h <= int64_max / params.out_w, "output_image_size overflows int64.");
101+
dims.output_image_size = params.out_h * params.out_w;
102+
103+
ORT_RETURN_IF_NOT(params.H <= int64_max / params.W_in, "input_image_size overflows int64.");
104+
dims.input_image_size = params.H * params.W_in;
105+
106+
ORT_RETURN_IF_NOT((params.C / params.group) <= int64_max / dims.kernel_size, "kernel_dim overflows int64.");
107+
dims.kernel_dim = (params.C / params.group) * dims.kernel_size;
108+
109+
return Status::OK();
110+
}
111+
76112
// Validates inputs and parses attributes into params.
77113
// Returns Status::OK() on success; on failure, params may be partially filled.
78114
inline Status DeformConvValidateAndParse(
@@ -159,10 +195,10 @@ inline Status DeformConvValidateAndParse(
159195
params.out_w = (params.W_in + params.pad_w + params.pad_w_end - params.dilation_w * (params.kW - 1) - 1) / params.stride_w + 1;
160196
ORT_RETURN_IF_NOT(params.out_h >= 0 && params.out_w >= 0, "Computed output spatial size must be non-negative.");
161197

162-
// CPU BilinearInterpolate uses int for indices (for performance optimization); W <= INT_MAX / (H+1) covers all index math.
198+
// CPU BilinearInterpolate uses int for indices (for performance optimization); W <= int_max / (H+1) covers all index math.
163199
ORT_RETURN_IF_NOT(params.H >= 0 && params.W_in >= 0, "Input spatial dimensions H and W must be non-negative.");
164-
ORT_RETURN_IF_NOT(params.W_in <= static_cast<int64_t>(INT_MAX) / (params.H + 1),
165-
"Input (H+1)*W must not exceed INT_MAX (for performance optimization).");
200+
ORT_RETURN_IF_NOT(params.W_in <= static_cast<int64_t>(std::numeric_limits<int>::max()) / (params.H + 1),
201+
"Input (H+1)*W must not exceed int max (for performance optimization).");
166202

167203
// Validate tensor shapes (use division to avoid int64 overflow in offset_group * 2 * kH * kW).
168204
ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size.");

onnxruntime/core/providers/cuda/nn/deform_conv.cc

Lines changed: 81 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22
// Licensed under the MIT License.
33
//
44
// CUDA implementation of DeformConv (deformable convolution 2D).
5+
// High-level pipeline matches CPU `nn/deform_conv.cc`: im2col then grouped GEMM then optional bias;
6+
// this file hosts the EP and batch chunking; device kernels live in `deform_conv_impl.cu`.
7+
//
8+
// High-level pipeline (batch may be chunked for col_buffer memory; see GetNParallelImgs):
9+
// (1) Deformable im2col per chunk: DeformConvIm2ColImpl launches GPU kernels that fill col_buffer
10+
// (bilinear sampling + optional mask fused in threads; no separate sampling plan like CPU).
11+
// (2) Grouped strided batched GEMM: Y = W * Col via cuBLAS (row-major vs column-major mapping in ComputeInternal).
12+
// (3) Optional bias: add B[m] to each output channel map (DeformConvAddBiasImpl).
13+
//
14+
// Main difference vs CPU path: CPU builds an AoSoA bilinear plan once per image then reuses it across channels;
15+
// CUDA recomputes bilinear samples in the im2col kernel while walking offset/mask tensors.
516

617
#include "core/providers/shared_library/provider_api.h"
718
#include "deform_conv.h"
@@ -21,31 +32,30 @@ namespace {
2132

2233
constexpr int kMaxParallelImgs = 32;
2334

24-
// Returns the greatest divisor of n that is <= bound. Used to choose uniform batch chunk sizes.
25-
// Fast path: if n % bound == 0 (common for batch 32/64/128), return immediately.
26-
// When n >= bound^2, linear scan from bound down is O(bound). Otherwise divisor enumeration
27-
// from 1 to sqrt(n) is O(sqrt(n)). Uses integer comparison (no sqrt) for branch decision.
28-
int GetGreatestDivisorBelowBound(int n, int bound) {
29-
if (bound <= 0 || n <= 0) return 1;
30-
if (n % bound == 0) return bound; // Fast path: batch is multiple of target
31-
32-
// n >= bound^2 <=> bound <= sqrt(n) => linear scan is cheaper
33-
if (static_cast<int64_t>(n) >= static_cast<int64_t>(bound) * bound) {
34-
for (int k = bound - 1; k > 1; --k) {
35-
if (n % k == 0) return k;
36-
}
37-
} else {
38-
// n < bound^2 <=> bound > sqrt(n) => divisor enumeration is cheaper
39-
int best = 1;
40-
for (int i = 1; static_cast<int64_t>(i) * i <= static_cast<int64_t>(n); ++i) {
41-
if (n % i != 0) continue;
42-
const int q = n / i;
43-
if (q <= bound && q > best) best = q;
44-
if (i <= bound && i > best) best = i;
45-
}
46-
return best;
47-
}
48-
return 1;
35+
// ceil(numer / denom) for numer >= 0, denom > 0 (integer, no floating point).
36+
// Avoid (numer + denom - 1) / denom: numer near INT_MAX overflows signed int (UB in C++).
37+
inline int CeilDiv(int numer, int denom) {
38+
return numer / denom + (numer % denom != 0 ? 1 : 0);
39+
}
40+
41+
// Chooses DeformConv batch chunk size k (images per outer-loop iteration) given batch N and
42+
// a hard cap T from temp-memory budget (target_parallel_imgs).
43+
//
44+
// Goals (in order):
45+
// 1) Minimize the number of outer rounds I = ceil(N / k). Under k <= T, the minimum achievable
46+
// I is I* = ceil(N / min(N, T)) — take the largest allowed step min(N, T), same as always
47+
// using k = T when N > T, or one round when N <= T.
48+
// 2) Among all k with ceil(N/k) == I*, pick k = ceil(N / I*) so chunk sizes are as balanced as
49+
// possible (last chunk is only slightly smaller than full chunks). k need not divide N; choosing
50+
// k = ceil(N / I*) instead of always k = T often shrinks col_buffer stride when a full-T last
51+
// chunk would leave a much smaller tail.
52+
//
53+
// Closed form: k_cap = min(N, T), I = ceil(N / k_cap), return ceil(N / I).
54+
inline int GetDeformConvParallelChunkSize(int N, int T) {
55+
if (N <= 0 || T <= 0) return 1;
56+
const int k_cap = std::min(N, T);
57+
const int num_rounds = CeilDiv(N, k_cap);
58+
return CeilDiv(N, num_rounds);
4959
}
5060

5161
// Returns the maximum temp memory (bytes) allowed for DeformConv's im2col + GEMM buffers.
@@ -76,28 +86,25 @@ size_t GetDeformConvEffectiveMaxTempBytes(size_t total_global_mem) {
7686
}
7787

7888
// Returns how many images to process in parallel per batch chunk for DeformConv.
79-
// Chooses the largest divisor of batch size N that fits in the temp budget and does not
80-
// exceed kMaxParallelImgs, so that batch dimension is split evenly (no remainder).
81-
// Note: if N is prime and N > target_parallel_imgs, the greatest divisor <= target_parallel_imgs is 1,
82-
// so batching is effectively disabled (single-image chunks).
89+
//
90+
// Temp budget → cap T (see below). Chunk size k = GetDeformConvParallelChunkSize(N, T): minimize
91+
// outer-loop rounds first, then balance chunk sizes via ceil(N / ceil(N / min(N,T))).
92+
// The host loop still uses cur_parallel = min(k, N - b), so k need not divide N.
8393
//
8494
// Formulas:
85-
// kernel_size = kH * kW
86-
// output_image_size = out_h * out_w
87-
// bytes_per_image = output_image_size * (C * kernel_size + M / group) * sizeof(T)
88-
// (temp bytes per image: im2col col buffer + GEMM output buffer per output position)
95+
// kernel_size / output_image_size come from validated common dims
96+
// bytes_per_image = output_image_size * C * kernel_size * sizeof(T)
97+
// (temp bytes per image: im2col col buffer only; GEMM writes directly to Y)
8998
// max_parallel_imgs_mem = max(1, floor(effective_max_temp / bytes_per_image))
90-
// target_parallel_imgs = min(kMaxParallelImgs, max_parallel_imgs_mem)
91-
// return GetGreatestDivisorBelowBound(N, target_parallel_imgs)
99+
// target_parallel_imgs T = min(kMaxParallelImgs, max_parallel_imgs_mem)
100+
// return GetDeformConvParallelChunkSize(N, T)
92101
template <typename T>
93-
int GetNParallelImgs(const DeformConvParams& params, size_t total_global_mem) {
102+
int GetNParallelImgs(const DeformConvParams& params, int64_t kernel_size, int64_t output_image_size, size_t total_global_mem) {
94103
const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes(total_global_mem);
95-
const int64_t kernel_size = params.kH * params.kW;
96-
const int64_t output_image_size = params.out_h * params.out_w;
97-
const size_t bytes_per_image = SafeInt<size_t>(output_image_size) * (params.C * kernel_size + params.M / params.group) * sizeof(T);
104+
const size_t bytes_per_image = SafeInt<size_t>(output_image_size) * params.C * kernel_size * sizeof(T);
98105
const int max_parallel_imgs_mem = std::max(1, static_cast<int>(effective_max_temp / std::max(size_t(1), bytes_per_image)));
99106
const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem);
100-
return GetGreatestDivisorBelowBound(static_cast<int>(params.N), target_parallel_imgs);
107+
return GetDeformConvParallelChunkSize(narrow<int>(params.N), target_parallel_imgs);
101108
}
102109

103110
} // namespace
@@ -146,21 +153,20 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
146153
return Status::OK();
147154
}
148155

149-
const int n_parallel_imgs = GetNParallelImgs<T>(params, GetDeviceProp().totalGlobalMem);
150-
151-
const int64_t kernel_size = kH * kW;
152-
const int64_t output_image_size = out_h * out_w;
153-
const int64_t input_image_size = H * W_in;
154-
const int64_t kernel_dim = (C / group) * kernel_size;
156+
DeformConvCommonDims common_dims;
157+
ORT_RETURN_IF_ERROR(DeformConvValidateAndComputeCommonDims(params, common_dims));
158+
const int64_t kernel_size = common_dims.kernel_size;
159+
const int64_t output_image_size = common_dims.output_image_size;
160+
const int64_t input_image_size = common_dims.input_image_size;
161+
const int64_t kernel_dim = common_dims.kernel_dim;
162+
const int n_parallel_imgs = GetNParallelImgs<T>(params, kernel_size, output_image_size, GetDeviceProp().totalGlobalMem);
155163

156164
const int64_t col_stride = static_cast<int64_t>(n_parallel_imgs) * output_image_size;
157165
const int64_t col_buffer_size = (C * kernel_size) * col_stride;
158166

159167
AllocatorPtr alloc;
160168
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
161169
auto col_buffer = IAllocator::MakeUniquePtr<T>(alloc, SafeInt<size_t>(col_buffer_size));
162-
// Removed col_transposed allocation as we avoid physical transpose.
163-
auto gemm_output_buffer = IAllocator::MakeUniquePtr<T>(alloc, SafeInt<size_t>((M / group) * col_stride));
164170

165171
const T* Xdata = X->Data<T>();
166172
const T* Wdata = W->Data<T>();
@@ -180,6 +186,7 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
180186
const int64_t cur_out_size = static_cast<int64_t>(cur_parallel) * output_image_size;
181187

182188
const T* X_block = Xdata + b * (C * input_image_size);
189+
// Stride per full image along N: offset [N, offset_group*2*kH*kW, OH, OW] -> offset_group * 2*kH*kW * OH*OW floats.
183190
const T* offset_block = offset_data + b * (offset_group * 2 * kernel_size * output_image_size);
184191
const T* mask_block = use_mask ? (mask_data + b * (offset_group * kernel_size * output_image_size)) : nullptr;
185192

@@ -215,16 +222,18 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
215222
// - W (row [M/group, kernel_dim]) -> cuBLAS interprets as col-major [kernel_dim, M/group] = W^T
216223
// - C = A*B = Col^T * W^T = (W*Col)^T = Y^T; C is col-major [cur_out_size, M/group] = Y in row-major
217224
//
218-
// m=cur_out_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim, ldc=cur_out_size.
225+
// Per batch image: m=output_image_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim,
226+
// ldc=output_image_size (row-major Y slice [M/group, OH*OW]).
219227
//
220-
// cur_parallel==1: cur_out_size==output_image_size, C layout (pos, channel) matches NCHW Y_g[0,ch,pos] -> write
221-
// directly into Y_g. Use strided batched for all groups in one call.
222-
// cur_parallel>1: layouts differ -> write to gemm_output_buffer, then DeformConvCopyGemmOutputRowMajorToNCHW.
223-
224-
const bool gemm_writes_directly = (cur_parallel == 1);
225-
if (gemm_writes_directly) {
226-
// Strided batched: one call for all groups. Strides between batches:
227-
const int64_t stride_col = kernel_dim * col_stride; // = kernel_dim * output_image_size when cur_parallel==1
228+
// cur_parallel==1: one strided-batched GEMM over all groups (single launch).
229+
// cur_parallel>1: per group, strided-batched GEMM with batch_count=cur_parallel; each batch writes one image
230+
// directly into NCHW Y (strideC = M * output_image_size), avoiding a temp buffer + scatter kernel.
231+
232+
if (cur_parallel == 1) {
233+
// col_buffer is packed per iteration with the current chunk width (cur_out_size).
234+
// Using outer-scope col_stride (based on n_parallel_imgs) breaks tail chunks where
235+
// cur_out_size != col_stride (including one-image tails) when group > 1.
236+
const int64_t stride_col = kernel_dim * cur_out_size;
228237
const int64_t stride_weight = (M / group) * kernel_dim;
229238
const int64_t stride_y = (M / group) * output_image_size;
230239
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
@@ -249,44 +258,42 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
249258
device_prop,
250259
UseTF32()));
251260
} else {
252-
// cur_parallel>1: GEMM output layout differs from NCHW; write to buffer then copy per group.
261+
const int64_t stride_a_col = output_image_size;
262+
const int64_t stride_b = 0;
263+
const int64_t stride_c_y = M * output_image_size;
253264
for (int64_t g = 0; g < group; ++g) {
254265
const T* W_g = Wdata + g * (M / group) * kernel_dim;
255-
const T* col_g = col_buffer.get() + g * kernel_dim * col_stride;
266+
const T* col_g = col_buffer.get() + g * kernel_dim * cur_out_size;
256267
T* Y_g = Ydata + b * M * output_image_size + g * (M / group) * output_image_size;
257268

258-
CUBLAS_RETURN_IF_ERROR((cublasGemmHelper(
269+
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
259270
cublas,
260271
CUBLAS_OP_N,
261272
CUBLAS_OP_N,
262-
narrow<int>(cur_out_size),
273+
narrow<int>(output_image_size),
263274
narrow<int>(M / group),
264275
narrow<int>(kernel_dim),
265276
&alpha,
266277
reinterpret_cast<const CudaT*>(col_g),
267278
narrow<int>(cur_out_size),
279+
stride_a_col,
268280
reinterpret_cast<const CudaT*>(W_g),
269281
narrow<int>(kernel_dim),
282+
stride_b,
270283
&beta,
271-
reinterpret_cast<CudaT*>(gemm_output_buffer.get()),
272-
narrow<int>(cur_out_size),
284+
reinterpret_cast<CudaT*>(Y_g),
285+
narrow<int>(output_image_size),
286+
stride_c_y,
287+
narrow<int>(cur_parallel),
273288
device_prop,
274-
UseTF32())));
275-
276-
ORT_RETURN_IF_ERROR(DeformConvCopyGemmOutputRowMajorToNCHW<T>(
277-
stream,
278-
gemm_output_buffer.get(),
279-
Y_g,
280-
M,
281-
M / group,
282-
output_image_size,
283-
cur_parallel));
289+
UseTF32()));
284290
}
285291
}
286292
}
287293

288294
if (Bdata != nullptr) {
289-
ORT_RETURN_IF_ERROR(DeformConvAddBiasImpl<T>(stream, Ydata, Bdata, N, M, out_h, out_w));
295+
ORT_RETURN_IF_ERROR(DeformConvAddBiasImpl<T>(stream, Ydata, Bdata, N, M, out_h, out_w,
296+
static_cast<int64_t>(device_prop.maxGridSize[1])));
290297
}
291298

292299
return Status::OK();

0 commit comments

Comments
 (0)