diff --git a/vmupdate/qube_connection.py b/vmupdate/qube_connection.py index 543974b..9da6b15 100644 --- a/vmupdate/qube_connection.py +++ b/vmupdate/qube_connection.py @@ -29,10 +29,12 @@ from typing import List import qubesadmin +import qubesadmin.exc from vmupdate.agent.source.args import AgentArgs from vmupdate.agent.source.log_config import LOGPATH, LOG_FILE from vmupdate.agent.source.status import StatusInfo, FinalStatus, FormatedLine from vmupdate.agent.source.common.process_result import ProcessResult +from vmupdate.utils import shutdown_domains class QubeConnection: @@ -91,11 +93,25 @@ def __exit__(self, exc_type, exc_val, exc_tb): ) if self.qube.is_running() and not self._initially_running: - self.logger.info("Shutdown %s", self.qube.name) - self.qube.shutdown() + if self._has_assigned_pci_devices(self.qube): + self.logger.info( + 'Waiting for full shutdown %s (PCI devices assigned)', + self.qube.name) + shutdown_domains([self.qube], self.logger) + else: + self.logger.info('Shutdown %s', self.qube.name) + self.qube.shutdown() self.__connected = False + @staticmethod + def _has_assigned_pci_devices(vm) -> bool: + """Return True when VM has assigned PCI devices.""" + try: + return any(vm.devices['pci'].get_assigned_devices()) + except qubesadmin.exc.QubesDaemonAccessError: + return False + def transfer_agent(self, src_dir: str) -> ProcessResult: """ Copy a directory content to the workdir in the qube. diff --git a/vmupdate/tests/test_qube_connection.py b/vmupdate/tests/test_qube_connection.py new file mode 100644 index 0000000..d8f7669 --- /dev/null +++ b/vmupdate/tests/test_qube_connection.py @@ -0,0 +1,80 @@ +# coding=utf-8 +# +# The Qubes OS Project, https://www.qubes-os.org +# +# Copyright (C) 2025 Jayant Saxena +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 2 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, +# USA. +from unittest.mock import Mock, patch + +from vmupdate.qube_connection import QubeConnection + + +@patch("vmupdate.qube_connection.shutdown_domains") +def test_wait_for_shutdown_when_vm_started_by_update(shutdown_domains): + vm = Mock() + vm.name = "hvm1" + vm.is_running.side_effect = [False, True] + vm.devices = {'pci': Mock()} + vm.devices['pci'].get_assigned_devices.return_value = ["00_1f.2"] + status_notifier = Mock() + logger = Mock() + + with QubeConnection( + vm, "/tmp/qubes-update", cleanup=False, logger=logger, + show_progress=False, status_notifier=status_notifier): + pass + + shutdown_domains.assert_called_once_with([vm], logger) + vm.shutdown.assert_not_called() + + +@patch("vmupdate.qube_connection.shutdown_domains") +def test_do_not_wait_for_shutdown_without_assigned_pci(shutdown_domains): + vm = Mock() + vm.name = "hvm2" + vm.is_running.side_effect = [False, True] + vm.devices = {'pci': Mock()} + vm.devices['pci'].get_assigned_devices.return_value = [] + status_notifier = Mock() + logger = Mock() + + with QubeConnection( + vm, "/tmp/qubes-update", cleanup=False, logger=logger, + show_progress=False, status_notifier=status_notifier): + pass + + vm.shutdown.assert_called_once_with() + shutdown_domains.assert_not_called() + + +@patch("vmupdate.qube_connection.shutdown_domains") +def test_do_not_shutdown_if_vm_was_already_running(shutdown_domains): + vm = Mock() + vm.name = "hvm3" + vm.is_running.return_value = True + vm.devices = {'pci': Mock()} + vm.devices['pci'].get_assigned_devices.return_value = ["00_1f.2"] + status_notifier = Mock() + logger = Mock() + + with QubeConnection( + vm, "/tmp/qubes-update", cleanup=False, logger=logger, + show_progress=False, status_notifier=status_notifier): + pass + + vm.shutdown.assert_not_called() + shutdown_domains.assert_not_called() diff --git a/vmupdate/utils.py b/vmupdate/utils.py new file mode 100644 index 0000000..ae76775 --- /dev/null +++ b/vmupdate/utils.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# +# The Qubes OS Project, http://www.qubes-os.org +# +# Copyright (C) 2022 Piotr Bartman +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 2 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, +# USA. +import asyncio +from datetime import datetime + +import qubesadmin.exc +from qubesadmin.events.utils import wait_for_domain_shutdown +from vmupdate.agent.source.common.exit_codes import EXIT + + +def shutdown_domains(to_shutdown, log): + """ + Try to shut down vms and wait to finish. + """ + ret_code = EXIT.OK + wait_for = [] + for vm in to_shutdown: + try: + vm.shutdown(force=True) + wait_for.append(vm) + except qubesadmin.exc.QubesVMError as exc: + log.error(str(exc)) + ret_code = EXIT.ERR_SHUTDOWN_APP + + asyncio.run(wait_for_domain_shutdown(wait_for)) + + return ret_code, wait_for + + +def get_feature(vm, feature_name, default_value=None): + """Get feature, with a working default_value.""" + try: + return vm.features.get(feature_name, default_value) + except qubesadmin.exc.QubesDaemonAccessError: + return default_value + + +def get_boolean_feature(vm, feature_name, default=False): + """Helper function to get a feature converted to bool if it exists. + + Necessary because true/false in features are coded as 1/empty string. + """ + result = get_feature(vm, feature_name, None) + if result is not None: + result = bool(result) + else: + result = default + return result + + +def is_stale(vm, expiration_period): + """Return True if VM has not been checked for updates recently.""" + today = datetime.today() + try: + if not ('qrexec' in vm.features.keys() + and vm.features.get('os', '') == 'Linux'): + return False + + last_update_str = vm.features.check_with_template( + 'last-updates-check', + datetime.fromtimestamp(0).strftime('%Y-%m-%d %H:%M:%S') + ) + last_update = datetime.fromisoformat(last_update_str) + if (today - last_update).days > expiration_period: + return True + except qubesadmin.exc.QubesDaemonCommunicationError: + pass + return False diff --git a/vmupdate/vmupdate.py b/vmupdate/vmupdate.py index edc6cb8..d9fa476 100644 --- a/vmupdate/vmupdate.py +++ b/vmupdate/vmupdate.py @@ -4,19 +4,18 @@ """ import argparse -import asyncio import logging import sys import os import grp -from datetime import datetime from typing import Set, Iterable, Dict, Tuple import qubesadmin import qubesadmin.exc -from qubesadmin.events.utils import wait_for_domain_shutdown from vmupdate.agent.source.status import FinalStatus from vmupdate.agent.source.common.exit_codes import EXIT +from vmupdate.utils import shutdown_domains, get_feature, get_boolean_feature, \ + is_stale from . import update_manager from .agent.source.args import AgentArgs @@ -355,27 +354,6 @@ def select_targets(targets, args) -> Set[qubesadmin.vm.QubesVM]: return selected -def is_stale(vm, expiration_period): - today = datetime.today() - try: - if not ( - "qrexec" in vm.features.keys() - and vm.features.get("os", "") == "Linux" - ): - return False - - last_update_str = vm.features.check_with_template( - "last-updates-check", - datetime.fromtimestamp(0).strftime("%Y-%m-%d %H:%M:%S"), - ) - last_update = datetime.fromisoformat(last_update_str) - if (today - last_update).days > expiration_period: - return True - except qubesadmin.exc.QubesDaemonCommunicationError: - pass - return False - - def run_update( targets, args, log, qube_klass="qubes", dom0=False ) -> Tuple[int, Dict[str, FinalStatus]]: @@ -408,26 +386,6 @@ def run_update( return ret_code, statuses -def get_feature(vm, feature_name, default_value=None): - """Get feature, with a working default_value.""" - try: - return vm.features.get(feature_name, default_value) - except qubesadmin.exc.QubesDaemonAccessError: - return default_value - - -def get_boolean_feature(vm, feature_name, default=False): - """helper function to get a feature converted to a Bool if it does exist. - Necessary because of the true/false in features being coded as 1/empty - string.""" - result = get_feature(vm, feature_name, None) - if result is not None: - result = bool(result) - else: - result = default - return result - - def apply_updates_to_appvm( args, vm_updated: Iterable, @@ -528,24 +486,6 @@ def get_derived_vm_to_apply(templates, derived_statuses): return to_restart, to_shutdown -def shutdown_domains(to_shutdown, log): - """ - Try to shut down vms and wait to finish. - """ - ret_code = EXIT.OK - wait_for = [] - for vm in to_shutdown: - try: - vm.shutdown(force=True) - wait_for.append(vm) - except qubesadmin.exc.QubesVMError as exc: - log.error(str(exc)) - ret_code = EXIT.ERR_SHUTDOWN_APP - - asyncio.run(wait_for_domain_shutdown(wait_for)) - - return ret_code, wait_for - def restart_vms(to_restart, log): """