-
Notifications
You must be signed in to change notification settings - Fork 388
Expand file tree
/
Copy pathfastcell_example.py
More file actions
100 lines (80 loc) · 3.49 KB
/
Copy pathfastcell_example.py
File metadata and controls
100 lines (80 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import helpermethods
import torch
import numpy as np
import sys
from edgeml_pytorch.graph.rnn import *
from edgeml_pytorch.trainer.fastTrainer import FastTrainer
def main():
    """Train an RNN cell on a dataset and save the run's artifacts.

    Supported ``--cell`` values are FastGRNN, FastRNN, UGRNN, GRU and LSTM.

    All settings are read from the command line via
    ``helpermethods.getArgs()``. The dataset in ``args.data_dir`` is loaded
    with ``helpermethods.preProcessData``. Each flat example is reshaped to
    ``(timeSteps, input_dim)``, and the axes are swapped to time-major when
    ``batch_first`` is false. The selected cell is then trained with
    ``FastTrainer``. Before training, a timestamped directory is created
    under the data dir, and the command line and the data mean/std are
    saved into it.

    Exits the process via ``sys.exit`` in two cases: when the cell name is
    not supported, or when ``input_dim`` is not a positive divisor of the
    data dimension.
    """
    # change cuda:0 to cuda:gpuid for specific allocation
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Fixing seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Hyper Param pre-processing
    args = helpermethods.getArgs()

    dataDir = args.data_dir
    cell = args.cell
    inputDims = args.input_dim
    batch_first = args.batch_first
    hiddenDims = args.hidden_dim

    totalEpochs = args.epochs
    learningRate = args.learning_rate
    outFile = args.output_file
    batchSize = args.batch_size
    decayStep = args.decay_step
    decayRate = args.decay_rate

    wRank = args.wRank
    uRank = args.uRank
    sW = args.sW
    sU = args.sU

    update_non_linearity = args.update_nl
    gate_non_linearity = args.gate_nl

    # Supported --cell names and their constructors. The name is checked
    # here, before any data is loaded or any run directory is created, so
    # a typo does not leave an empty timestamped directory behind.
    cellClasses = {
        "FastGRNN": FastGRNNCell,
        "FastRNN": FastRNNCell,
        "UGRNN": UGRNNLRCell,
        "GRU": GRULRCell,
        "LSTM": LSTMLRCell,
    }
    if cell not in cellClasses:
        sys.exit('Exiting: No Such Cell as ' + cell)

    (dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest,
     mean, std) = helpermethods.preProcessData(dataDir)

    # Each flat example must split evenly into steps of inputDims features.
    # This is an explicit check rather than an assert, because `python -O`
    # strips asserts. The positivity test runs before the modulo so that
    # inputDims == 0 cannot raise ZeroDivisionError.
    if inputDims <= 0 or dataDimension % inputDims != 0:
        sys.exit("Infeasible per step input, " +
                 "Timesteps have to be integer")
    # Floor division is exact for integers of any size; int(a / b) goes
    # through a float and can lose precision.
    timeSteps = int(dataDimension // inputDims)

    Xtrain = Xtrain.reshape((-1, timeSteps, inputDims))
    Xtest = Xtest.reshape((-1, timeSteps, inputDims))

    # Swapping axes 0 and 1 turns (-1, timeSteps, inputDims) into
    # (timeSteps, -1, inputDims). The same batch_first flag is passed to
    # FastTrainer below.
    if not batch_first:
        Xtrain = np.swapaxes(Xtrain, 0, 1)
        Xtest = np.swapaxes(Xtest, 0, 1)

    currDir = helpermethods.createTimeStampDir(dataDir, cell)

    helpermethods.dumpCommand(sys.argv, currDir)
    helpermethods.saveMeanStd(mean, std, currDir)

    # Every cell gets the update non-linearity and the low-rank settings.
    # Only FastGRNN also gets the gate non-linearity.
    # NOTE(review): this means --gate_nl is not forwarded to UGRNN/GRU/LSTM.
    # Confirm whether those cells accept gate_nonlinearity and should
    # receive it.
    cellKwargs = dict(update_nonlinearity=update_non_linearity,
                      wRank=wRank, uRank=uRank)
    if cell == "FastGRNN":
        cellKwargs["gate_nonlinearity"] = gate_non_linearity
    FastCell = cellClasses[cell](inputDims, hiddenDims, **cellKwargs)

    FastCellTrainer = FastTrainer(FastCell, numClasses, sW=sW, sU=sU,
                                  learningRate=learningRate, outFile=outFile,
                                  device=device, batch_first=batch_first)

    # Inputs and labels are cast to float32 before being handed to torch.
    FastCellTrainer.train(batchSize, totalEpochs,
                          torch.from_numpy(Xtrain.astype(np.float32)),
                          torch.from_numpy(Xtest.astype(np.float32)),
                          torch.from_numpy(Ytrain.astype(np.float32)),
                          torch.from_numpy(Ytest.astype(np.float32)),
                          decayStep, decayRate, dataDir, currDir)
if __name__ == "__main__":
    # Only start training when this file is run directly as a script.
    main()