@@ -121,8 +121,19 @@ def send(
121121 self ,
122122 model : str ,
123123 input : str | InputObject | dict ,
124- ) -> SendResponse :
125- """Send a completion request to the Edgee AI Gateway."""
124+ stream : bool = False ,
125+ ):
126+ """Send a completion request to the Edgee AI Gateway.
127+
128+ Args:
129+ model: The model to use for completion
130+ input: The input (string, dict, or InputObject)
131+ stream: If True, returns a generator yielding StreamChunk objects.
132+ If False, returns a SendResponse object.
133+
134+ Returns:
135+ SendResponse if stream=False, or a generator yielding StreamChunk objects if stream=True.
136+ """
126137
127138 if isinstance (input , str ):
128139 messages = [{"role" : "user" , "content" : input }]
@@ -138,6 +149,8 @@ def send(
138149 tool_choice = input .get ("tool_choice" )
139150
140151 body : dict = {"model" : model , "messages" : messages }
152+ if stream :
153+ body ["stream" ] = True
141154 if tools :
142155 body ["tools" ] = tools
143156 if tool_choice :
@@ -153,6 +166,13 @@ def send(
153166 method = "POST" ,
154167 )
155168
169+ if stream :
170+ return self ._handle_streaming_response (request )
171+ else :
172+ return self ._handle_non_streaming_response (request )
173+
174+ def _handle_non_streaming_response (self , request : Request ) -> SendResponse :
175+ """Handle non-streaming response."""
156176 try :
157177 with urlopen (request ) as response :
158178 data = json .loads (response .read ().decode ("utf-8" ))
@@ -179,49 +199,11 @@ def send(
179199
180200 return SendResponse (choices = choices , usage = usage )
181201
182- def stream (
183- self ,
184- model : str ,
185- input : str | InputObject | dict ,
186- ):
187- """Stream a completion request from the Edgee AI Gateway.
188-
189- Yields StreamChunk objects as they arrive from the API.
190- """
191-
192- if isinstance (input , str ):
193- messages = [{"role" : "user" , "content" : input }]
194- tools = None
195- tool_choice = None
196- elif isinstance (input , InputObject ):
197- messages = input .messages
198- tools = input .tools
199- tool_choice = input .tool_choice
200- else :
201- messages = input .get ("messages" , [])
202- tools = input .get ("tools" )
203- tool_choice = input .get ("tool_choice" )
204-
205- body : dict = {"model" : model , "messages" : messages , "stream" : True }
206- if tools :
207- body ["tools" ] = tools
208- if tool_choice :
209- body ["tool_choice" ] = tool_choice
210-
211- request = Request (
212- f"{ self .base_url } { API_ENDPOINT } " ,
213- data = json .dumps (body ).encode ("utf-8" ),
214- headers = {
215- "Content-Type" : "application/json" ,
216- "Authorization" : f"Bearer { self .api_key } " ,
217- },
218- method = "POST" ,
219- )
220-
202+ def _handle_streaming_response (self , request : Request ):
203+ """Handle streaming response, yielding StreamChunk objects."""
221204 try :
222205 with urlopen (request ) as response :
223206 # Read and parse SSE stream
224- buffer = ""
225207 for line in response :
226208 decoded_line = line .decode ("utf-8" )
227209
@@ -262,3 +244,15 @@ def stream(
262244 except HTTPError as e :
263245 error_body = e .read ().decode ("utf-8" )
264246 raise RuntimeError (f"API error { e .code } : { error_body } " ) from e
247+
248+ def stream (
249+ self ,
250+ model : str ,
251+ input : str | InputObject | dict ,
252+ ):
253+ """Stream a completion request from the Edgee AI Gateway.
254+
255+ Convenience method that calls send(stream=True).
256+ Yields StreamChunk objects as they arrive from the API.
257+ """
258+ return self .send (model = model , input = input , stream = True )
0 commit comments