Skip to content

Commit f9aab42

Browse files
google-genai-bot, copybara-github
authored and committed
feat: Add thought signature support for chat completions
- https://ai.google.dev/gemini-api/docs/thought-signatures
- Also adds logging that's similar to the GenAI flow
- Also fixes some issues that appeared when testing streaming
- Final tool call no longer returns the accumulated text
- Add thought signature support

PiperOrigin-RevId: 919115462
1 parent 0a40557 commit f9aab42

6 files changed

Lines changed: 804 additions & 125 deletions

File tree

core/src/main/java/com/google/adk/models/chat/ChatCompletionsHttpClient.java

Lines changed: 21 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -190,14 +190,21 @@ private static Duration resolveCallTimeout(HttpOptions httpOptions) {
190190
public Flowable<LlmResponse> complete(LlmRequest llmRequest, boolean stream) {
191191
return Flowable.defer(
192192
() -> {
193+
String effectiveModelName = llmRequest.model().orElse("?");
194+
logger.trace("Chat Completion Request Contents: {}", llmRequest.contents());
195+
llmRequest.config().ifPresent(c -> logger.trace("Chat Completion Request Config: {}", c));
196+
193197
ChatCompletionsRequest dtoRequest =
194198
ChatCompletionsRequest.fromLlmRequest(llmRequest, stream);
195199
String jsonPayload = objectMapper.writeValueAsString(dtoRequest);
196-
logger.trace(
197-
"Chat Completion Request: model={}, stream={}, messagesCount={}",
198-
dtoRequest.model,
199-
dtoRequest.stream,
200-
dtoRequest.messages != null ? dtoRequest.messages.size() : 0);
200+
logger.trace("Chat Completion Request JSON: {}", jsonPayload);
201+
202+
if (stream) {
203+
logger.debug(
204+
"Sending streaming chat-completion request to model {}", effectiveModelName);
205+
} else {
206+
logger.debug("Sending chat-completion request to model {}", effectiveModelName);
207+
}
201208

202209
Request.Builder requestBuilder =
203210
new Request.Builder().url(completionsUrl).post(RequestBody.create(jsonPayload, JSON));
@@ -209,11 +216,7 @@ public Flowable<LlmResponse> complete(LlmRequest llmRequest, boolean stream) {
209216
requestBuilder.header("Content-Type", JSON.toString());
210217

211218
Request request = requestBuilder.build();
212-
if (stream) {
213-
return createStreamingFlowable(request);
214-
} else {
215-
return createNonStreamingFlowable(request);
216-
}
219+
return stream ? createStreamingFlowable(request) : createNonStreamingFlowable(request);
217220
});
218221
}
219222

@@ -274,10 +277,14 @@ public void onResponse(Call call, Response response) {
274277
// A single malformed chunk must not abort the entire stream. Log a
275278
// warning and continue.
276279
try {
280+
logger.trace("Raw streaming chat-completion chunk: {}", data);
277281
ChatCompletionsResponse.ChatCompletionChunk chunk =
278282
objectMapper.readValue(
279283
data, ChatCompletionsResponse.ChatCompletionChunk.class);
280284
ImmutableList<LlmResponse> responses = collection.processChunk(chunk);
285+
if (!responses.isEmpty()) {
286+
logger.trace("Responses to emit: {}", responses);
287+
}
281288
for (LlmResponse resp : responses) {
282289
emitter.onNext(resp);
283290
}
@@ -341,9 +348,12 @@ public void onResponse(Call call, Response response) {
341348
}
342349

343350
String jsonResponse = body.string();
351+
logger.trace("Raw non-streaming chat-completion response: {}", jsonResponse);
344352
ChatCompletionsResponse.ChatCompletion completion =
345353
objectMapper.readValue(jsonResponse, ChatCompletionsResponse.ChatCompletion.class);
346-
emitter.onNext(completion.toLlmResponse());
354+
LlmResponse llmResponse = completion.toLlmResponse();
355+
logger.trace("Response to emit: {}", llmResponse);
356+
emitter.onNext(llmResponse);
347357
emitter.onComplete();
348358
} catch (Exception e) {
349359
emitter.tryOnError(e);

core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import com.google.adk.JsonBaseModel;
2727
import com.google.adk.models.LlmRequest;
2828
import com.google.common.collect.ImmutableList;
29+
import com.google.common.collect.ImmutableMap;
2930
import com.google.genai.types.Content;
3031
import com.google.genai.types.FunctionDeclaration;
3132
import com.google.genai.types.FunctionResponse;
@@ -351,41 +352,43 @@ private static List<Message> processContent(Content content) {
351352
List<ChatCompletionsCommon.ToolCall> toolCalls = new ArrayList<>();
352353
List<Message> toolResponses = new ArrayList<>();
353354
List<String> refusals = new ArrayList<>();
354-
355-
content
356-
.parts()
357-
.ifPresent(
358-
parts -> {
359-
for (Part part : parts) {
360-
if (part.text().isPresent()) {
361-
// Text Parts may carry refusal content prefixed with REFUSAL_PREFIX.
362-
ChatCompletionsCommon.RefusalSplit split =
363-
ChatCompletionsCommon.parseRefusalPrefix(part.text().get());
364-
if (split.content() != null) {
365-
ContentPart textPart = new ContentPart();
366-
textPart.type = "text";
367-
textPart.text = split.content();
368-
contentParts.add(textPart);
369-
}
370-
if (split.refusal() != null) {
371-
refusals.add(split.refusal());
372-
}
373-
} else if (part.inlineData().isPresent()) {
374-
contentParts.add(processInlineDataPart(part));
375-
} else if (part.fileData().isPresent()) {
376-
contentParts.add(processFileDataPart(part));
377-
} else if (part.functionCall().isPresent()) {
378-
toolCalls.add(processFunctionCallPart(part));
379-
} else if (part.functionResponse().isPresent()) {
380-
toolResponses.add(processFunctionResponsePart(part));
381-
} else if (part.executableCode().isPresent()) {
382-
logger.warn("Executable code is not supported in Chat Completion conversion");
383-
} else if (part.codeExecutionResult().isPresent()) {
384-
logger.warn(
385-
"Code execution result is not supported in Chat Completion conversion");
386-
}
387-
}
388-
});
355+
// Capture a message-level thought_signature from the first text Part that carries one.
356+
// This signature must be echoed back on subsequent turns to ensure proper round-tripping.
357+
byte[] textThoughtSignature = null;
358+
359+
if (content.parts().isPresent()) {
360+
for (Part part : content.parts().get()) {
361+
if (part.text().isPresent()) {
362+
// Text Parts may carry refusal content prefixed with REFUSAL_PREFIX.
363+
ChatCompletionsCommon.RefusalSplit split =
364+
ChatCompletionsCommon.parseRefusalPrefix(part.text().get());
365+
if (split.content() != null) {
366+
ContentPart textPart = new ContentPart();
367+
textPart.type = "text";
368+
textPart.text = split.content();
369+
contentParts.add(textPart);
370+
}
371+
if (split.refusal() != null) {
372+
refusals.add(split.refusal());
373+
}
374+
if (textThoughtSignature == null && part.thoughtSignature().isPresent()) {
375+
textThoughtSignature = part.thoughtSignature().get();
376+
}
377+
} else if (part.inlineData().isPresent()) {
378+
contentParts.add(processInlineDataPart(part));
379+
} else if (part.fileData().isPresent()) {
380+
contentParts.add(processFileDataPart(part));
381+
} else if (part.functionCall().isPresent()) {
382+
toolCalls.add(processFunctionCallPart(part));
383+
} else if (part.functionResponse().isPresent()) {
384+
toolResponses.add(processFunctionResponsePart(part));
385+
} else if (part.executableCode().isPresent()) {
386+
logger.warn("Executable code is not supported in Chat Completion conversion");
387+
} else if (part.codeExecutionResult().isPresent()) {
388+
logger.warn("Code execution result is not supported in Chat Completion conversion");
389+
}
390+
}
391+
}
389392

390393
if (!toolResponses.isEmpty()) {
391394
return toolResponses;
@@ -403,6 +406,14 @@ private static List<Message> processContent(Content content) {
403406
msg.content = new MessageContent(ImmutableList.copyOf(contentParts));
404407
}
405408
}
409+
// Round-trip the message-level thought_signature for assistant text responses.
410+
if (textThoughtSignature != null) {
411+
msg.extraContent =
412+
ImmutableMap.of(
413+
"google",
414+
ImmutableMap.of(
415+
"thought_signature", Base64.getEncoder().encodeToString(textThoughtSignature)));
416+
}
406417
List<Message> messages = new ArrayList<>();
407418
messages.add(msg);
408419
return messages;
@@ -446,6 +457,10 @@ private static ContentPart processFileDataPart(Part part) {
446457
/**
447458
* Processes a function call part and returns a mapped ToolCall.
448459
*
460+
* <p>If the source {@link Part} carries a {@code thoughtSignature}, it is round-tripped back out
461+
* as a base64-encoded string in {@code extra_content.google.thought_signature} to satisfy
462+
* endpoint requirements.
463+
*
449464
* @param part The input part containing a requested function call or invocation.
450465
* @return The mapped function call tool call.
451466
*/
@@ -464,6 +479,13 @@ private static ChatCompletionsCommon.ToolCall processFunctionCallPart(Part part)
464479
}
465480
}
466481
toolCall.function = function;
482+
part.thoughtSignature()
483+
.ifPresent(
484+
sigBytes -> {
485+
String sig = Base64.getEncoder().encodeToString(sigBytes);
486+
toolCall.extraContent =
487+
ImmutableMap.of("google", ImmutableMap.of("thought_signature", sig));
488+
});
467489
return toolCall;
468490
}
469491

@@ -616,6 +638,13 @@ static class Message {
616638

617639
/** See class definition for more details. */
618640
public String refusal;
641+
642+
/**
643+
* Message-level additional parameters used by some providers. Used for round-tripping data like
644+
* {@code extra_content.google.thought_signature}.
645+
*/
646+
@JsonProperty("extra_content")
647+
public Map<String, Object> extraContent;
619648
}
620649

621650
/**

0 commit comments

Comments
 (0)