Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions src/utils/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def _process_solr_chunks_for_documents(
async def _fetch_byok_rag(
client: AsyncLlamaStackClient,
query: str,
vector_store_ids: Optional[list[str]] = None,
vector_store_ids: Optional[list[str]] = None, # User-facing
) -> tuple[list[RAGChunk], list[ReferencedDocument]]:
"""Fetch chunks and documents from BYOK RAG sources.

Expand All @@ -339,22 +339,23 @@ async def _fetch_byok_rag(

# Determine which BYOK vector stores to query for inline RAG.
# Per-request override takes precedence; otherwise use config-based inline list.
if vector_store_ids is not None:
# Request-level override: filter out Solr store, use the rest
vector_store_ids_to_query = [
vs_id
for vs_id in vector_store_ids
if vs_id != constants.SOLR_DEFAULT_VECTOR_STORE_ID
]
else:
inline_rag_ids = [
rid
for rid in configuration.configuration.rag.inline
if rid != constants.OKP_RAG_ID
]
vector_store_ids_to_query = resolve_vector_store_ids(
inline_rag_ids, configuration.configuration.byok_rag
)
rag_ids_to_query = (
configuration.configuration.rag.inline
if vector_store_ids is None
else vector_store_ids
)

# Translate user-facing rag_ids to llama-stack ids
vector_store_ids_to_query: list[str] = resolve_vector_store_ids(
rag_ids_to_query, configuration.configuration.byok_rag
)

# Request-level override: filter out Solr store, use the rest
vector_store_ids_to_query = [
vs_id
for vs_id in vector_store_ids_to_query
if vs_id != constants.SOLR_DEFAULT_VECTOR_STORE_ID
]

# If inline byok stores are not defined, we disable the inline RAG for backward compatibility
if not vector_store_ids_to_query:
Expand Down
83 changes: 83 additions & 0 deletions tests/unit/utils/test_vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,89 @@ async def test_byok_enabled_success(self, mocker) -> None: # type: ignore[no-un
assert rag_chunks[0].content == "Test content"
assert len(referenced_docs) > 0

@pytest.mark.asyncio
async def test_user_facing_ids_translated_to_internal_ids( # type: ignore[no-untyped-def]
self, mocker
) -> None:
"""Test that user-facing rag_ids (vector_store_ids) are translated to llama-stack ids."""
config_mock = mocker.Mock(spec=AppConfig)
byok_rag_mock = mocker.Mock()
byok_rag_mock.rag_id = "my-kb"
byok_rag_mock.vector_db_id = "vs-internal-001"
config_mock.configuration.byok_rag = [byok_rag_mock]
config_mock.score_multiplier_mapping = {"vs-internal-001": 1.0}
config_mock.rag_id_mapping = {"vs-internal-001": "my-kb"}
mocker.patch("utils.vector_search.configuration", config_mock)

chunk_mock = mocker.Mock()
chunk_mock.content = "Test content"
chunk_mock.chunk_id = "chunk_1"
chunk_mock.metadata = {"document_id": "doc_1"}

search_response = mocker.Mock()
search_response.chunks = [chunk_mock]
search_response.scores = [0.9]

client_mock = mocker.AsyncMock()
client_mock.vector_io.query.return_value = search_response

# Pass user-facing rag_id "my-kb"
await _fetch_byok_rag(client_mock, "test query", vector_store_ids=["my-kb"])

# Must be called with the internal llama-stack ID, not the user-facing "my-kb"
client_mock.vector_io.query.assert_called_once_with(
vector_store_id="vs-internal-001",
query="test query",
params={"max_chunks": constants.BYOK_RAG_MAX_CHUNKS, "mode": "vector"},
)

@pytest.mark.asyncio
async def test_multiple_user_facing_ids_each_translated( # type: ignore[no-untyped-def]
self, mocker
) -> None:
"""Test that multiple user-facing rag_ids are each translated to their vector_store_id."""
config_mock = mocker.Mock(spec=AppConfig)
byok_rag_1 = mocker.Mock()
byok_rag_1.rag_id = "kb-part1"
byok_rag_1.vector_db_id = "vs-aaa-111"
byok_rag_2 = mocker.Mock()
byok_rag_2.rag_id = "kb-part2"
byok_rag_2.vector_db_id = "vs-bbb-222"
config_mock.configuration.byok_rag = [byok_rag_1, byok_rag_2]
config_mock.score_multiplier_mapping = {"vs-aaa-111": 1.0, "vs-bbb-222": 1.0}
config_mock.rag_id_mapping = {
"vs-aaa-111": "kb-part1",
"vs-bbb-222": "kb-part2",
}
mocker.patch("utils.vector_search.configuration", config_mock)

chunk_mock = mocker.Mock()
chunk_mock.content = "Content"
chunk_mock.chunk_id = "chunk_1"
chunk_mock.metadata = {}

search_response = mocker.Mock()
search_response.chunks = [chunk_mock]
search_response.scores = [0.8]

client_mock = mocker.AsyncMock()
client_mock.vector_io.query.return_value = search_response

# Pass two user-facing rag_ids
await _fetch_byok_rag(
client_mock, "test query", vector_store_ids=["kb-part1", "kb-part2"]
)

# Each call must use the internal ID, not the user-facing name
call_args = [
call.kwargs["vector_store_id"]
for call in client_mock.vector_io.query.call_args_list
]
assert "vs-aaa-111" in call_args
assert "vs-bbb-222" in call_args
assert "kb-part1" not in call_args
assert "kb-part2" not in call_args


class TestFetchSolrRag:
"""Tests for _fetch_solr_rag async function."""
Expand Down
Loading