From 1ea47c11ed32c7674b476779b1f670a03041d043 Mon Sep 17 00:00:00 2001 From: sami jaghouar Date: Thu, 4 Jun 2026 06:00:09 +0530 Subject: [PATCH 1/2] fix multi-node inference broadcast sizing --- .../src/prime_rl/configs/rl.py | 33 +++++++++++++------ skills/configs/SKILL.md | 6 ++++ tests/unit/test_configs.py | 21 ++++++++++++ 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/rl.py b/packages/prime-rl-configs/src/prime_rl/configs/rl.py index 1af3436d1b..d308524c78 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/rl.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/rl.py @@ -548,11 +548,9 @@ def auto_setup_deployment(self): self.inference.api_server_count = dp_per_node if self.weight_broadcast is not None and self.weight_broadcast.type == "nccl": - # Compute inference_world_size from actual worker count per server: - # each api_server runs tp workers that participate in collective_rpc. - api_server_count = self.inference.api_server_count if self.inference else 1 - tp = self.inference.parallel.tp if self.inference else 1 - total_infer_workers = self.deployment.total_infer_nodes * api_server_count * tp + # The multi-node RL launcher starts one TP-sharded backend for each + # DP slice on every allocated inference node. + total_infer_workers = self.deployment.total_infer_nodes * self.deployment.gpus_per_node assert self.trainer.weight_broadcast.type == "nccl" self.trainer.weight_broadcast.host = "0.0.0.0" self.trainer.weight_broadcast.inference_world_size = total_infer_workers @@ -600,16 +598,31 @@ def auto_setup_disaggregated_inference(self): def auto_setup_inference_client(self): """Auto-configure orchestrator student client from the inference server config. - For all modes, sets dp_rank_count from inference DP size. For SFT mode, - also sets base_url - rl/opd rely on the ClientConfig default - (``["http://localhost:8000/v1"]``) which already matches the auto-launched - student vLLM at inference.server.port = 8000. + For most modes, sets dp_rank_count from inference DP size. Dense + multi-node external-LB launches route through vllm-router, so requests + must not carry vLLM DP-rank headers. For SFT mode, also sets base_url - + rl/opd rely on the ClientConfig default (``["http://localhost:8000/v1"]``) + which already matches the auto-launched student vLLM at inference.server.port = 8000. """ if self.inference is None: return self client = self.orchestrator.student.client if "dp_rank_count" not in client.model_fields_set: - client.dp_rank_count = self.inference.data_parallel_size_local or self.inference.parallel.dp + if ( + self.deployment.type == "multi_node" + and self.inference.deployment.type != "disaggregated" + and self.inference.enable_expert_parallel + ): + dp_per_node = self.deployment.gpus_per_node // self.inference.parallel.tp + client.dp_rank_count = self.deployment.num_infer_nodes * dp_per_node + elif ( + self.deployment.type == "multi_node" + and self.inference.deployment.type != "disaggregated" + and not self.inference.enable_expert_parallel + ): + client.dp_rank_count = 1 + else: + client.dp_rank_count = self.inference.data_parallel_size_local or self.inference.parallel.dp if self.orchestrator.training_mode == "sft" and "base_url" not in client.model_fields_set: host = self.inference.server.host or "localhost" port = self.inference.server.port diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index 83f7dd8d47..2b14b01ae2 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -28,6 +28,12 @@ uv run rl --help # all fields and defaults uv run rl @ rl.toml --dry-run --output-dir /tmp/x # write resolved TOML to /tmp/x/configs ``` +For multi-node RL with NCCL weight broadcast, inspect the resolved trainer and +orchestrator TOML after `--dry-run`. `inference_world_size` should match allocated +inference GPUs. For dense external-LB router launches, the orchestrator student's +`dp_rank_count` should stay `1`; admin URLs cover every backend for weight updates, +while the router handles request load balancing/stickiness. + ## Validators Incompatible combinations (e.g. CP requires flash attention) must raise in a `model_validator` at resolve time, not at runtime. When renaming a field, emit a deprecation warning with a migration hint — never silently drop. diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index fcdee7a843..b5dd311b67 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -537,3 +537,24 @@ def test_explicit_inference_parser_wins_over_auto(): ) assert config.inference is not None assert config.inference.model.tool_call_parser == "hermes" + + +def test_multi_node_dense_nccl_world_size_matches_inference_gpu_count_and_router_client(): + config = RLConfig.model_validate( + { + "model": {"name": "Qwen/Qwen3-0.6B"}, + "slurm": {}, + "deployment": {"type": "multi_node", "num_train_nodes": 2, "num_infer_nodes": 2}, + "weight_broadcast": {"type": "nccl"}, + "trainer": {}, + "orchestrator": {"renderer": None}, + "inference": {"parallel": {"tp": 4, "dp": 4}}, + } + ) + + assert config.inference is not None + assert config.inference.api_server_count == 4 + assert config.inference.data_parallel_size_local == 2 + assert config.trainer.weight_broadcast.inference_world_size == 16 + assert config.orchestrator.weight_broadcast.inference_world_size == 16 + assert config.orchestrator.student.client.dp_rank_count == 1 From cf06820774cdb37a78622d82e3a4df4ffab5ef41 Mon Sep 17 00:00:00 2001 From: sami jaghouar Date: Thu, 4 Jun 2026 07:09:12 +0530 Subject: [PATCH 2/2] validate multi-node inference config invariants --- .../src/prime_rl/configs/rl.py | 43 +++++++++++++++++++ skills/configs/SKILL.md | 6 --- tests/unit/test_configs.py | 35 +++++++++++++++ 3 files changed, 78 insertions(+), 6 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/rl.py b/packages/prime-rl-configs/src/prime_rl/configs/rl.py index d308524c78..013fe90cb0 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/rl.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/rl.py @@ -629,6 +629,49 @@ def auto_setup_inference_client(self): client.base_url = [f"http://{host}:{port}/v1"] return self + @model_validator(mode="after") + def validate_multi_node_inference_resolution(self): + if self.deployment.type != "multi_node" or self.inference is None: + return self + + if self.trainer.weight_broadcast.type == "nccl" or self.orchestrator.weight_broadcast.type == "nccl": + expected_world_size = self.deployment.total_infer_nodes * self.deployment.gpus_per_node + + if ( + self.trainer.weight_broadcast.type == "nccl" + and self.trainer.weight_broadcast.inference_world_size != expected_world_size + ): + raise ValueError( + "trainer.weight_broadcast.inference_world_size must match allocated inference GPUs " + f"({expected_world_size}) for multi-node NCCL weight broadcast." + ) + + if ( + self.orchestrator.weight_broadcast.type == "nccl" + and self.orchestrator.weight_broadcast.inference_world_size != expected_world_size + ): + raise ValueError( + "orchestrator.weight_broadcast.inference_world_size must match allocated inference GPUs " + f"({expected_world_size}) for multi-node NCCL weight broadcast." + ) + + if self.inference.deployment.type != "disaggregated": + if self.inference.enable_expert_parallel: + dp_per_node = self.deployment.gpus_per_node // self.inference.parallel.tp + expected_dp_rank_count = self.deployment.num_infer_nodes * dp_per_node + else: + expected_dp_rank_count = 1 + + client = self.orchestrator.student.client + if client.dp_rank_count != expected_dp_rank_count: + raise ValueError( + "orchestrator.student.client.dp_rank_count must resolve to " + f"{expected_dp_rank_count} for multi-node non-disaggregated inference; " + f"got {client.dp_rank_count}." + ) + + return self + @model_validator(mode="after") def auto_setup_slurm_template(self): """Auto-setup the default single-node/multi-node SLURM template if no custom template is provided.""" diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index 2b14b01ae2..83f7dd8d47 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -28,12 +28,6 @@ uv run rl --help # all fields and defaults uv run rl @ rl.toml --dry-run --output-dir /tmp/x # write resolved TOML to /tmp/x/configs ``` -For multi-node RL with NCCL weight broadcast, inspect the resolved trainer and -orchestrator TOML after `--dry-run`. `inference_world_size` should match allocated -inference GPUs. For dense external-LB router launches, the orchestrator student's -`dp_rank_count` should stay `1`; admin URLs cover every backend for weight updates, -while the router handles request load balancing/stickiness. - ## Validators Incompatible combinations (e.g. CP requires flash attention) must raise in a `model_validator` at resolve time, not at runtime. When renaming a field, emit a deprecation warning with a migration hint — never silently drop. diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index b5dd311b67..b2afeb0a6f 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -558,3 +558,38 @@ def test_multi_node_dense_nccl_world_size_matches_inference_gpu_count_and_router assert config.trainer.weight_broadcast.inference_world_size == 16 assert config.orchestrator.weight_broadcast.inference_world_size == 16 assert config.orchestrator.student.client.dp_rank_count == 1 + + +def test_multi_node_dense_rejects_invalid_router_dp_rank_count(): + with pytest.raises(ValidationError, match=r"dp_rank_count must resolve to 1"): + RLConfig.model_validate( + { + "model": {"name": "Qwen/Qwen3-0.6B"}, + "slurm": {}, + "deployment": {"type": "multi_node", "num_train_nodes": 2, "num_infer_nodes": 2}, + "weight_broadcast": {"type": "nccl"}, + "trainer": {}, + "orchestrator": { + "renderer": None, + "student": {"client": {"dp_rank_count": 4}}, + }, + "inference": {"parallel": {"tp": 4, "dp": 4}}, + } + ) + + +def test_multi_node_nccl_rejects_invalid_inference_world_size_override(): + with pytest.raises(ValidationError, match=r"inference_world_size must match allocated inference GPUs"): + RLConfig.model_validate( + { + "model": {"name": "Qwen/Qwen3-0.6B"}, + "slurm": {}, + "deployment": {"type": "multi_node", "num_train_nodes": 2, "num_infer_nodes": 2}, + "trainer": {"weight_broadcast": {"type": "nccl", "inference_world_size": 32}}, + "orchestrator": { + "renderer": None, + "weight_broadcast": {"type": "nccl", "inference_world_size": 32}, + }, + "inference": {"parallel": {"tp": 4, "dp": 4}}, + } + )