Skip to content

Commit 920c8e2

Browse files
fix: OpenSearch async client initialization (#2645)
* offload init to a new thread in async environment * Add _ensure_initialized_async method * Fix method call * Fix lint * Fix format * Update tests teardown * Fix tests fixtures * Update test treardown * Update conftest * Update conftest --------- Co-authored-by: David S. Batista <dsbatista@gmail.com>
1 parent 690e767 commit 920c8e2

3 files changed

Lines changed: 45 additions & 17 deletions

File tree

integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def from_dict(cls, data: dict[str, Any]) -> "OpenSearchDocumentStore":
262262

263263
def _ensure_initialized(self):
264264
# Ideally, we have a warm-up stage for document stores as well as components.
265-
if not self._initialized:
265+
if not self._client:
266266
self._client = OpenSearch(
267267
hosts=self._hosts,
268268
http_auth=self._http_auth,
@@ -271,6 +271,12 @@ def _ensure_initialized(self):
271271
timeout=self._timeout,
272272
**self._kwargs,
273273
)
274+
self._initialized = True
275+
276+
self._ensure_index_exists()
277+
278+
async def _ensure_initialized_async(self):
279+
if not self._async_client:
274280
async_http_auth = AsyncAWSAuth(self._http_auth) if isinstance(self._http_auth, AWSAuth) else self._http_auth
275281
self._async_client = AsyncOpenSearch(
276282
hosts=self._hosts,
@@ -283,10 +289,22 @@ def _ensure_initialized(self):
283289
connection_class=AsyncHttpConnection,
284290
**self._kwargs,
285291
)
286-
287292
self._initialized = True
293+
await self._ensure_index_exists_async()
288294

289-
self._ensure_index_exists()
295+
async def _ensure_index_exists_async(self):
296+
assert self._async_client is not None
297+
298+
if await self._async_client.indices.exists(index=self._index):
299+
logger.debug(
300+
"The index '{index}' already exists. The `embedding_dim`, `method`, `mappings`, and "
301+
"`settings` values will be ignored.",
302+
index=self._index,
303+
)
304+
elif self._create_index:
305+
# Create the index if it doesn't exist
306+
body = {"mappings": self._mappings, "settings": self._settings}
307+
await self._async_client.indices.create(index=self._index, body=body)
290308

291309
def _ensure_index_exists(self):
292310
assert self._client is not None
@@ -315,7 +333,7 @@ async def count_documents_async(self) -> int:
315333
"""
316334
Asynchronously returns the total number of documents in the document store.
317335
"""
318-
self._ensure_initialized()
336+
await self._ensure_initialized_async()
319337

320338
assert self._async_client is not None
321339
return (await self._async_client.count(index=self._index))["count"]
@@ -376,7 +394,8 @@ async def filter_documents_async(self, filters: Optional[dict[str, Any]] = None)
376394
:param filters: The filters to apply to the document list.
377395
:returns: A list of Documents that match the given filters.
378396
"""
379-
self._ensure_initialized()
397+
await self._ensure_initialized_async()
398+
380399
return await self._search_documents_async(self._prepare_filter_search_request(filters))
381400

382401
def _prepare_bulk_write_request(
@@ -477,7 +496,8 @@ async def write_documents_async(
477496
:param policy: The duplicate policy to use when writing documents.
478497
:returns: The number of documents written to the document store.
479498
"""
480-
self._ensure_initialized()
499+
await self._ensure_initialized_async()
500+
assert self._async_client is not None
481501
bulk_params = self._prepare_bulk_write_request(documents=documents, policy=policy, is_async=True)
482502
documents_written, errors = await async_bulk(**bulk_params)
483503
# since we call async_bulk with stats_only=False, errors is guaranteed to be a list (not int)
@@ -525,7 +545,8 @@ async def delete_documents_async(self, document_ids: list[str]) -> None:
525545
526546
:param document_ids: the document ids to delete
527547
"""
528-
self._ensure_initialized()
548+
await self._ensure_initialized_async()
549+
assert self._async_client is not None
529550

530551
await async_bulk(**self._prepare_bulk_delete_request(document_ids=document_ids, is_async=True))
531552

@@ -583,7 +604,7 @@ async def delete_all_documents_async(self, recreate_index: bool = False) -> None
583604
:param recreate_index: If True, the index will be deleted and recreated with the original mappings and
584605
settings. If False, all documents will be deleted using the `delete_by_query` API.
585606
"""
586-
self._ensure_initialized()
607+
await self._ensure_initialized_async()
587608
assert self._async_client is not None
588609

589610
try:
@@ -643,7 +664,7 @@ async def delete_by_filter_async(self, filters: dict[str, Any]) -> int:
643664
For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering)
644665
:returns: The number of documents deleted.
645666
"""
646-
self._ensure_initialized()
667+
await self._ensure_initialized_async()
647668
assert self._async_client is not None
648669

649670
try:
@@ -707,7 +728,7 @@ async def update_by_filter_async(self, filters: dict[str, Any], meta: dict[str,
707728
:param meta: The metadata fields to update.
708729
:returns: The number of documents updated.
709730
"""
710-
self._ensure_initialized()
731+
await self._ensure_initialized_async()
711732
assert self._async_client is not None
712733

713734
try:
@@ -863,7 +884,8 @@ async def _bm25_retrieval_async(
863884
See `OpenSearchBM25Retriever` for more information.
864885
"""
865886

866-
self._ensure_initialized()
887+
await self._ensure_initialized_async()
888+
assert self._async_client is not None
867889

868890
search_params = self._prepare_bm25_search_request(
869891
query=query,
@@ -982,7 +1004,8 @@ async def _embedding_retrieval_async(
9821004
9831005
See `OpenSearchEmbeddingRetriever` for more information.
9841006
"""
985-
self._ensure_initialized()
1007+
await self._ensure_initialized_async()
1008+
assert self._async_client is not None
9861009

9871010
search_params = self._prepare_embedding_search_request(
9881011
query_embedding=query_embedding,

integrations/opensearch/tests/conftest.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ def document_store(request):
2525
return_embedding=True,
2626
method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"},
2727
)
28+
store._ensure_initialized()
2829
yield store
2930

30-
store._ensure_initialized()
31+
asyncio.run(store._ensure_initialized_async())
3132
assert store._client
33+
assert store._async_client
3234
store._client.indices.delete(index=index, params={"ignore": [400, 404]})
3335
asyncio.run(store._async_client.close())
3436

@@ -51,8 +53,11 @@ def document_store_2(request):
5153

5254
# Cleanup
5355
store._ensure_initialized()
56+
asyncio.run(store._ensure_initialized_async())
5457
assert store._client
58+
assert store._async_client
5559
store._client.indices.delete(index=index, params={"ignore": [400, 404]})
60+
asyncio.run(store._async_client.close())
5661

5762

5863
@pytest.fixture
@@ -74,13 +79,14 @@ def document_store_readonly(request):
7479
create_index=False,
7580
)
7681
store._ensure_initialized()
82+
asyncio.run(store._ensure_initialized_async())
7783
assert store._client
84+
assert store._async_client
7885
store._client.cluster.put_settings(body={"transient": {"action.auto_create_index": False}})
7986
yield store
8087

8188
store._client.cluster.put_settings(body={"transient": {"action.auto_create_index": True}})
8289
store._client.indices.delete(index=index, params={"ignore": [400, 404]})
83-
asyncio.run(store._async_client.close())
8490

8591

8692
@pytest.fixture
@@ -104,7 +110,6 @@ def document_store_embedding_dim_4_no_emb_returned(request):
104110
yield store
105111

106112
store._client.indices.delete(index=index, params={"ignore": [400, 404]})
107-
asyncio.run(store._async_client.close())
108113

109114

110115
@pytest.fixture
@@ -129,7 +134,6 @@ def document_store_embedding_dim_4_no_emb_returned_faiss(request):
129134
yield store
130135

131136
store._client.indices.delete(index=index, params={"ignore": [400, 404]})
132-
asyncio.run(store._async_client.close())
133137

134138

135139
@pytest.fixture

integrations/opensearch/tests/test_document_store_async.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88
from haystack.dataclasses import Document
9+
from haystack.document_stores.types import DuplicatePolicy
910

1011
from haystack_integrations.document_stores.opensearch.document_store import OpenSearchDocumentStore
1112

@@ -14,7 +15,7 @@
1415
class TestDocumentStoreAsync:
1516
@pytest.mark.asyncio
1617
async def test_write_documents(self, document_store: OpenSearchDocumentStore):
17-
assert await document_store.write_documents_async([Document(id="1")]) == 1
18+
assert await document_store.write_documents_async([Document(id="1")], policy=DuplicatePolicy.OVERWRITE) == 1
1819

1920
@pytest.mark.asyncio
2021
async def test_bm25_retrieval(self, document_store: OpenSearchDocumentStore, test_documents: list[Document]):

0 commit comments

Comments
 (0)