diff --git a/vmupdate/qube_connection.py b/vmupdate/qube_connection.py index 56aadd8..22658ed 100644 --- a/vmupdate/qube_connection.py +++ b/vmupdate/qube_connection.py @@ -18,6 +18,8 @@ # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, # USA. + +import asyncio import os import shutil import signal @@ -30,11 +32,11 @@ import qubesadmin import qubesadmin.exc +from qubesadmin.utils import shutdown from vmupdate.agent.source.args import AgentArgs from vmupdate.agent.source.log_config import LOGPATH, LOG_FILE from vmupdate.agent.source.status import StatusInfo, FinalStatus, FormatedLine from vmupdate.agent.source.common.process_result import ProcessResult -from vmupdate.utils import shutdown_domains class QubeConnection: @@ -93,15 +95,20 @@ def __exit__(self, exc_type, exc_val, exc_tb): ) if self.qube.is_running() and not self._initially_running: + wait = False if self._has_assigned_pci_devices(self.qube): self.logger.info( "Waiting for full shutdown %s (PCI devices assigned)", self.qube.name, ) - shutdown_domains([self.qube], self.logger) - else: - self.logger.info("Shutdown %s", self.qube.name) - self.qube.shutdown() + wait = True + failed = asyncio.run( + shutdown(domains=[self.qube], force=False, wait=wait) + ) + if failed: + exc = list(failed.values())[0] + self.logger.error(str(exc)) + raise exc self.__connected = False diff --git a/vmupdate/tests/test_qube_connection.py b/vmupdate/tests/test_qube_connection.py index b707a85..7b7c905 100644 --- a/vmupdate/tests/test_qube_connection.py +++ b/vmupdate/tests/test_qube_connection.py @@ -23,8 +23,7 @@ from vmupdate.qube_connection import QubeConnection -@patch("vmupdate.qube_connection.shutdown_domains") -def test_wait_for_shutdown_when_vm_started_by_update(shutdown_domains): +def test_wait_for_shutdown_when_vm_started_by_update(): vm = Mock() vm.name = "hvm1" vm.is_running.side_effect = [False, True] @@ -43,12 +42,10 @@ def test_wait_for_shutdown_when_vm_started_by_update(shutdown_domains): ): pass - shutdown_domains.assert_called_once_with([vm], logger) - vm.shutdown.assert_not_called() + vm.shutdown.assert_called_once_with(force=False, wait=True) -@patch("vmupdate.qube_connection.shutdown_domains") -def test_do_not_wait_for_shutdown_without_assigned_pci(shutdown_domains): +def test_do_not_wait_for_shutdown_without_assigned_pci(): vm = Mock() vm.name = "hvm2" vm.is_running.side_effect = [False, True] @@ -67,12 +64,10 @@ def test_do_not_wait_for_shutdown_without_assigned_pci(shutdown_domains): ): pass - vm.shutdown.assert_called_once_with() - shutdown_domains.assert_not_called() + vm.shutdown.assert_called_once_with(force=False, wait=False) -@patch("vmupdate.qube_connection.shutdown_domains") -def test_do_not_shutdown_if_vm_was_already_running(shutdown_domains): +def test_do_not_shutdown_if_vm_was_already_running(): vm = Mock() vm.name = "hvm3" vm.is_running.return_value = True @@ -92,4 +87,3 @@ def test_do_not_shutdown_if_vm_was_already_running(shutdown_domains): pass vm.shutdown.assert_not_called() - shutdown_domains.assert_not_called() diff --git a/vmupdate/tests/test_vmupdate.py b/vmupdate/tests/test_vmupdate.py index 7b0b4d5..b6fcaae 100644 --- a/vmupdate/tests/test_vmupdate.py +++ b/vmupdate/tests/test_vmupdate.py @@ -383,11 +383,9 @@ def test_selection( @patch("vmupdate.update_manager.UpdateAgentManager") @patch("multiprocessing.Pool") @patch("multiprocessing.Manager") -@patch("asyncio.run") @patch("subprocess.Popen") def test_restarting( dummy_subprocess, - arun, mp_manager, mp_pool, agent_mng, @@ -506,7 +504,6 @@ def test_restarting( fails = {args: failed[args] for args in failed if failed[args]} assert not fails - arun.asseert_called() stat = FinalStatus @@ -719,7 +716,6 @@ def test_error( @patch("os.chown") @patch("logging.FileHandler") @patch("logging.getLogger") -@patch("asyncio.run") @pytest.mark.parametrize( "action, code", ( @@ -729,7 +725,6 @@ def test_error( ), ) def test_error_apply( - _arun, _logger, _log_file, _chmod, diff --git a/vmupdate/utils.py b/vmupdate/utils.py index b822c1d..f9afffb 100644 --- a/vmupdate/utils.py +++ b/vmupdate/utils.py @@ -22,27 +22,40 @@ from datetime import datetime import qubesadmin.exc -from qubesadmin.events.utils import wait_for_domain_shutdown +from qubesadmin.utils import shutdown, start from vmupdate.agent.source.common.exit_codes import EXIT -def shutdown_domains(to_shutdown, log): +async def shutdown_domains( + to_shutdown, + log, + wait: bool = False, + force: bool = False, +): """ Try to shut down vms and wait to finish. """ ret_code = EXIT.OK - wait_for = [] - for vm in to_shutdown: - try: - vm.shutdown(force=True) - wait_for.append(vm) - except qubesadmin.exc.QubesVMError as exc: - log.error(str(exc)) - ret_code = EXIT.ERR_SHUTDOWN_APP + all_failed = [] + failed = await shutdown(domains=to_shutdown, wait=wait, force=force) + for qube, exc in failed.items(): + log.error(str(exc)) + all_failed.append(qube) + ret_code = EXIT.ERR_SHUTDOWN_APP + done = [qube for qube in to_shutdown if qube not in all_failed] + return ret_code, done - asyncio.run(wait_for_domain_shutdown(wait_for)) - return ret_code, wait_for +async def restart_vms(to_restart, log): + """ + Try to restart vms. + """ + ret_code, shutdowns = await shutdown_domains(to_restart, log) + failed = await start(domains=shutdowns) + for exc in failed.values(): + log.error(str(exc)) + ret_code = EXIT.ERR_START_APP + return ret_code def get_feature(vm, feature_name, default_value=None): diff --git a/vmupdate/vmupdate.py b/vmupdate/vmupdate.py index ddb08fe..f7b5300 100644 --- a/vmupdate/vmupdate.py +++ b/vmupdate/vmupdate.py @@ -4,6 +4,7 @@ """ import argparse +import asyncio import logging import sys import os @@ -11,11 +12,13 @@ from typing import Set, Iterable, Dict, Tuple import qubesadmin +import qubesadmin.utils import qubesadmin.exc from vmupdate.agent.source.status import FinalStatus from vmupdate.agent.source.common.exit_codes import EXIT from vmupdate.utils import ( shutdown_domains, + restart_vms, get_feature, get_boolean_feature, is_stale, @@ -119,8 +122,10 @@ def main(args=None, app=qubesadmin.Qubes()): if ret_code_appvm == EXIT.SIGINT: return EXIT.SIGINT - ret_code_restart = apply_updates_to_appvm( - args, independent, templ_statuses, app_statuses, log + ret_code_restart = asyncio.run( + apply_updates_to_appvm( + args, independent, templ_statuses, app_statuses, log + ) ) ret_code = max( @@ -396,7 +401,7 @@ def run_update( return ret_code, statuses -def apply_updates_to_appvm( +async def apply_updates_to_appvm( args, vm_updated: Iterable, template_statuses: Dict[str, FinalStatus], @@ -445,7 +450,7 @@ def apply_updates_to_appvm( # first shutdown templates to apply changes to the root volume # they are no need to start templates automatically - ret_code, _ = shutdown_domains(templates_to_shutdown, log) + ret_code, _ = await shutdown_domains(templates_to_shutdown, log) if ret_code != EXIT.OK: log.error("Shutdown of some templates fails with code %d", ret_code) @@ -464,11 +469,11 @@ def apply_updates_to_appvm( ) # both flags `restart` and `apply-to-all` include service vms - ret_code_ = restart_vms(to_restart, log) + ret_code_ = await restart_vms(to_restart, log) ret_code = max(ret_code, ret_code_) if args.apply_to_all: # there is no need to start plain AppVMs automatically - ret_code_, _ = shutdown_domains(to_shutdown, log) + ret_code_, _ = await shutdown_domains(to_shutdown, log) ret_code = max(ret_code, ret_code_) return ret_code @@ -496,22 +501,5 @@ def get_derived_vm_to_apply(templates, derived_statuses): return to_restart, to_shutdown -def restart_vms(to_restart, log): - """ - Try to restart vms. - """ - ret_code, shutdowns = shutdown_domains(to_restart, log) - - # restart shutdown qubes - for vm in shutdowns: - try: - vm.start() - except qubesadmin.exc.QubesVMError as exc: - log.error(str(exc)) - ret_code = EXIT.ERR_START_APP - - return ret_code - - if __name__ == "__main__": sys.exit(main())