From 32ce143397b6794387a5a776ea46f7258c28a636 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 2 Apr 2026 12:05:22 -0400 Subject: [PATCH 1/2] feat: add Oracle AI Vector Search DocumentStore MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New integration: oracle-haystack Adds OracleDocumentStore backed by Oracle Database 23ai/26ai native VECTOR type, plus OracleEmbeddingRetriever. Key features: - VECTOR(dim, FLOAT32) column with HNSW approximate search index - Supports Oracle Autonomous Database (ADB-S) wallet connections - All three DuplicatePolicy modes via INSERT / MERGE SQL - Haystack filter grammar translated to JSON_VALUE WHERE clauses - Full async support (asyncio.to_thread wrappers) - Credentials handled via Haystack Secret — never serialised as plaintext - 34 unit tests (mocked oracledb) + 13 integration tests (live Oracle) - e2e validated: 200 SQuAD passages at 181 docs/sec write, <700ms query Closes # --- integrations/oracle/CHANGELOG.md | 0 integrations/oracle/README.md | 0 integrations/oracle/pyproject.toml | 70 +++ .../components/retrievers/oracle/__init__.py | 3 + .../retrievers/oracle/embedding_retriever.py | 105 ++++ .../document_stores/oracle/__about__.py | 1 + .../document_stores/oracle/__init__.py | 6 + .../document_stores/oracle/document_store.py | 559 ++++++++++++++++++ integrations/oracle/tests/__init__.py | 0 integrations/oracle/tests/e2e_real_data.py | 209 +++++++ .../oracle/tests/integration/__init__.py | 0 .../integration/test_oracle_document_store.py | 212 +++++++ integrations/oracle/tests/unit/__init__.py | 0 .../oracle/tests/unit/test_document_store.py | 215 +++++++ .../tests/unit/test_embedding_retriever.py | 87 +++ .../tests/unit/test_filter_translator.py | 132 +++++ 16 files changed, 1599 insertions(+) create mode 100644 integrations/oracle/CHANGELOG.md create mode 100644 integrations/oracle/README.md create mode 100644 integrations/oracle/pyproject.toml create mode 100644 integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py create mode 100644 integrations/oracle/src/haystack_integrations/components/retrievers/oracle/embedding_retriever.py create mode 100644 integrations/oracle/src/haystack_integrations/document_stores/oracle/__about__.py create mode 100644 integrations/oracle/src/haystack_integrations/document_stores/oracle/__init__.py create mode 100644 integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py create mode 100644 integrations/oracle/tests/__init__.py create mode 100644 integrations/oracle/tests/e2e_real_data.py create mode 100644 integrations/oracle/tests/integration/__init__.py create mode 100644 integrations/oracle/tests/integration/test_oracle_document_store.py create mode 100644 integrations/oracle/tests/unit/__init__.py create mode 100644 integrations/oracle/tests/unit/test_document_store.py create mode 100644 integrations/oracle/tests/unit/test_embedding_retriever.py create mode 100644 integrations/oracle/tests/unit/test_filter_translator.py diff --git a/integrations/oracle/CHANGELOG.md b/integrations/oracle/CHANGELOG.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/oracle/README.md b/integrations/oracle/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/oracle/pyproject.toml b/integrations/oracle/pyproject.toml new file mode 100644 index 0000000000..6dfde413c6 --- /dev/null +++ b/integrations/oracle/pyproject.toml @@ -0,0 +1,70 @@ +[build-system] +requires 
= ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "oracle-haystack" +dynamic = ["version"] +description = "Oracle AI Vector Search DocumentStore integration for Haystack" +readme = "README.md" +requires-python = ">=3.10" +license = "Apache-2.0" +keywords = ["haystack", "oracle", "vector search", "document store", "RAG", "OCI"] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 3 - Alpha", + "Programming Language :: Python", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Implementation :: CPython", +] +dependencies = [ + "haystack-ai>=2.0.0", + "oracledb>=2.1.0,<3.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "pytest-mock>=3.12.0", + "ruff>=0.4.0", + "mypy>=1.9.0", +] + +[project.urls] +"Source Code" = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/oracle" +"Bug Tracker" = "https://github.com/deepset-ai/haystack-core-integrations/issues" + +[tool.hatch.version] +source = "vcs" +fallback-version = "0.1.0" + +[tool.hatch.version.raw-options] +root = "../.." +version_scheme = "no-guess-dev" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" +markers = [ + "unit: fast tests, no Oracle connection required", + "integration: require a live Oracle 23ai instance", +] + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "I", "B"] + +[tool.mypy] +python_version = "3.10" +disallow_untyped_defs = true +ignore_missing_imports = true diff --git a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py new file mode 100644 index 0000000000..dc9bed2c38 --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py @@ -0,0 +1,3 @@ +from haystack_integrations.components.retrievers.oracle.embedding_retriever import OracleEmbeddingRetriever + +__all__ = ["OracleEmbeddingRetriever"] diff --git a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/embedding_retriever.py b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/embedding_retriever.py new file mode 100644 index 0000000000..d8d3d955a8 --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/embedding_retriever.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from typing import Any + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document + +from haystack_integrations.document_stores.oracle import OracleDocumentStore + + +def _merge_filters( + base: dict[str, Any] | None, + override: dict[str, Any] | None, +) -> dict[str, Any] | None: + """AND-merge two Haystack filter dicts. 
Returns None if both are empty.""" + base = base or {} + override = override or {} + if not base and not override: + return None + if not base: + return override + if not override: + return base + return {"operator": "AND", "conditions": [base, override]} + + +@component +class OracleEmbeddingRetriever: + """Retrieves documents from an OracleDocumentStore using vector similarity. + + Use inside a Haystack pipeline after a text embedder:: + + pipeline.add_component("embedder", SentenceTransformersTextEmbedder()) + pipeline.add_component("retriever", OracleEmbeddingRetriever( + document_store=store, top_k=5 + )) + pipeline.connect("embedder.embedding", "retriever.query_embedding") + """ + + def __init__( + self, + *, + document_store: OracleDocumentStore, + filters: dict[str, Any] | None = None, + top_k: int = 10, + ) -> None: + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + + @component.output_types(documents=list[Document]) + def run( + self, + query_embedding: list[float], + filters: dict[str, Any] | None = None, + top_k: int | None = None, + ) -> dict[str, list[Document]]: + """Retrieve documents by vector similarity. + + Args: + query_embedding: Dense float vector from an embedder component. + filters: Runtime filters, AND-merged with constructor filters. + top_k: Override the constructor top_k for this call. + + Returns: + ``{"documents": [Document, ...]}`` + """ + merged = _merge_filters(self.filters, filters) + docs = self.document_store._embedding_retrieval( + query_embedding, + filters=merged, + top_k=top_k if top_k is not None else self.top_k, + ) + return {"documents": docs} + + @component.output_types(documents=list[Document]) + async def run_async( + self, + query_embedding: list[float], + filters: dict[str, Any] | None = None, + top_k: int | None = None, + ) -> dict[str, list[Document]]: + """Async variant of :meth:`run`.""" + merged = _merge_filters(self.filters, filters) + docs = await self.document_store._async_embedding_retrieval( + query_embedding, + filters=merged, + top_k=top_k if top_k is not None else self.top_k, + ) + return {"documents": docs} + + def to_dict(self) -> dict[str, Any]: + return default_to_dict( + self, + document_store=self.document_store.to_dict(), + filters=self.filters, + top_k=self.top_k, + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "OracleEmbeddingRetriever": + params = data.get("init_parameters", {}) + if "document_store" in params: + params["document_store"] = OracleDocumentStore.from_dict(params["document_store"]) + return default_from_dict(cls, data) diff --git a/integrations/oracle/src/haystack_integrations/document_stores/oracle/__about__.py b/integrations/oracle/src/haystack_integrations/document_stores/oracle/__about__.py new file mode 100644 index 0000000000..3dc1f76bc6 --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/document_stores/oracle/__about__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/integrations/oracle/src/haystack_integrations/document_stores/oracle/__init__.py b/integrations/oracle/src/haystack_integrations/document_stores/oracle/__init__.py new file mode 100644 index 0000000000..f024a242ef --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/document_stores/oracle/__init__.py @@ -0,0 +1,6 @@ +from haystack_integrations.document_stores.oracle.document_store import ( + OracleConnectionConfig, + OracleDocumentStore, +) + +__all__ = ["OracleConnectionConfig", "OracleDocumentStore"] diff --git 
a/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py b/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py
new file mode 100644
index 0000000000..bfff4d3612
--- /dev/null
+++ b/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py
@@ -0,0 +1,559 @@
+from __future__ import annotations
+
+import array as _array
+import asyncio
+import json
+import logging
+import re
+import threading
+from dataclasses import dataclass
+from typing import Any, Literal
+
+import oracledb
+from haystack import default_from_dict, default_to_dict
+from haystack.dataclasses import Document
+from haystack.document_stores.errors import DuplicateDocumentError
+from haystack.document_stores.types import DuplicatePolicy
+from haystack.utils import Secret, deserialize_secrets_inplace
+
+logger = logging.getLogger(__name__)
+
+# Oracle vector_distance() returns the negative dot product for the DOT metric,
+# so lower values = more similar for all three metrics → always ASC.
+_DISTANCE_ORDER: dict[str, str] = {
+    "COSINE": "ASC",
+    "EUCLIDEAN": "ASC",
+    "DOT": "ASC",
+}
+
+_SAFE_TABLE_NAME = re.compile(r"^[A-Za-z_][A-Za-z0-9_$#]{0,127}$")
+
+
+# ---------------------------------------------------------------------------
+# Connection config
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class OracleConnectionConfig:
+    """Connection parameters for Oracle Database.
+
+    Connections always use python-oracledb thin mode, so no Oracle Instant
+    Client is required. When *wallet_location* is set, it is passed to the
+    thin driver as both config_dir and wallet_location (together with
+    *wallet_password*), enabling wallet-based mTLS connections to ADB-S.
+    """
+
+    user: str
+    password: Secret
+    dsn: str
+    wallet_location: str | None = None
+    wallet_password: Secret | None = None
+    min_connections: int = 1
+    max_connections: int = 5
+
+    def to_dict(self) -> dict[str, Any]:
+        return {
+            "user": self.user,
+            "password": self.password.to_dict(),
+            "dsn": self.dsn,
+            "wallet_location": self.wallet_location,
+            "wallet_password": self.wallet_password.to_dict() if self.wallet_password else None,
+            "min_connections": self.min_connections,
+            "max_connections": self.max_connections,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "OracleConnectionConfig":
+        deserialize_secrets_inplace(data, keys=["password", "wallet_password"])
+        return cls(**data)
+
+
+# ---------------------------------------------------------------------------
+# Filter translator
+# ---------------------------------------------------------------------------
+
+
+class _FilterTranslator:
+    """Translates Haystack 2.x filter dicts into Oracle SQL WHERE fragments.
+
+    Example input:
+        {"operator": "AND", "conditions": [
+            {"field": "meta.author", "operator": "==", "value": "Alice"},
+            {"field": "meta.year", "operator": ">", "value": 2020},
+        ]}
+
+    Example output SQL fragment:
+        (JSON_VALUE(metadata, '$.author') = :p0
+         AND TO_NUMBER(JSON_VALUE(metadata, '$.year')) > :p1)
+
+    Params dict is mutated in-place; caller passes an empty dict and uses it
+    for cursor.execute / cursor.executemany bindings.
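+
+    A minimal sketch of the calling convention (mirrors _build_where below;
+    the SELECT is illustrative only)::
+
+        params: dict[str, Any] = {}
+        counter = [0]  # single-element list so the counter mutates across recursion
+        fragment = _FilterTranslator().translate(filters, params, counter)
+        cursor.execute(f"SELECT id FROM docs WHERE {fragment}", params)
+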
+ """ + + _OP_MAP: dict[str, str] = { + "==": "=", + "!=": "!=", + ">": ">", + ">=": ">=", + "<": "<", + "<=": "<=", + } + + def translate( + self, + filters: dict[str, Any], + params: dict[str, Any], + counter: list[int], + ) -> str: + op = filters.get("operator") + + # Logical nodes + if op == "AND": + parts = [self.translate(c, params, counter) for c in filters["conditions"]] + return "(" + " AND ".join(parts) + ")" + if op == "OR": + parts = [self.translate(c, params, counter) for c in filters["conditions"]] + return "(" + " OR ".join(parts) + ")" + if op == "NOT": + inner = self.translate(filters["conditions"][0], params, counter) + return f"(NOT {inner})" + + # Comparison leaf + field: str = filters["field"] + value: Any = filters["value"] + col = self._field_to_sql(field, value) + + if op in ("in", "not in"): + placeholders = [] + for v in value: + pname = f"p{counter[0]}" + counter[0] += 1 + params[pname] = v + placeholders.append(f":{pname}") + sql_op = "IN" if op == "in" else "NOT IN" + return f"{col} {sql_op} ({', '.join(placeholders)})" + + pname = f"p{counter[0]}" + counter[0] += 1 + params[pname] = value + sql_op = self._OP_MAP[op] + return f"{col} {sql_op} :{pname}" + + def _field_to_sql(self, field: str, value: Any) -> str: + if field == "id": + return "id" + if field == "content": + return "text" + if field.startswith("meta."): + key = field[len("meta.") :] + json_path = f"JSON_VALUE(metadata, '$.{key}')" + if isinstance(value, (int, float)) and not isinstance(value, bool): + return f"TO_NUMBER({json_path})" + return json_path + # Fallback: treat as top-level JSON key + json_path = f"JSON_VALUE(metadata, '$.{field}')" + if isinstance(value, (int, float)) and not isinstance(value, bool): + return f"TO_NUMBER({json_path})" + return json_path + + +# --------------------------------------------------------------------------- +# Document store +# --------------------------------------------------------------------------- + + +class OracleDocumentStore: + """Haystack DocumentStore backed by Oracle AI Vector Search. + + Requires Oracle Database 23ai or later (for VECTOR data type and + IF NOT EXISTS DDL support). + + Usage:: + + from haystack.utils import Secret + from haystack_integrations.document_stores.oracle import ( + OracleDocumentStore, OracleConnectionConfig, + ) + + store = OracleDocumentStore( + connection_config=OracleConnectionConfig( + user="scott", + password=Secret.from_env_var("ORACLE_PASSWORD"), + dsn="localhost:1521/freepdb1", + ), + embedding_dim=1536, + ) + """ + + def __init__( + self, + *, + connection_config: OracleConnectionConfig, + table_name: str = "haystack_documents", + embedding_dim: int, + distance_metric: Literal["COSINE", "EUCLIDEAN", "DOT"] = "COSINE", + create_table_if_not_exists: bool = True, + create_index: bool = False, + hnsw_neighbors: int = 32, + hnsw_ef_construction: int = 200, + hnsw_accuracy: int = 95, + hnsw_parallel: int = 4, + ) -> None: + if not _SAFE_TABLE_NAME.match(table_name): + msg = ( + f"Invalid table_name {table_name!r}. Must be a valid Oracle identifier " + "(letters, digits, _, $, # — max 128 chars, cannot start with a digit)." 
+ ) + raise ValueError(msg) + if embedding_dim <= 0: + raise ValueError(f"embedding_dim must be a positive integer, got {embedding_dim}") + + self.connection_config = connection_config + self.table_name = table_name + self.embedding_dim = embedding_dim + self.distance_metric = distance_metric + self.create_table_if_not_exists = create_table_if_not_exists + self.create_index = create_index + self.hnsw_neighbors = hnsw_neighbors + self.hnsw_ef_construction = hnsw_ef_construction + self.hnsw_accuracy = hnsw_accuracy + self.hnsw_parallel = hnsw_parallel + + self._pool: oracledb.ConnectionPool | None = None + self._pool_lock = threading.Lock() + + if create_table_if_not_exists: + self._ensure_table() + if create_index: + self.create_hnsw_index() + + # ------------------------------------------------------------------ + # Connection pool + # ------------------------------------------------------------------ + + def _get_pool(self) -> oracledb.ConnectionPool: + if self._pool is not None: + return self._pool + with self._pool_lock: + if self._pool is not None: + return self._pool + + cfg = self.connection_config + password = cfg.password.resolve_value() + + connect_kwargs: dict[str, Any] = { + "user": cfg.user, + "password": password, + "dsn": cfg.dsn, + "min": cfg.min_connections, + "max": cfg.max_connections, + "increment": 1, + } + if cfg.wallet_location: + connect_kwargs["config_dir"] = cfg.wallet_location + connect_kwargs["wallet_location"] = cfg.wallet_location + connect_kwargs["wallet_password"] = ( + cfg.wallet_password.resolve_value() if cfg.wallet_password else password + ) + + self._pool = oracledb.create_pool(**connect_kwargs) + return self._pool + + def _get_connection(self) -> oracledb.Connection: + return self._get_pool().acquire() + + def __del__(self) -> None: + if self._pool is not None: + try: + self._pool.close() + except Exception: + pass + + # ------------------------------------------------------------------ + # DDL + # ------------------------------------------------------------------ + + def _ensure_table(self) -> None: + sql = f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id VARCHAR2(64) PRIMARY KEY, + text CLOB, + metadata JSON, + embedding VECTOR({self.embedding_dim}, FLOAT32) + ) + """ + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql) + conn.commit() + + def create_hnsw_index(self) -> None: + """Create an HNSW vector index on the embedding column. + + Safe to call multiple times — uses IF NOT EXISTS. 
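+
+        Parameter intuition (standard HNSW semantics; tune per workload):
+        ``neighbors`` caps the per-node edge count in the graph (HNSW "M"),
+        ``efConstruction`` sets the candidate-list width during graph build,
+        and TARGET ACCURACY is Oracle's declarative recall target in percent.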
+ """ + sql = f""" + CREATE VECTOR INDEX IF NOT EXISTS {self.table_name}_vidx + ON {self.table_name}(embedding) + ORGANIZATION INMEMORY NEIGHBOR GRAPH + WITH TARGET ACCURACY {self.hnsw_accuracy} + DISTANCE {self.distance_metric} + PARAMETERS (type HNSW, neighbors {self.hnsw_neighbors}, + efConstruction {self.hnsw_ef_construction}) + PARALLEL {self.hnsw_parallel} + """ + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql) + conn.commit() + + async def acreate_hnsw_index(self) -> None: + await asyncio.to_thread(self.create_hnsw_index) + + # ------------------------------------------------------------------ + # Write + # ------------------------------------------------------------------ + + def write_documents( + self, + documents: list[Document], + policy: DuplicatePolicy = DuplicatePolicy.NONE, + ) -> int: + if not documents: + return 0 + if policy == DuplicatePolicy.NONE: + return self._insert_documents(documents) + if policy == DuplicatePolicy.SKIP: + return self._skip_duplicate_documents(documents) + if policy == DuplicatePolicy.OVERWRITE: + return self._upsert_documents(documents) + msg = f"Unknown DuplicatePolicy: {policy}" + raise ValueError(msg) + + def _to_row(self, doc: Document) -> tuple[str, str | None, str, bytes | None]: + """Convert a Document to (id, text, metadata_json, embedding_bytes). + + Haystack IDs are stored verbatim in a VARCHAR2(64) column, so any + string ID (UUID, SHA-256 hash, or custom) is accepted without conversion. + """ + doc_id = doc.id + text = doc.content + meta = json.dumps(doc.meta or {}) + emb: bytes | None = None + if doc.embedding is not None: + emb = _array.array("f", doc.embedding) # type: ignore[assignment] + return doc_id, text, meta, emb + + def _to_named_row(self, doc: Document) -> dict[str, Any]: + doc_id, text, meta, emb = self._to_row(doc) + return {"doc_id": doc_id, "doc_text": text, "doc_meta": meta, "doc_emb": emb} + + def _insert_documents(self, documents: list[Document]) -> int: + sql = f""" + INSERT INTO {self.table_name} (id, text, metadata, embedding) + VALUES (:doc_id, :doc_text, :doc_meta, :doc_emb) + """ + rows = [self._to_named_row(d) for d in documents] + try: + with self._get_connection() as conn, conn.cursor() as cur: + cur.executemany(sql, rows) + conn.commit() + except oracledb.IntegrityError as exc: + raise DuplicateDocumentError( + f"Document already exists. Use DuplicatePolicy.OVERWRITE or SKIP. Original error: {exc}" + ) from exc + return len(rows) + + def _skip_duplicate_documents(self, documents: list[Document]) -> int: + # MERGE rowcount in Oracle reflects rows touched, not just inserted. + # Count before/after to return an accurate number of newly written docs. 
+ sql = f""" + MERGE INTO {self.table_name} t + USING (SELECT :doc_id AS id FROM dual) s ON (t.id = s.id) + WHEN NOT MATCHED THEN + INSERT (id, text, metadata, embedding) + VALUES (s.id, :doc_text, :doc_meta, :doc_emb) + """ + rows = [self._to_named_row(d) for d in documents] + with self._get_connection() as conn, conn.cursor() as cur: + count_before = conn.cursor().execute(f"SELECT COUNT(*) FROM {self.table_name}").fetchone()[0] + cur.executemany(sql, rows) + count_after = conn.cursor().execute(f"SELECT COUNT(*) FROM {self.table_name}").fetchone()[0] + conn.commit() + return count_after - count_before + + def _upsert_documents(self, documents: list[Document]) -> int: + sql = f""" + MERGE INTO {self.table_name} t + USING (SELECT :doc_id AS id FROM dual) s ON (t.id = s.id) + WHEN MATCHED THEN + UPDATE SET t.text = :doc_text, t.metadata = :doc_meta, t.embedding = :doc_emb + WHEN NOT MATCHED THEN + INSERT (id, text, metadata, embedding) + VALUES (s.id, :doc_text, :doc_meta, :doc_emb) + """ + rows = [self._to_named_row(d) for d in documents] + with self._get_connection() as conn, conn.cursor() as cur: + cur.executemany(sql, rows) + written = cur.rowcount + conn.commit() + return written + + async def awrite_documents( + self, + documents: list[Document], + policy: DuplicatePolicy = DuplicatePolicy.NONE, + ) -> int: + return await asyncio.to_thread(self.write_documents, documents, policy) + + # ------------------------------------------------------------------ + # Filter + # ------------------------------------------------------------------ + + def _build_where(self, filters: dict[str, Any] | None) -> tuple[str, dict[str, Any]]: + if not filters: + return "", {} + params: dict[str, Any] = {} + counter = [0] + fragment = _FilterTranslator().translate(filters, params, counter) + return f"WHERE {fragment}", params + + def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Document]: + where, params = self._build_where(filters) + sql = f"SELECT id, text, metadata FROM {self.table_name} {where}" + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql, params) + rows = cur.fetchall() + return [self._row_to_document(r) for r in rows] + + async def afilter_documents(self, filters: dict[str, Any] | None = None) -> list[Document]: + return await asyncio.to_thread(self.filter_documents, filters) + + # ------------------------------------------------------------------ + # Delete + # ------------------------------------------------------------------ + + def delete_documents(self, document_ids: list[str]) -> None: + if not document_ids: + return + placeholders = ", ".join(f":p{i}" for i in range(len(document_ids))) + sql = f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})" + params = {f"p{i}": doc_id for i, doc_id in enumerate(document_ids)} + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql, params) + conn.commit() + + async def adelete_documents(self, document_ids: list[str]) -> None: + await asyncio.to_thread(self.delete_documents, document_ids) + + # ------------------------------------------------------------------ + # Count + # ------------------------------------------------------------------ + + def count_documents(self) -> int: + sql = f"SELECT COUNT(*) FROM {self.table_name}" + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql) + row = cur.fetchone() + return row[0] if row else 0 + + async def acount_documents(self) -> int: + return await asyncio.to_thread(self.count_documents) + + # 
------------------------------------------------------------------ + # Embedding retrieval + # ------------------------------------------------------------------ + + def _embedding_retrieval( + self, + query_embedding: list[float], + *, + filters: dict[str, Any] | None = None, + top_k: int = 10, + ) -> list[Document]: + order = _DISTANCE_ORDER[self.distance_metric] + where, params = self._build_where(filters) + sql = f""" + SELECT id, text, metadata, + vector_distance(embedding, :query_vec, {self.distance_metric}) AS score + FROM {self.table_name} + {where} + ORDER BY score {order} + FETCH APPROX FIRST :top_k ROWS ONLY + """ + params["query_vec"] = _array.array("f", query_embedding) + params["top_k"] = top_k + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql, params) + rows = cur.fetchall() + return [self._row_to_document(r, with_score=True) for r in rows] + + async def _async_embedding_retrieval( + self, + query_embedding: list[float], + *, + filters: dict[str, Any] | None = None, + top_k: int = 10, + ) -> list[Document]: + return await asyncio.to_thread( + self._embedding_retrieval, + query_embedding, + filters=filters, + top_k=top_k, + ) + + # ------------------------------------------------------------------ + # Row conversion + # ------------------------------------------------------------------ + + def _row_to_document(self, row: tuple, *, with_score: bool = False) -> Document: + if with_score: + raw_id, text, metadata_raw, score = row + else: + raw_id, text, metadata_raw, score = *row, None + + # oracledb returns CLOB/JSON as LOB objects — read them to strings + if hasattr(text, "read"): + text = text.read() + if hasattr(metadata_raw, "read"): + metadata_raw = metadata_raw.read() + + if isinstance(metadata_raw, str): + meta = json.loads(metadata_raw) + elif isinstance(metadata_raw, dict): + meta = metadata_raw + else: + meta = {} + + return Document( + id=raw_id.upper() if raw_id else None, + content=text, + meta=meta, + score=float(score) if score is not None else None, + embedding=None, + blob=None, + ) + + # ------------------------------------------------------------------ + # Serialization + # ------------------------------------------------------------------ + + def to_dict(self) -> dict[str, Any]: + return default_to_dict( + self, + connection_config=self.connection_config.to_dict(), + table_name=self.table_name, + embedding_dim=self.embedding_dim, + distance_metric=self.distance_metric, + create_table_if_not_exists=self.create_table_if_not_exists, + create_index=self.create_index, + hnsw_neighbors=self.hnsw_neighbors, + hnsw_ef_construction=self.hnsw_ef_construction, + hnsw_accuracy=self.hnsw_accuracy, + hnsw_parallel=self.hnsw_parallel, + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "OracleDocumentStore": + params = data.get("init_parameters", {}) + if "connection_config" in params: + params["connection_config"] = OracleConnectionConfig.from_dict(params["connection_config"]) + return default_from_dict(cls, data) diff --git a/integrations/oracle/tests/__init__.py b/integrations/oracle/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/oracle/tests/e2e_real_data.py b/integrations/oracle/tests/e2e_real_data.py new file mode 100644 index 0000000000..a29ee9b4fa --- /dev/null +++ b/integrations/oracle/tests/e2e_real_data.py @@ -0,0 +1,209 @@ +"""End-to-end test with a real HuggingFace dataset and real embeddings. 
+
+Uses:
+  - Dataset: `rajpurkar/squad` (Stanford QA) — real Wikipedia passages
+  - Embedder: `BAAI/bge-small-en-v1.5` via fastembed (384-dim, ONNX, no API key)
+  - Store: OracleDocumentStore → Oracle AI Database 26ai on OCI Free Tier
+
+Run:
+    ORACLE_USER=ADMIN \
+    ORACLE_PASSWORD=... \
+    ORACLE_DSN=deepresearch_low \
+    ORACLE_WALLET_LOCATION=~/.oracle/wallet_deepresearch \
+    ORACLE_WALLET_PASSWORD=... \
+    python tests/e2e_real_data.py
+"""
+
+import os
+import time
+
+from datasets import load_dataset
+from fastembed import TextEmbedding
+from haystack.dataclasses import Document
+from haystack.document_stores.types import DuplicatePolicy
+from haystack.utils import Secret
+
+from haystack_integrations.components.retrievers.oracle import OracleEmbeddingRetriever
+from haystack_integrations.document_stores.oracle import (
+    OracleConnectionConfig,
+    OracleDocumentStore,
+)
+
+# ---------------------------------------------------------------------------
+# Config
+# ---------------------------------------------------------------------------
+
+TABLE = "e2e_squad_test"
+EMBED_MODEL = "BAAI/bge-small-en-v1.5"  # 384-dim, ONNX, ~25MB, no PyTorch
+EMBED_DIM = 384
+N_DOCS = 200  # subset of SQuAD to keep it fast
+TOP_K = 5
+
+QUERIES = [
+    "What is the capital of France?",
+    "Who invented the telephone?",
+    "How does photosynthesis work?",
+    "What caused the First World War?",
+    "Who wrote Romeo and Juliet?",
+]
+
+
+def build_store() -> OracleDocumentStore:
+    return OracleDocumentStore(
+        connection_config=OracleConnectionConfig(
+            user=os.environ["ORACLE_USER"],
+            password=Secret.from_env_var("ORACLE_PASSWORD"),
+            dsn=os.environ["ORACLE_DSN"],
+            wallet_location=os.environ.get("ORACLE_WALLET_LOCATION"),
+            wallet_password=(
+                Secret.from_env_var("ORACLE_WALLET_PASSWORD") if os.environ.get("ORACLE_WALLET_PASSWORD") else None
+            ),
+        ),
+        table_name=TABLE,
+        embedding_dim=EMBED_DIM,
+        distance_metric="COSINE",
+        create_table_if_not_exists=True,
+    )
+
+
+def load_squad_passages(n: int) -> list[dict]:
+    """Load unique Wikipedia passages from the SQuAD validation set."""
+    print(f"Loading SQuAD dataset (first {n} unique passages)...")
+    ds = load_dataset("rajpurkar/squad", split="validation", trust_remote_code=True)
+    seen, passages = set(), []
+    for row in ds:
+        ctx = row["context"].strip()
+        if ctx not in seen:
+            seen.add(ctx)
+            passages.append(
+                {
+                    "text": ctx,
+                    "title": row["title"],
+                    "id": row["id"],
+                }
+            )
+        if len(passages) >= n:
+            break
+    print(f"  Loaded {len(passages)} unique passages")
+    return passages
+
+
+def embed(model: TextEmbedding, texts: list[str]) -> list[list[float]]:
+    return [v.tolist() for v in model.embed(texts)]
+
+
+def main() -> None:
+    # ------------------------------------------------------------------
+    # 1. Load model
+    # ------------------------------------------------------------------
+    print(f"\n{'=' * 60}")
+    print(f"Loading embedding model: {EMBED_MODEL}")
+    model = TextEmbedding(model_name=EMBED_MODEL)
+    print("  Model loaded")
+
+    # ------------------------------------------------------------------
+    # 2. Load dataset
+    # ------------------------------------------------------------------
+    passages = load_squad_passages(N_DOCS)
+
+    # ------------------------------------------------------------------
+    # 3.
Embed + # ------------------------------------------------------------------ + print(f"\nEmbedding {len(passages)} passages...") + t0 = time.perf_counter() + texts = [p["text"] for p in passages] + embeddings = embed(model, texts) + print(f" Done in {time.perf_counter() - t0:.1f}s") + + # ------------------------------------------------------------------ + # 4. Build Haystack Documents + # ------------------------------------------------------------------ + documents = [ + Document( + content=p["text"], + meta={"title": p["title"], "squad_id": p["id"]}, + embedding=emb, + ) + for p, emb in zip(passages, embeddings, strict=True) + ] + + # ------------------------------------------------------------------ + # 5. Connect to Oracle and write + # ------------------------------------------------------------------ + print(f"\nConnecting to Oracle ADB ({os.environ['ORACLE_DSN']})...") + store = build_store() + print(f" Connected — table: {TABLE}") + + print(f"\nWriting {len(documents)} documents...") + t0 = time.perf_counter() + written = store.write_documents(documents, policy=DuplicatePolicy.OVERWRITE) + elapsed = time.perf_counter() - t0 + print(f" Written: {written} docs in {elapsed:.1f}s ({written / elapsed:.0f} docs/sec)") + + total = store.count_documents() + print(f" Total in table: {total}") + + # ------------------------------------------------------------------ + # 6. Create HNSW index + # ------------------------------------------------------------------ + print("\nCreating HNSW index...") + t0 = time.perf_counter() + store.create_hnsw_index() + print(f" Index created in {time.perf_counter() - t0:.1f}s") + + # ------------------------------------------------------------------ + # 7. Query + # ------------------------------------------------------------------ + retriever = OracleEmbeddingRetriever(document_store=store, top_k=TOP_K) + + print(f"\n{'=' * 60}") + print("RETRIEVAL RESULTS") + print(f"{'=' * 60}") + + for query in QUERIES: + query_emb = list(model.embed([query]))[0].tolist() + t0 = time.perf_counter() + result = retriever.run(query_embedding=query_emb) + latency_ms = (time.perf_counter() - t0) * 1000 + + print(f'\nQuery: "{query}" [{latency_ms:.0f}ms]') + print("-" * 60) + for i, doc in enumerate(result["documents"], 1): + snippet = doc.content[:120].replace("\n", " ") + print(f" {i}. [{doc.score:.4f}] {doc.meta['title']}") + print(f" {snippet}...") + + # ------------------------------------------------------------------ + # 8. Filter test — only return passages about a specific topic + # ------------------------------------------------------------------ + print(f"\n{'=' * 60}") + print("FILTERED RETRIEVAL — only 'Beyoncé' passages") + print(f"{'=' * 60}") + query = "Who is Beyoncé?" + query_emb = list(model.embed([query]))[0].tolist() + result = retriever.run( + query_embedding=query_emb, + filters={"field": "meta.title", "operator": "==", "value": "Beyoncé"}, + ) + if result["documents"]: + for doc in result["documents"]: + print(f" [{doc.score:.4f}] {doc.content[:150].replace(chr(10), ' ')}...") + else: + print(" (no Beyoncé passages in this subset)") + + # ------------------------------------------------------------------ + # 9. Cleanup + # ------------------------------------------------------------------ + print(f"\n{'=' * 60}") + answer = input("Drop test table? 
[y/N] ").strip().lower() + if answer == "y": + with store._get_connection() as conn, conn.cursor() as cur: + cur.execute(f"DROP TABLE {TABLE} PURGE") + conn.commit() + print(f" Table {TABLE} dropped.") + else: + print(f" Table {TABLE} kept — {total} docs remain in Oracle.") + + +if __name__ == "__main__": + main() diff --git a/integrations/oracle/tests/integration/__init__.py b/integrations/oracle/tests/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/oracle/tests/integration/test_oracle_document_store.py b/integrations/oracle/tests/integration/test_oracle_document_store.py new file mode 100644 index 0000000000..0f90c80d1d --- /dev/null +++ b/integrations/oracle/tests/integration/test_oracle_document_store.py @@ -0,0 +1,212 @@ +"""Integration tests against a live Oracle 23ai instance. + +Required environment variables: + ORACLE_USER — database username + ORACLE_PASSWORD — database password + ORACLE_DSN — e.g. localhost:1521/freepdb1 + +Optional (for ADB-S / wallet connections): + ORACLE_WALLET_LOCATION + ORACLE_WALLET_PASSWORD + +Run with: + pytest tests/integration/ -v +""" + +from __future__ import annotations + +import os +from uuid import uuid4 + +import pytest +from haystack.dataclasses import Document +from haystack.document_stores.types import DuplicatePolicy +from haystack.utils import Secret + +from haystack_integrations.document_stores.oracle import OracleConnectionConfig, OracleDocumentStore + +pytestmark = pytest.mark.skipif( + not os.getenv("ORACLE_DSN"), + reason="ORACLE_DSN not set — skipping Oracle integration tests", +) + + +def _unique_table() -> str: + return f"hs_test_{uuid4().hex[:8]}" + + +@pytest.fixture(scope="module") +def store(): + table = _unique_table() + s = OracleDocumentStore( + connection_config=OracleConnectionConfig( + user=os.environ["ORACLE_USER"], + password=Secret.from_env_var("ORACLE_PASSWORD"), + dsn=os.environ["ORACLE_DSN"], + wallet_location=os.getenv("ORACLE_WALLET_LOCATION"), + wallet_password=( + Secret.from_env_var("ORACLE_WALLET_PASSWORD") if os.getenv("ORACLE_WALLET_PASSWORD") else None + ), + ), + table_name=table, + embedding_dim=4, + distance_metric="COSINE", + create_table_if_not_exists=True, + ) + yield s + # Teardown + with s._get_connection() as conn, conn.cursor() as cur: + cur.execute(f"DROP TABLE {table} PURGE") + conn.commit() + + +def _docs(n: int = 3) -> list[Document]: + return [ + Document( + id=uuid4().hex.upper()[:32], + content=f"document {i}", + meta={"index": i, "lang": "en"}, + embedding=[float(i), float(i + 1), float(i + 2), float(i + 3)], + ) + for i in range(n) + ] + + +def test_create_table_idempotent(store): + store._ensure_table() # should not raise + + +def test_write_and_count(store): + docs = _docs(3) + store.write_documents(docs) + assert store.count_documents() >= 3 + + +def test_filter_documents_no_filter(store): + store.write_documents(_docs(2)) + all_docs = store.filter_documents() + assert len(all_docs) >= 2 + + +def test_filter_documents_equality(store): + unique_lang = f"lang_{uuid4().hex[:6]}" + doc = Document( + id=uuid4().hex.upper()[:32], + content="unique lang doc", + meta={"lang": unique_lang}, + embedding=[1.0, 0.0, 0.0, 0.0], + ) + store.write_documents([doc]) + results = store.filter_documents(filters={"field": "meta.lang", "operator": "==", "value": unique_lang}) + assert len(results) == 1 + assert results[0].meta["lang"] == unique_lang + + +def test_filter_documents_in_operator(store): + tag = uuid4().hex[:6] + docs = [ + 
Document(id=uuid4().hex.upper()[:32], content="a", meta={"tag": f"{tag}_a"}, embedding=[1.0, 0.0, 0.0, 0.0]), + Document(id=uuid4().hex.upper()[:32], content="b", meta={"tag": f"{tag}_b"}, embedding=[0.0, 1.0, 0.0, 0.0]), + Document(id=uuid4().hex.upper()[:32], content="c", meta={"tag": f"{tag}_c"}, embedding=[0.0, 0.0, 1.0, 0.0]), + ] + store.write_documents(docs) + results = store.filter_documents(filters={"field": "meta.tag", "operator": "in", "value": [f"{tag}_a", f"{tag}_b"]}) + assert len(results) == 2 + + +def test_write_duplicate_none_policy_raises(store): + from haystack.document_stores.errors import DuplicateDocumentError + + doc = Document(id=uuid4().hex.upper()[:32], content="dup", meta={}, embedding=[1.0, 0.0, 0.0, 0.0]) + store.write_documents([doc]) + with pytest.raises(DuplicateDocumentError): + store.write_documents([doc], policy=DuplicatePolicy.NONE) + + +def test_write_duplicate_skip_policy_silently_ignores(store): + doc = Document(id=uuid4().hex.upper()[:32], content="skip-me", meta={}, embedding=[1.0, 0.0, 0.0, 0.0]) + store.write_documents([doc]) + count_before = store.count_documents() + store.write_documents([doc], policy=DuplicatePolicy.SKIP) + assert store.count_documents() == count_before + + +def test_write_duplicate_overwrite_policy_updates_content(store): + doc_id = uuid4().hex.upper()[:32] + doc = Document(id=doc_id, content="original", meta={}, embedding=[1.0, 0.0, 0.0, 0.0]) + store.write_documents([doc]) + updated = Document(id=doc_id, content="updated", meta={}, embedding=[1.0, 0.0, 0.0, 0.0]) + store.write_documents([updated], policy=DuplicatePolicy.OVERWRITE) + results = store.filter_documents(filters={"field": "id", "operator": "==", "value": doc_id}) + assert results[0].content == "updated" + + +def test_delete_documents(store): + doc_id = uuid4().hex.upper()[:32] + doc = Document(id=doc_id, content="to delete", meta={}, embedding=[1.0, 0.0, 0.0, 0.0]) + store.write_documents([doc]) + store.delete_documents([doc_id]) + results = store.filter_documents(filters={"field": "id", "operator": "==", "value": doc_id}) + assert len(results) == 0 + + +def test_embedding_retrieval_returns_ordered_results(store): + tag = uuid4().hex[:6] + docs = [ + Document(id=uuid4().hex.upper()[:32], content="near", meta={"tag": tag}, embedding=[1.0, 0.0, 0.0, 0.0]), + Document(id=uuid4().hex.upper()[:32], content="far", meta={"tag": tag}, embedding=[0.0, 0.0, 0.0, 1.0]), + Document(id=uuid4().hex.upper()[:32], content="medium", meta={"tag": tag}, embedding=[0.7, 0.7, 0.0, 0.0]), + ] + store.write_documents(docs) + results = store._embedding_retrieval( + [1.0, 0.0, 0.0, 0.0], + filters={"field": "meta.tag", "operator": "==", "value": tag}, + top_k=3, + ) + assert results[0].content == "near" + assert len(results) == 3 + + +def test_embedding_retrieval_with_filter(store): + tag = uuid4().hex[:6] + docs = [ + Document( + id=uuid4().hex.upper()[:32], content="en", meta={"tag": tag, "lang": "en"}, embedding=[1.0, 0.0, 0.0, 0.0] + ), # noqa: E501 + Document( + id=uuid4().hex.upper()[:32], content="de", meta={"tag": tag, "lang": "de"}, embedding=[1.0, 0.0, 0.0, 0.0] + ), # noqa: E501 + ] + store.write_documents(docs) + results = store._embedding_retrieval( + [1.0, 0.0, 0.0, 0.0], + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.tag", "operator": "==", "value": tag}, + {"field": "meta.lang", "operator": "==", "value": "en"}, + ], + }, + top_k=10, + ) + assert all(d.meta["lang"] == "en" for d in results) + + +def test_hnsw_index_creation(store): + 
store.create_hnsw_index() + with store._get_connection() as conn, conn.cursor() as cur: + cur.execute( + "SELECT COUNT(*) FROM USER_INDEXES WHERE INDEX_NAME = :1", + [f"{store.table_name.upper()}_VIDX"], + ) + count = cur.fetchone()[0] + assert count == 1 + + +@pytest.mark.asyncio +async def test_async_write_and_retrieve(store): + doc_id = uuid4().hex.upper()[:32] + doc = Document(id=doc_id, content="async test", meta={}, embedding=[0.5, 0.5, 0.0, 0.0]) + await store.awrite_documents([doc]) + results = await store._async_embedding_retrieval([0.5, 0.5, 0.0, 0.0], top_k=1) + assert len(results) >= 1 diff --git a/integrations/oracle/tests/unit/__init__.py b/integrations/oracle/tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/oracle/tests/unit/test_document_store.py b/integrations/oracle/tests/unit/test_document_store.py new file mode 100644 index 0000000000..f96dd9f5d7 --- /dev/null +++ b/integrations/oracle/tests/unit/test_document_store.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from haystack.dataclasses import Document +from haystack.document_stores.errors import DuplicateDocumentError +from haystack.document_stores.types import DuplicatePolicy +from haystack.utils import Secret + +from haystack_integrations.document_stores.oracle import OracleConnectionConfig, OracleDocumentStore + + +@pytest.fixture() +def mock_pool(monkeypatch): + """Patch oracledb.create_pool to return a mock pool with a mock connection/cursor.""" + cursor = MagicMock() + cursor.fetchall.return_value = [] + cursor.fetchone.return_value = (0,) + cursor.rowcount = 1 + + conn = MagicMock() + conn.cursor.return_value.__enter__ = lambda s: cursor + conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + pool = MagicMock() + pool.acquire.return_value.__enter__ = lambda s: conn + pool.acquire.return_value.__exit__ = MagicMock(return_value=False) + + monkeypatch.setattr( + "haystack_integrations.document_stores.oracle.document_store.oracledb.create_pool", + lambda **kw: pool, + ) + return pool, conn, cursor + + +@pytest.fixture() +def store(mock_pool, monkeypatch): + monkeypatch.setenv("ORACLE_PASSWORD", "p") + return OracleDocumentStore( + connection_config=OracleConnectionConfig( + user="u", + password=Secret.from_env_var("ORACLE_PASSWORD"), + dsn="localhost/xe", + ), + table_name="test_docs", + embedding_dim=4, + create_table_if_not_exists=False, + ) + + +def _doc(content="hello", embedding=None, doc_id="AABB" * 8): + return Document(id=doc_id, content=content, meta={"k": "v"}, embedding=embedding) + + +# ------------------------------------------------------------------ +# write_documents +# ------------------------------------------------------------------ + + +def test_write_documents_none_policy_calls_insert(store, mock_pool): + _, _, cursor = mock_pool + store.write_documents([_doc()], policy=DuplicatePolicy.NONE) + cursor.executemany.assert_called_once() + sql = cursor.executemany.call_args[0][0] + assert "INSERT INTO" in sql + assert ":doc_id" in sql + + +def test_write_documents_none_policy_duplicate_raises(store, mock_pool): + import oracledb as _oracledb + + _, _, cursor = mock_pool + cursor.executemany.side_effect = _oracledb.IntegrityError("ORA-00001") + with pytest.raises(DuplicateDocumentError): + store.write_documents([_doc()], policy=DuplicatePolicy.NONE) + + +def test_write_documents_skip_policy_uses_merge_not_matched(store, mock_pool): + _, _, cursor = mock_pool + 
store.write_documents([_doc()], policy=DuplicatePolicy.SKIP) + sql = cursor.executemany.call_args[0][0] + assert "MERGE INTO" in sql + assert "WHEN NOT MATCHED" in sql + assert "WHEN MATCHED" not in sql + + +def test_write_documents_overwrite_policy_uses_full_merge(store, mock_pool): + _, _, cursor = mock_pool + store.write_documents([_doc()], policy=DuplicatePolicy.OVERWRITE) + sql = cursor.executemany.call_args[0][0] + assert "MERGE INTO" in sql + assert "WHEN MATCHED" in sql + assert "WHEN NOT MATCHED" in sql + + +def test_write_documents_returns_count(store, mock_pool): + count = store.write_documents([_doc(), _doc(doc_id="CCDD" * 8)], policy=DuplicatePolicy.NONE) + assert count == 2 + + +def test_write_documents_empty_list_returns_zero(store, mock_pool): + _, _, cursor = mock_pool + count = store.write_documents([], policy=DuplicatePolicy.NONE) + assert count == 0 + cursor.executemany.assert_not_called() + + +# ------------------------------------------------------------------ +# filter_documents +# ------------------------------------------------------------------ + + +def test_filter_documents_no_filter_fetches_all(store, mock_pool): + _, _, cursor = mock_pool + cursor.fetchall.return_value = [ + ("AABB" * 8, "hello", '{"k": "v"}'), + ("CCDD" * 8, "world", "{}"), + ] + docs = store.filter_documents() + assert len(docs) == 2 + sql = cursor.execute.call_args[0][0] + assert "WHERE" not in sql + + +def test_filter_documents_equality_filter_produces_correct_sql(store, mock_pool): + _, _, cursor = mock_pool + cursor.fetchall.return_value = [] + store.filter_documents(filters={"field": "meta.author", "operator": "==", "value": "Alice"}) + sql, params = cursor.execute.call_args[0] + assert "JSON_VALUE(metadata, '$.author') = :p0" in sql + assert params["p0"] == "Alice" + + +def test_filter_documents_and_filter(store, mock_pool): + _, _, cursor = mock_pool + cursor.fetchall.return_value = [] + store.filter_documents( + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.lang", "operator": "==", "value": "en"}, + {"field": "meta.year", "operator": ">", "value": 2020}, + ], + } + ) + sql, params = cursor.execute.call_args[0] + assert "AND" in sql + assert len(params) == 2 + + +# ------------------------------------------------------------------ +# delete_documents +# ------------------------------------------------------------------ + + +def test_delete_documents_builds_correct_sql(store, mock_pool): + _, _, cursor = mock_pool + store.delete_documents(["AABB" * 8, "CCDD" * 8]) + sql = cursor.execute.call_args[0][0] + assert "DELETE FROM" in sql + assert "IN (:p0, :p1)" in sql + + +def test_delete_documents_empty_list_is_noop(store, mock_pool): + _, _, cursor = mock_pool + store.delete_documents([]) + cursor.execute.assert_not_called() + + +# ------------------------------------------------------------------ +# count_documents +# ------------------------------------------------------------------ + + +def test_count_documents_returns_value(store, mock_pool): + _, _, cursor = mock_pool + cursor.fetchone.return_value = (42,) + assert store.count_documents() == 42 + + +# ------------------------------------------------------------------ +# serialization +# ------------------------------------------------------------------ + + +def test_to_dict_does_not_expose_plain_password(store): + d = store.to_dict() + pw = d["init_parameters"]["connection_config"]["password"] + # Secret serializes as {"type": "...", ...} — never a plain string + assert isinstance(pw, dict) + assert pw.get("type") 
== "env_var" # stored as env-var reference, not plain token + + +def test_from_dict_roundtrip(store): + d = store.to_dict() + restored = OracleDocumentStore.from_dict(d) + assert restored.table_name == store.table_name + assert restored.embedding_dim == store.embedding_dim + assert restored.distance_metric == store.distance_metric + + +# ------------------------------------------------------------------ +# HNSW index SQL shape +# ------------------------------------------------------------------ + + +def test_create_hnsw_index_sql(store, mock_pool): + _, _, cursor = mock_pool + store.create_hnsw_index() + sql = cursor.execute.call_args[0][0] + assert "CREATE VECTOR INDEX" in sql + assert "HNSW" in sql + assert str(store.hnsw_neighbors) in sql + assert str(store.hnsw_ef_construction) in sql diff --git a/integrations/oracle/tests/unit/test_embedding_retriever.py b/integrations/oracle/tests/unit/test_embedding_retriever.py new file mode 100644 index 0000000000..8e3fd6f987 --- /dev/null +++ b/integrations/oracle/tests/unit/test_embedding_retriever.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from haystack.dataclasses import Document + +from haystack_integrations.components.retrievers.oracle import OracleEmbeddingRetriever +from haystack_integrations.document_stores.oracle import OracleDocumentStore + + +@pytest.fixture() +def mock_store(): + store = MagicMock(spec=OracleDocumentStore) + store.distance_metric = "COSINE" + store._embedding_retrieval.return_value = [Document(id="A" * 32, content="hi")] + store._async_embedding_retrieval.return_value = [Document(id="A" * 32, content="hi")] + store.to_dict.return_value = { + "type": "haystack_integrations.document_stores.oracle.document_store.OracleDocumentStore", + "init_parameters": { + "connection_config": { + "user": "u", + "password": {"type": "token", "token": "p"}, + "dsn": "localhost/xe", + "wallet_location": None, + "wallet_password": None, + "min_connections": 1, + "max_connections": 5, + }, + "table_name": "test_docs", + "embedding_dim": 4, + "distance_metric": "COSINE", + "create_table_if_not_exists": False, + "create_index": False, + "hnsw_neighbors": 32, + "hnsw_ef_construction": 200, + "hnsw_accuracy": 95, + "hnsw_parallel": 4, + }, + } + return store + + +def test_run_calls_embedding_retrieval(mock_store): + retriever = OracleEmbeddingRetriever(document_store=mock_store, top_k=5) + result = retriever.run(query_embedding=[0.1, 0.2, 0.3, 0.4]) + mock_store._embedding_retrieval.assert_called_once_with([0.1, 0.2, 0.3, 0.4], filters=None, top_k=5) + assert len(result["documents"]) == 1 + + +def test_run_merges_filters(mock_store): + retriever = OracleEmbeddingRetriever( + document_store=mock_store, + filters={"field": "meta.lang", "operator": "==", "value": "en"}, + ) + retriever.run( + query_embedding=[0.1, 0.2, 0.3, 0.4], + filters={"field": "meta.year", "operator": ">", "value": 2020}, + ) + call_filters = mock_store._embedding_retrieval.call_args.kwargs["filters"] + assert call_filters["operator"] == "AND" + assert len(call_filters["conditions"]) == 2 + + +def test_run_top_k_override(mock_store): + retriever = OracleEmbeddingRetriever(document_store=mock_store, top_k=10) + retriever.run(query_embedding=[0.1, 0.2, 0.3, 0.4], top_k=3) + assert mock_store._embedding_retrieval.call_args.kwargs["top_k"] == 3 + + +def test_to_dict_from_dict_roundtrip(mock_store, monkeypatch): + retriever = OracleEmbeddingRetriever( + document_store=mock_store, + top_k=7, + 
filters={"field": "meta.x", "operator": "==", "value": "y"}, + ) + d = retriever.to_dict() + assert d["init_parameters"]["top_k"] == 7 + assert d["init_parameters"]["filters"] == {"field": "meta.x", "operator": "==", "value": "y"} + + +@pytest.mark.asyncio +async def test_run_async_calls_async_retrieval(mock_store): + retriever = OracleEmbeddingRetriever(document_store=mock_store, top_k=5) + result = await retriever.run_async(query_embedding=[0.1, 0.2, 0.3, 0.4]) + mock_store._async_embedding_retrieval.assert_called_once() + assert "documents" in result diff --git a/integrations/oracle/tests/unit/test_filter_translator.py b/integrations/oracle/tests/unit/test_filter_translator.py new file mode 100644 index 0000000000..699332767e --- /dev/null +++ b/integrations/oracle/tests/unit/test_filter_translator.py @@ -0,0 +1,132 @@ +from haystack_integrations.document_stores.oracle.document_store import _FilterTranslator + + +def _translate(filters): + params = {} + counter = [0] + sql = _FilterTranslator().translate(filters, params, counter) + return sql, params + + +def test_equality(): + sql, params = _translate({"field": "meta.author", "operator": "==", "value": "Alice"}) + assert "JSON_VALUE(metadata, '$.author') = :p0" in sql + assert params == {"p0": "Alice"} + + +def test_inequality(): + sql, params = _translate({"field": "meta.status", "operator": "!=", "value": "draft"}) + assert "!= :p0" in sql + assert params["p0"] == "draft" + + +def test_greater_than(): + sql, params = _translate({"field": "meta.year", "operator": ">", "value": 2020}) + assert "TO_NUMBER" in sql + assert "> :p0" in sql + assert params["p0"] == 2020 + + +def test_in_operator(): + sql, params = _translate({"field": "meta.lang", "operator": "in", "value": ["en", "de", "fr"]}) + assert "IN (:p0, :p1, :p2)" in sql + assert params == {"p0": "en", "p1": "de", "p2": "fr"} + + +def test_not_in_operator(): + sql, params = _translate({"field": "meta.lang", "operator": "not in", "value": ["xx", "yy"]}) + assert "NOT IN (:p0, :p1)" in sql + + +def test_and_logical(): + sql, params = _translate( + { + "operator": "AND", + "conditions": [ + {"field": "meta.author", "operator": "==", "value": "Alice"}, + {"field": "meta.year", "operator": ">", "value": 2020}, + ], + } + ) + assert sql.startswith("(") + assert " AND " in sql + assert len(params) == 2 + + +def test_or_logical(): + sql, params = _translate( + { + "operator": "OR", + "conditions": [ + {"field": "meta.a", "operator": "==", "value": "x"}, + {"field": "meta.b", "operator": "==", "value": "y"}, + ], + } + ) + assert " OR " in sql + + +def test_not_logical(): + sql, params = _translate( + { + "operator": "NOT", + "conditions": [{"field": "meta.hidden", "operator": "==", "value": True}], + } + ) + assert sql.startswith("(NOT ") + + +def test_nested_and_or(): + sql, params = _translate( + { + "operator": "AND", + "conditions": [ + {"field": "meta.lang", "operator": "==", "value": "en"}, + { + "operator": "OR", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.type", "operator": "==", "value": "blog"}, + ], + }, + ], + } + ) + assert " AND " in sql + assert " OR " in sql + assert len(params) == 3 + + +def test_id_field_maps_to_id_column(): + sql, params = _translate({"field": "id", "operator": "==", "value": "ABCD1234"}) + assert "id = :p0" in sql + + +def test_content_field_maps_to_text(): + sql, params = _translate({"field": "content", "operator": "==", "value": "hello"}) + assert "text = :p0" in sql + + +def 
test_numeric_value_wraps_in_to_number():
    sql, params = _translate({"field": "meta.count", "operator": ">=", "value": 5})
    assert "TO_NUMBER(" in sql
    assert ">= :p0" in sql


def test_nested_meta_key():
    sql, params = _translate({"field": "meta.author.city", "operator": "==", "value": "NYC"})
    assert "'$.author.city'" in sql


def test_param_counter_increments_correctly():
    sql, params = _translate(
        {
            "operator": "AND",
            "conditions": [
                {"field": "meta.a", "operator": "==", "value": "x"},
                {"field": "meta.b", "operator": "==", "value": "y"},
                {"field": "meta.c", "operator": "==", "value": "z"},
            ],
        }
    )
    assert set(params.keys()) == {"p0", "p1", "p2"}

From b45476ce5c7082123e3af990764598c80c269d67 Mon Sep 17 00:00:00 2001
From: Federico Kamelhar
Date: Thu, 2 Apr 2026 18:35:45 -0400
Subject: [PATCH 2/2] fix: add pydoc config and fix hatch-vcs versioning for API ref CI

- Add pydoc/config_docusaurus.yml required by CI_check_api_ref workflow
- Add hatch docs/fmt scripts to [tool.hatch.envs.default]
- Use tag-pattern + git_describe_command matching repo conventions
- Add fallback-version = 0.1.0 for new integration with no tags yet
- Add [tool.hatch.envs.test] scripts for unit/integration test runs
---
 .../oracle/pydoc/config_docusaurus.yml | 14 ++++++++++++
 integrations/oracle/pyproject.toml     | 22 ++++++++++++++++++-
 2 files changed, 35 insertions(+), 1 deletion(-)
 create mode 100644 integrations/oracle/pydoc/config_docusaurus.yml

diff --git a/integrations/oracle/pydoc/config_docusaurus.yml b/integrations/oracle/pydoc/config_docusaurus.yml
new file mode 100644
index 0000000000..6dc6f8545d
--- /dev/null
+++ b/integrations/oracle/pydoc/config_docusaurus.yml
@@ -0,0 +1,14 @@
+loaders:
+  - modules:
+      - haystack_integrations.components.retrievers.oracle.embedding_retriever
+      - haystack_integrations.document_stores.oracle.document_store
+    search_path: [../src]
+processors:
+  - type: filter
+    documented_only: true
+    skip_empty_modules: true
+renderer:
+  description: Oracle AI Vector Search integration for Haystack
+  id: integrations-oracle
+  filename: oracle.md
+  title: Oracle AI Vector Search
diff --git a/integrations/oracle/pyproject.toml b/integrations/oracle/pyproject.toml
index 6dfde413c6..60365fa2de 100644
--- a/integrations/oracle/pyproject.toml
+++ b/integrations/oracle/pyproject.toml
@@ -41,15 +41,35 @@ dev = [

 [tool.hatch.version]
 source = "vcs"
+tag-pattern = 'integrations\/oracle-v(?P<version>.*)'
 fallback-version = "0.1.0"

 [tool.hatch.version.raw-options]
 root = "../.."
-version_scheme = "no-guess-dev"
+git_describe_command = 'git describe --tags --match="integrations/oracle-v[0-9]*"'

 [tool.hatch.build.targets.wheel]
 packages = ["src/haystack_integrations"]

+[tool.hatch.envs.default]
+installer = "uv"
+dependencies = ["haystack-pydoc-tools", "ruff"]
+
+[tool.hatch.envs.default.scripts]
+docs = ["haystack-pydoc pydoc/config_docusaurus.yml"]
+fmt = "ruff check --fix {args}; ruff format {args}"
+
+[tool.hatch.envs.test]
+dependencies = [
+    "pytest>=8.0.0",
+    "pytest-asyncio>=0.23.0",
+    "pytest-mock>=3.12.0",
+]
+
+[tool.hatch.envs.test.scripts]
+unit = "pytest tests/unit/ -v"
+integration = "pytest tests/integration/ -v"
+
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 asyncio_mode = "auto"
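
Usage note (a sketch, assuming a constructed `store` and a list of Haystack
`Document`s named `docs`): the three DuplicatePolicy modes added in patch 1/2
map onto the SQL paths as follows::

    from haystack.document_stores.types import DuplicatePolicy

    store.write_documents(docs)  # default NONE: plain INSERT, raises DuplicateDocumentError on an existing id
    store.write_documents(docs, policy=DuplicatePolicy.SKIP)  # MERGE ... WHEN NOT MATCHED THEN INSERT
    store.write_documents(docs, policy=DuplicatePolicy.OVERWRITE)  # full MERGE upsert (update or insert)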