|
6 | 6 | from unittest.mock import patch, Mock |
7 | 7 |
|
8 | 8 | import qubesagent.firewall |
| 9 | +import signal |
9 | 10 |
|
10 | 11 |
|
11 | 12 | class DummyIptablesRestore(object): |
@@ -572,3 +573,46 @@ def test_is_blocked(self): |
572 | 573 |
|
573 | 574 | for server in dns_servers_ipv6: |
574 | 575 | self.assertTrue(self.obj.is_blocked(rules, ("udp", server, "53"), dns)) |
| 576 | + |
| 577 | + def test_main_blocks_signals_during_qdb_operations(self): |
| 578 | + #Test that signals are blocked during qdb operations and only unblocked during read_watch(). |
| 579 | + |
| 580 | + self.obj.qdb.entries['/qubes-firewall/10.137.0.1/policy'] = b'accept' |
| 581 | + self.obj.qdb.entries['/connected-ips'] = b'' |
| 582 | + self.obj.qdb.entries['/connected-ips6'] = b'' |
| 583 | + |
| 584 | + # Track sigmask calls |
| 585 | + sigmask_calls = [] |
| 586 | + original_sigmask = signal.pthread_sigmask |
| 587 | + |
| 588 | + def mock_sigmask(how, mask): |
| 589 | + sigmask_calls.append((how, mask)) |
| 590 | + return original_sigmask(how, set()) # Don't actually block |
| 591 | + |
| 592 | + # Make read_watch() terminate the loop after first call |
| 593 | + call_count = [0] |
| 594 | + def mock_read_watch(): |
| 595 | + call_count[0] += 1 |
| 596 | + if call_count[0] == 1: |
| 597 | + return '/qubes-firewall/10.137.0.1' |
| 598 | + self.obj.terminate_requested = True |
| 599 | + raise OSError("Interrupted") |
| 600 | + |
| 601 | + self.obj.qdb.read_watch = mock_read_watch |
| 602 | + |
| 603 | + with patch.object(signal, 'pthread_sigmask', mock_sigmask): |
| 604 | + self.obj.main() |
| 605 | + |
| 606 | + # Verify signal blocking pattern: |
| 607 | + # 1. SIG_BLOCK at start |
| 608 | + # 2. SIG_UNBLOCK before read_watch |
| 609 | + # 3. SIG_BLOCK after read_watch (or in except) |
| 610 | + |
| 611 | + block_calls = [c for c in sigmask_calls if c[0] == signal.SIG_BLOCK] |
| 612 | + unblock_calls = [c for c in sigmask_calls if c[0] == signal.SIG_UNBLOCK] |
| 613 | + |
| 614 | + self.assertGreater(len(block_calls), 0, "Should have SIG_BLOCK calls") |
| 615 | + self.assertGreater(len(unblock_calls), 0, "Should have SIG_UNBLOCK calls") |
| 616 | + # First call should be SIG_BLOCK |
| 617 | + self.assertEqual(sigmask_calls[0][0], signal.SIG_BLOCK) |
| 618 | + |
0 commit comments