|
24 | 24 | from google.adk.sessions.session import Session |
25 | 25 | from google.genai.types import Content |
26 | 26 | from google.genai.types import FunctionCall |
| 27 | +from google.genai.types import FunctionResponse |
27 | 28 | from google.genai.types import Part |
28 | 29 | import pytest |
29 | 30 |
|
@@ -210,6 +211,82 @@ def test_should_not_pause_invocation_with_no_function_calls( |
210 | 211 | nonpausable_event |
211 | 212 | ) |
212 | 213 |
|
| 214 | + def test_has_unresolved_long_running_tool_calls_with_matching_response(self): |
| 215 | + """Tests that matching function responses resolve the pause.""" |
| 216 | + invocation_context = self._create_test_invocation_context( |
| 217 | + ResumabilityConfig(is_resumable=True) |
| 218 | + ) |
| 219 | + function_call = FunctionCall( |
| 220 | + id='tool_call_id_1', |
| 221 | + name='long_running_function_call', |
| 222 | + args={}, |
| 223 | + ) |
| 224 | + paused_event = Event( |
| 225 | + invocation_id='inv_1', |
| 226 | + author='agent', |
| 227 | + content=testing_utils.ModelContent([Part(function_call=function_call)]), |
| 228 | + long_running_tool_ids={function_call.id}, |
| 229 | + ) |
| 230 | + resolved_event = Event( |
| 231 | + invocation_id='inv_1', |
| 232 | + author='user', |
| 233 | + content=Content( |
| 234 | + role='user', |
| 235 | + parts=[ |
| 236 | + Part( |
| 237 | + function_response=FunctionResponse( |
| 238 | + name='long_running_function_call', |
| 239 | + response={'result': 'done'}, |
| 240 | + id=function_call.id, |
| 241 | + ) |
| 242 | + ) |
| 243 | + ], |
| 244 | + ), |
| 245 | + ) |
| 246 | + |
| 247 | + assert not invocation_context.has_unresolved_long_running_tool_calls( |
| 248 | + [paused_event, resolved_event] |
| 249 | + ) |
| 250 | + |
| 251 | + def test_has_unresolved_long_running_tool_calls_without_matching_response( |
| 252 | + self, |
| 253 | + ): |
| 254 | + """Tests that unmatched long-running calls still pause the invocation.""" |
| 255 | + invocation_context = self._create_test_invocation_context( |
| 256 | + ResumabilityConfig(is_resumable=True) |
| 257 | + ) |
| 258 | + function_call = FunctionCall( |
| 259 | + id='tool_call_id_1', |
| 260 | + name='long_running_function_call', |
| 261 | + args={}, |
| 262 | + ) |
| 263 | + paused_event = Event( |
| 264 | + invocation_id='inv_1', |
| 265 | + author='agent', |
| 266 | + content=testing_utils.ModelContent([Part(function_call=function_call)]), |
| 267 | + long_running_tool_ids={function_call.id}, |
| 268 | + ) |
| 269 | + unrelated_response_event = Event( |
| 270 | + invocation_id='inv_1', |
| 271 | + author='user', |
| 272 | + content=Content( |
| 273 | + role='user', |
| 274 | + parts=[ |
| 275 | + Part( |
| 276 | + function_response=FunctionResponse( |
| 277 | + name='long_running_function_call', |
| 278 | + response={'result': 'done'}, |
| 279 | + id='different_tool_call_id', |
| 280 | + ) |
| 281 | + ) |
| 282 | + ], |
| 283 | + ), |
| 284 | + ) |
| 285 | + |
| 286 | + assert invocation_context.has_unresolved_long_running_tool_calls( |
| 287 | + [paused_event, unrelated_response_event] |
| 288 | + ) |
| 289 | + |
213 | 290 | def test_is_resumable_true(self): |
214 | 291 | """Tests that is_resumable is True when resumability is enabled.""" |
215 | 292 | invocation_context = self._create_test_invocation_context( |
|
0 commit comments