Skip to content

Commit 756ce9d

Browse files
committed
[fel] Code tidying up.
1 parent 1d99486 commit 756ce9d

8 files changed

Lines changed: 74 additions & 65 deletions

File tree

framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/OpenAiModel.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import modelengine.fel.core.image.ImageModel;
3636
import modelengine.fel.core.image.ImageOption;
3737
import modelengine.fel.core.model.http.SecureConfig;
38-
import modelengine.fel.core.rerank.RerankApi;
3938
import modelengine.fel.core.rerank.RerankModel;
4039
import modelengine.fel.core.rerank.RerankOption;
4140
import modelengine.fit.http.client.HttpClassicClient;
@@ -180,13 +179,13 @@ public List<Media> generate(String prompt, ImageOption option) {
180179
public List<MeasurableDocument> generate(List<MeasurableDocument> documents, RerankOption rerankOption) {
181180
notEmpty(documents, "The documents cannot be empty.");
182181
notNull(rerankOption, "The rerank option cannot be null.");
183-
List<String> docs = documents.stream().map(MeasurableDocument::text).collect(Collectors.toList());
184-
OpenAiRerankRequest fields = new OpenAiRerankRequest(rerankOption, docs);
185-
182+
String modelSource = StringUtils.blankIf(rerankOption.baseUri(), this.baseUrl);
186183
HttpClassicClientRequest request = this.httpClient.get()
187184
.createRequest(HttpRequestMethod.POST,
188-
UrlUtils.combine(rerankOption.baseUri(), RerankApi.RERANK_ENDPOINT));
185+
UrlUtils.combine(modelSource, OpenAiApi.RERANK_ENDPOINT));
189186
HttpUtils.setBearerAuth(request, StringUtils.blankIf(rerankOption.apiKey(), this.defaultApiKey));
187+
List<String> docs = documents.stream().map(MeasurableDocument::text).collect(Collectors.toList());
188+
OpenAiRerankRequest fields = new OpenAiRerankRequest(rerankOption, docs);
190189
request.entity(Entity.createObject(request, fields));
191190
OpenAiRerankResponse rerankResponse = this.rerankExchange(request);
192191

framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/api/OpenAiApi.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ public interface OpenAiApi {
2727
*/
2828
String IMAGE_ENDPOINT = "/images/generations";
2929

30+
/**
31+
* 重排请求的端点。
32+
*/
33+
String RERANK_ENDPOINT = "/rerank";
34+
3035
/**
3136
* 请求头模型密钥字段。
3237
*/

framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/OpenAiModelTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@
77
package modelengine.fel.community.model.openai;
88

99
import static org.assertj.core.api.Assertions.assertThat;
10+
import static org.junit.jupiter.api.Assertions.assertAll;
1011

1112
import modelengine.fel.community.model.openai.config.OpenAiConfig;
1213
import modelengine.fel.core.chat.ChatMessage;
1314
import modelengine.fel.core.chat.ChatOption;
1415
import modelengine.fel.core.chat.support.ChatMessages;
1516
import modelengine.fel.core.chat.support.HumanMessage;
17+
import modelengine.fel.core.document.Document;
18+
import modelengine.fel.core.document.MeasurableDocument;
1619
import modelengine.fel.core.embed.EmbedOption;
1720
import modelengine.fel.core.embed.Embedding;
1821
import modelengine.fel.core.image.ImageOption;
22+
import modelengine.fel.core.rerank.RerankOption;
1923
import modelengine.fit.http.client.HttpClassicClientFactory;
2024
import modelengine.fitframework.annotation.Fit;
2125
import modelengine.fitframework.conf.Config;
@@ -31,6 +35,8 @@
3135
import org.junit.jupiter.api.Test;
3236

3337
import java.util.Arrays;
38+
import java.util.Collections;
39+
import java.util.HashMap;
3440
import java.util.List;
3541
import java.util.stream.Collectors;
3642

@@ -42,6 +48,9 @@
4248
@MvcTest(classes = TestModelController.class)
4349
public class OpenAiModelTest {
4450
private OpenAiModel openAiModel;
51+
private static final int EXPECTED_TOP_K = 3;
52+
private static final String HIGHEST_RANKED_TEXT = "C++ offers high performance.";
53+
private static final double EXPECTED_HIGHEST_SCORE = 0.999071;
4554

4655
@Fit
4756
private HttpClassicClientFactory httpClientFactory;
@@ -91,4 +100,45 @@ void testOpenAiImageModel() {
91100
"456",
92101
"789");
93102
}
103+
104+
@Test
105+
@DisplayName("测试重排模型返回:应返回按相关性排序的前 K 个文档")
106+
void testOpenAiRerankModel() {
107+
// Given: 准备输入文档
108+
List<MeasurableDocument> inputDocs = Arrays.asList(doc("0", "Java is a programming language."),
109+
doc("1", "Python is great for data science."),
110+
doc("2", HIGHEST_RANKED_TEXT),
111+
doc("3", "Rust offers high performance."),
112+
doc("4", "C offers high performance."));
113+
114+
RerankOption rerankOption = RerankOption.custom().model("rerank-model").build();
115+
116+
// When: 调用重排接口
117+
List<MeasurableDocument> result = openAiModel.generate(inputDocs, rerankOption);
118+
119+
// Then: 验证结果
120+
assertAll(() -> assertThat(result).as("应返回 top-%d 结果", EXPECTED_TOP_K).hasSize(EXPECTED_TOP_K),
121+
122+
() -> {
123+
List<Double> scores = result.stream().map(MeasurableDocument::score).collect(Collectors.toList());
124+
assertThat(scores).as("结果应按相关性分数降序排列").isSortedAccordingTo(Collections.reverseOrder());
125+
},
126+
127+
() -> {
128+
List<String> resultTexts =
129+
result.stream().map(MeasurableDocument::text).collect(Collectors.toList());
130+
List<String> inputTexts =
131+
inputDocs.stream().map(MeasurableDocument::text).collect(Collectors.toList());
132+
assertThat(inputTexts).as("所有返回文档必须来自输入集").containsAll(resultTexts);
133+
},
134+
135+
() -> assertThat(result.get(0).text()).as("得分最高的文档应为 C++").isEqualTo(HIGHEST_RANKED_TEXT),
136+
137+
() -> assertThat(result.get(0).score()).as("最高分应与模拟响应一致").isEqualTo(EXPECTED_HIGHEST_SCORE));
138+
}
139+
140+
private MeasurableDocument doc(String id, String text) {
141+
Document document = Document.custom().id(id).text(text).metadata(new HashMap<>()).build();
142+
return new MeasurableDocument(document, 0.0);
143+
}
94144
}

framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/TestModelController.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
import static modelengine.fel.community.model.openai.api.OpenAiApi.CHAT_ENDPOINT;
1010
import static modelengine.fel.community.model.openai.api.OpenAiApi.EMBEDDING_ENDPOINT;
1111
import static modelengine.fel.community.model.openai.api.OpenAiApi.IMAGE_ENDPOINT;
12+
import static modelengine.fel.community.model.openai.api.OpenAiApi.RERANK_ENDPOINT;
1213

1314
import modelengine.fel.community.model.openai.entity.embed.OpenAiEmbeddingResponse;
1415
import modelengine.fel.community.model.openai.entity.image.OpenAiImageResponse;
16+
import modelengine.fel.community.model.openai.entity.rerank.OpenAiRerankResponse;
1517
import modelengine.fit.http.annotation.PostMapping;
1618
import modelengine.fitframework.annotation.Component;
1719
import modelengine.fitframework.flowable.Choir;
@@ -81,4 +83,17 @@ public OpenAiImageResponse image() {
8183
+ "\"data\":[{\"b64_json\":\"123\"}, {\"b64_json\":\"456\"}, {\"b64_json\":\"789\"}]}";
8284
return this.serializer.deserialize(json, OpenAiImageResponse.class);
8385
}
86+
87+
/**
88+
* 测试用重排接口。
89+
*
90+
* @return 表示重排响应的 {@link OpenAiRerankResponse}。
91+
*/
92+
@PostMapping(RERANK_ENDPOINT)
93+
public OpenAiRerankResponse rerank() {
94+
String json =
95+
"{\"results\":[{\"index\":2,\"relevance_score\":0.999071},{\"index\":3,\"relevance_score\":0.7867867},"
96+
+ "{\"index\":0,\"relevance_score\":0.32713068}]}";
97+
return this.serializer.deserialize(json, OpenAiRerankResponse.class);
98+
}
8499
}

framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankApi.java

Lines changed: 0 additions & 19 deletions
This file was deleted.

framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankDocumentProcessor.java

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,13 @@
88

99
import modelengine.fel.core.document.DocumentPostProcessor;
1010
import modelengine.fel.core.document.MeasurableDocument;
11-
import modelengine.fel.core.rerank.RerankApi;
1211
import modelengine.fel.core.rerank.RerankModel;
1312
import modelengine.fel.core.rerank.RerankOption;
14-
import modelengine.fit.http.client.HttpClassicClient;
15-
import modelengine.fit.http.client.HttpClassicClientFactory;
16-
import modelengine.fit.http.client.HttpClassicClientRequest;
17-
import modelengine.fit.http.client.HttpClassicClientResponse;
18-
import modelengine.fit.http.entity.Entity;
19-
import modelengine.fit.http.entity.ObjectEntity;
20-
import modelengine.fit.http.protocol.HttpRequestMethod;
21-
import modelengine.fit.http.protocol.HttpResponseStatus;
22-
import modelengine.fitframework.exception.FitException;
2313
import modelengine.fitframework.inspection.Validation;
24-
import modelengine.fitframework.log.Logger;
25-
import modelengine.fitframework.resource.UrlUtils;
2614
import modelengine.fitframework.util.CollectionUtils;
27-
import modelengine.fitframework.util.LazyLoader;
28-
import modelengine.fitframework.util.ObjectUtils;
2915

30-
import java.io.IOException;
3116
import java.util.Collections;
3217
import java.util.List;
33-
import java.util.stream.Collectors;
3418

3519
/**
3620
* 表示检索文档的后置重排序接口。

framework/fel/java/fel-core/src/main/java/modelengine/fel/core/rerank/RerankApi.java

Lines changed: 0 additions & 19 deletions
This file was deleted.

framework/fel/java/fel-core/src/test/java/modelengine/fel/core/document/support/RerankDocumentProcessorTest.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,8 @@
1111

1212
import modelengine.fel.core.document.Document;
1313
import modelengine.fel.core.document.MeasurableDocument;
14-
import modelengine.fel.core.rerank.RerankModel;
1514
import modelengine.fel.core.rerank.RerankOption;
16-
import modelengine.fit.http.client.HttpClassicClientFactory;
1715
import modelengine.fitframework.annotation.Fit;
18-
import modelengine.fitframework.exception.FitException;
19-
import modelengine.fitframework.test.annotation.Mock;
20-
import modelengine.fitframework.test.annotation.MvcTest;
21-
import modelengine.fitframework.test.domain.mockito.MockitoMockBean;
2216
import modelengine.fitframework.test.domain.mvc.MockMvc;
2317

2418
import org.junit.jupiter.api.BeforeEach;

0 commit comments

Comments
 (0)