-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy path__init__.py
More file actions
108 lines (84 loc) · 2.82 KB
/
Copy path__init__.py
File metadata and controls
108 lines (84 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
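Provider factory with automatic fallback.
Resolution order (names are case-insensitive; duplicates are ignored):
1. CUA_PROVIDER env var (the primary provider; required unless
   CUA_FALLBACK_PROVIDERS names a usable provider)
2. CUA_FALLBACK_PROVIDERS env var (optional, comma-separated)
Unknown providers and providers that are not configured are skipped.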
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Protocol
from kernel import Kernel
@dataclass()
class TaskOptions:
    """Everything a provider needs to run one computer-use task.

    Only ``query``, ``kernel`` and ``session_id`` are mandatory; the
    remaining fields fall back to the defaults declared below.
    """

    # The task prompt handed to the provider (presumably natural language).
    query: str
    # Kernel SDK client the provider works through — TODO confirm usage in providers.
    kernel: Kernel
    # Session identifier to operate on (presumably a Kernel browser session).
    session_id: str
    # Optional model override; None means no override was requested.
    model: str | None = None
    # Viewport dimensions (assumed to be pixels — confirm with providers).
    viewport_width: int = 1280
    viewport_height: int = 800
@dataclass()
class TaskResult:
    """Outcome of a successful task run.

    Pairs the provider's output with the name of the provider that
    produced it, so a caller can tell which provider answered.
    """

    # Output text returned by the provider's ``run_task``.
    result: str
    # Name of the provider that produced ``result`` (presumably its ``name`` — confirm).
    provider: str
class CuaProvider(Protocol):
    """Structural interface that every computer-use provider satisfies.

    Any object with a ``name`` property, an ``is_configured`` method and an
    async ``run_task`` method matches; no inheritance is required.
    """

    @property
    def name(self) -> str:
        """Identifier shown in progress messages and failure summaries."""

    def is_configured(self) -> bool:
        """Return whether the provider is usable.

        ``resolve_providers`` skips providers for which this is False and
        reports them as missing an API key.
        """

    async def run_task(self, options: TaskOptions) -> TaskResult:
        """Run the task described by *options* and return its result.

        Any ``Exception`` raised here makes ``run_with_fallback`` move on to
        the next provider.
        """
def _build_provider(name: str) -> CuaProvider | None:
    """Instantiate the provider registered under *name*.

    Matching is exact, so callers pass an already stripped, lower-cased
    name. Each provider module is imported only when its name is requested,
    so selecting one provider never imports the others.

    Returns:
        A new provider instance, or ``None`` when *name* is not one of
        ``"anthropic"``, ``"openai"`` or ``"gemini"``.
    """

    def make_anthropic() -> CuaProvider:
        from .anthropic import AnthropicProvider

        return AnthropicProvider()

    def make_openai() -> CuaProvider:
        from .openai import OpenAIProvider

        return OpenAIProvider()

    def make_gemini() -> CuaProvider:
        from .gemini import GeminiProvider

        return GeminiProvider()

    factories = {
        "anthropic": make_anthropic,
        "openai": make_openai,
        "gemini": make_gemini,
    }
    factory = factories.get(name)
    if factory is None:
        return None
    return factory()
def _provider_names_from_env() -> list[str]:
    """Return the requested provider names, normalized, in order, without duplicates."""
    primary = os.environ.get("CUA_PROVIDER", "").strip().lower()
    fallbacks = [
        part.strip().lower()
        for part in os.environ.get("CUA_FALLBACK_PROVIDERS", "").split(",")
        if part.strip()
    ]
    ordered = ([primary] if primary else []) + fallbacks
    # dict keeps insertion order, so the first occurrence of each name wins.
    return list(dict.fromkeys(ordered))


def resolve_providers() -> list[CuaProvider]:
    """Build the ordered list of providers to try.

    Names come from CUA_PROVIDER (first) and then the comma-separated
    CUA_FALLBACK_PROVIDERS. They are lower-cased and de-duplicated.
    Unknown names and providers whose ``is_configured()`` is False are
    skipped, and a warning is printed for each.

    Returns:
        Usable providers in priority order, never empty.

    Raises:
        RuntimeError: if no requested provider is both known and configured.
            The message states whether nothing was requested or which
            providers were skipped and why.
    """
    providers: list[CuaProvider] = []
    skipped: list[str] = []
    for name in _provider_names_from_env():
        provider = _build_provider(name)
        if provider is None:
            print(f'Warning: Unknown provider "{name}", skipping.')
            skipped.append(f"{name} (unknown provider)")
            continue
        if not provider.is_configured():
            print(f'Warning: Provider "{name}" missing API key, skipping.')
            skipped.append(f"{name} (missing API key)")
            continue
        providers.append(provider)

    if not providers:
        # Say *why* nothing is usable; the fixed hint alone is misleading when
        # CUA_PROVIDER was set but skipped, or when only fallbacks were given.
        if skipped:
            detail = f" Skipped: {', '.join(skipped)}."
        else:
            detail = (
                " Neither CUA_PROVIDER nor CUA_FALLBACK_PROVIDERS names a provider."
            )
        raise RuntimeError(
            "No CUA provider is configured. "
            "Set CUA_PROVIDER to one of: anthropic, openai, gemini, "
            "and provide the matching API key." + detail
        )
    return providers
async def run_with_fallback(
    providers: list[CuaProvider],
    options: TaskOptions,
) -> TaskResult:
    """Run a CUA task, trying each provider in order until one succeeds.

    Args:
        providers: Providers to try, in priority order.
        options: Task parameters passed unchanged to every provider.

    Returns:
        The result of the first provider whose ``run_task`` completes
        without raising an ``Exception``.

    Raises:
        RuntimeError: if *providers* is empty, or if every provider raised.
            In the latter case the message summarizes each failure and the
            last provider's exception is chained as ``__cause__``.
    """
    if not providers:
        # Without this, the error below would read "All providers failed:"
        # with an empty summary.
        raise RuntimeError("No CUA providers to try.")

    errors: list[tuple[str, Exception]] = []
    for provider in providers:
        print(f"Attempting provider: {provider.name}")
        try:
            return await provider.run_task(options)
        except Exception as exc:  # fallback boundary: any failure moves on
            # Include the type: some exceptions (e.g. TimeoutError()) have an
            # empty str() and would otherwise print as a blank reason.
            print(f'Provider "{provider.name}" failed: {type(exc).__name__}: {exc}')
            errors.append((provider.name, exc))

    summary = "\n".join(
        f" {name}: {type(exc).__name__}: {exc}" for name, exc in errors
    )
    # Chain the last failure so its traceback is not lost.
    raise RuntimeError(f"All providers failed:\n{summary}") from errors[-1][1]