11import asyncio
22import inspect
3- from concurrent .futures import ThreadPoolExecutor
3+ from concurrent .futures import Executor
44from logging import getLogger
55from time import time
6- from typing import Any , Callable , Dict , get_type_hints
6+ from typing import Any , Callable , Dict , Optional , get_type_hints
77
88from taskiq_dependencies import DependencyGraph
99
1010from taskiq .abc .broker import AsyncBroker
1111from taskiq .abc .middleware import TaskiqMiddleware
12- from taskiq .cli .worker .args import WorkerArgs
13- from taskiq .cli .worker .params_parser import parse_params
1412from taskiq .context import Context
1513from taskiq .message import BrokerMessage , TaskiqMessage
14+ from taskiq .receiver .params_parser import parse_params
1615from taskiq .result import TaskiqResult
1716from taskiq .state import TaskiqState
1817from taskiq .utils import maybe_awaitable
@@ -37,20 +36,24 @@ def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
3736class Receiver :
3837 """Class that uses as a callback handler."""
3938
40- def __init__ (self , broker : AsyncBroker , cli_args : WorkerArgs ) -> None :
39+ def __init__ (
40+ self ,
41+ broker : AsyncBroker ,
42+ executor : Optional [Executor ] = None ,
43+ validate_params : bool = True ,
44+ max_async_tasks : int = 20 ,
45+ ) -> None :
4146 self .broker = broker
42- self .cli_args = cli_args
47+ self .executor = executor
48+ self .validate_params = validate_params
4349 self .task_signatures : Dict [str , inspect .Signature ] = {}
4450 self .task_hints : Dict [str , Dict [str , Any ]] = {}
4551 self .dependency_graphs : Dict [str , DependencyGraph ] = {}
4652 for task in self .broker .available_tasks .values ():
4753 self .task_signatures [task .task_name ] = inspect .signature (task .original_func )
4854 self .task_hints [task .task_name ] = get_type_hints (task .original_func )
4955 self .dependency_graphs [task .task_name ] = DependencyGraph (task .original_func )
50- self .executor = ThreadPoolExecutor (
51- max_workers = cli_args .max_threadpool_threads ,
52- )
53- self .sem = asyncio .Semaphore (cli_args .max_async_tasks )
56+ self .sem = asyncio .Semaphore (max_async_tasks )
5457
5558 async def callback ( # noqa: C901, WPS213
5659 self ,
@@ -152,10 +155,10 @@ async def run_task( # noqa: C901, WPS210
152155 loop = asyncio .get_running_loop ()
153156 returned = None
154157 found_exception = None
155- signature = self .task_signatures .get (message .task_name )
158+ signature = None
159+ if self .validate_params :
160+ signature = self .task_signatures .get (message .task_name )
156161 dependency_graph = self .dependency_graphs .get (message .task_name )
157- if self .cli_args .no_parse :
158- signature = None
159162 parse_params (signature , self .task_hints .get (message .task_name ) or {}, message )
160163
161164 dep_ctx = None
@@ -221,3 +224,25 @@ async def run_task( # noqa: C901, WPS210
221224 )
222225
223226 return result
227+
228+ async def listen (self ) -> None : # pragma: no cover
229+ """
230+ This function iterates over tasks asynchronously.
231+
232+ It uses listen() method of an AsyncBroker
233+ to get new messages from queues.
234+ """
235+ logger .debug ("Runing startup event." )
236+ await self .broker .startup ()
237+ logger .info ("Listening started." )
238+ tasks = set ()
239+ async for message in self .broker .listen ():
240+ task = asyncio .create_task (self .callback (message = message , raise_err = False ))
241+ tasks .add (task )
242+
243+ # We want the task to remove itself from the set when it's done.
244+ #
245+ # Because python's GC can silently cancel task
246+ # and it considered to be Hisenbug.
247+ # https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
248+ task .add_done_callback (tasks .discard )
0 commit comments