|
3 | 3 | import asyncio |
4 | 4 | import contextlib |
5 | 5 | import functools |
6 | | -from collections.abc import Callable |
| 6 | +import inspect |
| 7 | +from collections.abc import AsyncGenerator, Awaitable, Callable, Generator |
| 8 | +from typing import TypeVar, overload |
7 | 9 |
|
8 | 10 | from reflex.config import get_config |
9 | 11 |
|
@@ -91,28 +93,69 @@ async def monitor_async(): |
91 | 93 | yield |
92 | 94 |
|
93 | 95 |
|
94 | | -def monitor_leaks(): |
| 96 | +YieldType = TypeVar("YieldType") |
| 97 | +SendType = TypeVar("SendType") |
| 98 | +ReturnType = TypeVar("ReturnType") |
| 99 | + |
| 100 | + |
| 101 | +@overload |
| 102 | +def monitor_leaks( |
| 103 | + func: Callable[..., AsyncGenerator[YieldType, ReturnType]], |
| 104 | +) -> Callable[..., AsyncGenerator[YieldType, ReturnType]]: ... |
| 105 | + |
| 106 | + |
| 107 | +@overload |
| 108 | +def monitor_leaks( |
| 109 | + func: Callable[..., Generator[YieldType, SendType, ReturnType]], |
| 110 | +) -> Callable[..., Generator[YieldType, SendType, ReturnType]]: ... |
| 111 | + |
| 112 | + |
| 113 | +@overload |
| 114 | +def monitor_leaks( |
| 115 | + func: Callable[..., Awaitable[ReturnType]], |
| 116 | +) -> Callable[..., Awaitable[ReturnType]]: ... |
| 117 | + |
| 118 | + |
| 119 | +def monitor_leaks(func: Callable) -> Callable: |
95 | 120 | """Framework decorator using the monitoring module's context manager. |
96 | 121 |
|
| 122 | + Args: |
| 123 | + func: The function to be monitored for leaks. |
| 124 | +
|
97 | 125 | Returns: |
98 | 126 | Decorator function that applies PyLeak monitoring to sync/async functions. |
99 | 127 | """ |
| 128 | + if inspect.isasyncgenfunction(func): |
100 | 129 |
|
101 | | - def decorator(func: Callable): |
102 | | - if asyncio.iscoroutinefunction(func): |
| 130 | + @functools.wraps(func) |
| 131 | + async def async_gen_wrapper(*args, **kwargs): |
| 132 | + async with monitor_async(): |
| 133 | + async for item in func(*args, **kwargs): |
| 134 | + yield item |
103 | 135 |
|
104 | | - @functools.wraps(func) |
105 | | - async def async_wrapper(*args, **kwargs): |
106 | | - async with monitor_async(): |
107 | | - return await func(*args, **kwargs) |
| 136 | + return async_gen_wrapper |
108 | 137 |
|
109 | | - return async_wrapper |
| 138 | + if asyncio.iscoroutinefunction(func): |
110 | 139 |
|
111 | 140 | @functools.wraps(func) |
112 | | - def sync_wrapper(*args, **kwargs): |
| 141 | + async def async_wrapper(*args, **kwargs): |
| 142 | + async with monitor_async(): |
| 143 | + return await func(*args, **kwargs) |
| 144 | + |
| 145 | + return async_wrapper |
| 146 | + |
| 147 | + if inspect.isgeneratorfunction(func): |
| 148 | + |
| 149 | + @functools.wraps(func) |
| 150 | + def gen_wrapper(*args, **kwargs): |
113 | 151 | with monitor_sync(): |
114 | | - return func(*args, **kwargs) |
| 152 | + yield from func(*args, **kwargs) |
| 153 | + |
| 154 | + return gen_wrapper |
115 | 155 |
|
116 | | - return sync_wrapper # pyright: ignore[reportReturnType] |
| 156 | + @functools.wraps(func) |
| 157 | + def sync_wrapper(*args, **kwargs): |
| 158 | + with monitor_sync(): |
| 159 | + return func(*args, **kwargs) |
117 | 160 |
|
118 | | - return decorator |
| 161 | + return sync_wrapper # pyright: ignore[reportReturnType] |
0 commit comments