Skip to content

Commit 4b515df

Browse files
committed
Added cmd args for top_k and refactored code for readability
1 parent e4dcf69 commit 4b515df

1 file changed

Lines changed: 24 additions & 20 deletions

File tree

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 24 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -6,12 +6,16 @@
66
import os
77
import pickle
88
from typing import Sequence
9+
import argparse
10+
import sys
911
import tensorflow as tf
1012

11-
from absl import app
1213
import jax
1314
import numpy as np
15+
import functools
16+
from itertools import islice
1417

18+
from absl import app
1519
from MaxText import pyconfig
1620
from maxtext.utils import model_creation_utils
1721
from maxtext.input_pipeline import input_pipeline_interface
@@ -28,7 +32,7 @@ def get_top_k_logits(logits: jax.Array, k: int):
2832
return top_k_values, top_k_indices
2933

3034

31-
def generate_and_save_data(config):
35+
def generate_and_save_data(config, k_val):
3236
"""Generates top-k logits from the teacher model and saves them to an ArrayRecord file"""
3337
devices = jax.devices()
3438
devices_array = maxtext_utils.create_device_mesh(config, devices)
@@ -45,7 +49,6 @@ def generate_and_save_data(config):
4549
if config.run_name:
4650
output_dir = os.path.join(output_dir, config.run_name)
4751

48-
# create final GCS path and local temp file path
4952
final_gcs_file = os.path.join(output_dir, "teacher_top_k.array_record")
5053
local_temp_file = "/tmp/teacher_top_k.array_record"
5154

@@ -55,35 +58,26 @@ def generate_and_save_data(config):
5558
writer = array_record_module.ArrayRecordWriter(local_temp_file, "group_size:1000")
5659

5760
max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...")
58-
for step, batch in enumerate(train_iter):
59-
if step >= config.steps:
60-
break
61-
61+
for step, batch in enumerate(islice(train_iter, config.steps)):
6262
tokens = batch["inputs"]
63-
6463
logits = teacher_model(
6564
decoder_input_tokens=tokens,
6665
decoder_positions=batch["inputs_position"],
6766
enable_dropout=False,
6867
)
6968

70-
# determine top-k size and extract top-k logits and indices
71-
k_val = config.decode_sampling_top_k if config.decode_sampling_top_k > 0 else 128
7269
top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val)
7370

7471
gathered_vals = multihost_utils.process_allgather(top_k_vals, tiled=True)
7572
gathered_idx = multihost_utils.process_allgather(top_k_idx, tiled=True)
7673
gathered_tokens = multihost_utils.process_allgather(tokens, tiled=True)
7774

7875
optional_keys = ["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"]
79-
80-
gathered_optionals = {}
81-
for key in optional_keys:
82-
if key in batch:
83-
gathered_optionals[key] = multihost_utils.process_allgather(batch[key], tiled=True)
76+
gathered_optionals = {
77+
key: multihost_utils.process_allgather(batch[key], tiled=True) for key in optional_keys if key in batch
78+
}
8479

8580
if process_index == 0:
86-
# Writing the gathered tokens, top-k logits, and top-k indices to the ArrayRecord file
8781
record_dict = {
8882
"tokens": np.array(gathered_tokens),
8983
"top_k_logits": np.array(gathered_vals),
@@ -107,21 +101,31 @@ def generate_and_save_data(config):
107101
if not tf.io.gfile.exists(output_dir):
108102
tf.io.gfile.makedirs(output_dir)
109103

110-
# Upload the local file to GCS and remove the local temp file
111104
tf.io.gfile.copy(local_temp_file, final_gcs_file, overwrite=True)
112105
os.remove(local_temp_file)
113106
max_logging.log("Upload complete")
114107

115108

116-
def main(argv: Sequence[str]):
109+
def main(argv: Sequence[str], local_args):
117110
# Initialize the global configuration
118111
global_config = pyconfig.initialize(argv)
119112
teacher_overrides = global_config.teacher_overrides
120113
teacher_argv = [argv[0], argv[1]]
121114
teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides)
122115

123-
generate_and_save_data(teacher_config)
116+
generate_and_save_data(teacher_config, local_args.top_k)
124117

125118

126119
if __name__ == "__main__":
127-
app.run(main)
120+
parser = argparse.ArgumentParser()
121+
parser.add_argument(
122+
"--top_k",
123+
type=int,
124+
required=False,
125+
default=128,
126+
help="Top K value for logits.",
127+
)
128+
local_arg, remaining_args = parser.parse_known_args()
129+
130+
main_wrapper = functools.partial(main, local_args=local_arg)
131+
app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args)

0 commit comments

Comments
 (0)