-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathmain.py
More file actions
105 lines (79 loc) · 3.01 KB
/
Copy pathmain.py
File metadata and controls
105 lines (79 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
Unified CUA (Computer Use Agent) template with multi-provider support.
Supports Anthropic, OpenAI, and Gemini as interchangeable providers.
Configure via environment variables:
CUA_PROVIDER — primary provider ("anthropic", "openai", or "gemini")
CUA_FALLBACK_PROVIDERS — comma-separated fallback order (optional)
Each provider requires its own API key:
ANTHROPIC_API_KEY, OPENAI_API_KEY, GOOGLE_API_KEY
"""
from __future__ import annotations
import asyncio
from typing import Literal, TypedDict
import kernel
from kernel import Kernel
from providers import resolve_providers, run_with_fallback, TaskOptions
from session import KernelBrowserSession, SessionOptions
# Module-wide Kernel client; cua_task hands it to both the browser session
# and the task options.
kernel_client = Kernel()
# Kernel application object; actions are registered on it with @app.action.
app = kernel.App("python-cua")
class CuaInput(TypedDict, total=False):
    """Payload accepted by the ``cua-task`` action.

    Declared with ``total=False``, so no key is required at the type level;
    the action handler itself rejects payloads without a non-empty ``query``.
    """

    # Task description; the handler raises ValueError when it is missing/empty.
    query: str
    # Optional per-request override: the matching configured provider is
    # moved to the front of the fallback order.
    provider: Literal["anthropic", "openai", "gemini"]
    # Optional model name, forwarded unchanged to the task options.
    model: str
    # Forwarded to the session options; the handler defaults it to False.
    record_replay: bool
class CuaOutput(TypedDict, total=False):
    """Result returned by the ``cua-task`` action.

    Declared with ``total=False`` because ``replay_url`` is only present
    when the stopped session reports a replay view URL.
    """

    # Result text taken from the completed task.
    result: str
    # Provider reported by the completed task.
    provider: str
    # Replay view URL of the session; omitted when none is reported.
    replay_url: str
# Providers are resolved lazily, on the first action invocation, because env
# vars are not available during Hypeman's build/discovery phase.
_providers: list | None = None


def _get_providers():
    """Return the configured providers, resolving them on first use.

    The first call runs ``resolve_providers()``, stores the result in the
    module-level ``_providers`` cache and prints the provider order; later
    calls return the cached list without resolving or printing again.
    """
    global _providers

    # Fast path: already resolved by an earlier call.
    if _providers is not None:
        return _providers

    resolved = resolve_providers()
    # Cache before logging, so the cache is populated even if building the
    # log line fails.
    _providers = resolved
    chain = " -> ".join(provider.name for provider in resolved)
    print(f"Configured providers: {chain}")
    return resolved
@app.action("cua-task")
async def cua_task(ctx: kernel.KernelContext, payload: CuaInput | None = None) -> CuaOutput:
    """Run a computer-use task inside a Kernel browser session.

    Starts a browser session for this invocation, passes the query and the
    provider fallback order to ``run_with_fallback``, and stops the session
    afterwards, whether the task succeeded, failed or was cancelled.

    Args:
        ctx: Invocation context; only ``ctx.invocation_id`` is read here.
        payload: Task input. ``query`` is required. ``provider`` moves the
            matching configured provider to the front of the fallback order
            (a notice is printed and the default order kept when it is not
            configured). ``model`` is forwarded to the task options and
            ``record_replay`` (default ``False``) to the session options.

    Returns:
        A dict with ``result`` and ``provider`` from the completed task, plus
        ``replay_url`` when the stopped session reports a replay view URL.

    Raises:
        ValueError: If ``payload`` is missing or has no non-empty ``query``.
        BaseException: Anything raised while running the task or stopping the
            session is re-raised after a cleanup ``session.stop()``.
    """
    if not payload or not payload.get("query"):
        raise ValueError('Query is required. Payload must include: {"query": "your task description"}')

    providers = _get_providers()

    # Per-request provider override: move requested provider to front
    requested_name = payload.get("provider")
    if requested_name:
        requested = next((p for p in providers if p.name == requested_name), None)
        if requested is not None:
            providers = [requested] + [p for p in providers if p is not requested]
        else:
            # Keep the best-effort behavior (fall back to the default order)
            # but make the ignored override visible in the logs.
            print(f"Requested provider '{requested_name}' is not configured; using default order")

    session = KernelBrowserSession(
        kernel_client,
        SessionOptions(
            invocation_id=ctx.invocation_id,
            stealth=True,
            record_replay=payload.get("record_replay", False),
        ),
    )
    await session.start()
    print(f"Live view: {session.live_view_url}")

    try:
        task_result = await run_with_fallback(
            providers,
            TaskOptions(
                query=payload["query"],
                kernel=kernel_client,
                session_id=session.session_id,
                model=payload.get("model"),
                viewport_width=session.opts.viewport_width,
                viewport_height=session.opts.viewport_height,
            ),
        )
        session_info = await session.stop()

        output: CuaOutput = {
            "result": task_result.result,
            "provider": task_result.provider,
        }
        if session_info.replay_view_url:
            output["replay_url"] = session_info.replay_view_url
        return output
    except BaseException:
        # asyncio.CancelledError derives from BaseException, so catching only
        # Exception would skip this cleanup when the invocation is cancelled
        # and leave the browser session running.
        await session.stop()
        raise