Skip to content

Commit b7c33c7

Browse files
committed
Updated code to be less redundant and ensured that files written to GCS are removed from local storage
1 parent 2f2312f commit b7c33c7

1 file changed

Lines changed: 9 additions & 8 deletions

File tree

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import re
3434

3535
import jax
36+
import jax.numpy as jnp
3637
import functools
3738
from itertools import islice
3839

@@ -68,17 +69,15 @@ def get_start_step(config, local_args):
6869
return 0
6970

7071
# Find the highest part number from the filenames
71-
max_part_num = -1
72-
for f in existing_files:
73-
max_part_num = max(
74-
(int(m.group(1)) for f in existing_files if (m := re.search(r"part_(\d+).array_record", os.path.basename(f)))),
75-
default=-1,
76-
)
72+
max_part_num = max(
73+
(int(m.group(1)) for f in existing_files if (m := re.search(r"part_(\d+).array_record", os.path.basename(f)))),
74+
default=-1,
75+
)
7776

7877
if max_part_num == -1:
7978
return 0
8079

81-
start_step = (max_part_num + 1) * local_args.steps_per_file
80+
start_step = max_part_num * local_args.steps_per_file
8281
max_logging.log(f"Found existing data, resuming from step {start_step}")
8382
return start_step
8483

@@ -102,7 +101,7 @@ def generate_and_save_data(config, local_args):
102101

103102
# Determine start_step for resuming
104103
start_step = get_start_step(config, local_args)
105-
start_step = int(multihost_utils.broadcast_one_to_all(jax.numpy.array(start_step)))
104+
start_step = int(multihost_utils.broadcast_one_to_all(jnp.array(start_step)))
106105

107106
writer = None
108107
local_output_path = None
@@ -128,6 +127,7 @@ def generate_and_save_data(config, local_args):
128127
gcs_file_path = os.path.join(gcs_upload_path, os.path.basename(local_output_path))
129128
max_logging.log(f"Uploading {local_output_path} to {gcs_file_path}")
130129
tf.io.gfile.copy(local_output_path, gcs_file_path, overwrite=True)
130+
os.remove(local_output_path)
131131
max_logging.log("Upload complete.")
132132

133133
file_index = step // steps_per_file
@@ -179,6 +179,7 @@ def generate_and_save_data(config, local_args):
179179
gcs_file_path = os.path.join(gcs_upload_path, os.path.basename(local_output_path))
180180
max_logging.log(f"Uploading final chunk to: {gcs_file_path}")
181181
tf.io.gfile.copy(local_output_path, gcs_file_path, overwrite=True)
182+
os.remove(local_output_path)
182183
max_logging.log("GCS Upload complete.")
183184

184185
# Sync all hosts one last time so worker hosts don't terminate the job

0 commit comments

Comments
 (0)