Skip to content

Commit 3b0e29c

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add conversion from LlmRequest to ChatCompletionsRequest
This is part of a larger chain of commits for adding chat completion API support to the Apigee model. PiperOrigin-RevId: 893207742
1 parent dd46b25 commit 3b0e29c

3 files changed

Lines changed: 695 additions & 48 deletions

File tree

core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java

Lines changed: 333 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,37 @@
2121
import com.fasterxml.jackson.annotation.JsonInclude;
2222
import com.fasterxml.jackson.annotation.JsonProperty;
2323
import com.fasterxml.jackson.annotation.JsonValue;
24+
import com.fasterxml.jackson.core.type.TypeReference;
25+
import com.fasterxml.jackson.databind.ObjectMapper;
26+
import com.google.adk.JsonBaseModel;
27+
import com.google.adk.models.LlmRequest;
28+
import com.google.common.collect.ImmutableList;
29+
import com.google.genai.types.Content;
30+
import com.google.genai.types.FunctionDeclaration;
31+
import com.google.genai.types.FunctionResponse;
32+
import com.google.genai.types.GenerateContentConfig;
33+
import com.google.genai.types.Part;
34+
import java.util.ArrayList;
35+
import java.util.Base64;
2436
import java.util.List;
2537
import java.util.Map;
38+
import java.util.Objects;
39+
import java.util.Optional;
40+
import org.slf4j.Logger;
41+
import org.slf4j.LoggerFactory;
2642

2743
/**
2844
* Data Transfer Objects for Chat Completion API requests.
2945
*
46+
* <p>Can be used to translate from a {@link LlmRequest} into a {@link ChatCompletionsRequest} using
47+
* {@link #fromLlmRequest(LlmRequest, boolean)}.
48+
*
3049
* <p>See
3150
* https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create
3251
*/
3352
@JsonIgnoreProperties(ignoreUnknown = true)
3453
@JsonInclude(JsonInclude.Include.NON_NULL)
35-
final class ChatCompletionsRequest {
54+
public final class ChatCompletionsRequest {
3655

3756
/**
3857
* See
@@ -249,6 +268,319 @@ final class ChatCompletionsRequest {
249268
@JsonProperty("extra_body")
250269
public Map<String, Object> extraBody;
251270

271+
private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsRequest.class);
272+
private static final ObjectMapper objectMapper = JsonBaseModel.getMapper();
273+
274+
/**
275+
* Converts a standard {@link LlmRequest} into a {@link ChatCompletionsRequest} for
276+
* /chat/completions compatible endpoints.
277+
*
278+
* @param llmRequest The internal source request containing contents, configuration, and tool
279+
* definitions.
280+
* @param responseStreaming True if the request asks for a streaming response.
281+
* @return A populated ChatCompletionsRequest ready for JSON serialization.
282+
*/
283+
public static ChatCompletionsRequest fromLlmRequest(
284+
LlmRequest llmRequest, boolean responseStreaming) {
285+
ChatCompletionsRequest request = new ChatCompletionsRequest();
286+
request.model = llmRequest.model().orElse("");
287+
request.stream = responseStreaming;
288+
if (responseStreaming) {
289+
StreamOptions options = new StreamOptions();
290+
options.includeUsage = true;
291+
request.streamOptions = options;
292+
}
293+
294+
boolean isOSeries = request.model.matches("^o\\d+(?:-.*)?$");
295+
296+
List<Message> messages = new ArrayList<>();
297+
298+
llmRequest
299+
.config()
300+
.flatMap(config -> processSystemInstruction(config, isOSeries))
301+
.ifPresent(messages::add);
302+
303+
for (Content content : llmRequest.contents()) {
304+
messages.addAll(processContent(content));
305+
}
306+
307+
request.messages = ImmutableList.copyOf(messages);
308+
309+
llmRequest
310+
.config()
311+
.ifPresent(
312+
config -> {
313+
handleConfigOptions(config, request);
314+
handleTools(config, request);
315+
});
316+
317+
return request;
318+
}
319+
320+
/**
321+
* Processes the system instruction configuration and returns a mapped Message if present.
322+
*
323+
* @param config The content generation configuration that may contain a system instruction.
324+
* @param isOSeries True if the target model belongs to the OpenAI o-series (e.g., o1, o3), which
325+
* requires the "developer" role instead of the standard "system" role.
326+
* @return An Optional containing the mapped instruction, or empty if none exists.
327+
*/
328+
private static Optional<Message> processSystemInstruction(
329+
GenerateContentConfig config, boolean isOSeries) {
330+
if (config.systemInstruction().isPresent()) {
331+
Message systemMsg = new Message();
332+
systemMsg.role = isOSeries ? "developer" : "system";
333+
systemMsg.content = new MessageContent(config.systemInstruction().get().text());
334+
return Optional.of(systemMsg);
335+
}
336+
return Optional.empty();
337+
}
338+
339+
/**
340+
* Processes incoming content and returns a list of messages resulting from it.
341+
*
342+
* @param content The incoming content containing parts to map.
343+
* @return A list of mapped messages.
344+
*/
345+
private static List<Message> processContent(Content content) {
346+
Message msg = new Message();
347+
String role = content.role().orElse("user");
348+
msg.role = role.equals("model") ? "assistant" : role;
349+
350+
List<ContentPart> contentParts = new ArrayList<>();
351+
List<ChatCompletionsCommon.ToolCall> toolCalls = new ArrayList<>();
352+
List<Message> toolResponses = new ArrayList<>();
353+
354+
content
355+
.parts()
356+
.ifPresent(
357+
parts -> {
358+
for (Part part : parts) {
359+
if (part.text().isPresent()) {
360+
contentParts.add(processTextPart(part));
361+
} else if (part.inlineData().isPresent()) {
362+
contentParts.add(processInlineDataPart(part));
363+
} else if (part.fileData().isPresent()) {
364+
contentParts.add(processFileDataPart(part));
365+
} else if (part.functionCall().isPresent()) {
366+
toolCalls.add(processFunctionCallPart(part));
367+
} else if (part.functionResponse().isPresent()) {
368+
toolResponses.add(processFunctionResponsePart(part));
369+
} else if (part.executableCode().isPresent()) {
370+
logger.warn("Executable code is not supported in Chat Completion conversion");
371+
} else if (part.codeExecutionResult().isPresent()) {
372+
logger.warn(
373+
"Code execution result is not supported in Chat Completion conversion");
374+
}
375+
}
376+
});
377+
378+
if (!toolResponses.isEmpty()) {
379+
return toolResponses;
380+
} else {
381+
if (!toolCalls.isEmpty()) {
382+
msg.toolCalls = ImmutableList.copyOf(toolCalls);
383+
}
384+
if (!contentParts.isEmpty()) {
385+
if (contentParts.size() == 1 && Objects.equals(contentParts.get(0).type, "text")) {
386+
msg.content = new MessageContent(contentParts.get(0).text);
387+
} else {
388+
msg.content = new MessageContent(ImmutableList.copyOf(contentParts));
389+
}
390+
}
391+
List<Message> messages = new ArrayList<>();
392+
messages.add(msg);
393+
return messages;
394+
}
395+
}
396+
397+
/**
398+
* Processes a text part and returns a mapped ContentPart.
399+
*
400+
* @param part The input part containing simple text.
401+
* @return The mapped text part.
402+
*/
403+
private static ContentPart processTextPart(Part part) {
404+
ContentPart textPart = new ContentPart();
405+
textPart.type = "text";
406+
textPart.text = part.text().get();
407+
return textPart;
408+
}
409+
410+
/**
411+
* Processes an inline data part and returns a mapped ContentPart.
412+
*
413+
* @param part The input part containing base64 inline data.
414+
* @return The mapped inline data part.
415+
*/
416+
private static ContentPart processInlineDataPart(Part part) {
417+
ContentPart imgPart = new ContentPart();
418+
imgPart.type = "image_url";
419+
ImageUrl imageUrl = new ImageUrl();
420+
imageUrl.url =
421+
"data:"
422+
+ part.inlineData().get().mimeType().orElse("image/jpeg")
423+
+ ";base64,"
424+
+ Base64.getEncoder().encodeToString(part.inlineData().get().data().get());
425+
imgPart.imageUrl = imageUrl;
426+
return imgPart;
427+
}
428+
429+
/**
430+
* Processes a file data part and returns a mapped ContentPart.
431+
*
432+
* @param part The input part referencing a stored file via URI.
433+
* @return The mapped file data part.
434+
*/
435+
private static ContentPart processFileDataPart(Part part) {
436+
ContentPart imgPart = new ContentPart();
437+
imgPart.type = "image_url";
438+
ImageUrl imageUrl = new ImageUrl();
439+
imageUrl.url = part.fileData().get().fileUri().orElse("");
440+
imgPart.imageUrl = imageUrl;
441+
return imgPart;
442+
}
443+
444+
/**
445+
* Processes a function call part and returns a mapped ToolCall.
446+
*
447+
* @param part The input part containing a requested function call or invocation.
448+
* @return The mapped function call tool call.
449+
*/
450+
private static ChatCompletionsCommon.ToolCall processFunctionCallPart(Part part) {
451+
com.google.genai.types.FunctionCall fc = part.functionCall().get();
452+
ChatCompletionsCommon.ToolCall toolCall = new ChatCompletionsCommon.ToolCall();
453+
toolCall.id = fc.id().orElse("call_" + fc.name().orElse("unknown"));
454+
toolCall.type = "function";
455+
ChatCompletionsCommon.Function function = new ChatCompletionsCommon.Function();
456+
function.name = fc.name().orElse("");
457+
if (fc.args().isPresent()) {
458+
try {
459+
function.arguments = objectMapper.writeValueAsString(fc.args().get());
460+
} catch (Exception e) {
461+
logger.warn("Failed to serialize function arguments", e);
462+
}
463+
}
464+
toolCall.function = function;
465+
return toolCall;
466+
}
467+
468+
/**
469+
* Processes a function response part and returns a mapped Message.
470+
*
471+
* @param part The input part containing the execution results of a function.
472+
* @return The mapped tool response message.
473+
*/
474+
private static Message processFunctionResponsePart(Part part) {
475+
FunctionResponse fr = part.functionResponse().get();
476+
Message toolResp = new Message();
477+
toolResp.role = "tool";
478+
toolResp.toolCallId = fr.id().orElse("");
479+
if (fr.response().isPresent()) {
480+
try {
481+
toolResp.content = new MessageContent(objectMapper.writeValueAsString(fr.response().get()));
482+
} catch (Exception e) {
483+
logger.warn("Failed to serialize tool response", e);
484+
}
485+
}
486+
return toolResp;
487+
}
488+
489+
/**
490+
* Updates the request based on the provided configuration options.
491+
*
492+
* @param config The content generation configuration containing parameters such as temperature.
493+
* @param request The chat completions request to populate with matching options.
494+
*/
495+
private static void handleConfigOptions(
496+
GenerateContentConfig config, ChatCompletionsRequest request) {
497+
config.temperature().ifPresent(v -> request.temperature = v.doubleValue());
498+
config.topP().ifPresent(v -> request.topP = v.doubleValue());
499+
config
500+
.maxOutputTokens()
501+
.ifPresent(
502+
v -> {
503+
request.maxCompletionTokens = Math.toIntExact(v);
504+
});
505+
config.stopSequences().ifPresent(v -> request.stop = new StopCondition(v));
506+
config.candidateCount().ifPresent(v -> request.n = Math.toIntExact(v));
507+
config.presencePenalty().ifPresent(v -> request.presencePenalty = v.doubleValue());
508+
config.frequencyPenalty().ifPresent(v -> request.frequencyPenalty = v.doubleValue());
509+
config.seed().ifPresent(v -> request.seed = v.longValue());
510+
511+
if (config.responseJsonSchema().isPresent()) {
512+
ResponseFormatJsonSchema format = new ResponseFormatJsonSchema();
513+
ResponseFormatJsonSchema.JsonSchema schema = new ResponseFormatJsonSchema.JsonSchema();
514+
schema.name = "response_schema";
515+
schema.schema =
516+
objectMapper.convertValue(
517+
config.responseJsonSchema().get(), new TypeReference<Map<String, Object>>() {});
518+
schema.strict = true;
519+
format.jsonSchema = schema;
520+
request.responseFormat = format;
521+
} else if (config.responseMimeType().isPresent()
522+
&& config.responseMimeType().get().equals("application/json")) {
523+
request.responseFormat = new ResponseFormatJsonObject();
524+
}
525+
526+
if (config.responseLogprobs().isPresent() && config.responseLogprobs().get()) {
527+
request.logprobs = true;
528+
config.logprobs().ifPresent(v -> request.topLogprobs = Math.toIntExact(v));
529+
}
530+
}
531+
532+
/**
533+
* Updates the request tools list based on the provided tools configuration.
534+
*
535+
* @param config The content generation configuration defining available tools.
536+
* @param request The chat completions request to populate with mapped tool definitions.
537+
*/
538+
private static void handleTools(GenerateContentConfig config, ChatCompletionsRequest request) {
539+
if (config.tools().isPresent()) {
540+
List<Tool> tools = new ArrayList<>();
541+
for (com.google.genai.types.Tool t : config.tools().get()) {
542+
if (t.functionDeclarations().isPresent()) {
543+
for (FunctionDeclaration fd : t.functionDeclarations().get()) {
544+
Tool tool = new Tool();
545+
tool.type = "function";
546+
FunctionDefinition def = new FunctionDefinition();
547+
def.name = fd.name().orElse("");
548+
def.description = fd.description().orElse("");
549+
fd.parameters()
550+
.ifPresent(
551+
params ->
552+
def.parameters =
553+
objectMapper.convertValue(
554+
params, new TypeReference<Map<String, Object>>() {}));
555+
tool.function = def;
556+
tools.add(tool);
557+
}
558+
}
559+
}
560+
if (!tools.isEmpty()) {
561+
request.tools = ImmutableList.copyOf(tools);
562+
if (config.toolConfig().isPresent()
563+
&& config.toolConfig().get().functionCallingConfig().isPresent()) {
564+
config
565+
.toolConfig()
566+
.get()
567+
.functionCallingConfig()
568+
.get()
569+
.mode()
570+
.ifPresent(
571+
mode -> {
572+
switch (mode.knownEnum()) {
573+
case ANY -> request.toolChoice = new ToolChoiceMode("required");
574+
case NONE -> request.toolChoice = new ToolChoiceMode("none");
575+
case AUTO -> request.toolChoice = new ToolChoiceMode("auto");
576+
default -> {}
577+
}
578+
});
579+
}
580+
}
581+
}
582+
}
583+
252584
/**
253585
* A catch-all class for message parameters. See
254586
* https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20messages%20%3E%20(schema)

core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public final class ChatCompletionsResponse {
5050

5151
private ChatCompletionsResponse() {}
5252

53-
static @Nullable FinishReason mapFinishReason(String reason) {
53+
static @Nullable FinishReason mapFinishReason(@Nullable String reason) {
5454
if (reason == null) {
5555
return null;
5656
}
@@ -62,7 +62,7 @@ private ChatCompletionsResponse() {}
6262
};
6363
}
6464

65-
static @Nullable GenerateContentResponseUsageMetadata mapUsage(Usage usage) {
65+
static @Nullable GenerateContentResponseUsageMetadata mapUsage(@Nullable Usage usage) {
6666
if (usage == null) {
6767
return null;
6868
}
@@ -188,8 +188,15 @@ private ImmutableList<Part> mapMessageToParts(Message message) {
188188
return parts.build();
189189
}
190190

191+
/**
192+
* Maps a list of tool calls to a list of {@link Part} objects.
193+
*
194+
* @param toolCalls the list of tool calls to map (non-null).
195+
* @return a list of parts containing converted tool calls.
196+
*/
191197
private ImmutableList<Part> mapToolCallsToParts(
192198
List<ChatCompletionsCommon.ToolCall> toolCalls) {
199+
193200
ImmutableList.Builder<Part> parts = ImmutableList.builder();
194201
for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) {
195202
Part part = toolCall.toPart();

0 commit comments

Comments
 (0)