Skip to content

Commit 1527c3c

Browse files
committed
wip added inference streaming option, improve example
1 parent ea069f9 commit 1527c3c

File tree

3 files changed

+127
-27
lines changed

3 files changed

+127
-27
lines changed

datacrunch/InferenceClient/inference_client.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses_json import dataclass_json, Undefined # type: ignore
33
import requests
44
from requests.structures import CaseInsensitiveDict
5-
from typing import Optional, Dict, Any, Union
5+
from typing import Optional, Dict, Any, Union, Generator
66
from urllib.parse import urlparse
77

88

@@ -14,10 +14,76 @@ class InferenceClientError(Exception):
1414
@dataclass_json(undefined=Undefined.EXCLUDE)
1515
@dataclass
1616
class InferenceResponse:
17-
body: Any
1817
headers: CaseInsensitiveDict[str]
1918
status_code: int
2019
status_text: str
20+
_original_response: requests.Response
21+
_stream: bool = False
22+
23+
def _is_stream_response(self, headers: CaseInsensitiveDict[str]) -> bool:
24+
"""Check if the response headers indicate a streaming response.
25+
26+
Args:
27+
headers: The response headers to check
28+
29+
Returns:
30+
bool: True if the response is likely a stream, False otherwise
31+
"""
32+
# Standard chunked transfer encoding
33+
is_chunked_transfer = headers.get(
34+
'Transfer-Encoding', '').lower() == 'chunked'
35+
# Server-Sent Events content type
36+
is_event_stream = headers.get(
37+
'Content-Type', '').lower() == 'text/event-stream'
38+
# NDJSON
39+
is_ndjson = headers.get(
40+
'Content-Type', '').lower() == 'application/x-ndjson'
41+
# Stream JSON
42+
is_stream_json = headers.get(
43+
'Content-Type', '').lower() == 'application/stream+json'
44+
# Keep-alive
45+
is_keep_alive = headers.get(
46+
'Connection', '').lower() == 'keep-alive'
47+
# No content length
48+
has_no_content_length = 'Content-Length' not in headers
49+
50+
# No Content-Length with keep-alive often suggests streaming (though not definitive)
51+
is_keep_alive_and_no_content_length = is_keep_alive and has_no_content_length
52+
53+
return (self._stream or is_chunked_transfer or is_event_stream or is_ndjson or
54+
is_stream_json or is_keep_alive_and_no_content_length)
55+
56+
def output(self, is_text: bool = False) -> Any:
57+
try:
58+
if is_text:
59+
return self._original_response.text
60+
return self._original_response.json()
61+
except Exception as e:
62+
# if the response is a stream (check headers), raise relevant error
63+
if self._is_stream_response(self._original_response.headers):
64+
raise InferenceClientError(
65+
f"Response might be a stream, use the stream method instead")
66+
raise InferenceClientError(
67+
f"Failed to parse response as JSON: {str(e)}")
68+
69+
def stream(self, chunk_size: int = 512, as_text: bool = True) -> Generator[Any, None, None]:
70+
"""Stream the response content.
71+
72+
Args:
73+
chunk_size: Size of chunks to stream, in bytes
74+
as_text: If True, stream as text using iter_lines. If False, stream as binary using iter_content.
75+
76+
Returns:
77+
Generator yielding chunks of the response
78+
"""
79+
if as_text:
80+
for chunk in self._original_response.iter_lines(chunk_size=chunk_size):
81+
if chunk:
82+
yield chunk
83+
else:
84+
for chunk in self._original_response.iter_content(chunk_size=chunk_size):
85+
if chunk:
86+
yield chunk
2187

2288

2389
@dataclass_json(undefined=Undefined.EXCLUDE)
@@ -169,24 +235,24 @@ def _make_request(self, method: str, path: str, **kwargs) -> requests.Response:
169235
except requests.exceptions.RequestException as e:
170236
raise InferenceClientError(f"Request to {path} failed: {str(e)}")
171237

172-
def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None):
173-
response = self.post(
174-
path, json=data, timeout_seconds=timeout_seconds, headers=headers)
238+
def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None, http_method: str = "POST", stream: bool = False):
239+
response = self._make_request(
240+
http_method, path, json=data, timeout_seconds=timeout_seconds, headers=headers, stream=stream)
175241

176242
return InferenceResponse(
177-
body=response.json(),
178243
headers=response.headers,
179244
status_code=response.status_code,
180-
status_text=response.reason
245+
status_text=response.reason,
246+
_original_response=response
181247
)
182248

183-
def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None):
249+
def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None, http_method: str = "POST"):
184250
# Add the "Prefer: respond-async" header to the request, to indicate that the request is async
185251
headers = headers or {}
186252
headers['Prefer'] = 'respond-async'
187253

188-
response = self.post(
189-
path, json=data, timeout_seconds=timeout_seconds, headers=headers)
254+
response = self._make_request(
255+
http_method, path, json=data, timeout_seconds=timeout_seconds, headers=headers)
190256

191257
# TODO: this response format isn't final
192258
execution_id = response.json()['id']

datacrunch/containers/containers.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,16 @@ def _validate_inference_client(self) -> None:
358358
raise ValueError(
359359
"Inference client not initialized. Use from_dict_with_inference_key or set_inference_client to initialize inference capabilities.")
360360

361-
def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None) -> InferenceResponse:
361+
def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None, http_method: str = "POST", stream: bool = False) -> InferenceResponse:
362362
"""Runs a synchronous inference request.
363363
364364
Args:
365365
data: The data to send in the request.
366366
path: The endpoint path to send the request to.
367367
timeout_seconds: Maximum time to wait for the response.
368368
headers: Optional headers to include in the request.
369+
http_method: The HTTP method to use for the request.
370+
stream: Whether to stream the response.
369371
370372
Returns:
371373
InferenceResponse: The response from the inference request.
@@ -374,16 +376,18 @@ def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int =
374376
ValueError: If the inference client is not initialized.
375377
"""
376378
self._validate_inference_client()
377-
return self._inference_client.run_sync(data, path, timeout_seconds, headers)
379+
return self._inference_client.run_sync(data, path, timeout_seconds, headers, http_method, stream)
378380

379-
def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None):
381+
def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None, http_method: str = "POST", stream: bool = False):
380382
"""Runs an asynchronous inference request.
381383
382384
Args:
383385
data: The data to send in the request.
384386
path: The endpoint path to send the request to.
385387
timeout_seconds: Maximum time to wait for the response.
386388
headers: Optional headers to include in the request.
389+
http_method: The HTTP method to use for the request.
390+
stream: Whether to stream the response.
387391
388392
Returns:
389393
The response from the inference request.
@@ -392,7 +396,7 @@ def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 *
392396
ValueError: If the inference client is not initialized.
393397
"""
394398
self._validate_inference_client()
395-
return self._inference_client.run(data, path, timeout_seconds, headers)
399+
return self._inference_client.run(data, path, timeout_seconds, headers, http_method, stream)
396400

397401
def health(self):
398402
"""Checks the health of the deployed application.

examples/containers/sglang_deployment_example.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import time
99
import signal
1010
import sys
11+
import json
1112
from datetime import datetime
1213
from datacrunch import DataCrunchClient
1314
from datacrunch.exceptions import APIException
@@ -33,9 +34,9 @@
3334

3435
# Configuration constants
3536
DEPLOYMENT_NAME = f"sglang-deployment-example-{CURRENT_TIMESTAMP}"
36-
MODEL_PATH = "deepseek-ai/deepseek-llm-7b-chat"
37+
SGLANG_IMAGE_URL = "docker.io/lmsysorg/sglang:v0.4.1.post6-cu124"
38+
DEEPSEEK_MODEL_PATH = "deepseek-ai/deepseek-llm-7b-chat"
3739
HF_SECRET_NAME = "huggingface-token"
38-
IMAGE_URL = "docker.io/lmsysorg/sglang:v0.4.1.post6-cu124"
3940

4041
# Get confidential values from environment variables
4142
DATACRUNCH_CLIENT_ID = os.environ.get('DATACRUNCH_CLIENT_ID')
@@ -140,18 +141,19 @@ def graceful_shutdown(signum, frame) -> None:
140141
sys.exit(1)
141142

142143
# Create container configuration
144+
APP_PORT = 30000
143145
container = Container(
144-
image=IMAGE_URL,
145-
exposed_port=30000,
146+
image=SGLANG_IMAGE_URL,
147+
exposed_port=APP_PORT,
146148
healthcheck=HealthcheckSettings(
147149
enabled=True,
148-
port=30000,
150+
port=APP_PORT,
149151
path="/health"
150152
),
151153
entrypoint_overrides=EntrypointOverridesSettings(
152154
enabled=True,
153155
cmd=["python3", "-m", "sglang.launch_server", "--model-path",
154-
MODEL_PATH, "--host", "0.0.0.0", "--port", "30000"]
156+
DEEPSEEK_MODEL_PATH, "--host", "0.0.0.0", "--port", str(APP_PORT)]
155157
),
156158
env=[
157159
EnvVar(
@@ -162,16 +164,19 @@ def graceful_shutdown(signum, frame) -> None:
162164
]
163165
)
164166

165-
# Create scaling configuration - default values
167+
# Create scaling configuration
166168
scaling_options = ScalingOptions(
167169
min_replica_count=1,
168-
max_replica_count=2,
169-
scale_down_policy=ScalingPolicy(delay_seconds=300),
170-
scale_up_policy=ScalingPolicy(delay_seconds=300),
170+
max_replica_count=5,
171+
scale_down_policy=ScalingPolicy(delay_seconds=60 * 5),
172+
scale_up_policy=ScalingPolicy(
173+
delay_seconds=0), # No delay for scale up
171174
queue_message_ttl_seconds=500,
172-
concurrent_requests_per_replica=1,
175+
# Modern LLM engines are optimized for batching requests, with minimal performance impact. Taking advantage of batching can significantly improve throughput.
176+
concurrent_requests_per_replica=32,
173177
scaling_triggers=ScalingTriggers(
174-
queue_load=QueueLoadScalingTrigger(threshold=1),
178+
# lower value means more aggressive scaling
179+
queue_load=QueueLoadScalingTrigger(threshold=0.1),
175180
cpu_utilization=UtilizationScalingTrigger(
176181
enabled=True,
177182
threshold=90
@@ -224,7 +229,7 @@ def graceful_shutdown(signum, frame) -> None:
224229
# Test completions endpoint
225230
print("\nTesting completions API...")
226231
completions_data = {
227-
"model": MODEL_PATH,
232+
"model": DEEPSEEK_MODEL_PATH,
228233
"prompt": "Is consciousness fundamentally computational, or is there something more to subjective experience that cannot be reduced to information processing?",
229234
"max_tokens": 128,
230235
"temperature": 0.7,
@@ -239,6 +244,31 @@ def graceful_shutdown(signum, frame) -> None:
239244
print("Completions API is working!")
240245
print(f"Response: {completions_response}")
241246

247+
# Make a stream sync inference request to the SGLang server
248+
completions_response_stream = created_deployment.run_sync(
249+
completions_data,
250+
path="/v1/completions",
251+
stream=True
252+
)
253+
print("Stream completions API is working!")
254+
# Print the streamed response
255+
for line in completions_response_stream.stream(as_text=True):
256+
if line:
257+
line = line.decode('utf-8')
258+
259+
if line.startswith('data:'):
260+
data = line[5:] # Remove 'data: ' prefix
261+
if data == '[DONE]':
262+
break
263+
try:
264+
event_data = json.loads(data)
265+
token_text = event_data['choices'][0]['text']
266+
267+
# Print token immediately to show progress
268+
print(token_text, end='', flush=True)
269+
except json.JSONDecodeError:
270+
continue
271+
242272
except Exception as e:
243273
print(f"Error testing deployment: {e}")
244274

0 commit comments

Comments (0)