@@ -798,7 +798,6 @@ async def batch_add_requests(
798798 if min_delay_between_unprocessed_requests_retries :
799799 logger .warning ('`min_delay_between_unprocessed_requests_retries` is deprecated and not used anymore.' )
800800
801- tasks = set [asyncio .Task ]()
802801 asyncio_queue : asyncio .Queue [Iterable [dict ]] = asyncio .Queue ()
803802 request_params = self ._build_params (clientKey = self .client_key , forefront = forefront )
804803
@@ -815,29 +814,31 @@ async def batch_add_requests(
815814 for batch in batches :
816815 await asyncio_queue .put (batch )
817816
818- # Start a required number of worker tasks to process the batches.
819- for i in range (max_parallel ):
820- coro = self ._batch_add_requests_worker (
821- asyncio_queue ,
822- request_params ,
823- )
824- task = asyncio .create_task (coro , name = f'batch_add_requests_worker_{ i } ' )
825- tasks .add (task )
826-
827- # Wait for all batches to be processed.
828- await asyncio_queue .join ()
829-
830- # Send cancellation signals to all worker tasks and wait for them to finish.
831- for task in tasks :
832- task .cancel ()
833-
834- results : list [BatchAddResponse ] = await asyncio .gather (* tasks )
817+ # Use TaskGroup for structured concurrency — automatic cleanup and error propagation.
818+ try :
819+ async with asyncio .TaskGroup () as tg :
820+ workers = [
821+ tg .create_task (
822+ self ._batch_add_requests_worker (asyncio_queue , request_params ),
823+ name = f'batch_add_requests_worker_{ i } ' ,
824+ )
825+ for i in range (max_parallel )
826+ ]
827+
828+ # Wait for all batches to be processed, then cancel idle workers.
829+ await asyncio_queue .join ()
830+ for worker in workers :
831+ worker .cancel ()
832+ except ExceptionGroup as eg :
833+ # Re-raise the first worker exception directly to maintain backward-compatible error types.
834+ raise eg .exceptions [0 ] from None
835835
836836 # Combine the results from all workers and return them.
837837 processed_requests = list [AddedRequest ]()
838838 unprocessed_requests = list [RequestDraft ]()
839839
840- for result in results :
840+ for worker in workers :
841+ result = worker .result ()
841842 processed_requests .extend (result .data .processed_requests )
842843 unprocessed_requests .extend (result .data .unprocessed_requests )
843844
0 commit comments