Skip to content

Commit bd33bcc

Browse files
committed
fix: Remove sem/gather and improve handling of alpha default values
1 parent 7652c08 commit bd33bcc

4 files changed

Lines changed: 38 additions & 35 deletions

File tree

integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/hybrid_retriever.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from haystack.document_stores.types.filter_policy import apply_filter_policy
1010

1111
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
12+
from haystack_integrations.document_stores.weaviate.document_store import DEFAULT_ALPHA
1213

1314

1415
@component
@@ -23,7 +24,7 @@ def __init__(
2324
document_store: WeaviateDocumentStore,
2425
filters: dict[str, Any] | None = None,
2526
top_k: int = 10,
26-
alpha: float | None = None,
27+
alpha: float = DEFAULT_ALPHA,
2728
max_vector_distance: float | None = None,
2829
filter_policy: str | FilterPolicy = FilterPolicy.REPLACE,
2930
):
@@ -46,7 +47,7 @@ def __init__(
4647
- `alpha = 1.0`: only vector similarity scoring is used.
4748
- Values in between blend the two; higher values favor the vector score, lower values favor BM25.
4849
49-
If `None`, the Weaviate server default is used.
50+
By default, 0.7 is used which is the Weaviate server default.
5051
5152
See the official Weaviate docs on Hybrid Search parameters for more details:
5253
- [Hybrid search parameters](https://weaviate.io/developers/weaviate/search/hybrid#parameters)
@@ -66,7 +67,7 @@ def __init__(
6667
Policy to determine how filters are applied.
6768
"""
6869

69-
if alpha is not None and not 0.0 <= alpha <= 1.0:
70+
if not 0.0 <= alpha <= 1.0:
7071
msg = f"alpha ({alpha}) must be in the range [0.0, 1.0]"
7172
raise ValueError(msg)
7273

integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,14 @@
55
import base64
66
import datetime
77
import json
8-
from asyncio import Semaphore, gather
98
from dataclasses import asdict
109
from typing import Any, NoReturn
11-
from uuid import UUID
1210

1311
from haystack import logging
1412
from haystack.core.serialization import default_from_dict, default_to_dict
1513
from haystack.dataclasses.document import Document
1614
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
1715
from haystack.document_stores.types.policy import DuplicatePolicy
18-
from more_itertools import batched
1916

2017
import weaviate
2118
from weaviate.collections.classes.aggregate import (
@@ -69,6 +66,9 @@
6966
# See WeaviateDocumentStore._query_with_filters() for more information.
7067
DEFAULT_QUERY_LIMIT = 9999
7168

69+
# See weaviate.collections.queries.hybrid.query.sync.pyi for the default value of alpha
70+
DEFAULT_ALPHA = 0.7
71+
7272

7373
class WeaviateDocumentStore:
7474
"""
@@ -112,7 +112,6 @@ def __init__(
112112
additional_config: AdditionalConfig | None = None,
113113
grpc_port: int = 50051,
114114
grpc_secure: bool = False,
115-
concurrency_limit: int = 5,
116115
) -> None:
117116
"""
118117
Create a new instance of WeaviateDocumentStore and connects to the Weaviate instance.
@@ -156,8 +155,6 @@ def __init__(
156155
The port to use for the gRPC connection.
157156
:param grpc_secure:
158157
Whether to use a secure channel for the underlying gRPC API.
159-
:param concurrency_limit:
160-
Number of parallel requests to make. Default is 5.
161158
"""
162159
self._url = url
163160
self._auth_client_secret = auth_client_secret
@@ -166,7 +163,6 @@ def __init__(
166163
self._additional_config = additional_config
167164
self._grpc_port = grpc_port
168165
self._grpc_secure = grpc_secure
169-
self.concurrency_limit = concurrency_limit
170166
self._client: weaviate.WeaviateClient | None = None
171167
self._async_client: weaviate.WeaviateAsyncClient | None = None
172168
self._collection: weaviate.Collection | None = None
@@ -1225,7 +1221,6 @@ async def delete_all_documents_async(self, *, recreate_index: bool = False, batc
12251221
Reference: https://docs.weaviate.io/weaviate/manage-objects/delete#delete-all-objects
12261222
"""
12271223
client = await self.async_client
1228-
sem = Semaphore(max(1, self.concurrency_limit))
12291224

12301225
if recreate_index:
12311226
# get current up-to-date config from server, so we can recreate the collection faithfully
@@ -1245,18 +1240,23 @@ async def delete_all_documents_async(self, *, recreate_index: bool = False, batc
12451240
collection = await self.async_collection
12461241
async for obj in collection.iterator(return_properties=[], include_vector=False):
12471242
uuids.append(obj.uuid)
1248-
1249-
async def _runner(uuids: list[UUID]) -> None:
1250-
async with sem:
1243+
if len(uuids) >= batch_size:
12511244
res = await collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(uuids))
1245+
if res.successful < len(uuids):
1246+
logger.warning(
1247+
"Not all documents in the batch have been deleted. "
1248+
"Make sure to specify a deletion `batch_size` which is less than `QUERY_MAXIMUM_RESULTS`.",
1249+
)
1250+
uuids.clear()
1251+
1252+
if uuids:
1253+
res = await collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(uuids))
12521254
if res.successful < len(uuids):
12531255
logger.warning(
1254-
"Not all documents in the batch have been deleted. "
1256+
"Not all documents have been deleted. "
12551257
"Make sure to specify a deletion `batch_size` which is less than `QUERY_MAXIMUM_RESULTS`.",
12561258
)
12571259

1258-
await gather(*[_runner(list(batch)) for batch in batched(uuids, batch_size)])
1259-
12601260
def delete_by_filter(self, filters: dict[str, Any]) -> int:
12611261
"""
12621262
Deletes all documents that match the provided filters.
@@ -1572,12 +1572,10 @@ def _hybrid_retrieval(
15721572
query_embedding: list[float],
15731573
filters: dict[str, Any] | None = None,
15741574
top_k: int | None = None,
1575-
alpha: float | None = None,
1575+
alpha: float = DEFAULT_ALPHA,
15761576
max_vector_distance: float | None = None,
15771577
) -> list[Document]:
15781578
properties = [p.name for p in self.collection.config.get().properties]
1579-
if alpha is None:
1580-
alpha = 0.7
15811579
result = self.collection.query.hybrid(
15821580
query=query,
15831581
vector=query_embedding,
@@ -1599,14 +1597,12 @@ async def _hybrid_retrieval_async(
15991597
query_embedding: list[float],
16001598
filters: dict[str, Any] | None = None,
16011599
top_k: int | None = None,
1602-
alpha: float | None = None,
1600+
alpha: float = DEFAULT_ALPHA,
16031601
max_vector_distance: float | None = None,
16041602
) -> list[Document]:
16051603
collection = await self.async_collection
16061604
config = await collection.config.get()
16071605
properties = [p.name for p in config.properties]
1608-
if alpha is None:
1609-
alpha = 0.7
16101606
result = await collection.query.hybrid(
16111607
query=query,
16121608
vector=query_embedding,

integrations/weaviate/tests/test_document_store_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,4 +588,4 @@ async def test_delete_all_documents_excessive_batch_size_async(
588588
with caplog.at_level(logging.WARNING):
589589
await document_store.delete_all_documents_async(batch_size=20000)
590590
assert await document_store.count_documents_async() == 5
591-
assert "Not all documents in the batch have been deleted." in caplog.text
591+
assert "Not all documents have been deleted." in caplog.text

integrations/weaviate/tests/test_hybrid_retriever.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from haystack_integrations.components.retrievers.weaviate import WeaviateHybridRetriever
1111
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
12+
from haystack_integrations.document_stores.weaviate.document_store import DEFAULT_ALPHA
1213

1314

1415
def test_init_default():
@@ -17,7 +18,7 @@ def test_init_default():
1718
assert retriever._document_store == mock_document_store
1819
assert retriever._filters == {}
1920
assert retriever._top_k == 10
20-
assert retriever._alpha is None
21+
assert retriever._alpha == DEFAULT_ALPHA
2122
assert retriever._max_vector_distance is None
2223
assert retriever._filter_policy == FilterPolicy.REPLACE
2324

@@ -56,7 +57,7 @@ def test_to_dict(_mock_weaviate):
5657
"init_parameters": {
5758
"filters": {},
5859
"top_k": 10,
59-
"alpha": None,
60+
"alpha": DEFAULT_ALPHA,
6061
"max_vector_distance": None,
6162
"filter_policy": "replace",
6263
"document_store": {
@@ -112,7 +113,7 @@ def test_from_dict(_mock_weaviate):
112113
"init_parameters": {
113114
"filters": {},
114115
"top_k": 10,
115-
"alpha": None,
116+
"alpha": DEFAULT_ALPHA,
116117
"max_vector_distance": None,
117118
"filter_policy": "replace",
118119
"document_store": {
@@ -142,7 +143,7 @@ def test_from_dict(_mock_weaviate):
142143
assert retriever._document_store
143144
assert retriever._filters == {}
144145
assert retriever._top_k == 10
145-
assert retriever._alpha is None
146+
assert retriever._alpha == DEFAULT_ALPHA
146147
assert retriever._max_vector_distance is None
147148

148149

@@ -200,7 +201,12 @@ def test_run_basic():
200201
assert "documents" in result
201202
assert len(result["documents"]) == 1
202203
mock_document_store._hybrid_retrieval.assert_called_once_with(
203-
query="test query", query_embedding=[0.1, 0.2, 0.3], filters={}, top_k=10, alpha=None, max_vector_distance=None
204+
query="test query",
205+
query_embedding=[0.1, 0.2, 0.3],
206+
filters={},
207+
top_k=10,
208+
alpha=DEFAULT_ALPHA,
209+
max_vector_distance=None,
204210
)
205211

206212

@@ -217,7 +223,7 @@ def test_run_with_runtime_filters():
217223
query_embedding=[0.1, 0.2, 0.3],
218224
filters={"runtime": "filter"},
219225
top_k=10,
220-
alpha=None,
226+
alpha=DEFAULT_ALPHA,
221227
max_vector_distance=None,
222228
)
223229

@@ -259,7 +265,7 @@ def test_run_empty_query():
259265
assert "documents" in result
260266
assert len(result["documents"]) == 0
261267
mock_document_store._hybrid_retrieval.assert_called_once_with(
262-
query="", query_embedding=[0.1, 0.2, 0.3], filters={}, top_k=10, alpha=None, max_vector_distance=None
268+
query="", query_embedding=[0.1, 0.2, 0.3], filters={}, top_k=10, alpha=DEFAULT_ALPHA, max_vector_distance=None
263269
)
264270

265271

@@ -288,7 +294,7 @@ def test_from_dict_no_filter_policy(_mock_weaviate):
288294
"init_parameters": {
289295
"filters": {},
290296
"top_k": 10,
291-
"alpha": None,
297+
"alpha": DEFAULT_ALPHA,
292298
"max_vector_distance": None,
293299
# filter_policy intentionally omitted
294300
"document_store": {
@@ -318,7 +324,7 @@ def test_from_dict_no_filter_policy(_mock_weaviate):
318324
assert retriever._document_store
319325
assert retriever._filters == {}
320326
assert retriever._top_k == 10
321-
assert retriever._alpha is None
327+
assert retriever._alpha == DEFAULT_ALPHA
322328
assert retriever._max_vector_distance is None
323329
assert retriever._filter_policy == FilterPolicy.REPLACE
324330

@@ -381,7 +387,7 @@ def test_run_with_max_vector_distance_zero_runtime():
381387
query_embedding=[0.1, 0.2],
382388
filters={},
383389
top_k=10,
384-
alpha=None,
390+
alpha=DEFAULT_ALPHA,
385391
max_vector_distance=0.0,
386392
)
387393

@@ -402,7 +408,7 @@ def test_run_with_max_vector_distance_zero_init_and_none_runtime():
402408
query_embedding=[0.1, 0.2],
403409
filters={},
404410
top_k=10,
405-
alpha=None,
411+
alpha=DEFAULT_ALPHA,
406412
max_vector_distance=0.0,
407413
)
408414

0 commit comments

Comments
 (0)