Skip to content

Commit 95db5de

Browse files
committed
Updated top-k sparse kl div logic
1 parent 8b7058d commit 95db5de

2 files changed

Lines changed: 82 additions & 103 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -496,19 +496,22 @@ def compute_loss(
496496

497497
# --- Soft loss: KL on temperature-softened distributions ---
498498

499-
# 1. Pre-compute Student log-probs over the full vocabulary
499+
# Pre-compute Student log-probs over the full vocabulary
500500
log_s_T_full = jax.nn.log_softmax(s_logits / temperature, axis=-1)
501501
log_s_1_full = jax.nn.log_softmax(s_logits, axis=-1)
502502

503503
if getattr(teacher_output, "top_k_indices", None) is not None:
504504
# --- SPARSE KL DIVERGENCE (Offline Mode) ---
505505

506-
# 2. Normalize teacher probabilities ONLY over the saved Top-K subset
506+
# 1. Normalize teacher probabilities only over the saved Top-K subset
507507
t_p_T_sparse = jax.nn.softmax(t_logits / temperature, axis=-1)
508508
log_t_p_T_sparse = jax.nn.log_softmax(t_logits / temperature, axis=-1)
509509

510-
# 3. Gather Student log-probs at the Teacher's exact Top-K indices
511-
log_s_T_sparse = jnp.take_along_axis(log_s_T_full, teacher_output.top_k_indices, axis=-1)
510+
# 2. Gather Student unnormalized logits at the Teacher's exact Top-K indices
511+
s_logits_sparse = jnp.take_along_axis(s_logits, teacher_output.top_k_indices, axis=-1)
512+
513+
# 3. Normalize Student probabilities only over the exact same Top-K subset
514+
log_s_T_sparse = jax.nn.log_softmax(s_logits_sparse / temperature, axis=-1)
512515

513516
# 4. KL(T || S) = Sum_over_TopK( P_T * (log_P_T - log_P_S) )
514517
kl_softened_per_pos = jnp.sum(t_p_T_sparse * (log_t_p_T_sparse - log_s_T_sparse), axis=-1)
@@ -517,6 +520,7 @@ def compute_loss(
517520
ce_teacher_per_pos = jnp.zeros(s_logits.shape[:-1])
518521
kl_t1_sum = jnp.array(0.0)
519522

523+
520524
else:
521525
# --- DENSE KL DIVERGENCE (Online Mode) ---
522526
t_p_T = jax.nn.softmax(t_logits / temperature, axis=-1)

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 74 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
"""
216
This module provides functionality to save top-k teacher logits
317
for distillation purposes in MaxText.
@@ -81,12 +95,11 @@ def create_tf_example(example_dict):
8195
if flat_val.dtype in [np.float32, np.float64, np.float16, jnp.bfloat16]:
8296
if flat_val.dtype != np.float32:
8397
flat_val = flat_val.astype(np.float32)
84-
# Use .tolist() for extremely fast Protobuf C++ ingestion
98+
# Use .tolist() for fast Protobuf C++ ingestion
8599
features[key] = tf.train.Feature(float_list=tf.train.FloatList(value=flat_val.tolist()))
86100
elif flat_val.dtype in [np.int32, np.int64]:
87101
if flat_val.dtype != np.int64:
88102
flat_val = flat_val.astype(np.int64)
89-
# Use .tolist() for extremely fast Protobuf C++ ingestion
90103
features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=flat_val.tolist()))
91104
else:
92105
raise ValueError(f"Unsupported dtype {flat_val.dtype} for key {key}")
@@ -96,51 +109,46 @@ def create_tf_example(example_dict):
96109

97110
def background_process_and_write(writer, tokens, vals, idx, opt_data, serialization_executor):
98111
"""Executes entirely on a background CPU thread so the TPU never waits."""
99-
with tf.profiler.experimental.Trace("background_local_disk_write"):
100-
# Convert exactly once
101-
tokens_np = np.asarray(tokens)
102-
vals_np = np.asarray(vals)
103-
idx_np = np.asarray(idx)
104-
opt_data_np = {k: np.asarray(v) for k, v in opt_data.items()}
105-
106-
batch_size = tokens_np.shape[0]
107-
example_dicts = []
108-
109-
# Prepare dictionaries sequentially
110-
for i in range(batch_size):
111-
seq_bytes = tokens_np[i].tobytes()
112-
example_dict = {
113-
"inputs": tokens_np[i],
114-
"top_k_logits": vals_np[i],
115-
"top_k_indices": idx_np[i],
116-
"sequence_hash": hash(seq_bytes),
117-
}
118-
for key, val_np in opt_data_np.items():
119-
example_dict[key] = val_np[i]
120-
example_dicts.append(example_dict)
121-
122-
# Serialize to Protobufs in parallel across multiple CPU cores
123-
with tf.profiler.experimental.Trace("parallel_serialize"):
124-
serialized_records = list(serialization_executor.map(create_tf_example, example_dicts))
125-
126-
# Write the serialized bytes to disk sequentially
127-
with tf.profiler.experimental.Trace("sequential_write"):
128-
for record in serialized_records:
129-
writer.write(record)
112+
# Convert exactly once
113+
tokens_np = np.asarray(tokens)
114+
vals_np = np.asarray(vals)
115+
idx_np = np.asarray(idx)
116+
opt_data_np = {k: np.asarray(v) for k, v in opt_data.items()}
117+
118+
batch_size = tokens_np.shape[0]
119+
example_dicts = []
120+
121+
# Prepare dictionaries sequentially
122+
for i in range(batch_size):
123+
seq_bytes = tokens_np[i].tobytes()
124+
example_dict = {
125+
"inputs": tokens_np[i],
126+
"top_k_logits": vals_np[i],
127+
"top_k_indices": idx_np[i],
128+
"sequence_hash": hash(seq_bytes),
129+
}
130+
for key, val_np in opt_data_np.items():
131+
example_dict[key] = val_np[i]
132+
example_dicts.append(example_dict)
133+
134+
# Serialize to Protobufs in parallel across multiple CPU cores
135+
serialized_records = list(serialization_executor.map(create_tf_example, example_dicts))
136+
137+
# Write the serialized bytes to disk sequentially
138+
for record in serialized_records:
139+
writer.write(record)
130140

131141

132142
def background_upload(local_path, gcs_path, process_index):
133143
"""Executes a highly optimized concurrent upload via gcloud."""
134-
# Swapped to TF Trace context
135-
with tf.profiler.experimental.Trace("gcs_upload_and_cleanup"):
136-
try:
137-
subprocess.run(["gcloud", "storage", "cp", local_path, gcs_path], check=True, capture_output=True)
138-
os.remove(local_path)
139-
if process_index == 0:
140-
max_logging.log(f"Background upload complete: {gcs_path}")
141-
except subprocess.CalledProcessError as e:
142-
if process_index == 0:
143-
max_logging.log(f"Upload failed for {local_path}: {e.stderr.decode()}")
144+
try:
145+
subprocess.run(["gcloud", "storage", "cp", local_path, gcs_path], check=True, capture_output=True)
146+
os.remove(local_path)
147+
if process_index == 0:
148+
max_logging.log(f"Background upload complete: {gcs_path}")
149+
except subprocess.CalledProcessError as e:
150+
if process_index == 0:
151+
max_logging.log(f"Upload failed for {local_path}: {e.stderr.decode()}")
144152

145153

146154
@nnx.jit(static_argnames=("k",))
@@ -201,22 +209,6 @@ def generate_and_save_data(config, local_args):
201209
for step, batch in enumerate(islice(train_iter, start_step, config.steps), start=start_step):
202210
step_start = time.time()
203211

204-
# --- 1. PROFILER SETUP ---
205-
is_profiling_step = (
206-
config.profiler == "xplane"
207-
and step == config.skip_first_n_steps_for_profiler
208-
)
209-
210-
is_profiling_stop_step = (
211-
config.profiler == "xplane"
212-
and step == config.skip_first_n_steps_for_profiler + config.profiler_steps
213-
)
214-
215-
if is_profiling_step and jax.process_index() == 0:
216-
max_logging.log(f"Recording Host-Only XProf trace for step {step} using TF API...")
217-
options = tf.profiler.experimental.ProfilerOptions(host_tracer_level=2, device_tracer_level=0)
218-
tf.profiler.experimental.start(config.tensorboard_dir, options=options)
219-
220212
if step % steps_per_file == 0:
221213
if writer:
222214
write_executor.shutdown(wait=True)
@@ -225,9 +217,7 @@ def generate_and_save_data(config, local_args):
225217
gcs_file_path = os.path.join(gcs_upload_path, os.path.basename(local_output_path))
226218
if jax.process_index() == 0:
227219
max_logging.log(f"Queueing distributed background uploads for Step {step}...")
228-
# Swapped to TF Trace context
229-
with tf.profiler.experimental.Trace("submit_to_gcs_upload"):
230-
upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index())
220+
upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index())
231221

232222
# Re-initialize the writer with 1 worker
233223
write_executor = ThreadPoolExecutor(max_workers=1)
@@ -239,19 +229,17 @@ def generate_and_save_data(config, local_args):
239229

240230
tokens = batch["inputs"]
241231

242-
# --- TRACE 1: Model Forward Pass & Network Gather ---
243-
# Swapped to TF Trace context
244-
with tf.profiler.experimental.Trace("teacher_forward_and_gather"):
245-
top_k_vals, top_k_idx = teacher_step(teacher_model, batch, k_val)
232+
# --- Model Forward Pass & Network Gather ---
233+
top_k_vals, top_k_idx = teacher_step(teacher_model, batch, k_val)
246234

247-
global_tokens = jax.experimental.multihost_utils.process_allgather(tokens, tiled=True)
248-
global_vals = jax.experimental.multihost_utils.process_allgather(top_k_vals, tiled=True)
249-
global_idx = jax.experimental.multihost_utils.process_allgather(top_k_idx, tiled=True)
235+
global_tokens = jax.experimental.multihost_utils.process_allgather(tokens, tiled=True)
236+
global_vals = jax.experimental.multihost_utils.process_allgather(top_k_vals, tiled=True)
237+
global_idx = jax.experimental.multihost_utils.process_allgather(top_k_idx, tiled=True)
250238

251-
optional_data = {}
252-
for key in optional_keys:
253-
if key in batch:
254-
optional_data[key] = jax.experimental.multihost_utils.process_allgather(batch[key], tiled=True)
239+
optional_data = {}
240+
for key in optional_keys:
241+
if key in batch:
242+
optional_data[key] = jax.experimental.multihost_utils.process_allgather(batch[key], tiled=True)
255243

256244
if writer:
257245
global_tokens_np = np.array(global_tokens)
@@ -269,30 +257,23 @@ def generate_and_save_data(config, local_args):
269257
local_idx_np = global_idx_np[start_idx:end_idx]
270258
local_opt_data_np = {k: v[start_idx:end_idx] for k, v in optional_data_np.items()}
271259

272-
# --- TRACE 2: Local Disk Writing ---
260+
# --- Local Disk Writing ---
273261
# Submit to the background thread with the serialization_executor
274-
with tf.profiler.experimental.Trace("local_disk_write_submit"):
275-
write_executor.submit(
276-
background_process_and_write,
277-
writer,
278-
local_tokens_np,
279-
local_vals_np,
280-
local_idx_np,
281-
local_opt_data_np,
282-
serialization_executor
283-
)
262+
write_executor.submit(
263+
background_process_and_write,
264+
writer,
265+
local_tokens_np,
266+
local_vals_np,
267+
local_idx_np,
268+
local_opt_data_np,
269+
serialization_executor
270+
)
284271

285272
if step % 50 == 0 and jax.process_index() == 0:
286273
max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s")
287274

288275
multihost_utils.sync_global_devices(f"step_{step}_complete")
289276

290-
# --- 2. STOP PROFILER ---
291-
if is_profiling_stop_step:
292-
if jax.process_index() == 0:
293-
max_logging.log(f"Stopping XProf profiler and uploading clean host trace...")
294-
tf.profiler.experimental.stop()
295-
296277
if jax.process_index() == 0:
297278
max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s")
298279

@@ -307,9 +288,7 @@ def generate_and_save_data(config, local_args):
307288

308289
if gcs_upload_path:
309290
gcs_file_path = os.path.join(gcs_upload_path, os.path.basename(local_output_path))
310-
# Swapped to TF Trace context
311-
with tf.profiler.experimental.Trace("submit_to_gcs_upload"):
312-
upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index())
291+
upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index())
313292

314293
if upload_executor:
315294
if jax.process_index() == 0:
@@ -319,10 +298,6 @@ def generate_and_save_data(config, local_args):
319298
max_logging.log("All GCS uploads complete.")
320299

321300
multihost_utils.sync_global_devices("upload_complete")
322-
323-
if jax.process_index() == 0:
324-
max_logging.log("Waiting 15 seconds for XProf to save the trace...")
325-
time.sleep(15)
326301

327302

328303
def main(argv: Sequence[str], local_args):
@@ -345,8 +320,8 @@ def main(argv: Sequence[str], local_args):
345320
)
346321
parser.add_argument("--gcs_upload_path", type=str, default=None)
347322
parser.add_argument("--local_tmp_dir", type=str, default="/tmp")
348-
parser.add_argument("--steps_per_file", type=int, default=2)
323+
parser.add_argument("--steps_per_file", type=int, default=50)
349324
local_arg, remaining_args = parser.parse_known_args()
350325

351326
main_wrapper = functools.partial(main, local_args=local_arg)
352-
app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args)
327+
app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args)

0 commit comments

Comments
 (0)