Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 93 additions & 8 deletions core/src/main/java/com/google/adk/models/ApigeeLlm.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.base.Strings.isNullOrEmpty;

import com.google.adk.Version;
import com.google.adk.models.chat.ChatCompletionsHttpClient;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
Expand All @@ -28,6 +29,8 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A {@link BaseLlm} implementation for calling an Apigee proxy.
Expand All @@ -36,6 +39,7 @@
* allows for specifying the provider (Gemini or Vertex AI), API version, and model ID.
*/
public class ApigeeLlm extends BaseLlm {
private static final Logger logger = LoggerFactory.getLogger(ApigeeLlm.class);
private static final String GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME =
"GOOGLE_GENAI_USE_VERTEXAI";
private static final String APIGEE_PROXY_URL_ENV_VARIABLE_NAME = "APIGEE_PROXY_URL";
Expand All @@ -51,9 +55,18 @@ public class ApigeeLlm extends BaseLlm {
"user-agent", versionHeaderValue);
}

/**
* Defines the type of API to be used by the Apigee proxy.
*
* <p>The resolved value decides which backend {@code ApigeeLlm} routes a request through.
*/
public enum ApiType {
/**
* No explicit choice. The private {@code ApigeeLlm} constructor treats this value (and
* {@code null}) as a request to infer the type from the model name: names starting with
* {@code apigee/openai/} resolve to {@link #CHAT_COMPLETIONS}, all others to {@link #GENAI}.
*/
UNKNOWN,
/**
* Requests are sent through the {@code ChatCompletionsHttpClient}. {@code connect} throws
* {@link UnsupportedOperationException} for this type.
*/
CHAT_COMPLETIONS,
/** Requests are forwarded to the {@code Gemini} delegate. */
GENAI
}

// Backend used when apiType is GENAI; the private constructor sets it to null for CHAT_COMPLETIONS.
private final Gemini geminiDelegate;
// Backend used when apiType is CHAT_COMPLETIONS; the private constructor sets it to null otherwise.
private final ChatCompletionsHttpClient chatCompletionsHttpClient;
// GenAI client handed to geminiDelegate; null for CHAT_COMPLETIONS and in the test constructor.
private final Client apiClient;
// Options built by the private constructor and passed to the backend client; null in the test
// constructor.
private final HttpOptions httpOptions;
// Resolved API type; the constructors never leave this as null or UNKNOWN.
private final ApiType apiType;

/**
* Constructs a new ApigeeLlm instance.
Expand All @@ -62,7 +75,8 @@ public class ApigeeLlm extends BaseLlm {
* @param proxyUrl The URL of the Apigee proxy.
* @param customHeaders A map of custom headers to be sent with the request.
*/
private ApigeeLlm(String modelName, String proxyUrl, Map<String, String> customHeaders) {
private ApigeeLlm(
String modelName, String proxyUrl, Map<String, String> customHeaders, ApiType apiType) {
super(modelName);

if (!validateModelString(modelName)) {
Expand All @@ -71,6 +85,16 @@ private ApigeeLlm(String modelName, String proxyUrl, Map<String, String> customH
+ modelName);
}

if (apiType == null || apiType == ApiType.UNKNOWN) {
if (modelName.startsWith("apigee/openai/")) {
this.apiType = ApiType.CHAT_COMPLETIONS;
} else {
this.apiType = ApiType.GENAI;
}
} else {
this.apiType = apiType;
}

String effectiveProxyUrl = proxyUrl;
if (isNullOrEmpty(effectiveProxyUrl)) {
effectiveProxyUrl = System.getenv(APIGEE_PROXY_URL_ENV_VARIABLE_NAME);
Expand All @@ -96,13 +120,26 @@ private ApigeeLlm(String modelName, String proxyUrl, Map<String, String> customH
.buildOrThrow());
}
this.httpOptions = httpOptionsBuilder.build();
Client.Builder apiClientBuilder = Client.builder().httpOptions(this.httpOptions);
if (isVertexAiModel(modelName)) {
apiClientBuilder.vertexAI(true);

if (this.apiType == ApiType.CHAT_COMPLETIONS) {
this.apiClient = null;
this.geminiDelegate = null;
this.chatCompletionsHttpClient = new ChatCompletionsHttpClient(this.httpOptions);
} else {
Client.Builder apiClientBuilder = Client.builder().httpOptions(this.httpOptions);
if (isVertexAiModel(modelName)) {
apiClientBuilder.vertexAI(true);
}
this.apiClient = apiClientBuilder.build();
this.geminiDelegate = new Gemini(modelName, apiClient);
this.chatCompletionsHttpClient = null;
}

this.apiClient = apiClientBuilder.build();
this.geminiDelegate = new Gemini(modelName, apiClient);
logger.trace(
"ApigeeLlm constructed: modelName={} apiType={} effectiveProxyUrl={}",
modelName,
this.apiType,
effectiveProxyUrl);
}

/**
Expand All @@ -113,10 +150,31 @@ private ApigeeLlm(String modelName, String proxyUrl, Map<String, String> customH
*/
@VisibleForTesting
ApigeeLlm(String modelName, Gemini geminiDelegate) {
// Passing null for the chat-completions client makes the three-argument constructor select
// ApiType.GENAI, so requests go to the supplied Gemini delegate.
this(modelName, geminiDelegate, null);
}

/**
 * Test-only constructor that injects pre-built backends instead of creating real clients.
 *
 * <p>The API type is derived from the arguments: a non-null {@code chatCompletionsHttpClient}
 * selects {@link ApiType#CHAT_COMPLETIONS}; otherwise {@link ApiType#GENAI} is used.
 *
 * @param modelName The name of the Apigee model to use.
 * @param geminiDelegate The Gemini delegate to use for making API calls.
 * @param chatCompletionsHttpClient The ChatCompletionsHttpClient to use for making API calls;
 *     may be {@code null}.
 */
@VisibleForTesting
ApigeeLlm(
    String modelName,
    Gemini geminiDelegate,
    ChatCompletionsHttpClient chatCompletionsHttpClient) {
  super(modelName);

  // Injected backends.
  this.geminiDelegate = geminiDelegate;
  this.chatCompletionsHttpClient = chatCompletionsHttpClient;
  this.apiType =
      chatCompletionsHttpClient == null ? ApiType.GENAI : ApiType.CHAT_COMPLETIONS;

  // No real client or HTTP options exist when the backends are injected.
  this.apiClient = null;
  this.httpOptions = null;
}

/**
Expand Down Expand Up @@ -178,6 +236,7 @@ public static class Builder {
private String modelName;
private String proxyUrl;
private Map<String, String> customHeaders = new HashMap<>();
private ApiType apiType = ApiType.UNKNOWN;

protected Builder() {}

Expand Down Expand Up @@ -243,6 +302,18 @@ public Builder customHeaders(Map<String, String> customHeaders) {
return this;
}

/**
* Sets the explicit {@link ApiType} to use (e.g., CHAT_COMPLETIONS or GENAI).
*
* <p>The builder starts with {@link ApiType#UNKNOWN}. The value is stored as given and handed to
* the {@code ApigeeLlm} constructor by {@code build()}; when that value is {@code null} or
* {@code UNKNOWN}, the constructor infers the type from the model name instead.
*
* @param apiType the type of API.
* @return this builder.
*/
@CanIgnoreReturnValue
public Builder apiType(ApiType apiType) {
this.apiType = apiType;
return this;
}

/**
* Builds the {@link ApigeeLlm} instance.
*
Expand All @@ -255,7 +326,7 @@ public ApigeeLlm build() {
throw new IllegalArgumentException("Invalid model string: " + modelName);
}

return new ApigeeLlm(modelName, proxyUrl, customHeaders);
return new ApigeeLlm(modelName, proxyUrl, customHeaders, apiType);
}
}

Expand All @@ -264,11 +335,23 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre
String modelToUse = llmRequest.model().orElse(model());
String modelId = getModelId(modelToUse);
LlmRequest newLlmRequest = llmRequest.toBuilder().model(modelId).build();

logger.debug("ApigeeLlm.generateContent routing through {} for model {}", apiType, modelId);

if (apiType == ApiType.CHAT_COMPLETIONS) {
return chatCompletionsHttpClient.complete(newLlmRequest, stream);
}

return geminiDelegate.generateContent(newLlmRequest, stream);
}

@Override
public BaseLlmConnection connect(LlmRequest llmRequest) {
if (apiType == ApiType.CHAT_COMPLETIONS) {
throw new UnsupportedOperationException(
"Streaming connections are not supported for chat completions.");
}

String modelToUse = llmRequest.model().orElse(model());
String modelId = getModelId(modelToUse);
LlmRequest newLlmRequest = llmRequest.toBuilder().model(modelId).build();
Expand Down Expand Up @@ -297,7 +380,9 @@ private static boolean validateModelString(String model) {
return components[1].startsWith("v");
}
if (components.length == 2) {
if (components[0].equals("vertex_ai") || components[0].equals("gemini")) {
if (components[0].equals("vertex_ai")
|| components[0].equals("gemini")
|| components[0].equals("openai")) {
return true;
}
return components[0].startsWith("v");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
* <a href="https://developers.openai.com/api/reference/resources/chat">OpenAI Chat Completions API
* reference</a> for the wire protocol.
*/
public final class ChatCompletionsHttpClient {
public class ChatCompletionsHttpClient {
private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsHttpClient.class);
private static final ObjectMapper objectMapper = JsonBaseModel.getMapper();

Expand Down Expand Up @@ -190,14 +190,21 @@ private static Duration resolveCallTimeout(HttpOptions httpOptions) {
public Flowable<LlmResponse> complete(LlmRequest llmRequest, boolean stream) {
return Flowable.defer(
() -> {
String effectiveModelName = llmRequest.model().orElse("?");
logger.trace("Chat Completion Request Contents: {}", llmRequest.contents());
llmRequest.config().ifPresent(c -> logger.trace("Chat Completion Request Config: {}", c));

ChatCompletionsRequest dtoRequest =
ChatCompletionsRequest.fromLlmRequest(llmRequest, stream);
String jsonPayload = objectMapper.writeValueAsString(dtoRequest);
logger.trace(
"Chat Completion Request: model={}, stream={}, messagesCount={}",
dtoRequest.model,
dtoRequest.stream,
dtoRequest.messages != null ? dtoRequest.messages.size() : 0);
logger.trace("Chat Completion Request JSON: {}", jsonPayload);

if (stream) {
logger.debug(
"Sending streaming chat-completion request to model {}", effectiveModelName);
} else {
logger.debug("Sending chat-completion request to model {}", effectiveModelName);
}

Request.Builder requestBuilder =
new Request.Builder().url(completionsUrl).post(RequestBody.create(jsonPayload, JSON));
Expand All @@ -209,11 +216,7 @@ public Flowable<LlmResponse> complete(LlmRequest llmRequest, boolean stream) {
requestBuilder.header("Content-Type", JSON.toString());

Request request = requestBuilder.build();
if (stream) {
return createStreamingFlowable(request);
} else {
return createNonStreamingFlowable(request);
}
return stream ? createStreamingFlowable(request) : createNonStreamingFlowable(request);
});
}

Expand Down Expand Up @@ -274,10 +277,14 @@ public void onResponse(Call call, Response response) {
// A single malformed chunk must not abort the entire stream. Log a
// warning and continue.
try {
logger.trace("Raw streaming chat-completion chunk: {}", data);
ChatCompletionsResponse.ChatCompletionChunk chunk =
objectMapper.readValue(
data, ChatCompletionsResponse.ChatCompletionChunk.class);
ImmutableList<LlmResponse> responses = collection.processChunk(chunk);
if (!responses.isEmpty()) {
logger.trace("Responses to emit: {}", responses);
}
for (LlmResponse resp : responses) {
emitter.onNext(resp);
}
Expand Down Expand Up @@ -341,9 +348,12 @@ public void onResponse(Call call, Response response) {
}

String jsonResponse = body.string();
logger.trace("Raw non-streaming chat-completion response: {}", jsonResponse);
ChatCompletionsResponse.ChatCompletion completion =
objectMapper.readValue(jsonResponse, ChatCompletionsResponse.ChatCompletion.class);
emitter.onNext(completion.toLlmResponse());
LlmResponse llmResponse = completion.toLlmResponse();
logger.trace("Response to emit: {}", llmResponse);
emitter.onNext(llmResponse);
emitter.onComplete();
} catch (Exception e) {
emitter.tryOnError(e);
Expand Down
Loading