@@ -218,7 +218,7 @@ def generate_and_save_data(config, local_args):
218218 if jax .process_index () == 0 :
219219 max_logging .log (f"Queueing distributed background uploads for Step { step } ..." )
220220 upload_executor .submit (background_upload , local_output_path , gcs_file_path , jax .process_index ())
221-
221+
222222 # Re-initialize the writer with 1 worker
223223 write_executor = ThreadPoolExecutor (max_workers = 1 )
224224
@@ -260,13 +260,13 @@ def generate_and_save_data(config, local_args):
260260 # --- Local Disk Writing ---
261261 # Submit to the background thread with the serialization_executor
262262 write_executor .submit (
263- background_process_and_write ,
264- writer ,
265- local_tokens_np ,
266- local_vals_np ,
267- local_idx_np ,
263+ background_process_and_write ,
264+ writer ,
265+ local_tokens_np ,
266+ local_vals_np ,
267+ local_idx_np ,
268268 local_opt_data_np ,
269- serialization_executor
269+ serialization_executor ,
270270 )
271271
272272 if step % 50 == 0 and jax .process_index () == 0 :
@@ -303,9 +303,9 @@ def generate_and_save_data(config, local_args):
303303def main (argv : Sequence [str ], local_args ):
304304 global_config = pyconfig .initialize (argv )
305305 teacher_overrides = global_config .teacher_overrides
306-
306+
307307 teacher_config = pyconfig .initialize (argv , ** teacher_overrides )
308-
308+
309309 generate_and_save_data (teacher_config , local_args )
310310
311311
@@ -324,4 +324,4 @@ def main(argv: Sequence[str], local_args):
324324 local_arg , remaining_args = parser .parse_known_args ()
325325
326326 main_wrapper = functools .partial (main , local_args = local_arg )
327- app .run (main_wrapper , argv = [sys .argv [0 ]] + remaining_args )
327+ app .run (main_wrapper , argv = [sys .argv [0 ]] + remaining_args )
0 commit comments