Skip to content

Commit 8723620

Browse files
committed
[Stream] Support completions
1 parent f644eec commit 8723620

11 files changed

Lines changed: 278 additions & 3 deletions

File tree

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
---
2+
title: Console
3+
---
4+
5+
!!! Note
6+
7+
Please build the client before calling, the build code is as follows:
8+
9+
```java
10+
CountDownLatch countDownLatch = new CountDownLatch(1);
11+
ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder()
12+
.countDownLatch(countDownLatch)
13+
.build();
14+
OpenAiClient client = OpenAiClient.builder()
15+
.apiKey(System.getProperty("openai.token"))
16+
.listener(listener)
17+
.build();
18+
```
19+
20+
`System.getProperty("openai.token")` is the key to access the API authorization.
21+
22+
### Create completion
23+
24+
---
25+
26+
Creates a completion for the provided prompt and parameters.
27+
28+
```java
29+
CompletionEntity configure = CompletionEntity.builder()
30+
.model(CompleteModel.TEXT_DAVINCI_003.getName())
31+
.prompt("How to create a completion")
32+
.temperature(2D)
33+
.build();
34+
client.createCompletion(configure);
35+
```
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
---
2+
title: Console
3+
---
4+
5+
!!! Note
6+
7+
调用前请先构建客户端,构建代码如下:
8+
9+
```java
10+
CountDownLatch countDownLatch = new CountDownLatch(1);
11+
ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder()
12+
.countDownLatch(countDownLatch)
13+
.build();
14+
OpenAiClient client = OpenAiClient.builder()
15+
.apiKey(System.getProperty("openai.token"))
16+
.listener(listener)
17+
.build();
18+
```
19+
20+
`System.getProperty("openai.token")` 是访问 API 授权的关键。
21+
22+
### Create completion
23+
24+
---
25+
26+
为提供的提示和参数创建补全。
27+
28+
```java
29+
CompletionEntity configure = CompletionEntity.builder()
30+
.model(CompleteModel.TEXT_DAVINCI_003.getName())
31+
.prompt("How to create a completion")
32+
.temperature(2D)
33+
.build();
34+
client.createCompletion(configure);
35+
```

docs/mkdocs.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ plugins:
7474
nav:
7575
- index.md
7676
- Reference:
77+
- Stream (Not provider):
78+
- reference/stream/console.md
7779
- Open Ai:
7880
- reference/openai/users.md
7981
- reference/openai/models.md

pom.xml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
<groupId>org.devlive.sdk</groupId>
77
<artifactId>openai-java-sdk</artifactId>
8-
<version>1.7.0</version>
8+
<version>1.8.0-SNAPSHOT</version>
99

1010
<name>openai-java-sdk</name>
1111
<description>
@@ -103,6 +103,11 @@
103103
<artifactId>okhttp</artifactId>
104104
<version>${okhttp.version}</version>
105105
</dependency>
106+
<dependency>
107+
<groupId>com.squareup.okhttp3</groupId>
108+
<artifactId>okhttp-sse</artifactId>
109+
<version>${okhttp.version}</version>
110+
</dependency>
106111
<dependency>
107112
<groupId>com.google.guava</groupId>
108113
<artifactId>guava</artifactId>

src/main/java/org/devlive/sdk/openai/DefaultClient.java

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
package org.devlive.sdk.openai;
22

3+
import com.fasterxml.jackson.databind.ObjectMapper;
34
import lombok.extern.slf4j.Slf4j;
45
import okhttp3.MultipartBody;
56
import okhttp3.OkHttpClient;
7+
import okhttp3.Request;
8+
import okhttp3.RequestBody;
9+
import okhttp3.sse.EventSource;
10+
import okhttp3.sse.EventSourceListener;
11+
import okhttp3.sse.EventSources;
612
import org.apache.commons.lang3.ObjectUtils;
713
import org.devlive.sdk.openai.entity.AudioEntity;
814
import org.devlive.sdk.openai.entity.ChatEntity;
@@ -14,6 +20,8 @@
1420
import org.devlive.sdk.openai.entity.ModelEntity;
1521
import org.devlive.sdk.openai.entity.ModerationEntity;
1622
import org.devlive.sdk.openai.entity.UserKeyEntity;
23+
import org.devlive.sdk.openai.exception.RequestException;
24+
import org.devlive.sdk.openai.mixin.IgnoreUnknownMixin;
1725
import org.devlive.sdk.openai.model.ProviderModel;
1826
import org.devlive.sdk.openai.model.UrlModel;
1927
import org.devlive.sdk.openai.response.AudioResponse;
@@ -36,6 +44,8 @@ public abstract class DefaultClient
3644
protected DefaultApi api;
3745
protected ProviderModel provider;
3846
protected OkHttpClient client;
47+
protected String apiHost;
48+
protected EventSourceListener listener;
3949

4050
public ModelResponse getModels()
4151
{
@@ -51,8 +61,16 @@ public ModelEntity getModel(String model)
5161

5262
public CompleteResponse createCompletion(CompletionEntity configure)
5363
{
54-
return this.api.fetchCompletions(ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS), configure)
55-
.blockingGet();
64+
String url = ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS);
65+
if (ObjectUtils.isNotEmpty(this.listener)) {
66+
configure.setStream(true);
67+
this.createEventSource(url, configure);
68+
return null;
69+
}
70+
else {
71+
return this.api.fetchCompletions(url, configure)
72+
.blockingGet();
73+
}
5674
}
5775

5876
public ChatResponse createChatCompletion(ChatEntity configure)
@@ -168,6 +186,29 @@ public Object retrieveFileContent(String id)
168186
.blockingGet();
169187
}
170188

189+
private ObjectMapper createObjectMapper()
190+
{
191+
ObjectMapper objectMapper = new ObjectMapper();
192+
objectMapper.addMixIn(Object.class, IgnoreUnknownMixin.class);
193+
return objectMapper;
194+
}
195+
196+
private void createEventSource(String url, Object configure)
197+
{
198+
try {
199+
EventSource.Factory factory = EventSources.createFactory(this.client);
200+
ObjectMapper mapper = this.createObjectMapper();
201+
Request request = new Request.Builder()
202+
.url(String.join("/", this.apiHost, url))
203+
.post(RequestBody.create(MultipartBodyUtils.JSON, mapper.writeValueAsString(configure)))
204+
.build();
205+
factory.newEventSource(request, this.listener);
206+
}
207+
catch (Exception e) {
208+
throw new RequestException(String.format("Failed to create event source: %s", e.getMessage()));
209+
}
210+
}
211+
171212
public void close()
172213
{
173214
if (ObjectUtils.isNotEmpty(this.client)) {

src/main/java/org/devlive/sdk/openai/OpenAiClient.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import lombok.Builder;
66
import lombok.extern.slf4j.Slf4j;
77
import okhttp3.OkHttpClient;
8+
import okhttp3.sse.EventSourceListener;
89
import org.apache.commons.lang3.ObjectUtils;
910
import org.apache.commons.lang3.StringUtils;
1011
import org.devlive.sdk.openai.exception.ParamException;
@@ -35,6 +36,8 @@ public class OpenAiClient
3536
// Azure provider requires
3637
private String model; // The model name deployed in azure
3738
private String version;
39+
// Support see
40+
private EventSourceListener listener;
3841

3942
private OpenAiClient(OpenAiClientBuilder builder)
4043
{
@@ -69,9 +72,14 @@ private OpenAiClient(OpenAiClientBuilder builder)
6972
if (ObjectUtils.isEmpty(builder.client)) {
7073
builder.client(null);
7174
}
75+
if (ObjectUtils.isEmpty(builder.listener)) {
76+
builder.listener(null);
77+
}
7278

7379
super.provider = builder.provider;
7480
super.client = builder.client;
81+
super.listener = builder.listener;
82+
super.apiHost = builder.apiHost;
7583
// Build a remote API client
7684
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
7785
this.api = new Retrofit.Builder()
@@ -160,6 +168,9 @@ public OpenAiClientBuilder client(OkHttpClient client)
160168

161169
private String getDefaultHost()
162170
{
171+
if (ObjectUtils.isEmpty(this.provider)) {
172+
this.provider = ProviderModel.OPENAI;
173+
}
163174
if (this.provider.equals(ProviderModel.CLAUDE)) {
164175
return "https://api.anthropic.com";
165176
}

src/main/java/org/devlive/sdk/openai/entity/CompletionEntity.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ public class CompletionEntity
5050
@JsonProperty(value = "stop")
5151
private List<String> stop;
5252

53+
/**
54+
* Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message.
55+
* 是否流回部分进度。如果设置,令牌将在可用时作为仅数据服务器发送事件发送,流由 data: [DONE] 消息终止。
56+
*/
57+
@JsonProperty(value = "stream")
58+
private boolean stream = false;
59+
5360
private CompletionEntity(CompletionEntityBuilder builder)
5461
{
5562
if (ObjectUtils.isEmpty(builder.model)) {
@@ -151,6 +158,11 @@ public CompletionEntityBuilder presencePenalty(Double presencePenalty)
151158
return this;
152159
}
153160

161+
private CompletionEntityBuilder stream()
162+
{
163+
return this;
164+
}
165+
154166
public CompletionEntity build()
155167
{
156168
return new CompletionEntity(this);
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package org.devlive.sdk.openai.listener;
2+
3+
import com.fasterxml.jackson.core.JsonProcessingException;
4+
import lombok.Builder;
5+
import lombok.extern.slf4j.Slf4j;
6+
import okhttp3.Response;
7+
import okhttp3.sse.EventSource;
8+
import okhttp3.sse.EventSourceListener;
9+
import org.apache.commons.lang3.ObjectUtils;
10+
import org.devlive.sdk.openai.response.CompleteResponse;
11+
import org.devlive.sdk.openai.utils.JsonUtils;
12+
import org.jetbrains.annotations.NotNull;
13+
import org.jetbrains.annotations.Nullable;
14+
15+
import java.time.LocalDateTime;
16+
import java.util.concurrent.CountDownLatch;
17+
18+
@Slf4j
19+
@Builder
20+
public class ConsoleEventSourceListener
21+
extends EventSourceListener
22+
{
23+
private CountDownLatch countDownLatch;
24+
private JsonUtils<CompleteResponse> jsonUtils;
25+
26+
@Override
27+
public void onOpen(@NotNull EventSource eventSource, @NotNull Response response)
28+
{
29+
log.info("Console listener opened on time {}", LocalDateTime.now());
30+
this.jsonUtils = JsonUtils.getInstance();
31+
}
32+
33+
@Override
34+
public void onClosed(@NotNull EventSource eventSource)
35+
{
36+
log.info("Console listener closed on time {}", LocalDateTime.now());
37+
eventSource.cancel();
38+
this.close();
39+
}
40+
41+
@Override
42+
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data)
43+
{
44+
// OpenAI ends with [DONE] by default
45+
if (data.equals("[DONE]")) {
46+
eventSource.cancel();
47+
this.close();
48+
}
49+
else {
50+
try {
51+
CompleteResponse completeResponse = jsonUtils.getObject(data, CompleteResponse.class);
52+
log.info("Console event received on time {} id {} type {} data {}", LocalDateTime.now(), id, type, completeResponse.getChoices().get(0).getContent());
53+
}
54+
catch (JsonProcessingException e) {
55+
log.warn("Console event error on time {} id {} type {} data {}", LocalDateTime.now(), id, type, data, e);
56+
}
57+
}
58+
}
59+
60+
@Override
61+
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable throwable, @Nullable Response response)
62+
{
63+
if (throwable.getMessage().endsWith("CANCEL")) {
64+
log.info("Console listener cancelled on time {}", LocalDateTime.now());
65+
this.onClosed(eventSource);
66+
}
67+
else {
68+
log.error("Console listener throwable \n{}\n response: \n{}\n", throwable, response);
69+
}
70+
eventSource.cancel();
71+
this.close();
72+
}
73+
74+
private void close()
75+
{
76+
if (ObjectUtils.isNotEmpty(this.countDownLatch)) {
77+
this.countDownLatch.countDown();
78+
}
79+
}
80+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package org.devlive.sdk.openai.mixin;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
5+
@JsonIgnoreProperties(ignoreUnknown = true)
6+
public abstract class IgnoreUnknownMixin
7+
{
8+
}

src/main/java/org/devlive/sdk/openai/utils/MultipartBodyUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
public class MultipartBodyUtils
1010
{
1111
public static final MediaType TYPE = MediaType.parse("multipart/form-data");
12+
public static final MediaType JSON = MediaType.parse("application/json; charset=utf-8");
1213

1314
private MultipartBodyUtils()
1415
{

0 commit comments

Comments
 (0)