Skip to content

Commit 410cd28

Browse files
committed
Handle async gen and regular gen
Add overloads to `monitor_leaks` to avoid pyright ignores Remove extra layer of callable in `monitor_leaks`
1 parent cd8967b commit 410cd28

2 files changed

Lines changed: 58 additions & 15 deletions

File tree

reflex/monitoring.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import asyncio
44
import contextlib
55
import functools
6-
from collections.abc import Callable
6+
import inspect
7+
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
8+
from typing import TypeVar, overload
79

810
from reflex.config import get_config
911

@@ -91,28 +93,69 @@ async def monitor_async():
9193
yield
9294

9395

94-
def monitor_leaks():
96+
YieldType = TypeVar("YieldType")
97+
SendType = TypeVar("SendType")
98+
ReturnType = TypeVar("ReturnType")
99+
100+
101+
@overload
102+
def monitor_leaks(
103+
func: Callable[..., AsyncGenerator[YieldType, ReturnType]],
104+
) -> Callable[..., AsyncGenerator[YieldType, ReturnType]]: ...
105+
106+
107+
@overload
108+
def monitor_leaks(
109+
func: Callable[..., Generator[YieldType, SendType, ReturnType]],
110+
) -> Callable[..., Generator[YieldType, SendType, ReturnType]]: ...
111+
112+
113+
@overload
114+
def monitor_leaks(
115+
func: Callable[..., Awaitable[ReturnType]],
116+
) -> Callable[..., Awaitable[ReturnType]]: ...
117+
118+
119+
def monitor_leaks(func: Callable) -> Callable:
95120
"""Framework decorator using the monitoring module's context manager.
96121
122+
Args:
123+
func: The function to be monitored for leaks.
124+
97125
Returns:
98126
Decorator function that applies PyLeak monitoring to sync/async functions.
99127
"""
128+
if inspect.isasyncgenfunction(func):
100129

101-
def decorator(func: Callable):
102-
if asyncio.iscoroutinefunction(func):
130+
@functools.wraps(func)
131+
async def async_gen_wrapper(*args, **kwargs):
132+
async with monitor_async():
133+
async for item in func(*args, **kwargs):
134+
yield item
103135

104-
@functools.wraps(func)
105-
async def async_wrapper(*args, **kwargs):
106-
async with monitor_async():
107-
return await func(*args, **kwargs)
136+
return async_gen_wrapper
108137

109-
return async_wrapper
138+
if asyncio.iscoroutinefunction(func):
110139

111140
@functools.wraps(func)
112-
def sync_wrapper(*args, **kwargs):
141+
async def async_wrapper(*args, **kwargs):
142+
async with monitor_async():
143+
return await func(*args, **kwargs)
144+
145+
return async_wrapper
146+
147+
if inspect.isgeneratorfunction(func):
148+
149+
@functools.wraps(func)
150+
def gen_wrapper(*args, **kwargs):
113151
with monitor_sync():
114-
return func(*args, **kwargs)
152+
yield from func(*args, **kwargs)
153+
154+
return gen_wrapper
115155

116-
return sync_wrapper # pyright: ignore[reportReturnType]
156+
@functools.wraps(func)
157+
def sync_wrapper(*args, **kwargs):
158+
with monitor_sync():
159+
return func(*args, **kwargs)
117160

118-
return decorator
161+
return sync_wrapper # pyright: ignore[reportReturnType]

reflex/state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,7 +1787,7 @@ async def _process_event(
17871787
# Get the function to process the event.
17881788
if is_pyleak_enabled():
17891789
console.debug(f"Monitoring leaks for handler: {handler.fn.__qualname__}")
1790-
fn = functools.partial(monitor_leaks()(handler.fn), state)
1790+
fn = functools.partial(monitor_leaks(handler.fn), state)
17911791
else:
17921792
fn = functools.partial(handler.fn, state)
17931793

@@ -1874,7 +1874,7 @@ async def _process_event(
18741874

18751875
# Handle regular event chains.
18761876
else:
1877-
yield await state._as_state_update(handler, events, final=True) # pyright: ignore[reportArgumentType]
1877+
yield await state._as_state_update(handler, events, final=True)
18781878

18791879
# If an error occurs, throw a window alert.
18801880
except Exception as ex:

0 commit comments

Comments
 (0)