Skip to content

Commit 7cddd84

Browse files
committed
Updated code and cleaned up commits
1 parent f3d9f5c commit 7cddd84

1 file changed

Lines changed: 135 additions & 0 deletions

File tree

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""
2+
This module provides functionality to save top-k teacher logits
3+
for distillation purposes in MaxText.
4+
5+
Example command: python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py src/maxtext/configs/post_train/distillation.yml --top_k=128
6+
"""
7+
8+
import os
9+
import pickle
10+
from typing import Sequence
11+
import argparse
12+
import time
13+
import sys
14+
import tensorflow as tf
15+
16+
import jax
17+
import numpy as np
18+
import functools
19+
from itertools import islice
20+
21+
from absl import app
22+
from MaxText import pyconfig
23+
from maxtext.utils import model_creation_utils
24+
from maxtext.input_pipeline import input_pipeline_interface
25+
from maxtext.utils import maxtext_utils
26+
from maxtext.utils import max_logging
27+
28+
from jax.experimental import multihost_utils
29+
from array_record.python import array_record_module
30+
31+
32+
def get_top_k_logits(logits: jax.Array, k: int):
  """Return the k largest logit values and their vocabulary indices.

  Args:
    logits: Array whose last axis is the vocabulary dimension.
    k: Number of top entries to keep per position.

  Returns:
    A ``(values, indices)`` pair as produced by ``jax.lax.top_k``.
  """
  return jax.lax.top_k(logits, k)
36+
37+
38+
def get_local_cpu_array(arr):
  """Gather this host's shards of a (possibly sharded) JAX array into numpy.

  Each addressable shard is copied to host memory and the pieces are
  concatenated along axis 0, yielding only the data local to this process.
  """
  host_pieces = [np.asarray(shard.data) for shard in arr.addressable_shards]
  return np.concatenate(host_pieces, axis=0)
41+
42+
43+
def generate_and_save_data(config, k_val):
  """Run the teacher model over training batches and save top-k logits.

  For each of ``config.steps`` batches, computes the teacher's logits,
  keeps the ``k_val`` largest per token position, and appends one pickled
  record (tokens, top-k values, top-k indices, plus any configured optional
  batch fields) to a per-host ArrayRecord file under the run's output
  directory.

  Args:
    config: MaxText config providing the model, data pipeline, mesh axes,
      step count, and output paths.
    k_val: Number of top logits to retain per token position.
  """
  devices = jax.devices()
  devices_array = maxtext_utils.create_device_mesh(config, devices)
  mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

  # Loading teacher model and dataset iterator
  max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
  teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
  train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)

  output_dir = config.base_output_directory
  if config.run_name:
    output_dir = os.path.join(output_dir, config.run_name)

  # Only process 0 creates the directory so hosts don't race on the filesystem.
  if jax.process_index() == 0:
    if not tf.io.gfile.exists(output_dir):
      tf.io.gfile.makedirs(output_dir)

  # Sync all hosts to ensure directory exists before writers open files
  multihost_utils.sync_global_devices("create_output_dir")

  # Each host writes to a unique file based on its process index to avoid write conflicts
  filename = f"teacher_top_k_process_{jax.process_index()}.array_record"
  output_path = os.path.join(output_dir, filename)

  max_logging.log(f"Process {jax.process_index()} writing directly to: {output_path}")
  writer = array_record_module.ArrayRecordWriter(output_path, "group_size:1000")

  max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...")
  loop_start = time.time()
  try:
    for step, batch in enumerate(islice(train_iter, config.steps)):
      step_start = time.time()
      tokens = batch["inputs"]
      logits = teacher_model(
          decoder_input_tokens=tokens,
          decoder_positions=batch["inputs_position"],
          enable_dropout=False,
      )
      top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val)

      # Extract only the local data for this host (Distributed Writing)
      local_vals = get_local_cpu_array(top_k_vals)
      local_idx = get_local_cpu_array(top_k_idx)
      local_tokens = get_local_cpu_array(tokens)

      # Optional extra batch fields requested via config; silently skipped
      # when a key is absent from the batch.
      optional_keys = config.teacher_logits_optional_keys
      local_optionals = {
          key: get_local_cpu_array(batch[key]) for key in optional_keys if key in batch
      }

      record_dict = {
          "tokens": local_tokens,
          "top_k_logits": local_vals,
          "top_k_indices": local_idx,
      }
      record_dict.update(local_optionals)

      # NOTE(review): records are pickled — downstream readers must fully
      # trust these files (pickle.loads on untrusted data is unsafe).
      writer.write(pickle.dumps(record_dict))

      if step % 50 == 0:
        max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s")

    max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s")
  finally:
    # Always close the writer so buffered records are flushed even if a
    # step raises; previously an exception in the loop leaked the writer.
    writer.close()
  max_logging.log(f"Finished writing to {output_path}.")
111+
112+
113+
def main(argv: Sequence[str], local_args):
  """Build the teacher config and kick off top-k logit generation.

  ``argv`` is the absl-style argument list (program name, config path,
  then overrides); ``local_args`` carries the locally parsed ``--top_k``.
  """
  # Initialize the global configuration, then re-initialize using only the
  # program name and config path plus the teacher-specific overrides.
  base_config = pyconfig.initialize(argv)
  overrides = base_config.teacher_overrides
  teacher_config = pyconfig.initialize([argv[0], argv[1]], **overrides)

  generate_and_save_data(teacher_config, local_args.top_k)
121+
122+
123+
if __name__ == "__main__":
  # Parse --top_k locally; every other flag is forwarded to absl/pyconfig.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      "--top_k",
      type=int,
      required=False,
      default=128,
      help="Top K value for logits.",
  )
  parsed_local, passthrough = arg_parser.parse_known_args()

  # Bind the locally parsed args, then hand the remaining argv back to absl.
  app.run(
      functools.partial(main, local_args=parsed_local),
      argv=[sys.argv[0]] + passthrough,
  )

0 commit comments

Comments
 (0)