Skip to content

Commit 3a1a24c

Browse files
author
Jashwanth
committed
feat: add on_model_request_callback to plugins
1 parent 9e3b43f commit 3a1a24c

8 files changed

Lines changed: 110 additions & 2 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1320,6 +1320,16 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
13201320
invocation_context.agent.name
13211321
)
13221322

1323+
callback_context = CallbackContext(
1324+
invocation_context, event_actions=model_response_event.actions
1325+
)
1326+
if response := await invocation_context.plugin_manager.run_on_model_request_callback(
1327+
callback_context=callback_context,
1328+
llm_request=llm_request,
1329+
):
1330+
yield response
1331+
return
1332+
13231333
# Calls the LLM.
13241334
llm = self.__get_llm(invocation_context)
13251335

src/google/adk/plugins/base_plugin.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,26 @@ async def before_model_callback(
250250
"""
251251
pass
252252

253+
async def on_model_request_callback(
254+
self, *, callback_context: CallbackContext, llm_request: LlmRequest
255+
) -> Optional[LlmResponse]:
256+
"""Callback executed immediately before a request is sent to the model.
257+
258+
This hook is fired after all `before_model_callback`s have completed and
259+
the request has been finalized (e.g. labels injected). It is the correct
260+
place to observe the exact `LlmRequest` that will be sent to the model.
261+
262+
Args:
263+
callback_context: The context for the current agent call.
264+
llm_request: The final request object to be sent to the model.
265+
266+
Returns:
267+
An optional LlmResponse. If an LlmResponse is returned, it will be used
268+
instead of calling the model. Returning `None` allows the model call
269+
to proceed normally.
270+
"""
271+
pass
272+
253273
async def after_model_callback(
254274
self, *, callback_context: CallbackContext, llm_response: LlmResponse
255275
) -> Optional[LlmResponse]:

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3726,13 +3726,13 @@ async def after_agent_callback(
37263726
)
37273727

37283728
@_safe_callback
3729-
async def before_model_callback(
3729+
async def on_model_request_callback(
37303730
self,
37313731
*,
37323732
callback_context: CallbackContext,
37333733
llm_request: LlmRequest,
37343734
) -> None:
3735-
"""Callback before LLM call.
3735+
"""Callback immediately before LLM call.
37363736
37373737
Logs the LLM request details including:
37383738
1. Prompt content

src/google/adk/plugins/plugin_manager.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
"before_tool_callback",
5050
"after_tool_callback",
5151
"before_model_callback",
52+
"on_model_request_callback",
5253
"after_model_callback",
5354
"on_tool_error_callback",
5455
"on_model_error_callback",
@@ -245,6 +246,16 @@ async def run_before_model_callback(
245246
llm_request=llm_request,
246247
)
247248

249+
async def run_on_model_request_callback(
250+
self, *, callback_context: CallbackContext, llm_request: LlmRequest
251+
) -> Optional[LlmResponse]:
252+
"""Runs the `on_model_request_callback` for all plugins."""
253+
return await self._run_callbacks(
254+
"on_model_request_callback",
255+
callback_context=callback_context,
256+
llm_request=llm_request,
257+
)
258+
248259
async def run_after_model_callback(
249260
self, *, callback_context: CallbackContext, llm_response: LlmResponse
250261
) -> Optional[LlmResponse]:

tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class SpanCapturingPlugin(BasePlugin):
6363
def __init__(self):
6464
self.name = 'span_capturing_plugin'
6565
self.before_capture = _SpanCapture()
66+
self.request_capture = _SpanCapture()
6667
self.after_capture = _SpanCapture()
6768
self.error_capture = _SpanCapture()
6869

@@ -80,6 +81,15 @@ async def before_model_callback(
8081
return self._short_circuit_response
8182
return None
8283

84+
async def on_model_request_callback(
85+
self,
86+
*,
87+
callback_context: CallbackContext,
88+
llm_request: LlmRequest,
89+
) -> Optional[LlmResponse]:
90+
self.request_capture.capture()
91+
return None
92+
8393
async def after_model_callback(
8494
self,
8595
*,
@@ -149,6 +159,11 @@ def test_before_and_after_callbacks_share_same_span():
149159
f' before={plugin.before_capture.span_id:#x},'
150160
f' after={plugin.after_capture.span_id:#x}'
151161
)
162+
assert plugin.before_capture.span_id == plugin.request_capture.span_id, (
163+
'before_model_callback and on_model_request_callback saw different spans:'
164+
f' before={plugin.before_capture.span_id:#x},'
165+
f' request={plugin.request_capture.span_id:#x}'
166+
)
152167

153168

154169
def test_callbacks_same_trace_id():

tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@ class MockPlugin(BasePlugin):
4141
before_model_text = 'before_model_text from MockPlugin'
4242
after_model_text = 'after_model_text from MockPlugin'
4343
on_model_error_text = 'on_model_error_text from MockPlugin'
44+
on_model_request_text = 'on_model_request_text from MockPlugin'
4445

4546
def __init__(self, name='mock_plugin'):
4647
self.name = name
4748
self.enable_before_model_callback = False
4849
self.enable_after_model_callback = False
4950
self.enable_on_model_error_callback = False
51+
self.enable_on_model_request_callback = False
5052
self.before_model_response = LlmResponse(
5153
content=testing_utils.ModelContent(
5254
[types.Part.from_text(text=self.before_model_text)]
@@ -62,6 +64,11 @@ def __init__(self, name='mock_plugin'):
6264
[types.Part.from_text(text=self.on_model_error_text)]
6365
)
6466
)
67+
self.on_model_request_response = LlmResponse(
68+
content=testing_utils.ModelContent(
69+
[types.Part.from_text(text=self.on_model_request_text)]
70+
)
71+
)
6572

6673
async def before_model_callback(
6774
self, *, callback_context: CallbackContext, llm_request: LlmRequest
@@ -88,6 +95,13 @@ async def on_model_error_callback(
8895
return None
8996
return self.on_model_error_response
9097

98+
async def on_model_request_callback(
99+
self, *, callback_context: CallbackContext, llm_request: LlmRequest
100+
) -> Optional[LlmResponse]:
101+
if not self.enable_on_model_request_callback:
102+
return None
103+
return self.on_model_request_response
104+
91105

92106
CANONICAL_MODEL_CALLBACK_CONTENT = 'canonical_model_callback_content'
93107

@@ -138,6 +152,22 @@ def test_before_model_fallback_canonical_callback(mock_plugin):
138152
]
139153

140154

155+
def test_on_model_request_callback_with_plugin(mock_plugin):
156+
"""Tests that the model response is overridden by on_model_request_callback from the plugin."""
157+
responses = ['model_response']
158+
mock_model = testing_utils.MockModel.create(responses=responses)
159+
mock_plugin.enable_on_model_request_callback = True
160+
agent = Agent(
161+
name='root_agent',
162+
model=mock_model,
163+
)
164+
165+
runner = testing_utils.InMemoryRunner(agent, plugins=[mock_plugin])
166+
assert testing_utils.simplify_events(runner.run('test')) == [
167+
('root_agent', mock_plugin.on_model_request_text),
168+
]
169+
170+
141171
def test_before_model_callback_fallback_model(mock_plugin):
142172
"""Tests that the model response is executed normally when both plugin and canonical agent model callback return empty response."""
143173
responses = ['model_response']

tests/unittests/plugins/test_base_plugin.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ async def on_tool_error_callback(self, **kwargs) -> str:
7373
async def before_model_callback(self, **kwargs) -> str:
7474
return "overridden_before_model"
7575

76+
async def on_model_request_callback(self, **kwargs) -> str:
77+
return "overridden_on_model_request"
78+
7679
async def after_model_callback(self, **kwargs) -> str:
7780
return "overridden_after_model"
7881

@@ -158,6 +161,12 @@ async def test_base_plugin_default_callbacks_return_none():
158161
)
159162
is None
160163
)
164+
assert (
165+
await plugin.on_model_request_callback(
166+
callback_context=mock_context, llm_request=mock_context
167+
)
168+
is None
169+
)
161170
assert (
162171
await plugin.after_model_callback(
163172
callback_context=mock_context, llm_response=mock_context
@@ -240,6 +249,12 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
240249
)
241250
== "overridden_before_model"
242251
)
252+
assert (
253+
await plugin.on_model_request_callback(
254+
callback_context=mock_callback_context, llm_request=mock_llm_request
255+
)
256+
== "overridden_on_model_request"
257+
)
243258
assert (
244259
await plugin.after_model_callback(
245260
callback_context=mock_callback_context, llm_response=mock_llm_response

tests/unittests/plugins/test_plugin_manager.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ async def on_tool_error_callback(self, **kwargs):
8585
async def before_model_callback(self, **kwargs):
8686
return await self._handle_callback("before_model_callback")
8787

88+
async def on_model_request_callback(self, **kwargs):
89+
return await self._handle_callback("on_model_request_callback")
90+
8891
async def after_model_callback(self, **kwargs):
8992
return await self._handle_callback("after_model_callback")
9093

@@ -244,6 +247,9 @@ async def test_all_callbacks_are_supported(
244247
await service.run_before_model_callback(
245248
callback_context=mock_context, llm_request=mock_context
246249
)
250+
await service.run_on_model_request_callback(
251+
callback_context=mock_context, llm_request=mock_context
252+
)
247253
await service.run_after_model_callback(
248254
callback_context=mock_context, llm_response=mock_context
249255
)
@@ -265,6 +271,7 @@ async def test_all_callbacks_are_supported(
265271
"after_tool_callback",
266272
"on_tool_error_callback",
267273
"before_model_callback",
274+
"on_model_request_callback",
268275
"after_model_callback",
269276
"on_model_error_callback",
270277
]

0 commit comments

Comments
 (0)