Skip to content

Commit 1da0b1b

Browse files
committed
add validation connection for customer provided model
1 parent a242946 commit 1da0b1b

File tree

9 files changed

+107
-118
lines changed

9 files changed

+107
-118
lines changed

src/aks-sreclaw/azext_aks_sreclaw/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
# Licensed under the MIT License. See License.txt in the project root for license information.
44
# --------------------------------------------------------------------------------------------
55

6-
from azext_aks_sreclaw._client_factory import CUSTOM_MGMT_AKS
7-
86
# pylint: disable=unused-import
7+
import azext_aks_sreclaw._help
8+
from azext_aks_sreclaw._client_factory import CUSTOM_MGMT_AKS
99
from azure.cli.core import AzCommandsLoader
1010
from azure.cli.core.profiles import register_resource_type
1111

src/aks-sreclaw/azext_aks_sreclaw/custom.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
# pylint: disable=too-many-lines, disable=broad-except, disable=line-too-long
77

8-
import subprocess
9-
108
from azext_aks_sreclaw.sreclaw.aks import get_aks_credentials
119
from azext_aks_sreclaw.sreclaw.console import (
1210
ERROR_COLOR,
@@ -25,7 +23,6 @@
2523
from azure.cli.core.azclierror import AzCLIError
2624
from azure.cli.core.commands.client_factory import get_subscription_id
2725
from knack.log import get_logger
28-
from knack.util import CLIError
2926

3027
logger = get_logger(__name__)
3128

@@ -259,7 +256,6 @@ def aks_sreclaw_status(
259256
namespace,
260257
):
261258
"""Display the status of the SREClaw deployment."""
262-
console = get_console()
263259

264260
kubeconfig_path = get_aks_credentials(
265261
client,
@@ -453,7 +449,7 @@ def aks_sreclaw_connect(
453449
console.print(
454450
f"🚀 Port-forwarding: localhost:{local_port} -> {aks_sreclaw_manager.chart_name}:{target_port}", style=INFO_COLOR)
455451
console.print(f"🌐 Open your browser and navigate to: http://localhost:{local_port}", style=INFO_COLOR)
456-
console.print(f"Press Ctrl+C to stop\n", style="dim")
452+
console.print("Press Ctrl+C to stop\n", style="dim")
457453

458454
# Start blocking port-forward
459455
aks_sreclaw_manager.start_port_forward(pod_name, target_port, local_port)

src/aks-sreclaw/azext_aks_sreclaw/sreclaw/k8s/aks_sreclaw_manager.py

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@
88
import os
99
import tempfile
1010
from abc import ABC, abstractmethod
11-
from pathlib import Path
1211
from typing import Dict, List, Optional, Tuple, Union
1312

1413
from azext_aks_sreclaw._consts import (
1514
AGENT_NAMESPACE,
1615
AKS_SRECLAW_LABEL_SELECTOR,
17-
AKS_SRECLAW_VERSION,
1816
)
1917
from azext_aks_sreclaw.sreclaw.k8s.helm_manager import HelmManager
2018
from azext_aks_sreclaw.sreclaw.llm_config_manager import LLMConfigManager
@@ -480,7 +478,6 @@ def _wait_for_pods_ready(self, timeout: int = 300, interval: int = 5) -> bool:
480478

481479
while time.time() - start_time < timeout:
482480
try:
483-
# Check for pods with label selector
484481
pod_list = self.core_v1.list_namespaced_pod(
485482
namespace=self.namespace,
486483
label_selector=AKS_SRECLAW_LABEL_SELECTOR
@@ -491,30 +488,7 @@ def _wait_for_pods_ready(self, timeout: int = 300, interval: int = 5) -> bool:
491488
time.sleep(interval)
492489
continue
493490

494-
# Check if all pods are ready
495-
all_ready = True
496-
for pod in pod_list.items:
497-
pod_name = pod.metadata.name
498-
pod_phase = pod.status.phase
499-
500-
if pod_phase != "Running":
501-
logger.debug("Pod %s is in phase %s, waiting...", pod_name, pod_phase)
502-
all_ready = False
503-
break
504-
505-
# Check pod readiness condition
506-
pod_ready = False
507-
if pod.status.conditions:
508-
for condition in pod.status.conditions:
509-
if condition.type == "Ready" and condition.status == "True":
510-
pod_ready = True
511-
break
512-
513-
if not pod_ready:
514-
logger.debug("Pod %s is not ready yet, waiting...", pod_name)
515-
all_ready = False
516-
break
517-
491+
all_ready = self._check_all_pods_ready(pod_list.items)
518492
if all_ready:
519493
logger.info("All SREClaw pods are ready")
520494
return True
@@ -531,6 +505,30 @@ def _wait_for_pods_ready(self, timeout: int = 300, interval: int = 5) -> bool:
531505
logger.warning("Timeout waiting for SREClaw pods to be ready")
532506
return False
533507

508+
def _check_all_pods_ready(self, pods) -> bool:
509+
"""Check if all pods are ready."""
510+
for pod in pods:
511+
pod_name = pod.metadata.name
512+
pod_phase = pod.status.phase
513+
514+
if pod_phase != "Running":
515+
logger.debug("Pod %s is in phase %s, waiting...", pod_name, pod_phase)
516+
return False
517+
518+
if not self._is_pod_ready(pod):
519+
logger.debug("Pod %s is not ready yet, waiting...", pod_name)
520+
return False
521+
522+
return True
523+
524+
def _is_pod_ready(self, pod) -> bool:
525+
"""Check if a pod is ready."""
526+
if pod.status.conditions:
527+
for condition in pod.status.conditions:
528+
if condition.type == "Ready" and condition.status == "True":
529+
return True
530+
return False
531+
534532
def deploy_sreclaw(self, chart_version: Optional[str] = None, no_wait: bool = False) -> Tuple[bool, str]:
535533
"""
536534
Deploy SREClaw using helm chart.
@@ -1042,7 +1040,7 @@ def get_gateway_token(self) -> str:
10421040
)
10431041
raise AzCLIError(f"Failed to retrieve gateway token: {e}")
10441042

1045-
def port_forward_to_service(self, local_port: int = 18789) -> str:
1043+
def port_forward_to_service(self, local_port: int = 18789) -> str: # pylint: disable=unused-argument
10461044
"""Port-forward to aks-sreclaw service.
10471045
10481046
Args:
@@ -1054,12 +1052,6 @@ def port_forward_to_service(self, local_port: int = 18789) -> str:
10541052
Raises:
10551053
AzCLIError: If service or pod is not found, or port-forwarding fails
10561054
"""
1057-
import select
1058-
import socket
1059-
import threading
1060-
1061-
from kubernetes.stream import portforward
1062-
10631055
# Get gateway token first before starting port-forward
10641056
gateway_token = self.get_gateway_token()
10651057

@@ -1092,7 +1084,7 @@ def port_forward_to_service(self, local_port: int = 18789) -> str:
10921084
pod_name = pod.metadata.name
10931085
target_port = 18789
10941086

1095-
logger.info(f"Found running pod: {pod_name}")
1087+
logger.info("Found running pod: %s", pod_name)
10961088

10971089
# Return token to caller before starting blocking port-forward
10981090
return gateway_token, pod_name, target_port
@@ -1114,7 +1106,7 @@ def start_port_forward(self, pod_name: str, target_port: int, local_port: int =
11141106

11151107
from kubernetes.stream import portforward
11161108

1117-
logger.info(f"Port-forwarding localhost:{local_port} -> {pod_name}:{target_port}")
1109+
logger.info("Port-forwarding localhost:%d -> %s:%d", local_port, pod_name, target_port)
11181110

11191111
# Start a local TCP server and forward each connection through the k8s portforward API
11201112
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -1146,7 +1138,7 @@ def _forward(local_conn, pf_socket):
11461138
if not data:
11471139
break
11481140
local_conn.sendall(data)
1149-
except Exception:
1141+
except Exception: # pylint: disable=broad-exception-caught
11501142
pass
11511143
finally:
11521144
local_conn.close()

src/aks-sreclaw/azext_aks_sreclaw/sreclaw/llm_config_manager.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from typing import Dict
77

8-
import yaml
98
from azext_aks_sreclaw.sreclaw.llm_providers import LLMProvider
109
from knack.log import get_logger
1110

src/aks-sreclaw/azext_aks_sreclaw/sreclaw/llm_providers/anthropic_provider.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import Tuple
88

9-
import requests
9+
from openai import OpenAI
1010

1111
from .base import LLMProvider, non_empty
1212

@@ -16,10 +16,6 @@ class AnthropicProvider(LLMProvider):
1616
def readable_name(self) -> str:
1717
return "Anthropic"
1818

19-
@property
20-
def provider(self) -> str:
21-
return "anthropic"
22-
2319
@property
2420
def name(self) -> str:
2521
return "anthropic"
@@ -48,28 +44,24 @@ def validate_connection(self, params: dict) -> Tuple[str, str]:
4844
return "Missing required Anthropic parameters.", "retry_input"
4945

5046
models = [m.strip() for m in models_str.split(",")]
51-
model_name = models[0]
47+
client = OpenAI(
48+
api_key=api_key,
49+
base_url="https://api.anthropic.com/v1"
50+
)
5251

53-
url = "https://api.anthropic.com/v1/messages"
54-
headers = {
55-
"x-api-key": api_key,
56-
"anthropic-version": "2023-06-01",
57-
"Content-Type": "application/json"
58-
}
59-
payload = {
60-
"model": model_name,
61-
"max_tokens": 16,
62-
"messages": [{"role": "user", "content": "ping"}]
63-
}
52+
for model_name in models:
53+
try:
54+
client.chat.completions.create(
55+
model=model_name,
56+
messages=[{"role": "user", "content": "ping"}],
57+
max_tokens=16,
58+
timeout=10
59+
)
60+
except Exception as e: # pylint: disable=broad-exception-caught
61+
error_str = str(e).lower()
62+
if any(x in error_str for x in ["api key", "authentication", "unauthorized",
63+
"invalid", "bad request"]):
64+
return f"Model '{model_name}' validation failed: {e}", "retry_input"
65+
return f"Model '{model_name}' connection error: {e}", "connection_error"
6466

65-
try:
66-
resp = requests.post(url, headers=headers,
67-
json=payload, timeout=10)
68-
resp.raise_for_status()
69-
return None, "save"
70-
except requests.exceptions.HTTPError as e:
71-
if 400 <= resp.status_code < 500:
72-
return f"Client error: {e} - {resp.text}", "retry_input"
73-
return f"Server error: {e} - {resp.text}", "connection_error"
74-
except requests.exceptions.RequestException as e:
75-
return f"Request error: {e}", "connection_error"
67+
return None, "save"

src/aks-sreclaw/azext_aks_sreclaw/sreclaw/llm_providers/azure_provider.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,13 @@
55

66

77
from typing import Tuple
8-
from urllib.parse import urlencode, urljoin
98

10-
import requests
9+
from openai import AzureOpenAI
1110

1211
from .base import LLMProvider, is_valid_url, non_empty
1312

1413

1514
def is_valid_api_base(v: str) -> bool:
    """Validate an API base value.

    The value must start with ``https://``; when it does, the result is
    whatever ``is_valid_url`` returns for it. No host-suffix check is made.
    """
    return v.startswith("https://") and is_valid_url(v)
@@ -36,7 +32,7 @@ def parameter_schema(self):
3632
"models": {
3733
"secret": False,
3834
"default": None,
39-
"hint": "comma-separated model names, e.g., gpt-5.4,gpt-5.1",
35+
"hint": "comma-separated deployment names, e.g., gpt-5.4,gpt-5.1",
4036
"validator": non_empty,
4137
"alias": "models"
4238
},
@@ -55,4 +51,32 @@ def parameter_schema(self):
5551
}
5652

5753
def validate_connection(self, params: dict) -> Tuple[str, str]:
    """Check the Azure OpenAI settings by sending a tiny request to each deployment.

    Args:
        params: Provider parameters. Reads "api_key", "api_base" and
            "models" (comma-separated deployment names).

    Returns:
        An (error, action) tuple. On success error is None and action is
        "save". When a required value is missing, or the failure text looks
        like a credential/request/deployment problem, action is
        "retry_input"; any other failure yields "connection_error".
    """
    api_key = params.get("api_key")
    api_base = params.get("api_base")
    models_str = params.get("models")

    if not all([api_key, api_base, models_str]):
        return "Missing required Azure OpenAI parameters.", "retry_input"

    # Drop blank entries so input such as "gpt-a,,gpt-b", a trailing comma,
    # or a string of only separators never becomes a request for an empty
    # deployment name.
    models = [name for name in (part.strip() for part in models_str.split(",")) if name]
    if not models:
        return "Missing required Azure OpenAI parameters.", "retry_input"

    # Substrings that indicate the user should re-enter their input rather
    # than retry the connection as-is.
    retry_markers = ("api key", "authentication", "unauthorized",
                     "invalid", "bad request", "deployment")

    def _action_for(exc: Exception) -> str:
        text = str(exc).lower()
        return "retry_input" if any(marker in text for marker in retry_markers) else "connection_error"

    # The client checks its configuration when it is constructed and can
    # raise before any request is sent; report that through the normal
    # (error, action) contract instead of letting it escape to the caller.
    # NOTE(review): no api_version is passed here, so the client relies on
    # its own fallback (environment) for it - confirm this is intended.
    try:
        client = AzureOpenAI(
            api_key=api_key,
            azure_endpoint=api_base
        )
    except Exception as e:  # pylint: disable=broad-exception-caught
        return f"Failed to initialize Azure OpenAI client: {e}", _action_for(e)

    for model_name in models:
        try:
            client.responses.create(
                model=model_name,
                instructions="You are a helpful assistant.",
                input="ping",
                timeout=10
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            if _action_for(e) == "retry_input":
                return f"Model '{model_name}' validation failed: {e}", "retry_input"
            return f"Model '{model_name}' connection error: {e}", "connection_error"

    return None, "save"

src/aks-sreclaw/azext_aks_sreclaw/sreclaw/llm_providers/base.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# --------------------------------------------------------------------------------------------
55

66

7-
import base64
7+
import base64 # pylint: disable=unused-import
88
from abc import ABC, abstractmethod
99
from typing import Any, Callable, Dict, Tuple
1010
from urllib.parse import urlparse
@@ -45,18 +45,11 @@ def readable_name(self) -> str:
4545
@property
4646
def name(self) -> str:
4747
"""Return the provider name for this provider.
48-
provider name is the key to identity a llmprovider.
49-
https://docs.litellm.ai/docs/providers
48+
This name is used as the OpenClaw LLM provider identifier and must match
49+
the provider name expected by the OpenClaw configuration.
50+
Examples: "azure-openai", "openai", "anthropic"
5051
"""
51-
return self.provider
52-
53-
def model_name(self, model_name) -> str:
54-
"""Return the model name for this provider.
55-
The models name combines the model route and model name, e.g., "azure/gpt-5"
56-
https://docs.litellm.ai/docs/providers
57-
"""
58-
59-
return model_name
52+
return ""
6053

6154
@property
6255
@abstractmethod
@@ -145,7 +138,6 @@ def validate_params(self, params: dict):
145138
raise ValueError(f"Invalid value for parameter: {param}")
146139
return True
147140

148-
# pylint: disable=unused-argument
149141
@abstractmethod
150142
def validate_connection(self, params: dict) -> Tuple[str, str]:
151143
"""
@@ -154,6 +146,4 @@ def validate_connection(self, params: dict) -> Tuple[str, str]:
154146
where error is None if validation is successful, otherwise contains the error message.
155147
Action can be "retry_input", "connection_error", or "save".
156148
"""
157-
# TODO(mainred): leverage 3rd party libraries like litellm instead of
158-
# calling http request in each provider to complete the connection check.
159149
raise NotImplementedError()

0 commit comments

Comments
 (0)