Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/azure_ai_search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [
]
dependencies = [
"haystack-ai>=2.26.1",
"azure-search-documents>=11.5",
"azure-search-documents>=12.0.0",
"azure-identity"
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import logging as python_logging
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any

Expand Down Expand Up @@ -37,7 +38,7 @@
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)
from azure.search.documents.models import VectorizedQuery
from azure.search.documents.models import LookupDocument, VectorizedQuery
from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
from haystack.document_stores.types import DuplicatePolicy
Expand All @@ -58,6 +59,20 @@
datetime: "Edm.DateTimeOffset",
}


def _instantiate_azure_model(model_class: Any, data: Any) -> Any:
"""Instantiate an Azure SDK model from a dict, picking the right subclass if the model has subtypes."""
# Some Azure base classes (e.g. LexicalAnalyzer) have multiple subclasses (e.g. CustomAnalyzer);
# the concrete subclass to use is named in the "@odata.type" field of the dict, and the base class
# exposes a __mapping__ from that name to the subclass.
if isinstance(data, Mapping):
subtype_name = data.get("@odata.type")
subtypes = getattr(model_class, "__mapping__", {})
if subtype_name in subtypes:
return subtypes[subtype_name](data)
return model_class(data)


# Map of expected field names to their corresponding classes
AZURE_CLASS_MAPPING: dict[str, Any] = {
"suggesters": SearchSuggester,
Expand Down Expand Up @@ -273,11 +288,11 @@ def _create_index(self) -> None:

# default fields to create index based on Haystack Document (id, content, embedding)
default_fields = [
SimpleField(name="id", type=SearchFieldDataType.String, key=True, filterable=True),
SearchableField(name="content", type=SearchFieldDataType.String),
SimpleField(name="id", type=SearchFieldDataType.STRING, key=True, filterable=True),
SearchableField(name="content", type=SearchFieldDataType.STRING),
SearchField(
name="embedding",
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
type=f"Collection({SearchFieldDataType.SINGLE.value})",
searchable=True,
hidden=False,
vector_search_dimensions=self._embedding_dimension,
Expand Down Expand Up @@ -355,23 +370,21 @@ def from_dict(cls, data: dict[str, Any]) -> "AzureAISearchDocumentStore":
Deserialized component.
"""
if (fields := data["init_parameters"]["metadata_fields"]) is not None:
data["init_parameters"]["metadata_fields"] = {
key: SearchField.from_dict(field) for key, field in fields.items()
}
data["init_parameters"]["metadata_fields"] = {key: SearchField(field) for key, field in fields.items()}
else:
data["init_parameters"]["metadata_fields"] = {}

for key, model_class in AZURE_CLASS_MAPPING.items():
if key in data["init_parameters"]:
value = data["init_parameters"][key]
if isinstance(value, list):
data["init_parameters"][key] = [model_class.from_dict(item) for item in value]
data["init_parameters"][key] = [_instantiate_azure_model(model_class, item) for item in value]
else:
data["init_parameters"][key] = model_class.from_dict(value)
data["init_parameters"][key] = _instantiate_azure_model(model_class, value)

deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_endpoint"])
if (vector_search_configuration := data["init_parameters"].get("vector_search_configuration")) is not None:
data["init_parameters"]["vector_search_configuration"] = VectorSearch.from_dict(vector_search_configuration)
data["init_parameters"]["vector_search_configuration"] = VectorSearch(vector_search_configuration)
return default_from_dict(cls, data)

def count_documents(self) -> int:
Expand Down Expand Up @@ -460,7 +473,7 @@ def _map_azure_field_type(field: Any) -> str:
if field_type is None:
return "keyword"

field_type_name = str(field_type)
field_type_name = field_type.value if hasattr(field_type, "value") else str(field_type)
if field_type_name.startswith("Collection("):
inner_type = field_type_name[len("Collection(") : -1]
return FIELD_TYPE_MAPPING.get(inner_type, "keyword")
Expand Down Expand Up @@ -604,7 +617,7 @@ def delete_documents(self, document_ids: list[str]) -> None:
return
documents = self._get_raw_documents_by_id(document_ids)
if documents:
self.client.delete_documents(documents)
self.client.delete_documents([dict(doc) for doc in documents])

def delete_all_documents(self, recreate_index: bool = False) -> None: # noqa: FBT002, FBT001
"""
Expand Down Expand Up @@ -758,7 +771,7 @@ def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Docume
else:
return self.search_documents()

def _convert_search_result_to_documents(self, azure_docs: list[dict[str, Any]]) -> list[Document]:
def _convert_search_result_to_documents(self, azure_docs: Sequence[Mapping[str, Any]]) -> list[Document]:
"""
Converts Azure search results to Haystack Documents.
"""
Expand Down Expand Up @@ -807,7 +820,7 @@ def _index_exists(self, index_name: str | None) -> bool:
msg = "Index name is required to check if the index exists."
raise ValueError(msg)

def _get_raw_documents_by_id(self, document_ids: list[str]) -> list[dict]:
def _get_raw_documents_by_id(self, document_ids: list[str]) -> list[LookupDocument]:
"""
Retrieves all Azure documents with a matching document_ids from the document store.

Expand Down
1 change: 0 additions & 1 deletion integrations/azure_ai_search/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def document_store(request):

store = AzureAISearchDocumentStore(
index_name=index_name,
create_index=True,
embedding_dimension=768,
metadata_fields=metadata_fields,
include_search_metadata=include_search_metadata,
Expand Down
6 changes: 2 additions & 4 deletions integrations/azure_ai_search/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,12 @@ def test_to_dict():
"embedding_dimension": 768,
"metadata_fields": {},
"vector_search_configuration": {
"profiles": [
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
],
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
"algorithms": [
{
"name": "cosine-algorithm-config",
"hnswParameters": {"metric": "cosine"},
"kind": "hnsw",
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
}
],
},
Expand Down
42 changes: 18 additions & 24 deletions integrations/azure_ai_search/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,12 @@ def test_to_dict(monkeypatch):
"embedding_dimension": 768,
"metadata_fields": {},
"vector_search_configuration": {
"profiles": [
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
],
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
"algorithms": [
{
"name": "cosine-algorithm-config",
"hnswParameters": {"metric": "cosine"},
"kind": "hnsw",
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
}
],
},
Expand Down Expand Up @@ -110,27 +108,25 @@ def test_to_dict_with_params(monkeypatch):
"Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True).as_dict(),
},
"encryption_key": {
"key_name": "my-key",
"key_version": "my-version",
"vault_uri": "my-uri",
"keyVaultKeyName": "my-key",
"keyVaultKeyVersion": "my-version",
"keyVaultUri": "my-uri",
},
"analyzers": [
{
"name": "url-analyze",
"odata_type": "#Microsoft.Azure.Search.CustomAnalyzer",
"tokenizer_name": "uax_url_email",
"token_filters": ["lowercase"],
"tokenizer": "uax_url_email",
"tokenFilters": ["lowercase"],
"@odata.type": "#Microsoft.Azure.Search.CustomAnalyzer",
}
],
"vector_search_configuration": {
"profiles": [
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
],
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
"algorithms": [
{
"name": "cosine-algorithm-config",
"hnswParameters": {"metric": "cosine"},
"kind": "hnsw",
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
}
],
},
Expand Down Expand Up @@ -201,27 +197,25 @@ def test_from_dict_with_params(monkeypatch):
"Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True).as_dict(),
},
"encryption_key": {
"key_name": "my-key",
"key_version": "my-version",
"vault_uri": "my-uri",
"keyVaultKeyName": "my-key",
"keyVaultKeyVersion": "my-version",
"keyVaultUri": "my-uri",
},
"analyzers": [
{
"name": "url-analyze",
"odata_type": "#Microsoft.Azure.Search.CustomAnalyzer",
"tokenizer_name": "uax_url_email",
"token_filters": ["lowercase"],
"@odata.type": "#Microsoft.Azure.Search.CustomAnalyzer",
"tokenizer": "uax_url_email",
"tokenFilters": ["lowercase"],
}
],
"vector_search_configuration": {
"profiles": [
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
],
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
"algorithms": [
{
"name": "cosine-algorithm-config",
"kind": "hnsw",
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
"hnswParameters": {"metric": "cosine"},
}
],
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,12 @@ def test_to_dict():
"embedding_dimension": 768,
"metadata_fields": {},
"vector_search_configuration": {
"profiles": [
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
],
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
"algorithms": [
{
"name": "cosine-algorithm-config",
"hnswParameters": {"metric": "cosine"},
"kind": "hnsw",
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
}
],
},
Expand Down
6 changes: 2 additions & 4 deletions integrations/azure_ai_search/tests/test_hybrid_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,12 @@ def test_to_dict():
"embedding_dimension": 768,
"metadata_fields": {},
"vector_search_configuration": {
"profiles": [
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
],
"profiles": [{"name": "default-vector-config", "algorithm": "cosine-algorithm-config"}],
"algorithms": [
{
"name": "cosine-algorithm-config",
"hnswParameters": {"metric": "cosine"},
"kind": "hnsw",
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
}
],
},
Expand Down
Loading