Skip to content

Commit 9a06dd3

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: support optional types in function tool parameters
PiperOrigin-RevId: 924804796
1 parent 29d3203 commit 9a06dd3

7 files changed

Lines changed: 642 additions & 29 deletions

File tree

core/src/main/java/com/google/adk/SchemaUtils.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ private SchemaUtils() {} // Private constructor for utility class
4141
*/
4242
@SuppressWarnings("unchecked") // For tool parameter type casting.
4343
private static Boolean matchType(Object value, Schema schema, Boolean isInput) {
44+
if (value == null) {
45+
return schema.nullable().orElse(false);
46+
}
4447
// Based on types from https://cloud.google.com/vertex-ai/docs/reference/rest/v1/Schema
4548
Type.Known type = schema.type().get().knownEnum();
4649
switch (type) {
@@ -73,7 +76,6 @@ private static Boolean matchType(Object value, Schema schema, Boolean isInput) {
7376
throw new IllegalArgumentException(
7477
"Unsupported type: " + type + " is not a Open API data type.");
7578
default:
76-
// This category includes NULL, which is not supported.
7779
break;
7880
}
7981
return false;

core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,14 @@ public static Schema buildSchemaFromType(Type type, ObjectMapper objectMapper) {
188188
*/
189189
private static Schema buildSchemaRecursive(
190190
JavaType javaType, SchemaGenerationContext context, ObjectMapper objectMapper) {
191+
if (Optional.class.isAssignableFrom(javaType.getRawClass())) {
192+
JavaType containedType = javaType.containedType(0);
193+
if (containedType == null) {
194+
return Schema.builder().type("OBJECT").nullable(true).build();
195+
}
196+
Schema innerSchema = buildSchemaRecursive(containedType, context, objectMapper);
197+
return innerSchema.toBuilder().nullable(true).build();
198+
}
191199
if (context.isProcessing(javaType)) {
192200
logger.warn("Type {} is recursive. Omitting from schema.", javaType.toCanonical());
193201
return Schema.builder()

core/src/main/java/com/google/adk/tools/FunctionTool.java

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -271,29 +271,45 @@ private Maybe<Map<String, Object>> call(Map<String, Object> args, ToolContext to
271271
throws IllegalAccessException, InvocationTargetException {
272272
Object[] arguments = buildArguments(args, toolContext, null);
273273
Object result = func.invoke(instance, arguments);
274-
if (result == null) {
274+
if (result == null || isEmptyOptional(result)) {
275275
return Maybe.empty();
276276
} else if (result instanceof Maybe) {
277277
return ((Maybe<?>) result)
278-
.map(
279-
data -> objectMapper.convertValue(data, new TypeReference<Map<String, Object>>() {}));
278+
.filter(data -> !isEmptyOptional(data))
279+
.map(this::convertToMapOrResult);
280280
} else if (result instanceof Single) {
281281
return ((Single<?>) result)
282-
.map(data -> objectMapper.convertValue(data, new TypeReference<Map<String, Object>>() {}))
283-
.toMaybe();
282+
.toMaybe()
283+
.filter(data -> !isEmptyOptional(data))
284+
.map(this::convertToMapOrResult);
284285
} else {
285-
try {
286-
return Maybe.just(
287-
objectMapper.convertValue(result, new TypeReference<Map<String, Object>>() {}));
288-
} catch (IllegalArgumentException e) {
289-
// Conversion to map failed, in this case we follow
290-
// https://google.github.io/adk-docs/tools-custom/function-tools/#return-type and return
291-
// the { "result": $result }
292-
return Maybe.just(ImmutableMap.of("result", result));
286+
return Maybe.just(convertToMapOrResult(result));
287+
}
288+
}
289+
290+
private Map<String, Object> convertToMapOrResult(Object value) {
291+
if (value instanceof Optional) {
292+
value = ((Optional<?>) value).get();
293+
}
294+
try {
295+
Map<String, Object> map =
296+
objectMapper.convertValue(value, new TypeReference<Map<String, Object>>() {});
297+
if (map == null) {
298+
return ImmutableMap.of();
293299
}
300+
return map;
301+
} catch (IllegalArgumentException e) {
302+
// Conversion to map failed, in this case we follow
303+
// https://google.github.io/adk-docs/tools-custom/function-tools/#return-type and return
304+
// the { "result": $result }
305+
return ImmutableMap.of("result", value);
294306
}
295307
}
296308

309+
private static boolean isEmptyOptional(Object value) {
310+
return value instanceof Optional && ((Optional<?>) value).isEmpty();
311+
}
312+
297313
@SuppressWarnings("unchecked")
298314
public Flowable<Map<String, Object>> callLive(
299315
Map<String, Object> args, ToolContext toolContext, InvocationContext invocationContext)
@@ -308,6 +324,21 @@ public Flowable<Map<String, Object>> callLive(
308324
}
309325
}
310326

327+
@SuppressWarnings("unchecked") // For tool parameter type casting.
328+
private @Nullable Object resolveArgumentValue(
329+
@Nullable Object argValue, Class<?> paramType, Type parameterizedType, String paramName) {
330+
if (paramType.equals(List.class)) {
331+
if (argValue instanceof List) {
332+
Type type = ((ParameterizedType) parameterizedType).getActualTypeArguments()[0];
333+
Class<?> typeArgClass = getTypeClass(type, paramName);
334+
return createList((List<Object>) argValue, typeArgClass);
335+
}
336+
} else if (argValue instanceof Map) {
337+
return objectMapper.convertValue(argValue, paramType);
338+
}
339+
return castValue(argValue, paramType);
340+
}
341+
311342
@SuppressWarnings("unchecked") // For tool parameter type casting.
312343
private Object[] buildArguments(
313344
Map<String, Object> args,
@@ -336,9 +367,14 @@ private Object[] buildArguments(
336367
continue;
337368
}
338369
Annotations.Schema schema = parameters[i].getAnnotation(Annotations.Schema.class);
370+
Class<?> paramType = parameters[i].getType();
339371
if (!args.containsKey(paramName)) {
340372
if (schema != null && schema.optional()) {
341-
arguments[i] = null;
373+
if (paramType.equals(Optional.class)) {
374+
arguments[i] = Optional.empty();
375+
} else {
376+
arguments[i] = null;
377+
}
342378
continue;
343379
} else {
344380
throw new IllegalArgumentException(
@@ -347,22 +383,27 @@ private Object[] buildArguments(
347383
paramName));
348384
}
349385
}
350-
Class<?> paramType = parameters[i].getType();
351386
Object argValue = args.get(paramName);
352-
if (paramType.equals(List.class)) {
353-
if (argValue instanceof List) {
354-
Type type =
355-
((ParameterizedType) parameters[i].getParameterizedType())
356-
.getActualTypeArguments()[0];
357-
Class<?> typeArgClass = getTypeClass(type, paramName);
358-
arguments[i] = createList((List<Object>) argValue, typeArgClass);
359-
continue;
387+
if (paramType.equals(Optional.class)) {
388+
if (argValue == null) {
389+
arguments[i] = Optional.empty();
390+
} else {
391+
Type innerType;
392+
Type paramParameterizedType = parameters[i].getParameterizedType();
393+
if (paramParameterizedType instanceof ParameterizedType pType) {
394+
innerType = pType.getActualTypeArguments()[0];
395+
} else {
396+
innerType = Object.class;
397+
}
398+
Class<?> innerClass = getTypeClass(innerType, paramName);
399+
Object resolvedValue = resolveArgumentValue(argValue, innerClass, innerType, paramName);
400+
arguments[i] = Optional.ofNullable(resolvedValue);
360401
}
361-
} else if (argValue instanceof Map) {
362-
arguments[i] = objectMapper.convertValue(argValue, paramType);
363-
continue;
402+
} else {
403+
arguments[i] =
404+
resolveArgumentValue(
405+
argValue, paramType, parameters[i].getParameterizedType(), paramName);
364406
}
365-
arguments[i] = castValue(argValue, paramType);
366407
}
367408
return arguments;
368409
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk;
18+
19+
import static org.junit.Assert.assertThrows;
20+
21+
import com.google.common.collect.ImmutableMap;
22+
import com.google.genai.types.Schema;
23+
import java.util.HashMap;
24+
import java.util.Map;
25+
import org.junit.Test;
26+
import org.junit.runner.RunWith;
27+
import org.junit.runners.JUnit4;
28+
29+
/** Unit tests for {@link SchemaUtils}. */
30+
@RunWith(JUnit4.class)
31+
public final class SchemaUtilsTest {
32+
33+
@Test
34+
public void validateMapOnSchema_nullableField_allowsNull() {
35+
Schema schema =
36+
Schema.builder()
37+
.type("OBJECT")
38+
.properties(
39+
ImmutableMap.of(
40+
"nullableField", Schema.builder().type("STRING").nullable(true).build()))
41+
.build();
42+
43+
Map<String, Object> args = new HashMap<>();
44+
args.put("nullableField", null);
45+
46+
// Should not throw exception
47+
SchemaUtils.validateMapOnSchema(args, schema, /* isInput= */ true);
48+
}
49+
50+
@Test
51+
public void validateMapOnSchema_nonNullableField_throwsException() {
52+
Schema schema =
53+
Schema.builder()
54+
.type("OBJECT")
55+
.properties(
56+
ImmutableMap.of(
57+
"nonNullableField", Schema.builder().type("STRING").nullable(false).build()))
58+
.build();
59+
60+
Map<String, Object> args = new HashMap<>();
61+
args.put("nonNullableField", null);
62+
63+
assertThrows(
64+
IllegalArgumentException.class,
65+
() -> SchemaUtils.validateMapOnSchema(args, schema, /* isInput= */ true));
66+
}
67+
68+
@Test
69+
public void validateMapOnSchema_implicitNonNullableField_throwsException() {
70+
Schema schema =
71+
Schema.builder()
72+
.type("OBJECT")
73+
.properties(ImmutableMap.of("defaultField", Schema.builder().type("STRING").build()))
74+
.build();
75+
76+
Map<String, Object> args = new HashMap<>();
77+
args.put("defaultField", null);
78+
79+
assertThrows(
80+
IllegalArgumentException.class,
81+
() -> SchemaUtils.validateMapOnSchema(args, schema, /* isInput= */ true));
82+
}
83+
}

core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ public void ensureInvocationCompleted_timeout_cleansUpState() throws IOException
213213

214214
// Wait for cleanup side effects which run after terminal signal.
215215
long deadline = Instant.now().plusMillis(1000).toEpochMilli();
216-
while (!pluginState.getPendingTasksForInvocation(invocationId).isEmpty()
216+
while ((!pluginState.getBatchProcessors().isEmpty()
217+
|| !pluginState.getTraceManagers().isEmpty())
217218
&& Instant.now().toEpochMilli() < deadline) {
218219
try {
219220
Thread.sleep(10);
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.tools;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
21+
import com.fasterxml.jackson.core.type.TypeReference;
22+
import com.google.common.collect.ImmutableMap;
23+
import com.google.genai.types.Schema;
24+
import java.lang.reflect.Type;
25+
import java.util.List;
26+
import java.util.Optional;
27+
import org.junit.Test;
28+
import org.junit.runner.RunWith;
29+
import org.junit.runners.JUnit4;
30+
31+
/** Unit tests for {@link FunctionCallingUtils}. */
32+
@RunWith(JUnit4.class)
33+
public final class FunctionCallingUtilsTest {
34+
35+
public static class PojoWithFields {
36+
public String field1;
37+
public int field2;
38+
}
39+
40+
public static class PojoWithOptionalFields {
41+
public Optional<String> optionalField;
42+
public Optional<PojoWithFields> optionalPojo;
43+
public Optional<List<String>> optionalList;
44+
}
45+
46+
@Test
47+
public void buildSchemaFromType_optionalString_returnsNullableString() {
48+
Type type = new TypeReference<Optional<String>>() {}.getType();
49+
50+
Schema schema = FunctionCallingUtils.buildSchemaFromType(type);
51+
52+
assertThat(schema).isEqualTo(Schema.builder().type("STRING").nullable(true).build());
53+
}
54+
55+
@Test
56+
public void buildSchemaFromType_optionalPojo_returnsNullablePojoWithProperties() {
57+
Type type = new TypeReference<Optional<PojoWithFields>>() {}.getType();
58+
59+
Schema schema = FunctionCallingUtils.buildSchemaFromType(type);
60+
61+
assertThat(schema)
62+
.isEqualTo(
63+
Schema.builder()
64+
.type("OBJECT")
65+
.nullable(true)
66+
.properties(
67+
ImmutableMap.of(
68+
"field1", Schema.builder().type("STRING").build(),
69+
"field2", Schema.builder().type("INTEGER").build()))
70+
.build());
71+
}
72+
73+
@Test
74+
public void buildSchemaFromType_pojoWithOptionalFields_generatesCorrectSchema() {
75+
Type type = PojoWithOptionalFields.class;
76+
77+
Schema schema = FunctionCallingUtils.buildSchemaFromType(type);
78+
79+
Schema expectedSchema =
80+
Schema.builder()
81+
.type("OBJECT")
82+
.properties(
83+
ImmutableMap.of(
84+
"optionalField",
85+
Schema.builder().type("STRING").nullable(true).build(),
86+
"optionalPojo",
87+
Schema.builder()
88+
.type("OBJECT")
89+
.nullable(true)
90+
.properties(
91+
ImmutableMap.of(
92+
"field1", Schema.builder().type("STRING").build(),
93+
"field2", Schema.builder().type("INTEGER").build()))
94+
.build(),
95+
"optionalList",
96+
Schema.builder()
97+
.type("ARRAY")
98+
.nullable(true)
99+
.items(Schema.builder().type("STRING").build())
100+
.build()))
101+
.build();
102+
103+
assertThat(schema).isEqualTo(expectedSchema);
104+
}
105+
}

0 commit comments

Comments
 (0)