@@ -41,12 +41,14 @@ class MockPlugin(BasePlugin):
4141 before_model_text = 'before_model_text from MockPlugin'
4242 after_model_text = 'after_model_text from MockPlugin'
4343 on_model_error_text = 'on_model_error_text from MockPlugin'
44+ on_model_request_text = 'on_model_request_text from MockPlugin'
4445
4546 def __init__ (self , name = 'mock_plugin' ):
4647 self .name = name
4748 self .enable_before_model_callback = False
4849 self .enable_after_model_callback = False
4950 self .enable_on_model_error_callback = False
51+ self .enable_on_model_request_callback = False
5052 self .before_model_response = LlmResponse (
5153 content = testing_utils .ModelContent (
5254 [types .Part .from_text (text = self .before_model_text )]
@@ -62,6 +64,11 @@ def __init__(self, name='mock_plugin'):
6264 [types .Part .from_text (text = self .on_model_error_text )]
6365 )
6466 )
67+ self .on_model_request_response = LlmResponse (
68+ content = testing_utils .ModelContent (
69+ [types .Part .from_text (text = self .on_model_request_text )]
70+ )
71+ )
6572
6673 async def before_model_callback (
6774 self , * , callback_context : CallbackContext , llm_request : LlmRequest
@@ -88,6 +95,13 @@ async def on_model_error_callback(
8895 return None
8996 return self .on_model_error_response
9097
98+ async def on_model_request_callback (
99+ self , * , callback_context : CallbackContext , llm_request : LlmRequest
100+ ) -> Optional [LlmResponse ]:
101+ if not self .enable_on_model_request_callback :
102+ return None
103+ return self .on_model_request_response
104+
91105
92106CANONICAL_MODEL_CALLBACK_CONTENT = 'canonical_model_callback_content'
93107
@@ -138,6 +152,22 @@ def test_before_model_fallback_canonical_callback(mock_plugin):
138152 ]
139153
140154
155+ def test_on_model_request_callback_with_plugin (mock_plugin ):
156+ """Tests that the model response is overridden by on_model_request_callback from the plugin."""
157+ responses = ['model_response' ]
158+ mock_model = testing_utils .MockModel .create (responses = responses )
159+ mock_plugin .enable_on_model_request_callback = True
160+ agent = Agent (
161+ name = 'root_agent' ,
162+ model = mock_model ,
163+ )
164+
165+ runner = testing_utils .InMemoryRunner (agent , plugins = [mock_plugin ])
166+ assert testing_utils .simplify_events (runner .run ('test' )) == [
167+ ('root_agent' , mock_plugin .on_model_request_text ),
168+ ]
169+
170+
141171def test_before_model_callback_fallback_model (mock_plugin ):
142172 """Tests that the model response is executed normally when both plugin and canonical agent model callback return empty response."""
143173 responses = ['model_response' ]
0 commit comments