Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from haystack.dataclasses import Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.errors import FilterError
from haystack.utils import Secret, deserialize_secrets_inplace

from .filters import FilterTranslator, to_hybrid_filter
from .filters import FilterTranslator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -893,21 +892,26 @@ def _skip_duplicate_documents(self, documents: list[Document]) -> int:
return written

def _upsert_documents(self, documents: list[Document]) -> int:
sql = f"""
MERGE INTO {self.table_name} t
USING (SELECT :doc_id AS id FROM dual) s ON (t.id = s.id)
WHEN MATCHED THEN
UPDATE SET t.text = :doc_text, t.metadata = :doc_meta, t.embedding = :doc_emb
WHEN NOT MATCHED THEN
INSERT (id, text, metadata, embedding)
VALUES (s.id, :doc_text, :doc_meta, :doc_emb)
# A single MERGE combining WHEN MATCHED UPDATE with WHEN NOT MATCHED INSERT raises
# ORA-06531 ("reference to uninitialized collection") inside the DBMS_SEARCH
# keyword-index trigger on Oracle 23ai/26ai — even when every row is an insert.
# Delete-then-insert is an equivalent upsert that avoids the faulty trigger path.
# Rows are de-duplicated by id (last one wins) so a batch that repeats an id cannot
# violate the primary key after the deletes are applied.
rows_by_id: dict[str, dict[str, Any]] = {}
for document in documents:
rows_by_id[document.id] = OracleDocumentStore._to_named_row(document)
rows = list(rows_by_id.values())
delete_sql = f"DELETE FROM {self.table_name} WHERE id = :doc_id"
insert_sql = f"""
INSERT INTO {self.table_name} (id, text, metadata, embedding)
VALUES (:doc_id, :doc_text, :doc_meta, :doc_emb)
"""
rows = [OracleDocumentStore._to_named_row(d) for d in documents]
with self._get_connection() as conn, conn.cursor() as cur:
cur.executemany(sql, rows)
written = cur.rowcount
cur.executemany(delete_sql, [{"doc_id": row["doc_id"]} for row in rows])
cur.executemany(insert_sql, rows)
conn.commit()
return written
return len(rows)

async def write_documents_async(
self,
Expand All @@ -926,12 +930,19 @@ async def write_documents_async(
return await asyncio.to_thread(self.write_documents, documents, policy)

@staticmethod
def _build_where(filters: dict[str, Any] | None) -> tuple[str, dict[str, Any]]:
def _build_filter_fragment(filters: dict[str, Any] | None) -> tuple[str, dict[str, Any]]:
if not filters:
return "", {}
params: dict[str, Any] = {}
counter = [0]
fragment = FilterTranslator().translate(filters, params, counter)
return fragment, params

@staticmethod
def _build_where(filters: dict[str, Any] | None) -> tuple[str, dict[str, Any]]:
fragment, params = OracleDocumentStore._build_filter_fragment(filters)
if not fragment:
return "", {}
return f"WHERE {fragment}", params

def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Document]:
Expand Down Expand Up @@ -1411,10 +1422,13 @@ def _hybrid_search_params(
*,
index_name: str,
search_mode: Literal["keyword", "hybrid", "semantic"],
filters: dict[str, Any] | None,
top_k: int,
params: dict[str, Any] | None,
) -> dict[str, Any]:
# Haystack metadata filters are applied as a SQL predicate after ranking (see
# _hybrid_retrieval); they are not translated into DBMS_HYBRID_VECTOR ``filter_by``,
# whose paths resolve to base-table columns rather than JSON metadata fields. A native
# ``filter_by`` over declared filterable columns can still be supplied via ``params``.
if search_mode not in _VALID_HYBRID_SEARCH_MODES:
msg = f"search_mode must be one of {_VALID_HYBRID_SEARCH_MODES}, got {search_mode!r}"
raise ValueError(msg)
Expand All @@ -1429,12 +1443,6 @@ def _hybrid_search_params(
search_params["text"] = dict(search_params.get("text") or {})
search_params["text"]["search_text"] = query

if filters:
if "filter_by" in search_params:
msg = "Cannot combine Haystack filters with params['filter_by']."
raise FilterError(msg)
search_params["filter_by"] = to_hybrid_filter(filters)

search_params["return"] = {
"topN": top_k,
"values": ["rowid", "score", "vector_score", "text_score"],
Expand Down Expand Up @@ -1468,25 +1476,32 @@ def _hybrid_retrieval(
query,
index_name=index_name,
search_mode=search_mode,
filters=filters,
top_k=top_k,
params=params,
)
# DBMS_HYBRID_VECTOR ranks the hits; Haystack metadata filters are applied here as a
# SQL predicate while fetching each ranked row. Because filtering happens after ranking,
# fewer than top_k documents may be returned.
filter_fragment, filter_params = OracleDocumentStore._build_filter_fragment(filters)
row_sql = f"SELECT id, text, JSON_SERIALIZE(metadata) AS metadata FROM {self.table_name} WHERE ROWID = :rid"
if filter_fragment:
row_sql += f" AND ({filter_fragment})"

rows: list[tuple[Any, ...]] = []
matched_rows: list[dict[str, Any]] = []
with self._get_connection() as conn, conn.cursor() as cur:
cur.setinputsizes(search_params=oracledb.DB_TYPE_JSON)
cur.execute("SELECT DBMS_HYBRID_VECTOR.SEARCH(JSON(:search_params))", search_params=search_params)
search_rows = self._decode_hybrid_search_result(cur.fetchone()[0])
for row in search_rows:
cur.execute(
f"SELECT id, text, JSON_SERIALIZE(metadata) AS metadata FROM {self.table_name} WHERE ROWID = :rid",
rid=row["rowid"],
)
rows.extend(cur.fetchall())
cur.execute(row_sql, {"rid": row["rowid"], **filter_params})
fetched = cur.fetchall()
if fetched:
rows.extend(fetched)
matched_rows.append(row)

documents = [OracleDocumentStore._row_to_document(row) for row in rows]
self._merge_hybrid_scores(search_rows, documents, return_scores=return_scores)
self._merge_hybrid_scores(matched_rows, documents, return_scores=return_scores)
return documents

async def _hybrid_retrieval_async(
Expand Down Expand Up @@ -1516,12 +1531,17 @@ async def _hybrid_retrieval_async(
query,
index_name=index_name,
search_mode=search_mode,
filters=filters,
top_k=top_k,
params=params,
)
# See _hybrid_retrieval: filters are applied as a SQL predicate after ranking.
filter_fragment, filter_params = OracleDocumentStore._build_filter_fragment(filters)
row_sql = f"SELECT id, text, JSON_SERIALIZE(metadata) AS metadata FROM {self.table_name} WHERE ROWID = :rid"
if filter_fragment:
row_sql += f" AND ({filter_fragment})"

rows: list[tuple[Any, ...]] = []
matched_rows: list[dict[str, Any]] = []
pool = await self._get_async_pool()
async with pool.acquire() as conn:
with conn.cursor() as cur:
Expand All @@ -1534,17 +1554,14 @@ async def _hybrid_retrieval_async(
)
search_rows = await self._decode_hybrid_search_result_async((await _maybe_await(cur.fetchone()))[0])
for row in search_rows:
await _maybe_await(
cur.execute(
"SELECT id, text, JSON_SERIALIZE(metadata) AS metadata "
f"FROM {self.table_name} WHERE ROWID = :rid",
rid=row["rowid"],
)
)
rows.extend(await _maybe_await(cur.fetchall()))
await _maybe_await(cur.execute(row_sql, {"rid": row["rowid"], **filter_params}))
fetched = await _maybe_await(cur.fetchall())
if fetched:
rows.extend(fetched)
matched_rows.append(row)

documents = [OracleDocumentStore._row_to_document(row) for row in rows]
self._merge_hybrid_scores(search_rows, documents, return_scores=return_scores)
self._merge_hybrid_scores(matched_rows, documents, return_scores=return_scores)
return documents

def _embedding_retrieval(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,91 +162,6 @@ def _field_to_json_path(field: str) -> str:
return f"$.{key}"


def _infer_hybrid_filter_type(value: Any) -> str:
if isinstance(value, bool):
msg = "Boolean values are not supported for Oracle hybrid filters."
raise FilterError(msg)
if isinstance(value, (int, float)):
return "number"
if isinstance(value, str):
return "string"
msg = "Oracle hybrid filters support only string and numeric values."
raise FilterError(msg)


def _hybrid_filter_path(field: str) -> str:
if not field.startswith("meta."):
msg = "Oracle hybrid retrieval supports only filters under the 'meta.' field."
raise FilterError(msg)
if not re.match(_JSON_FIELD_NAME, field):
msg = f"Invalid metadata field name: {field!r}"
raise FilterError(msg)
return field


def to_hybrid_filter(filters: dict[str, Any]) -> dict[str, Any]:
"""
Converts Haystack filters into DBMS_HYBRID_VECTOR.SEARCH filter_by JSON.
"""
op = filters.get("operator")
if op in ("AND", "OR", "NOT"):
if "conditions" not in filters:
msg = f"'conditions' key missing in logical filter: {filters}"
raise FilterError(msg)
return {"op": op, "args": [to_hybrid_filter(condition) for condition in filters["conditions"]]}

if "field" not in filters:
msg = f"'field' key missing in comparison filter: {filters}"
raise FilterError(msg)
if "operator" not in filters:
msg = f"'operator' key missing in comparison filter: {filters}"
raise FilterError(msg)
if "value" not in filters:
msg = f"'value' key missing in comparison filter: {filters}"
raise FilterError(msg)

field = _hybrid_filter_path(filters["field"])
value = filters["value"]
if value is None:
msg = "Oracle hybrid retrieval does not support null comparisons."
raise FilterError(msg)
if op in {"contains", "not contains"}:
msg = f"Filter operation {op!r} is not supported for Oracle hybrid retrieval."
raise FilterError(msg)

if op in {"in", "not in"}:
if not isinstance(value, list) or not value:
msg = f"{op!r} filter requires a non-empty list."
raise FilterError(msg)
value_type = _infer_hybrid_filter_type(value[0])
if any(_infer_hybrid_filter_type(item) != value_type for item in value):
msg = "Oracle hybrid retrieval requires all 'in' filter values to share one type."
raise FilterError(msg)
hybrid_filter: dict[str, Any] = {"op": "IN", "path": field, "type": value_type, "args": value}
if op == "not in":
return {"op": "NOT", "args": [hybrid_filter]}
return hybrid_filter

hybrid_op_map = {
"==": "=",
"!=": "!=",
">": ">",
">=": ">=",
"<": "<",
"<=": "<=",
}
if not isinstance(op, str) or op not in hybrid_op_map:
msg = f"Unsupported filter operator: {op!r}"
raise FilterError(msg)

return {
"op": hybrid_op_map[op],
"path": field,
"type": _infer_hybrid_filter_type(value),
"args": [value],
}


def _is_iso_date(value: Any) -> bool:
"""Return True if *value* is a string that Python recognises as a valid ISO-8601 datetime."""
if not isinstance(value, str):
Expand Down
4 changes: 4 additions & 0 deletions integrations/oracle/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
_USER = os.getenv("ORACLE_USER") or os.getenv("VECDB_USER") or "haystack"
_PASSWORD = os.getenv("ORACLE_PASSWORD") or os.getenv("VECDB_PASS") or "haystack"
_DSN = os.getenv("ORACLE_DSN") or os.getenv("ORACLE_DB_DSN") or os.getenv("VECDB_HOST") or "localhost:1521/freepdb1"
_WALLET_LOCATION = os.getenv("ORACLE_WALLET_LOCATION")
_WALLET_PASSWORD = os.getenv("ORACLE_WALLET_PASSWORD")


def _make_store(table: str, embedding_dim: int) -> OracleDocumentStore:
Expand All @@ -23,6 +25,8 @@ def _make_store(table: str, embedding_dim: int) -> OracleDocumentStore:
user=Secret.from_token(_USER),
password=Secret.from_token(_PASSWORD),
dsn=Secret.from_token(_DSN),
wallet_location=_WALLET_LOCATION,
wallet_password=Secret.from_token(_WALLET_PASSWORD) if _WALLET_PASSWORD else None,
),
table_name=table,
embedding_dim=embedding_dim,
Expand Down
20 changes: 15 additions & 5 deletions integrations/oracle/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
_USER = os.getenv("ORACLE_USER") or os.getenv("VECDB_USER") or "haystack"
_PASSWORD = os.getenv("ORACLE_PASSWORD") or os.getenv("VECDB_PASS") or "haystack"
_DSN = os.getenv("ORACLE_DSN") or os.getenv("ORACLE_DB_DSN") or os.getenv("VECDB_HOST") or "localhost:1521/freepdb1"
_WALLET_LOCATION = os.getenv("ORACLE_WALLET_LOCATION")
_WALLET_PASSWORD = os.getenv("ORACLE_WALLET_PASSWORD")


def _doc(doc_id: str, content: str = "hello", meta: dict | None = None, embedding: list[float] | None = None):
Expand Down Expand Up @@ -75,6 +77,8 @@ def document_store(self):
user=Secret.from_token(_USER),
password=Secret.from_token(_PASSWORD),
dsn=Secret.from_token(_DSN),
wallet_location=_WALLET_LOCATION,
wallet_password=Secret.from_token(_WALLET_PASSWORD) if _WALLET_PASSWORD else None,
),
table_name=table,
embedding_dim=768,
Expand Down Expand Up @@ -140,13 +144,19 @@ def test_write_documents_skip_policy_uses_merge_not_matched(self, patched_store,
assert "WHEN NOT MATCHED" in sql
assert "WHEN MATCHED" not in sql

def test_write_documents_overwrite_policy_uses_full_merge(self, patched_store, mock_pool):
def test_write_documents_overwrite_policy_deletes_then_inserts(self, patched_store, mock_pool):
# OVERWRITE uses delete-then-insert instead of a MERGE, because a MERGE combining
# WHEN MATCHED UPDATE with WHEN NOT MATCHED INSERT raises ORA-06531 in the DBMS_SEARCH
# keyword-index trigger on Oracle 23ai/26ai.
_, _, cursor = mock_pool
patched_store.write_documents([self._mock_doc()], policy=DuplicatePolicy.OVERWRITE)
sql = cursor.executemany.call_args[0][0]
assert "MERGE INTO" in sql
assert "WHEN MATCHED" in sql
assert "WHEN NOT MATCHED" in sql
assert cursor.executemany.call_count == 2
delete_sql = cursor.executemany.call_args_list[0][0][0]
insert_sql = cursor.executemany.call_args_list[1][0][0]
assert "DELETE FROM" in delete_sql
assert "INSERT INTO" in insert_sql
assert "MERGE" not in delete_sql
assert "MERGE" not in insert_sql

def test_write_documents_returns_count(self, patched_store, mock_pool): # noqa: ARG002
count = patched_store.write_documents(
Expand Down
Loading
Loading