Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 66 additions & 10 deletions packages/prime-rl-configs/src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,11 +548,9 @@ def auto_setup_deployment(self):
self.inference.api_server_count = dp_per_node

if self.weight_broadcast is not None and self.weight_broadcast.type == "nccl":
# Compute inference_world_size from actual worker count per server:
# each api_server runs tp workers that participate in collective_rpc.
api_server_count = self.inference.api_server_count if self.inference else 1
tp = self.inference.parallel.tp if self.inference else 1
total_infer_workers = self.deployment.total_infer_nodes * api_server_count * tp
# The multi-node RL launcher starts one TP-sharded backend for each
# DP slice on every allocated inference node.
total_infer_workers = self.deployment.total_infer_nodes * self.deployment.gpus_per_node
assert self.trainer.weight_broadcast.type == "nccl"
self.trainer.weight_broadcast.host = "0.0.0.0"
self.trainer.weight_broadcast.inference_world_size = total_infer_workers
Expand Down Expand Up @@ -600,22 +598,80 @@ def auto_setup_disaggregated_inference(self):
def auto_setup_inference_client(self):
"""Auto-configure orchestrator student client from the inference server config.

For all modes, sets dp_rank_count from inference DP size. For SFT mode,
also sets base_url - rl/opd rely on the ClientConfig default
(``["http://localhost:8000/v1"]``) which already matches the auto-launched
student vLLM at inference.server.port = 8000.
For most modes, sets dp_rank_count from inference DP size. Dense
multi-node external-LB launches route through vllm-router, so requests
must not carry vLLM DP-rank headers. For SFT mode, also sets base_url -
rl/opd rely on the ClientConfig default (``["http://localhost:8000/v1"]``)
which already matches the auto-launched student vLLM at inference.server.port = 8000.
"""
if self.inference is None:
return self
client = self.orchestrator.student.client
if "dp_rank_count" not in client.model_fields_set:
client.dp_rank_count = self.inference.data_parallel_size_local or self.inference.parallel.dp
if (
self.deployment.type == "multi_node"
and self.inference.deployment.type != "disaggregated"
and self.inference.enable_expert_parallel
):
dp_per_node = self.deployment.gpus_per_node // self.inference.parallel.tp
client.dp_rank_count = self.deployment.num_infer_nodes * dp_per_node
elif (
self.deployment.type == "multi_node"
and self.inference.deployment.type != "disaggregated"
and not self.inference.enable_expert_parallel
):
client.dp_rank_count = 1
else:
client.dp_rank_count = self.inference.data_parallel_size_local or self.inference.parallel.dp
if self.orchestrator.training_mode == "sft" and "base_url" not in client.model_fields_set:
host = self.inference.server.host or "localhost"
port = self.inference.server.port
client.base_url = [f"http://{host}:{port}/v1"]
return self

@model_validator(mode="after")
def validate_multi_node_inference_resolution(self):
if self.deployment.type != "multi_node" or self.inference is None:
return self

if self.trainer.weight_broadcast.type == "nccl" or self.orchestrator.weight_broadcast.type == "nccl":
expected_world_size = self.deployment.total_infer_nodes * self.deployment.gpus_per_node

if (
self.trainer.weight_broadcast.type == "nccl"
and self.trainer.weight_broadcast.inference_world_size != expected_world_size
):
raise ValueError(
"trainer.weight_broadcast.inference_world_size must match allocated inference GPUs "
f"({expected_world_size}) for multi-node NCCL weight broadcast."
)

if (
self.orchestrator.weight_broadcast.type == "nccl"
and self.orchestrator.weight_broadcast.inference_world_size != expected_world_size
):
raise ValueError(
"orchestrator.weight_broadcast.inference_world_size must match allocated inference GPUs "
f"({expected_world_size}) for multi-node NCCL weight broadcast."
)

if self.inference.deployment.type != "disaggregated":
if self.inference.enable_expert_parallel:
dp_per_node = self.deployment.gpus_per_node // self.inference.parallel.tp
expected_dp_rank_count = self.deployment.num_infer_nodes * dp_per_node
else:
expected_dp_rank_count = 1

client = self.orchestrator.student.client
if client.dp_rank_count != expected_dp_rank_count:
raise ValueError(
"orchestrator.student.client.dp_rank_count must resolve to "
f"{expected_dp_rank_count} for multi-node non-disaggregated inference; "
f"got {client.dp_rank_count}."
)

return self

@model_validator(mode="after")
def auto_setup_slurm_template(self):
"""Auto-setup the default single-node/multi-node SLURM template if no custom template is provided."""
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,3 +537,59 @@ def test_explicit_inference_parser_wins_over_auto():
)
assert config.inference is not None
assert config.inference.model.tool_call_parser == "hermes"


def test_multi_node_dense_nccl_world_size_matches_inference_gpu_count_and_router_client():
config = RLConfig.model_validate(
{
"model": {"name": "Qwen/Qwen3-0.6B"},
"slurm": {},
"deployment": {"type": "multi_node", "num_train_nodes": 2, "num_infer_nodes": 2},
"weight_broadcast": {"type": "nccl"},
"trainer": {},
"orchestrator": {"renderer": None},
"inference": {"parallel": {"tp": 4, "dp": 4}},
}
)

assert config.inference is not None
assert config.inference.api_server_count == 4
assert config.inference.data_parallel_size_local == 2
assert config.trainer.weight_broadcast.inference_world_size == 16
assert config.orchestrator.weight_broadcast.inference_world_size == 16
assert config.orchestrator.student.client.dp_rank_count == 1


def test_multi_node_dense_rejects_invalid_router_dp_rank_count():
with pytest.raises(ValidationError, match=r"dp_rank_count must resolve to 1"):
RLConfig.model_validate(
{
"model": {"name": "Qwen/Qwen3-0.6B"},
"slurm": {},
"deployment": {"type": "multi_node", "num_train_nodes": 2, "num_infer_nodes": 2},
"weight_broadcast": {"type": "nccl"},
"trainer": {},
"orchestrator": {
"renderer": None,
"student": {"client": {"dp_rank_count": 4}},
},
"inference": {"parallel": {"tp": 4, "dp": 4}},
}
)


def test_multi_node_nccl_rejects_invalid_inference_world_size_override():
with pytest.raises(ValidationError, match=r"inference_world_size must match allocated inference GPUs"):
RLConfig.model_validate(
{
"model": {"name": "Qwen/Qwen3-0.6B"},
"slurm": {},
"deployment": {"type": "multi_node", "num_train_nodes": 2, "num_infer_nodes": 2},
"trainer": {"weight_broadcast": {"type": "nccl", "inference_world_size": 32}},
"orchestrator": {
"renderer": None,
"weight_broadcast": {"type": "nccl", "inference_world_size": 32},
},
"inference": {"parallel": {"tp": 4, "dp": 4}},
}
)
Loading