Skip to content

Commit ead7d68

Browse files
authored
test: Azure AI Search - add unit tests (#3201)
1 parent 940d470 commit ead7d68

5 files changed

Lines changed: 286 additions & 0 deletions

File tree

integrations/azure_ai_search/tests/test_bm25_retriever.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,19 @@ def test_run_time_params():
149149
assert res["documents"][0].content == "Test doc"
150150

151151

152+
def test_init_raises_type_error_on_invalid_document_store():
    """The BM25 retriever must reject a document store that is not an AzureAISearchDocumentStore."""
    not_a_store = object()
    expected_message = "document_store must be an instance of AzureAISearchDocumentStore"

    with pytest.raises(TypeError, match=expected_message):
        AzureAISearchBM25Retriever(document_store=not_a_store)
155+
156+
157+
def test_run_raises_runtime_error_when_retrieval_fails():
    """A failure inside the store's BM25 retrieval must surface from ``run`` as a RuntimeError."""
    failing_store = Mock(spec=AzureAISearchDocumentStore)
    failing_store._bm25_retrieval.side_effect = RuntimeError("boom")
    bm25_retriever = AzureAISearchBM25Retriever(document_store=failing_store)

    # "boom" does not contain the matched text, so the error caught here must carry
    # a different message than the one injected into the mocked store.
    with pytest.raises(RuntimeError, match="bm25 retrieval process"):
        bm25_retriever.run(query="Test query")
163+
164+
152165
@pytest.mark.skipif(
153166
not os.environ.get("AZURE_AI_SEARCH_ENDPOINT", None) and not os.environ.get("AZURE_AI_SEARCH_API_KEY", None),
154167
reason="Missing AZURE_AI_SEARCH_ENDPOINT or AZURE_AI_SEARCH_API_KEY.",

integrations/azure_ai_search/tests/test_document_store.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import pytest
1212
from azure.core.credentials import TokenCredential
13+
from azure.core.exceptions import ResourceNotFoundError
1314
from azure.search.documents.indexes.models import (
1415
CustomAnalyzer,
1516
SearchableField,
@@ -336,6 +337,128 @@ def test_query_sql_raises_not_implemented():
336337
document_store.query_sql("SELECT * FROM test-index")
337338

338339

340+
@pytest.mark.parametrize(
    ("metadata_fields", "expected_error_match"),
    [
        # SearchField whose own name differs from the key it is registered under
        (
            {"Title": SearchField(name="mismatched", type="Edm.String", filterable=True)},
            "Name of SearchField",
        ),
        # `object` is expected to be reported as an unsupported field type
        ({"Pages": object}, "Unsupported field type"),
    ],
)
def test_normalize_metadata_index_fields_raises(metadata_fields, expected_error_match):
    """Invalid metadata field definitions must be rejected with a ValueError."""
    normalize = AzureAISearchDocumentStore._normalize_metadata_index_fields

    with pytest.raises(ValueError, match=expected_error_match):
        normalize(metadata_fields)
353+
354+
355+
def test_normalize_metadata_index_fields_skips_non_alpha_keys(caplog):
    """A key starting with a digit is left out of the result and an "Invalid key" warning is logged."""
    raw_fields = {"1invalid": str, "valid": int}

    with caplog.at_level(logging.WARNING):
        result = AzureAISearchDocumentStore._normalize_metadata_index_fields(raw_fields)

    assert "valid" in result
    assert "1invalid" not in result
    assert "Invalid key" in caplog.text
361+
362+
363+
def test_normalize_metadata_index_fields_returns_empty_for_none():
    """Passing ``None`` as the metadata fields yields an empty mapping."""
    normalized = AzureAISearchDocumentStore._normalize_metadata_index_fields(None)

    assert normalized == {}
365+
366+
367+
@pytest.mark.parametrize(
    ("method", "kwargs", "expected_match"),
    [
        # Missing query text
        ("_bm25_retrieval", {"query": None}, "query must not be None"),
        ("_hybrid_retrieval", {"query": None, "query_embedding": [0.1]}, "query must not be None"),
        # Empty embedding vector
        ("_hybrid_retrieval", {"query": "q", "query_embedding": []}, "query_embedding must be a non-empty"),
        ("_embedding_retrieval", {"query_embedding": []}, "query_embedding must be a non-empty"),
    ],
)
def test_internal_retrieval_validates_inputs(method, kwargs, expected_match):
    """Each private retrieval method must raise ValueError for a ``None`` query or an empty embedding."""
    store = AzureAISearchDocumentStore(
        api_key=Secret.from_token("fake-api-key"),
        azure_endpoint=Secret.from_token("fake-endpoint"),
        index_name="test-index",
    )
    retrieval_method = getattr(store, method)

    with pytest.raises(ValueError, match=expected_match):
        retrieval_method(**kwargs)
384+
385+
386+
def test_collect_unique_values_combines_lists_and_scalars():
    """List and scalar values of a field are merged into one set; ``None`` adds nothing, duplicates collapse."""
    raw_docs = [
        {"tags": ["a", "b"]},
        {"tags": "c"},
        {"tags": None},
        {"tags": ["a", "d"]},
    ]

    unique_tags = AzureAISearchDocumentStore._collect_unique_values(raw_docs, "tags")

    assert unique_tags == {"a", "b", "c", "d"}
394+
395+
396+
@pytest.mark.parametrize(
    ("docs", "expected"),
    [
        # No documents at all
        ([], {"min": None, "max": None}),
        # Only a None value and a list value: neither is expected to count
        ([{"x": None}, {"x": [1, 2]}], {"min": None, "max": None}),
        # Plain scalar values in arbitrary order
        ([{"x": 3}, {"x": 1}, {"x": 2}], {"min": 1, "max": 3}),
    ],
)
def test_get_min_max_from_documents(docs, expected):
    """The min/max summary of field ``x`` matches the expected dictionary for each input."""
    summary = AzureAISearchDocumentStore._get_min_max_from_documents(docs, "x")

    assert summary == expected
406+
407+
408+
@pytest.mark.parametrize(
    ("field", "expected_type"),
    [
        # String fields: plain ones map to "keyword", searchable ones to "text"
        (SimpleField(name="cat", type=SearchFieldDataType.String, filterable=True), "keyword"),
        (SearchableField(name="content", type=SearchFieldDataType.String), "text"),
        (SearchableField(name="title", type=SearchFieldDataType.String), "text"),
        # Numeric fields
        (SimpleField(name="year", type=SearchFieldDataType.Int32, filterable=True), "long"),
        (SimpleField(name="rating", type=SearchFieldDataType.Double, filterable=True), "double"),
        # Collection of strings
        (
            SearchField(
                name="tags",
                type=SearchFieldDataType.Collection(SearchFieldDataType.String),
                filterable=True,
            ),
            "keyword",
        ),
        # Date/time field
        (SimpleField(name="when", type=SearchFieldDataType.DateTimeOffset, filterable=True), "date"),
    ],
)
def test_map_azure_field_type_variants(field, expected_type):
    """Each Azure field definition is mapped to the expected type name."""
    mapped_type = AzureAISearchDocumentStore._map_azure_field_type(field)

    assert mapped_type == expected_type
429+
430+
431+
def test_map_azure_field_type_without_type_attribute():
    """A field object that exposes no ``type`` attribute is mapped to "keyword"."""
    # An empty spec makes the mock raise AttributeError for every attribute that is
    # not assigned explicitly, so only ``name`` exists on this object.
    typeless_field = Mock(spec=[])
    typeless_field.name = "custom"

    mapped_type = AzureAISearchDocumentStore._map_azure_field_type(typeless_field)

    assert mapped_type == "keyword"
435+
436+
437+
def test_index_exists_raises_without_index_name():
    """``_index_exists`` must raise ValueError when it is called with ``None`` as the index name."""
    store = AzureAISearchDocumentStore(
        api_key=Secret.from_token("fake-api-key"),
        azure_endpoint=Secret.from_token("fake-endpoint"),
        index_name="test-index",
    )
    # Swap in a mock index client; presumably this keeps the check from touching a real service.
    store._index_client = Mock()

    with pytest.raises(ValueError, match="Index name is required"):
        store._index_exists(None)
446+
447+
448+
def test_get_raw_documents_by_id_skips_not_found(caplog):
    """Ids whose lookup raises ResourceNotFoundError are left out of the result and logged."""
    id_field = SimpleField(name="id", type=SearchFieldDataType.String, key=True, filterable=True)
    store, search_client, _ = _build_mock_document_store_with_schema([id_field])

    # First lookup returns a document, second lookup raises "not found".
    lookup_outcomes = [
        {"id": "1", "content": "c1"},
        ResourceNotFoundError("not found"),
    ]
    search_client.get_document.side_effect = lookup_outcomes

    with caplog.at_level(logging.WARNING):
        result = store._get_raw_documents_by_id(["1", "missing"])

    assert result == [{"id": "1", "content": "c1"}]
    assert "missing" in caplog.text
460+
461+
339462
def _assert_documents_are_equal(received: list[Document], expected: list[Document]):
340463
"""
341464
Assert that two lists of Documents are equal.

integrations/azure_ai_search/tests/test_embedding_retriever.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,19 @@ def test_run_time_params():
162162
assert res["documents"][0].embedding == [0.1, 0.2]
163163

164164

165+
def test_init_raises_on_invalid_document_store():
    """The embedding retriever must reject a document store that is not an AzureAISearchDocumentStore."""
    # NOTE(review): only the base Exception type is asserted here, while the BM25 and hybrid
    # retriever tests assert TypeError for the same message — confirm which type is raised.
    not_a_store = object()
    expected_message = "document_store must be an instance of AzureAISearchDocumentStore"

    with pytest.raises(Exception, match=expected_message):
        AzureAISearchEmbeddingRetriever(document_store=not_a_store)
168+
169+
170+
def test_run_raises_runtime_error_when_retrieval_fails():
    """A failure inside the store's embedding retrieval must surface from ``run`` as a RuntimeError."""
    failing_store = Mock(spec=AzureAISearchDocumentStore)
    failing_store._embedding_retrieval.side_effect = RuntimeError("boom")
    embedding_retriever = AzureAISearchEmbeddingRetriever(document_store=failing_store)

    # "boom" does not contain the matched text, so the error caught here must carry
    # a different message than the one injected into the mocked store.
    with pytest.raises(RuntimeError, match="embedding retrieval process"):
        embedding_retriever.run(query_embedding=[0.1, 0.2])
176+
177+
165178
@pytest.mark.skipif(
166179
not os.environ.get("AZURE_AI_SEARCH_ENDPOINT", None) and not os.environ.get("AZURE_AI_SEARCH_API_KEY", None),
167180
reason="Missing AZURE_AI_SEARCH_ENDPOINT or AZURE_AI_SEARCH_API_KEY.",
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import pytest
6+
7+
from haystack_integrations.document_stores.azure_ai_search.errors import AzureAISearchDocumentStoreFilterError
8+
from haystack_integrations.document_stores.azure_ai_search.filters import _normalize_filters
9+
10+
11+
@pytest.mark.parametrize(
    ("filters", "expected"),
    [
        # Equality: strings are quoted, booleans lower-cased, numbers bare, None becomes null
        ({"field": "meta.name", "operator": "==", "value": "alice"}, "name eq 'alice'"),
        ({"field": "meta.active", "operator": "==", "value": True}, "active eq true"),
        ({"field": "meta.count", "operator": "==", "value": 3}, "count eq 3"),
        ({"field": "meta.name", "operator": "==", "value": None}, "name eq null"),
        # Inequality is rendered as a negated equality
        ({"field": "meta.name", "operator": "!=", "value": "alice"}, "not (name eq 'alice')"),
        ({"field": "meta.active", "operator": "!=", "value": False}, "not (active eq false)"),
        ({"field": "meta.count", "operator": "!=", "value": 3}, "not (count eq 3)"),
        # Membership in a list of strings
        (
            {"field": "meta.page", "operator": "in", "value": ["1", "2"]},
            "search.in(page,'1,2',',')",
        ),
        # Ordering comparisons on numbers
        ({"field": "meta.count", "operator": ">", "value": 5}, "count gt 5"),
        ({"field": "meta.count", "operator": ">=", "value": 5}, "count ge 5"),
        ({"field": "meta.count", "operator": "<", "value": 5}, "count lt 5"),
        ({"field": "meta.count", "operator": "<=", "value": 5}, "count le 5"),
        # Ordering comparison on a date-time string: the value is emitted unquoted
        (
            {"field": "meta.date", "operator": ">", "value": "2020-01-01T00:00:00Z"},
            "date gt 2020-01-01T00:00:00Z",
        ),
        # Field name without the "meta." prefix is used as-is
        ({"field": "bare_field", "operator": "==", "value": "x"}, "bare_field eq 'x'"),
    ],
)
def test_normalize_filters_comparison_conditions(filters, expected):
    """A single comparison condition is rendered as the expected filter string."""
    rendered = _normalize_filters(filters)

    assert rendered == expected
38+
39+
40+
@pytest.mark.parametrize(
    ("filters", "expected"),
    [
        # AND of two comparisons
        (
            {
                "operator": "AND",
                "conditions": [
                    {"field": "meta.name", "operator": "==", "value": "alice"},
                    {"field": "meta.count", "operator": ">=", "value": 1},
                ],
            },
            "(name eq 'alice') and (count ge 1)",
        ),
        # OR of two comparisons
        (
            {
                "operator": "OR",
                "conditions": [
                    {"field": "meta.name", "operator": "==", "value": "alice"},
                    {"field": "meta.name", "operator": "==", "value": "bob"},
                ],
            },
            "(name eq 'alice') or (name eq 'bob')",
        ),
        # NOT around a single comparison
        (
            {
                "operator": "NOT",
                "conditions": [{"field": "meta.name", "operator": "==", "value": "alice"}],
            },
            "not ((name eq 'alice'))",
        ),
        # OR group nested inside an AND group
        (
            {
                "operator": "AND",
                "conditions": [
                    {"field": "meta.name", "operator": "==", "value": "alice"},
                    {
                        "operator": "OR",
                        "conditions": [
                            {"field": "meta.count", "operator": ">", "value": 1},
                            {"field": "meta.count", "operator": "<", "value": 10},
                        ],
                    },
                ],
            },
            "(name eq 'alice') and ((count gt 1) or (count lt 10))",
        ),
    ],
)
def test_normalize_filters_logical_conditions(filters, expected):
    """Logical groups, including nested ones, are rendered as the expected filter string."""
    rendered = _normalize_filters(filters)

    assert rendered == expected
90+
91+
92+
@pytest.mark.parametrize(
    ("filters", "expected_match"),
    [
        # Not a dictionary at all
        ("not a dict", "Filters must be a dictionary"),
        # Logical group with a required key missing
        ({"operator": "AND"}, "Missing key"),
        ({"conditions": []}, "Missing key"),
        # Logical operator that is not recognised
        (
            {"operator": "XOR", "conditions": [{"field": "a", "operator": "==", "value": 1}]},
            "Unknown operator XOR",
        ),
        # Comparison with required keys missing
        ({"field": "f"}, "Missing key"),
        # Comparison operator that is not recognised
        ({"field": "f", "operator": "???", "value": 1}, "Unknown operator"),
        # Ordering comparison against a non-date string or a list
        (
            {"field": "f", "operator": ">", "value": "not-a-date"},
            "Invalid value type",
        ),
        (
            {"field": "f", "operator": ">", "value": [1, 2]},
            "Invalid value type",
        ),
        # "in" with something other than a list of strings
        (
            {"field": "f", "operator": "in", "value": "not-a-list"},
            "only supports a list of strings",
        ),
        (
            {"field": "f", "operator": "in", "value": [1, 2]},
            "only supports a list of strings",
        ),
    ],
)
def test_normalize_filters_raises_on_invalid_input(filters, expected_match):
    """Malformed filters must raise AzureAISearchDocumentStoreFilterError with a matching message."""
    filter_error = AzureAISearchDocumentStoreFilterError

    with pytest.raises(filter_error, match=expected_match):
        _normalize_filters(filters)

integrations/azure_ai_search/tests/test_hybrid_retriever.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,19 @@ def test_run_time_params():
168168
assert res["documents"][0].embedding == [0.1, 0.2]
169169

170170

171+
def test_init_raises_type_error_on_invalid_document_store():
    """The hybrid retriever must reject a document store that is not an AzureAISearchDocumentStore."""
    not_a_store = object()
    expected_message = "document_store must be an instance of AzureAISearchDocumentStore"

    with pytest.raises(TypeError, match=expected_message):
        AzureAISearchHybridRetriever(document_store=not_a_store)
174+
175+
176+
def test_run_raises_runtime_error_when_retrieval_fails():
    """A failure inside the store's hybrid retrieval must surface from ``run`` as a RuntimeError."""
    failing_store = Mock(spec=AzureAISearchDocumentStore)
    failing_store._hybrid_retrieval.side_effect = RuntimeError("boom")
    hybrid_retriever = AzureAISearchHybridRetriever(document_store=failing_store)

    # "boom" does not contain the matched text, so the error caught here must carry
    # a different message than the one injected into the mocked store.
    with pytest.raises(RuntimeError, match="hybrid retrieval process"):
        hybrid_retriever.run(query="Test query", query_embedding=[0.1, 0.2])
182+
183+
171184
@pytest.mark.skipif(
172185
not os.environ.get("AZURE_AI_SEARCH_ENDPOINT", None) and not os.environ.get("AZURE_AI_SEARCH_API_KEY", None),
173186
reason="Missing AZURE_AI_SEARCH_ENDPOINT or AZURE_AI_SEARCH_API_KEY.",

0 commit comments

Comments
 (0)