11#include " im2col.cuh"
22
3+ #define MAX_GRIDDIM_Y 65535
34#define MAX_GRIDDIM_Z 65535
45
56template <typename T>
@@ -18,22 +19,23 @@ static __global__ void im2col_kernel(
1819 const int64_t ikh = rem / KW;
1920 const int64_t ikw = rem - ikh * KW;
2021
21- const int64_t iow = blockIdx .y ;
22- for (int64_t iz = blockIdx .z ; iz < N_OH; iz+= MAX_GRIDDIM_Z) {
23- const int64_t in = iz / OH;
24- const int64_t ioh = iz - in * OH;
22+ for ( int64_t iow = blockIdx .y ; iow < OW; iow += MAX_GRIDDIM_Y) {
23+ for (int64_t iz = blockIdx .z ; iz < N_OH; iz += MAX_GRIDDIM_Z) {
24+ const int64_t in = iz / OH;
25+ const int64_t ioh = iz - in * OH;
2526
26- const int64_t iiw = iow * s0 + ikw * d0 - p0;
27- const int64_t iih = ioh * s1 + ikh * d1 - p1;
27+ const int64_t iiw = iow * s0 + ikw * d0 - p0;
28+ const int64_t iih = ioh * s1 + ikh * d1 - p1;
2829
29- const int64_t offset_dst =
30- ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
30+ const int64_t offset_dst =
31+ ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
3132
32- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
33- dst[offset_dst] = 0 .0f ;
34- } else {
35- const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
36- dst[offset_dst] = x[offset_src + iih * IW + iiw];
33+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
34+ dst[offset_dst] = 0 .0f ;
35+ } else {
36+ const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
37+ dst[offset_dst] = x[offset_src + iih * IW + iiw];
38+ }
3739 }
3840 }
3941
@@ -51,7 +53,7 @@ static void im2col_cuda(const float * x, T* dst,
5153 const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1 ) / CUDA_IM2COL_BLOCK_SIZE;
5254 const int64_t N_OH = N * OH;
5355 const int64_t KH_KW = KW*KH;
54- dim3 block_nums (num_blocks, OW , MIN (N_OH, MAX_GRIDDIM_Z));
56+ dim3 block_nums (num_blocks, MIN (OW, MAX_GRIDDIM_Y) , MIN (N_OH, MAX_GRIDDIM_Z));
5557 im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0 , stream>>> (x, dst, IC, IW, IH, OH, OW, KW, KH,
5658 IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
5759 s0, s1, p0, p1, d0, d1);
@@ -136,23 +138,24 @@ static __global__ void im2col_3d_kernel(
136138 const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
137139 const int64_t ikw = i % KW;
138140
139- const int64_t iow = blockIdx .y ;
140- for (int64_t iz = blockIdx .z ; iz < N_OD_OH; iz+= MAX_GRIDDIM_Z) {
141- const int64_t in = iz / OD_OH;
142- const int64_t iod = (iz - in*OD_OH) / OH;
143- const int64_t ioh = iz % OH;
141+ for ( int64_t iow = blockIdx .y ; iow < OW; iow += MAX_GRIDDIM_Y) {
142+ for (int64_t iz = blockIdx .z ; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) {
143+ const int64_t in = iz / OD_OH;
144+ const int64_t iod = (iz - in*OD_OH) / OH;
145+ const int64_t ioh = iz % OH;
144146
145- const int64_t iiw = iow * s0 + ikw * d0 - p0;
146- const int64_t iih = ioh * s1 + ikh * d1 - p1;
147- const int64_t iid = iod * s2 + ikd * d2 - p2;
147+ const int64_t iiw = iow * s0 + ikw * d0 - p0;
148+ const int64_t iih = ioh * s1 + ikh * d1 - p1;
149+ const int64_t iid = iod * s2 + ikd * d2 - p2;
148150
149- const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
151+ const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
150152
151- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
152- dst[offset_dst] = 0 .0f ;
153- } else {
154- const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
155- dst[offset_dst] = src[offset_src];
153+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
154+ dst[offset_dst] = 0 .0f ;
155+ } else {
156+ const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
157+ dst[offset_dst] = src[offset_src];
158+ }
156159 }
157160 }
158161}
@@ -178,7 +181,7 @@ static void im2col_3d_cuda(const float * src, T* dst,
178181 const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
179182 const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
180183 const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1 ) / CUDA_IM2COL_BLOCK_SIZE;
181- dim3 block_nums (num_blocks, OW , MIN (N_OD_OH, MAX_GRIDDIM_Z));
184+ dim3 block_nums (num_blocks, MIN (OW, MAX_GRIDDIM_Y) , MIN (N_OD_OH, MAX_GRIDDIM_Z));
182185 im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0 , stream>>> (src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
183186 OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
184187 IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
0 commit comments