Skip to content

Commit 34e2452

Browse files
kesmit13 and claude committed
Address PR #121 review comments: memory safety, correctness, hardening
accel.c:
- Replace empty TODO type stubs with NotImplementedError raises
- Add CHECK_REMAINING macro for bounds checking on buffer reads
- Replace unaligned pointer-cast reads with memcpy for WASM/ARM safety
- Fix double-decref in output error paths (set to NULL before goto)
- Fix Py_None reference leak by removing pre-switch INCREF
- Fix MYSQL_TYPE_NULL consuming an extra byte from next column
- Add PyErr_Format in default switch cases
- Add PyErr_Occurred() checks after PyLong/PyFloat conversions

Python:
- Align list/tuple multi-return handling in registry.py with C path
- Add _write_all_fd helper for partial os.write() handling
- Harden handshake recvmsg: name length bound, ancdata validation, MSG_CTRUNC check, FD cleanup on error
- Wrap get_context('fork') with platform safety error
- Narrow events.py exception catch to (ImportError, OSError)
- Fix _iquery DataFrame check ordering (check before list())
- Expand setblocking(False) warning comment
- Update WIT and wasm.py docstrings for code parameter

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b66653a commit 34e2452

File tree

8 files changed

+269
-86
lines changed

8 files changed

+269
-86
lines changed

accel.c

Lines changed: 176 additions & 63 deletions
Large diffs are not rendered by default.

singlestoredb/connection.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,12 +1166,13 @@ def _iquery(
11661166
cur.execute(oper, params)
11671167
if not re.match(r'^\s*(select|show|call|echo)\s+', oper, flags=re.I):
11681168
return []
1169-
out = list(cur.fetchall())
1169+
raw = cur.fetchall()
1170+
if hasattr(raw, 'to_dict') and callable(raw.to_dict):
1171+
return raw.to_dict(orient='records')
1172+
out = list(raw)
11701173
if not out:
11711174
return []
1172-
if hasattr(out, 'to_dict') and callable(getattr(out, 'to_dict')):
1173-
out = out.to_dict(orient='records')
1174-
elif isinstance(out[0], (tuple, list)):
1175+
if isinstance(out[0], (tuple, list)):
11751176
if cur.description:
11761177
names = [x[0] for x in cur.description]
11771178
if fix_names:

singlestoredb/functions/ext/collocated/connection.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
# Pre-pack the status OK header prefix to avoid per-request struct.pack
4242
_STATUS_OK_PREFIX = struct.pack('<Q', STATUS_OK)
4343

44+
# Maximum function name length to prevent resource exhaustion
45+
_MAX_FUNCTION_NAME_LEN = 4096
46+
4447
# Enable per-request timing via environment variable
4548
_PROFILE = os.environ.get('SINGLESTOREDB_UDF_PROFILE', '') == '1'
4649

@@ -83,18 +86,56 @@ def _handle_connection_inner(
8386
logger.warning(f'Unsupported protocol version: {version}')
8487
return
8588

89+
if namelen > _MAX_FUNCTION_NAME_LEN:
90+
logger.warning(f'Function name too long: {namelen}')
91+
return
92+
8693
# Receive function name + 2 FDs via SCM_RIGHTS
8794
fd_model = array.array('i', [0, 0])
8895
msg, ancdata, flags, addr = conn.recvmsg(
8996
namelen,
9097
socket.CMSG_LEN(2 * fd_model.itemsize),
9198
)
92-
if len(ancdata) != 1:
93-
logger.warning(f'Expected 1 ancdata, got {len(ancdata)}')
94-
return
9599

96-
function_name = msg.decode('utf8')
97-
input_fd, output_fd = struct.unpack('<ii', ancdata[0][2])
100+
# Validate ancdata and extract FDs
101+
received_fds: list[int] = []
102+
try:
103+
if len(ancdata) != 1:
104+
logger.warning(f'Expected 1 ancdata, got {len(ancdata)}')
105+
return
106+
107+
level, type_, fd_data = ancdata[0]
108+
if level != socket.SOL_SOCKET or type_ != socket.SCM_RIGHTS:
109+
logger.warning(
110+
f'Unexpected ancdata level={level} type={type_}',
111+
)
112+
return
113+
114+
if flags & getattr(socket, 'MSG_CTRUNC', 0):
115+
logger.warning('Ancillary data was truncated (MSG_CTRUNC)')
116+
return
117+
118+
fd_array = array.array('i')
119+
fd_array.frombytes(fd_data)
120+
received_fds = list(fd_array)
121+
122+
if len(received_fds) != 2:
123+
logger.warning(
124+
f'Expected 2 FDs, got {len(received_fds)}',
125+
)
126+
return
127+
128+
function_name = msg.decode('utf8')
129+
input_fd, output_fd = received_fds[0], received_fds[1]
130+
# Clear so finally doesn't close FDs we're handing off
131+
received_fds = []
132+
finally:
133+
# Close any received FDs if we're returning early
134+
for fd in received_fds:
135+
try:
136+
os.close(fd)
137+
except OSError:
138+
pass
98139

99140
# --- Control signal path ---
100141
if function_name.startswith('@@'):
@@ -160,7 +201,7 @@ def _handle_control_signal(
160201
else:
161202
os.ftruncate(output_fd, max(_MIN_OUTPUT_SIZE, response_size))
162203
os.lseek(output_fd, 0, os.SEEK_SET)
163-
os.write(output_fd, response_bytes)
204+
_write_all_fd(output_fd, response_bytes)
164205

165206
# Send [status=200, size]
166207
conn.sendall(struct.pack('<QQ', STATUS_OK, response_size))
@@ -298,7 +339,7 @@ def _handle_udf_loop(
298339
os.ftruncate(output_fd, needed)
299340
current_output_size = needed
300341
os.lseek(output_fd, 0, os.SEEK_SET)
301-
os.write(output_fd, output_data)
342+
_write_all_fd(output_fd, output_data)
302343
if profile:
303344
t_mmap_write += time.monotonic() - t0
304345

@@ -359,3 +400,17 @@ def _recv_exact_py(sock: socket.socket, n: int) -> bytes | None:
359400
return None
360401
pos += nbytes
361402
return bytes(buf)
403+
404+
405+
def _write_all_fd(fd: int, data: bytes) -> None:
406+
"""Write all bytes to a file descriptor, handling partial writes."""
407+
view = memoryview(data)
408+
written = 0
409+
while written < len(data):
410+
try:
411+
n = os.write(fd, view[written:])
412+
except InterruptedError:
413+
continue
414+
if n == 0:
415+
raise RuntimeError('short write to output fd')
416+
written += n

singlestoredb/functions/ext/collocated/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def call_function(
458458
results = []
459459
for row in rows:
460460
result = func(*row)
461-
if not isinstance(result, tuple):
461+
if not isinstance(result, (tuple, list)):
462462
result = [result]
463463
results.append(list(result))
464464
return bytes(_dump_rowdat_1(return_types, row_ids, results))

singlestoredb/functions/ext/collocated/server.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import Optional
2525
from typing import Tuple
2626

27+
from .connection import _write_all_fd
2728
from .connection import handle_connection
2829
from .registry import FunctionRegistry
2930

@@ -61,7 +62,7 @@ def _write_pipe_message(fd: int, payload: bytes) -> None:
6162
Wire format: [u32 LE length][payload].
6263
"""
6364
header = struct.pack('<I', len(payload))
64-
os.write(fd, header + payload)
65+
_write_all_fd(fd, header + payload)
6566

6667

6768
class SharedRegistry:
@@ -252,7 +253,14 @@ def _run_process_mode(
252253
kills and re-forks all workers so every worker has the updated
253254
registry state.
254255
"""
255-
ctx = multiprocessing.get_context('fork')
256+
try:
257+
ctx = multiprocessing.get_context('fork')
258+
except ValueError:
259+
raise RuntimeError(
260+
"Process mode requires 'fork' multiprocessing context, "
261+
'which is not available on this platform. '
262+
"Use process_mode='thread' instead.",
263+
)
256264
# workers[wid] = (process, pipe_read_fd)
257265
workers: Dict[
258266
int,
@@ -407,11 +415,13 @@ def _worker_signal_handler(
407415
signal.signal(signal.SIGTERM, _worker_signal_handler)
408416
signal.signal(signal.SIGINT, signal.SIG_IGN)
409417

410-
# Set non-blocking so accept() raises BlockingIOError
411-
# instead of blocking when another worker wins the race.
412-
# O_NONBLOCK is on the open file description (shared across
413-
# forked processes), but that's fine — all workers want
414-
# non-blocking accept and the parent doesn't call accept.
418+
# WARNING: setblocking(False) sets O_NONBLOCK on the open
419+
# file description, which is shared across all forked
420+
# processes. This is intentional here — all workers need
421+
# non-blocking accept() to handle the thundering-herd race,
422+
# and the parent process never calls accept() on this
423+
# socket. Do NOT add blocking operations on this socket
424+
# in the parent process after workers are forked.
415425
server_sock.setblocking(False)
416426

417427
registry = self.shared_registry.get_thread_local_registry()

singlestoredb/functions/ext/collocated/wasm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ def create_function(
6060
code: str,
6161
replace: bool,
6262
) -> None:
63-
"""Register a function from its signature and Python source code."""
63+
"""Register a function from its signature and function body.
64+
65+
The ``code`` parameter should contain the function body, not a
66+
full ``def`` statement or ``@udf``-decorated source.
67+
"""
6468
try:
6569
_registry.create_function(signature, code, replace)
6670
except Exception as e:

singlestoredb/utils/events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
try:
88
from IPython import get_ipython
99
has_ipython = True
10-
except Exception:
10+
except (ImportError, OSError):
1111
has_ipython = False
1212

1313

wit/udf.wit

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ interface function-handler {
1414
/// args_data_format, returns_data_format, function_type, doc
1515
describe-functions: func() -> result<string, string>;
1616

17-
/// Register a function from its signature and Python source code.
17+
/// Register a function from its signature and source code.
1818
/// `signature` is a JSON object matching the describe-functions element schema.
19-
/// `code` is the Python source containing the @udf-decorated function.
19+
/// `code` is the function body (not a full `def` statement).
2020
/// `replace` controls whether an existing function of the same name is overwritten.
2121
create-function: func(signature: string, code: string, replace: bool) -> result<_, string>;
2222
}

0 commit comments

Comments
 (0)