|
24 | 24 |
|
25 | 25 | logger = init_logger("vllm.inference.vllm.worker_nccl") |
26 | 26 |
|
# vLLM 0.22's layerwise online-reload path drops the weight load for NemotronH ``mixer.D``,
# leaving the parameter uninitialized. Parameter names ending in one of these suffixes are
# captured from the incoming weights and re-applied after the reload.
_RELOAD_CORRUPTED_SUFFIXES = (".mixer.D",)


def _restore_reload_corrupted_params(model: Module, received: dict[str, torch.Tensor]) -> None:
    """Work around a vLLM 0.22 layerwise-reload bug for NemotronH.

    The online reload drops the weight load for every Mamba layer's ``mixer.D`` (the SSD skip
    connection), leaving it as uninitialized ``empty_strided`` memory -- it reads back as garbage
    (NaN/inf) and the logits go NaN after a weight update. The received broadcast value is
    correct, so restore D from it via the param's own ``weight_loader``. Remove once the upstream
    bug is fixed.

    Args:
        model: Module whose parameters ending in ``_RELOAD_CORRUPTED_SUFFIXES`` are restored
            in place.
        received: Mapping of received weight name to tensor. Names are matched against the
            model's parameter names from the first ``"layers."`` onwards, so a differing
            prefix in front of ``layers.`` on either side does not prevent a match.

    Parameters for which no tensor was received are left untouched. A received tensor that
    cannot be applied (no ``weight_loader`` and a shape mismatch) is reported with a warning
    instead of being dropped silently, since the parameter is then left unrestored.
    """
    if not received:
        # Nothing was captured (e.g. the model has no affected parameters): skip the scan.
        return

    def _layer_key(name: str) -> str:
        # Compare names from the first "layers." onwards so that differing prefixes on the
        # sender and the model side still line up; names without "layers." are kept whole.
        index = name.find("layers.")
        return name[index:] if index >= 0 else name

    received_by_key = {_layer_key(name): tensor for name, tensor in received.items()}

    for name, param in model.named_parameters():
        if not name.endswith(_RELOAD_CORRUPTED_SUFFIXES):
            continue
        tensor = received_by_key.get(_layer_key(name))
        if tensor is None:
            continue

        tensor = tensor.to(device=param.device)
        weight_loader = getattr(param, "weight_loader", None)
        if weight_loader is not None:
            # Prefer the parameter's own loader when it has one; shape and dtype handling
            # is delegated to it.
            weight_loader(param, tensor)
        elif tensor.shape == param.shape:
            param.data.copy_(tensor.to(param.dtype))
        else:
            # No safe way to apply the value; surface it rather than leaving the parameter
            # unrestored without any trace.
            logger.warning(
                "Could not restore %s after layerwise reload: received shape %s does not "
                "match parameter shape %s and the parameter has no weight_loader.",
                name,
                tuple(tensor.shape),
                tuple(param.shape),
            )
| 57 | + |
27 | 58 |
|
28 | 59 | def receive_integer(communicator: PyNcclCommunicator) -> int: |
29 | 60 | """Receive an integer from the trainer master rank using NCCL communicator.""" |
@@ -148,9 +179,20 @@ def update_weights_from_path(self, weight_dir: str) -> None: |
148 | 179 | update_mla_absorbed_weights(model) |
149 | 180 | return |
150 | 181 |
|
| 182 | + # vLLM 0.22's layerwise reload drops NemotronH mixer.D's weight load (see |
| 183 | + # _restore_reload_corrupted_params). Capture the correct received value to restore after. |
| 184 | + received_reload_fix: dict[str, torch.Tensor] = {} |
| 185 | + |
| 186 | + def _capture_reload_fix(weights): |
| 187 | + for name, tensor in weights: |
| 188 | + if name.endswith(_RELOAD_CORRUPTED_SUFFIXES): |
| 189 | + received_reload_fix[name] = tensor.detach().to("cpu", copy=True) |
| 190 | + yield name, tensor |
| 191 | + |
151 | 192 | load_weights_checkpoint_layerwise( |
152 | 193 | model, |
153 | | - state_iter, |
| 194 | + _capture_reload_fix(state_iter), |
154 | 195 | self.model_runner.model_config, |
155 | 196 | self.vllm_config, |
156 | 197 | ) |
| 198 | + _restore_reload_corrupted_params(model, received_reload_fix) |
0 commit comments