Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,20 @@ private Mono<PostCallEvent> handlePostCall(PostCallEvent event) {
*
* <p>Scans the message list from end to start to find the most recent user message,
* which is typically the current query that should be used for memory retrieval.
* Skips messages injected by other hooks (e.g., GenericRAGHook with name "retrieved_knowledge").
*
* @param messages the message list
* @return the last user message, or null if none found
* @return the index of the last user message, or -1 if none found
*/
private int extractLastUserMessageIndex(List<Msg> messages) {
for (int i = messages.size() - 1; i >= 0; i--) {
Msg msg = messages.get(i);
if (msg.getRole() == MsgRole.USER) {
// Skip messages injected by other hooks (e.g., RAG knowledge)
String name = msg.getName();
if ("retrieved_knowledge".equals(name)) {
continue;
}
return i;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,9 @@ private Mono<PreCallEvent> handlePreCall(PreCallEvent event) {
if (retrievedDocs == null || retrievedDocs.isEmpty()) {
return Mono.just(event);
}
List<Msg> enhancedMessages = new ArrayList<>();
// Build enhanced messages with knowledge context
Msg enhancedMessage = createEnhancedMessages(retrievedDocs);
enhancedMessages.addAll(inputMessages);
enhancedMessages.add(enhancedMessage);
List<Msg> enhancedMessages = new ArrayList<>(inputMessages);
// Build enhanced message with knowledge context
enhancedMessages.add(createEnhancedMessage(retrievedDocs));
event.setInputMessages(enhancedMessages);
return Mono.just(event);
})
Expand All @@ -165,7 +163,8 @@ private Mono<PreCallEvent> handlePreCall(PreCallEvent event) {
* Extracts query text from message list.
*
* <p>Finds the last user message as the query source (not just the last message, which could be
* ASSISTANT or TOOL in ReAct loops).
* ASSISTANT or TOOL in ReAct loops). Skips messages injected by other hooks (e.g.,
* StaticLongTermMemoryHook with name "long_term_memory").
*
* @param messages the message list
* @return the extracted query text, or empty string if no user message found
Expand All @@ -175,11 +174,15 @@ private String extractQueryFromMessages(List<Msg> messages) {
return "";
}

// Find the last user message (not just the last message, which could be
// ASSISTANT or TOOL in ReAct loops)
// Find the last user message, skipping hook-injected messages
for (int i = messages.size() - 1; i >= 0; i--) {
Msg msg = messages.get(i);
if (msg.getRole() == MsgRole.USER) {
// Skip messages injected by other hooks (e.g., long-term memory)
String name = msg.getName();
if ("long_term_memory".equals(name)) {
continue;
}
return msg.getTextContent();
}
}
Expand All @@ -189,16 +192,18 @@ private String extractQueryFromMessages(List<Msg> messages) {
/**
* Creates enhanced message list with knowledge context injected.
*
* <p>The knowledge is injected as a system message at the beginning of the message list.
* <p>The knowledge is injected as a user message appended to the end of the message list.
* The message uses a distinct name "retrieved_knowledge" so that other hooks can identify
* and skip it when extracting the original user query.
*
* @param retrievedDocs the retrieved documents
* @return the enhanced message list with knowledge context
* @return the enhanced message with knowledge context
*/
private Msg createEnhancedMessages(List<Document> retrievedDocs) {
private Msg createEnhancedMessage(List<Document> retrievedDocs) {
String knowledgeContent = buildKnowledgeContent(retrievedDocs);

return Msg.builder()
.name("user")
.name("retrieved_knowledge")
.role(MsgRole.USER)
.content(TextBlock.builder().text(knowledgeContent).build())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,4 +338,47 @@ void testOnEventWithPostCallEventAsyncRecordEmptyMemory() {

verify(mockLongTermMemory, never()).record(anyList());
}

@Test
void testOnEventSkipsRAGInjectedMessage() {
// Simulate GenericRAGHook having already injected a "retrieved_knowledge" message
List<Msg> inputMessages = new ArrayList<>();
inputMessages.add(
Msg.builder()
.role(MsgRole.USER)
.content(TextBlock.builder().text("What is the refund policy?").build())
.build());
inputMessages.add(
Msg.builder()
.role(MsgRole.USER)
.name("retrieved_knowledge")
.content(
TextBlock.builder()
.text("<retrieved_knowledge>...</retrieved_knowledge>")
.build())
.build());

PreCallEvent event = new PreCallEvent(mockAgent, inputMessages);

when(mockLongTermMemory.retrieve(any(Msg.class)))
.thenReturn(Mono.just("User prefers dark mode"));

StepVerifier.create(hook.onEvent(event))
.assertNext(
resultEvent -> {
List<Msg> messages = resultEvent.getInputMessages();
assertEquals(3, messages.size());
// The retrieval should use the original user message, not the RAG
// message
assertEquals(
"What is the refund policy?", messages.get(0).getTextContent());
assertEquals("retrieved_knowledge", messages.get(1).getName());
assertEquals(MsgRole.USER, messages.get(2).getRole());
assertTrue(
messages.get(2)
.getTextContent()
.contains("<long_term_memory>"));
})
.verifyComplete();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Copyright 2024-2026 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.agentscope.core.rag;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import io.agentscope.core.agent.AgentBase;
import io.agentscope.core.hook.PreCallEvent;
import io.agentscope.core.interruption.InterruptContext;
import io.agentscope.core.message.Msg;
import io.agentscope.core.message.MsgRole;
import io.agentscope.core.message.TextBlock;
import io.agentscope.core.rag.model.Document;
import io.agentscope.core.rag.model.DocumentMetadata;
import io.agentscope.core.rag.model.RetrieveConfig;
import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

/**
* Unit tests for {@link GenericRAGHook} message extraction logic.
*
* <p>These tests use a mocked {@link Knowledge} to isolate the hook's
* query-extraction behavior, in particular the skipping of hook-injected
* USER messages (name="long_term_memory").
*/
@Tag("unit")
@DisplayName("GenericRAGHook Unit Tests (core)")
class GenericRAGHookTest {

private Knowledge mockKnowledge;
private GenericRAGHook hook;
private AgentBase mockAgent;

@BeforeEach
void setUp() {
mockKnowledge = mock(Knowledge.class);
hook = new GenericRAGHook(mockKnowledge);
mockAgent =
new AgentBase("MockAgent") {
@Override
protected Mono<Msg> doCall(List<Msg> msgs) {
return Mono.just(msgs.get(0));
}

@Override
protected Mono<Void> doObserve(Msg msg) {
return Mono.empty();
}

@Override
protected Mono<Msg> handleInterrupt(
InterruptContext context, Msg... originalArgs) {
return Mono.just(
Msg.builder()
.name(getName())
.role(MsgRole.ASSISTANT)
.content(TextBlock.builder().text("Interrupted").build())
.build());
}
};
}

@Test
@DisplayName("Should skip long_term_memory message and use real user query for retrieval")
void testSkipsLongTermMemoryMessageForQuery() {
Document doc = createDocument("doc1", "Refund policy: 30 days return");
when(mockKnowledge.retrieve(anyString(), any(RetrieveConfig.class)))
.thenReturn(Mono.just(List.of(doc)));

List<Msg> inputMessages = new ArrayList<>();
inputMessages.add(
Msg.builder()
.role(MsgRole.USER)
.content(TextBlock.builder().text("What is the refund policy?").build())
.build());
// Simulate StaticLongTermMemoryHook having already injected its message
inputMessages.add(
Msg.builder()
.role(MsgRole.USER)
.name("long_term_memory")
.content(
TextBlock.builder()
.text("<long_term_memory>some memory</long_term_memory>")
.build())
.build());

PreCallEvent event = new PreCallEvent(mockAgent, inputMessages);

StepVerifier.create(hook.onEvent(event))
.assertNext(
result -> {
List<Msg> messages = result.getInputMessages();
// original user msg + long_term_memory + RAG knowledge = 3
assertEquals(3, messages.size());
// The injected knowledge message must have "retrieved_knowledge" name
assertEquals("retrieved_knowledge", messages.get(2).getName());
assertEquals(MsgRole.USER, messages.get(2).getRole());
assertTrue(
messages.get(2).getTextContent().contains("retrieved_knowledge")
|| messages.get(2)
.getTextContent()
.contains("knowledge base"));
})
.verifyComplete();
}

@Test
@DisplayName("Should return unchanged messages when only long_term_memory USER msg exists")
void testNoRealUserQueryWhenOnlyInjectedMessagesExist() {
List<Msg> inputMessages = new ArrayList<>();
// Only a hook-injected message — no genuine user input
inputMessages.add(
Msg.builder()
.role(MsgRole.USER)
.name("long_term_memory")
.content(
TextBlock.builder()
.text("<long_term_memory>memory content</long_term_memory>")
.build())
.build());

PreCallEvent event = new PreCallEvent(mockAgent, inputMessages);

StepVerifier.create(hook.onEvent(event))
.assertNext(
result -> {
// No real user query found: message list must remain unchanged
assertEquals(1, result.getInputMessages().size());
})
.verifyComplete();
}

private Document createDocument(String docId, String content) {
TextBlock textBlock = TextBlock.builder().text(content).build();
DocumentMetadata metadata = new DocumentMetadata(textBlock, docId, "0");
return new Document(metadata);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ void testHandlePreCallEvent() {
assertEquals(MsgRole.USER, enhancedMessages.get(0).getRole());
// Second message should be user message with knowledge retrieval
assertEquals(MsgRole.USER, enhancedMessages.get(1).getRole());
// Name should be "retrieved_knowledge" to distinguish from real user
// messages
assertEquals("retrieved_knowledge", enhancedMessages.get(1).getName());
assertTrue(
enhancedMessages
.get(1)
Expand Down Expand Up @@ -321,6 +324,83 @@ void testFormatKnowledgeContent() {
.verifyComplete();
}

@Test
@DisplayName("Should skip long_term_memory injected messages when extracting query")
void testSkipLongTermMemoryInjectedMessage() {
// Use scoreThreshold=0.0 to guarantee retrieval hits regardless of vector similarity
TestMockEmbeddingModel embeddingModel = new TestMockEmbeddingModel(DIMENSIONS);
InMemoryStore vectorStore = InMemoryStore.builder().dimensions(DIMENSIONS).build();
Knowledge zeroThresholdKnowledge =
SimpleKnowledge.builder()
.embeddingModel(embeddingModel)
.embeddingStore(vectorStore)
.build();
RetrieveConfig zeroThresholdConfig =
RetrieveConfig.builder().limit(5).scoreThreshold(0.0).build();
GenericRAGHook hookWithZeroThreshold =
new GenericRAGHook(zeroThresholdKnowledge, zeroThresholdConfig);

// Simulate StaticLongTermMemoryHook having already injected a "long_term_memory" message
Document doc = createDocument("doc1", "Refund policy: 30 days return");
zeroThresholdKnowledge.addDocuments(List.of(doc)).block();

List<Msg> inputMessages = new ArrayList<>();
inputMessages.add(
Msg.builder()
.role(MsgRole.USER)
.content(TextBlock.builder().text("What is the refund policy?").build())
.build());
inputMessages.add(
Msg.builder()
.role(MsgRole.USER)
.name("long_term_memory")
.content(
TextBlock.builder()
.text("<long_term_memory>...</long_term_memory>")
.build())
.build());

PreCallEvent event = new PreCallEvent(mockAgent, inputMessages);

StepVerifier.create(hookWithZeroThreshold.onEvent(event))
.assertNext(
result -> {
List<Msg> messages = result.getInputMessages();
// Should have: original user msg + long_term_memory + RAG knowledge
assertEquals(3, messages.size());
// The RAG hook must use the original user query, not the memory message
assertEquals("retrieved_knowledge", messages.get(2).getName());
assertTrue(messages.get(2).getTextContent().contains("knowledge base"));
})
.verifyComplete();
}

@Test
@DisplayName("Should return no results when all user messages are hook-injected")
void testNoRealUserMessageWhenAllAreInjected() {
// Only hook-injected USER messages exist — no genuine user query
List<Msg> inputMessages = new ArrayList<>();
inputMessages.add(
Msg.builder()
.role(MsgRole.USER)
.name("long_term_memory")
.content(
TextBlock.builder()
.text("<long_term_memory>some memory</long_term_memory>")
.build())
.build());

PreCallEvent event = new PreCallEvent(mockAgent, inputMessages);

StepVerifier.create(hook.onEvent(event))
.assertNext(
result -> {
// No real user query found: messages should remain unchanged
assertEquals(1, result.getInputMessages().size());
})
.verifyComplete();
}

/**
* Creates a test document.
*/
Expand Down
Loading