fix(inference): restore NemotronH mixer.D after vLLM 0.22 layerwise reload #2714
Merged
…fter vLLM reload vLLM 0.22's layerwise reload mis-loads exactly two NemotronH per-layer param families through the online-reload path -- mixer.D (Mamba SSD skip) and the MoE router's gate.e_score_correction_bias -- while loading all other weights correctly. mixer.D becomes non-deterministic garbage/inf (NaN logits) and the gate bias gets a wrong value (broken routing), so generations go to NaN after a weight update. Restore both from the received broadcast (correct by definition) via each param's own weight_loader. Also drop monkey_patch_vllm_layerwise_reload_alias_buffers: it crashes on vLLM 0.22 (AttributeError on the delattr'd conv_weights) and conv_weights is handled correctly by vLLM's native reload finalize. Supersedes #2701. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
570ce05 to
a10d187
S1ro1
reviewed
Jun 4, 2026
_RELOAD_CORRUPTED_SUFFIXES = (".mixer.D",)


def _restore_reload_corrupted_params(model: Module, received: dict[str, torch.Tensor]) -> None:
Collaborator
Let's remove this excessive comment; I think only the middle part is sufficient.
Direct measurement of the post-reload (pre-restore) state shows only mixer.D is actually corrupted by vLLM 0.22's layerwise online reload: its weight load is dropped and the param is left as uninitialized empty_strided memory -> non-deterministic garbage (NaN, inf, or huge finite values like 1e17) -> NaN logits. Same dtype (bf16) and strides as its neighbours dt_bias/A, which load fine, so it's a dropped load, not a dtype/stride issue. The MoE gate.e_score_correction_bias reloads correctly (post-reload value equals the received value exactly). It only appeared corrupted in an earlier norm-delta because the trainer broadcasts it shifted by -bias.min() (converting_nemotron_h.py) for bf16 representability -- a routing-invariant constant shift, not corruption. Restoring it was a no-op, so dropping it from the fix is behaviour-preserving. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
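The routing-invariance argument in the commit message above can be checked numerically. The following is a minimal sketch, not the trainer's code: the bias values, the router scores, and the helper names (`to_bf16`, `top_k`, `route`) are made up for illustration, and bf16 rounding is emulated in pure Python rather than with torch.

```python
import struct


def to_bf16(x: float) -> float:
    """Emulate rounding to bfloat16 (round-to-nearest-even) for finite values."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add the rounding bias, then keep only the upper 16 bits of the float32.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def top_k(choice_scores: list[float], k: int) -> list[int]:
    """Indices of the k largest scores, best first."""
    order = sorted(range(len(choice_scores)), key=lambda i: choice_scores[i], reverse=True)
    return order[:k]


# Hypothetical values: a bias near 57 with a ~0.04 spread across four experts.
bias = [57.00, 57.01, 57.03, 57.04]
scores = [0.520, 0.500, 0.505, 0.490]  # stand-ins for raw sigmoid router scores

shifted = [b - min(bias) for b in bias]        # the -bias.min() shift
bias_bf16 = [to_bf16(b) for b in bias]         # unshifted, stored in bf16
shifted_bf16 = [to_bf16(b) for b in shifted]   # shifted, stored in bf16


def route(expert_bias: list[float]) -> list[int]:
    """Top-2 expert selection on score + bias."""
    return top_k([s + b for s, b in zip(scores, expert_bias)], k=2)


reference = route(bias)  # selection with the exact, unshifted bias
print("exact bias:       ", reference)
print("shifted, bf16:    ", route(shifted_bf16))
print("unshifted, bf16:  ", route(bias_bf16), "distinct values:", len(set(bias_bf16)))
```

Near 57 the spacing between adjacent bf16 values is 0.25, so in this toy setup all four unshifted biases round to the same number and the selection falls back to the raw scores. The shifted biases stay distinct in bf16 and reproduce the reference selection.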
a10d187 to
d0d32e3
Summary
After a weight update, NemotronH inference produced NaN logits / garbage generations. Root cause: vLLM 0.22's layerwise online reload (load_weights_checkpoint_layerwise) drops the weight load for every Mamba layer's mixer.D (the SSD skip connection). The reload materializes each layer's tensors as uninitialized empty_strided memory and then replays buffered loads; mixer.D's load is dropped, so it is never written and reads back as non-deterministic garbage (NaN, inf, or huge finite values around 1e17), which makes the logits NaN.

Confirmed by direct measurement that this is not a dtype/stride issue: post-reload D is bf16 [8] contiguous, byte-identical in dtype/shape/stride to its neighbours dt_bias/A, which load fine.

The symptom is delayed/non-deterministic because of async lag: the orchestrator runs a couple of steps on the initial weights, so the garbage only appears at the first step that uses reloaded weights (Mismatch KL jumps from ~1e-3 to 2–5).
This supersedes #2701: the existing monkey_patch_vllm_layerwise_reload_alias_buffers (which #2701 tweaks) targets conv_weights, but that alias is a red herring: vLLM's reload finalize re-derives it correctly. The monkey-patch's copy-back loop instead getattrs conv_weights after it's been delattr'd, producing AttributeError: 'MambaMixer2' object has no attribute 'conv_weights' (500). #2701 crashes identically and does not address the real mixer.D drop.

Why only mixer.D (and not e_score_correction_bias)

An earlier per-tensor norm-delta also flagged the MoE router's gate.e_score_correction_bias. Direct post-reload measurement shows that one is a false positive: its post-reload value equals the received broadcast value exactly (norms match to 4 decimals across layers). It only looked changed because the trainer broadcasts it shifted by -bias.min() (converting_nemotron_h.py, HF→prime; the prime→HF path renames but never re-adds the min) so the ~57-magnitude bias fits bf16 without its ~0.04 inter-expert spread collapsing. That is a routing-invariant constant shift (top-k is invariant to a constant added to every expert; routing weights come from raw sigmoid), so reloading the shifted value is correct, and reversing the shift would reintroduce the bf16 collapse. So e_score_correction_bias is not corrupted, and the fix is scoped to mixer.D only.

Changes
- Remove monkey_patch_vllm_layerwise_reload_alias_buffers (call + definition). It crashes on vLLM 0.22, and conv_weights is handled by vLLM's native reload finalize (#42481).
- Restore mixer.D after reload (_restore_reload_corrupted_params in the NCCL weight-update worker): capture the received broadcast value for .mixer.D while streaming into load_weights_checkpoint_layerwise, then restore it via the param's own weight_loader (correct sharding). The received value is by definition the intended one.

Validation
2-node SLURM RL run (Nemotron-3-Nano-30B, reverse-text):
- … AttributeError on the first reload).
- mixer.D comes back NaN/inf/1e17 (non-deterministic, uninitialized), same dtype/strides as dt_bias/A; e_score_correction_bias equals the received value exactly.

Notes
The underlying defect is in vLLM's layerwise reload: it conflates "elements copied" with "elements loaded." The reload finalizes a layer when load_numel >= load_numel_total, where:
- load_numel is tracked by CopyCounter, a TorchDispatchMode that adds numel() on every aten.copy_.default op;
- load_numel_total = get_layer_size(layer) counts the layer's param elements (implicitly assuming one copy per element).

These disagree for any loader that writes a param more than once. The Mamba mixer's load_numel_total is 24 (A + D + dt_bias, 8 elems each) and it streams params in the order dt_bias, A_log, D, …. A's loader is composed_weight_loader(sharded_weight_loader(0), -exp), whose composed_loader issues two copy_ calls into the 8-element A: (1) default_weight_loader → param.data.copy_(shard), then (2) param.data.copy_(-exp(param)) to post-process. So CopyCounter attributes 16 to A, while D/dt_bias (plain sharded loader, one copy_) count 8 each.

Result: after dt_bias (8) + A (16), load_numel == 24 == load_numel_total and _layerwise_process finalizes the mixer (materializing it via empty_strided and replaying only dt_bias + A) before D (third in the stream) arrives. D's late load then hits the online_process_loader "Excessive loading" early-return and is dropped, leaving D uninitialized. Measured directly: [LW-PROC] buffered=['dt_bias','A'] D_in=False numel=24/24 in 368/368 observations, and .mixer.D is broadcast exactly once per layer (23×), so this is a vLLM bug, not a conversion/broadcast bug. (vLLM's own online_process_loader comment acknowledges a sibling case, qconfigs that "load the same weight multiple times" overshooting load_numel_total, but doesn't handle the composed_weight_loader transform case.) This worker-side restore is a workaround that can be removed once fixed upstream.

The e_score_correction_bias -bias.min() shift in converting_nemotron_h.py is intentional and correct (bf16 representability, routing-invariant); it should not be "reversed".

Requires the separate NemotronH offline-init fix (use_mamba_kernels=False, merged in fix(trainer): disable NemotronH HF-Hub mamba kernels for offline init #2713) for the trainer to start under HF_HUB_OFFLINE=1 and exercise this path end-to-end.

🤖 Generated with Claude Code
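The finalize-threshold accounting described in the Notes can be reproduced with a toy model. This is a sketch under stated assumptions, not vLLM code: `LayerReloadSim` and its fields are hypothetical names, and the per-param copy counts (one for dt_bias and D, two for A) are taken from the description above.

```python
from dataclasses import dataclass, field


@dataclass
class LayerReloadSim:
    """Toy layer that finalizes once the counted copied elements reach its size."""

    load_numel_total: int
    load_numel: int = 0
    buffered: list[str] = field(default_factory=list)
    dropped: list[str] = field(default_factory=list)
    finalized: bool = False

    def load(self, name: str, numel: int, copies: int) -> None:
        if self.finalized:
            # Stand-in for the "Excessive loading" early return: the load is ignored.
            self.dropped.append(name)
            return
        self.buffered.append(name)
        # Each copy_ into the param adds numel, so a two-copy loader counts twice.
        self.load_numel += numel * copies
        if self.load_numel >= self.load_numel_total:
            self.finalized = True


# Three 8-element params, streamed in the order given in the Notes.
mixer = LayerReloadSim(load_numel_total=24)
for name, numel, copies in [("dt_bias", 8, 1), ("A", 8, 2), ("D", 8, 1)]:
    mixer.load(name, numel, copies)

print("buffered:", mixer.buffered)
print("dropped: ", mixer.dropped)
print("counted: ", mixer.load_numel, "/", mixer.load_numel_total)
```

In this model the threshold is reached after dt_bias and A alone, so D arrives at an already finalized layer and is dropped, which matches the buffered=['dt_bias','A'] observation quoted above.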
Note
Medium Risk
Changes live inference weight-update behavior for NemotronH; wrong restore logic could corrupt Mamba skip weights, but the fix is narrow (.mixer.D only) and uses received broadcast values plus existing loaders.

Overview
Fixes NemotronH inference after NCCL weight updates on vLLM 0.22 by removing a broken layerwise-reload monkey-patch and re-applying mixer.D from the broadcast stream after load_weights_checkpoint_layerwise.
- patches.py: Stops registering monkey_patch_vllm_layerwise_reload_alias_buffers (call and implementation). That patch tried to skip buffer copies that alias parameters; on 0.22 it can AttributeError on reload finalize instead of fixing the real issue.
- nccl.py: Wraps the incoming weight iterator to snapshot .mixer.D tensors on CPU, runs layerwise reload as before, then _restore_reload_corrupted_params writes them back via each param's weight_loader (or copy_ fallback), keyed by layers.* suffixes. This works around vLLM finalizing Mamba mixer layers before D is loaded, which left skip-connection weights uninitialized and produced NaN logits after the first reload step.

Scope is NCCL online updates only (not the quantize/kernel path or filesystem loader).
Reviewed by Cursor Bugbot for commit d0d32e3.