diff --git a/packages/prime-rl-configs/src/prime_rl/configs/rl.py b/packages/prime-rl-configs/src/prime_rl/configs/rl.py index 1af3436d1b..013fe90cb0 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/rl.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/rl.py @@ -548,11 +548,9 @@ def auto_setup_deployment(self): self.inference.api_server_count = dp_per_node if self.weight_broadcast is not None and self.weight_broadcast.type == "nccl": - # Compute inference_world_size from actual worker count per server: - # each api_server runs tp workers that participate in collective_rpc. - api_server_count = self.inference.api_server_count if self.inference else 1 - tp = self.inference.parallel.tp if self.inference else 1 - total_infer_workers = self.deployment.total_infer_nodes * api_server_count * tp + # The multi-node RL launcher starts one TP-sharded backend for each + # DP slice on every allocated inference node. + total_infer_workers = self.deployment.total_infer_nodes * self.deployment.gpus_per_node assert self.trainer.weight_broadcast.type == "nccl" self.trainer.weight_broadcast.host = "0.0.0.0" self.trainer.weight_broadcast.inference_world_size = total_infer_workers @@ -600,22 +598,80 @@ def auto_setup_disaggregated_inference(self): def auto_setup_inference_client(self): """Auto-configure orchestrator student client from the inference server config. - For all modes, sets dp_rank_count from inference DP size. For SFT mode, - also sets base_url - rl/opd rely on the ClientConfig default - (``["http://localhost:8000/v1"]``) which already matches the auto-launched - student vLLM at inference.server.port = 8000. + For most modes, sets dp_rank_count from inference DP size. Dense + multi-node external-LB launches route through vllm-router, so requests + must not carry vLLM DP-rank headers. For SFT mode, also sets base_url - + rl/opd rely on the ClientConfig default (``["http://localhost:8000/v1"]``) + which already matches the auto-launched student vLLM at inference.server.port = 8000. """ if self.inference is None: return self client = self.orchestrator.student.client if "dp_rank_count" not in client.model_fields_set: - client.dp_rank_count = self.inference.data_parallel_size_local or self.inference.parallel.dp + if ( + self.deployment.type == "multi_node" + and self.inference.deployment.type != "disaggregated" + and self.inference.enable_expert_parallel + ): + dp_per_node = self.deployment.gpus_per_node // self.inference.parallel.tp + client.dp_rank_count = self.deployment.num_infer_nodes * dp_per_node + elif ( + self.deployment.type == "multi_node" + and self.inference.deployment.type != "disaggregated" + and not self.inference.enable_expert_parallel + ): + client.dp_rank_count = 1 + else: + client.dp_rank_count = self.inference.data_parallel_size_local or self.inference.parallel.dp if self.orchestrator.training_mode == "sft" and "base_url" not in client.model_fields_set: host = self.inference.server.host or "localhost" port = self.inference.server.port client.base_url = [f"http://{host}:{port}/v1"] return self + @model_validator(mode="after") + def validate_multi_node_inference_resolution(self): + if self.deployment.type != "multi_node" or self.inference is None: + return self + + if self.trainer.weight_broadcast.type == "nccl" or self.orchestrator.weight_broadcast.type == "nccl": + expected_world_size = self.deployment.total_infer_nodes * self.deployment.gpus_per_node + + if ( + self.trainer.weight_broadcast.type == "nccl" + and self.trainer.weight_broadcast.inference_world_size != expected_world_size + ): + raise ValueError( + "trainer.weight_broadcast.inference_world_size must match allocated inference GPUs " + f"({expected_world_size}) for multi-node NCCL weight broadcast." + ) + + if ( + self.orchestrator.weight_broadcast.type == "nccl" + and self.orchestrator.weight_broadcast.inference_world_size != expected_world_size + ): + raise ValueError( + "orchestrator.weight_broadcast.inference_world_size must match allocated inference GPUs " + f"({expected_world_size}) for multi-node NCCL weight broadcast." + ) + + if self.inference.deployment.type != "disaggregated": + if self.inference.enable_expert_parallel: + dp_per_node = self.deployment.gpus_per_node // self.inference.parallel.tp + expected_dp_rank_count = self.deployment.num_infer_nodes * dp_per_node + else: + expected_dp_rank_count = 1 + + client = self.orchestrator.student.client + if client.dp_rank_count != expected_dp_rank_count: + raise ValueError( + "orchestrator.student.client.dp_rank_count must resolve to " + f"{expected_dp_rank_count} for multi-node non-disaggregated inference; " + f"got {client.dp_rank_count}." + ) + + return self + @model_validator(mode="after") def auto_setup_slurm_template(self): """Auto-setup the default single-node/multi-node SLURM template if no custom template is provided.""" diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index fcdee7a843..b2afeb0a6f 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -537,3 +537,59 @@ def test_explicit_inference_parser_wins_over_auto(): ) assert config.inference is not None assert config.inference.model.tool_call_parser == "hermes" + + +def test_multi_node_dense_nccl_world_size_matches_inference_gpu_count_and_router_client(): + config = RLConfig.model_validate( + { + "model": {"name": "Qwen/Qwen3-0.6B"}, + "slurm": {}, + "deployment": {"type": "multi_node", "num_train_nodes": 2, "num_infer_nodes": 2}, + "weight_broadcast": {"type": "nccl"}, + "trainer": {}, + "orchestrator": {"renderer": None}, + "inference": {"parallel": {"tp": 4, "dp": 4}}, + } + ) + + assert config.inference is not None + assert config.inference.api_server_count == 4 + assert config.inference.data_parallel_size_local == 2 + assert config.trainer.weight_broadcast.inference_world_size == 16 + assert config.orchestrator.weight_broadcast.inference_world_size == 16 + assert config.orchestrator.student.client.dp_rank_count == 1 + + +def test_multi_node_dense_rejects_invalid_router_dp_rank_count(): + with pytest.raises(ValidationError, match=r"dp_rank_count must resolve to 1"): + RLConfig.model_validate( + { + "model": {"name": "Qwen/Qwen3-0.6B"}, + "slurm": {}, + "deployment": {"type": "multi_node", "num_train_nodes": 2, "num_infer_nodes": 2}, + "weight_broadcast": {"type": "nccl"}, + "trainer": {}, + "orchestrator": { + "renderer": None, + "student": {"client": {"dp_rank_count": 4}}, + }, + "inference": {"parallel": {"tp": 4, "dp": 4}}, + } + ) + + +def test_multi_node_nccl_rejects_invalid_inference_world_size_override(): + with pytest.raises(ValidationError, match=r"inference_world_size must match allocated inference GPUs"): + RLConfig.model_validate( + { + "model": {"name": "Qwen/Qwen3-0.6B"}, + "slurm": {}, + "deployment": {"type": "multi_node", "num_train_nodes": 2, "num_infer_nodes": 2}, + "trainer": {"weight_broadcast": {"type": "nccl", "inference_world_size": 32}}, + "orchestrator": { + "renderer": None, + "weight_broadcast": {"type": "nccl", "inference_world_size": 32}, + }, + "inference": {"parallel": {"tp": 4, "dp": 4}}, + } + )