Skip to content

Commit dd510d0

Browse files
damianmomotgooglecopybara-github
authored andcommitted
fix: introduce PARALLEL_SUBSCRIBE ToolExecutionMode; restore previous PARALLEL semantics
PiperOrigin-RevId: 921591557
1 parent fe88217 commit dd510d0

3 files changed

Lines changed: 97 additions & 27 deletions

File tree

core/src/main/java/com/google/adk/agents/RunConfig.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,26 @@ public enum StreamingMode {
4141
/**
4242
* Execution mode when the model requests multiple tools.
4343
*
44-
* <p>NONE: defaults to SEQUENTIAL.
44+
* <p>NONE: defaults to PARALLEL.
4545
*
46-
* <p>SEQUENTIAL: tools execute in request order on the caller thread.
46+
* <p>SEQUENTIAL: tools execute strictly in request order on the caller thread; each tool must
47+
* complete (including any asynchronous work) before the next one is subscribed to.
4748
*
48-
* <p>PARALLEL: tools execute concurrently on worker threads. Tool implementations must be
49-
* thread-safe.
49+
* <p>PARALLEL: tools are subscribed to eagerly on the caller thread (i.e. all are kicked off
50+
* up-front), but no worker threads are introduced. Tools that are truly asynchronous (e.g. they
51+
* return a {@code Single} backed by I/O or another scheduler) will run concurrently; tools that
52+
* block the subscribing thread (e.g. {@code Single.fromCallable} that performs blocking work)
53+
* will still execute sequentially. This preserves the historical default behavior.
54+
*
55+
* <p>PARALLEL_SUBSCRIBE: like {@code PARALLEL}, but every tool is additionally subscribed on a
56+
* worker thread, so blocking tools also run concurrently. Tool implementations must be
57+
* thread-safe. The worker is the agent's executor when set, otherwise the RxJava IO scheduler.
5058
*/
5159
public enum ToolExecutionMode {
5260
NONE,
5361
SEQUENTIAL,
54-
PARALLEL
62+
PARALLEL,
63+
PARALLEL_SUBSCRIBE
5564
}
5665

5766
public abstract @Nullable SpeechConfig speechConfig();

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -236,23 +236,40 @@ public static Maybe<Event> handleFunctionCallsLive(
236236
}
237237

238238
/**
239-
* Sequential by default; only {@link ToolExecutionMode#PARALLEL} with multiple calls dispatches
240-
* tools on workers (using {@code concatMapEager} to preserve input order).
239+
* Builds the tool-execution {@link Observable} for the configured {@link ToolExecutionMode}.
240+
*
241+
* <ul>
242+
* <li>{@link ToolExecutionMode#SEQUENTIAL} (or a single call, where parallelism is moot) uses
243+
* {@code concatMapMaybe}: each tool is subscribed only after the previous one completes.
244+
* <li>{@link ToolExecutionMode#PARALLEL} (the default) uses {@code concatMapEager}: all tools
245+
* are subscribed eagerly on the caller thread. Async tools therefore run concurrently, but
246+
* tools that block the subscribing thread still execute sequentially. This matches the
247+
* historical behavior of the default mode.
248+
* <li>{@link ToolExecutionMode#PARALLEL_SUBSCRIBE} uses {@code concatMapEager} and additionally
249+
* subscribes each tool on a worker scheduler, so blocking tools also run concurrently.
250+
* {@code concatMapEager} preserves input order required by {@link
251+
* #mergeParallelFunctionResponseEvents}.
252+
* </ul>
241253
*/
242254
private static Observable<Event> buildToolExecutionObservable(
243255
InvocationContext invocationContext,
244256
List<FunctionCall> validFunctionCalls,
245257
Function<FunctionCall, Maybe<Event>> functionCallMapper) {
246-
boolean parallel =
247-
invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.PARALLEL
248-
&& validFunctionCalls.size() > 1;
249-
if (!parallel) {
258+
ToolExecutionMode mode = invocationContext.runConfig().toolExecutionMode();
259+
boolean sequential = mode == ToolExecutionMode.SEQUENTIAL || validFunctionCalls.size() <= 1;
260+
if (sequential) {
250261
return Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper);
251262
}
252-
Scheduler scheduler = resolveToolExecutionScheduler(invocationContext);
263+
if (mode == ToolExecutionMode.PARALLEL_SUBSCRIBE) {
264+
Scheduler scheduler = resolveToolExecutionScheduler(invocationContext);
265+
return Observable.fromIterable(validFunctionCalls)
266+
.concatMapEager(
267+
call -> functionCallMapper.apply(call).toObservable().subscribeOn(scheduler));
268+
}
269+
// PARALLEL (and NONE, which defaults to PARALLEL): eager subscribe on the caller thread,
270+
// without offloading to a worker. Async tools run concurrently; blocking tools still block.
253271
return Observable.fromIterable(validFunctionCalls)
254-
.concatMapEager(
255-
call -> functionCallMapper.apply(call).toObservable().subscribeOn(scheduler));
272+
.concatMapEager(call -> functionCallMapper.apply(call).toObservable());
256273
}
257274

258275
/** Agent executor if set, otherwise the IO scheduler. */

core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -398,9 +398,10 @@ public void getAskUserConfirmationFunctionCalls_eventWithConfirmationFunctionCal
398398
assertThat(result).containsExactly(confirmationCall1, confirmationCall2);
399399
}
400400

401-
// Default ToolExecutionMode.NONE must execute tools sequentially.
401+
// Default ToolExecutionMode.NONE behaves like PARALLEL: blocking tools still execute serially
402+
// on the caller thread (no worker scheduler is used), preserving the historical default.
402403
@Test
403-
public void handleFunctionCalls_defaultMode_blockingTools_runSequentially() {
404+
public void handleFunctionCalls_defaultMode_blockingTools_runSerially() {
404405
long sleepMillis = 300L;
405406
int toolCount = 2;
406407
InvocationContext invocationContext =
@@ -435,29 +436,69 @@ public void handleFunctionCalls_defaultMode_blockingTools_runSequentially() {
435436
assertThat(durationMillis).isAtLeast((long) toolCount * sleepMillis);
436437
}
437438

439+
// PARALLEL mode does NOT introduce worker threads; blocking tools still run serially on the
440+
// caller thread. PARALLEL_SUBSCRIBE is the mode that runs blocking tools concurrently.
438441
@Test
439-
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_twoTools() {
440-
runParallelBlockingToolsTest(/* toolCount= */ 2);
442+
public void handleFunctionCalls_parallel_blockingTools_runSerially() {
443+
long sleepMillis = 300L;
444+
int toolCount = 2;
445+
InvocationContext invocationContext =
446+
createInvocationContext(
447+
createRootAgent(),
448+
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());
449+
450+
Map<String, BaseTool> tools = new LinkedHashMap<>();
451+
List<Part> callParts = new ArrayList<>();
452+
for (int i = 1; i <= toolCount; i++) {
453+
String toolName = "slow_tool_" + i;
454+
tools.put(toolName, new SleepingTool(toolName, sleepMillis));
455+
callParts.add(
456+
Part.builder()
457+
.functionCall(
458+
FunctionCall.builder()
459+
.id("call_" + i)
460+
.name(toolName)
461+
.args(ImmutableMap.of())
462+
.build())
463+
.build());
464+
}
465+
Event event =
466+
createEvent("event").toBuilder()
467+
.content(Content.fromParts(callParts.toArray(new Part[0])))
468+
.build();
469+
470+
long start = System.currentTimeMillis();
471+
Event functionResponseEvent =
472+
Functions.handleFunctionCalls(invocationContext, event, tools).blockingGet();
473+
long durationMillis = System.currentTimeMillis() - start;
474+
475+
assertThat(functionResponseEvent).isNotNull();
476+
assertThat(durationMillis).isAtLeast((long) toolCount * sleepMillis);
441477
}
442478

443479
@Test
444-
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_threeTools() {
445-
runParallelBlockingToolsTest(/* toolCount= */ 3);
480+
public void handleFunctionCalls_parallelSubscribe_blockingTools_runConcurrently_twoTools() {
481+
runParallelSubscribeBlockingToolsTest(/* toolCount= */ 2);
446482
}
447483

448484
@Test
449-
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_fiveTools() {
450-
runParallelBlockingToolsTest(/* toolCount= */ 5);
485+
public void handleFunctionCalls_parallelSubscribe_blockingTools_runConcurrently_threeTools() {
486+
runParallelSubscribeBlockingToolsTest(/* toolCount= */ 3);
487+
}
488+
489+
@Test
490+
public void handleFunctionCalls_parallelSubscribe_blockingTools_runConcurrently_fiveTools() {
491+
runParallelSubscribeBlockingToolsTest(/* toolCount= */ 5);
451492
}
452493

453494
/** Single-tool case bypasses the parallel scheduler path; must still return the correct event. */
454495
@Test
455-
public void handleFunctionCalls_parallel_blockingTool_singleTool() {
496+
public void handleFunctionCalls_parallelSubscribe_blockingTool_singleTool() {
456497
long sleepMillis = 200L;
457498
InvocationContext invocationContext =
458499
createInvocationContext(
459500
createRootAgent(),
460-
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());
501+
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL_SUBSCRIBE).build());
461502
SleepingTool tool = new SleepingTool("slow_tool_1", sleepMillis);
462503
Event event =
463504
createEvent("event").toBuilder()
@@ -491,13 +532,16 @@ public void handleFunctionCalls_parallel_blockingTool_singleTool() {
491532
.build());
492533
}
493534

494-
/** Asserts that {@code toolCount} blocking tools in PARALLEL mode run faster than sequential. */
495-
private static void runParallelBlockingToolsTest(int toolCount) {
535+
/**
536+
* Asserts that {@code toolCount} blocking tools in PARALLEL_SUBSCRIBE mode run faster than
537+
* sequential, since each tool is subscribed on a worker thread.
538+
*/
539+
private static void runParallelSubscribeBlockingToolsTest(int toolCount) {
496540
long sleepMillis = 500L;
497541
InvocationContext invocationContext =
498542
createInvocationContext(
499543
createRootAgent(),
500-
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());
544+
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL_SUBSCRIBE).build());
501545

502546
Map<String, BaseTool> tools = new LinkedHashMap<>();
503547
List<Part> callParts = new ArrayList<>();

0 commit comments

Comments
 (0)