11import asyncio
2- import pytest
3- from google .adk .runners import InMemoryRunner
2+
3+ from google .adk .agents import BaseAgent
44from google .adk .agents .run_config import RunConfig
55from google .adk .agents .run_config import StreamingMode
6- from google .adk .runners import StopSignal
7- from google .genai .types import Part , Content
8- from google .adk .agents import BaseAgent
96from google .adk .events import Event
7+ from google .adk .runners import InMemoryRunner
8+ from google .adk .runners import StopSignal
9+ from google .genai .types import Content
10+ from google .genai .types import Part
11+ import pytest
12+
1013
1114# Used to tell the main loop when to interrupt the run_async
1215class TriggerSignal :
16+
1317 def __init__ (self ):
1418 self .triggered = False
15-
19+
1620 def set (self ):
1721 self .triggered = True
1822
1923 def reset (self ):
2024 self .triggered = False
21-
25+
2226 def is_set (self ) -> bool :
2327 return self .triggered
2428
29+
2530# Mocks an LLM, outputting events every 1s
2631class MockAgent (BaseAgent ):
27- name : str = "MockAgent"
28-
29- async def run_async (self , ctx : Any ) -> AsyncGenerator [Event , None ]:
30- cycle = 0
31-
32- while True :
33- await asyncio .sleep (1 )
34- cycle += 1
35- yield Event (
36- author = 'model' ,
37- content = Content (
38- role = 'model' ,
39- parts = [Part (text = f"{ cycle } " )]
40- )
41- )
32+ name : str = "MockAgent"
33+
34+ async def run_async (self , ctx : Any ) -> AsyncGenerator [Event , None ]:
35+ cycle = 0
36+
37+ while True :
38+ await asyncio .sleep (1 )
39+ cycle += 1
40+ yield Event (
41+ author = "model" ,
42+ content = Content (role = "model" , parts = [Part (text = f"{ cycle } " )]),
43+ )
44+
4245
4346APP_NAME = "adk_test_app"
4447USER_ID = "adk_test_user"
4548STOP_SIGNAL = StopSignal ()
4649TRIGGER_SIGNAL = TriggerSignal ()
4750
51+
4852@pytest .mark .asyncio
4953async def test_stop_signal ():
50- async def consume_stream (runner , trigger_signal , ** kwargs ):
51- cycle = 0
52-
53- async for event in runner .run_async (** kwargs ):
54- cycle += 1
55- content = event .content .parts [0 ].text
56-
57- if not content :
58- # The fourth cycle should be interrupted
59- # Only interrupted cycles would not have content objects given this MockAgent
60- assert cycle == 4
61- assert event .interrupted
62- if content and content == "3" :
63- # Tell main loop to interrupt the fourth cycle
64- trigger_signal .set ()
65-
66- client = MockAgent ()
67- runner = InMemoryRunner (agent = client , app_name = APP_NAME )
68-
69- session = await runner .session_service .create_session (
70- app_name = APP_NAME , user_id = USER_ID
71- )
72-
73- message = Content (
74- role = "user" ,
75- parts = [Part (text = "Necessary message" )]
76- )
77-
78- task = asyncio .create_task (consume_stream (
79- runner = runner ,
80- trigger_signal = TRIGGER_SIGNAL ,
81- session_id = session .id ,
82- user_id = USER_ID ,
83- new_message = message ,
84- stop_signal = STOP_SIGNAL ,
85- run_config = RunConfig (
86- streaming_mode = StreamingMode .SSE
87- )
88- ))
89-
90- # Wait for 3 events to pass
91- while not TRIGGER_SIGNAL .is_set ():
92- await asyncio .sleep (0.1 )
93-
94- # Interrupt the run_async
95- STOP_SIGNAL .stop ()
54+ async def consume_stream (runner , trigger_signal , ** kwargs ):
55+ cycle = 0
56+
57+ async for event in runner .run_async (** kwargs ):
58+ cycle += 1
59+ content = event .content .parts [0 ].text
60+
61+ if not content :
62+ # The fourth cycle should be interrupted
63+ # Only interrupted cycles would not have content objects given this MockAgent
64+ assert cycle == 4
65+ assert event .interrupted
66+ if content and content == "3" :
67+ # Tell main loop to interrupt the fourth cycle
68+ trigger_signal .set ()
69+
70+ client = MockAgent ()
71+ runner = InMemoryRunner (agent = client , app_name = APP_NAME )
72+
73+ session = await runner .session_service .create_session (
74+ app_name = APP_NAME , user_id = USER_ID
75+ )
76+
77+ message = Content (role = "user" , parts = [Part (text = "Necessary message" )])
78+
79+ task = asyncio .create_task (
80+ consume_stream (
81+ runner = runner ,
82+ trigger_signal = TRIGGER_SIGNAL ,
83+ session_id = session .id ,
84+ user_id = USER_ID ,
85+ new_message = message ,
86+ stop_signal = STOP_SIGNAL ,
87+ run_config = RunConfig (streaming_mode = StreamingMode .SSE ),
88+ )
89+ )
90+
91+ # Wait for 3 events to pass
92+ while not TRIGGER_SIGNAL .is_set ():
93+ await asyncio .sleep (0.1 )
94+
95+ # Interrupt the run_async
96+ STOP_SIGNAL .stop ()
0 commit comments