Skip to content

Commit da8b70a

Browse files
Merge pull request #4056 from AI-Hypercomputer:ajkv/efficient-top-k-saving
PiperOrigin-RevId: 928686854
2 parents c2d7758 + df0a92f commit da8b70a

2 files changed

Lines changed: 65 additions & 31 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -496,19 +496,22 @@ def compute_loss(
496496

497497
# --- Soft loss: KL on temperature-softened distributions ---
498498

499-
# 1. Pre-compute Student log-probs over the full vocabulary
499+
# Pre-compute Student log-probs over the full vocabulary
500500
log_s_T_full = jax.nn.log_softmax(s_logits / temperature, axis=-1)
501501
log_s_1_full = jax.nn.log_softmax(s_logits, axis=-1)
502502

503503
if getattr(teacher_output, "top_k_indices", None) is not None:
504504
# --- SPARSE KL DIVERGENCE (Offline Mode) ---
505505

506-
# 2. Normalize teacher probabilities ONLY over the saved Top-K subset
506+
# 1. Normalize teacher probabilities only over the saved Top-K subset
507507
t_p_T_sparse = jax.nn.softmax(t_logits / temperature, axis=-1)
508508
log_t_p_T_sparse = jax.nn.log_softmax(t_logits / temperature, axis=-1)
509509

510-
# 3. Gather Student log-probs at the Teacher's exact Top-K indices
511-
log_s_T_sparse = jnp.take_along_axis(log_s_T_full, teacher_output.top_k_indices, axis=-1)
510+
# 2. Gather Student unnormalized logits at the Teacher's exact Top-K indices
511+
s_logits_sparse = jnp.take_along_axis(s_logits, teacher_output.top_k_indices, axis=-1)
512+
513+
# 3. Normalize Student probabilities only over the exact same Top-K subset
514+
log_s_T_sparse = jax.nn.log_softmax(s_logits_sparse / temperature, axis=-1)
512515

513516
# 4. KL(T || S) = Sum_over_TopK( P_T * (log_P_T - log_P_S) )
514517
kl_softened_per_pos = jnp.sum(t_p_T_sparse * (log_t_p_T_sparse - log_s_T_sparse), axis=-1)
@@ -661,6 +664,7 @@ def __init__(
661664

662665
# Re-initialize internal Orbax manager with MaxText's Grain handler
663666
# pylint: disable=access-member-before-definition
667+
# pytype: disable=attribute-error
664668
if self._checkpoint_manager is not None:
665669
root_directory = self._checkpoint_manager.directory
666670

@@ -681,6 +685,7 @@ def __init__(
681685
item_handlers=item_handlers,
682686
options=options,
683687
)
688+
# pytype: enable=attribute-error
684689
# pylint: enable=access-member-before-definition
685690

686691
def save(

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
import subprocess
2222
from concurrent.futures import ThreadPoolExecutor
2323

24-
# Force the pure Python protobuf implementation to avoid UPB compatibility issues with TFDS
25-
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
2624
from typing import Sequence
2725
import argparse
2826
import time
@@ -70,7 +68,6 @@ def get_start_step(config, local_args):
7068
if not existing_files:
7169
return 0
7270

73-
# Updated regex to handle the host ID in the filename
7471
max_part_num = max(
7572
(int(m.group(1)) for f in existing_files if (m := re.search(r"part_(\d+)_host", f.name))),
7673
default=-1,
@@ -92,26 +89,34 @@ def create_tf_example(example_dict):
9289
features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=[val]))
9390
continue
9491

95-
flat_val = np.asarray(val).flatten()
92+
flat_val = np.asarray(val).ravel()
9693

9794
if flat_val.dtype in [np.float32, np.float64, np.float16, jnp.bfloat16]:
98-
features[key] = tf.train.Feature(float_list=tf.train.FloatList(value=flat_val.astype(np.float32)))
95+
if flat_val.dtype != np.float32:
96+
flat_val = flat_val.astype(np.float32)
97+
features[key] = tf.train.Feature(float_list=tf.train.FloatList(value=flat_val.tolist()))
9998
elif flat_val.dtype in [np.int32, np.int64]:
100-
features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=flat_val.astype(np.int64)))
99+
if flat_val.dtype != np.int64:
100+
flat_val = flat_val.astype(np.int64)
101+
features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=flat_val.tolist()))
101102
else:
102103
raise ValueError(f"Unsupported dtype {flat_val.dtype} for key {key}")
103104

104105
return tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
105106

106107

107-
def background_process_and_write(writer, tokens, vals, idx, opt_data):
108+
def background_process_and_write(writer, tokens, vals, idx, opt_data, serialization_executor):
108109
"""Executes entirely on a background CPU thread so the TPU never waits."""
109-
tokens_np = np.array(tokens)
110-
vals_np = np.array(vals)
111-
idx_np = np.array(idx)
112-
opt_data_np = {k: np.array(v) for k, v in opt_data.items()}
110+
# Convert inputs to NumPy arrays once, up front (np.asarray does not copy when the input is already an ndarray)
111+
tokens_np = np.asarray(tokens)
112+
vals_np = np.asarray(vals)
113+
idx_np = np.asarray(idx)
114+
opt_data_np = {k: np.asarray(v) for k, v in opt_data.items()}
113115

114116
batch_size = tokens_np.shape[0]
117+
example_dicts = []
118+
119+
# Prepare dictionaries sequentially
115120
for i in range(batch_size):
116121
seq_bytes = tokens_np[i].tobytes()
117122
example_dict = {
@@ -122,8 +127,14 @@ def background_process_and_write(writer, tokens, vals, idx, opt_data):
122127
}
123128
for key, val_np in opt_data_np.items():
124129
example_dict[key] = val_np[i]
130+
example_dicts.append(example_dict)
131+
132+
# Serialize to Protobufs concurrently on the serialization thread pool (threads share the GIL, so a true multi-core speed-up is not guaranteed)
133+
serialized_records = list(serialization_executor.map(create_tf_example, example_dicts))
125134

126-
writer.write(create_tf_example(example_dict))
135+
# Write the serialized bytes to disk sequentially
136+
for record in serialized_records:
137+
writer.write(record)
127138

128139

129140
def background_upload(local_path, gcs_path, process_index):
@@ -163,12 +174,14 @@ def generate_and_save_data(config, local_args):
163174
writer = None
164175
local_output_path = None
165176

166-
# all hosts initialize their own directories and thread pools
167177
if not os.path.exists(local_tmp_dir):
168178
os.makedirs(local_tmp_dir, exist_ok=True)
169179

170-
upload_executor = ThreadPoolExecutor(max_workers=4)
171-
write_executor = ThreadPoolExecutor(max_workers=2)
180+
upload_executor = ThreadPoolExecutor(max_workers=local_args.upload_workers)
181+
# Restrict to 1 worker to ensure sequential writing to the ArrayRecord file
182+
write_executor = ThreadPoolExecutor(max_workers=1)
183+
# Dedicated executor for CPU-bound protobuf serialization, kept separate from the single-worker write executor
184+
serialization_executor = ThreadPoolExecutor(max_workers=local_args.serialization_workers)
172185

173186
devices = jax.devices()
174187
devices_array = maxtext_utils.create_device_mesh(config, devices)
@@ -194,7 +207,6 @@ def generate_and_save_data(config, local_args):
194207
for step, batch in enumerate(islice(train_iter, start_step, config.steps), start=start_step):
195208
step_start = time.time()
196209

197-
# ALL HOSTS execute the file opening/closing logic
198210
if step % steps_per_file == 0:
199211
if writer:
200212
write_executor.shutdown(wait=True)
@@ -204,15 +216,20 @@ def generate_and_save_data(config, local_args):
204216
if jax.process_index() == 0:
205217
max_logging.log(f"Queueing distributed background uploads for Step {step}...")
206218
upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index())
207-
write_executor = ThreadPoolExecutor(max_workers=2)
219+
220+
# Re-initialize the writer thread pool. We restrict it to exactly 1 worker
221+
# to ensure sequential, in-order writing to the ArrayRecord file. A new pool
222+
# is needed because the previous one was shut down to flush all pending writes.
223+
write_executor = ThreadPoolExecutor(max_workers=1)
208224

209225
file_index = step // steps_per_file
210-
# filename includes host ID to prevent GCS collisions
211226
filename = f"teacher_top_k_part_{file_index:05d}_host_{jax.process_index():03d}.array_record"
212227
local_output_path = os.path.join(local_tmp_dir, filename)
213228
writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1")
214229

215230
tokens = batch["inputs"]
231+
232+
# --- Model Forward Pass & Network Gather ---
216233
top_k_vals, top_k_idx = teacher_step(teacher_model, batch, k_val)
217234

218235
global_tokens = jax.experimental.multihost_utils.process_allgather(tokens, tiled=True)
@@ -225,13 +242,11 @@ def generate_and_save_data(config, local_args):
225242
optional_data[key] = jax.experimental.multihost_utils.process_allgather(batch[key], tiled=True)
226243

227244
if writer:
228-
# Convert to numpy safely on the CPU
229245
global_tokens_np = np.array(global_tokens)
230246
global_vals_np = np.array(global_vals)
231247
global_idx_np = np.array(global_idx)
232248
optional_data_np = {k: np.array(v) for k, v in optional_data.items()}
233249

234-
# Slice out this host's local fraction of the batch
235250
global_batch_size = global_tokens_np.shape[0]
236251
local_batch_size = global_batch_size // jax.process_count()
237252
start_idx = jax.process_index() * local_batch_size
@@ -242,24 +257,33 @@ def generate_and_save_data(config, local_args):
242257
local_idx_np = global_idx_np[start_idx:end_idx]
243258
local_opt_data_np = {k: v[start_idx:end_idx] for k, v in optional_data_np.items()}
244259

245-
# Write synchronously
246-
background_process_and_write(writer, local_tokens_np, local_vals_np, local_idx_np, local_opt_data_np)
260+
# --- Local Disk Writing ---
261+
# Queue the write on write_executor; the task fans serialization out to serialization_executor
262+
write_executor.submit(
263+
background_process_and_write,
264+
writer,
265+
local_tokens_np,
266+
local_vals_np,
267+
local_idx_np,
268+
local_opt_data_np,
269+
serialization_executor,
270+
)
247271

248272
if step % 50 == 0 and jax.process_index() == 0:
249273
max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s")
250274

251-
# Sync hosts briefly to ensure TPU compute stays aligned across the mesh
252275
multihost_utils.sync_global_devices(f"step_{step}_complete")
253276

254277
if jax.process_index() == 0:
255278
max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s")
256279

257280
multihost_utils.sync_global_devices("loop_finished")
258281

259-
# Finalize writing and handle GCS upload on all hosts
260282
if writer:
261283
if write_executor:
262284
write_executor.shutdown(wait=True)
285+
if serialization_executor:
286+
serialization_executor.shutdown(wait=True)
263287
writer.close()
264288

265289
if gcs_upload_path:
@@ -279,8 +303,9 @@ def generate_and_save_data(config, local_args):
279303
def main(argv: Sequence[str], local_args):
280304
global_config = pyconfig.initialize(argv)
281305
teacher_overrides = global_config.teacher_overrides
282-
teacher_argv = [argv[0], argv[1]]
283-
teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides)
306+
307+
teacher_config = pyconfig.initialize(argv, **teacher_overrides)
308+
284309
generate_and_save_data(teacher_config, local_args)
285310

286311

@@ -295,7 +320,11 @@ def main(argv: Sequence[str], local_args):
295320
)
296321
parser.add_argument("--gcs_upload_path", type=str, default=None)
297322
parser.add_argument("--local_tmp_dir", type=str, default="/tmp")
298-
parser.add_argument("--steps_per_file", type=int, default=1000)
323+
parser.add_argument("--steps_per_file", type=int, default=50)
324+
parser.add_argument("--upload_workers", type=int, default=4, help="Number of workers for GCS uploads.")
325+
parser.add_argument(
326+
"--serialization_workers", type=int, default=16, help="Number of workers for protobuf serialization."
327+
)
299328
local_arg, remaining_args = parser.parse_known_args()
300329

301330
main_wrapper = functools.partial(main, local_args=local_arg)

0 commit comments

Comments
 (0)