|
9 | 9 | import os |
10 | 10 | import sys |
11 | 11 | import timeit |
| 12 | +import concurrent.futures |
| 13 | +from concurrent.futures import ThreadPoolExecutor |
12 | 14 |
|
13 | 15 | from knack.cli import CLI |
14 | 16 | from knack.commands import CLICommandsLoader |
|
34 | 36 | ALWAYS_LOADED_MODULES = [] |
35 | 37 | # Extensions that will always be loaded if installed. They don't expose commands but hook into CLI core. |
36 | 38 | ALWAYS_LOADED_EXTENSIONS = ['azext_ai_examples', 'azext_next'] |
| 39 | +# Timeout (in seconds) for loading a single module. Acts as a safety valve to prevent indefinite hangs |
| 40 | +MODULE_LOAD_TIMEOUT_SECONDS = 30 |
| 41 | +# Maximum number of worker threads for parallel module loading. |
| 42 | +MAX_WORKER_THREAD_COUNT = 4 |
37 | 43 |
|
38 | 44 |
|
39 | 45 | def _configure_knack(): |
@@ -197,6 +203,17 @@ def _configure_style(self): |
197 | 203 | format_styled_text.theme = theme |
198 | 204 |
|
199 | 205 |
|
| 206 | +class ModuleLoadResult: # pylint: disable=too-few-public-methods |
| 207 | + def __init__(self, module_name, command_table, group_table, elapsed_time, error=None, traceback_str=None, command_loader=None): |
| 208 | + self.module_name = module_name |
| 209 | + self.command_table = command_table |
| 210 | + self.group_table = group_table |
| 211 | + self.elapsed_time = elapsed_time |
| 212 | + self.error = error |
| 213 | + self.traceback_str = traceback_str |
| 214 | + self.command_loader = command_loader |
| 215 | + |
| 216 | + |
200 | 217 | class MainCommandsLoader(CLICommandsLoader): |
201 | 218 |
|
202 | 219 | # Format string for pretty-print the command module table |
@@ -241,11 +258,11 @@ def load_command_table(self, args): |
241 | 258 | import pkgutil |
242 | 259 | import traceback |
243 | 260 | from azure.cli.core.commands import ( |
244 | | - _load_module_command_loader, _load_extension_command_loader, BLOCKED_MODS, ExtensionCommandSource) |
| 261 | + _load_extension_command_loader, ExtensionCommandSource) |
245 | 262 | from azure.cli.core.extension import ( |
246 | 263 | get_extensions, get_extension_path, get_extension_modname) |
247 | 264 | from azure.cli.core.breaking_change import ( |
248 | | - import_core_breaking_changes, import_module_breaking_changes, import_extension_breaking_changes) |
| 265 | + import_core_breaking_changes, import_extension_breaking_changes) |
249 | 266 |
|
250 | 267 | def _update_command_table_from_modules(args, command_modules=None): |
251 | 268 | """Loads command tables from modules and merge into the main command table. |
@@ -273,41 +290,17 @@ def _update_command_table_from_modules(args, command_modules=None): |
273 | 290 | except ImportError as e: |
274 | 291 | logger.warning(e) |
275 | 292 |
|
276 | | - count = 0 |
277 | | - cumulative_elapsed_time = 0 |
278 | | - cumulative_group_count = 0 |
279 | | - cumulative_command_count = 0 |
280 | | - logger.debug("Loading command modules:") |
281 | | - logger.debug(self.header_mod) |
| 293 | + start_time = timeit.default_timer() |
| 294 | + logger.debug("Loading command modules...") |
| 295 | + results = self._load_modules(args, command_modules) |
282 | 296 |
|
283 | | - for mod in [m for m in command_modules if m not in BLOCKED_MODS]: |
284 | | - try: |
285 | | - start_time = timeit.default_timer() |
286 | | - module_command_table, module_group_table = _load_module_command_loader(self, args, mod) |
287 | | - import_module_breaking_changes(mod) |
288 | | - for cmd in module_command_table.values(): |
289 | | - cmd.command_source = mod |
290 | | - self.command_table.update(module_command_table) |
291 | | - self.command_group_table.update(module_group_table) |
| 297 | + count, cumulative_group_count, cumulative_command_count = \ |
| 298 | + self._process_results_with_timing(results) |
292 | 299 |
|
293 | | - elapsed_time = timeit.default_timer() - start_time |
294 | | - logger.debug(self.item_format_string, mod, elapsed_time, |
295 | | - len(module_group_table), len(module_command_table)) |
296 | | - count += 1 |
297 | | - cumulative_elapsed_time += elapsed_time |
298 | | - cumulative_group_count += len(module_group_table) |
299 | | - cumulative_command_count += len(module_command_table) |
300 | | - except Exception as ex: # pylint: disable=broad-except |
301 | | - # Changing this error message requires updating CI script that checks for failed |
302 | | - # module loading. |
303 | | - from azure.cli.core import telemetry |
304 | | - logger.error("Error loading command module '%s': %s", mod, ex) |
305 | | - telemetry.set_exception(exception=ex, fault_type='module-load-error-' + mod, |
306 | | - summary='Error loading module: {}'.format(mod)) |
307 | | - logger.debug(traceback.format_exc()) |
| 300 | + total_elapsed_time = timeit.default_timer() - start_time |
308 | 301 | # Summary line |
309 | 302 | logger.debug(self.item_format_string, |
310 | | - "Total ({})".format(count), cumulative_elapsed_time, |
| 303 | + "Total ({})".format(count), total_elapsed_time, |
311 | 304 | cumulative_group_count, cumulative_command_count) |
312 | 305 |
|
313 | 306 | def _update_command_table_from_extensions(ext_suppressions, extension_modname=None): |
@@ -345,70 +338,80 @@ def _filter_modname(extensions): |
345 | 338 | return filtered_extensions |
346 | 339 |
|
347 | 340 | extensions = get_extensions() |
348 | | - if extensions: |
349 | | - if extension_modname is not None: |
350 | | - extension_modname.extend(ALWAYS_LOADED_EXTENSIONS) |
351 | | - extensions = _filter_modname(extensions) |
352 | | - allowed_extensions = _handle_extension_suppressions(extensions) |
353 | | - module_commands = set(self.command_table.keys()) |
354 | | - |
355 | | - count = 0 |
356 | | - cumulative_elapsed_time = 0 |
357 | | - cumulative_group_count = 0 |
358 | | - cumulative_command_count = 0 |
359 | | - logger.debug("Loading extensions:") |
360 | | - logger.debug(self.header_ext) |
361 | | - |
362 | | - for ext in allowed_extensions: |
363 | | - try: |
364 | | - # Import in the `for` loop because `allowed_extensions` can be []. In such case we |
365 | | - # don't need to import `check_version_compatibility` at all. |
366 | | - from azure.cli.core.extension.operations import check_version_compatibility |
367 | | - check_version_compatibility(ext.get_metadata()) |
368 | | - except CLIError as ex: |
369 | | - # issue warning and skip loading extensions that aren't compatible with the CLI core |
370 | | - logger.warning(ex) |
371 | | - continue |
372 | | - ext_name = ext.name |
373 | | - ext_dir = ext.path or get_extension_path(ext_name) |
374 | | - sys.path.append(ext_dir) |
375 | | - try: |
376 | | - ext_mod = get_extension_modname(ext_name, ext_dir=ext_dir) |
377 | | - # Add to the map. This needs to happen before we load commands as registering a command |
378 | | - # from an extension requires this map to be up-to-date. |
379 | | - # self._mod_to_ext_map[ext_mod] = ext_name |
380 | | - start_time = timeit.default_timer() |
381 | | - extension_command_table, extension_group_table = \ |
382 | | - _load_extension_command_loader(self, args, ext_mod) |
383 | | - import_extension_breaking_changes(ext_mod) |
384 | | - |
385 | | - for cmd_name, cmd in extension_command_table.items(): |
386 | | - cmd.command_source = ExtensionCommandSource( |
387 | | - extension_name=ext_name, |
388 | | - overrides_command=cmd_name in module_commands, |
389 | | - preview=ext.preview, |
390 | | - experimental=ext.experimental) |
391 | | - |
392 | | - self.command_table.update(extension_command_table) |
393 | | - self.command_group_table.update(extension_group_table) |
394 | | - |
395 | | - elapsed_time = timeit.default_timer() - start_time |
396 | | - logger.debug(self.item_ext_format_string, ext_name, elapsed_time, |
397 | | - len(extension_group_table), len(extension_command_table), |
398 | | - ext_dir) |
399 | | - count += 1 |
400 | | - cumulative_elapsed_time += elapsed_time |
401 | | - cumulative_group_count += len(extension_group_table) |
402 | | - cumulative_command_count += len(extension_command_table) |
403 | | - except Exception as ex: # pylint: disable=broad-except |
404 | | - self.cli_ctx.raise_event(EVENT_FAILED_EXTENSION_LOAD, extension_name=ext_name) |
405 | | - logger.warning("Unable to load extension '%s: %s'. Use --debug for more information.", |
406 | | - ext_name, ex) |
407 | | - logger.debug(traceback.format_exc()) |
408 | | - # Summary line |
409 | | - logger.debug(self.item_ext_format_string, |
410 | | - "Total ({})".format(count), cumulative_elapsed_time, |
411 | | - cumulative_group_count, cumulative_command_count, "") |
| 341 | + if not extensions: |
| 342 | + return |
| 343 | + |
| 344 | + if extension_modname is not None: |
| 345 | + extension_modname.extend(ALWAYS_LOADED_EXTENSIONS) |
| 346 | + extensions = _filter_modname(extensions) |
| 347 | + allowed_extensions = _handle_extension_suppressions(extensions) |
| 348 | + module_commands = set(self.command_table.keys()) |
| 349 | + |
| 350 | + count = 0 |
| 351 | + cumulative_elapsed_time = 0 |
| 352 | + cumulative_group_count = 0 |
| 353 | + cumulative_command_count = 0 |
| 354 | + logger.debug("Loading extensions:") |
| 355 | + logger.debug(self.header_ext) |
| 356 | + |
| 357 | + for ext in allowed_extensions: |
| 358 | + try: |
| 359 | + # Import in the `for` loop because `allowed_extensions` can be []. In such case we |
| 360 | + # don't need to import `check_version_compatibility` at all. |
| 361 | + from azure.cli.core.extension.operations import check_version_compatibility |
| 362 | + check_version_compatibility(ext.get_metadata()) |
| 363 | + except CLIError as ex: |
| 364 | + # issue warning and skip loading extensions that aren't compatible with the CLI core |
| 365 | + logger.warning(ex) |
| 366 | + continue |
| 367 | + ext_name = ext.name |
| 368 | + ext_dir = ext.path or get_extension_path(ext_name) |
| 369 | + sys.path.append(ext_dir) |
| 370 | + try: |
| 371 | + ext_mod = get_extension_modname(ext_name, ext_dir=ext_dir) |
| 372 | + # Add to the map. This needs to happen before we load commands as registering a command |
| 373 | + # from an extension requires this map to be up-to-date. |
| 374 | + # self._mod_to_ext_map[ext_mod] = ext_name |
| 375 | + start_time = timeit.default_timer() |
| 376 | + extension_command_table, extension_group_table, extension_command_loader = \ |
| 377 | + _load_extension_command_loader(self, args, ext_mod) |
| 378 | + import_extension_breaking_changes(ext_mod) |
| 379 | + |
| 380 | + for cmd_name, cmd in extension_command_table.items(): |
| 381 | + cmd.command_source = ExtensionCommandSource( |
| 382 | + extension_name=ext_name, |
| 383 | + overrides_command=cmd_name in module_commands, |
| 384 | + preview=ext.preview, |
| 385 | + experimental=ext.experimental) |
| 386 | + |
| 387 | + # Populate cmd_to_loader_map for extension commands |
| 388 | + if extension_command_loader: |
| 389 | + self.loaders.append(extension_command_loader) |
| 390 | + for cmd_name in extension_command_table: |
| 391 | + if cmd_name not in self.cmd_to_loader_map: |
| 392 | + self.cmd_to_loader_map[cmd_name] = [] |
| 393 | + self.cmd_to_loader_map[cmd_name].append(extension_command_loader) |
| 394 | + |
| 395 | + self.command_table.update(extension_command_table) |
| 396 | + self.command_group_table.update(extension_group_table) |
| 397 | + |
| 398 | + elapsed_time = timeit.default_timer() - start_time |
| 399 | + logger.debug(self.item_ext_format_string, ext_name, elapsed_time, |
| 400 | + len(extension_group_table), len(extension_command_table), |
| 401 | + ext_dir) |
| 402 | + count += 1 |
| 403 | + cumulative_elapsed_time += elapsed_time |
| 404 | + cumulative_group_count += len(extension_group_table) |
| 405 | + cumulative_command_count += len(extension_command_table) |
| 406 | + except Exception as ex: # pylint: disable=broad-except |
| 407 | + self.cli_ctx.raise_event(EVENT_FAILED_EXTENSION_LOAD, extension_name=ext_name) |
| 408 | + logger.warning("Unable to load extension '%s: %s'. Use --debug for more information.", |
| 409 | + ext_name, ex) |
| 410 | + logger.debug(traceback.format_exc()) |
| 411 | + # Summary line |
| 412 | + logger.debug(self.item_ext_format_string, |
| 413 | + "Total ({})".format(count), cumulative_elapsed_time, |
| 414 | + cumulative_group_count, cumulative_command_count, "") |
412 | 415 |
|
413 | 416 | def _wrap_suppress_extension_func(func, ext): |
414 | 417 | """ Wrapper method to handle centralization of log messages for extension filters """ |
@@ -587,6 +590,108 @@ def load_arguments(self, command=None): |
587 | 590 | self.extra_argument_registry.update(loader.extra_argument_registry) |
588 | 591 | loader._update_command_definitions() # pylint: disable=protected-access |
589 | 592 |
|
| 593 | + def _load_modules(self, args, command_modules): |
| 594 | + """Load command modules using ThreadPoolExecutor with timeout protection.""" |
| 595 | + from azure.cli.core.commands import BLOCKED_MODS |
| 596 | + |
| 597 | + results = [] |
| 598 | + with ThreadPoolExecutor(max_workers=MAX_WORKER_THREAD_COUNT) as executor: |
| 599 | + future_to_module = {executor.submit(self._load_single_module, mod, args): mod |
| 600 | + for mod in command_modules if mod not in BLOCKED_MODS} |
| 601 | + |
| 602 | + try: |
| 603 | + for future in concurrent.futures.as_completed(future_to_module, timeout=MODULE_LOAD_TIMEOUT_SECONDS): |
| 604 | + try: |
| 605 | + result = future.result() |
| 606 | + results.append(result) |
| 607 | + except (ImportError, AttributeError, TypeError, ValueError) as ex: |
| 608 | + mod = future_to_module[future] |
| 609 | + logger.warning("Module '%s' load failed: %s", mod, ex) |
| 610 | + results.append(ModuleLoadResult(mod, {}, {}, 0, ex)) |
| 611 | + except Exception as ex: # pylint: disable=broad-exception-caught |
| 612 | + mod = future_to_module[future] |
| 613 | + logger.warning("Module '%s' load failed with unexpected exception: %s", mod, ex) |
| 614 | + results.append(ModuleLoadResult(mod, {}, {}, 0, ex)) |
| 615 | + except concurrent.futures.TimeoutError: |
| 616 | + for future, mod in future_to_module.items(): |
| 617 | + if future.done(): |
| 618 | + try: |
| 619 | + result = future.result() |
| 620 | + results.append(result) |
| 621 | + except Exception as ex: # pylint: disable=broad-exception-caught |
| 622 | + logger.warning("Module '%s' load failed: %s", mod, ex) |
| 623 | + results.append(ModuleLoadResult(mod, {}, {}, 0, ex)) |
| 624 | + else: |
| 625 | + logger.warning("Module '%s' load timeout after %s seconds", mod, MODULE_LOAD_TIMEOUT_SECONDS) |
| 626 | + results.append(ModuleLoadResult(mod, {}, {}, 0, |
| 627 | + Exception(f"Module '{mod}' load timeout"))) |
| 628 | + |
| 629 | + return results |
| 630 | + |
| 631 | + def _load_single_module(self, mod, args): |
| 632 | + from azure.cli.core.breaking_change import import_module_breaking_changes |
| 633 | + from azure.cli.core.commands import _load_module_command_loader |
| 634 | + import traceback |
| 635 | + try: |
| 636 | + start_time = timeit.default_timer() |
| 637 | + module_command_table, module_group_table, command_loader = _load_module_command_loader(self, args, mod) |
| 638 | + import_module_breaking_changes(mod) |
| 639 | + elapsed_time = timeit.default_timer() - start_time |
| 640 | + return ModuleLoadResult(mod, module_command_table, module_group_table, elapsed_time, command_loader=command_loader) |
| 641 | + except Exception as ex: # pylint: disable=broad-except |
| 642 | + tb_str = traceback.format_exc() |
| 643 | + return ModuleLoadResult(mod, {}, {}, 0, ex, tb_str) |
| 644 | + |
| 645 | + def _handle_module_load_error(self, result): |
| 646 | + """Handle errors that occurred during module loading.""" |
| 647 | + from azure.cli.core import telemetry |
| 648 | + |
| 649 | + logger.error("Error loading command module '%s': %s", result.module_name, result.error) |
| 650 | + telemetry.set_exception(exception=result.error, |
| 651 | + fault_type='module-load-error-' + result.module_name, |
| 652 | + summary='Error loading module: {}'.format(result.module_name)) |
| 653 | + if result.traceback_str: |
| 654 | + logger.debug(result.traceback_str) |
| 655 | + |
| 656 | + def _process_successful_load(self, result): |
| 657 | + """Process successfully loaded module results.""" |
| 658 | + if result.command_loader: |
| 659 | + self.loaders.append(result.command_loader) |
| 660 | + |
| 661 | + for cmd in result.command_table: |
| 662 | + if cmd not in self.cmd_to_loader_map: |
| 663 | + self.cmd_to_loader_map[cmd] = [] |
| 664 | + self.cmd_to_loader_map[cmd].append(result.command_loader) |
| 665 | + |
| 666 | + for cmd in result.command_table.values(): |
| 667 | + cmd.command_source = result.module_name |
| 668 | + |
| 669 | + self.command_table.update(result.command_table) |
| 670 | + self.command_group_table.update(result.group_table) |
| 671 | + |
| 672 | + logger.debug(self.item_format_string, result.module_name, result.elapsed_time, |
| 673 | + len(result.group_table), len(result.command_table)) |
| 674 | + |
| 675 | + def _process_results_with_timing(self, results): |
| 676 | + """Process pre-loaded module results with timing and progress reporting.""" |
| 677 | + logger.debug("Loaded command modules in parallel:") |
| 678 | + logger.debug(self.header_mod) |
| 679 | + |
| 680 | + count = 0 |
| 681 | + cumulative_group_count = 0 |
| 682 | + cumulative_command_count = 0 |
| 683 | + |
| 684 | + for result in results: |
| 685 | + if result.error: |
| 686 | + self._handle_module_load_error(result) |
| 687 | + else: |
| 688 | + self._process_successful_load(result) |
| 689 | + count += 1 |
| 690 | + cumulative_group_count += len(result.group_table) |
| 691 | + cumulative_command_count += len(result.command_table) |
| 692 | + |
| 693 | + return count, cumulative_group_count, cumulative_command_count |
| 694 | + |
590 | 695 |
|
591 | 696 | class CommandIndex: |
592 | 697 |
|
|
0 commit comments