44import logging as python_logging
55import os
66from datetime import datetime
7- from typing import Any , Dict , List , Optional
7+ from typing import Any , ClassVar , Dict , List , Optional
88
99from azure .core .credentials import AzureKeyCredential
1010from azure .core .exceptions import ClientAuthenticationError , HttpResponseError , ResourceNotFoundError
1111from azure .identity import DefaultAzureCredential
1212from azure .search .documents import SearchClient
1313from azure .search .documents .indexes import SearchIndexClient
1414from azure .search .documents .indexes .models import (
15+ CharFilter ,
16+ CorsOptions ,
1517 HnswAlgorithmConfiguration ,
1618 HnswParameters ,
19+ LexicalAnalyzer ,
20+ LexicalTokenizer ,
21+ ScoringProfile ,
1722 SearchableField ,
1823 SearchField ,
1924 SearchFieldDataType ,
2025 SearchIndex ,
26+ SearchResourceEncryptionKey ,
27+ SearchSuggester ,
28+ SimilarityAlgorithm ,
2129 SimpleField ,
30+ TokenFilter ,
2231 VectorSearch ,
2332 VectorSearchAlgorithmMetric ,
2433 VectorSearchProfile ,
4049 datetime : "Edm.DateTimeOffset" ,
4150}
4251
52+ # Maps index-creation kwarg names to the Azure model classes used to rebuild them (via `from_dict`) on deserialization
53+ AZURE_CLASS_MAPPING = {
54+ "suggesters" : SearchSuggester ,
55+ "analyzers" : LexicalAnalyzer ,
56+ "tokenizers" : LexicalTokenizer ,
57+ "token_filters" : TokenFilter ,
58+ "char_filters" : CharFilter ,
59+ "cors_options" : CorsOptions ,
60+ "similarity_algorithm" : SimilarityAlgorithm ,
61+ "encryption_key" : SearchResourceEncryptionKey ,
62+ "scoring_profiles" : ScoringProfile ,
63+ }
64+
4365DEFAULT_VECTOR_SEARCH = VectorSearch (
4466 profiles = [
4567 VectorSearchProfile (name = "default-vector-config" , algorithm_configuration_name = "cosine-algorithm-config" )
6082
6183
6284class AzureAISearchDocumentStore :
85+ TYPE_MAP : ClassVar [Dict [str , type ]] = {"str" : str , "int" : int , "float" : float , "bool" : bool , "datetime" : datetime }
86+
6387 def __init__ (
6488 self ,
6589 * ,
@@ -138,7 +162,7 @@ def client(self) -> SearchClient:
138162 "The index '{idx_name}' does not exist. A new index will be created." ,
139163 idx_name = self ._index_name ,
140164 )
141- self ._create_index (self . _index_name )
165+ self ._create_index ()
142166 except (HttpResponseError , ClientAuthenticationError ) as error :
143167 msg = f"Failed to authenticate with Azure Search: { error } "
144168 raise AzureAISearchDocumentStoreConfigError (msg ) from error
@@ -154,11 +178,9 @@ def client(self) -> SearchClient:
154178
155179 return self ._client
156180
157- def _create_index (self , index_name : str ) -> None :
181+ def _create_index (self ) -> None :
158182 """
159- Creates a new search index.
160- :param index_name: Name of the index to create. If None, the index name from the constructor is used.
161- :param kwargs: Optional keyword parameters.
183+ Internally creates a new search index.
162184 """
163185
164186 # default fields to create index based on Haystack Document (id, content, embedding)
@@ -175,19 +197,66 @@ def _create_index(self, index_name: str) -> None:
175197 ),
176198 ]
177199
178- if not index_name :
179- index_name = self ._index_name
180200 if self ._metadata_fields :
181201 default_fields .extend (self ._create_metadata_index_fields (self ._metadata_fields ))
182202 index = SearchIndex (
183- name = index_name ,
203+ name = self . _index_name ,
184204 fields = default_fields ,
185205 vector_search = self ._vector_search_configuration ,
186206 ** self ._index_creation_kwargs ,
187207 )
188208 if self ._index_client :
189209 self ._index_client .create_index (index )
190210
211+ @classmethod
212+ def _deserialize_metadata_fields (cls , fields : Optional [Dict [str , str ]]) -> Optional [Dict [str , type ]]:
213+ """Convert string representations back to type objects."""
214+ if not fields :
215+ return None
216+ try :
217+ # Use the class-level TYPE_MAP for conversion.
218+ ans = {key : cls .TYPE_MAP [value ] for key , value in fields .items ()}
219+ return ans
220+ except KeyError as e :
221+ msg = f"Unsupported type encountered in metadata_fields: { e } "
222+ raise ValueError (msg ) from e
223+
224+ @staticmethod
225+ def _serialize_index_creation_kwargs (index_creation_kwargs : Dict [str , Any ]) -> Dict [str , Any ]:
226+ """
227+ Serializes the index creation kwargs to a dictionary.
228+ This is needed to handle serialization of Azure AI Search classes
229+ that are passed in the index creation kwargs.
230+ """
231+ result = {}
232+ for key , value in index_creation_kwargs .items ():
233+ if hasattr (value , "as_dict" ):
234+ result [key ] = value .as_dict ()
235+ elif isinstance (value , list ) and all (hasattr (item , "as_dict" ) for item in value ):
236+ result [key ] = [item .as_dict () for item in value ]
237+ else :
238+ result [key ] = value
239+ return result
240+
241+ @classmethod
242+ def _deserialize_index_creation_kwargs (cls , data : Dict [str , Any ]) -> Any :
243+ """
244+ Deserializes the index creation kwargs to the original classes.
245+ """
246+ result = {}
247+ for key , value in data .items ():
248+ if key in AZURE_CLASS_MAPPING :
249+ if isinstance (value , list ):
250+ result [key ] = [AZURE_CLASS_MAPPING [key ].from_dict (item ) for item in value ]
251+ else :
252+ result [key ] = AZURE_CLASS_MAPPING [key ].from_dict (value )
253+ elif isinstance (value , dict ) and hasattr (value , "from_dict" ):
254+ result [key ] = value .from_dict (value )
255+ else :
256+ result [key ] = value
257+
258+ return result [key ]
259+
191260 def to_dict (self ) -> Dict [str , Any ]:
192261 # This is not the best solution to serialise this class but is the fastest to implement.
193262 # Not all kwargs types can be serialised to text so this can fail. We must serialise each
@@ -198,15 +267,21 @@ def to_dict(self) -> Dict[str, Any]:
198267 :returns:
199268 Dictionary with serialized data.
200269 """
270+
271+ if self ._metadata_fields :
272+ serialized_metadata = {key : value .__name__ for key , value in self ._metadata_fields .items ()}
273+ else :
274+ serialized_metadata = None
275+
201276 return default_to_dict (
202277 self ,
203278 azure_endpoint = self ._azure_endpoint .to_dict () if self ._azure_endpoint else None ,
204279 api_key = self ._api_key .to_dict () if self ._api_key else None ,
205280 index_name = self ._index_name ,
206281 embedding_dimension = self ._embedding_dimension ,
207- metadata_fields = self . _metadata_fields ,
282+ metadata_fields = serialized_metadata ,
208283 vector_search_configuration = self ._vector_search_configuration .as_dict (),
209- ** self ._index_creation_kwargs ,
284+ ** self ._serialize_index_creation_kwargs ( self . _index_creation_kwargs ) ,
210285 )
211286
212287 @classmethod
@@ -220,6 +295,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore":
220295 :returns:
221296 Deserialized component.
222297 """
298+ if (fields := data ["init_parameters" ]["metadata_fields" ]) is not None :
299+ data ["init_parameters" ]["metadata_fields" ] = cls ._deserialize_metadata_fields (fields )
300+
301+ for key , _value in AZURE_CLASS_MAPPING .items ():
302+ if key in data ["init_parameters" ]:
303+ param_value = data ["init_parameters" ].get (key )
304+ data ["init_parameters" ][key ] = cls ._deserialize_index_creation_kwargs ({key : param_value })
223305
224306 deserialize_secrets_inplace (data ["init_parameters" ], keys = ["api_key" , "azure_endpoint" ])
225307 if (vector_search_configuration := data ["init_parameters" ].get ("vector_search_configuration" )) is not None :
0 commit comments