|
9 | 9 |
|
10 | 10 | import yaml |
11 | 11 | from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME |
| 12 | +from azext_aks_agent.agent.llm_providers import PROVIDER_REGISTRY |
12 | 13 | from azure.cli.core.api import get_config_dir |
13 | 14 | from azure.cli.core.azclierror import AzCLIError |
| 15 | +from knack.log import get_logger |
| 16 | + |
| 17 | +logger = get_logger(__name__) |
14 | 18 |
|
15 | 19 |
|
16 | 20 | class LLMConfigManager: |
@@ -67,18 +71,30 @@ def save(self, provider_name: str, params: dict): |
67 | 71 | configs = {} |
68 | 72 |
|
69 | 73 | models = configs.get("llms", []) |
70 | | - model_name = params.get("MODEL_NAME") |
71 | | - if not model_name: |
72 | | - raise ValueError("MODEL_NAME is required to save configuration.") |
73 | | - |
74 | | - # Check if model already exists, update it and move it to the last; |
75 | | - # otherwise, append new |
76 | | - models = [ |
77 | | - cfg for cfg in models if not ( |
78 | | - cfg.get("provider") == provider_name and cfg.get("MODEL_NAME") == model_name)] |
79 | | - models.append({"provider": provider_name, **params}) |
80 | 74 |
|
81 | | - configs["llms"] = models |
| 75 | + # Migrate existing Azure OpenAI configs from MODEL_NAME to DEPLOYMENT_NAME. NOTE(review): the condition below checks provider_name (the provider currently being saved), not each stored model's own provider — saving an Azure config therefore renames MODEL_NAME on entries from every provider; it should check model.get("provider") instead. |
| 76 | + for model in models: |
| 77 | + if provider_name.lower() == "azure" and "MODEL_NAME" in model: |
| 78 | + model["DEPLOYMENT_NAME"] = model.pop("MODEL_NAME") |
| 79 | + |
| 80 | + def _update_llm_config(provider_name, required_key, params, existing_models): |
| 81 | + required_value = params.get(required_key) |
| 82 | + if not required_value: |
| 83 | + raise ValueError(f"{required_key} is required to save configuration.") |
| 84 | + |
| 85 | + # Check if model already exists, update it and move it to the last; |
| 86 | + # otherwise, append the new one. |
| 87 | + models = [ |
| 88 | + cfg for cfg in existing_models if not ( |
| 89 | + cfg.get("provider") == provider_name and cfg.get(required_key) == required_value)] |
| 90 | + models.append({"provider": provider_name, **params}) |
| 91 | + return models |
| 92 | + |
| 93 | + # To be consistent, we expose DEPLOYMENT_NAME for Azure provider in both configuration file and init prompts. |
| 94 | + if provider_name.lower() == "azure": |
| 95 | + configs["llms"] = _update_llm_config(provider_name, "DEPLOYMENT_NAME", params, models) |
| 96 | + else: |
| 97 | + configs["llms"] = _update_llm_config(provider_name, "MODEL_NAME", params, models) |
82 | 98 |
|
83 | 99 | with open(self.config_path, "w") as f: |
84 | 100 | yaml.safe_dump(configs, f, sort_keys=False) |
@@ -112,14 +128,16 @@ def get_specific( |
112 | 128 | """ |
113 | 129 | model_configs = self.get_list() |
114 | 130 | for cfg in model_configs: |
115 | | - if cfg.get("provider") == provider_name and cfg.get( |
116 | | - "MODEL_NAME") == model_name: |
| 131 | + if cfg.get("provider") == provider_name and provider_name.lower() == "azure": |
| 132 | + if cfg.get("DEPLOYMENT_NAME") == model_name or cfg.get("MODEL_NAME") == model_name: |
| 133 | + return cfg |
| 134 | + if cfg.get("provider") == provider_name and cfg.get("MODEL_NAME") == model_name: |
117 | 135 | return cfg |
118 | 136 | return None |
119 | 137 |
|
120 | 138 | def get_model_config(self, model) -> Optional[Dict]: |
121 | 139 | prompt_for_init = "Run 'az aks agent-init' to set up your LLM endpoint (recommended path).\n" \ |
122 | | - "To configure your LLM manually, create a config file using the templates provided here: "\ |
| 140 | + "To configure your LLM manually, create a config file using the templates provided here: " \ |
123 | 141 | "https://aka.ms/aks/agentic-cli/init" |
124 | 142 |
|
125 | 143 | if not model: |
@@ -147,3 +165,23 @@ def is_config_complete(self, config, provider_schema): |
147 | 165 | config.get(key)): |
148 | 166 | return False |
149 | 167 | return True |
| 168 | + |
def export_model_config(self, llm_config) -> str:
    """Export an LLM config: set provider env vars and return the model identifier.

    Args:
        llm_config: a saved model config dict containing at least "provider"
            plus the provider-specific keys (MODEL_NAME or, for Azure,
            DEPLOYMENT_NAME) and credential/endpoint values.

    Returns:
        The provider-qualified model string from provider_instance.model_name().

    Raises:
        ValueError: if the provider is missing or not in PROVIDER_REGISTRY.
    """
    # Work on a copy so the caller's dict is not mutated by the Azure rename below.
    llm_config = dict(llm_config)

    provider_name = llm_config.get("provider")
    provider_cls = PROVIDER_REGISTRY.get(provider_name)
    if provider_cls is None:
        # Previously this crashed with an opaque "'NoneType' object is not
        # callable"; fail with an actionable message instead.
        raise ValueError(f"Unknown LLM provider in configuration: {provider_name!r}")
    provider_instance = provider_cls()

    is_azure = provider_name.lower() == "azure"
    # NOTE(mainred) for backward compatibility with Azure OpenAI, replace the
    # legacy MODEL_NAME key with DEPLOYMENT_NAME.
    if is_azure and "MODEL_NAME" in llm_config:
        llm_config["DEPLOYMENT_NAME"] = llm_config.pop("MODEL_NAME")

    model_name_key = "DEPLOYMENT_NAME" if is_azure else "MODEL_NAME"
    model = provider_instance.model_name(llm_config.get(model_name_key))

    # Set environment variables for the model provider. Every key except the
    # bookkeeping ones is exported as-is; values are assumed to be strings
    # as loaded from the YAML config — TODO confirm non-string values never occur.
    for key, value in llm_config.items():
        if key not in ("provider", "MODEL_NAME", "DEPLOYMENT_NAME"):
            os.environ[key] = value

    # Bug fix: the original logged llm_config.get("MODEL_NAME"), which is
    # always None for Azure after the rename above; log the resolved key.
    logger.info(
        "Using provider: %s, model: %s, Env vars setup successfully.",
        provider_name, llm_config.get(model_name_key))

    return model
0 commit comments