Skip to content

Commit 524ac1c

Browse files
committed
Added merging of chunks from inline and tool-based RAG in streaming query
1 parent 8d2deee commit 524ac1c

4 files changed

Lines changed: 207 additions & 31 deletions

File tree

src/app/endpoints/streaming_query.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
243243
moderation_result=moderation_result,
244244
vector_store_ids=extract_vector_store_ids_from_tools(responses_params.tools),
245245
rag_id_mapping=configuration.rag_id_mapping,
246+
inline_rag_context=inline_rag_context,
246247
)
247248

248249
# Update metrics for the LLM call
@@ -254,7 +255,6 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
254255
generator, turn_summary = await retrieve_response_generator(
255256
responses_params=responses_params,
256257
context=context,
257-
inline_rag_docs=inline_rag_context.referenced_documents,
258258
)
259259

260260
# Combine inline RAG results (BYOK + Solr) with tool-based results
@@ -283,7 +283,6 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
283283
async def retrieve_response_generator(
284284
responses_params: ResponsesApiParams,
285285
context: ResponseGeneratorContext,
286-
inline_rag_docs: list[ReferencedDocument],
287286
) -> tuple[AsyncIterator[str], TurnSummary]:
288287
"""
289288
Retrieve the appropriate response generator.
@@ -295,7 +294,6 @@ async def retrieve_response_generator(
295294
Args:
296295
responses_params: The Responses API parameters
297296
context: The response generator context
298-
inline_rag_docs: Inline RAG (BYOK + Solr) documents
299297
Returns:
300298
tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary
301299
@@ -328,7 +326,6 @@ async def retrieve_response_generator(
328326
response,
329327
context,
330328
turn_summary,
331-
inline_rag_docs,
332329
),
333330
turn_summary,
334331
)
@@ -582,7 +579,6 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
582579
turn_response: AsyncIterator[OpenAIResponseObjectStream],
583580
context: ResponseGeneratorContext,
584581
turn_summary: TurnSummary,
585-
inline_rag_docs: list[ReferencedDocument],
586582
) -> AsyncIterator[str]:
587583
"""Generate SSE formatted streaming response.
588584
@@ -594,7 +590,6 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
594590
turn_response: The streaming response from Llama Stack
595591
context: The response generator context
596592
turn_summary: TurnSummary to populate during streaming
597-
inline_rag_docs: Inline RAG (BYOK + Solr) documents
598593
Yields:
599594
SSE-formatted strings for tokens, tool calls, tool results,
600595
turn completion, and error events.
@@ -773,7 +768,11 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
773768
)
774769
# Combine inline RAG results (BYOK + Solr) with tool-based results
775770
turn_summary.referenced_documents = deduplicate_referenced_documents(
776-
inline_rag_docs + tool_rag_docs
771+
context.inline_rag_context.referenced_documents + tool_rag_docs
772+
)
773+
# Combine inline RAG chunks (BYOK + Solr) with tool-based chunks
774+
turn_summary.rag_chunks = (
775+
context.inline_rag_context.rag_chunks + turn_summary.rag_chunks
777776
)
778777

779778

src/models/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from llama_stack_client import AsyncLlamaStackClient
55

66
from models.requests import QueryRequest
7-
from utils.types import ShieldModerationResult
7+
from utils.types import RAGContext, ShieldModerationResult
88

99

1010
@dataclass
@@ -25,6 +25,7 @@ class ResponseGeneratorContext: # pylint: disable=too-many-instance-attributes
2525
started_at: Timestamp when the request started (ISO 8601 format)
2626
client: The Llama Stack client for API interactions
2727
moderation_result: The moderation result
28+
inline_rag_context: Inline RAG context
2829
vector_store_ids: Vector store IDs used in the query for source resolution.
2930
rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
3031
"""
@@ -47,5 +48,6 @@ class ResponseGeneratorContext: # pylint: disable=too-many-instance-attributes
4748
moderation_result: ShieldModerationResult
4849

4950
# RAG index identification
51+
inline_rag_context: RAGContext
5052
vector_store_ids: list[str] = field(default_factory=list)
5153
rag_id_mapping: dict[str, str] = field(default_factory=dict)

tests/unit/app/endpoints/test_query.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from models.responses import QueryResponse
1818
from utils.token_counter import TokenCounter
1919
from utils.types import (
20+
RAGChunk,
21+
RAGContext,
22+
ReferencedDocument,
2023
ResponsesApiParams,
2124
ShieldModerationPassed,
2225
ToolCallSummary,
@@ -174,6 +177,93 @@ async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary:
174177
assert response.conversation_id == "123"
175178
assert response.response == "Kubernetes is a container orchestration platform"
176179

180+
@pytest.mark.asyncio
181+
async def test_query_merges_inline_and_tool_rag_chunks_and_documents(
182+
self,
183+
dummy_request: Request,
184+
setup_configuration: AppConfig,
185+
mocker: MockerFixture,
186+
) -> None:
187+
"""Test that inline RAG and tool-based RAG chunks/docs are correctly merged."""
188+
query_request = QueryRequest(
189+
query="What is Kubernetes?"
190+
) # pyright: ignore[reportCallIssue]
191+
192+
mocker.patch("app.endpoints.query.configuration", setup_configuration)
193+
mocker.patch("app.endpoints.query.check_configuration_loaded")
194+
mocker.patch("app.endpoints.query.check_tokens_available")
195+
mocker.patch("app.endpoints.query.validate_model_provider_override")
196+
197+
mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient)
198+
mock_response_obj = mocker.Mock()
199+
mock_response_obj.output = []
200+
mock_client.responses = mocker.Mock()
201+
mock_client.responses.create = mocker.AsyncMock(return_value=mock_response_obj)
202+
mock_client_holder = mocker.Mock()
203+
mock_client_holder.get_client.return_value = mock_client
204+
mocker.patch(
205+
"app.endpoints.query.AsyncLlamaStackClientHolder",
206+
return_value=mock_client_holder,
207+
)
208+
mocker.patch(
209+
"app.endpoints.query.run_shield_moderation",
210+
new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
211+
)
212+
213+
inline_chunk = RAGChunk(content="inline chunk content", source="byok")
214+
inline_doc = ReferencedDocument(doc_title="Inline Doc")
215+
inline_rag = RAGContext(
216+
context_text="",
217+
rag_chunks=[inline_chunk],
218+
referenced_documents=[inline_doc],
219+
)
220+
mocker.patch(
221+
"app.endpoints.query.build_rag_context",
222+
new=mocker.AsyncMock(return_value=inline_rag),
223+
)
224+
225+
mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
226+
mock_responses_params.model = "provider1/model1"
227+
mock_responses_params.conversation = "conv_123"
228+
mock_responses_params.tools = None
229+
mock_responses_params.model_dump.return_value = {
230+
"input": "test",
231+
"model": "provider1/model1",
232+
}
233+
mocker.patch(
234+
"app.endpoints.query.prepare_responses_params",
235+
new=mocker.AsyncMock(return_value=mock_responses_params),
236+
)
237+
238+
tool_chunk = RAGChunk(content="tool chunk content", source="vs-1")
239+
tool_doc = ReferencedDocument(doc_title="Tool Doc")
240+
mock_turn_summary = TurnSummary()
241+
mock_turn_summary.rag_chunks = [tool_chunk]
242+
mock_turn_summary.referenced_documents = [tool_doc]
243+
244+
mocker.patch(
245+
"app.endpoints.query.retrieve_response",
246+
new=mocker.AsyncMock(return_value=mock_turn_summary),
247+
)
248+
mocker.patch("app.endpoints.query.store_query_results")
249+
mocker.patch("app.endpoints.query.consume_query_tokens")
250+
mocker.patch("app.endpoints.query.get_available_quotas", return_value={})
251+
252+
response = await query_endpoint_handler(
253+
request=dummy_request,
254+
query_request=query_request,
255+
auth=MOCK_AUTH,
256+
mcp_headers={},
257+
)
258+
259+
assert isinstance(response, QueryResponse)
260+
assert len(response.rag_chunks) == 2
261+
assert response.rag_chunks[0].content == "inline chunk content"
262+
assert response.rag_chunks[1].content == "tool chunk content"
263+
assert len(response.referenced_documents) == 2
264+
assert response.referenced_documents[0].doc_title == "Inline Doc"
265+
assert response.referenced_documents[1].doc_title == "Tool Doc"
266+
177267
@pytest.mark.asyncio
178268
async def test_successful_query_with_conversation(
179269
self,

0 commit comments

Comments
 (0)