@@ -110,6 +110,7 @@ def __init__(
110110 self ._max_ack_delay = max_ack_delay
111111 self ._stopping_grace = stopping_grace
112112 self ._has_requested_options = False
113+ self ._num_constructed_acks = 0
113114 self .subscription_id = b""
114115
115116 def _construct_initial_read_req (self ) -> persistent_pb2 .ReadReq :
@@ -130,10 +131,10 @@ def _construct_initial_read_req(self) -> persistent_pb2.ReadReq:
130131 options .all .CopyFrom (shared_pb2 .Empty ())
131132 return persistent_pb2 .ReadReq (options = options )
132133
133- @staticmethod
134134 def _construct_ack_or_nack_read_req (
135- subscription_id : bytes , event_ids : list [UUID ], action : str
135+ self , subscription_id : bytes , event_ids : list [UUID ], action : str
136136 ) -> persistent_pb2 .ReadReq :
137+ self ._num_constructed_acks = len (event_ids )
137138 ids = [shared_pb2 .UUID (string = str (event_id )) for event_id in event_ids ]
138139 if action == "ack" :
139140 read_req = persistent_pb2 .ReadReq (
@@ -210,17 +211,19 @@ async def __anext__(self) -> persistent_pb2.ReadReq:
210211 # First return read request with options, then return read request
211212 # with batch of n/acks whenever the batch is full, or when the n/ack
212213 # actions changes, or periodically, or when stopping.
213-
214214 if not self ._has_requested_options :
215215 # Return initial read request with options.
216216 self ._has_requested_options = True
217217 return self ._construct_initial_read_req ()
218218
219+ # Account on queue for previously returned n/acks.
220+ for _ in range (self ._num_constructed_acks ):
221+ self ._ack_queue .task_done ()
222+ self ._num_constructed_acks = 0
223+
219224 # Return read request with a batch of n/acks...
220225
221226 # Initialise batch, maybe from held n/ack.
222- for _ in self ._batch_ids :
223- self ._ack_queue .task_done ()
224227 self ._batch_ids = []
225228 batch_action : str | None = None
226229 if self ._ack_held is not None :
@@ -234,7 +237,6 @@ async def __anext__(self) -> persistent_pb2.ReadReq:
234237 if self ._is_stopping :
235238 # Allow time for server to process last n/acks.
236239 await asyncio .sleep (self ._stopping_grace )
237- self ._ack_queue .task_done ()
238240 raise StopAsyncIteration from None
239241
240242 try :
@@ -316,12 +318,11 @@ async def nack(
316318 assert action in ["unknown" , "park" , "retry" , "skip" , "stop" ]
317319 await self ._ack_queue .put ((event_id , action ))
318320
319- async def stop (self , * , wait_until_stopped : bool = True ) -> None :
321+ async def stop (self , * , timeout : float | None = None ) -> None :
320322 if not self ._is_poisoned :
321323 self ._is_poisoned = True
322324 await self ._ack_queue .put ((None , "poison" ))
323- if wait_until_stopped :
324- await self ._is_stopped .wait ()
325+ await asyncio .wait_for (self ._is_stopped .wait (), timeout )
325326
326327
327328class SubscriptionReadReqs (BaseSubscriptionReadReqs ):
@@ -364,6 +365,11 @@ def __next__(self) -> persistent_pb2.ReadReq:
364365 self ._has_requested_options = True
365366 return self ._construct_initial_read_req ()
366367
368+ # Account on queue for previously returned n/acks.
369+ for _ in range (self ._num_constructed_acks ):
370+ self ._ack_queue .task_done ()
371+ self ._num_constructed_acks = 0
372+
367373 # Send a batch of n/acks...
368374
369375 # Initialise batch, maybe from held n/ack.
@@ -387,7 +393,6 @@ def __next__(self) -> persistent_pb2.ReadReq:
387393 # Wait for next n/ack, timeout with "max ack delay".
388394 get_timeout = max (0.0 , self ._calc_time_until_next_ack_batch ())
389395 event_id , action = self ._ack_queue .get (timeout = get_timeout )
390- self ._ack_queue .task_done ()
391396
392397 # If queue was poisoned, send non-empty batch now.
393398 if action == "poison" :
@@ -572,7 +577,7 @@ async def init(self) -> None:
572577 msg = f"Expected subscription confirmation, got: { first_read_resp } "
573578 raise SubscriptionConfirmationError (msg )
574579 except BaseException :
575- await self .stop (wait_until_stopped = False )
580+ await self .stop (timeout = 1 )
576581 raise
577582
578583 @property
@@ -600,26 +605,27 @@ async def __anext__(self) -> RecordedEvent:
600605 else : # pragma: no cover
601606 pass
602607 except BaseException :
603- await self .stop (wait_until_stopped = False )
608+ await self .stop (timeout = 1 )
604609 raise
605610
606- async def stop (self , * , wait_until_stopped : bool = True ) -> None :
611+ async def stop (self , * , timeout : float | None = None ) -> None :
607612 if self ._is_context_manager_active :
608613 self ._is_stopping = True
609614 elif not await self ._set_is_stopped ():
610- await self ._read_reqs .stop (wait_until_stopped = wait_until_stopped )
611- self ._read_resp_stream_worker_task .cancel ()
612- self ._stream_stream_call .cancel ()
613- await asyncio .sleep (0.05 )
614- self ._grpc_streamers .remove (self )
615615 try :
616- await self ._read_resp_stream_worker_task
617- except asyncio .CancelledError :
618- pass
619- except grpc .RpcError as e :
620- raise handle_rpc_error (e ) from e
621- if self ._read_reqs .errored :
622- raise self ._read_reqs .errored
616+ await self ._read_reqs .stop (timeout = timeout )
617+ finally :
618+ self ._read_resp_stream_worker_task .cancel ()
619+ self ._stream_stream_call .cancel ()
620+ self ._grpc_streamers .remove (self )
621+ try :
622+ await self ._read_resp_stream_worker_task
623+ except asyncio .CancelledError :
624+ pass
625+ except grpc .RpcError as e :
626+ raise handle_rpc_error (e ) from e
627+ if self ._read_reqs .errored :
628+ raise self ._read_reqs .errored
623629
624630 async def ack (self , item : UUID | RecordedEvent ) -> None :
625631 await self ._read_reqs .ack (event_id = self ._get_event_id (item ))
0 commit comments