Skip to content

Commit e56218e

Browse files
committed
fix(agent): preserve cumulative stream history after tool calls
1 parent 8b2f451 commit e56218e

2 files changed

Lines changed: 241 additions & 4 deletions

File tree

agentscope-core/src/main/java/io/agentscope/core/agent/StreamingHook.java

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
import io.agentscope.core.message.ContentBlock;
2727
import io.agentscope.core.message.Msg;
2828
import io.agentscope.core.message.MsgRole;
29+
import io.agentscope.core.message.TextBlock;
30+
import io.agentscope.core.message.ThinkingBlock;
2931
import io.agentscope.core.message.ToolResultBlock;
32+
import io.agentscope.core.message.ToolUseBlock;
3033
import java.util.ArrayList;
3134
import java.util.HashMap;
3235
import java.util.List;
@@ -48,6 +51,10 @@ class StreamingHook implements Hook {
4851
// Track previous content for incremental mode
4952
private final Map<String, List<ContentBlock>> previousContent = new HashMap<>();
5053

54+
// Track cumulative reasoning content across ReAct reasoning/acting boundaries.
55+
private final List<ContentBlock> cumulativeReasoningContent = new ArrayList<>();
56+
private final Map<String, Integer> cumulativeReasoningPositions = new HashMap<>();
57+
5158
/**
5259
* Creates a new streaming hook.
5360
*
@@ -67,7 +74,11 @@ public <T extends HookEvent> Mono<T> onEvent(T event) {
6774
// This is the last/complete message
6875
if (options.shouldStream(EventType.REASONING)
6976
&& options.shouldIncludeReasoningEmission(false)) {
70-
emitEvent(EventType.REASONING, e.getReasoningMessage(), true);
77+
Msg msgToEmit =
78+
options.isIncremental()
79+
? e.getReasoningMessage()
80+
: accumulateReasoning(e.getReasoningMessage());
81+
emitEvent(EventType.REASONING, msgToEmit, true);
7182
}
7283
return Mono.just(event);
7384
} else if (event instanceof ReasoningChunkEvent) {
@@ -77,7 +88,9 @@ public <T extends HookEvent> Mono<T> onEvent(T event) {
7788
&& options.shouldIncludeReasoningEmission(true)) {
7889
// Use incremental or accumulated based on StreamOptions
7990
Msg msgToEmit =
80-
options.isIncremental() ? e.getIncrementalChunk() : e.getAccumulated();
91+
options.isIncremental()
92+
? e.getIncrementalChunk()
93+
: accumulateReasoning(e.getAccumulated());
8194
emitEvent(EventType.REASONING, msgToEmit, false);
8295
}
8396
return Mono.just(event);
@@ -136,6 +149,47 @@ private Msg createToolMessage(ToolResultBlock toolResultBlock) {
136149
.build();
137150
}
138151

152+
private Msg accumulateReasoning(Msg reasoningMsg) {
153+
for (int index = 0; index < reasoningMsg.getContent().size(); index++) {
154+
ContentBlock block = reasoningMsg.getContent().get(index);
155+
String key = reasoningContentKey(reasoningMsg.getId(), block, index);
156+
Integer position = cumulativeReasoningPositions.get(key);
157+
158+
if (position == null) {
159+
cumulativeReasoningPositions.put(key, cumulativeReasoningContent.size());
160+
cumulativeReasoningContent.add(block);
161+
} else {
162+
cumulativeReasoningContent.set(position, block);
163+
}
164+
}
165+
166+
return Msg.builder()
167+
.id(reasoningMsg.getId())
168+
.name(reasoningMsg.getName())
169+
.role(reasoningMsg.getRole())
170+
.content(new ArrayList<>(cumulativeReasoningContent))
171+
.metadata(new HashMap<>(reasoningMsg.getMetadata()))
172+
.timestamp(reasoningMsg.getTimestamp())
173+
.build();
174+
}
175+
176+
private String reasoningContentKey(String messageId, ContentBlock block, int index) {
177+
if (block instanceof ThinkingBlock) {
178+
return messageId + ":thinking";
179+
}
180+
if (block instanceof TextBlock) {
181+
return messageId + ":text";
182+
}
183+
if (block instanceof ToolUseBlock toolUseBlock) {
184+
String toolCallId = toolUseBlock.getId();
185+
if (toolCallId == null || toolCallId.isBlank()) {
186+
toolCallId = toolUseBlock.getName() + ":" + index;
187+
}
188+
return messageId + ":tool:" + toolCallId;
189+
}
190+
return messageId + ":" + block.getClass().getName() + ":" + index;
191+
}
192+
139193
/**
140194
* Emit an event to the sink.
141195
*
@@ -160,4 +214,4 @@ private void emitEvent(EventType type, Msg msg, boolean isLast) {
160214
previousContent.remove(msg.getId());
161215
}
162216
}
163-
}
217+
}

agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentTest.java

Lines changed: 184 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import io.agentscope.core.message.Msg;
3737
import io.agentscope.core.message.MsgRole;
3838
import io.agentscope.core.message.TextBlock;
39+
import io.agentscope.core.message.ThinkingBlock;
3940
import io.agentscope.core.message.ToolResultBlock;
4041
import io.agentscope.core.message.ToolUseBlock;
4142
import io.agentscope.core.model.ChatResponse;
@@ -807,6 +808,188 @@ public <T extends HookEvent> Mono<T> onEvent(T event) {
807808
"call_stream_1", accumulatedTub.getId(), "Accumulated tool call ID should match");
808809
}
809810

811+
@Test
812+
@DisplayName("Should keep cumulative reasoning chunks across tool calls")
813+
void testCumulativeReasoningStreamKeepsHistoryAfterToolCall() {
814+
MockModel toolModel = createTwoRoundStreamingModel();
815+
816+
agent =
817+
ReActAgent.builder()
818+
.name(TestConstants.TEST_REACT_AGENT_NAME)
819+
.sysPrompt(TestConstants.DEFAULT_SYS_PROMPT)
820+
.model(toolModel)
821+
.toolkit(mockToolkit)
822+
.memory(memory)
823+
.build();
824+
825+
StreamOptions options =
826+
StreamOptions.builder()
827+
.eventTypes(EventType.REASONING)
828+
.incremental(false)
829+
.includeReasoningResult(false)
830+
.build();
831+
832+
List<Event> events =
833+
agent.stream(
834+
TestUtils.createUserMessage("User", "Use a tool, then continue."),
835+
options)
836+
.collectList()
837+
.block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS));
838+
839+
assertNotNull(events, "Streaming events should not be null");
840+
841+
List<Msg> reasoningChunks =
842+
events.stream()
843+
.filter(event -> event.getType() == EventType.REASONING)
844+
.filter(event -> !event.isLast())
845+
.map(Event::getMessage)
846+
.toList();
847+
848+
assertEquals(5, reasoningChunks.size(), "Should emit every streamed reasoning chunk");
849+
850+
Msg finalCumulativeChunk = reasoningChunks.get(reasoningChunks.size() - 1);
851+
List<ContentBlock> cumulativeContent = finalCumulativeChunk.getContent();
852+
853+
assertTrue(
854+
cumulativeContent.stream()
855+
.filter(ThinkingBlock.class::isInstance)
856+
.map(ThinkingBlock.class::cast)
857+
.anyMatch(block -> block.getThinking().contains("think before tool.")),
858+
"Cumulative mode should keep pre-tool thinking content");
859+
assertTrue(
860+
cumulativeContent.stream()
861+
.filter(TextBlock.class::isInstance)
862+
.map(TextBlock.class::cast)
863+
.anyMatch(block -> block.getText().contains("text before tool.")),
864+
"Cumulative mode should keep pre-tool text content");
865+
assertTrue(
866+
cumulativeContent.stream()
867+
.filter(ToolUseBlock.class::isInstance)
868+
.map(ToolUseBlock.class::cast)
869+
.anyMatch(block -> "call_stream_reset_1".equals(block.getId())),
870+
"Cumulative mode should keep the tool call that split reasoning rounds");
871+
assertTrue(
872+
cumulativeContent.stream()
873+
.filter(ThinkingBlock.class::isInstance)
874+
.map(ThinkingBlock.class::cast)
875+
.anyMatch(block -> block.getThinking().contains("think after tool.")),
876+
"Cumulative mode should include post-tool thinking content");
877+
assertTrue(
878+
cumulativeContent.stream()
879+
.filter(TextBlock.class::isInstance)
880+
.map(TextBlock.class::cast)
881+
.anyMatch(block -> block.getText().contains("text after tool.")),
882+
"Cumulative mode should include post-tool text content");
883+
}
884+
885+
@Test
886+
@DisplayName("Should keep incremental reasoning chunks as deltas after tool calls")
887+
void testIncrementalReasoningStreamStillEmitsDeltasAfterToolCall() {
888+
MockModel toolModel = createTwoRoundStreamingModel();
889+
890+
agent =
891+
ReActAgent.builder()
892+
.name(TestConstants.TEST_REACT_AGENT_NAME)
893+
.sysPrompt(TestConstants.DEFAULT_SYS_PROMPT)
894+
.model(toolModel)
895+
.toolkit(mockToolkit)
896+
.memory(memory)
897+
.build();
898+
899+
StreamOptions options =
900+
StreamOptions.builder()
901+
.eventTypes(EventType.REASONING)
902+
.incremental(true)
903+
.includeReasoningResult(false)
904+
.build();
905+
906+
List<Event> events =
907+
agent.stream(
908+
TestUtils.createUserMessage("User", "Use a tool, then continue."),
909+
options)
910+
.collectList()
911+
.block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS));
912+
913+
assertNotNull(events, "Streaming events should not be null");
914+
915+
List<Msg> reasoningChunks =
916+
events.stream()
917+
.filter(event -> event.getType() == EventType.REASONING)
918+
.filter(event -> !event.isLast())
919+
.map(Event::getMessage)
920+
.toList();
921+
922+
assertEquals(5, reasoningChunks.size(), "Should emit every streamed reasoning chunk");
923+
924+
Msg finalIncrementalChunk = reasoningChunks.get(reasoningChunks.size() - 1);
925+
List<ContentBlock> incrementalContent = finalIncrementalChunk.getContent();
926+
927+
assertEquals(1, incrementalContent.size(), "Incremental mode should emit only the delta");
928+
TextBlock textBlock = assertInstanceOf(TextBlock.class, incrementalContent.get(0));
929+
assertEquals("text after tool.", textBlock.getText());
930+
}
931+
932+
private static MockModel createTwoRoundStreamingModel() {
933+
final int[] callCount = {0};
934+
return new MockModel(
935+
messages -> {
936+
int currentCall = callCount[0]++;
937+
if (currentCall == 0) {
938+
return List.of(
939+
ChatResponse.builder()
940+
.id("reasoning-round-1")
941+
.content(
942+
List.of(
943+
ThinkingBlock.builder()
944+
.thinking("think before tool. ")
945+
.build()))
946+
.usage(new ChatUsage(10, 20, 30))
947+
.build(),
948+
ChatResponse.builder()
949+
.id("reasoning-round-1")
950+
.content(
951+
List.of(
952+
TextBlock.builder()
953+
.text("text before tool. ")
954+
.build()))
955+
.usage(new ChatUsage(10, 20, 30))
956+
.build(),
957+
ChatResponse.builder()
958+
.id("reasoning-round-1")
959+
.content(
960+
List.of(
961+
ToolUseBlock.builder()
962+
.id("call_stream_reset_1")
963+
.name(TestConstants.TEST_TOOL_NAME)
964+
.input(Map.of())
965+
.content("{}")
966+
.build()))
967+
.usage(new ChatUsage(10, 20, 30))
968+
.build());
969+
}
970+
971+
return List.of(
972+
ChatResponse.builder()
973+
.id("reasoning-round-2")
974+
.content(
975+
List.of(
976+
ThinkingBlock.builder()
977+
.thinking("think after tool. ")
978+
.build()))
979+
.usage(new ChatUsage(10, 20, 30))
980+
.build(),
981+
ChatResponse.builder()
982+
.id("reasoning-round-2")
983+
.content(
984+
List.of(
985+
TextBlock.builder()
986+
.text("text after tool.")
987+
.build()))
988+
.usage(new ChatUsage(10, 20, 30))
989+
.build());
990+
});
991+
}
992+
810993
@Test
811994
@DisplayName("Should emit ReasoningChunkEvent for multiple parallel tool calls")
812995
void testStreamingMultipleToolCallsChunkEvents() {
@@ -1051,4 +1234,4 @@ private static ChatResponse createToolCallResponseHelper(
10511234
.usage(new ChatUsage(8, 15, 23))
10521235
.build();
10531236
}
1054-
}
1237+
}

0 commit comments

Comments
 (0)