diff --git a/qubesagent/firewall.py b/qubesagent/firewall.py index 2af38336..b7b8c013 100755 --- a/qubesagent/firewall.py +++ b/qubesagent/firewall.py @@ -344,6 +344,9 @@ def dns_addresses(family=None): def main(self): self.terminate_requested = False self.reload_requested = False + # Block SIGHUP and SIGTERM during all qdb operations to prevent interrupting request-response pairs which corrupts protocol state. + # Signals are only unblocked during read_watch() which is safe to interrupt as it's just waiting, not mid-operation. + signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGHUP, signal.SIGTERM}) self.init() self.run_firewall_dir() if not self.is_custom_persist_enabled(): @@ -366,11 +369,22 @@ def main(self): self.handle_addr(source_addr) self.reload_requested = False self.sd_notify('READY=1') + # Unblock signals only during read_watch() + signal.pthread_sigmask(signal.SIG_UNBLOCK, {signal.SIGHUP, signal.SIGTERM}) + + # Re-check flags after unblocking, in case signal arrived + if self.terminate_requested or self.reload_requested: + signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGHUP, signal.SIGTERM}) + continue try: watch_path = self.qdb.read_watch() except OSError: # EINTR - # signal received, re-check loop condition + # signal received, block signals again and re-check loop condition + signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGHUP, signal.SIGTERM}) continue + + #Block signals again before doing any qdb work + signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGHUP, signal.SIGTERM}) if watch_path is None: break diff --git a/qubesagent/test_firewall.py b/qubesagent/test_firewall.py index f562bcb6..2afafa55 100644 --- a/qubesagent/test_firewall.py +++ b/qubesagent/test_firewall.py @@ -6,6 +6,7 @@ from unittest.mock import patch, Mock import qubesagent.firewall +import signal class DummyIptablesRestore(object): @@ -572,3 +573,46 @@ def test_is_blocked(self): for server in dns_servers_ipv6: self.assertTrue(self.obj.is_blocked(rules, ("udp", server, "53"), dns)) + + def test_main_blocks_signals_during_qdb_operations(self): + #Test that signals are blocked during qdb operations and only unblocked during read_watch(). + + self.obj.qdb.entries['/qubes-firewall/10.137.0.1/policy'] = b'accept' + self.obj.qdb.entries['/connected-ips'] = b'' + self.obj.qdb.entries['/connected-ips6'] = b'' + + # Track sigmask calls + sigmask_calls = [] + original_sigmask = signal.pthread_sigmask + + def mock_sigmask(how, mask): + sigmask_calls.append((how, mask)) + return original_sigmask(how, set()) # Don't actually block + + # Make read_watch() terminate the loop after first call + call_count = [0] + def mock_read_watch(): + call_count[0] += 1 + if call_count[0] == 1: + return '/qubes-firewall/10.137.0.1' + self.obj.terminate_requested = True + raise OSError("Interrupted") + + self.obj.qdb.read_watch = mock_read_watch + + with patch.object(signal, 'pthread_sigmask', mock_sigmask): + self.obj.main() + + # Verify signal blocking pattern: + # 1. SIG_BLOCK at start + # 2. SIG_UNBLOCK before read_watch + # 3. SIG_BLOCK after read_watch (or in except) + + block_calls = [c for c in sigmask_calls if c[0] == signal.SIG_BLOCK] + unblock_calls = [c for c in sigmask_calls if c[0] == signal.SIG_UNBLOCK] + + self.assertGreater(len(block_calls), 0, "Should have SIG_BLOCK calls") + self.assertGreater(len(unblock_calls), 0, "Should have SIG_UNBLOCK calls") + # First call should be SIG_BLOCK + self.assertEqual(sigmask_calls[0][0], signal.SIG_BLOCK) +