diff --git a/agentscope-core/src/main/java/io/agentscope/core/memory/StaticLongTermMemoryHook.java b/agentscope-core/src/main/java/io/agentscope/core/memory/StaticLongTermMemoryHook.java index 4886deac7..003172039 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/memory/StaticLongTermMemoryHook.java +++ b/agentscope-core/src/main/java/io/agentscope/core/memory/StaticLongTermMemoryHook.java @@ -263,14 +263,20 @@ private Mono handlePostCall(PostCallEvent event) { * *

Scans the message list from end to start to find the most recent user message, * which is typically the current query that should be used for memory retrieval. + * Skips messages injected by other hooks (e.g., GenericRAGHook with name "retrieved_knowledge"). * * @param messages the message list - * @return the last user message, or null if none found + * @return the index of the last user message, or -1 if none found */ private int extractLastUserMessageIndex(List messages) { for (int i = messages.size() - 1; i >= 0; i--) { Msg msg = messages.get(i); if (msg.getRole() == MsgRole.USER) { + // Skip messages injected by other hooks (e.g., RAG knowledge) + String name = msg.getName(); + if ("retrieved_knowledge".equals(name)) { + continue; + } return i; } } diff --git a/agentscope-core/src/main/java/io/agentscope/core/rag/GenericRAGHook.java b/agentscope-core/src/main/java/io/agentscope/core/rag/GenericRAGHook.java index 59e30ea82..86d8f1954 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/rag/GenericRAGHook.java +++ b/agentscope-core/src/main/java/io/agentscope/core/rag/GenericRAGHook.java @@ -145,11 +145,9 @@ private Mono handlePreCall(PreCallEvent event) { if (retrievedDocs == null || retrievedDocs.isEmpty()) { return Mono.just(event); } - List enhancedMessages = new ArrayList<>(); - // Build enhanced messages with knowledge context - Msg enhancedMessage = createEnhancedMessages(retrievedDocs); - enhancedMessages.addAll(inputMessages); - enhancedMessages.add(enhancedMessage); + List enhancedMessages = new ArrayList<>(inputMessages); + // Build enhanced message with knowledge context + enhancedMessages.add(createEnhancedMessage(retrievedDocs)); event.setInputMessages(enhancedMessages); return Mono.just(event); }) @@ -165,7 +163,8 @@ private Mono handlePreCall(PreCallEvent event) { * Extracts query text from message list. * *

Finds the last user message as the query source (not just the last message, which could be - * ASSISTANT or TOOL in ReAct loops). + * ASSISTANT or TOOL in ReAct loops). Skips messages injected by other hooks (e.g., + * StaticLongTermMemoryHook with name "long_term_memory"). * * @param messages the message list * @return the extracted query text, or empty string if no user message found @@ -175,11 +174,15 @@ private String extractQueryFromMessages(List messages) { return ""; } - // Find the last user message (not just the last message, which could be - // ASSISTANT or TOOL in ReAct loops) + // Find the last user message, skipping hook-injected messages for (int i = messages.size() - 1; i >= 0; i--) { Msg msg = messages.get(i); if (msg.getRole() == MsgRole.USER) { + // Skip messages injected by other hooks (e.g., long-term memory) + String name = msg.getName(); + if ("long_term_memory".equals(name)) { + continue; + } return msg.getTextContent(); } } @@ -189,16 +192,18 @@ private String extractQueryFromMessages(List messages) { /** * Creates enhanced message list with knowledge context injected. * - *

The knowledge is injected as a system message at the beginning of the message list. + *

The knowledge is injected as a user message appended to the end of the message list. + * The message uses a distinct name "retrieved_knowledge" so that other hooks can identify + * and skip it when extracting the original user query. * * @param retrievedDocs the retrieved documents - * @return the enhanced message list with knowledge context + * @return the enhanced message with knowledge context */ - private Msg createEnhancedMessages(List retrievedDocs) { + private Msg createEnhancedMessage(List retrievedDocs) { String knowledgeContent = buildKnowledgeContent(retrievedDocs); return Msg.builder() - .name("user") + .name("retrieved_knowledge") .role(MsgRole.USER) .content(TextBlock.builder().text(knowledgeContent).build()) .build(); diff --git a/agentscope-core/src/test/java/io/agentscope/core/memory/StaticLongTermMemoryHookTest.java b/agentscope-core/src/test/java/io/agentscope/core/memory/StaticLongTermMemoryHookTest.java index 2f04a1b68..80eaac710 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/memory/StaticLongTermMemoryHookTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/memory/StaticLongTermMemoryHookTest.java @@ -338,4 +338,47 @@ void testOnEventWithPostCallEventAsyncRecordEmptyMemory() { verify(mockLongTermMemory, never()).record(anyList()); } + + @Test + void testOnEventSkipsRAGInjectedMessage() { + // Simulate GenericRAGHook having already injected a "retrieved_knowledge" message + List inputMessages = new ArrayList<>(); + inputMessages.add( + Msg.builder() + .role(MsgRole.USER) + .content(TextBlock.builder().text("What is the refund policy?").build()) + .build()); + inputMessages.add( + Msg.builder() + .role(MsgRole.USER) + .name("retrieved_knowledge") + .content( + TextBlock.builder() + .text("...") + .build()) + .build()); + + PreCallEvent event = new PreCallEvent(mockAgent, inputMessages); + + when(mockLongTermMemory.retrieve(any(Msg.class))) + .thenReturn(Mono.just("User prefers dark mode")); + + StepVerifier.create(hook.onEvent(event)) + .assertNext( + resultEvent -> { + List messages = resultEvent.getInputMessages(); + assertEquals(3, messages.size()); + // The retrieval should use the original user message, not the RAG + // message + assertEquals( + "What is the refund policy?", messages.get(0).getTextContent()); + assertEquals("retrieved_knowledge", messages.get(1).getName()); + assertEquals(MsgRole.USER, messages.get(2).getRole()); + assertTrue( + messages.get(2) + .getTextContent() + .contains("")); + }) + .verifyComplete(); + } } diff --git a/agentscope-core/src/test/java/io/agentscope/core/rag/GenericRAGHookTest.java b/agentscope-core/src/test/java/io/agentscope/core/rag/GenericRAGHookTest.java new file mode 100644 index 000000000..f8c9e071c --- /dev/null +++ b/agentscope-core/src/test/java/io/agentscope/core/rag/GenericRAGHookTest.java @@ -0,0 +1,162 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.agentscope.core.rag; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.agentscope.core.agent.AgentBase; +import io.agentscope.core.hook.PreCallEvent; +import io.agentscope.core.interruption.InterruptContext; +import io.agentscope.core.message.Msg; +import io.agentscope.core.message.MsgRole; +import io.agentscope.core.message.TextBlock; +import io.agentscope.core.rag.model.Document; +import io.agentscope.core.rag.model.DocumentMetadata; +import io.agentscope.core.rag.model.RetrieveConfig; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Unit tests for {@link GenericRAGHook} message extraction logic. + * + *

These tests use a mocked {@link Knowledge} to isolate the hook's + * query-extraction behavior, in particular the skipping of hook-injected + * USER messages (name="long_term_memory"). + */ +@Tag("unit") +@DisplayName("GenericRAGHook Unit Tests (core)") +class GenericRAGHookTest { + + private Knowledge mockKnowledge; + private GenericRAGHook hook; + private AgentBase mockAgent; + + @BeforeEach + void setUp() { + mockKnowledge = mock(Knowledge.class); + hook = new GenericRAGHook(mockKnowledge); + mockAgent = + new AgentBase("MockAgent") { + @Override + protected Mono doCall(List msgs) { + return Mono.just(msgs.get(0)); + } + + @Override + protected Mono doObserve(Msg msg) { + return Mono.empty(); + } + + @Override + protected Mono handleInterrupt( + InterruptContext context, Msg... originalArgs) { + return Mono.just( + Msg.builder() + .name(getName()) + .role(MsgRole.ASSISTANT) + .content(TextBlock.builder().text("Interrupted").build()) + .build()); + } + }; + } + + @Test + @DisplayName("Should skip long_term_memory message and use real user query for retrieval") + void testSkipsLongTermMemoryMessageForQuery() { + Document doc = createDocument("doc1", "Refund policy: 30 days return"); + when(mockKnowledge.retrieve(anyString(), any(RetrieveConfig.class))) + .thenReturn(Mono.just(List.of(doc))); + + List inputMessages = new ArrayList<>(); + inputMessages.add( + Msg.builder() + .role(MsgRole.USER) + .content(TextBlock.builder().text("What is the refund policy?").build()) + .build()); + // Simulate StaticLongTermMemoryHook having already injected its message + inputMessages.add( + Msg.builder() + .role(MsgRole.USER) + .name("long_term_memory") + .content( + TextBlock.builder() + .text("some memory") + .build()) + .build()); + + PreCallEvent event = new PreCallEvent(mockAgent, inputMessages); + + StepVerifier.create(hook.onEvent(event)) + .assertNext( + result -> { + List messages = result.getInputMessages(); + // original user msg + long_term_memory + RAG knowledge = 3 + assertEquals(3, messages.size()); + // The injected knowledge message must have "retrieved_knowledge" name + assertEquals("retrieved_knowledge", messages.get(2).getName()); + assertEquals(MsgRole.USER, messages.get(2).getRole()); + assertTrue( + messages.get(2).getTextContent().contains("retrieved_knowledge") + || messages.get(2) + .getTextContent() + .contains("knowledge base")); + }) + .verifyComplete(); + } + + @Test + @DisplayName("Should return unchanged messages when only long_term_memory USER msg exists") + void testNoRealUserQueryWhenOnlyInjectedMessagesExist() { + List inputMessages = new ArrayList<>(); + // Only a hook-injected message — no genuine user input + inputMessages.add( + Msg.builder() + .role(MsgRole.USER) + .name("long_term_memory") + .content( + TextBlock.builder() + .text("memory content") + .build()) + .build()); + + PreCallEvent event = new PreCallEvent(mockAgent, inputMessages); + + StepVerifier.create(hook.onEvent(event)) + .assertNext( + result -> { + // No real user query found: message list must remain unchanged + assertEquals(1, result.getInputMessages().size()); + }) + .verifyComplete(); + } + + private Document createDocument(String docId, String content) { + TextBlock textBlock = TextBlock.builder().text(content).build(); + DocumentMetadata metadata = new DocumentMetadata(textBlock, docId, "0"); + return new Document(metadata); + } +} diff --git a/agentscope-extensions/agentscope-extensions-rag-simple/src/test/java/io/agentscope/core/rag/hook/GenericRAGHookTest.java b/agentscope-extensions/agentscope-extensions-rag-simple/src/test/java/io/agentscope/core/rag/hook/GenericRAGHookTest.java index 0fcaa5b86..ffd6e1b8f 100644 --- a/agentscope-extensions/agentscope-extensions-rag-simple/src/test/java/io/agentscope/core/rag/hook/GenericRAGHookTest.java +++ b/agentscope-extensions/agentscope-extensions-rag-simple/src/test/java/io/agentscope/core/rag/hook/GenericRAGHookTest.java @@ -162,6 +162,9 @@ void testHandlePreCallEvent() { assertEquals(MsgRole.USER, enhancedMessages.get(0).getRole()); // Second message should be user message with knowledge retrieval assertEquals(MsgRole.USER, enhancedMessages.get(1).getRole()); + // Name should be "retrieved_knowledge" to distinguish from real user + // messages + assertEquals("retrieved_knowledge", enhancedMessages.get(1).getName()); assertTrue( enhancedMessages .get(1) @@ -321,6 +324,83 @@ void testFormatKnowledgeContent() { .verifyComplete(); } + @Test + @DisplayName("Should skip long_term_memory injected messages when extracting query") + void testSkipLongTermMemoryInjectedMessage() { + // Use scoreThreshold=0.0 to guarantee retrieval hits regardless of vector similarity + TestMockEmbeddingModel embeddingModel = new TestMockEmbeddingModel(DIMENSIONS); + InMemoryStore vectorStore = InMemoryStore.builder().dimensions(DIMENSIONS).build(); + Knowledge zeroThresholdKnowledge = + SimpleKnowledge.builder() + .embeddingModel(embeddingModel) + .embeddingStore(vectorStore) + .build(); + RetrieveConfig zeroThresholdConfig = + RetrieveConfig.builder().limit(5).scoreThreshold(0.0).build(); + GenericRAGHook hookWithZeroThreshold = + new GenericRAGHook(zeroThresholdKnowledge, zeroThresholdConfig); + + // Simulate StaticLongTermMemoryHook having already injected a "long_term_memory" message + Document doc = createDocument("doc1", "Refund policy: 30 days return"); + zeroThresholdKnowledge.addDocuments(List.of(doc)).block(); + + List inputMessages = new ArrayList<>(); + inputMessages.add( + Msg.builder() + .role(MsgRole.USER) + .content(TextBlock.builder().text("What is the refund policy?").build()) + .build()); + inputMessages.add( + Msg.builder() + .role(MsgRole.USER) + .name("long_term_memory") + .content( + TextBlock.builder() + .text("...") + .build()) + .build()); + + PreCallEvent event = new PreCallEvent(mockAgent, inputMessages); + + StepVerifier.create(hookWithZeroThreshold.onEvent(event)) + .assertNext( + result -> { + List messages = result.getInputMessages(); + // Should have: original user msg + long_term_memory + RAG knowledge + assertEquals(3, messages.size()); + // The RAG hook must use the original user query, not the memory message + assertEquals("retrieved_knowledge", messages.get(2).getName()); + assertTrue(messages.get(2).getTextContent().contains("knowledge base")); + }) + .verifyComplete(); + } + + @Test + @DisplayName("Should return no results when all user messages are hook-injected") + void testNoRealUserMessageWhenAllAreInjected() { + // Only hook-injected USER messages exist — no genuine user query + List inputMessages = new ArrayList<>(); + inputMessages.add( + Msg.builder() + .role(MsgRole.USER) + .name("long_term_memory") + .content( + TextBlock.builder() + .text("some memory") + .build()) + .build()); + + PreCallEvent event = new PreCallEvent(mockAgent, inputMessages); + + StepVerifier.create(hook.onEvent(event)) + .assertNext( + result -> { + // No real user query found: messages should remain unchanged + assertEquals(1, result.getInputMessages().size()); + }) + .verifyComplete(); + } + /** * Creates a test document. */