-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtf_example.py
More file actions
117 lines (92 loc) · 3.84 KB
/
tf_example.py
File metadata and controls
117 lines (92 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""TensorFlow RNN example."""
import warnings
import os
import string
with warnings.catch_warnings():
    # Suppress RuntimeWarning/FutureWarning only while TensorFlow is being
    # imported; catch_warnings() restores the previous filters on exit.
    warnings.simplefilter('ignore', RuntimeWarning)
    warnings.simplefilter('ignore', FutureWarning)
    import tensorflow as tf
from dataset import Dataset
from iterator import Iterator
import text
def tf_example(args):
    """Build, train, and test the discriminator using the TensorFlow frontend.

    Builds a two-layer LSTM binary classifier with a single-logit output,
    trains it with Adam on the data returned by ``text.get_data``, and prints
    loss and accuracy on the latest minibatch and on the validation split once
    per epoch. It then saves a checkpoint named ``tf_example.ckpt`` in the
    current working directory. Finally it prints, for every word returned by
    ``text.get_test_data``, the sigmoid of the model's logit, shown under the
    header ``P(<Language>)``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``length``, ``language``, ``validation_split``, ``n_epochs``,
            ``batch_size`` and ``n_state``.
    """
    dtype = tf.float32

    # Get one-hot encoded English and German/French words.
    words, labels = text.get_data(args.length, args.language, True)
    data = Dataset(words, labels, args.validation_split)
    iterator = Iterator(data, args.n_epochs, args.batch_size)

    # Model input and output. Axis 1 of x is left as None so that minibatches,
    # the validation split and the test words can all be fed to the same graph.
    # NOTE(review): LSTMBlockFusedCell consumes time-major input, so this
    # assumes ``words`` is laid out as (time, batch, features) -- confirm
    # against text.get_data.
    with tf.name_scope('io'):
        x = tf.placeholder(dtype, [words.shape[0], None, words.shape[2]], 'x')
        y = tf.placeholder(dtype, [None, 1], 'y')

    # The neural network: two stacked LSTM layers. The final cell states are
    # not used, so they are discarded.
    with tf.variable_scope('lstm0'):
        lstm0 = tf.contrib.rnn.LSTMBlockFusedCell(args.n_state)
        lstm0_output, _ = lstm0(x, dtype=dtype)
    with tf.variable_scope('lstm1'):
        lstm1 = tf.contrib.rnn.LSTMBlockFusedCell(args.n_state)
        lstm1_output, _ = lstm1(lstm0_output, dtype=dtype)
    with tf.variable_scope('dense'):
        # Classify from the last step along axis 0 of the second layer's
        # output. Apply the sigmoid in the loss, not in the dense layer.
        logits = tf.layers.dense(lstm1_output[-1, :, :], 1, name='logits')

    # The training loss (takes raw logits and applies the sigmoid itself).
    with tf.name_scope('loss'):
        loss = tf.losses.sigmoid_cross_entropy(y, logits)

    # Classification accuracy: predict 1 iff logits > 0 (i.e. sigmoid > 0.5).
    with tf.name_scope('accuracy'):
        correct = tf.cast(
            tf.equal(tf.cast(y, tf.bool), tf.greater(logits, 0)),
            dtype,
        )
        accuracy = tf.reduce_mean(correct)

    # The optimizer.
    with tf.name_scope('optimizer'):
        optimizer = tf.train.AdamOptimizer()
        global_step = tf.Variable(0, trainable=False, name='global_step')

    # The training operation; increments global_step on every update.
    with tf.name_scope('train'):
        train_op = optimizer.minimize(loss, global_step=global_step)

    # Sigmoid of the logit, printed below as P(<Language>) for each word.
    with tf.name_scope('predict'):
        prediction = tf.sigmoid(logits, 'label')

    # One-hot encode the test words with the same length/flag arguments used
    # for the training data.
    test_words = text.get_test_data(args.language)
    words_encoded = text.one_hot(test_words, args.length, True)

    # Build the saver while the graph is being constructed (after every
    # variable exists) rather than inside the session.
    saver = tf.train.Saver()
    checkpoint_path = os.path.join(os.getcwd(), 'tf_example.ckpt')

    # Run the training and testing.
    with tf.Session() as session:
        tf.global_variables_initializer().run()
        # Run through the minibatches.
        for data_x, data_y in iterator:
            batch_feed = {x: data_x, y: data_y}
            # Run the training operation (uses the default session).
            train_op.run(batch_feed)
            # Report diagnostics at every epoch.
            # NOTE(review): assumes iterator.new_epoch marks an epoch
            # boundary -- confirm against Iterator.
            if iterator.new_epoch:
                # "Train" metrics are measured on the most recent minibatch
                # only (after its update), not on the whole training split.
                train_loss, train_accuracy = session.run(
                    [loss, accuracy],
                    batch_feed,
                )
                # Validation metrics over the whole validation split.
                val_loss, val_accuracy = session.run(
                    [loss, accuracy],
                    {x: data.x['val'], y: data.y['val']},
                )
                print('Epoch {}:'.format(iterator.epoch))
                print(
                    ' Train: loss = {:.6f}, accuracy = {:.6f}'
                    .format(train_loss, train_accuracy)
                )
                print(
                    ' Validation: loss = {:.6f}, accuracy = {:.6f}'
                    .format(val_loss, val_accuracy)
                )
        # Label words.
        test_probs = prediction.eval({x: words_encoded})
        # Save all variables.
        saver.save(session, checkpoint_path)

    # Print predictions. Each row of test_probs has shape (1,) because the
    # dense layer has a single unit, so take its only element explicitly.
    print('\nWord: P({})'.format(args.language.capitalize()))
    for word, prob in zip(test_words, test_probs):
        print('{}: {:.3f}'.format(word, float(prob[0])))