Skip to content

Commit cdf7cf2

Browse files
committed
fix dependencies
1 parent 9b20e12 commit cdf7cf2

4 files changed

Lines changed: 44 additions & 58 deletions

File tree

src/strands/experimental/bidirectional_streaming/agent/agent.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from ..models.bidirectional_model import BidirectionalModel
3434
from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent
3535

36-
3736
logger = logging.getLogger(__name__)
3837

3938
_DEFAULT_AGENT_NAME = "Strands Agents"
@@ -81,7 +80,7 @@ def caller(
8180
8281
Args:
8382
user_message_override: Optional custom message to record instead of default
84-
record_direct_tool_call: Whether to record direct tool calls in message history.
83+
record_direct_tool_call: Whether to record direct tool calls in message history.
8584
For bidirectional agents, this is always True to maintain conversation history.
8685
**kwargs: Keyword arguments to pass to the tool.
8786
@@ -186,12 +185,12 @@ def __init__(
186185
self.model = model
187186
self.system_prompt = system_prompt
188187
self.messages = messages or []
189-
188+
190189
# Agent identification
191190
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
192191
self.name = name or _DEFAULT_AGENT_NAME
193192
self.description = description
194-
193+
195194
# Tool execution configuration
196195
self.record_direct_tool_call = record_direct_tool_call
197196
self.load_tools_from_directory = load_tools_from_directory
@@ -207,25 +206,25 @@ def __init__(
207206

208207
# Initialize tool registry
209208
self.tool_registry = ToolRegistry()
210-
209+
211210
if tools is not None:
212211
self.tool_registry.process_tools(tools)
213-
212+
214213
self.tool_registry.initialize_tools(self.load_tools_from_directory)
215-
214+
216215
# Initialize tool watcher if directory loading is enabled
217216
if self.load_tools_from_directory:
218217
self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry)
219218

220219
# Initialize tool executor
221220
self.tool_executor = tool_executor or ConcurrentToolExecutor()
222-
221+
223222
# Initialize hooks system
224223
self.hooks = HookRegistry()
225224
if hooks:
226225
for hook in hooks:
227226
self.hooks.add_hook(hook)
228-
227+
229228
# Initialize other components
230229
self.event_loop_metrics = EventLoopMetrics()
231230
self.tool_caller = BidirectionalAgent.ToolCaller(self)

src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py

Lines changed: 34 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@
1515
import logging
1616
import traceback
1717
import uuid
18+
from typing import TYPE_CHECKING
1819

19-
from ....tools._validator import validate_and_prepare_tools
2020
from ....telemetry.metrics import Trace
21+
from ....tools._validator import validate_and_prepare_tools
2122
from ....types._events import ToolResultEvent, ToolStreamEvent
2223
from ....types.content import Message
2324
from ....types.tools import ToolResult, ToolUse
2425
from ..models.bidirectional_model import BidirectionalModelSession
2526

26-
27+
if TYPE_CHECKING:
28+
from ..agent import BidirectionalAgent
2729
logger = logging.getLogger(__name__)
2830

2931
# Session constants
@@ -60,7 +62,7 @@ def __init__(self, model_session: BidirectionalModelSession, agent: "Bidirection
6062
# Interruption handling (model-agnostic)
6163
self.interrupted = False
6264
self.interruption_lock = asyncio.Lock()
63-
65+
6466
# Tool execution tracking
6567
self.tool_count = 0
6668

@@ -265,7 +267,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None:
265267
# Basic validation - skip invalid events
266268
if not isinstance(provider_event, dict):
267269
continue
268-
270+
269271
strands_event = provider_event
270272

271273
# Handle interruption detection (provider converts raw patterns to interruptionDetected)
@@ -291,7 +293,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None:
291293
if strands_event.get("messageStop"):
292294
logger.debug("Message added to history")
293295
session.agent.messages.append(strands_event["messageStop"]["message"])
294-
296+
295297
# Handle user audio transcripts - add to message history
296298
if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user":
297299
user_transcript = strands_event["textOutput"]["text"]
@@ -311,7 +313,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None:
311313
"""Execute tools concurrently with interruption support.
312314
313315
Background task that manages tool execution without blocking model event
314-
processing or user interaction. Uses proper asyncio cancellation for
316+
processing or user interaction. Uses proper asyncio cancellation for
315317
interruption handling rather than manual state checks.
316318
317319
Args:
@@ -323,10 +325,10 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None:
323325
tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT)
324326
tool_name = tool_use.get("name")
325327
tool_id = tool_use.get("toolUseId")
326-
328+
327329
session.tool_count += 1
328330
print(f"\nTool #{session.tool_count}: {tool_name}")
329-
331+
330332
logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id)
331333

332334
task_id = str(uuid.uuid4())
@@ -372,110 +374,96 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None:
372374
logger.debug("Tool execution processor stopped")
373375

374376

375-
376-
377-
378377
async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None:
379378
"""Execute tool using the complete Strands tool execution system.
380-
379+
381380
Uses proper Strands ToolExecutor system with validation, error handling,
382381
and event streaming.
383-
382+
384383
Args:
385384
session: BidirectionalConnection for context.
386385
tool_use: Tool use event to execute.
387386
"""
388387
tool_name = tool_use.get("name")
389388
tool_id = tool_use.get("toolUseId")
390-
389+
391390
logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id)
392-
391+
393392
try:
394-
# Create message structure for validation
393+
# Create message structure for validation
395394
tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]}
396-
395+
397396
# Use Strands validation system
398397
tool_uses: list[ToolUse] = []
399398
tool_results: list[ToolResult] = []
400399
invalid_tool_use_ids: list[str] = []
401-
400+
402401
validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids)
403-
402+
404403
# Filter valid tools
405404
valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids]
406-
405+
407406
if not valid_tool_uses:
408407
logger.warning("No valid tools after validation: %s", tool_name)
409408
return
410-
409+
411410
# Create invocation state for tool execution
412411
invocation_state = {
413412
"agent": session.agent,
414413
"model": session.agent.model,
415414
"messages": session.agent.messages,
416415
"system_prompt": session.agent.system_prompt,
417416
}
418-
417+
419418
# Create cycle trace and span
420419
cycle_trace = Trace("Bidirectional Tool Execution")
421420
cycle_span = None
422-
421+
423422
tool_events = session.agent.tool_executor._execute(
424-
session.agent,
425-
valid_tool_uses,
426-
tool_results,
427-
cycle_trace,
428-
cycle_span,
429-
invocation_state
423+
session.agent, valid_tool_uses, tool_results, cycle_trace, cycle_span, invocation_state
430424
)
431-
425+
432426
# Process tool events and send results to provider
433427
async for tool_event in tool_events:
434428
if isinstance(tool_event, ToolResultEvent):
435429
tool_result = tool_event.tool_result
436430
tool_use_id = tool_result.get("toolUseId")
437-
431+
438432
# Send result through provider-specific session
439433
await session.model_session.send_tool_result(tool_use_id, tool_result)
440434
logger.debug("Tool result sent: %s", tool_use_id)
441-
435+
442436
# Handle streaming events if needed later
443437
elif isinstance(tool_event, ToolStreamEvent):
444438
logger.debug("Tool stream event: %s", tool_event)
445439
pass
446-
440+
447441
# Add tool result message to conversation history
448442
if tool_results:
449443
from ....hooks import MessageAddedEvent
450-
444+
451445
tool_result_message: Message = {
452446
"role": "user",
453447
"content": [{"toolResult": result} for result in tool_results],
454448
}
455-
449+
456450
session.agent.messages.append(tool_result_message)
457451
session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message))
458452
logger.debug("Tool result message added to history: %s", tool_name)
459-
453+
460454
logger.debug("Tool execution completed: %s", tool_name)
461-
455+
462456
except asyncio.CancelledError:
463457
logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id)
464458
raise
465459
except Exception as e:
466460
logger.error("Tool execution error: %s - %s", tool_name, str(e))
467-
468-
# Send error result
469-
error_result: ToolResult = {
470-
"toolUseId": tool_id,
471-
"status": "error",
472-
"content": [{"text": f"Error: {str(e)}"}]
473-
}
461+
462+
# Send error result
463+
error_result: ToolResult = {"toolUseId": tool_id, "status": "error", "content": [{"text": f"Error: {str(e)}"}]}
474464
try:
475465
await session.model_session.send_tool_result(tool_id, error_result)
476466
logger.debug("Error result sent: %s", tool_id)
477467
except Exception:
478468
logger.error("Failed to send error result: %s", tool_id)
479469
pass # Session might be closed
480-
481-

src/strands/experimental/bidirectional_streaming/models/novasonic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@
3535
BidirectionalConnectionStartEvent,
3636
InterruptionDetectedEvent,
3737
TextOutputEvent,
38-
UsageMetricsEvent
38+
UsageMetricsEvent,
3939
)
40-
4140
from .bidirectional_model import BidirectionalModel, BidirectionalModelSession
4241

4342
logger = logging.getLogger(__name__)

src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class BidirectionalConnectionEndEvent(TypedDict):
120120
connectionId: Optional[str]
121121
metadata: Optional[Dict[str, Any]]
122122

123+
123124
class UsageMetricsEvent(TypedDict):
124125
"""Token usage and performance tracking.
125126
@@ -162,4 +163,3 @@ class BidirectionalStreamEvent(StreamEvent, total=False):
162163
BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent]
163164
BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent]
164165
usageMetrics: Optional[UsageMetricsEvent]
165-

0 commit comments

Comments
 (0)