From 99d11e1761e2eb6a19432a8e5b0e4439cd75c5a1 Mon Sep 17 00:00:00 2001 From: sid Date: Wed, 13 Jan 2021 23:18:26 +0100 Subject: [PATCH 01/33] add basic sampling code --- src/dalle_mtf/__init__.py | 3 +- src/dalle_mtf/models.py | 22 +---- src/dalle_mtf/sample.py | 176 ++++++++++++++++++++++++++++++++++++++ src/model_fns.py | 56 ++++++++++-- src/model_fns_tf.py | 3 +- 5 files changed, 232 insertions(+), 28 deletions(-) create mode 100644 src/dalle_mtf/sample.py diff --git a/src/dalle_mtf/__init__.py b/src/dalle_mtf/__init__.py index a53a710..d9d604b 100644 --- a/src/dalle_mtf/__init__.py +++ b/src/dalle_mtf/__init__.py @@ -1 +1,2 @@ -from .models import DALLE, DiscreteVAE \ No newline at end of file +from .models import DALLE, DiscreteVAE +from .sample import sample_autoregressive \ No newline at end of file diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 7bc7474..59b83be 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -179,6 +179,7 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq self.activation_fn = activation_fn if self.is_incremental_inference: assert self.context is not None, "must have context in incremental inference" + assert self.context['mode'] == 'incremental' if params is None: # extra params params = {} self.params = defaultdict(lambda: None, params) @@ -254,25 +255,7 @@ def attention(self, x, n_state, mask, attention_type="global", name="attn"): self.context.record_new_states([k, v]) with tf.variable_scope("attention"): - if attention_type == "local": - # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights. 
- radius = self.params.get("local_attention_radius", 256) - if self.is_incremental_inference: - q *= one_hot - a = mtf_transformer.attention.local_attention_1d( - q, k, v, - length_dim=k.shape[1], - key_dim=self.dimensions["kv_dim"], - value_dim=self.dimensions["kv_dim"], - radius=radius, - length_dim_num_splits=1, - fully_autoregressive=True, - attention_kwargs={}, - ) - if self.is_incremental_inference: - a = mtf.gather(a, self.context.position - 1, seq_dim) - - elif attention_type == "global": + if attention_type == "global": if exists(mask): if not self.is_incremental_inference: broadcasted_mask = mtf.broadcast(mask, @@ -402,6 +385,7 @@ def forward(self, features, return_loss=True, return_logits=False): out = self.transformer(tokens, mask=mask) logits = self.to_logits(out) if not return_loss: + logits = mtf.cast(logits, self.variable_dtype.master_dtype) return logits labels = pad(inputs, [0, 1], dim_name="total_seq_dim", pad_value=self.eos_token_id) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py new file mode 100644 index 0000000..e184ac8 --- /dev/null +++ b/src/dalle_mtf/sample.py @@ -0,0 +1,176 @@ +import mesh_tensorflow as mtf +import tensorflow.compat.v1 as tf +import mesh_tensorflow.transformer as mtf_transformer + + +def sample_autoregressive(inputs, + model, + params, + stop_at_token=50256, + max_steps=None, + temperature=0.9, + variable_dtype=mtf.VariableDType(tf.float32), + has_partial_sequences=True, + remove_partial_sequences=False, + sampling_keep_top_k=-1, + ): + """Sample randomly one token at a time. + + The partial_sequences represent partial sequences to be continued. The + first tokens of each sequence are nonzero representing the given partial + sequences and the last tokens of each sequence are zeros, representing what + needs to be filled in. 
+ + If there are no partial sequences (you want to sample from the beginning), + then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and + has_partial_sequences=False (so we can skip computation). + + Args: + inputs: an int32 Tensor with shape [, length_dim], + model: DALL-E model + params: model parameters. + stop_at_token: an optional integer eos id. Stop when we produce it. + max_steps: an optional integer, the max number of steps to decode. + temperature: an optional floating point value between 0.0 and 1.0. 0.0 + means argmax, 1.0 means sample according to predicted distribution. + variable_dtype: a mtf.VariableDType + has_partial_sequences: a boolean + decoding, one per each input layer + the embedding layer + remove_partial_sequences: a boolean - whether to remove the partial + sequences from the output + sampling_keep_top_k: an integer - if not -1, only sample from the top k + logits. + + Returns: + a Tensor with shape [, length_dim] + """ + + # with dalle, inputs will be a text sequence of len 256, then the rest image tokens. 
+ # the parts we want to fill in will be <|pad_token|>, which we should assign in the input + + batch_dims = inputs.shape.dims[:-1] + length_dim = inputs.shape.dims[-1] + padding_id = params.get("padding_id", 0) + + initial_position = mtf.reduce_sum( + mtf.to_int32(mtf.not_equal(inputs, padding_id)), + reduced_dim=length_dim) # Gets position where zero padding starts + + length_range = mtf.range(inputs.mesh, length_dim, tf.int32) + + # Builds context to pass around internally + # The 'first part' context records initial states of k / v / x + + context_first_part = mtf_transformer.transformer.Context( + model=None, + mesh=inputs.mesh, + batch_dims=batch_dims, + length_dim=length_dim, + variable_dtype=variable_dtype, + mode="first_part", + position=length_range, + position_is_default=True, + new_states=[], + initial_position=initial_position, + sequence_id=None, + constant_states=[], + inputs=inputs) + model.context = context_first_part + + with tf.variable_scope('dall-e'): + logits = model.forward({'tokens': inputs}, return_loss=False, return_logits=True) + del logits + + if not has_partial_sequences: + initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states] + else: + initial_states = context_first_part.new_states + + if not has_partial_sequences: + partial_sequences_eos_count = 0 + + if stop_at_token is not None: + partial_sequences_eos_count = mtf.reduce_sum( + mtf.to_int32(mtf.equal(inputs, stop_at_token)), + reduced_dim=length_dim) + + def cond_fn(position, ids, *unused_states): + """Should we run another loop iteration?""" + past_end = mtf.greater_equal(position, length_dim.size) + if max_steps: + past_end = mtf.logical_or( + past_end, mtf.greater_equal(position - initial_position, max_steps)) + + is_done = past_end + if stop_at_token is not None: + eos_count = mtf.reduce_sum( + mtf.to_int32(mtf.equal(ids, stop_at_token)), + reduced_dim=length_dim) + has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) + is_done = 
mtf.logical_or(is_done, has_additional_eos) + all_done = mtf.reduce_all(is_done) + return mtf.logical_not(all_done) + + def body_fn(position, ids, *states): + """One step in the decode loop.""" + nonlocal sampling_keep_top_k + + context = mtf_transformer.transformer.Context( + model=None, + mesh=inputs.mesh, + batch_dims=batch_dims, + length_dim=length_dim, + variable_dtype=variable_dtype, + mode="incremental", + position=position, + position_is_default=True, + states=states, + new_states=[], + initial_position=position, + sequence_id=None, + inputs=ids) + + model.is_incremental_inference = True + model.context = context + with tf.variable_scope("dall-e", reuse=tf.AUTO_REUSE): + logits = model.forward({'tokens': inputs}, return_loss=False, return_logits=True) + + # By default, do top_k sampling of 0.9 + if sampling_keep_top_k == -2: + sampling_keep_top_k = int(logits.shape[-1].size * 0.1) + + if sampling_keep_top_k != -1: + if sampling_keep_top_k <= 0: + raise ValueError("sampling_keep_top_k must either be -1 or positive.") + k_largest = mtf.nth_largest_element( + logits, n=sampling_keep_top_k, + reduced_dim=model.dimensions['final_vocab_dim']) + logits = mtf.where(mtf.less_equal(logits, k_largest), + mtf.ones_like(logits) * -1e6, logits) + + # temperature sampling + ids_this_step = mtf.sample_with_temperature( + logits, model.dimensions['final_vocab_dim'], temperature) + + # reshape & assign results + ids_this_step = mtf.reshape(ids_this_step, batch_dims) + one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32) + one_new_id = ids_this_step * one_hot + new_ids = (1 - one_hot) * ids + one_new_id + new_position = position + 1 + ret = [new_position, new_ids] + ret += context.new_states + return ret + + while_loop_inputs = [initial_position, inputs] + initial_states + final_position, outputs = mtf.while_loop( + cond_fn, body_fn, while_loop_inputs)[:2] + del final_position + if has_partial_sequences and remove_partial_sequences: + # Remove partial sequences from 
outputs + partial_length = mtf.reduce_sum( + mtf.to_int32(mtf.not_equal(inputs, padding_id)), + reduced_dim=length_dim) + outputs = mtf.dynamic_shift( + outputs, -partial_length, length_dim, wrap=False) + return outputs diff --git a/src/model_fns.py b/src/model_fns.py index f89ebb8..2c38147 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -4,8 +4,9 @@ import mesh_tensorflow.transformer as mtf_transformer from .optimizers import get_optimizer from .utils import mode_to_str, get_graph_info, create_host_call, simd_mesh_setup, scalar_summary -from .dalle_mtf import DALLE +from .dalle_mtf import DALLE, sample_autoregressive from .vae_tf import DiscreteVAE +from tensorflow.python.ops import resources def initialize_vae_weights(checkpoint_path, scope="vae"): @@ -16,7 +17,7 @@ def initialize_vae_weights(checkpoint_path, scope="vae"): vars_to_restore = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, scope=scope) ckpt_vars = [ - name for name, _ in tf.train.list_variables(checkpoint_path)] + name for name, _ in tf.train.list_variables(checkpoint_path)] tf.logging.info(f"RESTORING {len(vars_to_restore)} VAE VARS FROM CHECKPOINT: ") tf.logging.info(f"CHECKPOINT PATH: {checkpoint_path}") tf.logging.info(f"CHECKPOINT VARS:") @@ -132,7 +133,48 @@ def dalle_model_fn(features, labels, mode, params): scalar_summary("input_image", mtf_features["image_inputs"]) if mode == tf.estimator.ModeKeys.PREDICT: - raise NotImplementedError + # Set up the model for prediction + inputs = mtf_features["tokens"] + + mtf_samples = sample_autoregressive(inputs, + model, + params, + stop_at_token=model.eos_token_id, + max_steps=None, + temperature=0.9, + variable_dtype=model.variable_dtype, + has_partial_sequences=True, + remove_partial_sequences=True, + sampling_keep_top_k=-1, + ) + + mtf_samples = mtf.anonymize(mtf_samples) + inputs = mtf.anonymize(inputs) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True) + inputs = lowering.export_to_tf_tensor(inputs) + outputs = 
lowering.export_to_tf_tensor(mtf_samples) + # predictions_decoded = vae.decode(outputs) + predictions = { + "inputs": inputs, + "outputs": outputs} + + def scaffold_fn(): + return tf.train.Scaffold( + local_init_op=tf.group( + tf.train.Scaffold.default_local_init_op(), + lowering.copy_masters_to_slices(), + name="mtf_local_init_op"), + ready_op=tf.concat( + [tf.report_uninitialized_variables(), + resources.report_uninitialized_resources()], + axis=0, + name="mtf_ready_op")) + + return tpu_estimator.TPUEstimatorSpec( + mode=tf.estimator.ModeKeys.PREDICT, + predictions=predictions, + scaffold_fn=scaffold_fn, + prediction_hooks=[mtf.MtfRestoreHook(lowering)]) # We're not predicting, so we better be training or evaluating assert (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL) @@ -155,8 +197,9 @@ def dalle_model_fn(features, labels, mode, params): if num_microbatches > 1: # For serialize_training_step we need to modify the model to output results in a dict def serialized_fn(mtf_features): - loss, loss_batch = model.forward(mtf_features, return_loss=True) - return {"loss": loss, "loss_batch": loss_batch} + with tf.variable_scope('dall-e'): + loss, loss_batch = model.forward(mtf_features, return_loss=True) + return {"loss": loss, "loss_batch": loss_batch} # Serialize the training step - Gradients are accumulated locally and reduced once. 
var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, model.dimensions["batch_dim"], @@ -164,7 +207,8 @@ def serialized_fn(mtf_features): loss = output_dict["loss"] loss_batch = output_dict["loss_batch"] else: - loss, loss_batch = model.forward(mtf_features, return_loss=True) + with tf.variable_scope('dall-e'): + loss, loss_batch = model.forward(mtf_features, return_loss=True) del loss_batch # TODO: may need this for some metrics - otherwise, remove from output diff --git a/src/model_fns_tf.py b/src/model_fns_tf.py index 784a702..937dbdb 100644 --- a/src/model_fns_tf.py +++ b/src/model_fns_tf.py @@ -1,9 +1,8 @@ import tensorflow.compat.v1 as tf import tensorflow.compat.v2 as tf2 from tensorflow.python.tpu import tpu_estimator -from .optimizers import get_optimizer from .vae_tf import DiscreteVAE -from .utils import scalar_summary, mode_to_str, create_host_call +from .utils import mode_to_str def vae_model_fn(features, labels, mode, params): From b9fd039cc6639feb4348a14566fe3652942ebc23 Mon Sep 17 00:00:00 2001 From: sid Date: Thu, 14 Jan 2021 00:15:49 +0100 Subject: [PATCH 02/33] add prediction input / output fns --- src/input_fns.py | 29 +++++++++++++++++- src/model_fns.py | 75 ++++++++++++++++++++++++++-------------------- src/utils/utils.py | 6 +++- train_dalle.py | 13 +++++++- 4 files changed, 88 insertions(+), 35 deletions(-) diff --git a/src/input_fns.py b/src/input_fns.py index ee35bfc..d503647 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -38,6 +38,31 @@ def truncate_or_pad_label(label, params): return label +def pred_input(params, tokenizer, prompt='a cat in a hat'): + tokens = tokenizer.encode(prompt).ids + if len(tokens) > params["total_seq_len"]: + tf.logging.info("The length of your input prompt is longer than the model's text context length - truncating " + "input.") + tokens = tokens[len(tokens) - params["total_seq_len"]:] # TODO: left or right truncate here? 
+ if len(tokens) < params["total_seq_len"]: + tokens = tf.pad(tokens, [[0, params["total_seq_len"] - len(tokens)]], constant_values=params["padding_id"]) + t = tf.broadcast_to(tokens, [params["batch_size"], params["total_seq_len"]]) + dataset = tf.data.Dataset.from_tensors(t) + + def _dummy_labels(x): + return x, x + + dataset = dataset.map(_dummy_labels) + return dataset + + +def pred_output(predictions, out_name='test'): + with tf.gfile.Open(f"{out_name}.txt", "w") as f: + for i, p in enumerate(predictions): + p = p["outputs"] + f.write(str(p["outputs"])) + + def read_labeled_tfrecord(params): def read_fn(example): features = { @@ -103,6 +128,7 @@ def _process_path(file_path): dataset = configure_for_performance(dataset, params, eval) return dataset.repeat() + def dalle_input_fn(params, eval=False): path = params["dataset"]["train_path"] if not eval else params["dataset"]["eval_path"] files = tf.io.gfile.glob(path) @@ -113,7 +139,8 @@ def dalle_input_fn(params, eval=False): if not eval: dataset = dataset.shuffle(file_count, reshuffle_each_iteration=False) - dataset = dataset.apply(tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False)) + dataset = dataset.apply( + tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False)) parse_fn = read_labeled_tfrecord(params) dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) dataset = configure_for_performance(dataset, params, eval) diff --git a/src/model_fns.py b/src/model_fns.py index 53a8975..776b09a 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -64,18 +64,17 @@ def dalle_model_fn(features, labels, mode, params): vae, vae_checkpoint_path = load_vae_model(params, mode_str) initialize_vae_weights(vae_checkpoint_path) - H = W = params["dataset"]["image_size"] - image_seq_len = (vae.H // (2 ** len(vae.convblocks))) ** 2 // (vae.stack_factor ** 2) # TODO: check this is correct batch_size = 
params[f"{mode_str}_batch_size"] n_channels = params.get("input_channels", 3) + if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: - with tf.variable_scope("vae"): - vae_logits = vae.forward(features, return_logits=True) + with tf.variable_scope("vae"): + vae_logits = vae.forward(features, return_logits=True) - # TODO: using argmax sampling for now, but is that optimal? - tokens = tf.math.argmax(vae_logits, -1) - img_tokens_reshaped = tf.cast(tf.reshape(tokens, (batch_size, image_seq_len)), tf.int32) + # TODO: using argmax sampling for now, but is that optimal? + tokens = tf.math.argmax(vae_logits, -1) + img_tokens_reshaped = tf.cast(tf.reshape(tokens, (batch_size, params['image_seq_len'])), tf.int32) # Construct mtf graph + mesh from params graph = mtf.Graph() @@ -99,7 +98,7 @@ def dalle_model_fn(features, labels, mode, params): text_vocab_size=params["text_vocab_size"], image_vocab_size=params["image_vocab_size"], text_seq_len=params["text_seq_len"], - image_seq_len=image_seq_len, + image_seq_len=params['image_seq_len'], n_layers=params["n_layers"], n_heads=params["n_heads"], batch_size=batch_size, @@ -110,29 +109,41 @@ def dalle_model_fn(features, labels, mode, params): # Build mtf_features & seq length dict for getting number of microbatches # We need to pack inputs into a dict to pass into serialize_training_step - features_dict = {"image_inputs": features, - "text_inputs": labels} - mtf_features = {} - for key, x in features_dict.items(): - if x is not None: - if key == "text_inputs": - text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) - x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) - mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) - - mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) - - if key == "image_inputs": - mtf_shape = mtf.Shape([ - model.dimensions["batch_dim"], - mtf.Dimension("img_height_dim", vae.H), - 
mtf.Dimension("img_width_dim", vae.W), - mtf.Dimension("img_channel_dim", vae.num_ch), - ]) - x = tf.reshape(x, [batch_size, H, W, n_channels]) # NHWC - mtf_features["image_inputs"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) - - scalar_summary("input_image", mtf_features["image_inputs"]) + if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: + features_dict = {"image_inputs": features, + "text_inputs": labels} + mtf_features = {} + for key, x in features_dict.items(): + if x is not None: + if key == "text_inputs": + text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) + x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) + mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) + + mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + + if key == "image_inputs": + mtf_shape = mtf.Shape([ + model.dimensions["batch_dim"], + mtf.Dimension("img_height_dim", vae.H), + mtf.Dimension("img_width_dim", vae.W), + mtf.Dimension("img_channel_dim", vae.num_ch), + ]) + x = tf.reshape(x, [batch_size, H, W, n_channels]) # NHWC + mtf_features["image_inputs"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + scalar_summary("input_image", mtf_features["image_inputs"]) + else: + features_dict = {"text_inputs": labels} + mtf_features = {} + for key, x in features_dict.items(): + if x is not None: + if key == "text_inputs": + text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) + x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) + mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) + + mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + if mode == tf.estimator.ModeKeys.PREDICT: # Set up the model for prediction inputs = mtf_features["tokens"] @@ -151,7 +162,7 @@ def dalle_model_fn(features, labels, mode, params): 
mtf_samples = mtf.anonymize(mtf_samples) inputs = mtf.anonymize(inputs) - lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False) inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) # predictions_decoded = vae.decode(outputs) diff --git a/src/utils/utils.py b/src/utils/utils.py index 45178d6..95e8751 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -224,4 +224,8 @@ def scalar_summary(name, x): Returns: a Tensor which is identical in value to x """ - return ScalarSummaryOperation(name, x) \ No newline at end of file + return ScalarSummaryOperation(name, x) + +def get_image_seq_len(dalle_params): + return (dalle_params["vae_params"]['dataset']['image_size'] // (2 ** len(dalle_params["vae_params"]['convblocks']))) ** 2 // ( + dalle_params.get("vae_params").get("stack_factor", 1) ** 2) \ No newline at end of file diff --git a/train_dalle.py b/train_dalle.py index e5c4428..9432604 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -6,7 +6,7 @@ import argparse from src.utils import * from src.model_fns import dalle_model_fn -from src.input_fns import dalle_input_fn +from src.input_fns import dalle_input_fn, pred_input, pred_output from src.data import get_tokenizer def parse_args(): @@ -18,6 +18,8 @@ def parse_args(): parser.add_argument("--model", type=str, default=None, help="JSON file that contains model parameters.") parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and " "starts a new training run") + parser.add_argument('--predict', action='store_true', help='run model in predict mode') + parser.add_argument('--prompt', type=str, default='a cat in a hat') args = parser.parse_args() assert args.model is not None, "Model must be set" return args @@ -46,6 +48,8 @@ def main(): params["gpu_ids"] = args.gpu_ids tokenizer = get_tokenizer(params["tokenizer"]) assert len(tokenizer) 
== params["text_vocab_size"], f"tokenizer vocab size {len(tokenizer)} must equal model vocab size {params['text_vocab_size']}" + params['image_seq_len'] = get_image_seq_len(params) + params['total_seq_len'] = params['image_seq_len'] + params['text_seq_len'] params["padding_id"] = tokenizer.encode(tokenizer.pad_token)[0] # Set up TPUs and Estimator if args.tpu == "colab": @@ -76,6 +80,13 @@ def main(): eval_batch_size=params["eval_batch_size"], predict_batch_size=params["predict_batch_size"], params=params) + if args.predict: + # Predict + pred_input_fn = partial(pred_input, params, tokenizer, args.prompt) + predictions = estimator.predict(input_fn=pred_input_fn) + logging.info("Predictions generated") + pred_output(predictions, 'test') + return has_predict_or_eval_steps = params["predict_steps"] > 0 or params["eval_steps"] > 0 if has_predict_or_eval_steps: From cf76c6c3d5734e9b7b98c2ef338680845571a9d3 Mon Sep 17 00:00:00 2001 From: connor Date: Thu, 14 Jan 2021 00:16:16 +0000 Subject: [PATCH 03/33] get sample_autoregressive working --- configs/dalle_coco.json | 8 ++++---- src/dalle_mtf/models.py | 14 ++++++++++++-- src/input_fns.py | 3 +-- src/model_fns.py | 6 ++---- train_dalle.py | 2 +- 5 files changed, 20 insertions(+), 13 deletions(-) diff --git a/configs/dalle_coco.json b/configs/dalle_coco.json index 7d2c768..d71e6a6 100644 --- a/configs/dalle_coco.json +++ b/configs/dalle_coco.json @@ -7,19 +7,19 @@ }, "train_batch_size": 128, "eval_batch_size": 128, - "predict_batch_size": 128, + "predict_batch_size": 16, "steps_per_checkpoint": 5000, "iterations": 1000, "train_steps": 100000, "predict_steps": 0, "eval_steps": 0, "n_channels": 3, - "bf_16": false, + "bf_16": true, "recompute_grad": true, "lr": 0.0001, - "model_path": "gs://neo-models/dalle_coco/", + "model_path": "gs://neo-models/dalle_coco_sample/", "mesh_shape": "data:16,model:2", - "layout": "batch_dim:data", + "layout": "batch_dim:data,embed_dim:model", "n_embd": 1024, "text_vocab_size": 50258, 
"image_vocab_size": 512, diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 59b83be..f55ea54 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -228,8 +228,13 @@ def get_attn_mask(self, mesh, nd, ns): return self.attn_mask def attention(self, x, n_state, mask, attention_type="global", name="attn"): - # x :: [batch, seq, n_embd] - batch_dim, seq_dim, embd_dim = x_shape = x.shape + if not self.is_incremental_inference: + # x :: [batch, seq, n_embd] + batch_dim, seq_dim, embd_dim = x_shape = x.shape + else: + batch_dim, embd_dim = x_shape = x.shape + seq_dim = self.dimensions['total_seq_dim'] + assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads" with tf.variable_scope(name): # Compute attention inputs @@ -379,6 +384,11 @@ def to_logits(self, x): def forward(self, features, return_loss=True, return_logits=False): inputs = features["tokens"] + if self.is_incremental_inference: + # reshape inputs if in inference mode + inputs = mtf.gather(inputs, self.context.position - 1, self.dimensions['total_seq_dim']) + inputs = mtf.reshape(inputs, [self.dimensions['batch_dim']]) + tokens = self.positional_embedding(self.embedding(inputs, "embedding"), "positional_embedding") mask = self.get_attn_mask(tokens.mesh, tokens.shape[1], self.dimensions["memory_len_dim"]) diff --git a/src/input_fns.py b/src/input_fns.py index d503647..74e6ab8 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -39,7 +39,7 @@ def truncate_or_pad_label(label, params): def pred_input(params, tokenizer, prompt='a cat in a hat'): - tokens = tokenizer.encode(prompt).ids + tokens = tokenizer.encode(prompt) if len(tokens) > params["total_seq_len"]: tf.logging.info("The length of your input prompt is longer than the model's text context length - truncating " "input.") @@ -59,7 +59,6 @@ def _dummy_labels(x): def pred_output(predictions, out_name='test'): with tf.gfile.Open(f"{out_name}.txt", "w") as f: for i, p in enumerate(predictions): - p = 
p["outputs"] f.write(str(p["outputs"])) diff --git a/src/model_fns.py b/src/model_fns.py index 776b09a..4aabd5f 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -138,11 +138,9 @@ def dalle_model_fn(features, labels, mode, params): for key, x in features_dict.items(): if x is not None: if key == "text_inputs": - text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) - x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) + text_tokens = tf.reshape(x, [batch_size, params["total_seq_len"]]) mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) - - mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + mtf_features["tokens"] = mtf.import_fully_replicated(mesh, text_tokens, mtf_shape, name=key) if mode == tf.estimator.ModeKeys.PREDICT: # Set up the model for prediction diff --git a/train_dalle.py b/train_dalle.py index 9432604..6ad8fe6 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -82,7 +82,7 @@ def main(): params=params) if args.predict: # Predict - pred_input_fn = partial(pred_input, params, tokenizer, args.prompt) + pred_input_fn = partial(pred_input, tokenizer=tokenizer, prompt=args.prompt) predictions = estimator.predict(input_fn=pred_input_fn) logging.info("Predictions generated") pred_output(predictions, 'test') From c7ff6c4549bf25a3c5a9721e345f09ea18ecdf09 Mon Sep 17 00:00:00 2001 From: connor Date: Thu, 14 Jan 2021 00:21:22 +0000 Subject: [PATCH 04/33] truncate text tokens properly --- src/input_fns.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/input_fns.py b/src/input_fns.py index 74e6ab8..0862147 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -40,10 +40,10 @@ def truncate_or_pad_label(label, params): def pred_input(params, tokenizer, prompt='a cat in a hat'): tokens = tokenizer.encode(prompt) - if len(tokens) > params["total_seq_len"]: + if len(tokens) > params["text_seq_len"]: tf.logging.info("The 
length of your input prompt is longer than the model's text context length - truncating " "input.") - tokens = tokens[len(tokens) - params["total_seq_len"]:] # TODO: left or right truncate here? + tokens = tokens[len(tokens) - params["text_seq_len"]:] # TODO: left or right truncate here? if len(tokens) < params["total_seq_len"]: tokens = tf.pad(tokens, [[0, params["total_seq_len"] - len(tokens)]], constant_values=params["padding_id"]) t = tf.broadcast_to(tokens, [params["batch_size"], params["total_seq_len"]]) @@ -59,7 +59,7 @@ def _dummy_labels(x): def pred_output(predictions, out_name='test'): with tf.gfile.Open(f"{out_name}.txt", "w") as f: for i, p in enumerate(predictions): - f.write(str(p["outputs"])) + f.write(str(p["outputs"].tolist())) def read_labeled_tfrecord(params): From 4d51fd94bf74ad1424445da12bcfb9ffba0e5d85 Mon Sep 17 00:00:00 2001 From: connor Date: Thu, 14 Jan 2021 00:59:10 +0000 Subject: [PATCH 05/33] log model params to tensorboard --- src/utils/utils.py | 33 ++++++++++++++++++++++++++++++++- train_dalle.py | 1 + train_vae.py | 1 + 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/utils/utils.py b/src/utils/utils.py index 95e8751..e1ea317 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -9,6 +9,8 @@ import logging import sys from mesh_tensorflow.ops import Operation, Tensor +import re + def fetch_model_params(model): model_path = model if model.endswith(".json") else f"./configs/{model}.json" @@ -226,6 +228,35 @@ def scalar_summary(name, x): """ return ScalarSummaryOperation(name, x) + def get_image_seq_len(dalle_params): return (dalle_params["vae_params"]['dataset']['image_size'] // (2 ** len(dalle_params["vae_params"]['convblocks']))) ** 2 // ( - dalle_params.get("vae_params").get("stack_factor", 1) ** 2) \ No newline at end of file + dalle_params.get("vae_params").get("stack_factor", 1) ** 2) + +def save_config(params_dict, logdir): + tf.logging.info(f"Saving config to {logdir}") + text = "{\n\n" + total_params = 
len(params_dict) + for count, key in enumerate(params_dict): + config_value = str(params_dict[key]) + if re.search('[a-zA-Z]', config_value): + if config_value.lower() != 'true': + if config_value.lower() != 'false': + if config_value[0] != '[': + # TODO: Making a manual exception for parsing epsilon right now since it's the only number in + # scientific notation. Should fix this. + if key != "epsilon": + config_value = f'"{config_value}"' + if count == total_params - 1: + text += f'"{str(key)}"' + ' : ' + config_value + '\n\n' + else: + text += f'"{str(key)}"' + ' : ' + config_value + ',\n\n' + text += '\n\n}' + sess = tf.InteractiveSession() + summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text)) + summary_writer = tf.summary.FileWriter(f"{logdir}/config", sess.graph) + text = sess.run(summary_op) + summary_writer.add_summary(text, 0) + summary_writer.flush() + summary_writer.close() + tf.reset_default_graph() \ No newline at end of file diff --git a/train_dalle.py b/train_dalle.py index 6ad8fe6..8ad9a8e 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -31,6 +31,7 @@ def main(): logging = setup_logging(args) params = fetch_model_params(args.model) params["vae_params"] = fetch_model_params(params["vae_model"]) + save_config(params, params['model_dir']) assert params["model_type"].lower() == "dalle", f'model_type {params["model_type"]} not recognized' # Confirm deletion of checkpoint files if --new flag is set diff --git a/train_vae.py b/train_vae.py index 835cbbc..053cc79 100644 --- a/train_vae.py +++ b/train_vae.py @@ -28,6 +28,7 @@ def main(): args = parse_args() logging = setup_logging(args) params = fetch_model_params(args.model) + save_config(params, params['model_dir']) assert params["model_type"].lower() == "vae", f'model_type {params["model_type"]} not recognized' # get current step From 346871f5b2f0dad5019785881f2cbcdf62555a8d Mon Sep 17 00:00:00 2001 From: Ben Wang Date: Thu, 14 Jan 2021 15:49:50 +1100 Subject: [PATCH 06/33] add vae 
decoding and write to jpeg --- src/input_fns.py | 8 +++++--- src/model_fns.py | 9 +++++++-- src/vae_tf/models.py | 22 +++++++++++++++++++--- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/input_fns.py b/src/input_fns.py index 0862147..3559227 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -1,3 +1,5 @@ +import imageio +import numpy as np import tensorflow.compat.v1 as tf @@ -57,9 +59,9 @@ def _dummy_labels(x): def pred_output(predictions, out_name='test'): - with tf.gfile.Open(f"{out_name}.txt", "w") as f: - for i, p in enumerate(predictions): - f.write(str(p["outputs"].tolist())) + for i, p in enumerate(predictions): + denormalize = lambda x: (((x + 1) / 2) * 255.0).astype(np.uint8) + imageio.imwrite(f"{out_name}_{i}.jpeg", denormalize(p["predictions_decoded"])) def read_labeled_tfrecord(params): diff --git a/src/model_fns.py b/src/model_fns.py index 4aabd5f..ef022e6 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -163,10 +163,15 @@ def dalle_model_fn(features, labels, mode, params): lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False) inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) - # predictions_decoded = vae.decode(outputs) + + img_outputs = outputs[:, -model.image_seq_len:] + predictions_decoded = vae.decode(img_outputs) + predictions = { "inputs": inputs, - "outputs": outputs} + "outputs": outputs, + "predictions_decoded": predictions_decoded + } def scaffold_fn(): return tf.train.Scaffold( diff --git a/src/vae_tf/models.py b/src/vae_tf/models.py index d7dd073..6c8e8f6 100644 --- a/src/vae_tf/models.py +++ b/src/vae_tf/models.py @@ -75,6 +75,8 @@ def __init__(self, self.recompute_grad = recompute_grad self.bf16 = use_bf16 + self.n_hid = convblocks[-1][1] + assert math.log2(stack_factor).is_integer() # maybe you don't actually need this? 
self.stack_factor = stack_factor @@ -109,7 +111,6 @@ def encoder_block(x, channels=channels): x = x + res_out with tf.variable_scope(f"codebook"): - self.n_hid = x.shape[-1] embedding = tf.get_variable("codebook", shape=[self.n_hid, self.num_tokens], dtype=tf.float32) if self.bf16: @@ -119,9 +120,8 @@ def encoder_block(x, channels=channels): return output - def decoder(self, x): - with tf.variable_scope(f"codebook", reuse=True): + with tf.variable_scope(f"codebook", reuse=tf.AUTO_REUSE): embedding = tf.get_variable("codebook", shape=[self.n_hid, self.num_tokens], dtype=tf.float32) x = tf.matmul(x, embedding, transpose_b=True) @@ -162,6 +162,22 @@ def decoder_block(x, channels=channels): return x + def decode(self, input_indices): + batch, seqlen = input_indices.shape + + print(f"seqlen {seqlen}") + print(f"side expected {self.W // (2 ** len(self.convblocks))}") + + assert seqlen == (self.W // (2 ** len(self.convblocks))) * (self.H // (2 ** len(self.convblocks))) + + input_onehot = tf.one_hot(input_indices, self.num_tokens) + input_reshaped = tf.reshape(input_onehot, [batch, + self.H // (2 ** len(self.convblocks)), + self.W // (2 ** len(self.convblocks)), + self.num_tokens]) # NHWC + + return self.decoder(input_reshaped) + def forward(self, features, return_recon_loss=False, return_logits=False, hard_gumbel=True, temperature=1.): if isinstance(features, dict): img = features["inputs"] From d13c3309ac2124be2a9d3fe68e4e927e420a5883 Mon Sep 17 00:00:00 2001 From: Ben Wang Date: Thu, 14 Jan 2021 16:38:17 +1100 Subject: [PATCH 07/33] unshift image outputs at decode time --- src/model_fns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model_fns.py b/src/model_fns.py index ef022e6..a51e4dc 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -164,7 +164,7 @@ def dalle_model_fn(features, labels, mode, params): inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) - img_outputs = outputs[:, 
-model.image_seq_len:] + img_outputs = outputs[:, -model.image_seq_len:] - model.text_vocab_size predictions_decoded = vae.decode(img_outputs) predictions = { From f8a744929bb4189cf3efd0378daf027996ce7c74 Mon Sep 17 00:00:00 2001 From: Ben Wang Date: Thu, 14 Jan 2021 19:42:46 +1100 Subject: [PATCH 08/33] dirty hack to use vae decoder params when training dalle --- src/model_fns.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/model_fns.py b/src/model_fns.py index a51e4dc..4214274 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -76,6 +76,10 @@ def dalle_model_fn(features, labels, mode, params): tokens = tf.math.argmax(vae_logits, -1) img_tokens_reshaped = tf.cast(tf.reshape(tokens, (batch_size, params['image_seq_len'])), tf.int32) + # TODO: get rid of this ugly hack, its just to pull the decoder parameters in during training + with tf.variable_scope('vae'): + vae.decoder(tf.zeros_like(vae_logits)) + # Construct mtf graph + mesh from params graph = mtf.Graph() mesh_shape = mtf.convert_to_shape(params["mesh_shape"]) @@ -165,7 +169,8 @@ def dalle_model_fn(features, labels, mode, params): outputs = lowering.export_to_tf_tensor(mtf_samples) img_outputs = outputs[:, -model.image_seq_len:] - model.text_vocab_size - predictions_decoded = vae.decode(img_outputs) + with tf.variable_scope('vae'): + predictions_decoded = vae.decode(img_outputs) predictions = { "inputs": inputs, From ff56d1265482f65afd57df25b9bb1dbdbccab373 Mon Sep 17 00:00:00 2001 From: Leo Gao <54557097+leogao2@users.noreply.github.com> Date: Sat, 16 Jan 2021 19:13:35 -0700 Subject: [PATCH 09/33] Move initialize_vae_weights to after lowering --- src/model_fns.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/model_fns.py b/src/model_fns.py index 4214274..b9731bd 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -63,7 +63,6 @@ def dalle_model_fn(features, labels, mode, params): # load vae in tensorflow graph before mtf vae, vae_checkpoint_path 
= load_vae_model(params, mode_str) - initialize_vae_weights(vae_checkpoint_path) H = W = params["dataset"]["image_size"] batch_size = params[f"{mode_str}_batch_size"] n_channels = params.get("input_channels", 3) @@ -168,6 +167,8 @@ def dalle_model_fn(features, labels, mode, params): inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) + initialize_vae_weights(vae_checkpoint_path) + img_outputs = outputs[:, -model.image_seq_len:] - model.text_vocab_size with tf.variable_scope('vae'): predictions_decoded = vae.decode(img_outputs) From 4c4e0e06fa211036cf4b41d4942cba467b07a529 Mon Sep 17 00:00:00 2001 From: connor Date: Sun, 17 Jan 2021 14:13:55 +0000 Subject: [PATCH 10/33] fix vae checkpoint load in training --- src/input_fns.py | 3 +++ src/model_fns.py | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/input_fns.py b/src/input_fns.py index 3559227..b07098a 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -61,6 +61,9 @@ def _dummy_labels(x): def pred_output(predictions, out_name='test'): for i, p in enumerate(predictions): denormalize = lambda x: (((x + 1) / 2) * 255.0).astype(np.uint8) + # to debug: + # with open(f"{out_name}_{i}.txt", 'w') as f: + # f.write(str(p["outputs"].tolist())) imageio.imwrite(f"{out_name}_{i}.jpeg", denormalize(p["predictions_decoded"])) diff --git a/src/model_fns.py b/src/model_fns.py index b9731bd..52cf870 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -134,7 +134,8 @@ def dalle_model_fn(features, labels, mode, params): ]) x = tf.reshape(x, [batch_size, H, W, n_channels]) # NHWC mtf_features["image_inputs"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) - scalar_summary("input_image", mtf_features["image_inputs"]) + denormalize = lambda x: (x + 1) / 2 + scalar_summary("input_image", denormalize(mtf_features["image_inputs"])) else: features_dict = {"text_inputs": labels} mtf_features = {} @@ -163,19 +164,21 @@ def 
dalle_model_fn(features, labels, mode, params): mtf_samples = mtf.anonymize(mtf_samples) inputs = mtf.anonymize(inputs) - lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=params.get('autostack', True)) + inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) initialize_vae_weights(vae_checkpoint_path) - + img_outputs = outputs[:, -model.image_seq_len:] - model.text_vocab_size + with tf.variable_scope('vae'): predictions_decoded = vae.decode(img_outputs) predictions = { "inputs": inputs, - "outputs": outputs, + "outputs": img_outputs, "predictions_decoded": predictions_decoded } @@ -250,11 +253,12 @@ def serialized_fn(mtf_features): get_graph_info(graph) # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors - lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=params.get('autostack', True)) tf_loss = lowering.export_to_tf_tensor(loss) tf_loss = tf.cast(tf_loss, tf.float32) + if mode == tf.estimator.ModeKeys.TRAIN: # Use our patched version until mtf updates theirs host_call = create_host_call(params['model_path']) @@ -264,8 +268,12 @@ def serialized_fn(mtf_features): tf_update_ops = [lowering.lowered_operation(op) for op in update_ops] tf_update_ops.append(tf.assign_add(global_step, 1)) # Need to manually increment global_step train_op = tf.group(tf_update_ops) + with mtf.utils.outside_all_rewrites(): + # only *now* can we initialize vae weights (stupid tensorflow) + initialize_vae_weights(vae_checkpoint_path) + # Copy master variables to slices. Must be called first. 
restore_hook = mtf.MtfRestoreHook(lowering) if mode == tf.estimator.ModeKeys.TRAIN: From 2c14bde296835ae7c16b4dfca6da712de9fb4c84 Mon Sep 17 00:00:00 2001 From: connor Date: Mon, 18 Jan 2021 19:45:25 +0000 Subject: [PATCH 11/33] fix parameter count logging --- src/utils/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/utils/utils.py b/src/utils/utils.py index e1ea317..70869d5 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -68,7 +68,7 @@ def get_n_trainable_vars(graph): for dim in shape: variable_parameters *= dim.size total_parameters += variable_parameters - print(f"\n\nN PARAMS:\n{total_parameters:,}\n\n") + tf.logging.info(f"\n\nN PARAMS:\n{total_parameters:,}\n\n") def print_dim_names(graph): @@ -85,10 +85,10 @@ def print_dim_names(graph): # Print all dim names in graph & write to file all_dim_names = [item for sublist in all_dim_names for item in sublist] # Flatten all dims unique_dims = list(set(all_dim_names)) - print("ALL DIM NAMES:") + tf.logging.info("ALL DIM NAMES:") for dim_name in unique_dims: - print(dim_name) - print('\n') + tf.logging.info(dim_name) + tf.logging.info('\n') def get_graph_info(graph): From 130c26e3126659f1724d8c1bafad85b85b7c09dc Mon Sep 17 00:00:00 2001 From: connor Date: Mon, 18 Jan 2021 19:46:22 +0000 Subject: [PATCH 12/33] fix image vocab size --- configs/dalle_coco.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/dalle_coco.json b/configs/dalle_coco.json index d71e6a6..0454b98 100644 --- a/configs/dalle_coco.json +++ b/configs/dalle_coco.json @@ -22,7 +22,7 @@ "layout": "batch_dim:data,embed_dim:model", "n_embd": 1024, "text_vocab_size": 50258, - "image_vocab_size": 512, + "image_vocab_size": 2048, "text_seq_len": 256, "n_layers": 12, "n_heads": 8, From 42e7677401fe8fd96f65d92bb9f3627a5fc96b48 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 2 Apr 2021 13:19:50 -0700 Subject: [PATCH 13/33] add missing dep --- requirements.txt | 1 + 1 file changed, 
1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 6ef6eb9..bc423c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ tensorflow==2.4.0 +tensorflow-datasets mesh_tensorflow==0.1.18 tpunicorn lm_dataformat From fcab065f831c0949dd23101c363c68123d633f54 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 2 Apr 2021 20:56:25 -0700 Subject: [PATCH 14/33] ready tests for rewiring --- .github/workflows/tests.yml | 33 ++++++++++++++++ src/dalle_mtf/sample.py | 4 +- src/model_fns.py | 6 +-- test.py | 76 +++++++++++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 6 deletions(-) create mode 100644 .github/workflows/tests.yml create mode 100644 test.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..b9719eb --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,33 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7, 3.8, 3.9] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Test with pytest + run: | + pytest -s test.py diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index e184ac8..a438cd8 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -5,10 +5,10 @@ def sample_autoregressive(inputs, model, - params, stop_at_token=50256, max_steps=None, temperature=0.9, + 
padding_id = 0, variable_dtype=mtf.VariableDType(tf.float32), has_partial_sequences=True, remove_partial_sequences=False, @@ -28,7 +28,6 @@ def sample_autoregressive(inputs, Args: inputs: an int32 Tensor with shape [, length_dim], model: DALL-E model - params: model paramers. stop_at_token: an optional integer eos id. Stop when we produce it. max_steps: an optional integer, the max number of steps to decode. temperature: an optional floating point value between 0.0 and 1.0 0.0 @@ -50,7 +49,6 @@ def sample_autoregressive(inputs, batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] - padding_id = params.get("padding_id", 0) initial_position = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(inputs, padding_id)), diff --git a/src/model_fns.py b/src/model_fns.py index 52cf870..0486c6e 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -152,10 +152,10 @@ def dalle_model_fn(features, labels, mode, params): mtf_samples = sample_autoregressive(inputs, model, - params, - stop_at_token=model.eos_token_id, - max_steps=None, + max_steps=model.total_seq_dim, # will always run until the full image is produced + stop_at_token=None, temperature=0.9, + padding_id = 0, variable_dtype=model.variable_dtype, has_partial_sequences=True, remove_partial_sequences=True, diff --git a/test.py b/test.py new file mode 100644 index 0000000..fad6631 --- /dev/null +++ b/test.py @@ -0,0 +1,76 @@ +import pytest +import traceback +import logging +from collections import defaultdict +from contextlib import contextmanager + +import tensorflow as tf +tf.compat.v1.enable_eager_execution() +import mesh_tensorflow as mtf +from mesh_tensorflow import placement_mesh_impl + +from src.dalle_mtf.models import DALLE +from src.dalle_mtf.sample import sample_autoregressive + +# helper functions + +@contextmanager +def not_raises(exception): + try: + yield + except exception: + logging.error(traceback.format_exc()) + raise pytest.fail("DID RAISE {0}".format(exception)) + +# tests + +def 
test_model(): + graph = mtf.Graph() + mesh = mtf.Mesh(graph, "my_mesh") + + model = DALLE( + batch_size = 1, + n_embd = 16, + n_heads = 2, + bf_16 = False + ) + + batch_dim = model.dimensions["batch_dim"] + sequence_dim = model.dimensions["total_seq_dim"] + + features = { + 'tokens': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32), + 'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + } + + with not_raises(Exception): + loss, loss_batch, logits = model.forward(features, return_loss = True, return_logits = True) + + mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + logits = lowering.export_to_tf_tensor(logits) + +def test_sampling(): + graph = mtf.Graph() + mesh = mtf.Mesh(graph, "my_mesh") + + model = DALLE( + batch_size = 1, + n_embd = 16, + n_heads = 2, + bf_16 = False + ) + + batch_dim = model.dimensions["batch_dim"] + sequence_dim = model.dimensions["total_seq_dim"] + + inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + + with not_raises(Exception): + samples = sample_autoregressive( + inputs, model, variable_dtype=mtf.VariableDType(), max_steps = sequence_dim.size, + remove_partial_sequences=False, stop_at_token=None) + + mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + samples = lowering.export_to_tf_tensor(samples) From 385810826e9f599a77391fd9b3a3a9db48d91167 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 2 Apr 2021 20:57:23 -0700 Subject: [PATCH 15/33] no 3.9 for tests --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b9719eb..d83b8aa 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7, 3.8, 3.9] 
+ python-version: [3.7, 3.8] steps: - uses: actions/checkout@v2 From 002343a2032b7e05c969c152628395643a14a126 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 2 Apr 2021 22:32:35 -0700 Subject: [PATCH 16/33] make DALLE-mtf work with text and image logits created with separate projections --- src/dalle_mtf/models.py | 81 ++++++++++++++++++++++++++++++----------- src/dalle_mtf/sample.py | 9 ++++- test.py | 4 +- 3 files changed, 69 insertions(+), 25 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index f55ea54..0278edc 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -142,7 +142,7 @@ class DALLE: def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024, n_layers=6, n_heads=8, batch_size=32, bf_16=True, attn_mask=None, mode="train", - is_incremental_inference=False, context=None, loss_fn=None, params=None, eos_token_id=None, + is_incremental_inference=False, context=None, loss_fn=None, text_loss_weight = 0.15, params=None, padding_id=None, activation_fn=None): self.n_embd = n_embd @@ -154,8 +154,8 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq self.n_layers = n_layers self.n_heads = n_heads self.attn_mask = attn_mask - self.total_tokens = text_vocab_size + image_vocab_size + 1 # extra for EOS - self.eos_token_id = self.total_tokens - 1 if eos_token_id is None else eos_token_id + self.total_tokens = text_vocab_size + image_vocab_size + self.padding_id = 0 if padding_id is None else padding_id self.dimensions = {"embed_dim": mtf.Dimension("embed_dim", n_embd), "text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size), "image_vocab_dim": mtf.Dimension("vocab_dim", image_vocab_size), @@ -174,6 +174,7 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq if loss_fn is None: loss_fn = mtf.layers.softmax_cross_entropy_with_logits self.loss_fn = loss_fn + self.text_loss_weight = text_loss_weight if 
activation_fn is None: activation_fn = mtf.relu self.activation_fn = activation_fn @@ -331,12 +332,17 @@ def transformer(self, x, mask): x = mtf.recompute_grad(block_fn, [x]) else: x = block_fn(x) - return x + return self.layer_norm(x) - def _loss(self, logits, labels): + def _loss(self, text_logits, image_logits, text_labels, image_labels): with tf.variable_scope("loss_final"): - loss_batch = self.loss_fn(logits=logits, targets=labels, - vocab_dim=logits.shape[-1], z_loss=0.0) + text_loss_batch = self.loss_fn(logits=text_logits, targets=text_labels, + vocab_dim=text_logits.shape[-1], z_loss=0.0) + + image_loss_batch = self.loss_fn(logits=text_logits, targets=text_labels, + vocab_dim=text_logits.shape[-1], z_loss=0.0) + + loss_batch = text_loss_batch * self.text_loss_weight + image_loss_batch with tf.variable_scope("reduce_mean_final"): loss = mtf.reduce_mean(loss_batch) @@ -376,35 +382,66 @@ def layer_norm(self, x, name="layer_norm", axis=None, epsilon=1e-5): x = x * g + b return x - def to_logits(self, x): + def to_image_logits(self, x): with tf.variable_scope("to_logits"): - logits = self.linear(self.layer_norm(x), self.dimensions["final_vocab_dim"], name="linear_out") + if not self.is_incremental_inference: + x = mtf.slice(x, begin = self.text_seq_len, size = self.image_seq_len, slice_dim_name = x.shape[1].name) + + image_logits = self.linear(x, self.dimensions["image_vocab_dim"], name="linear_image_out") + # Go to full precision for the logits - return mtf.cast(logits, tf.float32) + image_logits = mtf.cast(image_logits, tf.float32) + return image_logits + + def to_text_logits(self, x): + with tf.variable_scope("to_logits"): + text_tokens = mtf.slice(x, begin = 0, size = self.text_seq_len, slice_dim_name = x.shape[1].name) + text_logits = self.linear(text_tokens, self.dimensions["text_vocab_dim"], name="linear_text_out") + + # Go to full precision for the logits + text_logits = mtf.cast(text_logits, tf.float32) + return text_logits def forward(self, features, 
return_loss=True, return_logits=False): inputs = features["tokens"] + orig_inputs = inputs + + mesh = inputs.mesh + if self.is_incremental_inference: # reshape inputs if in inference mode inputs = mtf.gather(inputs, self.context.position - 1, self.dimensions['total_seq_dim']) inputs = mtf.reshape(inputs, [self.dimensions['batch_dim']]) + else: + # add a to the inputs, and then remove the last token + inputs = pad(inputs, [1, 0], dim_name = inputs.shape[1].name, pad_value = 0.) + inputs = mtf.slice(inputs, begin = 1, size = (inputs.shape[1].size - 1), slice_dim_name = inputs.shape[1].name) + + # embed text and image tokens jointly and add positional embeds tokens = self.positional_embedding(self.embedding(inputs, "embedding"), "positional_embedding") - mask = self.get_attn_mask(tokens.mesh, tokens.shape[1], self.dimensions["memory_len_dim"]) + mask = self.get_attn_mask(mesh, orig_inputs.shape[1], self.dimensions["memory_len_dim"]) + out = self.transformer(tokens, mask=mask) - logits = self.to_logits(out) + + image_logits = self.to_image_logits(out) + if not return_loss: - logits = mtf.cast(logits, self.variable_dtype.master_dtype) - return logits - - labels = pad(inputs, [0, 1], dim_name="total_seq_dim", pad_value=self.eos_token_id) - indices = mtf.range(labels.mesh, mtf.Dimension("range", labels.shape[1].size - 1), tf.int32, name="labels_indices") + 1 - labels = mtf.gather(labels, indices, dim=labels.shape[1]) - labels = mtf.rename_dimension(labels, "range", "total_seq_dim") - loss, loss_batch = self._loss(logits, labels) + image_logits = mtf.cast(image_logits, self.variable_dtype.master_dtype) + return image_logits # we only care about image logits, text logits will be used for loss and never used otherwise + + text_logits = self.to_text_logits(out) + + labels = orig_inputs # a is prepended, so the labels it the same as the original input now + + text_labels = mtf.slice(labels, begin = 0, size = self.text_seq_len, slice_dim_name = labels.shape[1].name) + 
image_labels = mtf.slice(labels, begin = self.text_seq_len, size = self.image_seq_len, slice_dim_name = labels.shape[1].name) + + loss, loss_batch = self._loss(text_logits, image_logits, text_labels, image_labels) + if return_logits and return_loss: # Cast back to checkpoint dtype - logits = mtf.cast(logits, self.variable_dtype.master_dtype) - return loss, loss_batch, logits + image_logits = mtf.cast(image_logits, self.variable_dtype.master_dtype) + return loss, loss_batch, image_logits # we only care about image logits, text logits will be used for loss and never used otherwise return loss, loss_batch diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index a438cd8..12770ea 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -142,13 +142,18 @@ def body_fn(position, ids, *states): raise ValueError("sampling_keep_top_k must either be -1 or positive.") k_largest = mtf.nth_largest_element( logits, n=sampling_keep_top_k, - reduced_dim=model.dimensions['final_vocab_dim']) + reduced_dim=model.dimensions['image_vocab_dim']) logits = mtf.where(mtf.less_equal(logits, k_largest), mtf.ones_like(logits) * -1e6, logits) # temperature sampling ids_this_step = mtf.sample_with_temperature( - logits, model.dimensions['final_vocab_dim'], temperature) + logits, model.dimensions['image_vocab_dim'], temperature) + + # because the image ids are in the range of [0, image_seq_len) + # it must be bumped up by the `text_seq_len` to be in the order of [[text embeds] [image _embeds]] in the input embeddings + + ids_this_step += model.text_seq_len # reshape & assign results ids_this_step = mtf.reshape(ids_this_step, batch_dims) diff --git a/test.py b/test.py index fad6631..e1ba748 100644 --- a/test.py +++ b/test.py @@ -56,6 +56,8 @@ def test_sampling(): model = DALLE( batch_size = 1, + text_seq_len = 1, + image_seq_len = 4, n_embd = 16, n_heads = 2, bf_16 = False @@ -64,7 +66,7 @@ def test_sampling(): batch_dim = model.dimensions["batch_dim"] sequence_dim = 
model.dimensions["total_seq_dim"] - inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + inputs = mtf.zeros(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) with not_raises(Exception): samples = sample_autoregressive( From 037aefb63398ab821e75c6af1a5d12ebb0f341dd Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 2 Apr 2021 23:02:19 -0700 Subject: [PATCH 17/33] add axial positional embedding --- src/dalle_mtf/models.py | 62 ++++++++++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 0278edc..877f60b 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -202,7 +202,40 @@ def embedding(self, x, name): x = mtf.dropout(x, rate=embed_dropout, name="wte_dropout") return x - def positional_embedding(self, x, name): + def axial_positional_embedding(self, x, name): + mesh = x.mesh + + with tf.variable_scope(name): + axial_dim_side = int(sqrt(self.image_seq_len)) + + embd_dim = self.dimensions["embed_dim"] + axial_dim = mtf.Dimension("axial_dim", self.image_seq_len) + + dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_side, axial_dim_side))] + + axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]), + initializer=tf.random_normal_initializer(stddev=0.01), + master_dtype=self.variable_dtype.master_dtype, + slice_dtype=self.variable_dtype.slice_dtype, + activation_dtype=self.variable_dtype.activation_dtype) + + axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]), + initializer=tf.random_normal_initializer(stddev=0.01), + master_dtype=self.variable_dtype.master_dtype, + slice_dtype=self.variable_dtype.slice_dtype, + activation_dtype=self.variable_dtype.activation_dtype) + + axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]), + (axial_wpe_1, axial_wpe_2)) + wpe = (axial_wpe_1 + 
axial_wpe_2) / 2 + + wpe = mtf.reshape(wpe, [axial_dim, embd_dim]) + wpe = pad(wpe, [self.text_seq_len, 0], axial_dim.name) + wpe = mtf.replace_dimensions(wpe, wpe.shape[0], self.dimensions["embed_seq_dim"]) + return wpe + + + def absolute_positional_embedding(self, x, name): with tf.variable_scope(name): # Positional embedding wpe = mtf.get_variable(x.mesh, "wpe", @@ -211,14 +244,17 @@ def positional_embedding(self, x, name): master_dtype=self.variable_dtype.master_dtype, slice_dtype=self.variable_dtype.slice_dtype, activation_dtype=self.variable_dtype.activation_dtype) - position_indices = mtf.range(x.mesh, self.dimensions["total_seq_dim"], tf.int64) if not \ - self.is_incremental_inference else (self.context.position - 1) - pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0]) - embed_dropout = self.params.get("embed_dropout", 0) - if embed_dropout > 0 and self.mode == "train": - pos_emb = mtf.dropout(pos_emb, rate=embed_dropout, name="wte_dropout") - x += pos_emb - return x + return wpe + + def apply_positional_embedding(self, x, wpe): + position_indices = mtf.range(x.mesh, self.dimensions["total_seq_dim"], tf.int64) if not \ + self.is_incremental_inference else (self.context.position - 1) + pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0]) + embed_dropout = self.params.get("embed_dropout", 0) + if embed_dropout > 0 and self.mode == "train": + pos_emb = mtf.dropout(pos_emb, rate=embed_dropout, name="wte_dropout") + x += pos_emb + return x def get_attn_mask(self, mesh, nd, ns): if not exists(self.attn_mask): @@ -419,7 +455,13 @@ def forward(self, features, return_loss=True, return_logits=False): # embed text and image tokens jointly and add positional embeds - tokens = self.positional_embedding(self.embedding(inputs, "embedding"), "positional_embedding") + inputs = self.embedding(inputs, "embedding") + + abs_pos_emb = self.absolute_positional_embedding(inputs, "positional_embedding") + axial_pos_emb = self.axial_positional_embedding(inputs, 
"axial_positional_embedding") + + inputs = self.apply_positional_embedding(inputs, abs_pos_emb) + tokens = self.apply_positional_embedding(inputs, axial_pos_emb) mask = self.get_attn_mask(mesh, orig_inputs.shape[1], self.dimensions["memory_len_dim"]) From e203de4beec786fc14f0a20cf6e33b5d1d31f554 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 2 Apr 2021 23:07:07 -0700 Subject: [PATCH 18/33] cleanup --- src/dalle_mtf/models.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 877f60b..d6688e1 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -202,9 +202,7 @@ def embedding(self, x, name): x = mtf.dropout(x, rate=embed_dropout, name="wte_dropout") return x - def axial_positional_embedding(self, x, name): - mesh = x.mesh - + def axial_positional_embedding(self, mesh, name): with tf.variable_scope(name): axial_dim_side = int(sqrt(self.image_seq_len)) @@ -235,10 +233,10 @@ def axial_positional_embedding(self, x, name): return wpe - def absolute_positional_embedding(self, x, name): + def absolute_positional_embedding(self, mesh, name): with tf.variable_scope(name): # Positional embedding - wpe = mtf.get_variable(x.mesh, "wpe", + wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([self.dimensions["embed_seq_dim"], self.dimensions["embed_dim"]]), initializer=tf.random_normal_initializer(stddev=0.01), master_dtype=self.variable_dtype.master_dtype, @@ -457,8 +455,8 @@ def forward(self, features, return_loss=True, return_logits=False): inputs = self.embedding(inputs, "embedding") - abs_pos_emb = self.absolute_positional_embedding(inputs, "positional_embedding") - axial_pos_emb = self.axial_positional_embedding(inputs, "axial_positional_embedding") + abs_pos_emb = self.absolute_positional_embedding(mesh, "positional_embedding") + axial_pos_emb = self.axial_positional_embedding(mesh, "axial_positional_embedding") inputs = self.apply_positional_embedding(inputs, abs_pos_emb) 
tokens = self.apply_positional_embedding(inputs, axial_pos_emb) From 5ecb496d6e9867aa226f6f48b313849f8a584eb4 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 2 Apr 2021 23:45:11 -0700 Subject: [PATCH 19/33] shift by text vocab size (not text seq len) --- src/dalle_mtf/sample.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 12770ea..12ddbc3 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -150,10 +150,10 @@ def body_fn(position, ids, *states): ids_this_step = mtf.sample_with_temperature( logits, model.dimensions['image_vocab_dim'], temperature) - # because the image ids are in the range of [0, image_seq_len) - # it must be bumped up by the `text_seq_len` to be in the order of [[text embeds] [image _embeds]] in the input embeddings + # because the image ids are in the range of [0, image_vocab_size) + # it must be bumped up by the `text_vocab_size` to be in the order of [[text embeds] [image _embeds]] in the input embeddings - ids_this_step += model.text_seq_len + ids_this_step += model.text_vocab_size # reshape & assign results ids_this_step = mtf.reshape(ids_this_step, batch_dims) From ec5c29878b778c43814d23c7cf60bbbb6ac84994 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 3 Apr 2021 08:43:46 -0700 Subject: [PATCH 20/33] make sure sampling can be forced to never start below a certain minimum starting position, useful when the text tokens contain padding tokens of 0 --- src/dalle_mtf/sample.py | 6 ++++++ test.py | 10 ++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 12ddbc3..fd9b430 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -9,6 +9,7 @@ def sample_autoregressive(inputs, max_steps=None, temperature=0.9, padding_id = 0, + min_start_pos = None, variable_dtype=mtf.VariableDType(tf.float32), has_partial_sequences=True, remove_partial_sequences=False, @@ 
-54,6 +55,11 @@ def sample_autoregressive(inputs, mtf.to_int32(mtf.not_equal(inputs, padding_id)), reduced_dim=length_dim) # Gets position where zero padding starts + if min_start_pos is not None: + # force the sampling to never start below a minimum starting position, say the text length. + # this will also be useful for image completion, where you can start sampling from half the image tokens + initial_position = mtf.maximum(initial_position, min_start_pos) + length_range = mtf.range(inputs.mesh, length_dim, tf.int32) # Builds context to pass around internally diff --git a/test.py b/test.py index e1ba748..6196995 100644 --- a/test.py +++ b/test.py @@ -70,8 +70,14 @@ def test_sampling(): with not_raises(Exception): samples = sample_autoregressive( - inputs, model, variable_dtype=mtf.VariableDType(), max_steps = sequence_dim.size, - remove_partial_sequences=False, stop_at_token=None) + inputs, + model, + variable_dtype=mtf.VariableDType(), + max_steps = sequence_dim.size, + remove_partial_sequences=False, + stop_at_token=None, + min_start_pos=model.text_seq_len + ) mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) From 3492d80249ca0b871b57ea807dbdfdb9f6cec2ca Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 3 Apr 2021 10:48:35 -0700 Subject: [PATCH 21/33] add unique pad token ids, which obviates the need to remove attention and cross entropy loss from text padding tokens --- src/dalle_mtf/models.py | 27 +++++++++++++++++++++++---- src/model_fns.py | 4 ++-- test.py | 26 ++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index d6688e1..ec3d5fe 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -154,10 +154,11 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq self.n_layers = n_layers self.n_heads = n_heads self.attn_mask = attn_mask - 
self.total_tokens = text_vocab_size + image_vocab_size + self.total_tokens = text_vocab_size + text_seq_len + image_vocab_size # (this is the order of the embeddings as well [pad] [text tokens] [text padding tokens] [image tokens]) + self.padding_id = 0 if padding_id is None else padding_id self.dimensions = {"embed_dim": mtf.Dimension("embed_dim", n_embd), - "text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size), + "text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size + text_seq_len), "image_vocab_dim": mtf.Dimension("vocab_dim", image_vocab_size), "final_vocab_dim": mtf.Dimension("vocab_dim", self.total_tokens), "total_seq_dim": mtf.Dimension("total_seq_dim", self.total_seq_dim), @@ -185,6 +186,12 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq params = {} self.params = defaultdict(lambda: None, params) + def shift_image_tokens(self, image_tokens): + return image_tokens + self.text_seq_len + self.text_vocab_dim + + def unshift_image_tokens(self, image_tokens): + return image_tokens - (self.text_seq_len + self.text_vocab_dim) + def embedding(self, x, name): embd_dim = self.dimensions["embed_dim"] vocab_dim = self.dimensions["final_vocab_dim"] @@ -438,9 +445,21 @@ def to_text_logits(self, x): def forward(self, features, return_loss=True, return_logits=False): inputs = features["tokens"] + mesh = inputs.mesh + + # make sure all padding gets turned into unique padding tokens + + input_range = mtf.range(mesh, self.dimensions['total_seq_dim'], tf.int32) + + pad_mask = mtf.logical_and(mtf.less(input_range, self.text_seq_len), mtf.equal(inputs, 0)) # only mask in the positions less than text sequence length, and where the input is 0 + pad_token_ids = input_range + self.text_seq_len # shift to the range of pad token ids, which come after text token ids, and before image token ids + + inputs = mtf.where(pad_mask, pad_token_ids, inputs) + + # save original inputs to be used as labels + orig_inputs = inputs - mesh = 
inputs.mesh if self.is_incremental_inference: # reshape inputs if in inference mode @@ -448,7 +467,7 @@ def forward(self, features, return_loss=True, return_logits=False): inputs = mtf.reshape(inputs, [self.dimensions['batch_dim']]) else: # add a to the inputs, and then remove the last token - inputs = pad(inputs, [1, 0], dim_name = inputs.shape[1].name, pad_value = 0.) + inputs = pad(inputs, [1, 0], dim_name = inputs.shape[1].name, pad_value = self.padding_id) inputs = mtf.slice(inputs, begin = 1, size = (inputs.shape[1].size - 1), slice_dim_name = inputs.shape[1].name) # embed text and image tokens jointly and add positional embeds diff --git a/src/model_fns.py b/src/model_fns.py index 0486c6e..9ba11d8 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -120,7 +120,7 @@ def dalle_model_fn(features, labels, mode, params): if x is not None: if key == "text_inputs": text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) - x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) + x = tf.concat((text_tokens, model.shift_image_tokens(img_tokens_reshaped)), axis=1) mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) @@ -171,7 +171,7 @@ def dalle_model_fn(features, labels, mode, params): initialize_vae_weights(vae_checkpoint_path) - img_outputs = outputs[:, -model.image_seq_len:] - model.text_vocab_size + img_outputs = model.unshift_image_tokens(outputs[:, -model.image_seq_len:]) with tf.variable_scope('vae'): predictions_decoded = vae.decode(img_outputs) diff --git a/test.py b/test.py index 6196995..b3a165c 100644 --- a/test.py +++ b/test.py @@ -50,6 +50,32 @@ def test_model(): lowering = mtf.Lowering(graph, {mesh: mesh_impl}) logits = lowering.export_to_tf_tensor(logits) +def test_model_fn(): + graph = mtf.Graph() + mesh = mtf.Mesh(graph, "my_mesh") + + model = DALLE( + batch_size = 1, + n_embd = 16, + n_heads 
= 2, + bf_16 = False + ) + + batch_dim = model.dimensions["batch_dim"] + sequence_dim = model.dimensions["total_seq_dim"] + + features = { + 'tokens': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32), + 'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + } + + with not_raises(Exception): + loss, loss_batch, logits = model.forward(features, return_loss = True, return_logits = True) + + mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + logits = lowering.export_to_tf_tensor(logits) + def test_sampling(): graph = mtf.Graph() mesh = mtf.Mesh(graph, "my_mesh") From 3b7684938ba1eb3a9c8a9abc294839021b62069f Mon Sep 17 00:00:00 2001 From: sdtblck <46172032+sdtblck@users.noreply.github.com> Date: Sun, 4 Apr 2021 21:55:41 +0200 Subject: [PATCH 22/33] fix text_vocab_dim error --- src/dalle_mtf/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index ec3d5fe..020e796 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -187,10 +187,10 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq self.params = defaultdict(lambda: None, params) def shift_image_tokens(self, image_tokens): - return image_tokens + self.text_seq_len + self.text_vocab_dim + return image_tokens + self.text_seq_len + self.dimensions['text_vocab_dim'] def unshift_image_tokens(self, image_tokens): - return image_tokens - (self.text_seq_len + self.text_vocab_dim) + return image_tokens - (self.text_seq_len + self.dimensions['text_vocab_dim']) def embedding(self, x, name): embd_dim = self.dimensions["embed_dim"] From ef965371e66c05c19c817a0acb5a3ac1fad5c7d9 Mon Sep 17 00:00:00 2001 From: sdtblck <46172032+sdtblck@users.noreply.github.com> Date: Sun, 4 Apr 2021 21:58:29 +0200 Subject: [PATCH 23/33] Update models.py --- src/dalle_mtf/models.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 020e796..7dede09 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -187,10 +187,10 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq self.params = defaultdict(lambda: None, params) def shift_image_tokens(self, image_tokens): - return image_tokens + self.text_seq_len + self.dimensions['text_vocab_dim'] + return image_tokens + self.text_seq_len + self.dimensions['text_vocab_dim'].size def unshift_image_tokens(self, image_tokens): - return image_tokens - (self.text_seq_len + self.dimensions['text_vocab_dim']) + return image_tokens - (self.text_seq_len + self.dimensions['text_vocab_dim'].size) def embedding(self, x, name): embd_dim = self.dimensions["embed_dim"] From a1634d47d0f054394e8bb9a68299ebf14f469ce2 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 13:23:49 -0700 Subject: [PATCH 24/33] fix bug --- src/dalle_mtf/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 7dede09..2a8a1e4 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -380,8 +380,8 @@ def _loss(self, text_logits, image_logits, text_labels, image_labels): text_loss_batch = self.loss_fn(logits=text_logits, targets=text_labels, vocab_dim=text_logits.shape[-1], z_loss=0.0) - image_loss_batch = self.loss_fn(logits=text_logits, targets=text_labels, - vocab_dim=text_logits.shape[-1], z_loss=0.0) + image_loss_batch = self.loss_fn(logits=image_logits, targets=image_labels, + vocab_dim=image_logits.shape[-1], z_loss=0.0) loss_batch = text_loss_batch * self.text_loss_weight + image_loss_batch @@ -468,7 +468,7 @@ def forward(self, features, return_loss=True, return_logits=False): else: # add a to the inputs, and then remove the last token inputs = pad(inputs, [1, 0], dim_name = inputs.shape[1].name, pad_value = self.padding_id) - inputs = 
mtf.slice(inputs, begin = 1, size = (inputs.shape[1].size - 1), slice_dim_name = inputs.shape[1].name) + inputs = mtf.slice(inputs, begin = 1, size = self.total_seq_dim, slice_dim_name = inputs.shape[1].name) # embed text and image tokens jointly and add positional embeds From a3d21150c67914e6041262af9d05e2d7e8093727 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 14:18:54 -0700 Subject: [PATCH 25/33] fix bug with shift --- src/dalle_mtf/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 2a8a1e4..1f9f708 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -468,7 +468,7 @@ def forward(self, features, return_loss=True, return_logits=False): else: # add a to the inputs, and then remove the last token inputs = pad(inputs, [1, 0], dim_name = inputs.shape[1].name, pad_value = self.padding_id) - inputs = mtf.slice(inputs, begin = 1, size = self.total_seq_dim, slice_dim_name = inputs.shape[1].name) + inputs = mtf.slice(inputs, begin = 0, size = self.total_seq_dim, slice_dim_name = inputs.shape[1].name) # embed text and image tokens jointly and add positional embeds From 5f279ad45b2d5126d1124be90f5962e50f07e48c Mon Sep 17 00:00:00 2001 From: sdtblck <46172032+sdtblck@users.noreply.github.com> Date: Mon, 5 Apr 2021 00:34:45 +0200 Subject: [PATCH 26/33] Add adam weight decay optimizer --- src/optimizers.py | 166 +++++++++++++++++++++++----------------------- 1 file changed, 83 insertions(+), 83 deletions(-) diff --git a/src/optimizers.py b/src/optimizers.py index 7f77c04..42a627c 100644 --- a/src/optimizers.py +++ b/src/optimizers.py @@ -79,7 +79,7 @@ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): scalar_summary("lr", learning_rate) if optimizer_name.lower() == "adam": - optimizer = mtf.optimize.AdamWeightDecayOptimizer( + optimizer = AdamWeightDecayOptimizer( learning_rate=learning_rate, weight_decay_rate=params.get("weight_decay", 
0.0), beta_1=params.get("beta_1", 0.9), @@ -104,85 +104,85 @@ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): return learning_rate, update_ops, var_grads_fp -# class AdamWeightDecayOptimizer(mtf.optimize.Optimizer): -# """A basic Adam optimizer that includes "correct" L2 weight decay.""" - -# def __init__(self, -# learning_rate, -# weight_decay_rate=0.0, -# beta_1=0.9, -# beta_2=0.999, -# epsilon=1e-6, -# exclude_from_weight_decay=None, -# variable_dtype=None): -# """Constructs a AdamWeightDecayOptimizer.""" - -# self.learning_rate = learning_rate -# self.weight_decay_rate = weight_decay_rate -# self.beta_1 = beta_1 -# self.beta_2 = beta_2 -# self.epsilon = epsilon -# self.exclude_from_weight_decay = exclude_from_weight_decay -# self.variable_dtype = variable_dtype - -# def apply_grad(self, grad, var): -# """See base class.""" -# if grad is None: -# tf.logging.warning("Gradient is None for variable %s" % var.name) -# return [] - -# grad = mtf.to_float(grad) - -# assignments = [] - -# m = mtf.get_variable( -# var.mesh, var.name + "/adam_m", var.shape, -# initializer=tf.zeros_initializer(), -# # master_dtype=self.variable_dtype.master_dtype, -# # slice_dtype=self.variable_dtype.slice_dtype, -# # activation_dtype=self.variable_dtype.activation_dtype, -# trainable=False) - -# v = mtf.get_variable( -# var.mesh, var.name + "/adam_v", var.shape, -# initializer=tf.zeros_initializer(), -# # master_dtype=self.variable_dtype.master_dtype, -# # slice_dtype=self.variable_dtype.slice_dtype, -# # activation_dtype=self.variable_dtype.activation_dtype, -# trainable=False) - -# # Standard Adam update. 
-# next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad -# next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad) - -# update = next_m / (mtf.sqrt(next_v) + self.epsilon) - -# # Just adding the square of the weights to the loss function is *not* -# # the correct way of using L2 regularization/weight decay with Adam, -# # since that will interact with the m and v parameters in strange ways. -# # -# # Instead we want to decay the weights in a manner that doesn't interact -# # with the m/v parameters. This is equivalent to adding the square -# # of the weights to the loss with plain (non-momentum) SGD. -# if self._do_use_weight_decay(var.name): -# update += mtf.to_float(var.value) * self.weight_decay_rate - -# update_with_lr = self.learning_rate * update - -# var_update = mtf.assign_sub(var, update_with_lr) - -# assignments.extend( -# [var_update, -# mtf.assign(m, next_m), -# mtf.assign(v, next_v)]) -# return assignments - -# def _do_use_weight_decay(self, param_name): -# """Whether to use L2 weight decay for `param_name`.""" -# if not self.weight_decay_rate: -# return False -# if self.exclude_from_weight_decay: -# for r in self.exclude_from_weight_decay: -# if re.search(r, param_name) is not None: -# return False -# return True +class AdamWeightDecayOptimizer(mtf.optimize.Optimizer): + """A basic Adam optimizer that includes "correct" L2 weight decay.""" + + def __init__(self, + learning_rate, + weight_decay_rate=0.0, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-6, + exclude_from_weight_decay=None, + variable_dtype=None): + """Constructs a AdamWeightDecayOptimizer.""" + + self.learning_rate = learning_rate + self.weight_decay_rate = weight_decay_rate + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + self.exclude_from_weight_decay = exclude_from_weight_decay + self.variable_dtype = variable_dtype + + def apply_grad(self, grad, var): + """See base class.""" + if grad is None: + tf.logging.warning("Gradient is None for variable %s" % 
var.name) + return [] + + grad = mtf.to_float(grad) + + assignments = [] + + m = mtf.get_variable( + var.mesh, var.name + "/adam_m", var.shape, + initializer=tf.zeros_initializer(), + # master_dtype=self.variable_dtype.master_dtype, + # slice_dtype=self.variable_dtype.slice_dtype, + # activation_dtype=self.variable_dtype.activation_dtype, + trainable=False) + + v = mtf.get_variable( + var.mesh, var.name + "/adam_v", var.shape, + initializer=tf.zeros_initializer(), + # master_dtype=self.variable_dtype.master_dtype, + # slice_dtype=self.variable_dtype.slice_dtype, + # activation_dtype=self.variable_dtype.activation_dtype, + trainable=False) + + # Standard Adam update. + next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad + next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad) + + update = next_m / (mtf.sqrt(next_v) + self.epsilon) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. 
+ if self._do_use_weight_decay(var.name): + update += mtf.to_float(var.value) * self.weight_decay_rate + + update_with_lr = self.learning_rate * update + + var_update = mtf.assign_sub(var, update_with_lr) + + assignments.extend( + [var_update, + mtf.assign(m, next_m), + mtf.assign(v, next_v)]) + return assignments + + def _do_use_weight_decay(self, param_name): + """Whether to use L2 weight decay for `param_name`.""" + if not self.weight_decay_rate: + return False + if self.exclude_from_weight_decay: + for r in self.exclude_from_weight_decay: + if re.search(r, param_name) is not None: + return False + return True From 8db0d55097b0a7ad4f7eb86be882d71b7b0c244b Mon Sep 17 00:00:00 2001 From: sdtblck <46172032+sdtblck@users.noreply.github.com> Date: Mon, 5 Apr 2021 00:37:21 +0200 Subject: [PATCH 27/33] fix args.steps_per_checkpoint --- train_dalle.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/train_dalle.py b/train_dalle.py index 8ad9a8e..1a3ae10 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -19,7 +19,7 @@ def parse_args(): parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and " "starts a new training run") parser.add_argument('--predict', action='store_true', help='run model in predict mode') - parser.add_argument('--prompt', type=str, default='a cat in a hat') + parser.add_argument('--prompt', type=str, default='face') args = parser.parse_args() assert args.model is not None, "Model must be set" return args @@ -93,14 +93,15 @@ def main(): if has_predict_or_eval_steps: # Eval and train - stop and predict and/or eval every checkpoint while current_step < params["train_steps"]: - next_checkpoint = min(current_step + args.steps_per_checkpoint, params["train_steps"]) + next_checkpoint = min(current_step + params["steps_per_checkpoint"], params["train_steps"]) estimator.train(input_fn=partial(dalle_input_fn, eval=False), max_steps=next_checkpoint) current_step = next_checkpoint if 
params["predict_steps"] > 0: raise NotImplementedError if params["eval_steps"] > 0: - raise NotImplementedError + estimator.evaluate(input_fn=partial(dalle_input_fn, eval=True), + steps=params["eval_steps"]) return else: # Else, just train @@ -110,6 +111,7 @@ def main(): max_steps=params["train_steps"]) + if __name__ == "__main__": tf.disable_v2_behavior() main() From d3112ca80140fe353228bb3e61bacac7d89ef331 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 15:47:52 -0700 Subject: [PATCH 28/33] add variable scope --- src/dalle_mtf/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 1f9f708..9d0ed6b 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -424,7 +424,7 @@ def layer_norm(self, x, name="layer_norm", axis=None, epsilon=1e-5): return x def to_image_logits(self, x): - with tf.variable_scope("to_logits"): + with tf.variable_scope("to_image_logits"): if not self.is_incremental_inference: x = mtf.slice(x, begin = self.text_seq_len, size = self.image_seq_len, slice_dim_name = x.shape[1].name) @@ -435,7 +435,7 @@ def to_image_logits(self, x): return image_logits def to_text_logits(self, x): - with tf.variable_scope("to_logits"): + with tf.variable_scope("to_text_logits"): text_tokens = mtf.slice(x, begin = 0, size = self.text_seq_len, slice_dim_name = x.shape[1].name) text_logits = self.linear(text_tokens, self.dimensions["text_vocab_dim"], name="linear_text_out") From 33de20318170bf1a0bcbf41bda20847f652152ab Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 15:51:12 -0700 Subject: [PATCH 29/33] do an early copy of inputs for labels --- src/dalle_mtf/models.py | 3 +-- test.py | 26 -------------------------- 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 9d0ed6b..fd5040e 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -458,8 +458,7 @@ def 
forward(self, features, return_loss=True, return_logits=False): # save original inputs to be used as labels - orig_inputs = inputs - + orig_inputs = mtf.slice(inputs, begin = 0, size = self.total_seq_dim, slice_dim_name = inputs.shape[1].name) if self.is_incremental_inference: # reshape inputs if in inference mode diff --git a/test.py b/test.py index b3a165c..6196995 100644 --- a/test.py +++ b/test.py @@ -50,32 +50,6 @@ def test_model(): lowering = mtf.Lowering(graph, {mesh: mesh_impl}) logits = lowering.export_to_tf_tensor(logits) -def test_model_fn(): - graph = mtf.Graph() - mesh = mtf.Mesh(graph, "my_mesh") - - model = DALLE( - batch_size = 1, - n_embd = 16, - n_heads = 2, - bf_16 = False - ) - - batch_dim = model.dimensions["batch_dim"] - sequence_dim = model.dimensions["total_seq_dim"] - - features = { - 'tokens': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32), - 'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) - } - - with not_raises(Exception): - loss, loss_batch, logits = model.forward(features, return_loss = True, return_logits = True) - - mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) - lowering = mtf.Lowering(graph, {mesh: mesh_impl}) - logits = lowering.export_to_tf_tensor(logits) - def test_sampling(): graph = mtf.Graph() mesh = mtf.Mesh(graph, "my_mesh") From 4dbf727bb2d1d5bd124ca95378c928d87eaad82d Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Apr 2021 15:53:43 -0700 Subject: [PATCH 30/33] tweak --- src/dalle_mtf/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index fd5040e..c1f2280 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -344,6 +344,7 @@ def attention(self, x, n_state, mask, attention_type="global", name="attn"): a = mtf.dropout(a, rate=residual_dropout, name="res_dropout") return a + def mlp(self, x, n_state, name="mlp"): residual_dropout = 
self.params.get("residual_dropout", 0) with tf.variable_scope(name): @@ -479,8 +480,7 @@ def forward(self, features, return_loss=True, return_logits=False): inputs = self.apply_positional_embedding(inputs, abs_pos_emb) tokens = self.apply_positional_embedding(inputs, axial_pos_emb) - mask = self.get_attn_mask(mesh, orig_inputs.shape[1], self.dimensions["memory_len_dim"]) - + mask = self.get_attn_mask(mesh, self.dimensions["total_seq_dim"], self.dimensions["memory_len_dim"]) out = self.transformer(tokens, mask=mask) image_logits = self.to_image_logits(out) From 10dc0246637753ec572fcf471ca7ddaa09527b5e Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 5 Apr 2021 14:02:38 -0700 Subject: [PATCH 31/33] fix initial positions at text_seq_len --- src/dalle_mtf/sample.py | 10 ++-------- test.py | 3 ++- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index fd9b430..486fedd 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -51,14 +51,8 @@ def sample_autoregressive(inputs, batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] - initial_position = mtf.reduce_sum( - mtf.to_int32(mtf.not_equal(inputs, padding_id)), - reduced_dim=length_dim) # Gets position where zero padding starts - - if min_start_pos is not None: - # force the sampling to never start below a minimum starting position, say the text length. 
- # this will also be useful for image completion, where you can start sampling from half the image tokens - initial_position = mtf.maximum(initial_position, min_start_pos) + # Gets position (in image inputs) where zero padding starts + initial_position = mtf.zeros(inputs.mesh, batch_dims, dtype=tf.int32) + model.text_seq_len length_range = mtf.range(inputs.mesh, length_dim, tf.int32) diff --git a/test.py b/test.py index 6196995..e5d929d 100644 --- a/test.py +++ b/test.py @@ -56,7 +56,7 @@ def test_sampling(): model = DALLE( batch_size = 1, - text_seq_len = 1, + text_seq_len = 3, image_seq_len = 4, n_embd = 16, n_heads = 2, @@ -82,3 +82,4 @@ def test_sampling(): mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) samples = lowering.export_to_tf_tensor(samples) + print(samples) \ No newline at end of file From 9dacd357767cf481568ad26e32817db69a3ebdd5 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 5 Apr 2021 16:54:07 -0700 Subject: [PATCH 32/33] more cleanup --- src/dalle_mtf/sample.py | 27 --------------------------- src/model_fns.py | 6 +----- test.py | 6 +----- 3 files changed, 2 insertions(+), 37 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index 486fedd..3551fcd 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -5,14 +5,9 @@ def sample_autoregressive(inputs, model, - stop_at_token=50256, - max_steps=None, temperature=0.9, - padding_id = 0, - min_start_pos = None, variable_dtype=mtf.VariableDType(tf.float32), has_partial_sequences=True, - remove_partial_sequences=False, sampling_keep_top_k=-1, ): """Sample randomly one token at a time. 
@@ -87,25 +82,10 @@ def sample_autoregressive(inputs, if not has_partial_sequences: partial_sequences_eos_count = 0 - if stop_at_token is not None: - partial_sequences_eos_count = mtf.reduce_sum( - mtf.to_int32(mtf.equal(inputs, stop_at_token)), - reduced_dim=length_dim) - def cond_fn(position, ids, *unused_states): """Should we run another loop iteration?""" past_end = mtf.greater_equal(position, length_dim.size) - if max_steps: - past_end = mtf.logical_or( - past_end, mtf.greater_equal(position - initial_position, max_steps)) - is_done = past_end - if stop_at_token is not None: - eos_count = mtf.reduce_sum( - mtf.to_int32(mtf.equal(ids, stop_at_token)), - reduced_dim=length_dim) - has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) - is_done = mtf.logical_or(is_done, has_additional_eos) all_done = mtf.reduce_all(is_done) return mtf.logical_not(all_done) @@ -169,11 +149,4 @@ def body_fn(position, ids, *states): final_position, outputs = mtf.while_loop( cond_fn, body_fn, while_loop_inputs)[:2] del final_position - if has_partial_sequences and remove_partial_sequences: - # Remove partial sequences from outputs - partial_length = mtf.reduce_sum( - mtf.to_int32(mtf.not_equal(inputs, padding_id)), - reduced_dim=length_dim) - outputs = mtf.dynamic_shift( - outputs, -partial_length, length_dim, wrap=False) return outputs diff --git a/src/model_fns.py b/src/model_fns.py index 9ba11d8..3f1e243 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -152,14 +152,10 @@ def dalle_model_fn(features, labels, mode, params): mtf_samples = sample_autoregressive(inputs, model, - max_steps=model.total_seq_dim, # will always run until the full image is produced - stop_at_token=None, temperature=0.9, - padding_id = 0, variable_dtype=model.variable_dtype, has_partial_sequences=True, - remove_partial_sequences=True, - sampling_keep_top_k=-1, + sampling_keep_top_k=-2, ) mtf_samples = mtf.anonymize(mtf_samples) diff --git a/test.py b/test.py index e5d929d..0ff0728 
100644 --- a/test.py +++ b/test.py @@ -72,11 +72,7 @@ def test_sampling(): samples = sample_autoregressive( inputs, model, - variable_dtype=mtf.VariableDType(), - max_steps = sequence_dim.size, - remove_partial_sequences=False, - stop_at_token=None, - min_start_pos=model.text_seq_len + variable_dtype=mtf.VariableDType() ) mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) From c073fbc5f0a47b8bd9dbce1378119b5fc5f0ddb2 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 5 Apr 2021 22:47:35 -0700 Subject: [PATCH 33/33] make sure axial positional embedding is correctly shifted by one due to --- src/dalle_mtf/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index c1f2280..110c543 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -235,7 +235,8 @@ def axial_positional_embedding(self, mesh, name): wpe = (axial_wpe_1 + axial_wpe_2) / 2 wpe = mtf.reshape(wpe, [axial_dim, embd_dim]) - wpe = pad(wpe, [self.text_seq_len, 0], axial_dim.name) + wpe = pad(wpe, [self.text_seq_len + 1, 0], axial_dim.name) + wpe = mtf.slice(wpe, 0, self.total_seq_dim, axial_dim.name) wpe = mtf.replace_dimensions(wpe, wpe.shape[0], self.dimensions["embed_seq_dim"]) return wpe