Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@
logger = logging.getLogger('google_adk.' + __name__)


class StopSignal:

def __init__(self) -> None:
self.stopped = False

def stop(self) -> None:
self.stopped = True

def reset(self) -> None:
self.stopped = False

def is_set(self) -> bool:
return self.stopped


def _is_tool_call_or_response(event: Event) -> bool:
return bool(event.get_function_calls() or event.get_function_responses())

Expand Down Expand Up @@ -508,6 +523,7 @@ async def run_async(
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
stop_signal: Optional[StopSignal] = None,
) -> AsyncGenerator[Event, None]:
"""Main entry method to run the agent in this runner.

Expand Down Expand Up @@ -611,10 +627,13 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
session=session,
execute_fn=execute,
is_live_call=False,
stop_signal=stop_signal,
)
) as agen:
async for event in agen:
yield event
if event.interrupted:
return
# Run compaction after all events are yielded from the agent.
# (We don't compact in the middle of an invocation, we only compact at
# the end of an invocation.)
Expand Down Expand Up @@ -836,6 +855,8 @@ async def _exec_with_plugin(
session: Session,
execute_fn: Callable[[InvocationContext], AsyncGenerator[Event, None]],
is_live_call: bool = False,
*,
stop_signal: Optional[StopSignal] = None,
) -> AsyncGenerator[Event, None]:
"""Wraps execution with plugin callbacks.

Expand Down Expand Up @@ -890,6 +911,17 @@ async def _exec_with_plugin(

async with Aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
if stop_signal and stop_signal.is_set():
# See if the run_async execution is interrupted
# If an interruption is called, return an event that signifies as such
interrupted_event = Event(
invocation_id=invocation_context.invocation_id,
author='model',
interrupted=True,
content=types.Content(role='model', parts=[]),
)
yield interrupted_event

_apply_run_config_custom_metadata(
event, invocation_context.run_config
)
Expand Down
98 changes: 98 additions & 0 deletions tests/unittests/runners/test_cancel_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import asyncio
from typing import Any
from typing import AsyncGenerator

from google.adk.agents import BaseAgent
from google.adk.agents.run_config import RunConfig
from google.adk.agents.run_config import StreamingMode
from google.adk.events import Event
from google.adk.runners import InMemoryRunner
from google.adk.runners import StopSignal
from google.genai.types import Content
from google.genai.types import Part
import pytest


# Used to tell the main loop when to interrupt the run_async
class TriggerSignal:

def __init__(self):
self.triggered = False

def set(self):
self.triggered = True

def reset(self):
self.triggered = False

def is_set(self) -> bool:
return self.triggered


# Mocks an LLM, outputting events every 1s
class MockAgent(BaseAgent):
name: str = "MockAgent"

async def run_async(self, ctx: Any) -> AsyncGenerator[Event, None]:
cycle = 0

while True:
await asyncio.sleep(1)
cycle += 1
yield Event(
author="model",
content=Content(role="model", parts=[Part(text=f"{cycle}")]),
)


APP_NAME = "adk_test_app"
USER_ID = "adk_test_user"
STOP_SIGNAL = StopSignal()
TRIGGER_SIGNAL = TriggerSignal()


@pytest.mark.asyncio
async def test_stop_signal():
async def consume_stream(runner, trigger_signal, **kwargs):
cycle = 0

async for event in runner.run_async(**kwargs):
cycle += 1
content = event.content.parts[0].text

if not content:
# The fourth cycle should be interrupted
# Only interrupted cycles would not have content objects given this MockAgent
assert cycle == 4
assert event.interrupted
if content and content == "3":
# Tell main loop to interrupt the fourth cycle
trigger_signal.set()

client = MockAgent()
runner = InMemoryRunner(agent=client, app_name=APP_NAME)

session = await runner.session_service.create_session(
app_name=APP_NAME, user_id=USER_ID
)

message = Content(role="user", parts=[Part(text="Necessary message")])

task = asyncio.create_task(
consume_stream(
runner=runner,
trigger_signal=TRIGGER_SIGNAL,
session_id=session.id,
user_id=USER_ID,
new_message=message,
stop_signal=STOP_SIGNAL,
run_config=RunConfig(streaming_mode=StreamingMode.SSE),
)
)

# Wait for 3 events to pass
while not TRIGGER_SIGNAL.is_set():
await asyncio.sleep(0.1)

# Interrupt the run_async
STOP_SIGNAL.stop()
Loading