2121import subprocess
2222from concurrent .futures import ThreadPoolExecutor
2323
24- # Force the pure Python protobuf implementation to avoid UPB compatibility issues with TFDS
25- os .environ ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION" ] = "python"
2624from typing import Sequence
2725import argparse
2826import time
@@ -70,7 +68,6 @@ def get_start_step(config, local_args):
7068 if not existing_files :
7169 return 0
7270
73- # Updated regex to handle the host ID in the filename
7471 max_part_num = max (
7572 (int (m .group (1 )) for f in existing_files if (m := re .search (r"part_(\d+)_host" , f .name ))),
7673 default = - 1 ,
@@ -92,26 +89,34 @@ def create_tf_example(example_dict):
9289 features [key ] = tf .train .Feature (int64_list = tf .train .Int64List (value = [val ]))
9390 continue
9491
95- flat_val = np .asarray (val ).flatten ()
92+ flat_val = np .asarray (val ).ravel ()
9693
9794 if flat_val .dtype in [np .float32 , np .float64 , np .float16 , jnp .bfloat16 ]:
98- features [key ] = tf .train .Feature (float_list = tf .train .FloatList (value = flat_val .astype (np .float32 )))
95+ if flat_val .dtype != np .float32 :
96+ flat_val = flat_val .astype (np .float32 )
97+ features [key ] = tf .train .Feature (float_list = tf .train .FloatList (value = flat_val .tolist ()))
9998 elif flat_val .dtype in [np .int32 , np .int64 ]:
100- features [key ] = tf .train .Feature (int64_list = tf .train .Int64List (value = flat_val .astype (np .int64 )))
99+ if flat_val .dtype != np .int64 :
100+ flat_val = flat_val .astype (np .int64 )
101+ features [key ] = tf .train .Feature (int64_list = tf .train .Int64List (value = flat_val .tolist ()))
101102 else :
102103 raise ValueError (f"Unsupported dtype { flat_val .dtype } for key { key } " )
103104
104105 return tf .train .Example (features = tf .train .Features (feature = features )).SerializeToString ()
105106
106107
107- def background_process_and_write (writer , tokens , vals , idx , opt_data ):
108+ def background_process_and_write (writer , tokens , vals , idx , opt_data , serialization_executor ):
108109 """Executes entirely on a background CPU thread so the TPU never waits."""
109- tokens_np = np .array (tokens )
110- vals_np = np .array (vals )
111- idx_np = np .array (idx )
112- opt_data_np = {k : np .array (v ) for k , v in opt_data .items ()}
110+ # Convert exactly once
111+ tokens_np = np .asarray (tokens )
112+ vals_np = np .asarray (vals )
113+ idx_np = np .asarray (idx )
114+ opt_data_np = {k : np .asarray (v ) for k , v in opt_data .items ()}
113115
114116 batch_size = tokens_np .shape [0 ]
117+ example_dicts = []
118+
119+ # Prepare dictionaries sequentially
115120 for i in range (batch_size ):
116121 seq_bytes = tokens_np [i ].tobytes ()
117122 example_dict = {
@@ -122,8 +127,14 @@ def background_process_and_write(writer, tokens, vals, idx, opt_data):
122127 }
123128 for key , val_np in opt_data_np .items ():
124129 example_dict [key ] = val_np [i ]
130+ example_dicts .append (example_dict )
131+
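132+ # Serialize to protobufs concurrently on the serialization thread pool (map() yields results in input order; actual multi-core speedup depends on the GIL being released)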
133+ serialized_records = list (serialization_executor .map (create_tf_example , example_dicts ))
125134
126- writer .write (create_tf_example (example_dict ))
135+ # Write the serialized bytes to disk sequentially
136+ for record in serialized_records :
137+ writer .write (record )
127138
128139
129140def background_upload (local_path , gcs_path , process_index ):
@@ -163,12 +174,14 @@ def generate_and_save_data(config, local_args):
163174 writer = None
164175 local_output_path = None
165176
166- # all hosts initialize their own directories and thread pools
167177 if not os .path .exists (local_tmp_dir ):
168178 os .makedirs (local_tmp_dir , exist_ok = True )
169179
170- upload_executor = ThreadPoolExecutor (max_workers = 4 )
171- write_executor = ThreadPoolExecutor (max_workers = 2 )
180+ upload_executor = ThreadPoolExecutor (max_workers = local_args .upload_workers )
181+ # Restrict to 1 worker to ensure sequential writing to the ArrayRecord file
182+ write_executor = ThreadPoolExecutor (max_workers = 1 )
183+ # New executor purely for CPU-bound protobuf serialization
184+ serialization_executor = ThreadPoolExecutor (max_workers = local_args .serialization_workers )
172185
173186 devices = jax .devices ()
174187 devices_array = maxtext_utils .create_device_mesh (config , devices )
@@ -194,7 +207,6 @@ def generate_and_save_data(config, local_args):
194207 for step , batch in enumerate (islice (train_iter , start_step , config .steps ), start = start_step ):
195208 step_start = time .time ()
196209
197- # ALL HOSTS execute the file opening/closing logic
198210 if step % steps_per_file == 0 :
199211 if writer :
200212 write_executor .shutdown (wait = True )
@@ -204,15 +216,20 @@ def generate_and_save_data(config, local_args):
204216 if jax .process_index () == 0 :
205217 max_logging .log (f"Queueing distributed background uploads for Step { step } ..." )
206218 upload_executor .submit (background_upload , local_output_path , gcs_file_path , jax .process_index ())
207- write_executor = ThreadPoolExecutor (max_workers = 2 )
219+
220+ # Re-initialize the writer thread pool. We restrict it to exactly 1 worker
221+ # to ensure sequential, in-order writing to the ArrayRecord file. A new pool
222+ # is needed because the previous one was shut down to flush all pending writes.
223+ write_executor = ThreadPoolExecutor (max_workers = 1 )
208224
209225 file_index = step // steps_per_file
210- # filename includes host ID to prevent GCS collisions
211226 filename = f"teacher_top_k_part_{ file_index :05d} _host_{ jax .process_index ():03d} .array_record"
212227 local_output_path = os .path .join (local_tmp_dir , filename )
213228 writer = array_record_module .ArrayRecordWriter (local_output_path , "group_size:1" )
214229
215230 tokens = batch ["inputs" ]
231+
232+ # --- Model Forward Pass & Network Gather ---
216233 top_k_vals , top_k_idx = teacher_step (teacher_model , batch , k_val )
217234
218235 global_tokens = jax .experimental .multihost_utils .process_allgather (tokens , tiled = True )
@@ -225,13 +242,11 @@ def generate_and_save_data(config, local_args):
225242 optional_data [key ] = jax .experimental .multihost_utils .process_allgather (batch [key ], tiled = True )
226243
227244 if writer :
228- # Convert to numpy safely on the CPU
229245 global_tokens_np = np .array (global_tokens )
230246 global_vals_np = np .array (global_vals )
231247 global_idx_np = np .array (global_idx )
232248 optional_data_np = {k : np .array (v ) for k , v in optional_data .items ()}
233249
234- # Slice out this host's local fraction of the batch
235250 global_batch_size = global_tokens_np .shape [0 ]
236251 local_batch_size = global_batch_size // jax .process_count ()
237252 start_idx = jax .process_index () * local_batch_size
@@ -242,24 +257,33 @@ def generate_and_save_data(config, local_args):
242257 local_idx_np = global_idx_np [start_idx :end_idx ]
243258 local_opt_data_np = {k : v [start_idx :end_idx ] for k , v in optional_data_np .items ()}
244259
245- # Write synchronously
246- background_process_and_write (writer , local_tokens_np , local_vals_np , local_idx_np , local_opt_data_np )
260+ # --- Local Disk Writing ---
261+ # Submit to the background thread with the serialization_executor
262+ write_executor .submit (
263+ background_process_and_write ,
264+ writer ,
265+ local_tokens_np ,
266+ local_vals_np ,
267+ local_idx_np ,
268+ local_opt_data_np ,
269+ serialization_executor ,
270+ )
247271
248272 if step % 50 == 0 and jax .process_index () == 0 :
249273 max_logging .log (f"Successfully processed step { step } in { time .time () - step_start :.4f} s" )
250274
251- # Sync hosts briefly to ensure TPU compute stays aligned across the mesh
252275 multihost_utils .sync_global_devices (f"step_{ step } _complete" )
253276
254277 if jax .process_index () == 0 :
255278 max_logging .log (f"Generation loop finished in { time .time () - loop_start :.2f} s" )
256279
257280 multihost_utils .sync_global_devices ("loop_finished" )
258281
259- # Finalize writing and handle GCS upload on all hosts
260282 if writer :
261283 if write_executor :
262284 write_executor .shutdown (wait = True )
285+ if serialization_executor :
286+ serialization_executor .shutdown (wait = True )
263287 writer .close ()
264288
265289 if gcs_upload_path :
@@ -279,8 +303,9 @@ def generate_and_save_data(config, local_args):
279303def main (argv : Sequence [str ], local_args ):
280304 global_config = pyconfig .initialize (argv )
281305 teacher_overrides = global_config .teacher_overrides
282- teacher_argv = [argv [0 ], argv [1 ]]
283- teacher_config = pyconfig .initialize (teacher_argv , ** teacher_overrides )
306+
307+ teacher_config = pyconfig .initialize (argv , ** teacher_overrides )
308+
284309 generate_and_save_data (teacher_config , local_args )
285310
286311
@@ -295,7 +320,11 @@ def main(argv: Sequence[str], local_args):
295320 )
296321 parser .add_argument ("--gcs_upload_path" , type = str , default = None )
297322 parser .add_argument ("--local_tmp_dir" , type = str , default = "/tmp" )
298- parser .add_argument ("--steps_per_file" , type = int , default = 1000 )
323+ parser .add_argument ("--steps_per_file" , type = int , default = 50 )
324+ parser .add_argument ("--upload_workers" , type = int , default = 4 , help = "Number of workers for GCS uploads." )
325+ parser .add_argument (
326+ "--serialization_workers" , type = int , default = 16 , help = "Number of workers for protobuf serialization."
327+ )
299328 local_arg , remaining_args = parser .parse_known_args ()
300329
301330 main_wrapper = functools .partial (main , local_args = local_arg )
0 commit comments