|
| 1 | +import tensorflow as tf |
| 2 | +import spacy |
| 3 | +import os |
| 4 | +import numpy as np |
| 5 | +import ujson as json |
| 6 | + |
| 7 | + |
| 8 | +from func import cudnn_gru, native_gru, dot_attention, summ, ptr_net |
| 9 | +from prepro import word_tokenize, convert_idx |
| 10 | + |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # silence TensorFlow C++ startup/info logging

# Must be consistent with training — these shape the graph that the saved
# checkpoint was trained with; changing any of them breaks Saver.restore().
char_limit = 16     # max characters kept per token (prepro truncates at this)
hidden = 75         # GRU hidden size of the main encoders
char_dim = 8        # character embedding dimension
char_hidden = 100   # hidden size of the character-level GRU
use_cudnn = True    # select func.cudnn_gru over func.native_gru

# File path
target_dir = "data"      # preprocessed embeddings / vocab indices live here
save_dir = "log/model"   # directory holding the trained checkpoint
word_emb_file = os.path.join(target_dir, "word_emb.json")
char_emb_file = os.path.join(target_dir, "char_emb.json")
word2idx_file = os.path.join(target_dir, "word2idx.json")
char2idx_file = os.path.join(target_dir, "char2idx.json")
| 27 | + |
| 28 | + |
class InfModel(object):
    """Inference-only span-prediction QA graph (R-Net-style architecture).

    Batch size is fixed to 1; sequence lengths are dynamic.  The graph
    predicts the start (``yp1``) and end (``yp2``) token positions of the
    answer span in the context.  Variable scopes/names must match the
    training graph exactly so the checkpoint restores cleanly.
    """

    def __init__(self, word_mat, char_mat):
        # Placeholders: word ids for context/question ([1, T]) and char ids
        # per token ([1, T, char_limit]).
        self.c = tf.placeholder(tf.int32, [1, None])
        self.q = tf.placeholder(tf.int32, [1, None])
        self.ch = tf.placeholder(tf.int32, [1, None, char_limit])
        self.qh = tf.placeholder(tf.int32, [1, None, char_limit])

        # Pretrained word embeddings are frozen; char embeddings are left
        # trainable (the default).  NOTE(review): trainable is harmless here
        # (no optimizer is built), but it must mirror the training graph.
        self.word_mat = tf.get_variable("word_mat", initializer=tf.constant(
            word_mat, dtype=tf.float32), trainable=False)
        self.char_mat = tf.get_variable(
            "char_mat", initializer=tf.constant(char_mat, dtype=tf.float32))

        # Boolean masks from token ids; assumes id 0 is padding/absent —
        # TODO confirm against the word2idx/char2idx maps built by prepro.
        self.c_mask = tf.cast(self.c, tf.bool)
        self.q_mask = tf.cast(self.q, tf.bool)
        self.c_len = tf.reduce_sum(tf.cast(self.c_mask, tf.int32), axis=1)
        self.q_len = tf.reduce_sum(tf.cast(self.q_mask, tf.int32), axis=1)

        # With batch size 1 these are simply the (only) sequence lengths.
        self.c_maxlen = tf.reduce_max(self.c_len)
        self.q_maxlen = tf.reduce_max(self.q_len)

        # Per-token character counts, flattened to [batch * tokens] for the
        # char-level RNN below.
        self.ch_len = tf.reshape(tf.reduce_sum(
            tf.cast(tf.cast(self.ch, tf.bool), tf.int32), axis=2), [-1])
        self.qh_len = tf.reshape(tf.reduce_sum(
            tf.cast(tf.cast(self.qh, tf.bool), tf.int32), axis=2), [-1])

        self.ready()

    def ready(self):
        """Assemble the graph: embedding -> encoding -> attention -> pointer."""
        # N=batch, PL/QL=context/question length, CL=char limit, d=hidden,
        # dc=char embedding dim, dg=char-GRU hidden size.
        N, PL, QL, CL, d, dc, dg = 1, self.c_maxlen, self.q_maxlen, char_limit, hidden, char_dim, char_hidden
        gru = cudnn_gru if use_cudnn else native_gru

        with tf.variable_scope("emb"):
            with tf.variable_scope("char"):
                # One character sequence per token: [N * tokens, CL, dc].
                ch_emb = tf.reshape(tf.nn.embedding_lookup(
                    self.char_mat, self.ch), [N * PL, CL, dc])
                qh_emb = tf.reshape(tf.nn.embedding_lookup(
                    self.char_mat, self.qh), [N * QL, CL, dc])
                # The same cell objects encode both context and question
                # characters; the final fw/bw states become the per-token
                # character-level embedding.
                cell_fw = tf.contrib.rnn.GRUCell(dg)
                cell_bw = tf.contrib.rnn.GRUCell(dg)
                _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw, cell_bw, ch_emb, self.ch_len, dtype=tf.float32)
                ch_emb = tf.concat([state_fw, state_bw], axis=1)
                _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw, cell_bw, qh_emb, self.qh_len, dtype=tf.float32)
                qh_emb = tf.concat([state_fw, state_bw], axis=1)
                qh_emb = tf.reshape(qh_emb, [N, QL, 2 * dg])
                ch_emb = tf.reshape(ch_emb, [N, PL, 2 * dg])

            with tf.name_scope("word"):
                c_emb = tf.nn.embedding_lookup(self.word_mat, self.c)
                q_emb = tf.nn.embedding_lookup(self.word_mat, self.q)

            # Word and character features concatenated along the feature axis.
            c_emb = tf.concat([c_emb, ch_emb], axis=2)
            q_emb = tf.concat([q_emb, qh_emb], axis=2)

        with tf.variable_scope("encoding"):
            # A shared 3-layer bi-GRU encodes both sequences.
            rnn = gru(num_layers=3, num_units=d, batch_size=N,
                      input_size=c_emb.get_shape().as_list()[-1])
            c = rnn(c_emb, seq_len=self.c_len)
            q = rnn(q_emb, seq_len=self.q_len)

        with tf.variable_scope("attention"):
            # Context attends over the question (see func.dot_attention),
            # then the result is re-encoded.
            qc_att = dot_attention(c, q, mask=self.q_mask, hidden=d)
            rnn = gru(num_layers=1, num_units=d, batch_size=N,
                      input_size=qc_att.get_shape().as_list()[-1])
            att = rnn(qc_att, seq_len=self.c_len)

        with tf.variable_scope("match"):
            # Self-matching: the context representation attends over itself.
            self_att = dot_attention(att, att, mask=self.c_mask, hidden=d)
            rnn = gru(num_layers=1, num_units=d, batch_size=N,
                      input_size=self_att.get_shape().as_list()[-1])
            match = rnn(self_att, seq_len=self.c_len)

        with tf.variable_scope("pointer"):
            # A summary of the question (last bi-GRU layer's features)
            # initializes the pointer network, which emits start/end logits
            # over context positions.
            init = summ(q[:, :, -2 * d:], d, mask=self.q_mask)
            pointer = ptr_net(batch=N, hidden=init.get_shape().as_list()[-1])
            logits1, logits2 = pointer(init, match, d, self.c_mask)

        with tf.variable_scope("predict"):
            # Outer product of the start and end distributions; the band part
            # keeps only entries with 0 <= end - start <= 15 (answers of at
            # most 16 tokens), then argmax over each axis picks the best span.
            outer = tf.matmul(tf.expand_dims(tf.nn.softmax(logits1), axis=2),
                              tf.expand_dims(tf.nn.softmax(logits2), axis=1))
            outer = tf.matrix_band_part(outer, 0, 15)
            self.yp1 = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
            self.yp2 = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
| 114 | + |
| 115 | + |
class Inference(object):
    """Restore the trained model and answer questions over a context string.

    Loads the embedding matrices and vocabulary indices produced by the
    preprocessing step, builds the inference graph, and restores the latest
    checkpoint from ``save_dir``.
    """

    def __init__(self):
        """Load embeddings/vocabularies, build the graph, restore weights."""
        self.word_mat = np.array(self._read_json(word_emb_file), dtype=np.float32)
        self.char_mat = np.array(self._read_json(char_emb_file), dtype=np.float32)
        self.word2idx_dict = self._read_json(word2idx_file)
        self.char2idx_dict = self._read_json(char2idx_file)
        # The graph must exist before the Saver so it knows what to restore.
        self.model = InfModel(self.word_mat, self.char_mat)
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        tf.train.Saver().restore(self.sess, tf.train.latest_checkpoint(save_dir))

    @staticmethod
    def _read_json(path):
        """Read and parse one JSON file."""
        with open(path, "r") as fh:
            return json.load(fh)

    def response(self, context, question):
        """Return the answer to ``question`` as a substring of ``context``."""
        spans, c_idx, q_idx, c_char_idx, q_char_idx = self.prepro(context, question)
        feed = {
            self.model.c: c_idx,
            self.model.q: q_idx,
            self.model.ch: c_char_idx,
            self.model.qh: q_char_idx,
        }
        yp1, yp2 = self.sess.run([self.model.yp1, self.model.yp2], feed_dict=feed)
        # Map predicted token positions back to character offsets.
        start = spans[yp1[0]][0]
        end = spans[yp2[0]][1]
        return context[start: end]

    def prepro(self, context, question):
        """Tokenize and index one (context, question) pair for the model.

        Returns ``(spans, context_idxs, ques_idxs, context_char_idxs,
        ques_char_idxs)`` where ``spans`` maps context tokens to character
        offsets and the arrays are the batch-of-1 model inputs.
        """
        # Normalize PTB-style quote tokens before tokenizing.
        context = context.replace("''", '" ').replace("``", '" ')
        clean_question = question.replace("''", '" ').replace("``", '" ')
        context_tokens = word_tokenize(context)
        ques_tokens = word_tokenize(clean_question)
        spans = convert_idx(context, context_tokens)

        def _word_idx(word):
            # Try surface form, then common case variants; 1 marks OOV.
            for form in (word, word.lower(), word.capitalize(), word.upper()):
                idx = self.word2idx_dict.get(form)
                if idx is not None:
                    return idx
            return 1

        def _char_idx(char):
            # 1 marks an out-of-vocabulary character.
            return self.char2idx_dict.get(char, 1)

        context_idxs = np.array(
            [[_word_idx(tok) for tok in context_tokens]], dtype=np.int32)
        ques_idxs = np.array(
            [[_word_idx(tok) for tok in ques_tokens]], dtype=np.int32)

        context_char_idxs = np.zeros(
            [1, len(context_tokens), char_limit], dtype=np.int32)
        for i, token in enumerate(context_tokens):
            for j, char in enumerate(token[:char_limit]):
                context_char_idxs[0, i, j] = _char_idx(char)

        ques_char_idxs = np.zeros(
            [1, len(ques_tokens), char_limit], dtype=np.int32)
        for i, token in enumerate(ques_tokens):
            for j, char in enumerate(token[:char_limit]):
                ques_char_idxs[0, i, j] = _char_idx(char)

        return spans, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs
| 190 | + |
| 191 | + |
# Demo, example from paper "SQuAD: 100,000+ Questions for Machine Comprehension of Text"
if __name__ == "__main__":
    infer = Inference()
    context = (
        "In meteorology, precipitation is any product of the condensation "
        "of atmospheric water vapor that falls under gravity. The main forms "
        "of precipitation include drizzle, rain, sleet, snow, graupel and hail."
        "Precipitation forms as smaller droplets coalesce via collision with other "
        "rain drops or ice crystals within a cloud. Short, intense periods of rain "
        "in scattered locations are called “showers”."
    )
    questions = [
        # Correct: gravity, Output: drizzle, rain, sleet, snow, graupel and hail
        "What causes precipitation to fall?",
        # Correct: graupel, Output: graupel
        "What is another main form of precipitation besides drizzle, rain, snow, sleet and hail?",
        # Correct: within a cloud, Output: within a cloud
        "Where do water droplets collide with ice crystals to form precipitation?",
    ]
    for num, question in enumerate(questions, start=1):
        answer = infer.response(context, question)
        print("Answer {}: {}".format(num, answer))
0 commit comments