Skip to content

Commit 889d0c8

Browse files
feat: Add support for Azure TokenCredentials to AzureAISearchDocumentStore (#3014)
* feat: Add support for Azure TokenCredentials * Fix mypy error * Address review comments and add unit tests --------- Co-authored-by: bogdankostic <bogdankostic@web.de>
1 parent 880057f commit 889d0c8

2 files changed

Lines changed: 60 additions & 2 deletions

File tree

integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datetime import datetime
77
from typing import Any
88

9-
from azure.core.credentials import AzureKeyCredential
9+
from azure.core.credentials import AzureKeyCredential, TokenCredential
1010
from azure.core.exceptions import (
1111
ClientAuthenticationError,
1212
HttpResponseError,
@@ -120,6 +120,7 @@ def __init__(
120120
metadata_fields: dict[str, SearchField | type] | None = None,
121121
vector_search_configuration: VectorSearch | None = None,
122122
include_search_metadata: bool = False,
123+
azure_token_credential: TokenCredential | None = None,
123124
**index_creation_kwargs: Any,
124125
) -> None:
125126
"""
@@ -154,6 +155,8 @@ def __init__(
154155
in the returned documents. When set to True, the `meta` field of the returned
155156
documents will contain the @search.score, @search.reranker_score, @search.highlights,
156157
@search.captions, and other fields returned by Azure AI Search.
158+
:param azure_token_credential: An Azure `TokenCredential` instance used to authenticate requests.
159+
When provided, this takes priority over `api_key`.
157160
:param index_creation_kwargs: Optional keyword parameters to be passed to `SearchIndex` class
158161
during index creation. Some of the supported parameters:
159162
- `semantic_search`: Defines semantic configuration of the search index. This parameter is needed
@@ -175,6 +178,7 @@ def __init__(
175178
self._metadata_fields = AzureAISearchDocumentStore._normalize_metadata_index_fields(metadata_fields)
176179
self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH
177180
self._include_search_metadata = include_search_metadata
181+
self._azure_token_credential = azure_token_credential
178182
self._index_creation_kwargs = index_creation_kwargs
179183

180184
@property
@@ -183,7 +187,12 @@ def client(self) -> SearchClient:
183187
resolved_endpoint = self._azure_endpoint.resolve_value()
184188
resolved_key = self._api_key.resolve_value()
185189

186-
credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential()
190+
if self._azure_token_credential is not None:
191+
credential: TokenCredential | AzureKeyCredential = self._azure_token_credential
192+
elif resolved_key:
193+
credential = AzureKeyCredential(resolved_key)
194+
else:
195+
credential = DefaultAzureCredential()
187196

188197
# build a UserAgentPolicy to be used for the request
189198
ua_policy = UserAgentPolicy(user_agent=USER_AGENT)
@@ -316,6 +325,13 @@ def to_dict(self) -> dict[str, Any]:
316325
:returns:
317326
Dictionary with serialized data.
318327
"""
328+
if self._azure_token_credential:
329+
logger.warning(
330+
"AzureAISearchDocumentStore was initialized with `azure_token_credential`, "
331+
"which cannot be serialized. It will be excluded from the serialized output "
332+
"and must be provided again when deserializing."
333+
)
334+
319335
return default_to_dict(
320336
self,
321337
azure_endpoint=(self._azure_endpoint.to_dict() if self._azure_endpoint else None),

integrations/azure_ai_search/tests/test_document_store.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import logging
56
import os
67
import random
78
from datetime import datetime, timezone
89
from unittest.mock import Mock, patch
910

1011
import pytest
12+
from azure.core.credentials import TokenCredential
1113
from azure.search.documents.indexes.models import (
1214
CustomAnalyzer,
1315
SearchableField,
@@ -135,6 +137,24 @@ def test_to_dict_with_params(monkeypatch):
135137
}
136138

137139

140+
def test_to_dict_emits_warning_when_token_credential_is_used(
141+
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
142+
) -> None:
143+
monkeypatch.setenv("AZURE_AI_SEARCH_API_KEY", "test-api-key")
144+
monkeypatch.setenv("AZURE_AI_SEARCH_ENDPOINT", "test-endpoint")
145+
146+
mock_token_credential = Mock(spec=TokenCredential)
147+
document_store = AzureAISearchDocumentStore(azure_token_credential=mock_token_credential)
148+
149+
with caplog.at_level(logging.WARNING):
150+
result = document_store.to_dict()
151+
152+
assert "`azure_token_credential`, which cannot be serialized." in caplog.text
153+
154+
# token credential should not appear in the serialized output
155+
assert "azure_token_credential" not in result["init_parameters"]
156+
157+
138158
def test_from_dict(monkeypatch):
139159
monkeypatch.setenv("AZURE_AI_SEARCH_API_KEY", "test-api-key")
140160
monkeypatch.setenv("AZURE_AI_SEARCH_ENDPOINT", "test-endpoint")
@@ -246,6 +266,28 @@ def test_init():
246266
assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH
247267

248268

269+
def test_token_credential_takes_priority_over_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
270+
monkeypatch.setenv("AZURE_AI_SEARCH_API_KEY", "test-api-key")
271+
monkeypatch.setenv("AZURE_AI_SEARCH_ENDPOINT", "test-endpoint")
272+
273+
mock_token_credential = Mock(spec=TokenCredential)
274+
document_store = AzureAISearchDocumentStore(azure_token_credential=mock_token_credential)
275+
276+
with patch(
277+
"haystack_integrations.document_stores.azure_ai_search.document_store.SearchIndexClient"
278+
) as mock_index_client_cls:
279+
mock_index_client = Mock()
280+
mock_index_client.list_index_names.return_value = ["default"]
281+
mock_index_client.get_index.return_value = Mock(fields=[])
282+
mock_index_client.get_search_client.return_value = Mock()
283+
mock_index_client_cls.return_value = mock_index_client
284+
285+
_ = document_store.client
286+
287+
_, kwargs = mock_index_client_cls.call_args
288+
assert kwargs["credential"] is mock_token_credential
289+
290+
249291
def _build_mock_document_store_with_schema(index_fields):
250292
store = AzureAISearchDocumentStore(
251293
api_key=Secret.from_token("fake-api-key"),

0 commit comments

Comments
 (0)