Skip to content

Commit 2d675d0

Browse files
Mateusz Krawieccopybara-github
authored andcommitted
fix: support non-map return values returned from Function Tools by automatically wrapping them into {"result": <value>}
Rationale: https://google.github.io/adk-docs/tools-custom/function-tools/#return-type PiperOrigin-RevId: 840105092
1 parent f0da2b4 commit 2d675d0

4 files changed

Lines changed: 33 additions & 29 deletions

File tree

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,10 +514,12 @@ public Flowable<Event> runLive(
514514
.agent()
515515
.runLive(invocationContext)
516516
.doOnNext(event -> this.sessionService.appendEvent(session, event))
517-
.doOnError(
517+
.onErrorResumeNext(
518518
throwable -> {
519519
span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution");
520520
span.recordException(throwable);
521+
span.end();
522+
return Flowable.error(throwable);
521523
}));
522524
} catch (Throwable t) {
523525
span.setStatus(StatusCode.ERROR, "Error during runLive synchronous setup");

core/src/main/java/com/google/adk/tools/FunctionTool.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,15 @@ private Maybe<Map<String, Object>> call(Map<String, Object> args, ToolContext to
238238
data -> OBJECT_MAPPER.convertValue(data, new TypeReference<Map<String, Object>>() {}))
239239
.toMaybe();
240240
} else {
241-
return Maybe.just(
242-
OBJECT_MAPPER.convertValue(result, new TypeReference<Map<String, Object>>() {}));
241+
try {
242+
return Maybe.just(
243+
OBJECT_MAPPER.convertValue(result, new TypeReference<Map<String, Object>>() {}));
244+
} catch (IllegalArgumentException e) {
245+
// Conversion to map failed, in this case we follow
246+
// https://google.github.io/adk-docs/tools-custom/function-tools/#return-type and return
247+
// the { "result": $result }
248+
return Maybe.just(ImmutableMap.of("result", result));
249+
}
243250
}
244251
}
245252

core/src/test/java/com/google/adk/tools/FunctionToolTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,15 @@ public void call_withPojoParamWithFields() throws Exception {
328328
assertThat(result).containsExactly("field1", "abc", "field2", 123);
329329
}
330330

331+
@Test
332+
public void call_withBooleanReturnValue_returnsMapWithResult() throws Exception {
333+
FunctionTool tool = FunctionTool.create(Functions.class, "returnsBoolean");
334+
335+
Map<String, Object> result = tool.runAsync(ImmutableMap.of(), null).blockingGet();
336+
337+
assertThat(result).containsExactly("result", true);
338+
}
339+
331340
@Test
332341
public void call_withPojoParamWithGettersAndSetters() throws Exception {
333342
FunctionTool tool = FunctionTool.create(Functions.class, "pojoParamWithGettersAndSetters");
@@ -894,6 +903,10 @@ public static Single<Map<String, Object>> returnsSingleMap() {
894903
return Single.just(ImmutableMap.of("key", "value"));
895904
}
896905

906+
public static Boolean returnsBoolean() {
907+
return true;
908+
}
909+
897910
public static PojoWithGettersAndSetters returnsPojo() {
898911
PojoWithGettersAndSetters pojo = new PojoWithGettersAndSetters();
899912
pojo.setField1("abc");

core/src/test/java/com/google/adk/tools/streaming/StreamingToolTest.java

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848
@RunWith(JUnit4.class)
4949
public final class StreamingToolTest {
5050

51+
private static final RunConfig BIDI_STREAMING_RUN_CONFIG =
52+
RunConfig.builder().setStreamingMode(StreamingMode.BIDI).build();
53+
5154
public static final class StreamingTools {
5255
public static Flowable<ImmutableMap<String, Object>> monitorStockPrice(
5356
@Schema(name = "stockSymbol") String stockSymbol) {
@@ -130,10 +133,7 @@ public void runLive_asyncFunctionCall_succeeds() throws Exception {
130133

131134
Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet();
132135
List<Event> resEvents =
133-
runner
134-
.runLive(session, liveRequestQueue, RunConfig.builder().build())
135-
.toList()
136-
.blockingGet();
136+
runner.runLive(session, liveRequestQueue, BIDI_STREAMING_RUN_CONFIG).toList().blockingGet();
137137

138138
assertThat(resEvents).isNotNull();
139139
assertThat(resEvents).isNotEmpty();
@@ -215,10 +215,7 @@ public void runLive_functionCall_returnsErrors() throws Exception {
215215

216216
Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet();
217217
List<Event> resEvents =
218-
runner
219-
.runLive(session, liveRequestQueue, RunConfig.builder().build())
220-
.toList()
221-
.blockingGet();
218+
runner.runLive(session, liveRequestQueue, BIDI_STREAMING_RUN_CONFIG).toList().blockingGet();
222219

223220
assertThat(resEvents).isNotNull();
224221
assertThat(resEvents).isNotEmpty();
@@ -301,13 +298,7 @@ public void runLive_videoStreamingTool_receivesVideoFramesAndSendsResultsToLlm()
301298

302299
// Run the agent and collect events.
303300
List<Event> resEvents =
304-
runner
305-
.runLive(
306-
session,
307-
liveRequestQueue,
308-
RunConfig.builder().setStreamingMode(StreamingMode.BIDI).build())
309-
.toList()
310-
.blockingGet();
301+
runner.runLive(session, liveRequestQueue, BIDI_STREAMING_RUN_CONFIG).toList().blockingGet();
311302

312303
// Wait for the tool to send its 3 results back to the LLM
313304
assertThat(testLlm.waitForStreamingToolResults("monitorVideoStream", 3, Duration.ofSeconds(20)))
@@ -409,10 +400,7 @@ public void runLive_stopStreamingTool() throws Exception {
409400

410401
Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet();
411402
List<Event> resEvents =
412-
runner
413-
.runLive(session, liveRequestQueue, RunConfig.builder().build())
414-
.toList()
415-
.blockingGet();
403+
runner.runLive(session, liveRequestQueue, BIDI_STREAMING_RUN_CONFIG).toList().blockingGet();
416404

417405
assertThat(resEvents).isNotNull();
418406
assertThat(resEvents.size()).isAtLeast(1);
@@ -502,13 +490,7 @@ public void runLive_streamingTool_responsesAreSentAsUserContentToLlm() throws Ex
502490
Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet();
503491

504492
List<Event> resEvents =
505-
runner
506-
.runLive(
507-
session,
508-
liveRequestQueue,
509-
RunConfig.builder().setStreamingMode(StreamingMode.BIDI).build())
510-
.toList()
511-
.blockingGet();
493+
runner.runLive(session, liveRequestQueue, BIDI_STREAMING_RUN_CONFIG).toList().blockingGet();
512494

513495
assertThat(resEvents).isNotNull();
514496
assertThat(resEvents).isNotEmpty();

0 commit comments

Comments
 (0)