Skip to content

Commit e133e3b

Browse files
committed
autoformat.sh
1 parent 5fc0d25 commit e133e3b

2 files changed

Lines changed: 77 additions & 76 deletions

File tree

src/google/adk/runners.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,22 @@
6767

6868
logger = logging.getLogger('google_adk.' + __name__)
6969

70+
7071
class StopSignal:
72+
7173
def __init__(self):
7274
self.stopped = False
73-
75+
7476
def stop(self):
7577
self.stopped = True
7678

7779
def reset(self):
7880
self.stopped = False
79-
81+
8082
def is_set(self) -> bool:
8183
return self.stopped
8284

85+
8386
def _is_tool_call_or_response(event: Event) -> bool:
8487
return bool(event.get_function_calls() or event.get_function_responses())
8588

@@ -625,7 +628,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
625628
session=session,
626629
execute_fn=execute,
627630
is_live_call=False,
628-
stop_signal=stop_signal
631+
stop_signal=stop_signal,
629632
)
630633
) as agen:
631634
async for event in agen:
@@ -842,7 +845,7 @@ async def _exec_with_plugin(
842845
execute_fn: Callable[[InvocationContext], AsyncGenerator[Event, None]],
843846
is_live_call: bool = False,
844847
*,
845-
stop_signal: Optional[StopSignal] = None
848+
stop_signal: Optional[StopSignal] = None,
846849
) -> AsyncGenerator[Event, None]:
847850
"""Wraps execution with plugin callbacks.
848851
@@ -904,10 +907,7 @@ async def _exec_with_plugin(
904907
invocation_id=invocation_context.invocation_id,
905908
author='model',
906909
interrupted=True,
907-
content=types.Content(
908-
role='model',
909-
parts=[]
910-
)
910+
content=types.Content(role='model', parts=[]),
911911
)
912912
yield interrupted_event
913913

Lines changed: 69 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,96 @@
11
import asyncio
2-
import pytest
3-
from google.adk.runners import InMemoryRunner
2+
3+
from google.adk.agents import BaseAgent
44
from google.adk.agents.run_config import RunConfig
55
from google.adk.agents.run_config import StreamingMode
6-
from google.adk.runners import StopSignal
7-
from google.genai.types import Part, Content
8-
from google.adk.agents import BaseAgent
96
from google.adk.events import Event
7+
from google.adk.runners import InMemoryRunner
8+
from google.adk.runners import StopSignal
9+
from google.genai.types import Content
10+
from google.genai.types import Part
11+
import pytest
12+
1013

1114
# Used to tell the main loop when to interrupt the run_async
1215
class TriggerSignal:
16+
1317
def __init__(self):
1418
self.triggered = False
15-
19+
1620
def set(self):
1721
self.triggered = True
1822

1923
def reset(self):
2024
self.triggered = False
21-
25+
2226
def is_set(self) -> bool:
2327
return self.triggered
2428

29+
2530
# Mocks an LLM, outputting events every 1s
2631
class MockAgent(BaseAgent):
27-
name: str = "MockAgent"
28-
29-
async def run_async(self, ctx: Any) -> AsyncGenerator[Event, None]:
30-
cycle = 0
31-
32-
while True:
33-
await asyncio.sleep(1)
34-
cycle += 1
35-
yield Event(
36-
author='model',
37-
content=Content(
38-
role='model',
39-
parts=[Part(text=f"{cycle}")]
40-
)
41-
)
32+
name: str = "MockAgent"
33+
34+
async def run_async(self, ctx: Any) -> AsyncGenerator[Event, None]:
35+
cycle = 0
36+
37+
while True:
38+
await asyncio.sleep(1)
39+
cycle += 1
40+
yield Event(
41+
author="model",
42+
content=Content(role="model", parts=[Part(text=f"{cycle}")]),
43+
)
44+
4245

4346
APP_NAME = "adk_test_app"
4447
USER_ID = "adk_test_user"
4548
STOP_SIGNAL = StopSignal()
4649
TRIGGER_SIGNAL = TriggerSignal()
4750

51+
4852
@pytest.mark.asyncio
4953
async def test_stop_signal():
50-
async def consume_stream(runner, trigger_signal, **kwargs):
51-
cycle = 0
52-
53-
async for event in runner.run_async(**kwargs):
54-
cycle += 1
55-
content = event.content.parts[0].text
56-
57-
if not content:
58-
# The fourth cycle should be interrupted
59-
# Only interrupted cycles would not have content objects given this MockAgent
60-
assert cycle == 4
61-
assert event.interrupted
62-
if content and content == "3":
63-
# Tell main loop to interrupt the fourth cycle
64-
trigger_signal.set()
65-
66-
client = MockAgent()
67-
runner = InMemoryRunner(agent=client, app_name=APP_NAME)
68-
69-
session = await runner.session_service.create_session(
70-
app_name=APP_NAME, user_id=USER_ID
71-
)
72-
73-
message = Content(
74-
role="user",
75-
parts=[Part(text="Necessary message")]
76-
)
77-
78-
task = asyncio.create_task(consume_stream(
79-
runner=runner,
80-
trigger_signal=TRIGGER_SIGNAL,
81-
session_id=session.id,
82-
user_id=USER_ID,
83-
new_message=message,
84-
stop_signal=STOP_SIGNAL,
85-
run_config=RunConfig(
86-
streaming_mode=StreamingMode.SSE
87-
)
88-
))
89-
90-
# Wait for 3 events to pass
91-
while not TRIGGER_SIGNAL.is_set():
92-
await asyncio.sleep(0.1)
93-
94-
# Interrupt the run_async
95-
STOP_SIGNAL.stop()
54+
async def consume_stream(runner, trigger_signal, **kwargs):
55+
cycle = 0
56+
57+
async for event in runner.run_async(**kwargs):
58+
cycle += 1
59+
content = event.content.parts[0].text
60+
61+
if not content:
62+
# The fourth cycle should be interrupted
63+
# Only interrupted cycles would not have content objects given this MockAgent
64+
assert cycle == 4
65+
assert event.interrupted
66+
if content and content == "3":
67+
# Tell main loop to interrupt the fourth cycle
68+
trigger_signal.set()
69+
70+
client = MockAgent()
71+
runner = InMemoryRunner(agent=client, app_name=APP_NAME)
72+
73+
session = await runner.session_service.create_session(
74+
app_name=APP_NAME, user_id=USER_ID
75+
)
76+
77+
message = Content(role="user", parts=[Part(text="Necessary message")])
78+
79+
task = asyncio.create_task(
80+
consume_stream(
81+
runner=runner,
82+
trigger_signal=TRIGGER_SIGNAL,
83+
session_id=session.id,
84+
user_id=USER_ID,
85+
new_message=message,
86+
stop_signal=STOP_SIGNAL,
87+
run_config=RunConfig(streaming_mode=StreamingMode.SSE),
88+
)
89+
)
90+
91+
# Wait for 3 events to pass
92+
while not TRIGGER_SIGNAL.is_set():
93+
await asyncio.sleep(0.1)
94+
95+
# Interrupt the run_async
96+
STOP_SIGNAL.stop()

0 commit comments

Comments
 (0)