Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
New Functionality
^^^^^^^^^^^^^^^^^

- Implement hot-restart functionality for Multi-user endpoints (MEPs). See
:ref:`hot-restart` for full documentation; in short, send the ``SIGHUP``
signal to the MEP (parent) process. Currently, there is no equivalent
built-in sub-command of ``globus-compute-endpoint``.
5 changes: 5 additions & 0 deletions compute_endpoint/globus_compute_endpoint/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ def _do_start_endpoint(
reg_info = {}
config_str: str | None = None
audit_fd: int | None = None
restart_fd: int | None = None
fn_allow_list: list[str] | None | int = _no_fn_list_canary
if sys.stdin and not (sys.stdin.closed or sys.stdin.isatty()):
try:
Expand All @@ -593,6 +594,7 @@ def _do_start_endpoint(
reg_info = stdin_data.get("amqp_creds", {})
config_str = stdin_data.get("config")
audit_fd = stdin_data.get("audit_fd")
restart_fd = stdin_data.get("restart_fd")
fn_allow_list = stdin_data.get("allowed_functions", _no_fn_list_canary)

del stdin_data # clarity for intended scope
Expand Down Expand Up @@ -639,7 +641,10 @@ def _do_start_endpoint(
raise ClickException(
"multi-user endpoints are not supported on this system"
)

epm = EndpointManager(ep_dir, endpoint_uuid, ep_config, reg_info)
if restart_fd:
epm._finish_hot_restart(restart_fd)
epm.start()
else:
assert isinstance(ep_config, UserEndpointConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import os
import pathlib
import pickle
import platform
import pwd
import queue
Expand Down Expand Up @@ -157,8 +158,8 @@ def __init__(
else:
_import_pyprctl()

self._reload_requested = False
self._time_to_stop = False
self._restart = False

self._heartbeat_period: float = max(MINIMUM_HEARTBEAT, config.heartbeat_period)

Expand All @@ -174,7 +175,7 @@ def __init__(
self._cached_cmd_start_args: TTLCache[int, T_CMD_START_ARGS] = TTLCache(
maxsize=32768, ttl=config.mu_child_ep_grace_period_s
)
self._audit_pipes: dict[int, t.Any] = {}
self._audit_pipes: dict[int, dict[str, int | str]] = {}
self._audit_log_handler_stop = not (
self._config.high_assurance and bool(self._config.audit_log_path)
)
Expand Down Expand Up @@ -372,6 +373,9 @@ def get_metadata(self, config: ManagerEndpointConfig) -> dict:
"user_config_schema": user_config_schema,
}

def request_restart(self, sig_num, curr_stack_frame):
    """Signal-handler callback: record that a hot restart was requested.

    Only raises the ``_restart`` flag; the signal number and stack frame
    supplied by the ``signal`` machinery are not used.
    """
    self._restart = True

def request_shutdown(self, sig_num, curr_stack_frame):
    """Signal-handler callback: record that a shutdown was requested.

    Only raises the ``_time_to_stop`` flag; the signal number and stack
    frame supplied by the ``signal`` machinery are not used.
    """
    self._time_to_stop = True

Expand Down Expand Up @@ -488,12 +492,13 @@ def _audit_log_write(self, fd: int, fpath: io.BytesIO):
uid = uep_audit_info.get("uid")
eid = uep_audit_info.get("endpoint_id")
try:
msg = (
os.read(fd, self._audit_buf_size)
.replace(b"\n", b" ")
.replace(b"\r", b"")
.replace(b"\0", b"")
)
with self._audit_log_lock:
msg = (
os.read(fd, self._audit_buf_size)
.replace(b"\n", b" ")
.replace(b"\r", b"")
.replace(b"\0", b"")
)
if not msg:
self._audit_log_close_reader(fd)
return
Expand All @@ -511,6 +516,7 @@ def _audit_log_write(self, fd: int, fpath: io.BytesIO):
log.error(f"Failed to write audit log message: [{uid=}, {eid=}] - {e_str}")

def _install_signal_handlers(self):
signal.signal(signal.SIGHUP, self.request_restart)
signal.signal(signal.SIGTERM, self.request_shutdown)
signal.signal(signal.SIGINT, self.request_shutdown)
signal.signal(signal.SIGQUIT, self.request_shutdown)
Expand Down Expand Up @@ -629,6 +635,74 @@ def start(self):
# re-enable cursor visibility
print("\033[?25h", end="", file=msg_out)

def hot_restart(self):
    """Re-exec this manager process in place, handing its state to the new image.

    Steps, in order:

    1. Create an anonymous in-memory file (memfd) that is *not* close-on-exec,
       so the descriptor is still open after ``exec()``.
    2. Set the command stop event, stop the heartbeat publisher, and join
       the command subscriber.
    3. Replace stdin (fd 0) with the read end of a pipe pre-loaded with a JSON
       document carrying the queue info and the memfd number.
    4. While holding the audit-log lock: optionally append an "End MEP session"
       record to the audit log, pickle ``_audit_pipes``, ``_children`` and
       ``_cached_cmd_start_args`` into the memfd, rewind it, and exec
       ``sys.executable`` with the original ``sys.argv`` and environment.

    Nothing here is wrapped in ``try``/``except``; any failure (including a
    failed exec) propagates to the caller.
    """
    log.info("Manager hot hot_restart requested")
    r_fd = os.memfd_create("hot_restart", flags=0)  # 0 == *not* CLOEXEC

    # JSON payload delivered via stdin; JSON can only carry the fd *number*,
    # the pickled state itself travels through the memfd.
    stdin_data = {
        "amqp_creds": {
            "endpoint_id": self._endpoint_uuid_str,
            "command_queue_info": self._command.queue_info,
            "heartbeat_queue_info": self._heartbeat_publisher.queue_info,
        },
        "restart_fd": r_fd,
    }
    # Quiesce the command/heartbeat machinery before replacing the process.
    self._command_stop_event.set()
    self._heartbeat_publisher.stop()
    self._command.join()

    # Make fd 0 the read end of a fresh pipe and preload it with the payload;
    # both original pipe descriptors are closed afterwards, leaving only the
    # duplicate on fd 0.
    r, w = os.pipe()
    os.dup2(r, 0)
    os.write(w, json.dumps(stdin_data).encode())
    os.close(w)
    os.close(r)

    # The lock is held from here through the exec call and is never released
    # by this code path on success.
    with self._audit_log_lock:
        if not self._audit_log_handler_stop:
            nowtz = datetime.now().astimezone().isoformat()
            uid = os.getuid()
            pid = os.getpid()
            eid = self._endpoint_uuid_str
            msg = (
                f"{nowtz} uid={uid} pid={pid} eid={eid} End MEP session"
                f" [hot restart] .....\n"
            )
            # Unbuffered append so the record is written before the exec.
            with open(self._config.audit_log_path, "ab", buffering=0) as audit_f:
                audit_f.write(msg.encode())

        # only thread of consequence that we block; will be restarted in new exec();
        # AMQP will resend any interim received tasks because we won't ACK them.
        state = {
            "_audit_pipes": self._audit_pipes,
            "_children": self._children,
            "_cached_cmd_start_args": self._cached_cmd_start_args,
        }
        os.write(r_fd, pickle.dumps(state))
        os.fsync(r_fd)
        # Rewind so a reader of this fd starts at the beginning of the pickle.
        os.lseek(r_fd, 0, os.SEEK_SET)
        args = [sys.executable, *sys.argv]

        num_children = len(self._children)
        log.info(
            f"\n.......... Manager hot restarting {self._endpoint_uuid_str}"
            f" (task processors: {num_children})\n"
        )
        os.execvpe(args[0], args=args, env=os.environ)

def _finish_hot_restart(self, fd: int):
    """Merge state pickled by a previous process image back into this manager.

    Reads the whole pickle from ``fd`` (which is closed afterwards), updates
    ``_audit_pipes``, ``_children`` and ``_cached_cmd_start_args`` with the
    matching entries (a missing key contributes nothing), then registers
    every audit pipe with the audit selector for read events.

    NOTE(review): ``pickle.loads`` is only safe because the fd is expected to
    be the one produced by ``hot_restart``; never pass externally supplied data.
    """
    with os.fdopen(fd, "rb") as state_file:
        raw_state = state_file.read()
    restored: dict = pickle.loads(raw_state)

    # Update in place (rather than replace) so existing entries are retained.
    for attr_name in ("_audit_pipes", "_children", "_cached_cmd_start_args"):
        getattr(self, attr_name).update(restored.get(attr_name, {}))

    for pipe_fd in self._audit_pipes:
        self._audit_selector.register(
            pipe_fd, selectors.EVENT_READ, self._audit_log_write
        )

def _event_loop(self):
parent_identities: set[str] = set()
if not is_privileged():
Expand Down Expand Up @@ -668,6 +742,11 @@ def _event_loop(self):
if self._wait_for_child:
self.wait_for_children()

if self._restart:
# not protected; if exec() fails, then this raises and we shutdown
# ... "Failure is not an option!"
self.hot_restart()

if time.monotonic() - last_heartbeat >= self._heartbeat_period:
self.send_heartbeat()
last_heartbeat = time.monotonic()
Expand Down
148 changes: 146 additions & 2 deletions compute_endpoint/tests/unit/test_endpointmanager_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import logging
import os
import pathlib
import pickle
import pwd
import queue
import random
import re
import resource
import selectors
import signal
import sys
import time
Expand Down Expand Up @@ -49,6 +51,7 @@
EndpointManager,
InvalidUserError,
MappedPosixIdentity,
UserEndpointRecord,
)


Expand Down Expand Up @@ -267,6 +270,7 @@ def epmanager_as_root(
mock_os.pipe.return_value = 40, 41
mock_os.dup2.side_effect = (0, 1, 2, AssertionError("dup2: unexpected?"))
mock_os.open.side_effect = (4, 5, AssertionError("open: unexpected?"))
mock_os.memfd_create.return_value = random.randint(50, 10000)

mock_pwd = mocker.patch(f"{_MOCK_BASE}pwd")
mock_pwd.getpwnam.side_effect = (
Expand Down Expand Up @@ -295,8 +299,8 @@ def epmanager_as_root(
mock_auth_client.userinfo.return_value = {"identity_set": [{"sub": ident}]}

em = EndpointManager(conf_dir, ep_uuid, mock_conf_root)
em._command = mock.Mock(spec=CommandQueueSubscriber)
em._heartbeat_publisher = mock.Mock(spec=ResultPublisher)
em._command = mock.Mock(spec=CommandQueueSubscriber, queue_info={})
em._heartbeat_publisher = mock.Mock(spec=ResultPublisher, queue_info={})

yield conf_dir, mock_conf_root, mock_client, mock_os, mock_pwd, em
if em.identity_mapper:
Expand Down Expand Up @@ -2543,3 +2547,143 @@ def _called(fn_name):

assert pyexc.value.code == _GOOD_EC, "Q&D: verify we exec'ed, based on '+= 1'"
assert pamh.pam_close_session.called


def test_restart_signal(successful_exec_from_mocked_root, reset_signals):
    """SIGHUP must raise the restart flag and lead the event loop to restart."""
    *_, em = successful_exec_from_mocked_root

    # MemoryError is just a recognizable sentinel proving that the event loop
    # reached .hot_restart() (and a convenient way out of the loop).
    em.hot_restart = mock.Mock(side_effect=MemoryError)
    em._install_signal_handlers()
    assert not em._restart, "Verify test setup"

    own_pid = os.getpid()
    os.kill(own_pid, signal.SIGHUP)

    with pytest.raises(MemoryError):
        em._event_loop()

    assert em._restart, "Ensure class state, but main thing is .hot_restart() invoked"


def test_restart_restarts(successful_exec_from_mocked_root, randomstring):
    """``hot_restart`` must exec the original command line and environment."""
    mock_os, *_, em = successful_exec_from_mocked_root

    env_marker = randomstring()
    mock_os.environ = {"canary": env_marker}

    em.hot_restart()

    assert mock_os.execvpe.called, "Basic correctness"

    exec_positional, exec_keywords = mock_os.execvpe.call_args
    expected_argv = [sys.executable, *sys.argv]
    assert (expected_argv[0],) == exec_positional, "Expect repeat of initial args"
    assert exec_keywords["args"] == expected_argv, "Expect repeat of initial args"

    relayed_env = exec_keywords["env"]
    assert relayed_env["canary"] == env_marker, "Expect to relay environment variables"


def test_restart_conveys_state(successful_exec_from_mocked_root, randomstring):
    """``hot_restart`` must hand credentials (stdin) and state (memfd) onward.

    Two ``os.write`` calls are expected: the first carries the JSON document
    for the new process's stdin, the second the pickled manager state.
    """
    mock_os, *_, em = successful_exec_from_mocked_root

    em._audit_pipes[123] = {"pid": random.randint(1, 1000000)}
    em._children[123] = UserEndpointRecord(ep_name="abc", arguments="some_args")
    em._cached_cmd_start_args[123] = randomstring()
    em._command.queue_info = {"canary": randomstring()}
    em._heartbeat_publisher.queue_info = {"canary": randomstring()}
    em.hot_restart()

    assert mock_os.execvpe.called, "Basic correctness"
    assert mock_os.write.call_count == 2, "Verify test setup, expected writes"

    pipe_r, pipe_w = mock_os.pipe.return_value
    (stdin_fd, stdin_bytes), _ = mock_os.write.call_args_list[0]
    (mem_fd, conveyed), _ = mock_os.write.call_args_list[1]

    assert stdin_fd == pipe_w, "Expect write to new proc stdin"
    # Expect the pipe's read end to become the new process's stdin.  (A
    # trailing ", 'message'" after a mock assertion only builds a throwaway
    # tuple, so the explanation lives in this comment instead.)
    mock_os.dup2.assert_called_with(pipe_r, 0)
    stdin = json.loads(stdin_bytes)
    creds = stdin.get("amqp_creds")
    assert creds, "Expect reconnection credentials; no need to relogin"
    assert creds["endpoint_id"] == em._endpoint_uuid_str
    assert creds["command_queue_info"] == em._command.queue_info
    assert creds["heartbeat_queue_info"] == em._heartbeat_publisher.queue_info
    assert stdin.get("restart_fd") == mem_fd, "Hot restarted requires a state file"

    assert mem_fd == mock_os.memfd_create.return_value, "Should write *anonymous* file"

    state = pickle.loads(conveyed)
    assert state["_audit_pipes"] == em._audit_pipes
    assert state["_children"] == em._children
    assert state["_cached_cmd_start_args"] == em._cached_cmd_start_args


def test_restart_repopulates_state(successful_exec_from_mocked_root, randomstring):
    """State pickled by ``hot_restart`` is merged back by ``_finish_hot_restart``.

    The manager's containers are swapped for canary-only dicts between the
    two calls, so the final assertions show the restored entries were
    *added* to whatever was already present rather than replacing it.
    """
    mock_os, *_, em = successful_exec_from_mocked_root

    canary = randomstring()
    audit_pipes = {123: {"pid": random.randint(1, 1000000)}}
    children = {123: UserEndpointRecord(ep_name="abc", arguments="some_args")}
    cached_args = {123: randomstring()}
    em._audit_selector = mock.Mock(spec=selectors.DefaultSelector)
    em._audit_pipes = audit_pipes
    em._children = children
    em._cached_cmd_start_args = cached_args

    em.hot_restart()
    em._audit_pipes = {10000: canary}
    em._children = {10000: canary}
    em._cached_cmd_start_args = {10000: canary}

    # The second os.write carries the pickled state; serve it back through
    # the mocked os.fdopen as if it were read from the memfd.
    (mem_fd, conveyed), _ = mock_os.write.call_args_list[1]
    mem_f = io.BytesIO(conveyed)
    mem_f.seek(0)
    mock_os.fdopen.return_value = mem_f

    em._finish_hot_restart(mem_fd)
    # Expect the passed fd to be the one opened.  (A trailing ", 'message'"
    # after a mock assertion only builds a throwaway tuple, so the
    # explanation lives in this comment instead.)
    mock_os.fdopen.assert_called_with(mem_fd, "rb")
    assert em._audit_pipes[10000] == canary, "Expect updated, not overwritten"
    assert em._children[10000] == canary, "Expect updated, not overwritten"
    assert em._cached_cmd_start_args[10000] == canary, "Expect updated, not overwritten"
    del em._audit_pipes[10000], em._children[10000], em._cached_cmd_start_args[10000]

    assert em._audit_pipes == audit_pipes
    assert em._children == children
    assert em._cached_cmd_start_args == cached_args

    all_args = {
        fd: (evt, cb) for (fd, evt, cb), _ in em._audit_selector.register.call_args_list
    }

    exp_args = (selectors.EVENT_READ, em._audit_log_write)
    for audit_fd in em._audit_pipes:
        assert all_args[audit_fd] == exp_args, "Expect reregistration of audit pipes"


def test_restart_audit_pipes_protected(successful_exec_from_mocked_root):
    """Both the exec call and audit-pipe reads must run under the audit lock."""
    mock_os, *_, em = successful_exec_from_mocked_root

    fake_lock = mock.MagicMock()
    em._audit_pipes[123] = {"pid": 1235}
    em._audit_log_lock = fake_lock

    def _check_lock_is_held(*a, **k):
        # Used as a side effect: runs in place of the mocked os call, so the
        # lock must have been entered and not yet exited at this moment.
        assert em._audit_log_lock.__enter__.called
        assert not em._audit_log_lock.__exit__.called, "Expect locked at during call"
        return b"some audit bytes"

    # Case 1: the lock is held while exec'ing the replacement process.
    mock_os.execvpe.side_effect = _check_lock_is_held
    em.hot_restart()
    assert fake_lock.__enter__.called, "Verify test setup"

    # Case 2: the lock is held while reading from an audit pipe.
    mock_os.read.side_effect = _check_lock_is_held
    fake_lock.reset_mock()
    em._audit_log_write(123, mock.Mock())
    assert fake_lock.__enter__.called, "Verify test setup"


def test_restart_logs(successful_exec_from_mocked_root, mock_log):
    """``hot_restart`` must log both the request and the final restart notice."""
    *_, em = successful_exec_from_mocked_root

    em.hot_restart()

    # Every info() call is expected to carry exactly one positional argument.
    info_messages = [
        f"{message}" for (message,), _kwargs in mock_log.info.call_args_list
    ]
    i_logs = "\n".join(info_messages)

    assert "hot hot_restart requested" in i_logs, "Expect initial signal acknowledged"
    assert ".......... Manager hot restarting" in i_logs, "Expect last message"
    assert " (task processors: 0)" in i_logs, "Expect friendly count for admin"
Loading