1818
1919Example command:
2020python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py \
21- src/maxtext/configs/post_train/distillation.yml \
22- --top_k=128 \
23- --gcs_upload_path=gs://my-bucket/teacher_logits/
21+ src/maxtext/configs/post_train/distillation.yml \
22+ --local_tmp_dir=/tmp/save_logits_dir \
23+ --steps_per_file=10
2424"""
2525
2626import os
3030import time
3131import sys
3232import tensorflow as tf
33+ import re
3334
3435import jax
36+ import jax .numpy as jnp
3537import functools
3638from itertools import islice
3739
@@ -52,12 +54,41 @@ def get_top_k_logits(logits: jax.Array, k: int):
5254 return top_k_values , top_k_indices
5355
5456
def get_start_step(config, local_args):
  """Determine the step at which to resume the top-k logit generation loop.

  Only process 0 inspects the output directory; every other process returns 0
  immediately and is expected to receive the real value from process 0 via a
  broadcast at the call site.

  Args:
    config: MaxText config object (unused here; kept so the signature matches
      the other pipeline helpers that take `(config, local_args)`).
    local_args: Parsed local CLI args. Reads `gcs_upload_path`,
      `local_tmp_dir`, and `steps_per_file`.

  Returns:
    The global step index to resume from; 0 when no prior output exists.
  """
  if jax.process_index() != 0:
    return 0

  # Prefer the GCS destination when configured; otherwise look at the local
  # staging directory for previously written chunks.
  output_dir = local_args.gcs_upload_path if local_args.gcs_upload_path else local_args.local_tmp_dir
  if not tf.io.gfile.exists(output_dir):
    tf.io.gfile.makedirs(output_dir)
    return 0

  existing_files = tf.io.gfile.glob(os.path.join(output_dir, "teacher_top_k_part_*.array_record"))
  if not existing_files:
    return 0

  # Find the highest part number from the filenames. The dot before the
  # extension is escaped so it matches a literal '.' rather than any char.
  max_part_num = max(
      (int(m.group(1)) for f in existing_files if (m := re.search(r"part_(\d+)\.array_record", os.path.basename(f)))),
      default=-1,
  )

  if max_part_num == -1:
    return 0

  # Resume at the START of the last existing chunk: that chunk may be
  # partial, so it is rewritten in full rather than left truncated.
  start_step = max_part_num * local_args.steps_per_file
  max_logging.log(f"Found existing data, resuming from step {start_step}")
  return start_step
84+
5585def generate_and_save_data (config , local_args ):
56- """Generates top-k logits from the teacher model and saves them locally, optionally uploading to GCS. """
86+ """Generates top-k logits from the teacher model and saves them locally, optionally uploading to GCS"""
5787 k_val = local_args .top_k
5888 optional_keys = local_args .optional_keys
5989 gcs_upload_path = local_args .gcs_upload_path
6090 local_tmp_dir = local_args .local_tmp_dir
91+ steps_per_file = local_args .steps_per_file
6192
6293 devices = jax .devices ()
6394 devices_array = maxtext_utils .create_device_mesh (config , devices )
@@ -68,25 +99,43 @@ def generate_and_save_data(config, local_args):
6899 teacher_model , _ = model_creation_utils .create_nnx_model (config , mesh = mesh )
69100 train_iter , _ = input_pipeline_interface .create_data_iterator (config , mesh )
70101
71- # Setup local tmp directory for Host 0
72- filename = "teacher_top_k_global.array_record"
73- local_output_path = os . path . join ( local_tmp_dir , filename )
102+ # Determine start_step for resuming
103+ start_step = get_start_step ( config , local_args )
104+ start_step = int ( multihost_utils . broadcast_one_to_all ( jnp . array ( start_step )) )
74105
75106 writer = None
107+ local_output_path = None
76108 if jax .process_index () == 0 :
77109 if not os .path .exists (local_tmp_dir ):
78110 os .makedirs (local_tmp_dir )
79- max_logging .log (f"Process 0 writing globally gathered data to local path: { local_output_path } " )
80- writer = array_record_module .ArrayRecordWriter (local_output_path , "group_size:1000" )
81111
82112 # Sync all hosts before starting the loop
83113 multihost_utils .sync_global_devices ("start_generation_loop" )
84114
85- max_logging .log (f"Starting Top-K generation loop for { config .steps } steps..." )
115+ max_logging .log (f"Starting Top-K generation loop for { config .steps - start_step } steps..." )
86116 loop_start = time .time ()
87117
88- for step , batch in enumerate (islice (train_iter , config .steps )):
118+ for step , batch in enumerate (islice (train_iter , start_step , config .steps ), start = start_step ):
89119 step_start = time .time ()
120+
121+ # Open a new writer for each file chunk on process 0
122+ if jax .process_index () == 0 and step % steps_per_file == 0 :
123+ if writer :
124+ writer .close ()
125+ if gcs_upload_path :
126+ # Upload the previous file
127+ gcs_file_path = os .path .join (gcs_upload_path , os .path .basename (local_output_path ))
128+ max_logging .log (f"Uploading { local_output_path } to { gcs_file_path } " )
129+ tf .io .gfile .copy (local_output_path , gcs_file_path , overwrite = True )
130+ os .remove (local_output_path )
131+ max_logging .log ("Upload complete." )
132+
133+ file_index = step // steps_per_file
134+ filename = f"teacher_top_k_part_{ file_index :05d} .array_record"
135+ local_output_path = os .path .join (local_tmp_dir , filename )
136+ max_logging .log (f"Process 0 writing to new chunk: { local_output_path } " )
137+ writer = array_record_module .ArrayRecordWriter (local_output_path , "group_size:1000" )
138+
90139 tokens = batch ["inputs" ]
91140 logits = teacher_model (
92141 decoder_input_tokens = tokens ,
@@ -122,19 +171,15 @@ def generate_and_save_data(config, local_args):
122171 multihost_utils .sync_global_devices ("loop_finished" )
123172
124173 # Finalize writing and handle GCS upload on Host 0
125- if jax .process_index () == 0 :
174+ if jax .process_index () == 0 and writer :
126175 writer .close ()
127176 max_logging .log (f"Finished writing to local disk: { local_output_path } " )
128177
129178 if gcs_upload_path :
130- gcs_file_path = os .path .join (gcs_upload_path , filename )
131- max_logging .log (f"Flag --gcs_upload_path detected. Uploading to: { gcs_file_path } " )
132-
133- if not tf .io .gfile .exists (gcs_upload_path ):
134- tf .io .gfile .makedirs (gcs_upload_path )
135-
136- # Perform the bulk copy to GCS
179+ gcs_file_path = os .path .join (gcs_upload_path , os .path .basename (local_output_path ))
180+ max_logging .log (f"Uploading final chunk to: { gcs_file_path } " )
137181 tf .io .gfile .copy (local_output_path , gcs_file_path , overwrite = True )
182+ os .remove (local_output_path )
138183 max_logging .log ("GCS Upload complete." )
139184
140185 # Sync all hosts one last time so worker hosts don't terminate the job
@@ -182,6 +227,12 @@ def main(argv: Sequence[str], local_args):
182227 default = "/tmp" ,
183228 help = "Local temporary directory to write the ArrayRecord file before optional GCS upload." ,
184229 )
230+ parser .add_argument (
231+ "--steps_per_file" ,
232+ type = int ,
233+ default = 1000 ,
234+ help = "Number of steps to save in each chunk." ,
235+ )
185236 local_arg , remaining_args = parser .parse_known_args ()
186237
187238 main_wrapper = functools .partial (main , local_args = local_arg )
0 commit comments