|
| 1 | +# Copyright 2023–2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +""" |
| 16 | +This module provides functionality to save top-k teacher logits |
| 17 | +for distillation purposes in MaxText. |
| 18 | +
|
| 19 | +Example command: |
| 20 | +python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py \ |
| 21 | +src/maxtext/configs/post_train/distillation.yml \ |
| 22 | +--top_k=128 \ |
| 23 | +--gcs_upload_path=gs://my-bucket/teacher_logits/ |
| 24 | +""" |
| 25 | + |
| 26 | +import os |
| 27 | +import pickle |
| 28 | +from typing import Sequence |
| 29 | +import argparse |
| 30 | +import time |
| 31 | +import sys |
| 32 | +import tensorflow as tf |
| 33 | + |
| 34 | +import jax |
| 35 | +import functools |
| 36 | +from itertools import islice |
| 37 | + |
| 38 | +from absl import app |
| 39 | +from MaxText import pyconfig |
| 40 | +from maxtext.utils import model_creation_utils |
| 41 | +from maxtext.input_pipeline import input_pipeline_interface |
| 42 | +from maxtext.utils import maxtext_utils |
| 43 | +from maxtext.utils import max_logging |
| 44 | + |
| 45 | +from jax.experimental import multihost_utils |
| 46 | +from array_record.python import array_record_module |
| 47 | + |
| 48 | + |
| 49 | +def get_top_k_logits(logits: jax.Array, k: int): |
| 50 | + """Extracts the top-k values and their vocabulary indices""" |
| 51 | + top_k_values, top_k_indices = jax.lax.top_k(logits, k) |
| 52 | + return top_k_values, top_k_indices |
| 53 | + |
| 54 | + |
def generate_and_save_data(config, local_args):
  """Runs the teacher model over the dataset and saves top-k logits per step.

  Only process 0 writes records (one pickled dict per step) to a local
  ArrayRecord file, then optionally uploads that file to GCS. All processes
  run the forward passes and participate in the cross-host sync barriers.

  Args:
    config: MaxText pyconfig object for the teacher (mesh axes, checkpoint
      path, dataset settings, number of steps).
    local_args: argparse namespace with top_k, optional_keys,
      gcs_upload_path, and local_tmp_dir.
  """
  k_val = local_args.top_k
  optional_keys = local_args.optional_keys
  gcs_upload_path = local_args.gcs_upload_path
  local_tmp_dir = local_args.local_tmp_dir

  # Build the device mesh the model and data pipeline will be sharded over.
  devices = jax.devices()
  devices_array = maxtext_utils.create_device_mesh(config, devices)
  mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

  # Loading teacher model and dataset iterator
  max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
  teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
  train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)

  # Setup local tmp directory for Host 0
  filename = "teacher_top_k_global.array_record"
  local_output_path = os.path.join(local_tmp_dir, filename)

  # Only host 0 opens a writer; other hosts keep writer=None and never touch it.
  writer = None
  if jax.process_index() == 0:
    if not os.path.exists(local_tmp_dir):
      os.makedirs(local_tmp_dir)
    max_logging.log(f"Process 0 writing globally gathered data to local path: {local_output_path}")
    # "group_size:1000" is an ArrayRecord writer option controlling record grouping.
    writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1000")

  # Sync all hosts before starting the loop
  multihost_utils.sync_global_devices("start_generation_loop")

  max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...")
  loop_start = time.time()

  # islice caps the iterator at config.steps so a longer dataset terminates cleanly.
  for step, batch in enumerate(islice(train_iter, config.steps)):
    step_start = time.time()
    tokens = batch["inputs"]
    # Forward pass only (enable_dropout=False): we want deterministic teacher logits.
    logits = teacher_model(
        decoder_input_tokens=tokens,
        decoder_positions=batch["inputs_position"],
        enable_dropout=False,
    )
    top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val)

    # Fetch the global distributed jax arrays
    # NOTE(review): device_get on a sharded array assumes it is fully
    # addressable/replicated on each host — confirm the mesh config guarantees
    # host 0 sees the global batch.
    global_vals = jax.device_get(top_k_vals)
    global_idx = jax.device_get(top_k_idx)
    global_tokens = jax.device_get(tokens)

    if jax.process_index() == 0:
      record_dict = {
          "tokens": global_tokens,
          "top_k_logits": global_vals,
          "top_k_indices": global_idx,
      }

      # Attach any requested auxiliary batch fields (positions, segmentations, ...).
      for key in optional_keys:
        if key in batch:
          record_dict[key] = jax.device_get(batch[key])

      # Records are pickled dicts; readers must unpickle each ArrayRecord entry.
      writer.write(pickle.dumps(record_dict))

    # Periodic progress log to keep long runs observable without flooding output.
    if step % 50 == 0:
      max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s")

  max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s")

  # Sync to ensure all hosts finish the forward passes before host 0 starts uploading
  multihost_utils.sync_global_devices("loop_finished")

  # Finalize writing and handle GCS upload on Host 0
  if jax.process_index() == 0:
    writer.close()
    max_logging.log(f"Finished writing to local disk: {local_output_path}")

    if gcs_upload_path:
      gcs_file_path = os.path.join(gcs_upload_path, filename)
      max_logging.log(f"Flag --gcs_upload_path detected. Uploading to: {gcs_file_path}")

      if not tf.io.gfile.exists(gcs_upload_path):
        tf.io.gfile.makedirs(gcs_upload_path)

      # Perform the bulk copy to GCS
      tf.io.gfile.copy(local_output_path, gcs_file_path, overwrite=True)
      max_logging.log("GCS Upload complete.")

  # Sync all hosts one last time so worker hosts don't terminate the job
  multihost_utils.sync_global_devices("upload_complete")
| 143 | + |
def main(argv: Sequence[str], local_args):
  """Builds the teacher config from the CLI argv and runs the generation loop.

  The full argv is parsed once to pull out `teacher_overrides`; the config is
  then re-initialized with just the program name and config-file path plus
  those overrides, yielding the teacher-specific configuration.
  """
  base_config = pyconfig.initialize(argv)
  teacher_config = pyconfig.initialize(
      [argv[0], argv[1]],
      **base_config.teacher_overrides,
  )
  # local_args travels as a single namespace to keep the signature small.
  generate_and_save_data(teacher_config, local_args)
| 153 | + |
| 154 | + |
if __name__ == "__main__":
  # Script-local flags are parsed first; everything unrecognized is forwarded
  # to pyconfig via absl's app.run.
  flag_parser = argparse.ArgumentParser()
  flag_parser.add_argument(
      "--top_k",
      type=int,
      default=128,
      required=False,
      help="Top K value for logits.",
  )
  flag_parser.add_argument(
      "--optional_keys",
      nargs="*",
      type=str,
      default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"],
      help="Optional keys to save from teacher logits (space-separated).",
  )
  flag_parser.add_argument(
      "--gcs_upload_path",
      type=str,
      default=None,
      required=False,
      help="Optional GCS directory (e.g., gs://my-bucket/logits/) to upload the locally saved ArrayRecord file.",
  )
  flag_parser.add_argument(
      "--local_tmp_dir",
      type=str,
      default="/tmp",
      required=False,
      help="Local temporary directory to write the ArrayRecord file before optional GCS upload.",
  )
  parsed_flags, passthrough = flag_parser.parse_known_args()

  # Bind the parsed local flags into main, then let absl handle the rest of argv.
  app.run(functools.partial(main, local_args=parsed_flags), argv=[sys.argv[0]] + passthrough)