@@ -115,7 +115,6 @@ def __init__(
115115 self .__oom_wait_time = 300
116116
117117 self .__shutdown_loop = asyncio .Event ()
118- self .__sent_sentinel = asyncio .Event ()
119118
120119 self .__objs_cache_lock = asyncio .Lock ()
121120 self .__objs_cache : dict [str , BatchObject ] = {}
@@ -195,34 +194,12 @@ async def _wait(self) -> None:
195194 assert self .__bg_tasks is not None
196195 # this is how long an insert will take to timeout for, so we wait at most this time +5s for the batch to finish after shutdown is initiated, in case the server never hangs up
197196 shutdown_timeout = self .__connection .timeout_config .insert + 5
198- deadline = time .time () + shutdown_timeout
199- while time .time () < deadline :
200- if not self .__bg_tasks .any_alive ():
201- break
202- await asyncio .sleep (0.1 )
203- if self .__bg_tasks .any_alive ():
204- logger .warning (
205- f"Background batch tasks did not exit within { shutdown_timeout } s. "
206- f"Forcing shutdown. inflight_objs={ len (self .__inflight_objs )} , "
207- f"inflight_refs={ len (self .__inflight_refs )} , "
208- f"loop_alive={ self .__bg_tasks .loop_alive ()} , "
209- f"recv_alive={ self .__bg_tasks .recv_alive ()} "
210- )
211- self .__shutdown_loop .set () # force __loop to exit
212- self .__bg_tasks .recv .cancel ()
213- self .__bg_tasks .loop .cancel ()
214197 try :
215- await asyncio .wait_for (self .__bg_tasks .gather (), timeout = None )
198+ await asyncio .wait_for (self .__bg_tasks .gather (), timeout = shutdown_timeout )
216199 except asyncio .TimeoutError as e :
217200 raise WeaviateBatchStreamError (
218201 "Background batch tasks did not terminate after forced shutdown."
219202 ) from e
220- if self .__bg_tasks .any_alive ():
221- raise WeaviateBatchStreamError (
222- "Background batch tasks did not terminate after forced shutdown. "
223- f"loop_alive={ self .__bg_tasks .loop_alive ()} , "
224- f"recv_alive={ self .__bg_tasks .recv_alive ()} "
225- )
226203
227204 # copy the results to the public results
228205 self .__results_for_wrapper_backup .results = self .__results_for_wrapper .results
@@ -237,6 +214,15 @@ async def _wait(self) -> None:
async def _shutdown(self) -> None:
    """Mark the batch as stopped by setting the internal ``__is_stopped`` event."""
    stopped_flag = self.__is_stopped
    stopped_flag.set()
239216
async def __put(
    self, req: _BatchStreamRequest | None, *, poll_interval: float = 1
) -> bool:
    """Enqueue ``req`` on the request queue, giving up once the batch is stopping.

    The put is attempted in slices of ``poll_interval`` seconds. After each
    slice that times out, the background-exception slot and the shutdown
    event are checked so that a full queue cannot block the caller forever.

    This is a coroutine: callers must ``await`` it. A bare call returns a
    coroutine object, which is always truthy and never enqueues anything.

    Args:
        req: The stream request to enqueue, or ``None`` (the batching loop
            puts ``None`` as its end-of-stream sentinel).
        poll_interval: Seconds to wait on the queue before re-checking the
            stop conditions. Defaults to 1.

    Returns:
        ``True`` once ``req`` was placed on the queue; ``False`` if a put
        attempt timed out while a background exception was recorded or the
        shutdown event was set.
    """
    # Retry in a loop rather than by recursion: a recursive retry nests one
    # coroutine frame per timeout and eventually raises RecursionError when
    # the queue stays full for a long time.
    while True:
        try:
            await asyncio.wait_for(self.__reqs.put(req), timeout=poll_interval)
        except asyncio.TimeoutError:
            if self.__bg_exception is not None or self.__shutdown_loop.is_set():
                return False
        else:
            return True
225+
240226 async def __loop (self ) -> None :
241227 refresh_time : float = 0.01
242228 while self .__bg_exception is None and not self .__shutdown_loop .is_set ():
@@ -278,23 +264,18 @@ async def __loop(self) -> None:
278264 if paused :
279265 logger .info ("Server is back up, resuming batching loop..." )
280266 paused = False
281- try :
282- await asyncio .wait_for (self .__reqs .put (req ), timeout = 60 )
283- except asyncio .TimeoutError as e :
284- logger .warning (
285- "Batch queue is blocked for more than 60 seconds. Exiting the loop"
286- )
287- self .__bg_exception = e
267+ if not self .__put (req ):
268+ logger .info ("Batch loop is shutting down, stopping putting new requests..." )
288269 return
289270 elif (
290271 self .__is_stopped .is_set ()
291- and not self .__sent_sentinel .is_set ()
292272 and not self .__is_hungup .is_set ()
293273 and not self .__is_shutting_down .is_set ()
294274 and not self .__is_oom .is_set ()
295275 ):
296- await self .__reqs .put (None )
297- self .__sent_sentinel .set ()
276+ await self .__put (None )
277+ logger .info ("Sent sentinel, stopping batch loop..." )
278+ return
298279 await asyncio .sleep (refresh_time )
299280
300281 def __generate_stream_requests (
@@ -347,10 +328,7 @@ def request_maker():
347328 if len (request .data .objects .values ) > 0 or len (request .data .references .values ) > 0 :
348329 yield _BatchStreamRequest (request , uuids , beacons )
349330
350- async def __send (
351- self ,
352- ) -> AsyncGenerator [batch_pb2 .BatchStreamRequest , None ]:
353- self .__sent_sentinel .clear ()
331+ async def __send (self ) -> AsyncGenerator [batch_pb2 .BatchStreamRequest , None ]:
354332 yield batch_pb2 .BatchStreamRequest (
355333 start = batch_pb2 .BatchStreamRequest .Start (
356334 consistency_level = self .__batch_grpc ._consistency_level ,
@@ -393,14 +371,13 @@ async def __send(
393371 logger .info ("Batch send thread exiting due to exception..." )
394372
395373 async def __recv (self ) -> None :
396- stream = self .__batch_grpc .astream (
397- connection = self .__connection ,
398- requests = self .__send (),
399- )
400374 self .__is_renewing_stream .clear ()
401375 self .__is_shutting_down .clear ()
402376 self .__is_hungup .clear ()
403- async for message in stream :
377+ async for message in self .__batch_grpc .astream (
378+ connection = self .__connection ,
379+ requests = self .__send (),
380+ ):
404381 if message .HasField ("started" ):
405382 logger .info ("Batch stream started successfully" )
406383
0 commit comments