Skip to content

Commit 69461fa

Browse files
{Core} Use multi-threading to build command index to improve performance (#32730)
1 parent 8b8daa4 commit 69461fa

File tree

5 files changed

+302
-110
lines changed

5 files changed

+302
-110
lines changed

src/azure-cli-core/azure/cli/core/__init__.py

Lines changed: 202 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import os
1010
import sys
1111
import timeit
12+
import concurrent.futures
13+
from concurrent.futures import ThreadPoolExecutor
1214

1315
from knack.cli import CLI
1416
from knack.commands import CLICommandsLoader
@@ -34,6 +36,10 @@
3436
ALWAYS_LOADED_MODULES = []
3537
# Extensions that will always be loaded if installed. They don't expose commands but hook into CLI core.
3638
ALWAYS_LOADED_EXTENSIONS = ['azext_ai_examples', 'azext_next']
39+
# Timeout (in seconds) for loading a single module. Acts as a safety valve to prevent indefinite hangs
40+
MODULE_LOAD_TIMEOUT_SECONDS = 30
41+
# Maximum number of worker threads for parallel module loading.
42+
MAX_WORKER_THREAD_COUNT = 4
3743

3844

3945
def _configure_knack():
@@ -197,6 +203,17 @@ def _configure_style(self):
197203
format_styled_text.theme = theme
198204

199205

206+
class ModuleLoadResult: # pylint: disable=too-few-public-methods
207+
def __init__(self, module_name, command_table, group_table, elapsed_time, error=None, traceback_str=None, command_loader=None):
208+
self.module_name = module_name
209+
self.command_table = command_table
210+
self.group_table = group_table
211+
self.elapsed_time = elapsed_time
212+
self.error = error
213+
self.traceback_str = traceback_str
214+
self.command_loader = command_loader
215+
216+
200217
class MainCommandsLoader(CLICommandsLoader):
201218

202219
# Format string for pretty-print the command module table
@@ -241,11 +258,11 @@ def load_command_table(self, args):
241258
import pkgutil
242259
import traceback
243260
from azure.cli.core.commands import (
244-
_load_module_command_loader, _load_extension_command_loader, BLOCKED_MODS, ExtensionCommandSource)
261+
_load_extension_command_loader, ExtensionCommandSource)
245262
from azure.cli.core.extension import (
246263
get_extensions, get_extension_path, get_extension_modname)
247264
from azure.cli.core.breaking_change import (
248-
import_core_breaking_changes, import_module_breaking_changes, import_extension_breaking_changes)
265+
import_core_breaking_changes, import_extension_breaking_changes)
249266

250267
def _update_command_table_from_modules(args, command_modules=None):
251268
"""Loads command tables from modules and merge into the main command table.
@@ -273,41 +290,17 @@ def _update_command_table_from_modules(args, command_modules=None):
273290
except ImportError as e:
274291
logger.warning(e)
275292

276-
count = 0
277-
cumulative_elapsed_time = 0
278-
cumulative_group_count = 0
279-
cumulative_command_count = 0
280-
logger.debug("Loading command modules:")
281-
logger.debug(self.header_mod)
293+
start_time = timeit.default_timer()
294+
logger.debug("Loading command modules...")
295+
results = self._load_modules(args, command_modules)
282296

283-
for mod in [m for m in command_modules if m not in BLOCKED_MODS]:
284-
try:
285-
start_time = timeit.default_timer()
286-
module_command_table, module_group_table = _load_module_command_loader(self, args, mod)
287-
import_module_breaking_changes(mod)
288-
for cmd in module_command_table.values():
289-
cmd.command_source = mod
290-
self.command_table.update(module_command_table)
291-
self.command_group_table.update(module_group_table)
297+
count, cumulative_group_count, cumulative_command_count = \
298+
self._process_results_with_timing(results)
292299

293-
elapsed_time = timeit.default_timer() - start_time
294-
logger.debug(self.item_format_string, mod, elapsed_time,
295-
len(module_group_table), len(module_command_table))
296-
count += 1
297-
cumulative_elapsed_time += elapsed_time
298-
cumulative_group_count += len(module_group_table)
299-
cumulative_command_count += len(module_command_table)
300-
except Exception as ex: # pylint: disable=broad-except
301-
# Changing this error message requires updating CI script that checks for failed
302-
# module loading.
303-
from azure.cli.core import telemetry
304-
logger.error("Error loading command module '%s': %s", mod, ex)
305-
telemetry.set_exception(exception=ex, fault_type='module-load-error-' + mod,
306-
summary='Error loading module: {}'.format(mod))
307-
logger.debug(traceback.format_exc())
300+
total_elapsed_time = timeit.default_timer() - start_time
308301
# Summary line
309302
logger.debug(self.item_format_string,
310-
"Total ({})".format(count), cumulative_elapsed_time,
303+
"Total ({})".format(count), total_elapsed_time,
311304
cumulative_group_count, cumulative_command_count)
312305

313306
def _update_command_table_from_extensions(ext_suppressions, extension_modname=None):
@@ -345,70 +338,80 @@ def _filter_modname(extensions):
345338
return filtered_extensions
346339

347340
extensions = get_extensions()
348-
if extensions:
349-
if extension_modname is not None:
350-
extension_modname.extend(ALWAYS_LOADED_EXTENSIONS)
351-
extensions = _filter_modname(extensions)
352-
allowed_extensions = _handle_extension_suppressions(extensions)
353-
module_commands = set(self.command_table.keys())
354-
355-
count = 0
356-
cumulative_elapsed_time = 0
357-
cumulative_group_count = 0
358-
cumulative_command_count = 0
359-
logger.debug("Loading extensions:")
360-
logger.debug(self.header_ext)
361-
362-
for ext in allowed_extensions:
363-
try:
364-
# Import in the `for` loop because `allowed_extensions` can be []. In such case we
365-
# don't need to import `check_version_compatibility` at all.
366-
from azure.cli.core.extension.operations import check_version_compatibility
367-
check_version_compatibility(ext.get_metadata())
368-
except CLIError as ex:
369-
# issue warning and skip loading extensions that aren't compatible with the CLI core
370-
logger.warning(ex)
371-
continue
372-
ext_name = ext.name
373-
ext_dir = ext.path or get_extension_path(ext_name)
374-
sys.path.append(ext_dir)
375-
try:
376-
ext_mod = get_extension_modname(ext_name, ext_dir=ext_dir)
377-
# Add to the map. This needs to happen before we load commands as registering a command
378-
# from an extension requires this map to be up-to-date.
379-
# self._mod_to_ext_map[ext_mod] = ext_name
380-
start_time = timeit.default_timer()
381-
extension_command_table, extension_group_table = \
382-
_load_extension_command_loader(self, args, ext_mod)
383-
import_extension_breaking_changes(ext_mod)
384-
385-
for cmd_name, cmd in extension_command_table.items():
386-
cmd.command_source = ExtensionCommandSource(
387-
extension_name=ext_name,
388-
overrides_command=cmd_name in module_commands,
389-
preview=ext.preview,
390-
experimental=ext.experimental)
391-
392-
self.command_table.update(extension_command_table)
393-
self.command_group_table.update(extension_group_table)
394-
395-
elapsed_time = timeit.default_timer() - start_time
396-
logger.debug(self.item_ext_format_string, ext_name, elapsed_time,
397-
len(extension_group_table), len(extension_command_table),
398-
ext_dir)
399-
count += 1
400-
cumulative_elapsed_time += elapsed_time
401-
cumulative_group_count += len(extension_group_table)
402-
cumulative_command_count += len(extension_command_table)
403-
except Exception as ex: # pylint: disable=broad-except
404-
self.cli_ctx.raise_event(EVENT_FAILED_EXTENSION_LOAD, extension_name=ext_name)
405-
logger.warning("Unable to load extension '%s: %s'. Use --debug for more information.",
406-
ext_name, ex)
407-
logger.debug(traceback.format_exc())
408-
# Summary line
409-
logger.debug(self.item_ext_format_string,
410-
"Total ({})".format(count), cumulative_elapsed_time,
411-
cumulative_group_count, cumulative_command_count, "")
341+
if not extensions:
342+
return
343+
344+
if extension_modname is not None:
345+
extension_modname.extend(ALWAYS_LOADED_EXTENSIONS)
346+
extensions = _filter_modname(extensions)
347+
allowed_extensions = _handle_extension_suppressions(extensions)
348+
module_commands = set(self.command_table.keys())
349+
350+
count = 0
351+
cumulative_elapsed_time = 0
352+
cumulative_group_count = 0
353+
cumulative_command_count = 0
354+
logger.debug("Loading extensions:")
355+
logger.debug(self.header_ext)
356+
357+
for ext in allowed_extensions:
358+
try:
359+
# Import in the `for` loop because `allowed_extensions` can be []. In such case we
360+
# don't need to import `check_version_compatibility` at all.
361+
from azure.cli.core.extension.operations import check_version_compatibility
362+
check_version_compatibility(ext.get_metadata())
363+
except CLIError as ex:
364+
# issue warning and skip loading extensions that aren't compatible with the CLI core
365+
logger.warning(ex)
366+
continue
367+
ext_name = ext.name
368+
ext_dir = ext.path or get_extension_path(ext_name)
369+
sys.path.append(ext_dir)
370+
try:
371+
ext_mod = get_extension_modname(ext_name, ext_dir=ext_dir)
372+
# Add to the map. This needs to happen before we load commands as registering a command
373+
# from an extension requires this map to be up-to-date.
374+
# self._mod_to_ext_map[ext_mod] = ext_name
375+
start_time = timeit.default_timer()
376+
extension_command_table, extension_group_table, extension_command_loader = \
377+
_load_extension_command_loader(self, args, ext_mod)
378+
import_extension_breaking_changes(ext_mod)
379+
380+
for cmd_name, cmd in extension_command_table.items():
381+
cmd.command_source = ExtensionCommandSource(
382+
extension_name=ext_name,
383+
overrides_command=cmd_name in module_commands,
384+
preview=ext.preview,
385+
experimental=ext.experimental)
386+
387+
# Populate cmd_to_loader_map for extension commands
388+
if extension_command_loader:
389+
self.loaders.append(extension_command_loader)
390+
for cmd_name in extension_command_table:
391+
if cmd_name not in self.cmd_to_loader_map:
392+
self.cmd_to_loader_map[cmd_name] = []
393+
self.cmd_to_loader_map[cmd_name].append(extension_command_loader)
394+
395+
self.command_table.update(extension_command_table)
396+
self.command_group_table.update(extension_group_table)
397+
398+
elapsed_time = timeit.default_timer() - start_time
399+
logger.debug(self.item_ext_format_string, ext_name, elapsed_time,
400+
len(extension_group_table), len(extension_command_table),
401+
ext_dir)
402+
count += 1
403+
cumulative_elapsed_time += elapsed_time
404+
cumulative_group_count += len(extension_group_table)
405+
cumulative_command_count += len(extension_command_table)
406+
except Exception as ex: # pylint: disable=broad-except
407+
self.cli_ctx.raise_event(EVENT_FAILED_EXTENSION_LOAD, extension_name=ext_name)
408+
logger.warning("Unable to load extension '%s: %s'. Use --debug for more information.",
409+
ext_name, ex)
410+
logger.debug(traceback.format_exc())
411+
# Summary line
412+
logger.debug(self.item_ext_format_string,
413+
"Total ({})".format(count), cumulative_elapsed_time,
414+
cumulative_group_count, cumulative_command_count, "")
412415

413416
def _wrap_suppress_extension_func(func, ext):
414417
""" Wrapper method to handle centralization of log messages for extension filters """
@@ -587,6 +590,108 @@ def load_arguments(self, command=None):
587590
self.extra_argument_registry.update(loader.extra_argument_registry)
588591
loader._update_command_definitions() # pylint: disable=protected-access
589592

593+
def _load_modules(self, args, command_modules):
594+
"""Load command modules using ThreadPoolExecutor with timeout protection."""
595+
from azure.cli.core.commands import BLOCKED_MODS
596+
597+
results = []
598+
with ThreadPoolExecutor(max_workers=MAX_WORKER_THREAD_COUNT) as executor:
599+
future_to_module = {executor.submit(self._load_single_module, mod, args): mod
600+
for mod in command_modules if mod not in BLOCKED_MODS}
601+
602+
try:
603+
for future in concurrent.futures.as_completed(future_to_module, timeout=MODULE_LOAD_TIMEOUT_SECONDS):
604+
try:
605+
result = future.result()
606+
results.append(result)
607+
except (ImportError, AttributeError, TypeError, ValueError) as ex:
608+
mod = future_to_module[future]
609+
logger.warning("Module '%s' load failed: %s", mod, ex)
610+
results.append(ModuleLoadResult(mod, {}, {}, 0, ex))
611+
except Exception as ex: # pylint: disable=broad-exception-caught
612+
mod = future_to_module[future]
613+
logger.warning("Module '%s' load failed with unexpected exception: %s", mod, ex)
614+
results.append(ModuleLoadResult(mod, {}, {}, 0, ex))
615+
except concurrent.futures.TimeoutError:
616+
for future, mod in future_to_module.items():
617+
if future.done():
618+
try:
619+
result = future.result()
620+
results.append(result)
621+
except Exception as ex: # pylint: disable=broad-exception-caught
622+
logger.warning("Module '%s' load failed: %s", mod, ex)
623+
results.append(ModuleLoadResult(mod, {}, {}, 0, ex))
624+
else:
625+
logger.warning("Module '%s' load timeout after %s seconds", mod, MODULE_LOAD_TIMEOUT_SECONDS)
626+
results.append(ModuleLoadResult(mod, {}, {}, 0,
627+
Exception(f"Module '{mod}' load timeout")))
628+
629+
return results
630+
631+
def _load_single_module(self, mod, args):
632+
from azure.cli.core.breaking_change import import_module_breaking_changes
633+
from azure.cli.core.commands import _load_module_command_loader
634+
import traceback
635+
try:
636+
start_time = timeit.default_timer()
637+
module_command_table, module_group_table, command_loader = _load_module_command_loader(self, args, mod)
638+
import_module_breaking_changes(mod)
639+
elapsed_time = timeit.default_timer() - start_time
640+
return ModuleLoadResult(mod, module_command_table, module_group_table, elapsed_time, command_loader=command_loader)
641+
except Exception as ex: # pylint: disable=broad-except
642+
tb_str = traceback.format_exc()
643+
return ModuleLoadResult(mod, {}, {}, 0, ex, tb_str)
644+
645+
def _handle_module_load_error(self, result):
646+
"""Handle errors that occurred during module loading."""
647+
from azure.cli.core import telemetry
648+
649+
logger.error("Error loading command module '%s': %s", result.module_name, result.error)
650+
telemetry.set_exception(exception=result.error,
651+
fault_type='module-load-error-' + result.module_name,
652+
summary='Error loading module: {}'.format(result.module_name))
653+
if result.traceback_str:
654+
logger.debug(result.traceback_str)
655+
656+
def _process_successful_load(self, result):
657+
"""Process successfully loaded module results."""
658+
if result.command_loader:
659+
self.loaders.append(result.command_loader)
660+
661+
for cmd in result.command_table:
662+
if cmd not in self.cmd_to_loader_map:
663+
self.cmd_to_loader_map[cmd] = []
664+
self.cmd_to_loader_map[cmd].append(result.command_loader)
665+
666+
for cmd in result.command_table.values():
667+
cmd.command_source = result.module_name
668+
669+
self.command_table.update(result.command_table)
670+
self.command_group_table.update(result.group_table)
671+
672+
logger.debug(self.item_format_string, result.module_name, result.elapsed_time,
673+
len(result.group_table), len(result.command_table))
674+
675+
def _process_results_with_timing(self, results):
676+
"""Process pre-loaded module results with timing and progress reporting."""
677+
logger.debug("Loaded command modules in parallel:")
678+
logger.debug(self.header_mod)
679+
680+
count = 0
681+
cumulative_group_count = 0
682+
cumulative_command_count = 0
683+
684+
for result in results:
685+
if result.error:
686+
self._handle_module_load_error(result)
687+
else:
688+
self._process_successful_load(result)
689+
count += 1
690+
cumulative_group_count += len(result.group_table)
691+
cumulative_command_count += len(result.command_table)
692+
693+
return count, cumulative_group_count, cumulative_command_count
694+
590695

591696
class CommandIndex:
592697

src/azure-cli-core/azure/cli/core/commands/__init__.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,22 +1134,17 @@ def _load_command_loader(loader, args, name, prefix):
11341134
logger.debug("Module '%s' is missing `get_command_loader` entry.", name)
11351135

11361136
command_table = {}
1137+
command_loader = None
11371138

11381139
if loader_cls:
11391140
command_loader = loader_cls(cli_ctx=loader.cli_ctx)
1140-
loader.loaders.append(command_loader) # This will be used by interactive
11411141
if command_loader.supported_resource_type():
11421142
command_table = command_loader.load_command_table(args)
1143-
if command_table:
1144-
for cmd in list(command_table.keys()):
1145-
# TODO: If desired to for extension to patch module, this can be uncommented
1146-
# if loader.cmd_to_loader_map.get(cmd):
1147-
# loader.cmd_to_loader_map[cmd].append(command_loader)
1148-
# else:
1149-
loader.cmd_to_loader_map[cmd] = [command_loader]
11501143
else:
11511144
logger.debug("Module '%s' is missing `COMMAND_LOADER_CLS` entry.", name)
1152-
return command_table, command_loader.command_group_table
1145+
1146+
group_table = command_loader.command_group_table if command_loader else {}
1147+
return command_table, group_table, command_loader
11531148

11541149

11551150
def _load_extension_command_loader(loader, args, ext):

0 commit comments

Comments
 (0)