Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion qubesagent/firewall.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,9 @@ def dns_addresses(family=None):
def main(self):
self.terminate_requested = False
self.reload_requested = False
# Block SIGHUP and SIGTERM during all qdb operations to prevent interrupting request-response pairs which corrupts protocol state.
# Signals are only unblocked during read_watch() which is safe to interrupt as it's just waiting, not mid-operation.
signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGHUP, signal.SIGTERM})
self.init()
self.run_firewall_dir()
if not self.is_custom_persist_enabled():
Expand All @@ -366,11 +369,22 @@ def main(self):
self.handle_addr(source_addr)
self.reload_requested = False
self.sd_notify('READY=1')
# Unblock signals only during read_watch()
signal.pthread_sigmask(signal.SIG_UNBLOCK, {signal.SIGHUP, signal.SIGTERM})

# Re-check flags after unblocking, in case signal arrived
if self.terminate_requested or self.reload_requested:
signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGHUP, signal.SIGTERM})
continue
try:
watch_path = self.qdb.read_watch()
except OSError: # EINTR
# signal received, re-check loop condition
# signal received, block signals again and re-check loop condition
signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGHUP, signal.SIGTERM})
continue

#Block signals again before doing any qdb work
signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGHUP, signal.SIGTERM})

if watch_path is None:
break
Expand Down
44 changes: 44 additions & 0 deletions qubesagent/test_firewall.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from unittest.mock import patch, Mock

import qubesagent.firewall
import signal


class DummyIptablesRestore(object):
Expand Down Expand Up @@ -572,3 +573,46 @@ def test_is_blocked(self):

for server in dns_servers_ipv6:
self.assertTrue(self.obj.is_blocked(rules, ("udp", server, "53"), dns))

def test_main_blocks_signals_during_qdb_operations(self):
#Test that signals are blocked during qdb operations and only unblocked during read_watch().

self.obj.qdb.entries['/qubes-firewall/10.137.0.1/policy'] = b'accept'
self.obj.qdb.entries['/connected-ips'] = b''
self.obj.qdb.entries['/connected-ips6'] = b''

# Track sigmask calls
sigmask_calls = []
original_sigmask = signal.pthread_sigmask

def mock_sigmask(how, mask):
sigmask_calls.append((how, mask))
return original_sigmask(how, set()) # Don't actually block

# Make read_watch() terminate the loop after first call
call_count = [0]
def mock_read_watch():
call_count[0] += 1
if call_count[0] == 1:
return '/qubes-firewall/10.137.0.1'
self.obj.terminate_requested = True
raise OSError("Interrupted")

self.obj.qdb.read_watch = mock_read_watch

with patch.object(signal, 'pthread_sigmask', mock_sigmask):
self.obj.main()

# Verify signal blocking pattern:
# 1. SIG_BLOCK at start
# 2. SIG_UNBLOCK before read_watch
# 3. SIG_BLOCK after read_watch (or in except)

block_calls = [c for c in sigmask_calls if c[0] == signal.SIG_BLOCK]
unblock_calls = [c for c in sigmask_calls if c[0] == signal.SIG_UNBLOCK]

self.assertGreater(len(block_calls), 0, "Should have SIG_BLOCK calls")
self.assertGreater(len(unblock_calls), 0, "Should have SIG_UNBLOCK calls")
# First call should be SIG_BLOCK
self.assertEqual(sigmask_calls[0][0], signal.SIG_BLOCK)