Skip to content

Commit d73f897

Browse files
committed
fix
1 parent 876be7a commit d73f897

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .inference_client import InferenceClient, InferenceResponse
2+
3+
__all__ = ['InferenceClient', 'InferenceResponse']

datacrunch/InferenceClient/inference_client.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,24 @@ class InferenceResponse:
2323
@dataclass_json(undefined=Undefined.EXCLUDE)
2424
@dataclass
2525
class AsyncInferenceExecution:
26-
id: str
27-
status: str
28-
2926
_client: 'InferenceClient'
27+
id: str
28+
status: str # TODO: add a status enum
3029

30+
# TODO: Implement when the status endpoint is done
3131
def status(self) -> str:
32-
# TODO: Call the status endpoint and update the status
32+
# Call the status endpoint and update the status when
3333
return self.status
3434

35+
# TODO: Implement when the cancel inference execution endpoint is done
3536
# def cancel(self) -> None:
36-
# # TODO: Call the cancel inference executionendpoint
3737
# pass
3838

39+
# TODO: Implement when the results endpoint is done
3940
def get_results(self) -> Dict[str, Any]:
40-
# TODO: Call the results endpoint
4141
pass
42+
# alias for get_results
43+
output = get_results
4244

4345

4446
class InferenceClient:
@@ -186,12 +188,10 @@ def run(self, data: Dict[str, Any], path: str = "", timeout_seconds: int = 60 *
186188
response = self.post(
187189
path, json=data, timeout_seconds=timeout_seconds, headers=headers)
188190

189-
# TODO: create an async response class:
190-
# TODO: add a method to check the status of the async request
191-
# TODO: add a method to cancel the async request
192-
# TODO: add a method to get the results of the async request
191+
# TODO: this response format isn't final
192+
execution_id = response.json()['id']
193193

194-
return '837cdf50-6cf1-44b0-884e-ed115e700480'
194+
return AsyncInferenceExecution(self, execution_id)
195195

196196
def get(self, path: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, timeout_seconds: Optional[int] = None) -> requests.Response:
197197
return self._make_request('GET', path, params=params, headers=headers, timeout_seconds=timeout_seconds)

datacrunch/containers/containers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from enum import Enum
1111

1212
from datacrunch.http_client.http_client import HTTPClient
13-
from datacrunch.InferenceClient.inference_client import InferenceClient, InferenceResponse
13+
from datacrunch.InferenceClient import InferenceClient, InferenceResponse
1414

1515

1616
# API endpoints

0 commit comments

Comments
 (0)