@@ -2385,11 +2385,14 @@ def SubpixelConv2d(net, scale=2, n_out_channel=None, act=tf.identity, name='subp
23852385 def _PS (X , r , n_out_channel ):
23862386 if n_out_channel >= 1 :
23872387 assert int (X .get_shape ()[- 1 ]) == (r ** 2 ) * n_out_channel , _err_log
2388+ '''
23882389 bsize, a, b, c = X.get_shape().as_list()
23892390 bsize = tf.shape(X)[0] # Handling Dimension(None) type for undefined batch dim
23902391 Xs=tf.split(X,r,3) #b*h*w*r*r
23912392 Xr=tf.concat(Xs,2) #b*h*(r*w)*r
23922393 X=tf.reshape(Xr,(bsize,r*a,r*b,n_out_channel)) # b*(r*h)*(r*w)*c
2394+ '''
2395+ X = tf .depth_to_space (X ,r )
23932396 else :
23942397 print (_err_log )
23952398 return X
@@ -2467,31 +2470,12 @@ def SubpixelConv2d_old(net, scale=2, n_out_channel=None, act=tf.identity, name='
24672470 if scope_name :
24682471 name = scope_name + '/' + name
24692472
2470- def _phase_shift (I , r ):
2471- if tf .__version__ < '1.0' :
2472- raise Exception ("Only support TF1.0+" )
2473- bsize , a , b , c = I .get_shape ().as_list ()
2474- bsize = tf .shape (I )[0 ] # Handling Dimension(None) type for undefined batch dim
2475- X = tf .reshape (I , (bsize , a , b , r , r ))
2476- X = tf .transpose (X , (0 , 1 , 2 , 4 , 3 )) # bsize, a, b, 1, 1 # tf 0.12
2477- # X = tf.split(1, a, X) # a, [bsize, b, r, r] # tf 0.12
2478- X = tf .split (X , a , 1 )
2479- # X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, b, a*r, r # tf 0.12
2480- X = tf .concat ([tf .squeeze (x , axis = 1 ) for x in X ], 2 )
2481- # X = tf.split(1, b, X) # b, [bsize, a*r, r] # tf 0.12
2482- X = tf .split (X , b , 1 )
2483- # X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, a*r, b*r # tf 0.12
2484- X = tf .concat ([tf .squeeze (x , axis = 1 ) for x in X ], 2 )
2485- return tf .reshape (X , (bsize , a * r , b * r , 1 ))
2486-
24872473 def _PS (X , r , n_out_channel ):
24882474 if n_out_channel > 1 :
24892475 assert int (X .get_shape ()[- 1 ]) == (r ** 2 ) * n_out_channel , _err_log
2490- Xc = tf .split (X , n_out_channel , 3 )
2491- X = tf .concat ([_phase_shift (x , r ) for x in Xc ], 3 )
2492- elif n_out_channel == 1 :
2493- assert int (X .get_shape ()[- 1 ]) == (r ** 2 ), _err_log
2494- X = _phase_shift (X , r )
2476+ X = tf .transpose (X ,[0 ,2 ,1 ,3 ])
2477+ X = tf .depth_to_space (X ,r )
2478+ X = tf .transpose (X ,[0 ,2 ,1 ,3 ])
24952479 else :
24962480 print (_err_log )
24972481 return X
0 commit comments