|
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | 5 | import logging as python_logging |
| 6 | +from collections.abc import Mapping, Sequence |
6 | 7 | from datetime import datetime |
7 | 8 | from typing import Any |
8 | 9 |
|
|
37 | 38 | VectorSearchAlgorithmMetric, |
38 | 39 | VectorSearchProfile, |
39 | 40 | ) |
40 | | -from azure.search.documents.models import VectorizedQuery |
| 41 | +from azure.search.documents.models import LookupDocument, VectorizedQuery |
41 | 42 | from haystack import default_from_dict, default_to_dict, logging |
42 | 43 | from haystack.dataclasses import Document |
43 | 44 | from haystack.document_stores.types import DuplicatePolicy |
|
58 | 59 | datetime: "Edm.DateTimeOffset", |
59 | 60 | } |
60 | 61 |
|
| 62 | + |
| 63 | +def _instantiate_azure_model(model_class: Any, data: Any) -> Any: |
| 64 | + """Build an Azure SDK model from serialized data, using the subclass named by "@odata.type" when the model class registers one.""" |
| 65 | + # Some Azure base classes (e.g. LexicalAnalyzer) have multiple subclasses (e.g. CustomAnalyzer); the concrete |
| 66 | + # subclass is named in the dict's "@odata.type" field and looked up in the base class's __mapping__ (name -> subclass). |
| 67 | + # Non-mapping data, a missing "@odata.type", or a name absent from __mapping__ falls back to model_class(data). |
| 68 | + if isinstance(data, Mapping): |
| 69 | + subtype_name = data.get("@odata.type") |
| 70 | + subtypes = getattr(model_class, "__mapping__", {}) |
| 71 | + if subtype_name in subtypes: |
| 72 | + return subtypes[subtype_name](data) |
| 73 | + return model_class(data) |
| 74 | + |
| 75 | + |
61 | 76 | # Map of expected field names to their corresponding classes |
62 | 77 | AZURE_CLASS_MAPPING: dict[str, Any] = { |
63 | 78 | "suggesters": SearchSuggester, |
@@ -273,11 +288,11 @@ def _create_index(self) -> None: |
273 | 288 |
|
274 | 289 | # default fields to create index based on Haystack Document (id, content, embedding) |
275 | 290 | default_fields = [ |
276 | | - SimpleField(name="id", type=SearchFieldDataType.String, key=True, filterable=True), |
277 | | - SearchableField(name="content", type=SearchFieldDataType.String), |
| 291 | + SimpleField(name="id", type=SearchFieldDataType.STRING, key=True, filterable=True), |
| 292 | + SearchableField(name="content", type=SearchFieldDataType.STRING), |
278 | 293 | SearchField( |
279 | 294 | name="embedding", |
280 | | - type=SearchFieldDataType.Collection(SearchFieldDataType.Single), |
| 295 | + type=f"Collection({SearchFieldDataType.SINGLE.value})", |
281 | 296 | searchable=True, |
282 | 297 | hidden=False, |
283 | 298 | vector_search_dimensions=self._embedding_dimension, |
@@ -355,23 +370,21 @@ def from_dict(cls, data: dict[str, Any]) -> "AzureAISearchDocumentStore": |
355 | 370 | Deserialized component. |
356 | 371 | """ |
357 | 372 | if (fields := data["init_parameters"]["metadata_fields"]) is not None: |
358 | | - data["init_parameters"]["metadata_fields"] = { |
359 | | - key: SearchField.from_dict(field) for key, field in fields.items() |
360 | | - } |
| 373 | + data["init_parameters"]["metadata_fields"] = {key: SearchField(field) for key, field in fields.items()} |
361 | 374 | else: |
362 | 375 | data["init_parameters"]["metadata_fields"] = {} |
363 | 376 |
|
364 | 377 | for key, model_class in AZURE_CLASS_MAPPING.items(): |
365 | 378 | if key in data["init_parameters"]: |
366 | 379 | value = data["init_parameters"][key] |
367 | 380 | if isinstance(value, list): |
368 | | - data["init_parameters"][key] = [model_class.from_dict(item) for item in value] |
| 381 | + data["init_parameters"][key] = [_instantiate_azure_model(model_class, item) for item in value] |
369 | 382 | else: |
370 | | - data["init_parameters"][key] = model_class.from_dict(value) |
| 383 | + data["init_parameters"][key] = _instantiate_azure_model(model_class, value) |
371 | 384 |
|
372 | 385 | deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_endpoint"]) |
373 | 386 | if (vector_search_configuration := data["init_parameters"].get("vector_search_configuration")) is not None: |
374 | | - data["init_parameters"]["vector_search_configuration"] = VectorSearch.from_dict(vector_search_configuration) |
| 387 | + data["init_parameters"]["vector_search_configuration"] = VectorSearch(vector_search_configuration) |
375 | 388 | return default_from_dict(cls, data) |
376 | 389 |
|
377 | 390 | def count_documents(self) -> int: |
@@ -460,7 +473,7 @@ def _map_azure_field_type(field: Any) -> str: |
460 | 473 | if field_type is None: |
461 | 474 | return "keyword" |
462 | 475 |
|
463 | | - field_type_name = str(field_type) |
| 476 | + field_type_name = field_type.value if hasattr(field_type, "value") else str(field_type) |
464 | 477 | if field_type_name.startswith("Collection("): |
465 | 478 | inner_type = field_type_name[len("Collection(") : -1] |
466 | 479 | return FIELD_TYPE_MAPPING.get(inner_type, "keyword") |
@@ -604,7 +617,7 @@ def delete_documents(self, document_ids: list[str]) -> None: |
604 | 617 | return |
605 | 618 | documents = self._get_raw_documents_by_id(document_ids) |
606 | 619 | if documents: |
607 | | - self.client.delete_documents(documents) |
| 620 | + self.client.delete_documents([dict(doc) for doc in documents]) |
608 | 621 |
|
609 | 622 | def delete_all_documents(self, recreate_index: bool = False) -> None: # noqa: FBT002, FBT001 |
610 | 623 | """ |
@@ -758,7 +771,7 @@ def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Docume |
758 | 771 | else: |
759 | 772 | return self.search_documents() |
760 | 773 |
|
761 | | - def _convert_search_result_to_documents(self, azure_docs: list[dict[str, Any]]) -> list[Document]: |
| 774 | + def _convert_search_result_to_documents(self, azure_docs: Sequence[Mapping[str, Any]]) -> list[Document]: |
762 | 775 | """ |
763 | 776 | Converts Azure search results to Haystack Documents. |
764 | 777 | """ |
@@ -807,7 +820,7 @@ def _index_exists(self, index_name: str | None) -> bool: |
807 | 820 | msg = "Index name is required to check if the index exists." |
808 | 821 | raise ValueError(msg) |
809 | 822 |
|
810 | | - def _get_raw_documents_by_id(self, document_ids: list[str]) -> list[dict]: |
| 823 | + def _get_raw_documents_by_id(self, document_ids: list[str]) -> list[LookupDocument]: |
811 | 824 | """ |
812 | 825 | Retrieves all Azure documents with a matching document_ids from the document store. |
813 | 826 |
|
|
0 commit comments