|
23 | 23 | import time |
24 | 24 | from urllib.parse import urljoin |
25 | 25 |
|
26 | | -from inference_endpoint.core.types import Query |
| 26 | +from inference_endpoint.core.types import Query, QueryResult |
27 | 27 | from inference_endpoint.endpoint_client.configs import ( |
28 | 28 | AioHttpConfig, |
29 | 29 | HTTPClientConfig, |
30 | 30 | ZMQConfig, |
31 | 31 | ) |
32 | | -from inference_endpoint.endpoint_client.futures_client import FuturesHttpClient |
| 32 | +from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient |
33 | 33 | from inference_endpoint.exceptions import ( |
34 | 34 | ExecutionError, |
35 | 35 | InputValidationError, |
@@ -80,75 +80,120 @@ async def run_probe_command(args: argparse.Namespace) -> None: |
80 | 80 | zmq_readiness_queue_addr=f"ipc://{tmp_dir}/ready", |
81 | 81 | ) |
82 | 82 |
|
83 | | - client = FuturesHttpClient(http_config, aiohttp_config, zmq_config) |
| 83 | + client = HTTPEndpointClient(http_config, aiohttp_config, zmq_config) |
84 | 84 | await client.async_start() |
85 | 85 |
|
86 | 86 | logger.info(f"Sending {num_requests} requests...") |
87 | 87 |
|
88 | | - # Send test requests and collect futures |
89 | | - futures = [] |
90 | | - start_times = {} |
| 88 | + # Send test requests |
| 89 | + start_times: dict[str, float] = {} |
| 90 | + sent_query_ids: list[str] = [] |
| 91 | + issue_errors: list[str] = [] |
91 | 92 |
|
92 | 93 | # TODO: this might not work with a real vLLM/SGLang endpoint, fix this. |
93 | 94 | for i in range(num_requests): |
| 95 | + query_id = f"probe-{i}" |
94 | 96 | query = Query( |
95 | | - id=f"probe-{i}", |
| 97 | + id=query_id, |
96 | 98 | data={ |
97 | 99 | "prompt": test_prompt, |
98 | 100 | "model": model_name, |
99 | 101 | "max_tokens": 50, |
100 | 102 | "stream": False, |
101 | 103 | }, |
102 | 104 | ) |
103 | | - start_times[f"probe-{i}"] = time.time() |
104 | 105 |
|
105 | 106 | try: |
106 | | - future = await client.issue_query(query) |
107 | | - futures.append((f"probe-{i}", future)) |
108 | | - # Simple progress indicator |
109 | | - if (i + 1) % max(1, num_requests // 10) == 0 or i == num_requests - 1: |
110 | | - logger.info(f" Issued {i + 1}/{num_requests} requests") |
| 107 | + start_times[query_id] = time.time() |
| 108 | + await client.issue_query_async(query) |
| 109 | + # Only track successfully issued queries |
| 110 | + sent_query_ids.append(query_id) |
111 | 111 | except Exception as e: |
| 112 | + issue_errors.append(f"{query_id}: Failed to issue - {str(e)[:50]}") |
112 | 113 | logger.warning(f"Failed to issue request {i}: {str(e)[:50]}") |
| 114 | + continue |
| 115 | + |
| 116 | + # Simple progress indicator |
| 117 | + if (i + 1) % max(1, num_requests // 10) == 0 or i == num_requests - 1: |
| 118 | + logger.info(f" Issued {i + 1}/{num_requests} requests") |
113 | 119 |
|
114 | 120 | # Wait for all responses |
115 | | - latencies = [] |
116 | | - errors = [] |
117 | | - responses = [] |
| 121 | + latencies: list[float] = [] |
| 122 | + errors: list[str] = issue_errors # Include any issue errors |
| 123 | + responses: list[tuple[str, str]] = [] |
| 124 | + |
| 125 | + # Only count successfully issued queries |
| 126 | + num_expected = len(sent_query_ids) |
| 127 | + if num_expected == 0: |
| 128 | + logger.error("✗ No queries were successfully issued") |
| 129 | + raise ExecutionError("Probe failed: no queries could be issued") |
118 | 130 |
|
119 | 131 | # Wait for all responses with generous timeout (probe queries can be slow) |
120 | | - # Default HTTP client timeout is 30s, give extra buffer for processing |
121 | | - probe_timeout = 60.0 # 60 seconds per query |
| 132 | + probe_timeout = 60.0 # 60 seconds total |
| 133 | + start_wait = time.time() |
| 134 | + |
| 135 | + logger.info(f"Waiting for {num_expected} responses...") |
122 | 136 |
|
123 | | - logger.info(f"Waiting for {len(futures)} responses...") |
| 137 | + received_ids: set[str] = set() |
124 | 138 |
|
125 | | - for idx, (query_id, future) in enumerate(futures): |
| 139 | + while ( |
| 140 | + len(received_ids) < num_expected |
| 141 | + and (time.time() - start_wait) < probe_timeout |
| 142 | + ): |
126 | 143 | try: |
127 | | - result = await asyncio.wait_for(future, timeout=probe_timeout) |
128 | | - # Calculate latency - should always be in start_times |
129 | | - assert ( |
130 | | - query_id in start_times |
131 | | - ), f"Query {query_id} not found in start_times" |
| 144 | + result = await client.get_ready_responses_async() |
| 145 | + |
| 146 | + if result is None: |
| 147 | + await asyncio.sleep(0.01) |
| 148 | + continue |
| 149 | + |
| 150 | + # Skip non-final streaming chunks |
| 151 | + if not isinstance(result, QueryResult): |
| 152 | + continue |
| 153 | + |
| 154 | + query_id = result.id |
| 155 | + |
| 156 | + if query_id in received_ids: |
| 157 | + logger.warning(f"Received duplicate response for {query_id}") |
| 158 | + continue |
| 159 | + |
| 160 | + received_ids.add(query_id) |
| 161 | + |
| 162 | + # Calculate latency - should always be in start_times for issued queries |
| 163 | + if query_id not in start_times: |
| 164 | + logger.warning( |
| 165 | + f"Received response for unknown query_id: {query_id}, skipping" |
| 166 | + ) |
| 167 | + continue |
132 | 168 | latency_ms = (time.time() - start_times[query_id]) * 1000 |
133 | | - latencies.append(latency_ms) |
134 | 169 |
|
135 | 170 | if result.error: |
136 | 171 | errors.append(f"{query_id}: {result.error}") |
137 | 172 | else: |
138 | | - # Store successful response for sanity check |
| 173 | + latencies.append(latency_ms) |
139 | 174 | responses.append((query_id, result.response_output)) |
140 | | - except TimeoutError: |
141 | | - errors.append(f"{query_id}: Timeout (>{probe_timeout}s)") |
| 175 | + |
| 176 | + # Simple progress indicator |
| 177 | + if ( |
| 178 | + len(received_ids) % max(1, num_expected // 10) == 0 |
| 179 | + or len(received_ids) == num_expected |
| 180 | + ): |
| 181 | + logger.info( |
| 182 | + f" Processed {len(received_ids)}/{num_expected} responses" |
| 183 | + ) |
| 184 | + |
142 | 185 | except Exception as e: |
143 | | - errors.append(f"{query_id}: {str(e)[:50]}") |
| 186 | + logger.warning(f"Error receiving response: {str(e)[:50]}") |
| 187 | + await asyncio.sleep(0.01) |
144 | 188 |
|
145 | | - # Simple progress indicator |
146 | | - if (idx + 1) % max(1, len(futures) // 10) == 0 or idx == len(futures) - 1: |
147 | | - logger.info(f" Processed {idx + 1}/{len(futures)} responses") |
| 189 | + # Mark any issued but not received as timeout |
| 190 | + for query_id in sent_query_ids: |
| 191 | + if query_id not in received_ids: |
| 192 | + errors.append(f"{query_id}: Timeout (>{probe_timeout}s)") |
148 | 193 |
|
149 | 194 | # Report results |
150 | 195 | success_count = len(latencies) |
151 | | - logger.info(f"✓ Completed: {success_count}/{num_requests} successful") |
| 196 | + logger.info(f"✓ Completed: {success_count}/{num_expected} successful") |
152 | 197 |
|
153 | 198 | if latencies: |
154 | 199 | avg_latency = sum(latencies) / len(latencies) |
|
0 commit comments