Skip to content

Commit 65914f9

Browse files
S1ro1claude
andauthored
fix(configs): size dense multi-node NCCL world by inference GPU count (#2707)
The dense multi-node external-LB weight-broadcast world size was computed as total_infer_nodes * api_server_count * tp. api_server_count can resolve to the global DP size (e.g. when parallel.dp is set, or via validator ordering), which double-counts the node dimension, so the trainer's NCCL broadcast waits for more ranks than exist and init deadlocks. Every allocated inference GPU is one NCCL rank, and the external-LB launcher starts dp_per_node TP-sharded servers per node (gpus_per_node workers/node), so size the world directly as total_infer_nodes * gpus_per_node. This matches the disaggregated path. Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent b67fd12 commit 65914f9

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

  • packages/prime-rl-configs/src/prime_rl/configs

packages/prime-rl-configs/src/prime_rl/configs/rl.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,11 +548,14 @@ def auto_setup_deployment(self):
548548
self.inference.api_server_count = dp_per_node
549549

550550
if self.weight_broadcast is not None and self.weight_broadcast.type == "nccl":
551-
# Compute inference_world_size from actual worker count per server:
552-
# each api_server runs tp workers that participate in collective_rpc.
553-
api_server_count = self.inference.api_server_count if self.inference else 1
554-
tp = self.inference.parallel.tp if self.inference else 1
555-
total_infer_workers = self.deployment.total_infer_nodes * api_server_count * tp
551+
# Every allocated inference GPU is a NCCL rank in the weight broadcast.
552+
# The external-LB launcher starts dp_per_node (= gpus_per_node / tp)
553+
# TP-sharded servers per node, i.e. gpus_per_node workers per node, so use
554+
# the GPU count directly. Deriving it from api_server_count double-counts:
555+
# api_server_count can resolve to the *global* DP size, making the node
556+
# factor count twice and NCCL wait for ranks that never connect. Matches
557+
# the disaggregated path below.
558+
total_infer_workers = self.deployment.total_infer_nodes * self.deployment.gpus_per_node
556559
assert self.trainer.weight_broadcast.type == "nccl"
557560
self.trainer.weight_broadcast.host = "0.0.0.0"
558561
self.trainer.weight_broadcast.inference_world_size = total_infer_workers

0 commit comments

Comments
 (0)