Skip to content

Commit a744acb

Browse files
committed
wip async inference
1 parent 4dc5ddb commit a744acb

File tree

4 files changed

+55
-29
lines changed

4 files changed

+55
-29
lines changed

datacrunch/InferenceClient/inference_client.py

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -86,29 +86,6 @@ def stream(self, chunk_size: int = 512, as_text: bool = True) -> Generator[Any,
8686
yield chunk
8787

8888

89-
@dataclass_json(undefined=Undefined.EXCLUDE)
90-
@dataclass
91-
class AsyncInferenceExecution:
92-
_client: 'InferenceClient'
93-
id: str
94-
status: str # TODO: add a status enum
95-
96-
# TODO: Implement when the status endpoint is done
97-
def status(self) -> str:
98-
# Call the status endpoint and update the status when
99-
return self.status
100-
101-
# TODO: Implement when the cancel inference execution endpoint is done
102-
# def cancel(self) -> None:
103-
# pass
104-
105-
# TODO: Implement when the results endpoint is done
106-
def get_results(self) -> Dict[str, Any]:
107-
pass
108-
# alias for get_results
109-
output = get_results
110-
111-
11289
class InferenceClient:
11390
def __init__(self, inference_key: str, endpoint_base_url: str, timeout_seconds: int = 60 * 5) -> None:
11491
"""
@@ -131,6 +108,10 @@ def __init__(self, inference_key: str, endpoint_base_url: str, timeout_seconds:
131108

132109
self.inference_key = inference_key
133110
self.endpoint_base_url = endpoint_base_url.rstrip('/')
111+
self.base_domain = self.endpoint_base_url[:self.endpoint_base_url.rindex(
112+
'/')]
113+
self.deployment_name = self.endpoint_base_url[self.endpoint_base_url.rindex(
114+
'/')+1:]
134115
self.timeout_seconds = timeout_seconds
135116
self._session = requests.Session()
136117
self._global_headers = {
@@ -246,10 +227,17 @@ def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int =
246227
_original_response=response
247228
)
248229

249-
def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None, http_method: str = "POST"):
250-
# Add the "Prefer: respond-async" header to the request, to indicate that the request is async
230+
def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None, http_method: str = "POST", no_response: bool = False):
231+
# Add relevant headers to the request, to indicate that the request is async
251232
headers = headers or {}
252-
headers['Prefer'] = 'respond-async'
233+
if no_response:
234+
# If no_response is True, use the "Prefer: respond-async-proxy" header to run async and don't wait for the response
235+
headers['Prefer'] = 'respond-async-proxy'
236+
self._make_request(
237+
http_method, path, json=data, timeout_seconds=timeout_seconds, headers=headers)
238+
return
239+
# Add the "Prefer: async-inference" header to the request, to run async and wait for the response
240+
headers['Prefer'] = 'async-inference'
253241

254242
response = self._make_request(
255243
http_method, path, json=data, timeout_seconds=timeout_seconds, headers=headers)
@@ -297,3 +285,41 @@ def health(self, healthcheck_path: str = "/health") -> requests.Response:
297285
return self.get(healthcheck_path)
298286
except InferenceClientError as e:
299287
raise InferenceClientError(f"Health check failed: {str(e)}")
288+
289+
290+
@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class AsyncInferenceExecution:
    """Handle to a single asynchronous inference execution.

    Wraps the inference id returned by the async endpoints and uses the
    owning ``InferenceClient``'s HTTP session, base domain, and global
    headers to poll the execution's status and fetch its results.
    """

    _inference_client: 'InferenceClient'  # client whose session/headers/URLs are reused
    id: str  # inference execution id returned by the server
    _status: str  # last known status string; TODO: add a status enum?

    # Unannotated on purpose: the dataclass machinery treats it as a plain
    # class attribute, not an instance field. Header used to address a
    # specific execution on the status/results endpoints.
    INFERENCE_ID_HEADER = 'X-Inference-Id'

    def status(self) -> Dict[str, Any]:
        """Poll the status endpoint and refresh the cached ``_status``.

        Returns:
            Dict[str, Any]: The status response containing the execution
                status and other metadata.
        """
        url = f'{self._inference_client.base_domain}/status/{self._inference_client.deployment_name}'
        response = self._inference_client._session.get(
            url, headers={self.INFERENCE_ID_HEADER: self.id, **self._inference_client._global_headers})

        response_json = response.json()
        # Cache the server-reported status for later inspection.
        self._status = response_json['status']

        return response_json

    def result(self) -> Dict[str, Any]:
        """Get the results of the async inference execution.

        Returns:
            Dict[str, Any]: The parsed JSON results of the inference
                execution.
        """
        url = f'{self._inference_client.base_domain}/results/{self._inference_client.deployment_name}'
        response = self._inference_client._session.get(
            url, headers={self.INFERENCE_ID_HEADER: self.id, **self._inference_client._global_headers})

        # Bug fix: this method is annotated (and documented) to return a
        # dict, but previously returned the raw requests.Response object.
        return response.json()

    # alias for result (kept for callers written against the old
    # get_results/output naming)
    output = result

examples/containers/container_deployments_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from datacrunch import DataCrunchClient
1111
from datacrunch.exceptions import APIException
12-
from datacrunch.containers.containers import (
12+
from datacrunch.containers import (
1313
Container,
1414
ComputeResource,
1515
EnvVar,

examples/containers/sglang_deployment_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from datetime import datetime
1313
from datacrunch import DataCrunchClient
1414
from datacrunch.exceptions import APIException
15-
from datacrunch.containers.containers import (
15+
from datacrunch.containers import (
1616
Container,
1717
ComputeResource,
1818
ScalingOptions,

examples/containers/update_deployment_scaling_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from datacrunch import DataCrunchClient
99
from datacrunch.exceptions import APIException
10-
from datacrunch.containers.containers import (
10+
from datacrunch.containers import (
1111
ScalingOptions,
1212
ScalingPolicy,
1313
ScalingTriggers,

0 commit comments

Comments
 (0)