Skip to content

Commit 540c81b

Browse files
committed
Merge branch 'main' into pipeline_cls
2 parents d99f8fc + 66b094c commit 540c81b

1 file changed

Lines changed: 71 additions & 47 deletions

File tree

extraasync/sync_async_bridge.py

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import threading
77
import typing as t
88

9+
from asyncio.events import AbstractEventLoop
10+
from contextvars import ContextVar
911
from queue import Queue as ThreadingQueue
1012
from textwrap import dedent as D
1113

@@ -27,19 +29,16 @@
2729
console_handler.setLevel(logging.DEBUG)
2830

2931
# Create a formatter and attach it to the handler
30-
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
32+
formatter = logging.Formatter(
33+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
34+
)
3135
console_handler.setFormatter(formatter)
3236

3337
# Attach the handler to the logger
3438
logger.addHandler(console_handler)
3539

3640

3741
###########################
38-
39-
T = t.TypeVar("T")
40-
41-
_non_bridge_loop = contextvars.ContextVar("non_bridge_loop", default=None)
42-
_context_bound_loop = contextvars.ContextVar("_context_bound_loop", default=None)
4342
_poison = object()
4443

4544

@@ -52,18 +51,28 @@ class _QueueSet(t.NamedTuple):
5251
class _SyncTask(t.NamedTuple):
5352
sync_task: t.Callable
5453
args: tuple
55-
kwargs: dict
56-
loop: asyncio.BaseEventLoop
54+
kwargs: t.Mapping
55+
loop: AbstractEventLoop
5756
context: contextvars.Context
5857
done_future: asyncio.Future
5958

6059

60+
T = t.TypeVar("T")
61+
62+
_non_bridge_loop: ContextVar[t.Optional[AbstractEventLoop]] = ContextVar(
63+
"non_bridge_loop", default=None
64+
)
65+
_context_bound_task: ContextVar[t.Optional[_SyncTask]] = ContextVar(
66+
"_context_bound_task", default=None
67+
)
68+
69+
6170
def sync_to_async(
62-
func: t.Callable[[...,], T] | t.Coroutine,
71+
func: t.Callable[..., T] | t.Coroutine,
6372
args: t.Sequence[t.Any] = (),
64-
kwargs: t.Mapping[str, t.Any | None] = None
73+
kwargs: t.Optional[t.Mapping[str, t.Any | None]] = None,
6574
) -> T:
66-
""" Allows calling an async function from a synchronous context.
75+
"""Allows calling an async function from a synchronous context.
6776
6877
When called from a synchronous scenario, this
6978
will create and cache an asyncio.loop instance per thread for code that otherwise
@@ -88,9 +97,9 @@ def sync_to_async(
8897
raise TypeError("Can't accept extra arguments for existing coroutine")
8998
coro = func
9099
else:
91-
coro = func(*args, **kwargs)
100+
coro = t.cast(t.Callable, func)(*args, **kwargs)
92101

93-
root_sync_task = _context_bound_loop.get()
102+
root_sync_task = _context_bound_task.get()
94103
if not root_sync_task:
95104
return _sync_to_async_non_bridge(coro, args, kwargs)
96105

@@ -100,11 +109,16 @@ def sync_to_async(
100109
event = _ThreadPool.get(loop).all[threading.current_thread()].event
101110

102111
event.clear()
103-
task = None # Keep a strong reference to the task
112+
task: None | asyncio.Task = None # Keep a strong reference to the task
104113
inner_exception = None
114+
105115
def do_it():
106116
nonlocal task, inner_exception
107-
logger.debug("Creating task in %s from %s", loop, thread_name:=threading.current_thread().name)
117+
logger.debug(
118+
"Creating task in %s from %s",
119+
loop,
120+
thread_name := threading.current_thread().name,
121+
)
108122
try:
109123
task = loop.create_task(coro, context=sync_thread_context_copy)
110124
except Exception as exc:
@@ -118,37 +132,42 @@ def do_it():
118132
# Pauses sync-worker thread until original co-routine is finhsed in
119133
# the original event loop:
120134
event.wait()
135+
if task is None:
136+
# should be unreachable - but this check keeps the linters quiet!
137+
raise RuntimeError("SyncToAsync error: assynchronous code not executed")
121138
if inner_exception:
122139
raise inner_exception
123-
if exc:=task.exception():
140+
if exc := task.exception():
124141
raise exc
125142
return task.result()
126143

127144

128145
def _sync_to_async_non_bridge(
129-
coro: t.Coroutine,
146+
coro: t.Coroutine[t.Any, t.Any, T],
130147
args: t.Sequence[t.Any],
131-
kwargs: t.Mapping[str, t.Any]
148+
kwargs: t.Mapping[str, t.Any],
132149
) -> T:
133150
loop = _non_bridge_loop.get()
134151
if not loop:
135152
try:
136153
loop = asyncio.new_event_loop()
137154
except RuntimeError:
138-
raise RuntimeError(D("""\
155+
raise RuntimeError(
156+
D(
157+
"""\
139158
Error trying to create a new async loop - to be able to call 'sync_to_async' from
140159
code inside a running async loop, the parent 'sync execution branch' must be called
141160
with `extraasync.async_to_sync`.
142-
"""))
161+
"""
162+
)
163+
)
143164
_non_bridge_loop.set(loop)
144165

145-
146166
return loop.run_until_complete(coro)
147167

148168

149-
150169
class _ThreadPool:
151-
loop: asyncio.BaseEventLoop
170+
loop: AbstractEventLoop
152171

153172
pools = dict()
154173

@@ -163,11 +182,12 @@ def get(cls, loop=None):
163182

164183
return cls.pools[loop]
165184

166-
167185
def __init__(self):
168-
self.idle = set()
186+
self.idle = set[_QueueSet | tuple[()]]()
169187
self.all = dict()
170-
self.running = contextvars.ContextVar("running", default=())
188+
self.running: ContextVar[_QueueSet | tuple[()]] = ContextVar(
189+
"running", default=()
190+
)
171191

172192
def __enter__(self):
173193
if self.idle:
@@ -201,29 +221,31 @@ def __repr__(self):
201221
return f"<_ThreadPool with {len(self.all)} threads - with {len(self.all) - len(self.idle)} threads in use>"
202222

203223

204-
205224
def _in_context_sync_worker(sync_task: _SyncTask):
206225

207-
_context_bound_loop.set(sync_task)
226+
_context_bound_task.set(sync_task)
208227
try:
209228
logger.debug("Entering sync call in worker thread %s", sync_task)
210229
result = sync_task.sync_task(*sync_task.args, **sync_task.kwargs)
211230
logger.debug("Returning from sync call in worker thread: %s", result)
212231
finally:
213-
_context_bound_loop.set(None)
232+
_context_bound_task.set(None)
214233
return result
215234

216235

217236
def _sync_worker(queue):
218-
"""Inner function to call sync "tasks" in a separate thread.
219-
220-
221-
"""
237+
"""Inner function to call sync "tasks" in a separate thread."""
222238
while True:
223239
sync_task_bundle = queue.get()
224-
logger.debug("*" * 100 + "Got new sync call %s in worker-thread %s", sync_task_bundle, threading.current_thread().name)
240+
logger.debug(
241+
"*" * 100 + "Got new sync call %s in worker-thread %s",
242+
sync_task_bundle,
243+
threading.current_thread().name,
244+
)
225245
if sync_task_bundle is _poison:
226-
logger.info("Stopping sync-worker thread %s", threading.current_thread().name)
246+
logger.info(
247+
"Stopping sync-worker thread %s", threading.current_thread().name
248+
)
227249
return
228250
context = sync_task_bundle.context
229251
loop = sync_task_bundle.loop
@@ -234,15 +256,13 @@ def _sync_worker(queue):
234256
result = exc
235257
loop.call_soon_threadsafe(fut.set_exception, result)
236258
else:
237-
loop.call_soon_threadsafe(fut.set_result, result )
238-
259+
loop.call_soon_threadsafe(fut.set_result, result)
239260

240261

241262
def async_to_sync(
242-
func: t.Callable[[...,], T],
243-
*,
263+
func: t.Callable[..., T],
244264
args: t.Sequence[t.Any] = (),
245-
kwargs: t.Mapping[str, t.Any | None] = None
265+
kwargs: t.Optional[t.Mapping[str, t.Any | None]] = None,
246266
) -> asyncio.Future[T]:
247267
"""Returns a future wrapping a synchronous call in other thread
248268
@@ -269,10 +289,14 @@ def async_to_sync(
269289
kwargs = {}
270290

271291
task = asyncio.current_task()
292+
if task is None:
293+
raise RuntimeError("async_to_sync called outside of an asyncio task.")
272294
loop = task.get_loop()
273-
context = task.get_context() # it is resposibility of those using the context to copy it!
274-
# (also, with the "extracontext" package, there are
275-
# tools to modify an existing context if needed
295+
context = (
296+
task.get_context()
297+
) # it is resposibility of those using the context to copy it!
298+
# (also, with the "extracontext" package, there are
299+
# tools to modify an existing context if needed
276300

277301
done_future = loop.create_future()
278302

@@ -286,10 +310,10 @@ def async_to_sync(
286310
lambda fut, thread_pool=thread_pool: thread_pool.__exit__(None, None, None)
287311
)
288312

289-
290-
#with _ThreadPool.get() as queue_set:
291-
queue_set.queue.put(_SyncTask(func, args, kwargs, loop, context, done_future))
313+
# with _ThreadPool.get() as queue_set:
314+
queue_set.queue.put(
315+
_SyncTask(func, tuple(args), kwargs, loop, context, done_future)
316+
)
292317
logger.debug("Created future awaiting sync result from worker thread for %s", func)
293318

294319
return done_future
295-

0 commit comments

Comments
 (0)