|
41 | 41 | # Pre-pack the status OK header prefix to avoid per-request struct.pack |
42 | 42 | _STATUS_OK_PREFIX = struct.pack('<Q', STATUS_OK) |
43 | 43 |
|
| 44 | +# Maximum function name length to prevent resource exhaustion |
| 45 | +_MAX_FUNCTION_NAME_LEN = 4096 |
| 46 | + |
44 | 47 | # Enable per-request timing via environment variable |
45 | 48 | _PROFILE = os.environ.get('SINGLESTOREDB_UDF_PROFILE', '') == '1' |
46 | 49 |
|
@@ -83,18 +86,56 @@ def _handle_connection_inner( |
83 | 86 | logger.warning(f'Unsupported protocol version: {version}') |
84 | 87 | return |
85 | 88 |
|
| 89 | + if namelen > _MAX_FUNCTION_NAME_LEN: |
| 90 | + logger.warning(f'Function name too long: {namelen}') |
| 91 | + return |
| 92 | + |
86 | 93 | # Receive function name + 2 FDs via SCM_RIGHTS |
87 | 94 | fd_model = array.array('i', [0, 0]) |
88 | 95 | msg, ancdata, flags, addr = conn.recvmsg( |
89 | 96 | namelen, |
90 | 97 | socket.CMSG_LEN(2 * fd_model.itemsize), |
91 | 98 | ) |
92 | | - if len(ancdata) != 1: |
93 | | - logger.warning(f'Expected 1 ancdata, got {len(ancdata)}') |
94 | | - return |
95 | 99 |
|
96 | | - function_name = msg.decode('utf8') |
97 | | - input_fd, output_fd = struct.unpack('<ii', ancdata[0][2]) |
| 100 | + # Validate ancdata and extract FDs |
| 101 | + received_fds: list[int] = [] |
| 102 | + try: |
| 103 | + if len(ancdata) != 1: |
| 104 | + logger.warning(f'Expected 1 ancdata, got {len(ancdata)}') |
| 105 | + return |
| 106 | + |
| 107 | + level, type_, fd_data = ancdata[0] |
| 108 | + if level != socket.SOL_SOCKET or type_ != socket.SCM_RIGHTS: |
| 109 | + logger.warning( |
| 110 | + f'Unexpected ancdata level={level} type={type_}', |
| 111 | + ) |
| 112 | + return |
| 113 | + |
| 114 | + if flags & getattr(socket, 'MSG_CTRUNC', 0): |
| 115 | + logger.warning('Ancillary data was truncated (MSG_CTRUNC)') |
| 116 | + return |
| 117 | + |
| 118 | + fd_array = array.array('i') |
| 119 | + fd_array.frombytes(fd_data) |
| 120 | + received_fds = list(fd_array) |
| 121 | + |
| 122 | + if len(received_fds) != 2: |
| 123 | + logger.warning( |
| 124 | + f'Expected 2 FDs, got {len(received_fds)}', |
| 125 | + ) |
| 126 | + return |
| 127 | + |
| 128 | + function_name = msg.decode('utf8') |
| 129 | + input_fd, output_fd = received_fds[0], received_fds[1] |
| 130 | + # Clear so finally doesn't close FDs we're handing off |
| 131 | + received_fds = [] |
| 132 | + finally: |
| 133 | + # Close any received FDs if we're returning early |
| 134 | + for fd in received_fds: |
| 135 | + try: |
| 136 | + os.close(fd) |
| 137 | + except OSError: |
| 138 | + pass |
98 | 139 |
|
99 | 140 | # --- Control signal path --- |
100 | 141 | if function_name.startswith('@@'): |
@@ -160,7 +201,7 @@ def _handle_control_signal( |
160 | 201 | else: |
161 | 202 | os.ftruncate(output_fd, max(_MIN_OUTPUT_SIZE, response_size)) |
162 | 203 | os.lseek(output_fd, 0, os.SEEK_SET) |
163 | | - os.write(output_fd, response_bytes) |
| 204 | + _write_all_fd(output_fd, response_bytes) |
164 | 205 |
|
165 | 206 | # Send [status=200, size] |
166 | 207 | conn.sendall(struct.pack('<QQ', STATUS_OK, response_size)) |
@@ -298,7 +339,7 @@ def _handle_udf_loop( |
298 | 339 | os.ftruncate(output_fd, needed) |
299 | 340 | current_output_size = needed |
300 | 341 | os.lseek(output_fd, 0, os.SEEK_SET) |
301 | | - os.write(output_fd, output_data) |
| 342 | + _write_all_fd(output_fd, output_data) |
302 | 343 | if profile: |
303 | 344 | t_mmap_write += time.monotonic() - t0 |
304 | 345 |
|
@@ -359,3 +400,17 @@ def _recv_exact_py(sock: socket.socket, n: int) -> bytes | None: |
359 | 400 | return None |
360 | 401 | pos += nbytes |
361 | 402 | return bytes(buf) |
| 403 | + |
| 404 | + |
| 405 | +def _write_all_fd(fd: int, data: bytes) -> None: |
| 406 | + """Write all bytes to a file descriptor, handling partial writes.""" |
| 407 | + view = memoryview(data) |
| 408 | + written = 0 |
| 409 | + while written < len(data): |
| 410 | + try: |
| 411 | + n = os.write(fd, view[written:]) |
| 412 | + except InterruptedError: |
| 413 | + continue |
| 414 | + if n == 0: |
| 415 | + raise RuntimeError('short write to output fd') |
| 416 | + written += n |
0 commit comments