diff --git a/integrations/qdrant/tests/test_converters.py b/integrations/qdrant/tests/test_converters.py index fc6ba784f2..4a96cb46bb 100644 --- a/integrations/qdrant/tests/test_converters.py +++ b/integrations/qdrant/tests/test_converters.py @@ -1,7 +1,11 @@ import numpy as np +from haystack.dataclasses import Document, SparseEmbedding from qdrant_client.http import models as rest from haystack_integrations.document_stores.qdrant.converters import ( + DENSE_VECTORS_NAME, + SPARSE_VECTORS_NAME, + convert_haystack_documents_to_qdrant_points, convert_id, convert_qdrant_point_to_haystack_document, ) @@ -60,3 +64,63 @@ def test_point_to_document_reverts_proper_structure_from_record_without_sparse() assert document.sparse_embedding is None assert {"test_field": 1} == document.meta assert 0.0 == np.sum(np.array([1.0, 0.0, 0.0, 0.0]) - document.embedding) + + +def test_point_to_document_with_sparse_enabled_but_vector_none(): + point = rest.Record( + id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7", + payload={"id": "my-id", "content": "Lorem"}, + vector=None, + ) + document = convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=True) + assert document.embedding is None + assert document.sparse_embedding is None + + +def test_point_to_document_preserves_score_from_scored_point(): + point = rest.ScoredPoint( + id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7", + payload={"id": "my-id", "content": "Lorem"}, + vector=[0.1, 0.2], + score=0.75, + version=0, + ) + document = convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=False) + assert document.score == 0.75 + + +def test_convert_haystack_documents_to_qdrant_points_without_sparse(): + doc = Document(content="hello", embedding=[0.1, 0.2, 0.3]) + points = convert_haystack_documents_to_qdrant_points([doc], use_sparse_embeddings=False) + assert len(points) == 1 + assert points[0].vector == [0.1, 0.2, 0.3] + assert points[0].payload["content"] == "hello" + assert "embedding" not in points[0].payload + + +def test_convert_haystack_documents_to_qdrant_points_without_sparse_without_embedding(): + doc = Document(content="hello") + points = convert_haystack_documents_to_qdrant_points([doc], use_sparse_embeddings=False) + assert points[0].vector == {} + + +def test_convert_haystack_documents_to_qdrant_points_with_sparse(): + sparse = SparseEmbedding(indices=[0, 5], values=[0.1, 0.7]) + doc = Document(content="hello", embedding=[0.1, 0.2], sparse_embedding=sparse) + points = convert_haystack_documents_to_qdrant_points([doc], use_sparse_embeddings=True) + assert points[0].vector[DENSE_VECTORS_NAME] == [0.1, 0.2] + assert isinstance(points[0].vector[SPARSE_VECTORS_NAME], rest.SparseVector) + assert points[0].vector[SPARSE_VECTORS_NAME].indices == [0, 5] + assert points[0].vector[SPARSE_VECTORS_NAME].values == [0.1, 0.7] + + +def test_convert_haystack_documents_to_qdrant_points_with_sparse_only_dense(): + doc = Document(content="hello", embedding=[0.1, 0.2]) + points = convert_haystack_documents_to_qdrant_points([doc], use_sparse_embeddings=True) + assert points[0].vector == {DENSE_VECTORS_NAME: [0.1, 0.2]} + + +def test_convert_haystack_documents_to_qdrant_points_with_sparse_no_vectors(): + doc = Document(content="hello") + points = convert_haystack_documents_to_qdrant_points([doc], use_sparse_embeddings=True) + assert points[0].vector == {} diff --git a/integrations/qdrant/tests/test_document_store.py b/integrations/qdrant/tests/test_document_store.py index 170e036a19..35ea659101 100644 --- a/integrations/qdrant/tests/test_document_store.py +++ b/integrations/qdrant/tests/test_document_store.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -28,6 +29,7 @@ SPARSE_VECTORS_NAME, QdrantDocumentStore, QdrantStoreError, + get_batches_from_generator, ) @@ -180,6 +182,166 @@ def test_set_up_collection_with_dimension_mismatch(self): with pytest.raises(ValueError, match="different vector size"): document_store._set_up_collection("test_collection", 768, False, "cosine", False, False) + def test_get_distance_known(self): + document_store = QdrantDocumentStore(location=":memory:") + assert document_store.get_distance("cosine") == rest.Distance.COSINE + assert document_store.get_distance("dot_product") == rest.Distance.DOT + assert document_store.get_distance("l2") == rest.Distance.EUCLID + + def test_get_distance_unknown_raises(self): + document_store = QdrantDocumentStore(location=":memory:") + with pytest.raises(QdrantStoreError, match="not supported"): + document_store.get_distance("unknown") + + def test_validate_filters_accepts_dict_and_native(self): + QdrantDocumentStore._validate_filters(None) + QdrantDocumentStore._validate_filters({"operator": "==", "field": "meta.x", "value": 1}) + QdrantDocumentStore._validate_filters(rest.Filter(must=[])) + + def test_validate_filters_rejects_non_dict_non_filter(self): + with pytest.raises(ValueError, match="must be a dictionary"): + QdrantDocumentStore._validate_filters("not-a-filter") + + def test_validate_filters_rejects_dict_without_operator(self): + with pytest.raises(ValueError, match="Invalid filter syntax"): + QdrantDocumentStore._validate_filters({"field": "meta.x"}) + + def test_check_stop_scrolling(self): + assert QdrantDocumentStore._check_stop_scrolling(None) is True + empty_offset = SimpleNamespace(num=0, uuid="") + assert QdrantDocumentStore._check_stop_scrolling(empty_offset) is True + non_empty_offset = SimpleNamespace(num=5, uuid="abc") + assert QdrantDocumentStore._check_stop_scrolling(non_empty_offset) is False + + def test_infer_type_from_value(self): + assert QdrantDocumentStore._infer_type_from_value(True) == "boolean" + assert QdrantDocumentStore._infer_type_from_value(1) == "long" + assert QdrantDocumentStore._infer_type_from_value(1.5) == "float" + assert QdrantDocumentStore._infer_type_from_value("x") == "keyword" + assert QdrantDocumentStore._infer_type_from_value([1, 2]) == "keyword" + + def test_process_records_fields_info(self): + records = [ + SimpleNamespace(payload={"meta": {"category": "A", "score": 0.9, "missing": None}}), + SimpleNamespace(payload={"meta": {"category": "B"}}), # category already seen + SimpleNamespace(payload=None), # no payload + SimpleNamespace(payload={"other": "noise"}), # no meta + ] + field_info: dict = {} + QdrantDocumentStore._process_records_fields_info(records, field_info) + assert field_info == {"category": {"type": "keyword"}, "score": {"type": "float"}} + + def test_metadata_fields_info_from_schema(self): + schema = { + "meta.category": SimpleNamespace(data_type="keyword"), + "meta.priority": SimpleNamespace(data_type="integer"), + "meta.unknown": object(), # no data_type attribute + "not_meta_prefixed": SimpleNamespace(data_type="keyword"), + } + fields = QdrantDocumentStore._metadata_fields_info_from_schema(schema) + assert fields == { + "category": {"type": "keyword"}, + "priority": {"type": "integer"}, + "unknown": {"type": "unknown"}, + } + + def test_process_records_min_max(self): + records = [ + SimpleNamespace(payload={"meta": {"score": 0.5}}), + SimpleNamespace(payload={"meta": {"score": 0.9}}), + SimpleNamespace(payload={"meta": {"score": None}}), + SimpleNamespace(payload={"meta": {"other": 100}}), + SimpleNamespace(payload=None), + ] + min_v, max_v = QdrantDocumentStore._process_records_min_max(records, "score", None, None) + assert min_v == 0.5 + assert max_v == 0.9 + + def test_process_records_count_unique(self): + records = [ + SimpleNamespace(payload={"meta": {"category": "A", "tags": ["x"]}}), + SimpleNamespace(payload={"meta": {"category": "B", "tags": ["x"]}}), + SimpleNamespace(payload={"meta": {"category": "A", "tags": ["y"]}}), + SimpleNamespace(payload=None), + ] + unique: dict = {"category": set(), "tags": set()} + QdrantDocumentStore._process_records_count_unique(records, ["category", "tags"], unique) + assert unique["category"] == {"A", "B"} + assert unique["tags"] == {"['x']", "['y']"} + + def test_process_records_unique_values_stops_when_filled(self): + records = [SimpleNamespace(payload={"meta": {"v": i}}) for i in range(10)] + values: list = [] + values_set: set = set() + done = QdrantDocumentStore._process_records_unique_values(records, "v", values, values_set, offset=0, limit=3) + assert done is True + assert values[:3] == [0, 1, 2] + + def test_process_records_unique_values_not_done(self): + records = [SimpleNamespace(payload={"meta": {"v": 1}}), SimpleNamespace(payload=None)] + values: list = [] + values_set: set = set() + done = QdrantDocumentStore._process_records_unique_values(records, "v", values, values_set, offset=0, limit=5) + assert done is False + assert values == [1] + + def test_create_updated_point_from_record_adds_missing_meta(self): + record = SimpleNamespace( + id="abc", + payload={"content": "hello"}, + vector=[0.1, 0.2], + ) + point = QdrantDocumentStore._create_updated_point_from_record(record, {"status": "published"}) + assert point.payload["meta"] == {"status": "published"} + assert point.payload["content"] == "hello" + assert point.vector == [0.1, 0.2] + + def test_create_updated_point_from_record_merges_meta(self): + record = SimpleNamespace( + id="abc", + payload={"content": "hello", "meta": {"category": "A"}}, + vector=None, + ) + point = QdrantDocumentStore._create_updated_point_from_record(record, {"status": "published"}) + assert point.payload["meta"] == {"category": "A", "status": "published"} + assert point.vector == {} + + def test_drop_duplicate_documents(self): + document_store = QdrantDocumentStore(location=":memory:") + doc1 = Document(id="1", content="a") + doc2 = Document(id="2", content="b") + doc1_dup = Document(id="1", content="a") + result = document_store._drop_duplicate_documents([doc1, doc2, doc1_dup]) + assert [d.id for d in result] == ["1", "2"] + + def test_prepare_collection_config_without_sparse(self): + document_store = QdrantDocumentStore(location=":memory:", use_sparse_embeddings=False) + vectors_config, sparse_config = document_store._prepare_collection_config( + embedding_dim=768, distance=rest.Distance.COSINE + ) + assert isinstance(vectors_config, rest.VectorParams) + assert sparse_config is None + + def test_prepare_collection_config_with_sparse_and_idf(self): + document_store = QdrantDocumentStore(location=":memory:", use_sparse_embeddings=True) + vectors_config, sparse_config = document_store._prepare_collection_config( + embedding_dim=768, distance=rest.Distance.COSINE, sparse_idf=True + ) + assert DENSE_VECTORS_NAME in vectors_config + assert sparse_config[SPARSE_VECTORS_NAME].modifier == rest.Modifier.IDF + + def test_prepare_client_params_does_not_mutate_metadata(self): + metadata = {"key": "value"} + document_store = QdrantDocumentStore(location=":memory:", metadata=metadata) + params = document_store._prepare_client_params() + params["metadata"]["added"] = "x" + assert metadata == {"key": "value"} + + def test_get_batches_from_generator(self): + batches = list(get_batches_from_generator([1, 2, 3, 4, 5], 2)) + assert batches == [(1, 2), (3, 4), (5,)] + assert list(get_batches_from_generator([], 2)) == [] + @pytest.mark.integration class TestQdrantDocumentStore( diff --git a/integrations/qdrant/tests/test_embedding_retriever.py b/integrations/qdrant/tests/test_embedding_retriever.py index 59cc8c3f21..b2e9c9206b 100644 --- a/integrations/qdrant/tests/test_embedding_retriever.py +++ b/integrations/qdrant/tests/test_embedding_retriever.py @@ -1,3 +1,6 @@ +from dataclasses import replace +from unittest.mock import AsyncMock, Mock + import pytest from haystack.dataclasses import Document from haystack.document_stores.types import FilterPolicy @@ -5,6 +8,7 @@ FilterableDocsFixtureMixin, _random_embeddings, ) +from qdrant_client.http import models as rest from haystack_integrations.components.retrievers.qdrant import ( QdrantEmbeddingRetriever, @@ -13,6 +17,10 @@ class TestQdrantRetriever: + def test_init_raises_when_document_store_is_not_qdrant(self): + with pytest.raises(ValueError, match="must be an instance of QdrantDocumentStore"): + QdrantEmbeddingRetriever(document_store="not a document store") + def test_init_default(self): document_store = QdrantDocumentStore(location=":memory:", index="test", use_sparse_embeddings=False) retriever = QdrantEmbeddingRetriever(document_store=document_store) @@ -118,6 +126,67 @@ def test_from_dict(self): assert retriever._group_by is None assert retriever._group_size is None + def test_run(self): + mock_store = Mock(spec=QdrantDocumentStore) + mock_store._query_by_embedding.return_value = [Document(content="doc", embedding=[0.1, 0.2])] + + retriever = QdrantEmbeddingRetriever( + document_store=mock_store, + filters={"field": "meta.name", "operator": "==", "value": "foo"}, + top_k=7, + scale_score=True, + return_embedding=True, + score_threshold=0.2, + ) + res = retriever.run(query_embedding=[0.5, 0.7], top_k=3) + + call_kwargs = mock_store._query_by_embedding.call_args.kwargs + assert call_kwargs["query_embedding"] == [0.5, 0.7] + assert call_kwargs["top_k"] == 3 + assert call_kwargs["scale_score"] is True + assert call_kwargs["return_embedding"] is True + assert call_kwargs["score_threshold"] == 0.2 + assert call_kwargs["filters"] == {"field": "meta.name", "operator": "==", "value": "foo"} + assert res["documents"][0].content == "doc" + + @pytest.mark.asyncio + async def test_run_async(self): + mock_store = Mock(spec=QdrantDocumentStore) + mock_store._query_by_embedding_async = AsyncMock(return_value=[Document(content="doc", embedding=[0.1])]) + + retriever = QdrantEmbeddingRetriever(document_store=mock_store) + res = await retriever.run_async(query_embedding=[0.5]) + + mock_store._query_by_embedding_async.assert_awaited_once() + assert res["documents"][0].content == "doc" + + def test_run_raises_when_merge_with_native_init_filter(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantEmbeddingRetriever( + document_store=document_store, + filters=rest.Filter(must=[]), + filter_policy=FilterPolicy.MERGE, + ) + with pytest.raises(ValueError, match="Native Qdrant filters"): + retriever.run(query_embedding=[0.1]) + + def test_run_raises_when_merge_with_native_runtime_filter(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantEmbeddingRetriever(document_store=document_store, filter_policy=FilterPolicy.MERGE) + with pytest.raises(ValueError, match="Native Qdrant filters"): + retriever.run(query_embedding=[0.1], filters=rest.Filter(must=[])) + + @pytest.mark.asyncio + async def test_run_async_raises_when_merge_with_native_filter(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantEmbeddingRetriever( + document_store=document_store, + filters=rest.Filter(must=[]), + filter_policy=FilterPolicy.MERGE, + ) + with pytest.raises(ValueError, match="Native Qdrant filters"): + await retriever.run_async(query_embedding=[0.1]) + @pytest.mark.integration class TestQdrantEmbeddingRetrieverIntegration(FilterableDocsFixtureMixin): @@ -208,8 +277,10 @@ def test_run_with_sparse_activated(self, filterable_docs: list[Document]): def test_run_with_group_by(self, filterable_docs: list[Document]): document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) # Add group_field metadata to documents - for index, doc in enumerate(filterable_docs): - doc.meta = {"group_field": index // 2} # So at least two docs have same group each time + filterable_docs = [ + replace(doc, meta={"group_field": index // 2}) # So at least two docs have same group each time + for index, doc in enumerate(filterable_docs) + ] document_store.write_documents(filterable_docs) retriever = QdrantEmbeddingRetriever(document_store=document_store) @@ -315,8 +386,10 @@ async def test_run_with_sparse_activated_async(self, filterable_docs: list[Docum async def test_run_with_group_by_async(self, filterable_docs: list[Document]): document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) # Add group_field metadata to documents - for index, doc in enumerate(filterable_docs): - doc.meta = {"group_field": index // 2} # So at least two docs have same group each time + filterable_docs = [ + replace(doc, meta={"group_field": index // 2}) # So at least two docs have same group each time + for index, doc in enumerate(filterable_docs) + ] await document_store.write_documents_async(filterable_docs) retriever = QdrantEmbeddingRetriever(document_store=document_store) diff --git a/integrations/qdrant/tests/test_filters.py b/integrations/qdrant/tests/test_filters.py index 0470157e4d..dc1ca93312 100644 --- a/integrations/qdrant/tests/test_filters.py +++ b/integrations/qdrant/tests/test_filters.py @@ -5,6 +5,84 @@ from qdrant_client.http import models from haystack_integrations.document_stores.qdrant import QdrantDocumentStore +from haystack_integrations.document_stores.qdrant.filters import ( + convert_filters_to_qdrant, + is_datetime_string, +) + + +class TestConvertFiltersToQdrantUnit: + def test_native_filter_passthrough(self): + native = models.Filter(must=[]) + assert convert_filters_to_qdrant(native) is native + + def test_none_or_empty_returns_none(self): + assert convert_filters_to_qdrant(None) is None + assert convert_filters_to_qdrant({}) is None + assert convert_filters_to_qdrant([]) is None + + @pytest.mark.parametrize( + ("filter_input", "match"), + [ + ({"field": "meta.x", "value": 1}, "Operator not found"), + ({"operator": "~=", "field": "meta.x", "value": 1}, "Unknown operator"), + ({"operator": "AND"}, "'conditions' not found"), + ({"operator": "==", "value": 1}, "'field' or 'value' not found"), + ({"operator": "==", "field": "meta.x"}, "'field' or 'value' not found"), + ({"operator": "in", "field": "meta.x", "value": "not-a-list"}, "is not a list"), + ({"operator": "not in", "field": "meta.x", "value": "not-a-list"}, "is not a list"), + ], + ) + def test_invalid_filter_raises(self, filter_input, match): + with pytest.raises(FilterError, match=match): + convert_filters_to_qdrant(filter_input) + + @pytest.mark.parametrize("operator", ["<", "<=", ">", ">="]) + def test_range_operators_reject_non_numeric_non_datetime(self, operator): + with pytest.raises(FilterError, match="not an int or float or datetime string"): + convert_filters_to_qdrant({"operator": operator, "field": "meta.x", "value": "not-a-date"}) + + @pytest.mark.parametrize("operator", ["<", "<=", ">", ">="]) + def test_range_operators_accept_datetime_strings(self, operator): + qdrant_filter = convert_filters_to_qdrant( + {"operator": operator, "field": "meta.created_at", "value": "2024-01-01T00:00:00"} + ) + assert isinstance(qdrant_filter, models.Filter) + + def test_eq_with_spaces_uses_text_match(self): + qdrant_filter = convert_filters_to_qdrant({"operator": "==", "field": "meta.title", "value": "hello world"}) + condition = qdrant_filter.must[0] + assert isinstance(condition.match, models.MatchText) + + def test_eq_without_spaces_uses_value_match(self): + qdrant_filter = convert_filters_to_qdrant({"operator": "==", "field": "meta.name", "value": "name_0"}) + condition = qdrant_filter.must[0] + assert isinstance(condition.match, models.MatchValue) + + def test_single_logical_condition_unwrapped(self): + qdrant_filter = convert_filters_to_qdrant( + { + "operator": "AND", + "conditions": [{"operator": "==", "field": "meta.x", "value": 1}], + } + ) + # AND of single condition gets returned directly as the inner Filter + assert isinstance(qdrant_filter, models.Filter) + + def test_multiple_top_level_conditions_combined_with_and(self): + qdrant_filter = convert_filters_to_qdrant( + [ + {"operator": "==", "field": "meta.a", "value": 1}, + {"operator": "==", "field": "meta.b", "value": 2}, + ] + ) + assert isinstance(qdrant_filter, models.Filter) + assert len(qdrant_filter.must) == 2 + + +def test_is_datetime_string(): + assert is_datetime_string("2024-01-01T00:00:00") is True + assert is_datetime_string("not a date") is False @pytest.mark.integration diff --git a/integrations/qdrant/tests/test_hybrid_retriever.py b/integrations/qdrant/tests/test_hybrid_retriever.py index 3deaf91dd2..e16fc00f14 100644 --- a/integrations/qdrant/tests/test_hybrid_retriever.py +++ b/integrations/qdrant/tests/test_hybrid_retriever.py @@ -3,6 +3,7 @@ import pytest from haystack.dataclasses import Document, SparseEmbedding from haystack.document_stores.types import FilterPolicy +from qdrant_client.http import models as rest from haystack_integrations.components.retrievers.qdrant import ( QdrantHybridRetriever, @@ -11,6 +12,10 @@ class TestQdrantHybridRetriever: + def test_init_raises_when_document_store_is_not_qdrant(self): + with pytest.raises(ValueError, match="must be an instance of QdrantDocumentStore"): + QdrantHybridRetriever(document_store="not a document store") + def test_init_default(self): document_store = QdrantDocumentStore(location=":memory:", index="test", use_sparse_embeddings=True) retriever = QdrantHybridRetriever(document_store=document_store) @@ -244,3 +249,24 @@ async def test_run_with_group_by_async(self): assert result["documents"][0].content == "Test doc" assert result["documents"][0].embedding == [0.1, 0.2] assert result["documents"][0].sparse_embedding == sparse_embedding + + def test_run_raises_when_merge_with_native_filter(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantHybridRetriever( + document_store=document_store, + filters=rest.Filter(must=[]), + filter_policy=FilterPolicy.MERGE, + ) + sparse = SparseEmbedding(indices=[0], values=[0.1]) + with pytest.raises(ValueError, match="Native Qdrant filters"): + retriever.run(query_embedding=[0.1], query_sparse_embedding=sparse) + + @pytest.mark.asyncio + async def test_run_async_raises_when_merge_with_native_filter(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantHybridRetriever(document_store=document_store, filter_policy=FilterPolicy.MERGE) + sparse = SparseEmbedding(indices=[0], values=[0.1]) + with pytest.raises(ValueError, match="Native Qdrant filters"): + await retriever.run_async( + query_embedding=[0.1], query_sparse_embedding=sparse, filters=rest.Filter(must=[]) + ) diff --git a/integrations/qdrant/tests/test_sparse_embedding_retriever.py b/integrations/qdrant/tests/test_sparse_embedding_retriever.py index 6c1a4ff041..b0e38e061d 100644 --- a/integrations/qdrant/tests/test_sparse_embedding_retriever.py +++ b/integrations/qdrant/tests/test_sparse_embedding_retriever.py @@ -1,9 +1,13 @@ +from dataclasses import replace +from unittest.mock import AsyncMock, Mock + import pytest from haystack.dataclasses import Document, SparseEmbedding from haystack.document_stores.types import FilterPolicy from haystack.testing.document_store import ( FilterableDocsFixtureMixin, ) +from qdrant_client.http import models as rest from haystack_integrations.components.retrievers.qdrant import ( QdrantSparseEmbeddingRetriever, @@ -12,6 +16,10 @@ class TestQdrantSparseEmbeddingRetriever: + def test_init_raises_when_document_store_is_not_qdrant(self): + with pytest.raises(ValueError, match="must be an instance of QdrantDocumentStore"): + QdrantSparseEmbeddingRetriever(document_store="not a document store") + def test_init_default(self): document_store = QdrantDocumentStore(location=":memory:", index="test") retriever = QdrantSparseEmbeddingRetriever(document_store=document_store) @@ -146,6 +154,50 @@ def test_from_dict_no_filter_policy(self): assert retriever._group_by is None assert retriever._group_size is None + def test_run(self): + mock_store = Mock(spec=QdrantDocumentStore) + sparse = SparseEmbedding(indices=[0, 5], values=[0.1, 0.7]) + mock_store._query_by_sparse.return_value = [Document(content="doc", sparse_embedding=sparse)] + + retriever = QdrantSparseEmbeddingRetriever(document_store=mock_store) + res = retriever.run(query_sparse_embedding=sparse, top_k=4) + + call_kwargs = mock_store._query_by_sparse.call_args.kwargs + assert call_kwargs["query_sparse_embedding"] == sparse + assert call_kwargs["top_k"] == 4 + assert res["documents"][0].content == "doc" + + @pytest.mark.asyncio + async def test_run_async(self): + mock_store = Mock(spec=QdrantDocumentStore) + sparse = SparseEmbedding(indices=[0, 5], values=[0.1, 0.7]) + mock_store._query_by_sparse_async = AsyncMock(return_value=[Document(content="doc", sparse_embedding=sparse)]) + + retriever = QdrantSparseEmbeddingRetriever(document_store=mock_store) + res = await retriever.run_async(query_sparse_embedding=sparse) + + mock_store._query_by_sparse_async.assert_awaited_once() + assert res["documents"][0].content == "doc" + + def test_run_raises_when_merge_with_native_filter(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantSparseEmbeddingRetriever( + document_store=document_store, + filters=rest.Filter(must=[]), + filter_policy=FilterPolicy.MERGE, + ) + sparse = SparseEmbedding(indices=[0], values=[0.1]) + with pytest.raises(ValueError, match="Native Qdrant filters"): + retriever.run(query_sparse_embedding=sparse) + + @pytest.mark.asyncio + async def test_run_async_raises_when_merge_with_native_filter(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantSparseEmbeddingRetriever(document_store=document_store, filter_policy=FilterPolicy.MERGE) + sparse = SparseEmbedding(indices=[0], values=[0.1]) + with pytest.raises(ValueError, match="Native Qdrant filters"): + await retriever.run_async(query_sparse_embedding=sparse, filters=rest.Filter(must=[])) + @pytest.mark.integration class TestQdrantSparseEmbeddingRetrieverIntegration(FilterableDocsFixtureMixin): @@ -153,8 +205,7 @@ def test_run(self, filterable_docs: list[Document], generate_sparse_embedding): document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) # Add fake sparse embedding to documents - for doc in filterable_docs: - doc.sparse_embedding = generate_sparse_embedding() + filterable_docs = [replace(doc, sparse_embedding=generate_sparse_embedding()) for doc in filterable_docs] document_store.write_documents(filterable_docs) retriever = QdrantSparseEmbeddingRetriever(document_store=document_store) @@ -173,9 +224,14 @@ def test_run_with_group_by(self, filterable_docs: list[Document], generate_spars document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) # Add fake sparse embedding to documents - for index, doc in enumerate(filterable_docs): - doc.sparse_embedding = generate_sparse_embedding() - doc.meta = {"group_field": index // 2} # So at least two docs have same group each time + filterable_docs = [ + replace( + doc, + sparse_embedding=generate_sparse_embedding(), + meta={"group_field": index // 2}, # So at least two docs have same group each time + ) + for index, doc in enumerate(filterable_docs) + ] document_store.write_documents(filterable_docs) retriever = QdrantSparseEmbeddingRetriever(document_store=document_store) sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) @@ -197,8 +253,7 @@ async def test_run_async(self, filterable_docs: list[Document], generate_sparse_ document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) # Add fake sparse embedding to documents - for doc in filterable_docs: - doc.sparse_embedding = generate_sparse_embedding() + filterable_docs = [replace(doc, sparse_embedding=generate_sparse_embedding()) for doc in filterable_docs] await document_store.write_documents_async(filterable_docs) retriever = QdrantSparseEmbeddingRetriever(document_store=document_store) @@ -218,9 +273,14 @@ async def test_run_with_group_by_async(self, filterable_docs: list[Document], ge document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) # Add fake sparse embedding to documents - for index, doc in enumerate(filterable_docs): - doc.sparse_embedding = generate_sparse_embedding() - doc.meta = {"group_field": index // 2} # So at least two docs have same group each time + filterable_docs = [ + replace( + doc, + sparse_embedding=generate_sparse_embedding(), + meta={"group_field": index // 2}, # So at least two docs have same group each time + ) + for index, doc in enumerate(filterable_docs) + ] await document_store.write_documents_async(filterable_docs) retriever = QdrantSparseEmbeddingRetriever(document_store=document_store) sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33])