Skip to content

Commit 5353a18

Browse files
committed
fix(hook): resolve message extraction conflict between memory and RAG hooks
When both StaticLongTermMemoryHook and GenericRAGHook are attached to the same agent, they interfered with each other's message extraction logic. Each hook searches for "the last user message" to use as its query, but since both hooks inject their results as USER role messages at the end of the message list, one hook could pick up the other hook's injected message instead of the original user input. Changes: - GenericRAGHook: Change injected message name from "user" to "retrieved_knowledge" for clear identification; skip messages with name "long_term_memory" when extracting user query - StaticLongTermMemoryHook: Skip messages with name "retrieved_knowledge" when extracting user query - Add tests to verify hook interference is resolved Closes #1403
1 parent 13a7167 commit 5353a18

4 files changed

Lines changed: 109 additions & 13 deletions

File tree

agentscope-core/src/main/java/io/agentscope/core/memory/StaticLongTermMemoryHook.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,14 +263,20 @@ private Mono<PostCallEvent> handlePostCall(PostCallEvent event) {
263263
*
264264
* <p>Scans the message list from end to start to find the most recent user message,
265265
* which is typically the current query that should be used for memory retrieval.
266+
* Skips messages injected by other hooks (e.g., GenericRAGHook with name "retrieved_knowledge").
266267
*
267268
* @param messages the message list
268-
* @return the last user message, or null if none found
269+
* @return the index of the last user message, or -1 if none found
269270
*/
270271
private int extractLastUserMessageIndex(List<Msg> messages) {
271272
for (int i = messages.size() - 1; i >= 0; i--) {
272273
Msg msg = messages.get(i);
273274
if (msg.getRole() == MsgRole.USER) {
275+
// Skip messages injected by other hooks (e.g., RAG knowledge)
276+
String name = msg.getName();
277+
if ("retrieved_knowledge".equals(name)) {
278+
continue;
279+
}
274280
return i;
275281
}
276282
}

agentscope-core/src/main/java/io/agentscope/core/rag/GenericRAGHook.java

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,9 @@ private Mono<PreCallEvent> handlePreCall(PreCallEvent event) {
145145
if (retrievedDocs == null || retrievedDocs.isEmpty()) {
146146
return Mono.just(event);
147147
}
148-
List<Msg> enhancedMessages = new ArrayList<>();
149-
// Build enhanced messages with knowledge context
150-
Msg enhancedMessage = createEnhancedMessages(retrievedDocs);
151-
enhancedMessages.addAll(inputMessages);
152-
enhancedMessages.add(enhancedMessage);
148+
List<Msg> enhancedMessages = new ArrayList<>(inputMessages);
149+
// Build enhanced message with knowledge context
150+
enhancedMessages.add(createEnhancedMessage(retrievedDocs));
153151
event.setInputMessages(enhancedMessages);
154152
return Mono.just(event);
155153
})
@@ -165,7 +163,8 @@ private Mono<PreCallEvent> handlePreCall(PreCallEvent event) {
165163
* Extracts query text from message list.
166164
*
167165
* <p>Finds the last user message as the query source (not just the last message, which could be
168-
* ASSISTANT or TOOL in ReAct loops).
166+
* ASSISTANT or TOOL in ReAct loops). Skips messages injected by other hooks (e.g.,
167+
* StaticLongTermMemoryHook with name "long_term_memory").
169168
*
170169
* @param messages the message list
171170
* @return the extracted query text, or empty string if no user message found
@@ -175,11 +174,15 @@ private String extractQueryFromMessages(List<Msg> messages) {
175174
return "";
176175
}
177176

178-
// Find the last user message (not just the last message, which could be
179-
// ASSISTANT or TOOL in ReAct loops)
177+
// Find the last user message, skipping hook-injected messages
180178
for (int i = messages.size() - 1; i >= 0; i--) {
181179
Msg msg = messages.get(i);
182180
if (msg.getRole() == MsgRole.USER) {
181+
// Skip messages injected by other hooks (e.g., long-term memory)
182+
String name = msg.getName();
183+
if ("long_term_memory".equals(name)) {
184+
continue;
185+
}
183186
return msg.getTextContent();
184187
}
185188
}
@@ -189,16 +192,18 @@ private String extractQueryFromMessages(List<Msg> messages) {
189192
/**
190193
* Creates enhanced message list with knowledge context injected.
191194
*
192-
* <p>The knowledge is injected as a system message at the beginning of the message list.
195+
* <p>The knowledge is injected as a user message appended to the end of the message list.
196+
* The message uses a distinct name "retrieved_knowledge" so that other hooks can identify
197+
* and skip it when extracting the original user query.
193198
*
194199
* @param retrievedDocs the retrieved documents
195-
* @return the enhanced message list with knowledge context
200+
* @return the enhanced message with knowledge context
196201
*/
197-
private Msg createEnhancedMessages(List<Document> retrievedDocs) {
202+
private Msg createEnhancedMessage(List<Document> retrievedDocs) {
198203
String knowledgeContent = buildKnowledgeContent(retrievedDocs);
199204

200205
return Msg.builder()
201-
.name("user")
206+
.name("retrieved_knowledge")
202207
.role(MsgRole.USER)
203208
.content(TextBlock.builder().text(knowledgeContent).build())
204209
.build();

agentscope-core/src/test/java/io/agentscope/core/memory/StaticLongTermMemoryHookTest.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,4 +338,47 @@ void testOnEventWithPostCallEventAsyncRecordEmptyMemory() {
338338

339339
verify(mockLongTermMemory, never()).record(anyList());
340340
}
341+
342+
@Test
343+
void testOnEventSkipsRAGInjectedMessage() {
344+
// Simulate GenericRAGHook having already injected a "retrieved_knowledge" message
345+
List<Msg> inputMessages = new ArrayList<>();
346+
inputMessages.add(
347+
Msg.builder()
348+
.role(MsgRole.USER)
349+
.content(TextBlock.builder().text("What is the refund policy?").build())
350+
.build());
351+
inputMessages.add(
352+
Msg.builder()
353+
.role(MsgRole.USER)
354+
.name("retrieved_knowledge")
355+
.content(
356+
TextBlock.builder()
357+
.text("<retrieved_knowledge>...</retrieved_knowledge>")
358+
.build())
359+
.build());
360+
361+
PreCallEvent event = new PreCallEvent(mockAgent, inputMessages);
362+
363+
when(mockLongTermMemory.retrieve(any(Msg.class)))
364+
.thenReturn(Mono.just("User prefers dark mode"));
365+
366+
StepVerifier.create(hook.onEvent(event))
367+
.assertNext(
368+
resultEvent -> {
369+
List<Msg> messages = resultEvent.getInputMessages();
370+
assertEquals(3, messages.size());
371+
// The retrieval should use the original user message, not the RAG
372+
// message
373+
assertEquals(
374+
"What is the refund policy?", messages.get(0).getTextContent());
375+
assertEquals("retrieved_knowledge", messages.get(1).getName());
376+
assertEquals(MsgRole.USER, messages.get(2).getRole());
377+
assertTrue(
378+
messages.get(2)
379+
.getTextContent()
380+
.contains("<long_term_memory>"));
381+
})
382+
.verifyComplete();
383+
}
341384
}

agentscope-extensions/agentscope-extensions-rag-simple/src/test/java/io/agentscope/core/rag/hook/GenericRAGHookTest.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ void testHandlePreCallEvent() {
162162
assertEquals(MsgRole.USER, enhancedMessages.get(0).getRole());
163163
// Second message should be user message with knowledge retrieval
164164
assertEquals(MsgRole.USER, enhancedMessages.get(1).getRole());
165+
// Name should be "retrieved_knowledge" to distinguish from real user
166+
// messages
167+
assertEquals("retrieved_knowledge", enhancedMessages.get(1).getName());
165168
assertTrue(
166169
enhancedMessages
167170
.get(1)
@@ -321,6 +324,45 @@ void testFormatKnowledgeContent() {
321324
.verifyComplete();
322325
}
323326

327+
@Test
328+
@DisplayName("Should skip long_term_memory injected messages when extracting query")
329+
void testSkipLongTermMemoryInjectedMessage() {
330+
// Simulate StaticLongTermMemoryHook having already injected a "long_term_memory" message
331+
Document doc = createDocument("doc1", "Refund policy: 30 days return");
332+
knowledge.addDocuments(List.of(doc)).block();
333+
334+
List<Msg> inputMessages = new ArrayList<>();
335+
inputMessages.add(
336+
Msg.builder()
337+
.role(MsgRole.USER)
338+
.content(TextBlock.builder().text("What is the refund policy?").build())
339+
.build());
340+
inputMessages.add(
341+
Msg.builder()
342+
.role(MsgRole.USER)
343+
.name("long_term_memory")
344+
.content(
345+
TextBlock.builder()
346+
.text("<long_term_memory>...</long_term_memory>")
347+
.build())
348+
.build());
349+
350+
PreCallEvent event = new PreCallEvent(mockAgent, inputMessages);
351+
352+
StepVerifier.create(hook.onEvent(event))
353+
.assertNext(
354+
result -> {
355+
List<Msg> messages = result.getInputMessages();
356+
// Should have original user msg + long_term_memory + RAG knowledge
357+
assertEquals(3, messages.size());
358+
// The RAG hook should use the original user query, not the memory
359+
// message
360+
assertEquals("retrieved_knowledge", messages.get(2).getName());
361+
assertTrue(messages.get(2).getTextContent().contains("knowledge base"));
362+
})
363+
.verifyComplete();
364+
}
365+
324366
/**
325367
* Creates a test document.
326368
*/

0 commit comments

Comments
 (0)