diff --git a/src/configuration.py b/src/configuration.py index c9ea8e4af..41cd3deb1 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -2,39 +2,38 @@ from typing import Any, Optional +import yaml + # We want to support environment variable replacement in the configuration # similarly to how it is done in llama-stack, so we use their function directly from llama_stack.core.stack import replace_env_vars -import yaml import constants +from cache.cache import Cache +from cache.cache_factory import CacheFactory +from log import get_logger from models.config import ( A2AStateConfiguration, + AuthenticationConfiguration, AuthorizationConfiguration, AzureEntraIdConfiguration, Configuration, + ConversationHistoryConfiguration, Customization, + DatabaseConfiguration, + InferenceConfiguration, LlamaStackConfiguration, + ModelContextProtocolServer, OkpConfiguration, + QuotaHandlersConfiguration, RagConfiguration, - UserDataCollection, ServiceConfiguration, - ModelContextProtocolServer, - AuthenticationConfiguration, - InferenceConfiguration, - DatabaseConfiguration, - ConversationHistoryConfiguration, - QuotaHandlersConfiguration, SplunkConfiguration, + UserDataCollection, ) - -from cache.cache import Cache -from cache.cache_factory import CacheFactory - from quota.quota_limiter import QuotaLimiter -from quota.token_usage_history import TokenUsageHistory from quota.quota_limiter_factory import QuotaLimiterFactory -from log import get_logger +from quota.token_usage_history import TokenUsageHistory logger = get_logger(__name__) @@ -382,18 +381,28 @@ def okp(self) -> "OkpConfiguration": @property def rag_id_mapping(self) -> dict[str, str]: - """Return mapping from vector_db_id to rag_id from BYOK RAG config. + """Return mapping from vector_db_id to rag_id from BYOK and OKP RAG config. Returns: - dict[str, str]: Mapping where keys are llama-stack vector_db_ids - and values are user-facing rag_ids from configuration. + dict[str, str]: Mapping where keys are llama-stack vector_store_ids + (old vector_db_id) and values are user-facing rag_ids from configuration. Raises: LogicError: If the configuration has not been loaded. """ if self._configuration is None: raise LogicError("logic error: configuration is not loaded") - return {brag.vector_db_id: brag.rag_id for brag in self._configuration.byok_rag} + byok_mapping = { + brag.vector_db_id: brag.rag_id for brag in self._configuration.byok_rag + } + + rag = self._configuration.rag + okp_id = constants.OKP_RAG_ID + okp_enabled = okp_id in (rag.inline or []) or okp_id in (rag.tool or []) + okp_mapping = ( + {constants.SOLR_DEFAULT_VECTOR_STORE_ID: okp_id} if okp_enabled else {} + ) + return {**byok_mapping, **okp_mapping} @property def score_multiplier_mapping(self) -> dict[str, float]: diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py index acd1ca5af..da4307782 100644 --- a/tests/unit/test_configuration.py +++ b/tests/unit/test_configuration.py @@ -2,16 +2,18 @@ # pylint: disable=too-many-lines +from collections.abc import Generator from pathlib import Path from typing import Any -from collections.abc import Generator -from pydantic import ValidationError import pytest +from pydantic import ValidationError + +import constants +from cache.in_memory_cache import InMemoryCache +from cache.sqlite_cache import SQLiteCache from configuration import AppConfig, LogicError from models.config import CustomProfile, ModelContextProtocolServer -from cache.sqlite_cache import SQLiteCache -from cache.in_memory_cache import InMemoryCache # pylint: disable=broad-exception-caught,protected-access @@ -947,11 +949,61 @@ def test_load_configuration_with_incomplete_azure_entra_id_raises(tmpdir: Path) cfg.load_configuration(str(cfg_filename)) -def test_rag_id_mapping_empty_when_no_byok(minimal_config: AppConfig) -> None: - """Test that rag_id_mapping returns empty dict when no BYOK RAG configured.""" +def test_rag_id_mapping_excludes_solr_when_okp_not_configured( + minimal_config: AppConfig, +) -> None: + """Test that rag_id_mapping does not include OKP/Solr when OKP is not in rag config.""" assert minimal_config.rag_id_mapping == {} +def test_rag_id_mapping_includes_solr_when_okp_in_inline() -> None: + """Test that rag_id_mapping includes OKP/Solr mapping when OKP is in rag.inline.""" + cfg = AppConfig() + cfg.init_from_dict( + { + "name": "test", + "service": {"host": "localhost", "port": 8080}, + "llama_stack": { + "api_key": "k", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": {}, + "authentication": {"module": "noop"}, + "rag": {"inline": [constants.OKP_RAG_ID]}, + } + ) + assert constants.SOLR_DEFAULT_VECTOR_STORE_ID in cfg.rag_id_mapping + assert ( + cfg.rag_id_mapping[constants.SOLR_DEFAULT_VECTOR_STORE_ID] + == constants.OKP_RAG_ID + ) + + +def test_rag_id_mapping_includes_solr_when_okp_in_tool() -> None: + """Test that rag_id_mapping includes OKP/Solr mapping when OKP is in rag.tool.""" + cfg = AppConfig() + cfg.init_from_dict( + { + "name": "test", + "service": {"host": "localhost", "port": 8080}, + "llama_stack": { + "api_key": "k", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": {}, + "authentication": {"module": "noop"}, + "rag": {"tool": [constants.OKP_RAG_ID]}, + } + ) + assert constants.SOLR_DEFAULT_VECTOR_STORE_ID in cfg.rag_id_mapping + assert ( + cfg.rag_id_mapping[constants.SOLR_DEFAULT_VECTOR_STORE_ID] + == constants.OKP_RAG_ID + ) + + def test_rag_id_mapping_with_byok(tmp_path: Path) -> None: """Test that rag_id_mapping builds correct mapping from BYOK config.""" db_file = tmp_path / "test.db" @@ -980,6 +1032,41 @@ def test_rag_id_mapping_with_byok(tmp_path: Path) -> None: assert cfg.rag_id_mapping == {"vs-001": "my-kb"} +def test_rag_id_mapping_with_byok_and_okp(tmp_path: Path) -> None: + """Test that rag_id_mapping includes both BYOK and OKP entries when OKP is configured.""" + db_file = tmp_path / "test.db" + db_file.touch() + cfg = AppConfig() + cfg.init_from_dict( + { + "name": "test", + "service": {"host": "localhost", "port": 8080}, + "llama_stack": { + "api_key": "k", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": {}, + "authentication": {"module": "noop"}, + "rag": {"inline": [constants.OKP_RAG_ID]}, + "byok_rag": [ + { + "rag_id": "my-kb", + "vector_db_id": "vs-001", + "db_path": str(db_file), + }, + ], + } + ) + assert "vs-001" in cfg.rag_id_mapping + assert cfg.rag_id_mapping["vs-001"] == "my-kb" + assert constants.SOLR_DEFAULT_VECTOR_STORE_ID in cfg.rag_id_mapping + assert ( + cfg.rag_id_mapping[constants.SOLR_DEFAULT_VECTOR_STORE_ID] + == constants.OKP_RAG_ID + ) + + def test_resolve_index_name_with_mapping(minimal_config: AppConfig) -> None: """Test resolve_index_name uses mapping when available.""" mapping = {"vs-x": "user-friendly-name"}