-
Notifications
You must be signed in to change notification settings - Fork 507
Workflow to save top-k teacher logits in GCS to use in distillaiton #3193
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+292
−0
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
188 changes: 188 additions & 0 deletions
188
src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| # Copyright 2023–2026 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| This module provides functionality to save top-k teacher logits | ||
| for distillation purposes in MaxText. | ||
|
|
||
| Example command: | ||
| python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py \ | ||
| src/maxtext/configs/post_train/distillation.yml \ | ||
| --top_k=128 \ | ||
| --gcs_upload_path=gs://my-bucket/teacher_logits/ | ||
| """ | ||
|
|
||
| import os | ||
| import pickle | ||
| from typing import Sequence | ||
| import argparse | ||
| import time | ||
| import sys | ||
| import tensorflow as tf | ||
|
|
||
| import jax | ||
| import functools | ||
| from itertools import islice | ||
|
|
||
| from absl import app | ||
| from MaxText import pyconfig | ||
| from maxtext.utils import model_creation_utils | ||
| from maxtext.input_pipeline import input_pipeline_interface | ||
| from maxtext.utils import maxtext_utils | ||
| from maxtext.utils import max_logging | ||
|
|
||
| from jax.experimental import multihost_utils | ||
| from array_record.python import array_record_module | ||
|
|
||
|
|
||
| def get_top_k_logits(logits: jax.Array, k: int): | ||
| """Extracts the top-k values and their vocabulary indices""" | ||
| top_k_values, top_k_indices = jax.lax.top_k(logits, k) | ||
| return top_k_values, top_k_indices | ||
|
|
||
|
|
||
| def generate_and_save_data(config, local_args): | ||
| """Generates top-k logits from the teacher model and saves them locally, optionally uploading to GCS.""" | ||
| k_val = local_args.top_k | ||
| optional_keys = local_args.optional_keys | ||
| gcs_upload_path = local_args.gcs_upload_path | ||
| local_tmp_dir = local_args.local_tmp_dir | ||
|
|
||
| devices = jax.devices() | ||
| devices_array = maxtext_utils.create_device_mesh(config, devices) | ||
| mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) | ||
|
|
||
| # Loading teacher model and dataset iterator | ||
| max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...") | ||
| teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh) | ||
| train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh) | ||
|
ajkv-google marked this conversation as resolved.
|
||
|
|
||
| # Setup local tmp directory for Host 0 | ||
| filename = "teacher_top_k_global.array_record" | ||
| local_output_path = os.path.join(local_tmp_dir, filename) | ||
|
|
||
| writer = None | ||
| if jax.process_index() == 0: | ||
| if not os.path.exists(local_tmp_dir): | ||
| os.makedirs(local_tmp_dir) | ||
| max_logging.log(f"Process 0 writing globally gathered data to local path: {local_output_path}") | ||
| writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1000") | ||
|
|
||
| # Sync all hosts before starting the loop | ||
| multihost_utils.sync_global_devices("start_generation_loop") | ||
|
|
||
| max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...") | ||
| loop_start = time.time() | ||
|
|
||
| for step, batch in enumerate(islice(train_iter, config.steps)): | ||
| step_start = time.time() | ||
| tokens = batch["inputs"] | ||
| logits = teacher_model( | ||
| decoder_input_tokens=tokens, | ||
| decoder_positions=batch["inputs_position"], | ||
| enable_dropout=False, | ||
| ) | ||
| top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val) | ||
|
|
||
| # Fetch the global distributed jax arrays | ||
| global_vals = jax.device_get(top_k_vals) | ||
| global_idx = jax.device_get(top_k_idx) | ||
| global_tokens = jax.device_get(tokens) | ||
|
|
||
| if jax.process_index() == 0: | ||
| record_dict = { | ||
| "tokens": global_tokens, | ||
| "top_k_logits": global_vals, | ||
| "top_k_indices": global_idx, | ||
| } | ||
|
|
||
| for key in optional_keys: | ||
| if key in batch: | ||
| record_dict[key] = jax.device_get(batch[key]) | ||
|
|
||
| writer.write(pickle.dumps(record_dict)) | ||
|
|
||
| if step % 50 == 0: | ||
| max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s") | ||
|
|
||
| max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s") | ||
|
|
||
| # Sync to ensure all hosts finish the forward passes before host 0 starts uploading | ||
| multihost_utils.sync_global_devices("loop_finished") | ||
|
|
||
| # Finalize writing and handle GCS upload on Host 0 | ||
| if jax.process_index() == 0: | ||
| writer.close() | ||
| max_logging.log(f"Finished writing to local disk: {local_output_path}") | ||
|
|
||
| if gcs_upload_path: | ||
| gcs_file_path = os.path.join(gcs_upload_path, filename) | ||
| max_logging.log(f"Flag --gcs_upload_path detected. Uploading to: {gcs_file_path}") | ||
|
|
||
| if not tf.io.gfile.exists(gcs_upload_path): | ||
| tf.io.gfile.makedirs(gcs_upload_path) | ||
|
|
||
| # Perform the bulk copy to GCS | ||
| tf.io.gfile.copy(local_output_path, gcs_file_path, overwrite=True) | ||
| max_logging.log("GCS Upload complete.") | ||
|
|
||
| # Sync all hosts one last time so worker hosts don't terminate the job | ||
| multihost_utils.sync_global_devices("upload_complete") | ||
|
|
||
|
|
||
| def main(argv: Sequence[str], local_args): | ||
| # Initialize the global configuration | ||
| global_config = pyconfig.initialize(argv) | ||
| teacher_overrides = global_config.teacher_overrides | ||
| teacher_argv = [argv[0], argv[1]] | ||
| teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides) | ||
|
|
||
| # Pass the entire local_args object to clean up the function signature | ||
| generate_and_save_data(teacher_config, local_args) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument( | ||
| "--top_k", | ||
| type=int, | ||
| required=False, | ||
| default=128, | ||
| help="Top K value for logits.", | ||
| ) | ||
| parser.add_argument( | ||
| "--optional_keys", | ||
| type=str, | ||
| nargs="*", | ||
| default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"], | ||
| help="Optional keys to save from teacher logits (space-separated).", | ||
| ) | ||
| parser.add_argument( | ||
| "--gcs_upload_path", | ||
| type=str, | ||
| required=False, | ||
| default=None, | ||
| help="Optional GCS directory (e.g., gs://my-bucket/logits/) to upload the locally saved ArrayRecord file.", | ||
| ) | ||
| parser.add_argument( | ||
| "--local_tmp_dir", | ||
| type=str, | ||
| required=False, | ||
| default="/tmp", | ||
| help="Local temporary directory to write the ArrayRecord file before optional GCS upload.", | ||
| ) | ||
| local_arg, remaining_args = parser.parse_known_args() | ||
|
|
||
| main_wrapper = functools.partial(main, local_args=local_arg) | ||
| app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args) | ||
104 changes: 104 additions & 0 deletions
104
src/maxtext/trainers/post_train/distillation/verify_saved_logits.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| # Copyright 2023–2026 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| Verification script to check the correctness of saved top-k teacher logits. | ||
|
|
||
| Example usage: | ||
| python3 src/maxtext/trainers/post_train/distillation/verify_saved_logits.py \ | ||
| --output_dir=/path/to/your/output \ | ||
| --expected_steps=1000 \ | ||
| --top_k=128 | ||
| """ | ||
|
|
||
| import functools | ||
| import sys | ||
|
|
||
| import argparse | ||
| import pickle | ||
| from absl import app | ||
| import tensorflow as tf | ||
| from array_record.python import array_record_module | ||
| from maxtext.utils import max_logging | ||
|
|
||
|
|
||
| def verify_array_records(output_dir, expected_steps, expected_k, expected_keys): | ||
| """Verifies the contents of ArrayRecord files containing top-k teacher logits.""" | ||
|
|
||
| file_pattern = f"{output_dir}/*.array_record" | ||
| files = tf.io.gfile.glob(file_pattern) | ||
|
|
||
| if not files: | ||
| max_logging.log(f"Error: No ArrayRecord files found matching {file_pattern}") | ||
| return | ||
|
|
||
| max_logging.log(f"Found {len(files)} ArrayRecord files. Starting verification...") | ||
|
|
||
| for file_path in files: | ||
| max_logging.log(f"Verifying: {file_path}") | ||
| reader = array_record_module.ArrayRecordReader(file_path) | ||
| num_records = reader.num_records() | ||
|
|
||
| step_count = 0 | ||
| for _ in range(num_records): | ||
| record = reader.read() | ||
| data = pickle.loads(record) | ||
|
|
||
| # Verify all required keys are present | ||
| for key in ["tokens", "top_k_logits", "top_k_indices"]: | ||
| assert key in data, f"Missing required key '{key}' at step {step_count} in {file_path}" | ||
|
|
||
| # Verify all optional keys are present | ||
| for key in expected_keys: | ||
| assert key in data, f"Missing optional key '{key}' at step {step_count} in {file_path}" | ||
|
|
||
| # Verify shapes for Top-K outputs | ||
| actual_k_logits = data["top_k_logits"].shape[-1] | ||
| actual_k_indices = data["top_k_indices"].shape[-1] | ||
| assert actual_k_logits == expected_k, f"Expected top_k={expected_k}, got {actual_k_logits} for logits" | ||
| assert actual_k_indices == expected_k, f"Expected top_k={expected_k}, got {actual_k_indices} for indices" | ||
|
|
||
| step_count += 1 | ||
|
|
||
| # Verify the total number of steps processed | ||
| assert step_count == expected_steps, f"Expected {expected_steps} steps, but found {step_count} in {file_path}." | ||
|
|
||
| max_logging.log(f"Successfully verified {file_path}") | ||
| max_logging.log(f"- Total steps: {step_count} (Matches expected)") | ||
| max_logging.log(f"- Top-K dimension: {expected_k}") | ||
| max_logging.log(f"- Keys verified: {list(data.keys())}") | ||
|
|
||
|
|
||
| def main(argv, local_args): | ||
| verify_array_records(local_args.output_dir, local_args.expected_steps, local_args.top_k, local_args.optional_keys) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--output_dir", type=str, required=True, help="Directory containing the array_record files.") | ||
| parser.add_argument( | ||
| "--expected_steps", type=int, required=True, help="Number of expected steps (matches config.steps)." | ||
| ) | ||
| parser.add_argument("--top_k", type=int, default=128, help="Expected top K value.") | ||
| parser.add_argument( | ||
| "--optional_keys", | ||
| type=str, | ||
| nargs="*", | ||
| default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"], | ||
| help="Optional keys expected to be in the record.", | ||
| ) | ||
|
|
||
| local_arg, remaining_args = parser.parse_known_args() | ||
| main_wrapper = functools.partial(main, local_args=local_arg) | ||
| app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.