Skip to content

Commit d14a1d2

Browse files
authored
Merge pull request amathislab#5 from amathislab/new-spindle
New spindle model and task (for revised version)
2 parents 03aec7b + 6b746c5 commit d14a1d2

38 files changed

Lines changed: 2806 additions & 388 deletions

code/kinematics_decoding.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import sys
33
from nn_models import ConvModel, RecurrentModel, AffineModel
4+
from nn_rmodels import ConvRModel, RecurrentRModel
45

56
import numpy as np
67
import pandas as pd
@@ -39,6 +40,74 @@ def load_model(meta_data, experiment_id, model_type, is_trained):
3940
t_stride=int(meta_data['t_stride']),
4041
seed=myseed,
4142
train=is_trained)
43+
44+
if model_type == 'rec':
45+
46+
try:
47+
myseed = meta_data['seed']
48+
except:
49+
myseed = None
50+
51+
model = RecurrentModel(
52+
experiment_id=experiment_id,
53+
nclasses=20,
54+
rec_blocktype=meta_data['rec_blocktype'],
55+
n_recunits=meta_data['n_recunits'],
56+
npplayers=meta_data['npplayers'],
57+
s_kernelsize=meta_data['s_kernelsize'],
58+
s_stride=meta_data['s_stride'],
59+
nppfilters=meta_data['nppfilters'],
60+
seed=myseed,
61+
train=is_trained)
62+
63+
return model
64+
65+
66+
def load_rmodel(meta_data, experiment_id, model_type, is_trained):
67+
'''Load a trained `ConvRModel`, `RecurrentRModel` object.
68+
69+
Returns
70+
-------
71+
model : an instance of `ConvRModel` or `RecurrentRModel`
72+
73+
'''
74+
if model_type == 'conv':
75+
76+
try:
77+
myseed = meta_data['seed']
78+
except:
79+
myseed = None
80+
81+
model = ConvRModel(
82+
experiment_id=experiment_id,
83+
arch_type=meta_data['arch_type'],
84+
nlayers=meta_data['nlayers'],
85+
n_skernels=meta_data['n_skernels'],
86+
n_tkernels=meta_data['n_tkernels'],
87+
s_kernelsize=int(meta_data['s_kernelsize']),
88+
t_kernelsize=int(meta_data['t_kernelsize']),
89+
s_stride=int(meta_data['s_stride']),
90+
t_stride=int(meta_data['t_stride']),
91+
seed=myseed,
92+
train=is_trained)
93+
94+
if model_type == 'rec':
95+
96+
try:
97+
myseed = meta_data['seed']
98+
except:
99+
myseed = None
100+
101+
model = RecurrentRModel(
102+
experiment_id=experiment_id,
103+
rec_blocktype=meta_data['rec_blocktype'],
104+
n_recunits=meta_data['n_recunits'],
105+
npplayers=meta_data['npplayers'],
106+
s_kernelsize=meta_data['s_kernelsize'],
107+
s_stride=meta_data['s_stride'],
108+
nppfilters=meta_data['nppfilters'],
109+
seed=myseed,
110+
train=is_trained)
42111

43112
return model
44113

@@ -81,7 +150,7 @@ def set_kin_dimensions(kinematics, n_out_timesteps):
81150
assert n_timesteps == 320, "Kinematics shape mismatch. Please revise."
82151
skip_idx = n_timesteps // (n_out_timesteps - 1)
83152
kinematics = kinematics[:, :, ::skip_idx]
84-
assert kinematics.shape[-1] == n_out_timesteps, "Whaatt!"
153+
assert kinematics.shape[-1] == n_out_timesteps
85154
return kinematics.transpose(0, 2, 1)
86155

87156

@@ -141,7 +210,11 @@ def make_h5file_name(model_path, repfile_name, layername):
141210

142211
with tf.Graph().as_default():
143212
X = tf.placeholder(tf.float32, myshape, name="X")
144-
_, _, output = model.predict(X, is_training=False)
213+
name = type(model).__name__
214+
if name == 'ConvModel' or name == 'RecurrentModel':
215+
_, _, output = model.predict(X, is_training=False)
216+
elif name == 'ConvRModel' or name == 'RecurrentRModel':
217+
_, output = model.predict(X, is_training=False)
145218
restorer = tf.train.Saver()
146219
myconfig = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
147220

code/nn_models.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ class RecurrentModel():
244244
"""Defines a recurrent neural network model of the proprioceptive system."""
245245

246246
def __init__(
247-
self, experiment_id, nclasses, rec_blocktype, n_recunits, npplayers, nppunits, keep_prob):
247+
self, experiment_id, nclasses, rec_blocktype, n_recunits, npplayers, nppfilters,
248+
s_kernelsize, s_stride, seed=None, train=True, CPU=False):
248249
"""Set up the hyperparameters of the recurrent model.
249250
250251
Arguments
@@ -254,29 +255,37 @@ def __init__(
254255
rec_blocktype: {'lstm', 'gru'} str, type of recurrent block.
255256
n_recunits : int, number of units in the recurrent block.
256257
npplayers : int, number of layers in the fully-connected module.
257-
nppunits : list of ints, number of units in the affine layers for spatial processing.
258-
keep_prob : float, amount of dropout at each spatial processing layer.
258+
nppfilters : list of ints, number of filters (spatial convolutions) for spatial processing.
259+
s_kernelsize : int, size of conv kernel
260+
s_stride : int, stride for conv kernel
261+
seed : int, for saving random initializations
262+
train : bool, whether to train the model or not (just save random initialization)
259263
260264
"""
261265

262-
assert len(nppunits) == npplayers
263-
assert rec_blocktype in ('lstm', 'gru')
266+
assert len(nppfilters) == npplayers
267+
assert rec_blocktype in ('lstm', 'gru')
264268

265269
self.experiment_id = experiment_id
266270
self.nclasses = nclasses
267271
self.rec_blocktype = rec_blocktype
268272
self.n_recunits = n_recunits
269273
self.npplayers = npplayers
270-
self.nppunits = nppunits
271-
self.keep_prob = keep_prob
274+
self.nppfilters = nppfilters
275+
self.s_kernelsize = s_kernelsize
276+
self.s_stride = s_stride
277+
self.seed = seed
278+
self.CPU = CPU
272279

273280
# Make model name
274-
dropout_name = {'0.6': 'high', '0.7': 'med', '0.8': 'low'}
275-
units = ('-'.join(str(i) for i in nppunits))
276-
parts_name = [rec_blocktype, str(npplayers), units, str(n_recunits), dropout_name[str(keep_prob)]]
281+
units = ('-'.join(str(i) for i in nppfilters))
282+
parts_name = [rec_blocktype, str(npplayers), units, str(n_recunits)]
277283

278284
# Create model directory
279285
self.name = '_'.join(parts_name)
286+
if seed is not None: self.name += '_' + str(self.seed)
287+
if not train: self.name += 'r'
288+
280289
exp_dir = os.path.join(MODELS_DIR, f'experiment_{self.experiment_id}')
281290
self.model_path = os.path.join(exp_dir, self.name)
282291

@@ -292,25 +301,33 @@ def predict(self, X, is_training=True):
292301
net = OrderedDict()
293302

294303
with tf.variable_scope('Network', reuse=tf.AUTO_REUSE):
295-
score = tf.transpose(X, [0, 2, 1])
304+
score = X
296305
batch_size = X.get_shape()[0]
297306

298-
fully_connected = lambda score, layer_id: slim.fully_connected(
299-
score, self.nppunits[layer_id], normalizer_fn=slim.layer_norm, scope=f'FC{layer_id}')
307+
spatial_conv = lambda score, layer_id: slim.conv2d(
308+
score, self.nppfilters[layer_id], [self.s_kernelsize, 1], [self.s_stride, 1],
309+
data_format='NHWC', normalizer_fn=slim.layer_norm, scope=f'Spatial{layer_id}')
300310

301311
for layer in range(self.npplayers):
302-
score = fully_connected(score, layer)
303-
score = slim.dropout(score, keep_prob=self.keep_prob, is_training=is_training)
312+
score = spatial_conv(score, layer)
304313
net[f'spatial{layer}'] = score
305314

306315
# `cudnn_rnn` requires the inputs to be of shape [timesteps, batch_size, num_inputs]
316+
score = tf.transpose(score, [0, 2, 1, 3])
317+
score = tf.reshape(score, [batch_size, 320, -1])
307318
score = tf.transpose(score, [1, 0, 2])
308-
if self.rec_blocktype == 'lstm':
319+
320+
if self.rec_blocktype == 'lstm' and self.CPU == False:
309321
recurrent_cell = cudnn_rnn.CudnnLSTM(1, self.n_recunits, name='RecurrentBlock')
310322
score, _ = recurrent_cell.apply(score)
311-
elif self.rec_blocktype == 'gru':
323+
elif self.rec_blocktype == 'gru' and self.CPU == False:
312324
recurrent_cell = cudnn_rnn.CudnnGRU(1, self.n_recunits, name='RecurrentBlock')
313325
score, _ = recurrent_cell.apply(score)
326+
elif self.rec_blocktype == 'lstm' and self.CPU:
327+
with tf.variable_scope('RecurrentBlock'):
328+
rec_layer = lambda: cudnn_rnn.CudnnCompatibleLSTMCell(self.n_recunits)
329+
recurrent_cell = tf.nn.rnn_cell.MultiRNNCell([rec_layer() for _ in range(1)])
330+
score, _ = tf.nn.dynamic_rnn(recurrent_cell, score, dtype=tf.float32)
314331

315332
score = tf.transpose(score, [1, 0, 2])
316333
net['recurrent_out'] = score

0 commit comments

Comments
 (0)