Commit 05a15fd
Fix bypass distillation: PP-aware save, step semantics, resume bookkeeping
Several correctness/UX fixes to the bypass distillation stage, surfaced
while running the Nemotron-3-Nano-30B-A3B-Base-BF16 tutorial end-to-end.
Save / resume correctness
- save_bypass_checkpoint now calls save_checkpoint_from_shards
(gather-aware) instead of save_checkpoint. Under PP (pipeline
parallelism), each rank's model.state_dict() carries only the blocks
that rank owns; with the unsharded variant, every rank raced to write a
partial model.safetensors.index.json to the same path (last writer
wins), so resume's load_and_shard_model left most params on the meta
device.
- find_latest_run_dir: treat a missing parent dir as "no previous runs"
rather than a fatal error. This no longer crashes when the bypass dir
was freshly wiped, or when a non-master rank reaches the function
before master's set_experiment_dir mkdir becomes visible (there is no
barrier between them).
- puzzletron_nas_plugin "bypass already done" check now keys off a
_DONE sentinel written only on successful training completion. The
prior `latest` symlink check tripped on every periodic save, so a
Ctrl-C'd partial run looked completed and the pipeline skipped
ahead to stats.
- On resume, bump iter_num/step_num by 1: args.json records the iter
that just completed (the increments at the bottom of the loop run
after the save), so restoring the values as-is re-executed the saved
iter (visible as "iter 899" instead of "iter 900" in the first
post-resume chunk).
- Persist per-block bookkeeping (best_losses_by_name,
best_steps_by_name, initial_losses_by_name) into cfg.bypass on
each log chunk and restore on resume; the columns no longer reset
to "current is best" / re-anchor "Δ from initial" to the resume
point.
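
The resume-side rules above can be sketched roughly as follows. This is a minimal illustration, not the code from this commit: the function bodies, the name-sorted run-directory layout, and the helper names mark_done, is_bypass_done and resume_counters are assumptions; only find_latest_run_dir and the _DONE sentinel name come from the message.

```python
from pathlib import Path
from typing import Optional, Tuple

DONE_SENTINEL = "_DONE"  # name from the commit message; its location is assumed


def find_latest_run_dir(parent: Path) -> Optional[Path]:
    """Return the newest run dir under `parent`, or None if there is none.

    A missing parent means "no previous runs" instead of raising.
    Picking the newest dir by sorted name is an assumption of this sketch.
    """
    if not parent.is_dir():
        return None
    runs = sorted(p for p in parent.iterdir() if p.is_dir())
    return runs[-1] if runs else None


def mark_done(run_dir: Path) -> None:
    """Hypothetical helper: write the sentinel after training completes."""
    (run_dir / DONE_SENTINEL).touch()


def is_bypass_done(run_dir: Optional[Path]) -> bool:
    """Done means the sentinel exists, not merely that a checkpoint was saved."""
    return run_dir is not None and (run_dir / DONE_SENTINEL).is_file()


def resume_counters(saved_iter: int, saved_step: int) -> Tuple[int, int]:
    """Hypothetical helper: the saved values are the iter/step that already
    completed, so training resumes from the next one."""
    return saved_iter + 1, saved_step + 1
```

With this shape, a periodic save alone never makes a run look finished; only an explicit mark_done call at the end of training does.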
Step / iter semantics
- max_steps formula now divides by grad_accumulation_steps. Previously
max_steps = ceil(training_tokens / tokens_per_iter) and
max_iters = max_steps × grad_accum, so the actual budget was
grad_accum × training_tokens (8× overshoot for the Nemotron config).
- warmup_steps resolver gains a grad_accum arg so warmup is in
optimizer-step units matching _get_lr's step_num indexing.
YAMLs that reference the resolver pass the new arg.
- Save-dir naming changed from iter-NNNNNN-ckpt /
{start,final,best}-iter-* to step-NNNNNN-ckpt /
{start,final,best}-step-*. Saves were already step-gated, so the
"iter" in the name was misleading. Tests updated to match.
- Log line now leads with step (matching what users mean by "a step")
and reports avg_step_time = grad_accum × per-iter time; iter is
dropped from the headline entirely. wandb log keys follow the same
change.
- log_interval is now in optimizer-step units (multiplied by grad_accum
internally before slicing the per-iter history).
- Stitched Module Losses table: best_steps_by_name and the summary
line now record step_num (matching the header units), not iter.
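
To make the max_steps overshoot concrete, here is a small arithmetic sketch of the old and new formulas as described above. The function names and the tokens_per_iter value are hypothetical; grad_accum = 8 matches the 8× overshoot quoted for the Nemotron config.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division, avoiding float rounding."""
    return -(-a // b)


def old_budget(training_tokens: int, tokens_per_iter: int, grad_accum: int):
    """Old behaviour: steps ignore grad accumulation, iters multiply by it."""
    max_steps = ceil_div(training_tokens, tokens_per_iter)
    max_iters = max_steps * grad_accum
    return max_steps, max_iters, max_iters * tokens_per_iter


def new_budget(training_tokens: int, tokens_per_iter: int, grad_accum: int):
    """New behaviour: one optimizer step consumes grad_accum iterations."""
    max_steps = ceil_div(training_tokens, tokens_per_iter * grad_accum)
    max_iters = max_steps * grad_accum
    return max_steps, max_iters, max_iters * tokens_per_iter


# Hypothetical tokens_per_iter; 50M tokens and grad_accum=8 as in the tutorial.
print(old_budget(50_000_000, 65_536, 8))  # (763, 6104, 400031744)
print(new_budget(50_000_000, 65_536, 8))  # (96, 768, 50331648)
```

Under these assumed numbers the old formula consumes roughly 400M tokens for a 50M budget, while the new one stays within one optimizer step of the target.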
UX / readability
- Stitched Module Losses: replace "Change from avg" (a same-step
cross-block diff masquerading as a temporal delta) with
"Δ from initial", anchored at lowest_iter's per-block losses and
showing both the absolute and the % reduction. Columns reordered to
Block / Loss / Δ from initial / Best Value / Best Step so the
delta sits next to Loss.
- Architecture plot emojis: 🐙 for kv-heads attention, 🔀 for MoE,
🧱 for dense FFN (ffn_dim renamed from ffn_intermediate). Filled
in the gaps left by 🐍 mamba / ❌ no_op.
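
As a rough illustration of the "Δ from initial" column, the sketch below computes an absolute delta and a percentage reduction per block. The function name, the sign convention, and the zero-loss guard are assumptions made for this example, not the table code from the commit.

```python
from typing import Dict, Tuple


def delta_from_initial(
    initial_losses: Dict[str, float], current_losses: Dict[str, float]
) -> Dict[str, Tuple[float, float]]:
    """Return {block: (absolute_delta, percent_reduction)}.

    absolute_delta is current - initial (negative when the loss went down);
    percent_reduction is positive when the loss went down.
    """
    out = {}
    for name, current in current_losses.items():
        initial = initial_losses[name]
        abs_delta = current - initial
        # Guard against a zero initial loss (assumed handling).
        pct = 100.0 * (initial - current) / initial if initial else 0.0
        out[name] = (abs_delta, pct)
    return out


# Hypothetical block names and losses.
rows = delta_from_initial({"block_0": 2.0, "block_1": 4.0},
                          {"block_0": 1.5, "block_1": 4.0})
print(rows)  # {'block_0': (-0.5, 25.0), 'block_1': (0.0, 0.0)}
```

Because the anchor is each block's own initial loss, the value stays meaningful across a resume as long as initial_losses_by_name is restored rather than recomputed.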
Tutorial config
- nemotron-3-nano-30b-a3b: add_attention_no_ops false (no_op variant
is never picked by MIP at target_num_kv_heads=9 so scoring it is
wasted compute), training_tokens 1e+7 → 5e+7 (50M),
log_interval and save_interval set to 100 steps.
- Tutorial markdown updated for the 50M budget and the
attention-only variant set.
Test fixture: patched_save's monkeypatch updated from "save_checkpoint"
to "save_checkpoint_from_shards" to match the production change.
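
The reason the fixture had to change can be shown with a stand-in module: a patch only takes effect if it targets the name the production code actually calls. Everything below is a hypothetical sketch using unittest.mock; the namespace object and run_with_patch are illustrative, and the real test uses its own patched_save fixture.

```python
from types import SimpleNamespace
from unittest import mock

# Stand-in for the module that holds both save functions (hypothetical).
checkpoint_utils = SimpleNamespace(
    save_checkpoint=lambda model, path: "unsharded save",
    save_checkpoint_from_shards=lambda model, path: "gather-aware save",
)


def save_bypass_checkpoint(model, path):
    """After this commit the production path calls the gather-aware variant."""
    return checkpoint_utils.save_checkpoint_from_shards(model, path)


def run_with_patch(attr: str) -> bool:
    """Patch `attr`, run the save, and report whether the fake was hit."""
    with mock.patch.object(checkpoint_utils, attr) as fake:
        save_bypass_checkpoint(object(), "ckpt-dir")
        return fake.called


print(run_with_patch("save_checkpoint"))              # False: stale patch target
print(run_with_patch("save_checkpoint_from_shards"))  # True: matches production
```

A stale patch target fails silently, since the real function still runs, which is why the fixture is updated in the same commit as the production change.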
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
1 parent 88792f4, commit 05a15fd
13 files changed
Lines changed: 229 additions & 141 deletions
File tree
- examples/puzzletron
- configs
- llama-3_1-8B_pruneffn_memory/bypass
- nemotron-3-nano-30b-a3b
- bypass
- modelopt/torch/puzzletron
- bypass_distillation
- tools
- utils
- tests
- gpu/torch/puzzletron
- unit/torch/puzzletron