Skip to content

Commit 155d5b5

Browse files
feat(arcadedb): add metadata query methods to ArcadeDBDocumentStore (#3013)
* feat(arcadedb): add metadata query methods to ArcadeDBDocumentStore * increasing lowest haystack dependency * using Mixin tests and fixing get_metadata_field_min_max 'meta' prefix removal * adding missing comma * fixing type inference and 'meta' prefix removal on other operations * fixing renaming * adding safeguard for delete and update by filter * use stripped field name in return, consistent with get_metadata_field_info * adding missing docstring --------- Co-authored-by: David S. Batista <dsbatista@gmail.com>
1 parent f487dd6 commit 155d5b5

3 files changed

Lines changed: 306 additions & 12 deletions

File tree

integrations/arcadedb/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ classifiers = [
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
2626
dependencies = [
27-
"haystack-ai>=2.24.0",
27+
"haystack-ai>=2.26.1",
2828
"requests>=2.28.0",
2929
]
3030

integrations/arcadedb/src/haystack_integrations/document_stores/arcadedb/document_store.py

Lines changed: 209 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ class ArcadeDBDocumentStore:
5656
"dot": "DOT_PRODUCT",
5757
}
5858

59+
# Limit for projection documents
60+
SCHEMA_SAMPLING_LIMIT: ClassVar[int] = 1000
61+
5962
def __init__(
6063
self,
6164
*,
@@ -234,6 +237,70 @@ def count_documents(self) -> int:
234237
return int(rows[0].get("cnt", 0))
235238
return 0
236239

240+
@staticmethod
241+
def _extract_distinct_values(rows: list[dict[str, Any]]) -> set[str]:
242+
"""
243+
Extracts and flattens unique non-None strings from 'val' column result rows.
244+
:param rows: Raw result rows from ``_command``.
245+
:returns: A set of unique string values.
246+
"""
247+
result: set[str] = set()
248+
for row in rows:
249+
val = row.get("val")
250+
if isinstance(val, list):
251+
result.update(str(item) for item in val if item is not None)
252+
elif val is not None:
253+
result.add(str(val))
254+
return result
255+
256+
def _get_metadata_projection_documents(self) -> list[dict[str, Any]]:
    """
    Fetch up to ``SCHEMA_SAMPLING_LIMIT`` documents (content and meta only)
    to sample for schema inference.

    Note: does not call ``_ensure_initialized()``. To avoid redundant
    initialization checks during internal calls, the caller is responsible
    for ensuring the document store is initialized before invoking this.
    """
    query = f"SELECT content, meta FROM `{self._type_name}` LIMIT {self.SCHEMA_SAMPLING_LIMIT}"
    return self._command(query)
265+
266+
@staticmethod
267+
def _infer_metadata_field_type(values: list[Any]) -> str:
268+
"""
269+
Infers the metadata field type from a list of sampled values.
270+
:param values: A list of raw Python values sampled from the field.
271+
:returns: A type string — one of ``"boolean"``, ``"double"``, ``"long"``, or ``"keyword"``.
272+
Returns ``"keyword"`` if values are empty or of mixed types.
273+
"""
274+
inferred_types = set()
275+
for value in values:
276+
if isinstance(value, list):
277+
for item in value:
278+
if isinstance(item, bool):
279+
inferred_types.add("boolean")
280+
elif isinstance(item, float):
281+
inferred_types.add("double")
282+
elif isinstance(item, int):
283+
inferred_types.add("long")
284+
elif isinstance(item, str):
285+
inferred_types.add("keyword")
286+
elif isinstance(value, bool):
287+
inferred_types.add("boolean")
288+
elif isinstance(value, float):
289+
inferred_types.add("double")
290+
elif isinstance(value, int):
291+
inferred_types.add("long")
292+
elif isinstance(value, str):
293+
inferred_types.add("keyword")
294+
295+
if not inferred_types:
296+
return "keyword"
297+
298+
if len(inferred_types) > 1:
299+
logger.warning("Field has mixed metadata types %s. Defaulting to 'keyword'.", inferred_types)
300+
return "keyword"
301+
302+
return next(iter(inferred_types))
303+
237304
def filter_documents(
238305
self,
239306
filters: dict[str, Any] | None = None,
@@ -357,7 +424,14 @@ def delete_by_filter(self, filters: dict[str, Any]) -> int:
357424
:returns: The number of documents deleted.
358425
"""
359426
self._ensure_initialized()
360-
where = _convert_filters(filters)
427+
try:
428+
where = _convert_filters(filters)
429+
except ValueError as e:
430+
raise FilterError(str(e)) from e
431+
432+
if not where:
433+
msg = "delete_by_filter requires a non-empty filter. Use delete_all_documents() to delete all documents."
434+
raise FilterError(msg)
361435

362436
count_result = self._command(f"DELETE FROM `{self._type_name}` WHERE {where}")
363437

@@ -373,14 +447,143 @@ def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int
373447
:returns: The number of documents updated.
374448
"""
375449
self._ensure_initialized()
376-
where = _convert_filters(filters)
450+
try:
451+
where = _convert_filters(filters)
452+
except ValueError as e:
453+
raise FilterError(str(e)) from e
454+
455+
if not where:
456+
msg = "update_by_filter requires a non-empty filter."
457+
raise FilterError(msg)
377458

378459
sql_set = ",".join(f"meta[{_sql_str(key)}] = {_map_literal_base(value)}" for key, value in meta.items())
379460
sql = f"UPDATE `{self._type_name}` SET {sql_set} WHERE {where}"
380461
count_result = self._command(sql)
381462

382463
return count_result[0]["count"]
383464

465+
def count_documents_by_filter(self, filters: dict[str, Any]) -> int:
    """
    Counts the number of documents matching the provided filter.

    :param filters: The filters to apply to the documents. An empty filter
        matches every document.
    :returns: The number of documents that match the filter.
    :raises FilterError: If the filters cannot be converted to a WHERE clause.
    """
    self._ensure_initialized()
    try:
        where_clause = _convert_filters(filters)
    except ValueError as err:
        raise FilterError(str(err)) from err

    query = f"SELECT count(*) AS cnt FROM `{self._type_name}`"
    if where_clause:
        query = f"{query} WHERE {where_clause}"

    result_rows = self._command(query)
    return int(result_rows[0].get("cnt", 0)) if result_rows else 0
485+
486+
def count_unique_metadata_by_filter(self, filters: dict[str, Any], metadata_fields: list[str]) -> dict[str, int]:
    """
    Counts unique values for each metadata field in documents matching the provided filters.

    :param filters: The filters to apply to the document list.
    :param metadata_fields: Metadata fields for which to count unique values
        (an optional ``meta.`` prefix is stripped).
    :returns: A dictionary mapping each stripped field name to its count of
        unique values.
    :raises FilterError: If the filters cannot be converted to a WHERE clause.
    """
    self._ensure_initialized()
    try:
        where_clause = _convert_filters(filters)
    except ValueError as err:
        raise FilterError(str(err)) from err

    if not metadata_fields:
        return {}

    unique_counts: dict[str, int] = {}
    # ArcadeDB doesn't support COUNT(DISTINCT ...), so pull distinct values per field.
    for raw_field in metadata_fields:
        stripped = raw_field.removeprefix("meta.")
        query = f"SELECT DISTINCT meta[{_sql_str(stripped)}] AS val FROM `{self._type_name}`"
        if where_clause:
            query += f" WHERE {where_clause}"
        unique_counts[stripped] = len(self._extract_distinct_values(self._command(query)))

    return unique_counts
513+
514+
def get_metadata_fields_info(self) -> dict[str, dict[str, str]]:
    """
    Returns the metadata fields and their corresponding types based on sampled documents.

    The ``content`` field is reported as ``"text"`` when any sampled document
    has non-None content; metadata field types are inferred from sampled values.

    :returns: A dictionary mapping field names to dictionaries with a `type` key.
    """
    self._ensure_initialized()
    sampled = self._get_metadata_projection_documents()
    if not sampled:
        return {}

    info: dict[str, dict[str, str]] = {}
    if any(doc.get("content") is not None for doc in sampled):
        info["content"] = {"type": "text"}

    # Collect every observed value per metadata field across the sample.
    samples_by_field: dict[str, list[Any]] = {}
    for doc in sampled:
        for name, value in doc.get("meta", {}).items():
            samples_by_field.setdefault(name, []).append(value)

    for name, sampled_values in samples_by_field.items():
        info[name] = {"type": self._infer_metadata_field_type(sampled_values)}

    return info
539+
540+
def get_metadata_field_min_max(self, metadata_field: str) -> dict[str, Any]:
    """
    For a given metadata field, finds its min and max values.

    :param metadata_field: The metadata field to inspect (an optional
        ``meta.`` prefix is stripped).
    :returns: A dictionary with `min` and `max` keys and their corresponding
        values (both None when the query yields no rows).
    """
    self._ensure_initialized()

    stripped = metadata_field.removeprefix("meta.")
    accessor = f"meta[{_sql_str(stripped)}]"
    rows = self._command(
        f"SELECT MIN({accessor}) AS min_value, MAX({accessor}) AS max_value FROM `{self._type_name}`"
    )

    if not rows:
        return {"min": None, "max": None}

    first = rows[0]
    return {"min": first.get("min_value"), "max": first.get("max_value")}
557+
558+
def get_metadata_field_unique_values(
    self, metadata_field: str, search_term: str | None = None, from_: int = 0, size: int = 10
) -> tuple[list[str], int]:
    """
    Retrieves unique values for a field matching a search term, or all
    possible values when no search term is given.

    :param metadata_field: The metadata field to inspect (an optional
        ``meta.`` prefix is stripped).
    :param search_term: Optional case-insensitive substring search term.
    :param from_: The starting index for pagination.
    :param size: The number of values to return.
    :returns: A tuple of (paginated sorted values, total unpaginated count).
    """
    self._ensure_initialized()

    stripped = metadata_field.removeprefix("meta.")
    accessor = f"meta[{_sql_str(stripped)}]"

    clause = ""
    if search_term:
        # ILIKE gives case-insensitive substring matching.
        clause = f" WHERE {accessor} ILIKE {_sql_str(f'%{search_term}%')}"

    rows = self._command(f"SELECT DISTINCT {accessor} AS val FROM `{self._type_name}`{clause}")
    values = sorted(self._extract_distinct_values(rows))
    return values[from_ : from_ + size], len(values)
586+
384587
# ------------------------------------------------------------------
385588
# Retrieval (called by Retriever components)
386589
# ------------------------------------------------------------------
@@ -410,7 +613,10 @@ def _embedding_retrieval(
410613
return []
411614

412615
neighbors = rows[0]["neighbors"]
413-
where = _convert_filters(filters)
616+
try:
617+
where = _convert_filters(filters)
618+
except ValueError as e:
619+
raise FilterError(str(e)) from e
414620

415621
documents = []
416622
for neighbor in neighbors:

integrations/arcadedb/tests/test_document_store.py

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@
99
from haystack import Document
1010
from haystack.document_stores.errors import DuplicateDocumentError
1111
from haystack.document_stores.types import DuplicatePolicy
12-
from haystack.testing.document_store import DocumentStoreBaseExtendedTests
12+
from haystack.testing.document_store import (
13+
CountDocumentsByFilterTest,
14+
CountUniqueMetadataByFilterTest,
15+
DocumentStoreBaseExtendedTests,
16+
FilterableDocsFixtureMixin,
17+
GetMetadataFieldMinMaxTest,
18+
GetMetadataFieldsInfoTest,
19+
GetMetadataFieldUniqueValuesTest,
20+
)
1321

1422
from haystack_integrations.document_stores.arcadedb import ArcadeDBDocumentStore
1523

@@ -48,7 +56,15 @@ def test_to_dict_from_dict(self):
4856
reason="Set ARCADEDB_PASSWORD (e.g. via repo secret in CI) to run integration tests.",
4957
)
5058
@pytest.mark.integration
51-
class TestArcadeDBDocumentStore(DocumentStoreBaseExtendedTests):
59+
class TestArcadeDBDocumentStore(
60+
CountDocumentsByFilterTest,
61+
CountUniqueMetadataByFilterTest,
62+
DocumentStoreBaseExtendedTests,
63+
FilterableDocsFixtureMixin,
64+
GetMetadataFieldMinMaxTest,
65+
GetMetadataFieldsInfoTest,
66+
GetMetadataFieldUniqueValuesTest,
67+
):
5268
"""
5369
Run Haystack DocumentStore mixin tests against ArcadeDBDocumentStore.
5470
@@ -73,15 +89,16 @@ def assert_documents_are_equal(self, received: list[Document], expected: list[Do
7389
received = sorted(received, key=lambda x: x.id)
7490
expected = sorted(expected, key=lambda x: x.id)
7591
for received_doc, expected_doc in zip(received, expected, strict=True):
76-
received_doc.score = None
92+
actual = dataclasses.replace(received_doc, score=None)
7793
if expected_doc.embedding is None:
78-
received_doc.embedding = None
79-
elif received_doc.embedding is None:
94+
actual = dataclasses.replace(actual, embedding=None)
95+
elif actual.embedding is None:
8096
assert expected_doc.embedding is None
8197
else:
82-
assert received_doc.embedding == pytest.approx(expected_doc.embedding)
83-
received_doc.embedding, expected_doc.embedding = None, None
84-
assert received_doc == expected_doc
98+
assert actual.embedding == pytest.approx(expected_doc.embedding)
99+
actual = dataclasses.replace(actual, embedding=None)
100+
expected_clean = dataclasses.replace(expected_doc, embedding=None)
101+
assert actual == expected_clean
85102

86103
def test_write_documents(self, document_store: ArcadeDBDocumentStore):
87104
"""Override mixin: test default write_documents and duplicate fail behaviour."""
@@ -122,3 +139,74 @@ def test_embedding_retrieval(self, document_store: ArcadeDBDocumentStore):
122139
)
123140
assert len(results) <= 3
124141
assert results[0].score is not None
142+
143+
def test_count_documents_by_empty_filter(self, document_store: ArcadeDBDocumentStore):
    """An empty filter dict matches (and therefore counts) every document."""
    document_store.write_documents([Document(id="1", content="Doc 1", meta={"category": "news"})])

    assert document_store.count_documents_by_filter({}) == 1
153+
154+
def test_count_unique_metadata_by_filter_empty_fields(self, document_store: ArcadeDBDocumentStore):
    """An empty metadata_fields list yields an empty result dict."""
    document_store.write_documents([Document(id="1", content="Doc 1", meta={"category": "news"})])

    result = document_store.count_unique_metadata_by_filter(
        {"field": "meta.status", "operator": "==", "value": "news"},
        [],
    )

    assert result == {}
167+
168+
def test_get_metadata_field_min_max_nonexistent_field(self, document_store: ArcadeDBDocumentStore):
    """Min and max are both None for a field absent from every document."""
    document_store.write_documents([Document(id="1", content="Doc 1", meta={"category": "news"})])

    assert document_store.get_metadata_field_min_max("nonexistent") == {"min": None, "max": None}
176+
177+
def test_get_metadata_field_unique_values_pagination(self, document_store: ArcadeDBDocumentStore):
    """The page honours `size` while the total reflects the full unpaginated count."""
    document_store.write_documents(
        [
            Document(id="1", content="Doc 1", meta={"category": "alpha"}),
            Document(id="2", content="Doc 2", meta={"category": "beta"}),
            Document(id="3", content="Doc 3", meta={"category": "gamma"}),
        ]
    )

    page, total = document_store.get_metadata_field_unique_values("category", from_=0, size=2)

    assert len(page) == 2
    assert total == 3
190+
191+
def test_get_metadata_field_unique_values_case_insensitive(self, document_store: ArcadeDBDocumentStore):
    """A search term matches metadata values regardless of case."""
    document_store.write_documents(
        [
            Document(id="1", content="Doc 1", meta={"category": "Books"}),
            Document(id="2", content="Doc 2", meta={"category": "books"}),
            Document(id="3", content="Doc 3", meta={"category": "ELECTRONICS"}),
        ]
    )

    _, total = document_store.get_metadata_field_unique_values("category", search_term="book")

    assert total == 2
203+
204+
def test_get_metadata_field_unique_values_no_matches(self, document_store: ArcadeDBDocumentStore):
    """No values and a zero total when the search term matches nothing."""
    document_store.write_documents([Document(id="1", content="Doc 1", meta={"category": "news"})])

    page, total = document_store.get_metadata_field_unique_values("category", search_term="sports")

    assert page == []
    assert total == 0

0 commit comments

Comments
 (0)