@@ -86,29 +86,6 @@ def stream(self, chunk_size: int = 512, as_text: bool = True) -> Generator[Any,
8686 yield chunk
8787
8888
89- @dataclass_json (undefined = Undefined .EXCLUDE )
90- @dataclass
91- class AsyncInferenceExecution :
92- _client : 'InferenceClient'
93- id : str
94- status : str # TODO: add a status enum
95-
96- # TODO: Implement when the status endpoint is done
97- def status (self ) -> str :
98- # Call the status endpoint and update the status when
99- return self .status
100-
101- # TODO: Implement when the cancel inference execution endpoint is done
102- # def cancel(self) -> None:
103- # pass
104-
105- # TODO: Implement when the results endpoint is done
106- def get_results (self ) -> Dict [str , Any ]:
107- pass
108- # alias for get_results
109- output = get_results
110-
111-
11289class InferenceClient :
11390 def __init__ (self , inference_key : str , endpoint_base_url : str , timeout_seconds : int = 60 * 5 ) -> None :
11491 """
@@ -131,6 +108,10 @@ def __init__(self, inference_key: str, endpoint_base_url: str, timeout_seconds:
131108
132109 self .inference_key = inference_key
133110 self .endpoint_base_url = endpoint_base_url .rstrip ('/' )
111+ self .base_domain = self .endpoint_base_url [:self .endpoint_base_url .rindex (
112+ '/' )]
113+ self .deployment_name = self .endpoint_base_url [self .endpoint_base_url .rindex (
114+ '/' )+ 1 :]
134115 self .timeout_seconds = timeout_seconds
135116 self ._session = requests .Session ()
136117 self ._global_headers = {
@@ -246,10 +227,17 @@ def run_sync(self, data: Dict[str, Any], path: str = "", timeout_seconds: int =
246227 _original_response = response
247228 )
248229
249- def run (self , data : Dict [str , Any ], path : str = "" , timeout_seconds : int = 60 * 5 , headers : Optional [Dict [str , str ]] = None , http_method : str = "POST" ):
250- # Add the "Prefer: respond-async" header to the request, to indicate that the request is async
230+ def run (self , data : Dict [str , Any ], path : str = "" , timeout_seconds : int = 60 * 5 , headers : Optional [Dict [str , str ]] = None , http_method : str = "POST" , no_response : bool = False ):
231+ # Add relevant headers to the request, to indicate that the request is async
251232 headers = headers or {}
252- headers ['Prefer' ] = 'respond-async'
233+ if no_response :
234+ # If no_response is True, use the "Prefer: respond-async-proxy" header to run async and don't wait for the response
235+ headers ['Prefer' ] = 'respond-async-proxy'
236+ self ._make_request (
237+ http_method , path , json = data , timeout_seconds = timeout_seconds , headers = headers )
238+ return
239+ # Add the "Prefer: async-inference" header to the request, to run async and wait for the response
240+ headers ['Prefer' ] = 'async-inference'
253241
254242 response = self ._make_request (
255243 http_method , path , json = data , timeout_seconds = timeout_seconds , headers = headers )
@@ -297,3 +285,41 @@ def health(self, healthcheck_path: str = "/health") -> requests.Response:
297285 return self .get (healthcheck_path )
298286 except InferenceClientError as e :
299287 raise InferenceClientError (f"Health check failed: { str (e )} " )
288+
289+
290+ @dataclass_json (undefined = Undefined .EXCLUDE )
291+ @dataclass
292+ class AsyncInferenceExecution :
293+ _inference_client : 'InferenceClient'
294+ id : str
295+ _status : str # TODO: add a status enum?
296+ INFERENCE_ID_HEADER = 'X-Inference-Id'
297+
298+ def status (self ) -> Dict [str , Any ]:
299+ """Get the current status of the async inference execution.
300+
301+ Returns:
302+ Dict[str, Any]: The status response containing the execution status and other metadata
303+ """
304+ url = f'{ self ._inference_client .base_domain } /status/{ self ._inference_client .deployment_name } '
305+ response = self ._inference_client ._session .get (
306+ url , headers = {self .INFERENCE_ID_HEADER : self .id , ** self ._inference_client ._global_headers })
307+
308+ response_json = response .json ()
309+ self ._status = response_json ['status' ]
310+
311+ return response_json
312+
313+ def result (self ) -> Dict [str , Any ]:
314+ """Get the results of the async inference execution.
315+
316+ Returns:
317+ Dict[str, Any]: The results of the inference execution
318+ """
319+ url = f'{ self ._inference_client .base_domain } /results/{ self ._inference_client .deployment_name } '
320+ response = self ._inference_client ._session .get (
321+ url , headers = {self .INFERENCE_ID_HEADER : self .id , ** self ._inference_client ._global_headers })
322+
323+ return response
324+ # alias for get_results
325+ output = result