Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ private ChatCompletionsCommon() {}

private static final ObjectMapper objectMapper = new ObjectMapper();

static final String EMPTY_JSON_OBJECT = "{}";
static final Map<String, Object> EMPTY_PARAMETERS_SCHEMA =
Map.of("type", "object", "properties", Map.of());

public static final String ROLE_ASSISTANT = "assistant";
public static final String ROLE_MODEL = "model";

Expand Down Expand Up @@ -157,6 +161,18 @@ public Part applyThoughtSignature(Part part) {
}
}

/**
* Robust defense: Function arguments and responses MUST be JSON objects. If the LLM hallucinates
* "null", "NULL", "none", or conversational text, we fallback to an empty JSON object "{}". This
* prevents OpenAI-compatible proxies (like Groq) from throwing 400 Bad Request errors.
*/
static String enforceJsonObject(String json) {
if (json == null || !json.trim().startsWith("{")) {
return EMPTY_JSON_OBJECT;
}
return json.trim();
}

/**
* See
* https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message_function_tool_call%20%3E%20(schema)
Expand All @@ -181,21 +197,27 @@ public FunctionCall toFunctionCall(@Nullable String toolCallId) {
if (name != null) {
fcBuilder.name(name);
}
if (arguments != null && !arguments.isEmpty()) {
try {
Map<String, Object> args =
objectMapper.readValue(arguments, new TypeReference<Map<String, Object>>() {});
fcBuilder.args(args);
} catch (Exception e) {
throw new IllegalArgumentException(
"Failed to parse function arguments JSON: " + arguments, e);
}
}
fcBuilder.args(parseArguments(arguments));
if (toolCallId != null) {
fcBuilder.id(toolCallId);
}
return fcBuilder.build();
}

private Map<String, Object> parseArguments(String arguments) {
if (arguments == null || arguments.trim().isEmpty()) {
return Map.of();
}
try {
String json = enforceJsonObject(arguments);
Map<String, Object> result =
objectMapper.readValue(json, new TypeReference<Map<String, Object>>() {});
return result != null ? result : Map.of();
} catch (Exception e) {
throw new IllegalArgumentException(
"Failed to parse function arguments JSON: " + arguments, e);
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.google.adk.JsonBaseModel;
import com.google.adk.models.LlmRequest;
import com.google.common.collect.ImmutableList;
Expand All @@ -32,6 +36,8 @@
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.Part;
import com.google.genai.types.Type;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
Expand Down Expand Up @@ -270,7 +276,28 @@ public final class ChatCompletionsRequest {
public Map<String, Object> extraBody;

private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsRequest.class);
private static final ObjectMapper objectMapper = JsonBaseModel.getMapper();

/**
* Registers a custom serializer to force JSON Schema types to lowercase (e.g., "STRING" ->
* "string"). The genai SDK uses uppercase Enums for schema types, which strict OpenAI-compatible
* endpoints reject with HTTP 400.
*/
private static SimpleModule schemaNormalizerModule() {
SimpleModule module = new SimpleModule();
module.addSerializer(
Type.class,
new JsonSerializer<Type>() {
@Override
public void serialize(Type value, JsonGenerator gen, SerializerProvider serializers)
throws IOException {
gen.writeString(value.toString().toLowerCase());
}
});
return module;
}

private static final ObjectMapper objectMapper =
JsonBaseModel.getMapper().copy().registerModule(schemaNormalizerModule());

/**
* Converts a standard {@link LlmRequest} into a {@link ChatCompletionsRequest} for
Expand Down Expand Up @@ -473,10 +500,14 @@ private static ChatCompletionsCommon.ToolCall processFunctionCallPart(Part part)
function.name = fc.name().orElse("");
if (fc.args().isPresent()) {
try {
function.arguments = objectMapper.writeValueAsString(fc.args().get());
String json = objectMapper.writeValueAsString(fc.args().get());
function.arguments = ChatCompletionsCommon.enforceJsonObject(json);
} catch (Exception e) {
logger.warn("Failed to serialize function arguments", e);
function.arguments = ChatCompletionsCommon.EMPTY_JSON_OBJECT;
}
} else {
function.arguments = ChatCompletionsCommon.EMPTY_JSON_OBJECT;
}
toolCall.function = function;
part.thoughtSignature()
Expand All @@ -502,10 +533,14 @@ private static Message processFunctionResponsePart(Part part) {
toolResp.toolCallId = fr.id().orElse("");
if (fr.response().isPresent()) {
try {
toolResp.content = new MessageContent(objectMapper.writeValueAsString(fr.response().get()));
String json = objectMapper.writeValueAsString(fr.response().get());
toolResp.content = new MessageContent(ChatCompletionsCommon.enforceJsonObject(json));
} catch (Exception e) {
logger.warn("Failed to serialize tool response", e);
toolResp.content = new MessageContent(ChatCompletionsCommon.EMPTY_JSON_OBJECT);
}
} else {
toolResp.content = new MessageContent(ChatCompletionsCommon.EMPTY_JSON_OBJECT);
}
return toolResp;
}
Expand Down Expand Up @@ -570,12 +605,15 @@ private static void handleTools(GenerateContentConfig config, ChatCompletionsReq
FunctionDefinition def = new FunctionDefinition();
def.name = fd.name().orElse("");
def.description = fd.description().orElse("");
fd.parameters()
.ifPresent(
params ->
def.parameters =
objectMapper.convertValue(
params, new TypeReference<Map<String, Object>>() {}));
if (fd.parameters().isPresent()) {
def.parameters =
objectMapper.convertValue(
fd.parameters().get(), new TypeReference<Map<String, Object>>() {});
} else {
// OpenAI-compatible APIs (like Groq) strictly require the parameters object
// to exist, even for zero-argument functions.
def.parameters = ChatCompletionsCommon.EMPTY_PARAMETERS_SCHEMA;
}
tool.function = def;
tools.add(tool);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import com.google.genai.types.Tool;
import com.google.genai.types.ToolConfig;
import java.util.AbstractMap;
Expand Down Expand Up @@ -567,6 +568,84 @@ public void testFromLlmRequest_withFunctionCall() throws Exception {
assertThat(msg.toolCalls.get(0).function.arguments).isEqualTo("{\"location\":\"Paris\"}");
}

@Test
public void testFromLlmRequest_withAbsentFunctionArguments() throws Exception {
FunctionCall functionCall = FunctionCall.builder().id("call_123").name("get_time").build();
Part part = Part.builder().functionCall(functionCall).build();
Content content = Content.builder().role("model").parts(ImmutableList.of(part)).build();

LlmRequest llmRequest =
LlmRequest.builder().model("gemini-1.5-pro").contents(ImmutableList.of(content)).build();

ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);

assertThat(request.messages).hasSize(1);
ChatCompletionsRequest.Message msg = request.messages.get(0);
assertThat(msg.role).isEqualTo("assistant");
assertThat(msg.toolCalls).hasSize(1);
assertThat(msg.toolCalls.get(0).function.name).isEqualTo("get_time");
assertThat(msg.toolCalls.get(0).function.arguments).isEqualTo("{}");
}

@Test
public void testFromLlmRequest_withAbsentParameters() throws Exception {
FunctionDeclaration function =
FunctionDeclaration.builder().name("test_func").description("A test function").build();

Tool tool = Tool.builder().functionDeclarations(ImmutableList.of(function)).build();
GenerateContentConfig config =
GenerateContentConfig.builder().tools(ImmutableList.of(tool)).build();

LlmRequest llmRequest =
LlmRequest.builder()
.model("gemini-1.5-pro")
.config(config)
.contents(ImmutableList.of())
.build();

ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);

assertThat(request.tools).hasSize(1);
Map<String, Object> params = (Map<String, Object>) request.tools.get(0).function.parameters;
assertThat(params.get("type")).isEqualTo("object");
@SuppressWarnings("unchecked")
Map<String, Object> props = (Map<String, Object>) params.get("properties");
assertThat(props).isEmpty();
}

@Test
public void testFromLlmRequest_normalizesSchemaTypeToLowerCase() throws Exception {
Schema param1Schema = Schema.builder().type("STRING").build();

Schema functionSchema =
Schema.builder().type("OBJECT").properties(ImmutableMap.of("param1", param1Schema)).build();

FunctionDeclaration function =
FunctionDeclaration.builder().name("test_func").parameters(functionSchema).build();

Tool tool = Tool.builder().functionDeclarations(ImmutableList.of(function)).build();
GenerateContentConfig config =
GenerateContentConfig.builder().tools(ImmutableList.of(tool)).build();

LlmRequest llmRequest =
LlmRequest.builder()
.model("gemini-1.5-pro")
.config(config)
.contents(ImmutableList.of())
.build();

ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);

assertThat(request.tools).hasSize(1);
Map<String, Object> params = (Map<String, Object>) request.tools.get(0).function.parameters;
assertThat(params.get("type")).isEqualTo("object");
@SuppressWarnings("unchecked")
Map<String, Object> props = (Map<String, Object>) params.get("properties");
@SuppressWarnings("unchecked")
Map<String, Object> param1 = (Map<String, Object>) props.get("param1");
assertThat(param1.get("type")).isEqualTo("string");
}

@Test
public void testFromLlmRequest_withStreamOptions() throws Exception {
LlmRequest llmRequest =
Expand Down Expand Up @@ -628,11 +707,11 @@ public void testFromLlmRequest_withFunctionResponse() throws Exception {

assertThat(request.messages.get(1).role).isEqualTo("tool");
assertThat(request.messages.get(1).toolCallId).isEmpty();
assertThat(request.messages.get(1).content).isNull();
assertThat(request.messages.get(1).content.getValue()).isEqualTo("{}");

assertThat(request.messages.get(2).role).isEqualTo("tool");
assertThat(request.messages.get(2).toolCallId).isEqualTo("call_faulty");
assertThat(request.messages.get(2).content).isNull();
assertThat(request.messages.get(2).content.getValue()).isEqualTo("{}");
}

@Test
Expand Down