@@ -20,8 +20,29 @@ class InferenceResponse:
2020 status_text : str
2121
2222
@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class AsyncInferenceExecution:
    """
    Handle for an inference request that was submitted asynchronously.

    The handle only stores the values it was created with; it does not yet
    call the API (see the TODO notes on the methods).
    """

    # Identifier of the asynchronous execution.
    id: str
    # Last known status. Kept under a private name: a field called ``status``
    # would be shadowed by the ``status()`` method below, which makes the
    # method object the field's default and breaks dataclass creation
    # (a field without a default may not follow one with a default).
    _status: str
    # Client intended for follow-up calls (status / results).
    _client: 'InferenceClient'

    def status(self) -> str:
        """
        Return the last known status of the execution.

        Returns:
            str: The stored status value
        """
        # TODO: Call the status endpoint and update the status
        return self._status

    # TODO: Add a cancel() method that calls the cancel inference execution endpoint

    def get_results(self) -> Dict[str, Any]:
        """
        Fetch the results of the execution.

        Raises:
            NotImplementedError: Always, until the results endpoint is wired in
        """
        # TODO: Call the results endpoint
        # Fail loudly instead of silently returning None from a method that
        # is annotated to return a dict.
        raise NotImplementedError("Fetching async inference results is not implemented yet")
42+
43+
2344class InferenceClient :
24- def __init__ (self , inference_key : str , endpoint_base_url : str , timeout_seconds : int = 300 ) -> None :
45+ def __init__ (self , inference_key : str , endpoint_base_url : str , timeout_seconds : int = 60 * 5 ) -> None :
2546 """
2647 Initialize the InferenceClient.
2748
@@ -157,6 +178,21 @@ def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int =
157178 status_text = response .reason
158179 )
159180
def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 * 5, headers: Optional[Dict[str, str]] = None):
    """
    Submit an inference request flagged as asynchronous.

    Sends ``data`` as the JSON body of a POST request to ``path`` with the
    ``Prefer: respond-async`` header added.

    Args:
        data: Payload sent as the JSON body
        path: Path handed to ``self.post``
        timeout_seconds: Timeout handed to ``self.post``
        headers: Optional extra request headers; the caller's dict is left untouched

    Returns:
        str: A hard-coded placeholder execution id. The server response is
            not inspected yet (see the TODO notes below).
    """
    # Work on a copy: assigning into the caller's dict would leak the
    # "Prefer" header into whatever else the caller does with it.
    request_headers = dict(headers) if headers else {}
    # "Prefer: respond-async" indicates that the request is async.
    request_headers['Prefer'] = 'respond-async'

    # The response is deliberately not kept until the TODOs below are done.
    self.post(
        path, json=data, timeout_seconds=timeout_seconds, headers=request_headers)

    # TODO: create an async response class:
    # TODO: add a method to check the status of the async request
    # TODO: add a method to cancel the async request
    # TODO: add a method to get the results of the async request

    return '837cdf50-6cf1-44b0-884e-ed115e700480'
195+
def get(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
    """
    Send a GET request through ``self._make_request``.

    Args:
        path: Request path
        params: Optional query parameters
        headers: Optional request headers
        timeout_seconds: Optional timeout override

    Returns:
        requests.Response: Whatever ``self._make_request`` returns
    """
    response = self._make_request(
        'GET',
        path,
        params=params,
        headers=headers,
        timeout_seconds=timeout_seconds,
    )
    return response
162198
@@ -181,18 +217,17 @@ def head(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Opti
def options(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
    """
    Send an OPTIONS request through ``self._make_request``.

    Args:
        path: Request path
        params: Optional query parameters
        headers: Optional request headers
        timeout_seconds: Optional timeout override

    Returns:
        requests.Response: Whatever ``self._make_request`` returns
    """
    response = self._make_request(
        'OPTIONS',
        path,
        params=params,
        headers=headers,
        timeout_seconds=timeout_seconds,
    )
    return response
183219
def health(self, healthcheck_path: str = "/health") -> requests.Response:
    """
    Check the health status of the API.

    Args:
        healthcheck_path: Path of the health endpoint to request

    Returns:
        requests.Response: The response from the health check

    Raises:
        InferenceClientError: If the health check fails
    """
    try:
        return self.get(healthcheck_path)
    except InferenceClientError as e:
        # Chain explicitly so the original failure is recorded as the cause.
        raise InferenceClientError(f"Health check failed: {str(e)}") from e
0 commit comments