Skip to content

Commit d318709

Browse files
committed
fix(rm): raise clear error when context parallelism is used with DTensor RM training
Context parallelism (context_parallel_size > 1) is not supported for reward model training on the DTensor backend because the log_sigmoid operator lacks a DTensor sharding strategy for CP meshes. Instead of letting users hit cryptic runtime errors, raise a clear ValueError during setup with a link to the tracking issue. Signed-off-by: Terry Kong <terryk@nvidia.com>
1 parent a760f1c commit d318709

2 files changed

Lines changed: 79 additions & 1 deletion

File tree

nemo_rl/algorithms/rm.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,21 @@ def setup(
116116

117117
# Extract individual configs for easier access
118118
policy_config = master_config.policy
119+
120+
# TODO(https://github.com/NVIDIA-NeMo/RL/issues/2482): remove once CP is supported for RM training.
121+
dtensor_cfg = policy_config.get("dtensor_cfg", {})
122+
if (
123+
dtensor_cfg.get("enabled", False)
124+
and dtensor_cfg.get("context_parallel_size", 1) > 1
125+
):
126+
raise ValueError(
127+
"Context parallelism (context_parallel_size > 1) is not supported for reward model "
128+
"training on the DTensor backend. The log_sigmoid operator used in the RM loss does "
129+
"not have a DTensor sharding strategy registered for CP meshes. "
130+
"Please set policy.dtensor_cfg.context_parallel_size=1. "
131+
"See https://github.com/NVIDIA-NeMo/RL/issues/2482 for tracking."
132+
)
133+
119134
data_config = master_config.data
120135
rm_config = master_config.rm
121136
logger_config = master_config.logger

tests/unit/algorithms/test_rm.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from torchdata.stateful_dataloader import StatefulDataLoader
2020

2121
from nemo_rl.algorithms.loss import PreferenceLossFn
22-
from nemo_rl.algorithms.rm import MasterConfig, _default_rm_save_state, rm_train
22+
from nemo_rl.algorithms.rm import MasterConfig, _default_rm_save_state, rm_train, setup
2323

2424

2525
@pytest.fixture
@@ -125,6 +125,69 @@ def val_iter(self):
125125
}
126126

127127

128+
def test_context_parallel_rejected_for_dtensor_rm():
    """DTensor RM setup must fail fast when context parallelism is requested.

    A config with context_parallel_size > 1 on an enabled DTensor backend
    should raise ValueError during setup rather than surfacing a cryptic
    runtime error later.

    TODO(https://github.com/NVIDIA-NeMo/RL/issues/2482): remove when CP is supported for RM.
    """
    # Minimal config: only the dtensor_cfg fields the CP check inspects matter;
    # the remaining sections are stubbed so model_construct succeeds.
    dtensor_cfg = {
        "enabled": True,
        "context_parallel_size": 2,
        "tensor_parallel_size": 1,
        "sequence_parallel": False,
        "activation_checkpointing": False,
        "cpu_offload": False,
    }
    raw_config = {
        "policy": {"dtensor_cfg": dtensor_cfg},
        "rm": {"seed": 42},
        "data": {},
        "logger": {},
        "cluster": {},
        "checkpointing": {},
    }
    config = MasterConfig.model_construct(**raw_config)

    # The error must come from the CP validation, matched by message.
    expected_msg = "Context parallelism.*is not supported for reward model training"
    with pytest.raises(ValueError, match=expected_msg):
        setup(config, MagicMock(), MagicMock(), {})
157+
158+
159+
def test_context_parallel_allowed_when_one():
    """DTensor RM setup must not reject context_parallel_size=1.

    The CP validation should pass; we confirm any failure raised by the
    stubbed setup call originates from a later stage, not from the CP check.

    TODO(https://github.com/NVIDIA-NeMo/RL/issues/2482): remove when CP is supported for RM.
    """
    # Same stubbed config as the rejection test, but with CP disabled (size 1).
    dtensor_cfg = {
        "enabled": True,
        "context_parallel_size": 1,
        "tensor_parallel_size": 1,
        "sequence_parallel": False,
        "activation_checkpointing": False,
        "cpu_offload": False,
    }
    raw_config = {
        "policy": {"dtensor_cfg": dtensor_cfg},
        "rm": {"seed": 42},
        "data": {},
        "logger": {},
        "cluster": {},
        "checkpointing": {},
    }
    config = MasterConfig.model_construct(**raw_config)

    # setup() will still fail on the incomplete mocks — but the failure must
    # not be the CP validation error.
    with pytest.raises(Exception) as excinfo:
        setup(config, MagicMock(), MagicMock(), {})
    assert "Context parallelism" not in str(excinfo.value)
189+
190+
128191
def test_exit_on_max_steps(mock_components):
129192
"""Test that training loop exits when max_num_steps is reached"""
130193
# Set max steps to 12, which is less than len(train_dataloader) * max_num_epochs

0 commit comments

Comments (0)