Skip to content

Commit bcc2dcd

Browse files
authored
feat(rag): OpenAI Embedding Support (#191)
1 parent 53a0731 commit bcc2dcd

6 files changed

Lines changed: 1164 additions & 0 deletions

File tree

agentscope-extensions/agentscope-extensions-rag-simple/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@
4949
</exclusions>
5050
</dependency>
5151

52+
<!-- OpenAI Java SDK -->
53+
<dependency>
54+
<groupId>com.openai</groupId>
55+
<artifactId>openai-java</artifactId>
56+
</dependency>
57+
5258
<!-- Qdrant Java SDK -->
5359
<dependency>
5460
<groupId>io.qdrant</groupId>

agentscope-extensions/agentscope-extensions-rag-simple/src/main/java/io/agentscope/core/embedding/EmbeddingUtils.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,4 +229,21 @@ public static double[] convertDoubleListToArray(List<Double> values) {
229229
}
230230
return array;
231231
}
232+
233+
/**
234+
* Converts a List of Float values to a double array.
235+
*
236+
* <p>This method is used to convert embedding vectors returned by OpenAI SDK
237+
* (which uses List&lt;Float&gt;) to the standard double[] format used by EmbeddingModel.
238+
*
239+
* @param values the list of Float values
240+
* @return the double array
241+
*/
242+
public static double[] convertFloatListToDoubleArray(List<Float> values) {
243+
double[] array = new double[values.size()];
244+
for (int i = 0; i < values.size(); i++) {
245+
array[i] = values.get(i);
246+
}
247+
return array;
248+
}
232249
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package io.agentscope.core.embedding.openai;
17+
18+
import com.openai.client.OpenAIClient;
19+
import com.openai.client.okhttp.OpenAIOkHttpClient;
20+
import com.openai.models.embeddings.CreateEmbeddingResponse;
21+
import com.openai.models.embeddings.Embedding;
22+
import com.openai.models.embeddings.EmbeddingCreateParams;
23+
import io.agentscope.core.Version;
24+
import io.agentscope.core.embedding.EmbeddingException;
25+
import io.agentscope.core.embedding.EmbeddingModel;
26+
import io.agentscope.core.embedding.EmbeddingUtils;
27+
import io.agentscope.core.message.ContentBlock;
28+
import io.agentscope.core.message.TextBlock;
29+
import io.agentscope.core.model.ExecutionConfig;
30+
import java.util.List;
31+
import org.slf4j.Logger;
32+
import org.slf4j.LoggerFactory;
33+
import reactor.core.publisher.Mono;
34+
35+
/**
36+
* OpenAI Text Embedding Model implementation.
37+
*
38+
* <p>This implementation provides access to OpenAI's text embedding API, supporting both
39+
* single text embedding and batch embedding operations.
40+
*
41+
* <p>Supports only {@link TextBlock} content blocks. Other content block types will result in
42+
* an {@link EmbeddingException}.
43+
*
44+
* <p>Supports timeout and retry configuration through ExecutionConfig.
45+
*/
46+
public class OpenAITextEmbedding implements EmbeddingModel {
47+
48+
private static final Logger log = LoggerFactory.getLogger(OpenAITextEmbedding.class);
49+
50+
private final String apiKey;
51+
private final String modelName;
52+
private final int dimensions;
53+
private final ExecutionConfig defaultExecutionConfig;
54+
55+
private final String baseUrl;
56+
57+
/**
58+
* Creates a new OpenAI text embedding model instance.
59+
*
60+
* @param apiKey the API key for OpenAI authentication
61+
* @param modelName the model name (e.g., "text-embedding-3-small")
62+
* @param dimensions the dimension of embedding vectors
63+
* @param defaultExecutionConfig default execution configuration for timeout and retry
64+
* @param baseUrl custom base URL for OpenAI API (null for default)
65+
*/
66+
public OpenAITextEmbedding(
67+
String apiKey,
68+
String modelName,
69+
int dimensions,
70+
ExecutionConfig defaultExecutionConfig,
71+
String baseUrl) {
72+
this.apiKey = apiKey;
73+
this.modelName = modelName;
74+
this.dimensions = dimensions;
75+
this.defaultExecutionConfig =
76+
EmbeddingUtils.ensureDefaultExecutionConfig(defaultExecutionConfig);
77+
this.baseUrl = baseUrl;
78+
}
79+
80+
/**
81+
* Creates a new builder for OpenAITextEmbedding.
82+
*
83+
* @return a new Builder instance
84+
*/
85+
public static Builder builder() {
86+
return new Builder();
87+
}
88+
89+
@Override
90+
public Mono<double[]> embed(ContentBlock block) {
91+
if (block == null) {
92+
return Mono.error(
93+
new EmbeddingException("ContentBlock cannot be null", modelName, "openai"));
94+
}
95+
96+
if (!(block instanceof TextBlock textBlock)) {
97+
return Mono.error(
98+
new EmbeddingException(
99+
"OpenAITextEmbedding only supports TextBlock, but got: "
100+
+ block.getClass().getSimpleName(),
101+
modelName,
102+
"openai"));
103+
}
104+
105+
String text = textBlock.getText();
106+
if (text == null || text.trim().isEmpty()) {
107+
return Mono.error(
108+
new EmbeddingException(
109+
"TextBlock text cannot be null or empty", modelName, "openai"));
110+
}
111+
112+
Mono<double[]> embeddingMono =
113+
Mono.fromCallable(
114+
() -> {
115+
try {
116+
// Initialize OpenAI client
117+
OpenAIOkHttpClient.Builder clientBuilder =
118+
OpenAIOkHttpClient.builder();
119+
120+
if (apiKey != null) {
121+
clientBuilder.apiKey(apiKey);
122+
}
123+
124+
if (baseUrl != null) {
125+
clientBuilder.baseUrl(baseUrl);
126+
}
127+
128+
// Set unified AgentScope User-Agent (overrides OpenAI SDK
129+
// default)
130+
clientBuilder.putHeader(
131+
"User-Agent", Version.getUserAgent());
132+
133+
OpenAIClient client = clientBuilder.build();
134+
135+
EmbeddingCreateParams createParams =
136+
EmbeddingCreateParams.builder()
137+
.model(modelName)
138+
.dimensions(dimensions)
139+
.encodingFormat(
140+
EmbeddingCreateParams.EncodingFormat
141+
.FLOAT)
142+
.inputOfArrayOfStrings(List.of(text))
143+
.build();
144+
145+
log.debug(
146+
"OpenAI embedding call: model={},"
147+
+ " text_length={}",
148+
modelName,
149+
text.length());
150+
151+
CreateEmbeddingResponse result =
152+
client.embeddings().create(createParams);
153+
154+
if (result == null || result.data() == null) {
155+
throw new EmbeddingException(
156+
"Empty response from OpenAI embedding API",
157+
modelName,
158+
"openai");
159+
}
160+
161+
List<Embedding> embeddings = result.data();
162+
if (embeddings == null
163+
|| embeddings.isEmpty()
164+
|| embeddings.get(0) == null) {
165+
throw new EmbeddingException(
166+
"No embedding data in response",
167+
modelName,
168+
"openai");
169+
}
170+
171+
List<Float> embeddingValues = embeddings.get(0).embedding();
172+
if (embeddingValues == null || embeddingValues.isEmpty()) {
173+
throw new EmbeddingException(
174+
"Empty embedding vector in response",
175+
modelName,
176+
"openai");
177+
}
178+
179+
// Convert List<Float> to double[]
180+
double[] embeddingArray =
181+
EmbeddingUtils.convertFloatListToDoubleArray(
182+
embeddingValues);
183+
184+
// Validate dimension
185+
if (embeddingArray.length != dimensions) {
186+
log.warn(
187+
"Embedding dimension mismatch: expected={},"
188+
+ " actual={}",
189+
dimensions,
190+
embeddingArray.length);
191+
}
192+
193+
return embeddingArray;
194+
} catch (EmbeddingException e) {
195+
throw e;
196+
} catch (Exception e) {
197+
throw new EmbeddingException(
198+
"Failed to generate embedding: " + e.getMessage(),
199+
e,
200+
modelName,
201+
"openai");
202+
}
203+
})
204+
.onErrorMap(
205+
e -> {
206+
if (e instanceof EmbeddingException) {
207+
return e;
208+
}
209+
return new EmbeddingException(
210+
"OpenAI embedding API call failed: " + e.getMessage(),
211+
e,
212+
modelName,
213+
"openai");
214+
});
215+
216+
// Apply timeout and retry
217+
return EmbeddingUtils.applyTimeoutAndRetry(
218+
embeddingMono, defaultExecutionConfig, modelName, "openai", log);
219+
}
220+
221+
@Override
222+
public String getModelName() {
223+
return modelName;
224+
}
225+
226+
@Override
227+
public int getDimensions() {
228+
return dimensions;
229+
}
230+
231+
/**
232+
* Builder for OpenAITextEmbedding.
233+
*/
234+
public static class Builder {
235+
private String apiKey;
236+
private String modelName;
237+
private int dimensions = 1536;
238+
private ExecutionConfig defaultExecutionConfig;
239+
private String baseUrl;
240+
241+
/**
242+
* Sets the API key for OpenAI authentication.
243+
*
244+
* @param apiKey the API key
245+
* @return this builder instance
246+
*/
247+
public Builder apiKey(String apiKey) {
248+
this.apiKey = apiKey;
249+
return this;
250+
}
251+
252+
/**
253+
* Sets the model name to use.
254+
*
255+
* @param modelName the model name (e.g., "text-embedding-3-small")
256+
* @return this builder instance
257+
*/
258+
public Builder modelName(String modelName) {
259+
this.modelName = modelName;
260+
return this;
261+
}
262+
263+
/**
264+
* Sets the dimension of embedding vectors.
265+
*
266+
* @param dimensions the dimension
267+
* @return this builder instance
268+
*/
269+
public Builder dimensions(int dimensions) {
270+
this.dimensions = dimensions;
271+
return this;
272+
}
273+
274+
/**
275+
* Sets the default execution configuration.
276+
*
277+
* @param config the execution config (null for defaults)
278+
* @return this builder instance
279+
*/
280+
public Builder executionConfig(ExecutionConfig config) {
281+
this.defaultExecutionConfig = config;
282+
return this;
283+
}
284+
285+
/**
286+
* Sets a custom base URL for OpenAI API.
287+
*
288+
* @param baseUrl the base URL (null for default)
289+
* @return this builder instance
290+
*/
291+
public Builder baseUrl(String baseUrl) {
292+
this.baseUrl = baseUrl;
293+
return this;
294+
}
295+
296+
/**
297+
* Builds the OpenAITextEmbedding instance.
298+
*
299+
* <p>This method validates required parameters and ensures that the defaultExecutionConfig
300+
* always has proper defaults applied using EmbeddingUtils.ensureDefaultExecutionConfig().
301+
*
302+
* @return configured OpenAITextEmbedding instance
303+
* @throws IllegalStateException if required parameters are missing or invalid
304+
*/
305+
public OpenAITextEmbedding build() {
306+
// Validate required parameters
307+
if (apiKey == null || apiKey.isEmpty()) {
308+
throw new IllegalStateException("apiKey is required and cannot be null or empty");
309+
}
310+
if (modelName == null || modelName.isEmpty()) {
311+
throw new IllegalStateException(
312+
"modelName is required and cannot be null or empty");
313+
}
314+
if (dimensions <= 0) {
315+
throw new IllegalStateException("dimensions must be positive, got: " + dimensions);
316+
}
317+
318+
ExecutionConfig effectiveConfig =
319+
EmbeddingUtils.ensureDefaultExecutionConfig(defaultExecutionConfig);
320+
321+
return new OpenAITextEmbedding(apiKey, modelName, dimensions, effectiveConfig, baseUrl);
322+
}
323+
}
324+
}

0 commit comments

Comments
 (0)