Skip to content

Commit 92a4633

Browse files
committed
improved docstrings, wip AsyncInferenceExecution
1 parent 68275a2 commit 92a4633

File tree

2 files changed

+428
-206
lines changed

2 files changed

+428
-206
lines changed

datacrunch/InferenceClient/inference_client.py

Lines changed: 40 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,29 @@ class InferenceResponse:
2020
status_text: str
2121

2222

23+
@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class AsyncInferenceExecution:
    """
    Handle for an inference request that was submitted asynchronously.

    Work in progress: the status/results endpoints are not wired up yet.

    The last known status is stored in ``_status`` rather than ``status``.
    A dataclass field and a method cannot share a name: the method object
    would become the field's default value, and the following non-default
    ``_client`` field would then make ``@dataclass`` raise ``TypeError``
    when the class is created.

    Attributes:
        id: Identifier of the inference execution.
        _status: Last known status of the execution.
        _client: The InferenceClient used to talk to the endpoint.
    """
    id: str
    # NOTE(review): if instances are built from API payloads via from_dict,
    # confirm which payload key should populate this field.
    _status: str

    _client: 'InferenceClient'

    def status(self) -> str:
        """
        Return the last known status of the execution.

        Returns:
            str: The stored status value
        """
        # TODO: Call the status endpoint and update the status
        return self._status

    # TODO: add cancel(), calling the cancel inference execution endpoint

    def get_results(self) -> Dict[str, Any]:
        """
        Fetch the results of the execution (not implemented yet).

        Raises:
            NotImplementedError: Always, until the results endpoint is wired up
        """
        # TODO: Call the results endpoint
        # Fail loudly instead of silently returning None, which would
        # contradict the declared Dict return type.
        raise NotImplementedError(
            "AsyncInferenceExecution.get_results is not implemented yet")
42+
43+
2344
class InferenceClient:
24-
def __init__(self, inference_key: str, endpoint_base_url: str, timeout_seconds: int = 300) -> None:
45+
def __init__(self, inference_key: str, endpoint_base_url: str, timeout_seconds: int = 60 * 5) -> None:
2546
"""
2647
Initialize the InferenceClient.
2748
@@ -157,6 +178,21 @@ def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int =
157178
status_text=response.reason
158179
)
159180

181+
def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None) -> str:
    """
    Submit an inference request asynchronously (work in progress).

    Sends a POST request carrying the "Prefer: respond-async" header, which
    indicates that the request is async.

    Args:
        data: JSON-serializable payload sent as the request body
        path: Path the request is posted to
        timeout_seconds: Timeout passed to the underlying POST request
        headers: Optional extra request headers; the caller's dict is not modified

    Returns:
        str: Currently a hard-coded placeholder execution id
    """
    # Build a fresh dict so the caller's headers mapping is never mutated.
    request_headers = {**(headers or {}), 'Prefer': 'respond-async'}

    self.post(
        path, json=data, timeout_seconds=timeout_seconds, headers=request_headers)

    # TODO: build an AsyncInferenceExecution from the POST response instead
    #       of discarding it
    # TODO: add a method to check the status of the async request
    # TODO: add a method to cancel the async request
    # TODO: add a method to get the results of the async request

    # Placeholder until the response is parsed.
    return '837cdf50-6cf1-44b0-884e-ed115e700480'
195+
160196
def get(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
    """
    Send a GET request by delegating to ``self._make_request``.

    Args:
        path: Request path, forwarded unchanged
        params: Optional query parameters, forwarded unchanged
        headers: Optional request headers, forwarded unchanged
        timeout_seconds: Optional timeout, forwarded unchanged

    Returns:
        requests.Response: Whatever ``self._make_request`` returns
    """
    request_options = {
        'params': params,
        'headers': headers,
        'timeout_seconds': timeout_seconds,
    }
    return self._make_request('GET', path, **request_options)
162198

@@ -181,18 +217,17 @@ def head(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Opti
181217
def options(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
    """
    Send an OPTIONS request by delegating to ``self._make_request``.

    Args:
        path: Request path, forwarded unchanged
        params: Optional query parameters, forwarded unchanged
        headers: Optional request headers, forwarded unchanged
        timeout_seconds: Optional timeout, forwarded unchanged

    Returns:
        requests.Response: Whatever ``self._make_request`` returns
    """
    request_options = {
        'params': params,
        'headers': headers,
        'timeout_seconds': timeout_seconds,
    }
    return self._make_request('OPTIONS', path, **request_options)
183219

184-
def health(self, healthcheck_path: str = "/health") -> requests.Response:
    """
    Check the health status of the API.

    Args:
        healthcheck_path: Path of the health check endpoint, passed to
            ``self.get``. Defaults to "/health".

    Returns:
        requests.Response: The response from the health check

    Raises:
        InferenceClientError: If the health check fails
    """
    try:
        return self.get(healthcheck_path)
    except InferenceClientError as e:
        # Chain the original error explicitly so it is recorded as the
        # direct cause of the re-raised exception.
        raise InferenceClientError(f"Health check failed: {str(e)}") from e

0 commit comments

Comments
 (0)