Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/aks-agent/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ To release a new version, please select a new version number (usually plus 1 to
Pending
+++++++

1.0.0b9
+++++++
* agent-init: replace model name with deployment name for Azure OpenAI service.
* agent-init: remove importing holmesgpt to resolve the latency issue.

1.0.0b8
+++++++
* Error handling: don't raise traceback for init prompt and holmesgpt interaction.
Expand Down
4 changes: 4 additions & 0 deletions src/aks-agent/azext_aks_agent/_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@
CONST_MCP_MIN_VERSION = "0.0.10"
CONST_MCP_GITHUB_REPO = "Azure/aks-mcp"
CONST_MCP_BINARY_DIR = "bin"

# Color constants for terminal output
HELP_COLOR = "cyan" # same as AI_COLOR for now
ERROR_COLOR = "red"
66 changes: 52 additions & 14 deletions src/aks-agent/azext_aks_agent/agent/llm_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@

import yaml
from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME
from azext_aks_agent.agent.llm_providers import PROVIDER_REGISTRY
from azure.cli.core.api import get_config_dir
from azure.cli.core.azclierror import AzCLIError
from knack.log import get_logger

logger = get_logger(__name__)


class LLMConfigManager:
Expand Down Expand Up @@ -67,18 +71,30 @@ def save(self, provider_name: str, params: dict):
configs = {}

models = configs.get("llms", [])
model_name = params.get("MODEL_NAME")
if not model_name:
raise ValueError("MODEL_NAME is required to save configuration.")

# Check if model already exists, update it and move it to the last;
# otherwise, append new
models = [
cfg for cfg in models if not (
cfg.get("provider") == provider_name and cfg.get("MODEL_NAME") == model_name)]
models.append({"provider": provider_name, **params})

configs["llms"] = models
# modify existing azure openai config from model name to deployment name
for model in models:
if provider_name.lower() == "azure" and "MODEL_NAME" in model:
model["DEPLOYMENT_NAME"] = model.pop("MODEL_NAME")

def _update_llm_config(provider_name, required_key, params, existing_models):
required_value = params.get(required_key)
if not required_value:
raise ValueError(f"{required_key} is required to save configuration.")

# Check if model already exists, update it and move it to the last;
# otherwise, append the new one.
models = [
cfg for cfg in existing_models if not (
cfg.get("provider") == provider_name and cfg.get(required_key) == required_value)]
models.append({"provider": provider_name, **params})
return models

# To be consistent, we expose DEPLOYMENT_NAME for Azure provider in both configuration file and init prompts.
if provider_name.lower() == "azure":
configs["llms"] = _update_llm_config(provider_name, "DEPLOYMENT_NAME", params, models)
else:
configs["llms"] = _update_llm_config(provider_name, "MODEL_NAME", params, models)

with open(self.config_path, "w") as f:
yaml.safe_dump(configs, f, sort_keys=False)
Expand Down Expand Up @@ -112,14 +128,16 @@ def get_specific(
"""
model_configs = self.get_list()
for cfg in model_configs:
if cfg.get("provider") == provider_name and cfg.get(
"MODEL_NAME") == model_name:
if cfg.get("provider") == provider_name and provider_name.lower() == "azure":
if cfg.get("DEPLOYMENT_NAME") == model_name or cfg.get("MODEL_NAME") == model_name:
return cfg
if cfg.get("provider") == provider_name and cfg.get("MODEL_NAME") == model_name:
return cfg
return None

def get_model_config(self, model) -> Optional[Dict]:
prompt_for_init = "Run 'az aks agent-init' to set up your LLM endpoint (recommended path).\n" \
"To configure your LLM manually, create a config file using the templates provided here: "\
"To configure your LLM manually, create a config file using the templates provided here: " \
"https://aka.ms/aks/agentic-cli/init"

if not model:
Expand Down Expand Up @@ -147,3 +165,23 @@ def is_config_complete(self, config, provider_schema):
config.get(key)):
return False
return True

def export_model_config(self, llm_config) -> str:
    """Export an LLM configuration into the process environment.

    Resolves the provider-specific model identifier and exports every
    remaining config key as an environment variable so downstream
    tooling can pick it up.

    :param llm_config: a saved model config dict containing at least a
        "provider" key plus the provider's parameters. Mutated in place
        for Azure (MODEL_NAME is renamed to DEPLOYMENT_NAME).
    :return: the provider-qualified model route string.
    """
    provider_name = llm_config.get("provider")
    provider_instance = PROVIDER_REGISTRY.get(provider_name)()
    # NOTE(mainred) for backward compatibility with Azure OpenAI, replace the MODEL_NAME with DEPLOYMENT_NAME
    if provider_name.lower() == "azure" and "MODEL_NAME" in llm_config:
        llm_config["DEPLOYMENT_NAME"] = llm_config.pop("MODEL_NAME")

    # Azure identifies the model by deployment name; every other provider uses the model name.
    model_name_key = "MODEL_NAME" if provider_name.lower() != "azure" else "DEPLOYMENT_NAME"
    model = provider_instance.model_name(llm_config.get(model_name_key))

    # Set environment variables for the model provider
    for k, v in llm_config.items():
        if k not in ["provider", "MODEL_NAME", "DEPLOYMENT_NAME"]:
            os.environ[k] = v
    # Fix: log the effective identifier. For Azure, MODEL_NAME was popped
    # above, so llm_config.get("MODEL_NAME") would always log None here.
    logger.info(
        "Using provider: %s, model: %s, Env vars setup successfully.", provider_name, llm_config.get(model_name_key))

    return model
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import List, Tuple

from azext_aks_agent._consts import ERROR_COLOR, HELP_COLOR
from rich.console import Console

from .anthropic_provider import AnthropicProvider
Expand Down Expand Up @@ -47,7 +48,6 @@ def _get_provider_by_index(idx: int) -> LLMProvider:
Return provider instance by numeric index (1-based).
Raises ValueError if index is out of range.
"""
from holmes.utils.colors import HELP_COLOR
if 1 <= idx <= len(_PROVIDER_CLASSES):
console.print("You selected provider:", _PROVIDER_CLASSES[idx - 1]().readable_name, style=f"bold {HELP_COLOR}")
return _PROVIDER_CLASSES[idx - 1]()
Expand All @@ -59,7 +59,6 @@ def prompt_provider_choice() -> LLMProvider:
Show a numbered menu and return the chosen provider instance.
Keeps prompting until a valid selection is made.
"""
from holmes.utils.colors import ERROR_COLOR, HELP_COLOR
choices = _provider_choices_numbered()
if not choices:
raise ValueError("No providers are registered.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def model_route(self) -> str:
@property
def parameter_schema(self):
return {
"MODEL_NAME": {
"DEPLOYMENT_NAME": {
"secret": False,
"default": None,
"hint": "should be consistent with your deployed name, e.g., gpt-5",
"hint": "ensure your deployment name is the same as the model name, e.g., gpt-5",
"validator": non_empty
},
"AZURE_API_KEY": {
Expand All @@ -62,19 +62,19 @@ def validate_connection(self, params: dict) -> Tuple[bool, str, str]:
api_key = params.get("AZURE_API_KEY")
api_base = params.get("AZURE_API_BASE")
api_version = params.get("AZURE_API_VERSION")
model_name = params.get("MODEL_NAME")
deployment_name = params.get("DEPLOYMENT_NAME")

if not all([api_key, api_base, api_version, model_name]):
if not all([api_key, api_base, api_version, deployment_name]):
return False, "Missing required Azure parameters.", "retry_input"

# REST API reference: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/api-version-lifecycle?tabs=rest
url = urljoin(api_base, f"openai/deployments/{model_name}/chat/completions")
url = urljoin(api_base, f"openai/deployments/{deployment_name}/chat/completions")

query = {"api-version": api_version}
full_url = f"{url}?{urlencode(query)}"
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
payload = {
"model": model_name,
"model": deployment_name,
"messages": [{"role": "user", "content": "ping"}],
"max_tokens": 16
}
Expand Down
6 changes: 2 additions & 4 deletions src/aks-agent/azext_aks_agent/agent/llm_providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Callable, Dict, Tuple
from urllib.parse import urlparse

from azext_aks_agent._consts import ERROR_COLOR, HELP_COLOR
from rich.console import Console

console = Console()
Expand Down Expand Up @@ -85,9 +86,6 @@ def parameter_schema(self) -> Dict[str, Dict[str, Any]]:

def prompt_params(self):
"""Prompt user for parameters using parameter_schema when available."""
from holmes.interactive import SlashCommands
from holmes.utils.colors import ERROR_COLOR, HELP_COLOR

schema = self.parameter_schema
params = {}
for param, meta in schema.items():
Expand Down Expand Up @@ -134,7 +132,7 @@ def prompt_params(self):
params[param] = value
break
console.print(
f"Invalid value for {param}. Please try again, or type '{SlashCommands.EXIT.command}' to exit.",
f"Invalid value for {param}. Please try again, or type '/exit' to exit.",
style=f"{ERROR_COLOR}")

return params
Expand Down
59 changes: 22 additions & 37 deletions src/aks-agent/azext_aks_agent/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@

import os

from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME
from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME, HELP_COLOR

# pylint: disable=too-many-lines, disable=broad-except
from azext_aks_agent.agent.agent import aks_agent as aks_agent_internal
from azext_aks_agent.agent.llm_config_manager import LLMConfigManager
from azext_aks_agent.agent.llm_providers import (
PROVIDER_REGISTRY,
prompt_provider_choice,
)
from azext_aks_agent.agent.llm_providers import prompt_provider_choice
from azext_aks_agent.agent.logging import rich_logging
from azure.cli.core.api import get_config_dir
from azure.cli.core.azclierror import AzCLIError
Expand All @@ -25,29 +22,28 @@
# pylint: disable=unused-argument
def aks_agent_init(cmd):
    """Initialize AKS agent llm configuration.

    Interactively prompts for an LLM provider and its parameters,
    validates connectivity against the model endpoint, and saves the
    configuration on success.

    :param cmd: Azure CLI command context (unused).
    :raises AzCLIError: when the entered parameters are invalid or the
        endpoint cannot be reached.
    """
    # rich is imported lazily to keep `az` startup latency low.
    from rich.console import Console
    console = Console()
    console.print(
        "Welcome to AKS Agent LLM configuration setup. Type '/exit' to exit.",
        style=f"bold {HELP_COLOR}")

    provider = prompt_provider_choice()
    params = provider.prompt_params()

    llm_config_manager = LLMConfigManager()
    # If the connection to the model endpoint is valid, save the configuration
    is_valid, msg, action = provider.validate_connection(params)

    if is_valid and action == "save":
        llm_config_manager.save(provider.model_route if provider.model_route else "openai", params)
        console.print(
            f"LLM configuration setup successfully and is saved to {llm_config_manager.config_path}.",
            style=f"bold {HELP_COLOR}")
    elif not is_valid and action == "retry_input":
        raise AzCLIError(f"Please re-run `az aks agent-init` to correct the input parameters. {str(msg)}")
    else:
        raise AzCLIError(f"Please check your deployed model and network connectivity. {str(msg)}")


# pylint: disable=unused-argument
Expand Down Expand Up @@ -81,18 +77,7 @@ def aks_agent(
llm_config_manager = LLMConfigManager(config_file)
llm_config_manager.validate_config()
llm_config = llm_config_manager.get_model_config(model)

# Check if the configuration is complete
provider_name = llm_config.get("provider")
provider_instance = PROVIDER_REGISTRY.get(provider_name)()
model = provider_instance.model_name(llm_config.get("MODEL_NAME"))

# Set environment variables for the model provider
for k, v in llm_config.items():
if k not in ["provider", "MODEL_NAME"]:
os.environ[k] = v
logger.info(
"Using provider: %s, model: %s, Env vars setup successfully.", provider_name, llm_config.get("MODEL_NAME"))
llm_config_manager.export_model_config(llm_config)

with rich_logging():
aks_agent_internal(
Expand Down
Loading
Loading