Skip to content

Commit 248f897

Browse files
kesmit13claude
andcommitted
Add collocated Python UDF server with pre-fork process mode
Add a standalone collocated UDF server package that can run as a drop-in replacement for the Rust wasm-udf-server. Uses pre-fork worker processes (default) for true CPU parallelism, avoiding GIL contention in the C-accelerated call path. Thread pool mode is available via --process-mode thread. Collapse the wasm subpackage into a single wasm.py module since it only contained one class re-exported through __init__.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a18eb52 commit 248f897

File tree

9 files changed

+1006
-158
lines changed

9 files changed

+1006
-158
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ dev = [
7474
"singlestoredb[test,docs,build]",
7575
]
7676

77+
[project.scripts]
78+
python-udf-server = "singlestoredb.functions.ext.collocated.__main__:main"
79+
7780
[project.entry-points.pytest11]
7881
singlestoredb = "singlestoredb.pytest"
7982

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""High-performance collocated Python UDF server for SingleStoreDB."""
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""
2+
CLI entry point for the collocated Python UDF server.
3+
4+
Usage::
5+
6+
python -m singlestoredb.functions.ext.collocated \\
7+
--extension myfuncs \\
8+
--extension-path /home/user/libs \\
9+
--socket /tmp/my-udf.sock
10+
11+
Arguments match the Rust wasm-udf-server CLI for drop-in compatibility.
12+
"""
13+
import argparse
14+
import logging
15+
import os
16+
import secrets
17+
import sys
18+
import tempfile
19+
from typing import Any
20+
21+
from .registry import setup_logging
22+
from .server import Server
23+
24+
logger = logging.getLogger('collocated')
25+
26+
27+
def main(argv: Any = None) -> None:
28+
parser = argparse.ArgumentParser(
29+
prog='python -m singlestoredb.functions.ext.collocated',
30+
description='High-performance collocated Python UDF server',
31+
)
32+
parser.add_argument(
33+
'--extension',
34+
default=os.environ.get('EXTERNAL_UDF_EXTENSION', ''),
35+
help=(
36+
'Python module to import (e.g. myfuncs). '
37+
'Env: EXTERNAL_UDF_EXTENSION'
38+
),
39+
)
40+
parser.add_argument(
41+
'--extension-path',
42+
default=os.environ.get('EXTERNAL_UDF_EXTENSION_PATH', ''),
43+
help=(
44+
'Colon-separated search dirs for the module. '
45+
'Env: EXTERNAL_UDF_EXTENSION_PATH'
46+
),
47+
)
48+
parser.add_argument(
49+
'--socket',
50+
default=os.environ.get(
51+
'EXTERNAL_UDF_SOCKET_PATH',
52+
os.path.join(
53+
tempfile.gettempdir(),
54+
f'singlestore-udf-{os.getpid()}-{secrets.token_hex(4)}.sock',
55+
),
56+
),
57+
help=(
58+
'Unix socket path. '
59+
'Env: EXTERNAL_UDF_SOCKET_PATH'
60+
),
61+
)
62+
parser.add_argument(
63+
'--n-workers',
64+
type=int,
65+
default=int(os.environ.get('EXTERNAL_UDF_N_WORKERS', '0')),
66+
help=(
67+
'Worker threads (0 = CPU count). '
68+
'Env: EXTERNAL_UDF_N_WORKERS'
69+
),
70+
)
71+
parser.add_argument(
72+
'--max-connections',
73+
type=int,
74+
default=int(os.environ.get('EXTERNAL_UDF_MAX_CONNECTIONS', '32')),
75+
help=(
76+
'Socket backlog. '
77+
'Env: EXTERNAL_UDF_MAX_CONNECTIONS'
78+
),
79+
)
80+
parser.add_argument(
81+
'--log-level',
82+
default=os.environ.get('EXTERNAL_UDF_LOG_LEVEL', 'info'),
83+
choices=['debug', 'info', 'warning', 'error'],
84+
help=(
85+
'Logging level. '
86+
'Env: EXTERNAL_UDF_LOG_LEVEL'
87+
),
88+
)
89+
parser.add_argument(
90+
'--process-mode',
91+
default=os.environ.get('EXTERNAL_UDF_PROCESS_MODE', 'process'),
92+
choices=['thread', 'process'],
93+
help=(
94+
'Concurrency mode: "thread" uses a thread pool, '
95+
'"process" uses pre-fork workers for true CPU '
96+
'parallelism. Env: EXTERNAL_UDF_PROCESS_MODE'
97+
),
98+
)
99+
100+
args = parser.parse_args(argv)
101+
102+
if not args.extension:
103+
parser.error(
104+
'--extension is required '
105+
'(or set EXTERNAL_UDF_EXTENSION env var)',
106+
)
107+
108+
# Setup logging
109+
level = getattr(logging, args.log_level.upper())
110+
setup_logging(level)
111+
112+
config = {
113+
'extension': args.extension,
114+
'extension_path': args.extension_path,
115+
'socket': args.socket,
116+
'n_workers': args.n_workers,
117+
'max_connections': args.max_connections,
118+
'process_mode': args.process_mode,
119+
}
120+
121+
server = Server(config)
122+
try:
123+
server.run()
124+
except RuntimeError as exc:
125+
logger.error(str(exc))
126+
sys.exit(1)
127+
except KeyboardInterrupt:
128+
pass
129+
130+
131+
if __name__ == '__main__':
132+
main()
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
"""
2+
Connection handler: protocol, mmap I/O, request loop.
3+
4+
Implements the binary socket protocol matching the Rust wasm-udf-server:
5+
handshake, control signal dispatch, and UDF request loop with mmap I/O.
6+
"""
7+
from __future__ import annotations
8+
9+
import array
10+
import logging
11+
import mmap
12+
import os
13+
import select
14+
import socket
15+
import struct
16+
import threading
17+
import traceback
18+
from typing import TYPE_CHECKING
19+
20+
from .control import dispatch_control_signal
21+
from .registry import call_function
22+
23+
if TYPE_CHECKING:
24+
from .server import SharedRegistry
25+
26+
logger = logging.getLogger('collocated.connection')
27+
28+
# Protocol constants
29+
PROTOCOL_VERSION = 1
30+
STATUS_OK = 200
31+
STATUS_BAD_REQUEST = 400
32+
STATUS_ERROR = 500
33+
34+
# Minimum output mmap size to avoid repeated ftruncate
35+
_MIN_OUTPUT_SIZE = 128 * 1024
36+
37+
38+
def handle_connection(
39+
conn: socket.socket,
40+
shared_registry: SharedRegistry,
41+
shutdown_event: threading.Event,
42+
) -> None:
43+
"""Handle a single client connection (runs in a thread pool worker)."""
44+
try:
45+
_handle_connection_inner(conn, shared_registry, shutdown_event)
46+
except Exception:
47+
logger.error(f'Connection error:\n{traceback.format_exc()}')
48+
finally:
49+
try:
50+
conn.close()
51+
except OSError:
52+
pass
53+
54+
55+
def _handle_connection_inner(
56+
conn: socket.socket,
57+
shared_registry: SharedRegistry,
58+
shutdown_event: threading.Event,
59+
) -> None:
60+
"""Inner connection handler (may raise)."""
61+
# --- Handshake ---
62+
# Receive 16 bytes: [version: u64 LE][namelen: u64 LE]
63+
header = _recv_exact(conn, 16)
64+
if header is None:
65+
return
66+
version, namelen = struct.unpack('<QQ', header)
67+
68+
if version != PROTOCOL_VERSION:
69+
logger.warning(f'Unsupported protocol version: {version}')
70+
return
71+
72+
# Receive function name + 2 FDs via SCM_RIGHTS
73+
fd_model = array.array('i', [0, 0])
74+
msg, ancdata, flags, addr = conn.recvmsg(
75+
namelen,
76+
socket.CMSG_LEN(2 * fd_model.itemsize),
77+
)
78+
if len(ancdata) != 1:
79+
logger.warning(f'Expected 1 ancdata, got {len(ancdata)}')
80+
return
81+
82+
function_name = msg.decode('utf8')
83+
input_fd, output_fd = struct.unpack('<ii', ancdata[0][2])
84+
85+
# --- Control signal path ---
86+
if function_name.startswith('@@'):
87+
logger.info(f"Received control signal '{function_name}'")
88+
_handle_control_signal(
89+
conn, function_name, input_fd, output_fd, shared_registry,
90+
)
91+
return
92+
93+
# --- UDF request loop ---
94+
logger.info(f"Received request for function '{function_name}'")
95+
_handle_udf_loop(
96+
conn, function_name, input_fd, output_fd,
97+
shared_registry, shutdown_event,
98+
)
99+
100+
101+
def _handle_control_signal(
102+
conn: socket.socket,
103+
signal_name: str,
104+
input_fd: int,
105+
output_fd: int,
106+
shared_registry: SharedRegistry,
107+
) -> None:
108+
"""Handle a @@-prefixed control signal (one-shot request-response)."""
109+
try:
110+
# Read 8-byte request length
111+
len_buf = _recv_exact(conn, 8)
112+
if len_buf is None:
113+
return
114+
length = struct.unpack('<Q', len_buf)[0]
115+
116+
# Read input data from mmap (if any)
117+
request_data = b''
118+
if length > 0:
119+
mem = mmap.mmap(
120+
input_fd, length, mmap.MAP_SHARED, mmap.PROT_READ,
121+
)
122+
try:
123+
request_data = mem[:length]
124+
finally:
125+
mem.close()
126+
127+
# Dispatch
128+
result = dispatch_control_signal(
129+
signal_name, request_data, shared_registry,
130+
)
131+
132+
if result.ok:
133+
# Write response to output mmap
134+
response_bytes = result.data.encode('utf8')
135+
response_size = len(response_bytes)
136+
os.ftruncate(output_fd, max(_MIN_OUTPUT_SIZE, response_size))
137+
os.lseek(output_fd, 0, os.SEEK_SET)
138+
os.write(output_fd, response_bytes)
139+
140+
# Send [status=200, size]
141+
conn.sendall(struct.pack('<QQ', STATUS_OK, response_size))
142+
logger.debug(
143+
f"Control signal '{signal_name}' succeeded, "
144+
f'{response_size} bytes',
145+
)
146+
else:
147+
# Send [status=400, len, message]
148+
err_bytes = result.data.encode('utf8')
149+
conn.sendall(
150+
struct.pack(
151+
f'<QQ{len(err_bytes)}s',
152+
STATUS_BAD_REQUEST, len(err_bytes), err_bytes,
153+
),
154+
)
155+
logger.warning(
156+
f"Control signal '{signal_name}' failed: {result.data}",
157+
)
158+
finally:
159+
os.close(input_fd)
160+
os.close(output_fd)
161+
162+
163+
def _handle_udf_loop(
164+
conn: socket.socket,
165+
function_name: str,
166+
input_fd: int,
167+
output_fd: int,
168+
shared_registry: SharedRegistry,
169+
shutdown_event: threading.Event,
170+
) -> None:
171+
"""Handle the UDF request loop for a single function."""
172+
# Track output mmap size to avoid repeated ftruncate
173+
current_output_size = 0
174+
175+
try:
176+
# Get thread-local registry
177+
registry = shared_registry.get_thread_local_registry()
178+
179+
while not shutdown_event.is_set():
180+
# Select-based recv with 100ms timeout for shutdown checks
181+
readable, _, _ = select.select([conn], [], [], 0.1)
182+
if not readable:
183+
continue
184+
185+
# Read 8-byte request length
186+
len_buf = _recv_exact(conn, 8)
187+
if len_buf is None:
188+
break
189+
length = struct.unpack('<Q', len_buf)[0]
190+
if length == 0:
191+
break
192+
193+
# Read input from mmap
194+
mem = mmap.mmap(
195+
input_fd, length, mmap.MAP_SHARED, mmap.PROT_READ,
196+
)
197+
try:
198+
input_data = bytes(mem[:length])
199+
finally:
200+
mem.close()
201+
202+
# Refresh registry if generation changed
203+
registry = shared_registry.get_thread_local_registry()
204+
205+
# Call function
206+
try:
207+
output_data = call_function(registry, function_name, input_data)
208+
209+
# Write result to output mmap
210+
response_size = len(output_data)
211+
needed = max(_MIN_OUTPUT_SIZE, response_size)
212+
if needed > current_output_size:
213+
os.ftruncate(output_fd, needed)
214+
current_output_size = needed
215+
os.lseek(output_fd, 0, os.SEEK_SET)
216+
os.write(output_fd, output_data)
217+
218+
# Send [status=200, size]
219+
conn.sendall(struct.pack('<QQ', STATUS_OK, response_size))
220+
221+
except Exception as e:
222+
error_msg = (
223+
f"error in function '{function_name}': {e}"
224+
)
225+
logger.error(error_msg)
226+
for line in traceback.format_exc().splitlines():
227+
logger.error(line)
228+
err_bytes = error_msg.encode('utf8')
229+
conn.sendall(
230+
struct.pack(
231+
f'<QQ{len(err_bytes)}s',
232+
STATUS_ERROR, len(err_bytes), err_bytes,
233+
),
234+
)
235+
break
236+
237+
finally:
238+
os.close(input_fd)
239+
os.close(output_fd)
240+
241+
242+
def _recv_exact(sock: socket.socket, n: int) -> bytes | None:
243+
"""Receive exactly n bytes, or return None on EOF."""
244+
buf = bytearray()
245+
while len(buf) < n:
246+
chunk = sock.recv(n - len(buf))
247+
if not chunk:
248+
return None
249+
buf.extend(chunk)
250+
return bytes(buf)

0 commit comments

Comments
 (0)