Skip to content

Commit f3b38b2

Browse files
S1ro1claude
andcommitted
refactor: drop client-side DP-rank pinning for external-LB
With external-LB data parallelism each DP rank is its own API server on its own port (the URL is the rank selector), so the client no longer needs the hybrid-LB `X-data-parallel-rank` header to pin a rollout to an internal DP shard. Remove the `dp_rank_count` client field + its auto-setup and the per-rank client expansion: one client per base URL, no rank header. The router (vllm-router or llm-d EPP) balances across the per-rank endpoints. This also fixes llm-d routing: the EPP forwards the header to the dp=1 backend, which rejected it ("data_parallel_rank N is out of range"). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent e751794 commit f3b38b2

6 files changed

Lines changed: 33 additions & 38 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs.
44

5+
- **`client.dp_rank_count` removed**: With external-LB data parallelism each DP rank is its own endpoint (one `client.base_url` per rank), so the client no longer pins a rollout to an internal DP shard via the `X-data-parallel-rank` header. The `[orchestrator.student.client] dp_rank_count` field (and its per-rank client expansion) is gone — the router load-balances across the per-rank endpoints. Existing configs setting `dp_rank_count` must drop it (`extra="forbid"` rejects it); there is nothing to migrate. (2026-06-03)
56
- **`inference.kv_cache_offload.cpu_bytes` removed → discriminated `type` config**: The flat `[inference.kv_cache_offload]` block with a single `cpu_bytes` field is replaced by a backend-discriminated union with composable `cpu`/`disk` tiers. Migrate native CPU offload from `[inference.kv_cache_offload]\ncpu_bytes = N` to `[inference.kv_cache_offload]\ntype = "native"` plus `[inference.kv_cache_offload.cpu]\nnum_bytes = N`. A `type = "mooncake"` backend (per-node distributed store; multi-node/SLURM only) and an optional `[inference.kv_cache_offload.disk]\npath = "..."` tier (layered behind cpu) are also available. `extra="forbid"` rejects the old `cpu_bytes` key, so existing configs must migrate. (2026-06-02)
67
- **Orchestrator async-pipeline rewrite** (collection of removals/renames). The orchestrator was rewritten to overlap train/eval rollouts on a shared concurrency limiter; several config fields were removed or renamed.
78
- **`orchestrator.seed` removed**: was only consumed by the deleted buffer; no replacement.

packages/prime-rl-configs/src/prime_rl/configs/rl.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -600,16 +600,13 @@ def auto_setup_disaggregated_inference(self):
600600
def auto_setup_inference_client(self):
601601
"""Auto-configure orchestrator student client from the inference server config.
602602
603-
For all modes, sets dp_rank_count from inference DP size. For SFT mode,
604-
also sets base_url - rl/opd rely on the ClientConfig default
603+
For SFT mode, sets base_url - rl/opd rely on the ClientConfig default
605604
(``["http://localhost:8000/v1"]``) which already matches the auto-launched
606605
student vLLM at inference.server.port = 8000.
607606
"""
608607
if self.inference is None:
609608
return self
610609
client = self.orchestrator.student.client
611-
if "dp_rank_count" not in client.model_fields_set:
612-
client.dp_rank_count = self.inference.data_parallel_size_local or self.inference.parallel.dp
613610
if self.orchestrator.training_mode == "sft" and "base_url" not in client.model_fields_set:
614611
host = self.inference.server.host or "localhost"
615612
port = self.inference.server.port

packages/prime-rl-configs/src/prime_rl/configs/shared.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,6 @@ class ClientConfig(BaseConfig):
127127
skip_model_check: bool = False
128128
"""Skip checking that the model is available in the inference pool. Useful for external APIs or keys that do not expose ``/models``."""
129129

130-
dp_rank_count: int = Field(1, ge=1)
131-
"""Number of data-parallel ranks behind each base URL. When > 1, each URL is expanded into ``dp_rank_count`` logical clients pinned via the ``X-data-parallel-rank`` header, so every request within a rollout hits the same DP engine and reuses KV cache. Auto-set from the inference config when using the RL entrypoint."""
132-
133130
admin_base_url: list[str] | None = None
134131
"""Separate base URLs for admin operations (weight updates, health checks). When set, admin clients bypass routers and hit each server directly — used in disaggregated P/D deployments where the router must not handle admin traffic."""
135132

src/prime_rl/utils/client.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717
from prime_rl.configs.shared import ClientConfig
1818
from prime_rl.utils.logger import get_logger
1919

20-
# Identity tuple used by ``select_train_client`` to key load counts. ``api_base_url``
21-
# distinguishes servers; ``X-data-parallel-rank`` distinguishes DP shards within a
22-
# server, since the router uses that header to route to specific GPU ranks.
23-
ClientIdentity = tuple[str, str | None]
20+
# Identity used by ``select_train_client`` to key load counts. With external-LB
21+
# data parallelism each DP rank is its own endpoint, so ``api_base_url`` alone
22+
# uniquely identifies an inference target.
23+
ClientIdentity = str
2424

2525

2626
def client_identity(client: vf.ClientConfig) -> ClientIdentity:
2727
"""Stable identity for load balancing across inference clients."""
28-
return (client.api_base_url, client.extra_headers.get("X-data-parallel-rank"))
28+
return client.api_base_url
2929

3030

3131
@runtime_checkable
@@ -200,27 +200,24 @@ def setup_clients(
200200
k: v for k, v in ((k, os.getenv(v)) for k, v in client_config.headers_from_env.items()) if v is not None
201201
}
202202
for base_url in client_config.base_url:
203-
for dp_rank in range(client_config.dp_rank_count):
204-
headers = {**client_config.headers, **env_headers}
205-
if client_config.dp_rank_count > 1:
206-
headers["X-data-parallel-rank"] = str(dp_rank)
207-
clients.append(
208-
vf.ClientConfig(
209-
client_idx=client_idx,
210-
client_type=client_type,
211-
api_base_url=base_url,
212-
api_key_var=client_config.api_key_var,
213-
timeout=client_config.timeout,
214-
connect_timeout=client_config.connect_timeout,
215-
max_connections=8192,
216-
max_keepalive_connections=8192,
217-
max_retries=10,
218-
extra_headers=headers,
219-
extra_headers_from_state=client_config.extra_headers_from_state,
220-
**renderer_extra,
221-
)
203+
headers = {**client_config.headers, **env_headers}
204+
clients.append(
205+
vf.ClientConfig(
206+
client_idx=client_idx,
207+
client_type=client_type,
208+
api_base_url=base_url,
209+
api_key_var=client_config.api_key_var,
210+
timeout=client_config.timeout,
211+
connect_timeout=client_config.connect_timeout,
212+
max_connections=8192,
213+
max_keepalive_connections=8192,
214+
max_retries=10,
215+
extra_headers=headers,
216+
extra_headers_from_state=client_config.extra_headers_from_state,
217+
**renderer_extra,
222218
)
223-
client_idx += 1
219+
)
220+
client_idx += 1
224221
return clients
225222

226223

src/prime_rl/utils/elastic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@ def _rebuild_clients(self) -> None:
197197
api_key_var=self.client_config.api_key_var,
198198
headers=self.client_config.headers,
199199
headers_from_env=self.client_config.headers_from_env,
200-
dp_rank_count=self.client_config.dp_rank_count,
201200
extra_headers_from_state=self.client_config.extra_headers_from_state,
202201
)
203202
self._train_clients = (

tests/unit/utils/test_client.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,15 @@ def test_load_lora_adapter_succeeds_on_first_attempt():
4949
)
5050

5151

52-
def test_setup_clients_assigns_renderer_and_dp_rank_headers():
52+
def test_setup_clients_creates_one_renderer_client_per_url():
5353
from renderers import Qwen3VLRendererConfig
5454

55+
# External-LB: base_url is the list of per-rank endpoints (the URL is the rank
56+
# selector), so each URL maps to exactly one client with no rank header.
5557
client_config = ClientConfig(
56-
base_url=["http://worker-a:8000/v1"],
58+
base_url=["http://worker-a:8000/v1", "http://worker-a:8001/v1"],
5759
api_key_var="PRIME_API_KEY",
5860
headers={"X-Test": "test"},
59-
dp_rank_count=2,
6061
extra_headers_from_state={"X-Session-ID": "session_id"},
6162
)
6263

@@ -70,8 +71,11 @@ def test_setup_clients_assigns_renderer_and_dp_rank_headers():
7071
assert [client.client_type for client in clients] == ["renderer", "renderer"]
7172
assert [client.renderer_config for client in clients] == [renderer_settings, renderer_settings]
7273
assert [client.renderer_model_name for client in clients] == [None, None]
73-
assert [client.api_base_url for client in clients] == ["http://worker-a:8000/v1"] * 2
74-
assert [client.extra_headers["X-data-parallel-rank"] for client in clients] == ["0", "1"]
74+
assert [client.api_base_url for client in clients] == [
75+
"http://worker-a:8000/v1",
76+
"http://worker-a:8001/v1",
77+
]
78+
assert all("X-data-parallel-rank" not in client.extra_headers for client in clients)
7579
assert clients[0].extra_headers["X-Test"] == "test"
7680
assert clients[0].extra_headers_from_state == {"X-Session-ID": "session_id"}
7781

0 commit comments

Comments
 (0)