Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/design-docs/nemo-gym-integration.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# NeMo Gym Integration

This document describes how NeMo RL integrates with [NeMo Gym](https://docs.nvidia.com/nemo/gym/latest/index.html) for multi-step and multi-turn reinforcement learning training.
This document describes how NeMo RL integrates with [NeMo Gym](https://docs.nvidia.com/nemo/gym/v0.2.1/index.html) for multi-step and multi-turn reinforcement learning training.

## Overview

Expand Down Expand Up @@ -181,7 +181,7 @@ sequenceDiagram
GRPO->>Policy: Compute loss and train
```

> **NeMo Gym server types** (see [Core Components](https://docs.nvidia.com/nemo/gym/latest/about/concepts/core-components.html)):
> **NeMo Gym server types** (see [Core Components](https://docs.nvidia.com/nemo/gym/v0.2.1/about/concepts/core-components/)):
> - **Agent Server**: Orchestrates the rollout loop
> - **Model Server**: HTTP proxy to vLLM; translates Responses API ↔ Chat Completions
> - **Resource Server**: Provides tools and rewards
Expand Down Expand Up @@ -254,4 +254,4 @@ Token IDs are extracted at the NeMo RL vLLM layer via the `/tokenize` endpoint.
- Tokenization matches the exact model and tokenizer used for generation
- No re-tokenization drift between generation and training

For details on on-policy token ID handling, see {doc}`../guides/environments` and the [NeMo Gym on-policy corrections documentation](https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html).
For details on on-policy token ID handling, see {doc}`../guides/environments` and the [NeMo Gym on-policy corrections documentation](https://docs.nvidia.com/nemo/gym/v0.2.1/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html).
15 changes: 15 additions & 0 deletions nemo_rl/algorithms/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,21 @@ def setup(

# Extract individual configs for easier access
policy_config = master_config.policy

# TODO(https://github.com/NVIDIA-NeMo/RL/issues/2482): remove once CP is supported for RM training.
dtensor_cfg = policy_config.get("dtensor_cfg", {})
if (
dtensor_cfg.get("enabled", False)
and dtensor_cfg.get("context_parallel_size", 1) > 1
):
raise ValueError(
"Context parallelism (context_parallel_size > 1) is not supported for reward model "
"training on the DTensor backend. The log_sigmoid operator used in the RM loss does "
"not have a DTensor sharding strategy registered for CP meshes. "
"Please set policy.dtensor_cfg.context_parallel_size=1. "
"See https://github.com/NVIDIA-NeMo/RL/issues/2482 for tracking."
)

data_config = master_config.data
rm_config = master_config.rm
logger_config = master_config.logger
Expand Down
65 changes: 64 additions & 1 deletion tests/unit/algorithms/test_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torchdata.stateful_dataloader import StatefulDataLoader

from nemo_rl.algorithms.loss import PreferenceLossFn
from nemo_rl.algorithms.rm import MasterConfig, _default_rm_save_state, rm_train
from nemo_rl.algorithms.rm import MasterConfig, _default_rm_save_state, rm_train, setup


@pytest.fixture
Expand Down Expand Up @@ -125,6 +125,69 @@ def val_iter(self):
}


def test_context_parallel_rejected_for_dtensor_rm():
    """Verify that DTensor RM training refuses context_parallel_size > 1.

    TODO(https://github.com/NVIDIA-NeMo/RL/issues/2482): remove when CP is supported for RM.
    """
    # DTensor backend enabled with CP=2, which the setup() guard must reject.
    dtensor_settings = {
        "enabled": True,
        "context_parallel_size": 2,
        "tensor_parallel_size": 1,
        "sequence_parallel": False,
        "activation_checkpointing": False,
        "cpu_offload": False,
    }
    # model_construct skips validation so the otherwise-empty config sections
    # are accepted as-is.
    cfg = MasterConfig.model_construct(
        policy={"dtensor_cfg": dtensor_settings},
        rm={"seed": 42},
        data={},
        logger={},
        cluster={},
        checkpointing={},
    )
    expected = "Context parallelism.*is not supported for reward model training"
    with pytest.raises(ValueError, match=expected):
        setup(cfg, MagicMock(), MagicMock(), {})


def test_context_parallel_allowed_when_one():
    """Verify that context_parallel_size=1 passes the DTensor RM CP guard.

    The CP check is confirmed to pass by asserting that whatever error
    setup() eventually raises does not originate from the CP validation.

    TODO(https://github.com/NVIDIA-NeMo/RL/issues/2482): remove when CP is supported for RM.
    """
    # Same shape as the rejection test, but with CP=1 so the guard is a no-op.
    dtensor_settings = {
        "enabled": True,
        "context_parallel_size": 1,
        "tensor_parallel_size": 1,
        "sequence_parallel": False,
        "activation_checkpointing": False,
        "cpu_offload": False,
    }
    cfg = MasterConfig.model_construct(
        policy={"dtensor_cfg": dtensor_settings},
        rm={"seed": 42},
        data={},
        logger={},
        cluster={},
        checkpointing={},
    )
    # The mostly-empty config makes a later setup stage fail; that failure
    # must not be the CP validation error.
    with pytest.raises(Exception) as err:
        setup(cfg, MagicMock(), MagicMock(), {})
    assert "Context parallelism" not in str(err.value)


def test_exit_on_max_steps(mock_components):
"""Test that training loop exits when max_num_steps is reached"""
# Set max steps to 12, which is less than len(train_dataloader) * max_num_epochs
Expand Down
Loading