11from asyncio import AbstractEventLoop
22from logging import getLogger
3- from typing import Any , AsyncGenerator , Optional , TypeVar
3+ from typing import Any , AsyncGenerator , Callable , Optional , TypeVar
44
55from aio_pika import Channel , ExchangeType , Message , connect_robust
66from aio_pika .abc import AbstractChannel , AbstractRobustConnection
@@ -18,6 +18,7 @@ class AioPikaBroker(AsyncBroker):
1818 def __init__ (
1919 self ,
2020 result_backend : Optional [AsyncResultBackend [_T ]] = None ,
21+ task_id_generator : Optional [Callable [[], str ]] = None ,
2122 qos : int = 10 ,
2223 loop : Optional [AbstractEventLoop ] = None ,
2324 max_channel_pool_size : int = 2 ,
@@ -28,7 +29,7 @@ def __init__(
2829 * connection_args : Any ,
2930 ** connection_kwargs : Any ,
3031 ) -> None :
31- super ().__init__ (result_backend )
32+ super ().__init__ (result_backend , task_id_generator )
3233
3334 async def _get_rmq_connection () -> AbstractRobustConnection :
3435 return await connect_robust (* connection_args , ** connection_kwargs )
@@ -72,7 +73,7 @@ async def kick(self, message: BrokerMessage) -> None:
7273 headers = {
7374 "task_id" : message .task_id ,
7475 "task_name" : message .task_name ,
75- ** message .headers ,
76+ ** message .labels ,
7677 },
7778 )
7879 async with self .channel_pool .acquire () as channel :
@@ -88,10 +89,10 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
8889 async with rmq_message .process ():
8990 try :
9091 yield BrokerMessage (
91- task_id = rmq_message .headers [ "task_id" ] ,
92- task_name = rmq_message .headers [ "task_name" ] ,
92+ task_id = rmq_message .headers . pop ( "task_id" ) ,
93+ task_name = rmq_message .headers . pop ( "task_name" ) ,
9394 message = rmq_message .body ,
94- headers = rmq_message .headers ,
95+ labels = rmq_message .headers ,
9596 )
9697 except (ValueError , LookupError ) as exc :
9798 logger .debug (
@@ -101,4 +102,5 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
101102 )
102103
103104 async def shutdown (self ) -> None :
105+ await super ().shutdown ()
104106 await self .connection_pool .close ()
0 commit comments