@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class AsyncInferenceExecution:
    """Handle for an asynchronous inference request.

    Returned by ``InferenceClient.run`` for async executions. Holds the
    server-assigned execution id plus a reference back to the client so
    that follow-up calls (status polling, cancellation, result retrieval)
    can reuse the same client configuration.
    """

    _client: 'InferenceClient'    # client used for follow-up API calls
    id: str                       # server-assigned execution id
    # NOTE(review): the previous revision declared both a `status` field and
    # a `status()` method; the method object silently became the dataclass
    # field's default, so AsyncInferenceExecution(client, execution_id)
    # stored a function in `status` and the method was shadowed/uncallable.
    # An explicit None default keeps the two-argument constructor working.
    status: Optional[str] = None  # TODO: add a status enum

    # TODO: Implement when the status endpoint is done — call the status
    # endpoint, cache the result in `self.status`, and return it. (Omitted
    # for now: a method named `status` collides with the field above.)

    # TODO: Implement when the cancel inference execution endpoint is done
    # def cancel(self) -> None:
    #     pass

    # TODO: Implement when the results endpoint is done
    def get_results(self) -> Dict[str, Any]:
        """Fetch the execution's results. Not implemented yet."""
        # Raising is clearer than the previous bare `pass`, which quietly
        # returned None despite the Dict[str, Any] annotation.
        raise NotImplementedError("results endpoint is not available yet")

    # Alias so callers may use either spelling.
    output = get_results
4244
4345
4446class InferenceClient :
@@ -186,12 +188,10 @@ def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 *
186188 response = self .post (
187189 path , json = data , timeout_seconds = timeout_seconds , headers = headers )
188190
189- # TODO: create an async response class:
190- # TODO: add a method to check the status of the async request
191- # TODO: add a method to cancel the async request
192- # TODO: add a method to get the results of the async request
191+ # TODO: this response format isn't final
192+ execution_id = response .json ()['id' ]
193193
194- return '837cdf50-6cf1-44b0-884e-ed115e700480'
194+ return AsyncInferenceExecution ( self , execution_id )
195195
196196 def get (self , path : str , params : Optional [Dict [str , Any ]] = None , headers : Optional [Dict [str , str ]] = None , timeout_seconds : Optional [int ] = None ) -> requests .Response :
197197 return self ._make_request ('GET' , path , params = params , headers = headers , timeout_seconds = timeout_seconds )
0 commit comments