Skip to content

Commit 7cffe43

Browse files
committed
Add standalone script to save top-K teacher logits for offline distillation
1 parent 62ee818 commit 7cffe43

2 files changed

Lines changed: 138 additions & 1 deletion

File tree

src/maxtext/configs/types.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,13 @@ class Distillation(BaseModel):
10521052
# --- Loss Params ---
10531053
distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.")
10541054
distill_temperature: float = Field(1.0, description="Temperature for distillation softening.")
1055+
1056+
# --- Teacher topk distillation ---
1057+
teacher_logits_optional_keys: list[str] = Field(
1058+
default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"],
1059+
description="Optional keys to save from teacher logits"
1060+
)
1061+
10551062

10561063

10571064
class TrainingLoop(BaseModel):
@@ -1809,7 +1816,6 @@ class MaxTextConfig(
18091816
# Reinforcement Learning
18101817
RLHardware,
18111818
VLLM,
1812-
RL,
18131819
RLDataset,
18141820
RLEvaluation,
18151821
Reward,
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""
2+
This module provides functionality to save top-k teacher logits
3+
for distillation purposes in MaxText.
4+
"""
5+
6+
import os
7+
import pickle
8+
from typing import Sequence
9+
import argparse
10+
import sys
11+
import tensorflow as tf
12+
13+
import jax
14+
import numpy as np
15+
import functools
16+
from itertools import islice
17+
18+
from absl import app
19+
from MaxText import pyconfig
20+
from maxtext.utils import model_creation_utils
21+
from maxtext.input_pipeline import input_pipeline_interface
22+
from maxtext.utils import maxtext_utils
23+
from maxtext.utils import max_logging
24+
25+
from jax.experimental import multihost_utils
26+
from array_record.python import array_record_module
27+
28+
29+
def get_top_k_logits(logits: jax.Array, k: int):
  """Return the k largest logit values and their vocabulary indices.

  Args:
    logits: Array whose trailing axis is the vocabulary dimension.
    k: Number of top entries to keep along that axis.

  Returns:
    A ``(values, indices)`` pair, exactly as produced by ``jax.lax.top_k``.
  """
  return jax.lax.top_k(logits, k)
33+
34+
35+
def generate_and_save_data(config, k_val):
  """Generates top-k logits from the teacher model and saves them to an ArrayRecord file.

  Every host runs the teacher forward pass and participates in the cross-host
  allgathers; only process 0 serializes the gathered batches (via pickle) into
  a local ArrayRecord file and then copies it to the configured output
  directory (e.g. GCS). Non-zero processes return after the loop.

  Args:
    config: Teacher-model configuration (pyconfig-initialized); provides mesh
      axes, checkpoint path, dataset settings, `steps`, output directory and
      `teacher_logits_optional_keys`.
    k_val: Number of top logits to retain per token position.
  """
  devices = jax.devices()
  devices_array = maxtext_utils.create_device_mesh(config, devices)
  mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

  # Loading teacher model and dataset iterator
  max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
  teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
  train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)

  process_index = jax.process_index()

  output_dir = config.base_output_directory
  if config.run_name:
    output_dir = os.path.join(output_dir, config.run_name)

  final_gcs_file = os.path.join(output_dir, "teacher_top_k.array_record")
  local_temp_file = "/tmp/teacher_top_k.array_record"

  writer = None
  if process_index == 0:
    max_logging.log(f"Opening local ArrayRecordWriter at {local_temp_file}")
    writer = array_record_module.ArrayRecordWriter(local_temp_file, "group_size:1000")

  # Loop-invariant: read the optional-key list once instead of on every step.
  optional_keys = config.teacher_logits_optional_keys

  max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...")
  try:
    for step, batch in enumerate(islice(train_iter, config.steps)):
      tokens = batch["inputs"]
      logits = teacher_model(
          decoder_input_tokens=tokens,
          decoder_positions=batch["inputs_position"],
          enable_dropout=False,
      )

      top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val)

      # NOTE: these are cross-host collectives — every process must execute
      # them on every step, so none may be guarded by `process_index == 0`.
      gathered_vals = multihost_utils.process_allgather(top_k_vals, tiled=True)
      gathered_idx = multihost_utils.process_allgather(top_k_idx, tiled=True)
      gathered_tokens = multihost_utils.process_allgather(tokens, tiled=True)

      gathered_optionals = {
          key: multihost_utils.process_allgather(batch[key], tiled=True) for key in optional_keys if key in batch
      }

      if process_index == 0:
        record_dict = {
            "tokens": np.array(gathered_tokens),
            "top_k_logits": np.array(gathered_vals),
            "top_k_indices": np.array(gathered_idx),
        }
        for key, gathered_val in gathered_optionals.items():
          record_dict[key] = np.array(gathered_val)

        writer.write(pickle.dumps(record_dict))

      if step % 50 == 0:
        max_logging.log(f"Successfully processed step {step}")
  finally:
    # Close the writer even if the loop raises, so already-written records
    # are flushed instead of being silently lost.
    if writer is not None:
      writer.close()

  # Only process 0 holds the local file; everyone else is done.
  if writer is None:
    return

  max_logging.log(f"Finished writing locally, uploading to GCS: {final_gcs_file}...")

  if not tf.io.gfile.exists(output_dir):
    tf.io.gfile.makedirs(output_dir)

  tf.io.gfile.copy(local_temp_file, final_gcs_file, overwrite=True)
  os.remove(local_temp_file)
  max_logging.log("Upload complete")
107+
108+
109+
def main(argv: Sequence[str], local_args):
  """Builds the teacher configuration and runs top-k logit generation.

  Args:
    argv: MaxText-style argv (script name followed by the config file path
      and any overrides) consumed by ``pyconfig.initialize``.
    local_args: Namespace produced by the script-local argparse parser;
      carries ``top_k``.
  """
  # First build the full run config, then re-initialize with only the base
  # argv plus the teacher-specific overrides declared in that config.
  base_config = pyconfig.initialize(argv)
  overrides = base_config.teacher_overrides
  teacher_config = pyconfig.initialize([argv[0], argv[1]], **overrides)
  generate_and_save_data(teacher_config, local_args.top_k)
117+
118+
119+
if __name__ == "__main__":
  # Peel off the script-local --top_k flag; everything else is forwarded
  # untouched to absl/pyconfig via app.run.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      "--top_k",
      type=int,
      required=False,
      default=128,
      help="Top K value for logits.",
  )
  script_args, passthrough = arg_parser.parse_known_args()
  app.run(functools.partial(main, local_args=script_args), argv=[sys.argv[0]] + passthrough)

0 commit comments

Comments
 (0)