@@ -89,13 +89,11 @@ def create_tf_example(example_dict):
8989 features [key ] = tf .train .Feature (int64_list = tf .train .Int64List (value = [val ]))
9090 continue
9191
92- # Use .ravel() to avoid the memory copy that .flatten() does
9392 flat_val = np .asarray (val ).ravel ()
9493
9594 if flat_val .dtype in [np .float32 , np .float64 , np .float16 , jnp .bfloat16 ]:
9695 if flat_val .dtype != np .float32 :
9796 flat_val = flat_val .astype (np .float32 )
98- # Use .tolist() for fast Protobuf C++ ingestion
9997 features [key ] = tf .train .Feature (float_list = tf .train .FloatList (value = flat_val .tolist ()))
10098 elif flat_val .dtype in [np .int32 , np .int64 ]:
10199 if flat_val .dtype != np .int64 :
@@ -179,11 +177,11 @@ def generate_and_save_data(config, local_args):
179177 if not os .path .exists (local_tmp_dir ):
180178 os .makedirs (local_tmp_dir , exist_ok = True )
181179
182- upload_executor = ThreadPoolExecutor (max_workers = 4 )
180+ upload_executor = ThreadPoolExecutor (max_workers = local_args . upload_workers )
183181 # Restrict to 1 worker to ensure sequential writing to the ArrayRecord file
184182 write_executor = ThreadPoolExecutor (max_workers = 1 )
185183 # Dedicated executor for CPU-bound protobuf serialization, kept separate from the upload and write pools
186- serialization_executor = ThreadPoolExecutor (max_workers = 16 )
184+ serialization_executor = ThreadPoolExecutor (max_workers = local_args . serialization_workers )
187185
188186 devices = jax .devices ()
189187 devices_array = maxtext_utils .create_device_mesh (config , devices )
@@ -219,7 +217,9 @@ def generate_and_save_data(config, local_args):
219217 max_logging .log (f"Queueing distributed background uploads for Step { step } ..." )
220218 upload_executor .submit (background_upload , local_output_path , gcs_file_path , jax .process_index ())
221219
222- # Re-initialize the writer with 1 worker
220+ # Re-initialize the writer thread pool. We restrict it to exactly 1 worker
221+ # to ensure sequential, in-order writing to the ArrayRecord file. A new pool
222+ # is needed because the previous one was shut down to flush all pending writes.
223223 write_executor = ThreadPoolExecutor (max_workers = 1 )
224224
225225 file_index = step // steps_per_file
@@ -321,6 +321,10 @@ def main(argv: Sequence[str], local_args):
321321 parser .add_argument ("--gcs_upload_path" , type = str , default = None )
322322 parser .add_argument ("--local_tmp_dir" , type = str , default = "/tmp" )
323323 parser .add_argument ("--steps_per_file" , type = int , default = 50 )
324+ parser .add_argument ("--upload_workers" , type = int , default = 4 , help = "Number of workers for GCS uploads." )
325+ parser .add_argument (
326+ "--serialization_workers" , type = int , default = 16 , help = "Number of workers for protobuf serialization."
327+ )
324328 local_arg , remaining_args = parser .parse_known_args ()
325329
326330 main_wrapper = functools .partial (main , local_args = local_arg )
0 commit comments