Skip to content

Commit 7f6ca61

Browse files
authored
fix(inference): restore NemotronH mixer.D after vLLM 0.22 layerwise reload (#2714)
1 parent 434c9c0 commit 7f6ca61

2 files changed

Lines changed: 43 additions & 34 deletions

File tree

src/prime_rl/inference/patches.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def transformers_v5_compat():
1818
_patch_qwen35_lora()
1919
_patch_lora_key_prefix()
2020
monkey_patch_deep_gemm_silu_mul_quant_int64()
21-
monkey_patch_vllm_layerwise_reload_alias_buffers()
2221
monkey_patch_vllm_padded_input_scrub()
2322
monkey_patch_return_routed_experts_with_nixl_connector()
2423

@@ -67,38 +66,6 @@ def _post_init(config: VllmConfig):
6766
logger.warning("Enabled vLLM routed-experts capture with NIXL connector patch.")
6867

6968

70-
def monkey_patch_vllm_layerwise_reload_alias_buffers():
71-
# vLLM's layerwise reload materializes each buffer as an independent tensor
72-
# and then copies it back into the original kernel storage. When a buffer
73-
# aliases a parameter (e.g. NemotronH Mamba's mixer.conv_weights, a view of
74-
# mixer.conv1d.weight), the buffer copy stamps garbage into the parameter's
75-
# storage *after* the parameter has been correctly reloaded. Skip the copy
76-
# for any buffer that shares storage with a parameter; _place_kernel_tensors
77-
# re-registers the original view, which trivially reflects the parameter.
78-
# Remove this patch once https://github.com/vllm-project/vllm/pull/42481 is
79-
# included in the vLLM release we pin/use.
80-
from vllm.logger import init_logger
81-
from vllm.model_executor.model_loader.reload import layerwise as reload_layerwise
82-
83-
logger = init_logger(__name__)
84-
85-
def _copy_and_restore_kernel_tensors(layer: torch.nn.Module, info: reload_layerwise.LayerReloadingInfo):
86-
assert info.kernel_tensors is not None
87-
parameters, buffers = info.kernel_tensors
88-
param_storage_ptrs = {p.untyped_storage().data_ptr() for p in layer.parameters(recurse=True)}
89-
for name, param in parameters.items():
90-
param.data.copy_(getattr(layer, name))
91-
for name, buffer in buffers.items():
92-
if buffer.untyped_storage().data_ptr() in param_storage_ptrs:
93-
continue
94-
buffer.data.copy_(getattr(layer, name))
95-
96-
reload_layerwise._place_kernel_tensors(layer, info)
97-
98-
reload_layerwise._copy_and_restore_kernel_tensors = _copy_and_restore_kernel_tensors
99-
logger.warning("Enabled vLLM layerwise reload alias-buffer patch.")
100-
101-
10269
@triton.jit
10370
def _silu_mul_per_token_group_quant_fp8_colmajor_int64_kernel(
10471
y_ptr,

src/prime_rl/inference/vllm/worker/nccl.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,37 @@
2424

2525
logger = init_logger("vllm.inference.vllm.worker_nccl")
2626

27+
# NemotronH mixer.D is dropped by vLLM 0.22's layerwise online-reload path (left uninitialized).
28+
_RELOAD_CORRUPTED_SUFFIXES = (".mixer.D",)
29+
30+
31+
def _restore_reload_corrupted_params(model: Module, received: dict[str, torch.Tensor]) -> None:
32+
"""Work around a vLLM 0.22 layerwise-reload bug for NemotronH.
33+
34+
The online reload drops the weight load for every Mamba layer's ``mixer.D`` (the SSD skip
35+
connection), leaving it as uninitialized ``empty_strided`` memory -- it reads back as garbage
36+
(NaN/inf) and the logits go NaN after a weight update. The received broadcast value is correct,
37+
so restore D from it via the param's own ``weight_loader``. Remove once the upstream bug is fixed.
38+
"""
39+
40+
def _layer_key(name: str) -> str:
41+
index = name.find("layers.")
42+
return name[index:] if index >= 0 else name
43+
44+
received_by_key = {_layer_key(name): tensor for name, tensor in received.items()}
45+
for name, param in model.named_parameters():
46+
if not name.endswith(_RELOAD_CORRUPTED_SUFFIXES):
47+
continue
48+
tensor = received_by_key.get(_layer_key(name))
49+
if tensor is None:
50+
continue
51+
tensor = tensor.to(device=param.device)
52+
weight_loader = getattr(param, "weight_loader", None)
53+
if weight_loader is not None:
54+
weight_loader(param, tensor)
55+
elif tensor.shape == param.shape:
56+
param.data.copy_(tensor.to(param.dtype))
57+
2758

2859
def receive_integer(communicator: PyNcclCommunicator) -> int:
2960
"""Receive an integer from the trainer master rank using NCCL communicator."""
@@ -148,9 +179,20 @@ def update_weights_from_path(self, weight_dir: str) -> None:
148179
update_mla_absorbed_weights(model)
149180
return
150181

182+
# vLLM 0.22's layerwise reload drops NemotronH mixer.D's weight load (see
183+
# _restore_reload_corrupted_params). Capture the correct received value to restore after.
184+
received_reload_fix: dict[str, torch.Tensor] = {}
185+
186+
def _capture_reload_fix(weights):
187+
for name, tensor in weights:
188+
if name.endswith(_RELOAD_CORRUPTED_SUFFIXES):
189+
received_reload_fix[name] = tensor.detach().to("cpu", copy=True)
190+
yield name, tensor
191+
151192
load_weights_checkpoint_layerwise(
152193
model,
153-
state_iter,
194+
_capture_reload_fix(state_iter),
154195
self.model_runner.model_config,
155196
self.vllm_config,
156197
)
198+
_restore_reload_corrupted_params(model, received_reload_fix)

0 commit comments

Comments
 (0)