Skip to content

Commit 786b454

Browse files
yuki-97 and NeMo Bot
authored and committed
fix: fix gpt oss export + bump mbridge (#2249)
Signed-off-by: Yuki Huang <yukih@nvidia.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
1 parent c8d167a commit 786b454

8 files changed

Lines changed: 52 additions & 40 deletions

File tree

Submodule Megatron-Bridge updated 161 files

3rdparty/Megatron-Bridge-workspace/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
"flash-linear-attention",
5757
"timm",
5858
"open-clip-torch>=3.2.0",
59-
"mlflow>=3.5.0",
59+
"mlflow>=3.9.0",
6060
"comet-ml>=3.50.0",
6161
"torch>=2.6.0",
6262
]
Submodule Megatron-LM updated 225 files

3rdparty/Megatron-LM-workspace/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
# TODO(https://github.com/NVIDIA-NeMo/RL/issues/2111): upgrade to core_cu13 when we move to CUDA 13 base container
5252
"transformer-engine[pytorch,core_cu12]",
5353
# VCS dependency - must match pyproject.toml [tool.uv.sources]
54-
"nvidia-resiliency-ext @ git+https://github.com/NVIDIA/nvidia-resiliency-ext.git@63154570cea17f8805a7fd15cc3b8cc2919ba575",
54+
"nvidia-resiliency-ext @ git+https://github.com/NVIDIA/nvidia-resiliency-ext.git@15a851565a4ce846c04431ecb0cf09903ab4837e",
5555
"tqdm",
5656
"einops~=0.8",
5757
"tensorstore~=0.1,!=0.1.46,!=0.1.72",

nemo_rl/models/generation/vllm/vllm_backend.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,20 @@
3737
)
3838

3939

40+
def fix_gpt_oss_export_transpose(key: str, weight: torch.Tensor) -> torch.Tensor:
    """Apply the GPT-OSS ``down_proj`` transpose fix to a weight tensor.

    Workaround for the ``down_proj`` expert-weight layout differing across
    frameworks:

    - HF uses ``[in, out]``.
    - Megatron uses ``[in, out]``.
    - vLLM expects ``[out, in]``.

    See https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/3271 for more details.

    Args:
        key: Fully qualified parameter name from the state dict.
        weight: The weight tensor associated with ``key``.

    Returns:
        The weight with its last two dims swapped (contiguous) when ``key``
        names a grouped-expert down projection; otherwise ``weight`` unchanged.
    """
    # Only the expert down-projection weights need the layout swap.
    if key.endswith("mlp.experts.down_proj"):
        return weight.transpose(-2, -1).contiguous()
    return weight
52+
53+
4054
class VllmInternalWorkerExtension:
4155
def init_collective(
4256
self,
@@ -199,20 +213,30 @@ def update_weights_via_ipc_zmq(self) -> bool:
199213
shape, dtype = self.state_dict_info[key] # pyrefly
200214
if isinstance(shape, list):
201215
shape = torch.Size(shape)
216+
217+
# Get the weight from the buffer
202218
size_in_bytes = dtype.itemsize * shape.numel()
203-
weights.append(
204-
(
205-
key,
206-
buffer[offset : offset + size_in_bytes]
207-
.view(dtype=dtype)
208-
.view(shape),
209-
)
219+
weight = (
220+
buffer[offset : offset + size_in_bytes]
221+
.view(dtype=dtype)
222+
.view(shape)
210223
)
224+
# apply gpt-oss transpose fix
225+
if (
226+
"GptOssForCausalLM"
227+
in self.model_runner.vllm_config.model_config.architectures
228+
):
229+
weight = fix_gpt_oss_export_transpose(key, weight)
230+
weights.append((key, weight))
231+
232+
# Move offset to the next weight
211233
aligned_size = calculate_aligned_size(size_in_bytes)
212234
offset += aligned_size
235+
213236
assert offset == used_bytes, (
214237
"Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info"
215238
)
239+
216240
# Load weights into the model
217241
from nemo_rl.models.generation.vllm.quantization import fp8
218242

@@ -276,6 +300,15 @@ def _load_model_weights(weights, model_runner):
276300
"""
277301
from nemo_rl.models.generation.vllm.quantization import fp8
278302

303+
# apply gpt-oss transpose fix
304+
if (
305+
"GptOssForCausalLM"
306+
in self.model_runner.vllm_config.model_config.architectures
307+
):
308+
for idx, (key, weight) in enumerate(weights):
309+
weight = fix_gpt_oss_export_transpose(key, weight)
310+
weights[idx] = (key, weight)
311+
279312
policy_weights, draft_weights = self._split_policy_and_draft_weights(
280313
weights
281314
)

nemo_rl/models/megatron/setup.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -142,29 +142,6 @@ def destroy_parallel_state():
142142
except ImportError:
143143
pass
144144

145-
# Reset the third global async_calls instance in base strategy module
146-
try:
147-
import megatron.core.dist_checkpointing.strategies.base as base_strategy
148-
from megatron.core.dist_checkpointing.strategies.async_utils import (
149-
AsyncCallsQueue,
150-
)
151-
152-
# Clean up and reset the global async_calls in base strategy
153-
old_call_idx = getattr(base_strategy.async_calls, "call_idx", None)
154-
num_unfinalized = base_strategy.async_calls.get_num_unfinalized_calls()
155-
if num_unfinalized > 0:
156-
print(
157-
f"[WARNING] Resetting base strategy async_calls with {num_unfinalized} unfinalized calls"
158-
)
159-
try:
160-
base_strategy.async_calls.close()
161-
except:
162-
pass
163-
base_strategy.async_calls = AsyncCallsQueue()
164-
print(f"[DEBUG] Reset base strategy async_calls (old call_idx: {old_call_idx})")
165-
except ImportError:
166-
pass
167-
168145

169146
def setup_distributed() -> None:
170147
"""Handle NCCL settings, dtype mapping, and basic config setup."""

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ requires-dist = [
497497
"flash-linear-attention",
498498
"timm",
499499
"open-clip-torch>=3.2.0",
500-
"mlflow>=3.5.0",
500+
"mlflow>=3.9.0",
501501
"comet-ml>=3.50.0",
502502
"torch>=2.6.0",
503503
]

uv.lock

Lines changed: 7 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)