|
4 | 4 |
|
5 | 5 | import logging |
6 | 6 | import os |
| 7 | +from types import SimpleNamespace |
| 8 | +from unittest.mock import MagicMock |
7 | 9 |
|
| 10 | +import falkordb as _falkordb_module |
8 | 11 | import pytest |
9 | 12 | from haystack.dataclasses import Document |
| 13 | +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError |
| 14 | +from haystack.document_stores.types import DuplicatePolicy |
| 15 | +from haystack.errors import FilterError |
10 | 16 | from haystack.testing.document_store import DocumentStoreBaseTests |
11 | 17 |
|
12 | 18 | from haystack_integrations.components.retrievers.falkordb import ( |
13 | 19 | FalkorDBCypherRetriever, |
14 | 20 | FalkorDBEmbeddingRetriever, |
15 | 21 | ) |
16 | 22 | from haystack_integrations.document_stores.falkordb import FalkorDBDocumentStore |
| 23 | +from haystack_integrations.document_stores.falkordb.document_store import ( |
| 24 | + _convert_filters, |
| 25 | +) |
17 | 26 |
|
18 | 27 | logger = logging.getLogger(__name__) |
19 | 28 |
|
@@ -44,6 +53,221 @@ def test_to_dict_from_dict(self): |
44 | 53 | assert restored.similarity == "euclidean" |
45 | 54 |
|
46 | 55 |
|
| 56 | +def _result(rows): |
| 57 | + return MagicMock(result_set=rows) |
| 58 | + |
| 59 | + |
| 60 | +@pytest.fixture |
| 61 | +def mock_falkordb(monkeypatch): |
| 62 | + constructor = MagicMock() |
| 63 | + client = MagicMock() |
| 64 | + graph = MagicMock() |
| 65 | + constructor.return_value = client |
| 66 | + client.select_graph.return_value = graph |
| 67 | + graph.query.return_value = _result([]) |
| 68 | + monkeypatch.setattr(_falkordb_module, "FalkorDB", constructor) |
| 69 | + return constructor, client, graph |
| 70 | + |
| 71 | + |
| 72 | +class TestFalkorDBDocumentStoreUnit: |
| 73 | + def test_init_rejects_invalid_similarity(self): |
| 74 | + with pytest.raises(ValueError, match="not supported"): |
| 75 | + FalkorDBDocumentStore(similarity="invalid") |
| 76 | + |
| 77 | + @pytest.mark.parametrize( |
| 78 | + "filter_node, expected_clause, expected_params", |
| 79 | + [ |
| 80 | + ({"field": "year", "operator": "==", "value": 2024}, "coalesce(d.year = $p0, false)", {"p0": 2024}), |
| 81 | + ({"field": "year", "operator": "==", "value": None}, "d.year IS NULL", {}), |
| 82 | + ({"field": "y", "operator": "!=", "value": None}, "d.y IS NOT NULL", {}), |
| 83 | + ({"field": "y", "operator": ">", "value": 1}, "coalesce(d.y > $p0, false)", {"p0": 1}), |
| 84 | + ({"field": "y", "operator": ">", "value": None}, "false", {}), |
| 85 | + ( |
| 86 | + {"field": "tag", "operator": "in", "value": ["a"]}, |
| 87 | + "coalesce(d.tag IN $p0, false)", |
| 88 | + {"p0": ["a"]}, |
| 89 | + ), |
| 90 | + ( |
| 91 | + {"field": "tag", "operator": "not in", "value": ["a"]}, |
| 92 | + "coalesce(NOT (d.tag IN $p0), true)", |
| 93 | + {"p0": ["a"]}, |
| 94 | + ), |
| 95 | + ( |
| 96 | + {"field": "meta.year", "operator": "==", "value": 2024}, |
| 97 | + "coalesce(d.year = $p0, false)", |
| 98 | + {"p0": 2024}, |
| 99 | + ), |
| 100 | + ( |
| 101 | + { |
| 102 | + "operator": "OR", |
| 103 | + "conditions": [ |
| 104 | + {"field": "y", "operator": "==", "value": 1}, |
| 105 | + {"operator": "NOT", "conditions": [{"field": "b", "operator": "<=", "value": 0}]}, |
| 106 | + ], |
| 107 | + }, |
| 108 | + "(coalesce(d.y = $p0, false) OR NOT (coalesce(d.b <= $p1, false)))", |
| 109 | + {"p0": 1, "p1": 0}, |
| 110 | + ), |
| 111 | + ], |
| 112 | + ) |
| 113 | + def test_convert_filters(self, filter_node, expected_clause, expected_params): |
| 114 | + clause, params = _convert_filters(filter_node) |
| 115 | + assert clause == expected_clause |
| 116 | + assert params == expected_params |
| 117 | + |
| 118 | + @pytest.mark.parametrize( |
| 119 | + "filter_node, match", |
| 120 | + [ |
| 121 | + ({"operator": "AND"}, "requires a 'conditions' key"), |
| 122 | + ({"operator": "NOT"}, "requires a 'conditions' key"), |
| 123 | + ({"operator": "==", "value": 1}, "requires a 'field' key"), |
| 124 | + ({"operator": "==", "field": "f"}, "requires a 'value' key"), |
| 125 | + ({"field": "x", "operator": ">", "value": [1, 2]}, "does not support list values"), |
| 126 | + ({"field": "x", "operator": ">", "value": "not-a-date"}, "non-ISO string"), |
| 127 | + ({"field": "x", "operator": "in", "value": "scalar"}, "requires a list value"), |
| 128 | + ({"field": "x", "operator": "not in", "value": "scalar"}, "requires a list value"), |
| 129 | + ({"field": "x", "operator": "regex", "value": "."}, "Unsupported filter operator"), |
| 130 | + ], |
| 131 | + ) |
| 132 | + def test_convert_filters_errors(self, filter_node, match): |
| 133 | + with pytest.raises(FilterError, match=match): |
| 134 | + _convert_filters(filter_node) |
| 135 | + |
| 136 | + @pytest.mark.parametrize("rows, expected", [([[42]], 42), ([], 0)]) |
| 137 | + def test_count_documents(self, mock_falkordb, rows, expected): |
| 138 | + _, _, graph = mock_falkordb |
| 139 | + graph.query.side_effect = [_result([]), _result([]), _result(rows)] |
| 140 | + assert FalkorDBDocumentStore().count_documents() == expected |
| 141 | + |
| 142 | + def test_filter_documents_no_filters(self, mock_falkordb): |
| 143 | + _, _, graph = mock_falkordb |
| 144 | + node = SimpleNamespace(properties={"id": "n1", "content": "hello"}) |
| 145 | + graph.query.side_effect = [_result([]), _result([]), _result([[node]])] |
| 146 | + docs = FalkorDBDocumentStore().filter_documents() |
| 147 | + assert [d.content for d in docs] == ["hello"] |
| 148 | + assert "WHERE" not in graph.query.call_args_list[-1].args[0] |
| 149 | + |
| 150 | + def test_filter_documents_with_filters(self, mock_falkordb): |
| 151 | + _, _, graph = mock_falkordb |
| 152 | + graph.query.side_effect = [_result([]), _result([]), _result([])] |
| 153 | + FalkorDBDocumentStore().filter_documents({"field": "year", "operator": "==", "value": 2024}) |
| 154 | + last = graph.query.call_args_list[-1] |
| 155 | + assert "WHERE" in last.args[0] |
| 156 | + assert last.args[1] == {"p0": 2024} |
| 157 | + |
| 158 | + @pytest.mark.usefixtures("mock_falkordb") |
| 159 | + def test_filter_documents_malformed_raises(self): |
| 160 | + with pytest.raises(FilterError, match="Invalid filter syntax"): |
| 161 | + FalkorDBDocumentStore().filter_documents({"field": "year", "value": 2024}) |
| 162 | + |
| 163 | + def test_delete_documents_empty_is_noop(self, mock_falkordb): |
| 164 | + _, _, graph = mock_falkordb |
| 165 | + FalkorDBDocumentStore().delete_documents([]) |
| 166 | + assert graph.query.call_count == 2 |
| 167 | + |
| 168 | + def test_delete_documents_runs_query(self, mock_falkordb): |
| 169 | + _, _, graph = mock_falkordb |
| 170 | + FalkorDBDocumentStore().delete_documents(["a", "b"]) |
| 171 | + last = graph.query.call_args_list[-1] |
| 172 | + assert "DETACH DELETE" in last.args[0] |
| 173 | + assert last.args[1] == {"ids": ["a", "b"]} |
| 174 | + |
| 175 | + @pytest.mark.parametrize( |
| 176 | + "similarity, raw, scale_score, filters, expected_score, where_expected", |
| 177 | + [ |
| 178 | + ("cosine", 0.4, True, None, 0.8, False), |
| 179 | + ("cosine", 0.4, False, None, 0.4, False), |
| 180 | + ("euclidean", 1.0, True, {"field": "y", "operator": "==", "value": 1}, 0.5, True), |
| 181 | + ], |
| 182 | + ) |
| 183 | + def test_embedding_retrieval( |
| 184 | + self, mock_falkordb, similarity, raw, scale_score, filters, expected_score, where_expected |
| 185 | + ): |
| 186 | + _, _, graph = mock_falkordb |
| 187 | + node = SimpleNamespace(properties={"id": "n1", "content": "hello"}) |
| 188 | + graph.query.side_effect = [_result([]), _result([]), _result([[node, raw]])] |
| 189 | + docs = FalkorDBDocumentStore(similarity=similarity)._embedding_retrieval( |
| 190 | + query_embedding=[0.1], top_k=5, filters=filters, scale_score=scale_score |
| 191 | + ) |
| 192 | + assert docs[0].score == pytest.approx(expected_score) |
| 193 | + assert ("WHERE" in graph.query.call_args_list[-1].args[0]) is where_expected |
| 194 | + |
| 195 | + def test_cypher_retrieval_returns_documents(self, mock_falkordb): |
| 196 | + _, _, graph = mock_falkordb |
| 197 | + node = SimpleNamespace(properties={"id": "n1", "content": "hello"}) |
| 198 | + graph.query.side_effect = [_result([]), _result([]), _result([[node]])] |
| 199 | + docs = FalkorDBDocumentStore()._cypher_retrieval("MATCH (d) RETURN d", parameters={"k": 1}) |
| 200 | + assert docs[0].content == "hello" |
| 201 | + |
| 202 | + @pytest.mark.usefixtures("mock_falkordb") |
| 203 | + def test_write_documents_rejects_non_documents(self): |
| 204 | + with pytest.raises(ValueError, match="expects a list of Documents"): |
| 205 | + FalkorDBDocumentStore().write_documents(["not a doc"]) |
| 206 | + |
| 207 | + @pytest.mark.usefixtures("mock_falkordb") |
| 208 | + def test_write_documents_empty_list_returns_zero(self, caplog): |
| 209 | + with caplog.at_level(logging.WARNING): |
| 210 | + assert FalkorDBDocumentStore().write_documents([]) == 0 |
| 211 | + assert "empty list" in caplog.text |
| 212 | + |
| 213 | + def test_write_documents_policy_none_coerced_to_fail(self, mock_falkordb): |
| 214 | + _, _, graph = mock_falkordb |
| 215 | + graph.query.side_effect = [_result([]), _result([]), _result([["a"]])] |
| 216 | + with pytest.raises(DuplicateDocumentError): |
| 217 | + FalkorDBDocumentStore().write_documents([Document(id="a", content="x")]) |
| 218 | + |
| 219 | + def test_write_documents_drops_duplicates_within_batch(self, mock_falkordb, caplog): |
| 220 | + _, _, graph = mock_falkordb |
| 221 | + graph.query.side_effect = [_result([]), _result([]), _result([]), _result([[2]])] |
| 222 | + with caplog.at_level(logging.INFO): |
| 223 | + written = FalkorDBDocumentStore().write_documents( |
| 224 | + [ |
| 225 | + Document(id="a", content="x"), |
| 226 | + Document(id="a", content="x"), |
| 227 | + Document(id="b", content="y"), |
| 228 | + ], |
| 229 | + policy=DuplicatePolicy.SKIP, |
| 230 | + ) |
| 231 | + assert written == 2 |
| 232 | + assert "already present in the batch" in caplog.text |
| 233 | + sent_ids = [d["id"] for d in graph.query.call_args_list[-1].args[1]["docs"]] |
| 234 | + assert sent_ids == ["a", "b"] |
| 235 | + |
| 236 | + def test_write_documents_skip_filters_existing(self, mock_falkordb): |
| 237 | + _, _, graph = mock_falkordb |
| 238 | + graph.query.side_effect = [_result([]), _result([]), _result([["a"]]), _result([[1]])] |
| 239 | + written = FalkorDBDocumentStore().write_documents( |
| 240 | + [Document(id="a", content="x"), Document(id="b", content="y")], |
| 241 | + policy=DuplicatePolicy.SKIP, |
| 242 | + ) |
| 243 | + assert written == 1 |
| 244 | + sent_ids = [d["id"] for d in graph.query.call_args_list[-1].args[1]["docs"]] |
| 245 | + assert sent_ids == ["b"] |
| 246 | + |
| 247 | + def test_write_documents_overwrite_uses_on_match_set(self, mock_falkordb): |
| 248 | + _, _, graph = mock_falkordb |
| 249 | + graph.query.side_effect = [_result([]), _result([]), _result([[1]])] |
| 250 | + FalkorDBDocumentStore().write_documents([Document(id="a", content="x")], policy=DuplicatePolicy.OVERWRITE) |
| 251 | + assert "ON MATCH SET d = doc" in graph.query.call_args_list[-1].args[0] |
| 252 | + |
| 253 | + def test_write_documents_embeddings_second_pass(self, mock_falkordb): |
| 254 | + _, _, graph = mock_falkordb |
| 255 | + graph.query.side_effect = [_result([]), _result([]), _result([[1]]), _result([])] |
| 256 | + FalkorDBDocumentStore().write_documents( |
| 257 | + [Document(id="a", content="x", embedding=[0.1, 0.2, 0.3])], |
| 258 | + policy=DuplicatePolicy.OVERWRITE, |
| 259 | + ) |
| 260 | + last = graph.query.call_args_list[-1] |
| 261 | + assert "vecf32" in last.args[0] |
| 262 | + assert last.args[1]["docs"] == [{"id": "a", "emb": [0.1, 0.2, 0.3]}] |
| 263 | + |
| 264 | + def test_write_documents_wraps_errors(self, mock_falkordb): |
| 265 | + _, _, graph = mock_falkordb |
| 266 | + graph.query.side_effect = [_result([]), _result([]), Exception("boom")] |
| 267 | + with pytest.raises(DocumentStoreError, match="Failed to write documents"): |
| 268 | + FalkorDBDocumentStore().write_documents([Document(id="a", content="x")], policy=DuplicatePolicy.OVERWRITE) |
| 269 | + |
| 270 | + |
47 | 271 | @pytest.mark.integration |
48 | 272 | class TestDocumentStore(DocumentStoreBaseTests): |
49 | 273 | """ |
|
0 commit comments