|
| 1 | +"""SuperAgent 实例 / Super Agent Instance |
| 2 | +
|
| 3 | +``SuperAgent`` 是暴露给应用开发者的强类型实例对象, 承载 ``invoke`` / 会话管理 |
| 4 | +两类方法 (仅异步; 见决策 14)。CRUDL 由 ``SuperAgentClient`` 管理。 |
| 5 | +
|
| 6 | +本文件为模板 (``__agent_async_template.py``), codegen 会把 ``async def ...`` |
| 7 | +转换成同步骨架; 实际第一版异步主路径 + 同步 NotImplementedError 占位。 |
| 8 | +""" |
| 9 | + |
| 10 | +from dataclasses import dataclass, field |
| 11 | +from typing import Any, Dict, List, Optional |
| 12 | + |
| 13 | +from agentrun.super_agent.api.data import SuperAgentDataAPI |
| 14 | +from agentrun.super_agent.model import ConversationInfo, Message |
| 15 | +from agentrun.super_agent.stream import InvokeStream |
| 16 | +from agentrun.utils.config import Config |
| 17 | + |
| 18 | +_SYNC_UNSUPPORTED_MSG = ( |
| 19 | + "sync version not supported, use *_async (see decision 14 in" |
| 20 | + " openspec/changes/add-super-agent-sdk/design.md)" |
| 21 | +) |
| 22 | + |
| 23 | + |
| 24 | +@dataclass |
| 25 | +class SuperAgent: |
| 26 | + """超级 Agent 实例. |
| 27 | +
|
| 28 | + 业务字段 (``prompt / agents / tools / ...``) 从 ``protocolSettings.config`` |
| 29 | + 反解。系统字段 (``agent_runtime_id / arn / status / ...``) 来自 AgentRuntime。 |
| 30 | + """ |
| 31 | + |
| 32 | + name: str |
| 33 | + description: Optional[str] = None |
| 34 | + prompt: Optional[str] = None |
| 35 | + agents: List[str] = field(default_factory=list) |
| 36 | + tools: List[str] = field(default_factory=list) |
| 37 | + skills: List[str] = field(default_factory=list) |
| 38 | + sandboxes: List[str] = field(default_factory=list) |
| 39 | + workspaces: List[str] = field(default_factory=list) |
| 40 | + model_service_name: Optional[str] = None |
| 41 | + model_name: Optional[str] = None |
| 42 | + |
| 43 | + agent_runtime_id: str = "" |
| 44 | + arn: str = "" |
| 45 | + status: str = "" |
| 46 | + created_at: str = "" |
| 47 | + last_updated_at: str = "" |
| 48 | + external_endpoint: str = "" |
| 49 | + |
| 50 | + _client: Any = field(default=None, repr=False, compare=False) |
| 51 | + |
| 52 | + def _resolve_config(self, config: Optional[Config]) -> Config: |
| 53 | + client_cfg = ( |
| 54 | + getattr(self._client, "config", None) if self._client else None |
| 55 | + ) |
| 56 | + return Config.with_configs(client_cfg, config) |
| 57 | + |
| 58 | + def _forwarded_business_fields(self) -> Dict[str, Any]: |
| 59 | + """把 SuperAgent 实例字段打包成 ``forwardedProps`` 顶层业务字段 dict. |
| 60 | +
|
| 61 | + 与 ``protocolSettings[0].config`` 写入时的结构保持对称: list 型用 ``[]`` |
| 62 | + 代替 None, scalar 型保留 None (由 JSON 序列化为 ``null``)。服务端读取同 |
| 63 | + 一份语义, 避免客户端/服务端对"未设置"产生歧义。 |
| 64 | + """ |
| 65 | + return { |
| 66 | + "prompt": self.prompt, |
| 67 | + "agents": list(self.agents), |
| 68 | + "tools": list(self.tools), |
| 69 | + "skills": list(self.skills), |
| 70 | + "sandboxes": list(self.sandboxes), |
| 71 | + "workspaces": list(self.workspaces), |
| 72 | + "modelServiceName": self.model_service_name, |
| 73 | + "modelName": self.model_name, |
| 74 | + } |
| 75 | + |
| 76 | + async def invoke_async( |
| 77 | + self, |
| 78 | + messages: List[Dict[str, Any]], |
| 79 | + *, |
| 80 | + conversation_id: Optional[str] = None, |
| 81 | + config: Optional[Config] = None, |
| 82 | + ) -> InvokeStream: |
| 83 | + """Phase 1: POST /invoke; 返回包含 ``conversation_id`` 的 :class:`InvokeStream`. |
| 84 | +
|
| 85 | + 首次 ``async for ev in stream`` 才触发 Phase 2 拉流 (lazy)。 |
| 86 | + """ |
| 87 | + cfg = self._resolve_config(config) |
| 88 | + api = SuperAgentDataAPI(self.name, config=cfg) |
| 89 | + resp = await api.invoke_async( |
| 90 | + messages, |
| 91 | + conversation_id=conversation_id, |
| 92 | + config=cfg, |
| 93 | + forwarded_extras=self._forwarded_business_fields(), |
| 94 | + ) |
| 95 | + stream_url = resp.stream_url |
| 96 | + stream_headers = dict(resp.stream_headers) |
| 97 | + session_id = stream_headers.get("X-Super-Agent-Session-Id", "") |
| 98 | + |
| 99 | + async def _factory(): |
| 100 | + return api.stream_async( |
| 101 | + stream_url, stream_headers=stream_headers, config=cfg |
| 102 | + ) |
| 103 | + |
| 104 | + return InvokeStream( |
| 105 | + conversation_id=resp.conversation_id, |
| 106 | + session_id=session_id, |
| 107 | + stream_url=stream_url, |
| 108 | + stream_headers=stream_headers, |
| 109 | + _stream_factory=_factory, |
| 110 | + ) |
| 111 | + |
| 112 | + def invoke( |
| 113 | + self, |
| 114 | + messages: List[Dict[str, Any]], |
| 115 | + *, |
| 116 | + conversation_id: Optional[str] = None, |
| 117 | + config: Optional[Config] = None, |
| 118 | + ) -> InvokeStream: |
| 119 | + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) |
| 120 | + |
| 121 | + async def get_conversation_async( |
| 122 | + self, |
| 123 | + conversation_id: str, |
| 124 | + *, |
| 125 | + config: Optional[Config] = None, |
| 126 | + ) -> ConversationInfo: |
| 127 | + """GET /conversations/{id} → :class:`ConversationInfo` (缺字段用默认值).""" |
| 128 | + cfg = self._resolve_config(config) |
| 129 | + api = SuperAgentDataAPI(self.name, config=cfg) |
| 130 | + data = await api.get_conversation_async(conversation_id, config=cfg) |
| 131 | + return _conversation_info_from_dict( |
| 132 | + data, fallback_conversation_id=conversation_id |
| 133 | + ) |
| 134 | + |
| 135 | + def get_conversation( |
| 136 | + self, |
| 137 | + conversation_id: str, |
| 138 | + *, |
| 139 | + config: Optional[Config] = None, |
| 140 | + ) -> ConversationInfo: |
| 141 | + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) |
| 142 | + |
| 143 | + async def delete_conversation_async( |
| 144 | + self, |
| 145 | + conversation_id: str, |
| 146 | + *, |
| 147 | + config: Optional[Config] = None, |
| 148 | + ) -> None: |
| 149 | + """DELETE /conversations/{id}.""" |
| 150 | + cfg = self._resolve_config(config) |
| 151 | + api = SuperAgentDataAPI(self.name, config=cfg) |
| 152 | + await api.delete_conversation_async(conversation_id, config=cfg) |
| 153 | + |
| 154 | + def delete_conversation( |
| 155 | + self, |
| 156 | + conversation_id: str, |
| 157 | + *, |
| 158 | + config: Optional[Config] = None, |
| 159 | + ) -> None: |
| 160 | + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) |
| 161 | + |
| 162 | + |
| 163 | +def _to_message(raw: Dict[str, Any]) -> Message: |
| 164 | + return Message( |
| 165 | + role=str(raw.get("role") or ""), |
| 166 | + content=str(raw.get("content") or ""), |
| 167 | + message_id=raw.get("messageId") or raw.get("message_id"), |
| 168 | + created_at=raw.get("createdAt") or raw.get("created_at"), |
| 169 | + ) |
| 170 | + |
| 171 | + |
| 172 | +def _conversation_info_from_dict( |
| 173 | + data: Dict[str, Any], *, fallback_conversation_id: str |
| 174 | +) -> ConversationInfo: |
| 175 | + data = data or {} |
| 176 | + messages_raw = data.get("messages") or [] |
| 177 | + messages = [_to_message(m) for m in messages_raw if isinstance(m, dict)] |
| 178 | + return ConversationInfo( |
| 179 | + conversation_id=str( |
| 180 | + data.get("conversationId") or fallback_conversation_id |
| 181 | + ), |
| 182 | + agent_id=str(data.get("agentId") or data.get("agent_id") or ""), |
| 183 | + title=data.get("title"), |
| 184 | + main_user_id=data.get("mainUserId") or data.get("main_user_id"), |
| 185 | + sub_user_id=data.get("subUserId") or data.get("sub_user_id"), |
| 186 | + created_at=int(data.get("createdAt") or data.get("created_at") or 0), |
| 187 | + updated_at=int(data.get("updatedAt") or data.get("updated_at") or 0), |
| 188 | + error_message=data.get("errorMessage") or data.get("error_message"), |
| 189 | + invoke_info=data.get("invokeInfo") or data.get("invoke_info"), |
| 190 | + messages=messages, |
| 191 | + params=data.get("params"), |
| 192 | + ) |
| 193 | + |
| 194 | + |
| 195 | +__all__ = ["SuperAgent"] |
0 commit comments