diff --git a/reflex/utils/misc.py b/reflex/utils/misc.py index 1c5f948aa5f..59b81ca7ad8 100644 --- a/reflex/utils/misc.py +++ b/reflex/utils/misc.py @@ -5,20 +5,27 @@ from typing import Any -async def run_in_thread(func: Callable) -> Any: +async def run_in_thread(func: Callable, *, timeout: float | None = None) -> Any: """Run a function in a separate thread. To not block the UI event queue, run_in_thread must be inside inside a rx.event(background=True) decorated method. Args: func: The non-async function to run. + timeout: Maximum number of seconds to wait for the function to complete. + If None (default), wait indefinitely. Raises: ValueError: If the function is an async function. + asyncio.TimeoutError: If the function execution exceeds the specified timeout. Returns: Any: The return value of the function. """ if asyncio.coroutines.iscoroutinefunction(func): raise ValueError("func must be a non-async function") - return await asyncio.get_event_loop().run_in_executor(None, func) + + task = asyncio.get_event_loop().run_in_executor(None, func) + if timeout is not None: + return await asyncio.wait_for(task, timeout=timeout) + return await task diff --git a/tests/units/utils/test_misc.py b/tests/units/utils/test_misc.py new file mode 100644 index 00000000000..49cc705dfe3 --- /dev/null +++ b/tests/units/utils/test_misc.py @@ -0,0 +1,42 @@ +"""Test misc utilities.""" + +import asyncio +import time + +import pytest + +from reflex.utils.misc import run_in_thread + + +async def test_run_in_thread(): + """Test that run_in_thread runs a function in a separate thread.""" + + def simple_function(): + return 42 + + result = await run_in_thread(simple_function) + assert result == 42 + + def slow_function(): + time.sleep(0.1) + return "completed" + + result = await run_in_thread(slow_function, timeout=0.5) + assert result == "completed" + + async def async_function(): + return 42 + + with pytest.raises(ValueError): + await run_in_thread(async_function) + + +async def test_run_in_thread_timeout(): + """Test that run_in_thread raises TimeoutError when timeout is exceeded.""" + + def very_slow_function(): + time.sleep(0.5) + return "should not reach here" + + with pytest.raises(asyncio.TimeoutError): + await run_in_thread(very_slow_function, timeout=0.1)