3333import re
3434
3535import jax
36+ import jax .numpy as jnp
3637import functools
3738from itertools import islice
3839
@@ -68,17 +69,15 @@ def get_start_step(config, local_args):
6869 return 0
6970
7071 # Find the highest part number from the filenames
71- max_part_num = - 1
72- for f in existing_files :
73- max_part_num = max (
74- (int (m .group (1 )) for f in existing_files if (m := re .search (r"part_(\d+).array_record" , os .path .basename (f )))),
75- default = - 1 ,
76- )
72+ max_part_num = max (
73+ (int (m .group (1 )) for f in existing_files if (m := re .search (r"part_(\d+).array_record" , os .path .basename (f )))),
74+ default = - 1 ,
75+ )
7776
7877 if max_part_num == - 1 :
7978 return 0
8079
81- start_step = ( max_part_num + 1 ) * local_args .steps_per_file
80+ start_step = max_part_num * local_args .steps_per_file
8281 max_logging .log (f"Found existing data, resuming from step { start_step } " )
8382 return start_step
8483
@@ -102,7 +101,7 @@ def generate_and_save_data(config, local_args):
102101
103102 # Determine start_step for resuming
104103 start_step = get_start_step (config , local_args )
105- start_step = int (multihost_utils .broadcast_one_to_all (jax . numpy .array (start_step )))
104+ start_step = int (multihost_utils .broadcast_one_to_all (jnp .array (start_step )))
106105
107106 writer = None
108107 local_output_path = None
@@ -128,6 +127,7 @@ def generate_and_save_data(config, local_args):
128127 gcs_file_path = os .path .join (gcs_upload_path , os .path .basename (local_output_path ))
129128 max_logging .log (f"Uploading { local_output_path } to { gcs_file_path } " )
130129 tf .io .gfile .copy (local_output_path , gcs_file_path , overwrite = True )
130+ os .remove (local_output_path )
131131 max_logging .log ("Upload complete." )
132132
133133 file_index = step // steps_per_file
@@ -179,6 +179,7 @@ def generate_and_save_data(config, local_args):
179179 gcs_file_path = os .path .join (gcs_upload_path , os .path .basename (local_output_path ))
180180 max_logging .log (f"Uploading final chunk to: { gcs_file_path } " )
181181 tf .io .gfile .copy (local_output_path , gcs_file_path , overwrite = True )
182+ os .remove (local_output_path )
182183 max_logging .log ("GCS Upload complete." )
183184
184185 # Sync all hosts one last time so worker hosts don't terminate the job
0 commit comments