diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 0f36d6389d..62720e79ee 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -67,6 +67,21 @@ logger = logging.getLogger('google_adk.' + __name__) +class StopSignal: + + def __init__(self) -> None: + self.stopped = False + + def stop(self) -> None: + self.stopped = True + + def reset(self) -> None: + self.stopped = False + + def is_set(self) -> bool: + return self.stopped + + def _is_tool_call_or_response(event: Event) -> bool: return bool(event.get_function_calls() or event.get_function_responses()) @@ -508,6 +523,7 @@ async def run_async( new_message: Optional[types.Content] = None, state_delta: Optional[dict[str, Any]] = None, run_config: Optional[RunConfig] = None, + stop_signal: Optional[StopSignal] = None, ) -> AsyncGenerator[Event, None]: """Main entry method to run the agent in this runner. @@ -611,10 +627,13 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: session=session, execute_fn=execute, is_live_call=False, + stop_signal=stop_signal, ) ) as agen: async for event in agen: yield event + if event.interrupted: + return # Run compaction after all events are yielded from the agent. # (We don't compact in the middle of an invocation, we only compact at # the end of an invocation.) @@ -836,6 +855,8 @@ async def _exec_with_plugin( session: Session, execute_fn: Callable[[InvocationContext], AsyncGenerator[Event, None]], is_live_call: bool = False, + *, + stop_signal: Optional[StopSignal] = None, ) -> AsyncGenerator[Event, None]: """Wraps execution with plugin callbacks. @@ -890,6 +911,17 @@ async def _exec_with_plugin( async with Aclosing(execute_fn(invocation_context)) as agen: async for event in agen: + if stop_signal and stop_signal.is_set(): + # See if the run_async execution is interrupted + # If an interruption is called, return an event that signifies as such + interrupted_event = Event( + invocation_id=invocation_context.invocation_id, + author='model', + interrupted=True, + content=types.Content(role='model', parts=[]), + ) + yield interrupted_event + _apply_run_config_custom_metadata( event, invocation_context.run_config ) diff --git a/tests/unittests/runners/test_cancel_async.py b/tests/unittests/runners/test_cancel_async.py new file mode 100644 index 0000000000..7f23edbe78 --- /dev/null +++ b/tests/unittests/runners/test_cancel_async.py @@ -0,0 +1,98 @@ +import asyncio +from typing import Any +from typing import AsyncGenerator + +from google.adk.agents import BaseAgent +from google.adk.agents.run_config import RunConfig +from google.adk.agents.run_config import StreamingMode +from google.adk.events import Event +from google.adk.runners import InMemoryRunner +from google.adk.runners import StopSignal +from google.genai.types import Content +from google.genai.types import Part +import pytest + + +# Used to tell the main loop when to interrupt the run_async +class TriggerSignal: + + def __init__(self): + self.triggered = False + + def set(self): + self.triggered = True + + def reset(self): + self.triggered = False + + def is_set(self) -> bool: + return self.triggered + + +# Mocks an LLM, outputting events every 1s +class MockAgent(BaseAgent): + name: str = "MockAgent" + + async def run_async(self, ctx: Any) -> AsyncGenerator[Event, None]: + cycle = 0 + + while True: + await asyncio.sleep(1) + cycle += 1 + yield Event( + author="model", + content=Content(role="model", parts=[Part(text=f"{cycle}")]), + ) + + +APP_NAME = "adk_test_app" +USER_ID = "adk_test_user" +STOP_SIGNAL = StopSignal() +TRIGGER_SIGNAL = TriggerSignal() + + +@pytest.mark.asyncio +async def test_stop_signal(): + async def consume_stream(runner, trigger_signal, **kwargs): + cycle = 0 + + async for event in runner.run_async(**kwargs): + cycle += 1 + content = event.content.parts[0].text + + if not content: + # The fourth cycle should be interrupted + # Only interrupted cycles would not have content objects given this MockAgent + assert cycle == 4 + assert event.interrupted + if content and content == "3": + # Tell main loop to interrupt the fourth cycle + trigger_signal.set() + + client = MockAgent() + runner = InMemoryRunner(agent=client, app_name=APP_NAME) + + session = await runner.session_service.create_session( + app_name=APP_NAME, user_id=USER_ID + ) + + message = Content(role="user", parts=[Part(text="Necessary message")]) + + task = asyncio.create_task( + consume_stream( + runner=runner, + trigger_signal=TRIGGER_SIGNAL, + session_id=session.id, + user_id=USER_ID, + new_message=message, + stop_signal=STOP_SIGNAL, + run_config=RunConfig(streaming_mode=StreamingMode.SSE), + ) + ) + + # Wait for 3 events to pass + while not TRIGGER_SIGNAL.is_set(): + await asyncio.sleep(0.1) + + # Interrupt the run_async + STOP_SIGNAL.stop()