diff --git a/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py b/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py new file mode 100644 index 0000000000..05c567e924 --- /dev/null +++ b/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py @@ -0,0 +1,188 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides functionality to save top-k teacher logits +for distillation purposes in MaxText. + +Example command: +python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py \ +src/maxtext/configs/post_train/distillation.yml \ +--top_k=128 \ +--gcs_upload_path=gs://my-bucket/teacher_logits/ +""" + +import os +import pickle +from typing import Sequence +import argparse +import time +import sys +import tensorflow as tf + +import jax +import functools +from itertools import islice + +from absl import app +from MaxText import pyconfig +from maxtext.utils import model_creation_utils +from maxtext.input_pipeline import input_pipeline_interface +from maxtext.utils import maxtext_utils +from maxtext.utils import max_logging + +from jax.experimental import multihost_utils +from array_record.python import array_record_module + + +def get_top_k_logits(logits: jax.Array, k: int): + """Extracts the top-k values and their vocabulary indices""" + top_k_values, top_k_indices = jax.lax.top_k(logits, k) + return top_k_values, top_k_indices + + +def generate_and_save_data(config, local_args): + """Generates top-k logits from the teacher model and saves them locally, optionally uploading to GCS.""" + k_val = local_args.top_k + optional_keys = local_args.optional_keys + gcs_upload_path = local_args.gcs_upload_path + local_tmp_dir = local_args.local_tmp_dir + + devices = jax.devices() + devices_array = maxtext_utils.create_device_mesh(config, devices) + mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) + + # Loading teacher model and dataset iterator + max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...") + teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh) + train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh) + + # Setup local tmp directory for Host 0 + filename = "teacher_top_k_global.array_record" + local_output_path = os.path.join(local_tmp_dir, filename) + + writer = None + if jax.process_index() == 0: + if not os.path.exists(local_tmp_dir): + os.makedirs(local_tmp_dir) + max_logging.log(f"Process 0 writing globally gathered data to local path: {local_output_path}") + writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1000") + + # Sync all hosts before starting the loop + multihost_utils.sync_global_devices("start_generation_loop") + + max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...") + loop_start = time.time() + + for step, batch in enumerate(islice(train_iter, config.steps)): + step_start = time.time() + tokens = batch["inputs"] + logits = teacher_model( + decoder_input_tokens=tokens, + decoder_positions=batch["inputs_position"], + enable_dropout=False, + ) + top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val) + + # Fetch the global distributed jax arrays + global_vals = jax.device_get(top_k_vals) + global_idx = jax.device_get(top_k_idx) + global_tokens = jax.device_get(tokens) + + if jax.process_index() == 0: + record_dict = { + "tokens": global_tokens, + "top_k_logits": global_vals, + "top_k_indices": global_idx, + } + + for key in optional_keys: + if key in batch: + record_dict[key] = jax.device_get(batch[key]) + + writer.write(pickle.dumps(record_dict)) + + if step % 50 == 0: + max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s") + + max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s") + + # Sync to ensure all hosts finish the forward passes before host 0 starts uploading + multihost_utils.sync_global_devices("loop_finished") + + # Finalize writing and handle GCS upload on Host 0 + if jax.process_index() == 0: + writer.close() + max_logging.log(f"Finished writing to local disk: {local_output_path}") + + if gcs_upload_path: + gcs_file_path = os.path.join(gcs_upload_path, filename) + max_logging.log(f"Flag --gcs_upload_path detected. Uploading to: {gcs_file_path}") + + if not tf.io.gfile.exists(gcs_upload_path): + tf.io.gfile.makedirs(gcs_upload_path) + + # Perform the bulk copy to GCS + tf.io.gfile.copy(local_output_path, gcs_file_path, overwrite=True) + max_logging.log("GCS Upload complete.") + + # Sync all hosts one last time so worker hosts don't terminate the job + multihost_utils.sync_global_devices("upload_complete") + + +def main(argv: Sequence[str], local_args): + # Initialize the global configuration + global_config = pyconfig.initialize(argv) + teacher_overrides = global_config.teacher_overrides + teacher_argv = [argv[0], argv[1]] + teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides) + + # Pass the entire local_args object to clean up the function signature + generate_and_save_data(teacher_config, local_args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--top_k", + type=int, + required=False, + default=128, + help="Top K value for logits.", + ) + parser.add_argument( + "--optional_keys", + type=str, + nargs="*", + default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"], + help="Optional keys to save from teacher logits (space-separated).", + ) + parser.add_argument( + "--gcs_upload_path", + type=str, + required=False, + default=None, + help="Optional GCS directory (e.g., gs://my-bucket/logits/) to upload the locally saved ArrayRecord file.", + ) + parser.add_argument( + "--local_tmp_dir", + type=str, + required=False, + default="/tmp", + help="Local temporary directory to write the ArrayRecord file before optional GCS upload.", + ) + local_arg, remaining_args = parser.parse_known_args() + + main_wrapper = functools.partial(main, local_args=local_arg) + app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args) diff --git a/src/maxtext/trainers/post_train/distillation/verify_saved_logits.py b/src/maxtext/trainers/post_train/distillation/verify_saved_logits.py new file mode 100644 index 0000000000..aa60581571 --- /dev/null +++ b/src/maxtext/trainers/post_train/distillation/verify_saved_logits.py @@ -0,0 +1,104 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Verification script to check the correctness of saved top-k teacher logits. + +Example usage: +python3 src/maxtext/trainers/post_train/distillation/verify_saved_logits.py \ + --output_dir=/path/to/your/output \ + --expected_steps=1000 \ + --top_k=128 +""" + +import functools +import sys + +import argparse +import pickle +from absl import app +import tensorflow as tf +from array_record.python import array_record_module +from maxtext.utils import max_logging + + +def verify_array_records(output_dir, expected_steps, expected_k, expected_keys): + """Verifies the contents of ArrayRecord files containing top-k teacher logits.""" + + file_pattern = f"{output_dir}/*.array_record" + files = tf.io.gfile.glob(file_pattern) + + if not files: + max_logging.log(f"Error: No ArrayRecord files found matching {file_pattern}") + return + + max_logging.log(f"Found {len(files)} ArrayRecord files. Starting verification...") + + for file_path in files: + max_logging.log(f"Verifying: {file_path}") + reader = array_record_module.ArrayRecordReader(file_path) + num_records = reader.num_records() + + step_count = 0 + for _ in range(num_records): + record = reader.read() + data = pickle.loads(record) + + # Verify all required keys are present + for key in ["tokens", "top_k_logits", "top_k_indices"]: + assert key in data, f"Missing required key '{key}' at step {step_count} in {file_path}" + + # Verify all optional keys are present + for key in expected_keys: + assert key in data, f"Missing optional key '{key}' at step {step_count} in {file_path}" + + # Verify shapes for Top-K outputs + actual_k_logits = data["top_k_logits"].shape[-1] + actual_k_indices = data["top_k_indices"].shape[-1] + assert actual_k_logits == expected_k, f"Expected top_k={expected_k}, got {actual_k_logits} for logits" + assert actual_k_indices == expected_k, f"Expected top_k={expected_k}, got {actual_k_indices} for indices" + + step_count += 1 + + # Verify the total number of steps processed + assert step_count == expected_steps, f"Expected {expected_steps} steps, but found {step_count} in {file_path}." + + max_logging.log(f"Successfully verified {file_path}") + max_logging.log(f"- Total steps: {step_count} (Matches expected)") + max_logging.log(f"- Top-K dimension: {expected_k}") + max_logging.log(f"- Keys verified: {list(data.keys())}") + + +def main(argv, local_args): + verify_array_records(local_args.output_dir, local_args.expected_steps, local_args.top_k, local_args.optional_keys) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", type=str, required=True, help="Directory containing the array_record files.") + parser.add_argument( + "--expected_steps", type=int, required=True, help="Number of expected steps (matches config.steps)." + ) + parser.add_argument("--top_k", type=int, default=128, help="Expected top K value.") + parser.add_argument( + "--optional_keys", + type=str, + nargs="*", + default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"], + help="Optional keys expected to be in the record.", + ) + + local_arg, remaining_args = parser.parse_known_args() + main_wrapper = functools.partial(main, local_args=local_arg) + app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args)