|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -"""Probe command implementation for endpoint health checking.""" |
| 16 | +"""Probe command for endpoint health checking.""" |
17 | 17 |
|
18 | 18 | import argparse |
19 | 19 | import asyncio |
|
24 | 24 | from urllib.parse import urljoin |
25 | 25 |
|
26 | 26 | from inference_endpoint.core.types import Query, QueryResult |
27 | | -from inference_endpoint.endpoint_client.configs import ( |
| 27 | +from inference_endpoint.endpoint_client import ( |
28 | 28 | AioHttpConfig, |
29 | 29 | HTTPClientConfig, |
| 30 | + HTTPEndpointClient, |
30 | 31 | ZMQConfig, |
31 | 32 | ) |
32 | | -from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient |
33 | 33 | from inference_endpoint.exceptions import ( |
34 | 34 | ExecutionError, |
35 | 35 | InputValidationError, |
|
40 | 40 |
|
41 | 41 |
|
42 | 42 | async def run_probe_command(args: argparse.Namespace) -> None: |
43 | | - """Run endpoint probe to validate connectivity and basic functionality. |
44 | | -
|
45 | | - Actions: |
46 | | - 1. Send test requests using HTTP client with futures |
47 | | - 2. Measure basic latency |
48 | | - 3. Report validation status |
49 | | - """ |
50 | | - # Extract arguments |
| 43 | + """Probe endpoint to validate connectivity and measure latency.""" |
51 | 44 | endpoint = args.endpoint |
52 | 45 | num_requests = args.requests |
53 | | - test_prompt = args.prompt |
| 46 | + prompt = args.prompt |
| 47 | + model = getattr(args, "model", None) |
54 | 48 |
|
55 | | - # Model: use provided or default to valid OpenAI model name |
56 | | - model_name = getattr(args, "model", None) |
57 | | - if not model_name: |
58 | | - logger.error("Model required: --model or specify in YAML config") |
| 49 | + if not model: |
59 | 50 | raise InputValidationError("Model required: --model NAME") |
60 | | - # Note: API key handling would go in HTTP client config if needed |
61 | 51 |
|
62 | 52 | logger.info(f"Probing: {endpoint}") |
63 | 53 |
|
64 | | - # Create temp directory for ZMQ |
65 | 54 | tmp_dir = tempfile.mkdtemp(prefix="probe_") |
66 | 55 | client = None |
67 | 56 |
|
68 | | - # TODO (Rashid): Add a health check with a separate timeout. |
69 | 57 | try: |
70 | | - # Setup HTTP client with futures support |
71 | | - http_config = HTTPClientConfig( |
72 | | - endpoint_url=urljoin(endpoint, "/v1/chat/completions"), |
73 | | - num_workers=1, |
74 | | - max_concurrency=num_requests, |
| 58 | + client = HTTPEndpointClient( |
| 59 | + HTTPClientConfig( |
| 60 | + endpoint_url=urljoin(endpoint, "/v1/chat/completions"), |
| 61 | + num_workers=1, |
| 62 | + max_concurrency=num_requests, |
| 63 | + ), |
| 64 | + AioHttpConfig(), |
| 65 | + ZMQConfig( |
| 66 | + zmq_request_queue_prefix=f"ipc://{tmp_dir}/req", |
| 67 | + zmq_response_queue_addr=f"ipc://{tmp_dir}/resp", |
| 68 | + zmq_readiness_queue_addr=f"ipc://{tmp_dir}/ready", |
| 69 | + ), |
75 | 70 | ) |
76 | | - aiohttp_config = AioHttpConfig() |
77 | | - zmq_config = ZMQConfig( |
78 | | - zmq_request_queue_prefix=f"ipc://{tmp_dir}/req", |
79 | | - zmq_response_queue_addr=f"ipc://{tmp_dir}/resp", |
80 | | - zmq_readiness_queue_addr=f"ipc://{tmp_dir}/ready", |
81 | | - ) |
82 | | - |
83 | | - client = HTTPEndpointClient(http_config, aiohttp_config, zmq_config) |
84 | 71 |
|
| 72 | + # Issue requests |
85 | 73 | logger.info(f"Sending {num_requests} requests...") |
86 | | - |
87 | | - # Send test requests |
88 | 74 | start_times: dict[str, float] = {} |
89 | | - sent_query_ids: list[str] = [] |
90 | | - issue_errors: list[str] = [] |
91 | 75 |
|
92 | | - # TODO: this might not work with a real vLLM/SGLang endpoint, fix this. |
93 | 76 | for i in range(num_requests): |
94 | 77 | query_id = f"probe-{i}" |
95 | | - query = Query( |
96 | | - id=query_id, |
97 | | - data={ |
98 | | - "prompt": test_prompt, |
99 | | - "model": model_name, |
100 | | - "max_tokens": 50, |
101 | | - "stream": False, |
102 | | - }, |
| 78 | + start_times[query_id] = time.time() |
| 79 | + client.issue_query( |
| 80 | + Query( |
| 81 | + id=query_id, |
| 82 | + data={ |
| 83 | + "prompt": prompt, |
| 84 | + "model": model, |
| 85 | + "max_tokens": 50, |
| 86 | + "stream": False, |
| 87 | + }, |
| 88 | + ) |
103 | 89 | ) |
104 | 90 |
|
105 | | - try: |
106 | | - start_times[query_id] = time.time() |
107 | | - await client.issue_query_async(query) |
108 | | - # Only track successfully issued queries |
109 | | - sent_query_ids.append(query_id) |
110 | | - except Exception as e: |
111 | | - issue_errors.append(f"{query_id}: Failed to issue - {str(e)[:50]}") |
112 | | - logger.warning(f"Failed to issue request {i}: {str(e)[:50]}") |
113 | | - continue |
114 | | - |
115 | | - # Simple progress indicator |
116 | | - if (i + 1) % max(1, num_requests // 10) == 0 or i == num_requests - 1: |
117 | | - logger.info(f" Issued {i + 1}/{num_requests} requests") |
118 | | - |
119 | | - # Wait for all responses |
| 91 | + # Collect responses |
| 92 | + logger.info(f"Waiting for {num_requests} responses...") |
120 | 93 | latencies: list[float] = [] |
121 | | - errors: list[str] = issue_errors # Include any issue errors |
| 94 | + errors: list[str] = [] |
122 | 95 | responses: list[tuple[str, str]] = [] |
| 96 | + received: set[str] = set() |
123 | 97 |
|
124 | | - # Only count successfully issued queries |
125 | | - num_expected = len(sent_query_ids) |
126 | | - if num_expected == 0: |
127 | | - logger.error("✗ No queries were successfully issued") |
128 | | - raise ExecutionError("Probe failed: no queries could be issued") |
129 | | - |
130 | | - # Wait for all responses with generous timeout (probe queries can be slow) |
131 | | - probe_timeout = 60.0 # 60 seconds total |
| 98 | + timeout = 60.0 |
132 | 99 | start_wait = time.time() |
133 | 100 |
|
134 | | - logger.info(f"Waiting for {num_expected} responses...") |
135 | | - |
136 | | - received_ids: set[str] = set() |
137 | | - |
138 | | - while ( |
139 | | - len(received_ids) < num_expected |
140 | | - and (time.time() - start_wait) < probe_timeout |
141 | | - ): |
142 | | - try: |
143 | | - result = await client.try_recv_response_async() |
144 | | - |
145 | | - if result is None: |
146 | | - await asyncio.sleep(0.01) |
147 | | - continue |
148 | | - |
149 | | - # Skip non-final streaming chunks |
150 | | - if not isinstance(result, QueryResult): |
151 | | - continue |
152 | | - |
153 | | - query_id = result.id |
154 | | - |
155 | | - if query_id in received_ids: |
156 | | - logger.warning(f"Received duplicate response for {query_id}") |
157 | | - continue |
158 | | - |
159 | | - received_ids.add(query_id) |
160 | | - |
161 | | - # Calculate latency - should always be in start_times for issued queries |
162 | | - if query_id not in start_times: |
163 | | - logger.warning( |
164 | | - f"Received response for unknown query_id: {query_id}, skipping" |
165 | | - ) |
166 | | - continue |
167 | | - latency_ms = (time.time() - start_times[query_id]) * 1000 |
168 | | - |
169 | | - if result.error: |
170 | | - errors.append(f"{query_id}: {result.error}") |
171 | | - else: |
172 | | - latencies.append(latency_ms) |
173 | | - responses.append((query_id, result.response_output)) |
174 | | - |
175 | | - # Simple progress indicator |
176 | | - if ( |
177 | | - len(received_ids) % max(1, num_expected // 10) == 0 |
178 | | - or len(received_ids) == num_expected |
179 | | - ): |
180 | | - logger.info( |
181 | | - f" Processed {len(received_ids)}/{num_expected} responses" |
182 | | - ) |
183 | | - |
184 | | - except Exception as e: |
185 | | - logger.warning(f"Error receiving response: {str(e)[:50]}") |
| 101 | + while len(received) < num_requests and (time.time() - start_wait) < timeout: |
| 102 | + result = await client.try_recv_response() |
| 103 | + |
| 104 | + if result is None: |
186 | 105 | await asyncio.sleep(0.01) |
| 106 | + continue |
| 107 | + |
| 108 | + if not isinstance(result, QueryResult): |
| 109 | + continue |
| 110 | + |
| 111 | + if result.id in received: |
| 112 | + continue |
| 113 | + |
| 114 | + received.add(result.id) |
| 115 | + latency_ms = (time.time() - start_times.get(result.id, time.time())) * 1000 |
187 | 116 |
|
188 | | - # Mark any issued but not received as timeout |
189 | | - for query_id in sent_query_ids: |
190 | | - if query_id not in received_ids: |
191 | | - errors.append(f"{query_id}: Timeout (>{probe_timeout}s)") |
| 117 | + if result.error: |
| 118 | + errors.append(f"{result.id}: {result.error}") |
| 119 | + else: |
| 120 | + latencies.append(latency_ms) |
| 121 | + responses.append((result.id, result.response_output)) |
192 | 122 |
|
193 | | - # Report results |
| 123 | + # Report timeouts |
| 124 | + for query_id in start_times: |
| 125 | + if query_id not in received: |
| 126 | + errors.append(f"{query_id}: Timeout") |
| 127 | + |
| 128 | + # Results |
194 | 129 | success_count = len(latencies) |
195 | | - logger.info(f"✓ Completed: {success_count}/{num_expected} successful") |
| 130 | + logger.info(f"✓ Completed: {success_count}/{num_requests} successful") |
196 | 131 |
|
197 | 132 | if latencies: |
198 | | - avg_latency = sum(latencies) / len(latencies) |
199 | | - logger.info(f"✓ Avg latency: {avg_latency:.0f}ms") |
| 133 | + logger.info(f"✓ Avg latency: {sum(latencies) / len(latencies):.0f}ms") |
200 | 134 | logger.info(f"✓ Range: {min(latencies):.0f}ms - {max(latencies):.0f}ms") |
201 | 135 |
|
202 | | - # Show sample responses for sanity check |
203 | 136 | if responses: |
204 | 137 | logger.info(f"✓ Sample responses ({len(responses)} collected):") |
205 | | - # Show first 10 responses |
206 | 138 | for query_id, response in responses[:10]: |
207 | | - # Truncate long responses |
208 | | - response_preview = ( |
209 | | - response[:100] + "..." if len(response) > 100 else response |
210 | | - ) |
211 | | - logger.info(f" [{query_id}] {response_preview}") |
| 139 | + preview = response[:100] + "..." if len(response) > 100 else response |
| 140 | + logger.info(f" [{query_id}] {preview}") |
212 | 141 |
|
213 | 142 | if errors: |
214 | 143 | logger.warning(f"⚠ Errors: {len(errors)}") |
215 | | - if args.verbose: |
| 144 | + if getattr(args, "verbose", 0): |
216 | 145 | for error in errors[:3]: |
217 | 146 | logger.warning(f" {error}") |
218 | | - if len(errors) > 3: |
219 | | - logger.warning(f" ... +{len(errors) - 3} more") |
220 | 147 |
|
221 | | - # Check if probe was successful |
222 | 148 | if success_count < num_requests * 0.5: |
223 | | - logger.error("✗ Probe failed: Too many errors") |
224 | 149 | raise ExecutionError( |
225 | | - f"Probe failed: only {success_count}/{num_requests} requests successful" |
| 150 | + f"Probe failed: {success_count}/{num_requests} successful" |
226 | 151 | ) |
227 | 152 |
|
228 | 153 | logger.info("✓ Probe successful") |
229 | 154 |
|
230 | 155 | except ExecutionError: |
231 | | - # Re-raise our own exceptions |
232 | 156 | raise |
233 | 157 | except Exception as e: |
234 | | - logger.error("✗ Probe failed") |
235 | | - raise SetupError(f"Probe setup failed: {e}") from e |
| 158 | + raise SetupError(f"Probe failed: {e}") from e |
236 | 159 | finally: |
237 | | - # Cleanup |
238 | | - if client is not None: |
239 | | - await client.shutdown_async() |
| 160 | + if client: |
| 161 | + client.shutdown() |
240 | 162 | shutil.rmtree(tmp_dir, ignore_errors=True) |
0 commit comments