Skip to content

Commit fe05ce3

Browse files
hallerite and claude committed
fix(inference): restore NemotronH mixer.D + e_score_correction_bias after vLLM reload
vLLM 0.22's layerwise reload mis-loads exactly two NemotronH per-layer param families through the online-reload path -- mixer.D (Mamba SSD skip) and the MoE router's gate.e_score_correction_bias -- while loading all other weights correctly. mixer.D becomes non-deterministic garbage/inf (NaN logits) and the gate bias gets a wrong value (broken routing), so generations go to NaN after a weight update. Restore both from the received broadcast (correct by definition) via each param's own weight_loader. Also drop monkey_patch_vllm_layerwise_reload_alias_buffers: it crashes on vLLM 0.22 (AttributeError on the delattr'd conv_weights) and conv_weights is handled correctly by vLLM's native reload finalize. Supersedes #2701. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent f619e36 commit fe05ce3

2 files changed

Lines changed: 51 additions & 34 deletions

File tree

src/prime_rl/inference/patches.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def transformers_v5_compat():
1818
_patch_qwen35_lora()
1919
_patch_lora_key_prefix()
2020
monkey_patch_deep_gemm_silu_mul_quant_int64()
21-
monkey_patch_vllm_layerwise_reload_alias_buffers()
2221
monkey_patch_vllm_padded_input_scrub()
2322
monkey_patch_return_routed_experts_with_nixl_connector()
2423

@@ -67,38 +66,6 @@ def _post_init(config: VllmConfig):
6766
logger.warning("Enabled vLLM routed-experts capture with NIXL connector patch.")
6867

6968

70-
def monkey_patch_vllm_layerwise_reload_alias_buffers():
    """Replace vLLM's layerwise-reload copy step so aliased buffers are not copied.

    vLLM's layerwise reload materializes each buffer as an independent tensor
    and then copies it back into the original kernel storage. When a buffer
    aliases a parameter (e.g. NemotronH Mamba's mixer.conv_weights, a view of
    mixer.conv1d.weight), that copy stamps garbage into the parameter's storage
    *after* the parameter has been correctly reloaded. The replacement skips
    the copy for any buffer sharing storage with a parameter;
    _place_kernel_tensors re-registers the original view, which trivially
    reflects the parameter.

    Remove this patch once https://github.com/vllm-project/vllm/pull/42481 is
    included in the vLLM release we pin/use.
    """
    from vllm.logger import init_logger
    from vllm.model_executor.model_loader.reload import layerwise as reload_layerwise

    logger = init_logger(__name__)

    def _copy_and_restore_kernel_tensors(layer: torch.nn.Module, info: reload_layerwise.LayerReloadingInfo):
        assert info.kernel_tensors is not None
        kernel_params, kernel_buffers = info.kernel_tensors

        # Storage addresses of every parameter currently on the layer, gathered
        # before any copy so aliasing buffers can be recognised below.
        aliased_ptrs = {
            layer_param.untyped_storage().data_ptr() for layer_param in layer.parameters(recurse=True)
        }

        for attr_name, kernel_param in kernel_params.items():
            reloaded = getattr(layer, attr_name)
            kernel_param.data.copy_(reloaded)

        for attr_name, kernel_buffer in kernel_buffers.items():
            shares_param_storage = kernel_buffer.untyped_storage().data_ptr() in aliased_ptrs
            if not shares_param_storage:
                kernel_buffer.data.copy_(getattr(layer, attr_name))

        reload_layerwise._place_kernel_tensors(layer, info)

    reload_layerwise._copy_and_restore_kernel_tensors = _copy_and_restore_kernel_tensors
    logger.warning("Enabled vLLM layerwise reload alias-buffer patch.")
100-
101-
10269
@triton.jit
10370
def _silu_mul_per_token_group_quant_fp8_colmajor_int64_kernel(
10471
y_ptr,

src/prime_rl/inference/vllm/worker/nccl.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,45 @@
2424

2525
logger = init_logger("vllm.inference.vllm.worker_nccl")
2626

27+
# NemotronH params that vLLM 0.22's layerwise reload mis-loads through the online-reload path.
28+
# Kept as a tuple so it can be passed straight to ``str.endswith``; the leading "." means each
# entry only matches a complete trailing attribute path (e.g. "<prefix>.mixer.D").
_RELOAD_CORRUPTED_SUFFIXES = (".mixer.D", ".e_score_correction_bias")
29+
30+
31+
def _restore_reload_corrupted_params(model: Module, received: dict[str, torch.Tensor]) -> None:
    """Work around a vLLM 0.22 layerwise-reload bug for NemotronH.

    The online reload mis-loads exactly two per-layer parameter families -- ``mixer.D`` (Mamba SSD
    skip) and the MoE router's ``gate.e_score_correction_bias`` -- while loading all other weights
    correctly. ``mixer.D`` ends up as non-deterministic garbage/inf (NaN logits) and the gate bias
    gets a wrong value (broken expert routing), so generations go to NaN after a weight update.

    The received broadcast value is correct, so restore those params from it via each param's own
    ``weight_loader`` (which applies the right sharding). Remove once the upstream reload bug is fixed.

    Args:
        model: Model whose parameters have just been reloaded.
        received: Tensors captured from the incoming weight stream, keyed by their received name.

    Any affected parameter that cannot be restored (no matching received tensor, or a shape
    mismatch with no ``weight_loader`` available) is reported with a warning, because it may
    still hold the corrupted value.
    """

    def _layer_key(name: str) -> str:
        # Compare names from the first "layers." onward, ignoring whatever precedes it
        # (presumably received names and model parameter names can differ in prefix -- TODO confirm).
        index = name.find("layers.")
        return name[index:] if index >= 0 else name

    received_by_key = {_layer_key(name): tensor for name, tensor in received.items()}
    restored = 0
    unrestored: list[str] = []
    for name, param in model.named_parameters():
        if not name.endswith(_RELOAD_CORRUPTED_SUFFIXES):
            continue
        tensor = received_by_key.get(_layer_key(name))
        if tensor is None:
            unrestored.append(name)
            continue
        tensor = tensor.to(device=param.device)
        weight_loader = getattr(param, "weight_loader", None)
        if weight_loader is not None:
            weight_loader(param, tensor)
        elif tensor.shape == param.shape:
            param.data.copy_(tensor.to(param.dtype))
        else:
            # No loader to reshape/shard the tensor and a raw copy would not fit.
            unrestored.append(name)
            continue
        restored += 1
    logger.debug("Restored %d NemotronH params (mixer.D, e_score_correction_bias) after reload", restored)
    if unrestored:
        # Surface skipped params instead of leaving possibly-corrupted weights in place silently.
        logger.warning(
            "Could not restore %d reload-affected params (no matching received tensor or shape mismatch): %s",
            len(unrestored),
            ", ".join(unrestored),
        )
65+
2766

2867
def receive_integer(communicator: PyNcclCommunicator) -> int:
2968
"""Receive an integer from the trainer master rank using NCCL communicator."""
@@ -148,9 +187,20 @@ def update_weights_from_path(self, weight_dir: str) -> None:
148187
update_mla_absorbed_weights(model)
149188
return
150189

190+
# vLLM 0.22's layerwise reload mis-loads NemotronH mixer.D and MoE gate.e_score_correction_bias
191+
# (see _restore_reload_corrupted_params). Capture the correct received values to restore after.
192+
received_reload_fix: dict[str, torch.Tensor] = {}
193+
194+
def _capture_reload_fix(weights):
    """Yield every ``(name, tensor)`` pair unchanged while snapshotting the affected ones.

    For each name ending in one of ``_RELOAD_CORRUPTED_SUFFIXES``, a detached CPU copy of the
    tensor is stored in ``received_reload_fix`` before the pair is passed on.
    """
    for weight_name, weight in weights:
        is_affected = weight_name.endswith(_RELOAD_CORRUPTED_SUFFIXES)
        if is_affected:
            snapshot = weight.detach().to("cpu", copy=True)
            received_reload_fix[weight_name] = snapshot
        yield weight_name, weight
199+
151200
load_weights_checkpoint_layerwise(
152201
model,
153-
state_iter,
202+
_capture_reload_fix(state_iter),
154203
self.model_runner.model_config,
155204
self.vllm_config,
156205
)
206+
_restore_reload_corrupted_params(model, received_reload_fix)

0 commit comments

Comments
 (0)