diff --git a/cluster_tools/cluster_tools/_utils/multiprocessing_logging_handler.py b/cluster_tools/cluster_tools/_utils/multiprocessing_logging_handler.py
index a40ee29ec..b5978d07e 100644
--- a/cluster_tools/cluster_tools/_utils/multiprocessing_logging_handler.py
+++ b/cluster_tools/cluster_tools/_utils/multiprocessing_logging_handler.py
@@ -15,6 +15,10 @@
 
 # Inspired by https://stackoverflow.com/a/894284
 
+# Since the root logger object is global, accessing it from multiple threads can be problematic.
+# To be thread-safe, we use a lock.
+ROOT_LOGGER_LOCK = threading.Lock()
+
 
 class _MultiprocessingLoggingHandler(logging.Handler):
     """This class wraps a logging handler and instantiates a multiprocessing queue.
@@ -87,14 +91,18 @@ def decrement_usage(self) -> None:
         self._usage_counter -= 1
         if self._usage_counter == 0:
             # unwrap inner handler:
-            root_logger = getLogger()
-            root_logger.removeHandler(self)
-            root_logger.addHandler(self.wrapped_handler)
+            with ROOT_LOGGER_LOCK:
+                root_logger = getLogger()
+                root_logger.removeHandler(self)
+                root_logger.addHandler(self.wrapped_handler)
 
             self._is_closed = True
             self._queue_thread.join(30)
             self._manager.shutdown()
-            self.wrapped_handler.close()
+            # Thread-owned handlers (e.g. task-specific file handlers) are closed by their
+            # owner (attach_logging_handler), not by the pool.
+            if not getattr(self.wrapped_handler, "_owner_thread_id", None):
+                self.wrapped_handler.close()
             super().close()
 
     def close(self) -> None:
@@ -115,35 +123,51 @@ def _setup_logging_multiprocessing(
     """
     warnings.filters = filters
 
-    root_logger = getLogger()
-    for handler in root_logger.handlers:
-        root_logger.removeHandler(handler)
+    with ROOT_LOGGER_LOCK:
+        root_logger = getLogger()
+        for handler in root_logger.handlers:
+            root_logger.removeHandler(handler)
 
-    root_logger.setLevel(min(levels) if len(levels) else logging.DEBUG)
-    for queue, level in zip(queues, levels):
-        handler = QueueHandler(queue)
-        handler.setLevel(level)
-        root_logger.addHandler(handler)
+        root_logger.setLevel(min(levels) if len(levels) else logging.DEBUG)
+        for queue, level in zip(queues, levels):
+            handler = QueueHandler(queue)
+            handler.setLevel(level)
+            root_logger.addHandler(handler)
 
 
 class _MultiprocessingLoggingHandlerPool:
     def __init__(self) -> None:
-        root_logger = getLogger()
-
-        self.handlers = []
-        for i, handler in enumerate(list(root_logger.handlers)):
-            # Wrap logging handlers in _MultiprocessingLoggingHandlers to make them work in a multiprocessing setup
-            # when using start_methods other than fork, for example, spawn or forkserver
-            if not isinstance(handler, _MultiprocessingLoggingHandler):
-                mp_handler = _MultiprocessingLoggingHandler(
-                    f"multi-processing-handler-{i}", handler
+        with ROOT_LOGGER_LOCK:
+            root_logger = getLogger()
+            current_thread_id = threading.get_ident()
+
+            self.handlers = []
+            for i, handler in enumerate(list(root_logger.handlers)):
+                # Resolve the underlying handler to check for thread ownership.
+                # An already-wrapped handler exposes its original via .wrapped_handler.
+                underlying = (
+                    handler.wrapped_handler
+                    if isinstance(handler, _MultiprocessingLoggingHandler)
+                    else handler
                 )
-                root_logger.removeHandler(handler)
-                root_logger.addHandler(mp_handler)
-                self.handlers.append(mp_handler)
-            else:
-                handler.increment_usage()
-                self.handlers.append(handler)
+                owner_thread = getattr(underlying, "_owner_thread_id", None)
+                # Skip handlers owned by a different thread: wrapping them would transfer
+                # lifecycle ownership, causing premature closure of another task's file handler.
+                if owner_thread is not None and owner_thread != current_thread_id:
+                    continue
+
+                # Wrap logging handlers in _MultiprocessingLoggingHandlers to make them work in a multiprocessing setup
+                # when using start_methods other than fork, for example, spawn or forkserver
+                if not isinstance(handler, _MultiprocessingLoggingHandler):
+                    mp_handler = _MultiprocessingLoggingHandler(
+                        f"multi-processing-handler-{i}", handler
+                    )
+                    root_logger.removeHandler(handler)
+                    root_logger.addHandler(mp_handler)
+                    self.handlers.append(mp_handler)
+                else:
+                    handler.increment_usage()
+                    self.handlers.append(handler)
 
     def get_multiprocessing_logging_setup_fn(self) -> Callable[[], None]:
         # Return a logging setup function that when called will setup QueueHandler loggers
diff --git a/cluster_tools/cluster_tools/schedulers/cluster_executor.py b/cluster_tools/cluster_tools/schedulers/cluster_executor.py
index 75c498656..fe57ba55e 100644
--- a/cluster_tools/cluster_tools/schedulers/cluster_executor.py
+++ b/cluster_tools/cluster_tools/schedulers/cluster_executor.py
@@ -1,3 +1,4 @@
+import atexit
 import logging
 import os
 import signal
@@ -69,6 +70,7 @@ class ClusterExecutor(futures.Executor):
 
     _shutdown_hooks: list[Callable[[], None]] = []
     _installed_signal_handler: bool = False
+    _installed_atexit_handler: bool = False
 
     def __init__(
         self,
@@ -142,9 +144,24 @@ def executor_key(cls) -> str:
 
     @classmethod
    def _ensure_signal_handlers_are_installed(cls) -> None:
-        # Only overwrite the signal handler once
+        if not cls._installed_atexit_handler:
+            atexit.register(cls._run_shutdown_hooks)
+            cls._installed_atexit_handler = True
+
+        # signal.signal() only works from the main thread. If we're on a worker
+        # thread, skip but don't mark the handlers as installed; a later main-thread
+        # instantiation will still be able to install them.
         if cls._installed_signal_handler:
             return
+        if threading.current_thread() is not threading.main_thread():
+            logging.warning(
+                f"[{cls.__name__}] Cannot install signal handlers because the executor "
+                "was instantiated from a non-main thread. Cleanup on SIGTERM will not "
+                "work; SIGINT and normal exits are covered via atexit. "
+                f"Call {cls.__name__}.install_signal_handlers() from the main thread "
+                "at startup to enable full signal handling."
+            )
+            return
 
         # Clean up if a SIGINT or SIGTERM signal is received. However, do not
         # interfere with the existing signal handler of the process and execute
@@ -162,6 +179,20 @@ def _ensure_signal_handlers_are_installed(cls) -> None:
 
         cls._installed_signal_handler = True
 
+    @classmethod
+    def install_signal_handlers(cls) -> None:
+        """Install SIGINT/SIGTERM handlers for cluster job cleanup.
+
+        Must be called from the main thread. Call this once at program startup
+        before using ClusterExecutor from worker threads to ensure that
+        SIGINT and SIGTERM trigger proper job cancellation and cleanup.
+        """
+        if threading.current_thread() is not threading.main_thread():
+            raise RuntimeError(
+                f"{cls.__name__}.install_signal_handlers() must be called from the main thread."
+            )
+        cls._ensure_signal_handlers_are_installed()
+
     @classmethod
     def _register_shutdown_hook(cls, hook: Callable[[], None]) -> None:
         cls._shutdown_hooks.append(hook)
@@ -176,6 +207,14 @@ def _deregister_shutdown_hook(cls, hook: Callable[[], None]) -> None:
                 "Cannot deregister executors shutdown hook since it's not registered."
             )
 
+    @classmethod
+    def _run_shutdown_hooks(cls) -> None:
+        for hook in cls._shutdown_hooks:
+            try:
+                hook()
+            except Exception as e:
+                print(f"Error during shutdown: {e}")
+
     @classmethod
     def _handle_shutdown(
         cls,
@@ -186,11 +225,7 @@
         logging.critical(
             f"[{cls.__name__}] Caught signal {signal.Signals(signum).name}, running shutdown hooks"
         )
-        try:
-            for hook in cls._shutdown_hooks:
-                hook()
-        except Exception as e:
-            print(f"Error during shutdown: {e}")
+        cls._run_shutdown_hooks()
 
         if (
             callable(existing_signal_handler)
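Note (reviewer sketch, not part of the patch): the handler-pool changes above identify thread-owned handlers purely through an _owner_thread_id attribute read via getattr. The snippet below illustrates that convention under stated assumptions; the helper name attach_task_file_handler is hypothetical (the diff only mentions attach_logging_handler in a comment), while the attribute name, getLogger(), and addHandler() usage come from the code above.

import logging
import threading

def attach_task_file_handler(log_path: str) -> logging.FileHandler:
    # Tag the handler with the owning thread's id so that pools created by
    # other threads skip it entirely and decrement_usage() never closes it;
    # the owning thread stays responsible for calling handler.close() itself.
    handler = logging.FileHandler(log_path)
    handler._owner_thread_id = threading.get_ident()
    logging.getLogger().addHandler(handler)
    return handler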
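Note (reviewer sketch, not part of the patch): the startup sequence that the new install_signal_handlers() docstring recommends, assuming executors are created later from worker threads. The import path follows the file touched above; run_jobs is a placeholder for whatever instantiates and uses a concrete executor.

import threading

from cluster_tools.schedulers.cluster_executor import ClusterExecutor

def run_jobs() -> None:
    # Placeholder worker: real code would create and use an executor here.
    # With this patch, doing so from a worker thread logs a warning and still
    # registers the atexit-based cleanup.
    pass

def main() -> None:
    # Install SIGINT/SIGTERM cleanup once, from the main thread, before any
    # worker thread creates an executor.
    ClusterExecutor.install_signal_handlers()

    worker = threading.Thread(target=run_jobs)
    worker.start()
    worker.join()

if __name__ == "__main__":
    main()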