Skip to content

Commit df4e7d6

Browse files
feat: Add Astra document store operations (#2904)
* Add Astra document store operations * calling static methods from the class and not the instance --------- Co-authored-by: David S. Batista <dsbatista@gmail.com>
1 parent ecf3ba6 commit df4e7d6

3 files changed

Lines changed: 324 additions & 4 deletions

File tree

integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from warnings import warn
88

99
from astrapy import DataAPIClient as AstraDBClient
10+
from astrapy.collection import FilterType
1011
from astrapy.constants import ReturnDocument
1112
from astrapy.exceptions import CollectionAlreadyExistsException
1213
from haystack import logging
@@ -216,19 +217,22 @@ def _query(self, vector, top_k, filters=None):
216217

217218
return result
218219

219-
def find_documents(self, find_query):
220+
def find_documents(
221+
self, find_query: dict[str, Any], projection: dict[str, Any] | None = None
222+
) -> list[dict[str, Any]]:
220223
"""
221224
Find documents in the Astra index.
222225
223226
:param find_query: a dictionary with the query options
227+
:param projection: optional projection for the returned documents
224228
:returns: the documents found in the index
225229
"""
226230
find_cursor = self._astra_db_collection.find(
227231
filter=find_query.get("filter"),
228232
sort=find_query.get("sort"),
229233
limit=find_query.get("limit"),
230234
include_similarity=find_query.get("includeSimilarity"),
231-
projection={"*": 1},
235+
projection=projection or {"*": 1},
232236
)
233237

234238
find_results = []
@@ -354,12 +358,25 @@ def delete_all_documents(self) -> int:
354358

355359
return delete_result.deleted_count
356360

357-
def count_documents(self, upper_bound: int = 10000) -> int:
361+
def count_documents(self, filters: FilterType | None = None, upper_bound: int = 10000) -> int:
358362
"""
359363
Count the number of documents in the Astra index.
364+
365+
:param filters: optional filter to restrict the counted documents
366+
:param upper_bound: maximum expected count, required by Astra's API
360367
:returns: the number of documents in the index
361368
"""
362-
return self._astra_db_collection.count_documents({}, upper_bound=upper_bound)
369+
return self._astra_db_collection.count_documents(filters or {}, upper_bound=upper_bound)
370+
371+
def distinct(self, key: str, filters: FilterType | None = None) -> list[Any]:
372+
"""
373+
Return the distinct values for a field in the Astra index.
374+
375+
:param key: field name
376+
:param filters: optional filter to restrict the matching documents
377+
:returns: distinct values for the field
378+
"""
379+
return self._astra_db_collection.distinct(key, filter=filters)
363380

364381
def update(
365382
self,

integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,59 @@ def count_documents(self) -> int:
284284
"""
285285
return self.index.count_documents()
286286

287+
@staticmethod
288+
def _normalize_new_filter_input(filters: dict[str, Any]) -> dict[str, Any]:
289+
if not isinstance(filters, dict):
290+
msg = "Filters must be a dictionary"
291+
raise AstraDocumentStoreFilterError(msg)
292+
293+
normalized_filters = filters.copy()
294+
if "id" in normalized_filters:
295+
normalized_filters["_id"] = normalized_filters.pop("id")
296+
297+
return normalized_filters
298+
299+
@staticmethod
300+
def _infer_metadata_field_type(values: list[Any]) -> str:
301+
inferred_types = set()
302+
for value in values:
303+
if isinstance(value, list):
304+
for item in value:
305+
if isinstance(item, bool):
306+
inferred_types.add("boolean")
307+
elif isinstance(item, (int, float)):
308+
inferred_types.add("long")
309+
elif isinstance(item, str):
310+
inferred_types.add("keyword")
311+
elif isinstance(value, bool):
312+
inferred_types.add("boolean")
313+
elif isinstance(value, (int, float)):
314+
inferred_types.add("long")
315+
elif isinstance(value, str):
316+
inferred_types.add("keyword")
317+
318+
if not inferred_types:
319+
return "keyword"
320+
321+
if len(inferred_types) > 1:
322+
logger.warning("Field has mixed metadata types {types}. Defaulting to 'keyword'.", types=inferred_types)
323+
return "keyword"
324+
325+
return next(iter(inferred_types))
326+
327+
@staticmethod
328+
def _normalize_distinct_values(values: list[Any]) -> list[str]:
329+
normalized_values: set[str] = set()
330+
for value in values:
331+
if isinstance(value, list):
332+
normalized_values.update(str(item) for item in value)
333+
elif value is not None:
334+
normalized_values.add(str(value))
335+
return sorted(normalized_values)
336+
337+
def _get_metadata_projection_documents(self) -> list[dict[str, Any]]:
338+
return self.index.find_documents({}, projection={"content": 1, "meta": 1})
339+
287340
def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Document]:
288341
"""
289342
Returns at most 1000 documents that match the filter.
@@ -484,3 +537,92 @@ def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int
484537
logger.info(f"{update_count} documents updated by filter")
485538

486539
return update_count
540+
541+
def count_documents_by_filter(self, filters: dict[str, Any]) -> int:
542+
"""
543+
Applies a filter and counts the documents that matched it.
544+
545+
:param filters: The filters to apply to the document list.
546+
:returns: The number of documents that match the filter.
547+
"""
548+
normalized_filters = AstraDocumentStore._normalize_new_filter_input(filters)
549+
converted_filters = _convert_filters(normalized_filters)
550+
return self.index.count_documents(filters=converted_filters, upper_bound=1_000_000_000)
551+
552+
def count_unique_metadata_by_filter(self, filters: dict[str, Any], metadata_fields: list[str]) -> dict[str, int]:
553+
"""
554+
Applies a filter selecting documents and counts the unique values for each meta field of the matched
555+
documents.
556+
557+
:param filters: The filters to apply to the document list.
558+
:param metadata_fields: The metadata fields to count unique values for.
559+
:returns: A dictionary where the keys are the metadata field names and the values are the count of unique
560+
values.
561+
"""
562+
normalized_filters = AstraDocumentStore._normalize_new_filter_input(filters)
563+
converted_filters = _convert_filters(normalized_filters)
564+
565+
counts = {}
566+
for field in metadata_fields:
567+
distinct_values = self.index.distinct(f"meta.{field}", filters=converted_filters)
568+
counts[field] = len(AstraDocumentStore._normalize_distinct_values(distinct_values))
569+
return counts
570+
571+
def get_metadata_fields_info(self) -> dict[str, dict[str, str]]:
572+
"""
573+
Returns the metadata fields and the corresponding types.
574+
575+
:returns: A dictionary mapping field names to dictionaries with a `type` key.
576+
"""
577+
documents = self._get_metadata_projection_documents()
578+
if not documents:
579+
return {}
580+
581+
fields_info: dict[str, dict[str, str]] = {}
582+
583+
if any(document.get("content") is not None for document in documents):
584+
fields_info["content"] = {"type": "text"}
585+
586+
field_values: dict[str, list[Any]] = {}
587+
for document in documents:
588+
for field, value in document.get("meta", {}).items():
589+
field_values.setdefault(field, []).append(value)
590+
591+
for field, values in field_values.items():
592+
fields_info[field] = {"type": self._infer_metadata_field_type(values)}
593+
594+
return fields_info
595+
596+
def get_metadata_field_min_max(self, metadata_field: str) -> dict[str, Any]:
597+
"""
598+
For a given metadata field, find its max and min value.
599+
600+
:param metadata_field: The metadata field to inspect.
601+
:returns: A dictionary with `min` and `max`.
602+
"""
603+
distinct_values = self.index.distinct(f"meta.{metadata_field}")
604+
comparable_values = [value for value in distinct_values if isinstance(value, (str, int, float, bool))]
605+
if not comparable_values:
606+
return {"min": None, "max": None}
607+
608+
return {"min": min(comparable_values), "max": max(comparable_values)}
609+
610+
def get_metadata_field_unique_values(
611+
self, metadata_field: str, search_term: str | None = None, from_: int = 0, size: int = 10
612+
) -> tuple[list[str], int]:
613+
"""
614+
Retrieves unique values for a field matching a search term or all possible values if no search term is given.
615+
616+
:param metadata_field: The metadata field to inspect.
617+
:param search_term: Optional case-insensitive substring search term.
618+
:param from_: The starting index for pagination.
619+
:param size: The number of values to return.
620+
:returns: A tuple containing the paginated values and the total count.
621+
"""
622+
values = AstraDocumentStore._normalize_distinct_values(self.index.distinct(f"meta.{metadata_field}"))
623+
if search_term:
624+
search_term_lower = search_term.lower()
625+
values = [value for value in values if search_term_lower in value.lower()]
626+
627+
total_count = len(values)
628+
return values[from_ : from_ + size], total_count

integrations/astra/tests/test_document_store.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,93 @@ def test_to_dict(mock_auth): # noqa
4343
}
4444

4545

46+
@pytest.mark.usefixtures("mock_auth")
47+
@mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient")
48+
def test_count_documents_by_filter(mock_astra_client):
49+
mock_index = mock_astra_client.return_value
50+
mock_index.count_documents.return_value = 2
51+
52+
store = AstraDocumentStore()
53+
54+
count = store.count_documents_by_filter({"field": "meta.status", "operator": "==", "value": "draft"})
55+
56+
assert count == 2
57+
mock_index.count_documents.assert_called_once_with(
58+
filters={"meta.status": {"$eq": "draft"}}, upper_bound=1_000_000_000
59+
)
60+
61+
62+
@pytest.mark.usefixtures("mock_auth")
63+
@mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient")
64+
def test_count_unique_metadata_by_filter(mock_astra_client):
65+
mock_index = mock_astra_client.return_value
66+
mock_index.distinct.side_effect = [["news", "docs", ["docs", "faq"], None], [1, 2, 2]]
67+
68+
store = AstraDocumentStore()
69+
70+
counts = store.count_unique_metadata_by_filter(
71+
{"field": "meta.status", "operator": "==", "value": "published"}, ["category", "priority"]
72+
)
73+
74+
assert counts == {"category": 3, "priority": 2}
75+
assert mock_index.distinct.call_args_list == [
76+
mock.call("meta.category", filters={"meta.status": {"$eq": "published"}}),
77+
mock.call("meta.priority", filters={"meta.status": {"$eq": "published"}}),
78+
]
79+
80+
81+
@pytest.mark.usefixtures("mock_auth")
82+
@mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient")
83+
def test_get_metadata_fields_info(mock_astra_client):
84+
mock_index = mock_astra_client.return_value
85+
mock_index.find_documents.return_value = [
86+
{"content": "Doc 1", "meta": {"category": "news", "priority": 1, "active": True}},
87+
{"content": "Doc 2", "meta": {"category": "docs", "priority": 2.5, "tags": ["a", "b"]}},
88+
]
89+
90+
store = AstraDocumentStore()
91+
92+
fields_info = store.get_metadata_fields_info()
93+
94+
assert fields_info == {
95+
"content": {"type": "text"},
96+
"category": {"type": "keyword"},
97+
"priority": {"type": "long"},
98+
"active": {"type": "boolean"},
99+
"tags": {"type": "keyword"},
100+
}
101+
mock_index.find_documents.assert_called_once_with({}, projection={"content": 1, "meta": 1})
102+
103+
104+
@pytest.mark.usefixtures("mock_auth")
105+
@mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient")
106+
def test_get_metadata_field_min_max(mock_astra_client):
107+
mock_index = mock_astra_client.return_value
108+
mock_index.distinct.return_value = [10, 3, 7]
109+
110+
store = AstraDocumentStore()
111+
112+
result = store.get_metadata_field_min_max("priority")
113+
114+
assert result == {"min": 3, "max": 10}
115+
mock_index.distinct.assert_called_once_with("meta.priority")
116+
117+
118+
@pytest.mark.usefixtures("mock_auth")
119+
@mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient")
120+
def test_get_metadata_field_unique_values(mock_astra_client):
121+
mock_index = mock_astra_client.return_value
122+
mock_index.distinct.return_value = ["Beta", "alpha", ["gamma", "alphabet"], None]
123+
124+
store = AstraDocumentStore()
125+
126+
values, total_count = store.get_metadata_field_unique_values("category", search_term="alp", from_=0, size=5)
127+
128+
assert values == ["alpha", "alphabet"]
129+
assert total_count == 2
130+
mock_index.distinct.assert_called_once_with("meta.category")
131+
132+
46133
@pytest.mark.integration
47134
@pytest.mark.skipif(
48135
os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set"
@@ -204,6 +291,80 @@ def test_filter_documents_by_in_operator(self, document_store):
204291
self.assert_documents_are_equal([result[0]], [docs[0]])
205292
self.assert_documents_are_equal([result[1]], [docs[1]])
206293

294+
def test_count_documents_by_filter(self, document_store: AstraDocumentStore):
295+
docs = [
296+
Document(id="1", content="Doc 1", meta={"category": "news", "status": "published", "priority": 3}),
297+
Document(id="2", content="Doc 2", meta={"category": "docs", "status": "draft", "priority": 1}),
298+
Document(id="3", content="Doc 3", meta={"category": "news", "status": "published", "priority": 5}),
299+
]
300+
document_store.write_documents(docs)
301+
302+
count = document_store.count_documents_by_filter(
303+
{"field": "meta.status", "operator": "==", "value": "published"}
304+
)
305+
306+
assert count == 2
307+
308+
def test_count_unique_metadata_by_filter(self, document_store: AstraDocumentStore):
309+
docs = [
310+
Document(id="1", content="Doc 1", meta={"category": "news", "status": "published", "priority": 1}),
311+
Document(id="2", content="Doc 2", meta={"category": "docs", "status": "published", "priority": 2}),
312+
Document(id="3", content="Doc 3", meta={"category": "news", "status": "published", "priority": 2}),
313+
Document(id="4", content="Doc 4", meta={"category": "faq", "status": "draft", "priority": 3}),
314+
]
315+
document_store.write_documents(docs)
316+
317+
counts = document_store.count_unique_metadata_by_filter(
318+
{"field": "meta.status", "operator": "==", "value": "published"},
319+
["category", "priority"],
320+
)
321+
322+
assert counts == {"category": 2, "priority": 2}
323+
324+
def test_get_metadata_fields_info(self, document_store: AstraDocumentStore):
325+
docs = [
326+
Document(id="1", content="Doc 1", meta={"category": "news", "status": "published", "priority": 1}),
327+
Document(id="2", content="Doc 2", meta={"category": "docs", "status": "draft", "priority": 2}),
328+
]
329+
document_store.write_documents(docs)
330+
331+
fields_info = document_store.get_metadata_fields_info()
332+
333+
assert fields_info == {
334+
"content": {"type": "text"},
335+
"category": {"type": "keyword"},
336+
"status": {"type": "keyword"},
337+
"priority": {"type": "long"},
338+
}
339+
340+
def test_get_metadata_field_min_max(self, document_store: AstraDocumentStore):
341+
docs = [
342+
Document(id="1", content="Doc 1", meta={"priority": 3}),
343+
Document(id="2", content="Doc 2", meta={"priority": 1}),
344+
Document(id="3", content="Doc 3", meta={"priority": 7}),
345+
]
346+
document_store.write_documents(docs)
347+
348+
result = document_store.get_metadata_field_min_max("priority")
349+
350+
assert result == {"min": 1, "max": 7}
351+
352+
def test_get_metadata_field_unique_values(self, document_store: AstraDocumentStore):
353+
docs = [
354+
Document(id="1", content="Doc 1", meta={"category": "alpha"}),
355+
Document(id="2", content="Doc 2", meta={"category": "beta"}),
356+
Document(id="3", content="Doc 3", meta={"category": "alphabet"}),
357+
Document(id="4", content="Doc 4", meta={"category": "gamma"}),
358+
]
359+
document_store.write_documents(docs)
360+
361+
values, total_count = document_store.get_metadata_field_unique_values(
362+
"category", search_term="alp", from_=0, size=10
363+
)
364+
365+
assert values == ["alpha", "alphabet"]
366+
assert total_count == 2
367+
207368
@pytest.mark.skip(reason="Unsupported filter operator not.")
208369
def test_not_operator(self, document_store, filterable_docs):
209370
pass

0 commit comments

Comments
 (0)