Skip to content

Commit 90484ad

Browse files
authored
feat: add str handling for ElasticsearchDocumentStore api_key (#2934)
* feat: add str handling for ElasticsearchDocumentStore api_key * apply feedback: don't serialize api_key as str
1 parent ee42f9c commit 90484ad

2 files changed

Lines changed: 41 additions & 9 deletions

File tree

integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from haystack.dataclasses import Document
1717
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
1818
from haystack.document_stores.types import DuplicatePolicy
19-
from haystack.utils import Secret, deserialize_secrets_inplace
19+
from haystack.utils import Secret
2020
from haystack.version import __version__ as haystack_version
2121

2222
from elasticsearch import AsyncElasticsearch, Elasticsearch, helpers
@@ -82,8 +82,8 @@ def __init__(
8282
hosts: Hosts | None = None,
8383
custom_mapping: dict[str, Any] | None = None,
8484
index: str = "default",
85-
api_key: Secret = Secret.from_env_var("ELASTIC_API_KEY", strict=False),
86-
api_key_id: Secret = Secret.from_env_var("ELASTIC_API_KEY_ID", strict=False),
85+
api_key: Secret | str | None = Secret.from_env_var("ELASTIC_API_KEY", strict=False),
86+
api_key_id: Secret | str | None = Secret.from_env_var("ELASTIC_API_KEY_ID", strict=False),
8787
embedding_similarity_function: Literal["cosine", "dot_product", "l2_norm", "max_inner_product"] = "cosine",
8888
**kwargs: Any,
8989
):
@@ -217,8 +217,10 @@ def _handle_auth(self) -> str | tuple[str, str] | None:
217217

218218
api_key: str | tuple[str, str] | None # make the type checker happy
219219

220-
api_key_resolved = self._api_key.resolve_value()
221-
api_key_id_resolved = self._api_key_id.resolve_value()
220+
api_key_resolved = self._api_key.resolve_value() if isinstance(self._api_key, Secret) else self._api_key
221+
api_key_id_resolved = (
222+
self._api_key_id.resolve_value() if isinstance(self._api_key_id, Secret) else self._api_key_id
223+
)
222224

223225
# Scenario 1: both are found, use them
224226
if api_key_id_resolved and api_key_resolved:
@@ -271,8 +273,8 @@ def to_dict(self) -> dict[str, Any]:
271273
hosts=self._hosts,
272274
custom_mapping=self._custom_mapping,
273275
index=self._index,
274-
api_key=self._api_key.to_dict(),
275-
api_key_id=self._api_key_id.to_dict(),
276+
api_key=self._api_key.to_dict() if isinstance(self._api_key, Secret) else None,
277+
api_key_id=self._api_key_id.to_dict() if isinstance(self._api_key_id, Secret) else None,
276278
embedding_similarity_function=self._embedding_similarity_function,
277279
**self._kwargs,
278280
)
@@ -287,7 +289,10 @@ def from_dict(cls, data: dict[str, Any]) -> "ElasticsearchDocumentStore":
287289
:returns:
288290
Deserialized component.
289291
"""
290-
deserialize_secrets_inplace(data, keys=["api_key", "api_key_id"], recursive=True)
292+
if (api_key := data.get("api_key")) is not None and isinstance(api_key, dict):
293+
data["api_key"] = Secret.from_dict(api_key)
294+
if (api_key_id := data.get("api_key_id")) is not None and isinstance(api_key_id, dict):
295+
data["api_key_id"] = Secret.from_dict(api_key_id)
291296
return default_from_dict(cls, data)
292297

293298
def count_documents(self) -> int:

integrations/elasticsearch/tests/test_document_store.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ def test_to_dict_with_api_keys_as_secret():
108108
_ = document_store.to_dict()
109109

110110

111+
def test_to_dict_with_api_keys_str():
112+
document_store = ElasticsearchDocumentStore(
113+
hosts="https://localhost:9200", api_key="my_api_key", api_key_id="my_api_key_id"
114+
)
115+
res = document_store.to_dict()
116+
assert res["init_parameters"]["api_key"] is None
117+
assert res["init_parameters"]["api_key_id"] is None
118+
119+
111120
def test_from_dict_with_api_keys_env_vars():
112121
data = {
113122
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
@@ -126,6 +135,24 @@ def test_from_dict_with_api_keys_env_vars():
126135
assert document_store._api_key_id == Secret.from_env_var("ELASTIC_API_KEY_ID", strict=False)
127136

128137

138+
def test_from_dict_with_api_keys_str():
139+
data = {
140+
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
141+
"init_parameters": {
142+
"hosts": "some hosts",
143+
"custom_mapping": None,
144+
"index": "default",
145+
"api_key": "my_api_key",
146+
"api_key_id": "my_api_key_id",
147+
"embedding_similarity_function": "cosine",
148+
},
149+
}
150+
151+
document_store = ElasticsearchDocumentStore.from_dict(data)
152+
assert document_store._api_key == "my_api_key"
153+
assert document_store._api_key_id == "my_api_key_id"
154+
155+
129156
def test_api_key_validation_only_api_key():
130157
api_key = Secret.from_token("test_api_key")
131158
document_store = ElasticsearchDocumentStore(hosts="https://localhost:9200", api_key=api_key)
@@ -173,7 +200,7 @@ def test_client_initialization_with_api_key_tuple(_mock_async_es, _mock_es):
173200
@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
174201
@patch("haystack_integrations.document_stores.elasticsearch.document_store.AsyncElasticsearch")
175202
def test_client_initialization_with_api_key_string(_mock_async_es, _mock_es):
176-
api_key = Secret.from_token("test_api_key")
203+
api_key = "test_api_key"
177204

178205
# Mock the client.info() call to avoid actual connection
179206
mock_client = Mock()

0 commit comments

Comments
 (0)