Skip to content

Commit 570ce05

Browse files
halleriteclaude
andcommitted
fix(inference): scope NemotronH reload fix to mixer.D
Direct measurement of the post-reload (pre-restore) state shows only mixer.D is actually corrupted by vLLM 0.22's layerwise online reload: its weight load is dropped and the param is left as uninitialized empty_strided memory -> non-deterministic garbage (NaN, inf, or huge finite values like 1e17) -> NaN logits. Same dtype (bf16) and strides as its neighbours dt_bias/A, which load fine, so it's a dropped load, not a dtype/stride issue. The MoE gate.e_score_correction_bias reloads correctly (post-reload value equals the received value exactly). It only appeared corrupted in an earlier norm-delta because the trainer broadcasts it shifted by -bias.min() (converting_nemotron_h.py) for bf16 representability -- a routing-invariant constant shift, not corruption. Restoring it was a no-op, so dropping it from the fix is behaviour-preserving. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent ff53c5c commit 570ce05

1 file changed

Lines changed: 13 additions & 9 deletions

File tree

  • src/prime_rl/inference/vllm/worker

src/prime_rl/inference/vllm/worker/nccl.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,23 @@
2424

2525
logger = init_logger("vllm.inference.vllm.worker_nccl")
2626

27-
# NemotronH params that vLLM 0.22's layerwise reload mis-loads through the online-reload path.
28-
_RELOAD_CORRUPTED_SUFFIXES = (".mixer.D", ".e_score_correction_bias")
27+
# NemotronH mixer.D is dropped by vLLM 0.22's layerwise online-reload path (left uninitialized).
28+
_RELOAD_CORRUPTED_SUFFIXES = (".mixer.D",)
2929

3030

3131
def _restore_reload_corrupted_params(model: Module, received: dict[str, torch.Tensor]) -> None:
3232
"""Work around a vLLM 0.22 layerwise-reload bug for NemotronH.
3333
34-
The online reload mis-loads exactly two per-layer parameter families -- ``mixer.D`` (Mamba SSD
35-
skip) and the MoE router's ``gate.e_score_correction_bias`` -- while loading all other weights
36-
correctly. ``mixer.D`` ends up as non-deterministic garbage/inf (NaN logits) and the gate bias
37-
gets a wrong value (broken expert routing), so generations go to NaN after a weight update.
34+
The online reload drops the weight load for every Mamba layer's ``mixer.D`` (the SSD skip
35+
connection): the param is materialized as uninitialized ``empty_strided`` memory and never
36+
written, so it reads back as non-deterministic garbage (NaN, inf, or huge finite values like
37+
1e17), which makes the logits NaN after a weight update. Measured directly -- D has the same
38+
dtype (bf16) and strides as its neighbours ``dt_bias``/``A`` (which load fine), so this is a
39+
dropped load, not a dtype/stride issue. (The MoE ``gate.e_score_correction_bias`` reloads
40+
correctly -- it only looked corrupted in a norm-delta because the trainer broadcasts it shifted
41+
by ``-bias.min()`` for bf16 representability, a routing-invariant constant shift.)
3842
39-
The received broadcast value is correct, so restore those params from it via each param's own
43+
The received broadcast value is correct, so restore D from it via the param's own
4044
``weight_loader`` (which applies the right sharding). Remove once the upstream reload bug is fixed.
4145
"""
4246

@@ -182,8 +186,8 @@ def update_weights_from_path(self, weight_dir: str) -> None:
182186
update_mla_absorbed_weights(model)
183187
return
184188

185-
# vLLM 0.22's layerwise reload mis-loads NemotronH mixer.D and MoE gate.e_score_correction_bias
186-
# (see _restore_reload_corrupted_params). Capture the correct received values to restore after.
189+
# vLLM 0.22's layerwise reload drops NemotronH mixer.D's weight load (see
190+
# _restore_reload_corrupted_params). Capture the correct received value to restore after.
187191
received_reload_fix: dict[str, torch.Tensor] = {}
188192

189193
def _capture_reload_fix(weights):

0 commit comments

Comments
 (0)