Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs.

- **`client.dp_rank_count` removed**: With external-LB data parallelism each DP rank is its own endpoint (one `client.base_url` per rank), so the client no longer pins a rollout to an internal DP shard via the `X-data-parallel-rank` header. The `[orchestrator.student.client] dp_rank_count` field (and its per-rank client expansion) is gone — the router load-balances across the per-rank endpoints. The header was a leftover from the hybrid-LB path dropped in #2696, and forwarding it to single-DP backends raised `data_parallel_rank out of range`. Existing configs setting `dp_rank_count` must drop it (`extra="forbid"` rejects it); there is nothing to migrate. (2026-06-04)
- **`inference.kv_cache_offload.cpu_bytes` removed → discriminated `type` config**: The flat `[inference.kv_cache_offload]` block with a single `cpu_bytes` field is replaced by a backend-discriminated union with composable `cpu`/`disk` tiers. Migrate native CPU offload from `[inference.kv_cache_offload]\ncpu_bytes = N` to `[inference.kv_cache_offload]\ntype = "native"` plus `[inference.kv_cache_offload.cpu]\nnum_bytes = N`. A `type = "mooncake"` backend (per-node distributed store; multi-node/SLURM only) and an optional `[inference.kv_cache_offload.disk]\npath = "..."` tier (layered behind cpu) are also available. `extra="forbid"` rejects the old `cpu_bytes` key, so existing configs must migrate. (2026-06-02)
- **Orchestrator async-pipeline rewrite** (a collection of removals and renames): The orchestrator was rewritten to overlap train/eval rollouts on a shared concurrency limiter; as part of that rewrite, several config fields were removed or renamed.
- **`orchestrator.seed` removed**: it was only consumed by the deleted buffer, so there is no replacement.
Expand Down
5 changes: 1 addition & 4 deletions packages/prime-rl-configs/src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,16 +600,13 @@ def auto_setup_disaggregated_inference(self):
def auto_setup_inference_client(self):
"""Auto-configure orchestrator student client from the inference server config.

For all modes, sets dp_rank_count from inference DP size. For SFT mode,
also sets base_url - rl/opd rely on the ClientConfig default
For SFT mode, sets base_url - rl/opd rely on the ClientConfig default
(``["http://localhost:8000/v1"]``) which already matches the auto-launched
student vLLM at inference.server.port = 8000.
"""
if self.inference is None:
return self
client = self.orchestrator.student.client
if "dp_rank_count" not in client.model_fields_set:
client.dp_rank_count = self.inference.data_parallel_size_local or self.inference.parallel.dp
if self.orchestrator.training_mode == "sft" and "base_url" not in client.model_fields_set:
host = self.inference.server.host or "localhost"
port = self.inference.server.port
Expand Down
3 changes: 0 additions & 3 deletions packages/prime-rl-configs/src/prime_rl/configs/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,6 @@ class ClientConfig(BaseConfig):
skip_model_check: bool = False
"""Skip checking that the model is available in the inference pool. Useful for external APIs or keys that do not expose ``/models``."""

dp_rank_count: int = Field(1, ge=1)
"""Number of data-parallel ranks behind each base URL. When > 1, each URL is expanded into ``dp_rank_count`` logical clients pinned via the ``X-data-parallel-rank`` header, so every request within a rollout hits the same DP engine and reuses KV cache. Auto-set from the inference config when using the RL entrypoint."""

admin_base_url: list[str] | None = None
"""Separate base URLs for admin operations (weight updates, health checks). When set, admin clients bypass routers and hit each server directly — used in disaggregated P/D deployments where the router must not handle admin traffic."""

Expand Down
47 changes: 22 additions & 25 deletions src/prime_rl/utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
from prime_rl.configs.shared import ClientConfig
from prime_rl.utils.logger import get_logger

# Identity tuple used by ``select_train_client`` to key load counts. ``api_base_url``
# distinguishes servers; ``X-data-parallel-rank`` distinguishes DP shards within a
# server, since the router uses that header to route to specific GPU ranks.
ClientIdentity = tuple[str, str | None]
# Identity used by ``select_train_client`` to key load counts. With external-LB
# data parallelism each DP rank is its own endpoint, so ``api_base_url`` alone
# uniquely identifies an inference target.
ClientIdentity = str


def client_identity(client: vf.ClientConfig) -> ClientIdentity:
"""Stable identity for load balancing across inference clients."""
return (client.api_base_url, client.extra_headers.get("X-data-parallel-rank"))
return client.api_base_url


@runtime_checkable
Expand Down Expand Up @@ -200,27 +200,24 @@ def setup_clients(
k: v for k, v in ((k, os.getenv(v)) for k, v in client_config.headers_from_env.items()) if v is not None
}
for base_url in client_config.base_url:
for dp_rank in range(client_config.dp_rank_count):
headers = {**client_config.headers, **env_headers}
if client_config.dp_rank_count > 1:
headers["X-data-parallel-rank"] = str(dp_rank)
clients.append(
vf.ClientConfig(
client_idx=client_idx,
client_type=client_type,
api_base_url=base_url,
api_key_var=client_config.api_key_var,
timeout=client_config.timeout,
connect_timeout=client_config.connect_timeout,
max_connections=8192,
max_keepalive_connections=8192,
max_retries=10,
extra_headers=headers,
extra_headers_from_state=client_config.extra_headers_from_state,
**renderer_extra,
)
headers = {**client_config.headers, **env_headers}
clients.append(
vf.ClientConfig(
client_idx=client_idx,
client_type=client_type,
api_base_url=base_url,
api_key_var=client_config.api_key_var,
timeout=client_config.timeout,
connect_timeout=client_config.connect_timeout,
max_connections=8192,
max_keepalive_connections=8192,
max_retries=10,
extra_headers=headers,
extra_headers_from_state=client_config.extra_headers_from_state,
**renderer_extra,
)
client_idx += 1
)
client_idx += 1
return clients


Expand Down
1 change: 0 additions & 1 deletion src/prime_rl/utils/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ def _rebuild_clients(self) -> None:
api_key_var=self.client_config.api_key_var,
headers=self.client_config.headers,
headers_from_env=self.client_config.headers_from_env,
dp_rank_count=self.client_config.dp_rank_count,
extra_headers_from_state=self.client_config.extra_headers_from_state,
)
self._train_clients = (
Expand Down
14 changes: 9 additions & 5 deletions tests/unit/utils/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,15 @@ def test_load_lora_adapter_succeeds_on_first_attempt():
)


def test_setup_clients_assigns_renderer_and_dp_rank_headers():
def test_setup_clients_creates_one_renderer_client_per_url():
from renderers import Qwen3VLRendererConfig

# External-LB: base_url is the list of per-rank endpoints (the URL is the rank
# selector), so each URL maps to exactly one client with no rank header.
client_config = ClientConfig(
base_url=["http://worker-a:8000/v1"],
base_url=["http://worker-a:8000/v1", "http://worker-a:8001/v1"],
api_key_var="PRIME_API_KEY",
headers={"X-Test": "test"},
dp_rank_count=2,
extra_headers_from_state={"X-Session-ID": "session_id"},
)

Expand All @@ -70,8 +71,11 @@ def test_setup_clients_assigns_renderer_and_dp_rank_headers():
assert [client.client_type for client in clients] == ["renderer", "renderer"]
assert [client.renderer_config for client in clients] == [renderer_settings, renderer_settings]
assert [client.renderer_model_name for client in clients] == [None, None]
assert [client.api_base_url for client in clients] == ["http://worker-a:8000/v1"] * 2
assert [client.extra_headers["X-data-parallel-rank"] for client in clients] == ["0", "1"]
assert [client.api_base_url for client in clients] == [
"http://worker-a:8000/v1",
"http://worker-a:8001/v1",
]
assert all("X-data-parallel-rank" not in client.extra_headers for client in clients)
assert clients[0].extra_headers["X-Test"] == "test"
assert clients[0].extra_headers_from_state == {"X-Session-ID": "session_id"}

Expand Down
1 change: 0 additions & 1 deletion tests/unit/utils/test_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,6 @@ def test_elastic_clients_preserve_renderer_model_name_when_model_name_updates():
client_config.headers = {}
client_config.headers_from_env = {}
client_config.extra_headers_from_state = {}
client_config.dp_rank_count = 1

from renderers import Qwen3VLRendererConfig

Expand Down
Loading