|
5 | 5 |
|
6 | 6 | def sample_autoregressive(inputs, |
7 | 7 | model, |
8 | | - stop_at_token=50256, |
9 | 8 | max_steps=None, |
10 | 9 | temperature=0.9, |
11 | | - padding_id = 0, |
12 | | - min_start_pos = None, |
13 | 10 | variable_dtype=mtf.VariableDType(tf.float32), |
14 | 11 | has_partial_sequences=True, |
15 | | - remove_partial_sequences=False, |
16 | 12 | sampling_keep_top_k=-1, |
17 | 13 | ): |
18 | 14 | """Sample randomly one token at a time. |
@@ -87,25 +83,10 @@ def sample_autoregressive(inputs, |
87 | 83 | if not has_partial_sequences: |
88 | 84 | partial_sequences_eos_count = 0 |
89 | 85 |
|
90 | | - if stop_at_token is not None: |
91 | | - partial_sequences_eos_count = mtf.reduce_sum( |
92 | | - mtf.to_int32(mtf.equal(inputs, stop_at_token)), |
93 | | - reduced_dim=length_dim) |
94 | | - |
95 | 86 | def cond_fn(position, ids, *unused_states): |
96 | 87 | """Should we run another loop iteration?""" |
97 | 88 | past_end = mtf.greater_equal(position, length_dim.size) |
98 | | - if max_steps: |
99 | | - past_end = mtf.logical_or( |
100 | | - past_end, mtf.greater_equal(position - initial_position, max_steps)) |
101 | | - |
102 | 89 | is_done = past_end |
103 | | - if stop_at_token is not None: |
104 | | - eos_count = mtf.reduce_sum( |
105 | | - mtf.to_int32(mtf.equal(ids, stop_at_token)), |
106 | | - reduced_dim=length_dim) |
107 | | - has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) |
108 | | - is_done = mtf.logical_or(is_done, has_additional_eos) |
109 | 90 | all_done = mtf.reduce_all(is_done) |
110 | 91 | return mtf.logical_not(all_done) |
111 | 92 |
|
@@ -169,11 +150,4 @@ def body_fn(position, ids, *states): |
169 | 150 | final_position, outputs = mtf.while_loop( |
170 | 151 | cond_fn, body_fn, while_loop_inputs)[:2] |
171 | 152 | del final_position |
172 | | - if has_partial_sequences and remove_partial_sequences: |
173 | | - # Remove partial sequences from outputs |
174 | | - partial_length = mtf.reduce_sum( |
175 | | - mtf.to_int32(mtf.not_equal(inputs, padding_id)), |
176 | | - reduced_dim=length_dim) |
177 | | - outputs = mtf.dynamic_shift( |
178 | | - outputs, -partial_length, length_dim, wrap=False) |
179 | 153 | return outputs |
0 commit comments