Skip to content

Commit 8e1f9d0

Browse files
authored
CUDA: handle OW > 65535 in im2col (2D and 3D) (ggml-org#22944)
`im2col_cuda` and `im2col_3d_cuda` both dispatch with `block_nums.y = OW`, but CUDA caps the grid's Y dimension at 65535. Conv1d encoders on raw 16 kHz audio with T > 65535 (~4 s) trip the limit -- e.g. SEANet at 11 s lands at OW = 176000 -- and the launch returns `invalid configuration argument`. The fix clamps `block_nums.y` to `MIN(OW, MAX_GRIDDIM_Y)` and loops inside the kernel with stride `MAX_GRIDDIM_Y`; this is the same in-kernel stride pattern already used for the z axis (`MAX_GRIDDIM_Z`). Both the 2D `im2col_kernel` and the 3D `im2col_3d_kernel` need the same fix. Results are bit-identical for OW <= 65535 (the new outer loop runs a single iteration). Tested on T4 / Jetson Orin with a SEANet encoder running on 11 s / 16 kHz audio (im2col reaching OW ~ 176000): the pre-fix launch returns `invalid configuration argument`, while the post-fix launch runs to completion. Existing test-backend-ops im2col cases are unchanged.
1 parent e936660 commit 8e1f9d0

1 file changed

Lines changed: 32 additions & 29 deletions

File tree

ggml/src/ggml-cuda/im2col.cu

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "im2col.cuh"
22

3+
#define MAX_GRIDDIM_Y 65535
34
#define MAX_GRIDDIM_Z 65535
45

56
template <typename T>
@@ -18,22 +19,23 @@ static __global__ void im2col_kernel(
1819
const int64_t ikh = rem / KW;
1920
const int64_t ikw = rem - ikh * KW;
2021

21-
const int64_t iow = blockIdx.y;
22-
for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
23-
const int64_t in = iz / OH;
24-
const int64_t ioh = iz - in * OH;
22+
for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) {
23+
for (int64_t iz = blockIdx.z; iz < N_OH; iz += MAX_GRIDDIM_Z) {
24+
const int64_t in = iz / OH;
25+
const int64_t ioh = iz - in * OH;
2526

26-
const int64_t iiw = iow * s0 + ikw * d0 - p0;
27-
const int64_t iih = ioh * s1 + ikh * d1 - p1;
27+
const int64_t iiw = iow * s0 + ikw * d0 - p0;
28+
const int64_t iih = ioh * s1 + ikh * d1 - p1;
2829

29-
const int64_t offset_dst =
30-
((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
30+
const int64_t offset_dst =
31+
((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
3132

32-
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
33-
dst[offset_dst] = 0.0f;
34-
} else {
35-
const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
36-
dst[offset_dst] = x[offset_src + iih * IW + iiw];
33+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
34+
dst[offset_dst] = 0.0f;
35+
} else {
36+
const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
37+
dst[offset_dst] = x[offset_src + iih * IW + iiw];
38+
}
3739
}
3840
}
3941

@@ -51,7 +53,7 @@ static void im2col_cuda(const float * x, T* dst,
5153
const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
5254
const int64_t N_OH = N * OH;
5355
const int64_t KH_KW = KW*KH;
54-
dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
56+
dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OH, MAX_GRIDDIM_Z));
5557
im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
5658
IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
5759
s0, s1, p0, p1, d0, d1);
@@ -136,23 +138,24 @@ static __global__ void im2col_3d_kernel(
136138
const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
137139
const int64_t ikw = i % KW;
138140

139-
const int64_t iow = blockIdx.y;
140-
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
141-
const int64_t in = iz / OD_OH;
142-
const int64_t iod = (iz - in*OD_OH) / OH;
143-
const int64_t ioh = iz % OH;
141+
for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) {
142+
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) {
143+
const int64_t in = iz / OD_OH;
144+
const int64_t iod = (iz - in*OD_OH) / OH;
145+
const int64_t ioh = iz % OH;
144146

145-
const int64_t iiw = iow * s0 + ikw * d0 - p0;
146-
const int64_t iih = ioh * s1 + ikh * d1 - p1;
147-
const int64_t iid = iod * s2 + ikd * d2 - p2;
147+
const int64_t iiw = iow * s0 + ikw * d0 - p0;
148+
const int64_t iih = ioh * s1 + ikh * d1 - p1;
149+
const int64_t iid = iod * s2 + ikd * d2 - p2;
148150

149-
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
151+
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
150152

151-
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
152-
dst[offset_dst] = 0.0f;
153-
} else {
154-
const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
155-
dst[offset_dst] = src[offset_src];
153+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
154+
dst[offset_dst] = 0.0f;
155+
} else {
156+
const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
157+
dst[offset_dst] = src[offset_src];
158+
}
156159
}
157160
}
158161
}
@@ -178,7 +181,7 @@ static void im2col_3d_cuda(const float * src, T* dst,
178181
const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
179182
const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
180183
const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
181-
dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
184+
dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OD_OH, MAX_GRIDDIM_Z));
182185
im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
183186
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
184187
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,

0 commit comments

Comments
 (0)