diff --git a/integrations/weaviate/docker-compose.yml b/integrations/weaviate/docker-compose.yml index 67cbd31ea7..f8d16f5723 100644 --- a/integrations/weaviate/docker-compose.yml +++ b/integrations/weaviate/docker-compose.yml @@ -1,4 +1,3 @@ -version: '3.4' services: weaviate: command: diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index fe4ff6400a..1f71ea3324 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -217,8 +217,6 @@ def client(self): self._client.connect() - # Test connection, it will raise an exception if it fails. - self._client.collections.list_all(simple=True) if not self._client.collections.exists(self._collection_settings["class"]): self._client.collections.create_from_dict(self._collection_settings) @@ -262,8 +260,7 @@ async def async_client(self): ) await self._async_client.connect() - # Test connection, it will raise an exception if it fails. - await self._async_client.collections.list_all(simple=True) + if not await self._async_client.collections.exists(self._collection_settings["class"]): await self._async_client.collections.create_from_dict(self._collection_settings) @@ -287,6 +284,24 @@ async def async_collection(self): self._async_collection = async_client.collections.get(self._collection_settings["class"]) return self._async_collection + def close(self) -> None: + """ + Close the synchronous Weaviate client connection. + """ + if self._client: + self._client.close() + self._client = None + self._collection = None + + async def close_async(self) -> None: + """ + Close the asynchronous Weaviate client connection. 
+ """ + if self._async_client: + await self._async_client.close() + self._async_client = None + self._async_collection = None + def to_dict(self) -> dict[str, Any]: """ Serializes the component to a dictionary. diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 3ca6cf4f51..7b8e67999e 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -5,6 +5,7 @@ import base64 import logging import os +from collections.abc import Generator from unittest.mock import MagicMock, patch import pytest @@ -46,7 +47,7 @@ def test_init_is_lazy(_mock_client): @pytest.mark.integration class TestWeaviateDocumentStore(DocumentStoreBaseExtendedTests): @pytest.fixture - def document_store(self, request) -> WeaviateDocumentStore: + def document_store(self, request) -> Generator[WeaviateDocumentStore, None, None]: # Use a different index for each test so we can run them in parallel collection_settings = { "class": f"{request.node.name}", @@ -65,6 +66,7 @@ def document_store(self, request) -> WeaviateDocumentStore: ) yield store store.client.collections.delete(collection_settings["class"]) + store.close() @pytest.fixture def filterable_docs(self) -> list[Document]: @@ -154,6 +156,22 @@ def test_connection(self, mock_weaviate_client_class, monkeypatch): {"class": "My_collection", "properties": DOCUMENT_COLLECTION_PROPERTIES} ) + def test_close(self, document_store: WeaviateDocumentStore) -> None: + # Initialise client and collection + assert document_store.client is not None + assert document_store.collection is not None + + document_store.close() + + assert document_store._client is None + assert document_store._collection is None + + # Initialise client and collection, then test it still works after reopening + assert document_store.client is not None + assert document_store.collection is not None + + assert document_store.count_documents() == 0 + 
@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") def test_to_dict(self, _mock_weaviate, monkeypatch): monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") diff --git a/integrations/weaviate/tests/test_document_store_async.py b/integrations/weaviate/tests/test_document_store_async.py index b4c7803f37..4674c69b55 100644 --- a/integrations/weaviate/tests/test_document_store_async.py +++ b/integrations/weaviate/tests/test_document_store_async.py @@ -2,7 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 +from collections.abc import AsyncGenerator + import pytest +import pytest_asyncio from haystack.dataclasses.document import Document from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore @@ -11,8 +14,8 @@ @pytest.mark.integration class TestWeaviateDocumentStoreAsync: - @pytest.fixture - def document_store(self, request) -> WeaviateDocumentStore: + @pytest_asyncio.fixture + async def document_store(self, request) -> AsyncGenerator[WeaviateDocumentStore, None]: collection_settings = { "class": f"{request.node.name}", "invertedIndexConfig": {"indexNullState": True}, @@ -29,6 +32,34 @@ def document_store(self, request) -> WeaviateDocumentStore: ) yield store store.client.collections.delete(collection_settings["class"]) + store.close() + await store.close_async() + + @pytest.mark.asyncio + async def test_close_async(self, document_store: WeaviateDocumentStore) -> None: + # Initialise client and collection + assert await document_store.async_client is not None + assert await document_store.async_collection is not None + + await document_store.close_async() + + assert document_store._async_client is None + assert document_store._async_collection is None + + # Initialise client and collection, then test it still works after reopening + assert await document_store.async_client is not None + assert await document_store.async_collection is not None + + document_store.write_documents( + [ + Document(content="Haskell 
is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Python is an object oriented programming language"), + ] + ) + filters = {"field": "content", "operator": "==", "value": "Haskell"} + + assert await document_store.count_documents_by_filter_async(filters) == 1 @pytest.mark.asyncio async def test_bm25_retrieval_async(self, document_store):