Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/azure-cli-core/azure/cli/core/aaz/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import pickle
import base64
from azure.core import PipelineClient
from azure.core.configuration import Configuration
from azure.core.polling.base_polling import LocationPolling, StatusCheckPolling
Expand All @@ -17,6 +19,7 @@

def register_client(name):
def decorator(cls):
assert issubclass(cls, AAZPipelineClient)
if name in registered_clients:
assert registered_clients[name] == cls
else:
Expand All @@ -26,8 +29,17 @@ def decorator(cls):
return decorator


class AAZPipelineClient(PipelineClient):

def from_continuation_token(self, continuation_token):
session = pickle.loads(base64.b64decode(continuation_token)) # nosec
# Restore the transport in the context
session.context.transport = self._pipeline._transport # pylint: disable=protected-access
return session


@register_client("MgmtClient")
class AAZMgmtClient(PipelineClient):
class AAZMgmtClient(AAZPipelineClient):
"""Management Client for Management Plane APIs"""

class _Configuration(Configuration):
Expand Down
14 changes: 12 additions & 2 deletions src/azure-cli-core/azure/cli/core/aaz/_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class AAZCommand(CLICommand):
AZ_NAME = None
AZ_HELP = None
AZ_SUPPORT_NO_WAIT = False
AZ_SUPPORT_LRO_CONTINUE = False
AZ_SUPPORT_GENERIC_UPDATE = False

AZ_CONFIRMATION = None
Expand All @@ -83,8 +84,14 @@ def _build_arguments_schema(cls, *args, **kwargs):
if cls.AZ_SUPPORT_NO_WAIT:
schema.no_wait = AAZBoolArg(
options=['--no-wait'],
help='Do not wait for the long-running operation to finish.'
help='Do not wait for the long-running operation to finish.' + (
" The continuation token will be cached locally." if cls.AZ_SUPPORT_LRO_CONTINUE else "")
)
if cls.AZ_SUPPORT_LRO_CONTINUE:
schema.lro_continue = AAZBoolArg(
options=['--lro-continue'],
help='Continue the long-running operation from cached continuation token.'
)
if cls.AZ_SUPPORT_GENERIC_UPDATE:
schema.generic_update_add = AAZGenericUpdateAddArg()
schema.generic_update_set = AAZGenericUpdateSetArg()
Expand Down Expand Up @@ -232,7 +239,10 @@ def build_lro_poller(self, executor, extract_result):
polling_generator = executor()
if self.ctx.lro_no_wait:
# run until yield the first polling
_ = next(polling_generator)
polling = next(polling_generator)
if polling and self.AZ_SUPPORT_LRO_CONTINUE:
self.ctx.cache_continuation_token(polling)
logger.warning("The continuation token is cached locally. You can use `--lro-continue` for polling.")
return None
return AAZLROPoller(polling_generator=polling_generator, result_callback=extract_result)

Expand Down
63 changes: 63 additions & 0 deletions src/azure-cli-core/azure/cli/core/aaz/_command_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@

from azure.cli.core._profile import Profile
from azure.cli.core.azclierror import InvalidArgumentValueError
import os
import time
from urllib.parse import urlparse, urlunparse
from azure.cli.core._environment import get_config_dir
from knack.config import _ConfigFile
from knack.util import ensure_dir
import configparser

from ._arg_action import AAZArgActionOperations, AAZGenericUpdateAction
from ._base import AAZUndefined
Expand Down Expand Up @@ -47,6 +54,8 @@ def __init__(self, cli_ctx, schema, command_args, no_wait_arg=None):

self._aux_subscriptions = set()
self._aux_tenants = set()
# command config file
self._command_config = None

def format_args(self):
try:
Expand Down Expand Up @@ -117,6 +126,60 @@ def aux_subscriptions(self):
def aux_tenants(self):
return list(self._aux_tenants) or None

def _get_command_cache_directory(self):
return os.path.join(
get_config_dir(),
'command_cache',
self._cli_ctx.cloud.name,
self.subscription_id,
)

def _load_command_cache_config(self):
config_dir = self._get_command_cache_directory()
ensure_dir(config_dir)
config_path = os.path.join(config_dir, "command_cache.json")
config = _ConfigFile(config_dir=config_dir, config_path=config_path)
# clean up expired section
now = time.time()
clean_sections = []
for section in config.sections():
if config.has_option(section, "expires_at") and config.getfloat(section, "expires_at") < now:
clean_sections.append(section)
for section in clean_sections:
config.remove_section(section)
return config

@property
def command_config(self):
if not self._command_config:
self._command_config = self._load_command_cache_config()
return self._command_config

def get_continuation_token(self, http_operation):
section = self.get_command_cache_section(
self._cli_ctx.data['command'], http_operation.method, http_operation.url
)
try:
continuation_token = self.command_config.get(section, "continuation_token")
except (configparser.NoSectionError, configparser.NoOptionError):
raise InvalidArgumentValueError(
"Cannot find cached continuation token for the long-running operation: --lro-continue")
return continuation_token

def cache_continuation_token(self, polling):
request = polling._initial_response.http_request
section = self.get_command_cache_section(self._cli_ctx.data['command'], request.method, request.url)
continuation_token = polling.get_continuation_token()
expires_at = int(time.time()) + 24*60*60
self.command_config.set_value(section, "continuation_token", continuation_token)
self.command_config.set_value(section, "expires_at", str(expires_at))

def get_command_cache_section(self, command_name, method, url):
method = method.upper()
parsed = urlparse(url)
url = urlunparse([parsed.scheme, parsed.netloc, parsed.path, None, None, None])
return f"{command_name};{method};{url}"


def get_subscription_locations(ctx: AAZCommandCtx):
from azure.cli.core.commands.parameters import get_subscription_locations as _get_subscription_locations
Expand Down
Loading