diff --git a/src/azure-cli-core/azure/cli/core/aaz/_client.py b/src/azure-cli-core/azure/cli/core/aaz/_client.py index 21cd0b63965..6f52948934d 100644 --- a/src/azure-cli-core/azure/cli/core/aaz/_client.py +++ b/src/azure-cli-core/azure/cli/core/aaz/_client.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +import pickle +import base64 from azure.core import PipelineClient from azure.core.configuration import Configuration from azure.core.polling.base_polling import LocationPolling, StatusCheckPolling @@ -17,6 +19,7 @@ def register_client(name): def decorator(cls): + assert issubclass(cls, AAZPipelineClient) if name in registered_clients: assert registered_clients[name] == cls else: @@ -26,8 +29,17 @@ def decorator(cls): return decorator +class AAZPipelineClient(PipelineClient): + + def from_continuation_token(self, continuation_token): + session = pickle.loads(base64.b64decode(continuation_token)) # nosec + # Restore the transport in the context + session.context.transport = self._pipeline._transport # pylint: disable=protected-access + return session + + @register_client("MgmtClient") -class AAZMgmtClient(PipelineClient): +class AAZMgmtClient(AAZPipelineClient): """Management Client for Management Plane APIs""" class _Configuration(Configuration): diff --git a/src/azure-cli-core/azure/cli/core/aaz/_command.py b/src/azure-cli-core/azure/cli/core/aaz/_command.py index c47780b7dbf..3667c9aade7 100644 --- a/src/azure-cli-core/azure/cli/core/aaz/_command.py +++ b/src/azure-cli-core/azure/cli/core/aaz/_command.py @@ -60,6 +60,7 @@ class AAZCommand(CLICommand): AZ_NAME = None AZ_HELP = None AZ_SUPPORT_NO_WAIT = False + AZ_SUPPORT_LRO_CONTINUE = False AZ_SUPPORT_GENERIC_UPDATE = False AZ_CONFIRMATION = None @@ -83,8 +84,14 @@ def _build_arguments_schema(cls, *args, **kwargs): if cls.AZ_SUPPORT_NO_WAIT: schema.no_wait = AAZBoolArg( options=['--no-wait'], - help='Do not wait for the long-running operation to finish.' + help='Do not wait for the long-running operation to finish.' + ( + " The continuation token will be cached locally." if cls.AZ_SUPPORT_LRO_CONTINUE else "") ) + if cls.AZ_SUPPORT_LRO_CONTINUE: + schema.lro_continue = AAZBoolArg( + options=['--lro-continue'], + help='Continue the long-running operation from cached continuation token.' + ) if cls.AZ_SUPPORT_GENERIC_UPDATE: schema.generic_update_add = AAZGenericUpdateAddArg() schema.generic_update_set = AAZGenericUpdateSetArg() @@ -232,7 +239,10 @@ def build_lro_poller(self, executor, extract_result): polling_generator = executor() if self.ctx.lro_no_wait: # run until yield the first polling - _ = next(polling_generator) + polling = next(polling_generator) + if polling and self.AZ_SUPPORT_LRO_CONTINUE: + self.ctx.cache_continuation_token(polling) + logger.warning("The continuation token is cached locally. You can use `--lro-continue` for polling.") return None return AAZLROPoller(polling_generator=polling_generator, result_callback=extract_result) diff --git a/src/azure-cli-core/azure/cli/core/aaz/_command_ctx.py b/src/azure-cli-core/azure/cli/core/aaz/_command_ctx.py index 40409230a10..8b95eff0ad8 100644 --- a/src/azure-cli-core/azure/cli/core/aaz/_command_ctx.py +++ b/src/azure-cli-core/azure/cli/core/aaz/_command_ctx.py @@ -7,6 +7,13 @@ from azure.cli.core._profile import Profile from azure.cli.core.azclierror import InvalidArgumentValueError +import os +import time +from urllib.parse import urlparse, urlunparse +from azure.cli.core._environment import get_config_dir +from knack.config import _ConfigFile +from knack.util import ensure_dir +import configparser from ._arg_action import AAZArgActionOperations, AAZGenericUpdateAction from ._base import AAZUndefined @@ -47,6 +54,8 @@ def __init__(self, cli_ctx, schema, command_args, no_wait_arg=None): self._aux_subscriptions = set() self._aux_tenants = set() + # command config file + self._command_config = None def format_args(self): try: @@ -117,6 +126,60 @@ def aux_subscriptions(self): def aux_tenants(self): return list(self._aux_tenants) or None + def _get_command_cache_directory(self): + return os.path.join( + get_config_dir(), + 'command_cache', + self._cli_ctx.cloud.name, + self.subscription_id, + ) + + def _load_command_cache_config(self): + config_dir = self._get_command_cache_directory() + ensure_dir(config_dir) + config_path = os.path.join(config_dir, "command_cache.json") + config = _ConfigFile(config_dir=config_dir, config_path=config_path) + # clean up expired section + now = time.time() + clean_sections = [] + for section in config.sections(): + if config.has_option(section, "expires_at") and config.getfloat(section, "expires_at") < now: + clean_sections.append(section) + for section in clean_sections: + config.remove_section(section) + return config + + @property + def command_config(self): + if not self._command_config: + self._command_config = self._load_command_cache_config() + return self._command_config + + def get_continuation_token(self, http_operation): + section = self.get_command_cache_section( + self._cli_ctx.data['command'], http_operation.method, http_operation.url + ) + try: + continuation_token = self.command_config.get(section, "continuation_token") + except (configparser.NoSectionError, configparser.NoOptionError): + raise InvalidArgumentValueError( + "Cannot find cached continuation token for the long-running operation: --lro-continue") + return continuation_token + + def cache_continuation_token(self, polling): + request = polling._initial_response.http_request + section = self.get_command_cache_section(self._cli_ctx.data['command'], request.method, request.url) + continuation_token = polling.get_continuation_token() + expires_at = int(time.time()) + 24*60*60 + self.command_config.set_value(section, "continuation_token", continuation_token) + self.command_config.set_value(section, "expires_at", str(expires_at)) + + def get_command_cache_section(self, command_name, method, url): + method = method.upper() + parsed = urlparse(url) + url = urlunparse([parsed.scheme, parsed.netloc, parsed.path, None, None, None]) + return f"{command_name};{method};{url}" + def get_subscription_locations(ctx: AAZCommandCtx): from azure.cli.core.commands.parameters import get_subscription_locations as _get_subscription_locations