Skip to content

Commit e98ff61

Browse files
committed
streaming query
Signed-off-by: Anxhela Coba <acoba@redhat.com>
1 parent e55c56a commit e98ff61

5 files changed

Lines changed: 233 additions & 23 deletions

File tree

src/app/endpoints/query_v2.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -472,7 +472,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
472472
logger.info("Final params being sent to vector_io.query: %s", params)
473473

474474
query_response = await client.vector_io.query(
475-
vector_store_id=vector_store_id, query=query_request.query, params=params
475+
vector_store_id=vector_store_id,
476+
query=query_request.query,
477+
params=params,
476478
)
477479

478480
logger.info("The query response total payload: %s", query_response)
@@ -504,7 +506,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
504506
title = title or cm.get("title")
505507
reference_url = cm.get("reference_url")
506508
else:
507-
doc_id = getattr(cm, "doc_id", None) or getattr(cm, "document_id", None)
509+
doc_id = getattr(cm, "doc_id", None) or getattr(
510+
cm, "document_id", None
511+
)
508512
title = title or getattr(cm, "title", None)
509513
reference_url = getattr(cm, "reference_url", None)
510514
else:
@@ -523,7 +527,11 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
523527
else:
524528
# Use reference_url if online
525529
reference_doc = reference_url or doc_id
526-
doc_url = reference_doc if reference_doc.startswith("http") else ("https://mimir.corp.redhat.com" + reference_doc)
530+
doc_url = (
531+
reference_doc
532+
if reference_doc.startswith("http")
533+
else ("https://mimir.corp.redhat.com" + reference_doc)
534+
)
527535

528536
if reference_doc and reference_doc not in metadata_doc_ids:
529537
metadata_doc_ids.add(reference_doc)
@@ -534,8 +542,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
534542
)
535543
)
536544

537-
logger.info("Extracted %d unique document IDs from chunks", len(doc_ids_from_chunks))
538-
545+
logger.info(
546+
"Extracted %d unique document IDs from chunks", len(doc_ids_from_chunks)
547+
)
539548

540549
except (
541550
APIConnectionError,

src/app/endpoints/streaming_query_v2.py

Lines changed: 194 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
"""Streaming query handler using Responses API (v2)."""
22

33
import logging
4+
import traceback
45
from typing import Annotated, Any, AsyncIterator, Optional, cast
6+
from urllib.parse import urljoin
57

68
from fastapi import APIRouter, Depends, Request
79
from fastapi.responses import StreamingResponse
@@ -14,7 +16,7 @@
1416
OpenAIResponseObjectStreamResponseOutputTextDelta,
1517
OpenAIResponseObjectStreamResponseOutputTextDone,
1618
)
17-
from llama_stack_client import AsyncLlamaStackClient
19+
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
1820

1921
from app.endpoints.query import (
2022
is_transcripts_enabled,
@@ -51,6 +53,7 @@
5153
InternalServerErrorResponse,
5254
NotFoundResponse,
5355
QuotaExceededResponse,
56+
ReferencedDocument,
5457
ServiceUnavailableResponse,
5558
StreamingQueryResponse,
5659
UnauthorizedResponse,
@@ -97,6 +100,7 @@
97100

98101
def create_responses_response_generator( # pylint: disable=too-many-locals,too-many-statements
99102
context: ResponseGeneratorContext,
103+
doc_ids_from_chunks: Optional[list[ReferencedDocument]] = None,
100104
) -> Any:
101105
"""
102106
Create a response generator function for Responses API streaming.
@@ -106,6 +110,7 @@ def create_responses_response_generator( # pylint: disable=too-many-locals,too-
106110
107111
Args:
108112
context: Context object containing all necessary parameters for response generation
113+
doc_ids_from_chunks: Referenced documents extracted from vector DB chunks
109114
110115
Returns:
111116
An async generator function that yields SSE-formatted strings
@@ -294,17 +299,21 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
294299
model_id=context.model_id,
295300
provider_id=context.provider_id,
296301
)
297-
referenced_documents = parse_referenced_documents_from_responses_api(
302+
response_referenced_documents = parse_referenced_documents_from_responses_api(
298303
cast(OpenAIResponseObject, latest_response_object)
299304
)
305+
# Combine doc_ids_from_chunks with response_referenced_documents
306+
all_referenced_documents = (
307+
doc_ids_from_chunks or []
308+
) + response_referenced_documents
300309
available_quotas = get_available_quotas(
301310
configuration.quota_limiters, context.user_id
302311
)
303312
yield stream_end_event(
304313
context.metadata_map,
305314
token_usage,
306315
available_quotas,
307-
referenced_documents,
316+
all_referenced_documents,
308317
media_type,
309318
)
310319

@@ -382,7 +391,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
382391
query_request: QueryRequest,
383392
token: str,
384393
mcp_headers: Optional[dict[str, dict[str, str]]] = None,
385-
) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str]:
394+
) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str, list[ReferencedDocument]]:
386395
"""
387396
Retrieve response from LLMs and agents.
388397
@@ -403,8 +412,8 @@ async def retrieve_response( # pylint: disable=too-many-locals
403412
Multi-cluster proxy headers for tool integrations.
404413
405414
Returns:
406-
tuple: A tuple containing the streaming response object
407-
and the conversation ID.
415+
tuple: A tuple containing the streaming response object,
416+
the conversation ID, and the list of referenced documents from vector DB chunks.
408417
"""
409418
# use system prompt from request or default one
410419
system_prompt = get_system_prompt(query_request, configuration)
@@ -415,11 +424,180 @@ async def retrieve_response( # pylint: disable=too-many-locals
415424
if query_request.attachments:
416425
validate_attachments_metadata(query_request.attachments)
417426

418-
# Prepare tools for responses API
427+
# Prepare tools for responses API - skip RAG tools since we're doing direct vector query
419428
toolgroups = await prepare_tools_for_responses_api(
420-
client, query_request, token, configuration, mcp_headers=mcp_headers
429+
client,
430+
query_request,
431+
token,
432+
configuration,
433+
mcp_headers=mcp_headers,
434+
skip_rag_tools=True,
421435
)
422436

437+
# Extract RAG chunks from vector DB query response BEFORE calling responses API
438+
rag_chunks = []
439+
doc_ids_from_chunks = []
440+
retrieved_chunks = []
441+
retrieved_scores = []
442+
443+
# When offline is False, use reference_url for chunk source
444+
# When offline is True, use parent_id for chunk source
445+
# TODO: move this setting to a higher level configuration
446+
offline = True
447+
448+
try:
449+
# Get vector stores for direct querying
450+
if query_request.vector_store_ids:
451+
vector_store_ids = query_request.vector_store_ids
452+
logger.info(
453+
"Using specified vector_store_ids for direct query: %s",
454+
vector_store_ids,
455+
)
456+
else:
457+
vector_store_ids = [
458+
vector_store.id
459+
for vector_store in (await client.vector_stores.list()).data
460+
]
461+
logger.info(
462+
"Using all available vector_store_ids for direct query: %s",
463+
vector_store_ids,
464+
)
465+
466+
if vector_store_ids:
467+
vector_store_id = vector_store_ids[0] # Use first available vector store
468+
469+
params = {"k": 5, "score_threshold": 0.0, "mode": "hybrid"}
470+
logger.info("Initial params: %s", params)
471+
logger.info("query_request.solr: %s", query_request.solr)
472+
if query_request.solr:
473+
# Pass the entire solr dict under the 'solr' key
474+
params["solr"] = query_request.solr
475+
logger.info("Final params with solr filters: %s", params)
476+
else:
477+
logger.info("No solr filters provided")
478+
logger.info("Final params being sent to vector_io.query: %s", params)
479+
480+
query_response = await client.vector_io.query(
481+
vector_store_id=vector_store_id,
482+
query=query_request.query,
483+
params=params,
484+
)
485+
486+
logger.info("The query response total payload: %s", query_response)
487+
488+
if query_response.chunks:
489+
retrieved_chunks = query_response.chunks
490+
retrieved_scores = (
491+
query_response.scores if hasattr(query_response, "scores") else []
492+
)
493+
494+
# Extract doc_ids from chunks for referenced_documents
495+
metadata_doc_ids = set()
496+
497+
for chunk in query_response.chunks:
498+
logger.info("Extract doc ids from chunk: %s", chunk)
499+
500+
# 1) dict metadata
501+
md = getattr(chunk, "metadata", None) or {}
502+
doc_id = md.get("doc_id") or md.get("document_id")
503+
title = md.get("title")
504+
505+
# 2) typed chunk_metadata
506+
if not doc_id:
507+
cm = getattr(chunk, "chunk_metadata", None)
508+
if cm is not None:
509+
# cm might be a pydantic model or a dict depending on caller
510+
if isinstance(cm, dict):
511+
doc_id = cm.get("doc_id") or cm.get("document_id")
512+
title = title or cm.get("title")
513+
reference_url = cm.get("reference_url")
514+
else:
515+
doc_id = getattr(cm, "doc_id", None) or getattr(
516+
cm, "document_id", None
517+
)
518+
title = title or getattr(cm, "title", None)
519+
reference_url = getattr(cm, "reference_url", None)
520+
else:
521+
reference_url = None
522+
else:
523+
reference_url = md.get("reference_url")
524+
525+
if not doc_id and not reference_url:
526+
continue
527+
528+
# Build URL based on offline flag
529+
if offline:
530+
# Use parent/doc path
531+
reference_doc = doc_id
532+
doc_url = "https://mimir.corp.redhat.com" + reference_doc
533+
else:
534+
# Use reference_url if online
535+
reference_doc = reference_url or doc_id
536+
doc_url = (
537+
reference_doc
538+
if reference_doc.startswith("http")
539+
else ("https://mimir.corp.redhat.com" + reference_doc)
540+
)
541+
542+
if reference_doc and reference_doc not in metadata_doc_ids:
543+
metadata_doc_ids.add(reference_doc)
544+
doc_ids_from_chunks.append(
545+
ReferencedDocument(
546+
doc_title=title,
547+
doc_url=doc_url,
548+
)
549+
)
550+
551+
logger.info(
552+
"Extracted %d unique document IDs from chunks", len(doc_ids_from_chunks)
553+
)
554+
555+
except (
556+
APIConnectionError,
557+
APIStatusError,
558+
AttributeError,
559+
KeyError,
560+
ValueError,
561+
) as e:
562+
logger.warning("Failed to query vector database for chunks: %s", e)
563+
logger.debug("Vector DB query error details: %s", traceback.format_exc())
564+
# Continue without RAG chunks
565+
566+
# Convert retrieved chunks to RAGChunk format
567+
for i, chunk in enumerate(retrieved_chunks):
568+
# Extract source from chunk metadata based on offline flag
569+
source = None
570+
if chunk.metadata:
571+
if offline:
572+
parent_id = chunk.metadata.get("parent_id")
573+
if parent_id:
574+
source = urljoin("https://mimir.corp.redhat.com", parent_id)
575+
else:
576+
source = chunk.metadata.get("reference_url")
577+
578+
# Get score from retrieved_scores list if available
579+
score = retrieved_scores[i] if i < len(retrieved_scores) else None
580+
581+
rag_chunks.append(
582+
RAGChunk(
583+
content=chunk.content,
584+
source=source,
585+
score=score,
586+
)
587+
)
588+
589+
logger.info("Retrieved %d chunks from vector DB", len(rag_chunks))
590+
591+
# Format RAG context for injection into user message
592+
rag_context = ""
593+
if rag_chunks:
594+
context_chunks = []
595+
for chunk in rag_chunks[:5]: # Limit to top 5 chunks
596+
chunk_text = f"Source: {chunk.source or 'Unknown'}\n{chunk.content}"
597+
context_chunks.append(chunk_text)
598+
rag_context = "\n\nRelevant documentation:\n" + "\n\n".join(context_chunks)
599+
logger.info("Injecting %d RAG chunks into user message", len(context_chunks))
600+
423601
# Prepare input for Responses API
424602
# Convert attachments to text and concatenate with query
425603
input_text = query_request.query
@@ -430,6 +608,9 @@ async def retrieve_response( # pylint: disable=too-many-locals
430608
f"{attachment.content}"
431609
)
432610

611+
# Add RAG context to input text
612+
input_text += rag_context
613+
433614
# Handle conversation ID for Responses API
434615
# Create conversation upfront if not provided
435616
conversation_id = query_request.conversation_id
@@ -475,4 +656,8 @@ async def retrieve_response( # pylint: disable=too-many-locals
475656
response = await client.responses.create(**create_params)
476657
response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response)
477658

478-
return response_stream, normalize_conversation_id(conversation_id)
659+
return (
660+
response_stream,
661+
normalize_conversation_id(conversation_id),
662+
doc_ids_from_chunks,
663+
)

tests/unit/app/endpoints/test_query_v2.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -289,8 +289,20 @@ async def test_retrieve_response_builds_rag_and_mcp_tools( # pylint: disable=to
289289
# Mock shields.list and models.list for run_shield_moderation
290290
mock_client.shields.list = mocker.AsyncMock(return_value=[])
291291
mock_client.models.list = mocker.AsyncMock(return_value=[])
292+
293+
# Mock vector_io.query for direct vector querying
294+
mock_query_response = mocker.Mock()
295+
mock_query_response.chunks = []
296+
mock_query_response.scores = []
297+
mock_client.vector_io.query = mocker.AsyncMock(return_value=mock_query_response)
292298

293299
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
300+
301+
# Mock shield moderation
302+
mock_moderation_result = mocker.Mock()
303+
mock_moderation_result.blocked = False
304+
mocker.patch("app.endpoints.query_v2.run_shield_moderation", return_value=mock_moderation_result)
305+
294306
mock_cfg = mocker.Mock()
295307
mock_cfg.mcp_servers = [
296308
ModelContextProtocolServer(
@@ -314,11 +326,9 @@ async def test_retrieve_response_builds_rag_and_mcp_tools( # pylint: disable=to
314326
kwargs = mock_client.responses.create.call_args.kwargs
315327
tools = kwargs["tools"]
316328
assert isinstance(tools, list)
317-
# Expect one file_search and one mcp tool
329+
# Expect only MCP tools since RAG tools are skipped when doing direct vector querying
318330
tool_types = {t.get("type") for t in tools}
319-
assert tool_types == {"file_search", "mcp"}
320-
file_search = next(t for t in tools if t["type"] == "file_search")
321-
assert file_search["vector_store_ids"] == ["dbA"]
331+
assert tool_types == {"mcp"}
322332
mcp_tool = next(t for t in tools if t["type"] == "mcp")
323333
assert mcp_tool["server_label"] == "fs"
324334
assert mcp_tool["headers"] == {"Authorization": "Bearer mytoken"}

0 commit comments

Comments
 (0)