Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 62 additions & 60 deletions integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,53 +42,7 @@ def clear_chroma_system_cache():
SharedSystemClient.clear_system_cache()


class TestDocumentStore(
    CountDocumentsTest,
    DeleteDocumentsTest,
    FilterDocumentsTest,
    FilterableDocsFixtureMixin,
    UpdateByFilterTest,
    DeleteAllTest,
    DeleteByFilterTest,
    CountDocumentsByFilterTest,
    CountUniqueMetadataByFilterTest,
    GetMetadataFieldsInfoTest,
    GetMetadataFieldMinMaxTest,
    GetMetadataFieldUniqueValuesTest,
):
    """
    Chroma run of the shared document store test cases.

    The test cases themselves are inherited from the mixin classes listed as bases;
    this class supplies the `document_store` fixture and the comparison helper
    they rely on. Chroma-specific tests can be added here as well.
    """

    @pytest.fixture
    def document_store(self, embedding_function) -> ChromaDocumentStore:
        """
        Provide the document store instance the inherited test cases operate on.

        `get_embedding_function` is patched to return the `embedding_function`
        fixture while the store is constructed, and every store gets its own
        randomly named collection.
        """
        # NOTE(review): the patch is only active while the constructor runs; this
        # assumes the store resolves its embedding function eagerly - confirm.
        with mock.patch(
            "haystack_integrations.document_stores.chroma.document_store.get_embedding_function"
        ) as get_func:
            get_func.return_value = embedding_function
            return ChromaDocumentStore(embedding_function="test_function", collection_name=str(uuid.uuid1()))

    def assert_documents_are_equal(self, received: list[Document], expected: list[Document]):
        """
        Assert that two lists of Documents carry the same content and metadata.

        Both lists are ordered by `id` and compared pairwise. Only `content` and
        `meta` are checked, so fields the store may set on returned Documents
        (for example a score) do not affect the outcome.

        The comparison works on sorted copies, so the caller's lists keep their
        original order.

        :param received: Documents returned by the document store.
        :param expected: Documents the test expects to get back.
        :raises AssertionError: if a pair of Documents differs in content or meta.
        :raises ValueError: if the lists differ in length (`zip(..., strict=True)`).
        """
        by_id = operator.attrgetter("id")
        received_sorted = sorted(received, key=by_id)
        expected_sorted = sorted(expected, key=by_id)

        for doc_received, doc_expected in zip(received_sorted, expected_sorted, strict=True):
            assert doc_received.content == doc_expected.content
            assert doc_received.meta == doc_expected.meta

class TestDocumentStoreUnit:
def test_init_in_memory(self):
store = ChromaDocumentStore()

Expand Down Expand Up @@ -122,17 +76,6 @@ def test_invalid_initialization_both_host_and_persist_path(self):
store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost")
store._ensure_initialized()

def test_client_settings_applied(self, clear_chroma_system_cache):
    """
    Check that `client_settings` given to the store end up on the Chroma client.

    Chroma's in-memory client is a singleton backed by an internal cache: after a
    client has been created with one set of settings, creating another with
    different settings in the same process is rejected. The
    `clear_chroma_system_cache` fixture clears that cache before and after this
    test so it does not clash with tests relying on the default settings.
    """
    requested_settings = {"anonymized_telemetry": False}
    store = ChromaDocumentStore(client_settings=requested_settings)
    store._ensure_initialized()

    effective_settings = store._client.get_settings()
    assert effective_settings.anonymized_telemetry is False

def test_to_dict(self, request):
ds = ChromaDocumentStore(
collection_name=request.node.name,
Expand Down Expand Up @@ -182,6 +125,66 @@ def test_same_collection_name_reinitialization(self):
ChromaDocumentStore("test_1")
ChromaDocumentStore("test_1")


@pytest.mark.integration
class TestDocumentStore(
CountDocumentsTest,
DeleteDocumentsTest,
FilterDocumentsTest,
FilterableDocsFixtureMixin,
UpdateByFilterTest,
DeleteAllTest,
DeleteByFilterTest,
CountDocumentsByFilterTest,
CountUniqueMetadataByFilterTest,
GetMetadataFieldsInfoTest,
GetMetadataFieldMinMaxTest,
GetMetadataFieldUniqueValuesTest,
):
"""
Common test cases will be provided by `DocumentStoreBaseTests` but
you can add more to this class.
"""

@pytest.fixture
def document_store(self, embedding_function) -> ChromaDocumentStore:
    """
    Provide the document store instance the inherited test cases operate on.

    While the store is constructed, `get_embedding_function` is patched to return
    the `embedding_function` fixture. Each store is bound to its own randomly
    named collection.
    """
    patch_target = "haystack_integrations.document_stores.chroma.document_store.get_embedding_function"
    unique_collection = str(uuid.uuid1())

    # NOTE(review): the patch is only active while the constructor runs; this
    # assumes the store resolves its embedding function eagerly - confirm.
    with mock.patch(patch_target) as patched_lookup:
        patched_lookup.return_value = embedding_function
        store = ChromaDocumentStore(embedding_function="test_function", collection_name=unique_collection)
    return store

def assert_documents_are_equal(self, received: list[Document], expected: list[Document]):
    """
    Assert that two lists of Documents carry the same content and metadata.

    Both lists are ordered by `id` and compared pairwise. Only `content` and
    `meta` are checked, so fields the store may set on returned Documents
    (for example a score) do not affect the outcome.

    The comparison works on sorted copies, so the caller's lists keep their
    original order.

    :param received: Documents returned by the document store.
    :param expected: Documents the test expects to get back.
    :raises AssertionError: if a pair of Documents differs in content or meta.
    :raises ValueError: if the lists differ in length (`zip(..., strict=True)`).
    """
    by_id = operator.attrgetter("id")
    received_sorted = sorted(received, key=by_id)
    expected_sorted = sorted(expected, key=by_id)

    for doc_received, doc_expected in zip(received_sorted, expected_sorted, strict=True):
        assert doc_received.content == doc_expected.content
        assert doc_received.meta == doc_expected.meta

def test_client_settings_applied(self, clear_chroma_system_cache):
    """
    Check that `client_settings` given to the store end up on the Chroma client.

    Chroma's in-memory client is a singleton backed by an internal cache: after a
    client has been created with one set of settings, creating another with
    different settings in the same process is rejected. The
    `clear_chroma_system_cache` fixture clears that cache before and after this
    test so it does not clash with tests relying on the default settings.
    """
    requested_settings = {"anonymized_telemetry": False}
    store = ChromaDocumentStore(client_settings=requested_settings)
    store._ensure_initialized()

    effective_settings = store._client.get_settings()
    assert effective_settings.anonymized_telemetry is False

def test_distance_metric_initialization(self):
store = ChromaDocumentStore("test_2", distance_function="cosine")
store._ensure_initialized()
Expand Down Expand Up @@ -445,7 +448,6 @@ def test_comparison_less_than_equal_with_none(self, document_store, filterable_d
def test_not_operator(self, document_store, filterable_docs):
pass

@pytest.mark.integration
def test_search(self):
document_store = ChromaDocumentStore()
documents = [
Expand Down Expand Up @@ -491,7 +493,6 @@ def test_delete_all_documents_index_recreation(self, document_store: ChromaDocum
document_store.write_documents(docs)
assert document_store.count_documents() == 2

@pytest.mark.integration
def test_search_embeddings(self, document_store: ChromaDocumentStore):
query_embedding = TEST_EMBEDDING_1
documents = [
Expand All @@ -515,6 +516,7 @@ def test_search_embeddings(self, document_store: ChromaDocumentStore):
assert len(result_empty_filters[0]) == 2


@pytest.mark.integration
class TestMetadataOperations:
"""Test new metadata query operations for ChromaDocumentStore"""

Expand Down
1 change: 0 additions & 1 deletion integrations/chroma/tests/test_document_store_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ async def test_update_by_filter_async_no_matches(self, document_store: ChromaDoc
assert updated_count == 0
assert await document_store.count_documents_async() == 2

@pytest.mark.integration
async def test_search_embeddings_async(self, document_store: ChromaDocumentStore):
query_embedding = TEST_EMBEDDING_1
documents = [
Expand Down
61 changes: 31 additions & 30 deletions integrations/elasticsearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,37 @@ def test_client_initialization_with_api_key_string(_mock_async_es, _mock_es):
assert async_call_args[1]["api_key"] == "test_api_key"


@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_init_with_custom_mapping(mock_elasticsearch):
    """
    The index must be created with the user-supplied `custom_mapping`.

    The Elasticsearch client class is patched, so no server is contacted. The
    mocked `indices.exists` reports the index as missing, and the test expects a
    single `indices.create` call carrying the custom mapping.
    """
    keyword_template = {
        "strings": {
            "path_match": "*",
            "match_mapping_type": "string",
            "mapping": {
                "type": "keyword",
            },
        }
    }
    custom_mapping = {
        "properties": {
            "embedding": {"type": "dense_vector", "index": True, "similarity": "dot_product"},
            "content": {"type": "text"},
        },
        "dynamic_templates": [keyword_template],
    }

    indices_api = Mock(create=Mock(), exists=Mock(return_value=False))
    mock_elasticsearch.return_value = Mock(indices=indices_api)

    store = ElasticsearchDocumentStore(hosts="http://testhost:9200", custom_mapping=custom_mapping)
    # Only the side effect of accessing the property matters; its value is discarded.
    _ = store.client

    indices_api.create.assert_called_once_with(index="default", mappings=custom_mapping)


@pytest.mark.integration
class TestDocumentStore(
DocumentStoreBaseExtendedTests,
Expand Down Expand Up @@ -476,36 +507,6 @@ def test_write_documents_different_embedding_sizes_fail(self, document_store: El
with pytest.raises(DocumentStoreError):
document_store.write_documents(docs)

@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_init_with_custom_mapping(self, mock_elasticsearch):
    """
    The index must be created with the user-supplied `custom_mapping`.

    The Elasticsearch client class is patched, so no server is contacted. The
    mocked `indices.exists` reports the index as missing, and the test expects a
    single `indices.create` call carrying the custom mapping.
    """
    keyword_template = {
        "strings": {
            "path_match": "*",
            "match_mapping_type": "string",
            "mapping": {
                "type": "keyword",
            },
        }
    }
    custom_mapping = {
        "properties": {
            "embedding": {"type": "dense_vector", "index": True, "similarity": "dot_product"},
            "content": {"type": "text"},
        },
        "dynamic_templates": [keyword_template],
    }

    indices_api = Mock(create=Mock(), exists=Mock(return_value=False))
    mock_elasticsearch.return_value = Mock(indices=indices_api)

    store = ElasticsearchDocumentStore(hosts="http://testhost:9200", custom_mapping=custom_mapping)
    # Only the side effect of accessing the property matters; its value is discarded.
    _ = store.client

    indices_api.create.assert_called_once_with(index="default", mappings=custom_mapping)

def test_delete_all_documents_index_recreation(self, document_store: ElasticsearchDocumentStore):
# populate the index with some documents
docs = [Document(id="1", content="A first document"), Document(id="2", content="Second document")]
Expand Down
2 changes: 0 additions & 2 deletions integrations/mongodb_atlas/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ def test_get_metadata_field_unique_values(self, document_store: MongoDBAtlasDocu
assert len(values_page) == 1
assert values_page[0] in ["alpha", "beta", "gamma"]

@pytest.mark.integration
def test_custom_embedding_field(self):
"""Test that the custom embedding field is correctly used in the document store."""
# Create a document store with a custom embedding field
Expand Down Expand Up @@ -315,7 +314,6 @@ def test_custom_embedding_field(self):
finally:
database[collection_name].drop()

@pytest.mark.integration
def test_custom_content_field(self):
"""Test that the custom content field is correctly used in the document store."""
# Create a document store with a custom content field
Expand Down
75 changes: 41 additions & 34 deletions integrations/mongodb_atlas/tests/test_fulltext_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from time import sleep
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from haystack import Document
Expand All @@ -13,9 +13,11 @@
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore


def get_document_store(**kwargs):
def get_document_store(connection_string=None, **kwargs):
if connection_string is None:
connection_string = Secret.from_env_var("MONGO_CONNECTION_STRING_2")
return MongoDBAtlasDocumentStore(
mongo_connection_string=Secret.from_env_var("MONGO_CONNECTION_STRING_2"),
mongo_connection_string=connection_string,
database_name="haystack_test",
collection_name="test_collection",
vector_search_index="cosine_index",
Expand All @@ -24,36 +26,12 @@ def get_document_store(**kwargs):
)


@pytest.mark.skipif(
not os.environ.get("MONGO_CONNECTION_STRING_2"),
reason="No MongoDB Atlas connection string provided",
@patch(
"haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore._ensure_connection_setup"
)
@pytest.mark.integration
class TestFullTextRetrieval:
@pytest.fixture(scope="class")
def document_store(self) -> MongoDBAtlasDocumentStore:
    """Class-scoped store built by `get_document_store` with its default arguments."""
    store = get_document_store()
    return store

@pytest.fixture(autouse=True, scope="class")
def setup_teardown(self, document_store):
    """
    Reset the collection and seed it with four documents once per test class.

    Nothing follows the `yield`, so no cleanup is performed here.
    """
    seed_documents = [
        Document(content="The quick brown fox chased the dog", meta={"meta_field": "right_value"}),
        Document(content="The fox was brown", meta={"meta_field": "right_value"}),
        Document(content="The lazy dog"),
        Document(content="fox fox fox"),
    ]

    document_store._ensure_connection_setup()
    # Start from an empty collection so earlier runs cannot influence the results.
    document_store._collection.delete_many({})
    document_store.write_documents(seed_documents)

    # Give the search index time to pick up the freshly written documents.
    sleep(5)

    yield

def test_pipeline_correctly_passes_parameters(self, document_store):
document_store = get_document_store()
class TestFullTextRetrievalUnit:
def test_pipeline_correctly_passes_parameters(self, _mock_setup):
document_store = get_document_store(connection_string=Secret.from_token("test"))
mock_collection = MagicMock()
document_store._collection = mock_collection
mock_collection.aggregate.return_value = []
Expand Down Expand Up @@ -98,9 +76,9 @@ def test_pipeline_correctly_passes_parameters(self, document_store):
# Explicitly verify that the path in the text search is using the content_field
assert actual_pipeline[0]["$search"]["compound"]["must"][0]["text"]["path"] == document_store.content_field

def test_pipeline_with_custom_content_field(self, document_store):
def test_pipeline_with_custom_content_field(self, _mock_setup):
# Create a document store with a custom content field
document_store = get_document_store(content_field="custom_text")
document_store = get_document_store(connection_string=Secret.from_token("test"), content_field="custom_text")
mock_collection = MagicMock()
document_store._collection = mock_collection
mock_collection.aggregate.return_value = []
Expand All @@ -125,6 +103,35 @@ def test_pipeline_with_custom_content_field(self, document_store):
assert "$addFields" in actual_pipeline[3]
assert "$project" in actual_pipeline[4]


@pytest.mark.skipif(
not os.environ.get("MONGO_CONNECTION_STRING_2"),
reason="No MongoDB Atlas connection string provided",
)
@pytest.mark.integration
class TestFullTextRetrieval:
@pytest.fixture(scope="class")
def document_store(self) -> MongoDBAtlasDocumentStore:
    """Class-scoped store built by `get_document_store` with its default arguments."""
    store = get_document_store()
    return store

@pytest.fixture(autouse=True, scope="class")
def setup_teardown(self, document_store):
    """
    Reset the collection and seed it with four documents once per test class.

    Nothing follows the `yield`, so no cleanup is performed here.
    """
    seed_documents = [
        Document(content="The quick brown fox chased the dog", meta={"meta_field": "right_value"}),
        Document(content="The fox was brown", meta={"meta_field": "right_value"}),
        Document(content="The lazy dog"),
        Document(content="fox fox fox"),
    ]

    document_store._ensure_connection_setup()
    # Start from an empty collection so earlier runs cannot influence the results.
    document_store._collection.delete_many({})
    document_store.write_documents(seed_documents)

    # Give the search index time to pick up the freshly written documents.
    sleep(5)

    yield

def test_query_retrieval(self, document_store: MongoDBAtlasDocumentStore):
results = document_store._fulltext_retrieval(query="fox", top_k=2)
assert len(results) == 2
Expand Down
Loading
Loading