Skip to content

Commit 0b16489

Browse files
committed
fix(tool): support Flux-returning tool methods
1 parent de01c66 commit 0b16489

4 files changed

Lines changed: 256 additions & 2 deletions

File tree

agentscope-core/src/main/java/io/agentscope/core/tool/ToolMethodInvoker.java

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
import java.lang.reflect.Parameter;
2626
import java.lang.reflect.ParameterizedType;
2727
import java.lang.reflect.Type;
28+
import java.util.List;
2829
import java.util.Map;
2930
import java.util.concurrent.CompletableFuture;
31+
import reactor.core.publisher.Flux;
3032
import reactor.core.publisher.Mono;
3133

3234
/**
@@ -105,6 +107,9 @@ r, extractGenericType(method)))
105107
.onErrorResume(this::handleError))
106108
.onErrorResume(this::handleError);
107109

110+
} else if (returnType == Flux.class) {
111+
return invokeFlux(toolObject, method, input, agent, context, emitter, converter);
112+
108113
} else {
109114
// Sync method: wrap in Mono.fromCallable
110115
return Mono.fromCallable(
@@ -119,6 +124,101 @@ r, extractGenericType(method)))
119124
}
120125
}
121126

127+
private Mono<ToolResultBlock> invokeFlux(
128+
Object toolObject,
129+
Method method,
130+
Map<String, Object> input,
131+
Agent agent,
132+
ToolExecutionContext context,
133+
ToolEmitter emitter,
134+
ToolResultConverter converter) {
135+
Type itemType = extractGenericType(method);
136+
137+
return Mono.fromCallable(
138+
() -> {
139+
method.setAccessible(true);
140+
Object[] args =
141+
convertParameters(method, input, agent, context, emitter);
142+
@SuppressWarnings("unchecked")
143+
Flux<Object> flux = (Flux<Object>) method.invoke(toolObject, args);
144+
return flux != null ? flux : Flux.empty();
145+
})
146+
.flatMap(
147+
flux ->
148+
flux.doOnNext(
149+
item ->
150+
emitFluxChunk(
151+
emitter, converter, item, itemType))
152+
.collectList()
153+
.map(
154+
items ->
155+
converter.convert(
156+
aggregateFluxItems(items, itemType),
157+
resolveFluxAggregateType(
158+
items, itemType)))
159+
.onErrorResume(this::handleError))
160+
.onErrorResume(this::handleError);
161+
}
162+
163+
private void emitFluxChunk(
164+
ToolEmitter emitter, ToolResultConverter converter, Object item, Type itemType) {
165+
if (item == null) {
166+
return;
167+
}
168+
emitter.emit(toStreamingChunk(item, itemType, converter));
169+
}
170+
171+
private ToolResultBlock toStreamingChunk(
172+
Object item, Type itemType, ToolResultConverter converter) {
173+
if (item instanceof ToolResultBlock) {
174+
return (ToolResultBlock) item;
175+
}
176+
if (item instanceof CharSequence
177+
|| item instanceof Number
178+
|| item instanceof Boolean
179+
|| item instanceof Character) {
180+
return ToolResultBlock.text(String.valueOf(item));
181+
}
182+
return converter.convert(item, itemType);
183+
}
184+
185+
private Object aggregateFluxItems(List<Object> items, Type itemType) {
186+
if (shouldConcatenateFluxItems(items, itemType)) {
187+
StringBuilder aggregated = new StringBuilder();
188+
for (Object item : items) {
189+
if (item != null) {
190+
aggregated.append(item);
191+
}
192+
}
193+
return aggregated.toString();
194+
}
195+
if (items.isEmpty()) {
196+
return null;
197+
}
198+
if (items.size() == 1) {
199+
return items.get(0);
200+
}
201+
return items;
202+
}
203+
204+
private Type resolveFluxAggregateType(List<Object> items, Type itemType) {
205+
if (shouldConcatenateFluxItems(items, itemType)) {
206+
return String.class;
207+
}
208+
if (items.size() == 1) {
209+
return itemType;
210+
}
211+
return List.class;
212+
}
213+
214+
private boolean shouldConcatenateFluxItems(List<Object> items, Type itemType) {
215+
if (itemType == String.class || itemType == CharSequence.class) {
216+
return true;
217+
}
218+
return !items.isEmpty()
219+
&& items.stream().allMatch(item -> item == null || item instanceof CharSequence);
220+
}
221+
122222
/**
123223
* Convert input parameters to method arguments with automatic injection support.
124224
*
@@ -363,7 +463,7 @@ private ToolResultBlock handleInvocationError(Throwable e) {
363463
}
364464

365465
/**
366-
* Extract generic type from method return type (for CompletableFuture<T> or Mono<T>).
466+
* Extract generic type from method return type (for CompletableFuture<T>, Mono<T>, or Flux<T>).
367467
*
368468
* @param method the method
369469
* @return the generic type, or null if not found

agentscope-core/src/test/java/io/agentscope/core/tool/AsyncToolTest.java

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import io.agentscope.core.tool.test.SampleTools;
2727
import io.agentscope.core.util.JsonUtils;
2828
import java.time.Duration;
29+
import java.util.ArrayList;
2930
import java.util.List;
3031
import java.util.Map;
3132
import org.junit.jupiter.api.BeforeEach;
@@ -34,7 +35,7 @@
3435
import org.junit.jupiter.api.Test;
3536

3637
/**
37-
* Tests for async tool execution with CompletableFuture and Mono return types.
38+
* Tests for async tool execution with CompletableFuture, Mono, and Flux return types.
3839
*/
3940
@Tag("unit")
4041
@DisplayName("Async Tool Tests")
@@ -90,6 +91,56 @@ void shouldExecuteMonoAsyncTool() {
9091
assertEquals("\"HelloWorld\"", extractFirstText(response));
9192
}
9293

94+
@Test
95+
@DisplayName("Should execute Flux async tool")
96+
void shouldExecuteFluxAsyncTool() {
97+
Map<String, Object> input = Map.of("str1", "Hello", "str2", "World");
98+
ToolUseBlock toolCall =
99+
ToolUseBlock.builder()
100+
.id("call-async-flux")
101+
.name("async_flux_concat")
102+
.input(input)
103+
.content(JsonUtils.getJsonCodec().toJson(input))
104+
.build();
105+
106+
ToolResultBlock response =
107+
toolkit.callTool(ToolCallParam.builder().toolUseBlock(toolCall).build())
108+
.block(TIMEOUT);
109+
110+
assertNotNull(response, "Response should not be null");
111+
assertEquals("\"HelloWorld\"", extractFirstText(response));
112+
}
113+
114+
@Test
115+
@DisplayName("Should emit Flux chunks while aggregating final tool result")
116+
void shouldEmitFluxChunksWhileAggregatingFinalToolResult() {
117+
List<String> chunkToolIds = new ArrayList<>();
118+
List<String> chunkTexts = new ArrayList<>();
119+
toolkit.setChunkCallback(
120+
(toolUse, chunk) -> {
121+
chunkToolIds.add(toolUse.getId());
122+
chunkTexts.add(extractFirstText(chunk));
123+
});
124+
125+
Map<String, Object> input = Map.of("str1", "Alpha", "str2", "Beta");
126+
ToolUseBlock toolCall =
127+
ToolUseBlock.builder()
128+
.id("call-async-flux-chunk")
129+
.name("async_flux_concat")
130+
.input(input)
131+
.content(JsonUtils.getJsonCodec().toJson(input))
132+
.build();
133+
134+
ToolResultBlock response =
135+
toolkit.callTool(ToolCallParam.builder().toolUseBlock(toolCall).build())
136+
.block(TIMEOUT);
137+
138+
assertNotNull(response, "Response should not be null");
139+
assertEquals(List.of("call-async-flux-chunk", "call-async-flux-chunk"), chunkToolIds);
140+
assertEquals(List.of("Alpha", "Beta"), chunkTexts);
141+
assertEquals("\"AlphaBeta\"", extractFirstText(response));
142+
}
143+
93144
@Test
94145
@DisplayName("Should execute async tool with delay")
95146
void shouldExecuteAsyncToolWithDelay() {

agentscope-core/src/test/java/io/agentscope/core/tool/ToolMethodInvokerTest.java

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.junit.jupiter.api.Assertions;
3131
import org.junit.jupiter.api.BeforeEach;
3232
import org.junit.jupiter.api.Test;
33+
import reactor.core.publisher.Flux;
3334
import reactor.core.publisher.Mono;
3435

3536
/**
@@ -205,6 +206,26 @@ public Mono<String> suspendToolMonoSync(
205206
@ToolParam(name = "reason", description = "reason") String reason) {
206207
throw new ToolSuspendException(reason);
207208
}
209+
210+
public Flux<String> fluxConcat(
211+
@ToolParam(name = "prefix", description = "prefix") String prefix,
212+
@ToolParam(name = "suffix", description = "suffix") String suffix) {
213+
return Flux.just(prefix, suffix);
214+
}
215+
216+
public Flux<Integer> fluxSingleNumber(
217+
@ToolParam(name = "value", description = "value") Integer value) {
218+
return Flux.just(value);
219+
}
220+
221+
public Flux<Integer> fluxNumbers(
222+
@ToolParam(name = "start", description = "start") Integer start) {
223+
return Flux.just(start, start + 1, start + 2);
224+
}
225+
226+
public Flux<String> emptyFluxString() {
227+
return Flux.empty();
228+
}
208229
}
209230

210231
/** Test POJO for generic type testing (Issue #677). */
@@ -867,6 +888,75 @@ void testGenericMap_WithCustomClassValue() throws Exception {
867888
}
868889

869890
/** Test nested generic types like List&lt;List&lt;Integer&gt;&gt;. */
891+
@Test
892+
void testFluxStringAggregationAndChunkEmission() throws Exception {
893+
TestTools tools = new TestTools();
894+
Method method = TestTools.class.getMethod("fluxConcat", String.class, String.class);
895+
896+
Map<String, Object> input = new HashMap<>();
897+
input.put("prefix", "Hello");
898+
input.put("suffix", "World");
899+
900+
List<String> emittedChunks = new ArrayList<>();
901+
ToolUseBlock toolUseBlock = new ToolUseBlock("flux-id", method.getName(), input);
902+
ToolCallParam param =
903+
ToolCallParam.builder()
904+
.toolUseBlock(toolUseBlock)
905+
.input(input)
906+
.emitter(chunk -> emittedChunks.add(ToolTestUtils.extractContent(chunk)))
907+
.build();
908+
909+
ToolResultBlock response =
910+
invoker.invokeAsync(tools, method, param, responseConverter).block();
911+
912+
Assertions.assertNotNull(response);
913+
Assertions.assertFalse(ToolTestUtils.isErrorResponse(response));
914+
Assertions.assertEquals("\"HelloWorld\"", ToolTestUtils.extractContent(response));
915+
Assertions.assertEquals(List.of("Hello", "World"), emittedChunks);
916+
}
917+
918+
@Test
919+
void testFluxSingleValueAggregation() throws Exception {
920+
TestTools tools = new TestTools();
921+
Method method = TestTools.class.getMethod("fluxSingleNumber", Integer.class);
922+
923+
Map<String, Object> input = new HashMap<>();
924+
input.put("value", 7);
925+
926+
ToolResultBlock response = invokeWithParam(tools, method, input);
927+
928+
Assertions.assertNotNull(response);
929+
Assertions.assertFalse(ToolTestUtils.isErrorResponse(response));
930+
Assertions.assertEquals("7", ToolTestUtils.extractContent(response));
931+
}
932+
933+
@Test
934+
void testFluxMultipleValuesAggregateToJsonArray() throws Exception {
935+
TestTools tools = new TestTools();
936+
Method method = TestTools.class.getMethod("fluxNumbers", Integer.class);
937+
938+
Map<String, Object> input = new HashMap<>();
939+
input.put("start", 3);
940+
941+
ToolResultBlock response = invokeWithParam(tools, method, input);
942+
943+
Assertions.assertNotNull(response);
944+
Assertions.assertFalse(ToolTestUtils.isErrorResponse(response));
945+
Assertions.assertEquals("[3,4,5]", ToolTestUtils.extractContent(response));
946+
}
947+
948+
@Test
949+
void testEmptyFluxStringAggregatesToEmptyString() throws Exception {
950+
TestTools tools = new TestTools();
951+
Method method = TestTools.class.getMethod("emptyFluxString");
952+
953+
ToolResultBlock response = invokeWithParam(tools, method, new HashMap<>());
954+
955+
Assertions.assertNotNull(response);
956+
Assertions.assertFalse(ToolTestUtils.isErrorResponse(response));
957+
Assertions.assertEquals("\"\"", ToolTestUtils.extractContent(response));
958+
}
959+
870960
@Test
871961
void testNestedGenericList() throws Exception {
872962
TestTools tools = new TestTools();

agentscope-core/src/test/java/io/agentscope/core/tool/test/SampleTools.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io.agentscope.core.tool.ToolParam;
2020
import java.time.Duration;
2121
import java.util.concurrent.CompletableFuture;
22+
import reactor.core.publisher.Flux;
2223
import reactor.core.publisher.Mono;
2324

2425
/**
@@ -132,6 +133,18 @@ public Mono<String> asyncConcat(
132133
return Mono.fromCallable(() -> str1 + str2);
133134
}
134135

136+
/**
137+
* Async tool using Flux that streams string chunks.
138+
*/
139+
@Tool(
140+
name = "async_flux_concat",
141+
description = "Asynchronously stream and concatenate two strings")
142+
public Flux<String> asyncFluxConcat(
143+
@ToolParam(name = "str1", description = "First string") String str1,
144+
@ToolParam(name = "str2", description = "Second string") String str2) {
145+
return Flux.just(str1, str2).delayElements(Duration.ofMillis(25));
146+
}
147+
135148
/**
136149
* Async tool using Mono that simulates delay.
137150
*/

0 commit comments

Comments
 (0)