Skip to content

Commit fd5453d

Browse files
committed
Improvements to batching logic:
- Avoid 60s timeout putting into self.__reqs - Handle graceful stopping of all bg threads - Allow cancelling of hanging streams - Reraise bg_exceptions when they happen - Align shutdown timeout with client-defined insert timeout - Add mock tests of cancelling bidi streams
1 parent f5d5098 commit fd5453d

6 files changed

Lines changed: 110 additions & 22 deletions

File tree

mock_tests/conftest.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import time
33
from concurrent import futures
4-
from typing import Generator, Mapping
4+
from typing import AsyncGenerator, Generator, Mapping
55

66
import grpc
77
import pytest
@@ -141,6 +141,16 @@ def weaviate_client(
141141
client.close()
142142

143143

144+
@pytest.fixture(scope="function")
145+
async def weaviate_client_async(
146+
weaviate_mock: HTTPServer, start_grpc_server: grpc.Server
147+
) -> AsyncGenerator[weaviate.WeaviateAsyncClient, None]:
148+
client = weaviate.use_async_with_local(port=MOCK_PORT, host=MOCK_IP, grpc_port=MOCK_PORT_GRPC)
149+
await client.connect()
150+
yield client
151+
await client.close()
152+
153+
144154
@pytest.fixture(scope="function")
145155
def weaviate_timeouts_client(
146156
weaviate_timeouts_mock: HTTPServer, start_grpc_server: grpc.Server
@@ -368,3 +378,39 @@ def forbidden(
368378
service = MockForbiddenWeaviateService()
369379
weaviate_pb2_grpc.add_WeaviateServicer_to_server(service, start_grpc_server)
370380
return weaviate_client.collections.use("ForbiddenCollection")
381+
382+
383+
class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
384+
def BatchStream(
385+
self,
386+
request_iterator: Generator[batch_pb2.BatchStreamRequest, None, None],
387+
context: grpc.ServicerContext,
388+
) -> Generator[batch_pb2.BatchStreamReply, None, None]:
389+
while True:
390+
if context.is_active():
391+
time.sleep(0.1)
392+
else:
393+
raise grpc.RpcError(grpc.StatusCode.DEADLINE_EXCEEDED, "Deadline exceeded")
394+
395+
396+
@pytest.fixture(scope="function")
397+
def stream_cancel(
398+
weaviate_client: weaviate.WeaviateClient,
399+
weaviate_mock: HTTPServer,
400+
start_grpc_server: grpc.Server,
401+
) -> Generator[weaviate.collections.Collection, None, None]:
402+
name = "StreamCancelCollection"
403+
weaviate_mock.expect_request(f"/v1/schema/{name}").respond_with_response(
404+
Response(status=404)
405+
) # skips __create_batch_reset vectorizer logic
406+
weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
407+
client = weaviate.connect_to_local(
408+
port=MOCK_PORT,
409+
host=MOCK_IP,
410+
grpc_port=MOCK_PORT_GRPC,
411+
additional_config=weaviate.classes.init.AdditionalConfig(
412+
timeout=weaviate.classes.init.Timeout(insert=1)
413+
),
414+
)
415+
yield client.collections.use(name)
416+
client.close()

mock_tests/test_connect.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import time
2+
import pytest
3+
import weaviate
4+
from weaviate.proto.v1 import batch_pb2
5+
6+
7+
def test_bidi_stream_cancel_sync(stream_cancel: weaviate.collections.Collection):
8+
def gen():
9+
time.sleep(10)
10+
yield batch_pb2.BatchStreamRequest()
11+
12+
out, call = stream_cancel._connection.grpc_batch_stream(gen())
13+
assert call.is_active()
14+
call.cancel()
15+
assert not call.is_active()
16+
with pytest.raises(weaviate.exceptions.WeaviateBatchStreamError) as e:
17+
next(out)
18+
assert "StatusCode.CANCELLED(Locally cancelled by application!)" in e.value.message
19+
20+
21+
def test_batch_stream_hanging_server(stream_cancel: weaviate.collections.Collection):
22+
with pytest.raises(weaviate.exceptions.WeaviateBatchStreamError) as e:
23+
with stream_cancel.batch.stream() as batch:
24+
batch.add_object()
25+
assert "The server did not hangup its side of the stream gracefully in time" in e.value.message

weaviate/collections/batch/async_.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from weaviate.collections.batch.base import (
1616
GCP_STREAM_TIMEOUT,
17-
SHUTDOWN_TIMEOUT,
1817
ObjectsBatchRequest,
1918
ReferencesBatchRequest,
2019
_BatchDataWrapper,
@@ -194,14 +193,16 @@ async def recv_wrapper() -> None:
194193

195194
async def _wait(self) -> None:
196195
assert self.__bg_tasks is not None
197-
deadline = time.time() + SHUTDOWN_TIMEOUT
196+
# this is how long an insert will take to timeout for, so we wait at most this time +5s for the batch to finish after shutdown is initiated, in case the server never hangs up
197+
shutdown_timeout = self.__connection.timeout_config.insert + 5
198+
deadline = time.time() + shutdown_timeout
198199
while time.time() < deadline:
199200
if not self.__bg_tasks.any_alive():
200201
break
201202
await asyncio.sleep(0.1)
202203
if self.__bg_tasks.any_alive():
203204
logger.warning(
204-
f"Background batch tasks did not exit within {SHUTDOWN_TIMEOUT}s. "
205+
f"Background batch tasks did not exit within {shutdown_timeout}s. "
205206
f"Forcing shutdown. inflight_objs={len(self.__inflight_objs)}, "
206207
f"inflight_refs={len(self.__inflight_refs)}, "
207208
f"loop_alive={self.__bg_tasks.loop_alive()}, "
@@ -211,7 +212,7 @@ async def _wait(self) -> None:
211212
self.__bg_tasks.recv.cancel()
212213
self.__bg_tasks.loop.cancel()
213214
try:
214-
await asyncio.wait_for(self.__bg_tasks.gather(), timeout=5)
215+
await asyncio.wait_for(self.__bg_tasks.gather(), timeout=None)
215216
except asyncio.TimeoutError as e:
216217
raise WeaviateBatchStreamError(
217218
"Background batch tasks did not terminate after forced shutdown."

weaviate/collections/batch/base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
GCP_STREAM_TIMEOUT = (
6565
160 # GCP connections have a max lifetime of 180s, leave 20s of buffer as safety
6666
)
67-
SHUTDOWN_TIMEOUT = 300 # time to wait for background threads to exit after shutdown is initiated, in seconds, in the event the server never hangs up
6867

6968

7069
class BatchRequest(ABC, Generic[TBatchInput, TBatchReturn]):
@@ -911,10 +910,10 @@ def recv_alive(self) -> bool:
911910
return self.recv.is_alive()
912911
return True # not started yet so considered alive
913912

914-
def join(self, timeout: Optional[float] = None) -> None:
913+
def join(self) -> None:
915914
"""Join the background threads."""
916-
self.loop.join(timeout=timeout)
917-
self.recv.join(timeout=timeout)
915+
self.loop.join()
916+
self.recv.join()
918917

919918

920919
class _ClusterBatch:

weaviate/collections/batch/sync.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from weaviate.collections.batch.base import (
1212
GCP_STREAM_TIMEOUT,
13-
SHUTDOWN_TIMEOUT,
1413
ObjectsBatchRequest,
1514
ReferencesBatchRequest,
1615
_BatchDataWrapper,
@@ -158,14 +157,16 @@ def _start(self) -> None:
158157
)
159158

160159
def _wait(self) -> None:
161-
deadline = time.time() + SHUTDOWN_TIMEOUT
160+
# this is how long an insert will take to timeout for, so we wait at most this time +5s for the batch to finish after shutdown is initiated, in case the server never hangs up
161+
shutdown_timeout = self.__connection.timeout_config.insert + 5
162+
deadline = time.time() + shutdown_timeout
162163
while time.time() < deadline:
163164
if not self.__any_threads_alive():
164165
break
165166
time.sleep(0.1)
166167
if self.__any_threads_alive():
167168
logger.warning(
168-
f"Background batch threads did not exit within {SHUTDOWN_TIMEOUT}s. "
169+
f"Background batch threads did not exit within {shutdown_timeout}s. "
169170
f"Forcing shutdown. inflight_objs={len(self.__inflight_objs)}, "
170171
f"inflight_refs={len(self.__inflight_refs)}, "
171172
f"loop_alive={self.__bg_threads.loop_alive()}, "
@@ -175,7 +176,7 @@ def _wait(self) -> None:
175176
self.__is_stopped.set()
176177
self.__cancel_active_stream() # force __recv to exit by cancelling the stream
177178

178-
self.__bg_threads.join(timeout=5)
179+
self.__bg_threads.join()
179180
if self.__any_threads_alive():
180181
raise WeaviateBatchStreamError(
181182
"Background batch threads did not terminate after forced shutdown. "
@@ -193,10 +194,29 @@ def _wait(self) -> None:
193194
self.__results_for_wrapper.imported_shards
194195
)
195196

197+
if self.__bg_exception is not None:
198+
if "StatusCode.CANCELLED(Locally cancelled by application!)" in str(
199+
self.__bg_exception
200+
):
201+
raise WeaviateBatchStreamError(
202+
"The server did not hangup its side of the stream gracefully in time"
203+
)
204+
raise self.__bg_exception
205+
196206
def _shutdown(self) -> None:
197207
# Shutdown the current batch and wait for all requests to be finished
198208
self.__is_stopped.set()
199209

210+
def __put(self, req: _BatchStreamRequest | None):
211+
while True:
212+
try:
213+
self.__reqs.put(req, timeout=1)
214+
return True
215+
except Full:
216+
if self.__bg_exception is not None or self.__shutdown_loop.is_set():
217+
return False
218+
return self.__put(req)
219+
200220
def __loop(self) -> None:
201221
refresh_time: float = 0.01
202222
while self.__bg_exception is None and not self.__shutdown_loop.is_set():
@@ -238,21 +258,18 @@ def __loop(self) -> None:
238258
if paused:
239259
logger.info("Server is back up, resuming batching loop...")
240260
paused = False
241-
try:
242-
self.__reqs.put(req, timeout=60)
243-
except Full as e:
244-
logger.warning(
245-
"Batch queue is blocked for more than 60 seconds. Exiting the loop"
246-
)
247-
self.__bg_exception = e
261+
if not self.__put(req):
262+
logger.info("Batch loop is shutting down, stopping putting requests...")
248263
return
249264
elif (
250265
self.__is_stopped.is_set()
251266
and not self.__is_hungup.is_set()
252267
and not self.__is_shutting_down.is_set()
253268
and not self.__is_oom.is_set()
254269
):
255-
self.__reqs.put(None)
270+
if not self.__put(None):
271+
logger.info("Batch loop is shutting down, stopping putting shutdown signal...")
272+
return
256273
time.sleep(refresh_time)
257274

258275
def __generate_stream_requests(

weaviate/connect/v4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ def grpc_batch_stream(
10241024
request_iterator=requests,
10251025
timeout=self.timeout_config.stream,
10261026
metadata=self.grpc_headers(),
1027-
)()
1027+
)
10281028

10291029
def generator():
10301030
try:

0 commit comments

Comments
 (0)