Skip to content

Commit 6963cd9

Browse files
authored
Fix SIGINT shutdown during active inference (#8993)
1 parent ab6f186 commit 6963cd9

2 files changed

Lines changed: 185 additions & 3 deletions

File tree

invokeai/app/services/session_processor/session_processor_default.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,6 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
133133

134134
self._on_after_run_node(invocation, queue_item, output)
135135

136-
except KeyboardInterrupt:
137-
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
138-
pass
139136
except CanceledException:
140137
# A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need
141138
# to do any handling here, and no error should be set - just pass and the cancellation will be handled
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
from contextlib import contextmanager
2+
from threading import Event
3+
4+
import pytest
5+
6+
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
7+
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionRunner
8+
from tests.dangerously_run_function_in_subprocess import dangerously_run_function_in_subprocess
9+
10+
11+
@invocation_output("test_interrupt_output")
12+
class InterruptTestOutput(BaseInvocationOutput):
13+
pass
14+
15+
16+
@invocation("test_keyboard_interrupt", version="1.0.0")
17+
class KeyboardInterruptInvocation(BaseInvocation):
18+
def invoke(self, context) -> InterruptTestOutput:
19+
raise KeyboardInterrupt
20+
21+
22+
class _DummyStats:
23+
@contextmanager
24+
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str):
25+
yield
26+
27+
28+
class _DummyEvents:
29+
def emit_invocation_started(self, queue_item, invocation) -> None:
30+
pass
31+
32+
def emit_invocation_complete(self, invocation, queue_item, output) -> None:
33+
pass
34+
35+
def emit_invocation_error(self, queue_item, invocation, error_type, error_message, error_traceback) -> None:
36+
pass
37+
38+
39+
class _DummyLogger:
40+
def debug(self, msg) -> None:
41+
pass
42+
43+
def error(self, msg) -> None:
44+
pass
45+
46+
47+
class _DummyConfig:
48+
node_cache_size = 0
49+
50+
51+
def _build_runner(monkeypatch: pytest.MonkeyPatch) -> DefaultSessionRunner:
52+
monkeypatch.setattr(
53+
"invokeai.app.services.session_processor.session_processor_default.build_invocation_context",
54+
lambda data, services, is_canceled: None,
55+
)
56+
57+
runner = DefaultSessionRunner()
58+
runner.start(
59+
services=type(
60+
"Services",
61+
(),
62+
{
63+
"performance_statistics": _DummyStats(),
64+
"events": _DummyEvents(),
65+
"logger": _DummyLogger(),
66+
"configuration": _DummyConfig(),
67+
},
68+
)(),
69+
cancel_event=Event(),
70+
)
71+
return runner
72+
73+
74+
def _build_queue_item(invocation: BaseInvocation):
75+
return type(
76+
"QueueItem",
77+
(),
78+
{
79+
"item_id": 1,
80+
"session_id": "test-session",
81+
"session": type("Session", (), {"prepared_source_mapping": {invocation.id: invocation.id}})(),
82+
},
83+
)()
84+
85+
86+
def test_run_node_propagates_keyboard_interrupt(monkeypatch: pytest.MonkeyPatch) -> None:
87+
runner = _build_runner(monkeypatch)
88+
invocation = KeyboardInterruptInvocation(id="node")
89+
queue_item = _build_queue_item(invocation)
90+
91+
with pytest.raises(KeyboardInterrupt):
92+
runner.run_node(invocation=invocation, queue_item=queue_item)
93+
94+
95+
def test_run_node_does_not_swallow_sigint_in_subprocess() -> None:
96+
def test_func():
97+
import os
98+
import signal
99+
import threading
100+
import time
101+
from contextlib import contextmanager
102+
from threading import Event
103+
104+
import invokeai.app.services.session_processor.session_processor_default as session_processor_default
105+
from invokeai.app.invocations.baseinvocation import (
106+
BaseInvocation,
107+
BaseInvocationOutput,
108+
invocation,
109+
invocation_output,
110+
)
111+
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionRunner
112+
113+
@invocation_output("test_interrupt_output_subprocess")
114+
class InterruptTestOutput(BaseInvocationOutput):
115+
pass
116+
117+
@invocation("test_sigint_during_node", version="1.0.0")
118+
class SigIntDuringNodeInvocation(BaseInvocation):
119+
def invoke(self, context) -> InterruptTestOutput:
120+
timer = threading.Thread(target=lambda: (time.sleep(0.1), os.kill(os.getpid(), signal.SIGINT)))
121+
timer.daemon = True
122+
timer.start()
123+
time.sleep(5)
124+
return InterruptTestOutput()
125+
126+
class DummyStats:
127+
@contextmanager
128+
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str):
129+
yield
130+
131+
class DummyEvents:
132+
def emit_invocation_started(self, queue_item, invocation) -> None:
133+
pass
134+
135+
def emit_invocation_complete(self, invocation, queue_item, output) -> None:
136+
pass
137+
138+
def emit_invocation_error(self, queue_item, invocation, error_type, error_message, error_traceback) -> None:
139+
pass
140+
141+
class DummyLogger:
142+
def debug(self, msg) -> None:
143+
pass
144+
145+
def error(self, msg) -> None:
146+
pass
147+
148+
class DummyConfig:
149+
node_cache_size = 0
150+
151+
session_processor_default.build_invocation_context = lambda data, services, is_canceled: None
152+
153+
runner = DefaultSessionRunner()
154+
runner.start(
155+
services=type(
156+
"Services",
157+
(),
158+
{
159+
"performance_statistics": DummyStats(),
160+
"events": DummyEvents(),
161+
"logger": DummyLogger(),
162+
"configuration": DummyConfig(),
163+
},
164+
)(),
165+
cancel_event=Event(),
166+
)
167+
168+
invocation = SigIntDuringNodeInvocation(id="node")
169+
queue_item = type(
170+
"QueueItem",
171+
(),
172+
{
173+
"item_id": 1,
174+
"session_id": "test-session",
175+
"session": type("Session", (), {"prepared_source_mapping": {invocation.id: invocation.id}})(),
176+
},
177+
)()
178+
179+
runner.run_node(invocation=invocation, queue_item=queue_item)
180+
print("swallowed")
181+
182+
stdout, stderr, returncode = dangerously_run_function_in_subprocess(test_func)
183+
184+
assert stdout.strip() == ""
185+
assert returncode != 0, stderr

0 commit comments

Comments
 (0)