Skip to content

Commit 794a70e

Browse files
selcukguncopybara-github
authored andcommitted
Support async agent and model callbacks
PiperOrigin-RevId: 755542756
1 parent f96cdc6 commit 794a70e

25 files changed

Lines changed: 359 additions & 105 deletions

File tree

src/google/adk/agents/base_agent.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any
17+
import inspect
18+
from typing import Any, Awaitable, Union
1819
from typing import AsyncGenerator
1920
from typing import Callable
2021
from typing import final
@@ -37,10 +38,15 @@
3738

3839
tracer = trace.get_tracer('gcp.vertex.agent')
3940

40-
BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
41+
BeforeAgentCallback = Callable[
42+
[CallbackContext],
43+
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
44+
]
4145

42-
43-
AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
46+
AfterAgentCallback = Callable[
47+
[CallbackContext],
48+
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
49+
]
4450

4551

4652
class BaseAgent(BaseModel):
@@ -119,7 +125,7 @@ async def run_async(
119125
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
120126
ctx = self._create_invocation_context(parent_context)
121127

122-
if event := self.__handle_before_agent_callback(ctx):
128+
if event := await self.__handle_before_agent_callback(ctx):
123129
yield event
124130
if ctx.end_invocation:
125131
return
@@ -130,7 +136,7 @@ async def run_async(
130136
if ctx.end_invocation:
131137
return
132138

133-
if event := self.__handle_after_agent_callback(ctx):
139+
if event := await self.__handle_after_agent_callback(ctx):
134140
yield event
135141

136142
@final
@@ -230,7 +236,7 @@ def _create_invocation_context(
230236
invocation_context.branch = f'{parent_context.branch}.{self.name}'
231237
return invocation_context
232238

233-
def __handle_before_agent_callback(
239+
async def __handle_before_agent_callback(
234240
self, ctx: InvocationContext
235241
) -> Optional[Event]:
236242
"""Runs the before_agent_callback if it exists.
@@ -248,6 +254,9 @@ def __handle_before_agent_callback(
248254
callback_context=callback_context
249255
)
250256

257+
if inspect.isawaitable(before_agent_callback_content):
258+
before_agent_callback_content = await before_agent_callback_content
259+
251260
if before_agent_callback_content:
252261
ret_event = Event(
253262
invocation_id=ctx.invocation_id,
@@ -269,7 +278,7 @@ def __handle_before_agent_callback(
269278

270279
return ret_event
271280

272-
def __handle_after_agent_callback(
281+
async def __handle_after_agent_callback(
273282
self, invocation_context: InvocationContext
274283
) -> Optional[Event]:
275284
"""Runs the after_agent_callback if it exists.
@@ -287,6 +296,9 @@ def __handle_after_agent_callback(
287296
callback_context=callback_context
288297
)
289298

299+
if inspect.isawaitable(after_agent_callback_content):
300+
after_agent_callback_content = await after_agent_callback_content
301+
290302
if after_agent_callback_content or callback_context.state.has_delta():
291303
ret_event = Event(
292304
invocation_id=invocation_context.invocation_id,

src/google/adk/agents/llm_agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,12 @@
4949

5050

5151
BeforeModelCallback: TypeAlias = Callable[
52-
[CallbackContext, LlmRequest], Optional[LlmResponse]
52+
[CallbackContext, LlmRequest],
53+
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
5354
]
5455
AfterModelCallback: TypeAlias = Callable[
5556
[CallbackContext, LlmResponse],
56-
Optional[LlmResponse],
57+
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
5758
]
5859
BeforeToolCallback: TypeAlias = Callable[
5960
[BaseTool, dict[str, Any], ToolContext],

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from abc import ABC
1818
import asyncio
19+
import inspect
1920
import logging
2021
from typing import AsyncGenerator
2122
from typing import cast
@@ -199,7 +200,7 @@ def get_author(llm_response):
199200
return "user"
200201
else:
201202
return invocation_context.agent.name
202-
203+
203204
assert invocation_context.live_request_queue
204205
try:
205206
while True:
@@ -447,7 +448,7 @@ async def _call_llm_async(
447448
model_response_event: Event,
448449
) -> AsyncGenerator[LlmResponse, None]:
449450
# Runs before_model_callback if it exists.
450-
if response := self._handle_before_model_callback(
451+
if response := await self._handle_before_model_callback(
451452
invocation_context, llm_request, model_response_event
452453
):
453454
yield response
@@ -460,7 +461,7 @@ async def _call_llm_async(
460461
invocation_context.live_request_queue = LiveRequestQueue()
461462
async for llm_response in self.run_live(invocation_context):
462463
# Runs after_model_callback if it exists.
463-
if altered_llm_response := self._handle_after_model_callback(
464+
if altered_llm_response := await self._handle_after_model_callback(
464465
invocation_context, llm_response, model_response_event
465466
):
466467
llm_response = altered_llm_response
@@ -489,14 +490,14 @@ async def _call_llm_async(
489490
llm_response,
490491
)
491492
# Runs after_model_callback if it exists.
492-
if altered_llm_response := self._handle_after_model_callback(
493+
if altered_llm_response := await self._handle_after_model_callback(
493494
invocation_context, llm_response, model_response_event
494495
):
495496
llm_response = altered_llm_response
496497

497498
yield llm_response
498499

499-
def _handle_before_model_callback(
500+
async def _handle_before_model_callback(
500501
self,
501502
invocation_context: InvocationContext,
502503
llm_request: LlmRequest,
@@ -514,11 +515,16 @@ def _handle_before_model_callback(
514515
callback_context = CallbackContext(
515516
invocation_context, event_actions=model_response_event.actions
516517
)
517-
return agent.before_model_callback(
518+
before_model_callback_content = agent.before_model_callback(
518519
callback_context=callback_context, llm_request=llm_request
519520
)
520521

521-
def _handle_after_model_callback(
522+
if inspect.isawaitable(before_model_callback_content):
523+
before_model_callback_content = await before_model_callback_content
524+
525+
return before_model_callback_content
526+
527+
async def _handle_after_model_callback(
522528
self,
523529
invocation_context: InvocationContext,
524530
llm_response: LlmResponse,
@@ -536,10 +542,15 @@ def _handle_after_model_callback(
536542
callback_context = CallbackContext(
537543
invocation_context, event_actions=model_response_event.actions
538544
)
539-
return agent.after_model_callback(
545+
after_model_callback_content = agent.after_model_callback(
540546
callback_context=callback_context, llm_response=llm_response
541547
)
542548

549+
if inspect.isawaitable(after_model_callback_content):
550+
after_model_callback_content = await after_model_callback_content
551+
552+
return after_model_callback_content
553+
543554
def _finalize_model_response_event(
544555
self,
545556
llm_request: LlmRequest,

tests/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

tests/integration/fixture/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

tests/integration/fixture/callback_agent/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from . import agent
15+
from . import agent

tests/integration/models/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

tests/integration/test_evalute_agent_in_fixture.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from google.adk.evaluation import AgentEvaluator
2020
import pytest
2121

22+
2223
def agent_eval_artifacts_in_fixture():
2324
"""Get all agents from fixture folder."""
2425
agent_eval_artifacts = []

tests/integration/tools/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

tests/unittests/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

0 commit comments

Comments
 (0)