Skip to content

Commit bf8960d

Browse files
authored
Merge pull request #233 from ShieLian/patch-3
Update SubPixelConv2d using tf.depth_to_sapce
2 parents 0b10326 + 90459bf commit bf8960d

1 file changed

Lines changed: 6 additions & 22 deletions

File tree

tensorlayer/layers.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,11 +2385,14 @@ def SubpixelConv2d(net, scale=2, n_out_channel=None, act=tf.identity, name='subp
23852385
def _PS(X, r, n_out_channel):
23862386
if n_out_channel >= 1:
23872387
assert int(X.get_shape()[-1]) == (r ** 2) * n_out_channel, _err_log
2388+
'''
23882389
bsize, a, b, c = X.get_shape().as_list()
23892390
bsize = tf.shape(X)[0] # Handling Dimension(None) type for undefined batch dim
23902391
Xs=tf.split(X,r,3) #b*h*w*r*r
23912392
Xr=tf.concat(Xs,2) #b*h*(r*w)*r
23922393
X=tf.reshape(Xr,(bsize,r*a,r*b,n_out_channel)) # b*(r*h)*(r*w)*c
2394+
'''
2395+
X=tf.depth_to_space(X,r)
23932396
else:
23942397
print(_err_log)
23952398
return X
@@ -2467,31 +2470,12 @@ def SubpixelConv2d_old(net, scale=2, n_out_channel=None, act=tf.identity, name='
24672470
if scope_name:
24682471
name = scope_name + '/' + name
24692472

2470-
def _phase_shift(I, r):
2471-
if tf.__version__ < '1.0':
2472-
raise Exception("Only support TF1.0+")
2473-
bsize, a, b, c = I.get_shape().as_list()
2474-
bsize = tf.shape(I)[0] # Handling Dimension(None) type for undefined batch dim
2475-
X = tf.reshape(I, (bsize, a, b, r, r))
2476-
X = tf.transpose(X, (0, 1, 2, 4, 3)) # bsize, a, b, 1, 1 # tf 0.12
2477-
# X = tf.split(1, a, X) # a, [bsize, b, r, r] # tf 0.12
2478-
X = tf.split(X, a, 1)
2479-
# X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, b, a*r, r # tf 0.12
2480-
X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)
2481-
# X = tf.split(1, b, X) # b, [bsize, a*r, r] # tf 0.12
2482-
X = tf.split(X, b, 1)
2483-
# X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, a*r, b*r # tf 0.12
2484-
X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)
2485-
return tf.reshape(X, (bsize, a*r, b*r, 1))
2486-
24872473
def _PS(X, r, n_out_channel):
24882474
if n_out_channel > 1:
24892475
assert int(X.get_shape()[-1]) == (r ** 2) * n_out_channel, _err_log
2490-
Xc = tf.split(X, n_out_channel, 3)
2491-
X = tf.concat([_phase_shift(x, r) for x in Xc], 3)
2492-
elif n_out_channel == 1:
2493-
assert int(X.get_shape()[-1]) == (r ** 2), _err_log
2494-
X = _phase_shift(X, r)
2476+
X=tf.transpose(X,[0,2,1,3])
2477+
X=tf.depth_to_space(X,r)
2478+
X=tf.transpose(X,[0,2,1,3])
24952479
else:
24962480
print(_err_log)
24972481
return X

0 commit comments

Comments
 (0)