Skip to content

Commit 7900738

Browse files
authored
Fix streaming APIs in Agents and Projects libraries (#48832)
* Fix streaming APIs in Agents and Projects libraries
* Add unit tests
* Update HttpResponse
* Remove unused logger
* Remove unused import
1 parent 12b3511 commit 7900738

File tree

8 files changed

+796
-10
lines changed

8 files changed

+796
-10
lines changed

sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/http/AzureHttpResponseAdapter.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import com.azure.core.http.HttpHeader;
77
import com.azure.core.http.HttpHeaders;
8-
import com.azure.core.util.logging.ClientLogger;
98
import com.openai.core.http.Headers;
109
import com.openai.core.http.HttpResponse;
1110

@@ -17,8 +16,6 @@
1716
*/
1817
final class AzureHttpResponseAdapter implements HttpResponse {
1918

20-
private static final ClientLogger LOGGER = new ClientLogger(AzureHttpResponseAdapter.class);
21-
2219
private final com.azure.core.http.HttpResponse azureResponse;
2320

2421
/**
@@ -42,7 +39,9 @@ public Headers headers() {
4239

4340
@Override
4441
public InputStream body() {
45-
return azureResponse.getBodyAsBinaryData().toStream();
42+
// replace with azureResponse.bodyStream() and delete FluxInputStream class from this package
43+
// when new version of azure-core is released.
44+
return new FluxInputStream(azureResponse.getBody());
4645
}
4746

4847
@Override
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.azure.ai.agents.implementation.http;
5+
6+
import com.azure.core.util.FluxUtil;
7+
import com.azure.core.util.logging.ClientLogger;
8+
import org.reactivestreams.Subscription;
9+
import reactor.core.publisher.Flux;
10+
11+
import java.io.ByteArrayInputStream;
12+
import java.io.IOException;
13+
import java.io.InputStream;
14+
import java.nio.Buffer;
15+
import java.nio.ByteBuffer;
16+
import java.util.concurrent.locks.Condition;
17+
import java.util.concurrent.locks.Lock;
18+
import java.util.concurrent.locks.ReentrantLock;
19+
20+
/**
21+
* An InputStream that subscribes to a Flux.
22+
*/
23+
public class FluxInputStream extends InputStream {
24+
25+
private static final ClientLogger LOGGER = new ClientLogger(FluxInputStream.class);
26+
27+
// The data to subscribe to.
28+
private final Flux<ByteBuffer> data;
29+
30+
// Subscription to request more data from as needed
31+
private Subscription subscription;
32+
33+
private ByteArrayInputStream buffer;
34+
35+
private volatile boolean subscribed;
36+
private volatile boolean fluxComplete;
37+
private volatile boolean waitingForData;
38+
39+
/* The following lock and condition variable is to synchronize access between the reader and the
40+
reactor thread asynchronously reading data from the Flux. If no data is available, the reader
41+
acquires the lock and waits on the dataAvailable condition variable. Once data is available
42+
(or an error or completion event occurs) the reactor thread acquires the lock and signals that
43+
data is available. */
44+
private final Lock lock;
45+
private final Condition dataAvailable;
46+
47+
private IOException lastError;
48+
49+
/**
50+
* Creates a new FluxInputStream
51+
*
52+
* @param data The data to subscribe to and read from.
53+
*/
54+
public FluxInputStream(Flux<ByteBuffer> data) {
55+
this.subscribed = false;
56+
this.fluxComplete = false;
57+
this.waitingForData = false;
58+
this.data = data;
59+
this.lock = new ReentrantLock();
60+
this.dataAvailable = lock.newCondition();
61+
}
62+
63+
@Override
64+
public int read() throws IOException {
65+
byte[] ret = new byte[1];
66+
int count = read(ret, 0, 1);
67+
return count == -1 ? -1 : (ret[0] & 0xFF);
68+
}
69+
70+
@Override
71+
public int read(byte[] b, int off, int len) throws IOException {
72+
validateParameters(b, off, len);
73+
74+
/* If len is 0, then no bytes are read and 0 is returned. */
75+
if (len == 0) {
76+
return 0;
77+
}
78+
/* Attempt to read at least one byte. If no byte is available because the stream is at end of file,
79+
the value -1 is returned; otherwise, at least one byte is read and stored into b. */
80+
81+
/* Not subscribed? subscribe and block for data */
82+
if (!subscribed) {
83+
blockForData();
84+
}
85+
/* Now, we have subscribed. */
86+
/* At this point, buffer should not be null. If it is, that indicates either an error or completion event
87+
was emitted by the Flux. */
88+
if (this.buffer == null) { // Only executed on first subscription.
89+
if (this.lastError != null) {
90+
throw LOGGER.logThrowableAsError(this.lastError);
91+
}
92+
if (this.fluxComplete) {
93+
return -1;
94+
}
95+
throw LOGGER.logExceptionAsError(new IllegalStateException("An unexpected error occurred. No data was "
96+
+ "read from the stream but the stream did not indicate completion."));
97+
}
98+
99+
/* Now we are guaranteed that buffer is SOMETHING. */
100+
/* No data is available in the buffer. */
101+
if (this.buffer.available() == 0) {
102+
/* If the flux completed, there is no more data available to be read from the stream. Return -1. */
103+
if (this.fluxComplete) {
104+
return -1;
105+
}
106+
/* Block current thread until data is available. */
107+
blockForData();
108+
}
109+
110+
/* Data available in buffer, read the buffer. */
111+
if (this.buffer.available() > 0) {
112+
return this.buffer.read(b, off, len);
113+
}
114+
115+
/* If the flux completed, there is no more data available to be read from the stream. Return -1. */
116+
if (this.fluxComplete) {
117+
return -1;
118+
} else {
119+
throw LOGGER.logExceptionAsError(new IllegalStateException("An unexpected error occurred. No data was "
120+
+ "read from the stream but the stream did not indicate completion."));
121+
}
122+
}
123+
124+
@Override
125+
public void close() throws IOException {
126+
if (subscription != null) {
127+
subscription.cancel();
128+
}
129+
130+
if (this.buffer != null) {
131+
this.buffer.close();
132+
}
133+
super.close();
134+
if (this.lastError != null) {
135+
throw LOGGER.logThrowableAsError(this.lastError);
136+
}
137+
}
138+
139+
/**
140+
* Request more data and wait on data to become available.
141+
*/
142+
private void blockForData() {
143+
lock.lock();
144+
try {
145+
waitingForData = true;
146+
if (!subscribed) {
147+
subscribeToData();
148+
} else {
149+
subscription.request(1);
150+
}
151+
// Block current thread until data is available.
152+
while (waitingForData) {
153+
if (fluxComplete) {
154+
break;
155+
} else {
156+
try {
157+
dataAvailable.await();
158+
} catch (InterruptedException e) {
159+
Thread.currentThread().interrupt();
160+
throw LOGGER.logExceptionAsError(new RuntimeException(e));
161+
}
162+
}
163+
}
164+
} finally {
165+
lock.unlock();
166+
}
167+
}
168+
169+
/**
170+
* Subscribes to the data with a special subscriber.
171+
*/
172+
@SuppressWarnings("deprecation")
173+
private void subscribeToData() {
174+
this.data.filter(Buffer::hasRemaining) /* Filter to make sure only non empty byte buffers are emitted. */
175+
.onBackpressureBuffer()
176+
.subscribe(
177+
// ByteBuffer consumer
178+
byteBuffer -> {
179+
this.buffer = new ByteArrayInputStream(FluxUtil.byteBufferToArray(byteBuffer));
180+
lock.lock();
181+
try {
182+
this.waitingForData = false;
183+
// Signal the consumer when data is available.
184+
dataAvailable.signal();
185+
} finally {
186+
lock.unlock();
187+
}
188+
},
189+
// Error consumer
190+
throwable -> {
191+
// Signal the consumer in case an error occurs (indicates we completed without data).
192+
if (throwable instanceof IOException) {
193+
this.lastError = (IOException) throwable;
194+
} else {
195+
this.lastError = new IOException(throwable);
196+
}
197+
signalOnCompleteOrError();
198+
},
199+
// Complete consumer
200+
// Signal the consumer in case we completed without data.
201+
this::signalOnCompleteOrError,
202+
// Subscription consumer
203+
subscription -> {
204+
this.subscription = subscription;
205+
this.subscribed = true;
206+
this.subscription.request(1);
207+
});
208+
}
209+
210+
/**
211+
* Signals to the subscriber when the flux completes without data (onCompletion or onError)
212+
*/
213+
private void signalOnCompleteOrError() {
214+
this.fluxComplete = true;
215+
lock.lock();
216+
try {
217+
this.waitingForData = false;
218+
dataAvailable.signal();
219+
} finally {
220+
lock.unlock();
221+
}
222+
}
223+
224+
/**
225+
* Validates parameters according to {@link InputStream#read(byte[], int, int)} spec.
226+
*
227+
* @param bytes the buffer into which the data is read.
228+
* @param offset the start offset in array bytes at which the data is written.
229+
* @param length the maximum number of bytes to read.
230+
*/
231+
private void validateParameters(byte[] bytes, int offset, int length) {
232+
if (bytes == null) {
233+
throw LOGGER.logExceptionAsError(new NullPointerException("'bytes' cannot be null"));
234+
}
235+
if (offset < 0) {
236+
throw LOGGER.logExceptionAsError(new IndexOutOfBoundsException("'offset' cannot be less than 0"));
237+
}
238+
if (length < 0) {
239+
throw LOGGER.logExceptionAsError(new IndexOutOfBoundsException("'length' cannot be less than 0"));
240+
}
241+
if (length > (bytes.length - offset)) {
242+
throw LOGGER.logExceptionAsError(
243+
new IndexOutOfBoundsException("'length' cannot be greater than 'bytes'.length - 'offset'"));
244+
}
245+
}
246+
}

sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/http/HttpClientHelper.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.openai.core.http.HttpRequestBody;
2222
import com.openai.core.http.HttpResponse;
2323
import com.openai.errors.BadRequestException;
24+
import reactor.core.scheduler.Schedulers;
2425
import com.openai.errors.InternalServerException;
2526
import com.openai.errors.NotFoundException;
2627
import com.openai.errors.OpenAIException;
@@ -110,6 +111,10 @@ public CompletableFuture<HttpResponse> executeAsync(HttpRequest request, Request
110111
.flatMap(azureRequest -> this.httpPipeline.send(azureRequest, buildRequestContext(requestOptions)))
111112
.map(response -> (HttpResponse) new AzureHttpResponseAdapter(response))
112113
.onErrorMap(HttpClientWrapper::mapAzureExceptionToOpenAI)
114+
// publishOn moves the CompletableFuture completion (and all OpenAI SDK continuations that
115+
// run synchronously on it) off the Netty/OkHttp I/O thread and onto a thread pool that
116+
// is safe to block.
117+
.publishOn(Schedulers.boundedElastic())
113118
.toFuture();
114119
}
115120

@@ -240,7 +245,7 @@ private static HttpHeaders toAzureHeaders(Headers sourceHeaders) {
240245
* @return Azure request {@link Context}
241246
*/
242247
private static Context buildRequestContext(RequestOptions requestOptions) {
243-
Context context = new Context("azure-eagerly-read-response", true);
248+
Context context = Context.NONE;
244249
Timeout timeout = requestOptions.getTimeout();
245250
// we use "read" as it's the closest thing to the "response timeout"
246251
if (timeout != null && !timeout.read().isZero() && !timeout.read().isNegative()) {

0 commit comments

Comments
 (0)