
Commit 1448d5a

Merge pull request #3555 from AI-Hypercomputer:ajkv/fault-tolerant-save-top-k
PiperOrigin-RevId: 902881912
2 parents 59cf305 + b7c33c7 commit 1448d5a

2 files changed: 102 additions & 38 deletions

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py (70 additions & 19 deletions)
@@ -18,9 +18,9 @@
 
 Example command:
 python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py \
-    src/maxtext/configs/post_train/distillation.yml \
-    --top_k=128 \
-    --gcs_upload_path=gs://my-bucket/teacher_logits/
+    src/maxtext/configs/post_train/distillation.yml \
+    --local_tmp_dir=/tmp/save_logits_dir \
+    --steps_per_file=10
 """
 
 import os
@@ -30,8 +30,10 @@
 import time
 import sys
 import tensorflow as tf
+import re
 
 import jax
+import jax.numpy as jnp
 import functools
 from itertools import islice
 
@@ -52,12 +54,41 @@ def get_top_k_logits(logits: jax.Array, k: int):
   return top_k_values, top_k_indices
 
 
+def get_start_step(config, local_args):
+  """Determines the starting step for the generation process."""
+  if jax.process_index() != 0:
+    return 0
+
+  output_dir = local_args.gcs_upload_path if local_args.gcs_upload_path else local_args.local_tmp_dir
+  if not tf.io.gfile.exists(output_dir):
+    tf.io.gfile.makedirs(output_dir)
+    return 0
+
+  existing_files = tf.io.gfile.glob(os.path.join(output_dir, "teacher_top_k_part_*.array_record"))
+  if not existing_files:
+    return 0
+
+  # Find the highest part number from the filenames
+  max_part_num = max(
+      (int(m.group(1)) for f in existing_files if (m := re.search(r"part_(\d+).array_record", os.path.basename(f)))),
+      default=-1,
+  )
+
+  if max_part_num == -1:
+    return 0
+
+  start_step = max_part_num * local_args.steps_per_file
+  max_logging.log(f"Found existing data, resuming from step {start_step}")
+  return start_step
+
+
 def generate_and_save_data(config, local_args):
-  """Generates top-k logits from the teacher model and saves them locally, optionally uploading to GCS."""
+  """Generates top-k logits from the teacher model and saves them locally, optionally uploading to GCS"""
   k_val = local_args.top_k
   optional_keys = local_args.optional_keys
   gcs_upload_path = local_args.gcs_upload_path
   local_tmp_dir = local_args.local_tmp_dir
+  steps_per_file = local_args.steps_per_file
 
   devices = jax.devices()
   devices_array = maxtext_utils.create_device_mesh(config, devices)
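
The new get_start_step makes the resume point a pure function of the filenames on disk: with --steps_per_file=10, finding teacher_top_k_part_00003.array_record as the highest part means steps 0-29 are safely stored in parts 0-2, so generation restarts at step 30 and rewrites part 3, since that newest chunk may have been truncated by a failure. A minimal standalone sketch of the same computation (resume_step is an illustrative helper, not part of the diff):

import os
import re


def resume_step(existing_files, steps_per_file):
  """Illustrative helper: map the highest part index found on disk to a step.

  The newest part may have been cut short by a preemption, so the returned
  step regenerates it from its first record onward.
  """
  max_part = max(
      (int(m.group(1)) for f in existing_files
       if (m := re.search(r"part_(\d+)\.array_record", os.path.basename(f)))),
      default=-1,
  )
  return 0 if max_part == -1 else max_part * steps_per_file


# Parts 0-3 exist: parts 0-2 are complete chunks, part 3 gets rewritten.
files = [f"/tmp/save_logits_dir/teacher_top_k_part_{i:05d}.array_record" for i in range(4)]
print(resume_step(files, steps_per_file=10))  # 30
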
@@ -68,25 +99,43 @@ def generate_and_save_data(config, local_args):
   teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
   train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)
 
-  # Setup local tmp directory for Host 0
-  filename = "teacher_top_k_global.array_record"
-  local_output_path = os.path.join(local_tmp_dir, filename)
+  # Determine start_step for resuming
+  start_step = get_start_step(config, local_args)
+  start_step = int(multihost_utils.broadcast_one_to_all(jnp.array(start_step)))
 
   writer = None
+  local_output_path = None
   if jax.process_index() == 0:
     if not os.path.exists(local_tmp_dir):
       os.makedirs(local_tmp_dir)
-    max_logging.log(f"Process 0 writing globally gathered data to local path: {local_output_path}")
-    writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1000")
 
   # Sync all hosts before starting the loop
   multihost_utils.sync_global_devices("start_generation_loop")
 
-  max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...")
+  max_logging.log(f"Starting Top-K generation loop for {config.steps - start_step} steps...")
   loop_start = time.time()
 
-  for step, batch in enumerate(islice(train_iter, config.steps)):
+  for step, batch in enumerate(islice(train_iter, start_step, config.steps), start=start_step):
     step_start = time.time()
+
+    # Open a new writer for each file chunk on process 0
+    if jax.process_index() == 0 and step % steps_per_file == 0:
+      if writer:
+        writer.close()
+        if gcs_upload_path:
+          # Upload the previous file
+          gcs_file_path = os.path.join(gcs_upload_path, os.path.basename(local_output_path))
+          max_logging.log(f"Uploading {local_output_path} to {gcs_file_path}")
+          tf.io.gfile.copy(local_output_path, gcs_file_path, overwrite=True)
+          os.remove(local_output_path)
+          max_logging.log("Upload complete.")
+
+      file_index = step // steps_per_file
+      filename = f"teacher_top_k_part_{file_index:05d}.array_record"
+      local_output_path = os.path.join(local_tmp_dir, filename)
+      max_logging.log(f"Process 0 writing to new chunk: {local_output_path}")
+      writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1000")
+
     tokens = batch["inputs"]
     logits = teacher_model(
         decoder_input_tokens=tokens,
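
The loop body now rotates the writer at every chunk boundary: when step % steps_per_file == 0, process 0 closes the previous part, optionally copies it to GCS and deletes the local copy, then opens teacher_top_k_part_{step // steps_per_file:05d}.array_record for the next steps_per_file records. A minimal sketch of that rotation pattern, using a plain text file in place of an ArrayRecordWriter (paths and values here are illustrative):

import os

steps_per_file = 10  # the real flag defaults to 1000
local_tmp_dir = "/tmp/rotation_demo"
os.makedirs(local_tmp_dir, exist_ok=True)

writer = None
for step in range(35):
  if step % steps_per_file == 0:
    if writer:
      # The real script also uploads the finished chunk to GCS and removes it here.
      writer.close()
    path = os.path.join(local_tmp_dir, f"part_{step // steps_per_file:05d}.txt")
    writer = open(path, "w")
  writer.write(f"record for step {step}\n")

if writer:
  writer.close()  # the final chunk may be short (steps 30-34 here)
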
@@ -122,19 +171,15 @@ def generate_and_save_data(config, local_args):
   multihost_utils.sync_global_devices("loop_finished")
 
   # Finalize writing and handle GCS upload on Host 0
-  if jax.process_index() == 0:
+  if jax.process_index() == 0 and writer:
     writer.close()
     max_logging.log(f"Finished writing to local disk: {local_output_path}")
 
     if gcs_upload_path:
-      gcs_file_path = os.path.join(gcs_upload_path, filename)
-      max_logging.log(f"Flag --gcs_upload_path detected. Uploading to: {gcs_file_path}")
-
-      if not tf.io.gfile.exists(gcs_upload_path):
-        tf.io.gfile.makedirs(gcs_upload_path)
-
-      # Perform the bulk copy to GCS
+      gcs_file_path = os.path.join(gcs_upload_path, os.path.basename(local_output_path))
+      max_logging.log(f"Uploading final chunk to: {gcs_file_path}")
       tf.io.gfile.copy(local_output_path, gcs_file_path, overwrite=True)
+      os.remove(local_output_path)
       max_logging.log("GCS Upload complete.")
 
   # Sync all hosts one last time so worker hosts don't terminate the job
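
Only process 0 touches the filesystem, so the rest of the change is coordination: multihost_utils.broadcast_one_to_all distributes process 0's resume step so every host slices the data iterator identically, and the named sync_global_devices barriers keep worker hosts alive until host 0 finishes its final upload. A single-host sketch of those two primitives, assuming the standard jax.experimental.multihost_utils module (MaxText may import it from elsewhere):

import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils

# Process 0 computes the real value; other hosts contribute a placeholder
# that the broadcast overwrites.
local_value = 30 if jax.process_index() == 0 else 0
start_step = int(multihost_utils.broadcast_one_to_all(jnp.array(local_value)))

# Named barrier: no host proceeds until every host has reached this point.
multihost_utils.sync_global_devices("start_generation_loop")
print(f"host {jax.process_index()} resumes at step {start_step}")
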
@@ -182,6 +227,12 @@ def main(argv: Sequence[str], local_args):
       default="/tmp",
       help="Local temporary directory to write the ArrayRecord file before optional GCS upload.",
   )
+  parser.add_argument(
+      "--steps_per_file",
+      type=int,
+      default=1000,
+      help="Number of steps to save in each chunk.",
+  )
   local_arg, remaining_args = parser.parse_known_args()
 
   main_wrapper = functools.partial(main, local_args=local_arg)
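
The new --steps_per_file flag plugs into the script's existing CLI split: parse_known_args keeps the script-local flags and leaves everything else in remaining_args for MaxText's config parser, with functools.partial binding the local flags into main. A minimal sketch of that pattern (the argv list and the MaxText handoff are invented for illustration):

import argparse
import functools


def main(argv, local_args):
  # argv still holds whatever argparse did not recognize, e.g. the
  # distillation.yml path and config overrides destined for MaxText.
  print("steps_per_file:", local_args.steps_per_file, "forwarded:", argv)


parser = argparse.ArgumentParser()
parser.add_argument("--steps_per_file", type=int, default=1000)
local_arg, remaining_args = parser.parse_known_args(
    ["--steps_per_file", "10", "src/maxtext/configs/post_train/distillation.yml", "steps=140"]
)
main_wrapper = functools.partial(main, local_args=local_arg)
main_wrapper(remaining_args)
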

src/maxtext/trainers/post_train/distillation/verify_saved_logits.py (32 additions & 19 deletions)
@@ -16,10 +16,9 @@
 Verification script to check the correctness of saved top-k teacher logits.
 
 Example usage:
-python3 src/maxtext/trainers/post_train/distillation/verify_saved_logits.py \
-    --output_dir=/path/to/your/output \
-    --expected_steps=1000 \
-    --top_k=128
+python3 src/maxtext/trainers/post_train/distillation/verify_saved_logits.py \
+    --output_dir=/tmp/save_logits_dir \
+    --expected_steps=140
 """
 
 import functools
@@ -40,44 +39,58 @@ def verify_array_records(output_dir, expected_steps, expected_k, expected_keys):
   files = tf.io.gfile.glob(file_pattern)
 
   if not files:
-    max_logging.log(f"Error: No ArrayRecord files found matching {file_pattern}")
-    return
+    assert False, f"Error: No ArrayRecord files found matching {file_pattern}"
 
   max_logging.log(f"Found {len(files)} ArrayRecord files. Starting verification...")
 
+  total_records_processed = 0
+  all_keys_verified = set()
+
   for file_path in files:
     max_logging.log(f"Verifying: {file_path}")
     reader = array_record_module.ArrayRecordReader(file_path)
-    num_records = reader.num_records()
+    num_records_in_file = reader.num_records()
+
+    if num_records_in_file == 0:
+      max_logging.log(f"Warning: {file_path} is empty.")
+      continue
 
-    step_count = 0
-    for _ in range(num_records):
+    for record_idx in range(num_records_in_file):
       record = reader.read()
       data = pickle.loads(record)
 
       # Verify all required keys are present
-      for key in ["tokens", "top_k_logits", "top_k_indices"]:
-        assert key in data, f"Missing required key '{key}' at step {step_count} in {file_path}"
+      required_keys = ["tokens", "top_k_logits", "top_k_indices"]
+      for key in required_keys:
+        assert key in data, f"Missing required key '{key}' in record {record_idx} in {file_path}"
 
      # Verify all optional keys are present
      for key in expected_keys:
-        assert key in data, f"Missing optional key '{key}' at step {step_count} in {file_path}"
+        assert key in data, f"Missing optional key '{key}' in record {record_idx} in {file_path}"
 
      # Verify shapes for Top-K outputs
      actual_k_logits = data["top_k_logits"].shape[-1]
      actual_k_indices = data["top_k_indices"].shape[-1]
      assert actual_k_logits == expected_k, f"Expected top_k={expected_k}, got {actual_k_logits} for logits"
      assert actual_k_indices == expected_k, f"Expected top_k={expected_k}, got {actual_k_indices} for indices"
 
-      step_count += 1
+      if not all_keys_verified:
+        all_keys_verified.update(data.keys())
+
+    total_records_processed += num_records_in_file
+    max_logging.log(f"Verified {num_records_in_file} records in {file_path}")
 
-  # Verify the total number of steps processed
-  assert step_count == expected_steps, f"Expected {expected_steps} steps, but found {step_count} in {file_path}."
+  # Verify the total number of steps processed across all files
+  assert (
+      total_records_processed == expected_steps
+  ), f"Expected a total of {expected_steps} steps across all files, but found {total_records_processed}."
 
-  max_logging.log(f"Successfully verified {file_path}")
-  max_logging.log(f"- Total steps: {step_count} (Matches expected)")
-  max_logging.log(f"- Top-K dimension: {expected_k}")
-  max_logging.log(f"- Keys verified: {list(data.keys())}")
+  max_logging.log("-----------------------------------------")
+  max_logging.log("Verification Successful!")
+  max_logging.log(f"- Total files verified: {len(files)}")
+  max_logging.log(f"- Total steps verified: {total_records_processed} (Matches expected)")
+  max_logging.log(f"- Top-K dimension: {expected_k}")
+  max_logging.log(f"- Keys verified in records: {sorted(list(all_keys_verified))}")
 
 
 def main(argv, local_args):
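
For a quick spot check outside the full verifier, one file can be opened with the same reader calls the script uses (ArrayRecordReader, num_records, read) and its first record unpickled; the import path assumes the standard array_record package layout, and the file path is illustrative:

import pickle

from array_record.python import array_record_module

# Illustrative path: the first chunk written by save_top_k_teacher_logits.py.
reader = array_record_module.ArrayRecordReader(
    "/tmp/save_logits_dir/teacher_top_k_part_00000.array_record"
)
print("records in file:", reader.num_records())

data = pickle.loads(reader.read())  # each record is a pickled dict
print(sorted(data.keys()))  # expects tokens, top_k_logits, top_k_indices, ...
print(data["top_k_logits"].shape[-1])  # should equal the --top_k used at save time
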
