22
33import json
44import os
5+ from collections .abc import Callable
56from dataclasses import dataclass
6- from typing import Any , Callable , Type
7+ from typing import Any
78from urllib .error import HTTPError
89from urllib .request import Request , urlopen
910
@@ -71,7 +72,7 @@ def get_weather(params: WeatherParams) -> dict:
7172 def __init__ (
7273 self ,
7374 name : str ,
74- schema : Type [BaseModel ],
75+ schema : type [BaseModel ],
7576 handler : Callable [..., Any ],
7677 description : str | None = None ,
7778 ):
@@ -109,7 +110,7 @@ def execute(self, args: dict) -> Any:
109110
110111def create_tool (
111112 name : str ,
112- schema : Type [BaseModel ],
113+ schema : type [BaseModel ],
113114 handler : Callable [..., Any ],
114115 description : str | None = None ,
115116) -> Tool :
@@ -205,6 +206,15 @@ def tool_calls(self) -> list | None:
205206 return None
206207
207208
209+ @dataclass
210+ class StreamToolCallDelta :
211+ """Partial tool call in a stream."""
212+ index : int
213+ id : str | None = None
214+ type : str | None = None
215+ function : dict | None = None # {"name": str, "arguments": str}
216+
217+
208218@dataclass
209219class StreamDelta :
210220 role : str | None = None
@@ -244,6 +254,47 @@ def finish_reason(self) -> str | None:
244254 return self .choices [0 ].finish_reason
245255 return None
246256
257+ @property
258+ def tool_call_deltas (self ) -> list [dict ] | None :
259+ """Get tool call deltas from the first choice."""
260+ if self .choices and self .choices [0 ].delta .tool_calls :
261+ return self .choices [0 ].delta .tool_calls
262+ return None
263+
264+
265+ # Stream events for tool-enabled streaming
266+ @dataclass
267+ class ChunkEvent :
268+ """A chunk of streamed content."""
269+ type : str = "chunk"
270+ chunk : StreamChunk = None
271+
272+
273+ @dataclass
274+ class ToolStartEvent :
275+ """Tool execution is starting."""
276+ type : str = "tool_start"
277+ tool_call : dict = None
278+
279+
280+ @dataclass
281+ class ToolResultEvent :
282+ """Tool execution completed."""
283+ type : str = "tool_result"
284+ tool_call_id : str = None
285+ tool_name : str = None
286+ result : Any = None
287+
288+
289+ @dataclass
290+ class IterationCompleteEvent :
291+ """One iteration of the tool loop completed."""
292+ type : str = "iteration_complete"
293+ iteration : int = 0
294+
295+
296+ StreamEvent = ChunkEvent | ToolStartEvent | ToolResultEvent | IterationCompleteEvent
297+
247298
248299@dataclass
249300class EdgeeConfig :
@@ -541,10 +592,170 @@ def stream(
541592 self ,
542593 model : str ,
543594 input : str | InputObject | dict ,
595+ tools : list [Tool ] | None = None ,
596+ max_tool_iterations : int = 10 ,
544597 ):
545598 """Stream a completion request from the Edgee AI Gateway.
546599
547- Convenience method that calls send(stream=True).
548- Yields StreamChunk objects as they arrive from the API.
600+ Args:
601+ model: The model to use for completion
602+ input: The input (string, dict, or InputObject)
603+ tools: Optional list of Tool instances for automatic execution (simple mode only)
604+ max_tool_iterations: Maximum number of tool execution iterations (default: 10)
605+
606+ Yields:
607+ StreamChunk objects if no tools provided.
608+ StreamEvent objects (ChunkEvent, ToolStartEvent, ToolResultEvent, IterationCompleteEvent)
609+ if tools are provided.
610+
611+ Example without tools:
612+ ```python
613+ for chunk in edgee.stream("gpt-4o", "Hello!"):
614+ print(chunk.text, end="")
615+ ```
616+
617+ Example with tools:
618+ ```python
619+ for event in edgee.stream("gpt-4o", "What's the weather?", tools=[weather_tool]):
620+ if event.type == "chunk":
621+ print(event.chunk.text, end="")
622+ elif event.type == "tool_result":
623+ print(f"Tool result: {event.result}")
624+ ```
625+ """
626+ # Simple mode with tools: use agentic streaming loop
627+ if isinstance (input , str ) and tools :
628+ return self ._stream_simple (model , input , tools , max_tool_iterations )
629+
630+ # Simple mode without tools or advanced mode: regular streaming
631+ if isinstance (input , str ):
632+ messages = [{"role" : "user" , "content" : input }]
633+ return self ._call_api (model , messages , stream = True )
634+
635+ # Advanced mode: full InputObject or dict
636+ if isinstance (input , InputObject ):
637+ messages = input .messages
638+ api_tools = input .tools
639+ tool_choice = input .tool_choice
640+ else :
641+ messages = input .get ("messages" , [])
642+ api_tools = input .get ("tools" )
643+ tool_choice = input .get ("tool_choice" )
644+
645+ return self ._call_api (
646+ model , messages , api_tools = api_tools , tool_choice = tool_choice , stream = True
647+ )
648+
649+ def _stream_simple (
650+ self ,
651+ model : str ,
652+ input : str ,
653+ tools : list [Tool ],
654+ max_iterations : int ,
655+ ):
656+ """Handle simple mode streaming with automatic tool execution.
657+
658+ Yields StreamEvent objects for chunks, tool starts, tool results, and iteration completion.
549659 """
550- return self .send (model = model , input = input , stream = True )
660+ messages : list [dict ] = [{"role" : "user" , "content" : input }]
661+ openai_tools = [t .to_dict () for t in tools ]
662+ tool_map = {t .name : t for t in tools }
663+
664+ for iteration in range (1 , max_iterations + 1 ):
665+ # Accumulate the full response from stream
666+ role : str | None = None
667+ content = ""
668+ tool_calls_accumulator : dict [int , dict ] = {}
669+
670+ # Stream the response
671+ for chunk in self ._call_api (model , messages , api_tools = openai_tools , stream = True ):
672+ # Yield the chunk as an event
673+ yield ChunkEvent (chunk = chunk )
674+
675+ # Accumulate role
676+ if chunk .role :
677+ role = chunk .role
678+
679+ # Accumulate content
680+ if chunk .text :
681+ content += chunk .text
682+
683+ # Accumulate tool calls from deltas
684+ tool_call_deltas = chunk .tool_call_deltas
685+ if tool_call_deltas :
686+ for delta in tool_call_deltas :
687+ idx = delta .get ("index" , 0 )
688+ if idx in tool_calls_accumulator :
689+ # Append to existing tool call
690+ existing = tool_calls_accumulator [idx ]
691+ if delta .get ("function" , {}).get ("arguments" ):
692+ existing ["function" ]["arguments" ] += delta ["function" ]["arguments" ]
693+ else :
694+ # Start new tool call
695+ tool_calls_accumulator [idx ] = {
696+ "id" : delta .get ("id" , "" ),
697+ "type" : "function" ,
698+ "function" : {
699+ "name" : delta .get ("function" , {}).get ("name" , "" ),
700+ "arguments" : delta .get ("function" , {}).get ("arguments" , "" ),
701+ },
702+ }
703+
704+ # Convert accumulated tool calls to list
705+ tool_calls = list (tool_calls_accumulator .values ())
706+
707+ # No tool calls? We're done
708+ if not tool_calls :
709+ return
710+
711+ # Add assistant's message (with tool_calls) to messages
712+ assistant_message : dict = {
713+ "role" : role or "assistant" ,
714+ "tool_calls" : tool_calls ,
715+ }
716+ if content :
717+ assistant_message ["content" ] = content
718+ messages .append (assistant_message )
719+
720+ # Execute each tool call and add results
721+ for tool_call in tool_calls :
722+ tool_name = tool_call ["function" ]["name" ]
723+ tool = tool_map .get (tool_name )
724+
725+ # Yield tool_start event
726+ yield ToolStartEvent (tool_call = tool_call )
727+
728+ if tool :
729+ try :
730+ raw_args = json .loads (tool_call ["function" ]["arguments" ])
731+ result = tool .execute (raw_args )
732+ except ValidationError as e :
733+ result = {"error" : f"Invalid arguments: { e } " }
734+ except json .JSONDecodeError as e :
735+ result = {"error" : f"Failed to parse arguments: { e } " }
736+ except Exception as e :
737+ result = {"error" : f"Tool execution failed: { e } " }
738+ else :
739+ result = {"error" : f"Unknown tool: { tool_name } " }
740+
741+ # Yield tool_result event
742+ yield ToolResultEvent (
743+ tool_call_id = tool_call ["id" ],
744+ tool_name = tool_name ,
745+ result = result ,
746+ )
747+
748+ # Add tool result to messages
749+ messages .append ({
750+ "role" : "tool" ,
751+ "tool_call_id" : tool_call ["id" ],
752+ "content" : result if isinstance (result , str ) else json .dumps (result ),
753+ })
754+
755+ # Yield iteration complete event
756+ yield IterationCompleteEvent (iteration = iteration )
757+
758+ # Loop continues - model will process tool results
759+
760+ # Max iterations reached
761+ raise RuntimeError (f"Max tool iterations ({ max_iterations } ) reached" )
0 commit comments