Skip to content

Commit c5702d8

Browse files
committed
Updated code and cleaned up commits
Moved optional keys to be cmd arguments. Added script to verify writing of data to a GCS bucket. Added code to work in multihost. Fixed spacing and code formatting. Saving top-k teacher logits on one host to local disk, with the option to store to a GCS bucket. Updated code formatting. Updated to provide local filepath as a cmd arg.
1 parent 5cd1acb commit c5702d8

2 files changed

Lines changed: 264 additions & 0 deletions

File tree

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""
2+
This module provides functionality to save top-k teacher logits
3+
for distillation purposes in MaxText.
4+
5+
Example command:
6+
python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py \
7+
src/maxtext/configs/post_train/distillation.yml \
8+
--top_k=128 \
9+
--gcs_upload_path=gs://my-bucket/teacher_logits/
10+
"""
11+
12+
import os
13+
import pickle
14+
from typing import Sequence
15+
import argparse
16+
import time
17+
import sys
18+
import tensorflow as tf
19+
20+
import jax
21+
import functools
22+
from itertools import islice
23+
24+
from absl import app
25+
from MaxText import pyconfig
26+
from maxtext.utils import model_creation_utils
27+
from maxtext.input_pipeline import input_pipeline_interface
28+
from maxtext.utils import maxtext_utils
29+
from maxtext.utils import max_logging
30+
31+
from jax.experimental import multihost_utils
32+
from array_record.python import array_record_module
33+
34+
35+
def get_top_k_logits(logits: jax.Array, k: int):
  """Return the k largest logit values along the last axis plus their vocab indices.

  Args:
    logits: Teacher logits; the vocabulary is on the final axis.
    k: Number of top entries to keep.

  Returns:
    A ``(values, indices)`` pair as produced by ``jax.lax.top_k``.
  """
  return jax.lax.top_k(logits, k)
39+
40+
41+
def generate_and_save_data(config, local_args):
  """Generates top-k logits from the teacher model and saves them locally, optionally uploading to GCS.

  Every host runs the teacher forward pass over its shard of each batch; only
  process 0 gathers the results to host memory, appends one pickled record per
  step to a local ArrayRecord file, and (optionally) copies that file to GCS
  once the loop finishes. All hosts rendezvous at the sync barriers so workers
  do not terminate the job while host 0 is still writing or uploading.

  Args:
    config: Teacher pyconfig object; provides ``mesh_axes``,
      ``load_parameters_path`` and ``steps``.
    local_args: Parsed CLI namespace with ``top_k``, ``optional_keys``,
      ``gcs_upload_path`` and ``local_tmp_dir``.
  """
  k_val = local_args.top_k
  optional_keys = local_args.optional_keys
  gcs_upload_path = local_args.gcs_upload_path
  local_tmp_dir = local_args.local_tmp_dir

  devices = jax.devices()
  devices_array = maxtext_utils.create_device_mesh(config, devices)
  mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

  # Loading teacher model and dataset iterator
  max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
  teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
  train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)

  # Setup local tmp directory for Host 0; worker hosts keep writer=None.
  filename = "teacher_top_k_global.array_record"
  local_output_path = os.path.join(local_tmp_dir, filename)

  writer = None
  if jax.process_index() == 0:
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(local_tmp_dir, exist_ok=True)
    max_logging.log(f"Process 0 writing globally gathered data to local path: {local_output_path}")
    writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1000")

  # Sync all hosts before starting the loop
  multihost_utils.sync_global_devices("start_generation_loop")

  max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...")
  loop_start = time.time()

  try:
    for step, batch in enumerate(islice(train_iter, config.steps)):
      step_start = time.time()
      tokens = batch["inputs"]
      # Teacher forward pass in eval mode (no dropout).
      logits = teacher_model(
          decoder_input_tokens=tokens,
          decoder_positions=batch["inputs_position"],
          enable_dropout=False,
      )
      top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val)

      # Fetch the global distributed jax arrays
      global_vals = jax.device_get(top_k_vals)
      global_idx = jax.device_get(top_k_idx)
      global_tokens = jax.device_get(tokens)

      if jax.process_index() == 0:
        record_dict = {
            "tokens": global_tokens,
            "top_k_logits": global_vals,
            "top_k_indices": global_idx,
        }
        # Persist only the optional keys actually present in this batch.
        for key in optional_keys:
          if key in batch:
            record_dict[key] = jax.device_get(batch[key])

        writer.write(pickle.dumps(record_dict))

      if step % 50 == 0:
        max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s")
  finally:
    # Finalize the ArrayRecord file even if the loop raises, so the records
    # already written are not lost to an unflushed writer.
    if writer is not None:
      writer.close()

  max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s")

  # Sync to ensure all hosts finish the forward passes before host 0 starts uploading
  multihost_utils.sync_global_devices("loop_finished")

  # Handle GCS upload on Host 0 (writer was already closed above).
  if jax.process_index() == 0:
    max_logging.log(f"Finished writing to local disk: {local_output_path}")

    if gcs_upload_path:
      gcs_file_path = os.path.join(gcs_upload_path, filename)
      max_logging.log(f"Flag --gcs_upload_path detected. Uploading to: {gcs_file_path}")

      if not tf.io.gfile.exists(gcs_upload_path):
        tf.io.gfile.makedirs(gcs_upload_path)

      # Perform the bulk copy to GCS
      tf.io.gfile.copy(local_output_path, gcs_file_path, overwrite=True)
      max_logging.log("GCS Upload complete.")

  # Sync all hosts one last time so worker hosts don't terminate the job
  multihost_utils.sync_global_devices("upload_complete")
128+
129+
130+
def main(argv: Sequence[str], local_args):
  """absl entry point: derive the teacher config and run logit generation.

  The base config supplies a ``teacher_overrides`` dict; the teacher config is
  re-initialized from the same script/yaml pair with those overrides applied.
  """
  base_config = pyconfig.initialize(argv)
  overrides = base_config.teacher_overrides
  teacher_config = pyconfig.initialize(list(argv[:2]), **overrides)

  # Hand the whole parsed-args namespace through to keep the signature small.
  generate_and_save_data(teacher_config, local_args)
139+
140+
141+
if __name__ == "__main__":
  # Script-local flags; anything unrecognized is forwarded on to pyconfig/absl.
  parser = argparse.ArgumentParser()
  parser.add_argument("--top_k", type=int, required=False, default=128, help="Top K value for logits.")
  parser.add_argument(
      "--optional_keys",
      type=str,
      nargs="*",
      default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"],
      help="Optional keys to save from teacher logits (space-separated).",
  )
  parser.add_argument(
      "--gcs_upload_path",
      type=str,
      required=False,
      default=None,
      help="Optional GCS directory (e.g., gs://my-bucket/logits/) to upload the locally saved ArrayRecord file.",
  )
  parser.add_argument(
      "--local_tmp_dir",
      type=str,
      required=False,
      default="/tmp",
      help="Local temporary directory to write the ArrayRecord file before optional GCS upload.",
  )
  cli_args, passthrough = parser.parse_known_args()

  app.run(functools.partial(main, local_args=cli_args), argv=[sys.argv[0]] + passthrough)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""
2+
Verification script to check the correctness of saved top-k teacher logits.
3+
4+
Example usage:
5+
python3 src/maxtext/trainers/post_train/distillation/verify_saved_logits.py \
6+
--output_dir=/path/to/your/output \
7+
--expected_steps=1000 \
8+
--top_k=128
9+
"""
10+
11+
import functools
12+
import sys
13+
14+
import argparse
15+
import pickle
16+
from absl import app
17+
import tensorflow as tf
18+
from array_record.python import array_record_module
19+
from maxtext.utils import max_logging
20+
21+
22+
def verify_array_records(output_dir, expected_steps, expected_k, expected_keys):
  """Verifies the contents of ArrayRecord files containing top-k teacher logits.

  For every ``*.array_record`` file under ``output_dir``, checks that each
  record carries the required keys plus ``expected_keys``, that the trailing
  top-k axis has size ``expected_k``, and that the file holds exactly
  ``expected_steps`` records.

  Args:
    output_dir: Directory (local or gs://) holding the ArrayRecord files.
    expected_steps: Record count each file must contain (matches config.steps).
    expected_k: Expected size of the last axis of the top-k tensors.
    expected_keys: Optional batch keys every record must also contain.

  Raises:
    AssertionError: If any record or file fails a check.
  """

  file_pattern = f"{output_dir}/*.array_record"
  files = tf.io.gfile.glob(file_pattern)

  if not files:
    max_logging.log(f"Error: No ArrayRecord files found matching {file_pattern}")
    return

  max_logging.log(f"Found {len(files)} ArrayRecord files. Starting verification...")

  for file_path in files:
    max_logging.log(f"Verifying: {file_path}")
    reader = array_record_module.ArrayRecordReader(file_path)
    # Initialize before the loop so the "Keys verified" log below cannot hit
    # an unbound name when a file contains zero records.
    data = {}
    step_count = 0
    try:
      num_records = reader.num_records()

      for _ in range(num_records):
        record = reader.read()
        data = pickle.loads(record)

        # Verify all required keys are present
        for key in ["tokens", "top_k_logits", "top_k_indices"]:
          assert key in data, f"Missing required key '{key}' at step {step_count} in {file_path}"

        # Verify all optional keys are present
        for key in expected_keys:
          assert key in data, f"Missing optional key '{key}' at step {step_count} in {file_path}"

        # Verify shapes for Top-K outputs
        actual_k_logits = data["top_k_logits"].shape[-1]
        actual_k_indices = data["top_k_indices"].shape[-1]
        assert actual_k_logits == expected_k, f"Expected top_k={expected_k}, got {actual_k_logits} for logits"
        assert actual_k_indices == expected_k, f"Expected top_k={expected_k}, got {actual_k_indices} for indices"

        step_count += 1
    finally:
      # Release the file handle even when a check fails mid-file.
      reader.close()

    # Verify the total number of steps processed
    assert step_count == expected_steps, f"Expected {expected_steps} steps, but found {step_count} in {file_path}."

    max_logging.log(f"Successfully verified {file_path}")
    max_logging.log(f"- Total steps: {step_count} (Matches expected)")
    max_logging.log(f"- Top-K dimension: {expected_k}")
    max_logging.log(f"- Keys verified: {list(data.keys())}")
67+
68+
69+
def main(argv, local_args):
  """absl entry point: run verification using the parsed CLI flags."""
  del argv  # Unused; remaining args are only forwarded for absl's benefit.
  verify_array_records(
      local_args.output_dir,
      local_args.expected_steps,
      local_args.top_k,
      local_args.optional_keys,
  )
71+
72+
73+
if __name__ == "__main__":
  # Flags consumed by this script; anything unrecognized passes through to absl.
  parser = argparse.ArgumentParser()
  parser.add_argument("--output_dir", type=str, required=True, help="Directory containing the array_record files.")
  parser.add_argument(
      "--expected_steps", type=int, required=True, help="Number of expected steps (matches config.steps)."
  )
  parser.add_argument("--top_k", type=int, default=128, help="Expected top K value.")
  parser.add_argument(
      "--optional_keys",
      type=str,
      nargs="*",
      default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"],
      help="Optional keys expected to be in the record.",
  )

  cli_args, passthrough = parser.parse_known_args()
  app.run(functools.partial(main, local_args=cli_args), argv=[sys.argv[0]] + passthrough)

0 commit comments

Comments
 (0)