Skip to content

Commit 020499b

Browse files
damianmomotgooglecopybara-github
authored andcommitted
fix: run tools concurrently in PARALLEL ToolExecutionMode
PiperOrigin-RevId: 921331293
1 parent ae13073 commit 020499b

2 files changed

Lines changed: 175 additions & 18 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@
4444
import io.reactivex.rxjava3.core.Flowable;
4545
import io.reactivex.rxjava3.core.Maybe;
4646
import io.reactivex.rxjava3.core.Observable;
47+
import io.reactivex.rxjava3.core.Scheduler;
4748
import io.reactivex.rxjava3.core.Single;
4849
import io.reactivex.rxjava3.disposables.Disposable;
4950
import io.reactivex.rxjava3.functions.Function;
51+
import io.reactivex.rxjava3.schedulers.Schedulers;
5052
import java.util.ArrayList;
5153
import java.util.HashMap;
5254
import java.util.HashSet;
@@ -153,15 +155,8 @@ public static Maybe<Event> handleFunctionCalls(
153155
Function<FunctionCall, Maybe<Event>> functionCallMapper =
154156
getFunctionCallMapper(invocationContext, tools, toolConfirmations, false, parentContext);
155157

156-
Observable<Event> functionResponseEventsObservable;
157-
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
158-
functionResponseEventsObservable =
159-
Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper);
160-
} else {
161-
functionResponseEventsObservable =
162-
Observable.fromIterable(validFunctionCalls)
163-
.concatMapEager(call -> functionCallMapper.apply(call).toObservable());
164-
}
158+
Observable<Event> functionResponseEventsObservable =
159+
buildToolExecutionObservable(invocationContext, validFunctionCalls, functionCallMapper);
165160
return functionResponseEventsObservable
166161
.toList()
167162
.toMaybe()
@@ -224,15 +219,8 @@ public static Maybe<Event> handleFunctionCallsLive(
224219
Function<FunctionCall, Maybe<Event>> functionCallMapper =
225220
getFunctionCallMapper(invocationContext, tools, toolConfirmations, true, parentContext);
226221

227-
Observable<Event> responseEventsObservable;
228-
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
229-
responseEventsObservable =
230-
Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper);
231-
} else {
232-
responseEventsObservable =
233-
Observable.fromIterable(validFunctionCalls)
234-
.concatMapEager(call -> functionCallMapper.apply(call).toObservable());
235-
}
222+
Observable<Event> responseEventsObservable =
223+
buildToolExecutionObservable(invocationContext, validFunctionCalls, functionCallMapper);
236224

237225
return responseEventsObservable
238226
.toList()
@@ -247,6 +235,38 @@ public static Maybe<Event> handleFunctionCallsLive(
247235
});
248236
}
249237

238+
/**
239+
* Builds the tool-execution {@link Observable}.
240+
*
241+
* <p>SEQUENTIAL (or a single call, where parallelism is moot) runs on the caller thread via
242+
* {@code concatMapMaybe} to keep synchronous semantics. PARALLEL with multiple calls dispatches
243+
* each tool on a worker so blocking calls run concurrently; {@code concatMapEager} preserves
244+
* input order required by {@link #mergeParallelFunctionResponseEvents}.
245+
*/
246+
private static Observable<Event> buildToolExecutionObservable(
247+
InvocationContext invocationContext,
248+
List<FunctionCall> validFunctionCalls,
249+
Function<FunctionCall, Maybe<Event>> functionCallMapper) {
250+
boolean parallel =
251+
invocationContext.runConfig().toolExecutionMode() != ToolExecutionMode.SEQUENTIAL
252+
&& validFunctionCalls.size() > 1;
253+
if (!parallel) {
254+
return Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper);
255+
}
256+
Scheduler scheduler = resolveToolExecutionScheduler(invocationContext);
257+
return Observable.fromIterable(validFunctionCalls)
258+
.concatMapEager(
259+
call -> functionCallMapper.apply(call).toObservable().subscribeOn(scheduler));
260+
}
261+
262+
/** Agent executor if set, otherwise the IO scheduler. */
263+
private static Scheduler resolveToolExecutionScheduler(InvocationContext invocationContext) {
264+
if (invocationContext.agent() instanceof LlmAgent llmAgent) {
265+
return llmAgent.executor().map(Schedulers::from).orElse(Schedulers.io());
266+
}
267+
return Schedulers.io();
268+
}
269+
250270
private static Function<FunctionCall, Maybe<Event>> getFunctionCallMapper(
251271
InvocationContext invocationContext,
252272
Map<String, BaseTool> tools,

core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,21 @@
2626
import com.google.adk.agents.RunConfig.ToolExecutionMode;
2727
import com.google.adk.events.Event;
2828
import com.google.adk.testing.TestUtils;
29+
import com.google.adk.tools.BaseTool;
30+
import com.google.adk.tools.ToolContext;
2931
import com.google.common.collect.ImmutableList;
3032
import com.google.common.collect.ImmutableMap;
3133
import com.google.genai.types.Content;
3234
import com.google.genai.types.FunctionCall;
35+
import com.google.genai.types.FunctionDeclaration;
3336
import com.google.genai.types.FunctionResponse;
3437
import com.google.genai.types.Part;
38+
import io.reactivex.rxjava3.core.Single;
39+
import java.util.ArrayList;
40+
import java.util.LinkedHashMap;
41+
import java.util.List;
42+
import java.util.Map;
43+
import java.util.Optional;
3544
import org.junit.Test;
3645
import org.junit.runner.RunWith;
3746
import org.junit.runners.JUnit4;
@@ -388,4 +397,132 @@ public void getAskUserConfirmationFunctionCalls_eventWithConfirmationFunctionCal
388397
ImmutableList<FunctionCall> result = Functions.getAskUserConfirmationFunctionCalls(event);
389398
assertThat(result).containsExactly(confirmationCall1, confirmationCall2);
390399
}
400+
401+
@Test
402+
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_twoTools() {
403+
runParallelBlockingToolsTest(/* toolCount= */ 2);
404+
}
405+
406+
@Test
407+
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_threeTools() {
408+
runParallelBlockingToolsTest(/* toolCount= */ 3);
409+
}
410+
411+
@Test
412+
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_fiveTools() {
413+
runParallelBlockingToolsTest(/* toolCount= */ 5);
414+
}
415+
416+
/** Single-tool case bypasses the parallel scheduler path; must still return the correct event. */
417+
@Test
418+
public void handleFunctionCalls_parallel_blockingTool_singleTool() {
419+
long sleepMillis = 200L;
420+
InvocationContext invocationContext =
421+
createInvocationContext(
422+
createRootAgent(),
423+
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());
424+
SleepingTool tool = new SleepingTool("slow_tool_1", sleepMillis);
425+
Event event =
426+
createEvent("event").toBuilder()
427+
.content(
428+
Content.fromParts(
429+
Part.builder()
430+
.functionCall(
431+
FunctionCall.builder()
432+
.id("call_1")
433+
.name("slow_tool_1")
434+
.args(ImmutableMap.of())
435+
.build())
436+
.build()))
437+
.build();
438+
439+
Event functionResponseEvent =
440+
Functions.handleFunctionCalls(
441+
invocationContext, event, ImmutableMap.of("slow_tool_1", tool))
442+
.blockingGet();
443+
444+
assertThat(functionResponseEvent).isNotNull();
445+
assertThat(functionResponseEvent.content().get().parts().get())
446+
.containsExactly(
447+
Part.builder()
448+
.functionResponse(
449+
FunctionResponse.builder()
450+
.id("call_1")
451+
.name("slow_tool_1")
452+
.response(ImmutableMap.of("tool", "slow_tool_1"))
453+
.build())
454+
.build());
455+
}
456+
457+
/** Asserts that {@code toolCount} blocking tools in PARALLEL mode run faster than sequential. */
458+
private static void runParallelBlockingToolsTest(int toolCount) {
459+
long sleepMillis = 500L;
460+
InvocationContext invocationContext =
461+
createInvocationContext(
462+
createRootAgent(),
463+
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());
464+
465+
Map<String, BaseTool> tools = new LinkedHashMap<>();
466+
List<Part> callParts = new ArrayList<>();
467+
List<Part> expectedResponseParts = new ArrayList<>();
468+
for (int i = 1; i <= toolCount; i++) {
469+
String toolName = "slow_tool_" + i;
470+
String callId = "call_" + i;
471+
tools.put(toolName, new SleepingTool(toolName, sleepMillis));
472+
callParts.add(
473+
Part.builder()
474+
.functionCall(
475+
FunctionCall.builder().id(callId).name(toolName).args(ImmutableMap.of()).build())
476+
.build());
477+
expectedResponseParts.add(
478+
Part.builder()
479+
.functionResponse(
480+
FunctionResponse.builder()
481+
.id(callId)
482+
.name(toolName)
483+
.response(ImmutableMap.of("tool", toolName))
484+
.build())
485+
.build());
486+
}
487+
Event event =
488+
createEvent("event").toBuilder()
489+
.content(Content.fromParts(callParts.toArray(new Part[0])))
490+
.build();
491+
492+
long start = System.currentTimeMillis();
493+
Event functionResponseEvent =
494+
Functions.handleFunctionCalls(invocationContext, event, tools).blockingGet();
495+
long durationMillis = System.currentTimeMillis() - start;
496+
497+
assertThat(functionResponseEvent).isNotNull();
498+
assertThat(functionResponseEvent.content().get().parts().get())
499+
.containsExactlyElementsIn(expectedResponseParts)
500+
.inOrder();
501+
// Sequential would be ~toolCount * sleepMillis; parallel is ~sleepMillis + fixed overhead.
502+
assertThat(durationMillis).isLessThan((long) toolCount * sleepMillis);
503+
}
504+
505+
/** Tool that blocks the executing thread for {@code sleepMillis} before returning. */
506+
private static final class SleepingTool extends BaseTool {
507+
private final long sleepMillis;
508+
509+
SleepingTool(String name, long sleepMillis) {
510+
super(name, "Blocking tool used to verify parallel execution.");
511+
this.sleepMillis = sleepMillis;
512+
}
513+
514+
@Override
515+
public Optional<FunctionDeclaration> declaration() {
516+
return Optional.of(FunctionDeclaration.builder().name(name()).build());
517+
}
518+
519+
@Override
520+
public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContext toolContext) {
521+
return Single.fromCallable(
522+
() -> {
523+
Thread.sleep(sleepMillis);
524+
return ImmutableMap.<String, Object>of("tool", name());
525+
});
526+
}
527+
}
391528
}

0 commit comments

Comments
 (0)