Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/arcadedb/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"haystack-ai>=2.24.0",
"haystack-ai>=2.26.1",
"requests>=2.28.0",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class ArcadeDBDocumentStore:
"dot": "DOT_PRODUCT",
}

# Maximum number of documents sampled when inferring the metadata schema
SCHEMA_SAMPLING_LIMIT: ClassVar[int] = 1000

def __init__(
self,
*,
Expand Down Expand Up @@ -234,6 +237,70 @@ def count_documents(self) -> int:
return int(rows[0].get("cnt", 0))
return 0

@staticmethod
def _extract_distinct_values(rows: list[dict[str, Any]]) -> set[str]:
"""
Extracts and flattens unique non-None strings from 'val' column result rows.
:param rows: Raw result rows from ``_command``.
:returns: A set of unique string values.
"""
result: set[str] = set()
for row in rows:
val = row.get("val")
if isinstance(val, list):
result.update(str(item) for item in val if item is not None)
elif val is not None:
result.add(str(val))
return result

def _get_metadata_projection_documents(self) -> list[dict[str, Any]]:
"""
Private helper to fetch sample documents for schema inference.
Note: Does not `_ensure_initialized()`. To avoid redundant
initialization checks during internal calls, the caller is responsible for
ensuring the document store is initialized before invoking this method.
"""
sql = f"SELECT content, meta FROM `{self._type_name}` LIMIT {self.SCHEMA_SAMPLING_LIMIT}"
return self._command(sql)

@staticmethod
def _infer_metadata_field_type(values: list[Any]) -> str:
"""
Infers the metadata field type from a list of sampled values.
:param values: A list of raw Python values sampled from the field.
:returns: A type string — one of ``"boolean"``, ``"double"``, ``"long"``, or ``"keyword"``.
Returns ``"keyword"`` if values are empty or of mixed types.
"""
inferred_types = set()
for value in values:
if isinstance(value, list):
for item in value:
if isinstance(item, bool):
inferred_types.add("boolean")
elif isinstance(item, float):
inferred_types.add("double")
elif isinstance(item, int):
inferred_types.add("long")
elif isinstance(item, str):
inferred_types.add("keyword")
elif isinstance(value, bool):
inferred_types.add("boolean")
elif isinstance(value, float):
inferred_types.add("double")
elif isinstance(value, int):
inferred_types.add("long")
elif isinstance(value, str):
inferred_types.add("keyword")

if not inferred_types:
return "keyword"

if len(inferred_types) > 1:
logger.warning("Field has mixed metadata types %s. Defaulting to 'keyword'.", inferred_types)
return "keyword"

return next(iter(inferred_types))

def filter_documents(
self,
filters: dict[str, Any] | None = None,
Expand Down Expand Up @@ -357,7 +424,14 @@ def delete_by_filter(self, filters: dict[str, Any]) -> int:
:returns: The number of documents deleted.
"""
self._ensure_initialized()
where = _convert_filters(filters)
try:
where = _convert_filters(filters)
except ValueError as e:
raise FilterError(str(e)) from e

if not where:
msg = "delete_by_filter requires a non-empty filter. Use delete_all_documents() to delete all documents."
raise FilterError(msg)

count_result = self._command(f"DELETE FROM `{self._type_name}` WHERE {where}")

Expand All @@ -373,14 +447,143 @@ def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int
:returns: The number of documents updated.
"""
self._ensure_initialized()
where = _convert_filters(filters)
try:
where = _convert_filters(filters)
except ValueError as e:
raise FilterError(str(e)) from e

if not where:
msg = "update_by_filter requires a non-empty filter."
raise FilterError(msg)

sql_set = ",".join(f"meta[{_sql_str(key)}] = {_map_literal_base(value)}" for key, value in meta.items())
sql = f"UPDATE `{self._type_name}` SET {sql_set} WHERE {where}"
count_result = self._command(sql)

return count_result[0]["count"]

def count_documents_by_filter(self, filters: dict[str, Any]) -> int:
    """
    Counts the number of documents matching the provided filter
    :param filters: The filters to apply to the documents
    :returns: The number of documents that match the filter
    :raises FilterError: If the filter specification is invalid.
    """
    self._ensure_initialized()
    try:
        where_clause = _convert_filters(filters)
    except ValueError as e:
        raise FilterError(str(e)) from e

    base_sql = f"SELECT count(*) AS cnt FROM `{self._type_name}`"
    # An empty/None filter converts to an empty clause -> count everything.
    query = f"{base_sql} WHERE {where_clause}" if where_clause else base_sql

    result_rows = self._command(query)
    return int(result_rows[0].get("cnt", 0)) if result_rows else 0

def count_unique_metadata_by_filter(self, filters: dict[str, Any], metadata_fields: list[str]) -> dict[str, int]:
    """
    Counts unique values for each metadata field in documents matching the provided filters.
    :param filters: The filters to apply to the document list.
    :param metadata_fields: Metadata fields for which to count unique values.
    :returns: A dictionary where keys are metadata field names and values are the
        counts of unique values for that field.
    :raises FilterError: If the filter specification is invalid.
    """
    self._ensure_initialized()
    try:
        where_clause = _convert_filters(filters)
    except ValueError as e:
        raise FilterError(str(e)) from e

    if not metadata_fields:
        return {}

    where_suffix = f" WHERE {where_clause}" if where_clause else ""
    counts: dict[str, int] = {}
    # ArcadeDB doesn't support COUNT(DISTINCT ..), so issue one DISTINCT query per field.
    for raw_field in metadata_fields:
        name = raw_field.removeprefix("meta.")
        query = f"SELECT DISTINCT meta[{_sql_str(name)}] AS val FROM `{self._type_name}`{where_suffix}"
        counts[name] = len(self._extract_distinct_values(self._command(query)))

    return counts

def get_metadata_fields_info(self) -> dict[str, dict[str, str]]:
    """
    Returns the metadata fields and their corresponding types based on sampled documents.
    :returns: A dictionary mapping field names to dictionaries with a `type` key.
    """
    self._ensure_initialized()
    sampled = self._get_metadata_projection_documents()

    if not sampled:
        return {}

    info: dict[str, dict[str, str]] = {}

    # Report the content field only when at least one sampled document has content.
    if any(doc.get("content") is not None for doc in sampled):
        info["content"] = {"type": "text"}

    # Group every observed value per metadata field, then infer one type per field.
    samples_per_field: dict[str, list[Any]] = {}
    for doc in sampled:
        for name, value in doc.get("meta", {}).items():
            samples_per_field.setdefault(name, []).append(value)

    for name, observed in samples_per_field.items():
        info[name] = {"type": self._infer_metadata_field_type(observed)}

    return info

def get_metadata_field_min_max(self, metadata_field: str) -> dict[str, Any]:
    """
    For a given metadata field, finds its min and max values.
    :param metadata_field: The metadata field to inspect.
    :returns: A dictionary with `min` and `max` keys and their corresponding values.
    """
    self._ensure_initialized()

    # Accept both "field" and "meta.field" spellings.
    ref = f"meta[{_sql_str(metadata_field.removeprefix('meta.'))}]"
    rows = self._command(
        f"SELECT MIN({ref}) AS min_value, MAX({ref}) AS max_value FROM `{self._type_name}`"
    )

    if rows:
        first = rows[0]
        return {"min": first.get("min_value"), "max": first.get("max_value")}
    return {"min": None, "max": None}

def get_metadata_field_unique_values(
    self, metadata_field: str, search_term: str | None = None, from_: int = 0, size: int = 10
) -> tuple[list[str], int]:
    """
    Retrieves unique values for a field matching a search term or all possible values
    if no search term is given.
    :param metadata_field: The metadata field to inspect.
    :param search_term: Optional case-insensitive substring search term.
    :param from_: The starting index for pagination.
    :param size: The number of values to return.
    :returns: A tuple containing the paginated values and the total count.
    """
    self._ensure_initialized()

    # Accept both "field" and "meta.field" spellings.
    field_ref = f"meta[{_sql_str(metadata_field.removeprefix('meta.'))}]"

    clause = ""
    if search_term:
        # ILIKE provides case-insensitive substring matching.
        clause = f" WHERE {field_ref} ILIKE {_sql_str(f'%{search_term}%')}"

    rows = self._command(f"SELECT DISTINCT {field_ref} AS val FROM `{self._type_name}`{clause}")

    # Pagination is applied client-side over the sorted distinct values.
    unique_values = sorted(self._extract_distinct_values(rows))
    return unique_values[from_ : from_ + size], len(unique_values)

# ------------------------------------------------------------------
# Retrieval (called by Retriever components)
# ------------------------------------------------------------------
Expand Down Expand Up @@ -410,7 +613,10 @@ def _embedding_retrieval(
return []

neighbors = rows[0]["neighbors"]
where = _convert_filters(filters)
try:
where = _convert_filters(filters)
except ValueError as e:
raise FilterError(str(e)) from e

documents = []
for neighbor in neighbors:
Expand Down
104 changes: 96 additions & 8 deletions integrations/arcadedb/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@
from haystack import Document
from haystack.document_stores.errors import DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.testing.document_store import DocumentStoreBaseExtendedTests
from haystack.testing.document_store import (
CountDocumentsByFilterTest,
CountUniqueMetadataByFilterTest,
DocumentStoreBaseExtendedTests,
FilterableDocsFixtureMixin,
GetMetadataFieldMinMaxTest,
GetMetadataFieldsInfoTest,
GetMetadataFieldUniqueValuesTest,
)

from haystack_integrations.document_stores.arcadedb import ArcadeDBDocumentStore

Expand Down Expand Up @@ -48,7 +56,15 @@ def test_to_dict_from_dict(self):
reason="Set ARCADEDB_PASSWORD (e.g. via repo secret in CI) to run integration tests.",
)
@pytest.mark.integration
class TestArcadeDBDocumentStore(DocumentStoreBaseExtendedTests):
class TestArcadeDBDocumentStore(
CountDocumentsByFilterTest,
CountUniqueMetadataByFilterTest,
DocumentStoreBaseExtendedTests,
FilterableDocsFixtureMixin,
GetMetadataFieldMinMaxTest,
GetMetadataFieldsInfoTest,
GetMetadataFieldUniqueValuesTest,
):
"""
Run Haystack DocumentStore mixin tests against ArcadeDBDocumentStore.

Expand All @@ -73,15 +89,16 @@ def assert_documents_are_equal(self, received: list[Document], expected: list[Do
received = sorted(received, key=lambda x: x.id)
expected = sorted(expected, key=lambda x: x.id)
for received_doc, expected_doc in zip(received, expected, strict=True):
received_doc.score = None
actual = dataclasses.replace(received_doc, score=None)
if expected_doc.embedding is None:
received_doc.embedding = None
elif received_doc.embedding is None:
actual = dataclasses.replace(actual, embedding=None)
elif actual.embedding is None:
assert expected_doc.embedding is None
else:
assert received_doc.embedding == pytest.approx(expected_doc.embedding)
received_doc.embedding, expected_doc.embedding = None, None
assert received_doc == expected_doc
assert actual.embedding == pytest.approx(expected_doc.embedding)
actual = dataclasses.replace(actual, embedding=None)
expected_clean = dataclasses.replace(expected_doc, embedding=None)
assert actual == expected_clean

def test_write_documents(self, document_store: ArcadeDBDocumentStore):
"""Override mixin: test default write_documents and duplicate fail behaviour."""
Expand Down Expand Up @@ -122,3 +139,74 @@ def test_embedding_retrieval(self, document_store: ArcadeDBDocumentStore):
)
assert len(results) <= 3
assert results[0].score is not None

def test_count_documents_by_empty_filter(self, document_store: ArcadeDBDocumentStore):
    """An empty filter dict matches — and therefore counts — every stored document."""
    document_store.write_documents(
        [Document(id="1", content="Doc 1", meta={"category": "news"})]
    )

    assert document_store.count_documents_by_filter({}) == 1

def test_count_unique_metadata_by_filter_empty_fields(self, document_store: ArcadeDBDocumentStore):
    """Asking for unique counts of zero metadata fields yields an empty mapping."""
    document_store.write_documents(
        [Document(id="1", content="Doc 1", meta={"category": "news"})]
    )

    result = document_store.count_unique_metadata_by_filter(
        {"field": "meta.status", "operator": "==", "value": "news"},
        [],
    )

    assert result == {}

def test_get_metadata_field_min_max_nonexistent_field(self, document_store: ArcadeDBDocumentStore):
    """Min/max of a field no document carries is reported as None for both bounds."""
    document_store.write_documents(
        [Document(id="1", content="Doc 1", meta={"category": "news"})]
    )

    assert document_store.get_metadata_field_min_max("nonexistent") == {"min": None, "max": None}

def test_get_metadata_field_unique_values_pagination(self, document_store: ArcadeDBDocumentStore):
    """A page of at most `size` values is returned while the total reflects every unique value."""
    categories = ["alpha", "beta", "gamma"]
    document_store.write_documents(
        [
            Document(id=str(i), content=f"Doc {i}", meta={"category": cat})
            for i, cat in enumerate(categories, start=1)
        ]
    )

    values, total = document_store.get_metadata_field_unique_values("category", from_=0, size=2)

    assert len(values) == 2
    assert total == 3

def test_get_metadata_field_unique_values_case_insensitive(self, document_store: ArcadeDBDocumentStore):
    """Search-term matching ignores case on both the term and the stored values."""
    document_store.write_documents(
        [
            Document(id="1", content="Doc 1", meta={"category": "Books"}),
            Document(id="2", content="Doc 2", meta={"category": "books"}),
            Document(id="3", content="Doc 3", meta={"category": "ELECTRONICS"}),
        ]
    )

    _, total = document_store.get_metadata_field_unique_values("category", search_term="book")

    assert total == 2

def test_get_metadata_field_unique_values_no_matches(self, document_store: ArcadeDBDocumentStore):
    """A search term matching nothing yields an empty page and a zero total."""
    document_store.write_documents(
        [Document(id="1", content="Doc 1", meta={"category": "news"})]
    )

    values, total = document_store.get_metadata_field_unique_values("category", search_term="sports")

    assert values == []
    assert total == 0
Loading