Skip to content

Commit 117da67

Browse files
feat(bidi): support request_state stop_event_loop flag (#1954)
Co-authored-by: agent-of-mkmeral <agent-of-mkmeral@users.noreply.github.com>
1 parent dd7a7d9 commit 117da67

6 files changed

Lines changed: 215 additions & 13 deletions

File tree

src/strands/experimental/bidi/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# Model interface (for custom implementations)
1515
from .models.model import BidiModel
1616

17-
# Built-in tools
17+
# Built-in tools (deprecated - use strands_tools.stop instead)
1818
from .tools import stop_conversation
1919

2020
# Event types - For type hints and event handling
@@ -39,8 +39,6 @@
3939
__all__ = [
4040
# Main interface
4141
"BidiAgent",
42-
# Built-in tools
43-
"stop_conversation",
4442
# Input Event types
4543
"BidiTextInputEvent",
4644
"BidiAudioInputEvent",
@@ -64,6 +62,8 @@
6462
"ToolStreamEvent",
6563
# Model interface
6664
"BidiModel",
65+
# Built-in tools (deprecated)
66+
"stop_conversation",
6767
]
6868

6969

src/strands/experimental/bidi/agent/loop.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import asyncio
77
import logging
8+
import warnings
89
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
910

1011
from ....types._events import ToolInterruptEvent, ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent
@@ -248,6 +249,10 @@ async def _run_tool(self, tool_use: ToolUse) -> None:
248249

249250
tool_results: list[ToolResult] = []
250251

252+
# Ensure request_state exists for tools like strands_tools.stop
253+
if "request_state" not in self._invocation_state:
254+
self._invocation_state["request_state"] = {}
255+
251256
invocation_state: dict[str, Any] = {
252257
**self._invocation_state,
253258
"agent": self._agent,
@@ -282,16 +287,29 @@ async def _run_tool(self, tool_use: ToolUse) -> None:
282287

283288
await self._event_queue.put(ToolResultMessageEvent(tool_result_message))
284289

285-
# Check for stop_conversation before sending to model
286-
if tool_use["name"] == "stop_conversation":
287-
logger.info("tool_name=<%s> | conversation stop requested, skipping model send", tool_use["name"])
290+
# Check for stop_event_loop flag (set by strands_tools.stop, stop_conversation, or any custom tool)
291+
request_state = invocation_state.get("request_state", {})
292+
should_stop = request_state.get("stop_event_loop", False)
293+
294+
# Backward compatibility: also check for stop_conversation by name (deprecated)
295+
if not should_stop and tool_use["name"] == "stop_conversation":
296+
warnings.warn(
297+
"Stopping the event loop by tool name 'stop_conversation' is deprecated. "
298+
"Use request_state['stop_event_loop'] = True instead.",
299+
DeprecationWarning,
300+
stacklevel=2,
301+
)
302+
should_stop = True
303+
304+
if should_stop:
305+
logger.info("stop_event_loop=<True> | stopping conversation")
288306
connection_id = getattr(self._agent.model, "_connection_id", "unknown")
289307
await self._event_queue.put(
290308
BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request")
291309
)
292-
return # Skip the model send
310+
return # Skip sending result to model
293311

294-
# Send result to model (all tools except stop_conversation)
312+
# Send result to model
295313
await self.send(tool_result_event)
296314

297315
except Exception as error:

src/strands/experimental/bidi/io/text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def __call__(self, event: BidiOutputEvent) -> None:
4242

4343
elif isinstance(event, BidiConnectionCloseEvent):
4444
if event.reason == "user_request":
45-
print("user requested connection close using the stop_conversation tool.")
45+
print("user requested connection close using the stop tool.")
4646
logger.debug("connection_id=<%s> | user requested connection close", event.connection_id)
4747
elif isinstance(event, BidiTranscriptStreamEvent):
4848
text = event["text"]

src/strands/experimental/bidi/tools/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
1-
"""Built-in tools for bidirectional agents."""
1+
"""Built-in tools for bidirectional agents.
2+
3+
.. deprecated::
4+
The built-in ``stop_conversation`` tool is deprecated. Use ``strands_tools.stop`` or set
5+
``request_state["stop_event_loop"] = True`` in any custom tool instead.
6+
7+
To stop a bidirectional conversation, use the standard ``stop`` tool from strands_tools::
8+
9+
from strands_tools import stop
10+
agent = BidiAgent(tools=[stop, ...])
11+
12+
The stop tool sets ``request_state["stop_event_loop"] = True``, which signals the
13+
BidiAgent to gracefully close the connection.
14+
"""
215

316
from .stop_conversation import stop_conversation
417

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
"""Tool to gracefully stop a bidirectional connection."""
1+
"""Tool to gracefully stop a bidirectional connection.
2+
3+
.. deprecated::
4+
The ``stop_conversation`` tool is deprecated and will be removed in a future version.
5+
Use ``strands_tools.stop`` or set ``request_state["stop_event_loop"] = True`` in any custom tool instead.
6+
"""
7+
8+
import warnings
29

310
from ....tools.decorator import tool
411

@@ -7,10 +14,19 @@
714
def stop_conversation() -> str:
815
"""Stop the bidirectional conversation gracefully.
916
17+
.. deprecated::
18+
Use ``strands_tools.stop`` or set ``request_state["stop_event_loop"] = True`` in a custom tool instead.
19+
1020
Use ONLY when user says "stop conversation" exactly.
1121
Do NOT use for: "stop", "goodbye", "bye", "exit", "quit", "end" or other farewells or phrases.
1222
1323
Returns:
14-
Success message confirming the conversation will end
24+
Success message confirming the conversation will end.
1525
"""
26+
warnings.warn(
27+
"stop_conversation is deprecated and will be removed in a future version. "
28+
"Use strands_tools.stop or set request_state['stop_event_loop'] = True in any custom tool instead.",
29+
DeprecationWarning,
30+
stacklevel=2,
31+
)
1632
return "Ending conversation"

tests/strands/experimental/bidi/agent/test_loop.py

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import unittest.mock
2+
import warnings
23

34
import pytest
45
import pytest_asyncio
56

67
from strands import tool
78
from strands.experimental.bidi import BidiAgent
89
from strands.experimental.bidi.models import BidiModel, BidiModelTimeoutError
9-
from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent
10+
from strands.experimental.bidi.types.events import BidiConnectionCloseEvent, BidiConnectionRestartEvent, BidiTextInputEvent
1011
from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent
1112

1213

@@ -93,3 +94,157 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator):
9394
assert tru_messages == exp_messages
9495

9596
agent.model.send.assert_called_with(tool_result_event)
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_bidi_agent_loop_request_state_initialized_for_tools(loop, agent, agenerator):
101+
"""Test that request_state is initialized in invocation_state before tool execution.
102+
103+
This ensures request_state exists for tools that may need it via invocation_state,
104+
even when invocation_state is not provided by the user.
105+
"""
106+
tool_use = {"toolUseId": "t2", "name": "time_tool", "input": {}}
107+
tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="")
108+
109+
agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event]))
110+
111+
# Start without providing invocation_state
112+
await loop.start()
113+
114+
tru_events = []
115+
async for event in loop.receive():
116+
tru_events.append(event)
117+
if len(tru_events) >= 3:
118+
break
119+
120+
# Verify tool executed successfully
121+
tool_result_event = tru_events[1]
122+
assert isinstance(tool_result_event, ToolResultEvent)
123+
assert tool_result_event.tool_result["status"] == "success"
124+
125+
# Verify request_state was initialized in invocation_state
126+
assert "request_state" in loop._invocation_state
127+
assert isinstance(loop._invocation_state["request_state"], dict)
128+
129+
130+
@pytest.mark.asyncio
131+
async def test_bidi_agent_loop_stop_event_loop_flag(agent, agenerator):
132+
"""Test that the stop_event_loop flag in request_state gracefully closes the connection.
133+
134+
This simulates a tool (like strands_tools.stop) setting the flag via invocation_state.
135+
"""
136+
# Use a tool that modifies invocation_state to set the stop flag
137+
# We'll mock the tool executor to simulate this behavior
138+
loop = agent._loop
139+
140+
tool_use = {"toolUseId": "t3", "name": "time_tool", "input": {}}
141+
tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="")
142+
tool_result = {"toolUseId": "t3", "status": "success", "content": [{"text": "12:00"}]}
143+
144+
agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event]))
145+
146+
# Start with request_state that already has stop_event_loop=True
147+
# This simulates a tool having set it during execution
148+
await loop.start(invocation_state={"request_state": {"stop_event_loop": True}})
149+
150+
tru_events = []
151+
async for event in loop.receive():
152+
tru_events.append(event)
153+
154+
# Should receive: tool_use_event, tool_result_event, tool_result_message, connection_close
155+
assert len(tru_events) == 4
156+
157+
# Verify tool executed successfully
158+
tool_result_event = tru_events[1]
159+
assert isinstance(tool_result_event, ToolResultEvent)
160+
assert tool_result_event.tool_result["status"] == "success"
161+
162+
# Verify connection close event was emitted
163+
connection_close_event = tru_events[3]
164+
assert isinstance(connection_close_event, BidiConnectionCloseEvent)
165+
assert connection_close_event["reason"] == "user_request"
166+
167+
# Verify model.send was NOT called (tool result not sent to model)
168+
agent.model.send.assert_not_called()
169+
170+
171+
@pytest.mark.asyncio
172+
async def test_bidi_agent_loop_stop_conversation_deprecated_but_works(loop, agent, agenerator):
173+
"""Test that stop_conversation tool still works but emits a deprecation warning.
174+
175+
The stop_conversation tool is deprecated in favor of request_state["stop_event_loop"],
176+
but should continue to work for backward compatibility via the name-based check.
177+
"""
178+
from strands.experimental.bidi.tools import stop_conversation
179+
180+
agent.tool_registry.register_tool(stop_conversation)
181+
182+
tool_use = {"toolUseId": "t5", "name": "stop_conversation", "input": {}}
183+
tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="")
184+
185+
agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event]))
186+
187+
await loop.start()
188+
189+
tru_events = []
190+
with warnings.catch_warnings(record=True) as caught_warnings:
191+
warnings.simplefilter("always")
192+
async for event in loop.receive():
193+
tru_events.append(event)
194+
195+
# Should receive: tool_use_event, tool_result_event, tool_result_message, connection_close
196+
assert len(tru_events) == 4
197+
198+
# Verify tool executed successfully
199+
tool_result_event = tru_events[1]
200+
assert isinstance(tool_result_event, ToolResultEvent)
201+
assert tool_result_event.tool_result["status"] == "success"
202+
assert "Ending conversation" in tool_result_event.tool_result["content"][0]["text"]
203+
204+
# Verify connection close event was emitted
205+
connection_close_event = tru_events[3]
206+
assert isinstance(connection_close_event, BidiConnectionCloseEvent)
207+
assert connection_close_event["reason"] == "user_request"
208+
209+
# Verify model.send was NOT called (tool result not sent to model)
210+
agent.model.send.assert_not_called()
211+
212+
# Verify deprecation warnings were emitted (from both the tool itself and the loop name check)
213+
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, DeprecationWarning)]
214+
assert len(deprecation_warnings) >= 1
215+
assert any("stop_conversation" in str(w.message).lower() for w in deprecation_warnings)
216+
217+
218+
@pytest.mark.asyncio
219+
async def test_bidi_agent_loop_request_state_preserved_with_invocation_state(agent, agenerator):
220+
"""Test that existing invocation_state is preserved when request_state is initialized."""
221+
222+
@tool(name="check_invocation_state")
223+
async def check_invocation_state(custom_key: str) -> str:
224+
return f"custom_key: {custom_key}"
225+
226+
agent.tool_registry.register_tool(check_invocation_state)
227+
228+
tool_use = {"toolUseId": "t4", "name": "check_invocation_state", "input": {"custom_key": "from_state"}}
229+
tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="")
230+
231+
agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event]))
232+
233+
loop = agent._loop
234+
# Start with custom invocation_state but no request_state
235+
await loop.start(invocation_state={"custom_data": "preserved"})
236+
237+
tru_events = []
238+
async for event in loop.receive():
239+
tru_events.append(event)
240+
if len(tru_events) >= 3:
241+
break
242+
243+
# Verify tool executed successfully
244+
tool_result_event = tru_events[1]
245+
assert isinstance(tool_result_event, ToolResultEvent)
246+
assert tool_result_event.tool_result["status"] == "success"
247+
248+
# Verify request_state was added without removing custom_data
249+
assert "request_state" in loop._invocation_state
250+
assert loop._invocation_state.get("custom_data") == "preserved"

0 commit comments

Comments
 (0)