99import asyncio
1010import sys
1111from functools import partial
12- from inspect import CO_ASYNC_GENERATOR , CO_COROUTINE , iscoroutinefunction
12+ from inspect import (
13+ CO_ASYNC_GENERATOR ,
14+ CO_COROUTINE ,
15+ CO_GENERATOR ,
16+ CO_ITERABLE_COROUTINE ,
17+ iscoroutinefunction ,
18+ )
1319from threading import Lock , RLock
1420
1521from .__wrapt__ import BoundFunctionWrapper , CallableObjectProxy , FunctionWrapper
2329# inner decorator that invokes an async def via asyncio.run()).
2430
2531
26- _CO_SYNC_MASK = ~ (CO_COROUTINE | CO_ASYNC_GENERATOR )
27-
28-
2932class _SyncCodeProxy (CallableObjectProxy ):
3033
34+ def __init__ (self , wrapped , generator = None ):
35+ super ().__init__ (wrapped )
36+ self ._self_generator = generator
37+
3138 @property
3239 def co_flags (self ):
33- return self .__wrapped__ .co_flags & _CO_SYNC_MASK
40+ original = self .__wrapped__ .co_flags
41+ # Strip async-axis and iterable-coroutine bits; sync means neither
42+ # coroutine function nor async generator nor types.coroutine-style.
43+ flags = original & ~ (CO_COROUTINE | CO_ASYNC_GENERATOR | CO_ITERABLE_COROUTINE )
44+ if self ._self_generator is True :
45+ flags |= CO_GENERATOR
46+ elif self ._self_generator is False :
47+ flags &= ~ CO_GENERATOR
48+ else :
49+ # Auto: if input was an async generator, preserve generator-ness
50+ # on the sync side by setting CO_GENERATOR. Otherwise leave
51+ # CO_GENERATOR as-is (already copied from the wrapped flags).
52+ if original & CO_ASYNC_GENERATOR :
53+ flags |= CO_GENERATOR
54+ return flags
3455
3556
3657class _SyncFunctionSurrogate (CallableObjectProxy ):
3758
59+ def __init__ (self , wrapped , generator = None ):
60+ super ().__init__ (wrapped )
61+ self ._self_generator = generator
62+
3863 @property
3964 def __code__ (self ):
40- return _SyncCodeProxy (self .__wrapped__ .__code__ )
65+ return _SyncCodeProxy (self .__wrapped__ .__code__ , self . _self_generator )
4166
4267
4368class _BoundSyncFunctionWrapper (BoundFunctionWrapper ):
@@ -48,77 +73,151 @@ def __init__(self, *args, **kwargs):
4873
4974 @property
5075 def __func__ (self ):
51- return _SyncFunctionSurrogate (self .__wrapped__ .__func__ )
76+ return _SyncFunctionSurrogate (
77+ self .__wrapped__ .__func__ , self ._self_parent ._self_generator
78+ )
5279
5380
5481class _SyncFunctionWrapper (FunctionWrapper ):
5582
5683 __bound_function_wrapper__ = _BoundSyncFunctionWrapper
5784
58- def __init__ (self , * args , ** kwargs ):
59- super ().__init__ (* args , ** kwargs )
85+ def __init__ (self , wrapped , wrapper , generator = None ):
86+ super ().__init__ (wrapped , wrapper )
6087 self ._self_is_not_coroutine = True
88+ self ._self_generator = generator
6189
6290 @property
6391 def __code__ (self ):
64- return _SyncCodeProxy (self .__wrapped__ .__code__ )
92+ return _SyncCodeProxy (self .__wrapped__ .__code__ , self . _self_generator )
6593
6694
6795class _AsyncCodeProxy (CallableObjectProxy ):
6896
97+ def __init__ (self , wrapped , generator = None ):
98+ super ().__init__ (wrapped )
99+ self ._self_generator = generator
100+
69101 @property
70102 def co_flags (self ):
71- return (self .__wrapped__ .co_flags & ~ CO_ASYNC_GENERATOR ) | CO_COROUTINE
103+ original = self .__wrapped__ .co_flags
104+ # Strip all four convention bits; we reassert the right ones below.
105+ flags = original & ~ (
106+ CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
107+ )
108+ if self ._self_generator is True :
109+ flags |= CO_ASYNC_GENERATOR
110+ elif self ._self_generator is False :
111+ flags |= CO_COROUTINE
112+ else :
113+ # Auto: if input was a generator (sync or async), produce an
114+ # async generator; otherwise produce a coroutine function.
115+ if original & (CO_GENERATOR | CO_ASYNC_GENERATOR ):
116+ flags |= CO_ASYNC_GENERATOR
117+ else :
118+ flags |= CO_COROUTINE
119+ return flags
72120
73121
74122class _AsyncFunctionSurrogate (CallableObjectProxy ):
75123
124+ def __init__ (self , wrapped , generator = None ):
125+ super ().__init__ (wrapped )
126+ self ._self_generator = generator
127+
76128 @property
77129 def __code__ (self ):
78- return _AsyncCodeProxy (self .__wrapped__ .__code__ )
130+ return _AsyncCodeProxy (self .__wrapped__ .__code__ , self . _self_generator )
79131
80132
81133class _BoundAsyncFunctionWrapper (BoundFunctionWrapper ):
82134
83135 @property
84136 def __func__ (self ):
85- return _AsyncFunctionSurrogate (self .__wrapped__ .__func__ )
137+ return _AsyncFunctionSurrogate (
138+ self .__wrapped__ .__func__ , self ._self_parent ._self_generator
139+ )
86140
87141
88142class _AsyncFunctionWrapper (FunctionWrapper ):
89143
90144 __bound_function_wrapper__ = _BoundAsyncFunctionWrapper
91145
146+ def __init__ (self , wrapped , wrapper , generator = None ):
147+ super ().__init__ (wrapped , wrapper )
148+ self ._self_generator = generator
149+
92150 @property
93151 def __code__ (self ):
94- return _AsyncCodeProxy (self .__wrapped__ .__code__ )
152+ return _AsyncCodeProxy (self .__wrapped__ .__code__ , self . _self_generator )
95153
96154
97- def mark_as_sync (wrapped ):
155+ def mark_as_sync (wrapped = None , / , * , generator = None ):
98156 """Mark a callable as synchronous from the perspective of calling
99157 convention detection. The returned wrapper is a pass-through that
100158 reports `inspect.iscoroutinefunction()` as False regardless of
101159 whether the underlying callable is declared `async def`. Useful
102160 when a stacked decorator has already collapsed an async function
103- into a synchronous one (for example by using `asyncio.run()`)."""
161+ into a synchronous one (for example by using `asyncio.run()`).
104162
105- def wrapper ( wrapped , instance , args , kwargs ):
106- return wrapped ( * args , ** kwargs )
163+ The `generator` keyword toggles the sync generator bit
164+ (`CO_GENERATOR`) on the resulting wrapper. Tri-state:
107165
108- return _SyncFunctionWrapper (wrapped , wrapper )
166+ - `None` (default): auto. Preserve generator-ness from the input --
167+ if the input was an async generator, the wrapper reports as a sync
168+ generator; otherwise CO_GENERATOR is copied through unchanged.
169+ - `True`: force CO_GENERATOR on. Wrapper reports as a sync generator.
170+ - `False`: force CO_GENERATOR off. Wrapper reports as a plain sync
171+ function even if the input had CO_GENERATOR set.
172+
173+ Regardless of `generator`, CO_COROUTINE, CO_ASYNC_GENERATOR, and
174+ CO_ITERABLE_COROUTINE are all cleared (sync means none of those).
175+ """
176+
177+ def _decorator (wrapped ):
178+ def _wrapper (wrapped , instance , args , kwargs ):
179+ return wrapped (* args , ** kwargs )
109180
181+ return _SyncFunctionWrapper (wrapped , _wrapper , generator = generator )
110182
111- def mark_as_async (wrapped ):
183+ if wrapped is None :
184+ return _decorator
185+ return _decorator (wrapped )
186+
187+
188+ def mark_as_async (wrapped = None , / , * , generator = None ):
112189 """Mark a callable as asynchronous from the perspective of calling
113190 convention detection. The returned wrapper reports
114191 `inspect.iscoroutinefunction()` as True regardless of whether the
115192 underlying callable is declared `async def`. Useful when a stacked
116- decorator returns a coroutine from a plain `def` wrapper."""
193+ decorator returns a coroutine from a plain `def` wrapper.
194+
195+ The `generator` keyword chooses between coroutine function and
196+ async generator reporting. Tri-state:
197+
198+ - `None` (default): auto. If the input was a sync or async
199+ generator, the wrapper reports as an async generator
200+ (`CO_ASYNC_GENERATOR`); otherwise it reports as a coroutine
201+ function (`CO_COROUTINE`).
202+ - `True`: force async generator reporting (`CO_ASYNC_GENERATOR` set,
203+ `CO_COROUTINE` cleared). These two flags are mutually exclusive at
204+ the CPython code-object level.
205+ - `False`: force coroutine function reporting (`CO_COROUTINE` set,
206+ `CO_ASYNC_GENERATOR` cleared).
207+
208+ CO_GENERATOR and CO_ITERABLE_COROUTINE are always cleared (the
209+ async path does not use either).
210+ """
117211
118- async def wrapper (wrapped , instance , args , kwargs ):
212+ async def _wrapper (wrapped , instance , args , kwargs ):
119213 return wrapped (* args , ** kwargs )
120214
121- return _AsyncFunctionWrapper (wrapped , wrapper )
215+ def _decorator (wrapped ):
216+ return _AsyncFunctionWrapper (wrapped , _wrapper , generator = generator )
217+
218+ if wrapped is None :
219+ return _decorator
220+ return _decorator (wrapped )
122221
123222
124223def async_to_sync (wrapped ):
0 commit comments