Skip to content

Commit 05a15fd

Browse files
committed
Fix bypass distillation: PP-aware save, step semantics, resume bookkeeping
Several correctness/UX fixes to the bypass distillation stage, surfaced while running the Nemotron-3-Nano-30B-A3B-Base-BF16 tutorial end-to-end. Save / resume correctness - save_bypass_checkpoint now calls save_checkpoint_from_shards (gather- aware) instead of save_checkpoint. Under PP each rank's model.state_dict() only carries its own owned blocks; the unsharded variant let every rank race-write a partial model.safetensors.index.json to the same path (last writer wins), so resume's load_and_shard_model left most params on the meta device. - find_latest_run_dir: treat missing parent dir as "no previous runs" rather than fatal — a freshly-wiped bypass dir or a non-master rank reaching the function before master's set_experiment_dir mkdir becomes visible (no barrier between them) no longer crashes. - puzzletron_nas_plugin "bypass already done" check now keys off a _DONE sentinel written only on successful training completion. The prior `latest` symlink check tripped on every periodic save, so a Ctrl-C'd partial run looked completed and the pipeline skipped ahead to stats. - On resume, bump iter_num/step_num by 1 — args.json records the iter that just completed (increments at the bottom of the loop run after save), so restoring as-is re-executed the saved iter (visible as "iter 899" instead of "iter 900" in the first post-resume chunk). - Persist per-block bookkeeping (best_losses_by_name, best_steps_by_name, initial_losses_by_name) into cfg.bypass on each log chunk and restore on resume; the columns no longer reset to "current is best" / re-anchor "Δ from initial" to the resume point. Step / iter semantics - max_steps formula now divides by grad_accumulation_steps. Previously max_steps = ceil(training_tokens / tokens_per_iter) and max_iters = max_steps × grad_accum, so the actual budget was grad_accum × training_tokens (8× overshoot for the Nemotron config). - warmup_steps resolver gains a grad_accum arg so warmup is in optimizer-step units matching _get_lr's step_num indexing. YAMLs that reference the resolver pass the new arg. - Save-dir naming changed from iter-NNNNNN-ckpt / {start,final,best}-iter-* to step-NNNNNN-ckpt / {start,final,best}-step-* — saves were already step-gated, the iter in the name was misleading. Tests updated to match. - Log line now leads with step (matching what users mean by "a step") and reports avg_step_time = grad_accum × per-iter; iter is dropped from the headline entirely. wandb log keys follow. - log_interval is now in optimizer-step units (multiplied by grad_accum internally before slicing the per-iter history). - Stitched Module Losses table: best_steps_by_name and the summary line now record step_num (matching the header units), not iter. UX / readability - Stitched Module Losses: replace "Change from avg" (a same-step cross-block diff masquerading as a temporal delta) with "Δ from initial" — anchored at lowest_iter's per-block losses, showing both absolute and % reduction. Column reordered to Block / Loss / Δ from initial / Best Value / Best Step so the delta sits next to Loss. - Architecture plot emojis: 🐙 for kv-heads attention, 🔀 for MoE, 🧱 for dense FFN (ffn_dim renamed from ffn_intermediate). Filled in the gaps left by 🐍 mamba / ❌ no_op. Tutorial config - nemotron-3-nano-30b-a3b: add_attention_no_ops false (no_op variant is never picked by MIP at target_num_kv_heads=9 so scoring it is wasted compute), training_tokens 1e+7 → 5e+7 (50M), log_interval and save_interval set to 100 steps. - Tutorial markdown updated for the 50M budget and the attention-only variant set. Test fixture: patched_save's monkeypatch updated from "save_checkpoint" to "save_checkpoint_from_shards" to match the production change. Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
1 parent 88792f4 commit 05a15fd

13 files changed

Lines changed: 229 additions & 141 deletions

File tree

examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ A minimal end-to-end demonstration that **bypass distillation improves quality**
77
The teacher has 6 attention layers (each with `num_key_value_heads=2`) interleaved between Mamba and MoE-FFN blocks — **12 KV heads total** across the whole model. We compress to **9 KV heads (75% of teacher)** in two ways and compare:
88

99
1. **Without bypass** — replacement library uses Truncate-init weights (KV heads sliced from teacher; no further training).
10-
2. **With bypass** — the bypass step runs ~10M tokens of per-block knowledge distillation, training a 1-KV-head variant per attention layer against the teacher.
10+
2. **With bypass** — the bypass step runs ~50M tokens of per-block knowledge distillation, training a 1-KV-head variant per attention layer against the teacher.
1111

12-
Both runs use the same MIP solver and the same constraint (`target_num_kv_heads: 9`), so MIP picks per attention layer from `{teacher 2-head, 1-head, no_op}` (the no_op variant lets the solver drop attention entirely on a layer if doing so is cheap enough). FFN/MoE/Mamba blocks are copied verbatim from the teacher in both runs — only attention weights change.
12+
Both runs use the same MIP solver and the same constraint (`target_num_kv_heads: 9`), so MIP picks per attention layer from `{teacher 2-head, 1-head}`. FFN/MoE/Mamba blocks are copied verbatim from the teacher in both runs — only attention weights change.
1313

1414
**Metrics:** `lm_loss` and `token_accuracy_top_1` measured against the same held-out dataset by the realize-model step (printed automatically to `puzzle_dir/log.txt`).
1515

@@ -70,7 +70,7 @@ torchrun --nproc_per_node=8 examples/puzzletron/main.py \
7070
--config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml
7171
```
7272

73-
Skip-if-done caching reuses Step A's converted teacher checkpoint, activation scores, and pruned checkpoints. Only Step 5 (bypass distillation, ~60 min for 10M tokens) and the downstream library/scoring/MIP rerun. Wall-clock: roughly **+1.5 h** on top of Step A.
73+
Skip-if-done caching reuses Step A's converted teacher checkpoint, activation scores, and pruned checkpoints. Only Step 5 (bypass distillation, ~50M tokens) and the downstream library/scoring/MIP rerun.
7474

7575
Bypass writes its outputs under `${puzzle_dir}/bypass/bypass_runs/bypass_heads_1/` and creates a symlink `${puzzle_dir}/ckpts/bypass_heads_1` that the replacement library builder picks up automatically.
7676

@@ -84,14 +84,14 @@ Reducing total KV heads from 12 → 9 (75% of teacher) at fixed FFN/MoE/Mamba on
8484
|------------------------------|----------------------:|----------:|-----------------------:|
8585
| Teacher | 12 | 0.5950 | 0.8468 |
8686
| Pruned, **no bypass** (Truncate-init) | 9 | 0.6347 | 0.8373 |
87-
| Pruned, **with bypass** (10M-token BLD) | 9 | **0.6055**| **0.8441** |
87+
| Pruned, **with bypass** (50M-token BLD) | 9 | **0.6055**| **0.8441** |
8888

8989
**Bypass closes ~74% of the regression gap** at this compression budget:
9090

9191
- `lm_loss` gap to teacher: `0.0397` without bypass → `0.0105` with bypass — bypass recovers **74%**.
9292
- `token_accuracy_top_1` gap to teacher: `0.0095` without bypass → `0.0027` with bypass — bypass recovers **72%**.
9393

94-
For 10M tokens of per-block KD, that's a substantial lift on a real 30B-A3B teacher.
94+
For 50M tokens of per-block KD, that's a substantial lift on a real 30B-A3B teacher.
9595

9696
## Going further: full accuracy recovery
9797

examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ training:
3737
micro_batch_size: 2
3838
val_micro_batch_size: 1
3939
warmup_ratio: 0.05
40-
warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.warmup_ratio}} # Auto-calculated warmup steps
40+
warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.grad_accumulation_steps},${.warmup_ratio}} # Auto-calculated warmup steps
4141
min_lr_factor: 1e-5
4242
grad_accumulation_steps: 1
4343
skip_first_batches: 0 # Use for debugging or to skip few batches which cause crashes or optimization issues.

examples/puzzletron/configs/nemotron-3-nano-30b-a3b/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2
1515
skip_realize_model: false
1616

1717
# KV-heads-only pruning: lock off FFN/MoE-side variants. The replacement library
18-
# exposes {teacher 2-head, 1-head, no_op} per attention layer; FFN and Mamba
18+
# exposes {teacher 2-head, 1-head} per attention layer; FFN and Mamba
1919
# blocks are copied verbatim from the teacher.
2020
build_replacement_library:
2121
add_ffn_no_ops: false
22-
add_attention_no_ops: true
22+
add_attention_no_ops: false
2323

2424
calc_subblock_stats:
2525
batch_sizes: [64, 96, 128]

examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ data:
3838
# Training Configuration
3939
training:
4040
learning_rate: 3e-4
41-
training_tokens: 1e+7 # 10M tokens (toy budget)
41+
training_tokens: 5e+7 # 50M tokens (toy budget)
4242
micro_batch_size: 2
4343
val_micro_batch_size: 2
4444
warmup_ratio: 0.05
45-
warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.warmup_ratio}}
45+
warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.grad_accumulation_steps},${.warmup_ratio}}
4646
min_lr_factor: 1e-5
4747
grad_accumulation_steps: 8
4848
skip_first_batches: 0
@@ -69,7 +69,7 @@ model:
6969
model_overrides:
7070
delete_old_checkpoints: true
7171
save_interval_seconds: 12900
72-
save_interval: 1e+9
72+
save_interval: 100
7373
save_checkpoint_when_done: true
7474

7575
# Architecture override: only attention is touched. FFN/MoE/Mamba sub-blocks

examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@ puzzle_dir: /workspace/puzzle_dir
1515
# Toy KV-heads-only constraint.
1616
# Teacher has 6 attention layers × num_key_value_heads=2 = 12 KV heads total.
1717
# Target 9 leaves 75% of teacher KV heads — the MIP solver picks per-layer from
18-
# {teacher 2-head, 1-head, no_op} so some layers stay full, some collapse to 1
19-
# head, and some become no_op.
18+
# {teacher 2-head, 1-head} so some layers stay full and the rest collapse to 1
19+
# head.
2020
mip:
2121
human_constraints:
2222
target_num_kv_heads: 9
2323

2424
# KV-heads-only toy pruning task.
2525
# teacher num_attention_heads = 32, num_key_value_heads = 2 (n_heads_in_group = 16)
2626
# Bypass-trains a single 1-KV-head variant per attention layer
27-
# (n_heads_in_group = 32). Combined with `add_attention_no_ops: true` in the base
28-
# config, MIP picks per-layer from {teacher 2-head, 1-head, no_op}.
27+
# (n_heads_in_group = 32). Combined with `add_attention_no_ops: false` in the
28+
# base config, MIP picks per-layer from {teacher 2-head, 1-head}.
2929
pruning:
3030
n_heads_in_group_list: [32] # 32 / 32 = 1 KV head per attention layer

modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,42 @@
2626

2727
import modelopt.torch.utils.distributed as dist
2828
from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor
29-
from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint
29+
from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint_from_shards
3030
from modelopt.torch.puzzletron.tools.logger import aprint, mprint
3131
from modelopt.torch.utils.robust_json import json_dump
3232

3333
from .stitched_model_factory import StitchedModuleDescriptor
3434

3535

3636
def find_latest_run_dir(run_parent_dir: Union[str, Path]) -> str | None:
37-
"""Find the latest plain-iter checkpoint directory within a run parent directory.
37+
"""Find the latest plain-step checkpoint directory within a run parent directory.
3838
3939
Resume must pick a directory created by the step-interval / time-based / final save
40-
paths (named ``iter-NNNNNN-ckpt``) — not ``best-iter-*`` (which corresponds to a
40+
paths (named ``step-NNNNNN-ckpt``) — not ``best-step-*`` (which corresponds to a
4141
validation-best snapshot whose optimizer state may be stale relative to the latest
42-
iter), nor ``start-iter-*`` / ``final-iter-*`` (markers, not resume points).
42+
step), nor ``start-step-*`` / ``final-step-*`` (markers, not resume points).
4343
"""
4444
run_parent_dir = Path(run_parent_dir)
4545

4646
# Check for the "latest" symlink — set only by save_bypass_checkpoint, always
47-
# points at a plain ``iter-*`` directory. Fast path.
47+
# points at a plain ``step-*`` directory. Fast path.
4848
latest_dir = run_parent_dir / "latest"
4949
if latest_dir.exists() and (latest_dir / "saving_completed").exists():
5050
return str(latest_dir)
5151

52-
# Fallback: scan plain ``iter-NNNNNN-ckpt`` directories only.
53-
iter_re = re.compile(r"^iter-(\d+)-ckpt$")
52+
# Fallback: scan plain ``step-NNNNNN-ckpt`` directories only.
53+
# Treat a missing parent dir as "no previous runs" rather than fatal — this
54+
# handles two cases cleanly: a freshly-wiped bypass dir, and the race where
55+
# non-master ranks reach this function before master finishes the
56+
# ``set_experiment_dir`` mkdir on a shared filesystem.
57+
if not run_parent_dir.exists():
58+
return None
59+
step_re = re.compile(r"^step-(\d+)-ckpt$")
5460
candidate_dirs: list[tuple[int, Path]] = []
5561
for d in run_parent_dir.iterdir():
5662
if not d.is_dir():
5763
continue
58-
match = iter_re.match(d.name)
64+
match = step_re.match(d.name)
5965
if match:
6066
candidate_dirs.append((int(match.group(1)), d))
6167

@@ -207,8 +213,13 @@ def save_bypass_checkpoint(
207213
checkpoint_dir=checkpoint_dir,
208214
overwrite=cfg.bypass.model.model_overrides.delete_old_checkpoints,
209215
)
210-
# Save as HF checkpoint
211-
save_checkpoint(model=model, checkpoint_dir=checkpoint_dir, descriptor=descriptor)
216+
# Save as HF checkpoint. Must use the gather-aware variant: bypass training is
217+
# pipeline-parallel so each rank's `model.state_dict()` only carries its own
218+
# owned blocks. The unsharded `save_checkpoint` would have every rank write a
219+
# partial `model.safetensors.index.json` to the same path (last writer wins),
220+
# producing an index that omits most ranks' weights — resume then leaves params
221+
# on the meta device.
222+
save_checkpoint_from_shards(model=model, checkpoint_dir=checkpoint_dir, descriptor=descriptor)
212223

213224
if dist.is_master():
214225
# Create 'latest' symlink

0 commit comments

Comments
 (0)