66import os
77import pickle
88from typing import Sequence
9+ import argparse
10+ import sys
911import tensorflow as tf
1012
11- from absl import app
1213import jax
1314import numpy as np
15+ import functools
16+ from itertools import islice
1417
18+ from absl import app
1519from MaxText import pyconfig
1620from maxtext .utils import model_creation_utils
1721from maxtext .input_pipeline import input_pipeline_interface
@@ -28,7 +32,7 @@ def get_top_k_logits(logits: jax.Array, k: int):
2832 return top_k_values , top_k_indices
2933
3034
31- def generate_and_save_data (config ):
35+ def generate_and_save_data (config , k_val ):
3236 """Generates top-k logits from the teacher model and saves them to an ArrayRecord file"""
3337 devices = jax .devices ()
3438 devices_array = maxtext_utils .create_device_mesh (config , devices )
@@ -45,7 +49,6 @@ def generate_and_save_data(config):
4549 if config .run_name :
4650 output_dir = os .path .join (output_dir , config .run_name )
4751
48- # create final GCS path and local temp file path
4952 final_gcs_file = os .path .join (output_dir , "teacher_top_k.array_record" )
5053 local_temp_file = "/tmp/teacher_top_k.array_record"
5154
@@ -55,35 +58,26 @@ def generate_and_save_data(config):
5558 writer = array_record_module .ArrayRecordWriter (local_temp_file , "group_size:1000" )
5659
5760 max_logging .log (f"Starting Top-K generation loop for { config .steps } steps..." )
58- for step , batch in enumerate (train_iter ):
59- if step >= config .steps :
60- break
61-
61+ for step , batch in enumerate (islice (train_iter , config .steps )):
6262 tokens = batch ["inputs" ]
63-
6463 logits = teacher_model (
6564 decoder_input_tokens = tokens ,
6665 decoder_positions = batch ["inputs_position" ],
6766 enable_dropout = False ,
6867 )
6968
70- # determine top-k size and extract top-k logits and indices
71- k_val = config .decode_sampling_top_k if config .decode_sampling_top_k > 0 else 128
7269 top_k_vals , top_k_idx = get_top_k_logits (logits , k = k_val )
7370
7471 gathered_vals = multihost_utils .process_allgather (top_k_vals , tiled = True )
7572 gathered_idx = multihost_utils .process_allgather (top_k_idx , tiled = True )
7673 gathered_tokens = multihost_utils .process_allgather (tokens , tiled = True )
7774
7875 optional_keys = ["inputs_position" , "inputs_segmentation" , "targets_segmentation" , "targets" ]
79-
80- gathered_optionals = {}
81- for key in optional_keys :
82- if key in batch :
83- gathered_optionals [key ] = multihost_utils .process_allgather (batch [key ], tiled = True )
76+ gathered_optionals = {
77+ key : multihost_utils .process_allgather (batch [key ], tiled = True ) for key in optional_keys if key in batch
78+ }
8479
8580 if process_index == 0 :
86- # Writing the gathered tokens, top-k logits, and top-k indices to the ArrayRecord file
8781 record_dict = {
8882 "tokens" : np .array (gathered_tokens ),
8983 "top_k_logits" : np .array (gathered_vals ),
@@ -107,21 +101,31 @@ def generate_and_save_data(config):
107101 if not tf .io .gfile .exists (output_dir ):
108102 tf .io .gfile .makedirs (output_dir )
109103
110- # Upload the local file to GCS and remove the local temp file
111104 tf .io .gfile .copy (local_temp_file , final_gcs_file , overwrite = True )
112105 os .remove (local_temp_file )
113106 max_logging .log ("Upload complete" )
114107
115108
116- def main (argv : Sequence [str ]):
109+ def main (argv : Sequence [str ], local_args ):
117110 # Initialize the global configuration
118111 global_config = pyconfig .initialize (argv )
119112 teacher_overrides = global_config .teacher_overrides
120113 teacher_argv = [argv [0 ], argv [1 ]]
121114 teacher_config = pyconfig .initialize (teacher_argv , ** teacher_overrides )
122115
123- generate_and_save_data (teacher_config )
116+ generate_and_save_data (teacher_config , local_args . top_k )
124117
125118
126119if __name__ == "__main__" :
127- app .run (main )
120+ parser = argparse .ArgumentParser ()
121+ parser .add_argument (
122+ "--top_k" ,
123+ type = int ,
124+ required = False ,
125+ default = 128 ,
126+ help = "Top K value for logits." ,
127+ )
128+ local_arg , remaining_args = parser .parse_known_args ()
129+
130+ main_wrapper = functools .partial (main , local_args = local_arg )
131+ app .run (main_wrapper , argv = [sys .argv [0 ]] + remaining_args )
0 commit comments