1+ package modelengine .fit .jober .aipp .fel ;
2+
3+ import modelengine .fel .core .chat .ChatMessage ;
4+ import modelengine .fel .core .chat .ChatModel ;
5+ import modelengine .fel .core .chat .ChatOption ;
6+ import modelengine .fel .core .chat .Prompt ;
7+ import modelengine .fel .core .chat .support .AiMessage ;
8+ import modelengine .fel .core .chat .support .ChatMessages ;
9+ import modelengine .fel .core .chat .support .HumanMessage ;
10+ import modelengine .fel .core .tool .ToolCall ;
11+ import modelengine .fel .core .tool .ToolInfo ;
12+ import modelengine .fel .engine .flows .AiProcessFlow ;
13+ import modelengine .fel .tool .mcp .client .McpClient ;
14+ import modelengine .fel .tool .mcp .client .McpClientFactory ;
15+ import modelengine .fit .jade .tool .SyncToolCall ;
16+ import modelengine .fit .jober .aipp .constants .AippConst ;
17+ import modelengine .fitframework .flowable .Choir ;
18+ import modelengine .fitframework .util .MapBuilder ;
19+
20+ import org .apache .commons .collections .CollectionUtils ;
21+ import org .jetbrains .annotations .NotNull ;
22+ import org .junit .jupiter .api .Assertions ;
23+ import org .junit .jupiter .api .Test ;
24+ import org .junit .jupiter .api .extension .ExtendWith ;
25+ import org .mockito .Mock ;
26+ import org .mockito .junit .jupiter .MockitoExtension ;
27+
28+ import java .util .Arrays ;
29+ import java .util .Collections ;
30+ import java .util .HashMap ;
31+ import java .util .List ;
32+ import java .util .Map ;
33+ import java .util .concurrent .atomic .AtomicInteger ;
34+
35+ import static org .junit .jupiter .api .Assertions .*;
36+ import static org .mockito .ArgumentMatchers .any ;
37+ import static org .mockito .Mockito .doAnswer ;
38+ import static org .mockito .Mockito .mock ;
39+ import static org .mockito .Mockito .times ;
40+ import static org .mockito .Mockito .verify ;
41+ import static org .mockito .Mockito .when ;
42+
43+ @ ExtendWith (MockitoExtension .class )
44+ class WaterFlowAgentTest {
45+ @ Mock
46+ private SyncToolCall syncToolCall ;
47+
48+ @ Mock
49+ private ChatModel chatModel ;
50+
51+ @ Mock
52+ private McpClientFactory mcpClientFactory ;
53+
54+
55+ @ Test
56+ void shouldGetResultWhenRunFlowGivenNoToolCall () {
57+ WaterFlowAgent waterFlowAgent = new WaterFlowAgent (this .syncToolCall , this .chatModel , this .mcpClientFactory );
58+
59+ String expectResult = "0123" ;
60+ doAnswer (invocation -> Choir .create (emitter -> {
61+ for (int i = 0 ; i < 4 ; i ++) {
62+ emitter .emit (new AiMessage (String .valueOf (i )));
63+ }
64+ emitter .complete ();
65+ })).when (chatModel ).generate (any (), any ());
66+
67+ AiProcessFlow <Prompt , ChatMessage > flow = waterFlowAgent .buildFlow ();
68+ ChatMessage result = flow .converse ()
69+ .bind (ChatOption .custom ().build ())
70+ .offer (ChatMessages .from (new HumanMessage ("hi" ))).await ();
71+
72+ assertEquals (expectResult , result .text ());
73+ }
74+
75+ @ Test
76+ void shouldGetResultWhenRunFlowGivenStoreToolCall () {
77+ WaterFlowAgent waterFlowAgent = new WaterFlowAgent (this .syncToolCall , this .chatModel , this .mcpClientFactory );
78+
79+ String expectResult = "tool result:0123" ;
80+ String realName = "realName" ;
81+ ToolInfo toolInfo = buildToolInfo (realName );
82+ ToolCall toolCall = ToolCall .custom ().id ("id" ).name (toolInfo .name ()).arguments ("{}" ).build ();
83+ List <ToolCall > toolCalls = Collections .singletonList (toolCall );
84+ AtomicInteger step = new AtomicInteger ();
85+ doAnswer (invocation -> {
86+ Prompt prompt = invocation .getArgument (0 );
87+ return mockGenerateResult (step , toolCalls , prompt );
88+ }).when (chatModel ).generate (any (), any ());
89+ Map <String , Object > toolContext = MapBuilder .<String , Object >get ().put ("key" , "value" ).build ();
90+ when (this .syncToolCall .call (realName , toolCall .arguments (), toolContext )).thenReturn ("tool result:" );
91+
92+ AiProcessFlow <Prompt , ChatMessage > flow = waterFlowAgent .buildFlow ();
93+ ChatMessage result = flow .converse ()
94+ .bind (ChatOption .custom ().build ())
95+ .bind (AippConst .TOOL_CONTEXT_KEY , toolContext )
96+ .bind (AippConst .TOOLS_KEY , Collections .singletonList (toolInfo ))
97+ .offer (ChatMessages .from (new HumanMessage ("hi" ))).await ();
98+
99+ verify (this .mcpClientFactory , times (0 )).create (any (), any ());
100+ assertEquals (expectResult , result .text ());
101+ }
102+
103+ @ Test
104+ void shouldGetResultWhenRunFlowGivenMcpToolCall () {
105+ WaterFlowAgent waterFlowAgent = new WaterFlowAgent (this .syncToolCall , this .chatModel , this .mcpClientFactory );
106+
107+ String expectResult = "\" tool result:\" 0123" ;
108+ String realName = "realName" ;
109+ String baseUrl = "http://localhost" ;
110+ String sseEndpoint = "/sse" ;
111+ ToolInfo toolInfo = buildMcpToolInfo (realName , baseUrl , sseEndpoint );
112+ ToolCall toolCall = ToolCall .custom ().id ("id" ).name (toolInfo .name ()).arguments ("{}" ).build ();
113+ List <ToolCall > toolCalls = Collections .singletonList (toolCall );
114+ AtomicInteger step = new AtomicInteger ();
115+ doAnswer (invocation -> {
116+ Prompt prompt = invocation .getArgument (0 );
117+ return mockGenerateResult (step , toolCalls , prompt );
118+ }).when (chatModel ).generate (any (), any ());
119+ Map <String , Object > toolContext = MapBuilder .<String , Object >get ().put ("key" , "value" ).build ();
120+ McpClient mcpClient = mock (McpClient .class );
121+ when (this .mcpClientFactory .create (baseUrl , sseEndpoint )).thenReturn (mcpClient );
122+ when (mcpClient .callTool (realName , new HashMap <>())).thenReturn ("tool result:" );
123+
124+ AiProcessFlow <Prompt , ChatMessage > flow = waterFlowAgent .buildFlow ();
125+ ChatMessage result = flow .converse ()
126+ .bind (ChatOption .custom ().build ())
127+ .bind (AippConst .TOOL_CONTEXT_KEY , toolContext )
128+ .bind (AippConst .TOOLS_KEY , Collections .singletonList (toolInfo ))
129+ .offer (ChatMessages .from (new HumanMessage ("hi" ))).await ();
130+
131+ verify (this .syncToolCall , times (0 )).call (any (), any (), any ());
132+ assertEquals (expectResult , result .text ());
133+ }
134+
135+ private static Choir <Object > mockGenerateResult (AtomicInteger step , List <ToolCall > toolCalls , Prompt prompt ) {
136+ return Choir .create (emitter -> {
137+
138+ if (step .getAndIncrement () == 0 ) {
139+ emitter .emit (new AiMessage ("tool_data" , toolCalls ));
140+ emitter .complete ();
141+ return ;
142+ }
143+ if (CollectionUtils .isNotEmpty (prompt .messages ())) {
144+ emitter .emit (new AiMessage (prompt .messages ().get (prompt .messages ().size () - 1 ).text ()));
145+ }
146+ for (int i = 0 ; i < 4 ; i ++) {
147+ emitter .emit (new AiMessage (String .valueOf (i )));
148+ }
149+ emitter .complete ();
150+ });
151+ }
152+
153+ private static ToolInfo buildToolInfo (String realName ) {
154+ return ToolInfo .custom ()
155+ .name ("tool1" )
156+ .description ("desc" )
157+ .parameters (new HashMap <>())
158+ .extensions (MapBuilder .<String , Object >get ().put (AippConst .TOOL_REAL_NAME , realName ).build ())
159+ .build ();
160+ }
161+
162+ private static ToolInfo buildMcpToolInfo (String realName , String baseUrl , String sseEndpoint ) {
163+ return ToolInfo .custom ()
164+ .name ("tool1" )
165+ .description ("desc" )
166+ .parameters (new HashMap <>())
167+ .extensions (MapBuilder .<String , Object >get ()
168+ .put (AippConst .TOOL_REAL_NAME , realName )
169+ .put (AippConst .MCP_SERVER_KEY ,
170+ MapBuilder .get ().put (AippConst .MCP_SERVER_URL_KEY , baseUrl + sseEndpoint ).build ())
171+ .build ())
172+ .build ();
173+ }
174+ }
0 commit comments