diff --git a/CHANGELOG.md b/CHANGELOG.md index 9da33aa490..8ef754e236 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs. +- **`client.dp_rank_count` removed**: With external-LB data parallelism each DP rank is its own endpoint (one `client.base_url` entry per rank), so the client no longer pins a rollout to an internal DP shard via the `X-data-parallel-rank` header. The `[orchestrator.student.client] dp_rank_count` field (and its per-rank client expansion) is gone — the router load-balances across the per-rank endpoints. The header was a leftover from the hybrid-LB path dropped in #2696, and forwarding it to single-DP backends raised `data_parallel_rank out of range`. Existing configs setting `dp_rank_count` must drop it (`extra="forbid"` rejects it); there is no replacement field to migrate to. (2026-06-04) - **`inference.kv_cache_offload.cpu_bytes` removed → discriminated `type` config**: The flat `[inference.kv_cache_offload]` block with a single `cpu_bytes` field is replaced by a backend-discriminated union with composable `cpu`/`disk` tiers. Migrate native CPU offload from `[inference.kv_cache_offload]\ncpu_bytes = N` to `[inference.kv_cache_offload]\ntype = "native"` plus `[inference.kv_cache_offload.cpu]\nnum_bytes = N`. A `type = "mooncake"` backend (per-node distributed store; multi-node/SLURM only) and an optional `[inference.kv_cache_offload.disk]\npath = "..."` tier (layered behind cpu) are also available. `extra="forbid"` rejects the old `cpu_bytes` key, so existing configs must migrate. (2026-06-02) - **Orchestrator async-pipeline rewrite** (collection of removals/renames). The orchestrator was rewritten to overlap train/eval rollouts on a shared concurrency limiter; several config fields were removed or renamed. - **`orchestrator.seed` removed**: was only consumed by the deleted buffer; no replacement. 
diff --git a/packages/prime-rl-configs/src/prime_rl/configs/rl.py b/packages/prime-rl-configs/src/prime_rl/configs/rl.py index 1af3436d1b..e95fd5207b 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/rl.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/rl.py @@ -600,16 +600,13 @@ def auto_setup_disaggregated_inference(self): def auto_setup_inference_client(self): """Auto-configure orchestrator student client from the inference server config. - For all modes, sets dp_rank_count from inference DP size. For SFT mode, - also sets base_url - rl/opd rely on the ClientConfig default + For SFT mode, sets base_url - rl/opd rely on the ClientConfig default (``["http://localhost:8000/v1"]``) which already matches the auto-launched student vLLM at inference.server.port = 8000. """ if self.inference is None: return self client = self.orchestrator.student.client - if "dp_rank_count" not in client.model_fields_set: - client.dp_rank_count = self.inference.data_parallel_size_local or self.inference.parallel.dp if self.orchestrator.training_mode == "sft" and "base_url" not in client.model_fields_set: host = self.inference.server.host or "localhost" port = self.inference.server.port diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index ff311f145d..b510a26f75 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -127,9 +127,6 @@ class ClientConfig(BaseConfig): skip_model_check: bool = False """Skip checking that the model is available in the inference pool. Useful for external APIs or keys that do not expose ``/models``.""" - dp_rank_count: int = Field(1, ge=1) - """Number of data-parallel ranks behind each base URL. 
When > 1, each URL is expanded into ``dp_rank_count`` logical clients pinned via the ``X-data-parallel-rank`` header, so every request within a rollout hits the same DP engine and reuses KV cache. Auto-set from the inference config when using the RL entrypoint.""" - admin_base_url: list[str] | None = None """Separate base URLs for admin operations (weight updates, health checks). When set, admin clients bypass routers and hit each server directly — used in disaggregated P/D deployments where the router must not handle admin traffic.""" diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index b9ee8f4b9d..d5eb9b4a9f 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -17,15 +17,15 @@ from prime_rl.configs.shared import ClientConfig from prime_rl.utils.logger import get_logger -# Identity tuple used by ``select_train_client`` to key load counts. ``api_base_url`` -# distinguishes servers; ``X-data-parallel-rank`` distinguishes DP shards within a -# server, since the router uses that header to route to specific GPU ranks. -ClientIdentity = tuple[str, str | None] +# Identity used by ``select_train_client`` to key load counts. With external-LB +# data parallelism each DP rank is its own endpoint, so ``api_base_url`` alone +# uniquely identifies an inference target. 
+ClientIdentity = str def client_identity(client: vf.ClientConfig) -> ClientIdentity: """Stable identity for load balancing across inference clients.""" - return (client.api_base_url, client.extra_headers.get("X-data-parallel-rank")) + return client.api_base_url @runtime_checkable @@ -200,27 +200,24 @@ def setup_clients( k: v for k, v in ((k, os.getenv(v)) for k, v in client_config.headers_from_env.items()) if v is not None } for base_url in client_config.base_url: - for dp_rank in range(client_config.dp_rank_count): - headers = {**client_config.headers, **env_headers} - if client_config.dp_rank_count > 1: - headers["X-data-parallel-rank"] = str(dp_rank) - clients.append( - vf.ClientConfig( - client_idx=client_idx, - client_type=client_type, - api_base_url=base_url, - api_key_var=client_config.api_key_var, - timeout=client_config.timeout, - connect_timeout=client_config.connect_timeout, - max_connections=8192, - max_keepalive_connections=8192, - max_retries=10, - extra_headers=headers, - extra_headers_from_state=client_config.extra_headers_from_state, - **renderer_extra, - ) + headers = {**client_config.headers, **env_headers} + clients.append( + vf.ClientConfig( + client_idx=client_idx, + client_type=client_type, + api_base_url=base_url, + api_key_var=client_config.api_key_var, + timeout=client_config.timeout, + connect_timeout=client_config.connect_timeout, + max_connections=8192, + max_keepalive_connections=8192, + max_retries=10, + extra_headers=headers, + extra_headers_from_state=client_config.extra_headers_from_state, + **renderer_extra, ) - client_idx += 1 + ) + client_idx += 1 return clients diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 951b3673c1..b221046078 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -197,7 +197,6 @@ def _rebuild_clients(self) -> None: api_key_var=self.client_config.api_key_var, headers=self.client_config.headers, headers_from_env=self.client_config.headers_from_env, 
- dp_rank_count=self.client_config.dp_rank_count, extra_headers_from_state=self.client_config.extra_headers_from_state, ) self._train_clients = ( diff --git a/tests/unit/utils/test_client.py b/tests/unit/utils/test_client.py index 69325e6c4a..8e7f79c3a1 100644 --- a/tests/unit/utils/test_client.py +++ b/tests/unit/utils/test_client.py @@ -49,14 +49,15 @@ def test_load_lora_adapter_succeeds_on_first_attempt(): ) -def test_setup_clients_assigns_renderer_and_dp_rank_headers(): +def test_setup_clients_creates_one_renderer_client_per_url(): from renderers import Qwen3VLRendererConfig + # External-LB: base_url is the list of per-rank endpoints (the URL is the rank + # selector), so each URL maps to exactly one client with no rank header. client_config = ClientConfig( - base_url=["http://worker-a:8000/v1"], + base_url=["http://worker-a:8000/v1", "http://worker-a:8001/v1"], api_key_var="PRIME_API_KEY", headers={"X-Test": "test"}, - dp_rank_count=2, extra_headers_from_state={"X-Session-ID": "session_id"}, ) @@ -70,8 +71,11 @@ def test_setup_clients_assigns_renderer_and_dp_rank_headers(): assert [client.client_type for client in clients] == ["renderer", "renderer"] assert [client.renderer_config for client in clients] == [renderer_settings, renderer_settings] assert [client.renderer_model_name for client in clients] == [None, None] - assert [client.api_base_url for client in clients] == ["http://worker-a:8000/v1"] * 2 - assert [client.extra_headers["X-data-parallel-rank"] for client in clients] == ["0", "1"] + assert [client.api_base_url for client in clients] == [ + "http://worker-a:8000/v1", + "http://worker-a:8001/v1", + ] + assert all("X-data-parallel-rank" not in client.extra_headers for client in clients) assert clients[0].extra_headers["X-Test"] == "test" assert clients[0].extra_headers_from_state == {"X-Session-ID": "session_id"} diff --git a/tests/unit/utils/test_elastic.py b/tests/unit/utils/test_elastic.py index 21490497e7..f90e2da298 100644 --- 
a/tests/unit/utils/test_elastic.py +++ b/tests/unit/utils/test_elastic.py @@ -417,7 +417,6 @@ def test_elastic_clients_preserve_renderer_model_name_when_model_name_updates(): client_config.headers = {} client_config.headers_from_env = {} client_config.extra_headers_from_state = {} - client_config.dp_rank_count = 1 from renderers import Qwen3VLRendererConfig