diff --git a/src/aks-agent/HISTORY.rst b/src/aks-agent/HISTORY.rst index 57f3ff8a2da..2367a9c214a 100644 --- a/src/aks-agent/HISTORY.rst +++ b/src/aks-agent/HISTORY.rst @@ -12,6 +12,11 @@ To release a new version, please select a new version number (usually plus 1 to Pending +++++++ +1.0.0b6 ++++++++ +* Introduce the new `az aks agent-init` command for better cli interaction. +* Separate llm configuration from main agent command for improved clarity and extensibility. + 1.0.0b5 +++++++ * Bump holmesgpt to 0.15.0 - Enhanced AI debugging experience and bug fixes diff --git a/src/aks-agent/README.rst b/src/aks-agent/README.rst index 94df60cabb3..2d33953250a 100644 --- a/src/aks-agent/README.rst +++ b/src/aks-agent/README.rst @@ -4,32 +4,41 @@ Azure CLI AKS Agent Extension Introduction ============ -The AKS Agent extension provides the "az aks agent" command, an AI-powered assistant that -helps analyze and troubleshoot Azure Kubernetes Service (AKS) clusters using Large Language -Models (LLMs). The agent combines cluster context, configurable toolsets, and LLMs to answer -natural-language questions about your cluster (for example, "Why are my pods not starting?") -and can investigate issues in both interactive and non-interactive (batch) modes. + +The AKS Agent extension provides the "az aks agent" command, an AI-powered assistant that helps analyze and troubleshoot Azure Kubernetes Service (AKS) clusters using Large Language Models (LLMs). The agent combines cluster context, configurable toolsets, and LLMs to answer natural-language questions about your cluster (for example, "Why are my pods not starting?") and can investigate issues in both interactive and non-interactive (batch) modes. + +New in this version: **az aks agent-init** command for easy LLM model configuration! + +You can now use `az aks agent-init` to interactively add and configure LLM models before asking questions. This command guides you through the setup process, allowing you to add multiple models as needed. When asking questions with `az aks agent`, you can: + +- Use `--config-file` to specify your own model configuration file +- Use `--model` to select a previously configured model +- If neither is provided, the last configured LLM will be used by default + +This makes it much easier to manage and switch between multiple models for your AKS troubleshooting workflows. Key capabilities ---------------- + - Interactive and non-interactive modes (use --no-interactive for batch runs). -- Support for multiple LLM providers (Azure OpenAI, OpenAI, etc.) via environment variables. -- Configurable via a JSON/YAML config file provided with --config-file. +- Support for multiple LLM providers (Azure OpenAI, OpenAI, etc.) via interactive configuration. +- **Easy model setup with `az aks agent-init`**: interactively add and configure LLM models, run multiple times to add more models. +- Configurable via a JSON/YAML config file provided with --config-file, or select a model with --model. +- If no config or model is specified, the last configured LLM is used automatically. - Control echo and tool output visibility with --no-echo-request and --show-tool-output. - Refresh the available toolsets with --refresh-toolsets. - Stay in traditional toolset mode by default, or opt in to aks-mcp integration with ``--aks-mcp`` when you need the enhanced capabilities. Prerequisites ------------- - -Before using the agent, make sure provider-specific environment variables are set. For -example, Azure OpenAI typically requires AZURE_API_BASE, AZURE_API_VERSION, and AZURE_API_KEY, -while OpenAI requires OPENAI_API_KEY. For more details about supported providers and required +No need to manually set environment variables! All model and credential information can be configured interactively using `az aks agent-init`. +For more details about supported model providers and required variables, see: https://docs.litellm.ai/docs/providers + Quick start and examples -======================== +========================= Install the extension --------------------- @@ -38,25 +47,58 @@ Install the extension az extension add --name aks-agent -Run the agent (Azure OpenAI example) +Configure LLM models interactively +---------------------------------- + +.. code-block:: bash + + az aks agent-init + +This command will guide you through adding a new LLM model. You can run it multiple times to add more models or update existing models. All configured models are saved locally and can be selected when asking questions. + +Run the agent (Azure OpenAI example) : ----------------------------------- +**1. Use the last configured model (no extra parameters needed):** + .. code-block:: bash - export AZURE_API_BASE="https://my-azureopenai-service.openai.azure.com/" - export AZURE_API_VERSION="2025-01-01-preview" - export AZURE_API_KEY="sk-xxx" + az aks agent "Why are my pods not starting?" --name MyManagedCluster --resource-group MyResourceGroup + +**2. Specify a particular model you have configured:** + +.. code-block:: bash az aks agent "Why are my pods not starting?" --name MyManagedCluster --resource-group MyResourceGroup --model azure/my-gpt4.1-deployment +**3. Use a custom config file:** + +.. code-block:: bash + + az aks agent "Why are my pods not starting?" --config-file /path/to/your/model_config.yaml + + Run the agent (OpenAI example) ------------------------------ +**1. Use the last configured model (no extra parameters needed):** + .. code-block:: bash - export OPENAI_API_KEY="sk-xxx" + az aks agent "Why are my pods not starting?" --name MyManagedCluster --resource-group MyResourceGroup + +**2. Specify a particular model you have configured:** + +.. code-block:: bash + az aks agent "Why are my pods not starting?" --name MyManagedCluster --resource-group MyResourceGroup --model gpt-4o +**3. Use a custom config file:** + +.. code-block:: bash + + az aks agent "Why are my pods not starting?" --config-file /path/to/your/model_config.yaml + Run in non-interactive batch mode --------------------------------- diff --git a/src/aks-agent/azext_aks_agent/_help.py b/src/aks-agent/azext_aks_agent/_help.py index 429a351a09c..becc69c98f4 100644 --- a/src/aks-agent/azext_aks_agent/_help.py +++ b/src/aks-agent/azext_aks_agent/_help.py @@ -16,7 +16,7 @@ short-summary: Run AI assistant to analyze and troubleshoot Kubernetes clusters. long-summary: |- This command allows you to ask questions about your Azure Kubernetes cluster and get answers using AI models. - Environment variables must be set to use the AI model, please refer to https://docs.litellm.ai/docs/providers to learn more about supported AI providers and models and required environment variables. + No need to manually set environment variables! All model and credential information can be configured interactively using `az aks agent-init` or via a config file. parameters: - name: --name -n type: string @@ -36,7 +36,7 @@ Note: For Azure OpenAI, it is recommended to set the deployment name as the model name until https://github.com/BerriAI/litellm/issues/13950 is resolved. - name: --api-key type: string - short-summary: API key to use for the LLM (if not given, uses environment variables AZURE_API_KEY, OPENAI_API_KEY). + short-summary: API key to use for the LLM (if not given, uses environment variables AZURE_API_KEY, OPENAI_API_KEY). (Deprecated) - name: --config-file type: string short-summary: Path to configuration file. @@ -63,23 +63,25 @@ short-summary: Enable AKS MCP integration for enhanced capabilities. Traditional mode is the default. examples: + - name: Ask about pod issues in the cluster with last configured model + text: |- + az aks agent "Why are my pods not starting?" --name MyManagedCluster --resource-group MyResourceGroup - name: Ask about pod issues in the cluster with Azure OpenAI text: |- - export AZURE_API_BASE="https://my-azureopenai-service.openai.azure.com/" - export AZURE_API_VERSION="2025-01-01-preview" - export AZURE_API_KEY="sk-xxx" az aks agent "Why are my pods not starting?" --name MyManagedCluster --resource-group MyResourceGroup --model azure/gpt-4.1 - name: Ask about pod issues in the cluster with OpenAI text: |- - export OPENAI_API_KEY="sk-xxx" az aks agent "Why are my pods not starting?" --name MyManagedCluster --resource-group MyResourceGroup --model gpt-4o - name: Run agent with config file text: | az aks agent "Check kubernetes pod resource usage" --config-file /path/to/custom.yaml --name MyManagedCluster --resource-group MyResourceGroup Here is an example of config file: ```json - model: "azure/gpt-4.1" - api_key: "..." + llms: + - provider: "azure" + MODEL_NAME: "gpt-4.1" + AZURE_API_BASE: "https://" + AZURE_API_KEY: "" # define a list of mcp servers, mcp server can be defined mcp_servers: aks_mcp: @@ -131,3 +133,16 @@ - name: Refresh toolsets to get the latest available tools text: az aks agent "What is the status of my cluster?" --refresh-toolsets --model azure/my-gpt4.1-deployment """ + +helps[ + "aks agent-init" +] = """ + type: command + short-summary: Initialize and validate LLM provider/model configuration for AKS agent. + long-summary: |- + This command interactively guides you to select an LLM provider and model, validates the connection, and saves the configuration for later use. + You can run this command multiple times to add or update different model configurations. + examples: + - name: Initialize configuration for Azure OpenAI, OpenAI or other llms + text: az aks agent-init +""" diff --git a/src/aks-agent/azext_aks_agent/agent/agent.py b/src/aks-agent/azext_aks_agent/agent/agent.py index 50e079de4c7..95afc37fc63 100644 --- a/src/aks-agent/azext_aks_agent/agent/agent.py +++ b/src/aks-agent/azext_aks_agent/agent/agent.py @@ -371,6 +371,7 @@ async def _setup_mcp_mode(mcp_manager, config_file: str, model: str, api_key: st # Generate enhanced MCP config mcp_config_dict = ConfigurationGenerator.generate_mcp_config(base_config_dict, server_url) + mcp_config_dict.pop("llms", None) # Remove existing llms to avoid conflicts # Create temporary config file with MCP settings with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as temp_file: @@ -739,6 +740,7 @@ def _setup_traditional_mode_sync(config_file: str, model: str, api_key: str, # Generate traditional config traditional_config_dict = ConfigurationGenerator.generate_traditional_config(base_config_dict) + traditional_config_dict.pop("llms", None) # Remove existing llms to avoid conflicts # Create temporary config and load with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as temp_file: diff --git a/src/aks-agent/azext_aks_agent/agent/llm_config_manager.py b/src/aks-agent/azext_aks_agent/agent/llm_config_manager.py new file mode 100644 index 00000000000..73ae82c4d94 --- /dev/null +++ b/src/aks-agent/azext_aks_agent/agent/llm_config_manager.py @@ -0,0 +1,91 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + + +import os +from typing import List, Dict, Optional +import yaml + +from azure.cli.core.api import get_config_dir +from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME + + +class LLMConfigManager: + """Manages loading and saving LLM configuration from/to a YAML file.""" + + def __init__(self, config_path=None): + if config_path is None: + config_path = os.path.join( + get_config_dir(), CONST_AGENT_CONFIG_FILE_NAME) + self.config_path = os.path.expanduser(config_path) + + def save(self, provider_name: str, params: dict): + configs = self.load() + if not isinstance(configs, Dict): + configs = {} + + models = configs.get("llms", []) + model_name = params.get("MODEL_NAME") + if not model_name: + raise ValueError("MODEL_NAME is required to save configuration.") + + # Check if model already exists, update it and move it to the last; + # otherwise, append new + models = [ + cfg for cfg in models if not ( + cfg.get("provider") == provider_name and cfg.get("MODEL_NAME") == model_name)] + models.append({"provider": provider_name, **params}) + + configs["llms"] = models + + with open(self.config_path, "w") as f: + yaml.safe_dump(configs, f, sort_keys=False) + + def load(self): + """Load configurations from the YAML file.""" + if not os.path.exists(self.config_path): + return {} + with open(self.config_path, "r") as f: + configs = yaml.safe_load(f) + return configs if isinstance(configs, Dict) else {} + + def get_list(self) -> List[Dict]: + """Get the list of all model configurations""" + return self.load()["llms"] if self.load( + ) and "llms" in self.load() else [] + + def get_latest(self) -> Optional[Dict]: + """Get the last model configuration""" + model_configs = self.get_list() + if model_configs: + return model_configs[-1] + raise ValueError( + "No configurations found. Please run `az aks agent-init`") + + def get_specific( + self, + provider_name: str, + model_name: str) -> Optional[Dict]: + """ + Get specific model configuration by provider and model name during Q&A with --model provider/model + """ + model_configs = self.get_list() + for cfg in model_configs: + if cfg.get("provider") == provider_name and cfg.get( + "MODEL_NAME") == model_name: + return cfg + raise ValueError( + f"No configuration found for provider '{provider_name}' with model '{model_name}'. " + f"Please run `az aks agent-init`") + + def is_config_complete(self, config, provider_schema): + """ + Check if the given config has all required keys and valid values as per the provider schema. + """ + for key, meta in provider_schema.items(): + if meta.get("validator") and not meta["validator"]( + config.get(key)): + return False + return True diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/__init__.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/__init__.py new file mode 100644 index 00000000000..0efd01fbb48 --- /dev/null +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/__init__.py @@ -0,0 +1,85 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from typing import List, Tuple +from rich.console import Console +from .base import LLMProvider +from .azure_provider import AzureProvider +from .openai_provider import OpenAIProvider +from .anthropic_provider import AnthropicProvider +from .gemini_provider import GeminiProvider +from .openai_compatible_provider import OpenAICompatibleProvider + + +console = Console() + +_PROVIDER_CLASSES: List[LLMProvider] = [ + AzureProvider, + OpenAIProvider, + AnthropicProvider, + GeminiProvider, + OpenAICompatibleProvider, + # Add new providers here +] + +PROVIDER_REGISTRY = {} +for cls in _PROVIDER_CLASSES: + key = cls.name.lower() + if key not in PROVIDER_REGISTRY: + PROVIDER_REGISTRY[key] = cls + + +def _available_providers() -> List[str]: + """Return a list of registered provider names (lowercase): ["azure", "openai", ...]""" + return list(PROVIDER_REGISTRY.keys()) + + +def _provider_choices_numbered() -> List[Tuple[int, str]]: + """Return numbered choices: [(1, "azure"), (2, "openai"), ...].""" + return [(i + 1, name) for i, name in enumerate(_available_providers())] + + +def _get_provider_by_index(idx: int) -> LLMProvider: + """ + Return provider instance by numeric index (1-based). + Raises ValueError if index is out of range. + """ + from holmes.utils.colors import HELP_COLOR + if 1 <= idx <= len(_PROVIDER_CLASSES): + console.print("You selected provider:", _PROVIDER_CLASSES[idx - 1].name, style=f"bold {HELP_COLOR}") + return _PROVIDER_CLASSES[idx - 1]() + raise ValueError(f"Invalid provider index: {idx}") + + +def prompt_provider_choice() -> LLMProvider: + """ + Show a numbered menu and return the chosen provider instance. + Keeps prompting until a valid selection is made. + """ + from holmes.utils.colors import HELP_COLOR, ERROR_COLOR + from holmes.interactive import SlashCommands + choices = _provider_choices_numbered() + if not choices: + raise ValueError("No providers are registered.") + while True: + for idx, name in choices: + console.print(f" {idx}. {name}", style=f"bold {HELP_COLOR}") + sel_idx = console.input( + f"[bold {HELP_COLOR}]Enter the number of your LLM provider: [/bold {HELP_COLOR}]").strip().lower() + + if sel_idx == "/exit": + raise SystemExit(0) + try: + return _get_provider_by_index(int(sel_idx)) + except ValueError as e: + console.print( + f"{e}. Please enter a valid number, or type '{SlashCommands.EXIT.command}' to exit.", + style=f"{ERROR_COLOR}") + + +__all__ = [ + "PROVIDER_REGISTRY", + "prompt_provider_choice", +] diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/anthropic_provider.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/anthropic_provider.py new file mode 100644 index 00000000000..6a911c2e4db --- /dev/null +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/anthropic_provider.py @@ -0,0 +1,60 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + + +import requests +from .base import LLMProvider, non_empty + + +class AnthropicProvider(LLMProvider): + name = "anthropic" + + @property + def parameter_schema(self): + return { + "ANTHROPIC_API_KEY": { + "secret": True, + "default": None, + "hint": None, + "validator": non_empty + }, + "MODEL_NAME": { + "secret": False, + "default": "claude-3", + "hint": None, + "validator": non_empty + }, + } + + def validate_connection(self, params: dict): + api_key = params.get("ANTHROPIC_API_KEY") + model_name = params.get("MODEL_NAME") + + if not all([api_key, model_name]): + return False, "Missing required Anthropic parameters.", "retry_input" + + url = "https://api.anthropic.com/v1/messages" + headers = { + "x-api-key": api_key, + "anthropic-version": "2023-06-01", + "Content-Type": "application/json" + } + payload = { + "model": model_name, + "max_tokens": 16, + "messages": [{"role": "user", "content": "ping"}] + } + + try: + resp = requests.post(url, headers=headers, + json=payload, timeout=10) + resp.raise_for_status() + return True, "Connection successful.", "save" + except requests.exceptions.HTTPError as e: + if 400 <= resp.status_code < 500: + return False, f"Client error: {e} - {resp.text}", "retry_input" + return False, f"Server error: {e} - {resp.text}", "connection_error" + except requests.exceptions.RequestException as e: + return False, f"Request error: {e}", "connection_error" diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/azure_provider.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/azure_provider.py new file mode 100644 index 00000000000..ce7e0510635 --- /dev/null +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/azure_provider.py @@ -0,0 +1,72 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + + +import requests +from typing import Tuple +from urllib.parse import urljoin, urlencode +from .base import LLMProvider, is_valid_url, non_empty + + +class AzureProvider(LLMProvider): + name = "azure" + + @property + def parameter_schema(self): + return { + "MODEL_NAME": { + "secret": False, + "default": None, + "hint": "should be consistent with your deployed name, e.g., gpt-4.1", + "validator": non_empty + }, + "AZURE_API_KEY": { + "secret": True, + "default": None, + "hint": None, + "validator": non_empty + }, + "AZURE_API_BASE": { + "secret": False, + "default": None, + "hint": "https://{your-custom-endpoint}.openai.azure.com/", + "validator": is_valid_url + }, + "AZURE_API_VERSION": { + "secret": False, + "default": "2025-04-01-preview", + "hint": None, + "validator": non_empty + } + } + + def validate_connection(self, params: dict) -> Tuple[bool, str, str]: + api_key = params.get("AZURE_API_KEY") + api_base = params.get("AZURE_API_BASE") + api_version = params.get("AZURE_API_VERSION") + model_name = params.get("MODEL_NAME") + + if not all([api_key, api_base, api_version, model_name]): + return False, "Missing required Azure parameters.", "retry_input" + + # REST API reference: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/api-version-lifecycle?tabs=rest + url = urljoin(api_base, "openai/responses") + query = {"api-version": api_version} + full_url = f"{url}?{urlencode(query)}" + headers = {"api-key": api_key, "Content-Type": "application/json"} + payload = {"model": model_name, + "input": "ping", "max_output_tokens": 16} + + try: + resp = requests.post(full_url, headers=headers, + json=payload, timeout=10) + resp.raise_for_status() + return True, "Connection successful.", "save" + except requests.exceptions.HTTPError as e: + if 400 <= resp.status_code < 500: + return False, f"Client error: {e} - {resp.text}", "retry_input" + return False, f"Server error: {e} - {resp.text}", "connection_error" + except requests.exceptions.RequestException as e: + return False, f"Request error: {e}", "connection_error" diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/base.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/base.py new file mode 100644 index 00000000000..1a59fd1824d --- /dev/null +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/base.py @@ -0,0 +1,116 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + + +from abc import ABC, abstractmethod +from typing import Dict, Callable, Tuple, Any +from rich.console import Console +from rich.prompt import Prompt +from urllib.parse import urlparse + +console = Console() +HINT_COLOR = "bright_black" +DEFAULT_COLOR = "bright_black" + + +def non_empty(v: str) -> bool: + return bool(v and v.strip()) + + +def is_valid_url(v: str) -> bool: + try: + parsed = urlparse(v) + if not parsed.scheme or not parsed.netloc: + return False + return True + except ValueError: + return False + + +class LLMProvider(ABC): + name = "base" + + @property + @abstractmethod + def parameter_schema(self) -> Dict[str, Dict[str, Any]]: + """ + provider may return a schema mapping param -> metadata: + { + "PARAM_NAME": { + "prompt": "Prompt to show user", + "secret": True/False, + "default": "default value or None", + "hint": "Additional hint to show user", + "validator": Callable[[str], bool] # function to validate input + } + } + """ + raise NotImplementedError() + + def prompt_params(self): + """Prompt user for parameters using parameter_schema when available.""" + from holmes.utils.colors import HELP_COLOR, ERROR_COLOR + from holmes.interactive import SlashCommands + + schema = self.parameter_schema + params = {} + for param, meta in schema.items(): + prompt = meta.get("prompt", f"[bold {HELP_COLOR}]Enter value for {param}: [/]") + default = meta.get("default") + hint = meta.get("hint") + secret = meta.get("secret", False) + validator: Callable[[str], bool] = meta.get( + "validator", lambda x: True) + + if default: + prompt += f" [italic {DEFAULT_COLOR}](Default: {default})[/] " + if hint: + prompt += f" [italic {HINT_COLOR}](Hint: {hint})[/] " + + while True: + if secret: + value = Prompt.ask( + f"[bold {HELP_COLOR}]Enter your API key[/]", + password=True + ) + else: + value = console.input(prompt) + + if not value and default is not None: + value = default + + value = value.strip() + if value == "/exit": + raise SystemExit(0) + if validator(value): + params[param] = value + break + console.print( + f"Invalid value for {param}. Please try again, or type '{SlashCommands.EXIT.command}' to exit.", + style=f"{ERROR_COLOR}") + + return params + + def validate_params(self, params: dict): + """Validate parameters from provided config file against schema.""" + schema = self.parameter_schema + for param, meta in schema.items(): + if param not in params: + raise ValueError(f"Missing required parameter: {param}") + validator: Callable[[str], bool] = meta.get( + "validator", lambda x: True) + if not validator(params[param]): + raise ValueError(f"Invalid value for parameter: {param}") + return True + + # pylint: disable=unused-argument + @abstractmethod + def validate_connection(self, params: dict) -> Tuple[bool, str, str]: + """ + Validate connection to the model endpoint using provided parameters. + Returns a tuple of (is_valid: bool, message: str, action: str) + where action can be "retry_input", "connection_error", or "save". + """ + raise NotImplementedError() diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/gemini_provider.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/gemini_provider.py new file mode 100644 index 00000000000..14b17eb2c19 --- /dev/null +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/gemini_provider.py @@ -0,0 +1,55 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + + +import requests +from .base import LLMProvider, non_empty + + +class GeminiProvider(LLMProvider): + name = "gemini" + + @property + def parameter_schema(self): + return { + "GEMINI_API_KEY": { + "secret": True, + "default": None, + "hint": None, + "validator": non_empty + }, + "MODEL_NAME": { + "secret": False, + "default": None, + "hint": "gemini-2.5", + "validator": non_empty + }, + } + + def validate_connection(self, params: dict): + api_key = params.get("GEMINI_API_KEY") + model_name = params.get("MODEL_NAME") + + if not all([api_key, model_name]): + return False, "Missing required Gemini parameters.", "retry_input" + + url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" + headers = {"Content-Type": "application/json", + "x-goog-api-key": api_key} + payload = { + "contents": [{"role": "user", "parts": [{"text": "ping"}]}] + } + + try: + resp = requests.post(url, headers=headers, + json=payload, timeout=10) + resp.raise_for_status() + return True, "Connection successful.", "save" + except requests.exceptions.HTTPError as e: + if 400 <= resp.status_code < 500: + return False, f"Client error: {e} - {resp.text}", "retry_input" + return False, f"Server error: {e} - {resp.text}", "connection_error" + except requests.exceptions.RequestException as e: + return False, f"Request error: {e}", "connection_error" diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_compatible_provider.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_compatible_provider.py new file mode 100644 index 00000000000..952431ab2e8 --- /dev/null +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_compatible_provider.py @@ -0,0 +1,64 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import requests +from urllib.parse import urljoin +from .base import LLMProvider, non_empty, is_valid_url + + +class OpenAICompatibleProvider(LLMProvider): + name = "openai_compatible" + + @property + def parameter_schema(self): + return { + "MODEL_NAME": { + "secret": False, + "default": None, + "hint": None, + "validator": non_empty + }, + "API_KEY": { + "secret": True, + "default": "ollama", + "hint": None, + "validator": non_empty + }, + "API_BASE": { + "secret": False, + "default": "https://api.openai.com/v1", + "hint": None, + "validator": is_valid_url + }, + } + + def validate_connection(self, params: dict): + api_key = params.get("API_KEY") + api_base = params.get("API_BASE") + model_name = params.get("MODEL_NAME") + + if not all([api_key, api_base, model_name]): + return False, "Missing required parameters.", "retry_input" + + url = urljoin(api_base, "chat/completions") + headers = {"Authorization": f"Bearer {api_key}", + "Content-Type": "application/json"} + payload = { + "model": model_name, + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 16 + } + + try: + resp = requests.post(url, headers=headers, + json=payload, timeout=10) + resp.raise_for_status() + return True, "Connection successful.", "save" + except requests.exceptions.HTTPError as e: + if 400 <= resp.status_code < 500: + return False, f"Client error: {e} - {resp.text}", "retry_input" + return False, f"Server error: {e} - {resp.text}", "connection_error" + except requests.exceptions.RequestException as e: + return False, f"Request error: {e}", "connection_error" diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_provider.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_provider.py new file mode 100644 index 00000000000..c495b0197b6 --- /dev/null +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_provider.py @@ -0,0 +1,57 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + + +import requests +from .base import LLMProvider, non_empty + + +class OpenAIProvider(LLMProvider): + name = "openai" + + @property + def parameter_schema(self): + return { + "MODEL_NAME": { + "secret": False, + "default": None, + "hint": "gpt-4.1", + "validator": non_empty + }, + "OPENAI_API_KEY": { + "secret": True, + "default": None, + "hint": None, + "validator": non_empty + }, + } + + def validate_connection(self, params: dict): + api_key = params.get("OPENAI_API_KEY") + model_name = params.get("MODEL_NAME") + + if not all([api_key, model_name]): + return False, "Missing required OpenAI parameters.", "retry_input" + + url = "https://api.openai.com/v1/chat/completions" + headers = {"Authorization": f"Bearer {api_key}", + "Content-Type": "application/json"} + payload = { + "model": model_name, + "messages": [{"role": "user", "content": "ping"}], + "max_completion_tokens": 16 + } + + try: + resp = requests.post(url, headers=headers, + json=payload, timeout=10) + resp.raise_for_status() + return True, "Connection successful", "save" + except requests.exceptions.HTTPError as e: + if 400 <= resp.status_code < 500: + return False, f"Client error: {e} - {resp.text}", "retry_input" + return False, f"Server error: {e} - {resp.text}", "connection_error" + except requests.exceptions.RequestException as e: + return False, f"Request error: {e}", "connection_error" diff --git a/src/aks-agent/azext_aks_agent/commands.py b/src/aks-agent/azext_aks_agent/commands.py index 726dbb56590..713771ddfb1 100644 --- a/src/aks-agent/azext_aks_agent/commands.py +++ b/src/aks-agent/azext_aks_agent/commands.py @@ -14,3 +14,4 @@ def load_command_table(self, _): "aks", ) as g: g.custom_command("agent", "aks_agent") + g.custom_command("agent-init", "aks_agent_init") diff --git a/src/aks-agent/azext_aks_agent/custom.py b/src/aks-agent/azext_aks_agent/custom.py index 32da1118334..01b28f65b61 100644 --- a/src/aks-agent/azext_aks_agent/custom.py +++ b/src/aks-agent/azext_aks_agent/custom.py @@ -5,7 +5,15 @@ # pylint: disable=too-many-lines, disable=broad-except import os +import sys +from typing import Dict, Optional +from azure.cli.core.api import get_config_dir +from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME from azext_aks_agent.agent.agent import aks_agent as aks_agent_internal +from azext_aks_agent.agent.llm_providers import prompt_provider_choice, PROVIDER_REGISTRY +from azext_aks_agent.agent.llm_config_manager import LLMConfigManager + +from azext_aks_agent.agent.agent import init_log from knack.log import get_logger @@ -14,6 +22,47 @@ # pylint: disable=unused-argument +def aks_agent_init(cmd): + """Initialize AKS agent llm configuration.""" + + init_log() + + from rich.console import Console + from holmes.utils.colors import HELP_COLOR, ERROR_COLOR + from holmes.interactive import SlashCommands + + console = Console() + console.print( + f"Welcome to AKS Agent LLM configuration setup. Type '{SlashCommands.EXIT.command}' to exit.", + style=f"bold {HELP_COLOR}") + + provider = prompt_provider_choice() + params = provider.prompt_params() + + llm_config_manager = LLMConfigManager() + # If the connection to the model endpoint is valid, save the configuration + is_valid, message, action = provider.validate_connection(params) + + if is_valid and action == "save": + logger.info("%s", message) + llm_config_manager.save(provider.name, params) + console.print("LLM configuration setup successfully.", style=f"bold {HELP_COLOR}") + + elif not is_valid and action == "retry_input": + logger.warning("%s", message) + console.print( + "Please re-run [bold]`az aks agent-init`[/bold] to correct the input parameters.", style=f"{ERROR_COLOR}") + sys.exit(1) + + else: + logger.error("%s", message) + console.print( + "Please check your deployed model and network connectivity.", style=f"bold {ERROR_COLOR}") + sys.exit(1) + + +# pylint: disable=unused-argument +# pylint: disable=too-many-locals def aks_agent( cmd, prompt, @@ -34,6 +83,75 @@ def aks_agent( if status: return aks_agent_status(cmd) + llm_config_manager = LLMConfigManager() + llm_config = None + default_llm_config_path = os.path.join( + get_config_dir(), CONST_AGENT_CONFIG_FILE_NAME) + + if config_file == default_llm_config_path: + if not model: + logger.info("Using default configuration file: %s", config_file) + llm_config: Optional[Dict] = llm_config_manager.get_latest() + if not llm_config: + raise ValueError( + "No llm configurations found. " + "Please run `az aks agent init` " + "or provide a config file using --config-file.") + + else: + logger.info("Using specified model: %s", model) + # parsing model into provider/model + if "/" in model: + provider_name, model_name = model.split("/", 1) + else: + provider_name = "openai" + model_name = model + llm_config = llm_config_manager.get_specific( + provider_name, model_name) + + else: + if config_file: + logger.info("Using user configuration file: %s", config_file) + import yaml + try: + with open(config_file, "r") as f: + llm_config = yaml.safe_load(f)["llms"][0] + if not isinstance(llm_config, Dict): + raise ValueError( + "Configuration file format is invalid. It should be a YAML mapping.") + except Exception as e: + raise ValueError(f"Failed to load configuration file: {e}") + + else: + raise ValueError( + "No configuration found. " + "Please run `az aks agent-init` or provide a config file using --config-file, " + "or specify a model using --model.") + + # Check if the configuration is complete + provider_name = llm_config.get("provider") + provider_instance = PROVIDER_REGISTRY.get(provider_name)() + parameter_schema = provider_instance.parameter_schema + if _check_provider( + provider_name, + parameter_schema, + llm_config, + llm_config_manager): + # get model for holmesgpt/litellm: provider_name/model_name + model_name = llm_config.get("MODEL_NAME") + if provider_name == "openai": + model = model or model_name + elif provider_name == "openai_compatiable": + model = model or f"openai/{model_name}" + else: + model = model or f"{provider_name}/{model_name}" + # Set environment variables for the model provider + for k, v in llm_config.items(): + if k not in ["provider", "MODEL_NAME"]: + os.environ[k] = v + logger.info( + "Using provider: %s, model: %s, Env vars setup successfully.", provider_name, model_name) + aks_agent_internal( cmd, resource_group_name, @@ -51,6 +169,28 @@ def aks_agent( ) +def _check_provider( + provider_name: str, + parameter_schema: Dict, + llm_config: Dict, + llm_config_manager: LLMConfigManager +) -> bool: + # Check if provider name is not empty + if not provider_name: + raise ValueError("No provider name.") + # Check if provider is supported + if provider_name not in PROVIDER_REGISTRY: + supported = list(PROVIDER_REGISTRY.keys()) + raise ValueError( + f"Unsupported provider {provider_name} for LLM initialization." + f"Supported llm providers are {supported}. Please refer to doc.") + # check if provider config is complete + if not llm_config_manager.is_config_complete(llm_config, parameter_schema): + raise ValueError( + "Incomplete configuration in user config, please run `az aks agent-init` to initialize.") + return True + + def aks_agent_status(cmd): """ Show AKS agent configuration and status. @@ -61,7 +201,6 @@ def aks_agent_status(cmd): try: from azext_aks_agent.agent.binary_manager import AksMcpBinaryManager from azext_aks_agent.agent.mcp_manager import MCPManager - from azure.cli.core.api import get_config_dir from azext_aks_agent._consts import CONST_MCP_BINARY_DIR # Initialize status information @@ -163,7 +302,8 @@ def _display_agent_status(status_info): # MCP Binary status binary_info = status_info.get("mcp_binary", {}) - binary_status = "✅ Available" if binary_info.get("available") else "❌ Not available" + binary_status = "✅ Available" if binary_info.get( + "available") else "❌ Not available" binary_details = [] if binary_info.get("version"): @@ -177,7 +317,8 @@ def _display_agent_status(status_info): table.add_row("MCP Binary", binary_status, " | ".join(binary_details)) # Server status (only if binary is available) - if binary_info.get("available") and status_info.get("mode") in ["mcp_ready", "mcp"]: + if binary_info.get("available") and status_info.get( + "mode") in ["mcp_ready", "mcp"]: server_info = status_info.get("server", {}) server_status = "" server_details = [] @@ -224,21 +365,25 @@ def _get_recommendations(status_info): mode = status_info.get("mode", "unknown") if not binary_info.get("available"): - recommendations.append("Run 'az aks agent' to automatically download the MCP binary for enhanced capabilities") + recommendations.append( + "Run 'az aks agent' to automatically download the MCP binary for enhanced capabilities") elif not binary_info.get("version_valid", True): - recommendations.append("Update the MCP binary by running 'az aks agent --refresh-toolsets'") + recommendations.append( + "Update the MCP binary by running 'az aks agent --refresh-toolsets'") elif mode == "mcp_ready" and not server_info.get("running"): - recommendations.append("MCP binary is ready - run 'az aks agent' to start using enhanced capabilities") + recommendations.append( + "MCP binary is ready - run 'az aks agent' to start using enhanced capabilities") elif mode == "mcp_ready" and server_info.get("running") and not server_info.get("healthy"): - recommendations.append("MCP server is running but unhealthy - it will be automatically restarted on next use") + recommendations.append( + "MCP server is running but unhealthy - it will be automatically restarted on next use") elif mode in ["mcp_ready", "mcp"] and server_info.get("running") and server_info.get("healthy"): - recommendations.append("✅ AKS agent is ready with enhanced MCP capabilities") + recommendations.append( + "✅ AKS agent is ready with enhanced MCP capabilities") elif mode == "traditional": if binary_info.get("available"): recommendations.append( "Consider using MCP mode for enhanced capabilities by running 'az aks agent' " - "(run again with --aks-mcp to switch modes)" - ) + "(run again with --aks-mcp to switch modes)") else: recommendations.append("✅ AKS agent is ready in traditional mode") else: @@ -268,7 +413,8 @@ def _get_health_emoji(status_info): if mode == "traditional": return "✅" # Traditional mode is always healthy if working if mode in ["mcp_ready", "mcp"]: - if binary_info.get("available") and binary_info.get("version_valid", True): + if binary_info.get("available") and binary_info.get( + "version_valid", True): if server_info.get("running") and server_info.get("healthy"): return "✅" # Fully healthy if server_info.get("running"): diff --git a/src/aks-agent/azext_aks_agent/tests/latest/const.py b/src/aks-agent/azext_aks_agent/tests/latest/const.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_binary_manager.py b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_binary_manager.py index 038b5808f7a..a587e2f56d4 100644 --- a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_binary_manager.py +++ b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_binary_manager.py @@ -3,88 +3,89 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import unittest -from unittest import IsolatedAsyncioTestCase import os -import tempfile import platform import stat import subprocess -from unittest.mock import Mock, patch, AsyncMock +import tempfile +import unittest +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, Mock, patch +import pytest from azext_aks_agent.agent.binary_manager import AksMcpBinaryManager # Use IsolatedAsyncioTestCase for proper async test method support class TestAksMcpBinaryManager(IsolatedAsyncioTestCase): - + def setUp(self): """Set up test fixtures.""" self.test_install_dir = tempfile.mkdtemp() self.binary_manager = AksMcpBinaryManager(self.test_install_dir) - + def tearDown(self): """Clean up test fixtures.""" import shutil shutil.rmtree(self.test_install_dir, ignore_errors=True) - + def test_get_binary_path_linux(self): """Test binary path resolution on Linux.""" with patch('platform.system', return_value='Linux'): manager = AksMcpBinaryManager('/test/dir') expected_path = os.path.join('/test/dir', 'aks-mcp') self.assertEqual(manager.get_binary_path(), expected_path) - + def test_get_binary_path_windows(self): """Test binary path resolution on Windows.""" with patch('platform.system', return_value='Windows'): manager = AksMcpBinaryManager('/test/dir') expected_path = os.path.join('/test/dir', 'aks-mcp.exe') self.assertEqual(manager.get_binary_path(), expected_path) - + def test_get_binary_path_darwin(self): """Test binary path resolution on macOS.""" with patch('platform.system', return_value='Darwin'): manager = AksMcpBinaryManager('/test/dir') expected_path = os.path.join('/test/dir', 'aks-mcp') self.assertEqual(manager.get_binary_path(), expected_path) - + def test_is_binary_available_not_exists(self): """Test binary availability when file doesn't exist.""" self.assertFalse(self.binary_manager.is_binary_available()) - + def test_is_binary_available_exists_but_not_executable(self): """Test binary availability when file exists but is not executable.""" # Create a non-executable file with open(self.binary_manager.binary_path, 'w') as f: f.write('dummy content') os.chmod(self.binary_manager.binary_path, 0o644) # Read/write but not execute - + self.assertFalse(self.binary_manager.is_binary_available()) - + def test_is_binary_available_exists_and_executable(self): """Test binary availability when file exists and is executable.""" # Create an executable file with open(self.binary_manager.binary_path, 'w') as f: f.write('dummy content') os.chmod(self.binary_manager.binary_path, 0o755) # Read/write/execute - + self.assertTrue(self.binary_manager.is_binary_available()) - + @patch('os.access') def test_is_binary_available_os_error(self, mock_access): """Test binary availability when os.access raises OSError.""" # Create file first with open(self.binary_manager.binary_path, 'w') as f: f.write('dummy content') - + mock_access.side_effect = OSError("Permission denied") self.assertFalse(self.binary_manager.is_binary_available()) - + def test_get_binary_version_not_available(self): """Test version retrieval when binary is not available.""" self.assertIsNone(self.binary_manager.get_binary_version()) - + @patch('subprocess.run') def test_get_binary_version_success(self, mock_run): """Test successful version retrieval.""" @@ -92,13 +93,13 @@ def test_get_binary_version_success(self, mock_run): with open(self.binary_manager.binary_path, 'w') as f: f.write('#!/bin/bash\necho "aks-mcp version 0.1.0"') os.chmod(self.binary_manager.binary_path, 0o755) - + # Mock subprocess.run mock_result = Mock() mock_result.returncode = 0 mock_result.stdout = "aks-mcp version 0.1.0\n" mock_run.return_value = mock_result - + version = self.binary_manager.get_binary_version() self.assertEqual(version, "0.1.0") mock_run.assert_called_once_with( @@ -108,7 +109,7 @@ def test_get_binary_version_success(self, mock_run): timeout=10, check=False ) - + @patch('subprocess.run') def test_get_binary_version_different_format(self, mock_run): """Test version retrieval with different output format.""" @@ -116,16 +117,16 @@ def test_get_binary_version_different_format(self, mock_run): with open(self.binary_manager.binary_path, 'w') as f: f.write('dummy content') os.chmod(self.binary_manager.binary_path, 0o755) - + # Different format mock_result = Mock() mock_result.returncode = 0 mock_result.stdout = "version: 1.2.3\n" mock_run.return_value = mock_result - + version = self.binary_manager.get_binary_version() self.assertEqual(version, "1.2.3") - + @patch('subprocess.run') def test_get_binary_version_actual_format(self, mock_run): """Test version retrieval with actual aks-mcp version format.""" @@ -133,7 +134,7 @@ def test_get_binary_version_actual_format(self, mock_run): with open(self.binary_manager.binary_path, 'w') as f: f.write('dummy content') os.chmod(self.binary_manager.binary_path, 0o755) - + # Mock subprocess.run with actual aks-mcp output format mock_result = Mock() mock_result.returncode = 0 @@ -144,10 +145,10 @@ def test_get_binary_version_actual_format(self, mock_run): Platform: darwin/arm64 """ mock_run.return_value = mock_result - + version = self.binary_manager.get_binary_version() self.assertEqual(version, "0.0.6") - + @patch('subprocess.run') def test_get_binary_version_git_format_variations(self, mock_run): """Test version retrieval with different git-style version formats.""" @@ -155,24 +156,24 @@ def test_get_binary_version_git_format_variations(self, mock_run): with open(self.binary_manager.binary_path, 'w') as f: f.write('dummy content') os.chmod(self.binary_manager.binary_path, 0o755) - + test_cases = [ ("aks-mcp version v0.1.0", "0.1.0"), ("aks-mcp version v0.1.0-5-g123abc", "0.1.0"), ("aks-mcp version v1.2.3-10-gabc123+2025-01-01T12:00:00Z", "1.2.3"), ("version v0.0.7-dirty", "0.0.7"), ] - + for output, expected_version in test_cases: with self.subTest(output=output): mock_result = Mock() mock_result.returncode = 0 mock_result.stdout = output + "\n" mock_run.return_value = mock_result - + version = self.binary_manager.get_binary_version() self.assertEqual(version, expected_version) - + @patch('subprocess.run') def test_get_binary_version_subprocess_error(self, mock_run): """Test version retrieval when subprocess fails.""" @@ -180,11 +181,11 @@ def test_get_binary_version_subprocess_error(self, mock_run): with open(self.binary_manager.binary_path, 'w') as f: f.write('dummy content') os.chmod(self.binary_manager.binary_path, 0o755) - + mock_run.side_effect = subprocess.SubprocessError("Command failed") version = self.binary_manager.get_binary_version() self.assertIsNone(version) - + @patch('subprocess.run') def test_get_binary_version_timeout(self, mock_run): """Test version retrieval when subprocess times out.""" @@ -192,11 +193,11 @@ def test_get_binary_version_timeout(self, mock_run): with open(self.binary_manager.binary_path, 'w') as f: f.write('dummy content') os.chmod(self.binary_manager.binary_path, 0o755) - + mock_run.side_effect = subprocess.TimeoutExpired("cmd", 10) version = self.binary_manager.get_binary_version() self.assertIsNone(version) - + @patch('subprocess.run') def test_get_binary_version_non_zero_exit(self, mock_run): """Test version retrieval when command returns non-zero exit code.""" @@ -204,15 +205,15 @@ def test_get_binary_version_non_zero_exit(self, mock_run): with open(self.binary_manager.binary_path, 'w') as f: f.write('dummy content') os.chmod(self.binary_manager.binary_path, 0o755) - + mock_result = Mock() mock_result.returncode = 1 mock_result.stdout = "error" mock_run.return_value = mock_result - + version = self.binary_manager.get_binary_version() self.assertIsNone(version) - + @patch.object(AksMcpBinaryManager, 'get_binary_version') def test_validate_version_success(self, mock_get_version): """Test successful version validation.""" @@ -220,19 +221,19 @@ def test_validate_version_success(self, mock_get_version): self.assertTrue(self.binary_manager.validate_version("0.0.6")) self.assertTrue(self.binary_manager.validate_version("0.1.0")) self.assertFalse(self.binary_manager.validate_version("0.2.0")) - + @patch.object(AksMcpBinaryManager, 'get_binary_version') def test_validate_version_no_version(self, mock_get_version): """Test version validation when no version available.""" mock_get_version.return_value = None self.assertFalse(self.binary_manager.validate_version("0.0.6")) - + @patch.object(AksMcpBinaryManager, 'get_binary_version') def test_validate_version_invalid_format(self, mock_get_version): """Test version validation with invalid version format.""" mock_get_version.return_value = "invalid-version" self.assertFalse(self.binary_manager.validate_version("0.0.6")) - + def test_validate_version_complex_versions(self): """Test version validation with complex version numbers.""" with patch.object(self.binary_manager, 'get_binary_version') as mock_get_version: @@ -240,105 +241,106 @@ def test_validate_version_complex_versions(self): mock_get_version.return_value = "0.1.0.1" self.assertTrue(self.binary_manager.validate_version("0.1.0.0")) self.assertFalse(self.binary_manager.validate_version("0.2.0.0")) - + # Test equal versions mock_get_version.return_value = "1.2.3" self.assertTrue(self.binary_manager.validate_version("1.2.3")) - + def test_get_platform_info_linux_amd64(self): """Test platform info detection for Linux amd64.""" with patch('platform.system', return_value='Linux'), \ - patch('platform.machine', return_value='x86_64'): + patch('platform.machine', return_value='x86_64'): platform_name, arch_name = self.binary_manager._get_platform_info() self.assertEqual(platform_name, 'linux') self.assertEqual(arch_name, 'amd64') - + def test_get_platform_info_darwin_arm64(self): """Test platform info detection for macOS ARM64.""" with patch('platform.system', return_value='Darwin'), \ - patch('platform.machine', return_value='arm64'): + patch('platform.machine', return_value='arm64'): platform_name, arch_name = self.binary_manager._get_platform_info() self.assertEqual(platform_name, 'darwin') self.assertEqual(arch_name, 'arm64') - + def test_get_platform_info_windows_amd64(self): """Test platform info detection for Windows amd64.""" with patch('platform.system', return_value='Windows'), \ - patch('platform.machine', return_value='AMD64'): + patch('platform.machine', return_value='AMD64'): platform_name, arch_name = self.binary_manager._get_platform_info() self.assertEqual(platform_name, 'windows') self.assertEqual(arch_name, 'amd64') - + def test_get_platform_info_linux_aarch64(self): """Test platform info detection for Linux aarch64.""" with patch('platform.system', return_value='Linux'), \ - patch('platform.machine', return_value='aarch64'): + patch('platform.machine', return_value='aarch64'): platform_name, arch_name = self.binary_manager._get_platform_info() self.assertEqual(platform_name, 'linux') self.assertEqual(arch_name, 'arm64') - + def test_make_binary_executable_unix(self): """Test making binary executable on Unix-like systems.""" if platform.system() == 'Windows': self.skipTest("Skipping Unix test on Windows") - + # Create a test file test_file = os.path.join(self.test_install_dir, 'test-binary') with open(test_file, 'w') as f: f.write('dummy content') os.chmod(test_file, 0o644) # Read/write only - + # Make it executable success = self.binary_manager._make_binary_executable(test_file) self.assertTrue(success) - + # Check that it's now executable file_mode = os.stat(test_file).st_mode self.assertTrue(file_mode & stat.S_IEXEC) # Owner executable self.assertTrue(file_mode & stat.S_IXGRP) # Group executable self.assertTrue(file_mode & stat.S_IXOTH) # Others executable - + @patch('platform.system', return_value='Windows') def test_make_binary_executable_windows(self, mock_system): """Test making binary executable on Windows (should always succeed).""" test_file = os.path.join(self.test_install_dir, 'test-binary.exe') with open(test_file, 'w') as f: f.write('dummy content') - + success = self.binary_manager._make_binary_executable(test_file) self.assertTrue(success) - + def test_make_binary_executable_os_error(self): """Test making binary executable when OS operations fail.""" if platform.system() == 'Windows': self.skipTest("Skipping Unix test on Windows") - + # Try to make a non-existent file executable non_existent_file = os.path.join(self.test_install_dir, 'non-existent') success = self.binary_manager._make_binary_executable(non_existent_file) self.assertFalse(success) # New tests for GitHub Release API Integration - + def test_get_platform_binary_name_linux_amd64(self): """Test platform binary name generation for Linux AMD64.""" with patch.object(self.binary_manager, '_get_platform_info', return_value=('linux', 'amd64')): binary_name = self.binary_manager._get_platform_binary_name() self.assertEqual(binary_name, 'aks-mcp-linux-amd64') - + def test_get_platform_binary_name_windows_amd64(self): """Test platform binary name generation for Windows AMD64.""" with patch.object(self.binary_manager, '_get_platform_info', return_value=('windows', 'amd64')), \ - patch('platform.system', return_value='Windows'): + patch('platform.system', return_value='Windows'): binary_name = self.binary_manager._get_platform_binary_name() self.assertEqual(binary_name, 'aks-mcp-windows-amd64.exe') - + def test_get_platform_binary_name_darwin_arm64(self): """Test platform binary name generation for macOS ARM64.""" with patch.object(self.binary_manager, '_get_platform_info', return_value=('darwin', 'arm64')): binary_name = self.binary_manager._get_platform_binary_name() self.assertEqual(binary_name, 'aks-mcp-darwin-arm64') + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") @patch('aiohttp.ClientSession') async def test_get_latest_release_info_success(self, mock_session): """Test successful GitHub API release info retrieval.""" @@ -351,53 +353,55 @@ async def test_get_latest_release_info_success(self, mock_session): {"name": "aks-mcp-linux-amd64", "browser_download_url": "https://example.com/binary"} ] }) - + # Create mock session with proper async context manager support mock_session_ctx = AsyncMock() mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session_ctx) mock_session_ctx.__aexit__ = AsyncMock(return_value=None) - - # Create mock get response with proper async context manager support + + # Create mock get response with proper async context manager support mock_get_ctx = AsyncMock() mock_get_ctx.__aenter__ = AsyncMock(return_value=mock_response) mock_get_ctx.__aexit__ = AsyncMock(return_value=None) - + # Wire up the mocks mock_session.return_value = mock_session_ctx mock_session_ctx.get = Mock(return_value=mock_get_ctx) - + result = await self.binary_manager.get_latest_release_info() - + self.assertIsInstance(result, dict) self.assertEqual(result["tag_name"], "v0.1.0") self.assertIn("assets", result) - + + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") @patch('aiohttp.ClientSession') async def test_get_latest_release_info_http_error(self, mock_session): """Test GitHub API release info with HTTP error.""" # Mock HTTP error response mock_response = AsyncMock() mock_response.status = 404 - + # Create mock session with proper async context manager support mock_session_ctx = AsyncMock() mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session_ctx) mock_session_ctx.__aexit__ = AsyncMock(return_value=None) - - # Create mock get response with proper async context manager support + + # Create mock get response with proper async context manager support mock_get_ctx = AsyncMock() mock_get_ctx.__aenter__ = AsyncMock(return_value=mock_response) mock_get_ctx.__aexit__ = AsyncMock(return_value=None) - + # Wire up the mocks mock_session.return_value = mock_session_ctx mock_session_ctx.get = Mock(return_value=mock_get_ctx) - + with self.assertRaises(Exception) as context: await self.binary_manager.get_latest_release_info() - + self.assertIn("GitHub API request failed with status 404", str(context.exception)) - + + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") @patch('aiohttp.ClientSession') async def test_get_latest_release_info_network_error(self, mock_session): """Test GitHub API release info with network error.""" @@ -405,19 +409,20 @@ async def test_get_latest_release_info_network_error(self, mock_session): mock_session_ctx = AsyncMock() mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session_ctx) mock_session_ctx.__aexit__ = AsyncMock(return_value=None) - + # Mock get to raise aiohttp.ClientError import aiohttp mock_session_ctx.get = Mock(side_effect=aiohttp.ClientError("Network unreachable")) - + # Wire up the mocks mock_session.return_value = mock_session_ctx - + with self.assertRaises(Exception) as context: await self.binary_manager.get_latest_release_info() - + self.assertIn("Network error accessing GitHub API", str(context.exception)) - + + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") @patch('aiohttp.ClientSession') async def test_get_latest_release_info_json_error(self, mock_session): """Test GitHub API release info with JSON decode error.""" @@ -425,46 +430,46 @@ async def test_get_latest_release_info_json_error(self, mock_session): mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(side_effect=ValueError("Invalid JSON")) - + # Create mock session with proper async context manager support mock_session_ctx = AsyncMock() mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session_ctx) mock_session_ctx.__aexit__ = AsyncMock(return_value=None) - - # Create mock get response with proper async context manager support + + # Create mock get response with proper async context manager support mock_get_ctx = AsyncMock() mock_get_ctx.__aenter__ = AsyncMock(return_value=mock_response) mock_get_ctx.__aexit__ = AsyncMock(return_value=None) - + # Wire up the mocks mock_session.return_value = mock_session_ctx mock_session_ctx.get = Mock(return_value=mock_get_ctx) - + with self.assertRaises(Exception): await self.binary_manager.get_latest_release_info() - + @patch('urllib.request.urlopen') def test_verify_binary_integrity_subject_hash(self, mock_urlopen): """Test binary integrity verification with subject hash in attestation.""" - import json import hashlib - + import json + # Create a test binary file (any size) test_file = os.path.join(self.test_install_dir, 'aks-mcp-darwin-arm64') test_content = b'test content' with open(test_file, 'wb') as f: f.write(test_content) - + # Calculate the actual hash of our test content actual_hash = hashlib.sha256(test_content).hexdigest() - + # Mock release info with attestation file release_info = { "assets": [ {"name": "aks-mcp-darwin-arm64.intoto.jsonl", "browser_download_url": "https://example.com/attestation"} ] } - + # Mock attestation content with subject hash attestation_dict = { "subject": [ @@ -477,71 +482,74 @@ def test_verify_binary_integrity_subject_hash(self, mock_urlopen): ] } attestation_content = json.dumps(attestation_dict) - + # Mock HTTP response mock_response = Mock() mock_response.status = 200 mock_response.read.return_value = attestation_content.encode('utf-8') mock_urlopen.return_value.__enter__.return_value = mock_response - + result = self.binary_manager._verify_binary_integrity(test_file, release_info) self.assertTrue(result) - + def test_verify_binary_integrity_fallback_to_basic(self): """Test binary integrity verification fallback when no attestation found.""" # Create a test binary file (any size) test_file = os.path.join(self.test_install_dir, 'aks-mcp-linux-amd64') with open(test_file, 'wb') as f: f.write(b'test content') - + release_info = {"assets": []} result = self.binary_manager._verify_binary_integrity(test_file, release_info) self.assertTrue(result) + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") async def test_ensure_binary_already_available_and_valid(self): """Test ensure_binary when binary is already available and valid.""" with patch.object(self.binary_manager, 'is_binary_available', return_value=True), \ - patch.object(self.binary_manager, 'get_binary_version', return_value="1.0.0"), \ - patch.object(self.binary_manager, 'validate_version', return_value=True): - + patch.object(self.binary_manager, 'get_binary_version', return_value="1.0.0"), \ + patch.object(self.binary_manager, 'validate_version', return_value=True): + status = await self.binary_manager.ensure_binary() - + self.assertTrue(status.available) self.assertEqual(status.version, "1.0.0") self.assertTrue(status.version_valid) self.assertTrue(status.ready) self.assertIsNone(status.error_message) + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") async def test_ensure_binary_available_but_invalid_version(self): """Test ensure_binary when binary is available but has invalid version.""" mock_progress = Mock() - + with patch.object(self.binary_manager, 'is_binary_available', side_effect=[True, True]), \ - patch.object(self.binary_manager, 'get_binary_version', side_effect=["0.0.1", "1.0.0"]), \ - patch.object(self.binary_manager, 'validate_version', side_effect=[False, True]), \ - patch.object(self.binary_manager, '_create_installation_directory', return_value=True), \ - patch.object(self.binary_manager, 'download_binary', return_value=True): - + patch.object(self.binary_manager, 'get_binary_version', side_effect=["0.0.1", "1.0.0"]), \ + patch.object(self.binary_manager, 'validate_version', side_effect=[False, True]), \ + patch.object(self.binary_manager, '_create_installation_directory', return_value=True), \ + patch.object(self.binary_manager, 'download_binary', return_value=True): + status = await self.binary_manager.ensure_binary(progress_callback=mock_progress) - + self.assertTrue(status.available) self.assertEqual(status.version, "1.0.0") self.assertTrue(status.version_valid) self.assertTrue(status.ready) self.assertIsNone(status.error_message) + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") async def test_ensure_binary_not_available_download_success(self): """Test ensure_binary when binary is not available but download succeeds.""" mock_progress = Mock() - + with patch.object(self.binary_manager, 'is_binary_available', side_effect=[False, True]), \ - patch.object(self.binary_manager, 'get_binary_version', return_value="1.0.0"), \ - patch.object(self.binary_manager, 'validate_version', return_value=True), \ - patch.object(self.binary_manager, '_create_installation_directory', return_value=True), \ - patch.object(self.binary_manager, 'download_binary', return_value=True) as mock_download: - + patch.object(self.binary_manager, 'get_binary_version', return_value="1.0.0"), \ + patch.object(self.binary_manager, 'validate_version', return_value=True), \ + patch.object(self.binary_manager, '_create_installation_directory', return_value=True), \ + patch.object(self.binary_manager, 'download_binary', return_value=True) as mock_download: + status = await self.binary_manager.ensure_binary(progress_callback=mock_progress) - + self.assertTrue(status.available) self.assertEqual(status.version, "1.0.0") self.assertTrue(status.version_valid) @@ -549,49 +557,53 @@ async def test_ensure_binary_not_available_download_success(self): self.assertIsNone(status.error_message) mock_download.assert_called_once_with(progress_callback=mock_progress) + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") async def test_ensure_binary_directory_creation_fails(self): """Test ensure_binary when directory creation fails.""" with patch.object(self.binary_manager, 'is_binary_available', return_value=False), \ - patch.object(self.binary_manager, '_create_installation_directory', return_value=False): - + patch.object(self.binary_manager, '_create_installation_directory', return_value=False): + status = await self.binary_manager.ensure_binary() - + self.assertFalse(status.ready) self.assertIn("Failed to create installation directory", status.error_message) + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") async def test_ensure_binary_download_fails(self): """Test ensure_binary when download fails.""" with patch.object(self.binary_manager, 'is_binary_available', return_value=False), \ - patch.object(self.binary_manager, '_create_installation_directory', return_value=True), \ - patch.object(self.binary_manager, 'download_binary', return_value=False): - + patch.object(self.binary_manager, '_create_installation_directory', return_value=True), \ + patch.object(self.binary_manager, 'download_binary', return_value=False): + status = await self.binary_manager.ensure_binary() - + self.assertFalse(status.ready) self.assertEqual(status.error_message, "Binary download failed") + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") async def test_ensure_binary_download_success_but_validation_fails(self): """Test ensure_binary when download succeeds but validation fails.""" with patch.object(self.binary_manager, 'is_binary_available', side_effect=[False, True]), \ - patch.object(self.binary_manager, 'get_binary_version', return_value="0.0.1"), \ - patch.object(self.binary_manager, 'validate_version', return_value=False), \ - patch.object(self.binary_manager, '_create_installation_directory', return_value=True), \ - patch.object(self.binary_manager, 'download_binary', return_value=True): - + patch.object(self.binary_manager, 'get_binary_version', return_value="0.0.1"), \ + patch.object(self.binary_manager, 'validate_version', return_value=False), \ + patch.object(self.binary_manager, '_create_installation_directory', return_value=True), \ + patch.object(self.binary_manager, 'download_binary', return_value=True): + status = await self.binary_manager.ensure_binary() - + self.assertTrue(status.available) self.assertEqual(status.version, "0.0.1") self.assertFalse(status.version_valid) self.assertFalse(status.ready) self.assertEqual(status.error_message, "Downloaded binary failed validation") + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") async def test_ensure_binary_unexpected_exception(self): """Test ensure_binary handles unexpected exceptions gracefully.""" with patch.object(self.binary_manager, 'is_binary_available', side_effect=Exception("Unexpected error")): - + status = await self.binary_manager.ensure_binary() - + self.assertFalse(status.ready) self.assertIn("Unexpected error during binary management", status.error_message) self.assertIn("Unexpected error", status.error_message) @@ -599,4 +611,3 @@ async def test_ensure_binary_unexpected_exception(self): if __name__ == '__main__': unittest.main() - diff --git a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_init.py b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_init.py new file mode 100644 index 00000000000..360e8361fd9 --- /dev/null +++ b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_init.py @@ -0,0 +1,167 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import types +import unittest +from unittest.mock import patch, MagicMock +from azext_aks_agent.custom import aks_agent_init + + +mock_logging = MagicMock(name="init_logging") +mock_console_mod = types.SimpleNamespace(logging=types.SimpleNamespace(init_logging=mock_logging)) + +mock_holmes = types.SimpleNamespace( + interactive=types.SimpleNamespace( + SlashCommands=MagicMock() + ), + utils=types.SimpleNamespace( + colors=types.SimpleNamespace( + ERROR_COLOR=MagicMock(), + HELP_COLOR=MagicMock(), + ), + console=mock_console_mod, + ) +) + +mock_rich = types.SimpleNamespace( + console=types.SimpleNamespace( + Console=MagicMock() + ) +) + + +class TestAksAgentInit(unittest.TestCase): + + def setUp(self): + patcher = patch.dict( + 'sys.modules', + { + 'holmes': mock_holmes, + 'holmes.interactive': mock_holmes.interactive, + 'holmes.utils': mock_holmes.utils, + 'holmes.utils.colors': mock_holmes.utils.colors, + 'holmes.utils.console': mock_holmes.utils.console, + 'holmes.utils.console.logging': mock_holmes.utils.console.logging, + 'rich': mock_rich, + 'rich.console': mock_rich.console, + } + ) + self.addCleanup(patcher.stop) + patcher.start() + + @patch('holmes.interactive.SlashCommands') + @patch('holmes.utils.colors.ERROR_COLOR') + @patch('holmes.utils.colors.HELP_COLOR') + @patch('rich.console.Console') + @patch('azext_aks_agent.custom.prompt_provider_choice') + @patch('azext_aks_agent.custom.LLMConfigManager') + def test_init_successful_save( + self, + mock_config_manager_cls, + mock_prompt_provider_choice, + mock_console_cls, + mock_help_color, + mock_error_color, + mock_slash_commands + ): + mock_console = MagicMock() + mock_console_cls.return_value = mock_console + + mock_provider = MagicMock() + mock_provider.prompt_params.return_value = {'MODEL_NAME': 'test-model', 'param': 'value'} + mock_provider.validate_connection.return_value = (True, 'Valid', 'save') + mock_provider.name = 'openai' + mock_prompt_provider_choice.return_value = mock_provider + + mock_config_manager = MagicMock() + mock_config_manager_cls.return_value = mock_config_manager + + mock_help_color.__str__.return_value = "green" + mock_error_color.__str__.return_value = "red" + mock_slash_commands.EXIT.command = "exit" + + aks_agent_init(cmd=None) + mock_config_manager.save.assert_called_once_with('openai', {'MODEL_NAME': 'test-model', 'param': 'value'}) + mock_console.print.assert_any_call("LLM configuration setup successfully.", style=unittest.mock.ANY) + + @patch('holmes.interactive.SlashCommands') + @patch('holmes.utils.colors.ERROR_COLOR') + @patch('holmes.utils.colors.HELP_COLOR') + @patch('rich.console.Console') + @patch('azext_aks_agent.custom.prompt_provider_choice') + @patch('azext_aks_agent.custom.LLMConfigManager') + def test_init_retry_input( + self, + mock_config_manager_cls, + mock_prompt_provider_choice, + mock_console_cls, + mock_help_color, + mock_error_color, + mock_slash_commands + ): + mock_console = MagicMock() + mock_console_cls.return_value = mock_console + + mock_provider = MagicMock() + mock_provider.prompt_params.return_value = {'MODEL_NAME': 'test-model'} + mock_provider.validate_connection.return_value = (False, 'Invalid input', 'retry_input') + mock_provider.name = 'openai' + mock_prompt_provider_choice.return_value = mock_provider + + mock_config_manager_cls.return_value = MagicMock() + + mock_help_color.__str__.return_value = "green" + mock_error_color.__str__.return_value = "red" + mock_slash_commands.EXIT.command = "exit" + + with self.assertRaises(SystemExit) as cm: + aks_agent_init(cmd=None) + self.assertEqual(cm.exception.code, 1) + mock_console.print.assert_any_call( + "Please re-run [bold]`az aks agent-init`[/bold] to correct the input parameters.", + style=unittest.mock.ANY, + ) + + @patch('holmes.interactive.SlashCommands') + @patch('holmes.utils.colors.ERROR_COLOR') + @patch('holmes.utils.colors.HELP_COLOR') + @patch('rich.console.Console') + @patch('azext_aks_agent.custom.prompt_provider_choice') + @patch('azext_aks_agent.custom.LLMConfigManager') + def test_init_connection_error( + self, + mock_config_manager_cls, + mock_prompt_provider_choice, + mock_console_cls, + mock_help_color, + mock_error_color, + mock_slash_commands + ): + mock_console = MagicMock() + mock_console_cls.return_value = mock_console + + mock_provider = MagicMock() + mock_provider.prompt_params.return_value = {'MODEL_NAME': 'test-model'} + mock_provider.validate_connection.return_value = (False, 'Connection failed', 'other') + mock_provider.name = 'openai' + mock_prompt_provider_choice.return_value = mock_provider + + mock_config_manager_cls.return_value = MagicMock() + + mock_help_color.__str__.return_value = "green" + mock_error_color.__str__.return_value = "red" + mock_slash_commands.EXIT.command = "exit" + + with self.assertRaises(SystemExit) as cm: + aks_agent_init(cmd=None) + self.assertEqual(cm.exception.code, 1) + mock_console.print.assert_any_call( + "Please check your deployed model and network connectivity.", + style=unittest.mock.ANY, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_config_manager.py b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_config_manager.py new file mode 100644 index 00000000000..2e68fcd31d1 --- /dev/null +++ b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_config_manager.py @@ -0,0 +1,69 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import os +import tempfile +import unittest +from azext_aks_agent.agent.llm_config_manager import LLMConfigManager + + +class TestLLMConfigManager(unittest.TestCase): + def setUp(self): + # Create a temporary config file for testing + self.temp_file = tempfile.NamedTemporaryFile(delete=False) + self.config_path = self.temp_file.name + self.manager = LLMConfigManager(config_path=self.config_path) + + def tearDown(self): + # Remove the temporary file after each test + if os.path.exists(self.config_path): + os.unlink(self.config_path) + + def test_save_and_load(self): + params = {"MODEL_NAME": "test-model", "param1": "value1"} + self.manager.save("openai", params) + loaded = self.manager.load() + self.assertIn("llms", loaded) + self.assertEqual(loaded["llms"][0]["MODEL_NAME"], "test-model") + self.assertEqual(loaded["llms"][0]["provider"], "openai") + + def test_get_list_and_latest(self): + params1 = {"MODEL_NAME": "model1", "param": "v1"} + params2 = {"MODEL_NAME": "model2", "param": "v2"} + self.manager.save("openai", params1) + self.manager.save("openai", params2) + model_list = self.manager.get_list() + self.assertEqual(len(model_list), 2) + latest = self.manager.get_latest() + self.assertEqual(latest["MODEL_NAME"], "model2") + + def test_get_specific(self): + params1 = {"MODEL_NAME": "modelA", "param": "foo"} + params2 = {"MODEL_NAME": "modelB", "param": "bar"} + self.manager.save("openai", params1) + self.manager.save("openai", params2) + specific = self.manager.get_specific("openai", "modelA") + self.assertEqual(specific["param"], "foo") + with self.assertRaises(ValueError): + self.manager.get_specific("openai", "not_exist") + + def test_is_config_complete(self): + config = {"key1": "val1", "key2": "val2"} + schema = { + "key1": {"validator": lambda v: v == "val1"}, + "key2": {"validator": lambda v: v == "val2"} + } + self.assertTrue(self.manager.is_config_complete(config, schema)) + config["key2"] = "wrong" + self.assertFalse(self.manager.is_config_complete(config, schema)) + + def test_load_returns_empty_when_file_missing(self): + # Remove file and test load fallback + os.unlink(self.config_path) + self.assertEqual(self.manager.load(), {}) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_providers.py b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_providers.py new file mode 100644 index 00000000000..0636af94f47 --- /dev/null +++ b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_providers.py @@ -0,0 +1,29 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import unittest +from azext_aks_agent.agent.llm_providers import PROVIDER_REGISTRY, AnthropicProvider, GeminiProvider, AzureProvider, OpenAIProvider, OpenAICompatibleProvider + + +class TestLLMProviders(unittest.TestCase): + def test_provider_registry(self): + """Test that provider registry maps names to correct classes.""" + self.assertIs(PROVIDER_REGISTRY['azure'], AzureProvider) + self.assertIs(PROVIDER_REGISTRY['openai'], OpenAIProvider) + self.assertIs(PROVIDER_REGISTRY['anthropic'], AnthropicProvider) + self.assertIs(PROVIDER_REGISTRY['gemini'], GeminiProvider) + self.assertIs(PROVIDER_REGISTRY['openai_compatible'], OpenAICompatibleProvider) + + def test_provider_choices_numbered(self): + """Test numbered provider choices are correct and ordered.""" + from azext_aks_agent.agent.llm_providers import _provider_choices_numbered, _available_providers + choices = _provider_choices_numbered() + providers = _available_providers() + for idx, name in choices: + self.assertEqual(name, providers[idx-1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_mcp_manager.py b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_mcp_manager.py index 1743db4270b..f12c1bdc026 100644 --- a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_mcp_manager.py +++ b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_mcp_manager.py @@ -3,46 +3,48 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import unittest -from unittest import IsolatedAsyncioTestCase +import asyncio import os import tempfile -import asyncio -from unittest.mock import Mock, patch, AsyncMock +import unittest +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, Mock, patch +import pytest from azext_aks_agent.agent.mcp_manager import MCPManager class TestMCPManager(unittest.TestCase): - + def setUp(self): """Set up test fixtures.""" self.test_config_dir = tempfile.mkdtemp() # Create the bin subdirectory that would be expected self.test_bin_dir = os.path.join(self.test_config_dir, 'bin') os.makedirs(self.test_bin_dir, exist_ok=True) - + # Set up event loop for async tests self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - + def tearDown(self): """Clean up test fixtures.""" import shutil + # Clean up async tasks try: self.loop.close() except Exception: pass shutil.rmtree(self.test_config_dir, ignore_errors=True) - + @patch('azext_aks_agent.agent.mcp_manager.get_config_dir') def test_mcp_manager_init_with_default_config_dir(self, mock_get_config_dir): """Test MCP manager initialization with default config directory.""" mock_get_config_dir.return_value = '/mock/config/dir' - + manager = MCPManager() - + self.assertEqual(manager.config_dir, '/mock/config/dir') self.assertFalse(manager.verbose) self.assertIsNotNone(manager.binary_manager) @@ -51,11 +53,11 @@ def test_mcp_manager_init_with_default_config_dir(self, mock_get_config_dir): self.assertIsNone(manager.server_url) self.assertIsNone(manager.server_port) mock_get_config_dir.assert_called_once() - + def test_mcp_manager_init_with_custom_config_dir(self): """Test MCP manager initialization with custom config directory.""" manager = MCPManager(config_dir=self.test_config_dir, verbose=True) - + self.assertEqual(manager.config_dir, self.test_config_dir) self.assertTrue(manager.verbose) self.assertIsNotNone(manager.binary_manager) @@ -66,80 +68,80 @@ def test_mcp_manager_init_with_custom_config_dir(self): # Check that binary manager was initialized with correct path expected_bin_path = os.path.join(self.test_config_dir, 'bin') self.assertEqual(manager.binary_manager.install_dir, expected_bin_path) - + def test_is_binary_available_true(self): """Test binary availability check when binary is available.""" manager = MCPManager(config_dir=self.test_config_dir) - + with patch.object(manager.binary_manager, 'is_binary_available', return_value=True): self.assertTrue(manager.is_binary_available()) - + def test_is_binary_available_false(self): """Test binary availability check when binary is not available.""" manager = MCPManager(config_dir=self.test_config_dir) - + with patch.object(manager.binary_manager, 'is_binary_available', return_value=False): self.assertFalse(manager.is_binary_available()) - + def test_get_binary_version_with_version(self): """Test getting binary version when version is available.""" manager = MCPManager(config_dir=self.test_config_dir) expected_version = "0.1.0" - + with patch.object(manager.binary_manager, 'get_binary_version', return_value=expected_version): version = manager.get_binary_version() self.assertEqual(version, expected_version) - + def test_get_binary_version_none(self): """Test getting binary version when no version is available.""" manager = MCPManager(config_dir=self.test_config_dir) - + with patch.object(manager.binary_manager, 'get_binary_version', return_value=None): version = manager.get_binary_version() self.assertIsNone(version) - + def test_get_binary_path(self): """Test getting binary path.""" manager = MCPManager(config_dir=self.test_config_dir) expected_path = os.path.join(self.test_bin_dir, 'aks-mcp') - + with patch.object(manager.binary_manager, 'get_binary_path', return_value=expected_path): path = manager.get_binary_path() self.assertEqual(path, expected_path) - + def test_validate_binary_version_valid(self): """Test binary version validation when version is valid.""" manager = MCPManager(config_dir=self.test_config_dir) - + with patch.object(manager.binary_manager, 'validate_version', return_value=True): self.assertTrue(manager.validate_binary_version()) - + def test_validate_binary_version_invalid(self): """Test binary version validation when version is invalid.""" manager = MCPManager(config_dir=self.test_config_dir) - + with patch.object(manager.binary_manager, 'validate_version', return_value=False): self.assertFalse(manager.validate_binary_version()) class TestMCPManagerServerLifecycle(IsolatedAsyncioTestCase): """Test server lifecycle management functionality.""" - + def setUp(self): """Set up test fixtures for server tests.""" self.test_config_dir = tempfile.mkdtemp() self.test_bin_dir = os.path.join(self.test_config_dir, 'bin') os.makedirs(self.test_bin_dir, exist_ok=True) - + def tearDown(self): """Clean up test fixtures.""" import shutil shutil.rmtree(self.test_config_dir, ignore_errors=True) - + def test_initial_server_state(self): """Test initial server state after initialization.""" manager = MCPManager(config_dir=self.test_config_dir) - + self.assertIsNone(manager.server_process) self.assertIsNone(manager.server_url) self.assertIsNone(manager.server_port) @@ -147,168 +149,171 @@ def test_initial_server_state(self): self.assertFalse(manager.is_server_healthy()) self.assertIsNone(manager.get_server_url()) self.assertIsNone(manager.get_server_port()) - + def test_find_available_port_default(self): """Test finding available port starting from default.""" manager = MCPManager(config_dir=self.test_config_dir) - + port = manager._find_available_port(8003) self.assertGreaterEqual(port, 8003) self.assertLess(port, 8103) # Should be within 100 port range - + def test_find_available_port_custom_start(self): """Test finding available port with custom start port.""" manager = MCPManager(config_dir=self.test_config_dir) - + port = manager._find_available_port(9000) self.assertGreaterEqual(port, 9000) self.assertLess(port, 9100) # Should be within 100 port range - + @patch('socket.socket') def test_find_available_port_no_ports_available(self, mock_socket): """Test exception when no ports are available.""" manager = MCPManager(config_dir=self.test_config_dir) - + # Mock all sockets to fail binding mock_socket.return_value.__enter__.return_value.bind.side_effect = OSError("Port in use") - + with self.assertRaises(Exception) as cm: manager._find_available_port(8003) - + self.assertIn("No available ports found", str(cm.exception)) - + def test_is_server_running_no_process(self): """Test is_server_running when no process exists.""" manager = MCPManager(config_dir=self.test_config_dir) self.assertFalse(manager.is_server_running()) - + def test_is_server_running_with_process(self): """Test is_server_running with active process.""" manager = MCPManager(config_dir=self.test_config_dir) - + # Mock an active process mock_process = Mock() mock_process.returncode = None # Process is still running manager.server_process = mock_process - + self.assertTrue(manager.is_server_running()) - + def test_is_server_running_with_terminated_process(self): """Test is_server_running with terminated process.""" manager = MCPManager(config_dir=self.test_config_dir) - + # Mock a terminated process mock_process = Mock() mock_process.returncode = 0 # Process has exited manager.server_process = mock_process - + self.assertFalse(manager.is_server_running()) - + @patch('urllib.request.urlopen') def test_is_server_healthy_success(self, mock_urlopen): """Test server health check success.""" manager = MCPManager(config_dir=self.test_config_dir) - + # Setup server state manager.server_process = Mock() manager.server_process.returncode = None manager.server_url = "http://localhost:8003/sse" - + # Mock successful HTTP response mock_response = Mock() mock_response.status = 200 mock_urlopen.return_value.__enter__.return_value = mock_response - + self.assertTrue(manager.is_server_healthy()) mock_urlopen.assert_called_once_with("http://localhost:8003/sse", timeout=3) - + @patch('urllib.request.urlopen') def test_is_server_healthy_http_error(self, mock_urlopen): """Test server health check HTTP error.""" manager = MCPManager(config_dir=self.test_config_dir) - + # Setup server state manager.server_process = Mock() manager.server_process.returncode = None manager.server_url = "http://localhost:8003/sse" - + # Mock HTTP error import urllib.error mock_urlopen.side_effect = urllib.error.HTTPError( "http://localhost:8003/sse", 500, "Server Error", {}, None ) - + self.assertFalse(manager.is_server_healthy()) - + def test_is_server_healthy_no_url(self): """Test server health check when no URL is set.""" manager = MCPManager(config_dir=self.test_config_dir) - + # Setup server process but no URL manager.server_process = Mock() manager.server_process.returncode = None # manager.server_url remains None - + self.assertFalse(manager.is_server_healthy()) - + def test_is_server_healthy_not_running(self): """Test server health check when server is not running.""" manager = MCPManager(config_dir=self.test_config_dir) - + # No server process self.assertFalse(manager.is_server_healthy()) - + + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") @patch('azext_aks_agent.agent.mcp_manager.asyncio.create_subprocess_exec') @patch('azext_aks_agent.agent.mcp_manager.asyncio.sleep') async def test_start_server_success(self, mock_sleep, mock_create_subprocess): """Test successful server start.""" manager = MCPManager(config_dir=self.test_config_dir) - + # Mock binary availability with patch.object(manager, 'is_binary_available', return_value=True): with patch.object(manager, 'get_binary_path', return_value='/fake/aks-mcp'): with patch.object(manager, '_find_available_port', return_value=8003): with patch.object(manager, 'is_server_healthy', return_value=True): - + # Mock subprocess creation mock_process = AsyncMock() mock_create_subprocess.return_value = mock_process - + result = await manager.start_server() - + self.assertTrue(result) self.assertEqual(manager.server_process, mock_process) self.assertEqual(manager.server_url, "http://localhost:8003/sse") self.assertEqual(manager.server_port, 8003) - + # Verify subprocess was called correctly mock_create_subprocess.assert_called_once() args = mock_create_subprocess.call_args[0] self.assertEqual(args, ('/fake/aks-mcp', '--transport', 'sse', '--port', '8003')) - + + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") @patch('azext_aks_agent.agent.mcp_manager.asyncio.create_subprocess_exec') @patch('azext_aks_agent.agent.mcp_manager.asyncio.sleep') async def test_start_server_already_running_and_healthy(self, mock_sleep, mock_create_subprocess): """Test start_server when server is already running and healthy.""" manager = MCPManager(config_dir=self.test_config_dir, verbose=True) - + with patch.object(manager, 'is_binary_available', return_value=True): with patch.object(manager, 'is_server_running', return_value=True): with patch.object(manager, 'is_server_healthy', return_value=True): with patch('azext_aks_agent.agent.user_feedback.ProgressReporter.show_status_message') as mock_progress: - + result = await manager.start_server() - + self.assertTrue(result) # Should not create new subprocess mock_create_subprocess.assert_not_called() # Should show status message in verbose mode mock_progress.assert_called_with("MCP server is already running and healthy", "info") - + + @pytest.mark.skip(reason="The async test is currently not supported in test pipeline.") async def test_start_server_unhealthy_restart(self): """Test start_server restarts unhealthy running server.""" manager = MCPManager(config_dir=self.test_config_dir) - + with patch.object(manager, 'is_binary_available', return_value=True): with patch.object(manager, 'is_server_running', return_value=True): with patch.object(manager, 'is_server_healthy', return_value=False): @@ -317,65 +322,65 @@ async def test_start_server_unhealthy_restart(self): with patch.object(manager, 'get_binary_path', return_value='/fake/aks-mcp'): with patch('azext_aks_agent.agent.mcp_manager.asyncio.create_subprocess_exec') as mock_create: with patch('azext_aks_agent.agent.mcp_manager.asyncio.sleep'): - + # Mock the health check to fail first time, succeed second time health_calls = [False, True] - + def side_effect(*args, **kwargs): return health_calls.pop(0) if health_calls else True - + with patch.object(manager, 'is_server_healthy', side_effect=side_effect): mock_process = AsyncMock() mock_create.return_value = mock_process - + result = await manager.start_server() - + self.assertTrue(result) mock_stop.assert_called_once() - + def test_stop_server_no_process(self): """Test stop_server when no process exists.""" manager = MCPManager(config_dir=self.test_config_dir) - + # Should not raise exception manager.stop_server() - + self.assertIsNone(manager.server_process) self.assertIsNone(manager.server_url) self.assertIsNone(manager.server_port) - + def test_get_server_url_running(self): """Test get_server_url when server is running.""" manager = MCPManager(config_dir=self.test_config_dir) - + # Mock running server manager.server_process = Mock() manager.server_process.returncode = None manager.server_url = "http://localhost:8003/sse" - + self.assertEqual(manager.get_server_url(), "http://localhost:8003/sse") - + def test_get_server_url_not_running(self): """Test get_server_url when server is not running.""" manager = MCPManager(config_dir=self.test_config_dir) - + self.assertIsNone(manager.get_server_url()) - + def test_get_server_port_running(self): """Test get_server_port when server is running.""" manager = MCPManager(config_dir=self.test_config_dir) - + # Mock running server manager.server_process = Mock() manager.server_process.returncode = None manager.server_port = 8003 - + self.assertEqual(manager.get_server_port(), 8003) - + def test_get_server_port_not_running(self): """Test get_server_port when server is not running.""" manager = MCPManager(config_dir=self.test_config_dir) - + self.assertIsNone(manager.get_server_port()) @@ -383,10 +388,10 @@ def test_get_server_port_not_running(self): # Run tests including async tests loader = unittest.TestLoader() suite = unittest.TestSuite() - + # Add regular test cases suite.addTests(loader.loadTestsFromTestCase(TestMCPManager)) suite.addTests(loader.loadTestsFromTestCase(TestMCPManagerServerLifecycle)) - + runner = unittest.TextTestRunner(verbosity=2) result = runner.run(suite) diff --git a/src/aks-agent/azext_aks_agent/tests/latest/utils.py b/src/aks-agent/azext_aks_agent/tests/latest/utils.py deleted file mode 100644 index 69cda720ee8..00000000000 --- a/src/aks-agent/azext_aks_agent/tests/latest/utils.py +++ /dev/null @@ -1,11 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- - -import os - - -def get_test_data_file_path(filename): - curr_dir = os.path.dirname(os.path.realpath(__file__)) - return os.path.join(curr_dir, "data", filename) diff --git a/src/aks-agent/setup.py b/src/aks-agent/setup.py index ee2dbf93b6c..ec23c5df238 100644 --- a/src/aks-agent/setup.py +++ b/src/aks-agent/setup.py @@ -9,7 +9,7 @@ from setuptools import find_packages, setup -VERSION = "1.0.0b5" +VERSION = "1.0.0b6" CLASSIFIERS = [ "Development Status :: 4 - Beta", @@ -24,6 +24,7 @@ ] DEPENDENCIES = [ + "rich==13.9.4", "holmesgpt==0.15.0; python_version >= '3.10'", ]