Skip to content

Commit bbc0eba

Browse files
feat: more example + stream with auto_tools
1 parent 5f3bf93 commit bbc0eba

13 files changed

Lines changed: 595 additions & 175 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ dist/
2525
.pytest_cache/
2626
.coverage
2727
htmlcov/
28+
29+
# Environment
30+
.env

edgee/__init__.py

Lines changed: 217 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import json
44
import os
5+
from collections.abc import Callable
56
from dataclasses import dataclass
6-
from typing import Any, Callable, Type
7+
from typing import Any
78
from urllib.error import HTTPError
89
from urllib.request import Request, urlopen
910

@@ -71,7 +72,7 @@ def get_weather(params: WeatherParams) -> dict:
7172
def __init__(
7273
self,
7374
name: str,
74-
schema: Type[BaseModel],
75+
schema: type[BaseModel],
7576
handler: Callable[..., Any],
7677
description: str | None = None,
7778
):
@@ -109,7 +110,7 @@ def execute(self, args: dict) -> Any:
109110

110111
def create_tool(
111112
name: str,
112-
schema: Type[BaseModel],
113+
schema: type[BaseModel],
113114
handler: Callable[..., Any],
114115
description: str | None = None,
115116
) -> Tool:
@@ -205,6 +206,15 @@ def tool_calls(self) -> list | None:
205206
return None
206207

207208

209+
@dataclass
210+
class StreamToolCallDelta:
211+
"""Partial tool call in a stream."""
212+
index: int
213+
id: str | None = None
214+
type: str | None = None
215+
function: dict | None = None # {"name": str, "arguments": str}
216+
217+
208218
@dataclass
209219
class StreamDelta:
210220
role: str | None = None
@@ -244,6 +254,47 @@ def finish_reason(self) -> str | None:
244254
return self.choices[0].finish_reason
245255
return None
246256

257+
@property
258+
def tool_call_deltas(self) -> list[dict] | None:
259+
"""Get tool call deltas from the first choice."""
260+
if self.choices and self.choices[0].delta.tool_calls:
261+
return self.choices[0].delta.tool_calls
262+
return None
263+
264+
265+
# Stream events for tool-enabled streaming
266+
@dataclass
267+
class ChunkEvent:
268+
"""A chunk of streamed content."""
269+
type: str = "chunk"
270+
chunk: StreamChunk = None
271+
272+
273+
@dataclass
274+
class ToolStartEvent:
275+
"""Tool execution is starting."""
276+
type: str = "tool_start"
277+
tool_call: dict = None
278+
279+
280+
@dataclass
281+
class ToolResultEvent:
282+
"""Tool execution completed."""
283+
type: str = "tool_result"
284+
tool_call_id: str = None
285+
tool_name: str = None
286+
result: Any = None
287+
288+
289+
@dataclass
290+
class IterationCompleteEvent:
291+
"""One iteration of the tool loop completed."""
292+
type: str = "iteration_complete"
293+
iteration: int = 0
294+
295+
296+
StreamEvent = ChunkEvent | ToolStartEvent | ToolResultEvent | IterationCompleteEvent
297+
247298

248299
@dataclass
249300
class EdgeeConfig:
@@ -541,10 +592,170 @@ def stream(
541592
self,
542593
model: str,
543594
input: str | InputObject | dict,
595+
tools: list[Tool] | None = None,
596+
max_tool_iterations: int = 10,
544597
):
545598
"""Stream a completion request from the Edgee AI Gateway.
546599
547-
Convenience method that calls send(stream=True).
548-
Yields StreamChunk objects as they arrive from the API.
600+
Args:
601+
model: The model to use for completion
602+
input: The input (string, dict, or InputObject)
603+
tools: Optional list of Tool instances for automatic execution (simple mode only)
604+
max_tool_iterations: Maximum number of tool execution iterations (default: 10)
605+
606+
Yields:
607+
StreamChunk objects if no tools provided.
608+
StreamEvent objects (ChunkEvent, ToolStartEvent, ToolResultEvent, IterationCompleteEvent)
609+
if tools are provided.
610+
611+
Example without tools:
612+
```python
613+
for chunk in edgee.stream("gpt-4o", "Hello!"):
614+
print(chunk.text, end="")
615+
```
616+
617+
Example with tools:
618+
```python
619+
for event in edgee.stream("gpt-4o", "What's the weather?", tools=[weather_tool]):
620+
if event.type == "chunk":
621+
print(event.chunk.text, end="")
622+
elif event.type == "tool_result":
623+
print(f"Tool result: {event.result}")
624+
```
625+
"""
626+
# Simple mode with tools: use agentic streaming loop
627+
if isinstance(input, str) and tools:
628+
return self._stream_simple(model, input, tools, max_tool_iterations)
629+
630+
# Simple mode without tools or advanced mode: regular streaming
631+
if isinstance(input, str):
632+
messages = [{"role": "user", "content": input}]
633+
return self._call_api(model, messages, stream=True)
634+
635+
# Advanced mode: full InputObject or dict
636+
if isinstance(input, InputObject):
637+
messages = input.messages
638+
api_tools = input.tools
639+
tool_choice = input.tool_choice
640+
else:
641+
messages = input.get("messages", [])
642+
api_tools = input.get("tools")
643+
tool_choice = input.get("tool_choice")
644+
645+
return self._call_api(
646+
model, messages, api_tools=api_tools, tool_choice=tool_choice, stream=True
647+
)
648+
649+
def _stream_simple(
650+
self,
651+
model: str,
652+
input: str,
653+
tools: list[Tool],
654+
max_iterations: int,
655+
):
656+
"""Handle simple mode streaming with automatic tool execution.
657+
658+
Yields StreamEvent objects for chunks, tool starts, tool results, and iteration completion.
549659
"""
550-
return self.send(model=model, input=input, stream=True)
660+
messages: list[dict] = [{"role": "user", "content": input}]
661+
openai_tools = [t.to_dict() for t in tools]
662+
tool_map = {t.name: t for t in tools}
663+
664+
for iteration in range(1, max_iterations + 1):
665+
# Accumulate the full response from stream
666+
role: str | None = None
667+
content = ""
668+
tool_calls_accumulator: dict[int, dict] = {}
669+
670+
# Stream the response
671+
for chunk in self._call_api(model, messages, api_tools=openai_tools, stream=True):
672+
# Yield the chunk as an event
673+
yield ChunkEvent(chunk=chunk)
674+
675+
# Accumulate role
676+
if chunk.role:
677+
role = chunk.role
678+
679+
# Accumulate content
680+
if chunk.text:
681+
content += chunk.text
682+
683+
# Accumulate tool calls from deltas
684+
tool_call_deltas = chunk.tool_call_deltas
685+
if tool_call_deltas:
686+
for delta in tool_call_deltas:
687+
idx = delta.get("index", 0)
688+
if idx in tool_calls_accumulator:
689+
# Append to existing tool call
690+
existing = tool_calls_accumulator[idx]
691+
if delta.get("function", {}).get("arguments"):
692+
existing["function"]["arguments"] += delta["function"]["arguments"]
693+
else:
694+
# Start new tool call
695+
tool_calls_accumulator[idx] = {
696+
"id": delta.get("id", ""),
697+
"type": "function",
698+
"function": {
699+
"name": delta.get("function", {}).get("name", ""),
700+
"arguments": delta.get("function", {}).get("arguments", ""),
701+
},
702+
}
703+
704+
# Convert accumulated tool calls to list
705+
tool_calls = list(tool_calls_accumulator.values())
706+
707+
# No tool calls? We're done
708+
if not tool_calls:
709+
return
710+
711+
# Add assistant's message (with tool_calls) to messages
712+
assistant_message: dict = {
713+
"role": role or "assistant",
714+
"tool_calls": tool_calls,
715+
}
716+
if content:
717+
assistant_message["content"] = content
718+
messages.append(assistant_message)
719+
720+
# Execute each tool call and add results
721+
for tool_call in tool_calls:
722+
tool_name = tool_call["function"]["name"]
723+
tool = tool_map.get(tool_name)
724+
725+
# Yield tool_start event
726+
yield ToolStartEvent(tool_call=tool_call)
727+
728+
if tool:
729+
try:
730+
raw_args = json.loads(tool_call["function"]["arguments"])
731+
result = tool.execute(raw_args)
732+
except ValidationError as e:
733+
result = {"error": f"Invalid arguments: {e}"}
734+
except json.JSONDecodeError as e:
735+
result = {"error": f"Failed to parse arguments: {e}"}
736+
except Exception as e:
737+
result = {"error": f"Tool execution failed: {e}"}
738+
else:
739+
result = {"error": f"Unknown tool: {tool_name}"}
740+
741+
# Yield tool_result event
742+
yield ToolResultEvent(
743+
tool_call_id=tool_call["id"],
744+
tool_name=tool_name,
745+
result=result,
746+
)
747+
748+
# Add tool result to messages
749+
messages.append({
750+
"role": "tool",
751+
"tool_call_id": tool_call["id"],
752+
"content": result if isinstance(result, str) else json.dumps(result),
753+
})
754+
755+
# Yield iteration complete event
756+
yield IterationCompleteEvent(iteration=iteration)
757+
758+
# Loop continues - model will process tool results
759+
760+
# Max iterations reached
761+
raise RuntimeError(f"Max tool iterations ({max_iterations}) reached")

0 commit comments

Comments
 (0)