|
17 | 17 | from prime_rl.configs.shared import ClientConfig |
18 | 18 | from prime_rl.utils.logger import get_logger |
19 | 19 |
|
20 | | -# Identity tuple used by ``select_train_client`` to key load counts. ``api_base_url`` |
21 | | -# distinguishes servers; ``X-data-parallel-rank`` distinguishes DP shards within a |
22 | | -# server, since the router uses that header to route to specific GPU ranks. |
23 | | -ClientIdentity = tuple[str, str | None] |
| 20 | +# Identity used by ``select_train_client`` to key load counts. With external-LB |
| 21 | +# data parallelism each DP rank is its own endpoint, so ``api_base_url`` alone |
| 22 | +# uniquely identifies an inference target. |
| 23 | +ClientIdentity = str |
24 | 24 |
|
25 | 25 |
|
26 | 26 | def client_identity(client: vf.ClientConfig) -> ClientIdentity: |
27 | 27 | """Stable identity for load balancing across inference clients.""" |
28 | | - return (client.api_base_url, client.extra_headers.get("X-data-parallel-rank")) |
| 28 | + return client.api_base_url |
29 | 29 |
|
30 | 30 |
|
31 | 31 | @runtime_checkable |
@@ -200,27 +200,24 @@ def setup_clients( |
200 | 200 | k: v for k, v in ((k, os.getenv(v)) for k, v in client_config.headers_from_env.items()) if v is not None |
201 | 201 | } |
202 | 202 | for base_url in client_config.base_url: |
203 | | - for dp_rank in range(client_config.dp_rank_count): |
204 | | - headers = {**client_config.headers, **env_headers} |
205 | | - if client_config.dp_rank_count > 1: |
206 | | - headers["X-data-parallel-rank"] = str(dp_rank) |
207 | | - clients.append( |
208 | | - vf.ClientConfig( |
209 | | - client_idx=client_idx, |
210 | | - client_type=client_type, |
211 | | - api_base_url=base_url, |
212 | | - api_key_var=client_config.api_key_var, |
213 | | - timeout=client_config.timeout, |
214 | | - connect_timeout=client_config.connect_timeout, |
215 | | - max_connections=8192, |
216 | | - max_keepalive_connections=8192, |
217 | | - max_retries=10, |
218 | | - extra_headers=headers, |
219 | | - extra_headers_from_state=client_config.extra_headers_from_state, |
220 | | - **renderer_extra, |
221 | | - ) |
| 203 | + headers = {**client_config.headers, **env_headers} |
| 204 | + clients.append( |
| 205 | + vf.ClientConfig( |
| 206 | + client_idx=client_idx, |
| 207 | + client_type=client_type, |
| 208 | + api_base_url=base_url, |
| 209 | + api_key_var=client_config.api_key_var, |
| 210 | + timeout=client_config.timeout, |
| 211 | + connect_timeout=client_config.connect_timeout, |
| 212 | + max_connections=8192, |
| 213 | + max_keepalive_connections=8192, |
| 214 | + max_retries=10, |
| 215 | + extra_headers=headers, |
| 216 | + extra_headers_from_state=client_config.extra_headers_from_state, |
| 217 | + **renderer_extra, |
222 | 218 | ) |
223 | | - client_idx += 1 |
| 219 | + ) |
| 220 | + client_idx += 1 |
224 | 221 | return clients |
225 | 222 |
|
226 | 223 |
|
|
0 commit comments