|
| 1 | +""" |
| 2 | +This module provides functionality to save top-k teacher logits |
| 3 | +for distillation purposes in MaxText. |
| 4 | +
|
| 5 | +Example command: python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py src/maxtext/configs/post_train/distillation.yml --top_k=128 |
| 6 | +""" |
| 7 | + |
| 8 | +import os |
| 9 | +import pickle |
| 10 | +from typing import Sequence |
| 11 | +import argparse |
| 12 | +import time |
| 13 | +import sys |
| 14 | +import tensorflow as tf |
| 15 | + |
| 16 | +import jax |
| 17 | +import numpy as np |
| 18 | +import functools |
| 19 | +from itertools import islice |
| 20 | + |
| 21 | +from absl import app |
| 22 | +from MaxText import pyconfig |
| 23 | +from maxtext.utils import model_creation_utils |
| 24 | +from maxtext.input_pipeline import input_pipeline_interface |
| 25 | +from maxtext.utils import maxtext_utils |
| 26 | +from maxtext.utils import max_logging |
| 27 | + |
| 28 | +from jax.experimental import multihost_utils |
| 29 | +from array_record.python import array_record_module |
| 30 | + |
| 31 | + |
def get_top_k_logits(logits: jax.Array, k: int):
  """Return the k largest logit values and their vocabulary indices.

  Thin wrapper over ``jax.lax.top_k``; returns a ``(values, indices)`` pair
  with the trailing axis reduced from vocab size to ``k``.
  """
  return jax.lax.top_k(logits, k)
| 36 | + |
| 37 | + |
def get_local_cpu_array(arr):
  """Gather this host's addressable shards of a JAX array into host memory.

  Copies each local shard to a numpy array and stacks them along axis 0.
  NOTE(review): this presumes any sharding is over the leading (batch)
  axis — confirm against the mesh layout used by callers.
  """
  host_shards = [np.array(shard.data) for shard in arr.addressable_shards]
  return np.concatenate(host_shards, axis=0)
| 41 | + |
| 42 | + |
def generate_and_save_data(config, k_val):
  """Runs the teacher model over the dataset and saves top-k logits to disk.

  Each host writes only its local shard of every batch to a per-process
  ArrayRecord file (one pickled dict per step), so no cross-host gather of
  activations is needed.

  Args:
    config: MaxText config providing the teacher checkpoint path, dataset,
      mesh layout, step count, optional record keys, and output directory.
    k_val: Number of top logits (values plus vocab indices) kept per token.
  """
  devices = jax.devices()
  devices_array = maxtext_utils.create_device_mesh(config, devices)
  mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

  # Loading teacher model and dataset iterator
  max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
  teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
  train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)

  output_dir = config.base_output_directory
  if config.run_name:
    output_dir = os.path.join(output_dir, config.run_name)

  # Only process 0 creates the directory to avoid racing makedirs calls.
  if jax.process_index() == 0:
    if not tf.io.gfile.exists(output_dir):
      tf.io.gfile.makedirs(output_dir)

  # Sync all hosts to ensure directory exists before writers open files
  multihost_utils.sync_global_devices("create_output_dir")

  # Each host writes to a unique file based on its process index to avoid write conflicts
  filename = f"teacher_top_k_process_{jax.process_index()}.array_record"
  output_path = os.path.join(output_dir, filename)

  max_logging.log(f"Process {jax.process_index()} writing directly to: {output_path}")
  writer = array_record_module.ArrayRecordWriter(output_path, "group_size:1000")

  # try/finally ensures the writer is closed (and buffered records flushed)
  # even if the model forward pass or the input pipeline raises mid-loop.
  try:
    max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...")
    loop_start = time.time()
    for step, batch in enumerate(islice(train_iter, config.steps)):
      step_start = time.time()
      tokens = batch["inputs"]
      logits = teacher_model(
          decoder_input_tokens=tokens,
          decoder_positions=batch["inputs_position"],
          enable_dropout=False,
      )
      top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val)

      # Extract only the local data for this host (Distributed Writing)
      local_vals = get_local_cpu_array(top_k_vals)
      local_idx = get_local_cpu_array(top_k_idx)
      local_tokens = get_local_cpu_array(tokens)

      # Optional per-batch fields requested in the config are carried through
      # verbatim when present in the batch (missing keys are skipped).
      optional_keys = config.teacher_logits_optional_keys
      local_optionals = {key: get_local_cpu_array(batch[key]) for key in optional_keys if key in batch}

      record_dict = {
          "tokens": local_tokens,
          "top_k_logits": local_vals,
          "top_k_indices": local_idx,
      }
      record_dict.update(local_optionals)

      # NOTE: records are pickled, so downstream readers must fully trust
      # this data source (pickle.loads executes arbitrary code).
      writer.write(pickle.dumps(record_dict))

      if step % 50 == 0:
        max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s")

    max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s")
  finally:
    writer.close()

  max_logging.log(f"Finished writing to {output_path}.")
| 111 | + |
| 112 | + |
def main(argv: Sequence[str], local_args):
  """Builds the teacher config from CLI args and launches logit generation."""
  # Initialize the global configuration
  global_config = pyconfig.initialize(argv)
  # Re-initialize using only the program name and config-file path
  # (argv[0], argv[1] — assumes both are present; TODO confirm), layering
  # the teacher-specific overrides on top of the base config.
  overrides_for_teacher = global_config.teacher_overrides
  teacher_config = pyconfig.initialize([argv[0], argv[1]], **overrides_for_teacher)

  generate_and_save_data(teacher_config, local_args.top_k)
| 121 | + |
| 122 | + |
if __name__ == "__main__":
  # Parse only the flags this script owns; everything else is forwarded
  # untouched to pyconfig via absl's app.run.
  cli_parser = argparse.ArgumentParser()
  cli_parser.add_argument(
      "--top_k",
      type=int,
      default=128,
      help="Top K value for logits.",
  )
  known_args, passthrough_args = cli_parser.parse_known_args()

  entry = functools.partial(main, local_args=known_args)
  app.run(entry, argv=[sys.argv[0]] + passthrough_args)