Skip to content

Commit 0efc6ca

Browse files
Merge pull request #3193 from AI-Hypercomputer:ajkv-teacher-top-k-distillation
PiperOrigin-RevId: 886400252
2 parents 161f69a + d4f37dc commit 0efc6ca

2 files changed

Lines changed: 292 additions & 0 deletions

File tree

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
This module provides functionality to save top-k teacher logits
17+
for distillation purposes in MaxText.
18+
19+
Example command:
20+
python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py \
21+
src/maxtext/configs/post_train/distillation.yml \
22+
--top_k=128 \
23+
--gcs_upload_path=gs://my-bucket/teacher_logits/
24+
"""
25+
26+
import os
27+
import pickle
28+
from typing import Sequence
29+
import argparse
30+
import time
31+
import sys
32+
import tensorflow as tf
33+
34+
import jax
35+
import functools
36+
from itertools import islice
37+
38+
from absl import app
39+
from MaxText import pyconfig
40+
from maxtext.utils import model_creation_utils
41+
from maxtext.input_pipeline import input_pipeline_interface
42+
from maxtext.utils import maxtext_utils
43+
from maxtext.utils import max_logging
44+
45+
from jax.experimental import multihost_utils
46+
from array_record.python import array_record_module
47+
48+
49+
def get_top_k_logits(logits: jax.Array, k: int):
  """Return the k largest logit values and their vocabulary indices.

  Args:
    logits: Array whose last axis is the vocabulary dimension.
    k: Number of top entries to keep along the last axis.

  Returns:
    A ``(values, indices)`` pair as produced by ``jax.lax.top_k``.
  """
  return jax.lax.top_k(logits, k)
53+
54+
55+
def generate_and_save_data(config, local_args):
  """Generates top-k logits from the teacher model and saves them locally, optionally uploading to GCS.

  Runs the teacher model forward over ``config.steps`` batches, extracts the
  top-k logits per token position, and has host 0 pickle each step's record
  into a single local ArrayRecord file. If ``--gcs_upload_path`` was given,
  host 0 copies the finished file to GCS at the end. All hosts participate in
  the forward passes; only process 0 writes and uploads.

  Args:
    config: Teacher pyconfig configuration (supplies mesh axes, step count,
      checkpoint path and input-pipeline settings).
    local_args: Parsed CLI namespace with ``top_k``, ``optional_keys``,
      ``gcs_upload_path`` and ``local_tmp_dir``.
  """
  k_val = local_args.top_k
  optional_keys = local_args.optional_keys
  gcs_upload_path = local_args.gcs_upload_path
  local_tmp_dir = local_args.local_tmp_dir

  devices = jax.devices()
  devices_array = maxtext_utils.create_device_mesh(config, devices)
  mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

  # Loading teacher model and dataset iterator
  max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
  teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
  train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)

  # Setup local tmp directory for Host 0
  filename = "teacher_top_k_global.array_record"
  local_output_path = os.path.join(local_tmp_dir, filename)

  writer = None
  if jax.process_index() == 0:
    # exist_ok avoids the race between a separate existence check and creation.
    os.makedirs(local_tmp_dir, exist_ok=True)
    max_logging.log(f"Process 0 writing globally gathered data to local path: {local_output_path}")
    writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1000")

  # Sync all hosts before starting the loop
  multihost_utils.sync_global_devices("start_generation_loop")

  max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...")
  loop_start = time.time()

  for step, batch in enumerate(islice(train_iter, config.steps)):
    step_start = time.time()
    tokens = batch["inputs"]
    logits = teacher_model(
        decoder_input_tokens=tokens,
        decoder_positions=batch["inputs_position"],
        enable_dropout=False,
    )
    top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val)

    # Fetch the global distributed jax arrays
    global_vals = jax.device_get(top_k_vals)
    global_idx = jax.device_get(top_k_idx)
    global_tokens = jax.device_get(tokens)

    if jax.process_index() == 0:
      record_dict = {
          "tokens": global_tokens,
          "top_k_logits": global_vals,
          "top_k_indices": global_idx,
      }

      # Optional keys are saved only when the input pipeline actually
      # produced them for this batch.
      for key in optional_keys:
        if key in batch:
          record_dict[key] = jax.device_get(batch[key])

      writer.write(pickle.dumps(record_dict))

    if step % 50 == 0:
      max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s")

  max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s")

  # Sync to ensure all hosts finish the forward passes before host 0 starts uploading
  multihost_utils.sync_global_devices("loop_finished")

  # Finalize writing and handle GCS upload on Host 0
  if jax.process_index() == 0:
    writer.close()
    max_logging.log(f"Finished writing to local disk: {local_output_path}")

    if gcs_upload_path:
      gcs_file_path = os.path.join(gcs_upload_path, filename)
      max_logging.log(f"Flag --gcs_upload_path detected. Uploading to: {gcs_file_path}")

      if not tf.io.gfile.exists(gcs_upload_path):
        tf.io.gfile.makedirs(gcs_upload_path)

      # Perform the bulk copy to GCS
      tf.io.gfile.copy(local_output_path, gcs_file_path, overwrite=True)
      max_logging.log("GCS Upload complete.")

  # Sync all hosts one last time so worker hosts don't terminate the job
  multihost_utils.sync_global_devices("upload_complete")
142+
143+
144+
def main(argv: Sequence[str], local_args):
  """Builds the teacher configuration and runs the top-k generation job.

  Args:
    argv: Program arguments; ``argv[1]`` is the base config yaml path.
    local_args: Script-local CLI flags parsed in ``__main__``.
  """
  # The base config only supplies the teacher override dictionary.
  base_config = pyconfig.initialize(argv)
  # Re-initialize from just the program name + yaml path, applying overrides.
  teacher_config = pyconfig.initialize(list(argv[:2]), **base_config.teacher_overrides)

  # Pass the entire local_args object to clean up the function signature
  generate_and_save_data(teacher_config, local_args)
153+
154+
155+
if __name__ == "__main__":
  # Script-local flags; anything unrecognized is forwarded to pyconfig/absl.
  parser = argparse.ArgumentParser()
  parser.add_argument("--top_k", type=int, required=False, default=128, help="Top K value for logits.")
  parser.add_argument(
      "--optional_keys",
      type=str,
      nargs="*",
      default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"],
      help="Optional keys to save from teacher logits (space-separated).",
  )
  parser.add_argument(
      "--gcs_upload_path",
      type=str,
      required=False,
      default=None,
      help="Optional GCS directory (e.g., gs://my-bucket/logits/) to upload the locally saved ArrayRecord file.",
  )
  parser.add_argument(
      "--local_tmp_dir",
      type=str,
      required=False,
      default="/tmp",
      help="Local temporary directory to write the ArrayRecord file before optional GCS upload.",
  )
  cli_args, passthrough = parser.parse_known_args()

  # Bind the local flags and hand the remaining argv to absl.
  app.run(functools.partial(main, local_args=cli_args), argv=[sys.argv[0]] + passthrough)
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Verification script to check the correctness of saved top-k teacher logits.
17+
18+
Example usage:
19+
python3 src/maxtext/trainers/post_train/distillation/verify_saved_logits.py \
20+
--output_dir=/path/to/your/output \
21+
--expected_steps=1000 \
22+
--top_k=128
23+
"""
24+
25+
import functools
26+
import sys
27+
28+
import argparse
29+
import pickle
30+
from absl import app
31+
import tensorflow as tf
32+
from array_record.python import array_record_module
33+
from maxtext.utils import max_logging
34+
35+
36+
def verify_array_records(output_dir, expected_steps, expected_k, expected_keys):
  """Verifies the contents of ArrayRecord files containing top-k teacher logits.

  For every ``*.array_record`` file under ``output_dir``, checks that each
  record contains the required keys (``tokens``, ``top_k_logits``,
  ``top_k_indices``) plus all ``expected_keys``, that the trailing (top-k)
  dimension of the logit/index arrays equals ``expected_k``, and that the
  file holds exactly ``expected_steps`` records.

  Args:
    output_dir: Directory (local or GCS) holding the ArrayRecord files.
    expected_steps: Number of records each file must contain.
    expected_k: Expected size of the last axis of the top-k arrays.
    expected_keys: Optional keys that must also be present in every record.

  Raises:
    AssertionError: If any record or file fails a check.
  """

  file_pattern = f"{output_dir}/*.array_record"
  files = tf.io.gfile.glob(file_pattern)

  if not files:
    max_logging.log(f"Error: No ArrayRecord files found matching {file_pattern}")
    return

  max_logging.log(f"Found {len(files)} ArrayRecord files. Starting verification...")

  for file_path in files:
    max_logging.log(f"Verifying: {file_path}")
    reader = array_record_module.ArrayRecordReader(file_path)
    step_count = 0
    data = None  # Retains the last record for the summary log below.
    try:
      num_records = reader.num_records()

      for _ in range(num_records):
        record = reader.read()
        data = pickle.loads(record)

        # Verify all required keys are present
        for key in ["tokens", "top_k_logits", "top_k_indices"]:
          assert key in data, f"Missing required key '{key}' at step {step_count} in {file_path}"

        # Verify all optional keys are present
        for key in expected_keys:
          assert key in data, f"Missing optional key '{key}' at step {step_count} in {file_path}"

        # Verify shapes for Top-K outputs
        actual_k_logits = data["top_k_logits"].shape[-1]
        actual_k_indices = data["top_k_indices"].shape[-1]
        assert actual_k_logits == expected_k, f"Expected top_k={expected_k}, got {actual_k_logits} for logits"
        assert actual_k_indices == expected_k, f"Expected top_k={expected_k}, got {actual_k_indices} for indices"

        step_count += 1
    finally:
      # Release the file handle even if an assertion fails mid-file.
      reader.close()

    # Verify the total number of steps processed
    assert step_count == expected_steps, f"Expected {expected_steps} steps, but found {step_count} in {file_path}."

    max_logging.log(f"Successfully verified {file_path}")
    max_logging.log(f"- Total steps: {step_count} (Matches expected)")
    max_logging.log(f"- Top-K dimension: {expected_k}")
    # Guard against NameError when a file legitimately contains zero records.
    keys_seen = list(data.keys()) if data is not None else []
    max_logging.log(f"- Keys verified: {keys_seen}")
81+
82+
83+
def main(argv, local_args):
  """absl entry point: forwards the parsed CLI flags to the verifier."""
  del argv  # All inputs come from the script-local flags.
  verify_array_records(
      local_args.output_dir,
      local_args.expected_steps,
      local_args.top_k,
      local_args.optional_keys,
  )
85+
86+
87+
if __name__ == "__main__":
  # Script-local flags; anything unrecognized is passed through to absl.
  parser = argparse.ArgumentParser()
  parser.add_argument("--output_dir", type=str, required=True, help="Directory containing the array_record files.")
  parser.add_argument(
      "--expected_steps", type=int, required=True, help="Number of expected steps (matches config.steps)."
  )
  parser.add_argument("--top_k", type=int, default=128, help="Expected top K value.")
  parser.add_argument(
      "--optional_keys",
      type=str,
      nargs="*",
      default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"],
      help="Optional keys expected to be in the record.",
  )

  cli_args, passthrough = parser.parse_known_args()
  app.run(functools.partial(main, local_args=cli_args), argv=[sys.argv[0]] + passthrough)

0 commit comments

Comments
 (0)