-
Notifications
You must be signed in to change notification settings - Fork 75
Expand file tree
/
Copy pathdecoder_error_checker.py
More file actions
executable file
·104 lines (89 loc) · 4.83 KB
/
decoder_error_checker.py
File metadata and controls
executable file
·104 lines (89 loc) · 4.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
'''
File: decoder_error_checker.py
Project: models
File Created: Saturday, 29th December 2018 3:07:21 pm
Author: xiaofeng (sxf1052566766@163.com)
-----
Last Modified: Saturday, 29th December 2018 3:07:45 pm
Modified By: xiaofeng (sxf1052566766@163.com)
-----
2018.06 - 2018 Latex Math, Latex Math
'''
from __future__ import division
import tensorflow as tf
from models.component.attention_cell_sequence import AttCell
from models.component.decoder_beamsearch import BeamSearchDecoderCell
from models.component.decoder_greedy import GreedyDecoderCell
from models.component.decoder_dynamic import dynamic_decode
from tensorflow.contrib.rnn import GRUCell
from models.component.LnRnn import LNGRUCell, LNLSTMCell
from models.component.word_embeding import Embedding, embedding_initializer
class DecoderAtt(object):
    """Attention-based decoder for the error-checker model.

    Builds two decoding branches over one shared variable scope: a
    teacher-forced training branch (``tf.nn.dynamic_rnn`` over the
    ground-truth target sequence) and an inference branch that decodes
    with either greedy search or beam search, reusing the same weights.
    """

    def __init__(self, config, vocab):
        self._config = config
        self._vocab = vocab
        self._name = self._config.model.get('errche_decoder_name')
        # Size of the target-side vocabulary.
        self._targ_voc = self._vocab.errche_vocab_size_targ
        # Vocabulary index of the end-of-sequence token.
        self._id_end = self._config.dataset.get('id_end')
        # Embedding dimension for target tokens.
        self._embeding_dim_targ = self._config.model.get('errche_embeding_dims_target')
        # Hidden sizes: encoder RNN (per direction), decoder RNN, attention layer.
        self._rnn_encoder_dim = self._config.model.get('errche_rnn_encoder_dim')
        self._rnn_decoder_dim = self._config.model.get('errche_rnn_decoder_dim')
        self._att_dim = self._config.model.get('att_dim')
        # The encoder is bidirectional, so its concatenated output is twice
        # the per-direction width and must equal the decoder state size.
        assert self._rnn_encoder_dim * 2 == self._rnn_decoder_dim, \
            "Encoder BiRnn out dim is the double encoder dim and it must be equal with decoder dim"
        # Greedy decoding keeps a single hypothesis per example; beam search
        # tiles each example by the beam width.
        if self._config.model.decoding == 'greedy':
            self._tiles = 1
        else:
            self._tiles = self._config.model.beam_size
        self._embedding_table_traget = tf.get_variable(
            "targ_vocab_embeding", dtype=tf.float32,
            shape=[self._targ_voc, self._embeding_dim_targ],
            initializer=embedding_initializer())
        # Row 0 of the embedding table acts as the start-of-sequence token.
        self._start_token = tf.squeeze(
            input=self._embedding_table_traget[0, :],
            name='targ_start_flage')

    def __call__(self, encoder_out, droupout, input_sequence=None):
        """Build the training and inference decoding graphs.

        Args:
            encoder_out: encoder output tensor; dimension 0 is the batch.
            droupout: dropout tensor forwarded to the attention cell.
            input_sequence: ground-truth target ids for teacher forcing.

        Returns:
            Tuple ``(pred_train, pred_validate)``: outputs of the
            teacher-forced branch and of the greedy/beam-search branch.
        """
        self._batch_size = tf.shape(encoder_out)[0]
        # --- Training branch: teacher forcing over the label sequence. ---
        with tf.variable_scope(self._name, reuse=False,
                               initializer=tf.orthogonal_initializer()):
            embedded_targets = Embedding(
                'embeding', self._embedding_table_traget, input_sequence)
            # Layer-normalised GRU as the recurrent core of the attention cell.
            recurrent_core = LNGRUCell(name='DecoderGru', num_units=self._rnn_decoder_dim)
            attention_cell = AttCell(
                name='AttCell', attention_in=encoder_out, decoder_cell=recurrent_core,
                n_hid=self._rnn_decoder_dim, dim_att=self._att_dim, dim_o=self._rnn_decoder_dim,
                dropuout=droupout, vacab_size=self._targ_voc)
            # Every batch entry unrolls for the full label length,
            # shape [batch]: each entry equals the input label length.
            time_steps = tf.shape(embedded_targets)[1]
            seq_lengths = tf.tile(tf.expand_dims(time_steps, 0), [self._batch_size])
            train_out, _ = tf.nn.dynamic_rnn(
                attention_cell, embedded_targets,
                initial_state=attention_cell.initial_state(),
                sequence_length=seq_lengths, dtype=tf.float32, swap_memory=True)
        # --- Inference branch: reuse the weights created above. ---
        with tf.variable_scope(self._name, reuse=True):
            recurrent_core = LNGRUCell(name='DecoderGru', num_units=self._rnn_decoder_dim)
            attention_cell = AttCell(
                name='AttCell', attention_in=encoder_out, decoder_cell=recurrent_core,
                n_hid=self._rnn_decoder_dim, dim_att=self._att_dim, dim_o=self._rnn_decoder_dim,
                dropuout=droupout, vacab_size=self._targ_voc, tiles=self._tiles)
            if self._config.model.decoding == 'beams_search':
                search_cell = BeamSearchDecoderCell(
                    self._embedding_table_traget, attention_cell, self._batch_size,
                    self._start_token, self._id_end, self._config.model.beam_size,
                    self._config.model.div_gamma, self._config.model.div_prob)
            else:
                search_cell = GreedyDecoderCell(
                    self._embedding_table_traget, attention_cell, self._batch_size,
                    self._start_token, self._id_end)
            validate_out, _ = dynamic_decode(
                search_cell, self._config.model.MaxPredictLength + 1)
        return train_out, validate_out