Skip to content

Commit 5be0da7

Browse files
committed
[app-builder] support MCP tool invocation in LLM nodes
1 parent 523d902 commit 5be0da7

File tree

4 files changed

+243
-18
lines changed

4 files changed

+243
-18
lines changed

app-builder/jane/plugins/aipp-plugin/src/main/java/modelengine/fit/jober/aipp/fel/FelComponentConfig.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import modelengine.fel.core.chat.ChatModel;
1010
import modelengine.fel.core.chat.Prompt;
1111
import modelengine.fel.engine.operators.patterns.AbstractAgent;
12+
import modelengine.fel.tool.mcp.client.McpClientFactory;
1213
import modelengine.fit.jade.tool.SyncToolCall;
1314
import modelengine.fit.jober.aipp.constants.AippConst;
1415
import modelengine.fitframework.annotation.Bean;
@@ -28,11 +29,12 @@ public class FelComponentConfig {
2829
*
2930
* @param syncToolCall 表示同步工具调用服务的 {@link SyncToolCall}。
3031
* @param chatModel 表示模型流式服务的 {@link ChatModel}。
32+
* @param mcpClientFactory 表示大模型上下文客户端工厂的 {@link McpClientFactory}。
3133
* @return 返回 WaterFlow 场景的 Agent 服务的 {@link AbstractAgent}{@code <}{@link Prompt}{@code ,
3234
* }{@link Prompt}{@code >}。
3335
*/
3436
@Bean(AippConst.WATER_FLOW_AGENT_BEAN)
35-
public AbstractAgent getWaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatModel) {
36-
return new WaterFlowAgent(syncToolCall, chatModel);
37+
public AbstractAgent getWaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatModel, McpClientFactory mcpClientFactory) {
38+
return new WaterFlowAgent(syncToolCall, chatModel, mcpClientFactory);
3739
}
3840
}

app-builder/jane/plugins/aipp-plugin/src/main/java/modelengine/fit/jober/aipp/fel/WaterFlowAgent.java

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,40 @@
66

77
package modelengine.fit.jober.aipp.fel;
88

9+
import com.alibaba.fastjson.JSON;
10+
import com.alibaba.fastjson.JSONObject;
11+
912
import modelengine.fel.core.chat.ChatMessage;
1013
import modelengine.fel.core.chat.ChatModel;
1114
import modelengine.fel.core.chat.Prompt;
1215
import modelengine.fel.core.chat.support.ChatMessages;
1316
import modelengine.fel.core.chat.support.FlatChatMessage;
1417
import modelengine.fel.core.chat.support.ToolMessage;
1518
import modelengine.fel.core.tool.ToolCall;
19+
import modelengine.fel.core.tool.ToolInfo;
1620
import modelengine.fel.engine.flows.AiFlows;
1721
import modelengine.fel.engine.flows.AiProcessFlow;
1822
import modelengine.fel.engine.operators.models.ChatChunk;
1923
import modelengine.fel.engine.operators.models.ChatFlowModel;
2024
import modelengine.fel.engine.operators.patterns.AbstractAgent;
25+
import modelengine.fel.tool.mcp.client.McpClient;
26+
import modelengine.fel.tool.mcp.client.McpClientFactory;
2127
import modelengine.fit.jade.tool.SyncToolCall;
28+
import modelengine.fit.jober.aipp.common.exception.AippErrCode;
29+
import modelengine.fit.jober.aipp.common.exception.AippException;
2230
import modelengine.fit.jober.aipp.constants.AippConst;
31+
import modelengine.fit.jober.aipp.util.McpUtils;
2332
import modelengine.fit.waterflow.domain.context.StateContext;
2433
import modelengine.fitframework.annotation.Fit;
2534
import modelengine.fitframework.inspection.Validation;
35+
import modelengine.fitframework.util.CollectionUtils;
2636
import modelengine.fitframework.util.ObjectUtils;
2737

38+
import java.io.IOException;
2839
import java.util.Collections;
2940
import java.util.List;
3041
import java.util.Map;
42+
import java.util.function.Function;
3143
import java.util.stream.Collectors;
3244

3345
/**
@@ -42,28 +54,30 @@ public class WaterFlowAgent extends AbstractAgent {
4254

4355
private final String agentMsgKey;
4456
private final SyncToolCall syncToolCall;
57+
private final McpClientFactory mcpClientFactory;
4558

4659
/**
4760
* {@link WaterFlowAgent} 的构造方法。
4861
*
4962
* @param syncToolCall 表示工具调用服务的 {@link SyncToolCall}。
5063
* @param chatStreamModel 表示流式对话大模型的 {@link ChatModel}。
64+
* @param mcpClientFactory 表示大模型上下文客户端工厂的 {@link McpClientFactory}。
5165
*/
52-
public WaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatStreamModel) {
66+
public WaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatStreamModel,
67+
McpClientFactory mcpClientFactory) {
5368
super(new ChatFlowModel(chatStreamModel, null));
54-
this.syncToolCall = Validation.notNull(syncToolCall, "The tool sync tool call cannot be null.");
69+
this.syncToolCall = Validation.notNull(syncToolCall, "The tool sync tool call cannot be null.");
70+
this.mcpClientFactory = Validation.notNull(mcpClientFactory, "The mcp client factory cannot be null.");
5571
this.agentMsgKey = AGENT_MSG_KEY;
5672
}
5773

5874
@Override
5975
protected Prompt doToolCall(List<ToolCall> toolCalls, StateContext ctx) {
6076
Validation.notNull(ctx, "The state context cannot be null.");
61-
Map<String, Object> toolContext = ObjectUtils.getIfNull(ctx.getState(AippConst.TOOL_CONTEXT_KEY),
62-
Collections::emptyMap);
63-
return toolCalls.stream()
64-
.map(toolCall -> (ChatMessage) new ToolMessage(toolCall.id(),
65-
this.syncToolCall.call(toolCall.name(), toolCall.arguments(), toolContext)))
66-
.collect(Collectors.collectingAndThen(Collectors.toList(), ChatMessages::from));
77+
return ChatMessages.from(this.callTools(toolCalls, ctx)
78+
.stream()
79+
.map(message -> (ChatMessage) FlatChatMessage.from(message))
80+
.collect(Collectors.toList()));
6781
}
6882

6983
@Override
@@ -87,18 +101,53 @@ public AiProcessFlow<Prompt, ChatMessage> buildFlow() {
87101
private ChatMessage handleTool(ChatMessage input, StateContext ctx) {
88102
Validation.notNull(ctx, "The state context cannot be null.");
89103
Validation.notNull(input, "The input message cannot be null.");
90-
91-
Map<String, Object> toolContext = ObjectUtils.getIfNull(ctx.getState(AippConst.TOOL_CONTEXT_KEY),
92-
Collections::emptyMap);
93104
ChatMessages lastRequest = ctx.getState(this.agentMsgKey);
94105
lastRequest.add(input);
95-
input.toolCalls().forEach(toolCall -> {
96-
lastRequest.add(FlatChatMessage.from(new ToolMessage(toolCall.id(),
97-
this.syncToolCall.call(toolCall.name(), toolCall.arguments(), toolContext))));
98-
});
106+
lastRequest.addAll(this.callTools(input.toolCalls(), ctx));
99107
return input;
100108
}
101109

110+
private List<ChatMessage> callTools(List<ToolCall> toolCalls, StateContext ctx) {
111+
if (CollectionUtils.isEmpty(toolCalls)) {
112+
return Collections.emptyList();
113+
}
114+
List<ToolInfo> tools = ctx.getState(AippConst.TOOLS_KEY);
115+
Validation.notEmpty(tools, "Missing tool detected during call.");
116+
Map<String, ToolInfo> toolsMap = tools.stream().collect(Collectors.toMap(ToolInfo::name, Function.identity()));
117+
Map<String, Object> toolContext =
118+
ObjectUtils.getIfNull(ctx.getState(AippConst.TOOL_CONTEXT_KEY), Collections::emptyMap);
119+
return toolCalls.stream()
120+
.map(toolCall -> this.callTool(toolCall, toolsMap, toolContext))
121+
.collect(Collectors.toList());
122+
}
123+
124+
private ChatMessage callTool(ToolCall toolCall, Map<String, ToolInfo> toolsMap, Map<String, Object> toolContext) {
125+
ToolInfo toolInfo = toolsMap.get(toolCall.name());
126+
if (toolInfo == null) {
127+
throw new IllegalStateException(String.format("The tool call's tool is not exist. [toolName=%s]",
128+
toolCall.name()));
129+
}
130+
Map<String, Object> extensions = Validation.notNull(toolInfo.extensions(),
131+
"The tool call's extension is not exist. [toolName={0}]", toolCall.name());
132+
String toolRealName = Validation.notBlank(ObjectUtils.cast(extensions.get(AippConst.TOOL_REAL_NAME)),
133+
"Can not find the tool real name. [toolName={0}]",
134+
toolCall.name());
135+
Map<String, Object> mcpServerConfig = ObjectUtils.cast(extensions.get(AippConst.MCP_SERVER_KEY));
136+
if (mcpServerConfig != null) {
137+
String url = Validation.notBlank(ObjectUtils.cast(mcpServerConfig.get(AippConst.MCP_SERVER_URL_KEY)),
138+
"The mcp url should not be empty.");
139+
try (McpClient mcpClient = this.mcpClientFactory.create(McpUtils.getBaseUrl(url),
140+
McpUtils.getSseEndpoint(url))) {
141+
mcpClient.initialize();
142+
Object result = mcpClient.callTool(toolRealName, JSONObject.parseObject(toolCall.arguments()));
143+
return new ToolMessage(toolCall.id(), JSON.toJSONString(result));
144+
} catch (IOException exception) {
145+
throw new AippException(AippErrCode.CALL_MCP_SERVER_FAILED, exception.getMessage());
146+
}
147+
}
148+
return new ToolMessage(toolCall.id(), this.syncToolCall.call(toolRealName, toolCall.arguments(), toolContext));
149+
}
150+
102151
private ChatMessages getAgentMsg(ChatMessage input, StateContext ctx) {
103152
Validation.notNull(ctx, "The state context cannot be null.");
104153
return ctx.getState(this.agentMsgKey);
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
package modelengine.fit.jober.aipp.fel;
2+
3+
import modelengine.fel.core.chat.ChatMessage;
4+
import modelengine.fel.core.chat.ChatModel;
5+
import modelengine.fel.core.chat.ChatOption;
6+
import modelengine.fel.core.chat.Prompt;
7+
import modelengine.fel.core.chat.support.AiMessage;
8+
import modelengine.fel.core.chat.support.ChatMessages;
9+
import modelengine.fel.core.chat.support.HumanMessage;
10+
import modelengine.fel.core.tool.ToolCall;
11+
import modelengine.fel.core.tool.ToolInfo;
12+
import modelengine.fel.engine.flows.AiProcessFlow;
13+
import modelengine.fel.tool.mcp.client.McpClient;
14+
import modelengine.fel.tool.mcp.client.McpClientFactory;
15+
import modelengine.fit.jade.tool.SyncToolCall;
16+
import modelengine.fit.jober.aipp.constants.AippConst;
17+
import modelengine.fitframework.flowable.Choir;
18+
import modelengine.fitframework.util.MapBuilder;
19+
20+
import org.apache.commons.collections.CollectionUtils;
21+
import org.jetbrains.annotations.NotNull;
22+
import org.junit.jupiter.api.Assertions;
23+
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.api.extension.ExtendWith;
25+
import org.mockito.Mock;
26+
import org.mockito.junit.jupiter.MockitoExtension;
27+
28+
import java.util.Arrays;
29+
import java.util.Collections;
30+
import java.util.HashMap;
31+
import java.util.List;
32+
import java.util.Map;
33+
import java.util.concurrent.atomic.AtomicInteger;
34+
35+
import static org.junit.jupiter.api.Assertions.*;
36+
import static org.mockito.ArgumentMatchers.any;
37+
import static org.mockito.Mockito.doAnswer;
38+
import static org.mockito.Mockito.mock;
39+
import static org.mockito.Mockito.times;
40+
import static org.mockito.Mockito.verify;
41+
import static org.mockito.Mockito.when;
42+
43+
@ExtendWith(MockitoExtension.class)
44+
class WaterFlowAgentTest {
45+
@Mock
46+
private SyncToolCall syncToolCall;
47+
48+
@Mock
49+
private ChatModel chatModel;
50+
51+
@Mock
52+
private McpClientFactory mcpClientFactory;
53+
54+
55+
@Test
56+
void shouldGetResultWhenRunFlowGivenNoToolCall() {
57+
WaterFlowAgent waterFlowAgent = new WaterFlowAgent(this.syncToolCall, this.chatModel, this.mcpClientFactory);
58+
59+
String expectResult = "0123";
60+
doAnswer(invocation -> Choir.create(emitter -> {
61+
for (int i = 0; i < 4; i++) {
62+
emitter.emit(new AiMessage(String.valueOf(i)));
63+
}
64+
emitter.complete();
65+
})).when(chatModel).generate(any(), any());
66+
67+
AiProcessFlow<Prompt, ChatMessage> flow = waterFlowAgent.buildFlow();
68+
ChatMessage result = flow.converse()
69+
.bind(ChatOption.custom().build())
70+
.offer(ChatMessages.from(new HumanMessage("hi"))).await();
71+
72+
assertEquals(expectResult, result.text());
73+
}
74+
75+
@Test
76+
void shouldGetResultWhenRunFlowGivenStoreToolCall() {
77+
WaterFlowAgent waterFlowAgent = new WaterFlowAgent(this.syncToolCall, this.chatModel, this.mcpClientFactory);
78+
79+
String expectResult = "tool result:0123";
80+
String realName = "realName";
81+
ToolInfo toolInfo = buildToolInfo(realName);
82+
ToolCall toolCall = ToolCall.custom().id("id").name(toolInfo.name()).arguments("{}").build();
83+
List<ToolCall> toolCalls = Collections.singletonList(toolCall);
84+
AtomicInteger step = new AtomicInteger();
85+
doAnswer(invocation -> {
86+
Prompt prompt = invocation.getArgument(0);
87+
return mockGenerateResult(step, toolCalls, prompt);
88+
}).when(chatModel).generate(any(), any());
89+
Map<String, Object> toolContext = MapBuilder.<String, Object>get().put("key", "value").build();
90+
when(this.syncToolCall.call(realName, toolCall.arguments(), toolContext)).thenReturn("tool result:");
91+
92+
AiProcessFlow<Prompt, ChatMessage> flow = waterFlowAgent.buildFlow();
93+
ChatMessage result = flow.converse()
94+
.bind(ChatOption.custom().build())
95+
.bind(AippConst.TOOL_CONTEXT_KEY, toolContext)
96+
.bind(AippConst.TOOLS_KEY, Collections.singletonList(toolInfo))
97+
.offer(ChatMessages.from(new HumanMessage("hi"))).await();
98+
99+
verify(this.mcpClientFactory, times(0)).create(any(), any());
100+
assertEquals(expectResult, result.text());
101+
}
102+
103+
@Test
104+
void shouldGetResultWhenRunFlowGivenMcpToolCall() {
105+
WaterFlowAgent waterFlowAgent = new WaterFlowAgent(this.syncToolCall, this.chatModel, this.mcpClientFactory);
106+
107+
String expectResult = "\"tool result:\"0123";
108+
String realName = "realName";
109+
String baseUrl = "http://localhost";
110+
String sseEndpoint = "/sse";
111+
ToolInfo toolInfo = buildMcpToolInfo(realName, baseUrl, sseEndpoint);
112+
ToolCall toolCall = ToolCall.custom().id("id").name(toolInfo.name()).arguments("{}").build();
113+
List<ToolCall> toolCalls = Collections.singletonList(toolCall);
114+
AtomicInteger step = new AtomicInteger();
115+
doAnswer(invocation -> {
116+
Prompt prompt = invocation.getArgument(0);
117+
return mockGenerateResult(step, toolCalls, prompt);
118+
}).when(chatModel).generate(any(), any());
119+
Map<String, Object> toolContext = MapBuilder.<String, Object>get().put("key", "value").build();
120+
McpClient mcpClient = mock(McpClient.class);
121+
when(this.mcpClientFactory.create(baseUrl, sseEndpoint)).thenReturn(mcpClient);
122+
when(mcpClient.callTool(realName, new HashMap<>())).thenReturn("tool result:");
123+
124+
AiProcessFlow<Prompt, ChatMessage> flow = waterFlowAgent.buildFlow();
125+
ChatMessage result = flow.converse()
126+
.bind(ChatOption.custom().build())
127+
.bind(AippConst.TOOL_CONTEXT_KEY, toolContext)
128+
.bind(AippConst.TOOLS_KEY, Collections.singletonList(toolInfo))
129+
.offer(ChatMessages.from(new HumanMessage("hi"))).await();
130+
131+
verify(this.syncToolCall, times(0)).call(any(), any(), any());
132+
assertEquals(expectResult, result.text());
133+
}
134+
135+
private static Choir<Object> mockGenerateResult(AtomicInteger step, List<ToolCall> toolCalls, Prompt prompt) {
136+
return Choir.create(emitter -> {
137+
138+
if (step.getAndIncrement() == 0) {
139+
emitter.emit(new AiMessage("tool_data", toolCalls));
140+
emitter.complete();
141+
return;
142+
}
143+
if (CollectionUtils.isNotEmpty(prompt.messages())) {
144+
emitter.emit(new AiMessage(prompt.messages().get(prompt.messages().size() - 1).text()));
145+
}
146+
for (int i = 0; i < 4; i++) {
147+
emitter.emit(new AiMessage(String.valueOf(i)));
148+
}
149+
emitter.complete();
150+
});
151+
}
152+
153+
private static ToolInfo buildToolInfo(String realName) {
154+
return ToolInfo.custom()
155+
.name("tool1")
156+
.description("desc")
157+
.parameters(new HashMap<>())
158+
.extensions(MapBuilder.<String, Object>get().put(AippConst.TOOL_REAL_NAME, realName).build())
159+
.build();
160+
}
161+
162+
private static ToolInfo buildMcpToolInfo(String realName, String baseUrl, String sseEndpoint) {
163+
return ToolInfo.custom()
164+
.name("tool1")
165+
.description("desc")
166+
.parameters(new HashMap<>())
167+
.extensions(MapBuilder.<String, Object>get()
168+
.put(AippConst.TOOL_REAL_NAME, realName)
169+
.put(AippConst.MCP_SERVER_KEY,
170+
MapBuilder.get().put(AippConst.MCP_SERVER_URL_KEY, baseUrl + sseEndpoint).build())
171+
.build())
172+
.build();
173+
}
174+
}

app-builder/jane/plugins/aipp-plugin/src/test/java/modelengine/fit/jober/aipp/fitable/LlmComponentTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ protected AiProcessFlow<Prompt, ChatMessage> buildFlow() {
166166
}
167167

168168
private AbstractAgent getWaterFlowAgent(ChatModel model) {
169-
return new WaterFlowAgent(this.syncToolCall, model);
169+
return new WaterFlowAgent(this.syncToolCall, model, this.mcpClientFactory);
170170
}
171171

172172
private ChatModel buildChatStreamModel(String exceptionMsg) {

0 commit comments

Comments
 (0)