Skip to content

Commit aa90109

Browse files
authored
test: add unit tests for falkordb (#3283)
1 parent 891ebe2 commit aa90109

1 file changed

Lines changed: 224 additions & 0 deletions

File tree

integrations/falkordb/tests/test_document_store.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,25 @@
44

55
import logging
66
import os
7+
from types import SimpleNamespace
8+
from unittest.mock import MagicMock
79

10+
import falkordb as _falkordb_module
811
import pytest
912
from haystack.dataclasses import Document
13+
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
14+
from haystack.document_stores.types import DuplicatePolicy
15+
from haystack.errors import FilterError
1016
from haystack.testing.document_store import DocumentStoreBaseTests
1117

1218
from haystack_integrations.components.retrievers.falkordb import (
1319
FalkorDBCypherRetriever,
1420
FalkorDBEmbeddingRetriever,
1521
)
1622
from haystack_integrations.document_stores.falkordb import FalkorDBDocumentStore
23+
from haystack_integrations.document_stores.falkordb.document_store import (
24+
_convert_filters,
25+
)
1726

1827
logger = logging.getLogger(__name__)
1928

@@ -44,6 +53,221 @@ def test_to_dict_from_dict(self):
4453
assert restored.similarity == "euclidean"
4554

4655

56+
def _result(rows):
57+
return MagicMock(result_set=rows)
58+
59+
60+
@pytest.fixture
61+
def mock_falkordb(monkeypatch):
62+
constructor = MagicMock()
63+
client = MagicMock()
64+
graph = MagicMock()
65+
constructor.return_value = client
66+
client.select_graph.return_value = graph
67+
graph.query.return_value = _result([])
68+
monkeypatch.setattr(_falkordb_module, "FalkorDB", constructor)
69+
return constructor, client, graph
70+
71+
72+
class TestFalkorDBDocumentStoreUnit:
73+
def test_init_rejects_invalid_similarity(self):
74+
with pytest.raises(ValueError, match="not supported"):
75+
FalkorDBDocumentStore(similarity="invalid")
76+
77+
@pytest.mark.parametrize(
78+
"filter_node, expected_clause, expected_params",
79+
[
80+
({"field": "year", "operator": "==", "value": 2024}, "coalesce(d.year = $p0, false)", {"p0": 2024}),
81+
({"field": "year", "operator": "==", "value": None}, "d.year IS NULL", {}),
82+
({"field": "y", "operator": "!=", "value": None}, "d.y IS NOT NULL", {}),
83+
({"field": "y", "operator": ">", "value": 1}, "coalesce(d.y > $p0, false)", {"p0": 1}),
84+
({"field": "y", "operator": ">", "value": None}, "false", {}),
85+
(
86+
{"field": "tag", "operator": "in", "value": ["a"]},
87+
"coalesce(d.tag IN $p0, false)",
88+
{"p0": ["a"]},
89+
),
90+
(
91+
{"field": "tag", "operator": "not in", "value": ["a"]},
92+
"coalesce(NOT (d.tag IN $p0), true)",
93+
{"p0": ["a"]},
94+
),
95+
(
96+
{"field": "meta.year", "operator": "==", "value": 2024},
97+
"coalesce(d.year = $p0, false)",
98+
{"p0": 2024},
99+
),
100+
(
101+
{
102+
"operator": "OR",
103+
"conditions": [
104+
{"field": "y", "operator": "==", "value": 1},
105+
{"operator": "NOT", "conditions": [{"field": "b", "operator": "<=", "value": 0}]},
106+
],
107+
},
108+
"(coalesce(d.y = $p0, false) OR NOT (coalesce(d.b <= $p1, false)))",
109+
{"p0": 1, "p1": 0},
110+
),
111+
],
112+
)
113+
def test_convert_filters(self, filter_node, expected_clause, expected_params):
114+
clause, params = _convert_filters(filter_node)
115+
assert clause == expected_clause
116+
assert params == expected_params
117+
118+
@pytest.mark.parametrize(
119+
"filter_node, match",
120+
[
121+
({"operator": "AND"}, "requires a 'conditions' key"),
122+
({"operator": "NOT"}, "requires a 'conditions' key"),
123+
({"operator": "==", "value": 1}, "requires a 'field' key"),
124+
({"operator": "==", "field": "f"}, "requires a 'value' key"),
125+
({"field": "x", "operator": ">", "value": [1, 2]}, "does not support list values"),
126+
({"field": "x", "operator": ">", "value": "not-a-date"}, "non-ISO string"),
127+
({"field": "x", "operator": "in", "value": "scalar"}, "requires a list value"),
128+
({"field": "x", "operator": "not in", "value": "scalar"}, "requires a list value"),
129+
({"field": "x", "operator": "regex", "value": "."}, "Unsupported filter operator"),
130+
],
131+
)
132+
def test_convert_filters_errors(self, filter_node, match):
133+
with pytest.raises(FilterError, match=match):
134+
_convert_filters(filter_node)
135+
136+
@pytest.mark.parametrize("rows, expected", [([[42]], 42), ([], 0)])
137+
def test_count_documents(self, mock_falkordb, rows, expected):
138+
_, _, graph = mock_falkordb
139+
graph.query.side_effect = [_result([]), _result([]), _result(rows)]
140+
assert FalkorDBDocumentStore().count_documents() == expected
141+
142+
def test_filter_documents_no_filters(self, mock_falkordb):
143+
_, _, graph = mock_falkordb
144+
node = SimpleNamespace(properties={"id": "n1", "content": "hello"})
145+
graph.query.side_effect = [_result([]), _result([]), _result([[node]])]
146+
docs = FalkorDBDocumentStore().filter_documents()
147+
assert [d.content for d in docs] == ["hello"]
148+
assert "WHERE" not in graph.query.call_args_list[-1].args[0]
149+
150+
def test_filter_documents_with_filters(self, mock_falkordb):
151+
_, _, graph = mock_falkordb
152+
graph.query.side_effect = [_result([]), _result([]), _result([])]
153+
FalkorDBDocumentStore().filter_documents({"field": "year", "operator": "==", "value": 2024})
154+
last = graph.query.call_args_list[-1]
155+
assert "WHERE" in last.args[0]
156+
assert last.args[1] == {"p0": 2024}
157+
158+
@pytest.mark.usefixtures("mock_falkordb")
159+
def test_filter_documents_malformed_raises(self):
160+
with pytest.raises(FilterError, match="Invalid filter syntax"):
161+
FalkorDBDocumentStore().filter_documents({"field": "year", "value": 2024})
162+
163+
def test_delete_documents_empty_is_noop(self, mock_falkordb):
164+
_, _, graph = mock_falkordb
165+
FalkorDBDocumentStore().delete_documents([])
166+
assert graph.query.call_count == 2
167+
168+
def test_delete_documents_runs_query(self, mock_falkordb):
169+
_, _, graph = mock_falkordb
170+
FalkorDBDocumentStore().delete_documents(["a", "b"])
171+
last = graph.query.call_args_list[-1]
172+
assert "DETACH DELETE" in last.args[0]
173+
assert last.args[1] == {"ids": ["a", "b"]}
174+
175+
@pytest.mark.parametrize(
176+
"similarity, raw, scale_score, filters, expected_score, where_expected",
177+
[
178+
("cosine", 0.4, True, None, 0.8, False),
179+
("cosine", 0.4, False, None, 0.4, False),
180+
("euclidean", 1.0, True, {"field": "y", "operator": "==", "value": 1}, 0.5, True),
181+
],
182+
)
183+
def test_embedding_retrieval(
184+
self, mock_falkordb, similarity, raw, scale_score, filters, expected_score, where_expected
185+
):
186+
_, _, graph = mock_falkordb
187+
node = SimpleNamespace(properties={"id": "n1", "content": "hello"})
188+
graph.query.side_effect = [_result([]), _result([]), _result([[node, raw]])]
189+
docs = FalkorDBDocumentStore(similarity=similarity)._embedding_retrieval(
190+
query_embedding=[0.1], top_k=5, filters=filters, scale_score=scale_score
191+
)
192+
assert docs[0].score == pytest.approx(expected_score)
193+
assert ("WHERE" in graph.query.call_args_list[-1].args[0]) is where_expected
194+
195+
def test_cypher_retrieval_returns_documents(self, mock_falkordb):
196+
_, _, graph = mock_falkordb
197+
node = SimpleNamespace(properties={"id": "n1", "content": "hello"})
198+
graph.query.side_effect = [_result([]), _result([]), _result([[node]])]
199+
docs = FalkorDBDocumentStore()._cypher_retrieval("MATCH (d) RETURN d", parameters={"k": 1})
200+
assert docs[0].content == "hello"
201+
202+
@pytest.mark.usefixtures("mock_falkordb")
203+
def test_write_documents_rejects_non_documents(self):
204+
with pytest.raises(ValueError, match="expects a list of Documents"):
205+
FalkorDBDocumentStore().write_documents(["not a doc"])
206+
207+
@pytest.mark.usefixtures("mock_falkordb")
208+
def test_write_documents_empty_list_returns_zero(self, caplog):
209+
with caplog.at_level(logging.WARNING):
210+
assert FalkorDBDocumentStore().write_documents([]) == 0
211+
assert "empty list" in caplog.text
212+
213+
def test_write_documents_policy_none_coerced_to_fail(self, mock_falkordb):
214+
_, _, graph = mock_falkordb
215+
graph.query.side_effect = [_result([]), _result([]), _result([["a"]])]
216+
with pytest.raises(DuplicateDocumentError):
217+
FalkorDBDocumentStore().write_documents([Document(id="a", content="x")])
218+
219+
def test_write_documents_drops_duplicates_within_batch(self, mock_falkordb, caplog):
220+
_, _, graph = mock_falkordb
221+
graph.query.side_effect = [_result([]), _result([]), _result([]), _result([[2]])]
222+
with caplog.at_level(logging.INFO):
223+
written = FalkorDBDocumentStore().write_documents(
224+
[
225+
Document(id="a", content="x"),
226+
Document(id="a", content="x"),
227+
Document(id="b", content="y"),
228+
],
229+
policy=DuplicatePolicy.SKIP,
230+
)
231+
assert written == 2
232+
assert "already present in the batch" in caplog.text
233+
sent_ids = [d["id"] for d in graph.query.call_args_list[-1].args[1]["docs"]]
234+
assert sent_ids == ["a", "b"]
235+
236+
def test_write_documents_skip_filters_existing(self, mock_falkordb):
237+
_, _, graph = mock_falkordb
238+
graph.query.side_effect = [_result([]), _result([]), _result([["a"]]), _result([[1]])]
239+
written = FalkorDBDocumentStore().write_documents(
240+
[Document(id="a", content="x"), Document(id="b", content="y")],
241+
policy=DuplicatePolicy.SKIP,
242+
)
243+
assert written == 1
244+
sent_ids = [d["id"] for d in graph.query.call_args_list[-1].args[1]["docs"]]
245+
assert sent_ids == ["b"]
246+
247+
def test_write_documents_overwrite_uses_on_match_set(self, mock_falkordb):
248+
_, _, graph = mock_falkordb
249+
graph.query.side_effect = [_result([]), _result([]), _result([[1]])]
250+
FalkorDBDocumentStore().write_documents([Document(id="a", content="x")], policy=DuplicatePolicy.OVERWRITE)
251+
assert "ON MATCH SET d = doc" in graph.query.call_args_list[-1].args[0]
252+
253+
def test_write_documents_embeddings_second_pass(self, mock_falkordb):
254+
_, _, graph = mock_falkordb
255+
graph.query.side_effect = [_result([]), _result([]), _result([[1]]), _result([])]
256+
FalkorDBDocumentStore().write_documents(
257+
[Document(id="a", content="x", embedding=[0.1, 0.2, 0.3])],
258+
policy=DuplicatePolicy.OVERWRITE,
259+
)
260+
last = graph.query.call_args_list[-1]
261+
assert "vecf32" in last.args[0]
262+
assert last.args[1]["docs"] == [{"id": "a", "emb": [0.1, 0.2, 0.3]}]
263+
264+
def test_write_documents_wraps_errors(self, mock_falkordb):
265+
_, _, graph = mock_falkordb
266+
graph.query.side_effect = [_result([]), _result([]), Exception("boom")]
267+
with pytest.raises(DocumentStoreError, match="Failed to write documents"):
268+
FalkorDBDocumentStore().write_documents([Document(id="a", content="x")], policy=DuplicatePolicy.OVERWRITE)
269+
270+
47271
@pytest.mark.integration
48272
class TestDocumentStore(DocumentStoreBaseTests):
49273
"""

0 commit comments

Comments
 (0)