66counter caching for thread-safe live reload.
77"""
88import importlib
9+ import json
910import logging
1011import multiprocessing
1112import os
1213import select
1314import signal
1415import socket
16+ import struct
1517import sys
1618import threading
1719import traceback
2830logger = logging .getLogger ('collocated.server' )
2931
3032
def _read_exact_from_fd(
    fd: int,
    size: int,
    max_chunk: int = 65536,
) -> Optional[bytes]:
    """Read exactly ``size`` bytes from ``fd``.

    Returns the bytes read, or None if EOF is reached before ``size``
    bytes have arrived. ``OSError`` from ``os.read`` propagates to the
    caller.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        # Cap each request so a corrupted length prefix cannot turn into
        # one enormous read request.
        chunk = os.read(fd, min(remaining, max_chunk))
        if not chunk:
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    # Join once instead of growing a bytes object with repeated +=.
    return b''.join(chunks)


def _read_pipe_message(fd: int) -> Optional[bytes]:
    """Read a length-prefixed message from a pipe fd.

    Wire format: [u32 LE length][payload].

    Returns the payload (``b''`` for a zero-length message), or None on
    EOF, on a short read (the writer closed mid-message), or when
    ``os.read`` raises ``OSError``.
    """
    try:
        header = _read_exact_from_fd(fd, 4)
        if header is None:
            return None
        (length,) = struct.unpack('<I', header)
        return _read_exact_from_fd(fd, length)
    except OSError:
        return None
56+
57+
def _write_pipe_message(fd: int, payload: bytes) -> None:
    """Write a length-prefixed message to a pipe fd.

    Wire format: [u32 LE length][payload].

    ``os.write`` may transfer fewer bytes than requested, so the frame
    is written in a loop until every byte has been handed to the kernel.
    A truncated frame would desynchronise the reader's framing.

    Raises:
        struct.error: if ``len(payload)`` does not fit in an unsigned
            32-bit integer.
        OSError: if the underlying write fails (e.g. the read end has
            been closed).
    """
    frame = memoryview(struct.pack('<I', len(payload)) + payload)
    written = 0
    while written < len(frame):
        # Slicing a memoryview does not copy the remaining bytes.
        written += os.write(fd, frame[written:])
65+
66+
3167class SharedRegistry :
3268 """Thread-safe wrapper around FunctionRegistry with generation caching.
3369
@@ -208,23 +244,66 @@ def _run_process_mode(
208244 server_sock : socket .socket ,
209245 n_workers : int ,
210246 ) -> None :
211- """Pre-fork worker pool for true CPU parallelism."""
212- ctx = multiprocessing .get_context ('fork' )
213- workers : Dict [int , multiprocessing .process .BaseProcess ] = {}
247+ """Pre-fork worker pool for true CPU parallelism.
214248
215- def _spawn_worker (
216- worker_id : int ,
217- ) -> multiprocessing .process .BaseProcess :
249+ Each worker gets a pipe back to the main process. When a worker
250+ receives @@register, it writes the registration payload to its
251+ pipe. The main process reads it, updates its own registry, then
252+ kills and re-forks all workers so every worker has the updated
253+ registry state.
254+ """
255+ ctx = multiprocessing .get_context ('fork' )
256+ # workers[wid] = (process, pipe_read_fd)
257+ workers : Dict [
258+ int ,
259+ Tuple [multiprocessing .process .BaseProcess , int ],
260+ ] = {}
261+
262+ def _spawn_worker (worker_id : int ) -> Tuple [
263+ multiprocessing .process .BaseProcess , int ,
264+ ]:
265+ pipe_r , pipe_w = os .pipe ()
218266 p = ctx .Process (
219267 target = self ._worker_process_main ,
220- args = (server_sock , worker_id ),
268+ args = (server_sock , worker_id , pipe_w ),
221269 daemon = True ,
222270 )
223271 p .start ()
272+ # Close the write end in the parent — only the child writes
273+ os .close (pipe_w )
224274 logger .info (
225275 f'Started worker { worker_id } (pid={ p .pid } )' ,
226276 )
227- return p
277+ return p , pipe_r
278+
def _kill_all_workers() -> None:
    """SIGTERM all workers, wait, then SIGKILL stragglers.

    Operates on the enclosing ``workers`` mapping
    (``wid -> (process, pipe_read_fd)``) in three passes:

    1. Send SIGTERM to every worker that is still alive.
    2. Join each worker for up to 5 seconds. Any worker still alive is
       sent SIGKILL and joined for a further 2 seconds.
    3. Close every pipe read fd held by the parent.

    The mapping itself is not modified; callers clear or replace it.
    """
    # Pass 1: signal everyone first so workers shut down in parallel
    # rather than one after another.
    for proc, _pipe_r in workers.values():
        pid = proc.pid
        if pid is not None and proc.is_alive():
            try:
                os.kill(pid, signal.SIGTERM)
            except ProcessLookupError:
                # Worker exited between the liveness check and the kill.
                pass

    # Pass 2: wait for graceful exit, then escalate.
    for wid, (proc, _pipe_r) in workers.items():
        proc.join(timeout=5.0)
        if proc.is_alive():
            logger.warning(
                f'Worker {wid} (pid={proc.pid}) '
                f'did not exit, terminating...',
            )
            # The worker already received SIGTERM above and ignored it;
            # sending SIGTERM again would not help, so send SIGKILL.
            pid = proc.pid
            if pid is not None:
                try:
                    os.kill(pid, signal.SIGKILL)
                except ProcessLookupError:
                    pass
            proc.join(timeout=2.0)

    # Pass 3: close all pipe read fds. Best-effort: an fd that is
    # already closed must not abort the shutdown sequence.
    for _proc, pipe_r in workers.values():
        try:
            os.close(pipe_r)
        except OSError:
            pass
300+
def _respawn_all_workers() -> None:
    """Replace the whole worker pool with freshly forked workers.

    Stops every current worker via ``_kill_all_workers``, empties the
    ``workers`` mapping, then spawns ``n_workers`` new workers with ids
    ``0 .. n_workers - 1``.
    """
    _kill_all_workers()
    workers.clear()
    # Spawn one at a time so each entry is recorded as soon as its
    # worker has been started.
    for worker_id in range(n_workers):
        replacement = _spawn_worker(worker_id)
        workers[worker_id] = replacement
228307
229308 # Fork initial workers
230309 logger .info (
@@ -233,11 +312,59 @@ def _spawn_worker(
233312 for i in range (n_workers ):
234313 workers [i ] = _spawn_worker (i )
235314
236- # Monitor loop: restart dead workers
315+ # Monitor loop using poll() over pipe read fds
237316 try :
238317 while not self .shutdown_event .is_set ():
239- self .shutdown_event .wait (timeout = 0.5 )
240- for wid , proc in list (workers .items ()):
318+ poller = select .poll ()
319+ fd_to_wid : Dict [int , int ] = {}
320+ for wid , (proc , pipe_r ) in workers .items ():
321+ poller .register (
322+ pipe_r , select .POLLIN | select .POLLHUP ,
323+ )
324+ fd_to_wid [pipe_r ] = wid
325+
326+ events = poller .poll (500 ) # 500ms timeout
327+
328+ registration_received = False
329+ for fd , event in events :
330+ if fd not in fd_to_wid :
331+ continue
332+ wid = fd_to_wid [fd ]
333+
334+ if event & select .POLLIN :
335+ msg = _read_pipe_message (fd )
336+ if msg is not None :
337+ # Apply registration to main's registry
338+ try :
339+ body = json .loads (msg )
340+ self .shared_registry .create_function (
341+ body ['signature_json' ],
342+ body ['code' ],
343+ body ['replace' ],
344+ )
345+ logger .info (
346+ 'Main process: applied '
347+ '@@register from worker '
348+ f'{ wid } , will re-fork all '
349+ 'workers' ,
350+ )
351+ registration_received = True
352+ except Exception :
353+ logger .error (
354+ 'Main process: failed to '
355+ 'apply @@register:\n '
356+ f'{ traceback .format_exc ()} ' ,
357+ )
358+ elif event & select .POLLHUP :
359+ # Worker died — will be respawned below
360+ pass
361+
362+ if registration_received :
363+ _respawn_all_workers ()
364+ continue
365+
366+ # Check for dead workers and respawn individually
367+ for wid , (proc , pipe_r ) in list (workers .items ()):
241368 if not proc .is_alive ():
242369 exitcode = proc .exitcode
243370 if not self .shutdown_event .is_set ():
@@ -246,39 +373,29 @@ def _spawn_worker(
246373 f'exited with code { exitcode } , '
247374 f'restarting...' ,
248375 )
376+ try :
377+ os .close (pipe_r )
378+ except OSError :
379+ pass
249380 workers [wid ] = _spawn_worker (wid )
250381 finally :
251382 logger .info ('Shutting down worker processes...' )
252- # Signal all workers to stop
253- for wid , proc in workers .items ():
254- if proc .is_alive ():
255- assert proc .pid is not None
256- os .kill (proc .pid , signal .SIGTERM )
257-
258- # Wait for graceful exit
259- for wid , proc in workers .items ():
260- proc .join (timeout = 5.0 )
261- if proc .is_alive ():
262- logger .warning (
263- f'Worker { wid } (pid={ proc .pid } ) '
264- f'did not exit, terminating...' ,
265- )
266- proc .terminate ()
267- proc .join (timeout = 2.0 )
383+ _kill_all_workers ()
268384
269385 def _worker_process_main (
270386 self ,
271387 server_sock : socket .socket ,
272388 worker_id : int ,
389+ pipe_w : int ,
273390 ) -> None :
274- """Entry point for each forked worker process."""
275- try :
276- # Each worker gets its own registry and shutdown event
277- local_shared = SharedRegistry ()
278- local_registry = FunctionRegistry ()
279- local_registry .initialize ()
280- local_shared .set_base_registry (local_registry )
391+ """Entry point for each forked worker process.
281392
393+ Uses ``self.shared_registry`` inherited via fork (contains the
394+ main process's current state). ``pipe_w`` is used to notify the
395+ main process when @@register is handled so it can re-fork all
396+ workers.
397+ """
398+ try :
282399 local_shutdown = threading .Event ()
283400
284401 def _worker_signal_handler (
@@ -297,9 +414,10 @@ def _worker_signal_handler(
297414 # non-blocking accept and the parent doesn't call accept.
298415 server_sock .setblocking (False )
299416
417+ registry = self .shared_registry .get_thread_local_registry ()
300418 logger .info (
301419 f'Worker { worker_id } (pid={ os .getpid ()} ) ready, '
302- f'{ len (local_registry .functions )} function(s)' ,
420+ f'{ len (registry .functions )} function(s)' ,
303421 )
304422
305423 # Accept loop
@@ -322,15 +440,21 @@ def _worker_signal_handler(
322440
323441 handle_connection (
324442 conn ,
325- local_shared ,
443+ self . shared_registry ,
326444 local_shutdown ,
445+ pipe_write_fd = pipe_w ,
327446 )
328447 except Exception :
329448 logger .error (
330449 f'Worker { worker_id } crashed:\n '
331450 f'{ traceback .format_exc ()} ' ,
332451 )
333452 raise
453+ finally :
454+ try :
455+ os .close (pipe_w )
456+ except OSError :
457+ pass
334458
335459 def _initialize_registry (self ) -> FunctionRegistry :
336460 """Import the extension module and discover @udf functions."""
0 commit comments