
Commit d4d5e16

viraatc and Copilot authored

fix: sampling-param defaults, and make futures-http-client a testing util (#32)

* fix: futures client, sampling-params updates

* Update src/inference_endpoint/config/schema.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

1 parent 86d62dc commit d4d5e16

16 files changed: 292 additions & 666 deletions


pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -74,7 +74,8 @@ where = ["src"]
 [tool.setuptools.package-dir]
 "" = "src"
 
-
+[tool.autopep8]
+max_line_length = 88
 
 [tool.ruff]
 target-version = "py312"
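autopep8 reads the [tool.autopep8] table from pyproject.toml, so the new limit applies repo-wide. A quick sketch of the same option through the library API (the input string is made up; complex statements may still need an aggressive level to be rewrapped):

    # Sketch: exercise the 88-column limit via autopep8's Python API.
    import autopep8

    long_line = "x = {" + ", ".join(f"'k{i}': {i}" for i in range(20)) + "}\n"
    fixed = autopep8.fix_code(long_line, options={"max_line_length": 88})
    print(fixed)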

src/inference_endpoint/commands/benchmark.py

Lines changed: 4 additions & 0 deletions
@@ -513,6 +513,10 @@ def _run_benchmark(
             "model": model_name,
             "stream": enable_streaming,
             "max_completion_tokens": max_tokens,
+            "temperature": config.model_params.temperature,
+            "top_p": config.model_params.top_p,
+            "top_k": config.model_params.top_k,
+            "repetition_penalty": config.model_params.repetition_penalty,
         },
     )
     dataloader.load()
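The benchmark payload now forwards the four sampling parameters straight from ModelParams. Since they default to None (see the schema.py change below), unset parameters can be dropped before the request leaves the client. A sketch of that effect; drop_unset is a hypothetical stand-in for the adapter's omit-defaults encoding:

    # Sketch: sampling params flow from config into the request payload;
    # None entries fall away downstream (here via a hypothetical helper).
    def drop_unset(payload: dict) -> dict:
        return {k: v for k, v in payload.items() if v is not None}

    payload = {
        "model": "my-model",
        "temperature": None,  # unset in config -> endpoint default applies
        "top_p": 0.9,
        "top_k": None,
        "repetition_penalty": None,
    }
    print(drop_unset(payload))  # {'model': 'my-model', 'top_p': 0.9}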

src/inference_endpoint/commands/probe.py

Lines changed: 79 additions & 34 deletions
@@ -23,13 +23,13 @@
 import time
 from urllib.parse import urljoin
 
-from inference_endpoint.core.types import Query
+from inference_endpoint.core.types import Query, QueryResult
 from inference_endpoint.endpoint_client.configs import (
     AioHttpConfig,
     HTTPClientConfig,
     ZMQConfig,
 )
-from inference_endpoint.endpoint_client.futures_client import FuturesHttpClient
+from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient
 from inference_endpoint.exceptions import (
     ExecutionError,
     InputValidationError,
@@ -80,75 +80,120 @@ async def run_probe_command(args: argparse.Namespace) -> None:
         zmq_readiness_queue_addr=f"ipc://{tmp_dir}/ready",
     )
 
-    client = FuturesHttpClient(http_config, aiohttp_config, zmq_config)
+    client = HTTPEndpointClient(http_config, aiohttp_config, zmq_config)
     await client.async_start()
 
     logger.info(f"Sending {num_requests} requests...")
 
-    # Send test requests and collect futures
-    futures = []
-    start_times = {}
+    # Send test requests
+    start_times: dict[str, float] = {}
+    sent_query_ids: list[str] = []
+    issue_errors: list[str] = []
 
     # TODO: this might not work with a real vLLM/SGLang endpoint, fix this.
     for i in range(num_requests):
+        query_id = f"probe-{i}"
         query = Query(
-            id=f"probe-{i}",
+            id=query_id,
             data={
                 "prompt": test_prompt,
                 "model": model_name,
                 "max_tokens": 50,
                 "stream": False,
             },
         )
-        start_times[f"probe-{i}"] = time.time()
 
         try:
-            future = await client.issue_query(query)
-            futures.append((f"probe-{i}", future))
-            # Simple progress indicator
-            if (i + 1) % max(1, num_requests // 10) == 0 or i == num_requests - 1:
-                logger.info(f"  Issued {i + 1}/{num_requests} requests")
+            start_times[query_id] = time.time()
+            await client.issue_query_async(query)
+            # Only track successfully issued queries
+            sent_query_ids.append(query_id)
         except Exception as e:
+            issue_errors.append(f"{query_id}: Failed to issue - {str(e)[:50]}")
            logger.warning(f"Failed to issue request {i}: {str(e)[:50]}")
+            continue
+
+        # Simple progress indicator
+        if (i + 1) % max(1, num_requests // 10) == 0 or i == num_requests - 1:
+            logger.info(f"  Issued {i + 1}/{num_requests} requests")
 
     # Wait for all responses
-    latencies = []
-    errors = []
-    responses = []
+    latencies: list[float] = []
+    errors: list[str] = issue_errors  # Include any issue errors
+    responses: list[tuple[str, str]] = []
+
+    # Only count successfully issued queries
+    num_expected = len(sent_query_ids)
+    if num_expected == 0:
+        logger.error("✗ No queries were successfully issued")
+        raise ExecutionError("Probe failed: no queries could be issued")
 
     # Wait for all responses with generous timeout (probe queries can be slow)
-    # Default HTTP client timeout is 30s, give extra buffer for processing
-    probe_timeout = 60.0  # 60 seconds per query
+    probe_timeout = 60.0  # 60 seconds total
+    start_wait = time.time()
+
+    logger.info(f"Waiting for {num_expected} responses...")
 
-    logger.info(f"Waiting for {len(futures)} responses...")
+    received_ids: set[str] = set()
 
-    for idx, (query_id, future) in enumerate(futures):
+    while (
+        len(received_ids) < num_expected
+        and (time.time() - start_wait) < probe_timeout
+    ):
         try:
-            result = await asyncio.wait_for(future, timeout=probe_timeout)
-            # Calculate latency - should always be in start_times
-            assert (
-                query_id in start_times
-            ), f"Query {query_id} not found in start_times"
+            result = await client.get_ready_responses_async()
+
+            if result is None:
+                await asyncio.sleep(0.01)
+                continue
+
+            # Skip non-final streaming chunks
+            if not isinstance(result, QueryResult):
+                continue
+
+            query_id = result.id
+
+            if query_id in received_ids:
+                logger.warning(f"Received duplicate response for {query_id}")
+                continue
+
+            received_ids.add(query_id)
+
+            # Calculate latency - should always be in start_times for issued queries
+            if query_id not in start_times:
+                logger.warning(
+                    f"Received response for unknown query_id: {query_id}, skipping"
+                )
+                continue
             latency_ms = (time.time() - start_times[query_id]) * 1000
-            latencies.append(latency_ms)
 
             if result.error:
                 errors.append(f"{query_id}: {result.error}")
             else:
-                # Store successful response for sanity check
+                latencies.append(latency_ms)
                 responses.append((query_id, result.response_output))
-        except TimeoutError:
-            errors.append(f"{query_id}: Timeout (>{probe_timeout}s)")
+
+            # Simple progress indicator
+            if (
+                len(received_ids) % max(1, num_expected // 10) == 0
+                or len(received_ids) == num_expected
+            ):
+                logger.info(
+                    f"  Processed {len(received_ids)}/{num_expected} responses"
+                )
+
         except Exception as e:
-            errors.append(f"{query_id}: {str(e)[:50]}")
+            logger.warning(f"Error receiving response: {str(e)[:50]}")
+            await asyncio.sleep(0.01)
 
-        # Simple progress indicator
-        if (idx + 1) % max(1, len(futures) // 10) == 0 or idx == len(futures) - 1:
-            logger.info(f"  Processed {idx + 1}/{len(futures)} responses")
+    # Mark any issued but not received as timeout
+    for query_id in sent_query_ids:
+        if query_id not in received_ids:
+            errors.append(f"{query_id}: Timeout (>{probe_timeout}s)")
 
     # Report results
     success_count = len(latencies)
-    logger.info(f"✓ Completed: {success_count}/{num_requests} successful")
+    logger.info(f"✓ Completed: {success_count}/{num_expected} successful")
 
     if latencies:
         avg_latency = sum(latencies) / len(latencies)
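The probe no longer awaits per-query futures; it polls the client and drains responses until a deadline. Reduced to its essentials, the pattern looks like this, assuming a client whose get_ready_responses_async returns None when nothing is pending:

    # Sketch: drain-until-timeout loop underlying the new probe logic.
    import asyncio
    import time

    async def drain(client, expected_ids: set[str], timeout: float = 60.0):
        # Poll until every expected response arrives or the deadline passes.
        received: dict[str, object] = {}
        deadline = time.time() + timeout
        while len(received) < len(expected_ids) and time.time() < deadline:
            result = await client.get_ready_responses_async()
            if result is None:  # nothing ready yet; yield briefly
                await asyncio.sleep(0.01)
                continue
            received[result.id] = result
        missing = expected_ids - received.keys()
        return received, missing  # missing ids are treated as timeouts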

src/inference_endpoint/config/schema.py

Lines changed: 2 additions & 1 deletion
@@ -136,9 +136,10 @@ class ModelParams(BaseModel):
     """Model generation parameters."""
 
     name: str | None = None
-    temperature: float = 0.7
+    temperature: float | None = None
     top_k: int | None = None
     top_p: float | None = None
+    repetition_penalty: float | None = None
     max_new_tokens: int = 1024
     osl_distribution: OSLDistribution | None = None
     streaming: StreamingMode = StreamingMode.AUTO
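With temperature defaulting to None instead of 0.7, an unset value no longer overrides the serving endpoint's own default. A sketch of the effect, assuming ModelParams stays a plain pydantic BaseModel (the class below is a trimmed stand-in):

    # Sketch: None-valued sampling params can be excluded from a payload,
    # so the endpoint's own defaults win for anything the caller left unset.
    from pydantic import BaseModel

    class ModelParams(BaseModel):
        temperature: float | None = None
        top_k: int | None = None
        top_p: float | None = None
        repetition_penalty: float | None = None
        max_new_tokens: int = 1024

    params = ModelParams(top_p=0.9)
    print(params.model_dump(exclude_none=True))
    # {'top_p': 0.9, 'max_new_tokens': 1024}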

src/inference_endpoint/dataset_manager/dataloader.py

Lines changed: 2 additions & 1 deletion
@@ -325,7 +325,8 @@ def default_parser(x):
     def load(self):
         with open(self.file_path) as file:
             for line in file:
-                self.data.append(self.parser(json.loads(line)))
+                if line := line.strip():
+                    self.data.append(self.parser(json.loads(line)))
 
     def load_sample(self, index: int) -> Any:
         return self.data[index]
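The walrus-guarded strip makes the JSONL loader tolerant of blank lines, which previously reached json.loads and raised. A minimal before/after sketch:

    # Sketch: blank lines in a JSONL stream.
    import json

    lines = ['{"prompt": "hi"}\n', '\n', '{"prompt": "bye"}\n']

    # Before: json.loads("\n") raises json.JSONDecodeError on the blank line.
    # After: blank lines are skipped via the walrus-guarded strip.
    data = []
    for line in lines:
        if line := line.strip():
            data.append(json.loads(line))
    print(data)  # [{'prompt': 'hi'}, {'prompt': 'bye'}]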

src/inference_endpoint/endpoint_client/configs.py

Lines changed: 9 additions & 7 deletions
@@ -205,20 +205,22 @@ class ZMQConfig:
     """Configuration for ZMQ sockets and communication."""
 
     # Main ZMQ settings
-    zmq_io_threads: int = 4  # Number of ZMQ IO threads
-    zmq_high_water_mark: int = 10_000  # max msg queue size
+    zmq_io_threads: int = 4  # Number of ZMQ IO threads; TODO(vir): needs to scale?
+    zmq_high_water_mark: int = 0  # Max queue size per socket (0 = unlimited)
 
     # ZMQ addresses (use None for auto-generated prefixes using PID)
     zmq_request_queue_prefix: str | None = None
     zmq_response_queue_addr: str | None = None
     zmq_readiness_queue_addr: str | None = None
 
     # ZMQ socket options
-    zmq_linger: int = 0  # Don't block on close
-    zmq_send_timeout: int = -1  # Non-blocking send
-    zmq_recv_timeout: int = 100  # Timeout on receive() call
-    zmq_recv_buffer_size: int = 10 * 1024 * 1024  # 10MB receive buffer
-    zmq_send_buffer_size: int = 10 * 1024 * 1024  # 10MB send buffer
+    zmq_linger: int = 0  # 0 = don't block on close
+    zmq_immediate: int = 1  # Only enqueue messages on READY connections
+    zmq_send_timeout: int = -1  # -1 = non-blocking send
+    zmq_recv_timeout: int = 1  # Timeout on receive() in ms
+
+    zmq_recv_buffer_size: int = 10 * 1024 * 1024  # 10MB receive buffer (OS level)
+    zmq_send_buffer_size: int = 10 * 1024 * 1024  # 10MB send buffer (OS level)
 
     def __post_init__(self):
         """Generate portable ZMQ socket paths if not provided."""

src/inference_endpoint/endpoint_client/http_sample_issuer.py

Lines changed: 12 additions & 7 deletions
@@ -18,7 +18,6 @@
 import asyncio
 import logging
 import threading
-from typing import Any
 
 from inference_endpoint.core.types import Query, QueryResult, StreamChunk
 from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient
@@ -103,7 +102,18 @@ def issue(self, sample: Sample):
         if self.n_inflight == 0:
             self._client_idle_event.clear()
         self.n_inflight += 1
-        self.http_client.issue_query(Query(id=sample.uuid, data=sample.data))
+        self.http_client.issue_query(
+            Query(
+                id=sample.uuid,
+                data=sample.data,
+                headers={
+                    "Content-Type": "application/json",
+                    "Accept": "text/event-stream"
+                    if sample.data.get("stream", False)
+                    else "application/json",
+                },
+            )
+        )
 
     def wait_for_all_complete(self, timeout: float | None = None):
         """Wait (blocking) for all pending queries to complete.
@@ -123,8 +133,3 @@ def shutdown(self):
 
         if self.response_task:
             self.response_task.cancel()
-
-    def process_sample_data(self, s_uuid: int, sample_data: Any):
-        raise NotImplementedError(
-            "HttpClientSampleIssuer does not implement process_sample_data"
-        )
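The issuer now negotiates the response format per query: SSE for streaming samples, plain JSON otherwise. The selection logic in isolation, with a hypothetical build_headers helper:

    # Sketch: per-query header selection (hypothetical helper).
    def build_headers(sample_data: dict) -> dict[str, str]:
        streaming = sample_data.get("stream", False)
        return {
            "Content-Type": "application/json",
            # SSE for streaming responses, plain JSON otherwise
            "Accept": "text/event-stream" if streaming else "application/json",
        }

    print(build_headers({"stream": True}))   # Accept: text/event-stream
    print(build_headers({"prompt": "hi"}))   # Accept: application/json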

src/inference_endpoint/endpoint_client/worker.py

Lines changed: 2 additions & 4 deletions
@@ -266,18 +266,16 @@ async def _make_http_request(self, query: Query):
             return
 
         url = self.http_config.endpoint_url
-        headers = query.headers if hasattr(query, "headers") else {}
-
         logging.debug(
-            f"Making HTTP request to {url} with payload: {query} and headers: {headers}"
+            f"Making HTTP request to {url} with query: {query} and headers: {query.headers}"
         )
 
         # Encode query to bytes using adapter
         payload_bytes = self._adapter.encode_query(query)
 
         # Issue the request with pre-encoded bytes
         async with self._session.post(
-            url, data=payload_bytes, headers=headers
+            url, data=payload_bytes, headers=query.headers
         ) as response:
             if response.status != 200:
                 error_text = await response.text()
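The hasattr guard could go because Query is expected to always carry a headers field. A stand-in sketch of that assumption (the real Query lives in inference_endpoint.core.types):

    # Sketch: Query always exposes .headers, so callers need no hasattr guard.
    from dataclasses import dataclass, field
    from typing import Any

    @dataclass
    class Query:
        id: str
        data: dict[str, Any]
        headers: dict[str, str] = field(default_factory=dict)

    q = Query(id="probe-0", data={"prompt": "hi"})
    print(q.headers)  # {} -- always present, never AttributeError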

src/inference_endpoint/endpoint_client/zmq_utils.py

Lines changed: 13 additions & 3 deletions
@@ -76,27 +76,37 @@ def close(self, linger_ms: int | None = None) -> None:
 class ZMQPushSocket(ZMQSocket):
     """Async wrapper for ZMQ PUSH socket."""
 
-    def __init__(self, context: zmq.asyncio.Context, address: str, config: ZMQConfig):
+    def __init__(
+        self,
+        context: zmq.asyncio.Context,
+        address: str,
+        config: ZMQConfig,
+        bind: bool = False,
+    ):
         """
         Initialize ZMQ push socket.
 
         Args:
             context: ZMQ context
             address: Socket address
             config: ZMQ configuration
+            bind: Whether to bind (True) or connect (False) to the address
         """
-        super().__init__(context, zmq.PUSH, address, config, bind=False)
+        super().__init__(context, zmq.PUSH, address, config, bind=bind)
         self._encoder = msgspec.msgpack.Encoder()
 
     def _set_socket_options(self, config: ZMQConfig) -> None:
         """Set PUSH socket specific options."""
         self.socket.setsockopt(zmq.SNDHWM, config.zmq_high_water_mark)
         self.socket.setsockopt(zmq.SNDBUF, config.zmq_send_buffer_size)
         self.socket.setsockopt(zmq.SNDTIMEO, config.zmq_send_timeout)
+        self.socket.setsockopt(zmq.IMMEDIATE, config.zmq_immediate)
 
     @profile
     async def send(self, data: Any) -> None:
-        """Serialize to msgspec and send data through push socket."""
+        """
+        Serialize to msgspec and send data through push socket.
+        """
         serialized = self._encoder.encode(data)
         await self.socket.send(serialized, flags=zmq.NOBLOCK, copy=False)
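The new bind flag lets a PUSH socket own its endpoint instead of dialing one. In ZMQ either side of a PUSH/PULL pair may bind; conventionally the stable, long-lived side does. A plain pyzmq sketch of the two roles (the address is illustrative):

    # Sketch: bind vs connect on a PUSH/PULL pair in raw pyzmq.
    import zmq

    ctx = zmq.Context()

    # Stable side: owns the endpoint (bind=True in the wrapper above).
    pull = ctx.socket(zmq.PULL)
    pull.bind("ipc:///tmp/example-queue")

    # Transient side: dials in (the wrapper's default, bind=False).
    push = ctx.socket(zmq.PUSH)
    push.connect("ipc:///tmp/example-queue")

    push.send(b"hello")
    print(pull.recv())  # b"hello"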

src/inference_endpoint/openai/openai_msgspec_adapter.py

Lines changed: 15 additions & 11 deletions
@@ -37,24 +37,26 @@ class ChatMessage(msgspec.Struct, kw_only=True, omit_defaults=True):
 
     role: str
     content: str
-    name: str
+    name: str | None = None
 
 
 class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True):
     """OpenAI chat completion request."""
 
     model: str
     messages: list[ChatMessage]
-    temperature: float
-    max_completion_tokens: int
-    stream: bool
-    top_p: float
-    n: int
-    stop: str | list[str]
-    presence_penalty: float
-    frequency_penalty: float
-    logit_bias: dict[str, float]
-    user: str
+    temperature: float | None = None
+    max_completion_tokens: int | None = None
+    stream: bool | None = None
+    top_p: float | None = None
+    top_k: int | None = None
+    repetition_penalty: float | None = None
+    n: int | None = None
+    stop: str | list[str] | None = None
+    presence_penalty: float | None = None
+    frequency_penalty: float | None = None
+    logit_bias: dict[str, float] | None = None
+    user: str | None = None
 
 
 class ChatCompletionResponseMessage(msgspec.Struct, kw_only=True, omit_defaults=True):
@@ -158,6 +160,8 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest:
             max_completion_tokens=query.data.get("max_completion_tokens"),
             temperature=query.data.get("temperature"),
             top_p=query.data.get("top_p"),
+            top_k=query.data.get("top_k"),
+            repetition_penalty=query.data.get("repetition_penalty"),
             n=query.data.get("n"),
             presence_penalty=query.data.get("presence_penalty"),
             frequency_penalty=query.data.get("frequency_penalty"),
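Because these structs set omit_defaults=True, any field left at its None default disappears from the encoded request, which is what makes the relaxed defaults behave as "defer to the server". A self-contained msgspec sketch with a trimmed stand-in struct:

    # Sketch: omit_defaults=True drops fields still at their default, so a
    # request only carries the sampling params the caller actually set.
    import msgspec

    class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True):
        model: str
        temperature: float | None = None
        top_k: int | None = None

    req = ChatCompletionRequest(model="my-model", top_k=40)
    print(msgspec.json.encode(req))
    # b'{"model":"my-model","top_k":40}'  -- temperature omitted entirely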
