Commit a10d187

hallerite and claude committed
fix(inference): scope NemotronH reload fix to mixer.D
Direct measurement of the post-reload (pre-restore) state shows only mixer.D is actually corrupted by vLLM 0.22's layerwise online reload: its weight load is dropped and the param is left as uninitialized empty_strided memory -> non-deterministic garbage (NaN, inf, or huge finite values like 1e17) -> NaN logits. Same dtype (bf16) and strides as its neighbours dt_bias/A, which load fine, so it's a dropped load, not a dtype/stride issue.

The MoE gate.e_score_correction_bias reloads correctly (post-reload value equals the received value exactly). It only appeared corrupted in an earlier norm-delta because the trainer broadcasts it shifted by -bias.min() (converting_nemotron_h.py) for bf16 representability -- a routing-invariant constant shift, not corruption. Restoring it was a no-op, so dropping it from the fix is behaviour-preserving.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
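To illustrate why a constant shift of the bias leaves routing unchanged while still changing the tensor's norm, here is a minimal sketch. It assumes the correction bias is only added to the router scores before top-k expert selection. The helper names and the example numbers are hypothetical and are not taken from prime-rl or vLLM.

```python
import math


def top_k_experts(scores, bias, k):
    """Return the indices of the k experts with the largest (score + bias)."""
    biased = [s + b for s, b in zip(scores, bias)]
    order = sorted(range(len(biased)), key=lambda i: biased[i], reverse=True)
    return order[:k]


def l2_norm(values):
    """Euclidean norm, the kind of quantity a norm-delta check would compare."""
    return math.sqrt(sum(v * v for v in values))


# Hypothetical router scores and correction bias for four experts.
scores = [0.10, 0.70, 0.40, 0.55]
bias = [-3.0, -2.5, -1.0, -2.0]

# Same shift the commit message describes: subtract the minimum so the
# smallest entry becomes zero.
shifted_bias = [b - min(bias) for b in bias]

# Every biased score moves by the same constant, so the ranking is unchanged.
print(top_k_experts(scores, bias, 2), top_k_experts(scores, shifted_bias, 2))

# The norms differ, which is why a norm-based comparison flags the tensor.
print(l2_norm(bias), l2_norm(shifted_bias))
```

Under this assumption, a norm comparison between the trainer-side and reloaded bias reports a difference even though expert selection is identical, which matches the false positive described above.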
1 parent ff53c5c commit a10d187

1 file changed

Lines changed: 26 additions & 10 deletions

src/prime_rl/inference/vllm/worker/nccl.py

@@ -24,19 +24,35 @@

 logger = init_logger("vllm.inference.vllm.worker_nccl")

-# NemotronH params that vLLM 0.22's layerwise reload mis-loads through the online-reload path.
-_RELOAD_CORRUPTED_SUFFIXES = (".mixer.D", ".e_score_correction_bias")
+# NemotronH mixer.D is dropped by vLLM 0.22's layerwise online-reload path (left uninitialized).
+_RELOAD_CORRUPTED_SUFFIXES = (".mixer.D",)


 def _restore_reload_corrupted_params(model: Module, received: dict[str, torch.Tensor]) -> None:
     """Work around a vLLM 0.22 layerwise-reload bug for NemotronH.

-    The online reload mis-loads exactly two per-layer parameter families -- ``mixer.D`` (Mamba SSD
-    skip) and the MoE router's ``gate.e_score_correction_bias`` -- while loading all other weights
-    correctly. ``mixer.D`` ends up as non-deterministic garbage/inf (NaN logits) and the gate bias
-    gets a wrong value (broken expert routing), so generations go to NaN after a weight update.
-
-    The received broadcast value is correct, so restore those params from it via each param's own
+    The online reload drops the weight load for every Mamba layer's ``mixer.D`` (the SSD skip
+    connection): the param is materialized as uninitialized ``empty_strided`` memory and never
+    written, so it reads back as non-deterministic garbage (NaN, inf, or huge finite values like
+    1e17), which makes the logits NaN after a weight update. Measured directly -- D has the same
+    dtype (bf16) and strides as its neighbours ``dt_bias``/``A`` (which load fine), so this is a
+    dropped load, not a dtype/stride issue.
+
+    Precise trigger (instrumented): the mixer streams its params in the order dt_bias, A_log, D, ...
+    and its ``load_numel_total`` is 24 (A+D+dt_bias, 8 each). ``A``'s loader is
+    ``composed_weight_loader(sharded_weight_loader, -exp)``, whose extra copy makes vLLM's
+    ``CopyCounter`` attribute 16 elements to the 8-element ``A``. So after dt_bias (8) + A (16),
+    ``load_numel`` already equals ``load_numel_total`` and ``_layerwise_process`` finalizes the mixer
+    -- materializing it via ``empty_strided`` and replaying only dt_bias+A -- before ``D`` (third in
+    the stream) arrives; ``D``'s late load then hits the "Excessive loading" early-return and is
+    dropped. (D is broadcast correctly, exactly once per layer, so this is a vLLM bug, not a
+    conversion/broadcast bug.)
+
+    (The MoE ``gate.e_score_correction_bias`` reloads correctly -- it only looked corrupted in a
+    norm-delta because the trainer broadcasts it shifted by ``-bias.min()`` for bf16 representability,
+    a routing-invariant constant shift.)
+
+    The received broadcast value is correct, so restore D from it via the param's own
     ``weight_loader`` (which applies the right sharding). Remove once the upstream reload bug is fixed.
     """

@@ -182,8 +198,8 @@ def update_weights_from_path(self, weight_dir: str) -> None:
             update_mla_absorbed_weights(model)
             return

-        # vLLM 0.22's layerwise reload mis-loads NemotronH mixer.D and MoE gate.e_score_correction_bias
-        # (see _restore_reload_corrupted_params). Capture the correct received values to restore after.
+        # vLLM 0.22's layerwise reload drops NemotronH mixer.D's weight load (see
+        # _restore_reload_corrupted_params). Capture the correct received value to restore after.
         received_reload_fix: dict[str, torch.Tensor] = {}

         def _capture_reload_fix(weights):
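
The early-finalization trigger described in the docstring can be reproduced with a toy counter. This is only a sketch under stated assumptions: `simulate_layerwise_reload` is a hypothetical stand-in, not vLLM's `_layerwise_process` or `CopyCounter`, and it models just the bookkeeping (a layer is finalized once the credited element count reaches the expected total, and anything arriving later is dropped). The element counts (8 per param, 16 credited to `A`, total 24) are the ones reported in the commit.

```python
def simulate_layerwise_reload(stream, expected_total):
    """Toy model of an element-count-based layer finalizer (hypothetical).

    stream: list of (param_name, numel_credited) in arrival order.
    Returns (loaded, dropped) lists of param names.
    """
    counted = 0
    loaded, dropped = [], []
    for name, credited in stream:
        if counted >= expected_total:
            # The layer is already considered complete, so this late load
            # is treated as excess and skipped.
            dropped.append(name)
            continue
        counted += credited
        loaded.append(name)
    return loaded, dropped


# Three 8-element params, so the layer expects 24 elements in total.
EXPECTED_TOTAL = 24

# Reported behaviour: A's composed loader does an extra copy, so A is
# credited with 16 elements instead of 8.
buggy_stream = [("dt_bias", 8), ("A", 16), ("D", 8)]
print(simulate_layerwise_reload(buggy_stream, EXPECTED_TOTAL))

# With each param credited once, all three loads land before finalization.
correct_stream = [("dt_bias", 8), ("A", 8), ("D", 8)]
print(simulate_layerwise_reload(correct_stream, EXPECTED_TOTAL))
```

In the double-counted stream the total is reached after `dt_bias` and `A`, so `D` is the only param dropped, which is consistent with `D` alone being left uninitialized.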
