diff --git a/src/utils/vector_search.py b/src/utils/vector_search.py
index b24b3214a..269ce2b83 100644
--- a/src/utils/vector_search.py
+++ b/src/utils/vector_search.py
@@ -317,7 +317,7 @@ def _process_solr_chunks_for_documents(
 async def _fetch_byok_rag(
     client: AsyncLlamaStackClient,
     query: str,
-    vector_store_ids: Optional[list[str]] = None,
+    vector_store_ids: Optional[list[str]] = None,  # user-facing rag_ids
 ) -> tuple[list[RAGChunk], list[ReferencedDocument]]:
     """Fetch chunks and documents from BYOK RAG sources.
 
@@ -339,22 +339,26 @@ async def _fetch_byok_rag(
 
     # Determine which BYOK vector stores to query for inline RAG.
     # Per-request override takes precedence; otherwise use config-based inline list.
-    if vector_store_ids is not None:
-        # Request-level override: filter out Solr store, use the rest
-        vector_store_ids_to_query = [
-            vs_id
-            for vs_id in vector_store_ids
-            if vs_id != constants.SOLR_DEFAULT_VECTOR_STORE_ID
-        ]
-    else:
-        inline_rag_ids = [
-            rid
-            for rid in configuration.configuration.rag.inline
-            if rid != constants.OKP_RAG_ID
-        ]
-        vector_store_ids_to_query = resolve_vector_store_ids(
-            inline_rag_ids, configuration.configuration.byok_rag
-        )
+    rag_ids_to_query = (
+        configuration.configuration.rag.inline
+        if vector_store_ids is None
+        else vector_store_ids
+    )
+
+    # Translate user-facing rag_ids to llama-stack ids
+    vector_store_ids_to_query: list[str] = resolve_vector_store_ids(
+        rag_ids_to_query, configuration.configuration.byok_rag
+    )
+
+    # Drop the Solr store whether the ids came from the request or from config.
+    # NOTE(review): the old code removed constants.OKP_RAG_ID from the inline list
+    # before resolving; confirm resolve_vector_store_ids maps it to the Solr store
+    # id, otherwise it is no longer filtered out here.
+    vector_store_ids_to_query = [
+        vs_id
+        for vs_id in vector_store_ids_to_query
+        if vs_id != constants.SOLR_DEFAULT_VECTOR_STORE_ID
+    ]
 
     # If inline byok stores are not defined, we disable the inline RAG for backward compatibility
     if not vector_store_ids_to_query:
diff --git a/tests/unit/utils/test_vector_search.py b/tests/unit/utils/test_vector_search.py
index 930f59d36..12a07193d 100644
--- a/tests/unit/utils/test_vector_search.py
+++ b/tests/unit/utils/test_vector_search.py
@@ -400,6 +400,89 @@ async def test_byok_enabled_success(self, mocker) -> None:  # type: ignore[no-un
         assert rag_chunks[0].content == "Test content"
         assert len(referenced_docs) > 0
 
+    @pytest.mark.asyncio
+    async def test_user_facing_ids_translated_to_internal_ids(  # type: ignore[no-untyped-def]
+        self, mocker
+    ) -> None:
+        """Test that user-facing rag_ids (vector_store_ids) are translated to llama-stack ids."""
+        config_mock = mocker.Mock(spec=AppConfig)
+        byok_rag_mock = mocker.Mock()
+        byok_rag_mock.rag_id = "my-kb"
+        byok_rag_mock.vector_db_id = "vs-internal-001"
+        config_mock.configuration.byok_rag = [byok_rag_mock]
+        config_mock.score_multiplier_mapping = {"vs-internal-001": 1.0}
+        config_mock.rag_id_mapping = {"vs-internal-001": "my-kb"}
+        mocker.patch("utils.vector_search.configuration", config_mock)
+
+        chunk_mock = mocker.Mock()
+        chunk_mock.content = "Test content"
+        chunk_mock.chunk_id = "chunk_1"
+        chunk_mock.metadata = {"document_id": "doc_1"}
+
+        search_response = mocker.Mock()
+        search_response.chunks = [chunk_mock]
+        search_response.scores = [0.9]
+
+        client_mock = mocker.AsyncMock()
+        client_mock.vector_io.query.return_value = search_response
+
+        # Pass user-facing rag_id "my-kb"
+        await _fetch_byok_rag(client_mock, "test query", vector_store_ids=["my-kb"])
+
+        # Must be called with the internal llama-stack ID, not the user-facing "my-kb"
+        client_mock.vector_io.query.assert_called_once_with(
+            vector_store_id="vs-internal-001",
+            query="test query",
+            params={"max_chunks": constants.BYOK_RAG_MAX_CHUNKS, "mode": "vector"},
+        )
+
+    @pytest.mark.asyncio
+    async def test_multiple_user_facing_ids_each_translated(  # type: ignore[no-untyped-def]
+        self, mocker
+    ) -> None:
+        """Test that multiple user-facing rag_ids are each translated to their vector_store_id."""
+        config_mock = mocker.Mock(spec=AppConfig)
+        byok_rag_1 = mocker.Mock()
+        byok_rag_1.rag_id = "kb-part1"
+        byok_rag_1.vector_db_id = "vs-aaa-111"
+        byok_rag_2 = mocker.Mock()
+        byok_rag_2.rag_id = "kb-part2"
+        byok_rag_2.vector_db_id = "vs-bbb-222"
+        config_mock.configuration.byok_rag = [byok_rag_1, byok_rag_2]
+        config_mock.score_multiplier_mapping = {"vs-aaa-111": 1.0, "vs-bbb-222": 1.0}
+        config_mock.rag_id_mapping = {
+            "vs-aaa-111": "kb-part1",
+            "vs-bbb-222": "kb-part2",
+        }
+        mocker.patch("utils.vector_search.configuration", config_mock)
+
+        chunk_mock = mocker.Mock()
+        chunk_mock.content = "Content"
+        chunk_mock.chunk_id = "chunk_1"
+        chunk_mock.metadata = {}
+
+        search_response = mocker.Mock()
+        search_response.chunks = [chunk_mock]
+        search_response.scores = [0.8]
+
+        client_mock = mocker.AsyncMock()
+        client_mock.vector_io.query.return_value = search_response
+
+        # Pass two user-facing rag_ids
+        await _fetch_byok_rag(
+            client_mock, "test query", vector_store_ids=["kb-part1", "kb-part2"]
+        )
+
+        # Each call must use the internal ID, not the user-facing name
+        call_args = [
+            call.kwargs["vector_store_id"]
+            for call in client_mock.vector_io.query.call_args_list
+        ]
+        assert "vs-aaa-111" in call_args
+        assert "vs-bbb-222" in call_args
+        assert "kb-part1" not in call_args
+        assert "kb-part2" not in call_args
+
 
 class TestFetchSolrRag:
     """Tests for _fetch_solr_rag async function."""