|
26 | 26 | import com.google.adk.agents.RunConfig.ToolExecutionMode; |
27 | 27 | import com.google.adk.events.Event; |
28 | 28 | import com.google.adk.testing.TestUtils; |
| 29 | +import com.google.adk.tools.BaseTool; |
| 30 | +import com.google.adk.tools.ToolContext; |
29 | 31 | import com.google.common.collect.ImmutableList; |
30 | 32 | import com.google.common.collect.ImmutableMap; |
31 | 33 | import com.google.genai.types.Content; |
32 | 34 | import com.google.genai.types.FunctionCall; |
| 35 | +import com.google.genai.types.FunctionDeclaration; |
33 | 36 | import com.google.genai.types.FunctionResponse; |
34 | 37 | import com.google.genai.types.Part; |
| 38 | +import io.reactivex.rxjava3.core.Single; |
| 39 | +import java.util.ArrayList; |
| 40 | +import java.util.LinkedHashMap; |
| 41 | +import java.util.List; |
| 42 | +import java.util.Map; |
| 43 | +import java.util.Optional; |
35 | 44 | import org.junit.Test; |
36 | 45 | import org.junit.runner.RunWith; |
37 | 46 | import org.junit.runners.JUnit4; |
@@ -388,4 +397,132 @@ public void getAskUserConfirmationFunctionCalls_eventWithConfirmationFunctionCal |
388 | 397 | ImmutableList<FunctionCall> result = Functions.getAskUserConfirmationFunctionCalls(event); |
389 | 398 | assertThat(result).containsExactly(confirmationCall1, confirmationCall2); |
390 | 399 | } |
| 400 | + |
| 401 | + @Test |
| 402 | + public void handleFunctionCalls_parallel_blockingTools_runConcurrently_twoTools() { |
| 403 | + runParallelBlockingToolsTest(/* toolCount= */ 2); |
| 404 | + } |
| 405 | + |
| 406 | + @Test |
| 407 | + public void handleFunctionCalls_parallel_blockingTools_runConcurrently_threeTools() { |
| 408 | + runParallelBlockingToolsTest(/* toolCount= */ 3); |
| 409 | + } |
| 410 | + |
| 411 | + @Test |
| 412 | + public void handleFunctionCalls_parallel_blockingTools_runConcurrently_fiveTools() { |
| 413 | + runParallelBlockingToolsTest(/* toolCount= */ 5); |
| 414 | + } |
| 415 | + |
| 416 | + /** Single-tool case bypasses the parallel scheduler path; must still return the correct event. */ |
| 417 | + @Test |
| 418 | + public void handleFunctionCalls_parallel_blockingTool_singleTool() { |
| 419 | + long sleepMillis = 200L; |
| 420 | + InvocationContext invocationContext = |
| 421 | + createInvocationContext( |
| 422 | + createRootAgent(), |
| 423 | + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build()); |
| 424 | + SleepingTool tool = new SleepingTool("slow_tool_1", sleepMillis); |
| 425 | + Event event = |
| 426 | + createEvent("event").toBuilder() |
| 427 | + .content( |
| 428 | + Content.fromParts( |
| 429 | + Part.builder() |
| 430 | + .functionCall( |
| 431 | + FunctionCall.builder() |
| 432 | + .id("call_1") |
| 433 | + .name("slow_tool_1") |
| 434 | + .args(ImmutableMap.of()) |
| 435 | + .build()) |
| 436 | + .build())) |
| 437 | + .build(); |
| 438 | + |
| 439 | + Event functionResponseEvent = |
| 440 | + Functions.handleFunctionCalls( |
| 441 | + invocationContext, event, ImmutableMap.of("slow_tool_1", tool)) |
| 442 | + .blockingGet(); |
| 443 | + |
| 444 | + assertThat(functionResponseEvent).isNotNull(); |
| 445 | + assertThat(functionResponseEvent.content().get().parts().get()) |
| 446 | + .containsExactly( |
| 447 | + Part.builder() |
| 448 | + .functionResponse( |
| 449 | + FunctionResponse.builder() |
| 450 | + .id("call_1") |
| 451 | + .name("slow_tool_1") |
| 452 | + .response(ImmutableMap.of("tool", "slow_tool_1")) |
| 453 | + .build()) |
| 454 | + .build()); |
| 455 | + } |
| 456 | + |
| 457 | + /** Asserts that {@code toolCount} blocking tools in PARALLEL mode run faster than sequential. */ |
| 458 | + private static void runParallelBlockingToolsTest(int toolCount) { |
| 459 | + long sleepMillis = 500L; |
| 460 | + InvocationContext invocationContext = |
| 461 | + createInvocationContext( |
| 462 | + createRootAgent(), |
| 463 | + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build()); |
| 464 | + |
| 465 | + Map<String, BaseTool> tools = new LinkedHashMap<>(); |
| 466 | + List<Part> callParts = new ArrayList<>(); |
| 467 | + List<Part> expectedResponseParts = new ArrayList<>(); |
| 468 | + for (int i = 1; i <= toolCount; i++) { |
| 469 | + String toolName = "slow_tool_" + i; |
| 470 | + String callId = "call_" + i; |
| 471 | + tools.put(toolName, new SleepingTool(toolName, sleepMillis)); |
| 472 | + callParts.add( |
| 473 | + Part.builder() |
| 474 | + .functionCall( |
| 475 | + FunctionCall.builder().id(callId).name(toolName).args(ImmutableMap.of()).build()) |
| 476 | + .build()); |
| 477 | + expectedResponseParts.add( |
| 478 | + Part.builder() |
| 479 | + .functionResponse( |
| 480 | + FunctionResponse.builder() |
| 481 | + .id(callId) |
| 482 | + .name(toolName) |
| 483 | + .response(ImmutableMap.of("tool", toolName)) |
| 484 | + .build()) |
| 485 | + .build()); |
| 486 | + } |
| 487 | + Event event = |
| 488 | + createEvent("event").toBuilder() |
| 489 | + .content(Content.fromParts(callParts.toArray(new Part[0]))) |
| 490 | + .build(); |
| 491 | + |
| 492 | + long start = System.currentTimeMillis(); |
| 493 | + Event functionResponseEvent = |
| 494 | + Functions.handleFunctionCalls(invocationContext, event, tools).blockingGet(); |
| 495 | + long durationMillis = System.currentTimeMillis() - start; |
| 496 | + |
| 497 | + assertThat(functionResponseEvent).isNotNull(); |
| 498 | + assertThat(functionResponseEvent.content().get().parts().get()) |
| 499 | + .containsExactlyElementsIn(expectedResponseParts) |
| 500 | + .inOrder(); |
| 501 | + // Sequential would be ~toolCount * sleepMillis; parallel is ~sleepMillis + fixed overhead. |
| 502 | + assertThat(durationMillis).isLessThan((long) toolCount * sleepMillis); |
| 503 | + } |
| 504 | + |
| 505 | + /** Tool that blocks the executing thread for {@code sleepMillis} before returning. */ |
| 506 | + private static final class SleepingTool extends BaseTool { |
| 507 | + private final long sleepMillis; |
| 508 | + |
| 509 | + SleepingTool(String name, long sleepMillis) { |
| 510 | + super(name, "Blocking tool used to verify parallel execution."); |
| 511 | + this.sleepMillis = sleepMillis; |
| 512 | + } |
| 513 | + |
| 514 | + @Override |
| 515 | + public Optional<FunctionDeclaration> declaration() { |
| 516 | + return Optional.of(FunctionDeclaration.builder().name(name()).build()); |
| 517 | + } |
| 518 | + |
| 519 | + @Override |
| 520 | + public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContext toolContext) { |
| 521 | + return Single.fromCallable( |
| 522 | + () -> { |
| 523 | + Thread.sleep(sleepMillis); |
| 524 | + return ImmutableMap.<String, Object>of("tool", name()); |
| 525 | + }); |
| 526 | + } |
| 527 | + } |
391 | 528 | } |
0 commit comments