1+ /*---------------------------------------------------------------------------------------------
2+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3+ * This file is a part of the ModelEngine Project.
4+ * Licensed under the MIT License. See License.txt in the project root for license information.
5+ *--------------------------------------------------------------------------------------------*/
6+
7+ package modelengine .fit .jober .aipp .fel ;
8+
9+ import modelengine .fel .core .chat .ChatMessage ;
10+ import modelengine .fel .core .chat .ChatModel ;
11+ import modelengine .fel .core .chat .ChatOption ;
12+ import modelengine .fel .core .chat .Prompt ;
13+ import modelengine .fel .core .chat .support .AiMessage ;
14+ import modelengine .fel .core .chat .support .ChatMessages ;
15+ import modelengine .fel .core .chat .support .HumanMessage ;
16+ import modelengine .fel .core .tool .ToolCall ;
17+ import modelengine .fel .core .tool .ToolInfo ;
18+ import modelengine .fel .engine .flows .AiProcessFlow ;
19+ import modelengine .fel .tool .mcp .client .McpClient ;
20+ import modelengine .fel .tool .mcp .client .McpClientFactory ;
21+ import modelengine .fit .jade .tool .SyncToolCall ;
22+ import modelengine .fit .jober .aipp .constants .AippConst ;
23+ import modelengine .fitframework .flowable .Choir ;
24+ import modelengine .fitframework .util .MapBuilder ;
25+
26+ import org .apache .commons .collections .CollectionUtils ;
27+ import org .junit .jupiter .api .Test ;
28+ import org .junit .jupiter .api .extension .ExtendWith ;
29+ import org .mockito .Mock ;
30+ import org .mockito .junit .jupiter .MockitoExtension ;
31+
32+ import java .util .Collections ;
33+ import java .util .HashMap ;
34+ import java .util .List ;
35+ import java .util .Map ;
36+ import java .util .concurrent .atomic .AtomicReference ;
37+
38+ import static org .junit .jupiter .api .Assertions .*;
39+ import static org .mockito .ArgumentMatchers .any ;
40+ import static org .mockito .Mockito .doAnswer ;
41+ import static org .mockito .Mockito .mock ;
42+ import static org .mockito .Mockito .times ;
43+ import static org .mockito .Mockito .verify ;
44+ import static org .mockito .Mockito .when ;
45+
46+ /**
47+ * {@link WaterFlowAgent} 的测试。
48+ */
49+ @ ExtendWith (MockitoExtension .class )
50+ class WaterFlowAgentTest {
51+ private static final String TEXT_STEP = "textStep" ;
52+ private static final String TOOL_CALL_STEP = "toolCallStep" ;
53+
54+ @ Mock
55+ private SyncToolCall syncToolCall ;
56+ @ Mock
57+ private ChatModel chatModel ;
58+ @ Mock
59+ private McpClientFactory mcpClientFactory ;
60+
61+ @ Test
62+ void shouldGetResultWhenRunFlowGivenNoToolCall () {
63+ WaterFlowAgent waterFlowAgent = new WaterFlowAgent (this .syncToolCall , this .chatModel , this .mcpClientFactory );
64+
65+ String expectResult = "0123" ;
66+ doAnswer (invocation -> Choir .create (emitter -> {
67+ for (int i = 0 ; i < 4 ; i ++) {
68+ emitter .emit (new AiMessage (String .valueOf (i )));
69+ }
70+ emitter .complete ();
71+ })).when (chatModel ).generate (any (), any ());
72+
73+ AiProcessFlow <Prompt , ChatMessage > flow = waterFlowAgent .buildFlow ();
74+ ChatMessage result = flow .converse ()
75+ .bind (ChatOption .custom ().build ())
76+ .offer (ChatMessages .from (new HumanMessage ("hi" ))).await ();
77+
78+ assertEquals (expectResult , result .text ());
79+ }
80+
81+ @ Test
82+ void shouldGetResultWhenRunFlowGivenStoreToolCall () {
83+ WaterFlowAgent waterFlowAgent = new WaterFlowAgent (this .syncToolCall , this .chatModel , this .mcpClientFactory );
84+
85+ String expectResult = "tool result:0123" ;
86+ String realName = "realName" ;
87+ ToolInfo toolInfo = buildToolInfo (realName );
88+ ToolCall toolCall = ToolCall .custom ().id ("id" ).name (toolInfo .name ()).arguments ("{}" ).build ();
89+ List <ToolCall > toolCalls = Collections .singletonList (toolCall );
90+ AtomicReference <String > step = new AtomicReference <>(TOOL_CALL_STEP );
91+ doAnswer (invocation -> {
92+ Prompt prompt = invocation .getArgument (0 );
93+ Choir <Object > result = mockGenerateResult (step .get (), toolCalls , prompt );
94+ step .set (TEXT_STEP );
95+ return result ;
96+ }).when (chatModel ).generate (any (), any ());
97+ Map <String , Object > toolContext = MapBuilder .<String , Object >get ().put ("key" , "value" ).build ();
98+ when (this .syncToolCall .call (realName , toolCall .arguments (), toolContext )).thenReturn ("tool result:" );
99+
100+ AiProcessFlow <Prompt , ChatMessage > flow = waterFlowAgent .buildFlow ();
101+ ChatMessage result = flow .converse ()
102+ .bind (ChatOption .custom ().build ())
103+ .bind (AippConst .TOOL_CONTEXT_KEY , toolContext )
104+ .bind (AippConst .TOOLS_KEY , Collections .singletonList (toolInfo ))
105+ .offer (ChatMessages .from (new HumanMessage ("hi" ))).await ();
106+
107+ verify (this .mcpClientFactory , times (0 )).create (any (), any ());
108+ assertEquals (expectResult , result .text ());
109+ }
110+
111+ @ Test
112+ void shouldGetResultWhenRunFlowGivenMcpToolCall () {
113+ WaterFlowAgent waterFlowAgent = new WaterFlowAgent (this .syncToolCall , this .chatModel , this .mcpClientFactory );
114+
115+ String expectResult = "\" tool result:\" 0123" ;
116+ String realName = "realName" ;
117+ String baseUrl = "http://localhost" ;
118+ String sseEndpoint = "/sse" ;
119+ ToolInfo toolInfo = buildMcpToolInfo (realName , baseUrl , sseEndpoint );
120+ ToolCall toolCall = ToolCall .custom ().id ("id" ).name (toolInfo .name ()).arguments ("{}" ).build ();
121+ List <ToolCall > toolCalls = Collections .singletonList (toolCall );
122+ AtomicReference <String > step = new AtomicReference <>(TOOL_CALL_STEP );
123+ doAnswer (invocation -> {
124+ Prompt prompt = invocation .getArgument (0 );
125+ Choir <Object > result = mockGenerateResult (step .get (), toolCalls , prompt );
126+ step .set (TEXT_STEP );
127+ return result ;
128+ }).when (chatModel ).generate (any (), any ());
129+ Map <String , Object > toolContext = MapBuilder .<String , Object >get ().put ("key" , "value" ).build ();
130+ McpClient mcpClient = mock (McpClient .class );
131+ when (this .mcpClientFactory .create (baseUrl , sseEndpoint )).thenReturn (mcpClient );
132+ when (mcpClient .callTool (realName , new HashMap <>())).thenReturn ("tool result:" );
133+
134+ AiProcessFlow <Prompt , ChatMessage > flow = waterFlowAgent .buildFlow ();
135+ ChatMessage result = flow .converse ()
136+ .bind (ChatOption .custom ().build ())
137+ .bind (AippConst .TOOL_CONTEXT_KEY , toolContext )
138+ .bind (AippConst .TOOLS_KEY , Collections .singletonList (toolInfo ))
139+ .offer (ChatMessages .from (new HumanMessage ("hi" ))).await ();
140+
141+ verify (this .syncToolCall , times (0 )).call (any (), any (), any ());
142+ assertEquals (expectResult , result .text ());
143+ }
144+
145+ private static Choir <Object > mockGenerateResult (String step , List <ToolCall > toolCalls , Prompt prompt ) {
146+ return Choir .create (emitter -> {
147+ if (TOOL_CALL_STEP .equals (step )) {
148+ emitter .emit (new AiMessage ("tool_data" , toolCalls ));
149+ emitter .complete ();
150+ return ;
151+ }
152+ if (CollectionUtils .isNotEmpty (prompt .messages ())) {
153+ emitter .emit (new AiMessage (prompt .messages ().get (prompt .messages ().size () - 1 ).text ()));
154+ }
155+ for (int i = 0 ; i < 4 ; i ++) {
156+ emitter .emit (new AiMessage (String .valueOf (i )));
157+ }
158+ emitter .complete ();
159+ });
160+ }
161+
162+ private static ToolInfo buildToolInfo (String realName ) {
163+ return ToolInfo .custom ()
164+ .name ("tool1" )
165+ .description ("desc" )
166+ .parameters (new HashMap <>())
167+ .extensions (MapBuilder .<String , Object >get ().put (AippConst .TOOL_REAL_NAME , realName ).build ())
168+ .build ();
169+ }
170+
171+ private static ToolInfo buildMcpToolInfo (String realName , String baseUrl , String sseEndpoint ) {
172+ return ToolInfo .custom ()
173+ .name ("tool1" )
174+ .description ("desc" )
175+ .parameters (new HashMap <>())
176+ .extensions (MapBuilder .<String , Object >get ()
177+ .put (AippConst .TOOL_REAL_NAME , realName )
178+ .put (AippConst .MCP_SERVER_KEY ,
179+ MapBuilder .get ().put (AippConst .MCP_SERVER_URL_KEY , baseUrl + sseEndpoint ).build ())
180+ .build ())
181+ .build ();
182+ }
183+ }
0 commit comments