Skip to content

Commit e4d8627

Browse files
kesmit13 and claude committed
Fix @@register propagation in collocated server process mode
Each forked worker previously created its own independent SharedRegistry and FunctionRegistry. When @@register arrived at a worker, only that worker's local registry was updated — the main process and sibling workers never learned about the new function.

Add Unix pipe-based IPC (matching the R UDF server fix): each worker gets a pipe back to the main process. When a worker handles @@register, it writes the registration payload to its pipe. The main process detects the message via select.poll(), reads it, applies the registration to its own SharedRegistry, then kills and re-forks all workers so they inherit the updated state.

Thread mode is unaffected — pipe_write_fd is None, so the pipe write is skipped.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 248f897 commit e4d8627

File tree

3 files changed

+189
-40
lines changed

3 files changed

+189
-40
lines changed

singlestoredb/functions/ext/collocated/connection.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,13 @@ def handle_connection(
3939
conn: socket.socket,
4040
shared_registry: SharedRegistry,
4141
shutdown_event: threading.Event,
42+
pipe_write_fd: int | None = None,
4243
) -> None:
4344
"""Handle a single client connection (runs in a thread pool worker)."""
4445
try:
45-
_handle_connection_inner(conn, shared_registry, shutdown_event)
46+
_handle_connection_inner(
47+
conn, shared_registry, shutdown_event, pipe_write_fd,
48+
)
4649
except Exception:
4750
logger.error(f'Connection error:\n{traceback.format_exc()}')
4851
finally:
@@ -56,6 +59,7 @@ def _handle_connection_inner(
5659
conn: socket.socket,
5760
shared_registry: SharedRegistry,
5861
shutdown_event: threading.Event,
62+
pipe_write_fd: int | None = None,
5963
) -> None:
6064
"""Inner connection handler (may raise)."""
6165
# --- Handshake ---
@@ -87,6 +91,7 @@ def _handle_connection_inner(
8791
logger.info(f"Received control signal '{function_name}'")
8892
_handle_control_signal(
8993
conn, function_name, input_fd, output_fd, shared_registry,
94+
pipe_write_fd,
9095
)
9196
return
9297

@@ -104,6 +109,7 @@ def _handle_control_signal(
104109
input_fd: int,
105110
output_fd: int,
106111
shared_registry: SharedRegistry,
112+
pipe_write_fd: int | None = None,
107113
) -> None:
108114
"""Handle a @@-prefixed control signal (one-shot request-response)."""
109115
try:
@@ -126,7 +132,7 @@ def _handle_control_signal(
126132

127133
# Dispatch
128134
result = dispatch_control_signal(
129-
signal_name, request_data, shared_registry,
135+
signal_name, request_data, shared_registry, pipe_write_fd,
130136
)
131137

132138
if result.ok:

singlestoredb/functions/ext/collocated/control.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def dispatch_control_signal(
2929
signal_name: str,
3030
request_data: bytes,
3131
shared_registry: SharedRegistry,
32+
pipe_write_fd: int | None = None,
3233
) -> ControlResult:
3334
"""Dispatch a control signal to the appropriate handler."""
3435
try:
@@ -37,7 +38,9 @@ def dispatch_control_signal(
3738
elif signal_name == '@@functions':
3839
return _handle_functions(shared_registry)
3940
elif signal_name == '@@register':
40-
return _handle_register(request_data, shared_registry)
41+
return _handle_register(
42+
request_data, shared_registry, pipe_write_fd,
43+
)
4144
else:
4245
return ControlResult(
4346
ok=False,
@@ -62,8 +65,14 @@ def _handle_functions(shared_registry: SharedRegistry) -> ControlResult:
6265
def _handle_register(
6366
request_data: bytes,
6467
shared_registry: SharedRegistry,
68+
pipe_write_fd: int | None = None,
6569
) -> ControlResult:
66-
"""Handle @@register: register a new function dynamically."""
70+
"""Handle @@register: register a new function dynamically.
71+
72+
If ``pipe_write_fd`` is not None (process mode), the registration
73+
payload is written to the pipe so the main process can update its
74+
own registry and re-fork all workers.
75+
"""
6776
if not request_data:
6877
return ControlResult(ok=False, data='Missing registration payload')
6978

@@ -111,5 +120,15 @@ def _handle_register(
111120
except Exception as e:
112121
return ControlResult(ok=False, data=str(e))
113122

123+
# Notify main process so it can re-fork workers with updated state
124+
if pipe_write_fd is not None:
125+
from .server import _write_pipe_message
126+
payload = json.dumps({
127+
'signature_json': signature,
128+
'code': func_body,
129+
'replace': replace,
130+
}).encode()
131+
_write_pipe_message(pipe_write_fd, payload)
132+
114133
logger.info(f"@@register: added function '{function_name}'")
115134
return ControlResult(ok=True, data='{"status":"ok"}')

singlestoredb/functions/ext/collocated/server.py

Lines changed: 160 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
counter caching for thread-safe live reload.
77
"""
88
import importlib
9+
import json
910
import logging
1011
import multiprocessing
1112
import os
1213
import select
1314
import signal
1415
import socket
16+
import struct
1517
import sys
1618
import threading
1719
import traceback
@@ -28,6 +30,40 @@
2830
logger = logging.getLogger('collocated.server')
2931

3032

33+
def _read_exact(fd: int, count: int) -> Optional[bytes]:
    """Read exactly ``count`` bytes from ``fd``.

    Returns the bytes read, or None if EOF is reached before ``count``
    bytes have arrived. ``OSError`` from ``os.read`` is propagated to the
    caller.
    """
    # bytearray grows in place; repeated ``bytes +=`` would copy the
    # whole buffer on every chunk.
    buf = bytearray()
    while len(buf) < count:
        chunk = os.read(fd, count - len(buf))
        if not chunk:
            # Empty read means the write end was closed.
            return None
        buf += chunk
    return bytes(buf)


def _read_pipe_message(fd: int) -> Optional[bytes]:
    """Read a length-prefixed message from a pipe fd.

    Wire format: [u32 LE length][payload].

    Returns the payload (``b''`` for a zero-length message), or None on
    EOF, short read, or any ``OSError`` raised while reading.
    """
    try:
        header = _read_exact(fd, 4)
        if header is None:
            return None
        (length,) = struct.unpack('<I', header)
        return _read_exact(fd, length)
    except OSError:
        return None
56+
57+
58+
def _write_pipe_message(fd: int, payload: bytes) -> None:
    """Write a length-prefixed message to a pipe fd.

    Wire format: [u32 LE length][payload].

    ``os.write`` returns the number of bytes actually written, which may
    be fewer than requested. The frame is therefore written in a loop
    until every byte has been accepted; a truncated frame would leave the
    reader out of sync with the length prefix.

    Raises:
        OSError: if the underlying ``os.write`` fails (e.g. the read end
            has been closed).
        struct.error: if ``len(payload)`` does not fit in an unsigned
            32-bit integer.
    """
    # NOTE(review): frames larger than PIPE_BUF are not written
    # atomically; if several threads can share one pipe fd, their frames
    # could interleave — confirm against callers.
    frame = memoryview(struct.pack('<I', len(payload)) + payload)
    total = len(frame)
    offset = 0
    while offset < total:
        # Slicing a memoryview does not copy the remaining bytes.
        offset += os.write(fd, frame[offset:])
65+
66+
3167
class SharedRegistry:
3268
"""Thread-safe wrapper around FunctionRegistry with generation caching.
3369
@@ -208,23 +244,66 @@ def _run_process_mode(
208244
server_sock: socket.socket,
209245
n_workers: int,
210246
) -> None:
211-
"""Pre-fork worker pool for true CPU parallelism."""
212-
ctx = multiprocessing.get_context('fork')
213-
workers: Dict[int, multiprocessing.process.BaseProcess] = {}
247+
"""Pre-fork worker pool for true CPU parallelism.
214248
215-
def _spawn_worker(
216-
worker_id: int,
217-
) -> multiprocessing.process.BaseProcess:
249+
Each worker gets a pipe back to the main process. When a worker
250+
receives @@register, it writes the registration payload to its
251+
pipe. The main process reads it, updates its own registry, then
252+
kills and re-forks all workers so every worker has the updated
253+
registry state.
254+
"""
255+
ctx = multiprocessing.get_context('fork')
256+
# workers[wid] = (process, pipe_read_fd)
257+
workers: Dict[
258+
int,
259+
Tuple[multiprocessing.process.BaseProcess, int],
260+
] = {}
261+
262+
def _spawn_worker(worker_id: int) -> Tuple[
263+
multiprocessing.process.BaseProcess, int,
264+
]:
265+
pipe_r, pipe_w = os.pipe()
218266
p = ctx.Process(
219267
target=self._worker_process_main,
220-
args=(server_sock, worker_id),
268+
args=(server_sock, worker_id, pipe_w),
221269
daemon=True,
222270
)
223271
p.start()
272+
# Close the write end in the parent — only the child writes
273+
os.close(pipe_w)
224274
logger.info(
225275
f'Started worker {worker_id} (pid={p.pid})',
226276
)
227-
return p
277+
return p, pipe_r
278+
279+
def _kill_all_workers() -> None:
280+
"""SIGTERM all workers, wait, then SIGKILL stragglers."""
281+
for wid, (proc, pipe_r) in workers.items():
282+
if proc.is_alive():
283+
assert proc.pid is not None
284+
os.kill(proc.pid, signal.SIGTERM)
285+
for wid, (proc, pipe_r) in workers.items():
286+
proc.join(timeout=5.0)
287+
if proc.is_alive():
288+
logger.warning(
289+
f'Worker {wid} (pid={proc.pid}) '
290+
f'did not exit, terminating...',
291+
)
292+
proc.terminate()
293+
proc.join(timeout=2.0)
294+
# Close all pipe read fds
295+
for wid, (proc, pipe_r) in workers.items():
296+
try:
297+
os.close(pipe_r)
298+
except OSError:
299+
pass
300+
301+
def _respawn_all_workers() -> None:
302+
"""Kill all workers and re-fork them with fresh state."""
303+
_kill_all_workers()
304+
workers.clear()
305+
for i in range(n_workers):
306+
workers[i] = _spawn_worker(i)
228307

229308
# Fork initial workers
230309
logger.info(
@@ -233,11 +312,59 @@ def _spawn_worker(
233312
for i in range(n_workers):
234313
workers[i] = _spawn_worker(i)
235314

236-
# Monitor loop: restart dead workers
315+
# Monitor loop using poll() over pipe read fds
237316
try:
238317
while not self.shutdown_event.is_set():
239-
self.shutdown_event.wait(timeout=0.5)
240-
for wid, proc in list(workers.items()):
318+
poller = select.poll()
319+
fd_to_wid: Dict[int, int] = {}
320+
for wid, (proc, pipe_r) in workers.items():
321+
poller.register(
322+
pipe_r, select.POLLIN | select.POLLHUP,
323+
)
324+
fd_to_wid[pipe_r] = wid
325+
326+
events = poller.poll(500) # 500ms timeout
327+
328+
registration_received = False
329+
for fd, event in events:
330+
if fd not in fd_to_wid:
331+
continue
332+
wid = fd_to_wid[fd]
333+
334+
if event & select.POLLIN:
335+
msg = _read_pipe_message(fd)
336+
if msg is not None:
337+
# Apply registration to main's registry
338+
try:
339+
body = json.loads(msg)
340+
self.shared_registry.create_function(
341+
body['signature_json'],
342+
body['code'],
343+
body['replace'],
344+
)
345+
logger.info(
346+
'Main process: applied '
347+
'@@register from worker '
348+
f'{wid}, will re-fork all '
349+
'workers',
350+
)
351+
registration_received = True
352+
except Exception:
353+
logger.error(
354+
'Main process: failed to '
355+
'apply @@register:\n'
356+
f'{traceback.format_exc()}',
357+
)
358+
elif event & select.POLLHUP:
359+
# Worker died — will be respawned below
360+
pass
361+
362+
if registration_received:
363+
_respawn_all_workers()
364+
continue
365+
366+
# Check for dead workers and respawn individually
367+
for wid, (proc, pipe_r) in list(workers.items()):
241368
if not proc.is_alive():
242369
exitcode = proc.exitcode
243370
if not self.shutdown_event.is_set():
@@ -246,39 +373,29 @@ def _spawn_worker(
246373
f'exited with code {exitcode}, '
247374
f'restarting...',
248375
)
376+
try:
377+
os.close(pipe_r)
378+
except OSError:
379+
pass
249380
workers[wid] = _spawn_worker(wid)
250381
finally:
251382
logger.info('Shutting down worker processes...')
252-
# Signal all workers to stop
253-
for wid, proc in workers.items():
254-
if proc.is_alive():
255-
assert proc.pid is not None
256-
os.kill(proc.pid, signal.SIGTERM)
257-
258-
# Wait for graceful exit
259-
for wid, proc in workers.items():
260-
proc.join(timeout=5.0)
261-
if proc.is_alive():
262-
logger.warning(
263-
f'Worker {wid} (pid={proc.pid}) '
264-
f'did not exit, terminating...',
265-
)
266-
proc.terminate()
267-
proc.join(timeout=2.0)
383+
_kill_all_workers()
268384

269385
def _worker_process_main(
270386
self,
271387
server_sock: socket.socket,
272388
worker_id: int,
389+
pipe_w: int,
273390
) -> None:
274-
"""Entry point for each forked worker process."""
275-
try:
276-
# Each worker gets its own registry and shutdown event
277-
local_shared = SharedRegistry()
278-
local_registry = FunctionRegistry()
279-
local_registry.initialize()
280-
local_shared.set_base_registry(local_registry)
391+
"""Entry point for each forked worker process.
281392
393+
Uses ``self.shared_registry`` inherited via fork (contains the
394+
main process's current state). ``pipe_w`` is used to notify the
395+
main process when @@register is handled so it can re-fork all
396+
workers.
397+
"""
398+
try:
282399
local_shutdown = threading.Event()
283400

284401
def _worker_signal_handler(
@@ -297,9 +414,10 @@ def _worker_signal_handler(
297414
# non-blocking accept and the parent doesn't call accept.
298415
server_sock.setblocking(False)
299416

417+
registry = self.shared_registry.get_thread_local_registry()
300418
logger.info(
301419
f'Worker {worker_id} (pid={os.getpid()}) ready, '
302-
f'{len(local_registry.functions)} function(s)',
420+
f'{len(registry.functions)} function(s)',
303421
)
304422

305423
# Accept loop
@@ -322,15 +440,21 @@ def _worker_signal_handler(
322440

323441
handle_connection(
324442
conn,
325-
local_shared,
443+
self.shared_registry,
326444
local_shutdown,
445+
pipe_write_fd=pipe_w,
327446
)
328447
except Exception:
329448
logger.error(
330449
f'Worker {worker_id} crashed:\n'
331450
f'{traceback.format_exc()}',
332451
)
333452
raise
453+
finally:
454+
try:
455+
os.close(pipe_w)
456+
except OSError:
457+
pass
334458

335459
def _initialize_registry(self) -> FunctionRegistry:
336460
"""Import the extension module and discover @udf functions."""

0 commit comments

Comments
 (0)