66import threading
77import typing as t
88
9+ from asyncio .events import AbstractEventLoop
10+ from contextvars import ContextVar
911from queue import Queue as ThreadingQueue
1012from textwrap import dedent as D
1113
2729 console_handler .setLevel (logging .DEBUG )
2830
2931 # Create a formatter and attach it to the handler
30- formatter = logging .Formatter ("%(asctime)s - %(name)s - %(levelname)s - %(message)s" )
32+ formatter = logging .Formatter (
33+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
34+ )
3135 console_handler .setFormatter (formatter )
3236
3337 # Attach the handler to the logger
3438 logger .addHandler (console_handler )
3539
3640
3741###########################
38-
39- T = t .TypeVar ("T" )
40-
41- _non_bridge_loop = contextvars .ContextVar ("non_bridge_loop" , default = None )
42- _context_bound_loop = contextvars .ContextVar ("_context_bound_loop" , default = None )
4342_poison = object ()
4443
4544
@@ -52,18 +51,28 @@ class _QueueSet(t.NamedTuple):
5251class _SyncTask (t .NamedTuple ):
5352 sync_task : t .Callable
5453 args : tuple
55- kwargs : dict
56- loop : asyncio . BaseEventLoop
54+ kwargs : t . Mapping
55+ loop : AbstractEventLoop
5756 context : contextvars .Context
5857 done_future : asyncio .Future
5958
6059
60+ T = t .TypeVar ("T" )
61+
62+ _non_bridge_loop : ContextVar [t .Optional [AbstractEventLoop ]] = ContextVar (
63+ "non_bridge_loop" , default = None
64+ )
65+ _context_bound_task : ContextVar [t .Optional [_SyncTask ]] = ContextVar (
66+ "_context_bound_task" , default = None
67+ )
68+
69+
6170def sync_to_async (
62- func : t .Callable [[ ...,] , T ] | t .Coroutine ,
71+ func : t .Callable [..., T ] | t .Coroutine ,
6372 args : t .Sequence [t .Any ] = (),
64- kwargs : t .Mapping [str , t .Any | None ] = None
73+ kwargs : t .Optional [ t . Mapping [str , t .Any | None ]] = None ,
6574) -> T :
66- """ Allows calling an async function from a synchronous context.
75+ """Allows calling an async function from a synchronous context.
6776
6877 When called from a synchronous scenario, this
6978 will create and cache an asyncio.loop instance per thread for code that otherwise
@@ -88,9 +97,9 @@ def sync_to_async(
8897 raise TypeError ("Can't accept extra arguments for existing coroutine" )
8998 coro = func
9099 else :
91- coro = func (* args , ** kwargs )
100+ coro = t . cast ( t . Callable , func ) (* args , ** kwargs )
92101
93- root_sync_task = _context_bound_loop .get ()
102+ root_sync_task = _context_bound_task .get ()
94103 if not root_sync_task :
95104 return _sync_to_async_non_bridge (coro , args , kwargs )
96105
@@ -100,11 +109,16 @@ def sync_to_async(
100109 event = _ThreadPool .get (loop ).all [threading .current_thread ()].event
101110
102111 event .clear ()
103- task = None # Keep a strong reference to the task
112+ task : None | asyncio . Task = None # Keep a strong reference to the task
104113 inner_exception = None
114+
105115 def do_it ():
106116 nonlocal task , inner_exception
107- logger .debug ("Creating task in %s from %s" , loop , thread_name := threading .current_thread ().name )
117+ logger .debug (
118+ "Creating task in %s from %s" ,
119+ loop ,
120+ thread_name := threading .current_thread ().name ,
121+ )
108122 try :
109123 task = loop .create_task (coro , context = sync_thread_context_copy )
110124 except Exception as exc :
@@ -118,37 +132,42 @@ def do_it():
118132 # Pauses sync-worker thread until original co-routine is finhsed in
119133 # the original event loop:
120134 event .wait ()
135+ if task is None :
136+ # should be unreachable - but this check keeps the linters quiet!
137+ raise RuntimeError ("SyncToAsync error: assynchronous code not executed" )
121138 if inner_exception :
122139 raise inner_exception
123- if exc := task .exception ():
140+ if exc := task .exception ():
124141 raise exc
125142 return task .result ()
126143
127144
128145def _sync_to_async_non_bridge (
129- coro : t .Coroutine ,
146+ coro : t .Coroutine [ t . Any , t . Any , T ] ,
130147 args : t .Sequence [t .Any ],
131- kwargs : t .Mapping [str , t .Any ]
148+ kwargs : t .Mapping [str , t .Any ],
132149) -> T :
133150 loop = _non_bridge_loop .get ()
134151 if not loop :
135152 try :
136153 loop = asyncio .new_event_loop ()
137154 except RuntimeError :
138- raise RuntimeError (D ("""\
155+ raise RuntimeError (
156+ D (
157+ """\
139158 Error trying to create a new async loop - to be able to call 'sync_to_async' from
140159 code inside a running async loop, the parent 'sync execution branch' must be called
141160 with `extraasync.async_to_sync`.
142- """ ))
161+ """
162+ )
163+ )
143164 _non_bridge_loop .set (loop )
144165
145-
146166 return loop .run_until_complete (coro )
147167
148168
149-
150169class _ThreadPool :
151- loop : asyncio . BaseEventLoop
170+ loop : AbstractEventLoop
152171
153172 pools = dict ()
154173
@@ -163,11 +182,12 @@ def get(cls, loop=None):
163182
164183 return cls .pools [loop ]
165184
166-
167185 def __init__ (self ):
168- self .idle = set ()
186+ self .idle = set [ _QueueSet | tuple [()]] ()
169187 self .all = dict ()
170- self .running = contextvars .ContextVar ("running" , default = ())
188+ self .running : ContextVar [_QueueSet | tuple [()]] = ContextVar (
189+ "running" , default = ()
190+ )
171191
172192 def __enter__ (self ):
173193 if self .idle :
@@ -201,29 +221,31 @@ def __repr__(self):
201221 return f"<_ThreadPool with { len (self .all )} threads - with { len (self .all ) - len (self .idle )} threads in use>"
202222
203223
204-
205224def _in_context_sync_worker (sync_task : _SyncTask ):
206225
207- _context_bound_loop .set (sync_task )
226+ _context_bound_task .set (sync_task )
208227 try :
209228 logger .debug ("Entering sync call in worker thread %s" , sync_task )
210229 result = sync_task .sync_task (* sync_task .args , ** sync_task .kwargs )
211230 logger .debug ("Returning from sync call in worker thread: %s" , result )
212231 finally :
213- _context_bound_loop .set (None )
232+ _context_bound_task .set (None )
214233 return result
215234
216235
217236def _sync_worker (queue ):
218- """Inner function to call sync "tasks" in a separate thread.
219-
220-
221- """
237+ """Inner function to call sync "tasks" in a separate thread."""
222238 while True :
223239 sync_task_bundle = queue .get ()
224- logger .debug ("*" * 100 + "Got new sync call %s in worker-thread %s" , sync_task_bundle , threading .current_thread ().name )
240+ logger .debug (
241+ "*" * 100 + "Got new sync call %s in worker-thread %s" ,
242+ sync_task_bundle ,
243+ threading .current_thread ().name ,
244+ )
225245 if sync_task_bundle is _poison :
226- logger .info ("Stopping sync-worker thread %s" , threading .current_thread ().name )
246+ logger .info (
247+ "Stopping sync-worker thread %s" , threading .current_thread ().name
248+ )
227249 return
228250 context = sync_task_bundle .context
229251 loop = sync_task_bundle .loop
@@ -234,15 +256,13 @@ def _sync_worker(queue):
234256 result = exc
235257 loop .call_soon_threadsafe (fut .set_exception , result )
236258 else :
237- loop .call_soon_threadsafe (fut .set_result , result )
238-
259+ loop .call_soon_threadsafe (fut .set_result , result )
239260
240261
241262def async_to_sync (
242- func : t .Callable [[...,], T ],
243- * ,
263+ func : t .Callable [..., T ],
244264 args : t .Sequence [t .Any ] = (),
245- kwargs : t .Mapping [str , t .Any | None ] = None
265+ kwargs : t .Optional [ t . Mapping [str , t .Any | None ]] = None ,
246266) -> asyncio .Future [T ]:
247267 """Returns a future wrapping a synchronous call in other thread
248268
@@ -269,10 +289,14 @@ def async_to_sync(
269289 kwargs = {}
270290
271291 task = asyncio .current_task ()
292+ if task is None :
293+ raise RuntimeError ("async_to_sync called outside of an asyncio task." )
272294 loop = task .get_loop ()
273- context = task .get_context () # it is resposibility of those using the context to copy it!
274- # (also, with the "extracontext" package, there are
275- # tools to modify an existing context if needed
295+ context = (
296+ task .get_context ()
297+ ) # it is resposibility of those using the context to copy it!
298+ # (also, with the "extracontext" package, there are
299+ # tools to modify an existing context if needed
276300
277301 done_future = loop .create_future ()
278302
@@ -286,10 +310,10 @@ def async_to_sync(
286310 lambda fut , thread_pool = thread_pool : thread_pool .__exit__ (None , None , None )
287311 )
288312
289-
290- #with _ThreadPool.get() as queue_set:
291- queue_set .queue .put (_SyncTask (func , args , kwargs , loop , context , done_future ))
313+ # with _ThreadPool.get() as queue_set:
314+ queue_set .queue .put (
315+ _SyncTask (func , tuple (args ), kwargs , loop , context , done_future )
316+ )
292317 logger .debug ("Created future awaiting sync result from worker thread for %s" , func )
293318
294319 return done_future
295-
0 commit comments