Skip to content

Commit 560977a

Browse files
google-genai-bot authored and copybara-github committed
feat: Add streaming support for ChatCompletionsHTTPClient
This is part of a larger chain of commits for adding chat completion API support to the Apigee model. PiperOrigin-RevId: 907190909
1 parent b4791ef commit 560977a

2 files changed

Lines changed: 263 additions & 31 deletions

File tree

core/src/main/java/com/google/adk/models/chat/ChatCompletionsHttpClient.java

Lines changed: 105 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616

1717
package com.google.adk.models.chat;
1818

19+
import com.fasterxml.jackson.core.JsonProcessingException;
1920
import com.fasterxml.jackson.databind.ObjectMapper;
2021
import com.google.adk.JsonBaseModel;
2122
import com.google.adk.models.LlmRequest;
2223
import com.google.adk.models.LlmResponse;
24+
import com.google.common.annotations.VisibleForTesting;
25+
import com.google.common.collect.ImmutableList;
2326
import com.google.common.collect.ImmutableMap;
2427
import com.google.genai.types.HttpOptions;
2528
import io.reactivex.rxjava3.core.BackpressureStrategy;
@@ -38,6 +41,7 @@
3841
import okhttp3.RequestBody;
3942
import okhttp3.Response;
4043
import okhttp3.ResponseBody;
44+
import okio.BufferedSource;
4145
import org.slf4j.Logger;
4246
import org.slf4j.LoggerFactory;
4347

@@ -49,11 +53,12 @@
4953
* <a href="https://developers.openai.com/api/reference/resources/chat">OpenAI Chat Completions API
5054
* reference</a> for the wire protocol.
5155
*/
52-
public class ChatCompletionsHttpClient {
56+
public final class ChatCompletionsHttpClient {
5357
private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsHttpClient.class);
5458
private static final ObjectMapper objectMapper = JsonBaseModel.getMapper();
5559

5660
private static final MediaType JSON = MediaType.get("application/json; charset=utf-8");
61+
private static final String SSE_DATA_PREFIX = "data:";
5762

5863
/**
5964
* Default OkHttp call timeout used when the caller does not supply an {@link HttpOptions}
@@ -113,6 +118,10 @@ public class ChatCompletionsHttpClient {
113118
* HTTP(S) URL.
114119
*/
115120
public ChatCompletionsHttpClient(HttpOptions httpOptions) {
121+
this(httpOptions, buildClient(httpOptions));
122+
}
123+
124+
private ChatCompletionsHttpClient(HttpOptions httpOptions, OkHttpClient client) {
116125
Objects.requireNonNull(httpOptions, "httpOptions cannot be null");
117126
String baseUrl =
118127
httpOptions
@@ -133,14 +142,30 @@ public ChatCompletionsHttpClient(HttpOptions httpOptions) {
133142
.headers()
134143
.<ImmutableMap<String, String>>map(ImmutableMap::copyOf)
135144
.orElse(ImmutableMap.of());
145+
this.client = client;
146+
}
136147

137-
// Apply custom timeouts per instance. All internal timeouts are bounded by callTimeout.
148+
/**
149+
* Test-only factory that injects a custom {@link OkHttpClient} (typically a mock) without
150+
* touching production wiring. Production callers should use the public constructor.
151+
*/
152+
@VisibleForTesting
153+
static ChatCompletionsHttpClient forTesting(HttpOptions httpOptions, OkHttpClient client) {
154+
return new ChatCompletionsHttpClient(httpOptions, client);
155+
}
156+
157+
/**
158+
* Builds the production OkHttpClient by forking {@link #SHARED_POOL_CLIENT} so the connection
159+
* pool and dispatcher are reused across instances while applying per-instance timeouts.
160+
*/
161+
private static OkHttpClient buildClient(HttpOptions httpOptions) {
162+
Objects.requireNonNull(httpOptions, "httpOptions cannot be null");
138163
OkHttpClient.Builder builder = SHARED_POOL_CLIENT.newBuilder();
139164
builder.connectTimeout(Duration.ZERO);
140165
builder.readTimeout(Duration.ZERO);
141166
builder.writeTimeout(Duration.ZERO);
142167
builder.callTimeout(resolveCallTimeout(httpOptions));
143-
this.client = builder.build();
168+
return builder.build();
144169
}
145170

146171
/** Resolves the call timeout from HttpOptions. */
@@ -192,11 +217,82 @@ public Flowable<LlmResponse> complete(LlmRequest llmRequest, boolean stream) {
192217
});
193218
}
194219

195-
/** Placeholder for streaming responses. Errors with {@link UnsupportedOperationException}. */
196-
@SuppressWarnings("UnusedVariable")
197220
private Flowable<LlmResponse> createStreamingFlowable(Request request) {
198-
return Flowable.error(
199-
new UnsupportedOperationException("Streaming is not yet implemented in this client."));
221+
return Flowable.create(
222+
emitter -> {
223+
Call call = client.newCall(request);
224+
emitter.setCancellable(call::cancel);
225+
call.enqueue(
226+
new Callback() {
227+
@Override
228+
public void onFailure(Call call, IOException e) {
229+
emitter.tryOnError(e);
230+
}
231+
232+
@Override
233+
public void onResponse(Call call, Response response) {
234+
try (ResponseBody body = response.body()) {
235+
if (!response.isSuccessful()) {
236+
String bodyStr = body != null ? body.string() : "";
237+
emitter.tryOnError(
238+
new IOException(
239+
"HTTP request failed with status: "
240+
+ response
241+
+ " - body: "
242+
+ bodyStr));
243+
return;
244+
}
245+
if (body == null) {
246+
emitter.tryOnError(new IOException("Empty response body"));
247+
return;
248+
}
249+
250+
BufferedSource source = body.source();
251+
ChatCompletionsResponse.ChatCompletionChunkCollection collection =
252+
new ChatCompletionsResponse.ChatCompletionChunkCollection();
253+
while (!source.exhausted() && !emitter.isCancelled()) {
254+
String line = source.readUtf8Line();
255+
if (line == null) {
256+
break;
257+
}
258+
if (line.isEmpty()) {
259+
continue;
260+
}
261+
// TODO: Support SSE "event", "id", and "retry".
262+
// See
263+
// https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
264+
if (!line.startsWith(SSE_DATA_PREFIX)) {
265+
logger.debug("Ignoring SSE line without data prefix: {}", line);
266+
continue;
267+
}
268+
// The SSE spec allows whitespace after the prefix,
269+
// eg: "data:foo" vs "data: foo".
270+
String data = line.substring(SSE_DATA_PREFIX.length()).stripLeading();
271+
if (data.equals("[DONE]")) {
272+
break;
273+
}
274+
// A single malformed chunk must not abort the entire stream. Log a
275+
// warning and continue.
276+
try {
277+
ChatCompletionsResponse.ChatCompletionChunk chunk =
278+
objectMapper.readValue(
279+
data, ChatCompletionsResponse.ChatCompletionChunk.class);
280+
ImmutableList<LlmResponse> responses = collection.processChunk(chunk);
281+
for (LlmResponse resp : responses) {
282+
emitter.onNext(resp);
283+
}
284+
} catch (JsonProcessingException e) {
285+
logger.warn("Failed to parse JSON chunk: {}", data, e);
286+
}
287+
}
288+
emitter.onComplete();
289+
} catch (Exception e) {
290+
emitter.tryOnError(e);
291+
}
292+
}
293+
});
294+
},
295+
BackpressureStrategy.BUFFER);
200296
}
201297

202298
/**
@@ -235,7 +331,8 @@ public void onResponse(Call call, Response response) {
235331
if (!response.isSuccessful()) {
236332
String bodyStr = body != null ? body.string() : "";
237333
emitter.tryOnError(
238-
new IOException("Unexpected code " + response + " - body: " + bodyStr));
334+
new IOException(
335+
"HTTP request failed with status: " + response + " - body: " + bodyStr));
239336
return;
240337
}
241338
if (body == null) {

0 commit comments

Comments
 (0)