1- from typing import Any , AsyncGenerator , Dict , Optional , TypeVar
1+ from asyncio import AbstractEventLoop
2+ from typing import Any , AsyncGenerator , Optional , TypeVar
3+
4+ from aio_pika import Channel , ExchangeType , Message , connect_robust
5+ from aio_pika .abc import AbstractChannel , AbstractRobustConnection
6+ from aio_pika .pool import Pool
27from taskiq .abc .broker import AsyncBroker
38from taskiq .abc .result_backend import AsyncResultBackend
4- from taskiq .message import TaskiqMessage
5- from aio_pika .abc import AbstractRobustConnection
6- from aio_pika .pool import Pool
7- from asyncio import AbstractEventLoop
8- from aio_pika import connect_robust , Channel , Message , ExchangeType
9+ from taskiq .message import BrokerMessage
910
1011_T = TypeVar ("_T" )
1112
@@ -35,7 +36,7 @@ async def _get_rmq_connection() -> AbstractRobustConnection:
3536 loop = loop ,
3637 )
3738
38- async def get_channel () -> Channel :
39+ async def get_channel () -> AbstractChannel :
3940 async with self .connection_pool .acquire () as connection :
4041 return await connection .channel ()
4142
@@ -62,30 +63,30 @@ async def startup(self) -> None:
6263 queue = await channel .declare_queue (self .queue_name )
6364 await queue .bind (exchange = exchange , routing_key = "*" )
6465
65- async def kick (self , message : TaskiqMessage ) -> None :
66+ async def kick (self , message : BrokerMessage ) -> None :
6667 rmq_msg = Message (
67- body = message .json ().encode (),
68- content_type = "application/json" ,
68+ body = message .message .encode (),
6969 headers = {
7070 "task_id" : message .task_id ,
7171 "task_name" : message .task_name ,
72+ ** message .headers ,
7273 },
7374 )
7475 async with self .channel_pool .acquire () as channel :
7576 exchange = await channel .get_exchange (self .exchange_name , ensure = False )
7677 await exchange .publish (rmq_msg , routing_key = message .task_id )
7778
78- async def listen (self ) -> AsyncGenerator [TaskiqMessage , None ]:
79+ async def listen (self ) -> AsyncGenerator [BrokerMessage , None ]:
7980 async with self .channel_pool .acquire () as channel :
8081 await channel .set_qos (prefetch_count = self .qos )
8182 queue = await channel .get_queue (self .queue_name , ensure = False )
8283 async with queue .iterator () as queue_iter :
8384 async for rmq_message in queue_iter :
8485 async with rmq_message .process ():
8586 try :
86- yield TaskiqMessage .parse_raw (
87+ yield BrokerMessage .parse_raw (
8788 rmq_message .body ,
88- content_type = rmq_message .content_type ,
89+ content_type = rmq_message .content_type or "" ,
8990 )
9091 except ValueError :
9192 continue
0 commit comments