|
| 1 | +""" |
| 2 | +This module provides functionality to save top-k teacher logits |
| 3 | +for distillation purposes in MaxText. |
| 4 | +""" |
| 5 | + |
| 6 | +import os |
| 7 | +import pickle |
| 8 | +from typing import Sequence |
| 9 | +import argparse |
| 10 | +import sys |
| 11 | +import tensorflow as tf |
| 12 | + |
| 13 | +import jax |
| 14 | +import numpy as np |
| 15 | +import functools |
| 16 | +from itertools import islice |
| 17 | + |
| 18 | +from absl import app |
| 19 | +from MaxText import pyconfig |
| 20 | +from maxtext.utils import model_creation_utils |
| 21 | +from maxtext.input_pipeline import input_pipeline_interface |
| 22 | +from maxtext.utils import maxtext_utils |
| 23 | +from maxtext.utils import max_logging |
| 24 | + |
| 25 | +from jax.experimental import multihost_utils |
| 26 | +from array_record.python import array_record_module |
| 27 | + |
| 28 | + |
def get_top_k_logits(logits: jax.Array, k: int):
  """Returns the `k` largest logit values and their vocabulary indices.

  Thin wrapper over `jax.lax.top_k`, which yields a `(values, indices)` pair
  computed along the last axis of `logits`.
  """
  return jax.lax.top_k(logits, k)
| 33 | + |
| 34 | + |
def generate_and_save_data(config, k_val):
  """Generates top-k logits from the teacher model and saves them to an ArrayRecord file.

  Runs the teacher model over `config.steps` batches from the training input
  pipeline, extracts the top-`k_val` logits per position, allgathers the
  results across hosts, and — on process 0 only — pickles one record per step
  into a local ArrayRecord file that is copied to GCS at the end.

  Args:
    config: pyconfig configuration object; supplies the device mesh axes,
      checkpoint/output paths, `steps`, and `teacher_logits_optional_keys`.
    k_val: Number of top logits (and their vocabulary indices) to keep.
  """
  devices = jax.devices()
  devices_array = maxtext_utils.create_device_mesh(config, devices)
  mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

  # Loading teacher model and dataset iterator.
  max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
  teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
  train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)

  process_index = jax.process_index()

  output_dir = config.base_output_directory
  if config.run_name:
    output_dir = os.path.join(output_dir, config.run_name)

  final_gcs_file = os.path.join(output_dir, "teacher_top_k.array_record")
  local_temp_file = "/tmp/teacher_top_k.array_record"

  # Only the lead process writes; every process still runs the allgathers
  # below so the collectives stay in sync across hosts.
  writer = None
  if process_index == 0:
    max_logging.log(f"Opening local ArrayRecordWriter at {local_temp_file}")
    writer = array_record_module.ArrayRecordWriter(local_temp_file, "group_size:1000")

  try:
    max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...")
    for step, batch in enumerate(islice(train_iter, config.steps)):
      tokens = batch["inputs"]
      logits = teacher_model(
          decoder_input_tokens=tokens,
          decoder_positions=batch["inputs_position"],
          enable_dropout=False,
      )

      top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val)

      # Gather every host's shard so process 0 can write complete records.
      gathered_vals = multihost_utils.process_allgather(top_k_vals, tiled=True)
      gathered_idx = multihost_utils.process_allgather(top_k_idx, tiled=True)
      gathered_tokens = multihost_utils.process_allgather(tokens, tiled=True)

      # Extra per-batch keys (e.g. segmentation/masks — exact set comes from
      # config) are carried through only when present in the batch.
      optional_keys = config.teacher_logits_optional_keys
      gathered_optionals = {
          key: multihost_utils.process_allgather(batch[key], tiled=True) for key in optional_keys if key in batch
      }

      if writer is not None:
        record_dict = {
            "tokens": np.array(gathered_tokens),
            "top_k_logits": np.array(gathered_vals),
            "top_k_indices": np.array(gathered_idx),
        }
        for key, gathered_val in gathered_optionals.items():
          record_dict[key] = np.array(gathered_val)

        writer.write(pickle.dumps(record_dict))

      if step % 50 == 0:
        max_logging.log(f"Successfully processed step {step}")
  finally:
    # Fix: close the writer even when the loop raises, so the local
    # ArrayRecord file is flushed and the handle is not leaked. Previously
    # an exception mid-loop skipped `writer.close()` entirely.
    if writer is not None:
      writer.close()

  if writer is None:
    # Non-lead processes have nothing to upload.
    return

  max_logging.log(f"Finished writing locally, uploading to GCS: {final_gcs_file}...")

  if not tf.io.gfile.exists(output_dir):
    tf.io.gfile.makedirs(output_dir)

  tf.io.gfile.copy(local_temp_file, final_gcs_file, overwrite=True)
  os.remove(local_temp_file)
  max_logging.log("Upload complete")
| 107 | + |
| 108 | + |
def main(argv: Sequence[str], local_args):
  """Builds the teacher configuration and runs top-k logit generation.

  Initializes the global config from `argv`, then re-initializes a teacher
  config from just the program name and config path with `teacher_overrides`
  applied on top, and hands it to `generate_and_save_data`.
  """
  # Initialize the global configuration.
  base_config = pyconfig.initialize(argv)
  overrides = base_config.teacher_overrides
  # NOTE(review): assumes argv[1] is the config-file path — confirm callers.
  teacher_config = pyconfig.initialize([argv[0], argv[1]], **overrides)
  generate_and_save_data(teacher_config, local_args.top_k)
| 117 | + |
| 118 | + |
if __name__ == "__main__":
  # Peel off the local --top_k flag; everything else is forwarded to
  # absl/pyconfig untouched.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      "--top_k",
      type=int,
      required=False,
      default=128,
      help="Top K value for logits.",
  )
  parsed_args, passthrough = arg_parser.parse_known_args()

  app.run(functools.partial(main, local_args=parsed_args), argv=[sys.argv[0]] + passthrough)