@@ -13,7 +13,7 @@ layout (push_constant) uniform parameter
1313 uint IW; uint IH;
1414 uint OW; uint OH;
1515 uint KW; uint KH;
16- uint pelements ;
16+ uint OH_batch ;
1717 uint CHW;
1818 int s0; int s1;
1919 int p0; int p1;
@@ -34,82 +34,60 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
3434layout (buffer_reference) buffer D_ptr {D_TYPE d;};
3535#endif
3636
37- void im2col(const uint y, const uint z) {
38- const uint gidx = gl_GlobalInvocationID.x;
37+ void im2col(const uint ow, const uint z_idx) {
38+ const uint oh = z_idx % p.OH;
39+ const uint batch_idx = z_idx / p.OH;
3940
40- const uint oh = y ;
41- const uint batch = z / p.IC ;
42- const uint ic = z % p.IC ;
41+ const uint gidx = gl_LocalInvocationID.x ;
42+ const uint src_batch = batch_idx * p.batch_offset ;
43+ const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW ;
4344
44- const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
45- const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
46- const int oh_s1 = int(oh) * p.s1;
47- const uint ksize = p.OW * p.KH;
45+ const uint KHKW = p.KH * p.KW;
4846
49- const uint base_linear_idx = gidx * NUM_ITER;
47+ uint wg_x = gl_WorkGroupID.x;
48+ do {
49+ const uint wg_offset = wg_x * 512;
5050
51- uint current_kx = base_linear_idx / ksize;
52- const uint rem = base_linear_idx - (current_kx * ksize);
53- uint current_ky = rem / p.OW;
54- uint current_ix = rem % p.OW;
51+ [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) {
52+ const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE;
5553
56- A_TYPE values[NUM_ITER];
57- BDA_OFFSET_T offset_dst[NUM_ITER];
58- [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
59- values[idx] = A_TYPE(0);
60- }
61-
62- [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
63-
64- const uint linear_idx = base_linear_idx + idx;
65-
66- if (linear_idx >= p.pelements) {
67- continue;
68- }
69-
70- const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
71- const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
72-
73- offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
74-
75- if ((iih < p.IH) && (iiw < p.IW)) {
76- values[idx] = data_a[src_base + iih * p.IW + iiw];
77- }
78-
79- if (++current_ix == p.OW) {
80- current_ix = 0;
81- if (++current_ky == p.KH) {
82- current_ky = 0;
83- current_kx++;
54+ if (chw_idx >= p.CHW) {
55+ return;
8456 }
85- }
86- }
8757
88- [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
58+ const uint ic = chw_idx / KHKW;
59+ const uint rem = chw_idx - ic * KHKW;
60+ const uint ky = rem / p.KW;
61+ const uint kx = rem - ky * p.KW;
8962
90- const uint linear_idx = base_linear_idx + idx;
63+ const uint iiw = ow * p.s0 + kx * p.d0 - p.p0;
64+ const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
9165
92- if (linear_idx >= p.pelements) {
93- continue;
94- }
66+ A_TYPE val = A_TYPE(0);
67+ if (iih < p.IH && iiw < p.IW) {
68+ val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw];
69+ }
9570
9671#if BDA
97- D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx] );
98- dst_addr .d = D_TYPE(values[idx] );
72+ D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx) );
73+ out_ptr .d = D_TYPE(val );
9974#else
100- data_d[offset_dst[idx]] = D_TYPE(values[idx] );
75+ data_d[dst_row + chw_idx] = D_TYPE(val );
10176#endif
102- }
77+ }
78+
79+ wg_x += gl_NumWorkGroups.x;
80+ } while (wg_x * 512 < p.CHW);
10381}
10482
10583void main() {
106- uint y = gl_GlobalInvocationID.y;
107- while (y < p.OH ) {
84+ uint ow = gl_GlobalInvocationID.y;
85+ while (ow < p.OW ) {
10886 uint z = gl_GlobalInvocationID.z;
109- while (z < p.batch_IC ) {
110- im2col(y , z);
87+ while (z < p.OH_batch ) {
88+ im2col(ow , z);
11189 z += gl_NumWorkGroups.z;
11290 }
113- y += gl_NumWorkGroups.y;
91+ ow += gl_NumWorkGroups.y;
11492 }
11593}
0 commit comments