Skip to content

Commit c9285c2

Browse files
committed
chore: Azure AI - adaptations for the new SDK
1 parent a402a09 commit c9285c2

6 files changed

Lines changed: 52 additions & 51 deletions

File tree

integrations/azure_ai_search/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ classifiers = [
2525
]
2626
dependencies = [
2727
"haystack-ai>=2.26.1",
28-
"azure-search-documents>=11.5",
28+
"azure-search-documents>=12.0.0",
2929
"azure-identity"
3030
]
3131

integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import logging as python_logging
6+
from collections.abc import Mapping, Sequence
67
from datetime import datetime
78
from typing import Any
89

@@ -37,7 +38,7 @@
3738
VectorSearchAlgorithmMetric,
3839
VectorSearchProfile,
3940
)
40-
from azure.search.documents.models import VectorizedQuery
41+
from azure.search.documents.models import LookupDocument, VectorizedQuery
4142
from haystack import default_from_dict, default_to_dict, logging
4243
from haystack.dataclasses import Document
4344
from haystack.document_stores.types import DuplicatePolicy
@@ -58,6 +59,20 @@
5859
datetime: "Edm.DateTimeOffset",
5960
}
6061

62+
63+
def _instantiate_azure_model(model_class: Any, data: Any) -> Any:
64+
"""Instantiate an Azure SDK model from a dict, picking the right subclass if the model has subtypes."""
65+
# Some Azure base classes (e.g. LexicalAnalyzer) have multiple subclasses (e.g. CustomAnalyzer);
66+
# the concrete subclass to use is named in the "@odata.type" field of the dict, and the base class
67+
# exposes a __mapping__ from that name to the subclass.
68+
if isinstance(data, Mapping):
69+
subtype_name = data.get("@odata.type")
70+
subtypes = getattr(model_class, "__mapping__", {})
71+
if subtype_name in subtypes:
72+
return subtypes[subtype_name](data)
73+
return model_class(data)
74+
75+
6176
# Map of expected field names to their corresponding classes
6277
AZURE_CLASS_MAPPING: dict[str, Any] = {
6378
"suggesters": SearchSuggester,
@@ -273,11 +288,11 @@ def _create_index(self) -> None:
273288

274289
# default fields to create index based on Haystack Document (id, content, embedding)
275290
default_fields = [
276-
SimpleField(name="id", type=SearchFieldDataType.String, key=True, filterable=True),
277-
SearchableField(name="content", type=SearchFieldDataType.String),
291+
SimpleField(name="id", type=SearchFieldDataType.STRING, key=True, filterable=True),
292+
SearchableField(name="content", type=SearchFieldDataType.STRING),
278293
SearchField(
279294
name="embedding",
280-
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
295+
type=f"Collection({SearchFieldDataType.SINGLE.value})",
281296
searchable=True,
282297
hidden=False,
283298
vector_search_dimensions=self._embedding_dimension,
@@ -355,23 +370,21 @@ def from_dict(cls, data: dict[str, Any]) -> "AzureAISearchDocumentStore":
355370
Deserialized component.
356371
"""
357372
if (fields := data["init_parameters"]["metadata_fields"]) is not None:
358-
data["init_parameters"]["metadata_fields"] = {
359-
key: SearchField.from_dict(field) for key, field in fields.items()
360-
}
373+
data["init_parameters"]["metadata_fields"] = {key: SearchField(field) for key, field in fields.items()}
361374
else:
362375
data["init_parameters"]["metadata_fields"] = {}
363376

364377
for key, model_class in AZURE_CLASS_MAPPING.items():
365378
if key in data["init_parameters"]:
366379
value = data["init_parameters"][key]
367380
if isinstance(value, list):
368-
data["init_parameters"][key] = [model_class.from_dict(item) for item in value]
381+
data["init_parameters"][key] = [_instantiate_azure_model(model_class, item) for item in value]
369382
else:
370-
data["init_parameters"][key] = model_class.from_dict(value)
383+
data["init_parameters"][key] = _instantiate_azure_model(model_class, value)
371384

372385
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_endpoint"])
373386
if (vector_search_configuration := data["init_parameters"].get("vector_search_configuration")) is not None:
374-
data["init_parameters"]["vector_search_configuration"] = VectorSearch.from_dict(vector_search_configuration)
387+
data["init_parameters"]["vector_search_configuration"] = VectorSearch(vector_search_configuration)
375388
return default_from_dict(cls, data)
376389

377390
def count_documents(self) -> int:
@@ -460,7 +473,7 @@ def _map_azure_field_type(field: Any) -> str:
460473
if field_type is None:
461474
return "keyword"
462475

463-
field_type_name = str(field_type)
476+
field_type_name = field_type.value if hasattr(field_type, "value") else str(field_type)
464477
if field_type_name.startswith("Collection("):
465478
inner_type = field_type_name[len("Collection(") : -1]
466479
return FIELD_TYPE_MAPPING.get(inner_type, "keyword")
@@ -604,7 +617,7 @@ def delete_documents(self, document_ids: list[str]) -> None:
604617
return
605618
documents = self._get_raw_documents_by_id(document_ids)
606619
if documents:
607-
self.client.delete_documents(documents)
620+
self.client.delete_documents([dict(doc) for doc in documents])
608621

609622
def delete_all_documents(self, recreate_index: bool = False) -> None: # noqa: FBT002, FBT001
610623
"""
@@ -758,7 +771,7 @@ def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Docume
758771
else:
759772
return self.search_documents()
760773

761-
def _convert_search_result_to_documents(self, azure_docs: list[dict[str, Any]]) -> list[Document]:
774+
def _convert_search_result_to_documents(self, azure_docs: Sequence[Mapping[str, Any]]) -> list[Document]:
762775
"""
763776
Converts Azure search results to Haystack Documents.
764777
"""
@@ -807,7 +820,7 @@ def _index_exists(self, index_name: str | None) -> bool:
807820
msg = "Index name is required to check if the index exists."
808821
raise ValueError(msg)
809822

810-
def _get_raw_documents_by_id(self, document_ids: list[str]) -> list[dict]:
823+
def _get_raw_documents_by_id(self, document_ids: list[str]) -> list[LookupDocument]:
811824
"""
812825
Retrieves all Azure documents whose IDs match the given document_ids from the document store.
813826

integrations/azure_ai_search/tests/test_bm25_retriever.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,12 @@ def test_to_dict():
5050
"embedding_dimension": 768,
5151
"metadata_fields": {},
5252
"vector_search_configuration": {
53-
"profiles": [
54-
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
55-
],
53+
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
5654
"algorithms": [
5755
{
5856
"name": "cosine-algorithm-config",
57+
"hnswParameters": {"metric": "cosine"},
5958
"kind": "hnsw",
60-
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
6159
}
6260
],
6361
},

integrations/azure_ai_search/tests/test_document_store.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,12 @@ def test_to_dict(monkeypatch):
5858
"embedding_dimension": 768,
5959
"metadata_fields": {},
6060
"vector_search_configuration": {
61-
"profiles": [
62-
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
63-
],
61+
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
6462
"algorithms": [
6563
{
6664
"name": "cosine-algorithm-config",
65+
"hnswParameters": {"metric": "cosine"},
6766
"kind": "hnsw",
68-
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
6967
}
7068
],
7169
},
@@ -110,27 +108,25 @@ def test_to_dict_with_params(monkeypatch):
110108
"Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True).as_dict(),
111109
},
112110
"encryption_key": {
113-
"key_name": "my-key",
114-
"key_version": "my-version",
115-
"vault_uri": "my-uri",
111+
"keyVaultKeyName": "my-key",
112+
"keyVaultKeyVersion": "my-version",
113+
"keyVaultUri": "my-uri",
116114
},
117115
"analyzers": [
118116
{
119117
"name": "url-analyze",
120-
"odata_type": "#Microsoft.Azure.Search.CustomAnalyzer",
121-
"tokenizer_name": "uax_url_email",
122-
"token_filters": ["lowercase"],
118+
"tokenizer": "uax_url_email",
119+
"tokenFilters": ["lowercase"],
120+
"@odata.type": "#Microsoft.Azure.Search.CustomAnalyzer",
123121
}
124122
],
125123
"vector_search_configuration": {
126-
"profiles": [
127-
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
128-
],
124+
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
129125
"algorithms": [
130126
{
131127
"name": "cosine-algorithm-config",
128+
"hnswParameters": {"metric": "cosine"},
132129
"kind": "hnsw",
133-
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
134130
}
135131
],
136132
},
@@ -201,27 +197,25 @@ def test_from_dict_with_params(monkeypatch):
201197
"Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True).as_dict(),
202198
},
203199
"encryption_key": {
204-
"key_name": "my-key",
205-
"key_version": "my-version",
206-
"vault_uri": "my-uri",
200+
"keyVaultKeyName": "my-key",
201+
"keyVaultKeyVersion": "my-version",
202+
"keyVaultUri": "my-uri",
207203
},
208204
"analyzers": [
209205
{
210206
"name": "url-analyze",
211-
"odata_type": "#Microsoft.Azure.Search.CustomAnalyzer",
212-
"tokenizer_name": "uax_url_email",
213-
"token_filters": ["lowercase"],
207+
"@odata.type": "#Microsoft.Azure.Search.CustomAnalyzer",
208+
"tokenizer": "uax_url_email",
209+
"tokenFilters": ["lowercase"],
214210
}
215211
],
216212
"vector_search_configuration": {
217-
"profiles": [
218-
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
219-
],
213+
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
220214
"algorithms": [
221215
{
222216
"name": "cosine-algorithm-config",
223217
"kind": "hnsw",
224-
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
218+
"hnswParameters": {"metric": "cosine"},
225219
}
226220
],
227221
},

integrations/azure_ai_search/tests/test_embedding_retriever.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,12 @@ def test_to_dict():
5151
"embedding_dimension": 768,
5252
"metadata_fields": {},
5353
"vector_search_configuration": {
54-
"profiles": [
55-
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
56-
],
54+
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
5755
"algorithms": [
5856
{
5957
"name": "cosine-algorithm-config",
58+
"hnswParameters": {"metric": "cosine"},
6059
"kind": "hnsw",
61-
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
6260
}
6361
],
6462
},

integrations/azure_ai_search/tests/test_hybrid_retriever.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,12 @@ def test_to_dict():
5151
"embedding_dimension": 768,
5252
"metadata_fields": {},
5353
"vector_search_configuration": {
54-
"profiles": [
55-
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
56-
],
54+
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
5755
"algorithms": [
5856
{
5957
"name": "cosine-algorithm-config",
58+
"hnswParameters": {"metric": "cosine"},
6059
"kind": "hnsw",
61-
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
6260
}
6361
],
6462
},

0 commit comments

Comments
 (0)