diff --git a/integrations/oracle/docker-compose.yml b/integrations/oracle/docker-compose.yml index 7e4c7b9d04..5de27ee073 100644 --- a/integrations/oracle/docker-compose.yml +++ b/integrations/oracle/docker-compose.yml @@ -7,7 +7,7 @@ services: - ORACLE_PASSWORD=haystack - APP_USER=haystack - APP_USER_PASSWORD=haystack - - ORACLE_INIT_PARAMS=vector_memory_size=512M + - ORACLE_INIT_PARAMS=vector_memory_size=2G volumes: - ./init:/container-entrypoint-initdb.d healthcheck: diff --git a/integrations/oracle/init/01_vector_memory.sql b/integrations/oracle/init/01_vector_memory.sql index 79432e07e2..299a076c6f 100644 --- a/integrations/oracle/init/01_vector_memory.sql +++ b/integrations/oracle/init/01_vector_memory.sql @@ -1,7 +1,7 @@ -- Enable vector memory pool required for HNSW in-memory vector indexes. -- This script runs at container startup via /container-entrypoint-initdb.d. --- The primary mechanism is ORACLE_INIT_PARAMS=vector_memory_size=512M in +-- The primary mechanism is ORACLE_INIT_PARAMS=vector_memory_size=2G in -- docker-compose.yml (writes to SPFILE at DB creation time). This ALTER -- SYSTEM acts as a belt-and-suspenders dynamic setter for pre-existing -- database volumes where the SPFILE value may not have been written yet. -ALTER SYSTEM SET vector_memory_size = 512M SCOPE=BOTH; +ALTER SYSTEM SET vector_memory_size = 2G SCOPE=BOTH; diff --git a/integrations/oracle/pydoc/config_docusaurus.yml b/integrations/oracle/pydoc/config_docusaurus.yml index 6dc6f8545d..db26c7c165 100644 --- a/integrations/oracle/pydoc/config_docusaurus.yml +++ b/integrations/oracle/pydoc/config_docusaurus.yml @@ -1,6 +1,10 @@ loaders: - modules: + - haystack_integrations.components.embedders.oracle.document_embedder + - haystack_integrations.components.embedders.oracle.text_embedder - haystack_integrations.components.retrievers.oracle.embedding_retriever + - haystack_integrations.components.retrievers.oracle.hybrid_retriever + - haystack_integrations.components.retrievers.oracle.keyword_retriever - haystack_integrations.document_stores.oracle.document_store search_path: [../src] processors: diff --git a/integrations/oracle/pyproject.toml b/integrations/oracle/pyproject.toml index 1df25cff12..f503371b35 100644 --- a/integrations/oracle/pyproject.toml +++ b/integrations/oracle/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] dependencies = [ "haystack-ai>=2.28.0", - "oracledb>=2.1.0,<3.0.0", + "oracledb>=2.2.0", ] [project.optional-dependencies] @@ -75,7 +75,7 @@ integration = 'pytest -m "integration" {args:tests}' all = 'pytest {args:tests}' unit-cov-retry = 'pytest --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x -m "not integration" {args:tests}' integration-cov-append-retry = 'pytest --cov=haystack_integrations --cov-append --reruns 3 --reruns-delay 30 -x -m "integration" {args:tests}' -types = "mypy -p haystack_integrations.document_stores.oracle -p haystack_integrations.components.retrievers.oracle {args}" +types = "mypy -p haystack_integrations.document_stores.oracle -p haystack_integrations.components.retrievers.oracle -p haystack_integrations.components.embedders.oracle {args}" [tool.pytest.ini_options] diff --git a/integrations/oracle/src/haystack_integrations/components/embedders/oracle/__init__.py b/integrations/oracle/src/haystack_integrations/components/embedders/oracle/__init__.py new file mode 100644 index 0000000000..1a91769f38 --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/components/embedders/oracle/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_integrations.components.embedders.oracle.document_embedder import OracleDocumentEmbedder +from haystack_integrations.components.embedders.oracle.text_embedder import OracleTextEmbedder + +__all__ = ["OracleDocumentEmbedder", "OracleTextEmbedder"] diff --git a/integrations/oracle/src/haystack_integrations/components/embedders/oracle/_base.py b/integrations/oracle/src/haystack_integrations/components/embedders/oracle/_base.py new file mode 100644 index 0000000000..56b0b4cf47 --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/components/embedders/oracle/_base.py @@ -0,0 +1,244 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect +import json +import logging +from typing import Any + +import oracledb +from haystack.utils import Secret + +from haystack_integrations.document_stores.oracle import OracleConnectionConfig + +logger = logging.getLogger(__name__) + + +def _supports_parameter(callable_obj: Any, parameter_name: str) -> bool: + try: + return parameter_name in inspect.signature(callable_obj).parameters + except (TypeError, ValueError): + return False + + +def _execute_with_fetch_lobs(cursor: Any, statement: str, parameters: Any = None, **kwargs: Any) -> Any: + if _supports_parameter(cursor.execute, "fetch_lobs"): + kwargs.setdefault("fetch_lobs", False) + if parameters is None: + return cursor.execute(statement, **kwargs) + return cursor.execute(statement, parameters, **kwargs) + + +async def _execute_with_fetch_lobs_async(cursor: Any, statement: str, parameters: Any = None, **kwargs: Any) -> Any: + if _supports_parameter(cursor.execute, "fetch_lobs"): + kwargs.setdefault("fetch_lobs", False) + if parameters is None: + return await cursor.execute(statement, **kwargs) + return await cursor.execute(statement, parameters, **kwargs) + + +def _read_lob(value: Any) -> Any: + if hasattr(value, "read"): + return value.read() + return value + + +async def _read_lob_async(value: Any) -> Any: + if hasattr(value, "read"): + return await value.read() + return value + + +def _resolve_secret(value: Any) -> Any: + if isinstance(value, Secret): + return value.resolve_value() + return value + + +def _serialize_secret(value: Any) -> Any: + if isinstance(value, Secret): + return value.to_dict() + return value + + +class _OracleEmbedderBase: + def __init__( + self, + *, + connection_config: OracleConnectionConfig, + embedding_params: dict[str, Any] | None = None, + use_connection_pool: bool = False, + proxy: Secret | str | None = None, + ) -> None: + if connection_config is None: + msg = "connection_config must be provided." + raise ValueError(msg) + if embedding_params is None: + msg = "embedding_params must be provided." + raise ValueError(msg) + + self.connection_config = connection_config + self.embedding_params = dict(embedding_params) + self.use_connection_pool = use_connection_pool + self.proxy = proxy + + self._client: Any | None = None + self._client_async: Any | None = None + + def _connect_kwargs(self, *, pool_options: bool) -> dict[str, Any]: + cfg = self.connection_config + password = cfg.password.resolve_value() + connect_kwargs: dict[str, Any] = { + "user": cfg.user.resolve_value(), + "password": password, + "dsn": cfg.dsn.resolve_value(), + } + if pool_options: + connect_kwargs["min"] = cfg.min_connections + connect_kwargs["max"] = cfg.max_connections + connect_kwargs["increment"] = 1 + if cfg.wallet_location: + connect_kwargs["config_dir"] = cfg.wallet_location + connect_kwargs["wallet_location"] = cfg.wallet_location + connect_kwargs["wallet_password"] = cfg.wallet_password.resolve_value() if cfg.wallet_password else password + return connect_kwargs + + def _ensure_client(self) -> Any: + if self._client is not None: + return self._client + if self.use_connection_pool: + self._client = oracledb.create_pool(**self._connect_kwargs(pool_options=True)) + else: + self._client = oracledb.connect(**self._connect_kwargs(pool_options=False)) + return self._client + + def _connection_context(self) -> Any: + if self.use_connection_pool: + return self._ensure_client().acquire() + return oracledb.connect(**self._connect_kwargs(pool_options=False)) + + async def _ensure_client_async(self) -> Any: + if self._client_async is not None: + return self._client_async + if self.use_connection_pool: + create_pool_async = getattr(oracledb, "create_pool_async", None) + if create_pool_async is None: + msg = "python-oracledb does not provide create_pool_async." + raise RuntimeError(msg) + self._client_async = create_pool_async(**self._connect_kwargs(pool_options=True)) + else: + self._client_async = await oracledb.connect_async(**self._connect_kwargs(pool_options=False)) + return self._client_async + + async def _connection_context_async(self) -> Any: + if self.use_connection_pool: + return (await self._ensure_client_async()).acquire() + return await oracledb.connect_async(**self._connect_kwargs(pool_options=False)) + + def _serialize_proxy(self) -> Any: + return _serialize_secret(self.proxy) + + def _proxy_value(self) -> str | None: + proxy = _resolve_secret(self.proxy) + return str(proxy) if proxy else None + + def _embed_documents(self, texts: list[str]) -> list[list[float]]: + embeddings: list[list[float]] = [] + + with self._connection_context() as connection, connection.cursor() as cursor: + proxy_was_set = False + proxy = self._proxy_value() + if proxy: + _execute_with_fetch_lobs(cursor, "BEGIN UTL_HTTP.SET_PROXY(:proxy); END;", proxy=proxy) + proxy_was_set = True + try: + vector_array_type = connection.gettype("SYS.VECTOR_ARRAY_T") + chunks = [json.dumps({"chunk_id": index, "chunk_data": text}) for index, text in enumerate(texts, 1)] + inputs = vector_array_type.newobject(chunks) + cursor.setinputsizes(None, oracledb.DB_TYPE_JSON) + _execute_with_fetch_lobs( + cursor, + "SELECT t.* FROM DBMS_VECTOR_CHAIN.UTL_TO_EMBEDDINGS(:1, JSON(:2)) t", + [inputs, self.embedding_params], + ) + for row in cursor: + if row is None: + embeddings.append([]) + continue + row_data = json.loads(_read_lob(row[0])) + embeddings.append(json.loads(row_data["embed_vector"])) + except BaseException as exc: + if proxy_was_set: + self._clear_proxy(cursor, exc) + raise + else: + if proxy_was_set: + self._clear_proxy(cursor, None) + return embeddings + + async def _embed_documents_async(self, texts: list[str]) -> list[list[float]]: + embeddings: list[list[float]] = [] + + connection_context = await self._connection_context_async() + async with connection_context as connection: + with connection.cursor() as cursor: + proxy_was_set = False + proxy = self._proxy_value() + if proxy: + await _execute_with_fetch_lobs_async(cursor, "BEGIN UTL_HTTP.SET_PROXY(:proxy); END;", proxy=proxy) + proxy_was_set = True + try: + vector_array_type = await connection.gettype("SYS.VECTOR_ARRAY_T") + chunks = [ + json.dumps({"chunk_id": index, "chunk_data": text}) for index, text in enumerate(texts, 1) + ] + inputs = vector_array_type.newobject() + for chunk in chunks: + clob = await connection.createlob(oracledb.DB_TYPE_CLOB) + await clob.write(chunk) + inputs.append(clob) + cursor.setinputsizes(None, oracledb.DB_TYPE_JSON) + await _execute_with_fetch_lobs_async( + cursor, + "SELECT t.* FROM DBMS_VECTOR_CHAIN.UTL_TO_EMBEDDINGS(:1, JSON(:2)) t", + [inputs, self.embedding_params], + ) + async for row in cursor: + if row is None: + embeddings.append([]) + continue + row_data = json.loads(await _read_lob_async(row[0])) + embeddings.append(json.loads(row_data["embed_vector"])) + except BaseException as exc: + if proxy_was_set: + await self._clear_proxy_async(cursor, exc) + raise + else: + if proxy_was_set: + await self._clear_proxy_async(cursor, None) + return embeddings + + @staticmethod + def _clear_proxy(cursor: Any, original_error: BaseException | None) -> None: + try: + cursor.execute("BEGIN UTL_HTTP.SET_PROXY(:proxy); END;", proxy=None) + except Exception as cleanup_error: + logger.exception("Failed to clear Oracle session proxy.") + if original_error is not None: + msg = "Failed to clear Oracle session proxy after embedding failed." + raise RuntimeError(msg) from cleanup_error + msg = "Failed to clear Oracle session proxy after embedding succeeded." + raise RuntimeError(msg) from cleanup_error + + @staticmethod + async def _clear_proxy_async(cursor: Any, original_error: BaseException | None) -> None: + try: + await _execute_with_fetch_lobs_async(cursor, "BEGIN UTL_HTTP.SET_PROXY(:proxy); END;", proxy=None) + except Exception as cleanup_error: + logger.exception("Failed to clear Oracle session proxy.") + if original_error is not None: + msg = "Failed to clear Oracle session proxy after async embedding failed." + raise RuntimeError(msg) from cleanup_error + msg = "Failed to clear Oracle session proxy after async embedding succeeded." + raise RuntimeError(msg) from cleanup_error diff --git a/integrations/oracle/src/haystack_integrations/components/embedders/oracle/document_embedder.py b/integrations/oracle/src/haystack_integrations/components/embedders/oracle/document_embedder.py new file mode 100644 index 0000000000..8f99af2814 --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/components/embedders/oracle/document_embedder.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Mapping +from dataclasses import replace +from typing import Any + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.utils import Secret, deserialize_secrets_inplace + +from haystack_integrations.document_stores.oracle import OracleConnectionConfig + +from ._base import _OracleEmbedderBase + + +@component +class OracleDocumentEmbedder(_OracleEmbedderBase): + """ + Embeds Haystack Documents with Oracle Database embedding functions. + + The component embeds each document's content and can prepend selected metadata + values before sending text to Oracle. It returns copied documents with the + resulting vectors populated in ``Document.embedding``. + """ + + def __init__( + self, + *, + connection_config: OracleConnectionConfig, + embedding_params: dict[str, Any] | None = None, + use_connection_pool: bool = False, + proxy: Secret | str | None = None, + meta_fields_to_embed: list[str] | None = None, + embedding_separator: str = "\n", + ) -> None: + """ + Create an Oracle document embedder. + + :param connection_config: Oracle connection settings, including user, password, DSN, and optional wallet. + :param embedding_params: JSON-serializable Oracle embedding parameters, such as provider and model. + :param use_connection_pool: When ``True``, reuse a python-oracledb connection pool. + :param proxy: Optional HTTP proxy set in the Oracle session with ``UTL_HTTP.SET_PROXY``. + :param meta_fields_to_embed: Metadata keys to prepend to document content before embedding. + Missing keys and ``None`` values are skipped. + :param embedding_separator: Separator used between metadata values and document content. + :raises ValueError: If ``connection_config`` or ``embedding_params`` is missing. + """ + _OracleEmbedderBase.__init__( + self, + connection_config=connection_config, + embedding_params=embedding_params, + use_connection_pool=use_connection_pool, + proxy=proxy, + ) + self.meta_fields_to_embed = list(meta_fields_to_embed or []) + self.embedding_separator = embedding_separator + + def _prepare_texts_to_embed(self, documents: list[Document]) -> list[str]: + texts: list[str] = [] + for document in documents: + meta_values = [ + str(document.meta[field]) for field in self.meta_fields_to_embed if document.meta.get(field) is not None + ] + texts.append(self.embedding_separator.join([*meta_values, document.content or ""])) + return texts + + @component.output_types(documents=list[Document], meta=dict[str, Any]) + def run(self, documents: list[Document]) -> dict[str, Any]: + """ + Compute embeddings and return documents with ``Document.embedding`` populated. + + :param documents: Documents to embed. Each item must be a Haystack ``Document``. + :returns: A dictionary containing ``documents`` and ``meta``. The returned documents are copies of the input + documents with ``embedding`` populated. + :raises TypeError: If ``documents`` is not a list of Haystack ``Document`` objects. + """ + if not isinstance(documents, list) or any(not isinstance(document, Document) for document in documents): + msg = "OracleDocumentEmbedder expects a list of Document objects." + raise TypeError(msg) + embeddings = self._embed_documents(self._prepare_texts_to_embed(documents)) + embedded_documents = [ + replace(document, embedding=embedding) for document, embedding in zip(documents, embeddings, strict=True) + ] + return {"documents": embedded_documents, "meta": self.embedding_params} + + @component.output_types(documents=list[Document], meta=dict[str, Any]) + async def run_async(self, documents: list[Document]) -> dict[str, Any]: + """ + Compute embeddings asynchronously and return documents with ``Document.embedding`` populated. + + :param documents: Documents to embed. Each item must be a Haystack ``Document``. + :returns: A dictionary containing ``documents`` and ``meta``. The returned documents are copies of the input + documents with ``embedding`` populated. + :raises TypeError: If ``documents`` is not a list of Haystack ``Document`` objects. + """ + if not isinstance(documents, list) or any(not isinstance(document, Document) for document in documents): + msg = "OracleDocumentEmbedder expects a list of Document objects." + raise TypeError(msg) + embeddings = await self._embed_documents_async(self._prepare_texts_to_embed(documents)) + embedded_documents = [ + replace(document, embedding=embedding) for document, embedding in zip(documents, embeddings, strict=True) + ] + return {"documents": embedded_documents, "meta": self.embedding_params} + + def to_dict(self) -> dict[str, Any]: + """ + Serializes the component to a dictionary. + """ + return default_to_dict( + self, + connection_config=self.connection_config.to_dict(), + embedding_params=self.embedding_params, + use_connection_pool=self.use_connection_pool, + proxy=self._serialize_proxy(), + meta_fields_to_embed=self.meta_fields_to_embed, + embedding_separator=self.embedding_separator, + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "OracleDocumentEmbedder": + """ + Deserializes the component from a dictionary. + """ + params = data.get("init_parameters", {}) + connection_config = params.get("connection_config") + if isinstance(connection_config, Mapping): + params["connection_config"] = OracleConnectionConfig.from_dict(dict(connection_config)) + if isinstance(params.get("proxy"), dict) and "type" in params["proxy"]: + deserialize_secrets_inplace(params, keys=["proxy"]) + return default_from_dict(cls, data) diff --git a/integrations/oracle/src/haystack_integrations/components/embedders/oracle/py.typed b/integrations/oracle/src/haystack_integrations/components/embedders/oracle/py.typed new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/components/embedders/oracle/py.typed @@ -0,0 +1 @@ + diff --git a/integrations/oracle/src/haystack_integrations/components/embedders/oracle/text_embedder.py b/integrations/oracle/src/haystack_integrations/components/embedders/oracle/text_embedder.py new file mode 100644 index 0000000000..e5c0e3a3ea --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/components/embedders/oracle/text_embedder.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Mapping +from typing import Any + +from haystack import component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace + +from haystack_integrations.document_stores.oracle import OracleConnectionConfig + +from ._base import _OracleEmbedderBase + + +@component +class OracleTextEmbedder(_OracleEmbedderBase): + """ + Embeds strings with Oracle Database embedding functions. + + The component calls ``DBMS_VECTOR_CHAIN.UTL_TO_EMBEDDINGS`` with the configured + Oracle embedding parameters and returns one dense vector for each input text. + """ + + def __init__( + self, + *, + connection_config: OracleConnectionConfig, + embedding_params: dict[str, Any] | None = None, + use_connection_pool: bool = False, + proxy: Secret | str | None = None, + ) -> None: + """ + Create an Oracle text embedder. + + :param connection_config: Oracle connection settings, including user, password, DSN, and optional wallet. + :param embedding_params: JSON-serializable Oracle embedding parameters, such as provider and model. + :param use_connection_pool: When ``True``, reuse a python-oracledb connection pool. + :param proxy: Optional HTTP proxy set in the Oracle session with ``UTL_HTTP.SET_PROXY``. + :raises ValueError: If ``connection_config`` or ``embedding_params`` is missing. + """ + _OracleEmbedderBase.__init__( + self, + connection_config=connection_config, + embedding_params=embedding_params, + use_connection_pool=use_connection_pool, + proxy=proxy, + ) + + @component.output_types(embedding=list[float], meta=dict[str, Any]) + def run(self, text: str) -> dict[str, Any]: + """ + Compute one embedding for a single input string. + + :param text: Text to embed. + :returns: A dictionary containing ``embedding`` and ``meta``. ``meta`` contains the embedding parameters. + :raises TypeError: If ``text`` is not a string. + """ + if not isinstance(text, str): + msg = "OracleTextEmbedder expects a string input." + raise TypeError(msg) + return {"embedding": self._embed_documents([text])[0], "meta": self.embedding_params} + + @component.output_types(embedding=list[float], meta=dict[str, Any]) + async def run_async(self, text: str) -> dict[str, Any]: + """ + Compute one embedding for a single input string asynchronously. + + :param text: Text to embed. + :returns: A dictionary containing ``embedding`` and ``meta``. ``meta`` contains the embedding parameters. + :raises TypeError: If ``text`` is not a string. + """ + if not isinstance(text, str): + msg = "OracleTextEmbedder expects a string input." + raise TypeError(msg) + return {"embedding": (await self._embed_documents_async([text]))[0], "meta": self.embedding_params} + + def to_dict(self) -> dict[str, Any]: + """ + Serializes the component to a dictionary. + """ + return default_to_dict( + self, + connection_config=self.connection_config.to_dict(), + embedding_params=self.embedding_params, + use_connection_pool=self.use_connection_pool, + proxy=self._serialize_proxy(), + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "OracleTextEmbedder": + """ + Deserializes the component from a dictionary. + """ + params = data.get("init_parameters", {}) + connection_config = params.get("connection_config") + if isinstance(connection_config, Mapping): + params["connection_config"] = OracleConnectionConfig.from_dict(dict(connection_config)) + if isinstance(params.get("proxy"), dict) and "type" in params["proxy"]: + deserialize_secrets_inplace(params, keys=["proxy"]) + return default_from_dict(cls, data) diff --git a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py index 2cff16e4d6..5f6662110e 100644 --- a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py +++ b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from haystack_integrations.components.retrievers.oracle.embedding_retriever import OracleEmbeddingRetriever +from haystack_integrations.components.retrievers.oracle.hybrid_retriever import OracleHybridRetriever from haystack_integrations.components.retrievers.oracle.keyword_retriever import OracleKeywordRetriever -__all__ = ["OracleEmbeddingRetriever", "OracleKeywordRetriever"] +__all__ = ["OracleEmbeddingRetriever", "OracleHybridRetriever", "OracleKeywordRetriever"] diff --git a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/embedding_retriever.py b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/embedding_retriever.py index 094fa79b77..a90d73e96c 100644 --- a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/embedding_retriever.py +++ b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/embedding_retriever.py @@ -34,6 +34,15 @@ def __init__( top_k: int = 10, filter_policy: FilterPolicy = FilterPolicy.REPLACE, ) -> None: + """ + Create an Oracle embedding retriever. + + :param document_store: Oracle document store used for vector similarity search. + :param filters: Base Haystack metadata filters applied to every retrieval. + :param top_k: Maximum number of documents to return. + :param filter_policy: Policy for combining constructor filters with runtime filters. + :raises TypeError: If ``document_store`` is not an ``OracleDocumentStore``. + """ if not isinstance(document_store, OracleDocumentStore): msg = "document_store must be an instance of OracleDocumentStore" raise TypeError(msg) @@ -75,7 +84,14 @@ async def run_async( filters: dict[str, Any] | None = None, top_k: int | None = None, ) -> dict[str, list[Document]]: - """Async variant of :meth:`run`.""" + """ + Asynchronously retrieve documents by vector similarity. + + :param query_embedding: Dense float vector from an embedder component. + :param filters: Runtime filters, merged with constructor filters according to filter_policy. + :param top_k: Override the constructor top_k for this call. + :returns: ``{"documents": [Document, ...]}`` + """ filters = apply_filter_policy(self.filter_policy, self.filters, filters) docs = await self.document_store._embedding_retrieval_async( query_embedding, diff --git a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/hybrid_retriever.py b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/hybrid_retriever.py new file mode 100644 index 0000000000..81fe1fe731 --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/hybrid_retriever.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Literal + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.oracle import OracleDocumentStore + +_VALID_SEARCH_MODES = {"keyword", "hybrid", "semantic"} + + +@component +class OracleHybridRetriever: + """ + Retrieves documents with DBMS_HYBRID_VECTOR.SEARCH. + + The retriever runs against an existing Oracle hybrid vector index. It supports + keyword-only, semantic-only, and hybrid search modes, optional metadata filters, + and optional score metadata returned from Oracle. + """ + + def __init__( + self, + *, + document_store: OracleDocumentStore, + index_name: str, + search_mode: Literal["keyword", "hybrid", "semantic"] = "hybrid", + filters: dict[str, Any] | None = None, + top_k: int = 10, + params: dict[str, Any] | None = None, + return_scores: bool = False, + filter_policy: FilterPolicy = FilterPolicy.REPLACE, + ) -> None: + """ + Create an Oracle hybrid retriever. + + :param document_store: Oracle document store used to fetch documents by row id after hybrid search. + :param index_name: Name of the existing Oracle hybrid vector index. + :param search_mode: Search mode passed to Oracle. Supported values are ``"keyword"``, + ``"semantic"``, and ``"hybrid"``. + :param filters: Base Haystack metadata filters applied to every retrieval. + :param top_k: Maximum number of documents to return. + :param params: Optional DBMS_HYBRID_VECTOR.SEARCH parameters. ``search_text`` and + return settings derived by the retriever cannot be supplied here. + :param return_scores: When ``True``, include Oracle hybrid, text, and vector scores in document metadata. + :param filter_policy: Policy for combining constructor filters with runtime filters. + :raises TypeError: If ``document_store`` is not an ``OracleDocumentStore``. + :raises ValueError: If ``search_mode`` or ``params`` are invalid. + """ + if not isinstance(document_store, OracleDocumentStore): + msg = "document_store must be an instance of OracleDocumentStore" + raise TypeError(msg) + if search_mode not in _VALID_SEARCH_MODES: + msg = f"search_mode must be one of {_VALID_SEARCH_MODES}, got {search_mode!r}" + raise ValueError(msg) + + self.document_store = document_store + self.index_name = index_name + self.search_mode = search_mode + self.filters = filters or {} + self.top_k = top_k + self.params = OracleDocumentStore._validate_hybrid_params(params or {}) + self.return_scores = return_scores + self.filter_policy = FilterPolicy.from_str(filter_policy) if isinstance(filter_policy, str) else filter_policy + + def _merged_params(self, params: dict[str, Any] | None) -> dict[str, Any]: + merged_params = dict(self.params) + merged_params.update(OracleDocumentStore._validate_hybrid_params(params or {})) + return merged_params + + @component.output_types(documents=list[Document]) + def run( + self, + query: str, + filters: dict[str, Any] | None = None, + top_k: int | None = None, + params: dict[str, Any] | None = None, + ) -> dict[str, list[Document]]: + """ + Retrieve documents for a text query. + + :param query: Text query used for keyword, semantic, or hybrid search. + :param filters: Runtime filters combined with constructor filters according to ``filter_policy``. + :param top_k: Optional override for the maximum number of returned documents. + :param params: Optional runtime DBMS_HYBRID_VECTOR.SEARCH parameters merged with constructor params. + :returns: A dictionary containing the retrieved ``documents``. + """ + merged_filters = apply_filter_policy(self.filter_policy, self.filters, filters) + documents = self.document_store._hybrid_retrieval( + query, + index_name=self.index_name, + search_mode=self.search_mode, + filters=merged_filters, + top_k=top_k if top_k is not None else self.top_k, + params=self._merged_params(params), + return_scores=self.return_scores, + ) + return {"documents": documents} + + @component.output_types(documents=list[Document]) + async def run_async( + self, + query: str, + filters: dict[str, Any] | None = None, + top_k: int | None = None, + params: dict[str, Any] | None = None, + ) -> dict[str, list[Document]]: + """ + Asynchronously retrieve documents for a text query. + + :param query: Text query used for keyword, semantic, or hybrid search. + :param filters: Runtime filters combined with constructor filters according to ``filter_policy``. + :param top_k: Optional override for the maximum number of returned documents. + :param params: Optional runtime DBMS_HYBRID_VECTOR.SEARCH parameters merged with constructor params. + :returns: A dictionary containing the retrieved ``documents``. + """ + merged_filters = apply_filter_policy(self.filter_policy, self.filters, filters) + documents = await self.document_store._hybrid_retrieval_async( + query, + index_name=self.index_name, + search_mode=self.search_mode, + filters=merged_filters, + top_k=top_k if top_k is not None else self.top_k, + params=self._merged_params(params), + return_scores=self.return_scores, + ) + return {"documents": documents} + + def to_dict(self) -> dict[str, Any]: + """ + Serializes the component to a dictionary. + """ + return default_to_dict( + self, + document_store=self.document_store.to_dict(), + index_name=self.index_name, + search_mode=self.search_mode, + filters=self.filters, + top_k=self.top_k, + params=self.params, + return_scores=self.return_scores, + filter_policy=self.filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "OracleHybridRetriever": + """ + Deserializes the component from a dictionary. + """ + params = data.get("init_parameters", {}) + if "document_store" in params: + params["document_store"] = OracleDocumentStore.from_dict(params["document_store"]) + if filter_policy := params.get("filter_policy"): + params["filter_policy"] = FilterPolicy.from_str(filter_policy) + return default_from_dict(cls, data) diff --git a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/keyword_retriever.py b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/keyword_retriever.py index c4e6bfe185..1a10ccab6a 100644 --- a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/keyword_retriever.py +++ b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/keyword_retriever.py @@ -32,6 +32,15 @@ def __init__( top_k: int = 10, filter_policy: FilterPolicy = FilterPolicy.REPLACE, ) -> None: + """ + Create an Oracle keyword retriever. + + :param document_store: Oracle document store used for keyword search. + :param filters: Base Haystack metadata filters applied to every retrieval. + :param top_k: Maximum number of documents to return. + :param filter_policy: Policy for combining constructor filters with runtime filters. + :raises TypeError: If ``document_store`` is not an ``OracleDocumentStore``. + """ if not isinstance(document_store, OracleDocumentStore): msg = "document_store must be an instance of OracleDocumentStore" raise TypeError(msg) @@ -73,7 +82,14 @@ async def run_async( filters: dict[str, Any] | None = None, top_k: int | None = None, ) -> dict[str, list[Document]]: - """Async variant of :meth:`run`.""" + """ + Asynchronously retrieve documents by keyword search. + + :param query: Keyword query string. + :param filters: Runtime filters, merged with constructor filters according to filter_policy. + :param top_k: Override the constructor top_k for this call. + :returns: ``{"documents": [Document, ...]}`` + """ filters = apply_filter_policy(self.filter_policy, self.filters, filters) docs = await self.document_store._keyword_retrieval_async( query, diff --git a/integrations/oracle/src/haystack_integrations/document_stores/oracle/__init__.py b/integrations/oracle/src/haystack_integrations/document_stores/oracle/__init__.py index 1aa47db7cf..c14b42ad84 100644 --- a/integrations/oracle/src/haystack_integrations/document_stores/oracle/__init__.py +++ b/integrations/oracle/src/haystack_integrations/document_stores/oracle/__init__.py @@ -5,6 +5,7 @@ from haystack_integrations.document_stores.oracle.document_store import ( OracleConnectionConfig, OracleDocumentStore, + OracleVectorizerPreference, ) -__all__ = ["OracleConnectionConfig", "OracleDocumentStore"] +__all__ = ["OracleConnectionConfig", "OracleDocumentStore", "OracleVectorizerPreference"] diff --git a/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py b/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py index 19391bab21..ace5a7c5aa 100644 --- a/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py +++ b/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py @@ -8,23 +8,30 @@ import logging import re import threading -from dataclasses import dataclass -from typing import Any, Literal +import uuid +from dataclasses import dataclass, replace +from typing import Any, Literal, cast import oracledb from haystack import default_from_dict, default_to_dict from haystack.dataclasses import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy +from haystack.errors import FilterError from haystack.utils import Secret, deserialize_secrets_inplace -from .filters import FilterTranslator +from .filters import FilterTranslator, to_hybrid_filter logger = logging.getLogger(__name__) _SAFE_TABLE_NAME = re.compile(r"^[A-Za-z_][A-Za-z0-9_$#]{0,127}$") _SAFE_FIELD_PATH = re.compile(r"^[A-Za-z0-9_.]+$") +_SAFE_HYBRID_PARAM = re.compile(r"^[A-Za-z0-9_.$#:/,+ -]+$") +_VALID_DISTANCE_METRICS = {"COSINE", "EUCLIDEAN", "DOT"} +_VALID_VECTOR_INDEX_TYPES = {"HNSW", "IVF"} +_VALID_HYBRID_SEARCH_MODES = {"keyword", "hybrid", "semantic"} MAX_INDEX_NAME_LEN = 128 +VectorIndexType = Literal["HNSW", "IVF"] def _validate_field_path(field_path: str) -> None: @@ -50,6 +57,255 @@ def _try_parse_number(value: Any) -> Any: return value +def _validate_identifier(identifier: str, field_name: str) -> str: + if not _SAFE_TABLE_NAME.match(identifier): + msg = ( + f"Invalid {field_name} {identifier!r}. Must be a valid Oracle identifier " + "(letters, digits, _, $, # — max 128 chars, cannot start with a digit)." + ) + raise ValueError(msg) + return identifier + + +def _is_missing_object_error(error: oracledb.DatabaseError) -> bool: + original_error = error.args[0] if error.args else error + error_code = getattr(original_error, "code", None) + message = str(error) + return error_code in {942, 1418} or "DRG-10502" in message or "index does not exist" in message.lower() + + +def _is_dbms_search_unavailable_error(error: oracledb.DatabaseError) -> bool: + message = str(error).upper() + return "PLS-00201" in message and "DBMS_SEARCH" in message + + +def _output_type_string_handler(cursor: Any, metadata: Any) -> Any: + if metadata.type_code is oracledb.DB_TYPE_CLOB: + return cursor.var(oracledb.DB_TYPE_LONG, arraysize=cursor.arraysize) + if metadata.type_code is oracledb.DB_TYPE_NCLOB: + return cursor.var(oracledb.DB_TYPE_LONG_NVARCHAR, arraysize=cursor.arraysize) + return None + + +def _validate_distance_metric(distance_metric: str) -> str: + metric = distance_metric.upper() + if metric not in _VALID_DISTANCE_METRICS: + msg = f"distance_metric must be one of {_VALID_DISTANCE_METRICS}, got {distance_metric!r}" + raise ValueError(msg) + return metric + + +def _validate_index_type(index_type: str) -> VectorIndexType: + normalized = index_type.upper() + if normalized not in _VALID_VECTOR_INDEX_TYPES: + msg = f"vector_index_type must be one of {_VALID_VECTOR_INDEX_TYPES}, got {index_type!r}" + raise ValueError(msg) + return cast(VectorIndexType, normalized) + + +def _validate_int_param(config: dict[str, Any], name: str, minimum: int, maximum: int | None = None) -> None: + if name not in config: + return + value = config[name] + if isinstance(value, bool) or not isinstance(value, int) or value < minimum: + msg = f"{name} must be an integer >= {minimum}" + raise ValueError(msg) + if maximum is not None and value > maximum: + msg = f"{name} must be an integer <= {maximum}" + raise ValueError(msg) + + +def _default_index_name(table_name: str, suffix: str) -> str: + return f"{table_name}_{suffix}"[:MAX_INDEX_NAME_LEN] + + +def _normalise_hnsw_params( + table_name: str, + *, + distance_metric: str, + hnsw_neighbors: int, + hnsw_ef_construction: int, + hnsw_accuracy: int, + hnsw_parallel: int, + params: dict[str, Any] | None, +) -> dict[str, Any]: + config = { + "idx_name": _default_index_name(table_name, "vidx"), + "idx_type": "HNSW", + "neighbors": hnsw_neighbors, + "efConstruction": hnsw_ef_construction, + "accuracy": hnsw_accuracy, + "parallel": hnsw_parallel, + "distance_metric": distance_metric, + } + if params: + user_params = dict(params) + if "efconstruction" in user_params and "efConstruction" not in user_params: + user_params["efConstruction"] = user_params.pop("efconstruction") + allowed = {"idx_name", "idx_type", "neighbors", "efConstruction", "accuracy", "parallel"} + invalid = set(user_params) - allowed + if invalid: + msg = f"Unsupported HNSW vector index parameter(s): {sorted(invalid)}" + raise ValueError(msg) + config.update(user_params) + if str(config["idx_type"]).upper() != "HNSW": + msg = "HNSW index parameters must use idx_type='HNSW'." + raise ValueError(msg) + config["idx_name"] = _validate_identifier(str(config["idx_name"]), "idx_name") + _validate_int_param(config, "neighbors", 2, 2048) + _validate_int_param(config, "efConstruction", 1, 65535) + _validate_int_param(config, "accuracy", 1, 100) + _validate_int_param(config, "parallel", 1) + return config + + +def _normalise_ivf_params(table_name: str, *, distance_metric: str, params: dict[str, Any] | None) -> dict[str, Any]: + config = { + "idx_name": _default_index_name(table_name, "ivf_vidx"), + "idx_type": "IVF", + "neighbor_partitions": 32, + "accuracy": 95, + "parallel": 4, + "distance_metric": distance_metric, + } + if params: + allowed = { + "idx_name", + "idx_type", + "neighbor_partitions", + "samples_per_partition", + "min_vectors_per_partition", + "accuracy", + "parallel", + } + invalid = set(params) - allowed + if invalid: + msg = f"Unsupported IVF vector index parameter(s): {sorted(invalid)}" + raise ValueError(msg) + config.update(params) + if str(config["idx_type"]).upper() != "IVF": + msg = "IVF index parameters must use idx_type='IVF'." + raise ValueError(msg) + config["idx_name"] = _validate_identifier(str(config["idx_name"]), "idx_name") + _validate_int_param(config, "neighbor_partitions", 1, 10_000_000) + _validate_int_param(config, "samples_per_partition", 1) + _validate_int_param(config, "min_vectors_per_partition", 0) + _validate_int_param(config, "accuracy", 1, 100) + _validate_int_param(config, "parallel", 1) + return config + + +def _get_vector_index_ddl( + table_name: str, + *, + index_type: str, + distance_metric: str, + hnsw_neighbors: int, + hnsw_ef_construction: int, + hnsw_accuracy: int, + hnsw_parallel: int, + params: dict[str, Any] | None = None, +) -> str: + normalized_type = _validate_index_type(index_type) + metric = _validate_distance_metric(distance_metric) + if normalized_type == "HNSW": + config = _normalise_hnsw_params( + table_name, + distance_metric=metric, + hnsw_neighbors=hnsw_neighbors, + hnsw_ef_construction=hnsw_ef_construction, + hnsw_accuracy=hnsw_accuracy, + hnsw_parallel=hnsw_parallel, + params=params, + ) + return f""" + CREATE VECTOR INDEX IF NOT EXISTS {config["idx_name"]} + ON {table_name}(embedding) + ORGANIZATION INMEMORY NEIGHBOR GRAPH + WITH TARGET ACCURACY {config["accuracy"]} + DISTANCE {config["distance_metric"]} + PARAMETERS (type HNSW, neighbors {config["neighbors"]}, + efConstruction {config["efConstruction"]}) + PARALLEL {config["parallel"]} + """ + + config = _normalise_ivf_params(table_name, distance_metric=metric, params=params) + parameters = f"type IVF, neighbor partitions {config['neighbor_partitions']}" + if "samples_per_partition" in config: + parameters += f", samples_per_partition {config['samples_per_partition']}" + if "min_vectors_per_partition" in config: + parameters += f", min_vectors_per_partition {config['min_vectors_per_partition']}" + return f""" + CREATE VECTOR INDEX IF NOT EXISTS {config["idx_name"]} + ON {table_name}(embedding) + ORGANIZATION NEIGHBOR PARTITIONS + WITH TARGET ACCURACY {config["accuracy"]} + DISTANCE {config["distance_metric"]} + PARAMETERS ({parameters}) + PARALLEL {config["parallel"]} + """ + + +def _serialize_hybrid_parameter(value: Any, field_name: str) -> str: + text = str(value) + if not _SAFE_HYBRID_PARAM.match(text): + msg = f"Invalid hybrid index {field_name} value: {value!r}" + raise ValueError(msg) + return text + + +def _hybrid_identifier_list(values: Any, field_name: str) -> str: + if not isinstance(values, list) or not values: + msg = f"{field_name} must be a non-empty list of Oracle identifiers." + raise ValueError(msg) + return ",".join(_validate_identifier(str(value), field_name) for value in values) + + +def _get_hybrid_index_ddl( + table_name: str, + idx_name: str, + vectorizer_preference: "OracleVectorizerPreference", + params: dict[str, Any] | None = None, +) -> str: + params = params or {} + index_parameters = dict(params.get("parameters") or {}) + if any(key.lower() in {"model", "embedder_spec", "vectorizer", "vector_idxtype"} for key in index_parameters): + msg = ( + "Vectorization parameters must be configured through OracleVectorizerPreference, " + "not under params['parameters']." + ) + raise ValueError(msg) + + parameter_parts = [f"vectorizer {_validate_identifier(vectorizer_preference.preference_name, 'preference_name')}"] + for key, value in index_parameters.items(): + parameter_parts.append( + f"{_serialize_hybrid_parameter(key, 'parameter name')} {_serialize_hybrid_parameter(value, 'parameter')}" + ) + + filter_by = "" + if params.get("filter_by"): + filter_by = " FILTER BY " + _hybrid_identifier_list(params["filter_by"], "filter_by") + + order_by = "" + if params.get("order_by"): + direction = "ASC" if params.get("order_by_asc", True) else "DESC" + order_by = " ORDER BY " + _hybrid_identifier_list(params["order_by"], "order_by") + f" {direction}" + + parallel = "" + if params.get("parallel") is not None: + parallel_value = params["parallel"] + if isinstance(parallel_value, bool) or not isinstance(parallel_value, int) or parallel_value <= 0: + msg = "parallel must be a positive integer." + raise ValueError(msg) + parallel = f" PARALLEL {parallel_value}" + + escaped_params = " ".join(parameter_parts).replace("'", "''") + return ( + f"CREATE HYBRID VECTOR INDEX {idx_name} ON {table_name}(text) " + f"PARAMETERS ('{escaped_params}'){filter_by}{order_by}{parallel}" + ) + + @dataclass class OracleConnectionConfig: """ @@ -99,6 +355,112 @@ def from_dict(cls, data: dict[str, Any]) -> "OracleConnectionConfig": return cls(**data) +class OracleVectorizerPreference: + """ + Manages DBMS_VECTOR_CHAIN vectorizer preferences used by Oracle hybrid vector indexes. + """ + + _CREATE_DDL = """ + BEGIN + DBMS_VECTOR_CHAIN.CREATE_PREFERENCE( + :preference_name, + DBMS_VECTOR_CHAIN.VECTORIZER, + JSON(:preference_params) + ); + END; + """ + _DROP_DDL = "BEGIN DBMS_VECTOR_CHAIN.DROP_PREFERENCE(:preference_name); END;" + + def __init__(self, document_store: "OracleDocumentStore", preference_name: str) -> None: + self.document_store = document_store + self.preference_name = _validate_identifier(preference_name, "preference_name") + + @staticmethod + def _preference_params(text_embedder: Any, params: dict[str, Any] | None = None) -> dict[str, Any]: + embedding_params = getattr(text_embedder, "embedding_params", None) + if embedding_params is None: + embedding_params = getattr(text_embedder, "_embedding_params", None) + if not isinstance(embedding_params, dict): + msg = "text_embedder must expose embedding_params as a dictionary." + raise ValueError(msg) + + preference_params = dict(params or {}) + if "model" in preference_params or "embedder_spec" in preference_params: + return preference_params + if embedding_params.get("provider") == "database": + preference_params["model"] = embedding_params.get("model") + else: + preference_params["embedder_spec"] = embedding_params + return preference_params + + @classmethod + def create( + cls, + document_store: "OracleDocumentStore", + text_embedder: Any, + preference_name: str | None = None, + params: dict[str, Any] | None = None, + ) -> "OracleVectorizerPreference": + """ + Creates a vectorizer preference. + """ + preference = cls(document_store, preference_name or f"pref{uuid.uuid4().hex[:15]}") + with document_store._get_connection() as conn, conn.cursor() as cur: + cur.execute( + cls._CREATE_DDL, + preference_name=preference.preference_name, + preference_params=json.dumps(cls._preference_params(text_embedder, params)), + ) + conn.commit() + return preference + + @classmethod + async def create_async( + cls, + document_store: "OracleDocumentStore", + text_embedder: Any, + preference_name: str | None = None, + params: dict[str, Any] | None = None, + ) -> "OracleVectorizerPreference": + """ + Creates a vectorizer preference asynchronously. + """ + if await document_store._has_async_pool(): + preference = cls(document_store, preference_name or f"pref{uuid.uuid4().hex[:15]}") + pool = await document_store._get_async_pool() + async with pool.acquire() as conn: + with conn.cursor() as cur: + await cur.execute( + cls._CREATE_DDL, + preference_name=preference.preference_name, + preference_params=json.dumps(cls._preference_params(text_embedder, params)), + ) + await conn.commit() + return preference + return await asyncio.to_thread(cls.create, document_store, text_embedder, preference_name, params) + + def drop(self) -> None: + """ + Drops this vectorizer preference. + """ + with self.document_store._get_connection() as conn, conn.cursor() as cur: + cur.execute(self._DROP_DDL, preference_name=self.preference_name) + conn.commit() + + async def drop_async(self) -> None: + """ + Drops this vectorizer preference asynchronously. + """ + if await self.document_store._has_async_pool(): + pool = await self.document_store._get_async_pool() + async with pool.acquire() as conn: + with conn.cursor() as cur: + await cur.execute(self._DROP_DDL, preference_name=self.preference_name) + await conn.commit() + return + await asyncio.to_thread(self.drop) + + class OracleDocumentStore: """ Haystack DocumentStore backed by Oracle AI Vector Search. @@ -132,6 +494,8 @@ def __init__( distance_metric: Literal["COSINE", "EUCLIDEAN", "DOT"] = "COSINE", create_table_if_not_exists: bool = True, create_index: bool = False, + vector_index_type: Literal["HNSW", "IVF"] = "HNSW", + vector_index_params: dict[str, Any] | None = None, hnsw_neighbors: int = 32, hnsw_ef_construction: int = 200, hnsw_accuracy: int = 95, @@ -151,6 +515,12 @@ def __init__( pre-existing table. :param create_index: When ``True``, creates an HNSW vector index on initialisation. Equivalent to calling :meth:`create_hnsw_index` manually. Defaults to ``False``. + :param vector_index_type: Oracle vector index type to create when ``create_index=True``. + Defaults to ``"HNSW"`` to preserve existing behavior. ``"IVF"`` is also supported. + :param vector_index_params: Optional vector index parameters. For HNSW, supported keys are + ``idx_name``, ``idx_type``, ``neighbors``, ``efConstruction``, ``accuracy``, and ``parallel``. + For IVF, supported keys are ``idx_name``, ``idx_type``, ``neighbor_partitions``, + ``samples_per_partition``, ``min_vectors_per_partition``, ``accuracy``, and ``parallel``. :param hnsw_neighbors: Number of neighbours in the HNSW graph. Higher values improve recall at the cost of index size and build time. Defaults to ``32``. :param hnsw_ef_construction: Size of the dynamic candidate list during HNSW index construction. @@ -161,12 +531,7 @@ def __init__( :raises ValueError: If ``table_name`` is not a valid Oracle identifier or ``embedding_dim`` is not a positive integer. """ - if not _SAFE_TABLE_NAME.match(table_name): - msg = ( - f"Invalid table_name {table_name!r}. Must be a valid Oracle identifier " - "(letters, digits, _, $, # — max 128 chars, cannot start with a digit)." - ) - raise ValueError(msg) + _validate_identifier(table_name, "table_name") if embedding_dim <= 0: msg = f"embedding_dim must be a positive integer, got {embedding_dim}" raise ValueError(msg) @@ -174,21 +539,43 @@ def __init__( self.connection_config = connection_config self.table_name = table_name self.embedding_dim = embedding_dim - self.distance_metric = distance_metric + self.distance_metric = _validate_distance_metric(distance_metric) self.create_table_if_not_exists = create_table_if_not_exists self.create_index = create_index + self.vector_index_type = _validate_index_type(vector_index_type) + self.vector_index_params = dict(vector_index_params) if vector_index_params else None self.hnsw_neighbors = hnsw_neighbors self.hnsw_ef_construction = hnsw_ef_construction self.hnsw_accuracy = hnsw_accuracy self.hnsw_parallel = hnsw_parallel self._pool: oracledb.ConnectionPool | None = None + self._async_pool: Any | None = None self._pool_lock = threading.Lock() if create_table_if_not_exists: self._ensure_table() if create_index: - self.create_hnsw_index() + self.create_vector_index(index_type=self.vector_index_type, params=self.vector_index_params) + + def _connect_kwargs(self, *, pool_options: bool = True) -> dict[str, Any]: + cfg = self.connection_config + password = cfg.password.resolve_value() + + connect_kwargs: dict[str, Any] = { + "user": cfg.user.resolve_value(), + "password": password, + "dsn": cfg.dsn.resolve_value(), + } + if pool_options: + connect_kwargs["min"] = cfg.min_connections + connect_kwargs["max"] = cfg.max_connections + connect_kwargs["increment"] = 1 + if cfg.wallet_location: + connect_kwargs["config_dir"] = cfg.wallet_location + connect_kwargs["wallet_location"] = cfg.wallet_location + connect_kwargs["wallet_password"] = cfg.wallet_password.resolve_value() if cfg.wallet_password else password + return connect_kwargs def _get_pool(self) -> oracledb.ConnectionPool: if self._pool is not None: @@ -197,36 +584,56 @@ def _get_pool(self) -> oracledb.ConnectionPool: if self._pool is not None: return self._pool - cfg = self.connection_config - password = cfg.password.resolve_value() - - connect_kwargs: dict[str, Any] = { - "user": cfg.user.resolve_value(), - "password": password, - "dsn": cfg.dsn.resolve_value(), - "min": cfg.min_connections, - "max": cfg.max_connections, - "increment": 1, - } - if cfg.wallet_location: - connect_kwargs["config_dir"] = cfg.wallet_location - connect_kwargs["wallet_location"] = cfg.wallet_location - connect_kwargs["wallet_password"] = ( - cfg.wallet_password.resolve_value() if cfg.wallet_password else password - ) - - self._pool = oracledb.create_pool(**connect_kwargs) + self._pool = oracledb.create_pool(**self._connect_kwargs()) return self._pool def _get_connection(self) -> oracledb.Connection: return self._get_pool().acquire() - def __del__(self) -> None: + async def _has_async_pool(self) -> bool: + return getattr(oracledb, "create_pool_async", None) is not None + + async def _get_async_pool(self) -> Any: + if self._async_pool is not None: + return self._async_pool + create_pool_async = getattr(oracledb, "create_pool_async", None) + if create_pool_async is None: + msg = "python-oracledb does not provide create_pool_async; install a version with async pool support." + raise RuntimeError(msg) + self._async_pool = create_pool_async(**self._connect_kwargs()) + return self._async_pool + + def close(self) -> None: + """ + Close synchronous Oracle resources owned by this document store. + + This releases the connection pool without deleting the backing table or indexes. + """ if self._pool is not None: + pool = self._pool + self._pool = None + try: + pool.close() + except Exception: + logger.warning("Failed to close Oracle connection pool.", exc_info=True) + + async def close_async(self) -> None: + """ + Close asynchronous and synchronous Oracle resources owned by this document store. + + This releases connection pools without deleting the backing table or indexes. + """ + if self._async_pool is not None: + pool = self._async_pool + self._async_pool = None try: - self._pool.close() + await pool.close() except Exception: - logger.warning("Failed to close Oracle connection pool during cleanup.", exc_info=True) + logger.warning("Failed to close Oracle async connection pool.", exc_info=True) + self.close() + + def __del__(self) -> None: + self.close() def _ensure_table(self) -> None: sql = f""" @@ -244,17 +651,28 @@ def _ensure_table(self) -> None: self._ensure_keyword_index() def _ensure_keyword_index(self) -> None: - index_name = f"{self.table_name}_search_idx" - if len(index_name) > MAX_INDEX_NAME_LEN: - index_name = index_name[:MAX_INDEX_NAME_LEN] + index_name = self._keyword_index_name() try: with self._get_connection() as conn, conn.cursor() as cur: cur.execute( - f"BEGIN DBMS_SEARCH.CREATE_INDEX('{index_name}'); " - f"DBMS_SEARCH.ADD_SOURCE('{index_name}', '{self.table_name}'); END;" + """ + BEGIN + DBMS_SEARCH.CREATE_INDEX(:index_name); + DBMS_SEARCH.ADD_SOURCE(:index_name, :table_name); + END; + """, + index_name=index_name, + table_name=self.table_name, ) conn.commit() except oracledb.DatabaseError as e: + if _is_dbms_search_unavailable_error(e): + logger.warning( + "DBMS_SEARCH is unavailable; skipping keyword index creation for %s. " + "Oracle keyword retrieval requires DBMS_SEARCH.", + index_name, + ) + return logger.debug("Could not create keyword index (may already exist): %s", e) def create_keyword_index(self) -> None: @@ -268,25 +686,68 @@ def create_keyword_index(self) -> None: """ self._ensure_keyword_index() + def create_vector_index( + self, + *, + index_type: Literal["HNSW", "IVF"] | None = None, + params: dict[str, Any] | None = None, + ) -> None: + """ + Create a vector index on the embedding column. + + Defaults to the document store's configured index type. Existing callers that use + :meth:`create_hnsw_index` keep the previous HNSW behavior. + """ + sql = _get_vector_index_ddl( + self.table_name, + index_type=index_type or self.vector_index_type, + distance_metric=self.distance_metric, + hnsw_neighbors=self.hnsw_neighbors, + hnsw_ef_construction=self.hnsw_ef_construction, + hnsw_accuracy=self.hnsw_accuracy, + hnsw_parallel=self.hnsw_parallel, + params=params if params is not None else self.vector_index_params, + ) + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql) + conn.commit() + def create_hnsw_index(self) -> None: """ Create an HNSW vector index on the embedding column. Safe to call multiple times — uses IF NOT EXISTS. """ - sql = f""" - CREATE VECTOR INDEX IF NOT EXISTS {self.table_name}_vidx - ON {self.table_name}(embedding) - ORGANIZATION INMEMORY NEIGHBOR GRAPH - WITH TARGET ACCURACY {self.hnsw_accuracy} - DISTANCE {self.distance_metric} - PARAMETERS (type HNSW, neighbors {self.hnsw_neighbors}, - efConstruction {self.hnsw_ef_construction}) - PARALLEL {self.hnsw_parallel} + self.create_vector_index(index_type="HNSW") + + async def create_vector_index_async( + self, + *, + index_type: Literal["HNSW", "IVF"] | None = None, + params: dict[str, Any] | None = None, + ) -> None: """ - with self._get_connection() as conn, conn.cursor() as cur: - cur.execute(sql) - conn.commit() + Asynchronously creates a vector index on the embedding column. + """ + if not await self._has_async_pool(): + await asyncio.to_thread(self.create_vector_index, index_type=index_type, params=params) + return + + sql = _get_vector_index_ddl( + self.table_name, + index_type=index_type or self.vector_index_type, + distance_metric=self.distance_metric, + hnsw_neighbors=self.hnsw_neighbors, + hnsw_ef_construction=self.hnsw_ef_construction, + hnsw_accuracy=self.hnsw_accuracy, + hnsw_parallel=self.hnsw_parallel, + params=params if params is not None else self.vector_index_params, + ) + pool = await self._get_async_pool() + async with pool.acquire() as conn: + with conn.cursor() as cur: + await cur.execute(sql) + await conn.commit() async def create_hnsw_index_async(self) -> None: """ @@ -294,7 +755,96 @@ async def create_hnsw_index_async(self) -> None: Safe to call multiple times — uses ``IF NOT EXISTS``. """ - await asyncio.to_thread(self.create_hnsw_index) + await self.create_vector_index_async(index_type="HNSW") + + def create_hybrid_vector_index( + self, + idx_name: str, + *, + text_embedder: Any | None = None, + vectorizer_preference: OracleVectorizerPreference | None = None, + params: dict[str, Any] | None = None, + ) -> OracleVectorizerPreference: + """ + Create a DBMS_HYBRID_VECTOR hybrid index over the document text column. + + Either provide an existing ``vectorizer_preference`` or a text embedder from which a new + preference can be created. The returned preference can be dropped by the caller if it was + created only for this index. + """ + created_preference = vectorizer_preference is None + if created_preference: + if text_embedder is None: + msg = "text_embedder is required when vectorizer_preference is not provided." + raise ValueError(msg) + vectorizer_preference = OracleVectorizerPreference.create(self, text_embedder) + if vectorizer_preference is None: + msg = "vectorizer_preference could not be created." + raise RuntimeError(msg) + quoted_idx_name = _validate_identifier(idx_name, "idx_name") + ddl = _get_hybrid_index_ddl(self.table_name, quoted_idx_name, vectorizer_preference, params) + try: + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(ddl) + conn.commit() + except Exception: + if created_preference: + try: + vectorizer_preference.drop() + except Exception: + logger.exception( + "Failed to drop auto-created vectorizer preference %s after hybrid index creation failed.", + vectorizer_preference.preference_name, + ) + raise + return vectorizer_preference + + async def create_hybrid_vector_index_async( + self, + idx_name: str, + *, + text_embedder: Any | None = None, + vectorizer_preference: OracleVectorizerPreference | None = None, + params: dict[str, Any] | None = None, + ) -> OracleVectorizerPreference: + """ + Asynchronously create a DBMS_HYBRID_VECTOR hybrid index over the document text column. + """ + created_preference = vectorizer_preference is None + if created_preference: + if text_embedder is None: + msg = "text_embedder is required when vectorizer_preference is not provided." + raise ValueError(msg) + vectorizer_preference = await OracleVectorizerPreference.create_async(self, text_embedder) + if vectorizer_preference is None: + msg = "vectorizer_preference could not be created." + raise RuntimeError(msg) + try: + if not await self._has_async_pool(): + return await asyncio.to_thread( + self.create_hybrid_vector_index, + idx_name, + vectorizer_preference=vectorizer_preference, + params=params, + ) + quoted_idx_name = _validate_identifier(idx_name, "idx_name") + ddl = _get_hybrid_index_ddl(self.table_name, quoted_idx_name, vectorizer_preference, params) + pool = await self._get_async_pool() + async with pool.acquire() as conn: + with conn.cursor() as cur: + await cur.execute(ddl) + await conn.commit() + except Exception: + if created_preference: + try: + await vectorizer_preference.drop_async() + except Exception: + logger.exception( + "Failed to drop auto-created vectorizer preference %s after hybrid index creation failed.", + vectorizer_preference.preference_name, + ) + raise + return vectorizer_preference def write_documents( self, @@ -384,20 +934,33 @@ def _skip_duplicate_documents(self, documents: list[Document]) -> int: def _upsert_documents(self, documents: list[Document]) -> int: sql = f""" - MERGE INTO {self.table_name} t - USING (SELECT :doc_id AS id FROM dual) s ON (t.id = s.id) - WHEN MATCHED THEN - UPDATE SET t.text = :doc_text, t.metadata = :doc_meta, t.embedding = :doc_emb - WHEN NOT MATCHED THEN - INSERT (id, text, metadata, embedding) - VALUES (s.id, :doc_text, :doc_meta, :doc_emb) + BEGIN + UPDATE {self.table_name} + SET text = :doc_text, + metadata = :doc_meta, + embedding = :doc_emb + WHERE id = :doc_id; + + IF SQL%ROWCOUNT = 0 THEN + BEGIN + INSERT INTO {self.table_name} (id, text, metadata, embedding) + VALUES (:doc_id, :doc_text, :doc_meta, :doc_emb); + EXCEPTION + WHEN DUP_VAL_ON_INDEX THEN + UPDATE {self.table_name} + SET text = :doc_text, + metadata = :doc_meta, + embedding = :doc_emb + WHERE id = :doc_id; + END; + END IF; + END; """ rows = [OracleDocumentStore._to_named_row(d) for d in documents] with self._get_connection() as conn, conn.cursor() as cur: cur.executemany(sql, rows) - written = cur.rowcount conn.commit() - return written + return len(rows) async def write_documents_async( self, @@ -437,6 +1000,7 @@ def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Docume where, params = OracleDocumentStore._build_where(filters) sql = f"SELECT id, text, JSON_SERIALIZE(metadata) AS metadata FROM {self.table_name} {where}" with self._get_connection() as conn, conn.cursor() as cur: + cur.outputtypehandler = _output_type_string_handler cur.execute(sql, params) rows = cur.fetchall() return [OracleDocumentStore._row_to_document(r) for r in rows] @@ -498,17 +1062,42 @@ async def count_documents_async(self) -> int: """ return await asyncio.to_thread(self.count_documents) + def _keyword_index_name(self) -> str: + index_name = f"{self.table_name}_search_idx" + return index_name[:MAX_INDEX_NAME_LEN] + + def _drop_keyword_index(self, cur: Any) -> None: + index_name = self._keyword_index_name() + sql = "BEGIN DBMS_SEARCH.DROP_INDEX(:index_name); END;" + try: + cur.execute(sql, index_name=index_name) + except oracledb.DatabaseError as e: + if _is_missing_object_error(e): + logger.debug("Keyword index %s was already absent during table cleanup.", index_name) + return + if _is_dbms_search_unavailable_error(e): + logger.debug("DBMS_SEARCH is unavailable; skipping keyword index cleanup for %s.", index_name) + return + logger.debug("Failed to drop keyword index. SQL: %s", sql) + msg = ( + f"Failed to drop keyword index '{index_name}'. Error: {e!r}. " + "You can find the SQL query in the debug logs." + ) + raise DocumentStoreError(msg) from e + def delete_table(self) -> None: """ Permanently drops the document store table and its associated DBMS_SEARCH keyword index. Uses ``DROP TABLE ... PURGE`` which bypasses the Oracle recycle bin — the operation is - irreversible. The keyword index is dropped after the table; if either operation fails a + irreversible. The DBMS_SEARCH keyword index is dropped before the table because it is created + through the DBMS_SEARCH PL/SQL API. If either operation fails a :class:`DocumentStoreError` is raised. :raises DocumentStoreError: If the table or keyword index cannot be dropped. """ with self._get_connection() as conn, conn.cursor() as cur: + self._drop_keyword_index(cur) sql = f"DROP TABLE {self.table_name} PURGE" try: cur.execute(sql) @@ -519,24 +1108,11 @@ def delete_table(self) -> None: "You can find the SQL query in the debug logs." ) raise DocumentStoreError(msg) from e - index_name = f"{self.table_name}_search_idx" - if len(index_name) > MAX_INDEX_NAME_LEN: - index_name = index_name[:MAX_INDEX_NAME_LEN] - sql = f"BEGIN DBMS_SEARCH.DROP_INDEX('{index_name}'); END;" - try: - cur.execute(sql) - except oracledb.DatabaseError as e: - logger.debug("Failed to drop keyword index. SQL: %s", sql) - msg = ( - f"Failed to drop keyword index '{index_name}'. Error: {e!r}. " - "You can find the SQL query in the debug logs." - ) - raise DocumentStoreError(msg) from e conn.commit() async def delete_table_async(self) -> None: """ - Asynchronously permanently drops the document store table and its DBMS_SEARCH keyword index. + Asynchronously permanently drops the document store table and its indexes. Uses ``DROP TABLE ... PURGE`` which bypasses the Oracle recycle bin — the operation is irreversible. @@ -717,15 +1293,15 @@ def get_metadata_fields_info(self) -> dict[str, dict[str, str]]: """ sql = f"SELECT JSON_DATAGUIDE(metadata) FROM {self.table_name}" with self._get_connection() as conn, conn.cursor() as cur: + cur.outputtypehandler = _output_type_string_handler cur.execute(sql) row = cur.fetchone() if not row or not row[0]: return {} - raw_guide = row[0].read() if hasattr(row[0], "read") else row[0] - if not raw_guide: + if not row[0]: return {} fields: dict[str, dict[str, str]] = {} - dataguide = json.loads(raw_guide) + dataguide = json.loads(row[0]) for path_info in dataguide: path = path_info.get("o:path", "") if path.startswith("$."): @@ -858,6 +1434,175 @@ async def get_metadata_field_unique_values_async( """ return await asyncio.to_thread(self.get_metadata_field_unique_values, metadata_field, search_term, from_, size) + @staticmethod + def _validate_hybrid_params(params: dict[str, Any]) -> dict[str, Any]: + forbidden_top_level = {"search_text", "return"} + forbidden_vector = {"search_text", "search_vector"} + forbidden_text = {"search_text", "search_vector", "contains", "json_textcontains"} + if forbidden_top_level & set(params): + msg = "search_text and return are derived internally and cannot be set in params." + raise ValueError(msg) + if forbidden_vector & set(params.get("vector") or {}): + msg = "params['vector'] cannot contain search_text or search_vector." + raise ValueError(msg) + if forbidden_text & set(params.get("text") or {}): + msg = "params['text'] cannot contain search_text, search_vector, contains, or json_textcontains." + raise ValueError(msg) + return dict(params) + + def _hybrid_search_params( + self, + query: str, + *, + index_name: str, + search_mode: Literal["keyword", "hybrid", "semantic"], + filters: dict[str, Any] | None, + top_k: int, + params: dict[str, Any] | None, + ) -> dict[str, Any]: + if search_mode not in _VALID_HYBRID_SEARCH_MODES: + msg = f"search_mode must be one of {_VALID_HYBRID_SEARCH_MODES}, got {search_mode!r}" + raise ValueError(msg) + + search_params = self._validate_hybrid_params(params or {}) + search_params["hybrid_index_name"] = index_name + + if search_mode in {"hybrid", "semantic"}: + search_params["vector"] = dict(search_params.get("vector") or {}) + search_params["vector"]["search_text"] = query + if search_mode in {"hybrid", "keyword"}: + search_params["text"] = dict(search_params.get("text") or {}) + search_params["text"]["search_text"] = query + + if filters: + if "filter_by" in search_params: + msg = "Cannot combine Haystack filters with params['filter_by']." + raise FilterError(msg) + search_params["filter_by"] = to_hybrid_filter(filters) + + search_params["return"] = { + "topN": top_k, + "values": ["rowid", "score", "vector_score", "text_score"], + "format": "JSON", + } + return search_params + + @staticmethod + def _with_hybrid_scores(search_row: dict[str, Any], document: Document, *, return_scores: bool) -> Document: + score = search_row.get("score") + if not return_scores: + return replace(document, score=score) + return replace( + document, + score=score, + meta={ + **document.meta, + "score": score, + "text_score": search_row.get("text_score"), + "vector_score": search_row.get("vector_score"), + }, + ) + + def _hybrid_retrieval( + self, + query: str, + *, + index_name: str, + search_mode: Literal["keyword", "hybrid", "semantic"] = "hybrid", + filters: dict[str, Any] | None = None, + top_k: int = 10, + params: dict[str, Any] | None = None, + return_scores: bool = False, + ) -> list[Document]: + search_params = self._hybrid_search_params( + query, + index_name=index_name, + search_mode=search_mode, + filters=filters, + top_k=top_k, + params=params, + ) + + documents: list[Document] = [] + with self._get_connection() as conn, conn.cursor() as cur: + cur.outputtypehandler = _output_type_string_handler + cur.setinputsizes(search_params=oracledb.DB_TYPE_JSON) + cur.execute("SELECT DBMS_HYBRID_VECTOR.SEARCH(JSON(:search_params))", search_params=search_params) + search_rows = json.loads(cur.fetchone()[0]) + for search_row in search_rows: + cur.execute( + f"SELECT id, text, JSON_SERIALIZE(metadata) AS metadata FROM {self.table_name} WHERE ROWID = :rid", + rid=search_row["rowid"], + ) + document_row = cur.fetchone() + if document_row is None: + continue + document = OracleDocumentStore._row_to_document(document_row) + document = OracleDocumentStore._with_hybrid_scores(search_row, document, return_scores=return_scores) + documents.append(document) + + return documents + + async def _hybrid_retrieval_async( + self, + query: str, + *, + index_name: str, + search_mode: Literal["keyword", "hybrid", "semantic"] = "hybrid", + filters: dict[str, Any] | None = None, + top_k: int = 10, + params: dict[str, Any] | None = None, + return_scores: bool = False, + ) -> list[Document]: + if not await self._has_async_pool(): + return await asyncio.to_thread( + self._hybrid_retrieval, + query, + index_name=index_name, + search_mode=search_mode, + filters=filters, + top_k=top_k, + params=params, + return_scores=return_scores, + ) + + search_params = self._hybrid_search_params( + query, + index_name=index_name, + search_mode=search_mode, + filters=filters, + top_k=top_k, + params=params, + ) + + documents: list[Document] = [] + pool = await self._get_async_pool() + async with pool.acquire() as conn: + with conn.cursor() as cur: + cur.outputtypehandler = _output_type_string_handler + cur.setinputsizes(search_params=oracledb.DB_TYPE_JSON) + await cur.execute( + "SELECT DBMS_HYBRID_VECTOR.SEARCH(JSON(:search_params))", + search_params=search_params, + ) + search_rows = json.loads((await cur.fetchone())[0]) + for search_row in search_rows: + await cur.execute( + "SELECT id, text, JSON_SERIALIZE(metadata) AS metadata " + f"FROM {self.table_name} WHERE ROWID = :rid", + rid=search_row["rowid"], + ) + document_row = await cur.fetchone() + if document_row is None: + continue + document = OracleDocumentStore._row_to_document(document_row) + document = OracleDocumentStore._with_hybrid_scores( + search_row, document, return_scores=return_scores + ) + documents.append(document) + + return documents + def _embedding_retrieval( self, query_embedding: list[float], @@ -880,6 +1625,7 @@ def _embedding_retrieval( params["query_vec"] = _array.array("f", query_embedding) params["top_k"] = top_k with self._get_connection() as conn, conn.cursor() as cur: + cur.outputtypehandler = _output_type_string_handler try: cur.execute(sql, params) except oracledb.DatabaseError as e: @@ -899,19 +1645,46 @@ async def _embedding_retrieval_async( filters: dict[str, Any] | None = None, top_k: int = 10, ) -> list[Document]: - return await asyncio.to_thread( - self._embedding_retrieval, - query_embedding, - filters=filters, - top_k=top_k, - ) + if not await self._has_async_pool(): + return await asyncio.to_thread( + self._embedding_retrieval, + query_embedding, + filters=filters, + top_k=top_k, + ) + + order = "ASC" + where, params = OracleDocumentStore._build_where(filters) + sql = f""" + SELECT id, text, JSON_SERIALIZE(metadata) AS metadata, + vector_distance(embedding, :query_vec, {self.distance_metric}) AS score + FROM {self.table_name} + {where} + ORDER BY score {order} + FETCH APPROX FIRST :top_k ROWS ONLY + """ + params["query_vec"] = _array.array("f", query_embedding) + params["top_k"] = top_k + pool = await self._get_async_pool() + async with pool.acquire() as conn: + with conn.cursor() as cur: + cur.outputtypehandler = _output_type_string_handler + try: + await cur.execute(sql, params) + except oracledb.DatabaseError as e: + logger.debug("Async embedding retrieval failed. SQL: %s\nParams: %s", sql, params) + msg = ( + f"Async embedding retrieval failed. Error: {e!r}. " + "You can find the SQL query and the parameters in the debug logs." + ) + raise DocumentStoreError(msg) from e + rows = await cur.fetchall() + return [OracleDocumentStore._row_to_document(r, with_score=True) for r in rows] def _keyword_retrieval( self, query: str, *, filters: dict[str, Any] | None = None, top_k: int = 10 ) -> list[Document]: - index_name = f"{self.table_name}_search_idx" - if len(index_name) > MAX_INDEX_NAME_LEN: - index_name = index_name[:MAX_INDEX_NAME_LEN] + index_name = self._keyword_index_name() where, params = OracleDocumentStore._build_where(filters) where_cond = where.replace("WHERE", "WHERE t.") if where else "" sql = f""" @@ -931,6 +1704,7 @@ def _keyword_retrieval( params["query"] = query params["top_k"] = top_k with self._get_connection() as conn, conn.cursor() as cur: + cur.outputtypehandler = _output_type_string_handler try: cur.execute(sql, params) except oracledb.DatabaseError as e: @@ -955,12 +1729,6 @@ def _row_to_document(row: tuple, *, with_score: bool = False) -> Document: else: raw_id, text, metadata_raw, score = *row, None - # oracledb returns CLOB/JSON as LOB objects — read them to strings - if hasattr(text, "read"): - text = text.read() - if hasattr(metadata_raw, "read"): - metadata_raw = metadata_raw.read() - if isinstance(metadata_raw, str): meta = json.loads(metadata_raw) elif isinstance(metadata_raw, dict): @@ -984,19 +1752,23 @@ def to_dict(self) -> dict[str, Any]: :returns: Dictionary with serialized data. """ - return default_to_dict( - self, - connection_config=self.connection_config.to_dict(), - table_name=self.table_name, - embedding_dim=self.embedding_dim, - distance_metric=self.distance_metric, - create_table_if_not_exists=self.create_table_if_not_exists, - create_index=self.create_index, - hnsw_neighbors=self.hnsw_neighbors, - hnsw_ef_construction=self.hnsw_ef_construction, - hnsw_accuracy=self.hnsw_accuracy, - hnsw_parallel=self.hnsw_parallel, - ) + init_parameters: dict[str, Any] = { + "connection_config": self.connection_config.to_dict(), + "table_name": self.table_name, + "embedding_dim": self.embedding_dim, + "distance_metric": self.distance_metric, + "create_table_if_not_exists": self.create_table_if_not_exists, + "create_index": self.create_index, + "hnsw_neighbors": self.hnsw_neighbors, + "hnsw_ef_construction": self.hnsw_ef_construction, + "hnsw_accuracy": self.hnsw_accuracy, + "hnsw_parallel": self.hnsw_parallel, + } + if self.vector_index_type != "HNSW": + init_parameters["vector_index_type"] = self.vector_index_type + if self.vector_index_params is not None: + init_parameters["vector_index_params"] = self.vector_index_params + return default_to_dict(self, **init_parameters) @classmethod def from_dict(cls, data: dict[str, Any]) -> "OracleDocumentStore": diff --git a/integrations/oracle/src/haystack_integrations/document_stores/oracle/filters.py b/integrations/oracle/src/haystack_integrations/document_stores/oracle/filters.py index f2acdb3091..8f21aea44f 100644 --- a/integrations/oracle/src/haystack_integrations/document_stores/oracle/filters.py +++ b/integrations/oracle/src/haystack_integrations/document_stores/oracle/filters.py @@ -2,12 +2,14 @@ # # SPDX-License-Identifier: Apache-2.0 +import re from datetime import datetime from typing import Any, ClassVar from haystack.errors import FilterError _RANGE_OPS = {">", ">=", "<", "<="} +_JSON_FIELD_NAME = r"^[A-Za-z0-9_.]+$" class FilterTranslator: @@ -69,13 +71,26 @@ def translate(self, filters: dict[str, Any], params: dict[str, Any], counter: li msg = f"'value' key missing in comparison filter: {filters}" raise FilterError(msg) - if not isinstance(op, str) or op not in {*self._OP_MAP, "in", "not in"}: + if not isinstance(op, str) or op not in {*self._OP_MAP, "in", "not in", "contains", "not contains"}: msg = f"Unsupported filter operator: {op!r}" raise FilterError(msg) field: str = filters["field"] value: Any = filters["value"] + if op in ("contains", "not contains"): + if isinstance(value, list): + msg = f"{op!r} filter values must be scalar, got list" + raise FilterError(msg) + json_path = FilterTranslator._field_to_json_path(field) + pname = f"p{counter[0]}" + counter[0] += 1 + params[pname] = value + contains_sql = f"JSON_EXISTS(metadata, '{json_path}[*]?(@ == $val)' PASSING :{pname} AS \"val\")" + if op == "contains": + return contains_sql + return f"(NOT {contains_sql})" + if op in ("in", "not in"): if not isinstance(value, list): msg = f"'in' / 'not in' filter values must be a list, got {type(value).__name__!r}" @@ -135,6 +150,102 @@ def _field_to_sql(field: str, value: Any) -> str: return f"TO_NUMBER({json_path})" return json_path + @staticmethod + def _field_to_json_path(field: str) -> str: + if not field.startswith("meta."): + msg = f"Operator 'contains' supports only metadata fields, got {field!r}" + raise FilterError(msg) + key = field[len("meta.") :] + if not re.match(_JSON_FIELD_NAME, key): + msg = f"Invalid metadata field name: {field!r}" + raise FilterError(msg) + return f"$.{key}" + + +def _infer_hybrid_filter_type(value: Any) -> str: + if isinstance(value, bool): + msg = "Boolean values are not supported for Oracle hybrid filters." + raise FilterError(msg) + if isinstance(value, (int, float)): + return "number" + if isinstance(value, str): + return "string" + msg = "Oracle hybrid filters support only string and numeric values." + raise FilterError(msg) + + +def _hybrid_filter_path(field: str) -> str: + if not field.startswith("meta."): + msg = "Oracle hybrid retrieval supports only filters under the 'meta.' field." + raise FilterError(msg) + if not re.match(_JSON_FIELD_NAME, field): + msg = f"Invalid metadata field name: {field!r}" + raise FilterError(msg) + return "metadata." + field[len("meta.") :] + + +def to_hybrid_filter(filters: dict[str, Any]) -> dict[str, Any]: + """ + Converts Haystack filters into DBMS_HYBRID_VECTOR.SEARCH filter_by JSON. + """ + op = filters.get("operator") + if op in ("AND", "OR", "NOT"): + if "conditions" not in filters: + msg = f"'conditions' key missing in logical filter: {filters}" + raise FilterError(msg) + return {"op": op, "args": [to_hybrid_filter(condition) for condition in filters["conditions"]]} + + if "field" not in filters: + msg = f"'field' key missing in comparison filter: {filters}" + raise FilterError(msg) + if "operator" not in filters: + msg = f"'operator' key missing in comparison filter: {filters}" + raise FilterError(msg) + if "value" not in filters: + msg = f"'value' key missing in comparison filter: {filters}" + raise FilterError(msg) + + field = _hybrid_filter_path(filters["field"]) + value = filters["value"] + if value is None: + msg = "Oracle hybrid retrieval does not support null comparisons." + raise FilterError(msg) + if op in {"contains", "not contains"}: + msg = f"Filter operation {op!r} is not supported for Oracle hybrid retrieval." + raise FilterError(msg) + + if op in {"in", "not in"}: + if not isinstance(value, list) or not value: + msg = f"{op!r} filter requires a non-empty list." + raise FilterError(msg) + value_type = _infer_hybrid_filter_type(value[0]) + if any(_infer_hybrid_filter_type(item) != value_type for item in value): + msg = "Oracle hybrid retrieval requires all 'in' filter values to share one type." + raise FilterError(msg) + hybrid_filter: dict[str, Any] = {"op": "IN", "path": field, "type": value_type, "args": value} + if op == "not in": + return {"op": "NOT", "args": [hybrid_filter]} + return hybrid_filter + + hybrid_op_map = { + "==": "=", + "!=": "!=", + ">": ">", + ">=": ">=", + "<": "<", + "<=": "<=", + } + if not isinstance(op, str) or op not in hybrid_op_map: + msg = f"Unsupported filter operator: {op!r}" + raise FilterError(msg) + + return { + "op": hybrid_op_map[op], + "path": field, + "type": _infer_hybrid_filter_type(value), + "args": [value], + } + def _is_iso_date(value: Any) -> bool: """Return True if *value* is a string that Python recognises as a valid ISO-8601 datetime.""" diff --git a/integrations/oracle/tests/conftest.py b/integrations/oracle/tests/conftest.py index aa85d651ab..384a282078 100644 --- a/integrations/oracle/tests/conftest.py +++ b/integrations/oracle/tests/conftest.py @@ -2,27 +2,95 @@ # # SPDX-License-Identifier: Apache-2.0 +import os import uuid from unittest.mock import MagicMock +import oracledb as _oracledb import pytest from haystack.dataclasses import Document from haystack.utils import Secret from haystack_integrations.document_stores.oracle import OracleConnectionConfig, OracleDocumentStore -_USER = "haystack" -_PASSWORD = "haystack" -_DSN = "localhost:1521/freepdb1" +_ORACLE_FEATURE_INTEGRATION_FILE = "test_oracle_features_integration.py" +_ORACLE_FEATURE_SKIP_REASONS = { + "ORA-00904": "Oracle vector APIs are unavailable in this live database", + "ORA-51962": "Oracle vector memory area is exhausted in this live database", +} + + +def _env_value(*names: str, default: str | None = None) -> str | None: + for name in names: + value = os.getenv(name) + if value: + return value + return default + + +def _oracle_feature_skip_reason(exc: BaseException) -> str | None: + if not isinstance(exc, _oracledb.DatabaseError): + return None + message = str(exc) + for error_code, reason in _ORACLE_FEATURE_SKIP_REASONS.items(): + if error_code in message: + return reason + if "PLS-00201" in message and "DBMS_VECTOR_CHAIN" in message: + return "Oracle DBMS_VECTOR_CHAIN APIs are unavailable in this live database" + return None + + +def _is_oracle_feature_integration_test(item) -> bool: + return item.path.name == _ORACLE_FEATURE_INTEGRATION_FILE and item.get_closest_marker("integration") is not None + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_makereport(item, call): + outcome = yield + report = outcome.get_result() + if call.when != "call" or not report.failed or call.excinfo is None: + return + if not _is_oracle_feature_integration_test(item): + return + + reason = _oracle_feature_skip_reason(call.excinfo.value) + if reason is None: + return + + report.outcome = "skipped" + report.longrepr = (str(item.path), report.location[1], f"Skipped: {reason}") + + +def connection_config(*, secret_source: str = "token") -> OracleConnectionConfig: + wallet_location = _env_value("ORACLE_WALLET_LOCATION") + if secret_source == "env_var": + wallet_password = Secret.from_env_var("ORACLE_WALLET_PASSWORD", strict=False) if wallet_location else None + return OracleConnectionConfig( + user=Secret.from_env_var("ORACLE_USER", strict=False), + password=Secret.from_env_var("ORACLE_PASSWORD", strict=False), + dsn=Secret.from_env_var("ORACLE_DSN", strict=False), + wallet_location=wallet_location, + wallet_password=wallet_password, + ) + + wallet_password = _env_value("ORACLE_WALLET_PASSWORD") + return OracleConnectionConfig( + user=Secret.from_token(_env_value("ORACLE_USER", default="haystack")), + password=Secret.from_token(_env_value("ORACLE_PASSWORD", default="haystack")), + dsn=Secret.from_token(_env_value("ORACLE_DSN", default="localhost:1521/freepdb1")), + wallet_location=wallet_location, + wallet_password=Secret.from_token(wallet_password) if wallet_password else None, + ) + + +@pytest.fixture(name="connection_config") +def connection_config_fixture(): + return connection_config def _make_store(table: str, embedding_dim: int) -> OracleDocumentStore: return OracleDocumentStore( - connection_config=OracleConnectionConfig( - user=Secret.from_token(_USER), - password=Secret.from_token(_PASSWORD), - dsn=Secret.from_token(_DSN), - ), + connection_config=connection_config(), table_name=table, embedding_dim=embedding_dim, distance_metric="COSINE", @@ -35,10 +103,13 @@ def document_store(): """768-dim store required by the mixin's filterable_docs fixture.""" table = f"hs_sync_{uuid.uuid4().hex[:8]}" s = _make_store(table, embedding_dim=768) - yield s - with s._get_connection() as conn, conn.cursor() as cur: - cur.execute(f"DROP TABLE {table} PURGE") - conn.commit() + try: + yield s + finally: + try: + s.delete_table() + finally: + s.close() @pytest.fixture @@ -46,10 +117,13 @@ def embedding_store(): """4-dim store for embedding-retrieval, HNSW, and async tests.""" table = f"hs_emb_{uuid.uuid4().hex[:8]}" s = _make_store(table, embedding_dim=4) - yield s - with s._get_connection() as conn, conn.cursor() as cur: - cur.execute(f"DROP TABLE {table} PURGE") - conn.commit() + try: + yield s + finally: + try: + s.delete_table() + finally: + s.close() @pytest.fixture @@ -81,12 +155,10 @@ def patched_store(monkeypatch): monkeypatch.setenv("ORACLE_USER", "u") monkeypatch.setenv("ORACLE_PASSWORD", "p") monkeypatch.setenv("ORACLE_DSN", "localhost/xe") + monkeypatch.delenv("ORACLE_WALLET_LOCATION", raising=False) + monkeypatch.delenv("ORACLE_WALLET_PASSWORD", raising=False) return OracleDocumentStore( - connection_config=OracleConnectionConfig( - user=Secret.from_env_var("ORACLE_USER"), - password=Secret.from_env_var("ORACLE_PASSWORD"), - dsn=Secret.from_env_var("ORACLE_DSN"), - ), + connection_config=connection_config(secret_source="env_var"), table_name="test_docs", embedding_dim=4, create_table_if_not_exists=False, @@ -97,11 +169,14 @@ def patched_store(monkeypatch): def mock_store(): """MagicMock of OracleDocumentStore for retriever unit tests.""" store = MagicMock(spec=OracleDocumentStore) + store.table_name = "test_docs" store.distance_metric = "COSINE" store._embedding_retrieval.return_value = [Document(id="A" * 32, content="hi")] store._embedding_retrieval_async.return_value = [Document(id="A" * 32, content="hi")] store._keyword_retrieval.return_value = [Document(id="A" * 32, content="hi")] store._keyword_retrieval_async.return_value = [Document(id="A" * 32, content="hi")] + store._hybrid_retrieval.return_value = [Document(id="A" * 32, content="hi")] + store._hybrid_retrieval_async.return_value = [Document(id="A" * 32, content="hi")] store.to_dict.return_value = { "type": "haystack_integrations.document_stores.oracle.document_store.OracleDocumentStore", "init_parameters": { diff --git a/integrations/oracle/tests/test_document_store.py b/integrations/oracle/tests/test_document_store.py index fe855b7a5e..81cbdc3061 100644 --- a/integrations/oracle/tests/test_document_store.py +++ b/integrations/oracle/tests/test_document_store.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import logging import uuid import oracledb as _oracledb @@ -27,13 +28,9 @@ FilterableDocsFixtureMixin, UpdateByFilterAsyncTest, ) -from haystack.utils import Secret -from haystack_integrations.document_stores.oracle import OracleConnectionConfig, OracleDocumentStore - -_USER = "haystack" -_PASSWORD = "haystack" -_DSN = "localhost:1521/freepdb1" +from haystack_integrations.document_stores.oracle import OracleDocumentStore +from haystack_integrations.document_stores.oracle.document_store import _output_type_string_handler def _doc(doc_id: str, content: str = "hello", meta: dict | None = None, embedding: list[float] | None = None): @@ -47,6 +44,23 @@ def _uid(suffix: str = "") -> str: return f"{base}{suffix.upper():>4}"[:32] +def test_connection_config_reads_oracle_wallet_env(monkeypatch, connection_config): + monkeypatch.setenv("ORACLE_USER", "wallet_user") + monkeypatch.setenv("ORACLE_PASSWORD", "wallet_password") + monkeypatch.setenv("ORACLE_DSN", "wallet_dsn") + monkeypatch.setenv("ORACLE_WALLET_LOCATION", "/opt/oracle/wallet") + monkeypatch.setenv("ORACLE_WALLET_PASSWORD", "wallet_secret") + + config = connection_config() + + assert config.user.resolve_value() == "wallet_user" + assert config.password.resolve_value() == "wallet_password" + assert config.dsn.resolve_value() == "wallet_dsn" + assert config.wallet_location == "/opt/oracle/wallet" + assert config.wallet_password is not None + assert config.wallet_password.resolve_value() == "wallet_secret" + + @pytest.mark.integration class TestOracleDocumentStore( DocumentStoreBaseTests, @@ -66,24 +80,23 @@ def _mock_doc(content="hello", embedding=None, doc_id="AABB" * 8): return Document(id=doc_id, content=content, meta={"k": "v"}, embedding=embedding) @pytest.fixture - def document_store(self): + def document_store(self, connection_config): """768-dim store — overrides the mixin's NotImplementedError stub.""" table = f"hs_sync_{uuid.uuid4().hex[:8]}" s = OracleDocumentStore( - connection_config=OracleConnectionConfig( - user=Secret.from_token(_USER), - password=Secret.from_token(_PASSWORD), - dsn=Secret.from_token(_DSN), - ), + connection_config=connection_config(), table_name=table, embedding_dim=768, distance_metric="COSINE", create_table_if_not_exists=True, ) - yield s - with s._get_connection() as conn, conn.cursor() as cur: - cur.execute(f"DROP TABLE {table} PURGE") - conn.commit() + try: + yield s + finally: + try: + s.delete_table() + finally: + s.close() # Mixin override def assert_documents_are_equal(self, received: list[Document], expected: list[Document]) -> None: @@ -136,13 +149,18 @@ def test_write_documents_skip_policy_uses_merge_not_matched(self, patched_store, assert "WHEN NOT MATCHED" in sql assert "WHEN MATCHED" not in sql - def test_write_documents_overwrite_policy_uses_full_merge(self, patched_store, mock_pool): + def test_write_documents_overwrite_policy_uses_update_insert_fallback(self, patched_store, mock_pool): _, _, cursor = mock_pool - patched_store.write_documents([self._mock_doc()], policy=DuplicatePolicy.OVERWRITE) + count = patched_store.write_documents( + [self._mock_doc(), self._mock_doc(doc_id="CCDD" * 8)], policy=DuplicatePolicy.OVERWRITE + ) sql = cursor.executemany.call_args[0][0] - assert "MERGE INTO" in sql - assert "WHEN MATCHED" in sql - assert "WHEN NOT MATCHED" in sql + assert count == 2 + assert "MERGE INTO" not in sql + assert "UPDATE test_docs" in sql + assert "IF SQL%ROWCOUNT = 0 THEN" in sql + assert "INSERT INTO test_docs" in sql + assert "WHEN DUP_VAL_ON_INDEX THEN" in sql def test_write_documents_returns_count(self, patched_store, mock_pool): # noqa: ARG002 count = patched_store.write_documents( @@ -163,6 +181,7 @@ def test_filter_documents_no_filter_fetches_all(self, patched_store, mock_pool): ] docs = patched_store.filter_documents() assert len(docs) == 2 + assert cursor.outputtypehandler is _output_type_string_handler sql = cursor.execute.call_args[0][0] assert "WHERE" not in sql @@ -220,6 +239,29 @@ def test_from_dict_roundtrip(self, patched_store): assert restored.embedding_dim == patched_store.embedding_dim assert restored.distance_metric == patched_store.distance_metric + def test_output_type_handler_converts_lobs_to_strings(self): + class FakeCursor: + arraysize = 50 + + def __init__(self): + self.call = None + + def var(self, type_code, *, arraysize): + self.call = (type_code, arraysize) + return "converted" + + class FakeMetadata: + def __init__(self, type_code): + self.type_code = type_code + + cursor = FakeCursor() + + assert _output_type_string_handler(cursor, FakeMetadata(_oracledb.DB_TYPE_CLOB)) == "converted" + assert cursor.call == (_oracledb.DB_TYPE_LONG, 50) + + assert _output_type_string_handler(cursor, FakeMetadata(_oracledb.DB_TYPE_NCLOB)) == "converted" + assert cursor.call == (_oracledb.DB_TYPE_LONG_NVARCHAR, 50) + def test_create_hnsw_index_sql(self, patched_store, mock_pool): _, _, cursor = mock_pool patched_store.create_hnsw_index() @@ -229,6 +271,90 @@ def test_create_hnsw_index_sql(self, patched_store, mock_pool): assert str(patched_store.hnsw_neighbors) in sql assert str(patched_store.hnsw_ef_construction) in sql + def test_create_keyword_index_uses_deterministic_bound_index_name(self, patched_store, mock_pool): + _, conn, cursor = mock_pool + + patched_store.create_keyword_index() + + sql = cursor.execute.call_args.args[0] + assert "DBMS_SEARCH.CREATE_INDEX(:index_name)" in sql + assert "DBMS_SEARCH.ADD_SOURCE(:index_name, :table_name)" in sql + assert cursor.execute.call_args.kwargs == { + "index_name": "test_docs_search_idx", + "table_name": "test_docs", + } + conn.commit.assert_called_once() + + def test_create_keyword_index_warns_when_dbms_search_is_unavailable(self, patched_store, mock_pool, caplog): + _, conn, cursor = mock_pool + cursor.execute.side_effect = _oracledb.DatabaseError( + "ORA-06550: line 1, column 7:\nPLS-00201: identifier 'DBMS_SEARCH.CREATE_INDEX' must be declared" + ) + + with caplog.at_level(logging.WARNING): + patched_store.create_keyword_index() + + assert "DBMS_SEARCH is unavailable" in caplog.text + conn.commit.assert_not_called() + + def test_delete_table_drops_search_index_before_table(self, patched_store, mock_pool): + _, conn, cursor = mock_pool + + patched_store.delete_table() + + calls = cursor.execute.call_args_list + assert calls[0].args[0] == "BEGIN DBMS_SEARCH.DROP_INDEX(:index_name); END;" + assert calls[0].kwargs == {"index_name": "test_docs_search_idx"} + executed_sql = [call.args[0] for call in calls] + assert executed_sql[-1] == "DROP TABLE test_docs PURGE" + conn.commit.assert_called_once() + + def test_delete_table_skips_keyword_cleanup_when_dbms_search_is_unavailable(self, patched_store, mock_pool): + _, conn, cursor = mock_pool + cursor.execute.side_effect = [ + _oracledb.DatabaseError( + "ORA-06550: line 1, column 7:\nPLS-00201: identifier 'DBMS_SEARCH.DROP_INDEX' must be declared" + ), + None, + ] + + patched_store.delete_table() + + calls = cursor.execute.call_args_list + assert calls[0].args[0] == "BEGIN DBMS_SEARCH.DROP_INDEX(:index_name); END;" + assert calls[1].args[0] == "DROP TABLE test_docs PURGE" + conn.commit.assert_called_once() + + def test_close_closes_sync_pool(self, patched_store, mock_pool): + pool, _, _ = mock_pool + + patched_store.count_documents() + patched_store.close() + + pool.close.assert_called_once() + assert patched_store._pool is None + + @pytest.mark.asyncio + async def test_close_async_closes_async_and_sync_pools(self, patched_store, mock_pool): + class FakeAsyncPool: + def __init__(self): + self.closed = False + + async def close(self): + self.closed = True + + pool, _, _ = mock_pool + async_pool = FakeAsyncPool() + patched_store._async_pool = async_pool + + patched_store.count_documents() + await patched_store.close_async() + + assert async_pool.closed is True + assert patched_store._async_pool is None + pool.close.assert_called_once() + assert patched_store._pool is None + def test_write_documents_empty_list_returns_zero(self, document_store): assert document_store.write_documents([]) == 0 diff --git a/integrations/oracle/tests/test_document_store_features.py b/integrations/oracle/tests/test_document_store_features.py new file mode 100644 index 0000000000..4113d38e9d --- /dev/null +++ b/integrations/oracle/tests/test_document_store_features.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json + +import pytest + +from haystack_integrations.document_stores.oracle import OracleVectorizerPreference + + +def test_create_vector_index_ivf_sql(patched_store, mock_pool): + _, _, cursor = mock_pool + patched_store.create_vector_index( + index_type="IVF", + params={"idx_name": "TEST_DOCS_IVF", "neighbor_partitions": 64, "samples_per_partition": 8}, + ) + + sql = cursor.execute.call_args[0][0] + + assert "CREATE VECTOR INDEX IF NOT EXISTS TEST_DOCS_IVF" in sql + assert "ORGANIZATION NEIGHBOR PARTITIONS" in sql + assert "neighbor partitions 64" in sql + assert "samples_per_partition 8" in sql + + +def test_create_hybrid_vector_index_uses_text_column(patched_store, mock_pool): + _, _, cursor = mock_pool + preference = OracleVectorizerPreference(patched_store, "PREF_TEST_DOCS") + + returned = patched_store.create_hybrid_vector_index( + "TEST_DOCS_HYBRID", + vectorizer_preference=preference, + params={"parameters": {"language": "american"}, "parallel": 2}, + ) + sql = cursor.execute.call_args[0][0] + + assert returned is preference + assert "CREATE HYBRID VECTOR INDEX TEST_DOCS_HYBRID ON test_docs(text)" in sql + assert "vectorizer PREF_TEST_DOCS" in sql + assert "language american" in sql + assert "PARALLEL 2" in sql + + +def test_create_hybrid_vector_index_drops_auto_preference_on_failure(patched_store, mock_pool, monkeypatch): + _, _, cursor = mock_pool + cursor.execute.side_effect = RuntimeError("DDL failed") + preference = OracleVectorizerPreference(patched_store, "PREF_TEST_DOCS") + drop_calls = [] + + def create_preference(_cls, _document_store, _text_embedder): + return preference + + def drop_preference(): + drop_calls.append(preference.preference_name) + + monkeypatch.setattr(OracleVectorizerPreference, "create", classmethod(create_preference)) + monkeypatch.setattr(preference, "drop", drop_preference) + + with pytest.raises(RuntimeError, match="DDL failed"): + patched_store.create_hybrid_vector_index("TEST_DOCS_HYBRID", text_embedder=object()) + + assert drop_calls == ["PREF_TEST_DOCS"] + + +def test_create_hybrid_vector_index_keeps_caller_preference_on_failure(patched_store, mock_pool, monkeypatch): + _, _, cursor = mock_pool + cursor.execute.side_effect = RuntimeError("DDL failed") + preference = OracleVectorizerPreference(patched_store, "PREF_TEST_DOCS") + drop_calls = [] + + def drop_preference(): + drop_calls.append(preference.preference_name) + + monkeypatch.setattr(preference, "drop", drop_preference) + + with pytest.raises(RuntimeError, match="DDL failed"): + patched_store.create_hybrid_vector_index("TEST_DOCS_HYBRID", vectorizer_preference=preference) + + assert drop_calls == [] + + +@pytest.mark.asyncio +async def test_create_hybrid_vector_index_async_drops_auto_preference_on_fallback_failure(patched_store, monkeypatch): + preference = OracleVectorizerPreference(patched_store, "PREF_TEST_DOCS") + drop_calls = [] + + async def create_preference(_cls, _document_store, _text_embedder): + return preference + + async def has_async_pool(): + return False + + async def drop_preference(): + drop_calls.append(preference.preference_name) + + def create_index(_idx_name, **_kwargs): + msg = "DDL failed" + raise RuntimeError(msg) + + monkeypatch.setattr(OracleVectorizerPreference, "create_async", classmethod(create_preference)) + monkeypatch.setattr(patched_store, "_has_async_pool", has_async_pool) + monkeypatch.setattr(patched_store, "create_hybrid_vector_index", create_index) + monkeypatch.setattr(preference, "drop_async", drop_preference) + + with pytest.raises(RuntimeError, match="DDL failed"): + await patched_store.create_hybrid_vector_index_async("TEST_DOCS_HYBRID", text_embedder=object()) + + assert drop_calls == ["PREF_TEST_DOCS"] + + +def test_hybrid_retrieval_keeps_scores_aligned_when_rowid_disappears(patched_store, mock_pool): + _, _, cursor = mock_pool + matching_id = "B" * 32 + search_rows = [ + {"rowid": "deleted_rowid", "score": 0.9, "text_score": 0.8, "vector_score": 0.7}, + {"rowid": "matching_rowid", "score": 0.3, "text_score": 0.2, "vector_score": 0.1}, + ] + cursor.fetchone.side_effect = [ + (json.dumps(search_rows),), + None, + (matching_id, "matched document", '{"lang":"en"}'), + ] + + documents = patched_store._hybrid_retrieval( + "query", + index_name="TEST_DOCS_HYBRID", + top_k=2, + return_scores=True, + ) + + assert len(documents) == 1 + assert documents[0].id == matching_id + assert documents[0].score == 0.3 + assert documents[0].meta == {"lang": "en", "score": 0.3, "text_score": 0.2, "vector_score": 0.1} + + +def test_default_to_dict_omits_new_vector_index_fields(patched_store): + data = patched_store.to_dict() + + assert "vector_index_type" not in data["init_parameters"] + assert "vector_index_params" not in data["init_parameters"] + + +def test_custom_vector_index_config_roundtrips(patched_store): + patched_store.vector_index_type = "IVF" + patched_store.vector_index_params = {"idx_name": "TEST_DOCS_IVF", "neighbor_partitions": 64} + + restored = patched_store.from_dict(patched_store.to_dict()) + + assert restored.vector_index_type == "IVF" + assert restored.vector_index_params == {"idx_name": "TEST_DOCS_IVF", "neighbor_partitions": 64} diff --git a/integrations/oracle/tests/test_embedders.py b/integrations/oracle/tests/test_embedders.py new file mode 100644 index 0000000000..9fe937fd99 --- /dev/null +++ b/integrations/oracle/tests/test_embedders.py @@ -0,0 +1,278 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from haystack.dataclasses import Document + +from haystack_integrations.components.embedders.oracle import OracleDocumentEmbedder, OracleTextEmbedder +from haystack_integrations.components.embedders.oracle._base import ( + _execute_with_fetch_lobs, + _execute_with_fetch_lobs_async, +) + + +def test_text_embedder_requires_connection_config(): + with pytest.raises(ValueError, match="connection_config"): + OracleTextEmbedder( + connection_config=None, + embedding_params={"provider": "database", "model": "demo"}, + ) + + +def test_text_embedder_run_returns_single_embedding(monkeypatch, connection_config): + embedder = OracleTextEmbedder( + connection_config=connection_config(secret_source="env_var"), + embedding_params={"provider": "database", "model": "demo"}, + ) + + def embed_documents(texts): + assert texts == ["hello"] + return [[0.1, 0.2, 0.3]] + + monkeypatch.setattr(embedder, "_embed_documents", embed_documents) + + result = embedder.run("hello") + + assert result["embedding"] == [0.1, 0.2, 0.3] + assert result["meta"] == {"provider": "database", "model": "demo"} + + +def test_text_embedder_rejects_non_string(connection_config): + embedder = OracleTextEmbedder( + connection_config=connection_config(secret_source="env_var"), + embedding_params={"provider": "database", "model": "demo"}, + ) + with pytest.raises(TypeError, match="expects a string"): + embedder.run(["not text"]) + + +def test_text_embedder_execute_uses_fetch_lobs_when_supported(): + class FakeCursor: + def __init__(self): + self.call = None + + def execute(self, statement, parameters=None, *, fetch_lobs=True): + self.call = (statement, parameters, fetch_lobs) + + cursor = FakeCursor() + + _execute_with_fetch_lobs(cursor, "SELECT 1 FROM DUAL", ["p0"]) + + assert cursor.call == ("SELECT 1 FROM DUAL", ["p0"], False) + + +@pytest.mark.asyncio +async def test_text_embedder_execute_async_uses_fetch_lobs_when_supported(): + class FakeCursor: + def __init__(self): + self.call = None + + async def execute(self, statement, parameters=None, *, fetch_lobs=True): + self.call = (statement, parameters, fetch_lobs) + + cursor = FakeCursor() + + await _execute_with_fetch_lobs_async(cursor, "SELECT 1 FROM DUAL", ["p0"]) + + assert cursor.call == ("SELECT 1 FROM DUAL", ["p0"], False) + + +@pytest.mark.asyncio +async def test_text_embedder_async_awaits_gettype(monkeypatch, connection_config): + class FakeLob: + def __init__(self): + self.value = None + + async def write(self, value): + self.value = value + + class FakeVectorArray(list): + pass + + class FakeVectorArrayType: + def newobject(self): + return FakeVectorArray() + + class FakeCursor: + def __init__(self): + self.rows = iter([('{"embed_vector": "[0.1, 0.2, 0.3]"}',)]) + + def __enter__(self): + return self + + def __exit__(self, *_): + return None + + def setinputsizes(self, *_): + return None + + async def execute(self, statement, params): + assert "UTL_TO_EMBEDDINGS" in statement + assert len(params[0]) == 1 + assert params[0][0].value == '{"chunk_id": 1, "chunk_data": "hello"}' + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.rows) + except StopIteration as exc: + raise StopAsyncIteration from exc + + class FakeConnection: + def cursor(self): + return FakeCursor() + + async def gettype(self, name): + assert name == "SYS.VECTOR_ARRAY_T" + return FakeVectorArrayType() + + async def createlob(self, *_): + return FakeLob() + + class FakeConnectionContext: + async def __aenter__(self): + return FakeConnection() + + async def __aexit__(self, *_): + return None + + embedder = OracleTextEmbedder( + connection_config=connection_config(secret_source="env_var"), + embedding_params={"provider": "database", "model": "demo"}, + ) + + async def connection_context_async(): + return FakeConnectionContext() + + monkeypatch.setattr(embedder, "_connection_context_async", connection_context_async) + + result = await embedder.run_async("hello") + + assert result["embedding"] == [0.1, 0.2, 0.3] + + +@pytest.mark.asyncio +async def test_text_embedder_async_reads_lob_rows(monkeypatch, connection_config): + class FakeLob: + async def write(self, _): + return None + + async def read(self): + return '{"embed_vector": "[0.1, 0.2, 0.3]"}' + + class FakeVectorArray(list): + pass + + class FakeVectorArrayType: + def newobject(self): + return FakeVectorArray() + + class FakeCursor: + def __init__(self): + self.rows = iter([(FakeLob(),)]) + + def __enter__(self): + return self + + def __exit__(self, *_): + return None + + def setinputsizes(self, *_): + return None + + async def execute(self, *_): + return None + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.rows) + except StopIteration as exc: + raise StopAsyncIteration from exc + + class FakeConnection: + def cursor(self): + return FakeCursor() + + async def gettype(self, name): + assert name == "SYS.VECTOR_ARRAY_T" + return FakeVectorArrayType() + + async def createlob(self, *_): + return FakeLob() + + class FakeConnectionContext: + async def __aenter__(self): + return FakeConnection() + + async def __aexit__(self, *_): + return None + + embedder = OracleTextEmbedder( + connection_config=connection_config(secret_source="env_var"), + embedding_params={"provider": "database", "model": "demo"}, + ) + + async def connection_context_async(): + return FakeConnectionContext() + + monkeypatch.setattr(embedder, "_connection_context_async", connection_context_async) + + result = await embedder.run_async("hello") + + assert result["embedding"] == [0.1, 0.2, 0.3] + + +def test_document_embedder_prepares_metadata_and_content(connection_config): + embedder = OracleDocumentEmbedder( + connection_config=connection_config(secret_source="env_var"), + embedding_params={"provider": "database", "model": "demo"}, + meta_fields_to_embed=["title", "missing"], + embedding_separator=" | ", + ) + + assert not isinstance(embedder, OracleTextEmbedder) + + texts = embedder._prepare_texts_to_embed([Document(content="body", meta={"title": "heading"})]) + + assert texts == ["heading | body"] + + +def test_document_embedder_sets_document_embeddings(monkeypatch, connection_config): + embedder = OracleDocumentEmbedder( + connection_config=connection_config(secret_source="env_var"), + embedding_params={"provider": "database", "model": "demo"}, + ) + + def embed_documents(texts): + assert texts == ["one", "two"] + return [[0.4, 0.5], [0.6, 0.7]] + + monkeypatch.setattr(embedder, "_embed_documents", embed_documents) + documents = [Document(content="one"), Document(content="two")] + + result = embedder.run(documents) + + assert result["documents"] is not documents + assert documents[0].embedding is None + assert documents[1].embedding is None + assert result["documents"][0].embedding == [0.4, 0.5] + assert result["documents"][1].embedding == [0.6, 0.7] + + +def test_embedder_to_dict_keeps_connection_config_secret_structured(connection_config): + embedder = OracleTextEmbedder( + connection_config=connection_config(secret_source="env_var"), + embedding_params={"provider": "database", "model": "demo"}, + ) + + data = embedder.to_dict() + + assert "connection_params" not in data["init_parameters"] + password = data["init_parameters"]["connection_config"]["password"] + assert password["type"] == "env_var" diff --git a/integrations/oracle/tests/test_filter_translator.py b/integrations/oracle/tests/test_filter_translator.py index a31be3db82..7d9c22cc7c 100644 --- a/integrations/oracle/tests/test_filter_translator.py +++ b/integrations/oracle/tests/test_filter_translator.py @@ -134,3 +134,15 @@ def test_param_counter_increments_correctly(): } ) assert set(params.keys()) == {"p0", "p1", "p2"} + + +def test_contains_operator_uses_json_exists(): + sql, params = _translate({"field": "meta.tags", "operator": "contains", "value": "oracle"}) + assert "JSON_EXISTS(metadata, '$.tags[*]?(@ == $val)' PASSING :p0 AS \"val\")" in sql + assert params == {"p0": "oracle"} + + +def test_not_contains_operator_negates_json_exists(): + sql, params = _translate({"field": "meta.tags", "operator": "not contains", "value": "draft"}) + assert sql.startswith("(NOT JSON_EXISTS") + assert params == {"p0": "draft"} diff --git a/integrations/oracle/tests/test_hybrid_retriever.py b/integrations/oracle/tests/test_hybrid_retriever.py new file mode 100644 index 0000000000..e6ac005da6 --- /dev/null +++ b/integrations/oracle/tests/test_hybrid_retriever.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.errors import FilterError + +from haystack_integrations.components.retrievers.oracle import OracleHybridRetriever + + +def test_search_params_for_hybrid_mode(patched_store): + params = patched_store._hybrid_search_params( + "oracle vector", + index_name="TEST_HYBRID", + search_mode="hybrid", + filters=None, + top_k=5, + params=None, + ) + + assert params["hybrid_index_name"] == "TEST_HYBRID" + assert params["vector"]["search_text"] == "oracle vector" + assert params["text"]["search_text"] == "oracle vector" + assert params["return"]["topN"] == 5 + + +def test_search_params_for_keyword_mode(patched_store): + params = patched_store._hybrid_search_params( + "oracle vector", + index_name="TEST_HYBRID", + search_mode="keyword", + filters=None, + top_k=3, + params=None, + ) + + assert "vector" not in params + assert params["text"]["search_text"] == "oracle vector" + assert params["return"]["topN"] == 3 + + +def test_search_params_converts_filters(patched_store): + params = patched_store._hybrid_search_params( + "oracle", + index_name="TEST_HYBRID", + search_mode="hybrid", + filters={"field": "meta.lang", "operator": "==", "value": "en"}, + top_k=10, + params=None, + ) + + assert params["filter_by"] == {"op": "=", "path": "metadata.lang", "type": "string", "args": ["en"]} + + +def test_search_params_rejects_filter_by_collision(patched_store): + with pytest.raises(FilterError, match="Cannot combine"): + patched_store._hybrid_search_params( + "oracle", + index_name="TEST_HYBRID", + search_mode="hybrid", + filters={"field": "meta.lang", "operator": "==", "value": "en"}, + top_k=10, + params={"filter_by": {"op": "=", "path": "metadata.lang", "type": "string", "args": ["en"]}}, + ) + + +def test_validate_params_rejects_derived_search_text(mock_store): + with pytest.raises(ValueError, match="search_text"): + OracleHybridRetriever( + document_store=mock_store, + index_name="TEST_HYBRID", + params={"vector": {"search_text": "do not set"}}, + ) + + +def test_filter_policy_string_is_supported(mock_store): + retriever = OracleHybridRetriever( + document_store=mock_store, + index_name="TEST_HYBRID", + filter_policy="merge", + ) + + assert retriever.filter_policy == FilterPolicy.MERGE + + +def test_run_calls_hybrid_retrieval(mock_store): + documents = [Document(id="A" * 32, content="hi")] + mock_store._hybrid_retrieval.return_value = documents + retriever = OracleHybridRetriever( + document_store=mock_store, + index_name="TEST_HYBRID", + search_mode="semantic", + top_k=5, + params={"vector": {"score_weight": 2}}, + return_scores=True, + ) + filters = {"field": "meta.lang", "operator": "==", "value": "en"} + + result = retriever.run("oracle", filters=filters, top_k=3, params={"text": {"score_weight": 1}}) + + assert result["documents"] is documents + mock_store._hybrid_retrieval.assert_called_once_with( + "oracle", + index_name="TEST_HYBRID", + search_mode="semantic", + filters=filters, + top_k=3, + params={"vector": {"score_weight": 2}, "text": {"score_weight": 1}}, + return_scores=True, + ) + + +@pytest.mark.asyncio +async def test_run_async_calls_hybrid_retrieval_async(mock_store): + retriever = OracleHybridRetriever(document_store=mock_store, index_name="TEST_HYBRID", top_k=5) + + result = await retriever.run_async("oracle") + + mock_store._hybrid_retrieval_async.assert_called_once_with( + "oracle", + index_name="TEST_HYBRID", + search_mode="hybrid", + filters={}, + top_k=5, + params={}, + return_scores=False, + ) + assert "documents" in result + + +def test_to_dict_from_dict_roundtrip(mock_store): + retriever = OracleHybridRetriever( + document_store=mock_store, + index_name="TEST_HYBRID", + search_mode="semantic", + filters={"field": "meta.lang", "operator": "==", "value": "en"}, + top_k=2, + return_scores=True, + ) + + data = retriever.to_dict() + restored = OracleHybridRetriever.from_dict(data) + + assert restored.index_name == "TEST_HYBRID" + assert restored.search_mode == "semantic" + assert restored.filters == {"field": "meta.lang", "operator": "==", "value": "en"} + assert restored.top_k == 2 + assert restored.return_scores is True diff --git a/integrations/oracle/tests/test_oracle_features_integration.py b/integrations/oracle/tests/test_oracle_features_integration.py new file mode 100644 index 0000000000..3993cfc088 --- /dev/null +++ b/integrations/oracle/tests/test_oracle_features_integration.py @@ -0,0 +1,361 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +import uuid +from collections.abc import Iterator +from contextlib import contextmanager + +import oracledb +import pytest +from haystack import Pipeline +from haystack.dataclasses import Document +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.components.embedders.oracle import OracleDocumentEmbedder, OracleTextEmbedder +from haystack_integrations.components.retrievers.oracle import OracleEmbeddingRetriever, OracleHybridRetriever +from haystack_integrations.document_stores.oracle import ( + OracleDocumentStore, + OracleVectorizerPreference, +) + +pytestmark = pytest.mark.integration + +_DEFAULT_EMBEDDING_MODEL = "ALL_MINILM_L12_V2" + + +def _env_value(*names: str, default: str | None = None) -> str | None: + for name in names: + value = os.getenv(name) + if value: + return value + return default + + +def _embedding_params() -> dict: + if params := _env_value("ORACLE_EMBEDDING_PARAMS"): + return json.loads(params) + return { + "provider": _env_value("ORACLE_EMBEDDING_PROVIDER", default="database"), + "model": _env_value("ORACLE_EMBEDDING_MODEL", default=_DEFAULT_EMBEDDING_MODEL), + } + + +def _proxy() -> str | None: + return _env_value("ORACLE_EMBEDDING_PROXY", "ORACLE_PROXY") + + +def _table_name(prefix: str) -> str: + return f"{prefix}_{uuid.uuid4().hex[:12]}".upper() + + +def _drop_table(store: OracleDocumentStore) -> None: + store.delete_table() + + +def _drop_sql_index_if_exists(store: OracleDocumentStore, index_name: str) -> None: + if not index_name.replace("_", "").isalnum(): + msg = f"Invalid test index name: {index_name}" + raise ValueError(msg) + try: + with store._get_connection() as conn, conn.cursor() as cur: + cur.execute(f"DROP INDEX {index_name}") + conn.commit() + except oracledb.DatabaseError as exc: + message = str(exc) + if "ORA-01418" in message or "ORA-00942" in message: + return + raise + + +@contextmanager +def _temporary_store( + connection_config, embedding_dim: int = 4, *, prefix: str = "HS_IT" +) -> Iterator[OracleDocumentStore]: + store = OracleDocumentStore( + connection_config=connection_config(), + table_name=_table_name(prefix), + embedding_dim=embedding_dim, + distance_metric="COSINE", + create_table_if_not_exists=True, + ) + try: + yield store + finally: + try: + _drop_table(store) + finally: + store.close() + + +def _text_embedder(connection_config) -> OracleTextEmbedder: + return OracleTextEmbedder( + connection_config=connection_config(), + embedding_params=_embedding_params(), + proxy=_proxy(), + ) + + +def _document_embedder(connection_config) -> OracleDocumentEmbedder: + return OracleDocumentEmbedder( + connection_config=connection_config(), + embedding_params=_embedding_params(), + proxy=_proxy(), + meta_fields_to_embed=["title"], + ) + + +def test_contains_and_not_contains_filters_live(connection_config) -> None: + run_id = uuid.uuid4().hex + with _temporary_store(connection_config, prefix="HS_FLT") as store: + store.write_documents( + [ + Document(content="Oracle vector search", meta={"run_id": run_id, "tags": ["oracle", "vector"]}), + Document(content="Haystack pipelines", meta={"run_id": run_id, "tags": ["haystack", "pipeline"]}), + ], + policy=DuplicatePolicy.NONE, + ) + + contains_results = store.filter_documents( + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.run_id", "operator": "==", "value": run_id}, + {"field": "meta.tags", "operator": "contains", "value": "oracle"}, + ], + } + ) + not_contains_results = store.filter_documents( + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.run_id", "operator": "==", "value": run_id}, + {"field": "meta.tags", "operator": "not contains", "value": "oracle"}, + ], + } + ) + + assert [doc.content for doc in contains_results] == ["Oracle vector search"] + assert [doc.content for doc in not_contains_results] == ["Haystack pipelines"] + + +def test_hnsw_vector_index_creation_live(connection_config) -> None: + with _temporary_store(connection_config, prefix="HS_HNSW") as hnsw_store: + hnsw_index_name = f"{hnsw_store.table_name}_HNSW" + hnsw_store.write_documents( + [Document(content="hnsw", embedding=[1.0, 0.0, 0.0, 0.0])], + policy=DuplicatePolicy.NONE, + ) + try: + hnsw_store.create_vector_index( + index_type="HNSW", + params={ + "idx_name": hnsw_index_name, + "neighbors": 2, + "efConstruction": 16, + "accuracy": 80, + "parallel": 1, + }, + ) + finally: + _drop_sql_index_if_exists(hnsw_store, hnsw_index_name) + + +def test_ivf_vector_index_creation_live(connection_config) -> None: + with _temporary_store(connection_config, prefix="HS_IVF") as ivf_store: + ivf_index_name = f"{ivf_store.table_name}_IVF" + ivf_store.write_documents( + [Document(content="ivf", embedding=[1.0, 0.0, 0.0, 0.0])], + policy=DuplicatePolicy.NONE, + ) + try: + ivf_store.create_vector_index( + index_type="IVF", + params={ + "idx_name": ivf_index_name, + "neighbor_partitions": 1, + "accuracy": 90, + "parallel": 1, + }, + ) + finally: + _drop_sql_index_if_exists(ivf_store, ivf_index_name) + + +@pytest.mark.asyncio +async def test_async_ivf_vector_index_creation_live(connection_config) -> None: + with _temporary_store(connection_config, prefix="HS_AIVF") as store: + index_name = f"{store.table_name}_IVF" + store.write_documents( + [Document(content="async ivf", embedding=[1.0, 0.0, 0.0, 0.0])], + policy=DuplicatePolicy.NONE, + ) + try: + await store.create_vector_index_async( + index_type="IVF", + params={ + "idx_name": index_name, + "neighbor_partitions": 1, + "accuracy": 90, + "parallel": 1, + }, + ) + finally: + _drop_sql_index_if_exists(store, index_name) + + +def test_oracle_embedders_pipeline_retrieval_live(connection_config) -> None: + text_embedder = _text_embedder(connection_config) + query_embedding = text_embedder.run("Oracle Database vector search")["embedding"] + run_id = uuid.uuid4().hex + + with _temporary_store(connection_config, embedding_dim=len(query_embedding), prefix="HS_EMB") as store: + docs = [ + Document(content="Oracle Database supports AI Vector Search.", meta={"run_id": run_id, "title": "Oracle"}), + Document(content="Haystack pipelines connect components.", meta={"run_id": run_id, "title": "Haystack"}), + ] + embedded_docs = _document_embedder(connection_config).run(docs)["documents"] + store.write_documents(embedded_docs, policy=DuplicatePolicy.NONE) + + pipeline = Pipeline() + pipeline.add_component("text_embedder", text_embedder) + pipeline.add_component( + "retriever", + OracleEmbeddingRetriever( + document_store=store, + filters={"field": "meta.run_id", "operator": "==", "value": run_id}, + top_k=2, + ), + ) + pipeline.connect("text_embedder.embedding", "retriever.query_embedding") + + result = pipeline.run({"text_embedder": {"text": "Oracle vector search"}}) + + retrieved = result["retriever"]["documents"] + assert retrieved + assert any("Oracle Database" in doc.content for doc in retrieved) + + +@pytest.mark.asyncio +async def test_oracle_text_embedder_async_live(connection_config) -> None: + if not hasattr(oracledb, "connect_async"): + pytest.skip("python-oracledb does not provide connect_async") + + result = await _text_embedder(connection_config).run_async("Oracle Database vector search") + + assert result["embedding"] + assert all(isinstance(value, float) for value in result["embedding"]) + + +def test_vectorizer_preference_create_drop_live(connection_config) -> None: + preference: OracleVectorizerPreference | None = None + with _temporary_store(connection_config, prefix="HS_PREF") as store: + try: + preference = OracleVectorizerPreference.create( + store, + _text_embedder(connection_config), + preference_name=f"{store.table_name}_PREF", + ) + assert preference.preference_name == f"{store.table_name}_PREF" + finally: + if preference is not None: + preference.drop() + + +@pytest.mark.asyncio +async def test_async_hybrid_vector_index_creation_live(connection_config) -> None: + text_embedder = _text_embedder(connection_config) + query_embedding = text_embedder.run("Oracle hybrid vector search")["embedding"] + index_name = "" + store = OracleDocumentStore( + connection_config=connection_config(), + table_name=_table_name("HS_AHYB"), + embedding_dim=len(query_embedding), + distance_metric="COSINE", + create_table_if_not_exists=True, + ) + preference: OracleVectorizerPreference | None = None + try: + store.write_documents( + [Document(content="Oracle hybrid vector search", embedding=query_embedding)], + policy=DuplicatePolicy.NONE, + ) + index_name = f"{store.table_name}_HIDX" + preference = await store.create_hybrid_vector_index_async( + index_name, + text_embedder=text_embedder, + params={"parallel": 1}, + ) + assert isinstance(preference, OracleVectorizerPreference) + finally: + try: + try: + if index_name: + _drop_sql_index_if_exists(store, index_name) + _drop_table(store) + finally: + if preference is not None: + preference.drop() + finally: + store.close() + + +def test_hybrid_retriever_live(connection_config) -> None: + text_embedder = _text_embedder(connection_config) + document_embedder = _document_embedder(connection_config) + query_embedding = text_embedder.run("Oracle hybrid search")["embedding"] + index_name = "" + store = OracleDocumentStore( + connection_config=connection_config(), + table_name=_table_name("HS_HYB"), + embedding_dim=len(query_embedding), + distance_metric="COSINE", + create_table_if_not_exists=True, + ) + preference: OracleVectorizerPreference | None = None + try: + docs = document_embedder.run( + [ + Document(content="Oracle Database hybrid vector search.", meta={"title": "Oracle", "lang": "en"}), + Document(content="Haystack supports retrieval pipelines.", meta={"title": "Haystack", "lang": "de"}), + ] + )["documents"] + store.write_documents(docs, policy=DuplicatePolicy.NONE) + index_name = f"{store.table_name}_HIDX" + preference = store.create_hybrid_vector_index(index_name, text_embedder=text_embedder, params={"parallel": 1}) + + result = OracleHybridRetriever( + document_store=store, + index_name=index_name, + search_mode="hybrid", + top_k=2, + return_scores=True, + ).run("Oracle hybrid vector search") + + assert result["documents"] + assert any("Oracle Database" in doc.content for doc in result["documents"]) + assert all(doc.score is not None for doc in result["documents"]) + + filtered_result = OracleHybridRetriever( + document_store=store, + index_name=index_name, + search_mode="hybrid", + filters={"field": "meta.lang", "operator": "==", "value": "en"}, + top_k=2, + ).run("Oracle hybrid vector search") + + assert filtered_result["documents"] + assert all(doc.meta["lang"] == "en" for doc in filtered_result["documents"]) + finally: + try: + try: + if index_name: + _drop_sql_index_if_exists(store, index_name) + _drop_table(store) + finally: + if preference is not None: + preference.drop() + finally: + store.close()