Skip to content

Commit df5014e

Browse files
committed
Remove redundant stream cancellation logic, align async with sync impl
1 parent b5ddfb2 commit df5014e

6 files changed

Lines changed: 57 additions & 162 deletions

File tree

.github/workflows/main.yaml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,13 @@ env:
2222
WEAVIATE_128: 1.28.16
2323
WEAVIATE_129: 1.29.11
2424
WEAVIATE_130: 1.30.22
25-
WEAVIATE_131: 1.31.20
26-
WEAVIATE_132: 1.32.23
27-
WEAVIATE_133: 1.33.10
28-
WEAVIATE_134: 1.34.5
29-
WEAVIATE_135: 1.35.16-efdedfa
30-
WEAVIATE_136: 1.36.9-d905e6c
31-
WEAVIATE_137: 1.37.0-rc.0-b313954.amd64
32-
25+
WEAVIATE_131: 1.31.22
26+
WEAVIATE_132: 1.32.27
27+
WEAVIATE_133: 1.33.18
28+
WEAVIATE_134: 1.34.19
29+
WEAVIATE_135: 1.35.15
30+
WEAVIATE_136: 1.36.6-8edcf08.amd64
31+
WEAVIATE_137: 1.37.0-dev-29d5c87.amd64
3332

3433
jobs:
3534
lint-and-format:

mock_tests/test_connect.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

weaviate/collections/batch/async_.py

Lines changed: 20 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def __init__(
115115
self.__oom_wait_time = 300
116116

117117
self.__shutdown_loop = asyncio.Event()
118-
self.__sent_sentinel = asyncio.Event()
119118

120119
self.__objs_cache_lock = asyncio.Lock()
121120
self.__objs_cache: dict[str, BatchObject] = {}
@@ -195,34 +194,12 @@ async def _wait(self) -> None:
195194
assert self.__bg_tasks is not None
196195
# Wait at most the insert timeout plus 5s for the batch to finish after shutdown is initiated, in case the server never hangs up
197196
shutdown_timeout = self.__connection.timeout_config.insert + 5
198-
deadline = time.time() + shutdown_timeout
199-
while time.time() < deadline:
200-
if not self.__bg_tasks.any_alive():
201-
break
202-
await asyncio.sleep(0.1)
203-
if self.__bg_tasks.any_alive():
204-
logger.warning(
205-
f"Background batch tasks did not exit within {shutdown_timeout}s. "
206-
f"Forcing shutdown. inflight_objs={len(self.__inflight_objs)}, "
207-
f"inflight_refs={len(self.__inflight_refs)}, "
208-
f"loop_alive={self.__bg_tasks.loop_alive()}, "
209-
f"recv_alive={self.__bg_tasks.recv_alive()}"
210-
)
211-
self.__shutdown_loop.set() # force __loop to exit
212-
self.__bg_tasks.recv.cancel()
213-
self.__bg_tasks.loop.cancel()
214197
try:
215-
await asyncio.wait_for(self.__bg_tasks.gather(), timeout=None)
198+
await asyncio.wait_for(self.__bg_tasks.gather(), timeout=shutdown_timeout)
216199
except asyncio.TimeoutError as e:
217200
raise WeaviateBatchStreamError(
218201
"Background batch tasks did not terminate after forced shutdown."
219202
) from e
220-
if self.__bg_tasks.any_alive():
221-
raise WeaviateBatchStreamError(
222-
"Background batch tasks did not terminate after forced shutdown. "
223-
f"loop_alive={self.__bg_tasks.loop_alive()}, "
224-
f"recv_alive={self.__bg_tasks.recv_alive()}"
225-
)
226203

227204
# copy the results to the public results
228205
self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results
@@ -237,6 +214,15 @@ async def _wait(self) -> None:
237214
async def _shutdown(self) -> None:
238215
self.__is_stopped.set()
239216

217+
async def __put(self, req: _BatchStreamRequest | None) -> bool:
    """Enqueue ``req`` on the request queue, retrying until it fits or the batch stops.

    Each attempt waits at most one second for queue space so that the
    background-exception and shutdown flags are re-checked regularly instead
    of blocking indefinitely on a full queue.

    Args:
        req: The stream request to enqueue, or ``None`` (the sentinel the
            batching loop puts on the queue once the batch is stopped).

    Returns:
        ``True`` once the item was enqueued; ``False`` if a background
        exception was recorded or the shutdown event was set while waiting.
    """
    # Retry in a loop rather than recursing: a queue that stays full for a
    # long time would otherwise nest one more coroutine frame every second.
    while True:
        try:
            await asyncio.wait_for(self.__reqs.put(req), timeout=1)
        except asyncio.TimeoutError:
            if self.__bg_exception is not None or self.__shutdown_loop.is_set():
                return False
        else:
            return True
225+
240226
async def __loop(self) -> None:
241227
refresh_time: float = 0.01
242228
while self.__bg_exception is None and not self.__shutdown_loop.is_set():
@@ -278,23 +264,18 @@ async def __loop(self) -> None:
278264
if paused:
279265
logger.info("Server is back up, resuming batching loop...")
280266
paused = False
281-
try:
282-
await asyncio.wait_for(self.__reqs.put(req), timeout=60)
283-
except asyncio.TimeoutError as e:
284-
logger.warning(
285-
"Batch queue is blocked for more than 60 seconds. Exiting the loop"
286-
)
287-
self.__bg_exception = e
267+
if not self.__put(req):
268+
logger.info("Batch loop is shutting down, stopping putting new requests...")
288269
return
289270
elif (
290271
self.__is_stopped.is_set()
291-
and not self.__sent_sentinel.is_set()
292272
and not self.__is_hungup.is_set()
293273
and not self.__is_shutting_down.is_set()
294274
and not self.__is_oom.is_set()
295275
):
296-
await self.__reqs.put(None)
297-
self.__sent_sentinel.set()
276+
await self.__put(None)
277+
logger.info("Sent sentinel, stopping batch loop...")
278+
return
298279
await asyncio.sleep(refresh_time)
299280

300281
def __generate_stream_requests(
@@ -347,10 +328,7 @@ def request_maker():
347328
if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0:
348329
yield _BatchStreamRequest(request, uuids, beacons)
349330

350-
async def __send(
351-
self,
352-
) -> AsyncGenerator[batch_pb2.BatchStreamRequest, None]:
353-
self.__sent_sentinel.clear()
331+
async def __send(self) -> AsyncGenerator[batch_pb2.BatchStreamRequest, None]:
354332
yield batch_pb2.BatchStreamRequest(
355333
start=batch_pb2.BatchStreamRequest.Start(
356334
consistency_level=self.__batch_grpc._consistency_level,
@@ -393,14 +371,13 @@ async def __send(
393371
logger.info("Batch send thread exiting due to exception...")
394372

395373
async def __recv(self) -> None:
396-
stream = self.__batch_grpc.astream(
397-
connection=self.__connection,
398-
requests=self.__send(),
399-
)
400374
self.__is_renewing_stream.clear()
401375
self.__is_shutting_down.clear()
402376
self.__is_hungup.clear()
403-
async for message in stream:
377+
async for message in self.__batch_grpc.astream(
378+
connection=self.__connection,
379+
requests=self.__send(),
380+
):
404381
if message.HasField("started"):
405382
logger.info("Batch stream started successfully")
406383

weaviate/collections/batch/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -910,10 +910,10 @@ def recv_alive(self) -> bool:
910910
return self.recv.is_alive()
911911
return True # not started yet so considered alive
912912

913-
def join(self) -> None:
913+
def join(self, timeout: float | None = None) -> None:
914914
"""Join the background threads."""
915-
self.loop.join()
916-
self.recv.join()
915+
self.loop.join(timeout=timeout)
916+
self.recv.join(timeout=timeout)
917917

918918

919919
class _ClusterBatch:

weaviate/collections/batch/sync.py

Lines changed: 12 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def __init__(
6666

6767
self.__connection = connection
6868
self.__is_gcp_on_wcd = connection._connection_params.is_gcp_on_wcd()
69-
self.__stream_start: Optional[float] = None
7069
self.__is_renewing_stream = threading.Event()
7170
self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM
7271
self.__batch_size = 100
@@ -124,26 +123,6 @@ def number_errors(self) -> int:
124123
def __all_threads_alive(self) -> bool:
125124
return self.__bg_threads.is_alive()
126125

127-
def __any_threads_alive(self) -> bool:
128-
return self.__bg_threads.any_alive()
129-
130-
def __set_active_stream(self, call: Call) -> None:
131-
with self.__stream_lock:
132-
self.__active_stream = call
133-
134-
def __clear_active_stream(self) -> None:
135-
with self.__stream_lock:
136-
self.__active_stream = None
137-
138-
def __cancel_active_stream(self) -> bool:
139-
with self.__stream_lock:
140-
stream = self.__active_stream
141-
142-
if stream is None:
143-
return False
144-
145-
return stream.cancel()
146-
147126
def _start(self) -> None:
148127
self.__start_bg_threads()
149128
logger.info("Provisioned stream to the server for batch processing")
@@ -159,30 +138,12 @@ def _start(self) -> None:
159138
def _wait(self) -> None:
160139
# Wait at most the insert timeout plus 5s for the batch to finish after shutdown is initiated, in case the server never hangs up
161140
shutdown_timeout = self.__connection.timeout_config.insert + 5
162-
deadline = time.time() + shutdown_timeout
163-
while time.time() < deadline:
164-
if not self.__any_threads_alive():
165-
break
166-
time.sleep(0.1)
167-
if self.__any_threads_alive():
168-
logger.warning(
169-
f"Background batch threads did not exit within {shutdown_timeout}s. "
170-
f"Forcing shutdown. inflight_objs={len(self.__inflight_objs)}, "
171-
f"inflight_refs={len(self.__inflight_refs)}, "
172-
f"loop_alive={self.__bg_threads.loop_alive()}, "
173-
f"recv_alive={self.__bg_threads.recv_alive()}"
174-
)
175-
self.__shutdown_loop.set() # force __loop to exit
176-
self.__is_stopped.set()
177-
self.__cancel_active_stream() # force __recv to exit by cancelling the stream
178-
179-
self.__bg_threads.join()
180-
if self.__any_threads_alive():
141+
try:
142+
self.__bg_threads.join(shutdown_timeout)
143+
except TimeoutError as e:
181144
raise WeaviateBatchStreamError(
182-
"Background batch threads did not terminate after forced shutdown. "
183-
f"loop_alive={self.__bg_threads.loop_alive()}, "
184-
f"recv_alive={self.__bg_threads.recv_alive()}"
185-
)
145+
"Background batch threads did not terminate after forced shutdown."
146+
) from e
186147

187148
# copy the results to the public results
188149
self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results
@@ -194,15 +155,6 @@ def _wait(self) -> None:
194155
self.__results_for_wrapper.imported_shards
195156
)
196157

197-
if self.__bg_exception is not None:
198-
if "StatusCode.CANCELLED(Locally cancelled by application!)" in str(
199-
self.__bg_exception
200-
):
201-
raise WeaviateBatchStreamError(
202-
"The server did not hangup its side of the stream gracefully in time"
203-
)
204-
raise self.__bg_exception
205-
206158
def _shutdown(self) -> None:
207159
# Shutdown the current batch and wait for all requests to be finished
208160
self.__is_stopped.set()
@@ -267,9 +219,9 @@ def __loop(self) -> None:
267219
and not self.__is_shutting_down.is_set()
268220
and not self.__is_oom.is_set()
269221
):
270-
if not self.__put(None):
271-
logger.info("Batch loop is shutting down, stopping putting shutdown signal...")
272-
return
222+
self.__put(None)
223+
logger.info("Sent sentinel, stopping batch loop...")
224+
return
273225
time.sleep(refresh_time)
274226

275227
def __generate_stream_requests(
@@ -370,16 +322,13 @@ def __send(
370322
logger.info("Batch send thread exiting due to exception...")
371323

372324
def __recv(self) -> None:
373-
gen, call = self.__batch_grpc.stream(
374-
connection=self.__connection,
375-
requests=self.__send(),
376-
)
377-
self.__set_active_stream(call)
378-
379325
self.__is_renewing_stream.clear()
380326
self.__is_shutting_down.clear()
381327
self.__is_hungup.clear()
382-
for message in gen:
328+
for message in self.__batch_grpc.stream(
329+
connection=self.__connection,
330+
requests=self.__send(),
331+
):
383332
if message.HasField("started"):
384333
logger.info("Batch stream started successfully")
385334

weaviate/connect/v4.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,27 +1018,22 @@ def grpc_batch_objects(
10181018
def grpc_batch_stream(
10191019
self,
10201020
requests: Generator[batch_pb2.BatchStreamRequest, None, None],
1021-
) -> tuple[Generator[batch_pb2.BatchStreamReply, None, None], Call]:
1021+
) -> Generator[batch_pb2.BatchStreamReply, None, None]:
10221022
assert self.grpc_stub is not None
1023-
call = self.grpc_stub.BatchStream(
1024-
request_iterator=requests,
1025-
timeout=self.timeout_config.stream,
1026-
metadata=self.grpc_headers(),
1027-
)
1028-
1029-
def generator():
1030-
try:
1031-
for msg in call:
1032-
yield msg
1033-
except RpcError as e:
1034-
error = cast(Call, e)
1035-
if error.code() == StatusCode.PERMISSION_DENIED:
1036-
raise InsufficientPermissionsError(error)
1037-
if error.code() == StatusCode.ABORTED:
1038-
raise _BatchStreamShutdownError()
1039-
raise WeaviateBatchStreamError(f"{error.code()}({error.details()})")
1040-
1041-
return generator(), call
1023+
try:
1024+
for msg in self.grpc_stub.BatchStream(
1025+
request_iterator=requests,
1026+
timeout=self.timeout_config.stream,
1027+
metadata=self.grpc_headers(),
1028+
):
1029+
yield msg
1030+
except RpcError as e:
1031+
error = cast(Call, e)
1032+
if error.code() == StatusCode.PERMISSION_DENIED:
1033+
raise InsufficientPermissionsError(error)
1034+
if error.code() == StatusCode.ABORTED:
1035+
raise _BatchStreamShutdownError()
1036+
raise WeaviateBatchStreamError(f"{error.code()}({error.details()})")
10421037

10431038
def grpc_batch_delete(
10441039
self, request: batch_delete_pb2.BatchDeleteRequest

0 commit comments

Comments
 (0)