Skip to content

Commit 1c31437

Browse files
committed
update http client apis
1 parent 27b9464 commit 1c31437

7 files changed

Lines changed: 230 additions & 428 deletions

File tree

src/inference_endpoint/commands/probe.py

Lines changed: 70 additions & 148 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""Probe command implementation for endpoint health checking."""
16+
"""Probe command for endpoint health checking."""
1717

1818
import argparse
1919
import asyncio
@@ -24,12 +24,12 @@
2424
from urllib.parse import urljoin
2525

2626
from inference_endpoint.core.types import Query, QueryResult
27-
from inference_endpoint.endpoint_client.configs import (
27+
from inference_endpoint.endpoint_client import (
2828
AioHttpConfig,
2929
HTTPClientConfig,
30+
HTTPEndpointClient,
3031
ZMQConfig,
3132
)
32-
from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient
3333
from inference_endpoint.exceptions import (
3434
ExecutionError,
3535
InputValidationError,
@@ -40,201 +40,123 @@
4040

4141

4242
async def run_probe_command(args: argparse.Namespace) -> None:
43-
"""Run endpoint probe to validate connectivity and basic functionality.
44-
45-
Actions:
46-
1. Send test requests using HTTP client with futures
47-
2. Measure basic latency
48-
3. Report validation status
49-
"""
50-
# Extract arguments
43+
"""Probe endpoint to validate connectivity and measure latency."""
5144
endpoint = args.endpoint
5245
num_requests = args.requests
53-
test_prompt = args.prompt
46+
prompt = args.prompt
47+
model = getattr(args, "model", None)
5448

55-
# Model: use provided or default to valid OpenAI model name
56-
model_name = getattr(args, "model", None)
57-
if not model_name:
58-
logger.error("Model required: --model or specify in YAML config")
49+
if not model:
5950
raise InputValidationError("Model required: --model NAME")
60-
# Note: API key handling would go in HTTP client config if needed
6151

6252
logger.info(f"Probing: {endpoint}")
6353

64-
# Create temp directory for ZMQ
6554
tmp_dir = tempfile.mkdtemp(prefix="probe_")
6655
client = None
6756

68-
# TODO (Rashid): Add a health check with a separate timeout.
6957
try:
70-
# Setup HTTP client with futures support
71-
http_config = HTTPClientConfig(
72-
endpoint_url=urljoin(endpoint, "/v1/chat/completions"),
73-
num_workers=1,
74-
max_concurrency=num_requests,
58+
client = HTTPEndpointClient(
59+
HTTPClientConfig(
60+
endpoint_url=urljoin(endpoint, "/v1/chat/completions"),
61+
num_workers=1,
62+
max_concurrency=num_requests,
63+
),
64+
AioHttpConfig(),
65+
ZMQConfig(
66+
zmq_request_queue_prefix=f"ipc://{tmp_dir}/req",
67+
zmq_response_queue_addr=f"ipc://{tmp_dir}/resp",
68+
zmq_readiness_queue_addr=f"ipc://{tmp_dir}/ready",
69+
),
7570
)
76-
aiohttp_config = AioHttpConfig()
77-
zmq_config = ZMQConfig(
78-
zmq_request_queue_prefix=f"ipc://{tmp_dir}/req",
79-
zmq_response_queue_addr=f"ipc://{tmp_dir}/resp",
80-
zmq_readiness_queue_addr=f"ipc://{tmp_dir}/ready",
81-
)
82-
83-
client = HTTPEndpointClient(http_config, aiohttp_config, zmq_config)
8471

72+
# Issue requests
8573
logger.info(f"Sending {num_requests} requests...")
86-
87-
# Send test requests
8874
start_times: dict[str, float] = {}
89-
sent_query_ids: list[str] = []
90-
issue_errors: list[str] = []
9175

92-
# TODO: this might not work with a real vLLM/SGLang endpoint, fix this.
9376
for i in range(num_requests):
9477
query_id = f"probe-{i}"
95-
query = Query(
96-
id=query_id,
97-
data={
98-
"prompt": test_prompt,
99-
"model": model_name,
100-
"max_tokens": 50,
101-
"stream": False,
102-
},
78+
start_times[query_id] = time.time()
79+
client.issue_query(
80+
Query(
81+
id=query_id,
82+
data={
83+
"prompt": prompt,
84+
"model": model,
85+
"max_tokens": 50,
86+
"stream": False,
87+
},
88+
)
10389
)
10490

105-
try:
106-
start_times[query_id] = time.time()
107-
await client.issue_query_async(query)
108-
# Only track successfully issued queries
109-
sent_query_ids.append(query_id)
110-
except Exception as e:
111-
issue_errors.append(f"{query_id}: Failed to issue - {str(e)[:50]}")
112-
logger.warning(f"Failed to issue request {i}: {str(e)[:50]}")
113-
continue
114-
115-
# Simple progress indicator
116-
if (i + 1) % max(1, num_requests // 10) == 0 or i == num_requests - 1:
117-
logger.info(f" Issued {i + 1}/{num_requests} requests")
118-
119-
# Wait for all responses
91+
# Collect responses
92+
logger.info(f"Waiting for {num_requests} responses...")
12093
latencies: list[float] = []
121-
errors: list[str] = issue_errors # Include any issue errors
94+
errors: list[str] = []
12295
responses: list[tuple[str, str]] = []
96+
received: set[str] = set()
12397

124-
# Only count successfully issued queries
125-
num_expected = len(sent_query_ids)
126-
if num_expected == 0:
127-
logger.error("✗ No queries were successfully issued")
128-
raise ExecutionError("Probe failed: no queries could be issued")
129-
130-
# Wait for all responses with generous timeout (probe queries can be slow)
131-
probe_timeout = 60.0 # 60 seconds total
98+
timeout = 60.0
13299
start_wait = time.time()
133100

134-
logger.info(f"Waiting for {num_expected} responses...")
135-
136-
received_ids: set[str] = set()
137-
138-
while (
139-
len(received_ids) < num_expected
140-
and (time.time() - start_wait) < probe_timeout
141-
):
142-
try:
143-
result = await client.try_recv_response_async()
144-
145-
if result is None:
146-
await asyncio.sleep(0.01)
147-
continue
148-
149-
# Skip non-final streaming chunks
150-
if not isinstance(result, QueryResult):
151-
continue
152-
153-
query_id = result.id
154-
155-
if query_id in received_ids:
156-
logger.warning(f"Received duplicate response for {query_id}")
157-
continue
158-
159-
received_ids.add(query_id)
160-
161-
# Calculate latency - should always be in start_times for issued queries
162-
if query_id not in start_times:
163-
logger.warning(
164-
f"Received response for unknown query_id: {query_id}, skipping"
165-
)
166-
continue
167-
latency_ms = (time.time() - start_times[query_id]) * 1000
168-
169-
if result.error:
170-
errors.append(f"{query_id}: {result.error}")
171-
else:
172-
latencies.append(latency_ms)
173-
responses.append((query_id, result.response_output))
174-
175-
# Simple progress indicator
176-
if (
177-
len(received_ids) % max(1, num_expected // 10) == 0
178-
or len(received_ids) == num_expected
179-
):
180-
logger.info(
181-
f" Processed {len(received_ids)}/{num_expected} responses"
182-
)
183-
184-
except Exception as e:
185-
logger.warning(f"Error receiving response: {str(e)[:50]}")
101+
while len(received) < num_requests and (time.time() - start_wait) < timeout:
102+
result = await client.try_recv_response()
103+
104+
if result is None:
186105
await asyncio.sleep(0.01)
106+
continue
107+
108+
if not isinstance(result, QueryResult):
109+
continue
110+
111+
if result.id in received:
112+
continue
113+
114+
received.add(result.id)
115+
latency_ms = (time.time() - start_times.get(result.id, time.time())) * 1000
187116

188-
# Mark any issued but not received as timeout
189-
for query_id in sent_query_ids:
190-
if query_id not in received_ids:
191-
errors.append(f"{query_id}: Timeout (>{probe_timeout}s)")
117+
if result.error:
118+
errors.append(f"{result.id}: {result.error}")
119+
else:
120+
latencies.append(latency_ms)
121+
responses.append((result.id, result.response_output))
192122

193-
# Report results
123+
# Report timeouts
124+
for query_id in start_times:
125+
if query_id not in received:
126+
errors.append(f"{query_id}: Timeout")
127+
128+
# Results
194129
success_count = len(latencies)
195-
logger.info(f"✓ Completed: {success_count}/{num_expected} successful")
130+
logger.info(f"✓ Completed: {success_count}/{num_requests} successful")
196131

197132
if latencies:
198-
avg_latency = sum(latencies) / len(latencies)
199-
logger.info(f"✓ Avg latency: {avg_latency:.0f}ms")
133+
logger.info(f"✓ Avg latency: {sum(latencies) / len(latencies):.0f}ms")
200134
logger.info(f"✓ Range: {min(latencies):.0f}ms - {max(latencies):.0f}ms")
201135

202-
# Show sample responses for sanity check
203136
if responses:
204137
logger.info(f"✓ Sample responses ({len(responses)} collected):")
205-
# Show first 10 responses
206138
for query_id, response in responses[:10]:
207-
# Truncate long responses
208-
response_preview = (
209-
response[:100] + "..." if len(response) > 100 else response
210-
)
211-
logger.info(f" [{query_id}] {response_preview}")
139+
preview = response[:100] + "..." if len(response) > 100 else response
140+
logger.info(f" [{query_id}] {preview}")
212141

213142
if errors:
214143
logger.warning(f"⚠ Errors: {len(errors)}")
215-
if args.verbose:
144+
if getattr(args, "verbose", 0):
216145
for error in errors[:3]:
217146
logger.warning(f" {error}")
218-
if len(errors) > 3:
219-
logger.warning(f" ... +{len(errors) - 3} more")
220147

221-
# Check if probe was successful
222148
if success_count < num_requests * 0.5:
223-
logger.error("✗ Probe failed: Too many errors")
224149
raise ExecutionError(
225-
f"Probe failed: only {success_count}/{num_requests} requests successful"
150+
f"Probe failed: {success_count}/{num_requests} successful"
226151
)
227152

228153
logger.info("✓ Probe successful")
229154

230155
except ExecutionError:
231-
# Re-raise our own exceptions
232156
raise
233157
except Exception as e:
234-
logger.error("✗ Probe failed")
235-
raise SetupError(f"Probe setup failed: {e}") from e
158+
raise SetupError(f"Probe failed: {e}") from e
236159
finally:
237-
# Cleanup
238-
if client is not None:
239-
await client.shutdown_async()
160+
if client:
161+
client.shutdown()
240162
shutil.rmtree(tmp_dir, ignore_errors=True)

0 commit comments

Comments (0)