Skip to content

Commit 1db5e45

Browse files
committed
fix: true parallel tool execution via subscribeOn(Schedulers.io()) in concatMapEager
Fixes #1152
1 parent 4009905 commit 1db5e45

3 files changed

Lines changed: 501 additions & 4 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import io.reactivex.rxjava3.core.Single;
4848
import io.reactivex.rxjava3.disposables.Disposable;
4949
import io.reactivex.rxjava3.functions.Function;
50+
import io.reactivex.rxjava3.schedulers.Schedulers;
5051
import java.util.ArrayList;
5152
import java.util.HashMap;
5253
import java.util.HashSet;
@@ -154,13 +155,16 @@ public static Maybe<Event> handleFunctionCalls(
154155
getFunctionCallMapper(invocationContext, tools, toolConfirmations, false, parentContext);
155156

156157
Observable<Event> functionResponseEventsObservable;
157-
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
158+
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL
159+
|| validFunctionCalls.size() <= 1) {
158160
functionResponseEventsObservable =
159161
Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper);
160162
} else {
161163
functionResponseEventsObservable =
162164
Observable.fromIterable(validFunctionCalls)
163-
.concatMapEager(call -> functionCallMapper.apply(call).toObservable());
165+
.concatMapEager(
166+
call ->
167+
functionCallMapper.apply(call).toObservable().subscribeOn(Schedulers.io()));
164168
}
165169
return functionResponseEventsObservable
166170
.toList()
@@ -225,13 +229,16 @@ public static Maybe<Event> handleFunctionCallsLive(
225229
getFunctionCallMapper(invocationContext, tools, toolConfirmations, true, parentContext);
226230

227231
Observable<Event> responseEventsObservable;
228-
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
232+
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL
233+
|| validFunctionCalls.size() <= 1) {
229234
responseEventsObservable =
230235
Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper);
231236
} else {
232237
responseEventsObservable =
233238
Observable.fromIterable(validFunctionCalls)
234-
.concatMapEager(call -> functionCallMapper.apply(call).toObservable());
239+
.concatMapEager(
240+
call ->
241+
functionCallMapper.apply(call).toObservable().subscribeOn(Schedulers.io()));
235242
}
236243

237244
return responseEventsObservable

0 commit comments

Comments
 (0)