Skip to content

Commit 864256f

Browse files
committed
feat: add rag API
1 parent 369a4f5 commit 864256f

File tree

6 files changed

+253
-12
lines changed

6 files changed

+253
-12
lines changed

Dockerfile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ WORKDIR /app
1515
COPY --from=build /app/app/target/tiny-engine-app-*.jar /app/tiny-engine-app.jar
1616
COPY --from=build /app/base/target/tiny-engine-base-*.jar /app/tiny-engine-base.jar
1717
# 设置环境变量
18-
18+
ENV ACCESS_KEY_ID=""
19+
ENV ACCESS_KEY_SECRET=""
20+
ENV INDEX_ID=""
21+
ENV WORK_SPACE_ID=""
1922
ENV FOLDER_PATH="/app/documents"
2023
# 替换为自己的域名接口路径
2124
ENV TINY_ENGINE_URL="https://agent.opentiny.design/material-center/api/resource/download"

base/src/main/java/com/tinyengine/it/controller/AiChatController.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212

1313
package com.tinyengine.it.controller;
1414

15+
import com.tinyengine.it.common.base.Result;
1516
import com.tinyengine.it.common.log.SystemControllerLog;
1617
import com.tinyengine.it.model.dto.ChatRequest;
1718

19+
import com.tinyengine.it.model.dto.NodeDto;
1820
import com.tinyengine.it.service.app.v1.AiChatV1Service;
1921
import io.swagger.v3.oas.annotations.Operation;
2022
import io.swagger.v3.oas.annotations.Parameter;
@@ -35,6 +37,8 @@
3537
import org.springframework.web.bind.annotation.RestController;
3638
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
3739

40+
import java.util.List;
41+
3842
/**
3943
* The type Ai chat controller.
4044
*
@@ -120,4 +124,24 @@ public ResponseEntity<?> completions(@RequestBody ChatRequest request,
120124
.body(e.getMessage());
121125
}
122126
}
127+
128+
/**
129+
* AI search api
130+
*
131+
* @param content the AI search param
132+
* @return ai回答信息 result
133+
*/
134+
@Operation(summary = "搜索知识库", description = "搜索知识库",
135+
parameters = {
136+
@Parameter(name = "content", description = "入参对象")
137+
}, responses = {
138+
@ApiResponse(responseCode = "200", description = "返回信息",
139+
content = @Content(mediaType = "application/json", schema = @Schema())),
140+
@ApiResponse(responseCode = "400", description = "请求失败")
141+
})
142+
@SystemControllerLog(description = "AI serarch api")
143+
@PostMapping("/ai/search")
144+
public Result<List<NodeDto>> search(@RequestBody String content) throws Exception {
145+
return aiChatV1Service.chatSearch(content);
146+
}
123147
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/**
2+
* Copyright (c) 2023 - present TinyEngine Authors.
3+
* Copyright (c) 2023 - present Huawei Cloud Computing Technologies Co., Ltd.
4+
*
5+
* Use of this source code is governed by an MIT-style license.
6+
*
7+
* THE OPEN SOURCE SOFTWARE IN THIS PRODUCT IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL,
8+
* BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR
9+
* A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.
10+
*
11+
*/
12+
13+
package com.tinyengine.it.model.dto;
14+
15+
import lombok.Data;
16+
17+
/**
18+
* Node dto
19+
*
20+
* @since 2025-09-16
21+
*/
22+
@Data
23+
public class NodeDto {
24+
private Double score;
25+
private String docName;
26+
private String content;
27+
}

base/src/main/java/com/tinyengine/it/service/app/impl/v1/AiChatV1ServiceImpl.java

Lines changed: 183 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,22 @@
1212

1313
package com.tinyengine.it.service.app.impl.v1;
1414

15+
import com.aliyun.bailian20231229.Client;
16+
import com.aliyun.tea.TeaException;
17+
import com.aliyun.teaopenapi.models.Config;
18+
import com.aliyun.teaopenapi.models.OpenApiRequest;
19+
import com.aliyun.teaopenapi.models.Params;
20+
import com.aliyun.teautil.models.RuntimeOptions;
1521
import com.fasterxml.jackson.databind.JsonNode;
22+
import com.tinyengine.it.common.base.Result;
1623
import com.tinyengine.it.common.log.SystemServiceLog;
1724
import com.tinyengine.it.common.utils.JsonUtils;
1825
import com.tinyengine.it.config.OpenAIConfig;
1926
import com.tinyengine.it.model.dto.ChatRequest;
27+
import com.tinyengine.it.model.dto.NodeDto;
2028
import com.tinyengine.it.service.app.v1.AiChatV1Service;
29+
import org.slf4j.Logger;
30+
import org.slf4j.LoggerFactory;
2131
import org.springframework.stereotype.Service;
2232
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
2333

@@ -29,7 +39,9 @@
2939
import java.net.http.HttpResponse;
3040
import java.nio.charset.StandardCharsets;
3141
import java.time.Duration;
42+
import java.util.ArrayList;
3243
import java.util.HashMap;
44+
import java.util.List;
3345
import java.util.Map;
3446

3547
/**
@@ -39,10 +51,16 @@
3951
*/
4052
@Service
4153
public class AiChatV1ServiceImpl implements AiChatV1Service {
54+
private static final String ACCESS_KEY_ID = System.getenv("ACCESS_KEY_ID");
55+
private static final String ACCESS_KEY_SECRET = System.getenv("ACCESS_KEY_SECRET");
56+
private static final String ENDPOINT = "bailian.cn-beijing.aliyuncs.com";
57+
private static final String INDEX_ID = System.getenv("INDEX_ID");
58+
private static final String WORK_SPACE_ID = System.getenv("WORK_SPACE_ID");
4259
private final OpenAIConfig config = new OpenAIConfig();
60+
private static final Logger log = LoggerFactory.getLogger(AiChatV1ServiceImpl.class);
4361
private HttpClient httpClient = HttpClient.newBuilder()
44-
.connectTimeout(Duration.ofSeconds(config.getTimeoutSeconds()))
45-
.build();
62+
.connectTimeout(Duration.ofSeconds(config.getTimeoutSeconds()))
63+
.build();
4664

4765
/**
4866
* chatCompletion.
@@ -61,10 +79,10 @@ public Object chatCompletion(ChatRequest request) throws Exception {
6179
String normalizedUrl = normalizeApiUrl(baseUrl);
6280

6381
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
64-
.uri(URI.create(normalizedUrl))
65-
.header("Content-Type", "application/json")
66-
.header("Authorization", "Bearer " + apiKey)
67-
.POST(HttpRequest.BodyPublishers.ofString(requestBody));
82+
.uri(URI.create(normalizedUrl))
83+
.header("Content-Type", "application/json")
84+
.header("Authorization", "Bearer " + apiKey)
85+
.POST(HttpRequest.BodyPublishers.ofString(requestBody));
6886
if (request.isStream()) {
6987
requestBuilder.header("Accept", "text/event-stream");
7088
return processStreamResponse(requestBuilder);
@@ -90,7 +108,7 @@ private String normalizeApiUrl(String baseUrl) {
90108
return ensureUrlProtocol(baseUrl) + "/chat/completions";
91109
} else {
92110
return ensureUrlProtocol(baseUrl) + "/v1/chat/completions";
93-
}
111+
}
94112
}
95113

96114
/**
@@ -104,6 +122,160 @@ private String ensureUrlProtocol(String url) {
104122
return "https://" + url;
105123
}
106124

125+
/**
126+
* 创建客户端
127+
*/
128+
private Client createClient() throws Exception {
129+
return new Client(new Config()
130+
.setAccessKeyId(ACCESS_KEY_ID)
131+
.setAccessKeySecret(ACCESS_KEY_SECRET)
132+
.setEndpoint(ENDPOINT)
133+
.setEndpointType("access_key"));
134+
}
135+
136+
/**
137+
* 创建API信息
138+
*/
139+
private Params createApiInfo(String WorkspaceId) throws Exception {
140+
return new Params()
141+
// 接口名称
142+
.setAction("Retrieve")
143+
// 接口版本
144+
.setVersion("2023-12-29")
145+
// 接口协议
146+
.setProtocol("HTTPS")
147+
// 接口 HTTP 方法
148+
.setMethod("POST")
149+
.setAuthType("AK")
150+
.setStyle("ROA")
151+
// 接口 PATH
152+
.setPathname("/" + com.aliyun.openapiutil.Client.getEncodeParam(WorkspaceId) + "/index/retrieve")
153+
// 接口请求体内容格式
154+
.setReqBodyType("json")
155+
// 接口响应体内容格式
156+
.setBodyType("json");
157+
}
158+
159+
/**
160+
* 安全类型转换工具方法
161+
*/
162+
private <T> T safeCast(Object obj, Class<T> clazz, T defaultValue) {
163+
if (obj == null) {
164+
return defaultValue;
165+
}
166+
try {
167+
return clazz.cast(obj);
168+
} catch (ClassCastException e) {
169+
log.warn("类型转换失败: {} 无法转换为 {}", obj.getClass().getName(), clazz.getName());
170+
return defaultValue;
171+
}
172+
}
173+
174+
private String safeCastToString(Object obj) {
175+
return safeCast(obj, String.class, "");
176+
}
177+
178+
private Double safeCastToDouble(Object obj) {
179+
return safeCast(obj, Double.class, 0.0);
180+
}
181+
182+
private Long safeCastToLong(Object obj) {
183+
return safeCast(obj, Long.class, 0L);
184+
}
185+
186+
/**
187+
* chatSearch.
188+
*
189+
* @param content the content
190+
* @return String the String
191+
*/
192+
public Result chatSearch(String content) {
193+
try {
194+
Client client = createClient();
195+
Params params = createApiInfo(WORK_SPACE_ID);
196+
197+
Map<String, Object> queries = new HashMap<>();
198+
queries.put("IndexId", INDEX_ID);
199+
queries.put("Query", content);
200+
queries.put("EnableRewrite", "true");
201+
202+
RuntimeOptions runtime = new RuntimeOptions();
203+
OpenApiRequest request = new OpenApiRequest()
204+
.setQuery(com.aliyun.openapiutil.Client.query(queries));
205+
206+
Map<String, ?> response = client.callApi(params, request, runtime);
207+
Map<String, Object> body = (Map<String, Object>) response.get("body");
208+
209+
if (body == null) {
210+
return Result.failed("响应体为空");
211+
}
212+
213+
long status = safeCastToLong(body.get("Status"));
214+
if (status != 200L) {
215+
String message = safeCastToString(body.get("Message"));
216+
log.error("搜索失败: status={}, message={}", status, message);
217+
return Result.failed("搜索失败: " + message);
218+
}
219+
220+
Map data = safeCast(body.get("Data"), Map.class, new HashMap<>());
221+
if (data == null || data.isEmpty()) {
222+
return Result.success(new ArrayList<>());
223+
}
224+
225+
List nodes = safeCast(data.get("Nodes"), List.class, new ArrayList<>());
226+
if (nodes.isEmpty()) {
227+
return Result.success(new ArrayList<>());
228+
}
229+
230+
List nodeDtos = convertToNodeDtos(nodes);
231+
return Result.success(nodeDtos);
232+
233+
} catch (TeaException e) {
234+
log.error("阿里云Tea异常: {}", e.getMessage(), e);
235+
return Result.failed("阿里云服务异常: " + e.getMessage());
236+
} catch (Exception e) {
237+
log.error("搜索异常: {}", e.getMessage(), e);
238+
return Result.failed("系统异常: " + e.getMessage());
239+
}
240+
}
241+
242+
/**
243+
* 转换节点数据
244+
*/
245+
private List<NodeDto> convertToNodeDtos(List<Map<String, Object>> nodes) {
246+
List<NodeDto> nodeDtos = new ArrayList<>();
247+
248+
for (Map<String, Object> node : nodes) {
249+
try {
250+
NodeDto nodeDto = new NodeDto();
251+
252+
// 安全获取文本内容
253+
nodeDto.setContent(safeCastToString(node.get("Text")));
254+
255+
// 安全获取分数
256+
Object scoreObj = node.get("Score");
257+
if (scoreObj instanceof Number) {
258+
nodeDto.setScore(((Number) scoreObj).doubleValue());
259+
} else {
260+
nodeDto.setScore(safeCastToDouble(scoreObj));
261+
}
262+
263+
// 安全获取元数据
264+
Map metadata = safeCast(node.get("Metadata"), Map.class, new HashMap<>());
265+
if (metadata != null) {
266+
nodeDto.setDocName(safeCastToString(metadata.get("doc_name")));
267+
}
268+
269+
nodeDtos.add(nodeDto);
270+
271+
} catch (Exception e) {
272+
log.warn("节点数据转换失败: {}", e.getMessage());
273+
}
274+
}
275+
276+
return nodeDtos;
277+
}
278+
107279
private String buildRequestBody(ChatRequest request) {
108280
Map<String, Object> body = new HashMap<>();
109281
body.put("model", request.getModel() != null ? request.getModel() : config.getDefaultModel());
@@ -120,9 +292,9 @@ private String buildRequestBody(ChatRequest request) {
120292
}
121293

122294
private JsonNode processStandardResponse(HttpRequest.Builder requestBuilder)
123-
throws Exception {
295+
throws Exception {
124296
HttpResponse<String> response = httpClient.send(
125-
requestBuilder.build(), HttpResponse.BodyHandlers.ofString());
297+
requestBuilder.build(), HttpResponse.BodyHandlers.ofString());
126298
return JsonUtils.MAPPER.readTree(response.body());
127299
}
128300

@@ -131,8 +303,8 @@ private StreamingResponseBody processStreamResponse(HttpRequest.Builder requestB
131303
try {
132304
HttpClient client = HttpClient.newHttpClient();
133305
HttpResponse<InputStream> response = client.send(
134-
requestBuilder.build(),
135-
HttpResponse.BodyHandlers.ofInputStream()
306+
requestBuilder.build(),
307+
HttpResponse.BodyHandlers.ofInputStream()
136308
);
137309
if (response.statusCode() != 200) {
138310
String errorBody = new String(response.body().readAllBytes(), StandardCharsets.UTF_8);

base/src/main/java/com/tinyengine/it/service/app/v1/AiChatV1Service.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import com.tinyengine.it.common.base.Result;
1616
import com.tinyengine.it.model.dto.ChatRequest;
17+
import com.tinyengine.it.model.dto.NodeDto;
1718

1819
import java.util.List;
1920

@@ -30,4 +31,12 @@ public interface AiChatV1Service {
3031
* @return Object the Object
3132
*/
3233
public Object chatCompletion(ChatRequest request) throws Exception;
34+
35+
/**
36+
* chatSearch.
37+
*
38+
* @param content the content
39+
* @return String the String
40+
*/
41+
public Result<List<NodeDto>> chatSearch(String content) throws Exception;
3342
}

pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@
6767
<version>${fastjson.version}</version>
6868
</dependency>
6969

70+
<dependency>
71+
<groupId>com.aliyun</groupId>
72+
<artifactId>bailian20231229</artifactId>
73+
<version>2.4.1</version>
74+
</dependency>
75+
7076
<!-- Mybatis-plus 代码生成器依赖-->
7177
<dependency>
7278
<groupId>com.baomidou</groupId>

0 commit comments

Comments
 (0)