Skip to content

Commit 68275a2

Browse files
committed
Created InferenceClient class, use that in Deployment
1 parent d92163c commit 68275a2

File tree

3 files changed

+229
-32
lines changed

3 files changed

+229
-32
lines changed

datacrunch/InferenceClient/__init__.py

Whitespace-only changes.
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
from dataclasses import dataclass
2+
from dataclasses_json import dataclass_json, Undefined # type: ignore
3+
import requests
4+
from requests.structures import CaseInsensitiveDict
5+
from typing import Optional, Dict, Any, Union
6+
from urllib.parse import urlparse
7+
8+
9+
class InferenceClientError(Exception):
    """Raised for all InferenceClient failures.

    Base (and currently only) exception type for the inference client,
    so callers can handle every client error with a single catch.
    """
12+
13+
14+
@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class InferenceResponse:
    """Result of an inference request.

    Built by ``InferenceClient.run_sync`` from a ``requests.Response``:
    the decoded JSON body plus the response metadata. Unknown fields are
    excluded during (de)serialization via ``Undefined.EXCLUDE``.
    """
    # Decoded JSON payload of the response (response.json()).
    body: Any
    # Response headers; requests exposes them as a case-insensitive mapping.
    headers: CaseInsensitiveDict[str]
    # Numeric HTTP status code, e.g. 200.
    status_code: int
    # HTTP reason phrase, e.g. "OK" (response.reason).
    status_text: str
21+
22+
23+
class InferenceClient:
    """HTTP client for a model inference endpoint.

    Maintains a persistent ``requests.Session`` authenticated with a
    bearer token, supports global headers applied to every request, and
    normalizes all transport/HTTP failures into ``InferenceClientError``.
    Usable as a context manager so the session is closed deterministically.
    """

    def __init__(self, inference_key: str, endpoint_base_url: str, timeout_seconds: int = 300) -> None:
        """
        Initialize the InferenceClient.

        Args:
            inference_key: The authentication key for the API
            endpoint_base_url: The base URL for the API
            timeout_seconds: Default request timeout in seconds, used
                whenever a call does not supply its own timeout

        Raises:
            InferenceClientError: If the parameters are invalid
        """
        if not inference_key:
            raise InferenceClientError("inference_key cannot be empty")

        parsed_url = urlparse(endpoint_base_url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise InferenceClientError("endpoint_base_url must be a valid URL")

        self.inference_key = inference_key
        # Strip trailing slashes so _build_url can join path segments cleanly.
        self.endpoint_base_url = endpoint_base_url.rstrip('/')
        self.timeout_seconds = timeout_seconds
        self._session = requests.Session()
        self._global_headers = {
            'Authorization': f'Bearer {inference_key}',
            'Content-Type': 'application/json'
        }

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Release the pooled connections held by the session.
        self._session.close()

    @property
    def global_headers(self) -> Dict[str, str]:
        """
        Get the current global headers that will be used for all requests.

        Returns:
            A copy of the current global headers; mutating the returned
            dict does not affect the client
        """
        return self._global_headers.copy()

    def set_global_header(self, key: str, value: str) -> None:
        """
        Set or update a global header that will be used for all requests.

        Args:
            key: Header name
            value: Header value
        """
        self._global_headers[key] = value

    def set_global_headers(self, headers: Dict[str, str]) -> None:
        """
        Set multiple global headers at once that will be used for all requests.

        Args:
            headers: Dictionary of headers to set globally
        """
        self._global_headers.update(headers)

    def remove_global_header(self, key: str) -> None:
        """
        Remove a global header; removing a missing header is a no-op.

        Args:
            key: Header name to remove from global headers
        """
        # pop with a default replaces the check-then-delete pattern and
        # avoids a race-prone double lookup.
        self._global_headers.pop(key, None)

    def _build_url(self, path: str) -> str:
        """Construct the full URL by joining the base URL with the path."""
        return f"{self.endpoint_base_url}/{path.lstrip('/')}"

    def _build_request_headers(self, request_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
        """
        Build the final headers by merging global headers with request-specific headers.

        Request-specific headers override global ones on key collision.

        Args:
            request_headers: Optional headers specific to this request

        Returns:
            Merged headers dictionary
        """
        headers = self._global_headers.copy()
        if request_headers:
            headers.update(request_headers)
        return headers

    def _make_request(self, method: str, path: str, **kwargs) -> requests.Response:
        """
        Make an HTTP request with error handling.

        Args:
            method: HTTP method to use
            path: API endpoint path
            **kwargs: Additional arguments to pass to the request;
                'timeout_seconds' and 'headers' are consumed here

        Returns:
            Response object from the request

        Raises:
            InferenceClientError: If the request fails, times out, or
                returns an HTTP error status
        """
        # BUGFIX: the convenience methods (get/post/...) forward
        # timeout_seconds=None when the caller gave no explicit timeout.
        # Popping with self.timeout_seconds as the default then yields
        # None and silently DISABLES the timeout. Fall back explicitly.
        timeout = kwargs.pop('timeout_seconds', None)
        if timeout is None:
            timeout = self.timeout_seconds
        try:
            response = self._session.request(
                method=method,
                url=self._build_url(path),
                headers=self._build_request_headers(
                    kwargs.pop('headers', None)),
                timeout=timeout,
                **kwargs
            )
            # Normalize 4xx/5xx into InferenceClientError via the
            # RequestException handler below.
            response.raise_for_status()
            return response
        except requests.exceptions.Timeout as e:
            # Chain the cause so the original traceback is preserved.
            raise InferenceClientError(
                f"Request to {path} timed out after {timeout} seconds") from e
        except requests.exceptions.RequestException as e:
            raise InferenceClientError(
                f"Request to {path} failed: {str(e)}") from e

    def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None) -> InferenceResponse:
        """
        POST *data* as JSON and wrap the reply in an InferenceResponse.

        Args:
            data: JSON-serializable request payload
            path: Endpoint path relative to the base URL
            timeout_seconds: Timeout for this request
            headers: Optional extra headers for this request

        Returns:
            InferenceResponse holding the decoded JSON body, headers,
            status code and reason phrase

        Raises:
            InferenceClientError: If the request fails
        """
        response = self.post(
            path, json=data, timeout_seconds=timeout_seconds, headers=headers)

        return InferenceResponse(
            body=response.json(),
            headers=response.headers,
            status_code=response.status_code,
            status_text=response.reason
        )

    def get(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
        """Send a GET request to *path*."""
        return self._make_request('GET', path, params=params, headers=headers, timeout_seconds=timeout_seconds)

    def post(self, path: str, json: Optional[Dict[str, Any]] = None, data: Optional[Union[str, Dict[str, Any]]] = None,
             params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
        """Send a POST request to *path* with an optional JSON or form body."""
        return self._make_request('POST', path, json=json, data=data, params=params, headers=headers, timeout_seconds=timeout_seconds)

    def put(self, path: str, json: Optional[Dict[str, Any]] = None, data: Optional[Union[str, Dict[str, Any]]] = None,
            params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
        """Send a PUT request to *path* with an optional JSON or form body."""
        return self._make_request('PUT', path, json=json, data=data, params=params, headers=headers, timeout_seconds=timeout_seconds)

    def delete(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
        """Send a DELETE request to *path*."""
        return self._make_request('DELETE', path, params=params, headers=headers, timeout_seconds=timeout_seconds)

    def patch(self, path: str, json: Optional[Dict[str, Any]] = None, data: Optional[Union[str, Dict[str, Any]]] = None,
              params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
        """Send a PATCH request to *path* with an optional JSON or form body."""
        return self._make_request('PATCH', path, json=json, data=data, params=params, headers=headers, timeout_seconds=timeout_seconds)

    def head(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
        """Send a HEAD request to *path*."""
        return self._make_request('HEAD', path, params=params, headers=headers, timeout_seconds=timeout_seconds)

    def options(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
        """Send an OPTIONS request to *path*."""
        return self._make_request('OPTIONS', path, params=params, headers=headers, timeout_seconds=timeout_seconds)

    def health(self) -> dict:
        """
        Check the health status of the API.

        Returns:
            dict: Health status information

        Raises:
            InferenceClientError: If the health check fails
        """
        try:
            response = self.get('/health')
            return response.json()
        except InferenceClientError as e:
            # Re-wrap with context, chaining the cause for debuggability.
            raise InferenceClientError(f"Health check failed: {str(e)}") from e

datacrunch/containers/containers.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import requests
2-
from requests.structures import CaseInsensitiveDict
32
from dataclasses import dataclass
43
from dataclasses_json import dataclass_json, Undefined # type: ignore
54
from typing import List, Optional, Dict, Any
65
from enum import Enum
76

87
from datacrunch.http_client.http_client import HTTPClient
8+
from datacrunch.InferenceClient.inference_client import InferenceClient, InferenceResponse
99

1010

1111
# API endpoints
@@ -238,7 +238,6 @@ class Deployment:
238238
:param endpoint_base_url: Optional base URL for the deployment endpoint
239239
:param scaling: Optional scaling configuration
240240
:param created_at: Optional timestamp when the deployment was created
241-
:param inference_key: Optional inference key for the deployment
242241
"""
243242
name: str
244243
container_registry_settings: ContainerRegistrySettings
@@ -249,13 +248,13 @@ class Deployment:
249248
scaling: Optional[ScalingOptions] = None
250249
created_at: Optional[str] = None
251250

252-
_inference_key: Optional[str] = None
251+
_inference_client: Optional[InferenceClient] = None
253252

254253
def __str__(self):
255254
"""String representation of the deployment, excluding sensitive information."""
256-
# Get all attributes except _inference_key
255+
# Get all attributes except _inference_client
257256
attrs = {k: v for k, v in self.__dict__.items() if k !=
258-
'_inference_key'}
257+
'_inference_client'}
259258
# Format each attribute
260259
attr_strs = [f"{k}={repr(v)}" for k, v in attrs.items()]
261260
return f"Deployment({', '.join(attr_strs)})"
@@ -265,38 +264,47 @@ def __repr__(self):
265264
return self.__str__()
266265

267266
@classmethod
268-
def from_dict_with_inference_key(cls, data: Dict[str, Any], inference_key: str = None, **kwargs) -> 'Deployment':
267+
def from_dict_with_inference_key(cls, data: Dict[str, Any], inference_key: str = None) -> 'Deployment':
269268
"""Create a Deployment instance from a dictionary with an inference key.
270269
271270
:param data: Dictionary containing deployment data
272271
:param inference_key: inference key to set on the deployment
273-
:param **kwargs: Additional arguments to pass to from_dict
274272
:return: Deployment instance
275273
"""
276274
deployment = Deployment.from_dict(data, infer_missing=True)
277-
deployment._inference_key = inference_key
275+
if inference_key and deployment.endpoint_base_url:
276+
deployment._inference_client = InferenceClient(
277+
inference_key=inference_key,
278+
endpoint_base_url=deployment.endpoint_base_url
279+
)
278280
return deployment
279281

280-
def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5):
281-
if self._inference_key is None:
282-
# TODO: do something better
283-
raise ValueError("Inference key is not set")
282+
def set_inference_client(self, inference_key: str) -> None:
283+
"""Set the inference client for this deployment.
284284
285-
response = requests.post(
286-
url=f"{self.endpoint_base_url}{path}",
287-
json=data,
288-
headers={"Authorization": f"Bearer {self._inference_key}"},
289-
timeout=timeout_seconds
285+
:param inference_key: The inference key to use for authentication
286+
:type inference_key: str
287+
:raises ValueError: If endpoint_base_url is not set
288+
"""
289+
if self.endpoint_base_url is None:
290+
raise ValueError(
291+
"Endpoint base URL must be set to use inference client")
292+
self._inference_client = InferenceClient(
293+
inference_key=inference_key,
294+
endpoint_base_url=self.endpoint_base_url
290295
)
291296

292-
return InferenceResponse(
293-
body=response.json(),
294-
headers=response.headers,
295-
status_code=response.status_code,
296-
status_text=response.reason
297-
)
297+
def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5) -> InferenceResponse:
298+
if self._inference_client is None:
299+
if self.endpoint_base_url is None:
300+
raise ValueError(
301+
"Endpoint base URL must be set to use run_sync")
302+
raise ValueError(
303+
"Inference client not initialized. Use from_dict_with_inference_key or set_inference_client to initialize inference capabilities.")
304+
return self._inference_client.run_sync(data, path, timeout_seconds)
298305

299306
def health(self):
307+
# TODO: use inference client?
300308
healthcheck_path = "health"
301309
if self.containers and self.containers[0].healthcheck and self.containers[0].healthcheck.path:
302310
healthcheck_path = self.containers[0].healthcheck.path.lstrip('/')
@@ -310,15 +318,6 @@ def health(self):
310318
healthcheck = health
311319

312320

313-
@dataclass_json(undefined=Undefined.EXCLUDE)
314-
@dataclass
315-
class InferenceResponse:
316-
body: Any
317-
headers: CaseInsensitiveDict[str]
318-
status_code: int
319-
status_text: str
320-
321-
322321
@dataclass_json
323322
@dataclass
324323
class ReplicaInfo:

0 commit comments

Comments
 (0)