Skip to content

Commit 4c66170

Browse files
author
wzhouad
committed
Add Inference Module
1 parent f944fd5 commit 4c66170

1 file changed

Lines changed: 215 additions & 0 deletions

File tree

inference.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import tensorflow as tf
2+
import spacy
3+
import os
4+
import numpy as np
5+
import ujson as json
6+
7+
8+
from func import cudnn_gru, native_gru, dot_attention, summ, ptr_net
9+
from prepro import word_tokenize, convert_idx
10+
11+
# Suppress TensorFlow C++ backend log output ("3" = errors only).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Model hyper-parameters. Must be consistent with the values used at
# training time, otherwise the restored checkpoint will not match the graph.
char_limit = 16    # max characters kept per token (see prepro loops below)
hidden = 75        # hidden size of the main GRU encoders
char_dim = 8       # character embedding dimension
char_hidden = 100  # hidden size of the character-level GRU
use_cudnn = True   # use the cuDNN GRU implementation (presumably requires GPU — confirm)

# File paths: preprocessed embedding matrices and vocab lookup tables
# (presumably produced by the prepro step — verify), plus the checkpoint dir.
target_dir = "data"
save_dir = "log/model"
word_emb_file = os.path.join(target_dir, "word_emb.json")
char_emb_file = os.path.join(target_dir, "char_emb.json")
word2idx_file = os.path.join(target_dir, "word2idx.json")
char2idx_file = os.path.join(target_dir, "char2idx.json")
27+
28+
29+
class InfModel(object):
    """Inference-only reading-comprehension graph for a batch of one.

    Builds the same network as training — character/word embedding,
    contextual encoding, question-context attention, self-matching, and
    a pointer network — and exposes ``yp1``/``yp2``, the predicted
    answer start and end token indices.

    Args:
        word_mat: pretrained word-embedding matrix (frozen at inference).
        char_mat: character-embedding matrix (values come from the
            restored checkpoint).
        ans_limit: maximum allowed answer span, as ``end - start``
            (inclusive band width in the span-score matrix). Defaults to
            15, the value previously hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, word_mat, char_mat, ans_limit=15):
        self.ans_limit = ans_limit
        # Batch size is fixed at 1; sequence lengths are dynamic.
        self.c = tf.placeholder(tf.int32, [1, None])                # context word ids
        self.q = tf.placeholder(tf.int32, [1, None])                # question word ids
        self.ch = tf.placeholder(tf.int32, [1, None, char_limit])   # context char ids
        self.qh = tf.placeholder(tf.int32, [1, None, char_limit])   # question char ids

        self.word_mat = tf.get_variable("word_mat", initializer=tf.constant(
            word_mat, dtype=tf.float32), trainable=False)
        self.char_mat = tf.get_variable(
            "char_mat", initializer=tf.constant(char_mat, dtype=tf.float32))

        # A nonzero id marks a real token (id 0 is presumably the padding
        # index, consistent with the zero-initialized arrays in prepro).
        self.c_mask = tf.cast(self.c, tf.bool)
        self.q_mask = tf.cast(self.q, tf.bool)
        self.c_len = tf.reduce_sum(tf.cast(self.c_mask, tf.int32), axis=1)
        self.q_len = tf.reduce_sum(tf.cast(self.q_mask, tf.int32), axis=1)

        self.c_maxlen = tf.reduce_max(self.c_len)
        self.q_maxlen = tf.reduce_max(self.q_len)

        # Per-token character counts, flattened to (batch * seq_len,).
        self.ch_len = tf.reshape(tf.reduce_sum(
            tf.cast(tf.cast(self.ch, tf.bool), tf.int32), axis=2), [-1])
        self.qh_len = tf.reshape(tf.reduce_sum(
            tf.cast(tf.cast(self.qh, tf.bool), tf.int32), axis=2), [-1])

        self.ready()

    def ready(self):
        """Assemble the graph from embeddings through span prediction."""
        N, PL, QL, CL, d, dc, dg = 1, self.c_maxlen, self.q_maxlen, char_limit, hidden, char_dim, char_hidden
        gru = cudnn_gru if use_cudnn else native_gru

        with tf.variable_scope("emb"):
            with tf.variable_scope("char"):
                # Character-level encoding: one bi-GRU over each token's
                # characters; the final fw/bw states form the token vector.
                # The same cells are applied to context and question, so
                # the character encoder weights are shared.
                ch_emb = tf.reshape(tf.nn.embedding_lookup(
                    self.char_mat, self.ch), [N * PL, CL, dc])
                qh_emb = tf.reshape(tf.nn.embedding_lookup(
                    self.char_mat, self.qh), [N * QL, CL, dc])
                cell_fw = tf.contrib.rnn.GRUCell(dg)
                cell_bw = tf.contrib.rnn.GRUCell(dg)
                _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw, cell_bw, ch_emb, self.ch_len, dtype=tf.float32)
                ch_emb = tf.concat([state_fw, state_bw], axis=1)
                _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw, cell_bw, qh_emb, self.qh_len, dtype=tf.float32)
                qh_emb = tf.concat([state_fw, state_bw], axis=1)
                qh_emb = tf.reshape(qh_emb, [N, QL, 2 * dg])
                ch_emb = tf.reshape(ch_emb, [N, PL, 2 * dg])

            with tf.name_scope("word"):
                c_emb = tf.nn.embedding_lookup(self.word_mat, self.c)
                q_emb = tf.nn.embedding_lookup(self.word_mat, self.q)

            # Final token representation = word embedding ++ char encoding.
            c_emb = tf.concat([c_emb, ch_emb], axis=2)
            q_emb = tf.concat([q_emb, qh_emb], axis=2)

        with tf.variable_scope("encoding"):
            rnn = gru(num_layers=3, num_units=d, batch_size=N,
                      input_size=c_emb.get_shape().as_list()[-1])
            c = rnn(c_emb, seq_len=self.c_len)
            q = rnn(q_emb, seq_len=self.q_len)

        with tf.variable_scope("attention"):
            # Question-aware context representation.
            qc_att = dot_attention(c, q, mask=self.q_mask, hidden=d)
            rnn = gru(num_layers=1, num_units=d, batch_size=N,
                      input_size=qc_att.get_shape().as_list()[-1])
            att = rnn(qc_att, seq_len=self.c_len)

        with tf.variable_scope("match"):
            # Self-matching: attend the context representation over itself.
            self_att = dot_attention(att, att, mask=self.c_mask, hidden=d)
            rnn = gru(num_layers=1, num_units=d, batch_size=N,
                      input_size=self_att.get_shape().as_list()[-1])
            match = rnn(self_att, seq_len=self.c_len)

        with tf.variable_scope("pointer"):
            # Question summary initializes the pointer network state.
            init = summ(q[:, :, -2 * d:], d, mask=self.q_mask)
            pointer = ptr_net(batch=N, hidden=init.get_shape().as_list()[-1])
            logits1, logits2 = pointer(init, match, d, self.c_mask)

        with tf.variable_scope("predict"):
            # Outer product of start/end probabilities gives a score for
            # every (start, end) pair; keep only the upper band so that
            # start <= end and end - start <= ans_limit.
            outer = tf.matmul(tf.expand_dims(tf.nn.softmax(logits1), axis=2),
                              tf.expand_dims(tf.nn.softmax(logits2), axis=1))
            outer = tf.matrix_band_part(outer, 0, self.ans_limit)
            self.yp1 = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
            self.yp2 = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
114+
115+
116+
class Inference(object):
    """Loads the trained checkpoint and answers questions about a context."""

    def __init__(self):
        # Embedding matrices and vocabulary lookup tables written by the
        # preprocessing step.
        self.word_mat = np.array(self._read_json(word_emb_file), dtype=np.float32)
        self.char_mat = np.array(self._read_json(char_emb_file), dtype=np.float32)
        self.word2idx_dict = self._read_json(word2idx_file)
        self.char2idx_dict = self._read_json(char2idx_file)
        self.model = InfModel(self.word_mat, self.char_mat)
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        tf.train.Saver().restore(self.sess, tf.train.latest_checkpoint(save_dir))

    @staticmethod
    def _read_json(path):
        # Parse one JSON file and return its payload.
        with open(path, "r") as f:
            return json.load(f)

    def response(self, context, question):
        """Return the answer to `question` as a substring of `context`."""
        span, c_idx, q_idx, c_char_idx, q_char_idx = self.prepro(context, question)
        feed = {
            self.model.c: c_idx,
            self.model.q: q_idx,
            self.model.ch: c_char_idx,
            self.model.qh: q_char_idx,
        }
        yp1, yp2 = self.sess.run([self.model.yp1, self.model.yp2], feed_dict=feed)
        start_idx = span[yp1[0]][0]
        end_idx = span[yp2[0]][1]
        # Slice the ORIGINAL context string using the character spans of
        # the predicted start/end tokens.
        return context[start_idx: end_idx]

    def prepro(self, context, question):
        """Tokenize and index one (context, question) pair.

        Returns (spans, context_idxs, ques_idxs, context_char_idxs,
        ques_char_idxs), where spans maps token positions back to
        character offsets in `context`.
        """
        # Quote normalization: both replacements are length-preserving, so
        # character spans computed here still align with the caller's string.
        context = context.replace("''", '" ').replace("``", '" ')
        ques = question.replace("''", '" ').replace("``", '" ')
        context_tokens = word_tokenize(context)
        ques_tokens = word_tokenize(ques)
        spans = convert_idx(context, context_tokens)

        def word_index(word):
            # Casing fallbacks: raw, lower, capitalized, upper; 1 is the
            # out-of-vocabulary index.
            for variant in (word, word.lower(), word.capitalize(), word.upper()):
                if variant in self.word2idx_dict:
                    return self.word2idx_dict[variant]
            return 1

        def char_index(char):
            # 1 is the out-of-vocabulary index for characters as well.
            return self.char2idx_dict.get(char, 1)

        context_idxs = np.array(
            [[word_index(t) for t in context_tokens]], dtype=np.int32)
        ques_idxs = np.array(
            [[word_index(t) for t in ques_tokens]], dtype=np.int32)

        context_char_idxs = np.zeros(
            [1, len(context_tokens), char_limit], dtype=np.int32)
        ques_char_idxs = np.zeros(
            [1, len(ques_tokens), char_limit], dtype=np.int32)
        # Characters beyond char_limit per token are dropped.
        for i, token in enumerate(context_tokens):
            for j, ch in enumerate(token[:char_limit]):
                context_char_idxs[0, i, j] = char_index(ch)
        for i, token in enumerate(ques_tokens):
            for j, ch in enumerate(token[:char_limit]):
                ques_char_idxs[0, i, j] = char_index(ch)
        return spans, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs
190+
191+
192+
# Demo, example from paper "SQuAD: 100,000+ Questions for Machine Comprehension of Text"
if __name__ == "__main__":
    infer = Inference()
    # NOTE(review): there is no space between "hail." and "Precipitation" —
    # preserved byte-for-byte from the original demo text.
    context = ("In meteorology, precipitation is any product of the condensation "
               "of atmospheric water vapor that falls under gravity. The main forms "
               "of precipitation include drizzle, rain, sleet, snow, graupel and hail."
               "Precipitation forms as smaller droplets coalesce via collision with other "
               "rain drops or ice crystals within a cloud. Short, intense periods of rain "
               "in scattered locations are called “showers”.")
    questions = [
        # Correct: gravity, Output: drizzle, rain, sleet, snow, graupel and hail
        "What causes precipitation to fall?",
        # Correct: graupel, Output: graupel
        "What is another main form of precipitation besides drizzle, rain, snow, sleet and hail?",
        # Correct: within a cloud, Output: within a cloud
        "Where do water droplets collide with ice crystals to form precipitation?",
    ]
    for num, ques in enumerate(questions, start=1):
        ans = infer.response(context, ques)
        print("Answer {}: {}".format(num, ans))

0 commit comments

Comments
 (0)