-
Notifications
You must be signed in to change notification settings - Fork 75
Expand file tree
/
Copy pathdecoder_dynamic.py
More file actions
90 lines (70 loc) · 2.99 KB
/
decoder_dynamic.py
File metadata and controls
90 lines (70 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
'''
Filename: decoder_dynamic.py
Project: tflib
File Created: Thursday, 29th November 2018 7:35:20 pm
Author: xiaofeng (sxf1052566766@163.com)
--------------------------
Last Modified: Thursday, 29th November 2018 7:39:08 pm
Modified By: xiaofeng (sxf1052566766@163.com)
---------------------------
: 2018.06 - 2018 .
'''
import tensorflow as tf
from tensorflow.python.util import nest
def transpose_batch_time(t):
    """Swap the first two axes of ``t`` (batch-major <-> time-major).

    Args:
        t: a tensor whose rank is statically known and at least 2.

    Returns:
        ``t`` with axes 0 and 1 exchanged; all remaining axes keep their
        original order.

    Raises:
        NotImplementedError: if the static rank of ``t`` is unknown or is
            smaller than 2.
    """
    ndims = t.shape.ndims
    if ndims is None or ndims < 2:
        raise NotImplementedError(
            "transpose_batch_time expects a tensor of known rank >= 2, "
            "got rank {}".format(ndims))
    # Only axes 0 and 1 move, so the permutation generalizes to any rank
    # (previously only ranks 2-4 were supported).
    perm = [1, 0] + list(range(2, ndims))
    return tf.transpose(t, perm)
def dynamic_decode(decoder_cell, maximum_iterations):
    """Run ``decoder_cell`` step by step inside a ``tf.while_loop``.

    Similar to ``dynamic_rnn`` but for decoding: the loop keeps calling
    ``decoder_cell.step`` until every entry of ``finished`` is True (or the
    iteration cap is hit), collecting each step's output in TensorArrays.

    Args:
        decoder_cell: object providing ``output_dtype`` (nested structure of
            dtypes), ``initialize()`` -> (state, inputs, finished),
            ``step(time, state, inputs, finished)`` ->
            (output, state, inputs, finished) and
            ``finalize(outputs, state)``.
        maximum_iterations: int or integer scalar tensor capping the loop, or
            ``None`` to decode until ``decoder_cell`` reports all entries
            finished.

    Returns:
        (final_outputs, final_state): ``final_outputs`` is the structure
        returned by ``decoder_cell.finalize`` with ``transpose_batch_time``
        applied to every element (the stacked TensorArrays put the step index
        on axis 0); ``final_state`` is the state after the last step.

    Raises:
        TypeError: if ``maximum_iterations`` is not of an integer dtype.
    """
    # Normalize the iteration cap. ``None`` now means "no cap"; previously the
    # ValueError from convert_to_tensor(None) was swallowed and graph building
    # failed later inside the loop body.
    if maximum_iterations is not None:
        maximum_iterations = tf.convert_to_tensor(
            maximum_iterations, name="maximum_iterations")
        if not maximum_iterations.dtype.is_integer:
            raise TypeError(
                "maximum_iterations must be an integer, got dtype {}".format(
                    maximum_iterations.dtype))
        # `time` below is int32, so the comparison needs a matching dtype.
        maximum_iterations = tf.cast(maximum_iterations, tf.int32)

    # One dynamically-sized TensorArray per leaf of decoder_cell.output_dtype,
    # mirroring the structure of the cell's per-step output.
    def create_ta(dtype):
        return tf.TensorArray(dtype=dtype, size=0, dynamic_size=True)

    initial_time = tf.constant(0, dtype=tf.int32)
    initial_outputs_ta = nest.map_structure(create_ta,
                                            decoder_cell.output_dtype)
    initial_state, initial_inputs, initial_finished = decoder_cell.initialize()

    def condition(time, unused_outputs_ta, unused_state, unused_inputs,
                  finished):
        # Continue while at least one entry is still unfinished.
        return tf.logical_not(tf.reduce_all(finished))

    def body(time, outputs_ta, state, inputs, finished):
        # Per the original author, new_output holds logits/ids/parents
        # (presumably a beam-search cell) -- confirm against the cell in use.
        new_output, new_state, new_inputs, new_finished = decoder_cell.step(
            time, state, inputs, finished)
        outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                        outputs_ta, new_output)
        if maximum_iterations is not None:
            # Force every entry to finished once `time` reaches the cap. The
            # check happens after step `time` is written, so up to
            # maximum_iterations + 1 steps (time 0..maximum_iterations) are
            # produced. NOTE(review): tf.contrib.seq2seq.dynamic_decode stops
            # after exactly maximum_iterations steps; kept as-is so output
            # lengths do not change for existing callers.
            new_finished = tf.logical_or(
                tf.greater_equal(time, maximum_iterations),
                new_finished)
        return (time + 1, outputs_ta, new_state, new_inputs, new_finished)

    with tf.variable_scope("rnn"):
        res = tf.while_loop(
            condition,
            body,
            loop_vars=[initial_time, initial_outputs_ta, initial_state,
                       initial_inputs, initial_finished],
            back_prop=False)

    # res mirrors loop_vars: (time, outputs_ta, state, inputs, finished).
    final_outputs_ta, final_state = res[1], res[2]
    # Stack each TensorArray into a single tensor (step index on axis 0).
    final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)
    # Let the cell post-process the stacked outputs (e.g. backtracking).
    final_outputs = decoder_cell.finalize(final_outputs, final_state)
    # Swap axes 0 and 1 of every output tensor.
    final_outputs = nest.map_structure(transpose_batch_time, final_outputs)
    return final_outputs, final_state