Skip to content

Commit 5671eb1

Browse files
fix(highs): stop solve on Ctrl-C (#536)
Run HiGHS in a worker thread so the main thread can catch KeyboardInterrupt and signal cancelSolve(), preventing orphaned/continuing solves.
1 parent 643d29a commit 5671eb1

2 files changed

Lines changed: 116 additions & 1 deletion

File tree

linopy/solvers.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import re
1414
import subprocess as sub
1515
import sys
16+
import threading
1617
import warnings
1718
from abc import ABC, abstractmethod
1819
from collections import namedtuple
@@ -56,6 +57,73 @@
5657

5758
which = "where" if os.name == "nt" else "which"
5859

60+
61+
def _run_highs_with_keyboard_interrupt(h: Any) -> None:
62+
"""
63+
Run `highspy.Highs.run()` while ensuring Ctrl-C cancels the solve.
64+
65+
HiGHS can run for a long time inside a C-extension call. Running it in a
66+
worker thread allows the main thread to reliably receive KeyboardInterrupt
67+
and signal HiGHS to stop via `cancelSolve()`.
68+
"""
69+
70+
handle_keyboard_interrupt = getattr(h, "HandleKeyboardInterrupt", None)
71+
handle_user_interrupt = getattr(h, "HandleUserInterrupt", None)
72+
73+
old_handle_keyboard_interrupt = (
74+
handle_keyboard_interrupt if not callable(handle_keyboard_interrupt) else None
75+
)
76+
old_handle_user_interrupt = (
77+
handle_user_interrupt if not callable(handle_user_interrupt) else None
78+
)
79+
80+
try:
81+
if callable(handle_keyboard_interrupt):
82+
handle_keyboard_interrupt(True)
83+
elif handle_keyboard_interrupt is not None:
84+
h.HandleKeyboardInterrupt = True
85+
86+
if callable(handle_user_interrupt):
87+
handle_user_interrupt(True)
88+
elif handle_user_interrupt is not None:
89+
h.HandleUserInterrupt = True
90+
91+
finished = threading.Event()
92+
run_error: BaseException | None = None
93+
94+
def _target() -> None:
95+
nonlocal run_error
96+
try:
97+
h.run()
98+
except BaseException as exc: # pragma: no cover
99+
run_error = exc
100+
finally:
101+
finished.set()
102+
103+
thread = threading.Thread(target=_target, name="linopy-highs-run", daemon=True)
104+
thread.start()
105+
106+
try:
107+
while not finished.wait(0.1):
108+
pass
109+
except KeyboardInterrupt:
110+
cancel_solve = getattr(h, "cancelSolve", None)
111+
if callable(cancel_solve):
112+
with contextlib.suppress(Exception):
113+
cancel_solve()
114+
while not finished.wait(0.1):
115+
pass
116+
raise
117+
118+
if run_error is not None:
119+
raise run_error
120+
finally:
121+
if old_handle_keyboard_interrupt is not None:
122+
h.HandleKeyboardInterrupt = old_handle_keyboard_interrupt
123+
if old_handle_user_interrupt is not None:
124+
h.HandleUserInterrupt = old_handle_user_interrupt
125+
126+
59127
# the first available solver will be the default solver
60128
with contextlib.suppress(ModuleNotFoundError):
61129
import gurobipy
@@ -912,7 +980,7 @@ def _solve(
912980
elif warmstart_fn:
913981
h.readBasis(path_to_string(warmstart_fn))
914982

915-
h.run()
983+
_run_highs_with_keyboard_interrupt(h)
916984

917985
condition = h.getModelStatus()
918986
termination_condition = CONDITION_MAP.get(
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from __future__ import annotations
2+
3+
import _thread
4+
import threading
5+
import time
6+
7+
import pytest
8+
9+
from linopy.solvers import _run_highs_with_keyboard_interrupt
10+
11+
12+
class DummyHighs:
13+
def __init__(self) -> None:
14+
self.HandleKeyboardInterrupt = False
15+
self.HandleUserInterrupt = False
16+
self._cancel_event = threading.Event()
17+
self.started = threading.Event()
18+
self.finished = threading.Event()
19+
self.cancel_calls = 0
20+
21+
def run(self) -> None:
22+
self.started.set()
23+
self._cancel_event.wait(timeout=5)
24+
self.finished.set()
25+
26+
def cancelSolve(self) -> None:
27+
self.cancel_calls += 1
28+
self._cancel_event.set()
29+
30+
31+
def test_run_highs_cancels_on_keyboard_interrupt() -> None:
32+
dummy = DummyHighs()
33+
34+
def interrupter() -> None:
35+
assert dummy.started.wait(timeout=1)
36+
time.sleep(0.05)
37+
_thread.interrupt_main()
38+
39+
threading.Thread(target=interrupter, daemon=True).start()
40+
41+
with pytest.raises(KeyboardInterrupt):
42+
_run_highs_with_keyboard_interrupt(dummy)
43+
44+
assert dummy.cancel_calls >= 1
45+
assert dummy.finished.wait(timeout=1)
46+
assert dummy.HandleKeyboardInterrupt is False
47+
assert dummy.HandleUserInterrupt is False

0 commit comments

Comments
 (0)