22from dataclasses_json import dataclass_json , Undefined # type: ignore
33import requests
44from requests .structures import CaseInsensitiveDict
5- from typing import Optional , Dict , Any , Union
5+ from typing import Optional , Dict , Any , Union , Generator
66from urllib .parse import urlparse
77
88
@@ -14,10 +14,76 @@ class InferenceClientError(Exception):
@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class InferenceResponse:
    """Wrapper around a raw ``requests.Response`` for an inference call.

    Exposes the response metadata as dataclass fields and defers body
    consumption to :meth:`output` (buffered) or :meth:`stream` (chunked),
    since a streaming body cannot be parsed eagerly as JSON.
    """

    headers: CaseInsensitiveDict[str]
    status_code: int
    status_text: str
    # The underlying requests.Response; kept so the body can be read lazily.
    _original_response: requests.Response
    # Set by the caller when the request was issued with stream=True.
    _stream: bool = False

    def _is_stream_response(self, headers: CaseInsensitiveDict[str]) -> bool:
        """Check if the response headers indicate a streaming response.

        Args:
            headers: The response headers to check

        Returns:
            bool: True if the response is likely a stream, False otherwise
        """
        # Content-Type may carry parameters (e.g. "text/event-stream; charset=utf-8"),
        # so match on the media-type prefix rather than strict equality.
        content_type = headers.get('Content-Type', '').lower()
        streaming_media_types = (
            'text/event-stream',        # Server-Sent Events
            'application/x-ndjson',     # newline-delimited JSON
            'application/stream+json',  # stream JSON
        )
        is_stream_content_type = content_type.startswith(streaming_media_types)

        # Standard chunked transfer encoding
        is_chunked_transfer = headers.get(
            'Transfer-Encoding', '').lower() == 'chunked'

        # No Content-Length with keep-alive often suggests streaming (though not definitive)
        is_keep_alive = headers.get('Connection', '').lower() == 'keep-alive'
        has_no_content_length = 'Content-Length' not in headers
        is_keep_alive_and_no_content_length = is_keep_alive and has_no_content_length

        return (self._stream
                or is_chunked_transfer
                or is_stream_content_type
                or is_keep_alive_and_no_content_length)

    def output(self, is_text: bool = False) -> Any:
        """Return the buffered response body.

        Args:
            is_text: If True, return the raw text body; otherwise parse as JSON.

        Raises:
            InferenceClientError: if the body cannot be parsed, with a hint to
                use :meth:`stream` when the headers suggest a streaming response.
        """
        try:
            if is_text:
                return self._original_response.text
            return self._original_response.json()
        except Exception as e:
            # If the response is a stream (check headers), raise a relevant error.
            if self._is_stream_response(self._original_response.headers):
                raise InferenceClientError(
                    "Response might be a stream, use the stream method instead") from e
            raise InferenceClientError(
                f"Failed to parse response as JSON: {str(e)}") from e

    def stream(self, chunk_size: int = 512, as_text: bool = True) -> Generator[Any, None, None]:
        """Stream the response content.

        Args:
            chunk_size: Size of chunks to stream, in bytes
            as_text: If True, stream as text using iter_lines. If False, stream as binary using iter_content.

        Returns:
            Generator yielding chunks of the response
        """
        if as_text:
            for chunk in self._original_response.iter_lines(chunk_size=chunk_size):
                # iter_lines yields b'' for keep-alive newlines; skip empties.
                if chunk:
                    yield chunk
        else:
            for chunk in self._original_response.iter_content(chunk_size=chunk_size):
                if chunk:
                    yield chunk
2187
2288
2389@dataclass_json (undefined = Undefined .EXCLUDE )
@@ -169,24 +235,24 @@ def _make_request(self, method: str, path: str, **kwargs) -> requests.Response:
169235 except requests .exceptions .RequestException as e :
170236 raise InferenceClientError (f"Request to { path } failed: { str (e )} " )
171237
172- def run_sync (self , data : Dict [str , Any ], path : str = "" , timeout_seconds : int = 60 * 5 , headers : Optional [Dict [str , str ]] = None ):
173- response = self .post (
174- path , json = data , timeout_seconds = timeout_seconds , headers = headers )
238+ def run_sync (self , data : Dict [str , Any ], path : str = "" , timeout_seconds : int = 60 * 5 , headers : Optional [Dict [str , str ]] = None , http_method : str = "POST" , stream : bool = False ):
239+ response = self ._make_request (
240+ http_method , path , json = data , timeout_seconds = timeout_seconds , headers = headers , stream = stream )
175241
176242 return InferenceResponse (
177- body = response .json (),
178243 headers = response .headers ,
179244 status_code = response .status_code ,
180- status_text = response .reason
245+ status_text = response .reason ,
246+ _original_response = response
181247 )
182248
183- def run (self , data : Dict [str , Any ], path : str = "" , timeout_seconds : int = 60 * 5 , headers : Optional [Dict [str , str ]] = None ):
249+ def run (self , data : Dict [str , Any ], path : str = "" , timeout_seconds : int = 60 * 5 , headers : Optional [Dict [str , str ]] = None , http_method : str = "POST" ):
184250 # Add the "Prefer: respond-async" header to the request, to indicate that the request is async
185251 headers = headers or {}
186252 headers ['Prefer' ] = 'respond-async'
187253
188- response = self .post (
189- path , json = data , timeout_seconds = timeout_seconds , headers = headers )
254+ response = self ._make_request (
255+ http_method , path , json = data , timeout_seconds = timeout_seconds , headers = headers )
190256
191257 # TODO: this response format isn't final
192258 execution_id = response .json ()['id' ]
0 commit comments