Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions singlestoredb/functions/ext/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"""
import argparse
import asyncio
import concurrent.futures
import contextvars
import dataclasses
import datetime
Expand Down Expand Up @@ -1000,6 +1001,15 @@ def __init__(
self.log_level = log_level
self.disable_metrics = disable_metrics

# Dedicated event loop for async UDF execution, isolated from the server loop
self._udf_loop = asyncio.new_event_loop()
self._udf_thread = threading.Thread(
target=self._udf_loop.run_forever,
daemon=True,
name='async-udf-loop',
)
self._udf_thread.start()
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated

# Configure logging
self._configure_logging()

Expand Down Expand Up @@ -1033,6 +1043,11 @@ def _configure_logging(self) -> None:
# Prevent propagation to avoid duplicate or differently formatted messages
self.logger.propagate = False

def shutdown(self) -> None:
"""Shut down the dedicated UDF event loop."""
self._udf_loop.call_soon_threadsafe(self._udf_loop.stop)
self._udf_thread.join(timeout=5)
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated

def get_uvicorn_log_config(self) -> Dict[str, Any]:
"""
Create uvicorn log config that matches the Application's logging format.
Expand Down Expand Up @@ -1189,15 +1204,24 @@ async def __call__(
func_info['colspec'], b''.join(data),
)

func_task = asyncio.create_task(
func(cancel_event, call_timer, *inputs)
if func_info['is_async']
else to_thread(
Comment thread
cursor[bot] marked this conversation as resolved.
lambda: asyncio.run(
func(cancel_event, call_timer, *inputs),
func_task: 'asyncio.Task[Any]'
udf_future: 'Optional[concurrent.futures.Future[Any]]' = None
if func_info['is_async']:
udf_future = asyncio.run_coroutine_threadsafe(
func(cancel_event, call_timer, *inputs),
self._udf_loop,
)
func_task = asyncio.ensure_future(
asyncio.wrap_future(udf_future),
)
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
else:
func_task = asyncio.create_task(
to_thread(
lambda: asyncio.run(
func(cancel_event, call_timer, *inputs),
),
),
),
)
)
disconnect_task = asyncio.create_task(
asyncio.sleep(int(1e9))
if ignore_cancel else cancel_on_disconnect(receive),
Expand All @@ -1218,12 +1242,16 @@ async def __call__(
for task in done:
if task is disconnect_task:
cancel_event.set()
if udf_future is not None:
udf_future.cancel()
raise asyncio.CancelledError(
'Function call was cancelled by client disconnect',
)

elif task is timeout_task:
cancel_event.set()
if udf_future is not None:
udf_future.cancel()
raise asyncio.TimeoutError(
'Function call was cancelled due to timeout',
)
Expand Down
34 changes: 18 additions & 16 deletions singlestoredb/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
try:
import pandas as pd
has_pandas = True
_pd_str_dtype = str(pd.DataFrame({'a': ['x']}).dtypes['a'])
except ImportError:
has_pandas = False
_pd_str_dtype = 'object'


class TestConnection(unittest.TestCase):
Expand Down Expand Up @@ -1124,21 +1126,21 @@ def test_alltypes_pandas(self):
('timestamp', 'datetime64[us]'),
('timestamp_6', 'datetime64[us]'),
('year', 'float64'),
('char_100', 'object'),
('char_100', _pd_str_dtype),
('binary_100', 'object'),
('varchar_200', 'object'),
('varchar_200', _pd_str_dtype),
('varbinary_200', 'object'),
('longtext', 'object'),
('mediumtext', 'object'),
('text', 'object'),
('tinytext', 'object'),
('longtext', _pd_str_dtype),
('mediumtext', _pd_str_dtype),
('text', _pd_str_dtype),
('tinytext', _pd_str_dtype),
('longblob', 'object'),
('mediumblob', 'object'),
('blob', 'object'),
('tinyblob', 'object'),
('json', 'object'),
('enum', 'object'),
('set', 'object'),
('enum', _pd_str_dtype),
('set', _pd_str_dtype),
('bit', 'object'),
]

Expand Down Expand Up @@ -1266,21 +1268,21 @@ def test_alltypes_no_nulls_pandas(self):
('timestamp', 'datetime64[us]'),
('timestamp_6', 'datetime64[us]'),
('year', 'int16'),
('char_100', 'object'),
('char_100', _pd_str_dtype),
('binary_100', 'object'),
('varchar_200', 'object'),
('varchar_200', _pd_str_dtype),
('varbinary_200', 'object'),
('longtext', 'object'),
('mediumtext', 'object'),
('text', 'object'),
('tinytext', 'object'),
('longtext', _pd_str_dtype),
('mediumtext', _pd_str_dtype),
('text', _pd_str_dtype),
('tinytext', _pd_str_dtype),
('longblob', 'object'),
('mediumblob', 'object'),
('blob', 'object'),
('tinyblob', 'object'),
('json', 'object'),
('enum', 'object'),
('set', 'object'),
('enum', _pd_str_dtype),
('set', _pd_str_dtype),
('bit', 'object'),
]

Expand Down
Loading