Skip to content

Commit 3ed3cac

Browse files
committed
Re-apply lost parts of implementation after rebase
1 parent 96262a6 commit 3ed3cac

7 files changed

Lines changed: 571 additions & 12 deletions

File tree

src/app/endpoints/query.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import datetime
66
import logging
7-
from typing import Annotated, Any, cast
7+
from typing import Annotated, Any, Optional, cast
88

99
from fastapi import APIRouter, Depends, HTTPException, Request
1010
from llama_stack_api.openai_responses import OpenAIResponseObject
@@ -56,6 +56,7 @@
5656
build_tool_call_summary,
5757
extract_text_from_response_output_item,
5858
extract_token_usage,
59+
extract_vector_store_ids_from_tools,
5960
get_topic_summary,
6061
parse_referenced_documents,
6162
prepare_responses_params,
@@ -184,8 +185,14 @@ async def query_endpoint_handler(
184185
):
185186
client = await update_azure_token(client)
186187

188+
# Build index identification mapping for RAG source resolution
189+
vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
190+
rag_id_mapping = configuration.rag_id_mapping
191+
187192
# Retrieve response using Responses API
188-
turn_summary = await retrieve_response(client, responses_params)
193+
turn_summary = await retrieve_response(
194+
client, responses_params, vector_store_ids, rag_id_mapping
195+
)
189196

190197
if pre_rag_chunks:
191198
turn_summary.rag_chunks = pre_rag_chunks + (turn_summary.rag_chunks or [])
@@ -269,6 +276,8 @@ def parse_referenced_docs(
269276
async def retrieve_response( # pylint: disable=too-many-locals
270277
client: AsyncLlamaStackClient,
271278
responses_params: ResponsesApiParams,
279+
vector_store_ids: Optional[list[str]] = None,
280+
rag_id_mapping: Optional[dict[str, str]] = None,
272281
) -> TurnSummary:
273282
"""
274283
Retrieve response from LLMs and agents.
@@ -279,6 +288,8 @@ async def retrieve_response( # pylint: disable=too-many-locals
279288
Parameters:
280289
client: The AsyncLlamaStackClient to use for the request.
281290
responses_params: The Responses API parameters.
291+
vector_store_ids: Vector store IDs used in the query for source resolution.
292+
rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
282293
283294
Returns:
284295
TurnSummary: Summary of the LLM response content
@@ -323,7 +334,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
323334
summary.llm_response += message_text
324335

325336
tool_call, tool_result = build_tool_call_summary(
326-
output_item, summary.rag_chunks
337+
output_item, summary.rag_chunks, vector_store_ids, rag_id_mapping
327338
)
328339
if tool_call:
329340
summary.tool_calls.append(tool_call)
@@ -337,7 +348,9 @@ async def retrieve_response( # pylint: disable=too-many-locals
337348
)
338349

339350
# Extract referenced documents and token usage from Responses API response
340-
summary.referenced_documents = parse_referenced_documents(response)
351+
summary.referenced_documents = parse_referenced_documents(
352+
response, vector_store_ids, rag_id_mapping
353+
)
341354
summary.token_usage = extract_token_usage(response, responses_params.model)
342355

343356
return summary

src/app/endpoints/streaming_query.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
build_tool_call_summary,
7575
build_tool_result_from_mcp_output_item_done,
7676
extract_token_usage,
77+
extract_vector_store_ids_from_tools,
7778
get_topic_summary,
7879
parse_referenced_documents,
7980
prepare_responses_params,
@@ -204,7 +205,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
204205
):
205206
client = await update_azure_token(client)
206207

207-
# Create context
208+
# Create context with index identification mapping for RAG source resolution
208209
context = ResponseGeneratorContext(
209210
conversation_id=normalize_conversation_id(responses_params.conversation),
210211
model_id=responses_params.model,
@@ -213,6 +214,8 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
213214
query_request=query_request,
214215
started_at=started_at,
215216
client=client,
217+
vector_store_ids=extract_vector_store_ids_from_tools(responses_params.tools),
218+
rag_id_mapping=configuration.rag_id_mapping,
216219
)
217220

218221
# Update metrics for the LLM call
@@ -527,7 +530,10 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
527530
# For all other types (and mcp_call when arguments.done didn't happen),
528531
# emit both call and result together
529532
tool_call, tool_result = build_tool_call_summary(
530-
output_item_done_chunk.item, turn_summary.rag_chunks
533+
output_item_done_chunk.item,
534+
turn_summary.rag_chunks,
535+
context.vector_store_ids,
536+
context.rag_id_mapping,
531537
)
532538
if tool_call:
533539
turn_summary.tool_calls.append(tool_call)
@@ -587,7 +593,11 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
587593
turn_summary.token_usage = extract_token_usage(
588594
latest_response_object, context.model_id
589595
)
590-
tool_based_documents = parse_referenced_documents(latest_response_object)
596+
tool_based_documents = parse_referenced_documents(
597+
latest_response_object,
598+
vector_store_ids=context.vector_store_ids,
599+
rag_id_mapping=context.rag_id_mapping,
600+
)
591601

592602
# Merge pre-RAG documents with tool-based documents (similar to query.py)
593603
if turn_summary.pre_rag_documents:

src/utils/responses.py

Lines changed: 120 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,25 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
292292
)
293293

294294

295+
def extract_vector_store_ids_from_tools(
    tools: Optional[list[dict[str, Any]]],
) -> list[str]:
    """Extract vector store IDs from prepared tool configurations.

    Every ``file_search`` tool in ``tools`` is inspected, so IDs are not lost
    when more than one such tool is configured. IDs are returned in first-seen
    order without duplicates, and the result is always a new list so callers
    cannot accidentally mutate the tool configuration.

    Parameters:
        tools: The prepared tools list from ResponsesApiParams, or None.

    Returns:
        List of vector store IDs used in file_search tools, or empty list.
    """
    if not tools:
        return []

    # A dict is used as an insertion-ordered set to de-duplicate IDs.
    seen: dict[str, None] = {}
    for tool in tools:
        if tool.get("type") != "file_search":
            continue
        # ``or ()`` covers a key that is present but explicitly set to None,
        # which would otherwise leak None out of a function typed list[str].
        for store_id in tool.get("vector_store_ids") or ():
            seen.setdefault(store_id, None)
    return list(seen)
312+
313+
295314
def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]:
296315
"""Convert vector store IDs to tools format for Responses API.
297316
@@ -390,14 +409,18 @@ def _get_token_value(original: str, header: str) -> str | None:
390409

391410
def parse_referenced_documents(
392411
response: Optional[OpenAIResponseObject],
412+
vector_store_ids: Optional[list[str]] = None,
413+
rag_id_mapping: Optional[dict[str, str]] = None,
393414
) -> list[ReferencedDocument]:
394415
"""Parse referenced documents from Responses API response.
395416
396417
Args:
397418
response: The OpenAI Response API response object
419+
vector_store_ids: Vector store IDs used in the query for source resolution.
420+
rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
398421
399422
Returns:
400-
List of referenced documents with doc_url and doc_title
423+
List of referenced documents with doc_url, doc_title, and source
401424
"""
402425
documents: list[ReferencedDocument] = []
403426
# Use a set to track unique documents by (doc_url, doc_title) tuple
@@ -407,6 +430,10 @@ def parse_referenced_documents(
407430
if response is None or not response.output:
408431
return documents
409432

433+
resolved_source = _resolve_single_store_source(
434+
vector_store_ids or [], rag_id_mapping or {}
435+
)
436+
410437
for output_item in response.output:
411438
item_type = getattr(output_item, "type", None)
412439

@@ -434,13 +461,36 @@ def parse_referenced_documents(
434461
final_url = doc_url if doc_url else None
435462
if (final_url, doc_title) not in seen_docs:
436463
documents.append(
437-
ReferencedDocument(doc_url=final_url, doc_title=doc_title)
464+
ReferencedDocument(
465+
doc_url=final_url,
466+
doc_title=doc_title,
467+
source=resolved_source,
468+
)
438469
)
439470
seen_docs.add((final_url, doc_title))
440471

441472
return documents
442473

443474

475+
def _resolve_single_store_source(
    vector_store_ids: list[str],
    rag_id_mapping: dict[str, str],
) -> Optional[str]:
    """Look up the user-facing rag_id when the query used a single store.

    Parameters:
        vector_store_ids: The vector store IDs used in the query.
        rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.

    Returns:
        The mapped rag_id when exactly one store ID is given and it has an
        entry in the mapping; None in every other case.
    """
    # With zero or several stores the source is ambiguous at this level.
    if len(vector_store_ids) != 1:
        return None
    (only_store,) = vector_store_ids
    return rag_id_mapping.get(only_store)
492+
493+
444494
def extract_token_usage(
445495
response: Optional[OpenAIResponseObject], model_id: str
446496
) -> TokenCounter:
@@ -522,15 +572,19 @@ def extract_token_usage(
522572
return token_counter
523573

524574

525-
def build_tool_call_summary( # pylint: disable=too-many-return-statements,too-many-branches
575+
def build_tool_call_summary( # pylint: disable=too-many-return-statements,too-many-branches,too-many-locals
526576
output_item: OpenAIResponseOutput,
527577
rag_chunks: list[RAGChunk],
578+
vector_store_ids: Optional[list[str]] = None,
579+
rag_id_mapping: Optional[dict[str, str]] = None,
528580
) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]:
529581
"""Translate Responses API tool outputs into ToolCallSummary and ToolResultSummary.
530582
531583
Args:
532584
output_item: An OpenAIResponseOutput item from the response.output array
533585
rag_chunks: List to append extracted RAG chunks to (from file_search_call items)
586+
vector_store_ids: Vector store IDs used in the query for source resolution.
587+
rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
534588
535589
Returns:
536590
Tuple of (ToolCallSummary, ToolResultSummary), one may be None
@@ -551,7 +605,9 @@ def build_tool_call_summary( # pylint: disable=too-many-return-statements,too-m
551605

552606
if item_type == "file_search_call":
553607
file_search_item = cast(FileSearchCall, output_item)
554-
extract_rag_chunks_from_file_search_item(file_search_item, rag_chunks)
608+
extract_rag_chunks_from_file_search_item(
609+
file_search_item, rag_chunks, vector_store_ids, rag_id_mapping
610+
)
555611
response_payload: Optional[dict[str, Any]] = None
556612
if file_search_item.results is not None:
557613
response_payload = {
@@ -731,20 +787,79 @@ def build_tool_result_from_mcp_output_item_done(
731787
)
732788

733789

790+
def _resolve_source_for_result(
    result: Any,
    vector_store_ids: list[str],
    rag_id_mapping: dict[str, str],
) -> Optional[str]:
    """Pick the source label to attach to one file search result.

    With a single vector store, the store's entry in ``rag_id_mapping`` is
    used. With several stores, the result's ``attributes["vector_store_id"]``
    is looked up in the mapping. Whenever no mapping entry applies, the
    result's ``filename`` is returned instead.

    Parameters:
        result: A file search result object; ``filename`` is read, and
            ``attributes`` is read when present.
        vector_store_ids: The vector store IDs used in this query.
        rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.

    Returns:
        The mapped rag_id when one can be determined, otherwise
        ``result.filename``.
    """
    store_count = len(vector_store_ids)

    if store_count == 1:
        # Only one candidate store, so every result must come from it.
        return rag_id_mapping.get(vector_store_ids[0], result.filename)

    if store_count > 1:
        # Several stores: rely on the per-result attribute to tell them apart.
        metadata = getattr(result, "attributes", {}) or {}
        tagged_store: Optional[str] = metadata.get("vector_store_id")
        if tagged_store and tagged_store in rag_id_mapping:
            return rag_id_mapping[tagged_store]

    return result.filename
819+
820+
821+
def _build_chunk_attributes(result: Any) -> Optional[dict[str, Any]]:
    """Return the metadata dict carried by a file search result, if any.

    Parameters:
        result: A file search result object that may have an ``attributes``
            attribute.

    Returns:
        The ``attributes`` value itself when it is a non-empty dict; None when
        it is missing, empty, or not a dict.
    """
    metadata = getattr(result, "attributes", None)
    # Truthiness is checked first so empty/None values short-circuit.
    if metadata and isinstance(metadata, dict):
        return metadata
    return None
836+
837+
734838
def extract_rag_chunks_from_file_search_item(
    item: FileSearchCall,
    rag_chunks: list[RAGChunk],
    vector_store_ids: Optional[list[str]] = None,
    rag_id_mapping: Optional[dict[str, str]] = None,
) -> None:
    """Append one RAGChunk to ``rag_chunks`` per result of a file search call.

    Args:
        item: The file search tool call item
        rag_chunks: List that receives the extracted RAG chunks (mutated in place)
        vector_store_ids: Vector store IDs used in the query for source resolution.
        rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
    """
    search_results = item.results
    if search_results is None:
        return

    # Normalise the optional arguments once rather than per result.
    store_ids = vector_store_ids or []
    id_to_rag_name = rag_id_mapping or {}

    for hit in search_results:
        chunk_source = _resolve_source_for_result(hit, store_ids, id_to_rag_name)
        chunk_attributes = _build_chunk_attributes(hit)
        rag_chunks.append(
            RAGChunk(
                content=hit.text,
                source=chunk_source,
                score=hit.score,
                attributes=chunk_attributes,
            )
        )
750865

src/utils/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,11 @@ class ReferencedDocument(BaseModel):
181181
None, description="Title of the referenced document"
182182
)
183183

184+
source: Optional[str] = Field(
185+
default=None,
186+
description="Index name identifying the knowledge source from configuration",
187+
)
188+
184189

185190
class TurnSummary(BaseModel):
186191
"""Summary of a turn in llama stack."""

tests/unit/app/endpoints/test_query.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ async def test_successful_query_no_conversation(
125125
mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
126126
mock_responses_params.model = "provider1/model1"
127127
mock_responses_params.conversation = "conv_123"
128+
mock_responses_params.tools = None
128129
mock_responses_params.model_dump.return_value = {
129130
"input": "test",
130131
"model": "provider1/model1",
@@ -200,6 +201,7 @@ async def test_successful_query_with_conversation(
200201
mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
201202
mock_responses_params.model = "provider1/model1"
202203
mock_responses_params.conversation = "conv_123"
204+
mock_responses_params.tools = None
203205
mock_responses_params.model_dump.return_value = {
204206
"input": "test",
205207
"model": "provider1/model1",
@@ -273,6 +275,7 @@ async def test_query_with_attachments(
273275
mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
274276
mock_responses_params.model = "provider1/model1"
275277
mock_responses_params.conversation = "conv_123"
278+
mock_responses_params.tools = None
276279
mock_responses_params.model_dump.return_value = {
277280
"input": "test",
278281
"model": "provider1/model1",
@@ -332,6 +335,7 @@ async def test_query_with_topic_summary(
332335
mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
333336
mock_responses_params.model = "provider1/model1"
334337
mock_responses_params.conversation = "conv_123"
338+
mock_responses_params.tools = None
335339
mock_responses_params.model_dump.return_value = {
336340
"input": "test",
337341
"model": "provider1/model1",
@@ -401,6 +405,7 @@ async def test_query_azure_token_refresh(
401405
mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
402406
mock_responses_params.model = "azure/model1"
403407
mock_responses_params.conversation = "conv_123"
408+
mock_responses_params.tools = None
404409
mock_responses_params.model_dump.return_value = {
405410
"input": "test",
406411
"model": "azure/model1",

0 commit comments

Comments
 (0)