@@ -244,7 +244,8 @@ class RecurrentModel():
244244 """Defines a recurrent neural network model of the proprioceptive system."""
245245
246246 def __init__ (
247- self , experiment_id , nclasses , rec_blocktype , n_recunits , npplayers , nppunits , keep_prob ):
247+ self , experiment_id , nclasses , rec_blocktype , n_recunits , npplayers , nppfilters ,
248+ s_kernelsize , s_stride , seed = None , train = True , CPU = False ):
248249 """Set up the hyperparameters of the recurrent model.
249250
250251 Arguments
@@ -254,29 +255,37 @@ def __init__(
254255 rec_blocktype: {'lstm', 'gru'} str, type of recurrent block.
255256 n_recunits : int, number of units in the recurrent block.
256257 npplayers : int, number of layers in the spatial processing (convolutional) module.
257- nppunits : list of ints, number of units in the affine layers for spatial processing.
258- keep_prob : float, amount of dropout at each spatial processing layer.
258+ nppfilters : list of ints, number of filters (spatial convolutions) for spatial processing.
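259+ s_kernelsize : int, size of the spatial conv kernel along its first axis (kernel is [s_kernelsize, 1]).
260+ s_stride : int, stride of the spatial conv kernel along its first axis (stride is [s_stride, 1]).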
261+ seed : int or None, random seed; when not None it is appended to the model name.
262+ train : bool, whether the model is trained; if False (only the random initialization is saved), 'r' is appended to the model name.
259263
260264 """
261265
262- assert len (nppunits ) == npplayers
263- assert rec_blocktype in ('lstm' , 'gru' )
266+ assert len (nppfilters ) == npplayers
267+ assert rec_blocktype in ('lstm' , 'gru' )
264268
265269 self .experiment_id = experiment_id
266270 self .nclasses = nclasses
267271 self .rec_blocktype = rec_blocktype
268272 self .n_recunits = n_recunits
269273 self .npplayers = npplayers
270- self .nppunits = nppunits
271- self .keep_prob = keep_prob
274+ self .nppfilters = nppfilters
275+ self .s_kernelsize = s_kernelsize
276+ self .s_stride = s_stride
277+ self .seed = seed
278+ self .CPU = CPU
272279
273280 # Make model name
274- dropout_name = {'0.6' : 'high' , '0.7' : 'med' , '0.8' : 'low' }
275- units = ('-' .join (str (i ) for i in nppunits ))
276- parts_name = [rec_blocktype , str (npplayers ), units , str (n_recunits ), dropout_name [str (keep_prob )]]
281+ units = ('-' .join (str (i ) for i in nppfilters ))
282+ parts_name = [rec_blocktype , str (npplayers ), units , str (n_recunits )]
277283
278284 # Create model directory
279285 self .name = '_' .join (parts_name )
286+ if seed is not None : self .name += '_' + str (self .seed )
287+ if not train : self .name += 'r'
288+
280289 exp_dir = os .path .join (MODELS_DIR , f'experiment_{ self .experiment_id } ' )
281290 self .model_path = os .path .join (exp_dir , self .name )
282291
@@ -292,25 +301,33 @@ def predict(self, X, is_training=True):
292301 net = OrderedDict ()
293302
294303 with tf .variable_scope ('Network' , reuse = tf .AUTO_REUSE ):
295- score = tf . transpose ( X , [ 0 , 2 , 1 ])
304+ score = X
296305 batch_size = X .get_shape ()[0 ]
297306
298- fully_connected = lambda score , layer_id : slim .fully_connected (
299- score , self .nppunits [layer_id ], normalizer_fn = slim .layer_norm , scope = f'FC{ layer_id } ' )
307+ spatial_conv = lambda score , layer_id : slim .conv2d (
308+ score , self .nppfilters [layer_id ], [self .s_kernelsize , 1 ], [self .s_stride , 1 ],
309+ data_format = 'NHWC' , normalizer_fn = slim .layer_norm , scope = f'Spatial{ layer_id } ' )
300310
301311 for layer in range (self .npplayers ):
302- score = fully_connected (score , layer )
303- score = slim .dropout (score , keep_prob = self .keep_prob , is_training = is_training )
312+ score = spatial_conv (score , layer )
304313 net [f'spatial{ layer } ' ] = score
305314
306315 # `cudnn_rnn` requires the inputs to be of shape [timesteps, batch_size, num_inputs]
316+ score = tf .transpose (score , [0 , 2 , 1 , 3 ])
317+ score = tf .reshape (score , [batch_size , 320 , - 1 ])
307318 score = tf .transpose (score , [1 , 0 , 2 ])
308- if self .rec_blocktype == 'lstm' :
319+
320+ if self .rec_blocktype == 'lstm' and self .CPU == False :
309321 recurrent_cell = cudnn_rnn .CudnnLSTM (1 , self .n_recunits , name = 'RecurrentBlock' )
310322 score , _ = recurrent_cell .apply (score )
311- elif self .rec_blocktype == 'gru' :
323+ elif self .rec_blocktype == 'gru' and self . CPU == False :
312324 recurrent_cell = cudnn_rnn .CudnnGRU (1 , self .n_recunits , name = 'RecurrentBlock' )
313325 score , _ = recurrent_cell .apply (score )
326+ elif self .rec_blocktype == 'lstm' and self .CPU :
327+ with tf .variable_scope ('RecurrentBlock' ):
328+ rec_layer = lambda : cudnn_rnn .CudnnCompatibleLSTMCell (self .n_recunits )
329+ recurrent_cell = tf .nn .rnn_cell .MultiRNNCell ([rec_layer () for _ in range (1 )])
330+ score , _ = tf .nn .dynamic_rnn (recurrent_cell , score , dtype = tf .float32 )
314331
315332 score = tf .transpose (score , [1 , 0 , 2 ])
316333 net ['recurrent_out' ] = score
0 commit comments