Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/astra/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"astrapy>=1.5.0,<2.0",
"haystack-ai>=2.24.0",
"haystack-ai>=2.26.1",
"pydantic",
"typing_extensions",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,9 @@ def get_metadata_field_min_max(self, metadata_field: str) -> dict[str, Any]:
:param metadata_field: The metadata field to inspect.
:returns: A dictionary with `min` and `max`.
"""
distinct_values = self.index.distinct(f"meta.{metadata_field}")

field = metadata_field.removeprefix("meta.")
distinct_values = self.index.distinct(f"meta.{field}")
comparable_values = [value for value in distinct_values if isinstance(value, str | int | float | bool)]
if not comparable_values:
return {"min": None, "max": None}
Expand All @@ -621,7 +623,8 @@ def get_metadata_field_unique_values(
:param size: The number of values to return.
:returns: A tuple containing the paginated values and the total count.
"""
values = AstraDocumentStore._normalize_distinct_values(self.index.distinct(f"meta.{metadata_field}"))
field = metadata_field.removeprefix("meta.")
values = AstraDocumentStore._normalize_distinct_values(self.index.distinct(f"meta.{field}"))
if search_term:
search_term_lower = search_term.lower()
values = [value for value in values if search_term_lower in value.lower()]
Expand Down
92 changes: 16 additions & 76 deletions integrations/astra/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from haystack import Document
from haystack.document_stores.errors import MissingDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.testing.document_store import DocumentStoreBaseExtendedTests
from haystack.testing.document_store import (
CountDocumentsByFilterTest,
CountUniqueMetadataByFilterTest,
DocumentStoreBaseExtendedTests,
GetMetadataFieldMinMaxTest,
GetMetadataFieldsInfoTest,
GetMetadataFieldUniqueValuesTest,
)

from haystack_integrations.document_stores.astra import AstraDocumentStore

Expand Down Expand Up @@ -135,7 +142,14 @@ def test_get_metadata_field_unique_values(mock_astra_client):
os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set"
)
@pytest.mark.skipif(os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", reason="ASTRA_DB_API_ENDPOINT env var not set")
class TestDocumentStore(DocumentStoreBaseExtendedTests):
class TestDocumentStore(
DocumentStoreBaseExtendedTests,
CountDocumentsByFilterTest,
CountUniqueMetadataByFilterTest,
GetMetadataFieldsInfoTest,
GetMetadataFieldMinMaxTest,
GetMetadataFieldUniqueValuesTest,
):
"""
Common test cases will be provided by `DocumentStoreBaseExtendedTests` but
you can add more to this class.
Expand Down Expand Up @@ -292,80 +306,6 @@ def test_filter_documents_by_in_operator(self, document_store):
TestDocumentStore.assert_documents_are_equal([result[0]], [docs[0]])
TestDocumentStore.assert_documents_are_equal([result[1]], [docs[1]])

def test_count_documents_by_filter(self, document_store: AstraDocumentStore):
    """Write three documents and check that exactly the two `published` ones are counted."""
    published_only = {"field": "meta.status", "operator": "==", "value": "published"}
    document_store.write_documents(
        [
            Document(id="1", content="Doc 1", meta={"category": "news", "status": "published", "priority": 3}),
            Document(id="2", content="Doc 2", meta={"category": "docs", "status": "draft", "priority": 1}),
            Document(id="3", content="Doc 3", meta={"category": "news", "status": "published", "priority": 5}),
        ]
    )

    matching = document_store.count_documents_by_filter(published_only)

    assert matching == 2

def test_count_unique_metadata_by_filter(self, document_store: AstraDocumentStore):
    """Check the per-field distinct-value counts among the `published` documents."""
    seeded = [
        Document(id="1", content="Doc 1", meta={"category": "news", "status": "published", "priority": 1}),
        Document(id="2", content="Doc 2", meta={"category": "docs", "status": "published", "priority": 2}),
        Document(id="3", content="Doc 3", meta={"category": "news", "status": "published", "priority": 2}),
        Document(id="4", content="Doc 4", meta={"category": "faq", "status": "draft", "priority": 3}),
    ]
    document_store.write_documents(seeded)

    published_only = {"field": "meta.status", "operator": "==", "value": "published"}
    inspected_fields = ["category", "priority"]
    unique_counts = document_store.count_unique_metadata_by_filter(published_only, inspected_fields)

    # The draft document (id "4") is filtered out, leaving two categories and two priorities.
    assert unique_counts == {"category": 2, "priority": 2}

def test_get_metadata_fields_info(self, document_store: AstraDocumentStore):
    """Check the field-type mapping reported after writing two documents."""
    document_store.write_documents(
        [
            Document(id="1", content="Doc 1", meta={"category": "news", "status": "published", "priority": 1}),
            Document(id="2", content="Doc 2", meta={"category": "docs", "status": "draft", "priority": 2}),
        ]
    )

    reported = document_store.get_metadata_fields_info()

    expected = {
        "content": {"type": "text"},
        "category": {"type": "keyword"},
        "status": {"type": "keyword"},
        "priority": {"type": "long"},
    }
    assert reported == expected

def test_get_metadata_field_min_max(self, document_store: AstraDocumentStore):
    """Check that the smallest and largest `priority` values written are reported."""
    rows = (("1", "Doc 1", 3), ("2", "Doc 2", 1), ("3", "Doc 3", 7))
    document_store.write_documents(
        [Document(id=doc_id, content=text, meta={"priority": priority}) for doc_id, text, priority in rows]
    )

    bounds = document_store.get_metadata_field_min_max("priority")

    assert bounds == {"min": 1, "max": 7}

def test_get_metadata_field_unique_values(self, document_store: AstraDocumentStore):
    """Check that searching `category` values for "alp" returns the two matching values and their count."""
    rows = (("1", "Doc 1", "alpha"), ("2", "Doc 2", "beta"), ("3", "Doc 3", "alphabet"), ("4", "Doc 4", "gamma"))
    document_store.write_documents(
        [Document(id=doc_id, content=text, meta={"category": category}) for doc_id, text, category in rows]
    )

    matches, match_total = document_store.get_metadata_field_unique_values(
        "category", search_term="alp", from_=0, size=10
    )

    assert matches == ["alpha", "alphabet"]
    assert match_total == 2

@pytest.mark.skip(reason="Unsupported filter operator not.")
def test_not_operator(self, document_store, filterable_docs):
    """Placeholder with no assertions; the test is skipped unconditionally."""
Expand Down
Loading