diff --git a/ui/opensnitch/customwidgets/generictableview.py b/ui/opensnitch/customwidgets/generictableview.py index 30dd9ccd17..0b37871b33 100644 --- a/ui/opensnitch/customwidgets/generictableview.py +++ b/ui/opensnitch/customwidgets/generictableview.py @@ -12,6 +12,7 @@ QObject, pyqtSignal, QEvent, + QTimer, Qt) class GenericTableModel(QStandardItemModel): @@ -35,6 +36,10 @@ class GenericTableModel(QStandardItemModel): prevQueryStr = '' # modified query object realQuery = QSqlQuery() + # set when the query changes, to force the next viewport refresh. + # Otherwise the view would keep displaying the rows of the previous + # query when the scrollbar is not at the top/bottom of the view. + forceNextRefresh = False items = [] lastItems = [] @@ -169,6 +174,7 @@ def setQuery(self, q, db, binds=None, limit=None, offset=None): if self.prevQueryStr != self.origQueryStr: self.realQuery = tmpQuery + self.forceNextRefresh = True self.update_row_count() self.update_col_count() @@ -201,6 +207,9 @@ def refreshViewport(self, scrollValue, maxRowsInViewport, force=False): force var will force a refresh if the scrollbar is at the top or bottom of the viewport, otherwise skip it to allow rows analyzing without refreshing. """ + if self.forceNextRefresh: + force = True + self.forceNextRefresh = False if not force: return @@ -325,6 +334,10 @@ def __init__(self, parent): # current selected rows self._rows_selection = set() + # tracking-column text of the current (focused) row, used to + # restore currentIndex after a viewport refresh + self._current_row_text = None + # selection range to highlight the rows of the viewport, that is, # the rows of the current sql query (offset + limit). self._first_row_selected = None @@ -388,14 +401,11 @@ def selectDbRows(self, first, last): if selrows is None: return self._rows_selection.clear() - for rid, row in enumerate(selrows): - key = row[self.trackingCol] - self._rows_selection.add(key) - idx = self.model().index(rid, self.trackingCol) - self.selectionModel().setCurrentIndex( - idx, - QItemSelectionModel.SelectionFlag.Rows | QItemSelectionModel.SelectionFlag.SelectCurrent - ) + for row in selrows: + self._rows_selection.add(row[self.trackingCol]) + # the visual selection of the visible rows is applied by + # selectIndices(); selecting db-range positions here would + # highlight wrong viewport rows when the range is scrolled. self.selectIndices() def getMinViewportRow(self): @@ -509,8 +519,8 @@ def mouseMoveEvent(self, event): def mousePressEvent(self, event): # we need to call upper class to paint selections properly super().mousePressEvent(event) - self.mousePressed = True rightBtnPressed = event.button() != Qt.MouseButton.LeftButton + self.mousePressed = not rightBtnPressed self.keySelectAll = False if not self.shiftPressed: @@ -522,11 +532,15 @@ def mousePressEvent(self, event): pos = event.pos() item = self.indexAt(pos) row = self.rowAt(pos.y()) - if item is None: - return clickedItem = self.model().index(row, self.trackingCol) - if clickedItem.data() is None: + if not item.isValid() or clickedItem.data() is None: + # Qt clears the visual selection when pressing on an empty + # area; keep the tracked selection in sync, otherwise menu + # actions keep operating on rows no longer highlighted. + if not self.ctrlPressed: + self._rows_selection.clear() + self._current_row_text = None return flags = QItemSelectionModel.SelectionFlag.Rows | QItemSelectionModel.SelectionFlag.SelectCurrent @@ -607,6 +621,8 @@ def mousePressEvent(self, event): clickedItem, flags ) + isDeselect = bool(flags & QItemSelectionModel.SelectionFlag.Deselect) + self._current_row_text = None if isDeselect else clickedItem.data() def handleShiftPressed(self): # in the viewport, the rows start at 1, but in the db at 0 @@ -683,6 +699,7 @@ def clearSelection(self): self.selectionModel().reset() self.selectionModel().clearCurrentIndex() self._rows_selection.clear() + self._current_row_text = None self._first_row_selected = None self._last_row_selected = None self._db_selection_range = { @@ -743,6 +760,44 @@ def selectIndices(self): sel.append(QItemSelectionRange(i.index())) self.selectionModel().clear() self.selectionModel().select(sel, QItemSelectionModel.SelectionFlag.Select | QItemSelectionModel.SelectionFlag.Rows) + self._restoreCurrentIndex() + + def _restoreCurrentIndex(self): + """Re-apply the current (focused) row after the viewport has been + refreshed. selectionModel().clear() drops currentIndex, so keyboard + navigation would lose its position on every refresh otherwise. + """ + if self._current_row_text is None: + return + items = self.model().findItems(self._current_row_text, column=self.trackingCol) + if len(items) == 0: + return + self.selectionModel().setCurrentIndex( + items[0].index(), + QItemSelectionModel.SelectionFlag.NoUpdate + ) + + def _syncSelectionFromCurrentRow(self): + """Sync the tracked rows with the row that became current after Qt + processed a navigation key press. The view handles the key AFTER + our eventFilter runs, so reading currentIndex there returns the row + the user navigated AWAY from, leaving the tracked selection (and + thus context-menu actions) one row behind the visible selection. + """ + if self.ctrlPressed: + # ctrl+navigation moves the current row without changing the + # selection + return + curIdx = self.selectionModel().currentIndex() + if not curIdx.isValid(): + return + rowText = self.model().index(curIdx.row(), self.trackingCol).data() + if rowText is None: + return + if not self.shiftPressed: + self._rows_selection.clear() + self._rows_selection.add(rowText) + self._current_row_text = rowText def _selectLastRow(self): internalId = self.getCurrentIndex() @@ -781,17 +836,12 @@ def onScrollbarValueChanged(self, vSBNewValue): def onKeyUp(self): curIdx = self.selectionModel().currentIndex() - if not self.shiftPressed: - self._rows_selection.clear() - self._rows_selection.add(curIdx.data()) - viewport_row = self.getViewportRowPos(curIdx.row()) self._last_row_selected = viewport_row if self._first_row_selected is None: self._first_row_selected = viewport_row offset = self.model().queryOffset - limit = self.model().queryLimit if curIdx.row() == 0: self.vScrollBar.setValue(max(0, self.vScrollBar.value() - 1)) if curIdx.row() == 0 and viewport_row+offset-1 == offset: @@ -800,21 +850,21 @@ def onKeyUp(self): def onKeyDown(self): curIdx = self.selectionModel().currentIndex() curRow = curIdx.row() - if not self.shiftPressed: - self._rows_selection.clear() - self._rows_selection.add(curIdx.data()) + viewport_row = self.getViewportRowPos(curRow) viewport_row = self.getViewportRowPos(curRow) newValue = self.vScrollBar.value() - offset = self.model().queryOffset limit = self.model().queryLimit if curRow >= self.maxRowsInViewport-2: # this change will fire onScrollbarValueChanged, which will refresh the # view (the rows and the rows numbers) self.vScrollBar.setValue(newValue+1) self._selectLastRow() - if (offset == 0 and viewport_row == limit) or viewport_row+offset == limit+offset: + # wrap the selection to the first row after paginating to the next + # records window. The query offset cancels out on both sides of the + # comparison, so checking against the limit alone is enough. + if viewport_row == limit: self._selectRow(0) def onKeyHome(self): @@ -884,12 +934,16 @@ def eventFilter(self, obj, event): # some pyqt versions. if event.key() == Qt.Key.Key_Up: self.onKeyUp() + QTimer.singleShot(0, self._syncSelectionFromCurrentRow) elif event.key() == Qt.Key.Key_Down: self.onKeyDown() + QTimer.singleShot(0, self._syncSelectionFromCurrentRow) elif event.key() == Qt.Key.Key_Home: self.onKeyHome() + QTimer.singleShot(0, self._syncSelectionFromCurrentRow) elif event.key() == Qt.Key.Key_End: self.onKeyEnd() + QTimer.singleShot(0, self._syncSelectionFromCurrentRow) elif event.key() == Qt.Key.Key_PageUp: self.onKeyPageUp() elif event.key() == Qt.Key.Key_PageDown: diff --git a/ui/opensnitch/dialogs/events/menu_actions.py b/ui/opensnitch/dialogs/events/menu_actions.py index 6f50156959..5d7e2cc67f 100644 --- a/ui/opensnitch/dialogs/events/menu_actions.py +++ b/ui/opensnitch/dialogs/events/menu_actions.py @@ -372,6 +372,11 @@ def table_menu_edit(self, cur_idx, model, selection): QtWidgets.QMessageBox.Icon.Warning) return print(node, name) + # the editor runs its own event loop, and table refreshes + # are discarded while the context menu is flagged as + # active, so saving from the editor wouldn't update the + # views otherwise. + self.set_context_menu_active(False) r = RulesEditorDialog(modal=False) r.edit_rule(records, node) @@ -386,6 +391,9 @@ def table_menu_edit(self, cur_idx, model, selection): QC.translate("stats", "Rule not found by that name and node"), QtWidgets.QMessageBox.Icon.Warning) return + # see the TAB_MAIN branch: allow table refreshes while the + # editor is open. + self.set_context_menu_active(False) r = RulesEditorDialog(modal=False) r.edit_rule(records, node) break diff --git a/ui/opensnitch/dialogs/events/menus.py b/ui/opensnitch/dialogs/events/menus.py index 8aa725ba16..40a9ff4492 100644 --- a/ui/opensnitch/dialogs/events/menus.py +++ b/ui/opensnitch/dialogs/events/menus.py @@ -1,3 +1,5 @@ +import traceback + from PyQt6 import QtCore, QtWidgets, QtGui from PyQt6.QtCore import QCoreApplication as QC @@ -251,6 +253,8 @@ def configure_rules_contextual_menu(self, pos): model = table.model() selection = table.selectedRows() + if not selection: + return False menu = QtWidgets.QMenu() durMenu = QtWidgets.QMenu(self.COL_STR_DURATION) @@ -362,10 +366,11 @@ def configure_rules_contextual_menu(self, pos): elif action == _toDisk: self.table_menu_export_disk(cur_idx, model, selection) + return True except Exception as e: print("rules contextual menu exception:", e) - finally: - return True + traceback.print_exc() + return False def configure_alerts_contextual_menu(self, pos): try: diff --git a/ui/opensnitch/dialogs/ruleseditor/dialog.py b/ui/opensnitch/dialogs/ruleseditor/dialog.py index 160d366a27..fc220a11ff 100644 --- a/ui/opensnitch/dialogs/ruleseditor/dialog.py +++ b/ui/opensnitch/dialogs/ruleseditor/dialog.py @@ -271,15 +271,17 @@ def cb_save_clicked(self): if self._old_rule_name is not None and self._old_rule_name != self.rule.name: self.delete_rule() - self._old_rule_name = rule_name + # use the saved rule name, not the one typed in the field: save_rule() + # may rename the rule (e.g. when the action of an auto-named rule + # changes). Otherwise _old_rule_name would lag behind and the next + # save would wrongly report a name conflict with the just saved rule. + self._old_rule_name = self.rule.name # after adding a new rule, we enter into EDIT mode, to allow further # changes without closing the dialog. if constants.WORK_MODE == constants.ADD_RULE: constants.WORK_MODE = constants.EDIT_RULE - self._rules.updated.emit(0) - @QtCore.pyqtSlot(str, ui_pb2.NotificationReply) def cb_notification_callback(self, addr, reply): #print(self.LOG_TAG, "Rule notification received: ", reply.id, reply.code) @@ -390,8 +392,6 @@ def delete_rule(self): # if the rule name has changed, we need to remove the old one if self._old_rule_name != self.rule.name: node = nodes.get_node_addr(self) - old_rule = self.rule - old_rule.name = self._old_rule_name if self.nodeApplyAllCheck.isChecked(): nid, noti = self._nodes.delete_rule(rule_name=self._old_rule_name, addr=None, callback=self._notification_callback) self.notifications_sent[nid] = noti diff --git a/ui/opensnitch/rules.py b/ui/opensnitch/rules.py index 4185afd37c..ff04870bdf 100644 --- a/ui/opensnitch/rules.py +++ b/ui/opensnitch/rules.py @@ -94,7 +94,7 @@ def __init__(self): QObject.__init__(self) self._db = Database.instance() - def add(self, time, node, name, description, enabled, precedence, nolog, action, duration, op_type, op_sensitive, op_operand, op_data, created): + def add(self, time, node, name, description, enabled, precedence, nolog, action, duration, op_type, op_sensitive, op_operand, op_data, created, notify=True): # don't add rule if the user has selected to exclude temporary # rules if duration in Config.RULES_DURATION_FILTER: @@ -104,6 +104,8 @@ def add(self, time, node, name, description, enabled, precedence, nolog, action, "(time, node, name, description, enabled, precedence, nolog, action, duration, operator_type, operator_sensitive, operator_operand, operator_data, created)", (time, node, name, description, enabled, precedence, nolog, action, duration, op_type, op_sensitive, op_operand, op_data, created), action_on_conflict="REPLACE") + if notify: + self.updated.emit(0) def add_rules(self, addr, rules): try: @@ -123,8 +125,12 @@ def add_rules(self, addr, rules): r.operator.type, str(r.operator.sensitive), r.operator.operand, r.operator.data, - str(datetime.fromtimestamp(r.created).strftime(DBDateFieldFormat))) + str(datetime.fromtimestamp(r.created).strftime(DBDateFieldFormat)), + notify=False) + # notify once per batch, to avoid a refresh storm when a node + # connects and sends all its rules. + self.updated.emit(0) return True except Exception as e: log.warning("exception adding node rules to db: %s", repr(e)) @@ -142,6 +148,7 @@ def delete(self, name, addr, callback): if not self._db.delete_rule(rule.name, addr): return None + self.updated.emit(0) return rule def delete_by_field(self, field, values): @@ -184,6 +191,7 @@ def disable(self, addr, name): "name=? AND node=?", action_on_conflict="OR REPLACE" ) + self.updated.emit(0) def update_time(self, time, name, addr): """Updates the time of a rule, whenever a new connection matched a diff --git a/ui/opensnitch/service.py b/ui/opensnitch/service.py index 08c66443c6..bdd0bb7a41 100644 --- a/ui/opensnitch/service.py +++ b/ui/opensnitch/service.py @@ -23,6 +23,7 @@ from opensnitch.notifications import DesktopNotifications from opensnitch.firewall import Rules as FwRules from opensnitch.nodes import Nodes +from opensnitch.rules import Rules from opensnitch.config import Config from opensnitch.version import version from opensnitch.database import Database @@ -129,6 +130,7 @@ def __init__(self, app, on_exit, start_in_bg=False): self._nodes = Nodes.instance() self._nodes.reset_status() + self._rules = Rules.instance() self._last_stats = {} self._last_items = { @@ -898,7 +900,9 @@ def _disable_temp_rule(args): ost.start() elif kwargs['action'] == self.DELETE_RULE: - self._db.delete_rule(kwargs['name'], kwargs['addr']) + # route it through Rules, so the views are notified of the + # change. + self._rules.delete(kwargs['name'], kwargs['addr'], None) elif kwargs['action'] == self.NODE_DELETE: self._delete_node(kwargs['peer']) diff --git a/ui/tests/customwidgets/__init__.py b/ui/tests/customwidgets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ui/tests/customwidgets/test_generictableview.py b/ui/tests/customwidgets/test_generictableview.py new file mode 100644 index 0000000000..eb1dc5b26f --- /dev/null +++ b/ui/tests/customwidgets/test_generictableview.py @@ -0,0 +1,199 @@ +# +# pytest -v tests/customwidgets/test_generictableview.py +# +# Regression tests for the selection tracking of GenericTableView: +# the view keeps a parallel set of selected rows (texts of the tracking +# column) which menu actions operate on, so it must always match the +# visually selected rows. +# + +import pytest +from PyQt6 import QtCore, QtWidgets +from PyQt6.QtCore import Qt + +# opensnitch.utils must be imported before opensnitch.database to resolve +# their circular import +import opensnitch.utils # noqa: F401 +from opensnitch.database import Database +from opensnitch.customwidgets.generictableview import GenericTableModel, GenericTableView + +RULES_HEADERS = ["Time", "Node", "Name", "Enabled"] +RULES_QUERY = "SELECT time, node, name, enabled FROM rules ORDER BY name ASC" +COL_NAME = 2 +NUM_RULES = 6 +TEST_NODE = "unix:/tmp/osui.sock" +# time enough for the deferred (QTimer.singleShot) selection sync to run +DEFERRED_SYNC_WAIT_MS = 50 + + +def rule_name(num): + return "rule-{0:03d}".format(num) + + +def insert_test_rules(db, names=None): + if names is None: + names = [rule_name(i) for i in range(NUM_RULES)] + db.clean("rules") + for i, name in enumerate(names): + rule_time = "2026-06-06 10:{0:02d}:{1:02d}".format(i // 60, i % 60) + db.insert( + "rules", + "(time, node, name, enabled, precedence, action, duration, " \ + "operator_type, operator_sensitive, operator_operand, operator_data, " \ + "description, nolog, created)", + ( + rule_time, TEST_NODE, name, "True", + "False", "allow", "always", "simple", "False", "process.path", + "/bin/app-{0}".format(i), "", "False", rule_time + ) + ) + + +def build_rules_view(qtbot): + container = QtWidgets.QWidget() + layout = QtWidgets.QHBoxLayout(container) + view = GenericTableView(container) + scrollbar = QtWidgets.QScrollBar(container) + layout.addWidget(view) + layout.addWidget(scrollbar) + + model = GenericTableModel("rules", RULES_HEADERS) + view.setVerticalScrollBar(scrollbar) + view.setTrackingColumn(COL_NAME) + view.setModel(model) + model.setQuery(RULES_QUERY, Database.instance().get_db()) + + qtbot.addWidget(container) + container.resize(600, 400) + container.show() + qtbot.waitExposed(container) + view.refresh() + return container, view + + +@pytest.fixture +def rules_view(qtbot): + insert_test_rules(Database.instance()) + container, view = build_rules_view(qtbot) + # the fixture frame keeps the container referenced for the duration + # of the test; qtbot only holds a weak reference + yield view + + +@pytest.fixture +def mixed_rules_view(qtbot): + """Two name groups with enough rules to scroll the view.""" + names = ["app-{0:03d}".format(i) for i in range(30)] + names += ["term-{0:03d}".format(i) for i in range(30)] + insert_test_rules(Database.instance(), names) + container, view = build_rules_view(qtbot) + yield view + + +def click_row(qtbot, view, row): + cell_rect = view.visualRect(view.model().index(row, COL_NAME)) + qtbot.mouseClick(view.viewport(), Qt.MouseButton.LeftButton, pos=cell_rect.center()) + + +def get_current_row_name(view): + cur_idx = view.selectionModel().currentIndex() + if not cur_idx.isValid(): + return None + return view.model().index(cur_idx.row(), COL_NAME).data() + + +def test_click_tracks_clicked_row(rules_view, qtbot): + click_row(qtbot, rules_view, 0) + assert rules_view._rows_selection == {rule_name(0)} + assert get_current_row_name(rules_view) == rule_name(0) + + +def test_key_down_tracks_new_current_row(rules_view, qtbot): + """Regression: the tracked selection lagged one row behind the visible + one on keyboard navigation, so actions hit the wrong rule. Also guards + against the NameError raised by onKeyDown.""" + click_row(qtbot, rules_view, 0) + + qtbot.keyClick(rules_view, Qt.Key.Key_Down) + qtbot.wait(DEFERRED_SYNC_WAIT_MS) + + assert get_current_row_name(rules_view) == rule_name(1) + assert rules_view._rows_selection == {rule_name(1)} + + +def test_key_up_tracks_new_current_row(rules_view, qtbot): + click_row(qtbot, rules_view, 2) + + qtbot.keyClick(rules_view, Qt.Key.Key_Up) + qtbot.wait(DEFERRED_SYNC_WAIT_MS) + + assert get_current_row_name(rules_view) == rule_name(1) + assert rules_view._rows_selection == {rule_name(1)} + + +def test_click_empty_area_clears_tracked_selection(rules_view, qtbot): + """Regression: clicking on the empty area below the rows cleared the + visual selection but kept the tracked rows, so menu actions kept + operating on rules no longer highlighted.""" + click_row(qtbot, rules_view, 0) + assert rules_view._rows_selection == {rule_name(0)} + + empty_area_pos = QtCore.QPoint(10, rules_view.viewport().height() - 5) + qtbot.mouseClick(rules_view.viewport(), Qt.MouseButton.LeftButton, pos=empty_area_pos) + + assert rules_view._rows_selection == set() + assert rules_view.selectedRows() is None + + +def test_viewport_refresh_preserves_selection_and_current_row(rules_view, qtbot): + """Regression: the periodic viewport refresh cleared currentIndex, so + the focused row was lost every time the daemon pushed an event.""" + click_row(qtbot, rules_view, 2) + + rules_view.refresh() + + selected = rules_view.selectionModel().selectedRows(COL_NAME) + assert [sel.data() for sel in selected] == [rule_name(2)] + assert get_current_row_name(rules_view) == rule_name(2) + + +def test_selected_rows_returns_clicked_rule(rules_view, qtbot): + """selectedRows() feeds the context-menu actions: it must return the + db row matching the visually selected rule.""" + click_row(qtbot, rules_view, 1) + + selected_db_rows = rules_view.selectedRows() + assert selected_db_rows is not None + assert len(selected_db_rows) == 1 + assert selected_db_rows[0][COL_NAME] == rule_name(1) + + +def test_query_change_refreshes_viewport_while_scrolled(mixed_rules_view, qtbot): + """Regression: changing the query (e.g. typing a filter) while the + scrollbar was not at the top or bottom of the view kept displaying the + rows of the previous query, so the visible rows didn't match the data + that selections and menu actions operated on.""" + view = mixed_rules_view + model = view.model() + + view.vScrollBar.setValue(10) + displayed = [row[COL_NAME] for row in model.items] + assert len(displayed) > 0 + assert all(name.startswith("app-") for name in displayed) + + filtered_query = "SELECT time, node, name, enabled FROM rules " \ + "WHERE name LIKE 'term-%' ORDER BY name ASC" + model.setQuery(filtered_query, Database.instance().get_db()) + + displayed = [row[COL_NAME] for row in model.items] + assert len(displayed) > 0 + assert all(name.startswith("term-") for name in displayed) + + +def test_right_press_does_not_arm_drag_selection(rules_view, qtbot): + """Regression: a right-button press armed the drag-selection logic + (mousePressed), interfering with refresh skipping and row tracking.""" + cell_rect = rules_view.visualRect(rules_view.model().index(0, COL_NAME)) + qtbot.mousePress(rules_view.viewport(), Qt.MouseButton.RightButton, pos=cell_rect.center()) + assert rules_view.mousePressed is False + qtbot.mouseRelease(rules_view.viewport(), Qt.MouseButton.RightButton, pos=cell_rect.center()) diff --git a/ui/tests/dialogs/test_ruleseditor.py b/ui/tests/dialogs/test_ruleseditor.py index 68dac1198c..570c36c782 100644 --- a/ui/tests/dialogs/test_ruleseditor.py +++ b/ui/tests/dialogs/test_ruleseditor.py @@ -286,6 +286,47 @@ def handle_dialog(): records = self.rd._db.get_rule("www.test-renamed.com", node_addr) assert records.next() == True + def test_change_action_of_autonamed_rule(self, qtbot): + """ Regression: changing the action of an auto-named rule + (reject-xxx -> allow-xxx) renames it. The name tracking lagged + behind, so saving the rule a second time wrongly reported a name + conflict with the just renamed rule. + """ + qtbot.addWidget(self.rd) + node = re_nodes.get_node_addr(self.rd) + # start from a clean state for the names used here + self.rd._db.delete_rule("reject-rename-host", node) + self.rd._db.delete_rule("allow-rename-host", node) + + re_constants.WORK_MODE = re_constants.ADD_RULE + re_utils.reset_state(self.rd) + self.rd.statusLabel.setText("") + self.rd.ruleNameEdit.setText("reject-rename-host") + self.rd.dstHostCheck.setChecked(True) + self.rd.dstHostLine.setText("rename.example.com") + self.rd.actionRejectRadio.setChecked(True) + + qtbot.mouseClick(self.rd.buttonBox.button(QtWidgets.QDialogButtonBox.StandardButton.Save), QtCore.Qt.MouseButton.LeftButton) + assert self.rd.statusLabel.text() == "" + assert self.rd._old_rule_name == "reject-rename-host" + + # change the action: the rule gets renamed to allow-rename-host + self.rd.actionAllowRadio.setChecked(True) + qtbot.mouseClick(self.rd.buttonBox.button(QtWidgets.QDialogButtonBox.StandardButton.Save), QtCore.Qt.MouseButton.LeftButton) + + assert self.rd.statusLabel.text() == "" + assert self.rd.ruleNameEdit.text() == "allow-rename-host" + # name tracking must follow the rename, not lag on the old name + assert self.rd._old_rule_name == "allow-rename-host" + assert self.rd.rule.name == "allow-rename-host" + assert self.rd._db.get_rule("reject-rename-host", node).next() == False + assert self.rd._db.get_rule("allow-rename-host", node).next() == True + + # saving again must not report a name conflict with itself + qtbot.mouseClick(self.rd.buttonBox.button(QtWidgets.QDialogButtonBox.StandardButton.Save), QtCore.Qt.MouseButton.LeftButton) + assert self.rd.statusLabel.text() == "" + assert self.rd._db.get_rule("allow-rename-host", node).next() == True + def test_durations(self, qtbot): """ Test adding new rule with action "deny". """ diff --git a/ui/tests/test_rules_signals.py b/ui/tests/test_rules_signals.py new file mode 100644 index 0000000000..9422704ddd --- /dev/null +++ b/ui/tests/test_rules_signals.py @@ -0,0 +1,92 @@ +# +# pytest -v tests/test_rules_signals.py +# +# Regression tests for the Rules.updated signal: rule changes written to +# the db from any source (pop-up answers, temporary rules expiration, +# rules received on node connection) must notify the views, so they +# refresh what they display. +# + +from datetime import datetime + +import opensnitch.proto as proto +ui_pb2, ui_pb2_grpc = proto.import_() + +# opensnitch.utils must be imported before opensnitch.database to resolve +# their circular import +import opensnitch.utils # noqa: F401 +from opensnitch.rules import Rules + +TEST_NODE = "unix:/tmp/osui.sock" +RULE_TIME = "2026-06-10 10:00:00" + + +def new_proto_rule(name): + rule = ui_pb2.Rule(name=name) + rule.enabled = True + rule.action = "allow" + rule.duration = "always" + rule.created = int(datetime.now().timestamp()) + rule.operator.type = "simple" + rule.operator.operand = "process.path" + rule.operator.data = "/bin/test-app" + return rule + + +def add_test_rule(rules, name): + rules.add( + RULE_TIME, TEST_NODE, name, "", "True", "False", "False", + "allow", "always", "simple", "False", "process.path", + "/bin/test-app", RULE_TIME + ) + + +def count_emits(rules, operation): + emitted = [] + + def _on_updated(what): + emitted.append(what) + + rules.updated.connect(_on_updated) + try: + operation() + finally: + rules.updated.disconnect(_on_updated) + return len(emitted) + + +def test_add_emits_updated(qtbot): + rules = Rules.instance() + assert count_emits(rules, lambda: add_test_rule(rules, "sig-add")) == 1 + + +def test_add_rules_emits_once_per_batch(qtbot): + """A node sends all its rules on connection: one refresh per batch, + not one per rule.""" + rules = Rules.instance() + batch = [new_proto_rule("sig-batch-0"), new_proto_rule("sig-batch-1")] + results = [] + emits = count_emits(rules, lambda: results.append(rules.add_rules(TEST_NODE, batch))) + assert results == [True] + assert emits == 1 + + +def test_delete_emits_updated(qtbot): + rules = Rules.instance() + add_test_rule(rules, "sig-delete") + assert count_emits(rules, lambda: rules.delete("sig-delete", TEST_NODE, None)) == 1 + + +def test_disable_emits_updated(qtbot): + """Temporary rules are marked as disabled in the db when they expire.""" + rules = Rules.instance() + add_test_rule(rules, "sig-disable") + assert count_emits(rules, lambda: rules.disable(TEST_NODE, "sig-disable")) == 1 + + +def test_update_time_does_not_emit(qtbot): + """update_time() runs for every connection matching a rule; emitting + here would refresh the views non-stop.""" + rules = Rules.instance() + add_test_rule(rules, "sig-time") + assert count_emits(rules, lambda: rules.update_time(RULE_TIME, "sig-time", TEST_NODE)) == 0