Skip to content

Commit b039d55

Browse files
committed
[app-builder] support MCP tool invocation in LLM nodes
1 parent 523d902 commit b039d55

4 files changed

Lines changed: 252 additions & 18 deletions

File tree

app-builder/jane/plugins/aipp-plugin/src/main/java/modelengine/fit/jober/aipp/fel/FelComponentConfig.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import modelengine.fel.core.chat.ChatModel;
1010
import modelengine.fel.core.chat.Prompt;
1111
import modelengine.fel.engine.operators.patterns.AbstractAgent;
12+
import modelengine.fel.tool.mcp.client.McpClientFactory;
1213
import modelengine.fit.jade.tool.SyncToolCall;
1314
import modelengine.fit.jober.aipp.constants.AippConst;
1415
import modelengine.fitframework.annotation.Bean;
@@ -28,11 +29,12 @@ public class FelComponentConfig {
2829
*
2930
* @param syncToolCall 表示同步工具调用服务的 {@link SyncToolCall}。
3031
* @param chatModel 表示模型流式服务的 {@link ChatModel}。
32+
* @param mcpClientFactory 表示大模型上下文客户端工厂的 {@link McpClientFactory}。
3133
* @return 返回 WaterFlow 场景的 Agent 服务的 {@link AbstractAgent}{@code <}{@link Prompt}{@code ,
3234
* }{@link Prompt}{@code >}。
3335
*/
3436
@Bean(AippConst.WATER_FLOW_AGENT_BEAN)
35-
public AbstractAgent getWaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatModel) {
36-
return new WaterFlowAgent(syncToolCall, chatModel);
37+
public AbstractAgent getWaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatModel, McpClientFactory mcpClientFactory) {
38+
return new WaterFlowAgent(syncToolCall, chatModel, mcpClientFactory);
3739
}
3840
}

app-builder/jane/plugins/aipp-plugin/src/main/java/modelengine/fit/jober/aipp/fel/WaterFlowAgent.java

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,40 @@
66

77
package modelengine.fit.jober.aipp.fel;
88

9+
import com.alibaba.fastjson.JSON;
10+
import com.alibaba.fastjson.JSONObject;
11+
912
import modelengine.fel.core.chat.ChatMessage;
1013
import modelengine.fel.core.chat.ChatModel;
1114
import modelengine.fel.core.chat.Prompt;
1215
import modelengine.fel.core.chat.support.ChatMessages;
1316
import modelengine.fel.core.chat.support.FlatChatMessage;
1417
import modelengine.fel.core.chat.support.ToolMessage;
1518
import modelengine.fel.core.tool.ToolCall;
19+
import modelengine.fel.core.tool.ToolInfo;
1620
import modelengine.fel.engine.flows.AiFlows;
1721
import modelengine.fel.engine.flows.AiProcessFlow;
1822
import modelengine.fel.engine.operators.models.ChatChunk;
1923
import modelengine.fel.engine.operators.models.ChatFlowModel;
2024
import modelengine.fel.engine.operators.patterns.AbstractAgent;
25+
import modelengine.fel.tool.mcp.client.McpClient;
26+
import modelengine.fel.tool.mcp.client.McpClientFactory;
2127
import modelengine.fit.jade.tool.SyncToolCall;
28+
import modelengine.fit.jober.aipp.common.exception.AippErrCode;
29+
import modelengine.fit.jober.aipp.common.exception.AippException;
2230
import modelengine.fit.jober.aipp.constants.AippConst;
31+
import modelengine.fit.jober.aipp.util.McpUtils;
2332
import modelengine.fit.waterflow.domain.context.StateContext;
2433
import modelengine.fitframework.annotation.Fit;
2534
import modelengine.fitframework.inspection.Validation;
35+
import modelengine.fitframework.util.CollectionUtils;
2636
import modelengine.fitframework.util.ObjectUtils;
2737

38+
import java.io.IOException;
2839
import java.util.Collections;
2940
import java.util.List;
3041
import java.util.Map;
42+
import java.util.function.Function;
3143
import java.util.stream.Collectors;
3244

3345
/**
@@ -42,28 +54,30 @@ public class WaterFlowAgent extends AbstractAgent {
4254

4355
private final String agentMsgKey;
4456
private final SyncToolCall syncToolCall;
57+
private final McpClientFactory mcpClientFactory;
4558

4659
/**
4760
* {@link WaterFlowAgent} 的构造方法。
4861
*
4962
* @param syncToolCall 表示工具调用服务的 {@link SyncToolCall}。
5063
* @param chatStreamModel 表示流式对话大模型的 {@link ChatModel}。
64+
* @param mcpClientFactory 表示大模型上下文客户端工厂的 {@link McpClientFactory}。
5165
*/
52-
public WaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatStreamModel) {
66+
public WaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatStreamModel,
67+
McpClientFactory mcpClientFactory) {
5368
super(new ChatFlowModel(chatStreamModel, null));
54-
this.syncToolCall = Validation.notNull(syncToolCall, "The tool sync tool call cannot be null.");
69+
this.syncToolCall = Validation.notNull(syncToolCall, "The tool sync tool call cannot be null.");
70+
this.mcpClientFactory = Validation.notNull(mcpClientFactory, "The mcp client factory cannot be null.");
5571
this.agentMsgKey = AGENT_MSG_KEY;
5672
}
5773

5874
@Override
5975
protected Prompt doToolCall(List<ToolCall> toolCalls, StateContext ctx) {
6076
Validation.notNull(ctx, "The state context cannot be null.");
61-
Map<String, Object> toolContext = ObjectUtils.getIfNull(ctx.getState(AippConst.TOOL_CONTEXT_KEY),
62-
Collections::emptyMap);
63-
return toolCalls.stream()
64-
.map(toolCall -> (ChatMessage) new ToolMessage(toolCall.id(),
65-
this.syncToolCall.call(toolCall.name(), toolCall.arguments(), toolContext)))
66-
.collect(Collectors.collectingAndThen(Collectors.toList(), ChatMessages::from));
77+
return ChatMessages.from(this.callTools(toolCalls, ctx)
78+
.stream()
79+
.map(message -> (ChatMessage) FlatChatMessage.from(message))
80+
.collect(Collectors.toList()));
6781
}
6882

6983
@Override
@@ -87,18 +101,53 @@ public AiProcessFlow<Prompt, ChatMessage> buildFlow() {
87101
private ChatMessage handleTool(ChatMessage input, StateContext ctx) {
88102
Validation.notNull(ctx, "The state context cannot be null.");
89103
Validation.notNull(input, "The input message cannot be null.");
90-
91-
Map<String, Object> toolContext = ObjectUtils.getIfNull(ctx.getState(AippConst.TOOL_CONTEXT_KEY),
92-
Collections::emptyMap);
93104
ChatMessages lastRequest = ctx.getState(this.agentMsgKey);
94105
lastRequest.add(input);
95-
input.toolCalls().forEach(toolCall -> {
96-
lastRequest.add(FlatChatMessage.from(new ToolMessage(toolCall.id(),
97-
this.syncToolCall.call(toolCall.name(), toolCall.arguments(), toolContext))));
98-
});
106+
lastRequest.addAll(this.callTools(input.toolCalls(), ctx));
99107
return input;
100108
}
101109

110+
private List<ChatMessage> callTools(List<ToolCall> toolCalls, StateContext ctx) {
111+
if (CollectionUtils.isEmpty(toolCalls)) {
112+
return Collections.emptyList();
113+
}
114+
List<ToolInfo> tools = ctx.getState(AippConst.TOOLS_KEY);
115+
Validation.notEmpty(tools, "Missing tool detected during call.");
116+
Map<String, ToolInfo> toolsMap = tools.stream().collect(Collectors.toMap(ToolInfo::name, Function.identity()));
117+
Map<String, Object> toolContext =
118+
ObjectUtils.getIfNull(ctx.getState(AippConst.TOOL_CONTEXT_KEY), Collections::emptyMap);
119+
return toolCalls.stream()
120+
.map(toolCall -> this.callTool(toolCall, toolsMap, toolContext))
121+
.collect(Collectors.toList());
122+
}
123+
124+
private ChatMessage callTool(ToolCall toolCall, Map<String, ToolInfo> toolsMap, Map<String, Object> toolContext) {
125+
ToolInfo toolInfo = toolsMap.get(toolCall.name());
126+
if (toolInfo == null) {
127+
throw new IllegalStateException(String.format("The tool call's tool is not exist. [toolName=%s]",
128+
toolCall.name()));
129+
}
130+
Map<String, Object> extensions = Validation.notNull(toolInfo.extensions(),
131+
"The tool call's extension is not exist. [toolName={0}]", toolCall.name());
132+
String toolRealName = Validation.notBlank(ObjectUtils.cast(extensions.get(AippConst.TOOL_REAL_NAME)),
133+
"Can not find the tool real name. [toolName={0}]",
134+
toolCall.name());
135+
Map<String, Object> mcpServerConfig = ObjectUtils.cast(extensions.get(AippConst.MCP_SERVER_KEY));
136+
if (mcpServerConfig != null) {
137+
String url = Validation.notBlank(ObjectUtils.cast(mcpServerConfig.get(AippConst.MCP_SERVER_URL_KEY)),
138+
"The mcp url should not be empty.");
139+
try (McpClient mcpClient = this.mcpClientFactory.create(McpUtils.getBaseUrl(url),
140+
McpUtils.getSseEndpoint(url))) {
141+
mcpClient.initialize();
142+
Object result = mcpClient.callTool(toolRealName, JSONObject.parseObject(toolCall.arguments()));
143+
return new ToolMessage(toolCall.id(), JSON.toJSONString(result));
144+
} catch (IOException exception) {
145+
throw new AippException(AippErrCode.CALL_MCP_SERVER_FAILED, exception.getMessage());
146+
}
147+
}
148+
return new ToolMessage(toolCall.id(), this.syncToolCall.call(toolRealName, toolCall.arguments(), toolContext));
149+
}
150+
102151
private ChatMessages getAgentMsg(ChatMessage input, StateContext ctx) {
103152
Validation.notNull(ctx, "The state context cannot be null.");
104153
return ctx.getState(this.agentMsgKey);
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3+
* This file is a part of the ModelEngine Project.
4+
* Licensed under the MIT License. See License.txt in the project root for license information.
5+
*--------------------------------------------------------------------------------------------*/
6+
7+
package modelengine.fit.jober.aipp.fel;
8+
9+
import modelengine.fel.core.chat.ChatMessage;
10+
import modelengine.fel.core.chat.ChatModel;
11+
import modelengine.fel.core.chat.ChatOption;
12+
import modelengine.fel.core.chat.Prompt;
13+
import modelengine.fel.core.chat.support.AiMessage;
14+
import modelengine.fel.core.chat.support.ChatMessages;
15+
import modelengine.fel.core.chat.support.HumanMessage;
16+
import modelengine.fel.core.tool.ToolCall;
17+
import modelengine.fel.core.tool.ToolInfo;
18+
import modelengine.fel.engine.flows.AiProcessFlow;
19+
import modelengine.fel.tool.mcp.client.McpClient;
20+
import modelengine.fel.tool.mcp.client.McpClientFactory;
21+
import modelengine.fit.jade.tool.SyncToolCall;
22+
import modelengine.fit.jober.aipp.constants.AippConst;
23+
import modelengine.fitframework.flowable.Choir;
24+
import modelengine.fitframework.util.MapBuilder;
25+
26+
import org.apache.commons.collections.CollectionUtils;
27+
import org.junit.jupiter.api.Test;
28+
import org.junit.jupiter.api.extension.ExtendWith;
29+
import org.mockito.Mock;
30+
import org.mockito.junit.jupiter.MockitoExtension;
31+
32+
import java.util.Collections;
33+
import java.util.HashMap;
34+
import java.util.List;
35+
import java.util.Map;
36+
import java.util.concurrent.atomic.AtomicReference;
37+
38+
import static org.junit.jupiter.api.Assertions.*;
39+
import static org.mockito.ArgumentMatchers.any;
40+
import static org.mockito.Mockito.doAnswer;
41+
import static org.mockito.Mockito.mock;
42+
import static org.mockito.Mockito.times;
43+
import static org.mockito.Mockito.verify;
44+
import static org.mockito.Mockito.when;
45+
46+
/**
47+
* {@link WaterFlowAgent} 的测试。
48+
*/
49+
@ExtendWith(MockitoExtension.class)
50+
class WaterFlowAgentTest {
51+
private static final String TEXT_STEP = "textStep";
52+
private static final String TOOL_CALL_STEP = "toolCallStep";
53+
54+
@Mock
55+
private SyncToolCall syncToolCall;
56+
@Mock
57+
private ChatModel chatModel;
58+
@Mock
59+
private McpClientFactory mcpClientFactory;
60+
61+
@Test
62+
void shouldGetResultWhenRunFlowGivenNoToolCall() {
63+
WaterFlowAgent waterFlowAgent = new WaterFlowAgent(this.syncToolCall, this.chatModel, this.mcpClientFactory);
64+
65+
String expectResult = "0123";
66+
doAnswer(invocation -> Choir.create(emitter -> {
67+
for (int i = 0; i < 4; i++) {
68+
emitter.emit(new AiMessage(String.valueOf(i)));
69+
}
70+
emitter.complete();
71+
})).when(chatModel).generate(any(), any());
72+
73+
AiProcessFlow<Prompt, ChatMessage> flow = waterFlowAgent.buildFlow();
74+
ChatMessage result = flow.converse()
75+
.bind(ChatOption.custom().build())
76+
.offer(ChatMessages.from(new HumanMessage("hi"))).await();
77+
78+
assertEquals(expectResult, result.text());
79+
}
80+
81+
@Test
82+
void shouldGetResultWhenRunFlowGivenStoreToolCall() {
83+
WaterFlowAgent waterFlowAgent = new WaterFlowAgent(this.syncToolCall, this.chatModel, this.mcpClientFactory);
84+
85+
String expectResult = "tool result:0123";
86+
String realName = "realName";
87+
ToolInfo toolInfo = buildToolInfo(realName);
88+
ToolCall toolCall = ToolCall.custom().id("id").name(toolInfo.name()).arguments("{}").build();
89+
List<ToolCall> toolCalls = Collections.singletonList(toolCall);
90+
AtomicReference<String> step = new AtomicReference<>(TOOL_CALL_STEP);
91+
doAnswer(invocation -> {
92+
Prompt prompt = invocation.getArgument(0);
93+
Choir<Object> result = mockGenerateResult(step.get(), toolCalls, prompt);
94+
step.set(TEXT_STEP);
95+
return result;
96+
}).when(chatModel).generate(any(), any());
97+
Map<String, Object> toolContext = MapBuilder.<String, Object>get().put("key", "value").build();
98+
when(this.syncToolCall.call(realName, toolCall.arguments(), toolContext)).thenReturn("tool result:");
99+
100+
AiProcessFlow<Prompt, ChatMessage> flow = waterFlowAgent.buildFlow();
101+
ChatMessage result = flow.converse()
102+
.bind(ChatOption.custom().build())
103+
.bind(AippConst.TOOL_CONTEXT_KEY, toolContext)
104+
.bind(AippConst.TOOLS_KEY, Collections.singletonList(toolInfo))
105+
.offer(ChatMessages.from(new HumanMessage("hi"))).await();
106+
107+
verify(this.mcpClientFactory, times(0)).create(any(), any());
108+
assertEquals(expectResult, result.text());
109+
}
110+
111+
@Test
112+
void shouldGetResultWhenRunFlowGivenMcpToolCall() {
113+
WaterFlowAgent waterFlowAgent = new WaterFlowAgent(this.syncToolCall, this.chatModel, this.mcpClientFactory);
114+
115+
String expectResult = "\"tool result:\"0123";
116+
String realName = "realName";
117+
String baseUrl = "http://localhost";
118+
String sseEndpoint = "/sse";
119+
ToolInfo toolInfo = buildMcpToolInfo(realName, baseUrl, sseEndpoint);
120+
ToolCall toolCall = ToolCall.custom().id("id").name(toolInfo.name()).arguments("{}").build();
121+
List<ToolCall> toolCalls = Collections.singletonList(toolCall);
122+
AtomicReference<String> step = new AtomicReference<>(TOOL_CALL_STEP);
123+
doAnswer(invocation -> {
124+
Prompt prompt = invocation.getArgument(0);
125+
Choir<Object> result = mockGenerateResult(step.get(), toolCalls, prompt);
126+
step.set(TEXT_STEP);
127+
return result;
128+
}).when(chatModel).generate(any(), any());
129+
Map<String, Object> toolContext = MapBuilder.<String, Object>get().put("key", "value").build();
130+
McpClient mcpClient = mock(McpClient.class);
131+
when(this.mcpClientFactory.create(baseUrl, sseEndpoint)).thenReturn(mcpClient);
132+
when(mcpClient.callTool(realName, new HashMap<>())).thenReturn("tool result:");
133+
134+
AiProcessFlow<Prompt, ChatMessage> flow = waterFlowAgent.buildFlow();
135+
ChatMessage result = flow.converse()
136+
.bind(ChatOption.custom().build())
137+
.bind(AippConst.TOOL_CONTEXT_KEY, toolContext)
138+
.bind(AippConst.TOOLS_KEY, Collections.singletonList(toolInfo))
139+
.offer(ChatMessages.from(new HumanMessage("hi"))).await();
140+
141+
verify(this.syncToolCall, times(0)).call(any(), any(), any());
142+
assertEquals(expectResult, result.text());
143+
}
144+
145+
private static Choir<Object> mockGenerateResult(String step, List<ToolCall> toolCalls, Prompt prompt) {
146+
return Choir.create(emitter -> {
147+
if (TOOL_CALL_STEP.equals(step)) {
148+
emitter.emit(new AiMessage("tool_data", toolCalls));
149+
emitter.complete();
150+
return;
151+
}
152+
if (CollectionUtils.isNotEmpty(prompt.messages())) {
153+
emitter.emit(new AiMessage(prompt.messages().get(prompt.messages().size() - 1).text()));
154+
}
155+
for (int i = 0; i < 4; i++) {
156+
emitter.emit(new AiMessage(String.valueOf(i)));
157+
}
158+
emitter.complete();
159+
});
160+
}
161+
162+
private static ToolInfo buildToolInfo(String realName) {
163+
return ToolInfo.custom()
164+
.name("tool1")
165+
.description("desc")
166+
.parameters(new HashMap<>())
167+
.extensions(MapBuilder.<String, Object>get().put(AippConst.TOOL_REAL_NAME, realName).build())
168+
.build();
169+
}
170+
171+
private static ToolInfo buildMcpToolInfo(String realName, String baseUrl, String sseEndpoint) {
172+
return ToolInfo.custom()
173+
.name("tool1")
174+
.description("desc")
175+
.parameters(new HashMap<>())
176+
.extensions(MapBuilder.<String, Object>get()
177+
.put(AippConst.TOOL_REAL_NAME, realName)
178+
.put(AippConst.MCP_SERVER_KEY,
179+
MapBuilder.get().put(AippConst.MCP_SERVER_URL_KEY, baseUrl + sseEndpoint).build())
180+
.build())
181+
.build();
182+
}
183+
}

app-builder/jane/plugins/aipp-plugin/src/test/java/modelengine/fit/jober/aipp/fitable/LlmComponentTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ protected AiProcessFlow<Prompt, ChatMessage> buildFlow() {
166166
}
167167

168168
private AbstractAgent getWaterFlowAgent(ChatModel model) {
169-
return new WaterFlowAgent(this.syncToolCall, model);
169+
return new WaterFlowAgent(this.syncToolCall, model, this.mcpClientFactory);
170170
}
171171

172172
private ChatModel buildChatStreamModel(String exceptionMsg) {

0 commit comments

Comments
 (0)