Skip to content

Commit ec68753

Browse files
committed
more cleanup
1 parent 10dc024 commit ec68753

3 files changed

Lines changed: 2 additions & 35 deletions

File tree

src/dalle_mtf/sample.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,10 @@
55

66
def sample_autoregressive(inputs,
77
model,
8-
stop_at_token=50256,
98
max_steps=None,
109
temperature=0.9,
11-
padding_id = 0,
12-
min_start_pos = None,
1310
variable_dtype=mtf.VariableDType(tf.float32),
1411
has_partial_sequences=True,
15-
remove_partial_sequences=False,
1612
sampling_keep_top_k=-1,
1713
):
1814
"""Sample randomly one token at a time.
@@ -87,25 +83,10 @@ def sample_autoregressive(inputs,
8783
if not has_partial_sequences:
8884
partial_sequences_eos_count = 0
8985

90-
if stop_at_token is not None:
91-
partial_sequences_eos_count = mtf.reduce_sum(
92-
mtf.to_int32(mtf.equal(inputs, stop_at_token)),
93-
reduced_dim=length_dim)
94-
9586
def cond_fn(position, ids, *unused_states):
9687
"""Should we run another loop iteration?"""
9788
past_end = mtf.greater_equal(position, length_dim.size)
98-
if max_steps:
99-
past_end = mtf.logical_or(
100-
past_end, mtf.greater_equal(position - initial_position, max_steps))
101-
10289
is_done = past_end
103-
if stop_at_token is not None:
104-
eos_count = mtf.reduce_sum(
105-
mtf.to_int32(mtf.equal(ids, stop_at_token)),
106-
reduced_dim=length_dim)
107-
has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
108-
is_done = mtf.logical_or(is_done, has_additional_eos)
10990
all_done = mtf.reduce_all(is_done)
11091
return mtf.logical_not(all_done)
11192

@@ -169,11 +150,4 @@ def body_fn(position, ids, *states):
169150
final_position, outputs = mtf.while_loop(
170151
cond_fn, body_fn, while_loop_inputs)[:2]
171152
del final_position
172-
if has_partial_sequences and remove_partial_sequences:
173-
# Remove partial sequences from outputs
174-
partial_length = mtf.reduce_sum(
175-
mtf.to_int32(mtf.not_equal(inputs, padding_id)),
176-
reduced_dim=length_dim)
177-
outputs = mtf.dynamic_shift(
178-
outputs, -partial_length, length_dim, wrap=False)
179153
return outputs

src/model_fns.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,10 @@ def dalle_model_fn(features, labels, mode, params):
152152

153153
mtf_samples = sample_autoregressive(inputs,
154154
model,
155-
max_steps=model.total_seq_dim, # will always run until the full image is produced
156-
stop_at_token=None,
157155
temperature=0.9,
158-
padding_id = 0,
159156
variable_dtype=model.variable_dtype,
160157
has_partial_sequences=True,
161-
remove_partial_sequences=True,
162-
sampling_keep_top_k=-1,
158+
sampling_keep_top_k=-2,
163159
)
164160

165161
mtf_samples = mtf.anonymize(mtf_samples)

test.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,7 @@ def test_sampling():
7373
inputs,
7474
model,
7575
variable_dtype=mtf.VariableDType(),
76-
max_steps = sequence_dim.size,
77-
remove_partial_sequences=False,
78-
stop_at_token=None,
79-
min_start_pos=model.text_seq_len
76+
max_steps = sequence_dim.size
8077
)
8178

8279
mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])

0 commit comments

Comments
 (0)