|
37 | 37 | ) |
38 | 38 |
|
39 | 39 |
|
def fix_gpt_oss_export_transpose(key: str, weight: torch.Tensor) -> torch.Tensor:
    """Transpose GPT-OSS ``down_proj`` expert weights so vLLM can load them.

    Workaround for the cross-framework layout mismatch of the expert
    ``down_proj`` weight:
    - HF needs [in, out] layout.
    - Megatron needs [in, out] layout.
    - vLLM needs [out, in] layout.
    See https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/3271 for more details.

    Weights whose key is not an expert ``down_proj`` are returned untouched
    (same tensor object, no copy).
    """
    if not key.endswith("mlp.experts.down_proj"):
        return weight
    # Swap the last two dims and materialize contiguously for the consumer.
    return weight.transpose(-2, -1).contiguous()
| 52 | + |
| 53 | + |
40 | 54 | class VllmInternalWorkerExtension: |
41 | 55 | def init_collective( |
42 | 56 | self, |
@@ -199,20 +213,30 @@ def update_weights_via_ipc_zmq(self) -> bool: |
199 | 213 | shape, dtype = self.state_dict_info[key] # pyrefly |
200 | 214 | if isinstance(shape, list): |
201 | 215 | shape = torch.Size(shape) |
| 216 | + |
| 217 | + # Get the weight from the buffer |
202 | 218 | size_in_bytes = dtype.itemsize * shape.numel() |
203 | | - weights.append( |
204 | | - ( |
205 | | - key, |
206 | | - buffer[offset : offset + size_in_bytes] |
207 | | - .view(dtype=dtype) |
208 | | - .view(shape), |
209 | | - ) |
| 219 | + weight = ( |
| 220 | + buffer[offset : offset + size_in_bytes] |
| 221 | + .view(dtype=dtype) |
| 222 | + .view(shape) |
210 | 223 | ) |
| 224 | + # apply gpt-oss transpose fix |
| 225 | + if ( |
| 226 | + "GptOssForCausalLM" |
| 227 | + in self.model_runner.vllm_config.model_config.architectures |
| 228 | + ): |
| 229 | + weight = fix_gpt_oss_export_transpose(key, weight) |
| 230 | + weights.append((key, weight)) |
| 231 | + |
| 232 | + # Move offset to the next weight |
211 | 233 | aligned_size = calculate_aligned_size(size_in_bytes) |
212 | 234 | offset += aligned_size |
| 235 | + |
213 | 236 | assert offset == used_bytes, ( |
214 | 237 | "Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info" |
215 | 238 | ) |
| 239 | + |
216 | 240 | # Load weights into the model |
217 | 241 | from nemo_rl.models.generation.vllm.quantization import fp8 |
218 | 242 |
|
@@ -276,6 +300,15 @@ def _load_model_weights(weights, model_runner): |
276 | 300 | """ |
277 | 301 | from nemo_rl.models.generation.vllm.quantization import fp8 |
278 | 302 |
|
| 303 | + # apply gpt-oss transpose fix |
| 304 | + if ( |
| 305 | + "GptOssForCausalLM" |
| 306 | + in self.model_runner.vllm_config.model_config.architectures |
| 307 | + ): |
| 308 | + for idx, (key, weight) in enumerate(weights): |
| 309 | + weight = fix_gpt_oss_export_transpose(key, weight) |
| 310 | + weights[idx] = (key, weight) |
| 311 | + |
279 | 312 | policy_weights, draft_weights = self._split_policy_and_draft_weights( |
280 | 313 | weights |
281 | 314 | ) |
|
0 commit comments