Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

# Inspired by https://stackoverflow.com/a/894284

# Since the root logger object is global, accessing it from multiple threads can be problematic
# To be thread safe, we use a lock.
ROOT_LOGGER_LOCK = threading.Lock()


class _MultiprocessingLoggingHandler(logging.Handler):
"""This class wraps a logging handler and instantiates a multiprocessing queue.
def decrement_usage(self) -> None:
    """Release one user of this wrapper; tear it down when the last user leaves.

    The counter decrement, the zero check, and the restore/close of the
    wrapped handler all happen under ROOT_LOGGER_LOCK so the sequence is
    atomic with respect to other threads that mutate the root logger's
    handlers (e.g. a concurrently constructed handler pool). Without the
    lock, a concurrent pool could observe the restored wrapped handler and
    re-wrap it just before it is closed here, leaving the new pool with a
    closed, broken handler.
    """
    with ROOT_LOGGER_LOCK:
        self._usage_counter -= 1
        is_last_user = self._usage_counter == 0
        if is_last_user:
            # unwrap inner handler: put the original handler back onto the
            # root logger before this wrapper goes away.
            root_logger = getLogger()
            root_logger.removeHandler(self)
            root_logger.addHandler(self.wrapped_handler)
            # Thread-owned handlers (e.g. task-specific file handlers) are
            # closed by their owner (attach_logging_handler), not by the pool.
            if not getattr(self.wrapped_handler, "_owner_thread_id", None):
                self.wrapped_handler.close()
    if is_last_user:
        # Shut down the queue machinery outside the lock: join()/shutdown()
        # can block and must not stall other threads that only need to touch
        # the root logger's handler list.
        self._is_closed = True
        self._queue_thread.join(30)
        self._manager.shutdown()
        super().close()
Comment on lines 91 to 106
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The decrement_usage method has several thread-safety and lifecycle issues:

  1. Non-atomic counter: The decrement and check of self._usage_counter (lines 91-92) are not synchronized. If shutdown() is called concurrently on the same executor from different threads, the counter could end up in an inconsistent state.
  2. Race condition on close: The wrapped_handler.close() call (line 105) happens outside the ROOT_LOGGER_LOCK. A new pool could be initialized in another thread, see the wrapped_handler (which was added back to the root logger at line 97), and wrap it before it is closed here. This would leave the new pool with a closed and broken handler.
  3. Global handler lifecycle: Closing handlers that don't have an _owner_thread_id (lines 104-105) will destroy global handlers (like a standard StreamHandler on the root logger) once the first executor finishes. This breaks logging for the rest of the application.

It is recommended to move the counter logic and the restoration/closure of the handler inside the ROOT_LOGGER_LOCK. Also, consider if the pool should be closing these handlers at all, as they were not created by the pool.

Suggested change
self._usage_counter -= 1
if self._usage_counter == 0:
# unwrap inner handler:
root_logger = getLogger()
root_logger.removeHandler(self)
root_logger.addHandler(self.wrapped_handler)
with ROOT_LOGGER_LOCK:
root_logger = getLogger()
root_logger.removeHandler(self)
root_logger.addHandler(self.wrapped_handler)
self._is_closed = True
self._queue_thread.join(30)
self._manager.shutdown()
self.wrapped_handler.close()
# Thread-owned handlers (e.g. task-specific file handlers) are closed by their
# owner (attach_logging_handler), not by the pool.
if not getattr(self.wrapped_handler, "_owner_thread_id", None):
self.wrapped_handler.close()
super().close()
with ROOT_LOGGER_LOCK:
self._usage_counter -= 1
if self._usage_counter == 0:
# unwrap inner handler:
root_logger = getLogger()
root_logger.removeHandler(self)
root_logger.addHandler(self.wrapped_handler)
# Thread-owned handlers (e.g. task-specific file handlers) are closed by their
# owner (attach_logging_handler), not by the pool.
if not getattr(self.wrapped_handler, "_owner_thread_id", None):
self.wrapped_handler.close()
if self._usage_counter == 0:
self._is_closed = True
self._queue_thread.join(30)
self._manager.shutdown()
super().close()


def close(self) -> None:
Expand All @@ -115,35 +123,51 @@ def _setup_logging_multiprocessing(
"""
warnings.filters = filters

root_logger = getLogger()
for handler in root_logger.handlers:
root_logger.removeHandler(handler)
with ROOT_LOGGER_LOCK:
root_logger = getLogger()
for handler in root_logger.handlers:
root_logger.removeHandler(handler)

root_logger.setLevel(min(levels) if len(levels) else logging.DEBUG)
for queue, level in zip(queues, levels):
handler = QueueHandler(queue)
handler.setLevel(level)
root_logger.addHandler(handler)
root_logger.setLevel(min(levels) if len(levels) else logging.DEBUG)
for queue, level in zip(queues, levels):
handler = QueueHandler(queue)
handler.setLevel(level)
root_logger.addHandler(handler)


class _MultiprocessingLoggingHandlerPool:
def __init__(self) -> None:
    """Wrap the root logger's handlers for use with non-fork start methods.

    Holds ROOT_LOGGER_LOCK for the whole scan-and-swap so that a concurrent
    pool construction or teardown cannot interleave with the handler-list
    mutation.
    """
    with ROOT_LOGGER_LOCK:
        root_logger = getLogger()
        current_thread_id = threading.get_ident()

        self.handlers = []
        for i, handler in enumerate(list(root_logger.handlers)):
            # Resolve the underlying handler to check for thread ownership.
            # An already-wrapped handler exposes its original via .wrapped_handler.
            underlying = (
                handler.wrapped_handler
                if isinstance(handler, _MultiprocessingLoggingHandler)
                else handler
            )
            owner_thread = getattr(underlying, "_owner_thread_id", None)
            # Skip handlers owned by a different thread: wrapping them would transfer
            # lifecycle ownership, causing premature closure of another task's file handler.
            if owner_thread is not None and owner_thread != current_thread_id:
                continue

            # Wrap logging handlers in _MultiprocessingLoggingHandlers to make them work in a multiprocessing setup
            # when using start_methods other than fork, for example, spawn or forkserver
            if not isinstance(handler, _MultiprocessingLoggingHandler):
                mp_handler = _MultiprocessingLoggingHandler(
                    f"multi-processing-handler-{i}", handler
                )
                root_logger.removeHandler(handler)
                root_logger.addHandler(mp_handler)
                self.handlers.append(mp_handler)
            else:
                # Already wrapped by another pool: share it via reference counting.
                handler.increment_usage()
                self.handlers.append(handler)

def get_multiprocessing_logging_setup_fn(self) -> Callable[[], None]:
# Return a logging setup function that when called will setup QueueHandler loggers
Expand Down
47 changes: 41 additions & 6 deletions cluster_tools/cluster_tools/schedulers/cluster_executor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import atexit
import logging
import os
import signal
Expand Down Expand Up @@ -69,6 +70,7 @@ class ClusterExecutor(futures.Executor):

_shutdown_hooks: list[Callable[[], None]] = []
_installed_signal_handler: bool = False
_installed_atexit_handler: bool = False

def __init__(
self,
Expand Down Expand Up @@ -142,9 +144,24 @@ def executor_key(cls) -> str:

@classmethod
def _ensure_signal_handlers_are_installed(cls) -> None:
# Only overwrite the signal handler once
if not cls._installed_atexit_handler:
atexit.register(cls._run_shutdown_hooks)
cls._installed_atexit_handler = True

# signal.signal() only works from the main thread. If we're on a worker
# thread, skip but don't mark as installed — a later main-thread
# instantiation will still be able to install the handlers.
if cls._installed_signal_handler:
return
if threading.current_thread() is not threading.main_thread():
logging.warning(
f"[{cls.__name__}] Cannot install signal handlers because the executor "
"was instantiated from a non-main thread. Cleanup on SIGTERM will not "
"work; SIGINT and normal exits are covered via atexit. "
f"Call {cls.__name__}.install_signal_handlers() from the main thread "
"at startup to enable full signal handling."
)
return

# Clean up if a SIGINT or SIGTERM signal is received. However, do not
# interfere with the existing signal handler of the process and execute
Expand All @@ -162,6 +179,20 @@ def _ensure_signal_handlers_are_installed(cls) -> None:

cls._installed_signal_handler = True

@classmethod
def install_signal_handlers(cls) -> None:
    """Install SIGINT/SIGTERM handlers for cluster job cleanup.

    This must run on the main thread (signal.signal() is restricted to it).
    Invoke it once during program startup — before any ClusterExecutor is
    created from a worker thread — so that SIGINT and SIGTERM reliably
    trigger job cancellation and cleanup.

    Raises:
        RuntimeError: if called from any thread other than the main thread.
    """
    on_main_thread = threading.current_thread() is threading.main_thread()
    if not on_main_thread:
        raise RuntimeError(
            f"{cls.__name__}.install_signal_handlers() must be called from the main thread."
        )
    cls._ensure_signal_handlers_are_installed()

@classmethod
def _register_shutdown_hook(cls, hook: Callable[[], None]) -> None:
cls._shutdown_hooks.append(hook)
Expand All @@ -176,6 +207,14 @@ def _deregister_shutdown_hook(cls, hook: Callable[[], None]) -> None:
"Cannot deregister executors shutdown hook since it's not registered."
)

@classmethod
def _run_shutdown_hooks(cls) -> None:
    """Invoke every registered shutdown hook, isolating failures per hook."""
    for hook in cls._shutdown_hooks:
        # One failing hook must not prevent the remaining hooks from running,
        # so each invocation gets its own try/except.
        try:
            hook()
        except Exception as e:
            print(f"Error during shutdown: {e}")

@classmethod
def _handle_shutdown(
cls,
Expand All @@ -186,11 +225,7 @@ def _handle_shutdown(
logging.critical(
f"[{cls.__name__}] Caught signal {signal.Signals(signum).name}, running shutdown hooks"
)
try:
for hook in cls._shutdown_hooks:
hook()
except Exception as e:
print(f"Error during shutdown: {e}")
cls._run_shutdown_hooks()

if (
callable(existing_signal_handler)
Expand Down
Loading