diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3f227d2..655fd35 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,6 +35,9 @@ jobs: - name: Install project dependencies run: uv sync --frozen + - name: Run pre-commit + uses: pre-commit/action@v3.0.1 + - name: Setup Node uses: actions/setup-node@v4 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..3499abf --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-json + - id: check-added-large-files + args: ["--maxkb=512"] + - id: check-merge-conflict + - id: check-toml + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.10 + hooks: + - id: ruff-format + + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.10.0.1 + hooks: + - id: shellcheck + files: \.sh$ diff --git a/apps/api/api_runtime_console.py b/apps/api/api_runtime_console.py index b96c9bc..49d0383 100644 --- a/apps/api/api_runtime_console.py +++ b/apps/api/api_runtime_console.py @@ -140,6 +140,7 @@ def _tools(app: Any, *, tool_overrides: Mapping[str, Any]) -> list[dict[str, Any ) return rows + def _cron_jobs(app: Any) -> list[dict[str, Any]]: rows: list[dict[str, Any]] = [] for job in app.cron_runtime.list_jobs(): @@ -183,7 +184,13 @@ def _proactive_ask_system_job(app: Any) -> dict[str, Any] | None: "status": "scheduled" if enabled else "paused", "profileId": None, "eggId": None, - "payload": {"type": "proactive_ask", "enabled": enabled, "idle_threshold_minutes": idle, "daily_max": daily_max, "quiet_hours": [qs, qe]}, + "payload": { + "type": "proactive_ask", + "enabled": enabled, + "idle_threshold_minutes": idle, + "daily_max": daily_max, + "quiet_hours": [qs, qe], + }, "skills": [], "createdAt": None, "updatedAt": None, diff --git a/apps/api/api_runtime_console_ops.py b/apps/api/api_runtime_console_ops.py index 01b1227..6ac254f 100644 --- a/apps/api/api_runtime_console_ops.py +++ b/apps/api/api_runtime_console_ops.py @@ -21,8 +21,6 @@ global_config_schema, load_global_config, load_extensions_from_config, - save_extensions_to_config, - save_provider_to_config, parse_global_config_text, read_global_config_text, write_global_config, @@ -67,7 +65,13 @@ def _write_manifest_to_config(state_dir: Path, manifest: Mapping[str, Any]) -> P models["default_provider_source"] = "config" config["models"] = models # Extension keys - extension_keys = ("tool_manifests", "skill_manifests", "skill_overrides", "tool_overrides", "skill_packages") + extension_keys = ( + "tool_manifests", + "skill_manifests", + "skill_overrides", + "tool_overrides", + "skill_packages", + ) extensions = config.get("extensions", {}) for key in extension_keys: if key in manifest: @@ -88,6 +92,7 @@ def _read_json_file(path: Path) -> Any: def _load_manifest_from_config(state_dir: Path) -> dict[str, Any]: """Load manifest data (gateway, extensions) from config.yaml for the given state_dir.""" from packages.runtime_config import global_config_path_for_state_dir + config_path = global_config_path_for_state_dir(state_dir) try: config = load_global_config(config_path, state_dir=state_dir) @@ -141,8 +146,16 @@ def _read_text_file(path: Path, *, max_chars: int = 20_000) -> str | None: "summary": "Feishu bot long-connection bridge for p2p and group chat messages.", "eventPath": "/feishu/events", "secretFields": ( - {"key": "app_id", "label": "App ID", "defaultEnvVar": "ELEPHANT_FEISHU_APP_ID"}, - {"key": "app_secret", "label": "App Secret", "defaultEnvVar": "ELEPHANT_FEISHU_APP_SECRET"}, + { + "key": "app_id", + "label": "App ID", + "defaultEnvVar": "ELEPHANT_FEISHU_APP_ID", + }, + { + "key": "app_secret", + "label": "App Secret", + "defaultEnvVar": "ELEPHANT_FEISHU_APP_SECRET", + }, ), "supportsDirectConfig": True, }, @@ -155,7 +168,11 @@ def _read_text_file(path: Path, *, max_chars: int = 20_000) -> str | None: "transports": ("gateway",), "summary": "Discord bot gateway bridge for DMs, channels, and threads.", "secretFields": ( - {"key": "bot_token", "label": "Bot token", "defaultEnvVar": "ELEPHANT_DISCORD_BOT_TOKEN"}, + { + "key": "bot_token", + "label": "Bot token", + "defaultEnvVar": "ELEPHANT_DISCORD_BOT_TOKEN", + }, ), "supportsDirectConfig": True, }, @@ -168,9 +185,21 @@ def _read_text_file(path: Path, *, max_chars: int = 20_000) -> str | None: "transports": ("stream",), "summary": "DingDing stream bridge for chatbot messages.", "secretFields": ( - {"key": "client_id", "label": "Client ID", "defaultEnvVar": "ELEPHANT_DINGDING_CLIENT_ID"}, - {"key": "client_secret", "label": "Client Secret", "defaultEnvVar": "ELEPHANT_DINGDING_CLIENT_SECRET"}, - {"key": "robot_code", "label": "Robot Code", "defaultEnvVar": "ELEPHANT_DINGDING_ROBOT_CODE"}, + { + "key": "client_id", + "label": "Client ID", + "defaultEnvVar": "ELEPHANT_DINGDING_CLIENT_ID", + }, + { + "key": "client_secret", + "label": "Client Secret", + "defaultEnvVar": "ELEPHANT_DINGDING_CLIENT_SECRET", + }, + { + "key": "robot_code", + "label": "Robot Code", + "defaultEnvVar": "ELEPHANT_DINGDING_ROBOT_CODE", + }, ), "supportsDirectConfig": True, }, @@ -183,8 +212,16 @@ def _read_text_file(path: Path, *, max_chars: int = 20_000) -> str | None: "transports": ("websocket",), "summary": "WeCom AI Bot WebSocket bridge for chats and groups.", "secretFields": ( - {"key": "bot_id", "label": "Bot ID", "defaultEnvVar": "ELEPHANT_WECOM_BOT_ID"}, - {"key": "secret", "label": "Secret", "defaultEnvVar": "ELEPHANT_WECOM_SECRET"}, + { + "key": "bot_id", + "label": "Bot ID", + "defaultEnvVar": "ELEPHANT_WECOM_BOT_ID", + }, + { + "key": "secret", + "label": "Secret", + "defaultEnvVar": "ELEPHANT_WECOM_SECRET", + }, ), "supportsDirectConfig": True, }, @@ -385,7 +422,11 @@ def _gateway_services( service = str(spec["service"]) adapter = adapters_payload.get(service) adapter_payload = adapter if isinstance(adapter, Mapping) else {} - account_rows = [dict(item) for item in adapter_payload.get("accounts", ()) if isinstance(item, Mapping)] if isinstance(adapter_payload.get("accounts"), (list, tuple)) else [] + account_rows = ( + [dict(item) for item in adapter_payload.get("accounts", ()) if isinstance(item, Mapping)] + if isinstance(adapter_payload.get("accounts"), (list, tuple)) + else [] + ) primary_account = account_rows[0] if account_rows else {} account_id = str(primary_account.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID) secret_fields = [] @@ -401,14 +442,18 @@ def _gateway_services( secret_key=secret_key, default_env_var=default_env_var, ) - secret_fields.append({ - "key": secret_key, - "label": str(field.get("label") or secret_key), - "hasValue": bool(local_secrets.get(env_var)), - }) + secret_fields.append( + { + "key": secret_key, + "label": str(field.get("label") or secret_key), + "hasValue": bool(local_secrets.get(env_var)), + } + ) service_runtime_files = [row for row in runtime_files if _gateway_runtime_service_key(row) == service] control = adapter_payload.get("control") if isinstance(adapter_payload.get("control"), Mapping) else {} - configured_transport = str(primary_account.get("surface") or adapter_payload.get("surface") or spec.get("defaultTransport") or "") + configured_transport = str( + primary_account.get("surface") or adapter_payload.get("surface") or spec.get("defaultTransport") or "" + ) enabled = adapter_payload.get("enabled") is True runtime_states = [_gateway_runtime_status(row) for row in service_runtime_files] is_running = any(state == "running" for state in runtime_states) @@ -421,22 +466,29 @@ def _gateway_services( if err: last_error = err break - rows.append({ - **{key: value for key, value in spec.items() if key != "secretFields"}, - "enabled": enabled, - "configured": bool(account_rows), - "configuredTransport": configured_transport, - "accountCount": len(account_rows), - "accounts": tuple(account_rows), - "primaryAccountId": account_id, - "eventPath": str(primary_account.get("event_path") or adapter_payload.get("event_path") or spec.get("eventPath") or ""), - "allowGroupChats": bool(control.get("allow_group_chats") is True), - "secretFields": tuple(secret_fields), - "runtimeFiles": tuple(service_runtime_files), - "running": is_running, - "starting": is_starting, - "lastError": last_error, - }) + rows.append( + { + **{key: value for key, value in spec.items() if key != "secretFields"}, + "enabled": enabled, + "configured": bool(account_rows), + "configuredTransport": configured_transport, + "accountCount": len(account_rows), + "accounts": tuple(account_rows), + "primaryAccountId": account_id, + "eventPath": str( + primary_account.get("event_path") + or adapter_payload.get("event_path") + or spec.get("eventPath") + or "" + ), + "allowGroupChats": bool(control.get("allow_group_chats") is True), + "secretFields": tuple(secret_fields), + "runtimeFiles": tuple(service_runtime_files), + "running": is_running, + "starting": is_starting, + "lastError": last_error, + } + ) return rows @@ -474,7 +526,9 @@ def _gateway_manifest(state_dir: Path) -> dict[str, Any]: return dict(manifest) if isinstance(manifest, Mapping) else {} -def _gateway_adapter_payload(manifest: Mapping[str, Any], service: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: +def _gateway_adapter_payload( + manifest: Mapping[str, Any], service: str +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: gateway_payload = manifest.get("gateway") if isinstance(manifest.get("gateway"), Mapping) else {} adapters_payload = gateway_payload.get("adapters") if isinstance(gateway_payload.get("adapters"), Mapping) else {} adapter_payload = adapters_payload.get(service) if isinstance(adapters_payload.get(service), Mapping) else {} @@ -538,7 +592,9 @@ def _gateway_weixin_config_from_payload(payload: Mapping[str, Any]) -> dict[str, return dict(config) -def _gateway_weixin_qr_payload(session_id: str, session_state: Mapping[str, Any], *, status: str = "wait") -> dict[str, Any]: +def _gateway_weixin_qr_payload( + session_id: str, session_state: Mapping[str, Any], *, status: str = "wait" +) -> dict[str, Any]: scan_data = str(session_state.get("qrScanData") or "") return { "status": status, @@ -557,7 +613,9 @@ async def _fetch_weixin_qr(*, bot_type: str) -> dict[str, Any]: from apps.gateway import weixin_support as wx if not wx.check_weixin_requirements(): - raise RuntimeError("WeChat QR login requires aiohttp and cryptography. Install gateway WeChat dependencies first.") + raise RuntimeError( + "WeChat QR login requires aiohttp and cryptography. Install gateway WeChat dependencies first." + ) async with wx.aiohttp.ClientSession(trust_env=True, connector=wx._make_ssl_connector()) as session: return await wx._api_get( session, @@ -571,7 +629,9 @@ async def _poll_weixin_qr(*, qrcode: str, base_url: str) -> dict[str, Any]: from apps.gateway import weixin_support as wx if not wx.check_weixin_requirements(): - raise RuntimeError("WeChat QR login requires aiohttp and cryptography. Install gateway WeChat dependencies first.") + raise RuntimeError( + "WeChat QR login requires aiohttp and cryptography. Install gateway WeChat dependencies first." + ) async with wx.aiohttp.ClientSession(trust_env=True, connector=wx._make_ssl_connector()) as session: return await wx._api_get( session, @@ -602,7 +662,9 @@ def _gateway_weixin_qr_start(self, payload: Mapping[str, Any]) -> dict[str, Any] return _gateway_weixin_qr_payload(session_id, session_state, status="wait") -def _gateway_persist_weixin_credentials(self, credentials: Mapping[str, Any], config: Mapping[str, Any]) -> dict[str, Any]: +def _gateway_persist_weixin_credentials( + self, credentials: Mapping[str, Any], config: Mapping[str, Any] +) -> dict[str, Any]: from apps.gateway import weixin_support as wx database_path = self.repository.database_path @@ -621,8 +683,14 @@ def _gateway_persist_weixin_credentials(self, credentials: Mapping[str, Any], co base_url=str(credentials.get("base_url") or credentials.get("baseurl") or wx.ILINK_BASE_URL), user_id=str(credentials.get("user_id") or credentials.get("ilink_user_id") or ""), ) - control_payload = dict(adapter_payload.get("control")) if isinstance(adapter_payload.get("control"), Mapping) else {} - allow_group_chats = bool(config.get("allowGroupChats")) if isinstance(config.get("allowGroupChats"), bool) else bool(control_payload.get("allow_group_chats") is True) + control_payload = ( + dict(adapter_payload.get("control")) if isinstance(adapter_payload.get("control"), Mapping) else {} + ) + allow_group_chats = ( + bool(config.get("allowGroupChats")) + if isinstance(config.get("allowGroupChats"), bool) + else bool(control_payload.get("allow_group_chats") is True) + ) account_payload: dict[str, Any] = { "account_id": account_id, "token": token, @@ -631,7 +699,9 @@ def _gateway_persist_weixin_credentials(self, credentials: Mapping[str, Any], co "surface": "ilink", "enabled": bool(config.get("accountEnabled")) if isinstance(config.get("accountEnabled"), bool) else True, } - event_path = str(config.get("eventPath") or config.get("event_path") or adapter_payload.get("event_path") or "/weixin/events").strip() + event_path = str( + config.get("eventPath") or config.get("event_path") or adapter_payload.get("event_path") or "/weixin/events" + ).strip() if event_path: account_payload["event_path"] = event_path adapter_payload["accounts"] = _gateway_upsert_account(accounts, account_payload) @@ -667,14 +737,25 @@ def _gateway_weixin_qr_poll(self, payload: Mapping[str, Any]) -> dict[str, Any]: raise ValueError("WeChat QR session is missing or expired; start QR setup again") if time.time() > datetime.fromisoformat(str(session_state["expiresAt"])).timestamp(): store.pop(session_id, None) - return {**_gateway_weixin_qr_payload(session_id, session_state, status="expired"), "message": "QR session expired; start again."} - status_resp = asyncio.run(_poll_weixin_qr(qrcode=str(session_state["qrcode"]), base_url=str(session_state.get("baseUrl") or "https://ilinkai.weixin.qq.com"))) + return { + **_gateway_weixin_qr_payload(session_id, session_state, status="expired"), + "message": "QR session expired; start again.", + } + status_resp = asyncio.run( + _poll_weixin_qr( + qrcode=str(session_state["qrcode"]), + base_url=str(session_state.get("baseUrl") or "https://ilinkai.weixin.qq.com"), + ) + ) status = str(status_resp.get("status") or "wait") if status == "scaned_but_redirect": redirect_host = str(status_resp.get("redirect_host") or "").strip() if redirect_host: session_state["baseUrl"] = f"https://{redirect_host}" - return {**_gateway_weixin_qr_payload(session_id, session_state, status=status), "message": "Redirected QR polling host."} + return { + **_gateway_weixin_qr_payload(session_id, session_state, status=status), + "message": "Redirected QR polling host.", + } if status == "confirmed": credentials = { "account_id": str(status_resp.get("ilink_bot_id") or ""), @@ -687,12 +768,22 @@ def _gateway_weixin_qr_poll(self, payload: Mapping[str, Any]) -> dict[str, Any]: return { **_gateway_weixin_qr_payload(session_id, session_state, status="confirmed"), "message": f"WeChat connected as {credentials['account_id']}", - "credentials": {"account_id": credentials["account_id"], "base_url": credentials["base_url"], "user_id": credentials["user_id"]}, + "credentials": { + "account_id": credentials["account_id"], + "base_url": credentials["base_url"], + "user_id": credentials["user_id"], + }, **persisted, } if status == "need_verifycode": - return {**_gateway_weixin_qr_payload(session_id, session_state, status=status), "message": "Scanned. Please confirm the verification code on your phone to continue."} - return {**_gateway_weixin_qr_payload(session_id, session_state, status=status), "message": "Scan the QR with WeChat and confirm login."} + return { + **_gateway_weixin_qr_payload(session_id, session_state, status=status), + "message": "Scanned. Please confirm the verification code on your phone to continue.", + } + return { + **_gateway_weixin_qr_payload(session_id, session_state, status=status), + "message": "Scan the QR with WeChat and confirm login.", + } def _gateway_configure_service(self, payload: Mapping[str, Any], *, service: str) -> dict[str, Any]: @@ -704,15 +795,51 @@ def _gateway_configure_service(self, payload: Mapping[str, Any], *, service: str manifest = _gateway_manifest(state_dir) gateway_payload, adapters_payload, adapter_payload = _gateway_adapter_payload(manifest, service) accounts = _gateway_accounts(adapter_payload) - account_id = str(config.get("accountId") or config.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID).strip() or DEFAULT_GATEWAY_ACCOUNT_ID - existing_account = next((account for account in accounts if str(account.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID) == account_id), {}) - transport = str(config.get("transport") or existing_account.get("surface") or adapter_payload.get("surface") or spec.get("defaultTransport") or "").strip() + account_id = ( + str(config.get("accountId") or config.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID).strip() + or DEFAULT_GATEWAY_ACCOUNT_ID + ) + existing_account = next( + (account for account in accounts if str(account.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID) == account_id), + {}, + ) + transport = str( + config.get("transport") + or existing_account.get("surface") + or adapter_payload.get("surface") + or spec.get("defaultTransport") + or "" + ).strip() if transport not in tuple(spec.get("transports", ())): - raise ValueError(f"gateway {service} transport must be one of {', '.join(spec.get('transports', ())) }") - enabled = bool(config.get("enabled")) if isinstance(config.get("enabled"), bool) else bool(adapter_payload.get("enabled") is not False) - account_enabled = bool(config.get("accountEnabled")) if isinstance(config.get("accountEnabled"), bool) else bool(existing_account.get("enabled") is not False) - event_path = str(config.get("eventPath") or config.get("event_path") or existing_account.get("event_path") or adapter_payload.get("event_path") or spec.get("eventPath") or "").strip() - allow_group_chats = bool(config.get("allowGroupChats")) if isinstance(config.get("allowGroupChats"), bool) else bool((adapter_payload.get("control") if isinstance(adapter_payload.get("control"), Mapping) else {}).get("allow_group_chats") is True) + raise ValueError(f"gateway {service} transport must be one of {', '.join(spec.get('transports', ()))}") + enabled = ( + bool(config.get("enabled")) + if isinstance(config.get("enabled"), bool) + else bool(adapter_payload.get("enabled") is not False) + ) + account_enabled = ( + bool(config.get("accountEnabled")) + if isinstance(config.get("accountEnabled"), bool) + else bool(existing_account.get("enabled") is not False) + ) + event_path = str( + config.get("eventPath") + or config.get("event_path") + or existing_account.get("event_path") + or adapter_payload.get("event_path") + or spec.get("eventPath") + or "" + ).strip() + allow_group_chats = ( + bool(config.get("allowGroupChats")) + if isinstance(config.get("allowGroupChats"), bool) + else bool( + (adapter_payload.get("control") if isinstance(adapter_payload.get("control"), Mapping) else {}).get( + "allow_group_chats" + ) + is True + ) + ) secrets = config.get("secrets") if isinstance(config.get("secrets"), Mapping) else {} secret_fields = tuple(field for field in spec.get("secretFields", ()) if isinstance(field, Mapping)) env_payload: dict[str, str] = {} @@ -740,7 +867,12 @@ def _gateway_configure_service(self, payload: Mapping[str, Any], *, service: str account_payload["event_path"] = event_path if service == "feishu": account_payload["secret_references"] = tuple( - _gateway_secret_reference(service=service, account_id=account_id, secret_key=secret_key, env_var=env_var) + _gateway_secret_reference( + service=service, + account_id=account_id, + secret_key=secret_key, + env_var=env_var, + ) for secret_key, env_var in env_payload.items() ) elif env_payload: @@ -759,7 +891,9 @@ def _gateway_configure_service(self, payload: Mapping[str, Any], *, service: str adapter_payload["enabled"] = enabled if event_path: adapter_payload["event_path"] = event_path - control_payload = dict(adapter_payload.get("control")) if isinstance(adapter_payload.get("control"), Mapping) else {} + control_payload = ( + dict(adapter_payload.get("control")) if isinstance(adapter_payload.get("control"), Mapping) else {} + ) control_payload.pop("default_elephant_id", None) control_payload.pop("default_session_id", None) control_payload.pop("auto_create_elephant", None) @@ -892,7 +1026,9 @@ def _row_id(row: Mapping[str, Any]) -> str: elif len(accounts) == 1: resolved_id = existing_ids[0] if requested_id and requested_id != resolved_id: - reason = f"requested accountId {requested_id!r} not found; removed the only configured account {resolved_id!r}" + reason = ( + f"requested accountId {requested_id!r} not found; removed the only configured account {resolved_id!r}" + ) else: # Ambiguous: either multiple accounts and id didn't match, or zero accounts. if not accounts: @@ -995,7 +1131,9 @@ def gateway_action(self, payload: Mapping[str, Any]) -> dict[str, Any]: if action == "remove": return _gateway_remove_service_account(self, payload, service=service) if action not in {"status", "doctor", "start", "stop", "restart"}: - raise ValueError("gateway action must be status, doctor, start, stop, restart, configure, remove, qr-start, or qr-poll") + raise ValueError( + "gateway action must be status, doctor, start, stop, restart, configure, remove, qr-start, or qr-poll" + ) database_path = self.repository.database_path state_dir = database_path.parent command = [sys.executable, "-m", "apps.gateway", service, action] @@ -1005,12 +1143,14 @@ def gateway_action(self, payload: Mapping[str, Any]) -> dict[str, Any]: transport = str(payload.get("transport") or payload.get("runtimeTarget") or "").strip() if transport: command.extend(["--transport", transport]) - command.extend([ - "--state-dir", - str(state_dir), - "--cli-state-dir", - str(state_dir), - ]) + command.extend( + [ + "--state-dir", + str(state_dir), + "--cli-state-dir", + str(state_dir), + ] + ) if action == "start": command.append("--detach") if action in {"stop", "restart"} and bool(payload.get("force")): @@ -1067,11 +1207,7 @@ def _override_enabled(overrides: Mapping[str, Any], item_id: str, default: bool) def _mapping_rows(value: object) -> dict[str, dict[str, Any]]: if not isinstance(value, Mapping): return {} - return { - str(key): dict(item) - for key, item in value.items() - if str(key).strip() and isinstance(item, Mapping) - } + return {str(key): dict(item) for key, item in value.items() if str(key).strip() and isinstance(item, Mapping)} def _text_list(value: object) -> list[str]: @@ -1111,11 +1247,7 @@ def _object_payload(value: object, *, field: str) -> dict[str, Any]: def _string_object_payload(value: object, *, field: str) -> dict[str, str]: - return { - str(key): str(item) - for key, item in _object_payload(value, field=field).items() - if str(key).strip() - } + return {str(key): str(item) for key, item in _object_payload(value, field=field).items() if str(key).strip()} def _optional_text(value: object) -> str | None: @@ -1153,7 +1285,11 @@ def _mcp_catalog(*, config_path: Path, config: Mapping[str, Any]) -> dict[str, A url = str(server.get("url") or "").strip() transport = str(server.get("transport") or ("http" if url else "stdio")).strip() or "stdio" env = _mapping_rows({"env": server.get("env")}).get("env", {}) if isinstance(server.get("env"), Mapping) else {} - headers = _mapping_rows({"headers": server.get("headers")}).get("headers", {}) if isinstance(server.get("headers"), Mapping) else {} + headers = ( + _mapping_rows({"headers": server.get("headers")}).get("headers", {}) + if isinstance(server.get("headers"), Mapping) + else {} + ) env_keys = sorted(str(key) for key in env if str(key).strip()) header_keys = sorted(str(key) for key in headers if str(key).strip()) server_rows.append( @@ -1211,7 +1347,9 @@ def _mcp_catalog(*, config_path: Path, config: Mapping[str, Any]) -> dict[str, A "writesState": bool(tool.get("writes_state", False)), "touchesNetwork": bool(tool.get("touches_network", False)), "touchesSecrets": bool(tool.get("touches_secrets", False)), - "requiredFields": tuple(str(item) for item in schema.get("required", []) if str(item).strip()) if isinstance(schema.get("required"), list) else (), + "requiredFields": tuple(str(item) for item in schema.get("required", []) if str(item).strip()) + if isinstance(schema.get("required"), list) + else (), "schema": schema, "provenance": f"{config_path}#mcp_servers.{server_id}.tools.{tool_name}", "backend": "mcp", @@ -1225,7 +1363,9 @@ def _mcp_catalog(*, config_path: Path, config: Mapping[str, Any]) -> dict[str, A } -def _load_operator_global_config(database_path: Path) -> tuple[Path, Path, dict[str, Any]]: +def _load_operator_global_config( + database_path: Path, +) -> tuple[Path, Path, dict[str, Any]]: state_dir = database_path.parent config_path = global_config_path_for_state_dir(database_path.parent) config = load_global_config(config_path, state_dir=state_dir) @@ -1464,8 +1604,17 @@ def sync_operator_mcp_server(self, payload: Mapping[str, Any]) -> dict[str, Any] server_exists = server_id in servers existing_server = dict(servers.get(server_id, {})) next_server = _apply_mcp_server_payload(existing_server, payload) - transport = str(next_server.get("transport") or ("http" if str(next_server.get("url") or "").strip() else "stdio")).strip().lower() or "stdio" - headers = _mapping_rows({"headers": next_server.get("headers")}).get("headers", {}) if isinstance(next_server.get("headers"), Mapping) else {} + transport = ( + str(next_server.get("transport") or ("http" if str(next_server.get("url") or "").strip() else "stdio")) + .strip() + .lower() + or "stdio" + ) + headers = ( + _mapping_rows({"headers": next_server.get("headers")}).get("headers", {}) + if isinstance(next_server.get("headers"), Mapping) + else {} + ) existing_tools = _mapping_rows(existing_server.get("tools")) merged_tools = _merge_discovered_mcp_tools(existing_tools, discovered_tools, transport=transport, headers=headers) next_server["tools"] = merged_tools @@ -1576,7 +1725,9 @@ def _mcp_discover_payload(payload: Mapping[str, Any]) -> dict[str, Any]: } -def _mcporter_command_for_discovery(payload: Mapping[str, Any]) -> tuple[list[str], Any | None]: +def _mcporter_command_for_discovery( + payload: Mapping[str, Any], +) -> tuple[list[str], Any | None]: repo_root = Path(__file__).resolve().parents[2] transport = str(payload.get("transport") or "stdio") server_id = str(payload.get("serverId") or "mcp-probe") @@ -1642,8 +1793,12 @@ def _mcp_discovered_tool_rows(payload: Mapping[str, Any]) -> list[dict[str, Any] "name": str(item.get("name") or "").strip(), "description": str(item.get("description") or "").strip(), "inputSchema": schema, - "requiredFields": tuple(str(field) for field in schema.get("required", []) if str(field).strip()) if isinstance(schema.get("required"), list) else (), - "options": [option for option in item.get("options", []) if isinstance(option, Mapping)] if isinstance(item.get("options"), list) else [], + "requiredFields": tuple(str(field) for field in schema.get("required", []) if str(field).strip()) + if isinstance(schema.get("required"), list) + else (), + "options": [option for option in item.get("options", []) if isinstance(option, Mapping)] + if isinstance(item.get("options"), list) + else [], } ) return rows @@ -1733,7 +1888,10 @@ def discover_operator_mcp_server(self, payload: Mapping[str, Any]) -> dict[str, error_text = str(parsed.get("error") or "").strip() if parsed else "" if not error_text and result.returncode != 0: error_text = (result.stderr or result.stdout or "mcporter discovery failed").strip() - status = str(parsed.get("status") or ("ok" if result.returncode == 0 and not error_text else "failed")).strip() or "failed" + status = ( + str(parsed.get("status") or ("ok" if result.returncode == 0 and not error_text else "failed")).strip() + or "failed" + ) return { "status": status, "serverId": probe["serverId"], diff --git a/apps/api/api_runtime_cron_ops.py b/apps/api/api_runtime_cron_ops.py index 97fe00f..ce13815 100644 --- a/apps/api/api_runtime_cron_ops.py +++ b/apps/api/api_runtime_cron_ops.py @@ -12,7 +12,10 @@ def run_proactive_ask_now(self) -> dict[str, Any]: """Run the built-in proactive ask scheduler once on demand.""" from apps.gateway.cron_service import CONFIGURED_IM_ADAPTERS - from apps.gateway.proactive_ask_job import ProactiveAskTickResult, run_proactive_ask_tick + from apps.gateway.proactive_ask_job import ( + ProactiveAskTickResult, + run_proactive_ask_tick, + ) from apps.gateway.runtime import build_gateway_app from packages.runtime_config import ( global_config_path_for_state_dir, diff --git a/apps/api/api_runtime_http_dispatch_helpers.py b/apps/api/api_runtime_http_dispatch_helpers.py index 01c07e3..b4e6470 100644 --- a/apps/api/api_runtime_http_dispatch_helpers.py +++ b/apps/api/api_runtime_http_dispatch_helpers.py @@ -10,17 +10,14 @@ def _elephant_id_from_name(name: str) -> str: """Convert elephant display name to elephant ID format.""" import re + return re.sub(r"[^a-zA-Z0-9_-]", "", name.lower().replace(" ", "-")) def _cron_payload(payload: Mapping[str, Any]) -> dict[str, Any]: """Extract validated cron job payload.""" job_payload = { - key: value - for key, value in ( - ("prompt", _optional_str(payload.get("prompt"))), - ) - if value is not None + key: value for key, value in (("prompt", _optional_str(payload.get("prompt"))),) if value is not None } skills = _cron_skill_ids(payload.get("skills")) if skills: diff --git a/apps/api/api_runtime_http_methods.py b/apps/api/api_runtime_http_methods.py index 691e222..2bdb55f 100644 --- a/apps/api/api_runtime_http_methods.py +++ b/apps/api/api_runtime_http_methods.py @@ -3,7 +3,6 @@ from __future__ import annotations from dataclasses import replace -import re import shutil from typing import Any, Mapping from urllib.parse import unquote @@ -17,7 +16,10 @@ from packages.kernel import KernelSourceRequest, ReconciliationPipeline, StateReconciler from packages.operator.runtime import RecallEvidenceOperatorDetail from packages.runtime_layout import elephant_file_path -from packages.state import render_default_elephant_identity, write_elephant_identity_file +from packages.state import ( + render_default_elephant_identity, + write_elephant_identity_file, +) from .api_runtime_support import ( APILoopRecord, @@ -33,11 +35,11 @@ _cron_job_system_kind, _elephant_id_from_name, _cron_payload, - _cron_skill_ids, _cron_job_record, _read_wsgi_body, ) + def run_loop( self, episode_id: str, @@ -139,6 +141,8 @@ def run_loop( latest_loop=record, inspection=inspection, ) + + def dispatch(self, method: str, path: str, body: bytes | None = None) -> APIResponse: if method.upper() == "GET" and path == "/healthz": return APIResponse(200, {"status": "ok", "service": "elephant-api"}) @@ -168,6 +172,8 @@ def dispatch(self, method: str, path: str, body: bytes | None = None) -> APIResp return APIResponse(422, {"error": "configuration_required", "detail": str(error)}) except Exception as error: return APIResponse(500, {"error": "internal_error", "detail": str(error)}) + + def _unique_elephant_id(self, base_elephant_id: str) -> str: root = _elephant_id_from_name(base_elephant_id) elephant_id = root @@ -176,6 +182,8 @@ def _unique_elephant_id(self, base_elephant_id: str) -> str: elephant_id = f"{root}-{suffix}" suffix += 1 return elephant_id + + def _elephant_state_for_id(self, elephant_id: str): target = elephant_id.strip() if not target: @@ -183,7 +191,12 @@ def _elephant_state_for_id(self, elephant_id: str): direct = self.repository.load_state(f"state:{target}") if direct is not None: return direct - return next((state for state in self.repository.list_states() if state.elephant_id == target), None) + return next( + (state for state in self.repository.list_states() if state.elephant_id == target), + None, + ) + + def _default_elephant_identity_text(*, elephant_id: str, display_name: str, mode: str) -> str: """Seed identity text when none is supplied via the API. @@ -206,22 +219,34 @@ def _default_elephant_identity_text(*, elephant_id: str, display_name: str, mode charter, ) ) -def _elephant_identity_text_from_payload(payload: Mapping[str, Any], *, elephant_id: str, display_name: str, mode: str) -> str: - return ( - _optional_str(payload.get("elephant_identity_text") or payload.get("eggIdentityText") or payload.get("text") or payload.get("content")) - or _default_elephant_identity_text(elephant_id=elephant_id, display_name=display_name, mode=mode) - ) + + +def _elephant_identity_text_from_payload( + payload: Mapping[str, Any], *, elephant_id: str, display_name: str, mode: str +) -> str: + return _optional_str( + payload.get("elephant_identity_text") + or payload.get("eggIdentityText") + or payload.get("text") + or payload.get("content") + ) or _default_elephant_identity_text(elephant_id=elephant_id, display_name=display_name, mode=mode) + + def _write_elephant_identity_file(self, *, elephant_id: str, text: str) -> str: path = write_elephant_identity_file( elephant_file_path(elephant_id, install_root=self.config.install_root), text, ) return str(path) + + def _dispatch_elephants(self, method: str, parts: tuple[str, ...], body: bytes | None) -> APIResponse: normalized_method = method.upper() if normalized_method == "POST" and not parts: payload = _read_json_bytes(body) - display_name = str(payload.get("elephant_name") or payload.get("display_name") or payload.get("name") or "").strip() + display_name = str( + payload.get("elephant_name") or payload.get("display_name") or payload.get("name") or "" + ).strip() if not display_name: raise ValueError("display_name is required") raw_elephant_id = str(payload.get("elephant_id") or payload.get("eggId") or "").strip() @@ -234,7 +259,9 @@ def _dispatch_elephants(self, method: str, parts: tuple[str, ...], body: bytes | or payload.get("profile_id") or self.repository.ensure_default_personal_model().personal_model_id ).strip() - identity_text = _elephant_identity_text_from_payload(payload, elephant_id=elephant_id, display_name=display_name, mode=mode) + identity_text = _elephant_identity_text_from_payload( + payload, elephant_id=elephant_id, display_name=display_name, mode=mode + ) state = self.repository.create_state( personal_model_id=personal_model_id, state_id=f"state:{elephant_id}", @@ -250,7 +277,10 @@ def _dispatch_elephants(self, method: str, parts: tuple[str, ...], body: bytes | metadata={"profile_id": personal_model_id}, ) elephant_identity_path = _write_elephant_identity_file(self, elephant_id=elephant_id, text=identity_text) - return APIResponse(201, _jsonable({"elephant": state, "eggIdentityPath": elephant_identity_path})) + return APIResponse( + 201, + _jsonable({"elephant": state, "eggIdentityPath": elephant_identity_path}), + ) if len(parts) != 1: return APIResponse(404, {"error": "not_found"}) elephant_id = unquote(parts[0]).strip() @@ -261,12 +291,19 @@ def _dispatch_elephants(self, method: str, parts: tuple[str, ...], body: bytes | payload = _read_json_bytes(body) display_name = _optional_str(payload.get("elephant_name") or payload.get("display_name") or payload.get("name")) mode = _optional_str(payload.get("mode")) - identity_text = _optional_str(payload.get("elephant_identity_text") or payload.get("eggIdentityText") or payload.get("text") or payload.get("content")) + identity_text = _optional_str( + payload.get("elephant_identity_text") + or payload.get("eggIdentityText") + or payload.get("text") + or payload.get("content") + ) updated = replace( state, elephant_name=display_name or state.elephant_name, identity_mode=mode or state.identity_mode or "companion", - initiative=_optional_str(payload.get("initiative")) if payload.get("initiative") is not None else state.initiative, + initiative=_optional_str(payload.get("initiative")) + if payload.get("initiative") is not None + else state.initiative, working_style=( _optional_str(payload.get("personality_preset") or payload.get("working_style")) if payload.get("personality_preset") is not None or payload.get("working_style") is not None @@ -279,15 +316,34 @@ def _dispatch_elephants(self, method: str, parts: tuple[str, ...], body: bytes | self.repository.upsert_state(updated) elephant_identity_path = "" if identity_text is not None: - elephant_identity_path = _write_elephant_identity_file(self, elephant_id=updated.elephant_id, text=identity_text) - return APIResponse(200, _jsonable({"elephant": updated, "eggIdentityPath": elephant_identity_path})) + elephant_identity_path = _write_elephant_identity_file( + self, elephant_id=updated.elephant_id, text=identity_text + ) + return APIResponse( + 200, + _jsonable({"elephant": updated, "eggIdentityPath": elephant_identity_path}), + ) if normalized_method == "DELETE": episode_ids = tuple(episode.episode_id for episode in self.repository.list_episodes(state_id=state.state_id)) deleted_sessions = self.repository.delete_episodes(episode_ids, delete_orphaned_profiles=False) self.repository.delete_state(state.state_id) - shutil.rmtree(elephant_file_path(state.elephant_id, install_root=self.config.install_root), ignore_errors=True) - return APIResponse(200, _jsonable({"elephant_id": state.elephant_id, "deleted": True, "deleted_sessions": deleted_sessions})) + shutil.rmtree( + elephant_file_path(state.elephant_id, install_root=self.config.install_root), + ignore_errors=True, + ) + return APIResponse( + 200, + _jsonable( + { + "elephant_id": state.elephant_id, + "deleted": True, + "deleted_sessions": deleted_sessions, + } + ), + ) return APIResponse(404, {"error": "not_found"}) + + def _dispatch_episodes(self, method: str, parts: tuple[str, ...], body: bytes | None) -> APIResponse: if method.upper() == "POST" and len(parts) == 0: payload = _read_json_bytes(body) @@ -330,17 +386,38 @@ def _dispatch_episodes(self, method: str, parts: tuple[str, ...], body: bytes | if method.upper() == "GET" and len(parts) == 2 and parts[1] == "profile": inspection = self.inspect_episode(episode_id) return APIResponse(200, _jsonable({"personal_model": inspection.personal_model})) - if len(parts) == 2 and parts[1] in {"identity", "user", "relationship", "continuity"}: + if len(parts) == 2 and parts[1] in { + "identity", + "user", + "relationship", + "continuity", + }: episode = self.repository.load_episode(episode_id) if episode is None: raise KeyError(episode_id) return self._dispatch_states(method, (episode.state_id, parts[1]), body) if len(parts) == 2 and parts[1] == "recall": if method.upper() == "GET": - return APIResponse(200, _jsonable({"episode_id": episode_id, "recall": self.inspect_recall_evidence_surface(episode_id)})) + return APIResponse( + 200, + _jsonable( + { + "episode_id": episode_id, + "recall": self.inspect_recall_evidence_surface(episode_id), + } + ), + ) if len(parts) == 3 and parts[1] == "recall" and parts[2] == "evidence": if method.upper() == "GET": - return APIResponse(200, _jsonable({"episode_id": episode_id, "evidence": self.list_recall_evidence(episode_id)})) + return APIResponse( + 200, + _jsonable( + { + "episode_id": episode_id, + "evidence": self.list_recall_evidence(episode_id), + } + ), + ) if len(parts) == 3 and parts[1] == "recall" and parts[2] == "search": payload = _read_json_bytes(body) query = _optional_str(payload.get("query")) @@ -349,10 +426,12 @@ def _dispatch_episodes(self, method: str, parts: tuple[str, ...], body: bytes | limit = int(payload.get("limit", 5)) return APIResponse( 200, - _jsonable({ - "episode_id": episode_id, - "recall": self.search_recall_evidence_surface(episode_id, query=query, limit=limit), - }), + _jsonable( + { + "episode_id": episode_id, + "recall": self.search_recall_evidence_surface(episode_id, query=query, limit=limit), + } + ), ) if len(parts) == 3 and parts[1] == "recall": evidence_ref = parts[2] @@ -372,13 +451,23 @@ def _dispatch_episodes(self, method: str, parts: tuple[str, ...], body: bytes | ), ) return APIResponse(404, {"error": "not_found"}) + + def _dispatch_states(self, method: str, parts: tuple[str, ...], body: bytes | None) -> APIResponse: if len(parts) != 2: return APIResponse(404, {"error": "not_found"}) state_id, surface = parts if surface == "identity": if method.upper() == "GET": - return APIResponse(200, _jsonable({"state_id": state_id, "identity": self.inspect_identity(state_id=state_id)})) + return APIResponse( + 200, + _jsonable( + { + "state_id": state_id, + "identity": self.inspect_identity(state_id=state_id), + } + ), + ) if method.upper() in {"PATCH", "POST"}: payload = _read_json_bytes(body) result = self.update_identity_state( @@ -386,13 +475,21 @@ def _dispatch_states(self, method: str, parts: tuple[str, ...], body: bytes | No display_name=_optional_str(payload.get("display_name") or payload.get("name")), personality_preset=_optional_str(payload.get("personality_preset")), initiative=_optional_str(payload.get("initiative")), - elephant_identity_text=_optional_str(payload.get("elephant_identity_text") or payload.get("eggIdentityText") or payload.get("text") or payload.get("content")), + elephant_identity_text=_optional_str( + payload.get("elephant_identity_text") + or payload.get("eggIdentityText") + or payload.get("text") + or payload.get("content") + ), clear_elephant_identity=bool(payload.get("clear_elephant_identity", False)), ) return APIResponse(200, _jsonable({"state_id": state_id, "identity": result})) if surface == "user": if method.upper() == "GET": - return APIResponse(200, _jsonable({"state_id": state_id, "user": self.inspect_user(state_id=state_id)})) + return APIResponse( + 200, + _jsonable({"state_id": state_id, "user": self.inspect_user(state_id=state_id)}), + ) if method.upper() in {"PATCH", "POST"}: payload = _read_json_bytes(body) result = self.update_user_state( @@ -405,7 +502,15 @@ def _dispatch_states(self, method: str, parts: tuple[str, ...], body: bytes | No return APIResponse(200, _jsonable({"state_id": state_id, "user": result})) if surface == "relationship": if method.upper() == "GET": - return APIResponse(200, _jsonable({"state_id": state_id, "relationship": self.inspect_relationship(state_id=state_id)})) + return APIResponse( + 200, + _jsonable( + { + "state_id": state_id, + "relationship": self.inspect_relationship(state_id=state_id), + } + ), + ) if method.upper() in {"PATCH", "POST"}: payload = _read_json_bytes(body) result = self.update_relationship_state( @@ -418,6 +523,8 @@ def _dispatch_states(self, method: str, parts: tuple[str, ...], body: bytes | No if surface == "continuity" and method.upper() == "GET": return APIResponse(200, _jsonable(self.inspect_continuity(state_id).to_record())) return APIResponse(404, {"error": "not_found"}) + + def _dispatch_providers(self, method: str, parts: tuple[str, ...], body: bytes | None) -> APIResponse: if method.upper() == "GET" and len(parts) == 0: return APIResponse(200, _jsonable(self.list_providers())) @@ -456,6 +563,7 @@ def _dispatch_providers(self, method: str, parts: tuple[str, ...], body: bytes | return APIResponse(200, _jsonable(self.delete_provider_key(parts[1]))) return APIResponse(404, {"error": "not_found"}) + def _dispatch_internal(self, method: str, parts: tuple[str, ...], body: bytes | None) -> APIResponse: if method.upper() == "GET" and len(parts) == 2 and parts[0] == "dashboard": return APIResponse(200, {"dashboard": _jsonable(self.inspect_internal_dashboard(parts[1]))}) @@ -480,10 +588,14 @@ def _dispatch_internal(self, method: str, parts: tuple[str, ...], body: bytes | return APIResponse(200, _jsonable(result)) return APIResponse(404, {"error": "not_found"}) + def _dispatch_operator(self, method: str, parts: tuple[str, ...], body: bytes | None) -> APIResponse: if parts and parts[0] == "cron": if method.upper() == "GET" and len(parts) == 1: - return APIResponse(200, {"cron": {"jobs": [_cron_job_record(job) for job in self.cron_runtime.list_jobs()]}}) + return APIResponse( + 200, + {"cron": {"jobs": [_cron_job_record(job) for job in self.cron_runtime.list_jobs()]}}, + ) if method.upper() == "POST" and len(parts) == 1: payload = _read_json_bytes(body) job_payload = _cron_payload(payload) @@ -506,7 +618,10 @@ def _dispatch_operator(self, method: str, parts: tuple[str, ...], body: bytes | if job is None: raise ValueError(f"system cron job unavailable: {job_id}") return APIResponse(200, {"cron": {"job": job}}) - return APIResponse(200, {"cron": {"job": _cron_job_record(self.cron_runtime.inspect_job(job_id))}}) + return APIResponse( + 200, + {"cron": {"job": _cron_job_record(self.cron_runtime.inspect_job(job_id))}}, + ) if method.upper() == "PATCH": payload = _read_json_bytes(body) action = str(payload.get("action") or "").strip().lower() @@ -530,7 +645,10 @@ def _dispatch_operator(self, method: str, parts: tuple[str, ...], body: bytes | raise ValueError("cron PATCH requires action=pause or action=resume") if job is None: raise ValueError(f"system cron job unavailable: {job_id}") - return APIResponse(200, {"cron": {"job": job if isinstance(job, Mapping) else _cron_job_record(job)}}) + return APIResponse( + 200, + {"cron": {"job": job if isinstance(job, Mapping) else _cron_job_record(job)}}, + ) if method.upper() == "DELETE": if job_id == "system:proactive-ask": return APIResponse(403, {"error": "system_cron_jobs_cannot_be_deleted"}) @@ -586,9 +704,9 @@ def _dispatch_operator(self, method: str, parts: tuple[str, ...], body: bytes | ) return APIResponse(200, _jsonable(result)) return APIResponse(404, {"error": "not_found"}) -def _dispatch_personal_model( - self, method: str, parts: tuple[str, ...], body: bytes | None -) -> APIResponse: + + +def _dispatch_personal_model(self, method: str, parts: tuple[str, ...], body: bytes | None) -> APIResponse: """Operator-surface writes against Personal Model claims and questions. Routes: @@ -627,9 +745,21 @@ def _dispatch_personal_model( intensity = str(payload.get("learning_intensity") or "").strip().lower() if intensity in {"low", "medium", "high"} and not proactive_updates: _INTENSITY_MAP = { - "low": {"idle_threshold_minutes": 720, "daily_max": 2, "quiet_hours": [23, 7]}, - "medium": {"idle_threshold_minutes": 180, "daily_max": 8, "quiet_hours": [23, 7]}, - "high": {"idle_threshold_minutes": 60, "daily_max": 24, "quiet_hours": [1, 7]}, + "low": { + "idle_threshold_minutes": 720, + "daily_max": 2, + "quiet_hours": [23, 7], + }, + "medium": { + "idle_threshold_minutes": 180, + "daily_max": 8, + "quiet_hours": [23, 7], + }, + "high": { + "idle_threshold_minutes": 60, + "daily_max": 24, + "quiet_hours": [1, 7], + }, } proactive_updates = _INTENSITY_MAP[intensity] if not proactive_updates: @@ -655,9 +785,20 @@ def _dispatch_personal_model( return APIResponse(404, {"error": "question_not_found"}) bumped = replace(target, priority=min(1.0, max(target.priority, 0.85))) upsert(bumped) - return APIResponse(200, {"personal_model": {"question_id": question_id, "priority": bumped.priority}}) + return APIResponse( + 200, + { + "personal_model": { + "question_id": question_id, + "priority": bumped.priority, + } + }, + ) if action == "dismiss": - surface = PersonalModelUnderstandingSurface(repository=self.repository, semantic_summary_indexer=getattr(self, "semantic_summary_indexer", None)) + surface = PersonalModelUnderstandingSurface( + repository=self.repository, + semantic_summary_indexer=getattr(self, "semantic_summary_indexer", None), + ) result = surface.manage_personal_model_questions( str(payload.get("episode_id") or "dashboard"), action="dismiss", @@ -670,7 +811,10 @@ def _dispatch_personal_model( content = str(payload.get("content") or "").strip() if not content: raise ValueError("answer requires 'content'") - surface = PersonalModelUnderstandingSurface(repository=self.repository, semantic_summary_indexer=getattr(self, "semantic_summary_indexer", None)) + surface = PersonalModelUnderstandingSurface( + repository=self.repository, + semantic_summary_indexer=getattr(self, "semantic_summary_indexer", None), + ) result = surface.manage_personal_model_questions( str(payload.get("episode_id") or "dashboard"), action="answer", @@ -684,11 +828,26 @@ def _dispatch_personal_model( if normalized == "POST" and len(parts) >= 3 and parts[0] == "claims": claim_id = unquote(parts[1]).strip() action = parts[2].strip().lower() - if action not in {"correct", "forget", "dispute", "restore", "delete", "protect", "unprotect"}: + if action not in { + "correct", + "forget", + "dispute", + "restore", + "delete", + "protect", + "unprotect", + }: return APIResponse(404, {"error": "not_found"}) payload = _read_json_bytes(body) if body else {} - personal_model_id = str(payload.get("personal_model_id") or DEFAULT_PERSONAL_MODEL_ID).strip() or DEFAULT_PERSONAL_MODEL_ID - facts = tuple(self.repository.list_personal_model_facts(personal_model_id=personal_model_id, status=("active", "retired", "disputed") if action in {"restore", "delete"} else "active")) + personal_model_id = ( + str(payload.get("personal_model_id") or DEFAULT_PERSONAL_MODEL_ID).strip() or DEFAULT_PERSONAL_MODEL_ID + ) + facts = tuple( + self.repository.list_personal_model_facts( + personal_model_id=personal_model_id, + status=("active", "retired", "disputed") if action in {"restore", "delete"} else "active", + ) + ) target = next((fact for fact in facts if fact.fact_id == claim_id), None) if target is None: return APIResponse(404, {"error": "claim_not_found"}) @@ -713,11 +872,31 @@ def _dispatch_personal_model( } updated = replace(target, metadata=next_metadata) self.repository.upsert_personal_model_fact(updated) - return APIResponse(200, {"personal_model": {"action": action, "status": "active", "ref": claim_id, "claim": _serialize(updated)}}) + return APIResponse( + 200, + { + "personal_model": { + "action": action, + "status": "active", + "ref": claim_id, + "claim": _serialize(updated), + } + }, + ) if action == "delete": - from packages.understanding.personal_model_governance import is_protected_topic + from packages.understanding.personal_model_governance import ( + is_protected_topic, + ) + if is_protected_topic(str(metadata.get("topic") or ""), metadata): - return APIResponse(409, {"error": "protected_topic", "detail": "protected Personal Model topics must be unprotected before delete", "ref": claim_id}) + return APIResponse( + 409, + { + "error": "protected_topic", + "detail": "protected Personal Model topics must be unprotected before delete", + "ref": claim_id, + }, + ) now = _now() deleted = replace( target, @@ -749,11 +928,23 @@ def _dispatch_personal_model( }, ) ) - return APIResponse(200, {"personal_model": {"action": "delete", "status": "deleted", "ref": claim_id}}) + return APIResponse( + 200, + { + "personal_model": { + "action": "delete", + "status": "deleted", + "ref": claim_id, + } + }, + ) topic = str(payload.get("topic") or metadata.get("topic") or "").strip() if not topic: return APIResponse(409, {"error": "claim_missing_topic"}) - surface = PersonalModelUnderstandingSurface(repository=self.repository, semantic_summary_indexer=getattr(self, "semantic_summary_indexer", None)) + surface = PersonalModelUnderstandingSurface( + repository=self.repository, + semantic_summary_indexer=getattr(self, "semantic_summary_indexer", None), + ) result = surface.update_personal_model( str(payload.get("episode_id") or "dashboard"), action=action, @@ -769,16 +960,22 @@ def _dispatch_personal_model( return APIResponse(404, {"error": "not_found"}) + def _persist_proactive_ask_config(state_dir, updates: dict) -> None: try: from packages.runtime_config import ( - personal_model_question_config_from_global, global_config_path_for_state_dir, - load_global_config, write_global_config, + personal_model_question_config_from_global, + global_config_path_for_state_dir, + load_global_config, + write_global_config, ) + config_path = global_config_path_for_state_dir(state_dir) config = load_global_config(config_path, state_dir=state_dir) question_policy = personal_model_question_config_from_global(config) - proactive = question_policy.get("proactive_ask") if isinstance(question_policy.get("proactive_ask"), dict) else {} + proactive = ( + question_policy.get("proactive_ask") if isinstance(question_policy.get("proactive_ask"), dict) else {} + ) proactive.update(updates) question_policy["proactive_ask"] = proactive question_policy.pop("learning_intensity", None) @@ -787,6 +984,7 @@ def _persist_proactive_ask_config(state_dir, updates: dict) -> None: except Exception: # pragma: no cover return + def run_cron_job_now(self, job_id: str) -> dict[str, Any]: """Fire one cron job on demand and return its execution result. @@ -799,7 +997,10 @@ def run_cron_job_now(self, job_id: str) -> dict[str, Any]: from pathlib import Path as _Path from apps.cli.runtime import CliRuntime - from apps.gateway.cron_service import build_gateway_cron_delivery_callback, cron_execution_should_deliver + from apps.gateway.cron_service import ( + build_gateway_cron_delivery_callback, + cron_execution_should_deliver, + ) state_dir = _Path(str(self.repository.database_path.parent)) # Gateway and CLI share the same state dir and DB (`/herd`) — the @@ -840,6 +1041,7 @@ def run_cron_job_now(self, job_id: str) -> dict[str, Any]: } } + def __call__(self, environ: Mapping[str, Any], start_response: Any) -> list[bytes]: from .api_runtime_support import _json_bytes as encode_json diff --git a/apps/api/api_runtime_impl.py b/apps/api/api_runtime_impl.py index 55b52bb..534a4d6 100644 --- a/apps/api/api_runtime_impl.py +++ b/apps/api/api_runtime_impl.py @@ -1,36 +1,21 @@ """Programmatic API runtime implementation assembled from smaller method modules.""" - from __future__ import annotations -from dataclasses import asdict, dataclass, is_dataclass, replace -from datetime import UTC, datetime from pathlib import Path -import json -from typing import Any, Mapping -from uuid import uuid4 +from typing import Mapping from apps.provider_runtime import load_provider_profile from packages.runtime_config import global_config_path_for_state_dir from packages.models import SurfaceModelProviderCapability -from packages.auth import AuthProfile, PersistentAuthProfileStore +from packages.auth import PersistentAuthProfileStore from packages.context import ContextRuntime from packages.cron import CronRuntime -from packages.contracts import ( - ContextBundle, - EventEnvelope, - ExecutionResult, -) -from packages.contracts.runtime import PersonalModelRuntimeState, RecallEvidence -from packages.kernel import KernelDependencies, KernelOutcome, KernelService, KernelSourceRequest, ReconciliationPipeline, StateReconciler -from packages.evidence import RecallRuntime, SemanticSummaryIndexer, build_semantic_index_bundle -from packages.operator.runtime import ( - RecallEvidenceOperatorDetail, - RecallEvidenceSearchHit, - ProcedureOperatorDetail, - build_recall_evidence_operator_surface, - build_procedure_operator_surface, - build_profile_operator_surface, +from packages.kernel import KernelDependencies, KernelService +from packages.evidence import ( + RecallRuntime, + SemanticSummaryIndexer, + build_semantic_index_bundle, ) from packages.runtime_config import configured_external_skill_dirs, load_global_config from packages.runtime_layout import default_cron_dir, infer_install_root_from_state_dir @@ -54,7 +39,10 @@ build_tool_runtime, sync_custom_mcp_tools, ) -from packages.tools.adapters import DeliveryMessageSurfaceAdapter, StructuredClarifySurface +from packages.tools.adapters import ( + DeliveryMessageSurfaceAdapter, + StructuredClarifySurface, +) from packages.understanding import PersonalModelUnderstandingSurface from packages.tools.browser_backend import create_playwright_browser_backend from packages.tools.local_roots import default_local_allowed_roots @@ -67,7 +55,7 @@ APITelemetrySink, APIToolExecution, ) -from .state_runtime import APIContinuityInspection, APIStateService +from .state_runtime import APIStateService from .api_runtime_support import ( APIAppConfig, @@ -90,7 +78,12 @@ def _enabled_overrides(state_dir: Path, section: str) -> dict[str, bool]: """Load skill/extension override settings from config.yaml.""" - from packages.runtime_config import load_global_config, load_extensions_from_config, global_config_path_for_state_dir + from packages.runtime_config import ( + load_global_config, + load_extensions_from_config, + global_config_path_for_state_dir, + ) + manifest = {} try: config_path = global_config_path_for_state_dir(state_dir) @@ -116,13 +109,20 @@ def __init__(self, config: APIAppConfig) -> None: self.repository = RuntimeStorageRepository(config.database_path) self.repository.bootstrap() runtime_state_dir = self.repository.database_path.parent - _obs_cfg = load_global_config(global_config_path_for_state_dir(runtime_state_dir), state_dir=runtime_state_dir) + _obs_cfg = load_global_config( + global_config_path_for_state_dir(runtime_state_dir), + state_dir=runtime_state_dir, + ) from packages.observability import setup_from_config + setup_from_config(_obs_cfg, state_dir=str(runtime_state_dir)) install_root = config.install_root or infer_install_root_from_state_dir(runtime_state_dir) sync_builtin_skill_shelf(destination_root=install_root / "skills" / "builtin") self.profile_loader = ProfileLoader(install_root) - active_provider_profile = load_provider_profile(runtime_state_dir, config_path=global_config_path_for_state_dir(runtime_state_dir)) + active_provider_profile = load_provider_profile( + runtime_state_dir, + config_path=global_config_path_for_state_dir(runtime_state_dir), + ) active_provider_profile_id = None active_provider_id = None if active_provider_profile is not None: @@ -147,7 +147,9 @@ def __init__(self, config: APIAppConfig) -> None: loaded_profile = self.profile_loader.load() prompt_contract = build_prompt_contract(loaded_profile, prompt_mode="full") context_instruction_refs = prompt_contract.instruction_refs or config.instruction_refs - self.context_runtime = ContextRuntime(instruction_refs=context_instruction_refs, total_tokens=config.total_tokens) + self.context_runtime = ContextRuntime( + instruction_refs=context_instruction_refs, total_tokens=config.total_tokens + ) self.personal_state = APIStateService( repository=self.repository, recall_runtime=self.recall_runtime, @@ -203,6 +205,7 @@ def __init__(self, config: APIAppConfig) -> None: profile_loader=self.profile_loader, install_root=install_root, ) + def _resolve_elephant_state(elephant_id: str): resolved_elephant_id = elephant_id.strip() if resolved_elephant_id: @@ -232,6 +235,7 @@ def _tool_context_for_session(session_id: str, requester: ToolRequester | None) elephant_id=elephant_id, episode_id=episode.episode_id, ) + self.tool_runtime = build_tool_runtime( enabled_overrides=_enabled_overrides(runtime_state_dir, "tool_overrides"), dependencies=BuiltinToolDependencies( @@ -299,6 +303,7 @@ def _tool_context_for_session(session_id: str, requester: ToolRequester | None) ) self._loops: dict[str, list[APILoopRecord]] = {} + ElephantAPIApp.list_providers = _provider_methods.list_providers ElephantAPIApp.setup_provider = _provider_methods.setup_provider ElephantAPIApp.discover_provider_models = _provider_methods.discover_provider_models @@ -360,6 +365,7 @@ def _tool_context_for_session(session_id: str, requester: ToolRequester | None) ElephantAPIApp.run_proactive_ask_now = _cron_methods.run_proactive_ask_now ElephantAPIApp.__call__ = _http_methods.__call__ + def create_app( *, database_path: str | Path, @@ -376,6 +382,7 @@ def create_app( ) ) + __all__ = [ "APIAppConfig", "APIResponse", diff --git a/apps/api/api_runtime_internal_methods.py b/apps/api/api_runtime_internal_methods.py index b928f7b..d82a295 100644 --- a/apps/api/api_runtime_internal_methods.py +++ b/apps/api/api_runtime_internal_methods.py @@ -20,7 +20,7 @@ _tools, ) from .api_runtime_console_usage import normalize_token_usage_row -from .api_runtime_support import _jsonable, _now +from .api_runtime_support import _jsonable _INTERNAL_DASHBOARD_QUERY_CONTRACT = ( @@ -43,7 +43,10 @@ def _sort_items(items: tuple[Any, ...], *, id_field: str, time_field: str | None return tuple( sorted( items, - key=lambda item: (str(getattr(item, time_field) or ""), str(getattr(item, id_field))), + key=lambda item: ( + str(getattr(item, time_field) or ""), + str(getattr(item, id_field)), + ), reverse=True, ) ) @@ -163,11 +166,7 @@ def _usage_by_elephant(events: list[dict[str, Any]]) -> tuple[dict[str, Any], .. grouped: dict[str, dict[str, Any]] = {} for row in events: elephant_id = str( - row.get("eggId") - or row.get("elephant_id") - or row.get("profile_id") - or row.get("session_id") - or "unknown" + row.get("eggId") or row.get("elephant_id") or row.get("profile_id") or row.get("session_id") or "unknown" ) elephant_name = str(row.get("eggName") or row.get("elephant_name") or elephant_id) bucket = grouped.setdefault( @@ -193,7 +192,10 @@ def _usage_by_elephant(events: list[dict[str, Any]]) -> tuple[dict[str, Any], .. return tuple( sorted( grouped.values(), - key=lambda row: (_usage_int(row.get("totalTokens")), str(row.get("lastUsedAt") or "")), + key=lambda row: ( + _usage_int(row.get("totalTokens")), + str(row.get("lastUsedAt") or ""), + ), reverse=True, )[:50] ) @@ -323,7 +325,10 @@ def _learning_snapshot( list_jobs = getattr(self.repository, "list_learning_jobs", None) jobs = tuple(list_jobs(limit=500)) if callable(list_jobs) else () try: - from apps.learning_worker_runtime import load_learning_worker_record, learning_worker_is_running + from apps.learning_worker_runtime import ( + load_learning_worker_record, + learning_worker_is_running, + ) worker = dict(load_learning_worker_record(state_dir) or {}) worker.setdefault("running", learning_worker_is_running(state_dir)) @@ -369,7 +374,9 @@ def _learning_snapshot( } -def _operation_snapshot(self, *, active_provider: Mapping[str, Any], provider_doctor: Mapping[str, Any]) -> dict[str, Any]: +def _operation_snapshot( + self, *, active_provider: Mapping[str, Any], provider_doctor: Mapping[str, Any] +) -> dict[str, Any]: database_path = self.repository.database_path state_dir = database_path.parent settings = _settings(state_dir, database_path) @@ -400,10 +407,26 @@ def _operation_snapshot(self, *, active_provider: Mapping[str, Any], provider_do _PERSONAL_MODEL_LENSES = ( - ("identity", "Identity", "Who the person is — durable attributes: character, values, style, and body."), - ("world", "World", "What is around the person — environment: people, projects, tools, and places."), - ("pulse", "Pulse", "How the person is right now — current state: chapter, focus, mood, and blockers."), - ("journey", "Journey", "What the person has been through — accumulated experience: lessons, patterns, and decisions."), + ( + "identity", + "Identity", + "Who the person is — durable attributes: character, values, style, and body.", + ), + ( + "world", + "World", + "What is around the person — environment: people, projects, tools, and places.", + ), + ( + "pulse", + "Pulse", + "How the person is right now — current state: chapter, focus, mood, and blockers.", + ), + ( + "journey", + "Journey", + "What the person has been through — accumulated experience: lessons, patterns, and decisions.", + ), ) @@ -446,16 +469,18 @@ def _personal_model_lens_summaries( for key, label, description in _PERSONAL_MODEL_LENSES: lens_facts = tuple(facts_by_lens[key]) latest = _latest_time(lens_facts) - rows.append({ - "component_key": key, - "lens": key, - "label": label, - "description": description, - "claim_count": len(lens_facts), - "active_claim_count": len(lens_facts), - "latest_observation_at": latest, - "status": "active" if lens_facts else "empty", - }) + rows.append( + { + "component_key": key, + "lens": key, + "label": label, + "description": description, + "claim_count": len(lens_facts), + "active_claim_count": len(lens_facts), + "latest_observation_at": latest, + "status": "active" if lens_facts else "empty", + } + ) return tuple(rows) @@ -520,9 +545,17 @@ def _step_event_content(step: Any, source_payloads: Mapping[str, Mapping[str, An metadata = _step_metadata(step) event_type = _step_event_type(step) if event_type == "user_query": - return _metadata_text(metadata, "effective_user_query") or _metadata_text(metadata, "user_query") or _payload_ref_prompt(step, source_payloads) + return ( + _metadata_text(metadata, "effective_user_query") + or _metadata_text(metadata, "user_query") + or _payload_ref_prompt(step, source_payloads) + ) if event_type == "source_input": - return _metadata_text(metadata, "user_query") or _metadata_text(metadata, "raw_user_query") or _payload_ref_prompt(step, source_payloads) + return ( + _metadata_text(metadata, "user_query") + or _metadata_text(metadata, "raw_user_query") + or _payload_ref_prompt(step, source_payloads) + ) if event_type == "system_prompt": return _metadata_text(metadata, "system_prompt") or _metadata_text(metadata, "model_prompt") if event_type == "tool_call": @@ -579,7 +612,10 @@ def _runtime_traces( episode_loops = tuple( sorted( loops_by_episode.get(episode.episode_id, ()), - key=lambda item: (str(getattr(item, "started_at", "") or ""), str(getattr(item, "loop_id", ""))), + key=lambda item: ( + str(getattr(item, "started_at", "") or ""), + str(getattr(item, "loop_id", "")), + ), ) ) loop_rows = [] @@ -588,7 +624,10 @@ def _runtime_traces( loop_steps = tuple( sorted( steps_by_loop.get(loop.loop_id, ()), - key=lambda item: (int(getattr(item, "sequence", 0) or 0), str(getattr(item, "created_at", "") or "")), + key=lambda item: ( + int(getattr(item, "sequence", 0) or 0), + str(getattr(item, "created_at", "") or ""), + ), ) ) step_rows = tuple(_dashboard_step_row(step, source_payloads) for step in loop_steps) @@ -607,7 +646,16 @@ def _runtime_traces( from .api_runtime_internal_sections import inspect_internal_dashboard -from .api_runtime_internal_triggers import delete_diary_entry, trigger_diary_write, trigger_reflect_job +from .api_runtime_internal_triggers import ( + delete_diary_entry, + trigger_diary_write, + trigger_reflect_job, +) -__all__ = ["delete_diary_entry", "inspect_internal_dashboard", "trigger_diary_write", "trigger_reflect_job"] +__all__ = [ + "delete_diary_entry", + "inspect_internal_dashboard", + "trigger_diary_write", + "trigger_reflect_job", +] diff --git a/apps/api/api_runtime_internal_sections.py b/apps/api/api_runtime_internal_sections.py index e888067..9109c02 100644 --- a/apps/api/api_runtime_internal_sections.py +++ b/apps/api/api_runtime_internal_sections.py @@ -10,7 +10,10 @@ from packages.runtime_layout import elephant_file_path from packages.state import ELEPHANT_IDENTITY_FILENAME from packages.storage.repository_support import DEFAULT_PERSONAL_MODEL_ID -from packages.understanding.personal_model_governance import is_skill_affinity_topic, skill_affinity_index_id +from packages.understanding.personal_model_governance import ( + is_skill_affinity_topic, + skill_affinity_index_id, +) from .api_runtime_console import ( _cron_jobs, @@ -27,7 +30,6 @@ _canonical_usage, _connection, _dashboard_active_provider, - _dashboard_step_row, _learning_snapshot, _personal_model_lens_summaries, _now, @@ -50,6 +52,7 @@ "semantic_index_entries", } + def _count_rows(database_path: Path, table: str) -> int: if table not in _COUNT_TABLES: raise ValueError(f"Unsupported dashboard count table: {table}") @@ -59,17 +62,24 @@ def _count_rows(database_path: Path, table: str) -> int: row = connection.execute("SELECT COUNT(*) AS count FROM " + table).fetchone() return int(row["count"] if row is not None else 0) + def _read_optional_text(path: Path, *, max_chars: int = 20_000) -> str: try: return path.read_text(encoding="utf-8", errors="replace")[:max_chars].strip() except OSError: return "" + def _elephant_identity_file(elephant_id: str, *, install_root: Path | None, fallback_text: str = "") -> dict[str, Any]: try: elephant_root = elephant_file_path(elephant_id, install_root=install_root) except ValueError: - return {"eggId": elephant_id, "path": "", "exists": False, "text": fallback_text.strip()} + return { + "eggId": elephant_id, + "path": "", + "exists": False, + "text": fallback_text.strip(), + } path = elephant_root / ELEPHANT_IDENTITY_FILENAME return { "eggId": elephant_id, @@ -138,7 +148,13 @@ def _empty_dashboard(self, *, section: str, generated_at: str) -> dict[str, Any] "herd": (), "personal_models": (), "states": (), - "runtime": {"episodes": (), "loops": (), "steps": (), "episode_traces": (), "learning_jobs": ()}, + "runtime": { + "episodes": (), + "loops": (), + "steps": (), + "episode_traces": (), + "learning_jobs": (), + }, "learning": _empty_learning(), "evidence": { "semantic_index_entries": (), @@ -206,47 +222,55 @@ def _state_projection_rows( state_episodes = episode_map.get(state.state_id, ()) state_loops = tuple(loop for episode in state_episodes for loop in loop_map.get(episode.episode_id, ())) state_steps = tuple(step for loop in state_loops for step in step_map.get(loop.loop_id, ())) - state_index_entries = tuple(entry for entry in semantic_index_entries if getattr(entry, "state_id", None) == state.state_id) + state_index_entries = tuple( + entry for entry in semantic_index_entries if getattr(entry, "state_id", None) == state.state_id + ) is_current = bool(current_state is not None and current_state.state_id == state.state_id) - growth_state = repository.load_personal_model_growth(state.personal_model_id) if repository is not None else None + growth_state = ( + repository.load_personal_model_growth(state.personal_model_id) if repository is not None else None + ) growth = build_growth_snapshot(growth_state or default_growth_state(state.personal_model_id)) - elephant_rows.append({ - "elephant_id": state.elephant_id, - "elephant_name": state.elephant_name, - "state_id": state.state_id, - "personal_model_id": state.personal_model_id, - "profile_id": state.personal_model_id, - "status": state.status, - "current": is_current, - "level": growth.level, - "checkpoint_label": f"checkpoint {growth.level}", - "stage": growth.stage.display_name, - "stage_id": growth.stage.stage_id, - "progress_percent": growth.progress_percent, - "score_to_next_level": growth.score_to_next_level, - "identity_mode": state.identity_mode, - "initiative": state.initiative, - "working_style": state.working_style, - "summary": state.summary, - "current_context_note": state.current_context_note, - "elephant_identity_text": state.elephant_identity_text, - "elephant_identity_file": _elephant_identity_file( - state.elephant_id, - install_root=install_root, - fallback_text=state.elephant_identity_text, - ), - "updated_at": _serialize(state).get("updated_at"), - }) + elephant_rows.append( + { + "elephant_id": state.elephant_id, + "elephant_name": state.elephant_name, + "state_id": state.state_id, + "personal_model_id": state.personal_model_id, + "profile_id": state.personal_model_id, + "status": state.status, + "current": is_current, + "level": growth.level, + "checkpoint_label": f"checkpoint {growth.level}", + "stage": growth.stage.display_name, + "stage_id": growth.stage.stage_id, + "progress_percent": growth.progress_percent, + "score_to_next_level": growth.score_to_next_level, + "identity_mode": state.identity_mode, + "initiative": state.initiative, + "working_style": state.working_style, + "summary": state.summary, + "current_context_note": state.current_context_note, + "elephant_identity_text": state.elephant_identity_text, + "elephant_identity_file": _elephant_identity_file( + state.elephant_id, + install_root=install_root, + fallback_text=state.elephant_identity_text, + ), + "updated_at": _serialize(state).get("updated_at"), + } + ) state_payload = dict(_serialize(state)) state_payload["current_context_note"] = state.current_context_note - state_rows.append({ - **state_payload, - "current": is_current, - "episode_count": len(state_episodes), - "loop_count": len(state_loops), - "step_count": len(state_steps), - "semantic_index_entry_count": len(state_index_entries), - }) + state_rows.append( + { + **state_payload, + "current": is_current, + "episode_count": len(state_episodes), + "loop_count": len(state_loops), + "step_count": len(state_steps), + "semantic_index_entry_count": len(state_index_entries), + } + ) return elephant_rows, state_rows @@ -255,6 +279,7 @@ def _personal_model_dashboard_row(model: Any, repository: Any) -> dict[str, Any] personal_model_id = str(model.personal_model_id) # Derive user_profile directly from active PM facts from packages.state.profile_from_claims import derive_profile_from_claims + facts = _active_personal_model_facts(repository, personal_model_id) user_profile = derive_profile_from_claims(facts) if user_profile: @@ -292,30 +317,39 @@ def _personal_model_rows( } rows: list[dict[str, Any]] = [] for model in personal_models: - model_index_entries = tuple(entry for entry in semantic_index_entries if entry.personal_model_id == model.personal_model_id) + model_index_entries = tuple( + entry for entry in semantic_index_entries if entry.personal_model_id == model.personal_model_id + ) model_facts = _active_personal_model_facts(repository, str(model.personal_model_id)) - model_all_facts = _personal_model_facts(repository, str(model.personal_model_id), ("active", "retired", "disputed")) - rows.append({ - **_personal_model_dashboard_row(model, repository), - "state_count": len(states_by_personal_model.get(model.personal_model_id, ())), - "personal_model_fact_count": len(model_facts), - "semantic_index_entry_count": len(model_index_entries), - "states": tuple({ - "state_id": state.state_id, - "elephant_id": state.elephant_id, - "elephant_name": state.elephant_name, - "status": state.status, - "summary": state.summary, - "current_context_note": state.current_context_note, - "updated_at": _serialize(state).get("updated_at"), - } for state in states_by_personal_model.get(model.personal_model_id, ())), - "understanding_components": _personal_model_lens_summaries( - model_facts=model_facts, - ), - "personal_model_facts": tuple(_serialize(fact) for fact in model_facts), - "personal_model_all_facts": tuple(_serialize(fact) for fact in model_all_facts), - "semantic_index_entries": tuple(_serialize(entry) for entry in model_index_entries), - }) + model_all_facts = _personal_model_facts( + repository, str(model.personal_model_id), ("active", "retired", "disputed") + ) + rows.append( + { + **_personal_model_dashboard_row(model, repository), + "state_count": len(states_by_personal_model.get(model.personal_model_id, ())), + "personal_model_fact_count": len(model_facts), + "semantic_index_entry_count": len(model_index_entries), + "states": tuple( + { + "state_id": state.state_id, + "elephant_id": state.elephant_id, + "elephant_name": state.elephant_name, + "status": state.status, + "summary": state.summary, + "current_context_note": state.current_context_note, + "updated_at": _serialize(state).get("updated_at"), + } + for state in states_by_personal_model.get(model.personal_model_id, ()) + ), + "understanding_components": _personal_model_lens_summaries( + model_facts=model_facts, + ), + "personal_model_facts": tuple(_serialize(fact) for fact in model_facts), + "personal_model_all_facts": tuple(_serialize(fact) for fact in model_all_facts), + "semantic_index_entries": tuple(_serialize(entry) for entry in model_index_entries), + } + ) return rows @@ -323,7 +357,9 @@ def _basic_personal_model_rows(personal_models: tuple[Any, ...], *, repository: return tuple(_personal_model_dashboard_row(model, repository) for model in personal_models) -def _runtime_collections(self) -> tuple[tuple[Any, ...], tuple[Any, ...], tuple[Any, ...]]: +def _runtime_collections( + self, +) -> tuple[tuple[Any, ...], tuple[Any, ...], tuple[Any, ...]]: episodes = _sort_items(self.repository.list_episodes(), id_field="episode_id", time_field="started_at") loops = _sort_items(self.repository.list_loops(), id_field="loop_id", time_field="started_at") steps = _sort_items(self.repository.list_steps(), id_field="step_id", time_field="created_at") @@ -337,18 +373,33 @@ def _runtime_maps( loops: tuple[Any, ...], steps: tuple[Any, ...], ) -> tuple[dict[str, tuple[Any, ...]], dict[str, tuple[Any, ...]], dict[str, tuple[Any, ...]]]: - episodes_by_state = {state.state_id: tuple(episode for episode in episodes if episode.state_id == state.state_id) for state in states} - loops_by_episode = {episode.episode_id: tuple(loop for loop in loops if loop.episode_id == episode.episode_id) for episode in episodes} + episodes_by_state = { + state.state_id: tuple(episode for episode in episodes if episode.state_id == state.state_id) for state in states + } + loops_by_episode = { + episode.episode_id: tuple(loop for loop in loops if loop.episode_id == episode.episode_id) + for episode in episodes + } steps_by_loop = {loop.loop_id: tuple(step for step in steps if step.loop_id == loop.loop_id) for loop in loops} return episodes_by_state, loops_by_episode, steps_by_loop -def _episode_rows(episodes: tuple[Any, ...], loops_by_episode: Mapping[str, tuple[Any, ...]], steps_by_loop: Mapping[str, tuple[Any, ...]]) -> list[dict[str, Any]]: +def _episode_rows( + episodes: tuple[Any, ...], + loops_by_episode: Mapping[str, tuple[Any, ...]], + steps_by_loop: Mapping[str, tuple[Any, ...]], +) -> list[dict[str, Any]]: rows: list[dict[str, Any]] = [] for episode in episodes: episode_loops = loops_by_episode.get(episode.episode_id, ()) episode_steps = tuple(step for loop in episode_loops for step in steps_by_loop.get(loop.loop_id, ())) - rows.append({**_serialize(episode), "loop_count": len(episode_loops), "step_count": len(episode_steps)}) + rows.append( + { + **_serialize(episode), + "loop_count": len(episode_loops), + "step_count": len(episode_steps), + } + ) return rows @@ -356,8 +407,12 @@ def _loop_rows(loops: tuple[Any, ...], steps_by_loop: Mapping[str, tuple[Any, .. return [{**_serialize(loop), "step_count": len(steps_by_loop.get(loop.loop_id, ()))} for loop in loops] -def _semantic_index_health(semantic_index_entries: tuple[Any, ...], active_provider: Mapping[str, Any]) -> dict[str, Any]: - semantic_index_status = str(active_provider.get("embedding_bootstrap_status") or ("indexed" if semantic_index_entries else "empty")) +def _semantic_index_health( + semantic_index_entries: tuple[Any, ...], active_provider: Mapping[str, Any] +) -> dict[str, Any]: + semantic_index_status = str( + active_provider.get("embedding_bootstrap_status") or ("indexed" if semantic_index_entries else "empty") + ) return { "status": semantic_index_status, "entry_count": len(semantic_index_entries), @@ -399,7 +454,9 @@ def _provider_catalog_rows(self, active_provider: Mapping[str, Any]) -> tuple[di return tuple(rows) -def _operation_model_snapshot(self, *, active_provider: Mapping[str, Any], embedding_provider: Mapping[str, Any]) -> dict[str, Any]: +def _operation_model_snapshot( + self, *, active_provider: Mapping[str, Any], embedding_provider: Mapping[str, Any] +) -> dict[str, Any]: provider_keys = self.list_provider_keys() return { "activeProvider": _dashboard_active_provider(dict(active_provider)), @@ -445,7 +502,12 @@ def _learning_overview(self) -> dict[str, Any]: ) return { "worker": {}, - "summary": {**counts, "total": len(jobs), "active_job_id": None, "latest_completed_at": None}, + "summary": { + **counts, + "total": len(jobs), + "active_job_id": None, + "latest_completed_at": None, + }, "jobs": job_rows, } @@ -463,9 +525,15 @@ def _latest_episode_row(self, *, limit: int = 20) -> tuple[dict[str, Any], ...]: episodes = _sort_items(self.repository.list_episodes(), id_field="episode_id", time_field="started_at") if not episodes: return () - recent = episodes[:max(1, int(limit))] + recent = episodes[: max(1, int(limit))] return tuple( - {**_serialize(episode), "loop_count": 0, "step_count": 0, "loops": (), "timeline": ()} + { + **_serialize(episode), + "loop_count": 0, + "step_count": 0, + "loops": (), + "timeline": (), + } for episode in recent ) @@ -473,25 +541,38 @@ def _latest_episode_row(self, *, limit: int = 20) -> tuple[dict[str, Any], ...]: def _fill_overview(dashboard: dict[str, Any], self) -> None: database_path = self.repository.database_path states, current_state = _fill_states(dashboard, self) - personal_models = _sort_items(self.repository.list_personal_models(), id_field="personal_model_id", time_field="updated_at") + personal_models = _sort_items( + self.repository.list_personal_models(), + id_field="personal_model_id", + time_field="updated_at", + ) canonical_models = tuple( model for model in personal_models if model.personal_model_id == DEFAULT_PERSONAL_MODEL_ID )[:1] overview_target_models = canonical_models or personal_models[:1] current_personal_model_id = DEFAULT_PERSONAL_MODEL_ID - semantic_index_entries = self.repository.list_semantic_index_entries() if hasattr(self.repository, "list_semantic_index_entries") else () - overview_models = tuple(_personal_model_rows( - personal_models=overview_target_models, - states=states, - semantic_index_entries=tuple(semantic_index_entries), - repository=self.repository, - )) + semantic_index_entries = ( + self.repository.list_semantic_index_entries() if hasattr(self.repository, "list_semantic_index_entries") else () + ) + overview_models = tuple( + _personal_model_rows( + personal_models=overview_target_models, + states=states, + semantic_index_entries=tuple(semantic_index_entries), + repository=self.repository, + ) + ) active_provider = dict(self.model_provider.describe()) learning = _learning_overview(self) semantic_index_count = _count_rows(database_path, "semantic_index_entries") - provider_auth_states = self.repository.list_provider_auth_states() if hasattr(self.repository, "list_provider_auth_states") else () + provider_auth_states = ( + self.repository.list_provider_auth_states() if hasattr(self.repository, "list_provider_auth_states") else () + ) dashboard["personal_models"] = overview_models - dashboard["runtime"] = {**dashboard["runtime"], "episode_traces": _latest_episode_row(self)} + dashboard["runtime"] = { + **dashboard["runtime"], + "episode_traces": _latest_episode_row(self), + } dashboard["learning"] = learning dashboard["overview"] = { "counts": { @@ -512,21 +593,33 @@ def _fill_overview(dashboard: dict[str, Any], self) -> None: "current_state_id": current_state.state_id if current_state is not None else None, "current_personal_model_id": current_personal_model_id, "provider_status": str(active_provider.get("status") or "unknown"), - "semantic_index_status": str(active_provider.get("embedding_bootstrap_status") or ("indexed" if semantic_index_count else "empty")), + "semantic_index_status": str( + active_provider.get("embedding_bootstrap_status") or ("indexed" if semantic_index_count else "empty") + ), "note": "Overview fetches counts, current elephant, current PersonalModel identity, and latest Episode summary only.", } def _fill_personal_models(dashboard: dict[str, Any], self) -> None: states, _ = _state_collections(self) - personal_models = _sort_items(self.repository.list_personal_models(), id_field="personal_model_id", time_field="updated_at") + personal_models = _sort_items( + self.repository.list_personal_models(), + id_field="personal_model_id", + time_field="updated_at", + ) canonical_models = tuple(model for model in personal_models if model.personal_model_id == DEFAULT_PERSONAL_MODEL_ID) - dashboard["personal_models"] = tuple(_personal_model_rows( - personal_models=canonical_models or personal_models[:1], - states=states, - semantic_index_entries=_sort_items(self.repository.list_semantic_index_entries(), id_field="semantic_index_entry_id", time_field="updated_at"), - repository=self.repository, - )) + dashboard["personal_models"] = tuple( + _personal_model_rows( + personal_models=canonical_models or personal_models[:1], + states=states, + semantic_index_entries=_sort_items( + self.repository.list_semantic_index_entries(), + id_field="semantic_index_entry_id", + time_field="updated_at", + ), + repository=self.repository, + ) + ) def _fill_runtime(dashboard: dict[str, Any], self) -> None: @@ -543,8 +636,14 @@ def _fill_runtime(dashboard: dict[str, Any], self) -> None: for loop in recent_loops_tuple: recent_steps.extend(self.repository.list_steps(loop_id=loop.loop_id)) recent_steps_tuple = tuple(recent_steps) - loops_by_episode = {ep.episode_id: tuple(loop for loop in recent_loops_tuple if loop.episode_id == ep.episode_id) for ep in recent_episodes} - steps_by_loop = {loop.loop_id: tuple(step for step in recent_steps_tuple if step.loop_id == loop.loop_id) for loop in recent_loops_tuple} + loops_by_episode = { + ep.episode_id: tuple(loop for loop in recent_loops_tuple if loop.episode_id == ep.episode_id) + for ep in recent_episodes + } + steps_by_loop = { + loop.loop_id: tuple(step for step in recent_steps_tuple if step.loop_id == loop.loop_id) + for loop in recent_loops_tuple + } elephant_rows, state_rows = _state_projection_rows( states, current_state=current_state, @@ -555,7 +654,12 @@ def _fill_runtime(dashboard: dict[str, Any], self) -> None: dashboard["states"] = tuple(state_rows) dashboard["runtime"] = { "episodes": tuple(_episode_rows(all_episodes, loops_by_episode, steps_by_loop)), - "episode_traces": _runtime_traces(episodes=recent_episodes, loops_by_episode=loops_by_episode, steps_by_loop=steps_by_loop, source_payloads={}), + "episode_traces": _runtime_traces( + episodes=recent_episodes, + loops_by_episode=loops_by_episode, + steps_by_loop=steps_by_loop, + source_payloads={}, + ), "learning_jobs": (), } @@ -586,20 +690,42 @@ def _fill_chat(dashboard: dict[str, Any], self) -> None: ) personal_models = tuple( model - for model in _sort_items(self.repository.list_personal_models(), id_field="personal_model_id", time_field="updated_at") + for model in _sort_items( + self.repository.list_personal_models(), + id_field="personal_model_id", + time_field="updated_at", + ) if model.personal_model_id == DEFAULT_PERSONAL_MODEL_ID ) dashboard["herd"] = tuple(elephant_rows) dashboard["states"] = tuple(state_rows) dashboard["personal_models"] = _basic_personal_model_rows(personal_models, repository=self.repository) - dashboard["runtime"] = {**dashboard["runtime"], "episode_traces": _runtime_traces(episodes=episodes, loops_by_episode=loops_by_episode, steps_by_loop=steps_by_loop, source_payloads={})} - dashboard["overview"] = {**dashboard["overview"], "current_state_id": current_state.state_id if current_state is not None else None, "current_personal_model_id": current_state.personal_model_id if current_state is not None else DEFAULT_PERSONAL_MODEL_ID} + dashboard["runtime"] = { + **dashboard["runtime"], + "episode_traces": _runtime_traces( + episodes=episodes, + loops_by_episode=loops_by_episode, + steps_by_loop=steps_by_loop, + source_payloads={}, + ), + } + dashboard["overview"] = { + **dashboard["overview"], + "current_state_id": current_state.state_id if current_state is not None else None, + "current_personal_model_id": current_state.personal_model_id + if current_state is not None + else DEFAULT_PERSONAL_MODEL_ID, + } def _fill_evidence(dashboard: dict[str, Any], self) -> None: # Steps, Episodes, and Facts own evidence. This section only exposes the # shared semantic index that makes those rows searchable. - semantic_index_entries = _sort_items(self.repository.list_semantic_index_entries(), id_field="semantic_index_entry_id", time_field="updated_at") + semantic_index_entries = _sort_items( + self.repository.list_semantic_index_entries(), + id_field="semantic_index_entry_id", + time_field="updated_at", + ) active_provider = dict(self.model_provider.describe()) dashboard["evidence"] = { "semantic_index_entries": tuple(_serialize(entry) for entry in semantic_index_entries), @@ -665,7 +791,10 @@ def _fill_questions(dashboard: dict[str, Any], self) -> None: # Coverage grid — one row per (lens, facet) with durable Fact counts. by_key: dict[tuple[str, str], dict[str, Any]] = {} for fact in facts: - key = (str(getattr(fact, "lens", "") or ""), str(getattr(fact, "facet", "") or "")) + key = ( + str(getattr(fact, "lens", "") or ""), + str(getattr(fact, "facet", "") or ""), + ) row = by_key.setdefault(key, {"lens": key[0], "facet": key[1], "facts": 0}) row["facts"] += 1 for row in by_key.values(): @@ -713,6 +842,7 @@ def _dashboard_question_config(repository) -> dict[str, Any]: global_config_path_for_state_dir, load_global_config, ) + state_dir = repository.database_path.parent config = load_global_config( global_config_path_for_state_dir(state_dir), @@ -761,7 +891,9 @@ def _fill_providers(dashboard: dict[str, Any], self) -> None: } dashboard["operations"] = { **dashboard["operations"], - "models": _operation_model_snapshot(self, active_provider=active_provider, embedding_provider=embedding_provider), + "models": _operation_model_snapshot( + self, active_provider=active_provider, embedding_provider=embedding_provider + ), } @@ -779,7 +911,7 @@ def _skill_affinity_rows(self) -> tuple[dict[str, Any], ...]: return () personal_models = _sort_items(list_models(), id_field="personal_model_id", time_field="updated_at") canonical_models = tuple(model for model in personal_models if model.personal_model_id == DEFAULT_PERSONAL_MODEL_ID) - target_model = (canonical_models or personal_models[:1]) + target_model = canonical_models or personal_models[:1] if not target_model: return () personal_model_id = str(target_model[0].personal_model_id) @@ -817,7 +949,10 @@ def _skill_affinity_rows(self) -> tuple[dict[str, Any], ...]: return tuple( sorted( slots.values(), - key=lambda row: (-int(row["activeCount"]), str(row["skillId"] or row["indexId"] or row["topic"])), + key=lambda row: ( + -int(row["activeCount"]), + str(row["skillId"] or row["indexId"] or row["topic"]), + ), ) ) @@ -858,7 +993,10 @@ def _fill_gateway(dashboard: dict[str, Any], self) -> None: def _fill_cron(dashboard: dict[str, Any], self) -> None: _fill_states(dashboard, self) - dashboard["operations"] = {**dashboard["operations"], "cron": {"jobs": tuple(_cron_jobs(self))}} + dashboard["operations"] = { + **dashboard["operations"], + "cron": {"jobs": tuple(_cron_jobs(self))}, + } def _fill_settings(dashboard: dict[str, Any], self) -> None: @@ -867,11 +1005,17 @@ def _fill_settings(dashboard: dict[str, Any], self) -> None: def _fill_usage(dashboard: dict[str, Any], self) -> None: - dashboard["operations"] = {**dashboard["operations"], "usage": _canonical_usage(self.repository.database_path)} + dashboard["operations"] = { + **dashboard["operations"], + "usage": _canonical_usage(self.repository.database_path), + } def _fill_logs(dashboard: dict[str, Any], self) -> None: - dashboard["operations"] = {**dashboard["operations"], "logs": tuple(_logs(self.repository.database_path.parent))} + dashboard["operations"] = { + **dashboard["operations"], + "logs": tuple(_logs(self.repository.database_path.parent)), + } def _fill_diary(dashboard: dict[str, Any], self) -> None: @@ -938,4 +1082,9 @@ def inspect_internal_dashboard(self, section: str) -> dict[str, Any]: return dashboard -__all__ = ["DASHBOARD_SECTIONS", "inspect_internal_dashboard", "trigger_diary_write", "trigger_reflect_job"] +__all__ = [ + "DASHBOARD_SECTIONS", + "inspect_internal_dashboard", + "trigger_diary_write", + "trigger_reflect_job", +] diff --git a/apps/api/api_runtime_internal_triggers.py b/apps/api/api_runtime_internal_triggers.py index 2c9f520..5328734 100644 --- a/apps/api/api_runtime_internal_triggers.py +++ b/apps/api/api_runtime_internal_triggers.py @@ -22,6 +22,7 @@ def trigger_diary_write(self, *, target_date: str) -> dict[str, Any]: try: # Attempt to enqueue journal job from datetime import datetime + target = datetime.strptime(target_date.strip()[:10], "%Y-%m-%d").date() metadata["target_date"] = target.isoformat() except (ValueError, AttributeError): @@ -57,7 +58,11 @@ def delete_diary_entry(self, *, entry_date: str) -> dict[str, Any]: personal_model_id=pm.personal_model_id, entry_date=target, ) - return {"status": "deleted" if deleted else "not_found", "entry_date": target, "deleted": deleted} + return { + "status": "deleted" if deleted else "not_found", + "entry_date": target, + "deleted": deleted, + } def trigger_reflect_job(self, *, trigger: str, features: str | None = None) -> dict[str, Any]: @@ -102,7 +107,12 @@ def trigger_reflect_job(self, *, trigger: str, features: str | None = None) -> d ensure_learning_worker_running(state_dir=self.repository.database_path.parent) except Exception: pass - return {"status": "queued", "job_id": job.job_id, "trigger": trigger or "manual", "features": features} + return { + "status": "queued", + "job_id": job.job_id, + "trigger": trigger or "manual", + "features": features, + } __all__ = ["delete_diary_entry", "trigger_diary_write", "trigger_reflect_job"] diff --git a/apps/api/api_runtime_provider_methods.py b/apps/api/api_runtime_provider_methods.py index f024d4a..6b9d577 100644 --- a/apps/api/api_runtime_provider_methods.py +++ b/apps/api/api_runtime_provider_methods.py @@ -1,11 +1,8 @@ """Provider methods for the API runtime app.""" - from __future__ import annotations -from dataclasses import asdict, dataclass, is_dataclass, replace -from pathlib import Path -import json +from dataclasses import asdict, replace from typing import Any, Mapping from uuid import uuid4 @@ -21,59 +18,21 @@ OPENAI_COMPATIBLE_EMBED_SECRET_REFERENCE_ID, default_local_embedding_provider_config, ) -from packages.models import SurfaceModelProviderCapability -from packages.auth import AuthProfile, PersistentAuthProfileStore, SecretReference -from packages.context import ContextRuntime +from packages.auth import AuthProfile, SecretReference from packages.contracts import ( ContextBundle, Episode, - EventEnvelope, ExecutionResult, ) -from packages.contracts.runtime import PersonalModelRuntimeState, RecallEvidence -from packages.kernel import KernelDependencies, KernelOutcome, KernelService, KernelSourceRequest, ReconciliationPipeline, StateReconciler -from packages.evidence.recall_runtime import RecallRuntime -from packages.operator.runtime import ( - RecallEvidenceOperatorDetail, - RecallEvidenceSearchHit, - ProcedureOperatorDetail, - build_recall_evidence_operator_surface, - build_procedure_operator_surface, - build_profile_operator_surface, -) -from packages.storage import RuntimeStorageRepository -from packages.runtime_config import global_config_path_for_state_dir, save_provider_to_config -from packages.tools import BuiltinToolDependencies, build_tool_runtime -from packages.tools.adapters import DeliveryMessageSurfaceAdapter, StructuredClarifySurface -from packages.tools.browser_backend import create_playwright_browser_backend - -from .capabilities import ( - APIContextCapability, - APIDeliveryCapability, - APIRecallCapability, - APIModelProvider, - APITelemetrySink, - APIToolExecution, +from packages.contracts.runtime import PersonalModelRuntimeState +from packages.runtime_config import ( + global_config_path_for_state_dir, + save_provider_to_config, ) -from .state_runtime import APIContinuityInspection, APIStateService + from .api_runtime_support import ( - APIAppConfig, - APIResponse, - APIEpisodeCreationResult, - APIEpisodeInspection, - APIEpisodeTransitionResult, - APILoopRecord, - APILoopResult, - _coerce_str_tuple, - _json_bytes, - _jsonable, _now, - _optional_bool, - _optional_datetime, - _optional_str, - _read_json_bytes, - _split_path, ) _EMBEDDING_API_KEY_ENV_VAR = OPENAI_COMPATIBLE_EMBED_DEFAULT_SECRET_ENV_VAR @@ -108,6 +67,7 @@ def list_providers(self) -> dict[str, Any]: "providers": providers, } + def setup_provider(self, provider_id: str) -> dict[str, Any]: guide = self.model_provider.runtime_resolver.build_setup_guide(provider_id) return { @@ -115,6 +75,7 @@ def setup_provider(self, provider_id: str) -> dict[str, Any]: "guide": guide.as_mapping(), } + def discover_provider_models(self, payload: Mapping[str, Any]) -> dict[str, Any]: provider_id = str(payload.get("providerId") or payload.get("provider_id") or "").strip() if not provider_id: @@ -133,6 +94,7 @@ def discover_provider_models(self, payload: Mapping[str, Any]) -> dict[str, Any] "models": [asdict(model) for model in models], } + def _metadata_context_window_tokens(metadata: Mapping[str, str]) -> int | None: raw_value = metadata.get("context_window_tokens") if raw_value is None: @@ -189,6 +151,7 @@ def set_default_provider(self, provider_profile: Mapping[str, Any]) -> dict[str, "active_provider": self.model_provider.describe(), } + def _provider_probe( self, *, @@ -227,6 +190,7 @@ def _provider_probe( prompt=prompt, ) + def test_provider(self, *, prompt: str = "Summarize the current provider configuration.") -> dict[str, Any]: active_provider = self.model_provider.describe() try: @@ -243,6 +207,7 @@ def test_provider(self, *, prompt: str = "Summarize the current provider configu "result": result, } + def doctor_provider(self) -> dict[str, Any]: active_provider = self.model_provider.describe() bootstrap_check = { @@ -372,7 +337,10 @@ def _stored_api_key_for_active_provider(self, provider_id: str) -> str | None: active_profile = self.model_provider.active_profile() if active_profile is None or active_profile.provider_id != provider_id: return None - reference = next((item for item in active_profile.secret_references if item.secret_key == "api_key"), None) + reference = next( + (item for item in active_profile.secret_references if item.secret_key == "api_key"), + None, + ) if reference is None or not self.repository.has_auth_secret_value(reference.reference_id): return None credentials = self.model_provider.resolve_credentials(active_profile) @@ -384,7 +352,10 @@ def embedding_provider_summary(self) -> dict[str, Any]: active_provider = dict(self.model_provider.describe()) profile = self._active_embedding_provider_profile() if profile is not None: - reference = next((item for item in profile.secret_references if item.secret_key == "api_key"), None) + reference = next( + (item for item in profile.secret_references if item.secret_key == "api_key"), + None, + ) reference_id = reference.reference_id if reference is not None else "" has_secret = bool(reference_id) and self.repository.has_auth_secret_value(reference_id) return { @@ -543,7 +514,9 @@ def create_provider_key(self, payload: Mapping[str, Any]) -> dict[str, Any]: provider_id = str(payload.get("providerId") or payload.get("provider_id") or "").strip() secret_key = str(payload.get("secretKey") or payload.get("secret_key") or "api_key").strip() secret_name = str(payload.get("secretName") or payload.get("secret_name") or "api_token").strip() - reference_id = str(payload.get("referenceId") or payload.get("reference_id") or f"secret:{profile_id}:{secret_key}").strip() + reference_id = str( + payload.get("referenceId") or payload.get("reference_id") or f"secret:{profile_id}:{secret_key}" + ).strip() if not profile_id or not provider_id or not reference_id: raise ValueError("profileId, providerId, and referenceId are required") profile = self.repository.load_auth_profile(profile_id) diff --git a/apps/api/api_runtime_surface_methods.py b/apps/api/api_runtime_surface_methods.py index d95a545..545185e 100644 --- a/apps/api/api_runtime_surface_methods.py +++ b/apps/api/api_runtime_surface_methods.py @@ -9,16 +9,16 @@ from apps.provider_runtime import provider_profile_from_payload from packages.contracts import Episode, State from packages.contracts.runtime import RecallEvidence, PersonalModelRuntimeState -from packages.evidence.recall_runtime import RecallRuntime from packages.growth import ProgressionProjectionBuilder -from packages.kernel.episode_state_machine import open_next_episode as _open_next_episode +from packages.kernel.episode_state_machine import ( + open_next_episode as _open_next_episode, +) from packages.state.persistence import resolve_runtime_state from packages.storage.repository_support import canonical_personal_model_id from packages.operator.runtime import ( RecallEvidenceOperatorDetail, RecallEvidenceSearchHit, ProcedureOperatorDetail, - build_canonical_procedure_detail, build_recall_evidence_operator_surface, ) @@ -27,10 +27,7 @@ APIEpisodeInspection, APIEpisodeLifecycleResult, APIEpisodeTransitionResult, - APILoopRecord, - APILoopResult, _now, - _optional_str, ) from .state_runtime import APIContinuityInspection @@ -75,7 +72,10 @@ def _ensure_episode_state( identity_mode=personal_model.mode, surface_bindings=("api",), summary=f"{personal_model.display_name} is ready for API-bound continuity.", - metadata={"personal_model_id": personal_model.profile_id, "episode_id": episode.episode_id}, + metadata={ + "personal_model_id": personal_model.profile_id, + "episode_id": episode.episode_id, + }, ) else: state = replace( @@ -84,7 +84,11 @@ def _ensure_episode_state( identity_mode=personal_model.mode, state_anchor=state_anchor, surface_bindings=tuple(sorted({*existing.surface_bindings, "api"})), - metadata={**dict(existing.metadata), "personal_model_id": personal_model.profile_id, "episode_id": episode.episode_id}, + metadata={ + **dict(existing.metadata), + "personal_model_id": personal_model.profile_id, + "episode_id": episode.episode_id, + }, ) self.repository.upsert_state(state) self.repository.switch_state(state.state_id) @@ -127,7 +131,9 @@ def create_episode( ) resolved_elephant_id = elephant_id or personal_model_id timestamp = _now() - resolved_state_id = f"state:{resolved_elephant_id}" if resolved_elephant_id else f"state:{personal_model.profile_id}:api" + resolved_state_id = ( + f"state:{resolved_elephant_id}" if resolved_elephant_id else f"state:{personal_model.profile_id}:api" + ) episode = Episode( episode_id=episode_id or uuid4().hex, state_id=resolved_state_id, @@ -354,7 +360,11 @@ def search_recall_evidence_surface(self, episode_id: str, *, query: str, limit: for evidence in self.list_recall_evidence(episode_id) ) hits = tuple( - RecallEvidenceSearchHit(evidence=candidate.evidence, score=candidate.score, reasons=candidate.reasons) + RecallEvidenceSearchHit( + evidence=candidate.evidence, + score=candidate.score, + reasons=candidate.reasons, + ) for candidate in retrieval.candidates ) return build_recall_evidence_operator_surface( diff --git a/apps/api/state_runtime.py b/apps/api/state_runtime.py index 919aa28..5050474 100644 --- a/apps/api/state_runtime.py +++ b/apps/api/state_runtime.py @@ -7,7 +7,10 @@ from packages.continuity import ContinuityProjection, ContinuityProjectionService from packages.contracts import ElephantIdentityRecord, Episode -from packages.state.rendered_views import RenderedRelationshipView, RenderedUserProfileView +from packages.state.rendered_views import ( + RenderedRelationshipView, + RenderedUserProfileView, +) from packages.contracts.runtime import PersonalModelRuntimeState from packages.evidence.recall_runtime import RecallRuntime from packages.state import ( @@ -22,7 +25,10 @@ resolve_runtime_state, sync_canonical_profile_state, ) -from packages.state.projection import build_loaded_profile_from_state, render_user_profile_projection_text +from packages.state.projection import ( + build_loaded_profile_from_state, + render_user_profile_projection_text, +) from packages.state.user_updates import apply_user_profile_update from packages.storage import RuntimeStorageRepository from packages.storage.repository_support import canonical_personal_model_id @@ -89,7 +95,9 @@ def ensure_personal_model_state( elephant_id=elephant_id, required=False, ) - resolved_elephant_id = elephant_id or (resolved_state.elephant_id if resolved_state is not None and resolved_state.elephant_id else None) + resolved_elephant_id = elephant_id or ( + resolved_state.elephant_id if resolved_state is not None and resolved_state.elephant_id else None + ) persisted = load_persisted_canonical_state(self.repository, canonical_personal_model.profile_id) if ( persisted.elephant_identity is not None @@ -231,12 +239,17 @@ def update_identity_state( ) bundle = build_canonical_profile_state( loaded, - elephant_id=resolved_state.elephant_id if resolved_state is not None and resolved_state.elephant_id else None, + elephant_id=resolved_state.elephant_id + if resolved_state is not None and resolved_state.elephant_id + else None, ) synced = sync_canonical_profile_state( self.repository, bundle, - previous=load_persisted_canonical_state(self.repository, canonical_personal_model_id(updated_personal_model.profile_id)), + previous=load_persisted_canonical_state( + self.repository, + canonical_personal_model_id(updated_personal_model.profile_id), + ), sync_source="api.identity.update", recall_runtime=self.recall_runtime, surface="api", @@ -284,12 +297,16 @@ def update_user_state( ) bundle = build_canonical_profile_state( replace(loaded, user_profile_text=render_user_profile_projection_text(next_user)), - elephant_id=resolved_state.elephant_id if resolved_state is not None and resolved_state.elephant_id else None, + elephant_id=resolved_state.elephant_id + if resolved_state is not None and resolved_state.elephant_id + else None, ) synced = sync_canonical_profile_state( self.repository, bundle, - previous=load_persisted_canonical_state(self.repository, canonical_personal_model_id(personal_model.profile_id)), + previous=load_persisted_canonical_state( + self.repository, canonical_personal_model_id(personal_model.profile_id) + ), sync_source="api.user.update", recall_runtime=self.recall_runtime, surface="api", @@ -344,12 +361,16 @@ def update_relationship_state( next_notes = current_notes bundle = build_canonical_profile_state( replace(loaded, companion=replace(companion, notes=next_notes)), - elephant_id=resolved_state.elephant_id if resolved_state is not None and resolved_state.elephant_id else None, + elephant_id=resolved_state.elephant_id + if resolved_state is not None and resolved_state.elephant_id + else None, ) synced = sync_canonical_profile_state( self.repository, bundle, - previous=load_persisted_canonical_state(self.repository, canonical_personal_model_id(personal_model.profile_id)), + previous=load_persisted_canonical_state( + self.repository, canonical_personal_model_id(personal_model.profile_id) + ), sync_source="api.relationship.update", recall_runtime=self.recall_runtime, surface="api", diff --git a/apps/cli/cli_main_elephant_support.py b/apps/cli/cli_main_elephant_support.py index 185ba3d..a4ee527 100644 --- a/apps/cli/cli_main_elephant_support.py +++ b/apps/cli/cli_main_elephant_support.py @@ -2,48 +2,17 @@ from __future__ import annotations -import argparse -from dataclasses import dataclass -import os -import random -import re -import sys -from collections.abc import Iterable, Mapping -from pathlib import Path +from collections.abc import Mapping -from packages.state import DEFAULT_ELEPHANT_IDENTITY_TEXT, render_default_elephant_identity from .runtime import CliRuntime from .turn_metrics import cache_hit_metric_line -from .provider_flow import ( - ProviderSelectionState, - provider_choices as _shared_provider_choices, - provider_setup_defaults, - run_provider_selection_wizard, -) -from .shell import ( - Align, - BRAND_ACCENT, - BRAND_LIGHT, - BRAND_MUTED, - Console, - Group, - Panel, - ProductizedShell, - RICH_AVAILABLE, - Table, - Text, - _resolve_elephant_version, - render_elephant_mark, -) from .wizard import ( WIZARD_BACK, WizardChoice, - _WizardBackSignal, _interactive_shell_supported, _wizard_choice_prompt, _wizard_dialogs_supported, - _wizard_text_prompt, ) DEFAULT_PROVIDER_ID = "openai-compatible" @@ -84,7 +53,6 @@ CLI_THEME_SUBTITLE = "shaped from you · alive between sessions." - from .cli_main_support import * # noqa: F401,F403 @@ -142,6 +110,7 @@ def _select_elephant(runtime: CliRuntime, elephant_id: str): runtime.repository.switch_state(elephant_state.state_id) return session + def _print_doctor(runtime: CliRuntime, *, deep: bool = False) -> None: provider = runtime.provider_doctor(deep=deep) security = runtime.security_doctor() @@ -158,16 +127,14 @@ def _print_doctor(runtime: CliRuntime, *, deep: bool = False) -> None: f"active_provider_embedding_ready · {_embedding_bootstrap_ready_label(active.get('embedding_bootstrap_status'))}", ) provider_checks = tuple( - f"{check['check']} · {check['status']}{f' · {check['summary']}' if check.get('summary') else ''}" + f"{check['check']} · {check['status']}{f' · {check["summary"]}' if check.get('summary') else ''}" for check in provider["checks"] ) security_checks = tuple( - f"{check['check']} · {check['status']}{f' · {check['summary']}' if check.get('summary') else ''}" + f"{check['check']} · {check['status']}{f' · {check["summary"]}' if check.get('summary') else ''}" for check in security["checks"] ) - extra_lines = ( - (f"probe_summary · {provider['probe_summary']}",) if provider["probe_summary"] else () - ) + extra_lines = (f"probe_summary · {provider['probe_summary']}",) if provider["probe_summary"] else () sections = [CliCardSection("Readiness", status_lines)] embedding_status_lines = _embedding_bootstrap_status_lines(embedding) if embedding_status_lines: @@ -192,6 +159,7 @@ def _print_doctor(runtime: CliRuntime, *, deep: bool = False) -> None: else ("elephant init",), ) + def _print_elephant_created(runtime: CliRuntime, session_id: str) -> None: session = runtime.inspect_session(session_id) elephant_id = runtime.elephant_id_for_session(session) @@ -211,9 +179,14 @@ def _print_elephant_created(runtime: CliRuntime, session_id: str) -> None: tuple(ready_lines), ), ), - next_commands=("elephant wake", f"elephant wake --elephant-id {elephant_id}", "elephant herd"), + next_commands=( + "elephant wake", + f"elephant wake --elephant-id {elephant_id}", + "elephant herd", + ), ) + def _print_elephant_paused() -> None: _print_cli_card( "Elephant Agent elephant paused", @@ -221,6 +194,7 @@ def _print_elephant_paused() -> None: next_commands=("elephant herd new ", "elephant wake", "elephant herd"), ) + def _print_herd(runtime: CliRuntime) -> None: herd = runtime.list_herd(limit=24) current_session = _current_elephant_session(runtime) @@ -251,13 +225,18 @@ def _print_herd(runtime: CliRuntime) -> None: ), ) + def _print_current_elephant(runtime: CliRuntime) -> None: session = _current_elephant_session(runtime) if session is None: _print_cli_card( "Current elephant", "No current elephant has been selected yet.", - next_commands=("elephant herd", "elephant herd use ", "elephant wake"), + next_commands=( + "elephant herd", + "elephant herd use ", + "elephant wake", + ), ) return elephant_id = runtime.elephant_id_for_session(session) @@ -276,7 +255,11 @@ def _print_current_elephant(runtime: CliRuntime) -> None: ), ), ), - next_commands=(f"elephant wake --elephant-id {elephant_id}", "elephant wake", "elephant herd"), + next_commands=( + f"elephant wake --elephant-id {elephant_id}", + "elephant wake", + "elephant herd", + ), ) @@ -302,6 +285,7 @@ def _print_elephant_selected(runtime: CliRuntime, elephant_id: str) -> None: next_commands=("elephant wake", "elephant herd current", "elephant herd"), ) + def _print_elephant_retired(elephant_id: str, deleted_sessions: int) -> None: _print_cli_card( "Elephant retired", @@ -319,6 +303,7 @@ def _print_elephant_retired(elephant_id: str, deleted_sessions: int) -> None: next_commands=("elephant herd", "elephant herd new ", "elephant wake"), ) + def _print_elephant_retire_paused() -> None: _print_cli_card( "Elephant retire paused", @@ -326,6 +311,7 @@ def _print_elephant_retire_paused() -> None: next_commands=("elephant herd", "elephant wake", "elephant herd new "), ) + def _print_all_herd_retired(deleted_elephants: int, deleted_sessions: int) -> None: _print_cli_card( "All herd retired", @@ -343,6 +329,7 @@ def _print_all_herd_retired(deleted_elephants: int, deleted_sessions: int) -> No next_commands=("elephant herd new ", "elephant init", "elephant status"), ) + def _prompt_elephant_choice( runtime: CliRuntime, herd, @@ -407,6 +394,7 @@ def _prompt_elephant_choice( return elephant print(" enter an elephant number or elephant id from the list above") + def _open_growth_episode( runtime: CliRuntime, *, @@ -453,14 +441,18 @@ def open_next(selected): return opened.episode_id, f"Opened elephant {selected.elephant_id}" if current is not None: opened = open_next(current) - return opened.episode_id, f"Opened elephant {runtime.elephant_id_for_session(current)}" + return ( + opened.episode_id, + f"Opened elephant {runtime.elephant_id_for_session(current)}", + ) raise ValueError("multiple herd are available; pass --elephant-id or enter wake from an interactive TTY") + def _print_elephant_blocked(runtime: CliRuntime) -> None: report = runtime.provider_doctor() provider = report["provider"] checks = tuple( - f"{check['check']} · {check['status']}{f' · {check['summary']}' if check.get('summary') else ''}" + f"{check['check']} · {check['status']}{f' · {check["summary"]}' if check.get('summary') else ''}" for check in report["checks"] ) sections = [ @@ -482,11 +474,12 @@ def _print_elephant_blocked(runtime: CliRuntime) -> None: next_commands=("elephant init", "elephant status"), ) + def _print_grow_blocked(runtime: CliRuntime) -> None: report = runtime.provider_doctor() provider = report["provider"] checks = tuple( - f"{check['check']} · {check['status']}{f' · {check['summary']}' if check.get('summary') else ''}" + f"{check['check']} · {check['status']}{f' · {check["summary"]}' if check.get('summary') else ''}" for check in report["checks"] ) sections = [ @@ -508,19 +501,14 @@ def _print_grow_blocked(runtime: CliRuntime) -> None: next_commands=("elephant init", "elephant status"), ) + def _provider_session_ready(report: dict[str, object]) -> bool: raw_checks = tuple(report.get("checks", ())) if not raw_checks: return str(report.get("status", "")).strip().lower() == "ready" - checks = { - str(check.get("check")): str(check.get("status")) - for check in raw_checks - if isinstance(check, dict) - } - return ( - checks.get("provider_profile") == "configured" - and checks.get("credentials") in {"available", "not-required"} - ) + checks = {str(check.get("check")): str(check.get("status")) for check in raw_checks if isinstance(check, dict)} + return checks.get("provider_profile") == "configured" and checks.get("credentials") in {"available", "not-required"} + def _print_no_elephants() -> None: _print_cli_card( @@ -529,6 +517,7 @@ def _print_no_elephants() -> None: next_commands=("elephant herd new ", "elephant status"), ) + def _print_assistant_turn(runtime: CliRuntime, outcome, *, title: str = "Elephant Agent turn") -> None: provider = dict(runtime.provider_summary()) lines = [ @@ -577,9 +566,14 @@ def _print_provider_turn_failed( title, "The provider failed before the Loop completed.", sections=(CliCardSection("Recovery state", tuple(lines)),), - next_commands=("elephant provider status", "elephant status", "elephant wake --message \"...\""), + next_commands=( + "elephant provider status", + "elephant status", + 'elephant wake --message "..."', + ), ) + __all__ = [ "DEFAULT_PROVIDER_ID", "DEFAULT_ELEPHANT_NAME_SUGGESTIONS", diff --git a/apps/cli/cli_main_impl.py b/apps/cli/cli_main_impl.py index 77e08a1..9bad511 100644 --- a/apps/cli/cli_main_impl.py +++ b/apps/cli/cli_main_impl.py @@ -3,9 +3,7 @@ from __future__ import annotations import argparse -from dataclasses import dataclass import os -import random import re import subprocess import sys @@ -15,12 +13,15 @@ import typer -from packages.state import DEFAULT_ELEPHANT_IDENTITY_TEXT, render_default_elephant_identity, render_user_profile_text +from packages.state import ( + DEFAULT_ELEPHANT_IDENTITY_TEXT, + render_default_elephant_identity, + render_user_profile_text, +) from .runtime import CliRuntime from .provider_flow import ( ProviderSelectionState, - provider_choices as _shared_provider_choices, provider_setup_defaults, run_provider_selection_wizard, ) @@ -34,10 +35,8 @@ Panel, ProductizedShell, RICH_AVAILABLE, - Table, Text, _resolve_elephant_version, - render_stage_zero_elephant_mark, ) from .wizard import ( WIZARD_BACK, @@ -46,7 +45,6 @@ _WizardBackSignal, _interactive_shell_supported, _wizard_choice_prompt, - _wizard_dialogs_supported, _wizard_multi_choice_prompt, _wizard_text_prompt, ) @@ -89,7 +87,6 @@ CLI_THEME_SUBTITLE = "Personal Model first, curious at your pace." - from .cli_main_elephant_support import * # noqa: F401,F403 from .cli_main_elephant_support import _current_elephant_session from .cli_main_setup import * # noqa: F401,F403 @@ -103,8 +100,16 @@ def _prompt_first_elephant_name( language: str = "en", ) -> str | _WizardBackSignal: return _wizard_text_prompt( - _init_text(language, "Name Your First Elephant Agent", "给你的第一个 Elephant Agent 起名"), - _init_text(language, "This first Elephant Agent is yours. What name feels right?", "这是你的第一个 Elephant Agent。哪个名字最合适?"), + _init_text( + language, + "Name Your First Elephant Agent", + "给你的第一个 Elephant Agent 起名", + ), + _init_text( + language, + "This first Elephant Agent is yours. What name feels right?", + "这是你的第一个 Elephant Agent。哪个名字最合适?", + ), default=default_name, allow_back=allow_back, ) @@ -119,24 +124,40 @@ def _prompt_learning_intensity( """Let the user choose how often Elephant Agent may ask Personal Model questions.""" return _wizard_choice_prompt( _init_text(language, "Elephant Agent's Questions", "Elephant Agent 的问题频率"), - _init_text(language, "How often should Elephant Agent ask open questions to learn more about you?", "Elephant Agent 可以多频繁地问开放问题来更了解你?"), + _init_text( + language, + "How often should Elephant Agent ask open questions to learn more about you?", + "Elephant Agent 可以多频繁地问开放问题来更了解你?", + ), ( WizardChoice( value="low", label=_init_text(language, "Quiet questions", "安静提问"), - detail=_init_text(language, "Low touch. Up to two open questions per day, usually morning or before bed.", "低频打扰。每天最多两次,通常偏早晨或睡前。"), + detail=_init_text( + language, + "Low touch. Up to two open questions per day, usually morning or before bed.", + "低频打扰。每天最多两次,通常偏早晨或睡前。", + ), emoji="🌙", ), WizardChoice( value="medium", label=_init_text(language, "Gentle questions", "温和提问"), - detail=_init_text(language, "Default. If an IM route is running, asks after roughly 3 idle hours.", "默认。如果 IM 通道在线,空闲约 3 小时后会问一个问题。"), + detail=_init_text( + language, + "Default. If an IM route is running, asks after roughly 3 idle hours.", + "默认。如果 IM 通道在线,空闲约 3 小时后会问一个问题。", + ), emoji="🌿", ), WizardChoice( value="high", label=_init_text(language, "Active questions", "积极提问"), - detail=_init_text(language, "Most active. Outside quiet hours, an IM route may ask once an elephant has been idle for 1 hour.", "最主动。静默时间外,如果 IM 通道在线,elephant 空闲 1 小时后就可以主动问。"), + detail=_init_text( + language, + "Most active. Outside quiet hours, an IM route may ask once an elephant has been idle for 1 hour.", + "最主动。静默时间外,如果 IM 通道在线,elephant 空闲 1 小时后就可以主动问。", + ), emoji="⚡", ), ), @@ -164,8 +185,16 @@ def _prompt_first_language(default: str = "en", *, allow_back: bool = False) -> "First language / 第一语言", "Choose the language Elephant Agent should use for the rest of init.", ( - WizardChoice(value="en", label="English", detail="Use English for init and store English as your first language."), - WizardChoice(value="zh", label="中文", detail="后续初始化过程使用中文,并把中文记录为你的第一语言。"), + WizardChoice( + value="en", + label="English", + detail="Use English for init and store English as your first language.", + ), + WizardChoice( + value="zh", + label="中文", + detail="后续初始化过程使用中文,并把中文记录为你的第一语言。", + ), ), default=_normalize_first_language(default), allow_back=allow_back, @@ -200,7 +229,11 @@ def _prompt_required_text( default: str = "", allow_back: bool = True, ) -> str | _WizardBackSignal: - required = _init_text(language, "Please add a little something here before continuing.", "这里需要写一点内容,才能继续。") + required = _init_text( + language, + "Please add a little something here before continuing.", + "这里需要写一点内容,才能继续。", + ) while True: answer = _wizard_text_prompt( _init_text(language, title_en, title_zh), @@ -312,7 +345,11 @@ def _prompt_hobbies(language: str, default: str = "", *, allow_back: bool = True existing = tuple(part.strip() for part in re.split(r"[,,、/]+", default or "") if part.strip()) answer = _wizard_multi_choice_prompt( _init_text(language, "Personal hobbies", "个人爱好"), - _init_text(language, "Optional. Use Space to select any hobbies Elephant Agent should know.", "可选。用空格多选你希望 Elephant Agent 知道的个人爱好。"), + _init_text( + language, + "Optional. Use Space to select any hobbies Elephant Agent should know.", + "可选。用空格多选你希望 Elephant Agent 知道的个人爱好。", + ), tuple(_init_wizard_choice(choice) for choice in choices), default_values=existing, allow_back=allow_back, @@ -326,33 +363,131 @@ def _prompt_hobbies(language: str, default: str = "", *, allow_back: bool = True _ATTENTION_CHOICES_EN = ( - ("a project wants to move", "A project wants to move", "Work, product, writing, craft, or something you want to bring into shape.", "🚀", "Primary attention is on moving a concrete project or piece of work forward; prioritize momentum, blockers, completion pressure, and output rhythm."), - ("standing at a fork", "Standing at a fork", "Changing direction, deciding, leaving, or beginning a new road.", "🧭", "Currently in transition and choice, possibly changing direction, deciding, leaving an old path, or beginning a new one; prioritize trade-offs, risks, what is hard to leave, and reversible next steps."), - ("chewing on a new question", "Chewing on a new question", "Reading, studying, testing ideas, or trying to understand something important.", "🔎", "Drawn to a new question and forming judgment through study, research, or testing; prioritize structure, key assumptions, evidence, and the next round of exploration."), - ("relationships are tugging", "Relationships are tugging", "Family, friends, intimacy, distance, care, or where you belong among people.", "🤝", "Attention is being pulled by relationships, belonging, or social position; include distance, care, promises, boundaries, and emotional safety in the frame."), - ("body needs attention first", "Body needs attention first", "Sleep, health, rhythm, pressure, stamina, or recovery may need to be seen first.", "🌿", "Body, energy, and recovery rhythm need attention first; consider sleep, pressure, stamina, safety, and restoration before pushing intensity."), - ("steady the life floor", "Steady the life floor", "Home, money, routines, logistics, or making ordinary life hold you again.", "🏠", "Basic life stability needs to come first, including home, money, routines, logistics, or real-world order; prioritize structure, certainty, and low-friction arrangements that hold daily life."), + ( + "a project wants to move", + "A project wants to move", + "Work, product, writing, craft, or something you want to bring into shape.", + "🚀", + "Primary attention is on moving a concrete project or piece of work forward; prioritize momentum, blockers, completion pressure, and output rhythm.", + ), + ( + "standing at a fork", + "Standing at a fork", + "Changing direction, deciding, leaving, or beginning a new road.", + "🧭", + "Currently in transition and choice, possibly changing direction, deciding, leaving an old path, or beginning a new one; prioritize trade-offs, risks, what is hard to leave, and reversible next steps.", + ), + ( + "chewing on a new question", + "Chewing on a new question", + "Reading, studying, testing ideas, or trying to understand something important.", + "🔎", + "Drawn to a new question and forming judgment through study, research, or testing; prioritize structure, key assumptions, evidence, and the next round of exploration.", + ), + ( + "relationships are tugging", + "Relationships are tugging", + "Family, friends, intimacy, distance, care, or where you belong among people.", + "🤝", + "Attention is being pulled by relationships, belonging, or social position; include distance, care, promises, boundaries, and emotional safety in the frame.", + ), + ( + "body needs attention first", + "Body needs attention first", + "Sleep, health, rhythm, pressure, stamina, or recovery may need to be seen first.", + "🌿", + "Body, energy, and recovery rhythm need attention first; consider sleep, pressure, stamina, safety, and restoration before pushing intensity.", + ), + ( + "steady the life floor", + "Steady the life floor", + "Home, money, routines, logistics, or making ordinary life hold you again.", + "🏠", + "Basic life stability needs to come first, including home, money, routines, logistics, or real-world order; prioritize structure, certainty, and low-friction arrangements that hold daily life.", + ), ("type", "None fit; I’ll write one", "Write one short phrase instead", "✍️"), ) _ATTENTION_CHOICES_ZH = ( - ("一件作品正在往前推", "一件作品正在往前推", "像是有件东西正在手里发热,想被认真推到前面去。可能是项目、产品、写作、作品,或任何你希望它慢慢成形的事。", "🚀", "最近的主要注意力在推进一个具体作品或项目;优先关注推进节奏、阻力、完成欲和产出压力。"), - ("正站在一个岔路口", "正站在一个岔路口", "像站在一条路将要分开的地方,心里已经知道不能一直停在原处。可能关于换方向、做决定、离开,或开始一段新路。", "🧭", "最近处在过渡和选择中,可能正在考虑换方向、做决定、离开原来的路径或开始新路;优先澄清取舍、风险、舍不得的东西和可逆的下一步。"), - ("在啃一个新问题", "在啃一个新问题", "有个问题一直在脑海里发亮,想被读懂、拆开、验证。可能是学习、研究、准备,或理解某件重要的事。", "🔎", "最近被一个新问题吸引,正在通过学习、研究或验证来形成判断;优先整理问题结构、关键假设、证据和下一轮探索。"), - ("关系和归属感在拉扯", "关系和归属感在拉扯", "有些牵挂来自人和人之间的位置:靠近、距离、照顾、承诺,或不知道自己该站在哪里。", "🤝", "最近的注意力被关系、归属感或人际位置牵动;距离、照顾、承诺、边界和情感安全都需要一起纳入判断。"), - ("身体和精力先要照顾", "身体和精力先要照顾", "身体像先举了一下手,提醒你慢一点。睡眠、健康、节奏、压力、体力或恢复,可能比别的事更需要被看见。", "🌿", "最近首先需要照顾身体、精力和恢复节奏;先考虑睡眠、压力、体力、安全感和节奏修复,再谈更高强度的推进。"), - ("先把生活地基稳住", "先把生活地基稳住", "像先把房间的灯打开、地面扫平,让生活重新能托住你。可能关于住处、金钱、日程、杂事,或现实里的秩序。", "🏠", "最近需要先稳定生活基础,包括住处、金钱、日程、杂事或现实秩序;优先关注能承托日常的结构、确定性和低摩擦安排。"), + ( + "一件作品正在往前推", + "一件作品正在往前推", + "像是有件东西正在手里发热,想被认真推到前面去。可能是项目、产品、写作、作品,或任何你希望它慢慢成形的事。", + "🚀", + "最近的主要注意力在推进一个具体作品或项目;优先关注推进节奏、阻力、完成欲和产出压力。", + ), + ( + "正站在一个岔路口", + "正站在一个岔路口", + "像站在一条路将要分开的地方,心里已经知道不能一直停在原处。可能关于换方向、做决定、离开,或开始一段新路。", + "🧭", + "最近处在过渡和选择中,可能正在考虑换方向、做决定、离开原来的路径或开始新路;优先澄清取舍、风险、舍不得的东西和可逆的下一步。", + ), + ( + "在啃一个新问题", + "在啃一个新问题", + "有个问题一直在脑海里发亮,想被读懂、拆开、验证。可能是学习、研究、准备,或理解某件重要的事。", + "🔎", + "最近被一个新问题吸引,正在通过学习、研究或验证来形成判断;优先整理问题结构、关键假设、证据和下一轮探索。", + ), + ( + "关系和归属感在拉扯", + "关系和归属感在拉扯", + "有些牵挂来自人和人之间的位置:靠近、距离、照顾、承诺,或不知道自己该站在哪里。", + "🤝", + "最近的注意力被关系、归属感或人际位置牵动;距离、照顾、承诺、边界和情感安全都需要一起纳入判断。", + ), + ( + "身体和精力先要照顾", + "身体和精力先要照顾", + "身体像先举了一下手,提醒你慢一点。睡眠、健康、节奏、压力、体力或恢复,可能比别的事更需要被看见。", + "🌿", + "最近首先需要照顾身体、精力和恢复节奏;先考虑睡眠、压力、体力、安全感和节奏修复,再谈更高强度的推进。", + ), + ( + "先把生活地基稳住", + "先把生活地基稳住", + "像先把房间的灯打开、地面扫平,让生活重新能托住你。可能关于住处、金钱、日程、杂事,或现实里的秩序。", + "🏠", + "最近需要先稳定生活基础,包括住处、金钱、日程、杂事或现实秩序;优先关注能承托日常的结构、确定性和低摩擦安排。", + ), ("type", "都不像,我写一句", "如果上面都不贴切,可以写一个短句", "✍️"), ) _MBTI_EMOJI = { - "INTJ": "♟️", "INTP": "🧩", "ENTJ": "🧭", "ENTP": "⚡", - "INFJ": "🌙", "INFP": "🌿", "ENFJ": "🌻", "ENFP": "✨", - "ISTJ": "📚", "ISFJ": "🕯️", "ESTJ": "🏗️", "ESFJ": "🤝", - "ISTP": "🛠️", "ISFP": "🎨", "ESTP": "🏃", "ESFP": "🎉", + "INTJ": "♟️", + "INTP": "🧩", + "ENTJ": "🧭", + "ENTP": "⚡", + "INFJ": "🌙", + "INFP": "🌿", + "ENFJ": "🌻", + "ENFP": "✨", + "ISTJ": "📚", + "ISFJ": "🕯️", + "ESTJ": "🏗️", + "ESFJ": "🤝", + "ISTP": "🛠️", + "ISFP": "🎨", + "ESTP": "🏃", + "ESFP": "🎉", } _MBTI_CODES = ( - "INTJ", "INTP", "ENTJ", "ENTP", "INFJ", "INFP", "ENFJ", "ENFP", - "ISTJ", "ISFJ", "ESTJ", "ESFJ", "ISTP", "ISFP", "ESTP", "ESFP", + "INTJ", + "INTP", + "ENTJ", + "ENTP", + "INFJ", + "INFP", + "ENFJ", + "ENFP", + "ISTJ", + "ISFJ", + "ESTJ", + "ESFJ", + "ISTP", + "ISFP", + "ESTP", + "ESFP", ) _MBTI_TRAITS_EN = { "INTJ": "Architect: imaginative, strategic, private, and long-range; prefers clear plans, competence, and room to think independently", @@ -424,12 +559,37 @@ def _mbti_choices(language: str = "en") -> tuple[tuple[str, ...], ...]: ("music", "Music", "Listening, playing, collecting, or live shows", "🎧"), ("films and shows", "Films / shows", "Movies, series, anime, documentaries", "🎬"), ("games", "Games", "Video games, board games, puzzles, or playful systems", "🎮"), - ("sports and movement", "Sports / movement", "Gym, running, climbing, dancing, walking", "🏃"), - ("food and cooking", "Food / cooking", "Eating, cooking, baking, coffee, restaurants", "🍳"), - ("travel and city walks", "Travel / city walks", "Exploring places, routes, neighborhoods, trips", "🧳"), - ("art and design", "Art / design", "Drawing, photography, visual taste, making things beautiful", "🎨"), + ( + "sports and movement", + "Sports / movement", + "Gym, running, climbing, dancing, walking", + "🏃", + ), + ( + "food and cooking", + "Food / cooking", + "Eating, cooking, baking, coffee, restaurants", + "🍳", + ), + ( + "travel and city walks", + "Travel / city walks", + "Exploring places, routes, neighborhoods, trips", + "🧳", + ), + ( + "art and design", + "Art / design", + "Drawing, photography, visual taste, making things beautiful", + "🎨", + ), ("writing", "Writing", "Journaling, essays, fiction, notes, scripts", "✍️"), - ("technology and making", "Technology / making", "Coding, gadgets, tools, building small systems", "🛠️"), + ( + "technology and making", + "Technology / making", + "Coding, gadgets, tools, building small systems", + "🛠️", + ), ("skip", "Skip", "Leave this blank for now", "➖"), ) _HOBBY_CHOICES_ZH = ( @@ -457,11 +617,20 @@ def _mbti_choices(language: str = "en") -> tuple[tuple[str, ...], ...]: "hobbies": {"lens": "identity", "topic": "identity.style.hobbies.personal"}, "city": {"lens": "world", "topic": "world.places.city.current"}, "food_allergies": {"lens": "identity", "topic": "identity.body.allergy.food"}, - "medication_allergies": {"lens": "identity", "topic": "identity.body.allergy.medication"}, - "chronic_conditions": {"lens": "identity", "topic": "identity.body.condition.chronic"}, + "medication_allergies": { + "lens": "identity", + "topic": "identity.body.allergy.medication", + }, + "chronic_conditions": { + "lens": "identity", + "topic": "identity.body.condition.chronic", + }, "trauma_history": {"lens": "identity", "topic": "identity.body.history.trauma"}, "safety_boundaries": {"lens": "identity", "topic": "identity.body.safety.boundary"}, - "inferred_companion_posture": {"lens": "identity", "topic": "identity.style.companion.posture"}, + "inferred_companion_posture": { + "lens": "identity", + "topic": "identity.style.companion.posture", + }, } @@ -473,18 +642,66 @@ def _mbti_choices(language: str = "en") -> tuple[tuple[str, ...], ...]: "en": "If your recent inner weather were an image, which one is closest?", "zh": "如果把你现在的内心状态想象成一种风景,会是什么样的?", "choices_en": ( - ("standing in fog", "Standing in fog", "Not lost, but the horizon has not opened yet; reflect context first, then clarify the next visible step", "🌫️", "Not completely lost, but visibility and direction are not open yet; first confirm the ground underfoot, then gently clarify the next visible step."), - ("tabs open everywhere", "Tabs open everywhere", "Many thoughts are running in the background; help gather, order, and reduce cognitive load", "🗂️", "Many thoughts or unfinished tasks are open at once; help gather, order, and reduce cognitive load."), - ("boat resting in harbor", "Boat resting in harbor", "Pausing at shore before setting out again; allow recovery before asking for motion", "⚓", "In a pause, repair, or harboring phase before setting out again; do not push too quickly, allow replenishment and rhythm to return."), - ("small light ahead", "Small light ahead", "Direction is faint but present; protect the signal and test forward gradually", "🕯️", "A faint but meaningful direction is already visible; protect that signal and use small experiments to make the path clearer."), + ( + "standing in fog", + "Standing in fog", + "Not lost, but the horizon has not opened yet; reflect context first, then clarify the next visible step", + "🌫️", + "Not completely lost, but visibility and direction are not open yet; first confirm the ground underfoot, then gently clarify the next visible step.", + ), + ( + "tabs open everywhere", + "Tabs open everywhere", + "Many thoughts are running in the background; help gather, order, and reduce cognitive load", + "🗂️", + "Many thoughts or unfinished tasks are open at once; help gather, order, and reduce cognitive load.", + ), + ( + "boat resting in harbor", + "Boat resting in harbor", + "Pausing at shore before setting out again; allow recovery before asking for motion", + "⚓", + "In a pause, repair, or harboring phase before setting out again; do not push too quickly, allow replenishment and rhythm to return.", + ), + ( + "small light ahead", + "Small light ahead", + "Direction is faint but present; protect the signal and test forward gradually", + "🕯️", + "A faint but meaningful direction is already visible; protect that signal and use small experiments to make the path clearer.", + ), ("type", "None fit; I’ll describe it", "A short image or phrase", "✍️"), ("skip", "Leave this blank for now", "", "➖"), ), "choices_zh": ( - ("像站在起雾的路口", "像站在起雾的路口", "雾还没有散,不是不知道往哪走,只是远处暂时看不清。也许可以先陪你确认脚下,再慢慢等下一步显出来。", "🌫️", "并非完全迷失,而是处在视野未打开、方向暂不清晰的阶段;适合先确认脚下处境,再温和澄清下一步。"), - ("像房间里开满标签页", "像房间里开满标签页", "脑海里像同时亮着很多窗口,每个都还在发出一点声音。也许先把它们轻轻放到桌面上,会舒服一些。", "🗂️", "近期可能同时承载很多念头和未关闭的任务;适合帮助收束、排序、减轻认知负荷。"), - ("像一艘船暂时靠岸", "像一艘船暂时靠岸", "不是不再出发,只是船需要靠岸、补给、修整一下。等风向更清楚时,再离岸也不迟。", "⚓", "可能处在修整、恢复或重新出发前的停靠期;不要急着推动,应允许补给和节奏恢复。"), - ("像远处有一盏小灯", "像远处有一盏小灯", "答案还没有完全出现,但远处已经有一点光。那点光也许很小,却值得先被守住。", "🕯️", "已有微弱但重要的方向感;适合保护这点信号,并用小步试探让方向更清晰。"), + ( + "像站在起雾的路口", + "像站在起雾的路口", + "雾还没有散,不是不知道往哪走,只是远处暂时看不清。也许可以先陪你确认脚下,再慢慢等下一步显出来。", + "🌫️", + "并非完全迷失,而是处在视野未打开、方向暂不清晰的阶段;适合先确认脚下处境,再温和澄清下一步。", + ), + ( + "像房间里开满标签页", + "像房间里开满标签页", + "脑海里像同时亮着很多窗口,每个都还在发出一点声音。也许先把它们轻轻放到桌面上,会舒服一些。", + "🗂️", + "近期可能同时承载很多念头和未关闭的任务;适合帮助收束、排序、减轻认知负荷。", + ), + ( + "像一艘船暂时靠岸", + "像一艘船暂时靠岸", + "不是不再出发,只是船需要靠岸、补给、修整一下。等风向更清楚时,再离岸也不迟。", + "⚓", + "可能处在修整、恢复或重新出发前的停靠期;不要急着推动,应允许补给和节奏恢复。", + ), + ( + "像远处有一盏小灯", + "像远处有一盏小灯", + "答案还没有完全出现,但远处已经有一点光。那点光也许很小,却值得先被守住。", + "🕯️", + "已有微弱但重要的方向感;适合保护这点信号,并用小步试探让方向更清晰。", + ), ("type", "都不像,我自己描述", "写一个短句或画面就好", "✍️"), ("skip", "暂时留空", "", "➖"), ), @@ -496,20 +713,80 @@ def _mbti_choices(language: str = "en") -> tuple[tuple[str, ...], ...]: "en": "When you make trade-offs lately, what feels most important not to lose?", "zh": "最近做取舍时,你最不想弄丢的是什么?", "choices_en": ( - ("keep my authorship", "Keep my authorship", "Autonomy and authorship matter in trade-offs; preserve choice space and avoid over-directing", "🧭", "Authorship and autonomy matter in trade-offs; do not over-decide on their behalf, preserve choice space and help them hold the wheel."), - ("keep the ground steady", "Keep the ground steady", "Safety and certainty matter in trade-offs; reduce collapse risk before optimizing", "🪨", "Safety and certainty are bottom-layer needs in the trade-off; reduce collapse risk and real-world instability before optimizing or taking bigger risks."), - ("stay true inside", "Stay true inside", "Authenticity and inner consistency matter in trade-offs; slower is better than self-betrayal", "💎", "Authenticity and inner consistency matter; respect the value signal rather than evaluating only by efficiency, gain, or speed."), - ("protect important people", "Protect important people", "Relationships, promises, and care matter in trade-offs; include responsibility and attachment in the frame", "🤲", "Relationships, promises, and care strongly shape the decision; include emotional responsibility and relational boundaries in the analysis."), - ("open the future", "Open the future", "Possibility matters in trade-offs; evaluate long-term space, growth, and optionality", "🌱", "Possibility, growth space, and long-term optionality matter; help evaluate which path makes the future wider."), + ( + "keep my authorship", + "Keep my authorship", + "Autonomy and authorship matter in trade-offs; preserve choice space and avoid over-directing", + "🧭", + "Authorship and autonomy matter in trade-offs; do not over-decide on their behalf, preserve choice space and help them hold the wheel.", + ), + ( + "keep the ground steady", + "Keep the ground steady", + "Safety and certainty matter in trade-offs; reduce collapse risk before optimizing", + "🪨", + "Safety and certainty are bottom-layer needs in the trade-off; reduce collapse risk and real-world instability before optimizing or taking bigger risks.", + ), + ( + "stay true inside", + "Stay true inside", + "Authenticity and inner consistency matter in trade-offs; slower is better than self-betrayal", + "💎", + "Authenticity and inner consistency matter; respect the value signal rather than evaluating only by efficiency, gain, or speed.", + ), + ( + "protect important people", + "Protect important people", + "Relationships, promises, and care matter in trade-offs; include responsibility and attachment in the frame", + "🤲", + "Relationships, promises, and care strongly shape the decision; include emotional responsibility and relational boundaries in the analysis.", + ), + ( + "open the future", + "Open the future", + "Possibility matters in trade-offs; evaluate long-term space, growth, and optionality", + "🌱", + "Possibility, growth space, and long-term optionality matter; help evaluate which path makes the future wider.", + ), ("type", "None fit; I’ll name it", "A short value or phrase", "✍️"), ("skip", "Leave this blank for now", "", "➖"), ), "choices_zh": ( - ("我想保住选择权", "我想保住选择权", "最怕的不是慢一点,而是把方向感交出去。这个选择最好仍然像是你自己做出的。", "🧭", "取舍中很在意自主感和作者性;不要替其下结论,应保留选择空间,帮助重新握住方向盘。"), - ("我想先踩稳地面", "我想先踩稳地面", "在往前之前,你可能需要先确认地面不会塌。安全感和确定性,是这次取舍里很重要的底色。", "🪨", "安全感和确定性是当前取舍中的底层需求;应先降低坍塌感和现实风险,再谈优化或冒险。"), - ("我不想背离真心", "我不想背离真心", "有些决定不只是对错,也关乎是否还像自己。宁可慢一点,也不想把真实感弄丢。", "💎", "真实感和内在一致性很重要;需要尊重其价值感,不要只用效率或收益衡量。"), - ("我想顾住重要的人", "我想顾住重要的人", "这件事不只属于你一个人。关系、承诺、照顾和亏欠感,都可能一起坐在桌边。", "🤲", "关系、承诺和照顾会显著影响判断;应把情感责任和关系边界纳入分析。"), - ("我想把未来打开", "我想把未来打开", "你在意这个选择会把生活带到哪里。它最好不是关上一扇门,而是让未来多一点空气。", "🌱", "重视可能性、成长空间和长期可选项;应帮助评估哪条路让未来更宽。"), + ( + "我想保住选择权", + "我想保住选择权", + "最怕的不是慢一点,而是把方向感交出去。这个选择最好仍然像是你自己做出的。", + "🧭", + "取舍中很在意自主感和作者性;不要替其下结论,应保留选择空间,帮助重新握住方向盘。", + ), + ( + "我想先踩稳地面", + "我想先踩稳地面", + "在往前之前,你可能需要先确认地面不会塌。安全感和确定性,是这次取舍里很重要的底色。", + "🪨", + "安全感和确定性是当前取舍中的底层需求;应先降低坍塌感和现实风险,再谈优化或冒险。", + ), + ( + "我不想背离真心", + "我不想背离真心", + "有些决定不只是对错,也关乎是否还像自己。宁可慢一点,也不想把真实感弄丢。", + "💎", + "真实感和内在一致性很重要;需要尊重其价值感,不要只用效率或收益衡量。", + ), + ( + "我想顾住重要的人", + "我想顾住重要的人", + "这件事不只属于你一个人。关系、承诺、照顾和亏欠感,都可能一起坐在桌边。", + "🤲", + "关系、承诺和照顾会显著影响判断;应把情感责任和关系边界纳入分析。", + ), + ( + "我想把未来打开", + "我想把未来打开", + "你在意这个选择会把生活带到哪里。它最好不是关上一扇门,而是让未来多一点空气。", + "🌱", + "重视可能性、成长空间和长期可选项;应帮助评估哪条路让未来更宽。", + ), ("type", "都不像,我自己命名", "写一个词或短句就好", "✍️"), ("skip", "暂时留空", "", "➖"), ), @@ -521,20 +798,80 @@ def _mbti_choices(language: str = "en") -> tuple[tuple[str, ...], ...]: "en": "When pressure rises, what do you usually do first?", "zh": "压力升起来时,你通常会先怎么保护自己?", "choices_en": ( - ("retreat into quiet", "Retreat into quiet", "Under pressure, tends to pull inward and process quietly before speaking", "🫧", "Under pressure, low-input and low-interruption inner processing space is needed; offer quiet and buffer before inviting expression."), - ("comb the knots into lines", "Comb the knots into lines", "Under pressure, tends to use lists, structure, and plans to separate the knots", "🧵", "Under pressure, stability returns through structure, lists, and decomposition; organize the mess into layers and steps."), - ("get the wheels moving", "Get the wheels moving", "Under pressure, tends to move first and regain stability by adjusting in motion", "🏃", "Under pressure, action restores feel and stability; offer a concrete small step rather than staying in abstract analysis."), - ("ask where it hurts", "Ask where it hurts", "Under pressure, tends to ask what pain point, value, or meaning is being touched", "🔦", "Under pressure, the deeper pain point, value, or emotion needs to be understood; ask first about meaning and where it hurts."), - ("borrow another mind", "Borrow another mind", "Under pressure, tends to think with another person rather than metabolize it alone", "👂", "Under pressure, co-thinking and being held matter more than processing alone; provide companionate sorting and shared simulation."), + ( + "retreat into quiet", + "Retreat into quiet", + "Under pressure, tends to pull inward and process quietly before speaking", + "🫧", + "Under pressure, low-input and low-interruption inner processing space is needed; offer quiet and buffer before inviting expression.", + ), + ( + "comb the knots into lines", + "Comb the knots into lines", + "Under pressure, tends to use lists, structure, and plans to separate the knots", + "🧵", + "Under pressure, stability returns through structure, lists, and decomposition; organize the mess into layers and steps.", + ), + ( + "get the wheels moving", + "Get the wheels moving", + "Under pressure, tends to move first and regain stability by adjusting in motion", + "🏃", + "Under pressure, action restores feel and stability; offer a concrete small step rather than staying in abstract analysis.", + ), + ( + "ask where it hurts", + "Ask where it hurts", + "Under pressure, tends to ask what pain point, value, or meaning is being touched", + "🔦", + "Under pressure, the deeper pain point, value, or emotion needs to be understood; ask first about meaning and where it hurts.", + ), + ( + "borrow another mind", + "Borrow another mind", + "Under pressure, tends to think with another person rather than metabolize it alone", + "👂", + "Under pressure, co-thinking and being held matter more than processing alone; provide companionate sorting and shared simulation.", + ), ("type", "None fit; I’ll describe it", "A short pattern is enough", "✍️"), ("skip", "Leave this blank for now", "", "➖"), ), "choices_zh": ( - ("先缩回安静里", "先缩回安静里", "压力一来,你可能会先往安静处退一小步。不是逃开,是给自己一点重新听见自己的空间。", "🫧", "压力下需要低输入、低打扰的内在处理空间;应先给安静和缓冲,再邀请表达。"), - ("先把乱麻理成线", "先把乱麻理成线", "混乱靠近时,你会想把它拆成线、列成项、排出顺序。把看不清的东西变清楚,会让人稳一点。", "🧵", "压力下靠结构、清单和拆解恢复稳定;适合把混乱整理成层次和步骤。"), - ("先动手让车跑起来", "先动手让车跑起来", "你可能不是等想明白才动,而是在动起来之后找回手感。车先跑起来,方向可以边走边调。", "🏃", "压力下通过行动找回手感和稳定;适合给出可执行的小步,而不是停留在抽象分析。"), - ("先问这事伤到哪儿", "先问这事伤到哪儿", "你会想知道它到底碰到了哪里:是害怕、委屈、价值感,还是某个一直没被说清的东西。", "🔦", "压力下需要理解被触动的深层痛点、价值或情绪;应先追问意义和伤处。"), - ("先找个人一起想", "先找个人一起想", "压力太满时,一个人在房间里可能不够。你需要另一个脑子,也需要一个能接住话的人。", "👂", "压力下需要共思和被接住,而不是独自消化;应提供陪伴式梳理和共同推演。"), + ( + "先缩回安静里", + "先缩回安静里", + "压力一来,你可能会先往安静处退一小步。不是逃开,是给自己一点重新听见自己的空间。", + "🫧", + "压力下需要低输入、低打扰的内在处理空间;应先给安静和缓冲,再邀请表达。", + ), + ( + "先把乱麻理成线", + "先把乱麻理成线", + "混乱靠近时,你会想把它拆成线、列成项、排出顺序。把看不清的东西变清楚,会让人稳一点。", + "🧵", + "压力下靠结构、清单和拆解恢复稳定;适合把混乱整理成层次和步骤。", + ), + ( + "先动手让车跑起来", + "先动手让车跑起来", + "你可能不是等想明白才动,而是在动起来之后找回手感。车先跑起来,方向可以边走边调。", + "🏃", + "压力下通过行动找回手感和稳定;适合给出可执行的小步,而不是停留在抽象分析。", + ), + ( + "先问这事伤到哪儿", + "先问这事伤到哪儿", + "你会想知道它到底碰到了哪里:是害怕、委屈、价值感,还是某个一直没被说清的东西。", + "🔦", + "压力下需要理解被触动的深层痛点、价值或情绪;应先追问意义和伤处。", + ), + ( + "先找个人一起想", + "先找个人一起想", + "压力太满时,一个人在房间里可能不够。你需要另一个脑子,也需要一个能接住话的人。", + "👂", + "压力下需要共思和被接住,而不是独自消化;应提供陪伴式梳理和共同推演。", + ), ("type", "都不像,我自己描述", "写一个短句就好", "✍️"), ("skip", "暂时留空", "", "➖"), ), @@ -546,20 +883,80 @@ def _mbti_choices(language: str = "en") -> tuple[tuple[str, ...], ...]: "en": "When your energy is low, what usually helps you return to yourself?", "zh": "当你需要恢复精力、让自己舒服一点时,通常会怎么做?", "choices_en": ( - ("give me a quiet corner", "Give me a quiet corner", "Low energy recovery starts with quiet space, less input, and no rushing", "🌙", "Recovery needs less input, less rushing, and space that does not require explanation; lower interruption density."), - ("talk softly for a while", "Talk softly for a while", "Low energy recovery is helped by calm presence and gentle conversation", "🕯️", "Steady presence and low-pressure conversation help the mind land; accompany first, solve second."), - ("change the body rhythm", "Change the body rhythm", "Low energy recovery is helped by walking, sleep, music, food, or a body-rhythm reset", "🌿", "Body rhythm can lead psychological recovery; consider walking, rest, music, food, or rhythm reset first."), - ("finish one tiny action", "Finish one tiny action", "Low energy recovery is helped by completing one tiny action and restoring agency", "✅", "Tiny completion restores agency; break suggestions into one very small step that can be completed immediately."), - ("use beauty and ritual", "Use beauty and ritual", "Low energy recovery is helped by beauty, light, music, order, objects, or small rituals", "✨", "Beauty, order, light, music, objects, or small rituals help return to self; support through sensory and ritualized cues."), + ( + "give me a quiet corner", + "Give me a quiet corner", + "Low energy recovery starts with quiet space, less input, and no rushing", + "🌙", + "Recovery needs less input, less rushing, and space that does not require explanation; lower interruption density.", + ), + ( + "talk softly for a while", + "Talk softly for a while", + "Low energy recovery is helped by calm presence and gentle conversation", + "🕯️", + "Steady presence and low-pressure conversation help the mind land; accompany first, solve second.", + ), + ( + "change the body rhythm", + "Change the body rhythm", + "Low energy recovery is helped by walking, sleep, music, food, or a body-rhythm reset", + "🌿", + "Body rhythm can lead psychological recovery; consider walking, rest, music, food, or rhythm reset first.", + ), + ( + "finish one tiny action", + "Finish one tiny action", + "Low energy recovery is helped by completing one tiny action and restoring agency", + "✅", + "Tiny completion restores agency; break suggestions into one very small step that can be completed immediately.", + ), + ( + "use beauty and ritual", + "Use beauty and ritual", + "Low energy recovery is helped by beauty, light, music, order, objects, or small rituals", + "✨", + "Beauty, order, light, music, objects, or small rituals help return to self; support through sensory and ritualized cues.", + ), ("type", "None fit; I’ll name it", "A short recovery cue", "✍️"), ("skip", "Leave this blank for now", "", "➖"), ), "choices_zh": ( - ("给我一块安静角落", "给我一块安静角落", "恢复有时不是被鼓励,而是先少一点声音、少一点催促。你需要一块不必解释自己的安静角落。", "🌙", "恢复时需要少输入、少催促、不必解释自己的空间;应降低打扰密度。"), - ("陪我轻轻说一会儿", "陪我轻轻说一会儿", "有时候不是要立刻解决什么,只是有人在旁边轻轻说话,心就会慢慢落回身体里。", "🕯️", "通过温和陪伴和低压对话恢复落地感;应先陪伴,再解决。"), - ("先让身体换个节奏", "先让身体换个节奏", "身体换了节奏,心也会跟着松一点。走路、睡觉、音乐、吃点东西,都可能是一条回来的路。", "🌿", "身体节奏会带动心理恢复;可优先建议散步、休息、音乐、饮食或节奏重置。"), - ("完成一个很小动作", "完成一个很小动作", "把一件很小的事做完,会像在地上放下一颗钉子:不大,却能让人重新有一点掌控感。", "✅", "微小完成感能帮助恢复掌控;应把建议切成很小、能立刻完成的一步。"), - ("靠一点美感和仪式", "靠一点美感和仪式", "一点光线、音乐、整理、香气或小物件,能把散掉的自己慢慢召回来。", "✨", "审美、秩序、光线、音乐或小仪式能帮助回到自己;可用更有感官和仪式感的方式支持。"), + ( + "给我一块安静角落", + "给我一块安静角落", + "恢复有时不是被鼓励,而是先少一点声音、少一点催促。你需要一块不必解释自己的安静角落。", + "🌙", + "恢复时需要少输入、少催促、不必解释自己的空间;应降低打扰密度。", + ), + ( + "陪我轻轻说一会儿", + "陪我轻轻说一会儿", + "有时候不是要立刻解决什么,只是有人在旁边轻轻说话,心就会慢慢落回身体里。", + "🕯️", + "通过温和陪伴和低压对话恢复落地感;应先陪伴,再解决。", + ), + ( + "先让身体换个节奏", + "先让身体换个节奏", + "身体换了节奏,心也会跟着松一点。走路、睡觉、音乐、吃点东西,都可能是一条回来的路。", + "🌿", + "身体节奏会带动心理恢复;可优先建议散步、休息、音乐、饮食或节奏重置。", + ), + ( + "完成一个很小动作", + "完成一个很小动作", + "把一件很小的事做完,会像在地上放下一颗钉子:不大,却能让人重新有一点掌控感。", + "✅", + "微小完成感能帮助恢复掌控;应把建议切成很小、能立刻完成的一步。", + ), + ( + "靠一点美感和仪式", + "靠一点美感和仪式", + "一点光线、音乐、整理、香气或小物件,能把散掉的自己慢慢召回来。", + "✨", + "审美、秩序、光线、音乐或小仪式能帮助回到自己;可用更有感官和仪式感的方式支持。", + ), ("type", "都不像,我自己命名", "写一个短句就好", "✍️"), ("skip", "暂时留空", "", "➖"), ), @@ -571,20 +968,80 @@ def _mbti_choices(language: str = "en") -> tuple[tuple[str, ...], ...]: "en": "When a choice stays unresolved, what usually brings the answer closer?", "zh": "当一个选择还悬在那里,什么会让你离答案近一点?", "choices_en": ( - ("put trade-offs on paper", "Put trade-offs on paper", "Unresolved choices become clearer when trade-offs are written down and invisible factors become visible", "📝", "Externalizing and writing make hidden weights visible; help list trade-offs, costs, and what must be preserved."), - ("hear it spoken aloud", "Hear it spoken aloud", "Unresolved choices become clearer when spoken aloud, giving the problem a shape", "🗣️", "Speaking gives the problem shape; use conversational reflection, follow-up questions, and shared naming."), - ("lay out possible futures", "Lay out possible futures", "Unresolved choices become clearer by laying out possible futures and where each road leads", "🛤️", "Different paths need to be compared as lived future scenes; unfold possible futures rather than only listing pros and cons."), - ("try one small experiment", "Try one small experiment", "Unresolved choices become clearer through a small reversible experiment before deciding", "🧪", "Reversible experiments are a good way to gather feedback; design low-risk trials rather than forcing a one-shot decision."), - ("wait for the body signal", "Wait for the body signal", "Unresolved choices become clearer by noticing body signals like relief, resistance, energy, or fatigue", "🌡️", "Body signals help calibrate decisions; pay attention to relief, resistance, excitement, and fatigue."), + ( + "put trade-offs on paper", + "Put trade-offs on paper", + "Unresolved choices become clearer when trade-offs are written down and invisible factors become visible", + "📝", + "Externalizing and writing make hidden weights visible; help list trade-offs, costs, and what must be preserved.", + ), + ( + "hear it spoken aloud", + "Hear it spoken aloud", + "Unresolved choices become clearer when spoken aloud, giving the problem a shape", + "🗣️", + "Speaking gives the problem shape; use conversational reflection, follow-up questions, and shared naming.", + ), + ( + "lay out possible futures", + "Lay out possible futures", + "Unresolved choices become clearer by laying out possible futures and where each road leads", + "🛤️", + "Different paths need to be compared as lived future scenes; unfold possible futures rather than only listing pros and cons.", + ), + ( + "try one small experiment", + "Try one small experiment", + "Unresolved choices become clearer through a small reversible experiment before deciding", + "🧪", + "Reversible experiments are a good way to gather feedback; design low-risk trials rather than forcing a one-shot decision.", + ), + ( + "wait for the body signal", + "Wait for the body signal", + "Unresolved choices become clearer by noticing body signals like relief, resistance, energy, or fatigue", + "🌡️", + "Body signals help calibrate decisions; pay attention to relief, resistance, excitement, and fatigue.", + ), ("type", "None fit; I’ll name it", "A short decision cue", "✍️"), ("skip", "Leave this blank for now", "", "➖"), ), "choices_zh": ( - ("把取舍写到纸上", "把取舍写到纸上", "有些答案要先落到纸上才会显形。把取舍写出来,心里那些看不见的重量就有了位置。", "📝", "靠外化和书写看清选择里的隐形权重;应帮助列出取舍、代价和保留项。"), - ("说出来听听形状", "说出来听听形状", "话说出口之前,问题像一团雾;说出来以后,它会有边缘、有形状,也更容易被一起看见。", "🗣️", "通过表达来让问题成形;适合用对话复述、追问和共同命名。"), - ("把几种未来摆开", "把几种未来摆开", "你需要的不只是选项列表,而是看见每条路会把生活带向哪里,哪一种未来更像你。", "🛤️", "需要比较不同路径导向的生活图景;应帮助展开未来场景,而不是只列优缺点。"), - ("先做一个小实验", "先做一个小实验", "不用一下子把门关死。先试一个可逆的小动作,身体和现实都会给出一点回音。", "🧪", "适合通过可逆试探获得反馈;应设计低风险实验,而不是要求一次性定案。"), - ("等身体先给信号", "等身体先给信号", "有时候答案不是先从脑子里来,而是从身体里冒出来:放松、抗拒、兴奋,或者忽然很累。", "🌡️", "会用身体感受校准决定;应关注放松、抗拒、兴奋和疲惫等体感线索。"), + ( + "把取舍写到纸上", + "把取舍写到纸上", + "有些答案要先落到纸上才会显形。把取舍写出来,心里那些看不见的重量就有了位置。", + "📝", + "靠外化和书写看清选择里的隐形权重;应帮助列出取舍、代价和保留项。", + ), + ( + "说出来听听形状", + "说出来听听形状", + "话说出口之前,问题像一团雾;说出来以后,它会有边缘、有形状,也更容易被一起看见。", + "🗣️", + "通过表达来让问题成形;适合用对话复述、追问和共同命名。", + ), + ( + "把几种未来摆开", + "把几种未来摆开", + "你需要的不只是选项列表,而是看见每条路会把生活带向哪里,哪一种未来更像你。", + "🛤️", + "需要比较不同路径导向的生活图景;应帮助展开未来场景,而不是只列优缺点。", + ), + ( + "先做一个小实验", + "先做一个小实验", + "不用一下子把门关死。先试一个可逆的小动作,身体和现实都会给出一点回音。", + "🧪", + "适合通过可逆试探获得反馈;应设计低风险实验,而不是要求一次性定案。", + ), + ( + "等身体先给信号", + "等身体先给信号", + "有时候答案不是先从脑子里来,而是从身体里冒出来:放松、抗拒、兴奋,或者忽然很累。", + "🌡️", + "会用身体感受校准决定;应关注放松、抗拒、兴奋和疲惫等体感线索。", + ), ("type", "都不像,我自己命名", "写一个短句就好", "✍️"), ("skip", "暂时留空", "", "➖"), ), @@ -622,19 +1079,19 @@ def _mbti_choices(language: str = "en") -> tuple[tuple[str, ...], ...]: ), ) _SAFETY_FIELD_LABELS = { - field_id: (title_en, title_zh) - for field_id, title_en, title_zh, _prompt_en, _prompt_zh in _SAFETY_PROMPTS + field_id: (title_en, title_zh) for field_id, title_en, title_zh, _prompt_en, _prompt_zh in _SAFETY_PROMPTS } _SAFETY_LABEL_TO_FIELD = { - label.casefold(): field_id - for field_id, labels in _SAFETY_FIELD_LABELS.items() - for label in (field_id, *labels) + label.casefold(): field_id for field_id, labels in _SAFETY_FIELD_LABELS.items() for label in (field_id, *labels) } _SAFETY_FACT_TEMPLATES = { "food_allergies": ("食物过敏:{value}。", "Food allergies: {value}."), "medication_allergies": ("药物过敏:{value}。", "Medication allergies: {value}."), "chronic_conditions": ("健康注意事项:{value}。", "Health notes: {value}."), - "trauma_history": ("不愿给别人说、藏在心里的秘密:{value}。", "Secrets you keep inside: {value}."), + "trauma_history": ( + "不愿给别人说、藏在心里的秘密:{value}。", + "Secrets you keep inside: {value}.", + ), } @@ -668,17 +1125,36 @@ def _print_init_section(language: str, title_en: str, title_zh: str, body_en: st _print_heading(title, body) return console = Console(highlight=False, soft_wrap=True) - console.print(Panel(body, title=f"[bold {BRAND_ACCENT}]{title}[/bold {BRAND_ACCENT}]", border_style=BRAND_ACCENT, padding=(1, 2))) + console.print( + Panel( + body, + title=f"[bold {BRAND_ACCENT}]{title}[/bold {BRAND_ACCENT}]", + border_style=BRAND_ACCENT, + padding=(1, 2), + ) + ) def _starter_question_model_hints(question_id: str) -> dict[str, str]: topic_map = { "inner_landscape": {"lens": "pulse", "topic": "pulse.mood.inner_landscape"}, - "value_anchor": {"lens": "identity", "topic": "identity.values.trade_off_anchor"}, + "value_anchor": { + "lens": "identity", + "topic": "identity.values.trade_off_anchor", + }, "recent_resonance": {"lens": "pulse", "topic": "pulse.mood.recent_resonance"}, - "pressure_pattern": {"lens": "identity", "topic": "identity.character.rhythm.pressure"}, - "recovery_style": {"lens": "identity", "topic": "identity.character.rhythm.recovery"}, - "decision_compass": {"lens": "identity", "topic": "identity.character.decision.compass"}, + "pressure_pattern": { + "lens": "identity", + "topic": "identity.character.rhythm.pressure", + }, + "recovery_style": { + "lens": "identity", + "topic": "identity.character.rhythm.recovery", + }, + "decision_compass": { + "lens": "identity", + "topic": "identity.character.decision.compass", + }, } return topic_map.get(question_id, {}) @@ -752,11 +1228,19 @@ def _run_embedding_birth_wizard( ) -> tuple[str, str, str, str, int | None, str | None] | _WizardBackSignal: provider = _wizard_choice_prompt( _init_text(language, "Choose Embedding Recall", "选择记忆嵌入方式"), - _init_text(language, "How should Elephant Agent's evidence grow to know you?", "Elephant Agent 应该怎样建立可检索的记忆来了解你?"), + _init_text( + language, + "How should Elephant Agent's evidence grow to know you?", + "Elephant Agent 应该怎样建立可检索的记忆来了解你?", + ), ( WizardChoice( value="local", - label=_init_text(language, "Local embedding (recommended & free)", "本地嵌入(推荐 & 免费)"), + label=_init_text( + language, + "Local embedding (recommended & free)", + "本地嵌入(推荐 & 免费)", + ), detail=_init_text( language, "Powered by sentence-transformers. Runs entirely on your machine.", @@ -765,8 +1249,16 @@ def _run_embedding_birth_wizard( ), WizardChoice( value="openai-compatible", - label=_init_text(language, "Embedding provider (paid & accuracy first)", "嵌入模型服务(付费 & 精度优先)"), - detail=_init_text(language, "Use an OpenAI-compatible embedding endpoint.", "使用 OpenAI-compatible 的嵌入接口。"), + label=_init_text( + language, + "Embedding provider (paid & accuracy first)", + "嵌入模型服务(付费 & 精度优先)", + ), + detail=_init_text( + language, + "Use an OpenAI-compatible embedding endpoint.", + "使用 OpenAI-compatible 的嵌入接口。", + ), ), ), default=default_provider or "local", @@ -850,14 +1342,25 @@ def _run_embedding_birth_wizard( dimensions = default_dimensions or 1024 api_key = _wizard_text_prompt( _init_text(language, "Embedding Key", "嵌入接口密钥"), - _init_text(language, "Enter an embedding key if this endpoint needs one.", "如果这个接口需要密钥,请输入。"), + _init_text( + language, + "Enter an embedding key if this endpoint needs one.", + "如果这个接口需要密钥,请输入。", + ), default=None, allow_back=True, password=True, ) if api_key is WIZARD_BACK: return WIZARD_BACK - return (selected, "", str(base_url).strip(), str(model).strip(), dimensions, str(api_key).strip() or None) + return ( + selected, + "", + str(base_url).strip(), + str(model).strip(), + dimensions, + str(api_key).strip() or None, + ) def _mapping_or_empty(value: object) -> dict[str, object]: @@ -893,8 +1396,27 @@ def _infer_init_companion_posture(bootstrap_state: object, *, language: str) -> for token in ("quiet", "安静", "room", "房间", "walk", "走") ) or mbti in {"INFJ", "INFP", "INTJ", "INTP", "ISFJ", "ISFP"} action_signals = any( - token in " ".join((pressure, decision, recovery, str(getattr(bootstrap_state, "occupation", "")))).lower() - for token in ("experiment", "实验", "project", "项目", "next step", "下一步", "plan", "计划", "move fast", "先动") + token + in " ".join( + ( + pressure, + decision, + recovery, + str(getattr(bootstrap_state, "occupation", "")), + ) + ).lower() + for token in ( + "experiment", + "实验", + "project", + "项目", + "next step", + "下一步", + "plan", + "计划", + "move fast", + "先动", + ) ) or mbti in {"ENTJ", "ESTJ", "ESTP", "ISTP"} if language == "zh": if quiet_signals and not action_signals: @@ -914,15 +1436,40 @@ def _learned_init_entries(language: str, bootstrap_state: object) -> list[tuple[ is_zh = language == "zh" entries: list[tuple[str, dict[str, str]]] = [] if is_zh: - entries.append(("中文", {"field": "first_language", **_INIT_FIELD_MODEL_HINTS["first_language"]})) + entries.append( + ( + "中文", + { + "field": "first_language", + **_INIT_FIELD_MODEL_HINTS["first_language"], + }, + ) + ) else: - entries.append(("English", {"field": "first_language", **_INIT_FIELD_MODEL_HINTS["first_language"]})) + entries.append( + ( + "English", + { + "field": "first_language", + **_INIT_FIELD_MODEL_HINTS["first_language"], + }, + ) + ) def add(field: str, value: object, extra: dict[str, str] | None = None) -> None: cleaned = str(value or "").strip() if not cleaned: return - entries.append((cleaned, {"field": field, **_INIT_FIELD_MODEL_HINTS.get(field, {}), **(extra or {})})) + entries.append( + ( + cleaned, + { + "field": field, + **_INIT_FIELD_MODEL_HINTS.get(field, {}), + **(extra or {}), + }, + ) + ) add("preferred_name", getattr(bootstrap_state, "preferred_name", "")) add("occupation", getattr(bootstrap_state, "occupation", "")) @@ -936,7 +1483,16 @@ def add(field: str, value: object, extra: dict[str, str] | None = None) -> None: text = f"MBTI:{mbti};特征参考:{traits}" if traits else f"MBTI:{mbti}" else: text = f"MBTI: {mbti}; trait reference: {traits}" if traits else f"MBTI: {mbti}" - entries.append((text, {"field": "mbti", "mbti_traits": traits, **_INIT_FIELD_MODEL_HINTS["mbti"]})) + entries.append( + ( + text, + { + "field": "mbti", + "mbti_traits": traits, + **_INIT_FIELD_MODEL_HINTS["mbti"], + }, + ) + ) add("hobbies", getattr(bootstrap_state, "hobbies", "")) for field_id, value in _init_care_entries(bootstrap_state): entries.append((value, {"field": field_id, **_INIT_FIELD_MODEL_HINTS[field_id]})) @@ -948,7 +1504,15 @@ def add(field: str, value: object, extra: dict[str, str] | None = None) -> None: entries.append((answer, {"field": question_id, **hints})) posture = _infer_init_companion_posture(bootstrap_state, language=language) - entries.append((posture, {"field": "inferred_companion_posture", **_INIT_FIELD_MODEL_HINTS["inferred_companion_posture"]})) + entries.append( + ( + posture, + { + "field": "inferred_companion_posture", + **_INIT_FIELD_MODEL_HINTS["inferred_companion_posture"], + }, + ) + ) return entries @@ -1066,7 +1630,9 @@ def _go_back() -> bool: if not _go_back(): return None continue - state.occupation = str(occupation).strip() or _choice_saved_value(attention_choices, str(attention_choices[0][0])) + state.occupation = str(occupation).strip() or _choice_saved_value( + attention_choices, str(attention_choices[0][0]) + ) gender = _prompt_choice_with_type( state.first_language, @@ -1221,7 +1787,9 @@ def _go_back() -> bool: step_index += 1 continue if step == "learning_intensity": - answer = _prompt_learning_intensity(state.learning_intensity, allow_back=True, language=state.first_language) + answer = _prompt_learning_intensity( + state.learning_intensity, allow_back=True, language=state.first_language + ) if answer is WIZARD_CANCEL: return None if answer is WIZARD_BACK: @@ -1268,6 +1836,7 @@ def _persist_init_question_config(runtime: CliRuntime, *, first_language: str, l load_global_config, write_global_config, ) + config_path = global_config_path_for_state_dir(runtime.paths.state_dir) config = load_global_config(config_path, state_dir=runtime.paths.state_dir) question_config = personal_model_question_config_from_global(config) @@ -1284,7 +1853,9 @@ def _persist_init_question_config(runtime: CliRuntime, *, first_language: str, l return -def _proactive_ask_config_for_learning_intensity(learning_intensity: str) -> dict[str, object]: +def _proactive_ask_config_for_learning_intensity( + learning_intensity: str, +) -> dict[str, object]: intensity = str(learning_intensity or "").strip().lower() if intensity == "low": return {"idle_threshold_minutes": 720, "daily_max": 2, "quiet_hours": [23, 7]} @@ -1370,6 +1941,7 @@ def _bootstrap_personal_model_from_init(runtime: CliRuntime, session, bootstrap_ language = _normalize_first_language(getattr(bootstrap_state, "first_language", "en")) try: from dataclasses import replace as _dc_replace + profile = runtime.repository.load_personal_model_runtime_state(personal_model_id) if profile is not None: preferences = list(tuple(getattr(profile, "preferences", ()) or ())) @@ -1410,9 +1982,7 @@ def _bootstrap_personal_model_from_init(runtime: CliRuntime, session, bootstrap_ repository=runtime.repository, semantic_summary_indexer=semantic_summary_indexer, semantic_searcher=( - runtime.semantic_index_bundle.searcher - if runtime.semantic_index_bundle is not None - else None + runtime.semantic_index_bundle.searcher if runtime.semantic_index_bundle is not None else None ), embedding_service=embedding_service, ) @@ -1441,9 +2011,7 @@ def _bootstrap_personal_model_from_init(runtime: CliRuntime, session, bootstrap_ summary="initial profile and skill-affinity learning", metadata=_init_profile_learning_metadata( bootstrap_state, - learning_intensity=str( - getattr(bootstrap_state, "learning_intensity", "medium") or "medium" - ), + learning_intensity=str(getattr(bootstrap_state, "learning_intensity", "medium") or "medium"), language=language, ), ) @@ -1473,7 +2041,6 @@ def _bootstrap_personal_model_from_init(runtime: CliRuntime, session, bootstrap_ pass - def _run_setup(runtime: CliRuntime, args: argparse.Namespace) -> int: provider_id = args.provider_id loaded = runtime.current_profile() @@ -1486,11 +2053,14 @@ def _run_setup(runtime: CliRuntime, args: argparse.Namespace) -> int: else: display_name = _suggest_elephant_name(runtime) mode = "companion" - personality_preset = _default_personality_preset( - runtime, - mode=mode, - current=(loaded.companion.personality_preset if loaded.companion is not None else None), - ) or "companion" + personality_preset = ( + _default_personality_preset( + runtime, + mode=mode, + current=(loaded.companion.personality_preset if loaded.companion is not None else None), + ) + or "companion" + ) initiative = loaded.companion.initiative if loaded.companion is not None else "gentle" requested_elephant_identity_text = getattr(args, "elephant_identity_text", None) secret_env_var = getattr(args, "secret_env_var", None) @@ -1651,6 +2221,7 @@ def _run_setup(runtime: CliRuntime, args: argparse.Namespace) -> int: profile_state = runtime.repository.load_personal_model_runtime_state(configured.state.profile_id) if profile_state is not None: from dataclasses import replace as _dc_replace + runtime.repository.upsert_personal_model_runtime_state( _dc_replace(profile_state, learning_intensity=learning_intensity) ) @@ -1689,6 +2260,7 @@ def _run_setup(runtime: CliRuntime, args: argparse.Namespace) -> int: ) try: from dataclasses import replace as _dc_replace + profile_state = runtime.repository.load_personal_model_runtime_state(first_elephant.personal_model_id) if profile_state is not None: runtime.repository.upsert_personal_model_runtime_state( @@ -1698,7 +2270,10 @@ def _run_setup(runtime: CliRuntime, args: argparse.Namespace) -> int: pass _bootstrap_personal_model_from_init(runtime, first_elephant, bootstrap_state) if first_elephant_status == "created": - _play_creating_transition("Elephant Agent init", f"{display_name} is becoming a continuing personal AI thread.") + _play_creating_transition( + "Elephant Agent init", + f"{display_name} is becoming a continuing personal AI thread.", + ) readiness_lines = [ f"elephant · {runtime.elephant_id_for_session(first_elephant)}", f"status · {first_elephant_status}", @@ -1733,6 +2308,7 @@ def _run_setup(runtime: CliRuntime, args: argparse.Namespace) -> int: ) return 0 + def _run_brain(runtime: CliRuntime, args: argparse.Namespace) -> int: action = str(getattr(args, "provider_command", "configure") or "configure") if action == "status": @@ -1759,7 +2335,12 @@ def _run_brain(runtime: CliRuntime, args: argparse.Namespace) -> int: ) initial_state.api_key = args.api_key initial_state.reasoning_effort = ( - str(getattr(args, "reasoning_effort", None) or provider.get("reasoning_effort") or initial_state.reasoning_effort).strip() or None + str( + getattr(args, "reasoning_effort", None) + or provider.get("reasoning_effort") + or initial_state.reasoning_effort + ).strip() + or None ) if args.context_window_mode is not None: initial_state.context_window_mode = args.context_window_mode @@ -1796,7 +2377,9 @@ def _run_brain(runtime: CliRuntime, args: argparse.Namespace) -> int: and not configured.api_key and not _provider_secret_ready(runtime, provider_id=configured.provider_id) ): - raise SystemExit("provider requires a provider key for API-key providers; rerun interactively or pass --api-key") + raise SystemExit( + "provider requires a provider key for API-key providers; rerun interactively or pass --api-key" + ) context_window_tokens = configured.context_window_tokens if context_window_tokens is None and configured.model_id: @@ -1845,7 +2428,10 @@ def _run_embedding_setup_wizard(runtime: CliRuntime) -> int: # Detect user's first language from global config. language = "en" try: - from packages.runtime_config import global_config_path_for_state_dir, load_global_config + from packages.runtime_config import ( + global_config_path_for_state_dir, + load_global_config, + ) config_path = global_config_path_for_state_dir(runtime.paths.state_dir) config = load_global_config(config_path, state_dir=runtime.paths.state_dir) @@ -1885,7 +2471,10 @@ def _run_embedding_setup_wizard(runtime: CliRuntime) -> int: "Embedding provider updated", "Elephant Agent will use the local embedding model for semantic retrieval.", sections=tuple(sections), - next_commands=("elephant provider embeddings status", "elephant provider status"), + next_commands=( + "elephant provider embeddings status", + "elephant provider status", + ), ) else: if not base_url or not model or dimensions is None: @@ -1914,7 +2503,10 @@ def _run_embedding_setup_wizard(runtime: CliRuntime) -> int: ), ), ), - next_commands=("elephant provider embeddings status", "elephant provider status"), + next_commands=( + "elephant provider embeddings status", + "elephant provider status", + ), ) return 0 @@ -1950,7 +2542,10 @@ def _run_embedding_provider(runtime: CliRuntime, args: argparse.Namespace) -> in "Embedding provider updated", "Elephant Agent will fall back to the local embedding default for semantic retrieval.", sections=tuple(sections), - next_commands=("elephant provider embeddings status", "elephant provider status"), + next_commands=( + "elephant provider embeddings status", + "elephant provider status", + ), ) return 0 if action != "openai-compatible": @@ -1994,10 +2589,14 @@ def _run_embedding_provider(runtime: CliRuntime, args: argparse.Namespace) -> in ), ), ), - next_commands=("elephant provider embeddings status", "elephant provider status"), + next_commands=( + "elephant provider embeddings status", + "elephant provider status", + ), ) return 0 + def _run_elephant(runtime: CliRuntime, args: argparse.Namespace) -> int: report = runtime.provider_doctor() if not _provider_session_ready(report): @@ -2006,7 +2605,10 @@ def _run_elephant(runtime: CliRuntime, args: argparse.Namespace) -> int: raw_elephant_name = args.elephant_name interactive_shell = _interactive_shell_supported() if raw_elephant_name is None and not interactive_shell: - _print_heading("Name needed", "Run elephant herd new , or rerun in a TTY and Elephant Agent will ask you.") + _print_heading( + "Name needed", + "Run elephant herd new , or rerun in a TTY and Elephant Agent will ask you.", + ) _print_command_hints("elephant herd new ", "elephant wake", "elephant herd") return 1 if interactive_shell and raw_elephant_name is None: @@ -2039,10 +2641,16 @@ def _run_elephant(runtime: CliRuntime, args: argparse.Namespace) -> int: _print_assistant_turn(runtime, outcome) return 0 if _interactive_shell_supported(): - return ProductizedShell(runtime, session_id=session.episode_id, opened="Shaped new", debug=args.debug).run() + return ProductizedShell( + runtime, + session_id=session.episode_id, + opened="Shaped new", + debug=args.debug, + ).run() _print_elephant_created(runtime, session.episode_id) return 0 + def _run_herd(runtime: CliRuntime, args: argparse.Namespace) -> int: if args.herd_command is None: _print_herd(runtime) @@ -2064,7 +2672,11 @@ def _run_herd(runtime: CliRuntime, args: argparse.Namespace) -> int: _print_cli_card( "Elephant selection paused", "No current elephant was changed.", - next_commands=("elephant herd", "elephant wake", "elephant herd new "), + next_commands=( + "elephant herd", + "elephant wake", + "elephant herd new ", + ), ) return 0 elephant_id = selected.elephant_id @@ -2175,7 +2787,11 @@ def _print_fact_list(runtime: CliRuntime, *, elephant_id: str | None = None) -> status_breakdown = ", ".join(_fact_status_breakdown(entries)) or "" fact_line_list: list[str] = [] for entry in entries[:10]: - timestamp = entry.committed_at.isoformat(timespec="seconds") if getattr(entry, "committed_at", None) is not None else "" + timestamp = ( + entry.committed_at.isoformat(timespec="seconds") + if getattr(entry, "committed_at", None) is not None + else "" + ) metadata = dict(getattr(entry, "metadata", {}) or {}) facet = str(metadata.get("facet") or metadata.get("topic") or "claim").strip() fact_line_list.append(f"{entry.fact_id} · {entry.lens}.{facet} · status={entry.status} · {timestamp}") @@ -2206,16 +2822,23 @@ def _print_fact_list(runtime: CliRuntime, *, elephant_id: str | None = None) -> ) -def _delete_personal_model_fact(runtime: CliRuntime, *, elephant_id: str | None, fact_id: str, reason: str | None) -> None: +def _delete_personal_model_fact( + runtime: CliRuntime, *, elephant_id: str | None, fact_id: str, reason: str | None +) -> None: session, state, resolved_elephant_id = _resolve_fact_target(runtime, elephant_id=elephant_id) owner_id = _fact_owner_id(session, state) deletion_reason = reason or "fact retired from elephant evidence command" - facts = tuple(runtime.repository.list_personal_model_facts(personal_model_id=owner_id, status=("active", "retired", "disputed"))) + facts = tuple( + runtime.repository.list_personal_model_facts( + personal_model_id=owner_id, status=("active", "retired", "disputed") + ) + ) current = next((fact for fact in facts if getattr(fact, "fact_id", "") == fact_id), None) if current is None: raise ValueError(f"unknown Personal Model entry: {fact_id}") from dataclasses import replace as _dc_replace from datetime import datetime, timezone + updated = _dc_replace( current, status="deleted", @@ -2257,7 +2880,11 @@ def _run_facts(runtime: CliRuntime, args: argparse.Namespace) -> int: _print_cli_card( "Elephant Agent evidence", "No elephant is available yet.", - next_commands=("elephant init", "elephant herd new ", "elephant wake"), + next_commands=( + "elephant init", + "elephant herd new ", + "elephant wake", + ), ) return 1 command = args.facts_command or "list" @@ -2325,7 +2952,10 @@ def _learning_job_lines(jobs: Iterable[object], *, runtime: CliRuntime | None = def _learning_worker_lines(runtime: CliRuntime) -> tuple[str, ...]: - from apps.learning_worker_runtime import load_learning_worker_record, learning_worker_is_running + from apps.learning_worker_runtime import ( + load_learning_worker_record, + learning_worker_is_running, + ) record = load_learning_worker_record(runtime.paths.state_dir) or {} return ( @@ -2346,7 +2976,11 @@ def _print_learning_history(runtime: CliRuntime, *, limit: int) -> None: CliCardSection("Worker", _learning_worker_lines(runtime)), CliCardSection("Jobs", _learning_job_lines(jobs, runtime=runtime)), ), - next_commands=("elephant reflect status", "elephant reflect start", "elephant wake"), + next_commands=( + "elephant reflect status", + "elephant reflect start", + "elephant wake", + ), ) @@ -2392,7 +3026,11 @@ def _print_learning_status(runtime: CliRuntime, *, elephant_id: str | None, limi CliCardSection("Counts", tuple(lines)), CliCardSection("Recent jobs", tuple(job_lines) or ("",)), ), - next_commands=("elephant reflect queue", "elephant reflect run", "elephant reflect history"), + next_commands=( + "elephant reflect queue", + "elephant reflect run", + "elephant reflect history", + ), ) @@ -2475,7 +3113,9 @@ def _run_learn(runtime: CliRuntime, args: argparse.Namespace) -> int: ) worker_exit_code = int(completed.returncode or 0) if worker_exit_code: - from apps.learning_worker_runtime import mark_learning_job_terminal_failure + from apps.learning_worker_runtime import ( + mark_learning_job_terminal_failure, + ) mark_learning_job_terminal_failure( runtime, @@ -2499,7 +3139,11 @@ def _run_learn(runtime: CliRuntime, args: argparse.Namespace) -> int: ), ), ), - next_commands=("elephant reflect list", "elephant reflect kill", "elephant wake"), + next_commands=( + "elephant reflect list", + "elephant reflect kill", + "elephant wake", + ), ) return worker_exit_code raise ValueError(f"unknown learn command: {command}") @@ -2562,7 +3206,11 @@ def _run_grow(runtime: CliRuntime, args: argparse.Namespace) -> int: _print_cli_card( "Grow paused", "No elephant was selected.", - next_commands=("elephant wake", "elephant herd", "elephant herd new "), + next_commands=( + "elephant wake", + "elephant herd", + "elephant herd new ", + ), ) return 0 except LookupError: @@ -2584,6 +3232,7 @@ def _run_grow(runtime: CliRuntime, args: argparse.Namespace) -> int: runtime.prepare_session_surface(episode_id) return _run_stream_grow_loop(runtime, episode_id, sys.stdin) + def _run_stream_grow_loop(runtime: CliRuntime, session_id: str, stream: Iterable[str]) -> int: for line in stream: prompt = line.rstrip("\n").strip() @@ -2597,6 +3246,7 @@ def _run_stream_grow_loop(runtime: CliRuntime, session_id: str, stream: Iterable _print_assistant_turn(runtime, outcome) return 0 + def _run_default_entry(runtime: CliRuntime) -> int: _print_root_cli_help() return 0 @@ -2616,8 +3266,14 @@ def _show_cli_banner() -> None: console = Console(highlight=False, soft_wrap=True) header = Text() header.append("🐘 Elephant Agent CLI\n", style=f"bold {BRAND_LIGHT}") - header.append("A warm, steady way back to the elephant that remembers your path.\n", style=BRAND_MUTED) - header.append(f"🐾 v{_resolve_elephant_version()} · here with you, built to stay.", style=BRAND_ACCENT) + header.append( + "A warm, steady way back to the elephant that remembers your path.\n", + style=BRAND_MUTED, + ) + header.append( + f"🐾 v{_resolve_elephant_version()} · here with you, built to stay.", + style=BRAND_ACCENT, + ) console.print( Panel( Group( @@ -2625,7 +3281,10 @@ def _show_cli_banner() -> None: Text(" "), Align.center(_render_cli_banner_mark()), Text(" "), - Text("Model what matters · ask gently · follow the path", style=BRAND_LIGHT), + Text( + "Model what matters · ask gently · follow the path", + style=BRAND_LIGHT, + ), ), border_style=BRAND_ACCENT, title=f"[bold {BRAND_ACCENT}]Welcome[/bold {BRAND_ACCENT}]", @@ -2644,7 +3303,10 @@ def _print_root_cli_help() -> None: commands=CLI_HELP_COMMANDS, options=( ("--help", "Show this message and exit."), - ("--no-animation", "Prefer steady output over animated transitions when the terminal supports motion."), + ( + "--no-animation", + "Prefer steady output over animated transitions when the terminal supports motion.", + ), ("--color ", "Control colorized output."), ), next_commands=CLI_HELP_NEXT_COMMANDS, @@ -2727,36 +3389,128 @@ def main_callback( @app.command("init") def init_command( ctx: typer.Context, - provider_id: str = typer.Option(DEFAULT_PROVIDER_ID, "--provider-id", help="Provider id to configure for dialogue turns."), - display_name: str | None = typer.Option(None, "--display-name", help="Display name to persist for the active profile."), - elephant_text: str | None = typer.Option(None, "--elephant-text", help="Optional identity text for the first elephant."), - elephant_name: str | None = typer.Option(None, "--elephant-name", help="Name for the first elephant created during init."), + provider_id: str = typer.Option( + DEFAULT_PROVIDER_ID, + "--provider-id", + help="Provider id to configure for dialogue turns.", + ), + display_name: str | None = typer.Option( + None, + "--display-name", + help="Display name to persist for the active profile.", + ), + elephant_text: str | None = typer.Option( + None, + "--elephant-text", + help="Optional identity text for the first elephant.", + ), + elephant_name: str | None = typer.Option( + None, + "--elephant-name", + help="Name for the first elephant created during init.", + ), base_url: str | None = typer.Option(None, "--base-url", help="Provider base URL."), model_id: str | None = typer.Option(None, "--model-id", help="Dialogue model id to save as default."), api_key: str | None = typer.Option(None, "--api-key", help="Provider API key to persist or use immediately."), - secret_env_var: str | None = typer.Option(None, "--secret-env-var", help="Environment variable name to read the provider key from."), - embedding_provider: str = typer.Option("local", "--embedding-provider", help="Embedding provider kind: local or openai-compatible."), - embedding_base_url: str | None = typer.Option(None, "--embedding-base-url", help="Embedding provider base URL."), + secret_env_var: str | None = typer.Option( + None, + "--secret-env-var", + help="Environment variable name to read the provider key from.", + ), + embedding_provider: str = typer.Option( + "local", + "--embedding-provider", + help="Embedding provider kind: local or openai-compatible.", + ), + embedding_base_url: str | None = typer.Option( + None, "--embedding-base-url", help="Embedding provider base URL." + ), embedding_model: str | None = typer.Option(None, "--embedding-model", help="Embedding model id."), - embedding_dimensions: str | None = typer.Option(None, "--embedding-dimensions", help="Embedding vector dimensions."), + embedding_dimensions: str | None = typer.Option( + None, "--embedding-dimensions", help="Embedding vector dimensions." + ), embedding_api_key: str | None = typer.Option(None, "--embedding-api-key", help="Embedding API key."), - embedding_secret_env_var: str | None = typer.Option(None, "--embedding-secret-env-var", help="Environment variable name for the embedding provider key."), - context_window_mode: str | None = typer.Option(None, "--context-window-mode", help="Context window selection mode."), - context_window: str | None = typer.Option(None, "--context-window", help="Explicit context window token count."), - first_language: str = typer.Option("en", "--first-language", help="User first language for Personal Model bootstrap: en or zh."), - learning_intensity: str = typer.Option("medium", "--learning-intensity", help="Personal Model question cadence tier: low, medium, or high."), - preferred_name: str | None = typer.Option(None, "--preferred-name", help="Preferred name for Personal Model bootstrap."), - age: str | None = typer.Option(None, "--age", help="Optional age or age range for Personal Model bootstrap."), - birth_date: str | None = typer.Option(None, "--birth-date", help="Optional birth date for Personal Model bootstrap."), - gender: str | None = typer.Option(None, "--gender", help="Optional gender/self-description for Personal Model bootstrap."), - occupation: str | None = typer.Option(None, "--occupation", help="Optional role or occupation for Personal Model bootstrap."), - city: str | None = typer.Option(None, "--city", help="Optional city or timezone for Personal Model bootstrap."), - mbti: str | None = typer.Option(None, "--mbti", help="Optional MBTI/self-label for Personal Model bootstrap."), - hobbies: str | None = typer.Option(None, "--hobbies", help="Optional comma-separated personal hobbies for Personal Model bootstrap."), - astrology: str | None = typer.Option(None, "--astrology", help="Optional astrology/zodiac self-label for Personal Model bootstrap."), - safety_boundaries: str | None = typer.Option(None, "--safety-boundaries", help="Optional boundaries Elephant Agent should respect."), - communication_preference: str | None = typer.Option(None, "--communication-preference", help="Optional communication preference for Personal Model bootstrap."), - relationship_mode: str | None = typer.Option(None, "--relationship-mode", help="Optional starting relationship mode for Personal Model bootstrap."), + embedding_secret_env_var: str | None = typer.Option( + None, + "--embedding-secret-env-var", + help="Environment variable name for the embedding provider key.", + ), + context_window_mode: str | None = typer.Option( + None, "--context-window-mode", help="Context window selection mode." + ), + context_window: str | None = typer.Option( + None, "--context-window", help="Explicit context window token count." + ), + first_language: str = typer.Option( + "en", + "--first-language", + help="User first language for Personal Model bootstrap: en or zh.", + ), + learning_intensity: str = typer.Option( + "medium", + "--learning-intensity", + help="Personal Model question cadence tier: low, medium, or high.", + ), + preferred_name: str | None = typer.Option( + None, + "--preferred-name", + help="Preferred name for Personal Model bootstrap.", + ), + age: str | None = typer.Option( + None, + "--age", + help="Optional age or age range for Personal Model bootstrap.", + ), + birth_date: str | None = typer.Option( + None, + "--birth-date", + help="Optional birth date for Personal Model bootstrap.", + ), + gender: str | None = typer.Option( + None, + "--gender", + help="Optional gender/self-description for Personal Model bootstrap.", + ), + occupation: str | None = typer.Option( + None, + "--occupation", + help="Optional role or occupation for Personal Model bootstrap.", + ), + city: str | None = typer.Option( + None, + "--city", + help="Optional city or timezone for Personal Model bootstrap.", + ), + mbti: str | None = typer.Option( + None, + "--mbti", + help="Optional MBTI/self-label for Personal Model bootstrap.", + ), + hobbies: str | None = typer.Option( + None, + "--hobbies", + help="Optional comma-separated personal hobbies for Personal Model bootstrap.", + ), + astrology: str | None = typer.Option( + None, + "--astrology", + help="Optional astrology/zodiac self-label for Personal Model bootstrap.", + ), + safety_boundaries: str | None = typer.Option( + None, + "--safety-boundaries", + help="Optional boundaries Elephant Agent should respect.", + ), + communication_preference: str | None = typer.Option( + None, + "--communication-preference", + help="Optional communication preference for Personal Model bootstrap.", + ), + relationship_mode: str | None = typer.Option( + None, + "--relationship-mode", + help="Optional starting relationship mode for Personal Model bootstrap.", + ), non_interactive: bool = typer.Option(False, "--non-interactive", help="Skip wizards and rely on flags only."), ) -> None: params = ctx.parent.params if ctx.parent is not None else ctx.params @@ -2809,7 +3563,9 @@ def status_command( @app.command("wake") def wake_command( ctx: typer.Context, - elephant_id: str | None = typer.Option(None, "--elephant-id", help="Open the next Episode for a known elephant."), + elephant_id: str | None = typer.Option( + None, "--elephant-id", help="Open the next Episode for a known elephant." + ), debug: bool = typer.Option(False, "--debug", help="Show runtime diagnostics inside the wake surface."), message: str | None = typer.Option(None, "--message", help="Run one wake turn and exit."), ) -> None: @@ -2857,7 +3613,9 @@ def provider_catalog_command(ctx: typer.Context) -> None: @provider_app.command("models") def provider_models_command( ctx: typer.Context, - provider_id: str | None = typer.Option(None, "--provider-id", help="Inspect models for a specific provider id."), + provider_id: str | None = typer.Option( + None, "--provider-id", help="Inspect models for a specific provider id." + ), ) -> None: params = ctx.parent.parent.params if ctx.parent is not None and ctx.parent.parent is not None else ctx.params runtime = _cli_runtime(params["state_dir"]) @@ -2870,10 +3628,22 @@ def provider_configure_command( base_url: str | None = typer.Option(None, "--base-url", help="Provider base URL."), model_id: str | None = typer.Option(None, "--model-id", help="Dialogue model id."), api_key: str | None = typer.Option(None, "--api-key", help="Provider API key."), - secret_env_var: str | None = typer.Option(None, "--secret-env-var", help="Environment variable name to read the provider key from."), - reasoning_effort: str | None = typer.Option(None, "--reasoning-effort", help="Reasoning effort to save for the active model."), - context_window_mode: str | None = typer.Option(None, "--context-window-mode", help="Context window selection mode."), - context_window: str | None = typer.Option(None, "--context-window", help="Explicit context window token count."), + secret_env_var: str | None = typer.Option( + None, + "--secret-env-var", + help="Environment variable name to read the provider key from.", + ), + reasoning_effort: str | None = typer.Option( + None, + "--reasoning-effort", + help="Reasoning effort to save for the active model.", + ), + context_window_mode: str | None = typer.Option( + None, "--context-window-mode", help="Context window selection mode." + ), + context_window: str | None = typer.Option( + None, "--context-window", help="Explicit context window token count." + ), non_interactive: bool = typer.Option(False, "--non-interactive", help="Skip interactive provider selection."), ) -> None: params = ctx.parent.parent.params if ctx.parent is not None and ctx.parent.parent is not None else ctx.params @@ -2894,25 +3664,56 @@ def provider_configure_command( @provider_embeddings_app.command("status") def provider_embeddings_status_command(ctx: typer.Context) -> None: - params = ctx.parent.parent.parent.params if ctx.parent and ctx.parent.parent and ctx.parent.parent.parent else ctx.params + params = ( + ctx.parent.parent.parent.params + if ctx.parent and ctx.parent.parent and ctx.parent.parent.parent + else ctx.params + ) runtime = _cli_runtime(params["state_dir"]) - raise typer.Exit(_run_brain(runtime, _namespace(provider_command="embeddings", embedding_command="status"))) + raise typer.Exit( + _run_brain( + runtime, + _namespace(provider_command="embeddings", embedding_command="status"), + ) + ) @provider_embeddings_app.command("local") def provider_embeddings_local_command( ctx: typer.Context, source: str = typer.Option("huggingface", "--source", help="Model source: huggingface or modelscope."), ) -> None: - params = ctx.parent.parent.parent.params if ctx.parent and ctx.parent.parent and ctx.parent.parent.parent else ctx.params + params = ( + ctx.parent.parent.parent.params + if ctx.parent and ctx.parent.parent and ctx.parent.parent.parent + else ctx.params + ) runtime = _cli_runtime(params["state_dir"]) - raise typer.Exit(_run_brain(runtime, _namespace(provider_command="embeddings", embedding_command="local", embedding_source=source))) + raise typer.Exit( + _run_brain( + runtime, + _namespace( + provider_command="embeddings", + embedding_command="local", + embedding_source=source, + ), + ) + ) @provider_embeddings_app.command("setup") def provider_embeddings_setup_command(ctx: typer.Context) -> None: """Interactive embedding provider setup wizard.""" - params = ctx.parent.parent.parent.params if ctx.parent and ctx.parent.parent and ctx.parent.parent.parent else ctx.params + params = ( + ctx.parent.parent.parent.params + if ctx.parent and ctx.parent.parent and ctx.parent.parent.parent + else ctx.params + ) runtime = _cli_runtime(params["state_dir"]) - raise typer.Exit(_run_brain(runtime, _namespace(provider_command="embeddings", embedding_command="setup"))) + raise typer.Exit( + _run_brain( + runtime, + _namespace(provider_command="embeddings", embedding_command="setup"), + ) + ) @provider_embeddings_app.command("openai-compatible") def provider_embeddings_openai_command( @@ -2921,9 +3722,17 @@ def provider_embeddings_openai_command( model: str = typer.Option(..., "--model", help="Embedding model id."), dimensions: str = typer.Option(..., "--dimensions", help="Embedding vector dimensions."), api_key: str | None = typer.Option(None, "--api-key", help="Embedding API key."), - secret_env_var: str | None = typer.Option(None, "--secret-env-var", help="Environment variable name for the embedding provider key."), + secret_env_var: str | None = typer.Option( + None, + "--secret-env-var", + help="Environment variable name for the embedding provider key.", + ), ) -> None: - params = ctx.parent.parent.parent.params if ctx.parent and ctx.parent.parent and ctx.parent.parent.parent else ctx.params + params = ( + ctx.parent.parent.parent.params + if ctx.parent and ctx.parent.parent and ctx.parent.parent.parent + else ctx.params + ) runtime = _cli_runtime(params["state_dir"]) args = _namespace( provider_command="embeddings", @@ -2996,7 +3805,14 @@ def herd_delete_command( runtime = _cli_runtime(params["state_dir"]) try: raise typer.Exit( - _run_herd(runtime, _namespace(herd_command="delete", elephant_id=elephant_id, delete_all=delete_all)) + _run_herd( + runtime, + _namespace( + herd_command="delete", + elephant_id=elephant_id, + delete_all=delete_all, + ), + ) ) except ValueError as error: raise typer.BadParameter(str(error)) from error @@ -3011,7 +3827,11 @@ def facts_callback(ctx: typer.Context) -> None: @facts_app.command("list") def facts_list_command( ctx: typer.Context, - elephant_id: str | None = typer.Option(None, "--elephant-id", help="Resolve Personal Model facts through a named elephant."), + elephant_id: str | None = typer.Option( + None, + "--elephant-id", + help="Resolve Personal Model facts through a named elephant.", + ), ) -> None: params = ctx.parent.parent.params if ctx.parent and ctx.parent.parent else ctx.params runtime = _cli_runtime(params["state_dir"]) @@ -3021,8 +3841,16 @@ def facts_list_command( def facts_delete_command( ctx: typer.Context, fact_id: str = typer.Argument(..., help="Name the Personal Model entry to retire."), - elephant_id: str | None = typer.Option(None, "--elephant-id", help="Resolve Personal Model facts through a named elephant."), - reason: str | None = typer.Option(None, "--reason", help="Record why this Personal Model entry is being retired."), + elephant_id: str | None = typer.Option( + None, + "--elephant-id", + help="Resolve Personal Model facts through a named elephant.", + ), + reason: str | None = typer.Option( + None, + "--reason", + help="Record why this Personal Model entry is being retired.", + ), ) -> None: params = ctx.parent.parent.params if ctx.parent and ctx.parent.parent else ctx.params runtime = _cli_runtime(params["state_dir"]) @@ -3030,7 +3858,12 @@ def facts_delete_command( raise typer.Exit( _run_facts( runtime, - _namespace(facts_command="delete", elephant_id=elephant_id, fact_id=fact_id, reason=reason), + _namespace( + facts_command="delete", + elephant_id=elephant_id, + fact_id=fact_id, + reason=reason, + ), ) ) except ValueError as error: @@ -3046,7 +3879,12 @@ def reflect_callback( params = ctx.parent.params if ctx.parent is not None else ctx.params runtime = _cli_runtime(params["state_dir"]) try: - raise typer.Exit(_run_learn(runtime, _namespace(learn_command="list", elephant_id=elephant_id, limit=limit))) + raise typer.Exit( + _run_learn( + runtime, + _namespace(learn_command="list", elephant_id=elephant_id, limit=limit), + ) + ) except ValueError as error: raise typer.BadParameter(str(error)) from error @@ -3064,10 +3902,22 @@ def reflect_list_command( def reflect_run_command( ctx: typer.Context, elephant_id: str | None = typer.Option(None, "--elephant-id", help="Run reflect for a named elephant."), - features: str | None = typer.Option(None, "--features", help="Comma-separated feature set (pm,questions,dream,diary,skills,recall,compress)."), - date: str | None = typer.Option(None, "--date", help="Target date for dream/diary feature (YYYY-MM-DD). Defaults to today for dream and yesterday for diary."), + features: str | None = typer.Option( + None, + "--features", + help="Comma-separated feature set (pm,questions,dream,diary,skills,recall,compress).", + ), + date: str | None = typer.Option( + None, + "--date", + help="Target date for dream/diary feature (YYYY-MM-DD). Defaults to today for dream and yesterday for diary.", + ), wait: bool = typer.Option(False, "--wait", help="Wait for the reflect agent to finish."), - install_cron: bool = typer.Option(False, "--install-cron", help="Install the built-in nightly Dream learning cron job."), + install_cron: bool = typer.Option( + False, + "--install-cron", + help="Install the built-in nightly Dream learning cron job.", + ), ) -> None: """Run a reflect agent with the specified features.""" from datetime import date as date_type, timedelta @@ -3082,13 +3932,19 @@ def reflect_run_command( cron_label = "Nightly dream cron job installed." else: if "dream" not in requested_features: - raise typer.BadParameter("--install-cron only installs the dream feature; diary remains manual-only outside Dream") + raise typer.BadParameter( + "--install-cron only installs the dream feature; diary remains manual-only outside Dream" + ) _ensure_dream_cron(runtime) cron_label = "Nightly dream cron job installed." _print_cli_card( "Elephant Agent learning cron", cron_label, - next_commands=("elephant reflect run --features dream --date ", "elephant reflect run --features diary --date ", "elephant cron list"), + next_commands=( + "elephant reflect run --features dream --date ", + "elephant reflect run --features diary --date ", + "elephant cron list", + ), ) if not features: raise typer.Exit(0) @@ -3125,7 +3981,14 @@ def reflect_run_command( worker_exit_code = 0 if wait: completed = subprocess.run( - (sys.executable, "-m", "apps.learning_worker_command", "--state-dir", str(runtime.paths.state_dir), "--once"), + ( + sys.executable, + "-m", + "apps.learning_worker_command", + "--state-dir", + str(runtime.paths.state_dir), + "--once", + ), check=False, ) worker_exit_code = int(completed.returncode or 0) @@ -3134,12 +3997,15 @@ def reflect_run_command( "Elephant Agent reflect", f"Reflect agent {'completed' if wait else 'queued'}.", sections=( - CliCardSection("Job", ( - f"job_id · {job.job_id}", - f"trigger · {trigger}", - f"features · {features or '(trigger default)'}", - f"status · {worker_line}", - )), + CliCardSection( + "Job", + ( + f"job_id · {job.job_id}", + f"trigger · {trigger}", + f"features · {features or '(trigger default)'}", + f"status · {worker_line}", + ), + ), ), next_commands=("elephant reflect list",), ) diff --git a/apps/cli/cli_main_setup.py b/apps/cli/cli_main_setup.py index 972cbd9..c745b6f 100644 --- a/apps/cli/cli_main_setup.py +++ b/apps/cli/cli_main_setup.py @@ -3,23 +3,15 @@ from __future__ import annotations import argparse -from dataclasses import dataclass import os -import random -import re import select import sys import time -from collections.abc import Iterable -from pathlib import Path -from packages.state import DEFAULT_ELEPHANT_IDENTITY_TEXT, render_default_elephant_identity from .runtime import CliRuntime from .provider_flow import ( ProviderSelectionState, - provider_choices as _shared_provider_choices, - provider_setup_defaults, run_provider_selection_wizard, ) from .shell import ( @@ -32,7 +24,6 @@ Console, Group, Panel, - ProductizedShell, RICH_AVAILABLE, Table, Text, @@ -42,11 +33,8 @@ from .wizard import ( WIZARD_BACK, WIZARD_CANCEL, - WizardChoice, _WizardBackSignal, _interactive_shell_supported, - _wizard_choice_prompt, - _wizard_dialogs_supported, _wizard_text_prompt, ) from .shell_stack import Live @@ -101,9 +89,9 @@ ) - from .cli_main_support import * # noqa: F401,F403 + def _default_personality_preset(runtime: CliRuntime, *, mode: str, current: str | None = None) -> str | None: if mode != "companion": return None @@ -114,9 +102,13 @@ def _default_personality_preset(runtime: CliRuntime, *, mode: str, current: str return preset.preset_id return runtime.personality_presets()[0].preset_id + def _print_birth_wizard_intro() -> None: if not RICH_AVAILABLE or Table is None or Panel is None or Group is None: - _print_heading("Elephant Agent Init", "Start from you, then choose the first elephant and model path.") + _print_heading( + "Elephant Agent Init", + "Start from you, then choose the first elephant and model path.", + ) for line in INIT_REFLECTION_LINES: _print_bullet(line) return @@ -161,13 +153,15 @@ def _print_birth_wizard_intro() -> None: logo_block.add_row(_center_brand_block(render_stage_zero_elephant_mark())) layout.add_row(_center_brand_block(logo_block), Text(" "), questions, Text(" "), flow) console.print( - _center_intro_window(Panel( - layout, - title=f"[bold {BRAND_ACCENT}]Elephant Agent Init · Stage 0 → first wake · v{_resolve_elephant_version()}[/bold {BRAND_ACCENT}]", - border_style=BRAND_ACCENT, - expand=False, - padding=(1, 2), - )) + _center_intro_window( + Panel( + layout, + title=f"[bold {BRAND_ACCENT}]Elephant Agent Init · Stage 0 → first wake · v{_resolve_elephant_version()}[/bold {BRAND_ACCENT}]", + border_style=BRAND_ACCENT, + expand=False, + padding=(1, 2), + ) + ) ) @@ -245,7 +239,9 @@ def _print_birth_wizard_intro() -> None: ) -def _init_welcome_variant(variant_index: int) -> tuple[str, str, str, str, tuple[str, ...], str]: +def _init_welcome_variant( + variant_index: int, +) -> tuple[str, str, str, str, tuple[str, ...], str]: variant = _INIT_WELCOME_VARIANTS[variant_index % len(_INIT_WELCOME_VARIANTS)] return ( str(variant["title"]), @@ -268,12 +264,7 @@ def _init_welcome_elephant_mark(): return mark plain = getattr(mark, "plain", "") rows = plain.splitlines() - visible_cells = [ - index - for row in rows - for index, cell in enumerate(row) - if cell != " " - ] + visible_cells = [index for row in rows for index, cell in enumerate(row) if cell != " "] if not rows or not visible_cells: return mark visible_left = min(visible_cells) @@ -304,20 +295,25 @@ def _init_welcome_frame(variant_index: int): copy.append(prefix, style=style) copy.append("Elephant Agent", style=f"bold {BRAND_LIGHT}") copy.append(suffix + "\n", style=style) - indicator = " ".join("●" if index == variant_index % len(_INIT_WELCOME_VARIANTS) else "·" for index in range(len(_INIT_WELCOME_VARIANTS))) + indicator = " ".join( + "●" if index == variant_index % len(_INIT_WELCOME_VARIANTS) else "·" + for index in range(len(_INIT_WELCOME_VARIANTS)) + ) copy.append("\n" + indicator + "\n", style=BRAND_MUTED) copy.append(enter + "\n", style=f"bold {BRAND_LIGHT}") body.add_row(_center_brand_block(copy)) - return _center_intro_window(Panel( - body, - subtitle=f"[bold {BRAND_ACCENT}]Create yours[/bold {BRAND_ACCENT}]", - subtitle_align="center", - border_style=BRAND_DARK, - expand=True, - padding=(1, 3), - width=92, - height=28, - )) + return _center_intro_window( + Panel( + body, + subtitle=f"[bold {BRAND_ACCENT}]Create yours[/bold {BRAND_ACCENT}]", + subtitle_align="center", + border_style=BRAND_DARK, + expand=True, + padding=(1, 3), + width=92, + height=28, + ) + ) def _prompt_init_welcome_gate() -> bool: @@ -381,7 +377,12 @@ def _center_intro_window(renderable): _, height = _intro_console_size() try: if height > 0: - return Align(renderable, align="center", vertical="middle", height=max(22, height - 1)) + return Align( + renderable, + align="center", + vertical="middle", + height=max(22, height - 1), + ) return Align.center(renderable, vertical="middle") except TypeError: return Align.center(renderable) @@ -399,6 +400,7 @@ def _prompt_first_elephant_name(default_name: str, *, allow_back: bool = False) allow_back=allow_back, ) + def _run_interactive_elephant_wizard( runtime: CliRuntime, *, @@ -415,6 +417,7 @@ def _run_interactive_elephant_wizard( return None return str(answer).strip() or current_elephant_name + def _run_interactive_birth_wizard( runtime: CliRuntime, *, @@ -475,6 +478,7 @@ def _run_interactive_birth_wizard( continue return state + def _print_birth_paused() -> None: _print_cli_card( "Elephant Agent birth paused", @@ -482,6 +486,7 @@ def _print_birth_paused() -> None: next_commands=("elephant init", "elephant status"), ) + def _gateway_birth_lines(elephant_name: str) -> tuple[str, ...]: return ( "wire IM · elephant gateway setup", @@ -490,6 +495,7 @@ def _gateway_birth_lines(elephant_name: str) -> tuple[str, ...]: "launch operator dashboard · elephant daemon start && elephant dashboard", ) + def _prompt_im_onboarding(runtime: CliRuntime, *, elephant_name: str) -> None: from apps.gateway.__main__ import run_im_setup @@ -501,6 +507,7 @@ def _prompt_im_onboarding(runtime: CliRuntime, *, elephant_name: str) -> None: allow_skip=True, ) + def _print_overview(runtime: CliRuntime) -> None: provider = dict(runtime.provider_summary()) doctor = runtime.provider_doctor() @@ -515,23 +522,59 @@ def _print_overview(runtime: CliRuntime) -> None: capability = Text("You · Threads · Herd · Skills · Providers", style=BRAND_MUTED) action_lines = Text() action_lines.append("Start\n", style=f"bold {BRAND_ACCENT}") - action_lines.append(f"{_format_command_line('elephant wake', 'continue the active thread')}\n", style=BRAND_LIGHT) - action_lines.append(f"{_format_command_line('elephant init', 'set name, provider, model, and recall path')}\n", style=BRAND_LIGHT) - action_lines.append(f"{_format_command_line('elephant herd new ', 'create another named continuity thread')}\n", style=BRAND_LIGHT) - action_lines.append(f"{_format_command_line('elephant herd', 'inspect named continuity threads')}\n", style=BRAND_LIGHT) - action_lines.append(f"{_format_command_line('elephant dashboard', 'open the continuity console')}\n", style=BRAND_LIGHT) + action_lines.append( + f"{_format_command_line('elephant wake', 'continue the active thread')}\n", + style=BRAND_LIGHT, + ) + action_lines.append( + f"{_format_command_line('elephant init', 'set name, provider, model, and recall path')}\n", + style=BRAND_LIGHT, + ) + action_lines.append( + f"{_format_command_line('elephant herd new ', 'create another named continuity thread')}\n", + style=BRAND_LIGHT, + ) + action_lines.append( + f"{_format_command_line('elephant herd', 'inspect named continuity threads')}\n", + style=BRAND_LIGHT, + ) + action_lines.append( + f"{_format_command_line('elephant dashboard', 'open the continuity console')}\n", + style=BRAND_LIGHT, + ) action_lines.append("\nSystem controls\n", style=f"bold {BRAND_ACCENT}") - action_lines.append(f"{_format_command_line('elephant provider', 'manage models, keys, context, and embeddings')}\n", style=BRAND_LIGHT) - action_lines.append(f"{_format_command_line('elephant skills', 'inspect, install, search, and toggle skills')}\n", style=BRAND_LIGHT) - action_lines.append(f"{_format_command_line('elephant gateway', 'bind messenger surfaces')}\n", style=BRAND_LIGHT) - action_lines.append(f"{_format_command_line('elephant status', 'check provider and recall readiness')}\n", style=BRAND_LIGHT) + action_lines.append( + f"{_format_command_line('elephant provider', 'manage models, keys, context, and embeddings')}\n", + style=BRAND_LIGHT, + ) + action_lines.append( + f"{_format_command_line('elephant skills', 'inspect, install, search, and toggle skills')}\n", + style=BRAND_LIGHT, + ) + action_lines.append( + f"{_format_command_line('elephant gateway', 'bind messenger surfaces')}\n", + style=BRAND_LIGHT, + ) + action_lines.append( + f"{_format_command_line('elephant status', 'check provider and recall readiness')}\n", + style=BRAND_LIGHT, + ) action_lines.append("\nCurrent install\n", style=f"bold {BRAND_ACCENT}") - action_lines.append(f"readiness · {doctor['status']}\n", style=BRAND_MUTED if doctor["status"] != "ready" else BRAND_LIGHT) + action_lines.append( + f"readiness · {doctor['status']}\n", + style=BRAND_MUTED if doctor["status"] != "ready" else BRAND_LIGHT, + ) action_lines.append(f"provider · {provider['provider_id']}\n", style=BRAND_MUTED) if provider.get("model_id") or provider.get("default_model"): - action_lines.append(f"model · {provider.get('model_id') or provider.get('default_model')}\n", style=BRAND_MUTED) + action_lines.append( + f"model · {provider.get('model_id') or provider.get('default_model')}\n", + style=BRAND_MUTED, + ) if herd: - action_lines.append("states · " + ", ".join(elephant.elephant_id for elephant in herd), style=BRAND_MUTED) + action_lines.append( + "states · " + ", ".join(elephant.elephant_id for elephant in herd), + style=BRAND_MUTED, + ) else: action_lines.append("states · none yet", style=BRAND_MUTED) brand.add_row(_center_brand_block(headline)) @@ -581,13 +624,18 @@ def _print_overview(runtime: CliRuntime) -> None: _print_field("provider", provider["provider_id"]) if provider.get("model_id") or provider.get("default_model"): _print_field("model", provider.get("model_id") or provider.get("default_model")) - _print_field("states", ", ".join(elephant.elephant_id for elephant in herd) if herd else "none yet") + _print_field( + "states", + ", ".join(elephant.elephant_id for elephant in herd) if herd else "none yet", + ) + def _center_brand_block(renderable): if Align is None: return renderable return Align.center(renderable) + def _print_setup_intro(runtime: CliRuntime, *, provider_id: str) -> None: guide = runtime.provider_setup_guide(provider_id) loaded = runtime.current_profile() @@ -615,6 +663,7 @@ def _print_setup_intro(runtime: CliRuntime, *, provider_id: str) -> None: ), ) + def _default_born_args() -> argparse.Namespace: return argparse.Namespace( provider_id=DEFAULT_PROVIDER_ID, @@ -638,6 +687,7 @@ def _default_born_args() -> argparse.Namespace: non_interactive=False, ) + def _default_grow_args() -> argparse.Namespace: return argparse.Namespace( elephant_id=None, @@ -645,6 +695,7 @@ def _default_grow_args() -> argparse.Namespace: message=None, ) + def _ensure_elephant_ready( runtime: CliRuntime, *, @@ -663,6 +714,7 @@ def _ensure_elephant_ready( ) return session, "created" + __all__ = [ "DEFAULT_PROVIDER_ID", "DEFAULT_ELEPHANT_NAME_SUGGESTIONS", diff --git a/apps/cli/cli_main_support.py b/apps/cli/cli_main_support.py index 56f4669..952f5ae 100644 --- a/apps/cli/cli_main_support.py +++ b/apps/cli/cli_main_support.py @@ -4,45 +4,28 @@ from dataclasses import dataclass import random import re -import sys -from collections.abc import Iterable, Mapping +from collections.abc import Mapping from pathlib import Path -from packages.state import DEFAULT_ELEPHANT_IDENTITY_TEXT, render_default_elephant_identity from .runtime import CliRuntime from .provider_flow import ( - ProviderSelectionState, provider_choices as _shared_provider_choices, - provider_setup_defaults, - run_provider_selection_wizard, ) _provider_choices = _shared_provider_choices from .shell import ( Align, BRAND_ACCENT, - BRAND_DARK, BRAND_LIGHT, BRAND_MUTED, Console, Group, Panel, - ProductizedShell, RICH_AVAILABLE, - Table, Text, - _resolve_elephant_version, render_stage_zero_elephant_mark, ) -from .wizard import ( - WIZARD_BACK, - WizardChoice, - _WizardBackSignal, - _wizard_choice_prompt, - _wizard_dialogs_supported, - _wizard_text_prompt, -) try: from .shell_ui import BRAND_ACCENT_STRONG @@ -87,17 +70,38 @@ CLI_THEME_SUBTITLE = "Personal Model first, curious at your pace." CLI_HELP_TAGLINE = "🐘 Model what matters · 👂 Ask gently · 🐾 Follow the path" CLI_HELP_COMMANDS = ( - ("init", "Run first-time setup and persist identity, provider readiness, and the first elephant session."), - ("wake", "Enter an existing Elephant Agent elephant through the branded TUI or run one provider-backed turn."), - ("dashboard", "Launch the local operator dashboard when frontend assets are present."), + ( + "init", + "Run first-time setup and persist identity, provider readiness, and the first elephant session.", + ), + ( + "wake", + "Enter an existing Elephant Agent elephant through the branded TUI or run one provider-backed turn.", + ), + ( + "dashboard", + "Launch the local operator dashboard when frontend assets are present.", + ), ("herd", "Create, inspect, select, or retire existing Elephant Agent herd."), - ("provider", "Configure or inspect the active provider, model, reasoning effort, and context window."), + ( + "provider", + "Configure or inspect the active provider, model, reasoning effort, and context window.", + ), ("facts", "Inspect or retire Personal Model facts without entering wake."), - ("reflect", "Run, inspect, and manage background reflect agents (PM learning, dream, diary, audit)."), - ("skills", "Inspect, search, install, and toggle skill packages without entering wake."), + ( + "reflect", + "Run, inspect, and manage background reflect agents (PM learning, dream, diary, audit).", + ), + ( + "skills", + "Inspect, search, install, and toggle skill packages without entering wake.", + ), ("gateway", "Manage IM providers and accounts."), ("cron", "Manage the background cron scheduler."), - ("status", "Review provider, model, and security readiness before opening the wake surface."), + ( + "status", + "Review provider, model, and security readiness before opening the wake surface.", + ), ) CLI_COMMAND_HELP = {command: description for command, description in CLI_HELP_COMMANDS} CLI_HELP_NEXT_COMMANDS = ("elephant init", "elephant wake", "elephant dashboard") @@ -122,6 +126,7 @@ class CliCardSection: title: str lines: tuple[str, ...] = () + class _WizardCancelledError(Exception): __slots__ = ("surface",) @@ -129,6 +134,7 @@ def __init__(self, surface: str) -> None: super().__init__(surface) self.surface = surface + @dataclass(slots=True) class BirthWizardState: display_name: str @@ -166,6 +172,7 @@ class BirthWizardState: relationship_mode: str = "" starter_answers: tuple[tuple[str, str, str], ...] = () + def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( prog="elephant", @@ -179,7 +186,9 @@ def build_parser() -> argparse.ArgumentParser: def _add_init_parser(name: str, *, hidden: bool = False) -> None: init = subparsers.add_parser( name, - help=argparse.SUPPRESS if hidden else "Run first-time setup and persist identity, provider readiness, and the first elephant session.", + help=argparse.SUPPRESS + if hidden + else "Run first-time setup and persist identity, provider readiness, and the first elephant session.", ) init.add_argument("--provider-id", default=DEFAULT_PROVIDER_ID) init.add_argument("--display-name", default=None) @@ -189,7 +198,11 @@ def _add_init_parser(name: str, *, hidden: bool = False) -> None: init.add_argument("--model-id", default=None) init.add_argument("--api-key", default=None) init.add_argument("--secret-env-var", default=None) - init.add_argument("--embedding-provider", choices=("local", "openai-compatible"), default="local") + init.add_argument( + "--embedding-provider", + choices=("local", "openai-compatible"), + default="local", + ) init.add_argument("--embedding-base-url", default=None) init.add_argument("--embedding-model", default=None) init.add_argument("--embedding-dimensions", default=None) @@ -213,7 +226,9 @@ def _add_init_parser(name: str, *, hidden: bool = False) -> None: def _add_status_parser(name: str, *, hidden: bool = False) -> None: status = subparsers.add_parser( name, - help=argparse.SUPPRESS if hidden else "Review provider, model, and security readiness before opening the wake surface.", + help=argparse.SUPPRESS + if hidden + else "Review provider, model, and security readiness before opening the wake surface.", ) status.add_argument( "--deep", @@ -224,7 +239,9 @@ def _add_status_parser(name: str, *, hidden: bool = False) -> None: def _add_provider_parser(name: str, *, hidden: bool = False) -> None: provider = subparsers.add_parser( name, - help=argparse.SUPPRESS if hidden else "Configure or inspect the active provider, model, reasoning effort, and context window.", + help=argparse.SUPPRESS + if hidden + else "Configure or inspect the active provider, model, reasoning effort, and context window.", ) provider.add_argument( "provider_command", @@ -249,10 +266,20 @@ def _add_provider_parser(name: str, *, hidden: bool = False) -> None: def _add_wake_parser(name: str, *, hidden: bool = False) -> None: wake = subparsers.add_parser( name, - help=argparse.SUPPRESS if hidden else "Open the next Elephant Agent Episode through the branded TUI or run one provider-backed turn.", + help=argparse.SUPPRESS + if hidden + else "Open the next Elephant Agent Episode through the branded TUI or run one provider-backed turn.", + ) + wake.add_argument( + "--elephant-id", + default=None, + help="Open the next Episode for a known elephant.", + ) + wake.add_argument( + "--debug", + action="store_true", + help="Show runtime diagnostics inside the wake surface.", ) - wake.add_argument("--elephant-id", default=None, help="Open the next Episode for a known elephant.") - wake.add_argument("--debug", action="store_true", help="Show runtime diagnostics inside the wake surface.") wake.add_argument("--message", default=None, help="Run one wake turn and exit.") _add_init_parser("init") @@ -271,7 +298,11 @@ def _add_wake_parser(name: str, *, hidden: bool = False) -> None: elephant_new.add_argument("elephant_name", nargs="?", help="Name the new Elephant Agent elephant.") elephant_new.add_argument("--profile-id", default=None) elephant_new.add_argument("--display-name", default=None) - elephant_new.add_argument("--debug", action="store_true", help="Show runtime diagnostics inside the wake surface.") + elephant_new.add_argument( + "--debug", + action="store_true", + help="Show runtime diagnostics inside the wake surface.", + ) elephant_new.add_argument("--message", default=None, help="Create the elephant, run one turn, and exit.") elephant_current = herd_subparsers.add_parser( "current", @@ -293,39 +324,59 @@ def _add_wake_parser(name: str, *, hidden: bool = False) -> None: "facts", help="Inspect or retire Personal Model facts without entering wake.", ) - facts.add_argument("--elephant-id", default=None, help="Resolve Personal Model facts through a named elephant.") + facts.add_argument( + "--elephant-id", + default=None, + help="Resolve Personal Model facts through a named elephant.", + ) facts_subparsers = facts.add_subparsers(dest="facts_command") facts_list = facts_subparsers.add_parser( "list", help="List Personal Model facts for the current or named elephant.", ) - facts_list.add_argument("--elephant-id", default=None, help="Resolve Personal Model facts through a named elephant.") + facts_list.add_argument( + "--elephant-id", + default=None, + help="Resolve Personal Model facts through a named elephant.", + ) facts_delete = facts_subparsers.add_parser( "delete", help="Retire one Personal Model entry by id.", ) facts_delete.add_argument("fact_id", help="Name the Personal Model entry to retire.") - facts_delete.add_argument("--elephant-id", default=None, help="Resolve Personal Model facts through a named elephant.") - facts_delete.add_argument("--reason", default=None, help="Record why this Personal Model entry is being retired.") + facts_delete.add_argument( + "--elephant-id", + default=None, + help="Resolve Personal Model facts through a named elephant.", + ) + facts_delete.add_argument( + "--reason", + default=None, + help="Record why this Personal Model entry is being retired.", + ) _add_wake_parser("wake") return parser + def _print_heading(title: str, detail: str | None = None) -> None: print(f"{CLI_THEME_TITLE_GLYPH} {title}") if detail: print(f" {detail}") + def _print_field(label: str, value: object) -> None: rendered = "" if value is not None: rendered = str(value) print(f" {label}: {rendered}") + def _print_bullet(text: str) -> None: print(f" {CLI_THEME_BULLET} {text}") + def _command_hint_glyph(command: str) -> str: normalized = " ".join(command.split()).strip() for prefix, glyph in CLI_COMMAND_GLYPHS: @@ -333,9 +384,11 @@ def _command_hint_glyph(command: str) -> str: return glyph return CLI_THEME_BULLET + def _format_command_hint(command: str) -> str: return f"{_command_hint_glyph(command)} {command}" + def _format_command_line(command: str, detail: str) -> str: return f"{_command_hint_glyph(command)} {command} · {detail}" @@ -348,7 +401,9 @@ def _append_command_highlight(target: Text, line: str) -> None: marker = " · " command_part, separator, detail_part = line.partition(marker) leading_token = command_part.split(maxsplit=1)[0] if command_part else "" - has_command_glyph = leading_token == CLI_THEME_BULLET or any(leading_token == glyph for _, glyph in CLI_COMMAND_GLYPHS) + has_command_glyph = leading_token == CLI_THEME_BULLET or any( + leading_token == glyph for _, glyph in CLI_COMMAND_GLYPHS + ) if not has_command_glyph: target.append(f"{CLI_THEME_BULLET} ", style=BRAND_MUTED) if separator: @@ -358,9 +413,11 @@ def _append_command_highlight(target: Text, line: str) -> None: else: target.append(line, style=BRAND_LIGHT) + def _print_command_line(command: str, detail: str) -> None: print(f" {_format_command_line(command, detail)}") + def _print_command_hints(*commands: str) -> None: if not commands: return @@ -368,6 +425,7 @@ def _print_command_hints(*commands: str) -> None: for command in commands: print(f" {_format_command_hint(command)}") + def _print_cli_card( title: str, detail: str | None = None, @@ -438,9 +496,7 @@ def _print_cli_help( next_commands: tuple[str, ...] = (), tagline: str | None = None, ) -> None: - intro = ( - "Elephant Agent is personal-model-first AI — it grows from you, understands first, gets curious at your pace, and grows into your shape over time." - ) + intro = "Elephant Agent is personal-model-first AI — it grows from you, understands first, gets curious at your pace, and grows into your shape over time." sections: list[CliCardSection] = [CliCardSection("Elephant Agent", (intro,))] if options: sections.append( @@ -468,12 +524,13 @@ def _print_cli_help( def _play_creating_transition(title: str, detail: str) -> None: return None + def _provider_secret_ready(runtime: CliRuntime, *, provider_id: str) -> bool: provider_summary = dict(runtime.provider_summary()) - if ( - provider_summary.get("provider_id") == provider_id - and provider_summary.get("secret_status") in {"stored", "not-required"} - ): + if provider_summary.get("provider_id") == provider_id and provider_summary.get("secret_status") in { + "stored", + "not-required", + }: return True try: discovered = runtime.discovered_provider(provider_id) @@ -497,7 +554,9 @@ def _embedding_bootstrap_ready_label(status: object) -> str: return normalized or "unknown" -def _embedding_bootstrap_status_lines(embedding: Mapping[str, object]) -> tuple[str, ...]: +def _embedding_bootstrap_status_lines( + embedding: Mapping[str, object], +) -> tuple[str, ...]: status = str(embedding.get("embedding_bootstrap_status") or "") summary = str(embedding.get("embedding_bootstrap_summary") or "").strip() lines = [ @@ -509,7 +568,9 @@ def _embedding_bootstrap_status_lines(embedding: Mapping[str, object]) -> tuple[ return tuple(lines) -def _embedding_bootstrap_notice_lines(embedding: Mapping[str, object]) -> tuple[str, ...]: +def _embedding_bootstrap_notice_lines( + embedding: Mapping[str, object], +) -> tuple[str, ...]: status = str(embedding.get("embedding_bootstrap_status") or "").strip().lower() source = str(embedding.get("source") or "").strip().lower() if source != "local-default": @@ -532,6 +593,7 @@ def _embedding_bootstrap_notice_lines(embedding: Mapping[str, object]) -> tuple[ ) return () + def _print_brain_status(runtime: CliRuntime) -> None: provider = dict(runtime.provider_summary()) embedding = dict(runtime.embedding_provider_summary()) @@ -594,6 +656,7 @@ def _print_brain_status(runtime: CliRuntime) -> None: ), ) + def _print_brain_provider_inventory(runtime: CliRuntime) -> None: lines = tuple( f"{state.provider_id} · {state.display_name} · {state.transport_display_name} · status={state.status} · source={state.source}" @@ -607,6 +670,7 @@ def _print_brain_provider_inventory(runtime: CliRuntime) -> None: next_commands=("elephant provider", "elephant provider status"), ) + def _print_brain_models(runtime: CliRuntime, *, provider_id: str) -> None: try: models = runtime.discover_provider_models(provider_id=provider_id) @@ -628,6 +692,7 @@ def _print_brain_models(runtime: CliRuntime, *, provider_id: str) -> None: next_commands=("elephant provider", "elephant provider status"), ) + def _print_embedding_provider_status(runtime: CliRuntime) -> None: embedding = dict(runtime.embedding_provider_summary()) sections = ( @@ -644,7 +709,11 @@ def _print_embedding_provider_status(runtime: CliRuntime) -> None: *_embedding_bootstrap_status_lines(embedding), ), ), - *((CliCardSection("Background bootstrap", _embedding_bootstrap_notice_lines(embedding)),) if _embedding_bootstrap_notice_lines(embedding) else ()), + *( + (CliCardSection("Background bootstrap", _embedding_bootstrap_notice_lines(embedding)),) + if _embedding_bootstrap_notice_lines(embedding) + else () + ), ) _print_cli_card( "Embedding provider status", @@ -657,22 +726,23 @@ def _print_embedding_provider_status(runtime: CliRuntime) -> None: ), ) + def _slugify_elephant_name(value: str) -> str: collapsed = re.sub(r"[^a-zA-Z0-9]+", "-", value.strip().lower()).strip("-") return collapsed or "elephant" + def _display_name_from_elephant_name(value: str) -> str: collapsed = re.sub(r"[^a-zA-Z0-9]+", " ", value.strip()).strip() return collapsed.title() or "Elephant Agent" + def _suggest_elephant_name(runtime: CliRuntime | None = None) -> str: candidates = DEFAULT_ELEPHANT_NAME_SUGGESTIONS if runtime is None: return random.choice(candidates) available = tuple( - name - for name in candidates - if runtime.latest_session_for_elephant(_slugify_elephant_name(name)) is None + name for name in candidates if runtime.latest_session_for_elephant(_slugify_elephant_name(name)) is None ) return random.choice(available or candidates) @@ -686,6 +756,7 @@ def _unique_elephant_name(runtime: CliRuntime, value: str) -> str: suffix += 1 return candidate + __all__ = [ "_provider_choices", "DEFAULT_PROVIDER_ID", diff --git a/apps/cli/provider_flow.py b/apps/cli/provider_flow.py index 0fb4f84..6708a5a 100644 --- a/apps/cli/provider_flow.py +++ b/apps/cli/provider_flow.py @@ -51,10 +51,7 @@ def provider_setup_defaults(runtime: CliRuntime, provider_id: str) -> ProviderSe summary = dict(runtime.provider_summary()) same_provider = str(summary.get("provider_id", "")).strip().lower() == normalized_provider_id base_url = str( - (summary.get("base_url") if same_provider else None) - or discovered.base_url - or guide.suggested_base_url - or "" + (summary.get("base_url") if same_provider else None) or discovered.base_url or guide.suggested_base_url or "" ).strip() model_id = str( (summary.get("model_id") if same_provider else None) @@ -68,10 +65,7 @@ def provider_setup_defaults(runtime: CliRuntime, provider_id: str) -> ProviderSe context_window_tokens = int(summary["context_window_tokens"]) except (TypeError, ValueError): context_window_tokens = None - context_window_mode = str( - (summary.get("context_window_mode") if same_provider else None) - or "auto" - ) + context_window_mode = str((summary.get("context_window_mode") if same_provider else None) or "auto") reasoning_effort = ( str(summary.get("reasoning_effort")).strip() if same_provider and summary.get("reasoning_effort") is not None @@ -101,7 +95,9 @@ def _manual_model_default(provider_id: str, model_id: str | None) -> str | None: def _should_retry_provider_key_on_model_discovery_failure(provider_id: str, auth_type: str) -> bool: normalized_provider = str(provider_id or "").strip().lower() normalized_auth_type = str(auth_type or "").strip().lower() - return normalized_auth_type == "api_key" and normalized_provider not in _MODEL_DISCOVERY_KEY_RETRY_EXCLUDED_PROVIDERS + return ( + normalized_auth_type == "api_key" and normalized_provider not in _MODEL_DISCOVERY_KEY_RETRY_EXCLUDED_PROVIDERS + ) def _choose_model( @@ -126,7 +122,11 @@ def _choose_model( if not models and _should_retry_provider_key_on_model_discovery_failure(state.provider_id, auth_type): refreshed_key = _wizard_text_prompt( _pf_text(language, "Refresh The Provider Key", "重新输入模型服务密钥"), - _pf_text(language, "Elephant Agent could not read the provider model catalog. Re-enter the provider key so it can retry live model discovery.", "Elephant Agent 读取不到模型列表。请重新输入密钥,让它重试实时模型发现。"), + _pf_text( + language, + "Elephant Agent could not read the provider model catalog. Re-enter the provider key so it can retry live model discovery.", + "Elephant Agent 读取不到模型列表。请重新输入密钥,让它重试实时模型发现。", + ), allow_back=allow_back, password=True, ) @@ -156,17 +156,23 @@ def _choose_model( WizardChoice( value=MANUAL_MODEL_SENTINEL, label=_pf_text(language, "Manual model id", "手动输入模型 ID"), - detail=_pf_text(language, "Type a model id that is not advertised by the provider catalog.", "输入模型列表里没有展示的模型 ID。"), + detail=_pf_text( + language, + "Type a model id that is not advertised by the provider catalog.", + "输入模型列表里没有展示的模型 ID。", + ), ), ) default_value = ( - state.model_id - if any(model.model_id == state.model_id for model in models) - else models[0].model_id + state.model_id if any(model.model_id == state.model_id for model in models) else models[0].model_id ) answer = _wizard_choice_prompt( _pf_text(language, "Choose The Model", "选择模型"), - _pf_text(language, "Pick the model Elephant Agent should use from this provider endpoint.", "从这个服务端点里选择 Elephant Agent 要使用的模型。"), + _pf_text( + language, + "Pick the model Elephant Agent should use from this provider endpoint.", + "从这个服务端点里选择 Elephant Agent 要使用的模型。", + ), model_choices, default=default_value, allow_back=allow_back, @@ -259,7 +265,11 @@ def _go_back() -> bool: if step == "provider_id": answer = _wizard_choice_prompt( _pf_text(language, "Choose A Provider", "选择模型服务"), - _pf_text(language, "Where should Elephant Agent think from next?", "Elephant Agent 接下来应该从哪里思考?"), + _pf_text( + language, + "Where should Elephant Agent think from next?", + "Elephant Agent 接下来应该从哪里思考?", + ), provider_choices(runtime), default=state.provider_id, allow_back=allow_back, @@ -285,13 +295,13 @@ def _go_back() -> bool: state_base_url = str(state.base_url or "").strip() supports_custom_base_url = "base_url" in guide.required_config_keys known_base_url = summary_base_url if same_provider and summary_base_url else discovered_base_url - same_endpoint = ( - not supports_custom_base_url - or (bool(state_base_url) and bool(known_base_url) and state_base_url == known_base_url) - ) - discovered_secret_reusable = discovered.status in {"authenticated", "configured"} and ( - not supports_custom_base_url or same_endpoint + same_endpoint = not supports_custom_base_url or ( + bool(state_base_url) and bool(known_base_url) and state_base_url == known_base_url ) + discovered_secret_reusable = discovered.status in { + "authenticated", + "configured", + } and (not supports_custom_base_url or same_endpoint) has_resolved_secret = ( same_provider and same_endpoint and summary.get("secret_status") in {"stored", "not-required"} ) or discovered_secret_reusable @@ -302,7 +312,11 @@ def _go_back() -> bool: continue answer = _wizard_text_prompt( _pf_text(language, "Set The Endpoint", "设置接口地址"), - _pf_text(language, "What endpoint should Elephant Agent call?", "Elephant Agent 应该调用哪个接口?"), + _pf_text( + language, + "What endpoint should Elephant Agent call?", + "Elephant Agent 应该调用哪个接口?", + ), default=state.base_url, allow_back=allow_back and step_index > 0, ) @@ -317,10 +331,11 @@ def _go_back() -> bool: continue if step == "api_key": - if ( - not guide.required_secret_keys - or guide.auth_type in {"oauth_external", "oauth_device_code", "external_process"} - ): + if not guide.required_secret_keys or guide.auth_type in { + "oauth_external", + "oauth_device_code", + "external_process", + }: state.api_key = None step_index += 1 continue @@ -410,7 +425,11 @@ def _go_back() -> bool: WizardChoice( value="", label=_pf_text(language, "Provider default", "服务默认"), - detail=_pf_text(language, "Let the provider choose its default reasoning budget.", "让模型服务使用默认推理预算。"), + detail=_pf_text( + language, + "Let the provider choose its default reasoning budget.", + "让模型服务使用默认推理预算。", + ), ), *tuple( WizardChoice( @@ -445,7 +464,11 @@ def _go_back() -> bool: auto_value = detected or state.context_window_tokens or DEFAULT_CONTEXT_WINDOW_TOKENS answer = _wizard_choice_prompt( _pf_text(language, "Choose The Context Window", "选择上下文窗口"), - _pf_text(language, "How should Elephant Agent size the context budget?", "Elephant Agent 应该怎样设置上下文预算?"), + _pf_text( + language, + "How should Elephant Agent size the context budget?", + "Elephant Agent 应该怎样设置上下文预算?", + ), ( WizardChoice( value="auto", @@ -486,7 +509,11 @@ def _go_back() -> bool: default_tokens = str(state.context_window_tokens or detected or DEFAULT_CONTEXT_WINDOW_TOKENS) answer = _wizard_text_prompt( _pf_text(language, "Enter The Context Window", "输入上下文窗口"), - _pf_text(language, "How many tokens of context should Elephant Agent budget for this model?", "Elephant Agent 应该为这个模型预留多少上下文 token?"), + _pf_text( + language, + "How many tokens of context should Elephant Agent budget for this model?", + "Elephant Agent 应该为这个模型预留多少上下文 token?", + ), default=default_tokens, allow_back=allow_back and step_index > 0, ) diff --git a/apps/cli/runtime.py b/apps/cli/runtime.py index c53d484..6f62a33 100644 --- a/apps/cli/runtime.py +++ b/apps/cli/runtime.py @@ -4,13 +4,4 @@ from .runtime_support import * # noqa: F401,F403 from .runtime_cognition import * # noqa: F401,F403 -from .runtime_cognition import ( - _CliContextCapability, - _DurableRecallCapability, - _PreviewDeliveryCapability, - _PreviewRecallCapability, - _PreviewModelProviderCapability, - _PreviewToolCapability, -) from .runtime_extensions import _PreviewTelemetrySink # noqa: F401 -from .runtime_impl import CliRuntime diff --git a/apps/cli/runtime_cognition.py b/apps/cli/runtime_cognition.py index 2564656..3264910 100644 --- a/apps/cli/runtime_cognition.py +++ b/apps/cli/runtime_cognition.py @@ -300,7 +300,9 @@ class _CliContextCapability: startup_cwd: Path | None = None summary_model_provider: Any | None = None embedding_service: Any | None = None - last_projection_compaction: ContextProjectionCompactionResult | None = field(default=None, init=False, repr=False, compare=False) + last_projection_compaction: ContextProjectionCompactionResult | None = field( + default=None, init=False, repr=False, compare=False + ) descriptor: CapabilityDescriptor = CapabilityDescriptor( capability_id="cli.context.runtime", kind="context_assembler", @@ -562,10 +564,14 @@ def _capability_artifacts( def _runtime_path_artifact(self, session: Episode) -> str: lines: list[str] = [] if self.startup_cwd is not None: - lines.append(f"startup_cwd={self.startup_cwd.expanduser().resolve()} (the directory where this session launched; use as working directory when the user asks to explore 'here' or 'current project')") + lines.append( + f"startup_cwd={self.startup_cwd.expanduser().resolve()} (the directory where this session launched; use as working directory when the user asks to explore 'here' or 'current project')" + ) if self.workspaces_dir is not None and session.elephant_id: - elephant_ws = self.workspaces_dir.expanduser().resolve() / quote(session.elephant_id.strip(), safe='') - lines.append(f"elephant_workspace={elephant_ws} (default scratch directory for file output when the user does not specify a path)") + elephant_ws = self.workspaces_dir.expanduser().resolve() / quote(session.elephant_id.strip(), safe="") + lines.append( + f"elephant_workspace={elephant_ws} (default scratch directory for file output when the user does not specify a path)" + ) if not lines: return "" return "runtime-paths: " + "; ".join(lines) @@ -583,9 +589,13 @@ def _resolve_pm_state_and_facts(self, session: Episode) -> tuple[Any, tuple[Any, state = self.repository.current_state() if state is None: active_states = self.repository.list_states(status="active") - profile_states = [c for c in active_states if str(c.metadata.get("profile_id") or "").strip() == session.personal_model_id] - if len(profile_states) == 1: state = profile_states[0] - elif len(active_states) == 1: state = active_states[0] + profile_states = [ + c for c in active_states if str(c.metadata.get("profile_id") or "").strip() == session.personal_model_id + ] + if len(profile_states) == 1: + state = profile_states[0] + elif len(active_states) == 1: + state = active_states[0] if state is None: return (None, ()) list_facts = getattr(self.repository, "list_personal_model_facts", None) @@ -601,6 +611,7 @@ def _recently_learned_from_facts(self, facts: tuple[Any, ...]) -> tuple[str, ... """Find PM facts committed in last 24h for UX visibility.""" try: from datetime import timedelta, datetime, timezone + cutoff = datetime.now(timezone.utc) - timedelta(hours=24) recent: list[str] = [] for fact in facts: @@ -610,10 +621,13 @@ def _recently_learned_from_facts(self, facts: tuple[Any, ...]) -> tuple[str, ... if not promoted_at: continue try: - if datetime.fromisoformat(promoted_at.replace("Z", "+00:00")) < cutoff: continue - except (ValueError, TypeError): continue + if datetime.fromisoformat(promoted_at.replace("Z", "+00:00")) < cutoff: + continue + except (ValueError, TypeError): + continue effect = str(metadata.get("behavioral_effect") or getattr(fact, "text", "") or "").strip() - if effect and effect not in recent: recent.append(_compact_runtime_text(effect, limit=120)) + if effect and effect not in recent: + recent.append(_compact_runtime_text(effect, limit=120)) return tuple(recent[:4]) except Exception: return () @@ -629,28 +643,45 @@ def _personal_model_behavior_contract_from_facts(self, facts: tuple[Any, ...], * if str(getattr(fact, "status", "") or "active") != "active": continue effect = str(metadata.get("behavioral_effect") or getattr(fact, "text", "") or "").strip() - if not effect: continue + if not effect: + continue family = str(metadata.get("facet") or getattr(fact, "lens", "") or "general").strip() grouped.setdefault(family, []) compact = _compact_runtime_text(effect, limit=160) - if compact not in grouped[family]: grouped[family].append(compact) + if compact not in grouped[family]: + grouped[family].append(compact) if not grouped: return "" lines: list[str] = [] total = 0 - family_labels = {"style": "Style", "core": "Identity", "relationship": "Relationship", "procedural": "Workflow", "personal_knowledge": "Knowledge"} - for family in ("style", "core", "relationship", "procedural", "personal_knowledge"): + family_labels = { + "style": "Style", + "core": "Identity", + "relationship": "Relationship", + "procedural": "Workflow", + "personal_knowledge": "Knowledge", + } + for family in ( + "style", + "core", + "relationship", + "procedural", + "personal_knowledge", + ): effects = grouped.pop(family, []) - if not effects: continue + if not effects: + continue label = family_labels.get(family, family.replace("_", " ").title()) for effect in effects: - if total >= limit: break + if total >= limit: + break lines.append(f"- {label}: {effect}") total += 1 for family, effects in grouped.items(): label = family.replace("_", " ").title() for effect in effects: - if total >= limit: break + if total >= limit: + break lines.append(f"- {label}: {effect}") total += 1 return "\n".join(lines) if lines else "" @@ -664,21 +695,15 @@ def _generation_artifacts( plan: PlanDraft | None, continuity: EpisodeContinuityState | None, ) -> tuple[str, ...]: - artifacts = [ - artifact - for artifact in ( - _continuity_artifact(continuity), - ) - if artifact.strip() - ] + artifacts = [artifact for artifact in (_continuity_artifact(continuity),) if artifact.strip()] if plan is not None and plan.steps: step = plan.steps[0] artifacts.append( - "runtime-plan-step: " - f"{step.title}; rationale={_compact_runtime_text(step.rationale, limit=160)}" - ) + f"runtime-plan-step: {step.title}; rationale={_compact_runtime_text(step.rationale, limit=160)}" + ) return tuple(artifacts) + def _continuity_artifact(continuity: EpisodeContinuityState | None) -> str: if continuity is None or not continuity.requires_recovery: return "" @@ -790,6 +815,7 @@ def generate( side_effects=(f"model_role={model_role}",), ) + class _PreviewToolCapability: descriptor: Any = None diff --git a/apps/cli/runtime_cron_sub_agents.py b/apps/cli/runtime_cron_sub_agents.py index 14119be..7eeecf0 100644 --- a/apps/cli/runtime_cron_sub_agents.py +++ b/apps/cli/runtime_cron_sub_agents.py @@ -54,7 +54,14 @@ def describe(self, tool_id: str): def list_tools(self, **kwargs: Any) -> tuple[Any, ...]: return tuple(tool for tool in self._runtime.list_tools(**kwargs) if tool.tool_id in self._allowed_tool_ids) - def invoke(self, tool_name: str, arguments: Mapping[str, Any], *, session_id: str, requester: str | None = None): + def invoke( + self, + tool_name: str, + arguments: Mapping[str, Any], + *, + session_id: str, + requester: str | None = None, + ): self._ensure_allowed(tool_name) return self._runtime.invoke(tool_name, arguments, session_id=session_id, requester=requester) @@ -65,7 +72,6 @@ def list_executions(self) -> tuple[Any, ...]: return self._runtime.list_executions() - def cron_skill_ids(value: object) -> tuple[str, ...]: if value is None: return () @@ -112,7 +118,16 @@ def run_sub_agent_task( result = run_sub_agent_tasks( runtime, session_id=session_id, - tasks=({"task": task, "name": name, "skills": skills, "allowed_tools": allowed_tools, "system_prompt": system_prompt, "learning_agent": learning_agent},), + tasks=( + { + "task": task, + "name": name, + "skills": skills, + "allowed_tools": allowed_tools, + "system_prompt": system_prompt, + "learning_agent": learning_agent, + }, + ), max_concurrency=1, ) results = tuple(result.get("results") or ()) @@ -229,7 +244,9 @@ def start_sub_agent_tasks( run_id=run_id, task_index=index, ) - future.add_done_callback(lambda completed, task_index=index: _record_async_sub_agent_result(run, task_index, completed)) + future.add_done_callback( + lambda completed, task_index=index: _record_async_sub_agent_result(run, task_index, completed) + ) futures.append(future) run.futures = tuple(futures) with _ASYNC_SUB_AGENT_RUNS_LOCK: @@ -265,8 +282,7 @@ def list_sub_agent_runs( return { "status": "completed", "summary": "\n".join( - f"{run.run_id}: {run.status} ({_completed_sub_agent_count(run)}/{len(run.results)} done)" - for run in runs + f"{run.run_id}: {run.status} ({_completed_sub_agent_count(run)}/{len(run.results)} done)" for run in runs ) or "no sub-agent runs", "runs": [_sub_agent_run_payload(run) for run in runs], @@ -305,12 +321,16 @@ def _prepare_sub_agent_child( ) -> Mapping[str, Any]: child = _open_sub_agent_child_episode(runtime, parent_session_id) child_session_id = child.episode_id - prompt = task if system_prompt.strip() else _compose_sub_agent_prompt( - runtime, - task=task, - name=name, - skills=skills, - session_id=child_session_id, + prompt = ( + task + if system_prompt.strip() + else _compose_sub_agent_prompt( + runtime, + task=task, + name=name, + skills=skills, + session_id=child_session_id, + ) ) return { "name": name or "sub-agent", @@ -470,7 +490,6 @@ def _run_prepared_sub_agent_child( return result - def _create_child_runtime(runtime: Any) -> Any: return runtime.__class__.create( state_dir=runtime.paths.state_dir, @@ -499,7 +518,14 @@ def _normalize_sub_agent_task(item: Mapping[str, Any]) -> Mapping[str, Any]: allowed_tools = cron_skill_ids(item.get("allowed_tools") or item.get("allowed_tool_ids")) system_prompt = str(item.get("system_prompt") or "").strip() learning_agent = bool(item.get("learning_agent")) - return {"task": task, "name": name_text, "skills": skills, "allowed_tools": allowed_tools, "system_prompt": system_prompt, "learning_agent": learning_agent} + return { + "task": task, + "name": name_text, + "skills": skills, + "allowed_tools": allowed_tools, + "system_prompt": system_prompt, + "learning_agent": learning_agent, + } def _aggregate_sub_agent_status(results: list[Mapping[str, Any] | None]) -> str: @@ -570,7 +596,9 @@ def _sub_agent_run_payload(run: _AsyncSubAgentRun) -> Mapping[str, Any]: f"progress: {completed}/{total}", ] if status == "running": - summary_lines.append(f"Use tool.sub_agents action=status run_id={run.run_id} to check progress, or action=join to wait.") + summary_lines.append( + f"Use tool.sub_agents action=status run_id={run.run_id} to check progress, or action=join to wait." + ) for index, item in enumerate(resolved): summary_lines.append( f"{index + 1}. {item.get('name') or 'sub-agent'}: {item.get('summary') or item.get('status') or 'finished'}" diff --git a/apps/cli/runtime_extensions.py b/apps/cli/runtime_extensions.py index 344fe56..b4d74c5 100644 --- a/apps/cli/runtime_extensions.py +++ b/apps/cli/runtime_extensions.py @@ -11,7 +11,11 @@ from .snapshot_io import load_snapshot_payload, write_snapshot_payload from packages.state import ProfileLoader from packages.security import SecurityPolicy -from packages.skills import SkillActivationContext, SkillRuntime, builtin_skill_definitions +from packages.skills import ( + SkillActivationContext, + SkillRuntime, + builtin_skill_definitions, +) from packages.storage import RuntimeStorageRepository from packages.tools import ( BuiltinToolDependencies, @@ -223,6 +227,9 @@ def _resolve_elephant_state(repository: RuntimeStorageRepository, elephant_id: s if state is not None: return state for candidate in repository.list_states(): - if candidate.elephant_id == resolved_elephant_id or candidate.state_anchor in {resolved_elephant_id, f"elephant:{resolved_elephant_id}"}: + if candidate.elephant_id == resolved_elephant_id or candidate.state_anchor in { + resolved_elephant_id, + f"elephant:{resolved_elephant_id}", + }: return candidate return repository.current_state() diff --git a/apps/cli/runtime_extensions_skill_sources.py b/apps/cli/runtime_extensions_skill_sources.py index 2913df1..881ece3 100644 --- a/apps/cli/runtime_extensions_skill_sources.py +++ b/apps/cli/runtime_extensions_skill_sources.py @@ -17,7 +17,9 @@ ) -def source_descriptor_for_hub_entry(entry: SkillHubEntry) -> PublicSkillSourceDescriptor | None: +def source_descriptor_for_hub_entry( + entry: SkillHubEntry, +) -> PublicSkillSourceDescriptor | None: existing = public_skill_source_descriptor_from_metadata(entry.metadata) source_reference = public_hub_source_reference(entry) install_reference = public_hub_install_reference(entry) @@ -97,7 +99,11 @@ def local_skill_trust_level(source_id: str, metadata: Mapping[str, Any]) -> str: return configured if source_id == "builtin": return "builtin" - if source_id in {"path", "elephant-installed", "elephant-authored"} or source_id.startswith("custom-"): + if source_id in { + "path", + "elephant-installed", + "elephant-authored", + } or source_id.startswith("custom-"): return "trusted" return "community" diff --git a/apps/cli/runtime_extensions_surface.py b/apps/cli/runtime_extensions_surface.py index d949816..3401a64 100644 --- a/apps/cli/runtime_extensions_surface.py +++ b/apps/cli/runtime_extensions_surface.py @@ -13,9 +13,26 @@ from packages.contracts.layers import Episode from packages.contracts.runtime import ExperienceRecord, ExecutionResult from packages.cron import CronJob, CronJobExecution -from packages.runtime_config import global_config_path_for_state_dir, load_global_config, save_extensions_to_config, load_extensions_from_config +from packages.runtime_config import ( + global_config_path_for_state_dir, + load_global_config, + save_extensions_to_config, +) from packages.growth import GrowthUpdate, ProgressionProjection, ProgressionTransition -from packages.skills import PublicSkillSourceDescriptor, SkillDefinition, SkillHubEntry, SkillManifestLoadRecord, SkillPackageLoader, SkillSearchEntry, build_installed_skill_provenance, build_public_skill_source_descriptor, install_bucket_for_source_descriptor, load_skill_package_definition, materialize_skill_package, public_skill_source_descriptor_from_metadata +from packages.skills import ( + PublicSkillSourceDescriptor, + SkillDefinition, + SkillHubEntry, + SkillManifestLoadRecord, + SkillPackageLoader, + SkillSearchEntry, + build_installed_skill_provenance, + build_public_skill_source_descriptor, + install_bucket_for_source_descriptor, + load_skill_package_definition, + materialize_skill_package, + public_skill_source_descriptor_from_metadata, +) from packages.skills.authoring import write_skill_package from packages.state import ( PromptContract, @@ -25,14 +42,40 @@ personality_presets, profile_with_authored_elephant_identity, ) -from packages.tools import BuiltinToolDependencies, ToolAudience, ToolDefinition, ToolManifestLoadRecord, sync_custom_mcp_tools +from packages.tools import ( + BuiltinToolDependencies, + ToolAudience, + ToolDefinition, + ToolManifestLoadRecord, + sync_custom_mcp_tools, +) from packages.tools.adapters import StructuredClarifySurface from packages.understanding import PersonalModelUnderstandingSurface -from .runtime_extensions import CliExtensionManifest, build_skill_runtime, build_tool_runtime, load_extension_manifest, sanitize_extension_manifest_payload, serialize_manifest_path -from .runtime_extensions_skill_sources import install_record_detail as _install_record_detail, installed_skill_record as _installed_skill_record, matching_install_record as _matching_install_record, normalized_install_requester as _normalized_install_requester, record_install_reference as _record_install_reference, remote_skill_definition as _remote_skill_definition, source_descriptor_for_hub_entry as _source_descriptor_for_hub_entry, source_descriptor_for_path as _source_descriptor_for_path +from .runtime_extensions import ( + CliExtensionManifest, + build_skill_runtime, + build_tool_runtime, + load_extension_manifest, + sanitize_extension_manifest_payload, + serialize_manifest_path, +) +from .runtime_extensions_skill_sources import ( + install_record_detail as _install_record_detail, + installed_skill_record as _installed_skill_record, + matching_install_record as _matching_install_record, + normalized_install_requester as _normalized_install_requester, + record_install_reference as _record_install_reference, + remote_skill_definition as _remote_skill_definition, + source_descriptor_for_hub_entry as _source_descriptor_for_hub_entry, + source_descriptor_for_path as _source_descriptor_for_path, +) from .runtime_cron_sub_agents import compose_cron_prompt -from .runtime_growth_surface import inspect_experiences as _inspect_experiences, inspect_growth as _inspect_growth, inspect_growth_transition as _inspect_growth_transition +from .runtime_growth_surface import ( + inspect_experiences as _inspect_experiences, + inspect_growth as _inspect_growth, + inspect_growth_transition as _inspect_growth_transition, +) from .runtime_sub_agents import CliRuntimeSubAgentsMixin from .runtime_support import _path_is_within, _utc_now @@ -227,7 +270,18 @@ def write_learning_result( if job.personal_model_id != resolved_pm_id: raise PermissionError(f"learning job does not belong to this personal model: {job_id}") normalized_status = status if status in {"completed", "partial", "no_op", "failed"} else "partial" - normalized_mode = mode if mode in {"init_bootstrap", "episode_close", "im_idle", "context_compression", "manual"} else "manual" + normalized_mode = ( + mode + if mode + in { + "init_bootstrap", + "episode_close", + "im_idle", + "context_compression", + "manual", + } + else "manual" + ) payload = { "job_id": job_id, "mode": normalized_mode, @@ -235,19 +289,44 @@ def write_learning_result( "summary": summary, "pm_facts": _normalize_result_section( pm_facts, - defaults={"created_refs": [], "updated_refs": [], "retired_refs": [], "notes": ""}, + defaults={ + "created_refs": [], + "updated_refs": [], + "retired_refs": [], + "notes": "", + }, aliases={ "created_refs": ("created", "created_ids", "created_facts"), "updated_refs": ("updated", "updated_ids", "updated_facts"), - "retired_refs": ("retired", "retired_ids", "forgotten", "forgotten_refs"), + "retired_refs": ( + "retired", + "retired_ids", + "forgotten", + "forgotten_refs", + ), }, ), "skill_affinities": _normalize_result_section( skill_affinities, - defaults={"included_refs": [], "excluded_refs": [], "candidate_refs": [], "notes": ""}, + defaults={ + "included_refs": [], + "excluded_refs": [], + "candidate_refs": [], + "notes": "", + }, aliases={ - "included_refs": ("included", "included_ids", "created", "created_refs"), - "excluded_refs": ("excluded", "excluded_ids", "retired", "retired_refs"), + "included_refs": ( + "included", + "included_ids", + "created", + "created_refs", + ), + "excluded_refs": ( + "excluded", + "excluded_ids", + "retired", + "retired_refs", + ), "candidate_refs": ("candidates", "candidate_ids"), }, ), @@ -265,7 +344,11 @@ def write_learning_result( "settled_ids": ("settled", "answered", "answered_ids"), "created_ids": ("created", "created_questions"), "updated_ids": ("updated", "updated_questions"), - "next_ask_candidate_ids": ("next_candidates", "next_ask_candidates", "candidates"), + "next_ask_candidate_ids": ( + "next_candidates", + "next_ask_candidates", + "candidates", + ), "dismissed_ids": ("dismissed", "dismissed_questions", "stale_ids"), }, ), @@ -285,7 +368,12 @@ def write_learning_result( worker_id=getattr(job, "worker_id", None) or "learning-result-tool", progress_detail=summary, ) - return {"job_id": job_id, "status": normalized_status, "summary": summary, "learning_result": payload} + return { + "job_id": job_id, + "status": normalized_status, + "summary": summary, + "learning_result": payload, + } # --- DiarySurface implementation --- @@ -300,6 +388,7 @@ def write_diary_entry( ) -> Mapping[str, Any]: from uuid import uuid4 from packages.contracts import DiaryEntry + entry_id = f"diary:{uuid4().hex[:12]}" entry = DiaryEntry( entry_id=entry_id, @@ -327,7 +416,11 @@ def list_diary_entries( ) return { "entries": [ - {"entry_id": e.entry_id, "entry_date": e.entry_date, "content": e.content} + { + "entry_id": e.entry_id, + "entry_date": e.entry_date, + "content": e.content, + } for e in entries ], "count": len(entries), @@ -349,7 +442,10 @@ def learning_runtime_status( queued = tuple(job for job in jobs if job.status == "queued") failed = tuple(job for job in jobs if job.status == "failed") completed = tuple(job for job in jobs if job.status == "completed") - from apps.learning_worker_runtime import load_learning_worker_record, learning_worker_is_running + from apps.learning_worker_runtime import ( + load_learning_worker_record, + learning_worker_is_running, + ) worker_record = load_learning_worker_record(self.paths.state_dir) or {} return { @@ -383,7 +479,9 @@ def learning_runtime_status( ), } - def tool_catalog(self, *, session_id: str | None = None, audience: ToolAudience | None = None) -> tuple[ToolDefinition, ...]: + def tool_catalog( + self, *, session_id: str | None = None, audience: ToolAudience | None = None + ) -> tuple[ToolDefinition, ...]: if session_id is not None: self.prepare_session_surface(session_id, steady_embeddings=False) return self.tool_runtime.list_tools(audience=audience) @@ -432,10 +530,7 @@ def install_tool_manifest( existing_paths = list(load_extension_manifest(manifest, profile_dir=profile_dir).tool_manifest_paths) if resolved_path not in existing_paths: existing_paths.append(resolved_path) - manifest["tool_manifests"] = [ - serialize_manifest_path(path, profile_dir=profile_dir) - for path in existing_paths - ] + manifest["tool_manifests"] = [serialize_manifest_path(path, profile_dir=profile_dir) for path in existing_paths] self._save_extensions_manifest(manifest) self._refresh_extensions(profile_id=resolved_profile_id) return self._tool_manifest_load_record(resolved_path) @@ -542,7 +637,10 @@ def run_due_cron_jobs_for_scheduler(self) -> tuple[CronJobExecution, ...]: def executor(job: CronJob) -> tuple[str, str]: session = self._cron_session_for_job(job) if session is None: - return ("failed", f"{job.name} skipped because no matching session is available.") + return ( + "failed", + f"{job.name} skipped because no matching session is available.", + ) return self._execute_cron_job(job, session_id=session.episode_id) return self.cron_runtime.run_due(executor) @@ -564,7 +662,10 @@ def run_cron_job_now(self, job_id: str) -> CronJobExecution: started = self.cron_runtime.begin_execution(job_id, now=now) session = self._cron_session_for_job(started.job) if session is None: - outcome, summary = ("failed", f"{started.job.name} skipped because no matching session is available.") + outcome, summary = ( + "failed", + f"{started.job.name} skipped because no matching session is available.", + ) else: outcome, summary = self._execute_cron_job(started.job, session_id=session.episode_id) return self.cron_runtime.record_execution_result( @@ -609,6 +710,7 @@ def _execute_cron_learning_job( ) -> tuple[str, str]: """Execute a cron job that triggers a learning agent.""" from datetime import date as date_type, timedelta + trigger = str(job.payload.get("trigger") or "").strip() if not trigger: raise ValueError("cron learning jobs require a 'trigger' in payload") @@ -647,12 +749,18 @@ def _cron_session_for_job(self, job: CronJob) -> Episode | None: def has_due_cron_jobs(self, *, session_id: str) -> bool: session = self._load_session(session_id) loaded = self._load_profile(session.personal_model_id) - return bool(self.cron_runtime.due_jobs(profile_id=loaded.state.profile_id, elephant_id=self.elephant_id_for_session(session))) + return bool( + self.cron_runtime.due_jobs( + profile_id=loaded.state.profile_id, + elephant_id=self.elephant_id_for_session(session), + ) + ) def skill_catalog(self, *, session_id: str | None = None) -> tuple[SkillDefinition, ...]: if session_id is not None: self.prepare_session_surface(session_id, steady_embeddings=False) return self.skill_runtime.catalog.list() + def list_skill_hub(self, *, limit: int | None = None) -> tuple[SkillHubEntry, ...]: entries = self.skill_hub.list(self._current_skill_enabled_overrides()) if limit is None or limit <= 0: @@ -660,11 +768,32 @@ def list_skill_hub(self, *, limit: int | None = None) -> tuple[SkillHubEntry, .. return entries[:limit] def search_skill_hub(self, query: str, *, limit: int = 12) -> tuple[SkillHubEntry, ...]: - return self.skill_hub.search(query, limit=limit, enabled_overrides=self._current_skill_enabled_overrides()) - def search_skill_sources(self, query: str, *, source: str | None = None, limit: int = 12) -> tuple[SkillSearchEntry, ...]: + return self.skill_hub.search( + query, + limit=limit, + enabled_overrides=self._current_skill_enabled_overrides(), + ) + + def search_skill_sources( + self, query: str, *, source: str | None = None, limit: int = 12 + ) -> tuple[SkillSearchEntry, ...]: return self.skill_search_hub.search(query, source=source, limit=limit) - def inspect_experiences(self, *, session_id: str | None = None, profile_id: str | None = None, statuses: tuple[str, ...] = (), limit: int | None = None) -> tuple[ExperienceRecord, ...]: - return _inspect_experiences(self, session_id=session_id, profile_id=profile_id, statuses=statuses, limit=limit) + + def inspect_experiences( + self, + *, + session_id: str | None = None, + profile_id: str | None = None, + statuses: tuple[str, ...] = (), + limit: int | None = None, + ) -> tuple[ExperienceRecord, ...]: + return _inspect_experiences( + self, + session_id=session_id, + profile_id=profile_id, + statuses=statuses, + limit=limit, + ) def inspect_growth( self, @@ -673,8 +802,10 @@ def inspect_growth( profile_id: str | None = None, ) -> ProgressionProjection: return _inspect_growth(self, session_id=session_id, profile_id=profile_id) + def consume_growth_update(self, *, session_id: str) -> GrowthUpdate | None: return self.growth_updates.pop(session_id, None) + def inspect_growth_transition(self, update: GrowthUpdate, *, session_id: str) -> ProgressionTransition: return _inspect_growth_transition(self, update, session_id=session_id) @@ -711,7 +842,8 @@ def inspect_skill(self, skill_id: str, *, session_id: str | None = None) -> Skil return replace(skill, metadata=metadata) def inspect_skill_source(self, skill_id: str, *, session_id: str | None = None) -> SkillDefinition: - if session_id is not None: self.prepare_session_surface(session_id) + if session_id is not None: + self.prepare_session_surface(session_id) try: return self.inspect_skill(skill_id) except KeyError: @@ -765,8 +897,7 @@ def install_skill_manifest( if resolved_path not in existing_paths: existing_paths.append(resolved_path) manifest["skill_manifests"] = [ - serialize_manifest_path(path, profile_dir=profile_dir) - for path in existing_paths + serialize_manifest_path(path, profile_dir=profile_dir) for path in existing_paths ] self._save_extensions_manifest(manifest) self._refresh_extensions(profile_id=resolved_profile_id) @@ -1090,10 +1221,7 @@ def _install_skill_package_path( retained_resolved.add(resolved_existing) if materialized_path not in retained_resolved: retained_paths.append(materialized_path) - manifest["skill_packages"] = [ - serialize_manifest_path(path, profile_dir=profile_dir) - for path in retained_paths - ] + manifest["skill_packages"] = [serialize_manifest_path(path, profile_dir=profile_dir) for path in retained_paths] self._save_extensions_manifest(manifest) for stale_path in stale_paths: if not _path_is_within(stale_path, installed_root): @@ -1128,9 +1256,7 @@ def _refresh_extensions(self, *, profile_id: str | None = None) -> None: manifest_payload, removed_manifest_keys = sanitize_extension_manifest_payload(dict(loaded.manifest)) if removed_manifest_keys: self._save_extensions_manifest(manifest_payload) - self._apply_extension_manifest( - load_extension_manifest(manifest_payload, profile_dir=Path(loaded.profile_dir)) - ) + self._apply_extension_manifest(load_extension_manifest(manifest_payload, profile_dir=Path(loaded.profile_dir))) def _sync_global_custom_mcp_tools(self) -> None: config_path = global_config_path_for_state_dir(self.paths.state_dir) @@ -1153,6 +1279,7 @@ def _elephant_file_root_for_session(session_id: str | None) -> Path: elephant_files.mkdir(parents=True, exist_ok=True) return elephant_files return Path.cwd() + embedding_service = self.recall_runtime.retriever.evidence_retriever.embedding_service semantic_summary_indexer = None if self.semantic_index_bundle is not None and embedding_service is not None: @@ -1189,9 +1316,7 @@ def _elephant_file_root_for_session(session_id: str | None) -> Path: repository=self.repository, semantic_summary_indexer=semantic_summary_indexer, semantic_searcher=( - self.semantic_index_bundle.searcher - if self.semantic_index_bundle is not None - else None + self.semantic_index_bundle.searcher if self.semantic_index_bundle is not None else None ), embedding_service=embedding_service, ), diff --git a/apps/cli/runtime_growth_metrics.py b/apps/cli/runtime_growth_metrics.py index 977523f..787238c 100644 --- a/apps/cli/runtime_growth_metrics.py +++ b/apps/cli/runtime_growth_metrics.py @@ -78,7 +78,10 @@ def personal_model_growth_metrics( if changed_this_turn: new_facts += 1 action = str(metadata.get("action") or "").strip().lower() - if getattr(fact, "supersedes_fact_id", None) or action in {"correct", "restore"}: + if getattr(fact, "supersedes_fact_id", None) or action in { + "correct", + "restore", + }: updated_facts += 1 fact_count = len(facts) return PersonalModelGrowthMetrics( diff --git a/apps/cli/runtime_growth_surface.py b/apps/cli/runtime_growth_surface.py index eea7a0f..8c14266 100644 --- a/apps/cli/runtime_growth_surface.py +++ b/apps/cli/runtime_growth_surface.py @@ -3,12 +3,24 @@ from __future__ import annotations from packages.contracts.runtime import ExperienceRecord -from packages.growth import GrowthUpdate, ProgressionProjection, ProgressionProjectionBuilder, ProgressionTransition +from packages.growth import ( + GrowthUpdate, + ProgressionProjection, + ProgressionProjectionBuilder, + ProgressionTransition, +) _PROGRESSION_BUILDER = ProgressionProjectionBuilder() -def inspect_experiences(runtime, *, session_id: str | None = None, profile_id: str | None = None, statuses: tuple[str, ...] = (), limit: int | None = None) -> tuple[ExperienceRecord, ...]: +def inspect_experiences( + runtime, + *, + session_id: str | None = None, + profile_id: str | None = None, + statuses: tuple[str, ...] = (), + limit: int | None = None, +) -> tuple[ExperienceRecord, ...]: """Return experience records. Procedural memory has been removed; returns empty.""" return () diff --git a/apps/cli/runtime_impl.py b/apps/cli/runtime_impl.py index 282b63f..b854cc7 100644 --- a/apps/cli/runtime_impl.py +++ b/apps/cli/runtime_impl.py @@ -25,12 +25,24 @@ PersonalModelRuntimeState, ) from packages.cron import CronRuntime -from packages.evidence import RecallRuntime, SemanticSummaryIndexer, build_semantic_index_bundle -from packages.gateway_core import FileGatewayIdentityStore, GatewayOutboundQueue, default_outbound_queue_path +from packages.evidence import ( + RecallRuntime, + SemanticSummaryIndexer, + build_semantic_index_bundle, +) +from packages.gateway_core import ( + FileGatewayIdentityStore, + GatewayOutboundQueue, + default_outbound_queue_path, +) from packages.gateway_core.outbound_delivery import GatewayMessageDeliverySurface from packages.growth import GrowthUpdate from packages.kernel import KernelDependencies, KernelOutcome -from packages.runtime_config import configured_external_skill_dirs, global_config_path_for_state_dir, load_global_config +from packages.runtime_config import ( + configured_external_skill_dirs, + global_config_path_for_state_dir, + load_global_config, +) from packages.runtime_layout import ( default_authored_skills_dir, default_builtin_skills_dir, @@ -42,10 +54,21 @@ infer_install_root_from_state_dir, ) from packages.security import SecurityPolicy -from packages.skills import SkillHub, SkillPromptContextBuilder, SkillSearchHub, SkillRuntime, default_skill_hub_sources, sync_builtin_skill_shelf +from packages.skills import ( + SkillHub, + SkillPromptContextBuilder, + SkillSearchHub, + SkillRuntime, + default_skill_hub_sources, + sync_builtin_skill_shelf, +) from packages.state import ProfileLoader from packages.storage import RuntimeStorageRepository -from packages.tools import BuiltinToolDependencies, InMemorySessionTodoStore, ToolRuntime +from packages.tools import ( + BuiltinToolDependencies, + InMemorySessionTodoStore, + ToolRuntime, +) from packages.tools.adapters import StructuredClarifySurface from packages.tools.browser_backend import create_playwright_browser_backend from packages.tools.surfaces import BrowserToolBackend, ClarifySurface @@ -62,8 +85,6 @@ build_skill_runtime, build_tool_runtime, load_extension_manifest, - load_json_file, - sanitize_extension_manifest_payload, ) from .runtime_extensions_surface import CliRuntimeExtensionsMixin from .runtime_profile import CliRuntimeProfileMixin @@ -77,7 +98,10 @@ write_snapshot as _write_runtime_snapshot, ) from .runtime_support import * # noqa: F401,F403 -from .runtime_support import _default_elephant_identity_file_text, _seed_elephant_identity_text +from .runtime_support import ( + _default_elephant_identity_file_text, + _seed_elephant_identity_text, +) from .runtime_turns import ( build_kernel_dependencies as _build_runtime_kernel_dependencies, create_elephant_session as _create_runtime_elephant_session, @@ -89,8 +113,14 @@ start_episode as _start_runtime_session, ) + @dataclass(frozen=True, slots=True) -class CliRuntime(CliRuntimeProfileMixin, CliRuntimeProviderMixin, CliRuntimeExtensionsMixin, CliRuntimeRecordsMixin): +class CliRuntime( + CliRuntimeProfileMixin, + CliRuntimeProviderMixin, + CliRuntimeExtensionsMixin, + CliRuntimeRecordsMixin, +): paths: CliPaths repository: RuntimeStorageRepository profile_loader: ProfileLoader @@ -141,9 +171,12 @@ def create( profile_loader = ProfileLoader(home_dir) global_config_path = global_config_path_for_state_dir(state_dir) # Ensure config.yaml is always written so the file is visible - from packages.runtime_config import read_global_config_text if not global_config_path.exists(): - from packages.runtime_config import write_global_config, default_global_config + from packages.runtime_config import ( + write_global_config, + default_global_config, + ) + write_global_config( global_config_path, default_global_config(state_dir=state_dir), @@ -153,6 +186,7 @@ def create( state_dir=state_dir, ) from packages.observability import setup_from_config + setup_from_config(global_config, state_dir=str(state_dir)) active_provider_profile = load_provider_profile(state_dir, config_path=global_config_path) active_provider_profile_id = None @@ -164,9 +198,14 @@ def create( capture_runtime_secret_env(paths.state_dir, active_provider_profile) # Load extension manifest from config.yaml from packages.runtime_config import load_extensions_from_config + config_extensions = load_extensions_from_config(global_config) extension_manifest = load_extension_manifest(config_extensions, profile_dir=home_dir) - cron_runtime = CronRuntime(paths.cron_jobs_path, output_dir=paths.cron_output_dir, lock_path=paths.cron_lock_path) + cron_runtime = CronRuntime( + paths.cron_jobs_path, + output_dir=paths.cron_output_dir, + lock_path=paths.cron_lock_path, + ) skill_hub = SkillHub( sources=default_skill_hub_sources( external_dirs=configured_external_skill_dirs(global_config), diff --git a/apps/cli/runtime_profile.py b/apps/cli/runtime_profile.py index 9fd68ab..45d1fff 100644 --- a/apps/cli/runtime_profile.py +++ b/apps/cli/runtime_profile.py @@ -7,13 +7,15 @@ from packages.context import ContextAssemblyResult from packages.contracts.runtime import ElephantIdentityRecord -from packages.state.rendered_views import RenderedRelationshipView, RenderedUserProfileView +from packages.state.rendered_views import ( + RenderedRelationshipView, + RenderedUserProfileView, +) from packages.continuity import ContinuityProjectionService from packages.operator.runtime import ( RecallEvidenceOperatorDetail, RecallEvidenceSearchHit, ProcedureOperatorDetail, - build_canonical_procedure_detail, build_recall_evidence_operator_surface, build_procedure_operator_surface, build_profile_operator_surface, @@ -94,7 +96,9 @@ def inspect_continuity(self, *, session_id: str | None = None) -> ContinuityStat recovery = self._planning_recall_evidence_recovery(session) wake_action = "continue" if active_state_focus else "idle" wake_summary = active_state_focus if active_state_focus else "No durable elephant focus is available yet." - wake_factors: tuple[str, ...] = tuple(("state-continuity", f"recall-scope={','.join(recovery.scope_episode_ids)}")) + wake_factors: tuple[str, ...] = tuple( + ("state-continuity", f"recall-scope={','.join(recovery.scope_episode_ids)}") + ) return ContinuityStatus( profile=profile, session=session, @@ -160,12 +164,26 @@ def inspect_profile_surface(self, session_id: str): def patch_profile_surface(self, session_id: str, payload: dict[str, object]): if any( key in payload - for key in {"display_name", "name", "personality_preset", "initiative", "elephant_identity_text", "text", "content", "clear_elephant_identity"} + for key in { + "display_name", + "name", + "personality_preset", + "initiative", + "elephant_identity_text", + "text", + "content", + "clear_elephant_identity", + } ): display_name = str(payload.get("display_name") or payload.get("name") or "").strip() or None personality_preset = str(payload.get("personality_preset") or "").strip() or None initiative = str(payload.get("initiative") or "").strip() or None - elephant_identity_text = str(payload.get("elephant_identity_text") or payload.get("text") or payload.get("content") or "").strip() or None + elephant_identity_text = ( + str( + payload.get("elephant_identity_text") or payload.get("text") or payload.get("content") or "" + ).strip() + or None + ) self.update_identity_state( session_id=session_id, display_name=display_name, @@ -174,7 +192,16 @@ def patch_profile_surface(self, session_id: str, payload: dict[str, object]): elephant_identity_text=elephant_identity_text, clear_elephant_identity=bool(payload.get("clear_elephant_identity", False)), ) - if any(key in payload for key in {"user_text", "user_content", "user_fields", "user_append", "user_clear"}): + if any( + key in payload + for key in { + "user_text", + "user_content", + "user_fields", + "user_append", + "user_clear", + } + ): self.update_user_state( session_id=session_id, text=str(payload.get("user_text") or payload.get("user_content") or "").strip() or None, @@ -182,7 +209,15 @@ def patch_profile_surface(self, session_id: str, payload: dict[str, object]): append=bool(payload.get("user_append", False)), clear=bool(payload.get("user_clear", False)), ) - if any(key in payload for key in {"relationship_text", "relationship_content", "relationship_append", "relationship_clear"}): + if any( + key in payload + for key in { + "relationship_text", + "relationship_content", + "relationship_append", + "relationship_clear", + } + ): self.update_relationship_state( session_id=session_id, text=str(payload.get("relationship_text") or payload.get("relationship_content") or "").strip() or None, @@ -262,6 +297,7 @@ def _coerce_str_tuple(self, value: object) -> tuple[str, ...]: def _session_continuity_state(self, session_id: str, *, session): from packages.continuity import build_episode_continuity_state + return build_episode_continuity_state( session, lineage=self.repository.episode_lineage(session_id), @@ -472,16 +508,22 @@ def update_identity_state( if clear_elephant_identity or elephant_identity_text is not None: self._authorize_write( operation="cli.identity.surface.update", - session_id=session_id or (self.latest_session().episode_id if self.latest_session() is not None else None), + session_id=session_id + or (self.latest_session().episode_id if self.latest_session() is not None else None), description="update elephant identity", - metadata={"profile_id": resolved_profile_id, "elephant_id": target_elephant_id}, + metadata={ + "profile_id": resolved_profile_id, + "elephant_id": target_elephant_id, + }, ) if target_session is not None and target_elephant_id: elephant_root = self.paths.elephant_file_path(target_elephant_id) next_state_text = ( render_default_elephant_identity( display_name=display_name or loaded.state.display_name, - personality_preset=(loaded.companion.personality_preset if loaded.companion is not None else None), + personality_preset=( + loaded.companion.personality_preset if loaded.companion is not None else None + ), initiative=(loaded.companion.initiative if loaded.companion is not None else "gentle"), mode=loaded.state.mode, ) @@ -504,7 +546,11 @@ def update_identity_state( identity_mode=elephant_state.identity_mode or loaded.state.mode, personality_preset=( elephant_state.working_style - or (loaded.companion.personality_preset if loaded.companion is not None else base_identity.personality_preset) + or ( + loaded.companion.personality_preset + if loaded.companion is not None + else base_identity.personality_preset + ) ), initiative=( elephant_state.initiative @@ -519,17 +565,25 @@ def update_identity_state( companion=loaded.companion, profile_dir=loaded.profile_dir, manifest_path=loaded.manifest_path, - elephant_identity_text=None if clear_elephant_identity else _normalized_profile_text(elephant_identity_text), + elephant_identity_text=None + if clear_elephant_identity + else _normalized_profile_text(elephant_identity_text), user_profile_text=loaded.user_profile_text, user_profile_path=loaded.user_profile_path, manifest=dict(loaded.manifest), ), sync_source="identity.state.update", ) - if target_session is not None and target_elephant_id and ( - display_name is not None or personality_preset is not None or initiative is not None + if ( + target_session is not None + and target_elephant_id + and (display_name is not None or personality_preset is not None or initiative is not None) ): - refreshed_state_text = read_elephant_identity_file(self.paths.elephant_file_path(target_elephant_id)) or loaded.elephant_identity_text or "" + refreshed_state_text = ( + read_elephant_identity_file(self.paths.elephant_file_path(target_elephant_id)) + or loaded.elephant_identity_text + or "" + ) elephant_state = self.ensure_elephant_state( target_session, elephant_identity_text=refreshed_state_text, @@ -544,7 +598,11 @@ def update_identity_state( identity_mode=elephant_state.identity_mode or loaded.state.mode, personality_preset=( elephant_state.working_style - or (loaded.companion.personality_preset if loaded.companion is not None else base_identity.personality_preset) + or ( + loaded.companion.personality_preset + if loaded.companion is not None + else base_identity.personality_preset + ) ), initiative=( elephant_state.initiative diff --git a/apps/cli/runtime_prompt_messages.py b/apps/cli/runtime_prompt_messages.py index 93691ba..9e8ec91 100644 --- a/apps/cli/runtime_prompt_messages.py +++ b/apps/cli/runtime_prompt_messages.py @@ -29,7 +29,13 @@ def session_history_messages( if summary: messages.append(PromptMessage(role="assistant", content=summary)) if delivery is not None and delivery.summary.strip(): - messages.append(PromptMessage(role="assistant", content=delivery.summary.strip(), metadata={"source": "delivery"})) + messages.append( + PromptMessage( + role="assistant", + content=delivery.summary.strip(), + metadata={"source": "delivery"}, + ) + ) return tuple(messages) @@ -82,4 +88,3 @@ def prompt_messages_tuple(value: Any) -> tuple[PromptMessage, ...]: ) ) return tuple(messages) - diff --git a/apps/cli/runtime_provider.py b/apps/cli/runtime_provider.py index e7cf573..c0b3cbd 100644 --- a/apps/cli/runtime_provider.py +++ b/apps/cli/runtime_provider.py @@ -7,7 +7,10 @@ from typing import Any from uuid import uuid4 -from apps.provider_runtime import capture_runtime_secret_env, provider_profile_from_payload +from apps.provider_runtime import ( + capture_runtime_secret_env, + provider_profile_from_payload, +) from packages.continuity import RelationshipPolicy, build_relationship_policy from packages.auth import AuthProfile, SecretReference from packages.embeddings import ( @@ -26,17 +29,22 @@ from packages.models.provider_catalog import provider_definition from packages.models.provider_runtime import ProviderCatalogRecord, ProviderSetupGuide from packages.security import SecurityPolicy, default_surface_policy_bundles -from packages.state import CompanionSettings, LoadedProfile, normalize_profile_mode -from packages.state.loader import companion_manifest_payload +from packages.state import CompanionSettings, LoadedProfile from .runtime_voice import VoiceInputRequest, build_provider_voice_service from .runtime_cognition import _CliContextCapability from .runtime_extensions import _PreviewTelemetrySink -from .runtime_support import CliVoiceTurnResult, _PLACEHOLDER_MODELS_BY_PROVIDER, _iso, _utc_now +from .runtime_support import ( + CliVoiceTurnResult, + _PLACEHOLDER_MODELS_BY_PROVIDER, + _iso, + _utc_now, +) _EMBEDDING_API_KEY_ENV_VAR = OPENAI_COMPATIBLE_EMBED_DEFAULT_SECRET_ENV_VAR _EMBEDDING_API_KEY_REFERENCE_ID = OPENAI_COMPATIBLE_EMBED_SECRET_REFERENCE_ID + class CliRuntimeProviderMixin: def provider_summary(self) -> Mapping[str, object]: return self.model_provider.describe() @@ -201,7 +209,10 @@ def embedding_provider_summary(self) -> Mapping[str, object]: provider = dict(self.provider_summary()) profile = self._active_embedding_provider_profile() if profile is not None: - reference = next((item for item in profile.secret_references if item.secret_key == "api_key"), None) + reference = next( + (item for item in profile.secret_references if item.secret_key == "api_key"), + None, + ) reference_id = reference.reference_id if reference is not None else "" has_secret = bool(reference_id) and self.repository.has_auth_secret_value(reference_id) return { @@ -420,7 +431,13 @@ def provider_doctor(self, *, deep: bool = True) -> dict[str, Any]: ) ) except Exception as error: # pragma: no cover - defensive surface guard - checks.append({"check": "model_catalog", "status": "not-ready", "summary": str(error)}) + checks.append( + { + "check": "model_catalog", + "status": "not-ready", + "summary": str(error), + } + ) else: live_models = tuple(model for model in discovered_models if model.source != "catalog-hint") if live_models: @@ -448,9 +465,7 @@ def provider_doctor(self, *, deep: bool = True) -> dict[str, Any]: "or enter the exact model id before running runtime checks" ) elif live_models and configured_model and configured_model not in {model.model_id for model in live_models}: - probe_error = ( - f"configured model '{configured_model}' was not returned by the provider model catalog" - ) + probe_error = f"configured model '{configured_model}' was not returned by the provider model catalog" else: try: probe = self.provider_test(prompt="Doctor check") @@ -493,9 +508,7 @@ def security_doctor(self) -> dict[str, Any]: if self.repository.has_auth_secret_value(reference.reference_id) ) missing_reference_ids = tuple( - reference.reference_id - for reference in secret_refs - if reference.reference_id not in stored_reference_ids + reference.reference_id for reference in secret_refs if reference.reference_id not in stored_reference_ids ) checks: list[dict[str, object]] = [ { @@ -505,21 +518,14 @@ def security_doctor(self) -> dict[str, Any]: }, { "check": "secret_boundary", - "status": ( - "ok" - if not missing_reference_ids - else "warning" - ), + "status": ("ok" if not missing_reference_ids else "warning"), "summary": ( "preview fallback carries no runtime provider secrets" if provider["source"] != "configured" and embedding_profile is None else ( "provider and embedding secrets are stored in the encrypted local vault" if not missing_reference_ids - else ( - "missing stored provider secrets for " - + ", ".join(missing_reference_ids) - ) + else ("missing stored provider secrets for " + ", ".join(missing_reference_ids)) ) ), }, @@ -530,16 +536,10 @@ def security_doctor(self) -> dict[str, Any]: }, ] return { - "status": ( - "ready" - if not missing_reference_ids - else "not-ready" - ), + "status": ("ready" if not missing_reference_ids else "not-ready"), "provider": provider, "checks": checks, - "surface_bundles": tuple( - bundle.to_record(policy) for bundle in default_surface_policy_bundles() - ), + "surface_bundles": tuple(bundle.to_record(policy) for bundle in default_surface_policy_bundles()), "support_bundle": self.security_support_bundle(), } @@ -764,7 +764,11 @@ def set_default_provider( extra_headers=extra_headers, ) # Write provider to config.yaml - from packages.runtime_config import save_provider_to_config, global_config_path_for_state_dir + from packages.runtime_config import ( + save_provider_to_config, + global_config_path_for_state_dir, + ) + config_path = global_config_path_for_state_dir(self.paths.state_dir) save_provider_to_config( config_path, diff --git a/apps/cli/runtime_records.py b/apps/cli/runtime_records.py index 9bff6e9..cbe4f23 100644 --- a/apps/cli/runtime_records.py +++ b/apps/cli/runtime_records.py @@ -5,14 +5,11 @@ from collections.abc import Mapping from dataclasses import replace from datetime import datetime, timezone -import json -from pathlib import Path import shutil from typing import Any -from uuid import uuid4 from packages.contracts.layers import Episode -from packages.contracts.runtime import EvidenceRetrievalRequest, EvidenceRetrievalResult, RecallEvidence +from packages.contracts.runtime import RecallEvidence from packages.evidence import ( UnifiedRecallRequest, render_recall_hit, @@ -25,30 +22,20 @@ ) from packages.state.canonical import build_canonical_profile_state from packages.state.governance import parse_elephant_identity_display_name -from packages.state.loader import profile_manifest_payload from packages.state.persistence import ( load_persisted_canonical_state, resolve_runtime_state, sync_canonical_profile_state, ) -from .runtime_cognition import ( - _list_scope_recall_evidence, - _recall_query_seed, - _recall_query_with_relationship, - _recall_scope_reason, - _recall_scope_session_ids, -) from .runtime_snapshot import load_snapshot_state_focus from .runtime_support import ( EggSummary, _PlanningRecallRecovery, _elephant_state_id, - _coerce_str_tuple, - _optional_datetime, - _utc_now, ) + def _hidden_elephant_id(elephant_id: str) -> bool: return str(elephant_id or "").strip().startswith("learn-live") @@ -107,7 +94,16 @@ def list_herd(self, *, limit: int = 12) -> tuple[EggSummary, ...]: ) ) herd = tuple(herd_items) - ordered = tuple(sorted(herd, key=lambda item: (item.updated_at or datetime.min.replace(tzinfo=timezone.utc), item.elephant_id), reverse=True)) + ordered = tuple( + sorted( + herd, + key=lambda item: ( + item.updated_at or datetime.min.replace(tzinfo=timezone.utc), + item.elephant_id, + ), + reverse=True, + ) + ) return ordered[:limit] def latest_session_for_elephant(self, elephant_id: str) -> Episode | None: @@ -205,9 +201,7 @@ def ensure_elephant_state( } keep_summary = existing.summary if existing.summary not in _seed_summary_markers else "" keep_context_note = ( - existing.current_context_note - if existing.current_context_note not in _seed_summary_markers - else "" + existing.current_context_note if existing.current_context_note not in _seed_summary_markers else "" ) updated = replace( existing, @@ -219,7 +213,10 @@ def ensure_elephant_state( elephant_identity_text=elephant_identity_text, summary=keep_summary, current_context_note=keep_context_note, - metadata={**dict(existing.metadata), "profile_id": session.personal_model_id}, + metadata={ + **dict(existing.metadata), + "profile_id": session.personal_model_id, + }, ) self.repository.upsert_state(updated) refreshed = self.repository.load_state(updated.state_id) @@ -265,7 +262,9 @@ def _profile_ids_for_sessions(self, session_ids: tuple[str, ...]) -> tuple[str, return tuple(profile_ids) def _delete_elephant_file_dirs(self, elephant_ids: tuple[str, ...]) -> None: - cleaned_elephant_ids = tuple(dict.fromkeys(elephant_id.strip() for elephant_id in elephant_ids if elephant_id.strip())) + cleaned_elephant_ids = tuple( + dict.fromkeys(elephant_id.strip() for elephant_id in elephant_ids if elephant_id.strip()) + ) for elephant_id in cleaned_elephant_ids: shutil.rmtree(self.paths.elephant_file_path(elephant_id), ignore_errors=True) @@ -274,8 +273,8 @@ def elephant_id_for_session(self, session: Episode) -> str: return session.elephant_id # Infer from state_id: state:milo -> milo (exclude non-elephant states like state:xxx:default) state_id = str(getattr(session, "state_id", "") or "").strip() - if state_id.startswith("state:") and ":" not in state_id[len("state:"):]: - inferred = state_id[len("state:"):] + if state_id.startswith("state:") and ":" not in state_id[len("state:") :]: + inferred = state_id[len("state:") :] if inferred: return inferred lineage = self.repository.episode_lineage(session.episode_id) @@ -354,12 +353,26 @@ def recall_evidence( no hybrid hit comes back. No record ids are returned. """ normalized_scope = scope.strip().lower() or "all" - if normalized_scope not in {"personal_model", "state", "episodes", "episode", "steps", "sources", "all"}: + if normalized_scope not in { + "personal_model", + "state", + "episodes", + "episode", + "steps", + "sources", + "all", + }: normalized_scope = "all" capped = max(1, min(int(limit or 5), 10)) if normalized_scope == "all": - scopes: tuple[str, ...] = ("personal_model", "state", "episodes", "steps", "sources") + scopes: tuple[str, ...] = ( + "personal_model", + "state", + "episodes", + "steps", + "sources", + ) else: scopes = (normalized_scope,) @@ -408,7 +421,12 @@ def _load_profile(self, profile_id: str) -> LoadedProfile: profile_loader=self.profile_loader, ) # Merge config.yaml extensions into the manifest so extension data is available - from packages.runtime_config import load_extensions_from_config, global_config_path_for_state_dir, load_global_config + from packages.runtime_config import ( + load_extensions_from_config, + global_config_path_for_state_dir, + load_global_config, + ) + config_path = global_config_path_for_state_dir(self.paths.state_dir) try: config = load_global_config( @@ -440,7 +458,12 @@ def _load_session(self, session_id: str) -> Episode: def _load_profile_manifest(self) -> dict[str, Any]: """Load extension manifest data from config.yaml.""" - from packages.runtime_config import load_extensions_from_config, global_config_path_for_state_dir, load_global_config + from packages.runtime_config import ( + load_extensions_from_config, + global_config_path_for_state_dir, + load_global_config, + ) + config_path = global_config_path_for_state_dir(self.paths.state_dir) try: config = load_global_config( @@ -456,7 +479,11 @@ def _load_profile_manifest(self) -> dict[str, Any]: def _write_profile_manifest(self, manifest: Mapping[str, Any]) -> None: """Write extension manifest data to config.yaml.""" - from packages.runtime_config import save_extensions_to_config, global_config_path_for_state_dir + from packages.runtime_config import ( + save_extensions_to_config, + global_config_path_for_state_dir, + ) + config_path = global_config_path_for_state_dir(self.paths.state_dir) save_extensions_to_config( config_path, @@ -475,14 +502,20 @@ def _persist_profile( resolved_state = resolve_runtime_state( self.repository, personal_model_id=loaded_profile.state.profile_id, - episode_id=(latest_session.episode_id if latest_session is not None and latest_session.personal_model_id == loaded_profile.state.profile_id else None), + episode_id=( + latest_session.episode_id + if latest_session is not None and latest_session.personal_model_id == loaded_profile.state.profile_id + else None + ), required=False, ) # Persist identity to SQLite only; no longer writing profile.json self.repository.upsert_personal_model_runtime_state(loaded_profile.state) canonical_bundle = build_canonical_profile_state( loaded_profile, - elephant_id=resolved_state.elephant_id if resolved_state is not None and resolved_state.elephant_id else None, + elephant_id=resolved_state.elephant_id + if resolved_state is not None and resolved_state.elephant_id + else None, ) sync_canonical_profile_state( self.repository, @@ -492,7 +525,11 @@ def _persist_profile( recall_runtime=self.recall_runtime, surface="cli", state_id=resolved_state.state_id if resolved_state is not None else None, - episode_id=(latest_session.episode_id if latest_session is not None and latest_session.personal_model_id == loaded_profile.state.profile_id else None), + episode_id=( + latest_session.episode_id + if latest_session is not None and latest_session.personal_model_id == loaded_profile.state.profile_id + else None + ), ) reloaded = self._load_profile(loaded_profile.state.profile_id) self.repository.upsert_personal_model_runtime_state(reloaded.state) diff --git a/apps/cli/runtime_snapshot.py b/apps/cli/runtime_snapshot.py index 11ba5fd..81e2e6f 100644 --- a/apps/cli/runtime_snapshot.py +++ b/apps/cli/runtime_snapshot.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Mapping -from dataclasses import dataclass, replace +from dataclasses import replace from datetime import datetime, timezone from pathlib import Path from types import SimpleNamespace @@ -50,7 +50,10 @@ if TYPE_CHECKING: from apps.cli.runtime import CliRuntime -from .runtime_growth_metrics import active_personal_model_facts_for_growth, personal_model_growth_metrics +from .runtime_growth_metrics import ( + active_personal_model_facts_for_growth, + personal_model_growth_metrics, +) from .runtime_support import _resolved_session_skills @@ -82,7 +85,9 @@ def restore_snapshot_state_focus( payload = snapshot.get("state_focus") if not isinstance(payload, Mapping): return None - reasons = tuple(_restore_state_focus_reason(reason) for reason in payload.get("reasons", ()) if isinstance(reason, Mapping)) + reasons = tuple( + _restore_state_focus_reason(reason) for reason in payload.get("reasons", ()) if isinstance(reason, Mapping) + ) candidate_scores = tuple( _restore_state_focus_candidate_score(score) for score in payload.get("candidate_scores", ()) @@ -229,11 +234,7 @@ def _growth_state_predates_profile_sessions( if growth_timestamp is None: return False episodes = runtime.repository.list_episodes() - started_at_values = [ - episode.started_at - for episode in episodes - if episode.personal_model_id == profile_id - ] + started_at_values = [episode.started_at for episode in episodes if episode.personal_model_id == profile_id] if not started_at_values: return False first_started_at = min(started_at_values) @@ -270,7 +271,9 @@ def _build_growth_turn_signals( work_item_status=None, work_item_priority=None, progression_action="", - resume_signal="continue" if any(step.action == "resume" and step.status == "completed" for step in outcome.steps) else "none", + resume_signal="continue" + if any(step.action == "resume" and step.status == "completed" for step in outcome.steps) + else "none", continuity_mode="background" if session.interruption_state else "foreground", execution_outcome=outcome.execution.outcome, experience_status=experience.status if experience is not None else None, @@ -308,9 +311,7 @@ def _promoted_procedure_delta( if not procedures: return 0, () promoted = tuple( - procedure.procedure_id - for procedure in procedures - if procedure.status in {"active", "promoted", "verified"} + procedure.procedure_id for procedure in procedures if procedure.status in {"active", "promoted", "verified"} ) already_recorded = current.promoted_experiences if current is not None else 0 delta = max(0, len(promoted) - already_recorded) @@ -407,9 +408,20 @@ def _next_session_context_epoch( ) -> SessionContextEpoch: disclosures = _skill_disclosure_records(runtime, context=context) frozen_skill_index = _frozen_session_skill_index(runtime, profile=profile, session=session) - can_refresh_episode_open = existing is not None and existing.frozen and event is None and execution is None and not existing.history_messages + can_refresh_episode_open = ( + existing is not None + and existing.frozen + and event is None + and execution is None + and not existing.history_messages + ) if context is None and (existing is None or not existing.frozen or can_refresh_episode_open): - context = _episode_open_frozen_context(runtime, profile=profile, session=session, frozen_skill_index=frozen_skill_index) + context = _episode_open_frozen_context( + runtime, + profile=profile, + session=session, + frozen_skill_index=frozen_skill_index, + ) is_user_turn = event is not None and _snapshot_event_is_user_turn(event.event_type, event.source) fallback_history = session_history_messages( event=event, @@ -460,7 +472,10 @@ def _episode_open_frozen_context( runtime_path_lines = _episode_open_runtime_path_lines(runtime, session=session) runtime_context = ContextRuntime( instruction_refs=stable_prefix_lines + skill_lines + runtime_path_lines, - total_tokens=max(1024, int(getattr(runtime, "active_provider_context_window", lambda: 0)() or 0)), + total_tokens=max( + 1024, + int(getattr(runtime, "active_provider_context_window", lambda: 0)() or 0), + ), ) assembled = runtime_context.assemble_detailed( session, @@ -499,7 +514,9 @@ def _episode_open_frozen_context( return None -def _frozen_skill_shelf_prompt_lines(frozen_skill_index: tuple[FrozenSkillIndexEntry, ...]) -> tuple[str, ...]: +def _frozen_skill_shelf_prompt_lines( + frozen_skill_index: tuple[FrozenSkillIndexEntry, ...], +) -> tuple[str, ...]: if not frozen_skill_index: return () lines = [ @@ -634,7 +651,9 @@ def _skill_index_id(skill_id: str) -> str: return "_".join(part for part in cleaned.split("_") if part) -def _skill_affinity_rows(runtime: CliRuntime, *, personal_model_id: str) -> tuple[tuple[float, str, dict[str, str], str], ...]: +def _skill_affinity_rows( + runtime: CliRuntime, *, personal_model_id: str +) -> tuple[tuple[float, str, dict[str, str], str], ...]: list_facts = getattr(runtime.repository, "list_personal_model_facts", None) if not callable(list_facts): return () @@ -670,10 +689,7 @@ def _frozen_session_skill_ids( profile: PersonalModelRuntimeState, session: Episode, ) -> tuple[str, ...]: - return tuple( - entry.skill_id - for entry in _frozen_session_skill_index(runtime, profile=profile, session=session) - ) + return tuple(entry.skill_id for entry in _frozen_session_skill_index(runtime, profile=profile, session=session)) def _frozen_session_tool_count(runtime: CliRuntime) -> int: @@ -711,9 +727,7 @@ def _skill_disclosure_records( for skill_id in dict.fromkeys(disclosed_skill_ids): definition = runtime.skill_runtime.describe(skill_id) display_name = ( - definition.display_name.strip() - if definition is not None and definition.display_name.strip() - else skill_id + definition.display_name.strip() if definition is not None and definition.display_name.strip() else skill_id ) records.append( SkillDisclosureRecord( @@ -726,9 +740,7 @@ def _skill_disclosure_records( def _skill_disclosure_reason(*, skill_id: str, display_name: str) -> str: - return ( - f"{display_name} ({skill_id}) was disclosed because the runtime recorded an explicit skill overlay." - ) + return f"{display_name} ({skill_id}) was disclosed because the runtime recorded an explicit skill overlay." def _profile_payload(profile: PersonalModelRuntimeState, *, elephant_identity_text: str | None) -> dict[str, Any]: @@ -817,7 +829,9 @@ def _execution_payload(execution: ExecutionResult | None) -> dict[str, Any] | No } -def _state_focus_payload(state_focus: StateFocusDecision | None) -> dict[str, Any] | None: +def _state_focus_payload( + state_focus: StateFocusDecision | None, +) -> dict[str, Any] | None: if state_focus is None: return None return { @@ -851,7 +865,9 @@ def _state_focus_reason_payload(reason: StateFocusReason) -> dict[str, Any]: } -def _state_focus_candidate_score_payload(score: StateFocusCandidateScore) -> dict[str, Any]: +def _state_focus_candidate_score_payload( + score: StateFocusCandidateScore, +) -> dict[str, Any]: return { "candidate_id": score.candidate_id, "kind": score.kind, @@ -880,7 +896,9 @@ def _restore_state_focus_reason(payload: Mapping[str, Any]) -> StateFocusReason: ) -def _restore_state_focus_candidate_score(payload: Mapping[str, Any]) -> StateFocusCandidateScore: +def _restore_state_focus_candidate_score( + payload: Mapping[str, Any], +) -> StateFocusCandidateScore: return StateFocusCandidateScore( candidate_id=str(payload.get("candidate_id") or "").strip(), kind=str(payload.get("kind") or "").strip(), @@ -889,14 +907,9 @@ def _restore_state_focus_candidate_score(payload: Mapping[str, Any]) -> StateFoc heuristics_score=float(payload.get("heuristics_score") or 0.0), embedding_score=float(payload.get("embedding_score") or 0.0), reasons=tuple( - _restore_state_focus_reason(reason) - for reason in payload.get("reasons", ()) - if isinstance(reason, Mapping) + _restore_state_focus_reason(reason) for reason in payload.get("reasons", ()) if isinstance(reason, Mapping) ), - metadata={ - str(key): str(value) - for key, value in dict(payload.get("metadata") or {}).items() - }, + metadata={str(key): str(value) for key, value in dict(payload.get("metadata") or {}).items()}, ) diff --git a/apps/cli/runtime_support.py b/apps/cli/runtime_support.py index e87762b..8621a30 100644 --- a/apps/cli/runtime_support.py +++ b/apps/cli/runtime_support.py @@ -192,7 +192,10 @@ def _resolved_state_for_elephant(repository: Any, elephant_id: str): return direct if hasattr(repository, "list_states"): for state in repository.list_states(): - if state.elephant_id == target or state.state_anchor in {target, f"elephant:{target}"}: + if state.elephant_id == target or state.state_anchor in { + target, + f"elephant:{target}", + }: return state return None diff --git a/apps/cli/runtime_turns.py b/apps/cli/runtime_turns.py index d1e7293..70c5976 100644 --- a/apps/cli/runtime_turns.py +++ b/apps/cli/runtime_turns.py @@ -8,7 +8,12 @@ from uuid import uuid4 from packages.contracts.layers import Episode -from packages.contracts.runtime import EventEnvelope, ExecutionResult, PersonalModelRuntimeState, PromptMessage +from packages.contracts.runtime import ( + EventEnvelope, + ExecutionResult, + PersonalModelRuntimeState, + PromptMessage, +) from packages.kernel import ( KernelDependencies, KernelOutcome, @@ -20,7 +25,9 @@ ) from packages.context.compress import split_for_compress, _deterministic_summary from packages.kernel.context_compaction import projection_compaction_detail -from packages.kernel.episode_state_machine import open_next_episode as _open_next_episode +from packages.kernel.episode_state_machine import ( + open_next_episode as _open_next_episode, +) from packages.storage.repository_support import DEFAULT_PERSONAL_MODEL_ID from packages.state import ( ensure_elephant_identity_file, @@ -336,7 +343,10 @@ def run_turn( refreshed_session = runtime._load_session(session.episode_id) persisted_profile = runtime._load_profile(refreshed_session.personal_model_id) decision_summary = _decision_summary_from_outcome(outcome) - observed_event = replace(event, payload=_payload_with_turn_reasoning(event.payload, outcome, decision_summary=decision_summary)) + observed_event = replace( + event, + payload=_payload_with_turn_reasoning(event.payload, outcome, decision_summary=decision_summary), + ) if performed_turn_reconciliation: turn_observation = ReconciliationPipeline().observe_turn( inbound_event=observed_event, @@ -509,10 +519,10 @@ def _reflect_compress_summary( object.__setattr__(runtime, "sub_agent_active", previous_sub_agent_active) - def _compact_snapshot_after_high_usage(runtime: CliRuntime, outcome: KernelOutcome) -> KernelOutcome: """Trigger synchronous reflect-based context compression when usage is high.""" import logging + log = logging.getLogger(__name__) usage_tokens = _execution_context_usage_tokens(outcome.execution) @@ -522,23 +532,36 @@ def _compact_snapshot_after_high_usage(runtime: CliRuntime, outcome: KernelOutco return outcome trigger_tokens = max(1, int(context_limit * _USAGE_AFTER_TURN_COMPACTION_RATIO)) if usage_tokens < trigger_tokens: - log.debug("compress skipped: usage %s < trigger %s (limit=%s ratio=%s)", - usage_tokens, trigger_tokens, context_limit, _USAGE_AFTER_TURN_COMPACTION_RATIO) + log.debug( + "compress skipped: usage %s < trigger %s (limit=%s ratio=%s)", + usage_tokens, + trigger_tokens, + context_limit, + _USAGE_AFTER_TURN_COMPACTION_RATIO, + ) return outcome - log.info("compress triggered: usage=%s trigger=%s limit=%s session=%s", - usage_tokens, trigger_tokens, context_limit, outcome.route_session_id) + log.info( + "compress triggered: usage=%s trigger=%s limit=%s session=%s", + usage_tokens, + trigger_tokens, + context_limit, + outcome.route_session_id, + ) # Load the frozen epoch to get history messages from apps.cli.runtime_snapshot import restore_snapshot_session_context_epoch from apps.cli.snapshot_io import load_snapshot_payload + snapshot_path = getattr(runtime, "snapshot_path", None) if snapshot_path is None: log.warning("compress skipped: snapshot_path is None") _emit_compress_skip_stage(runtime, outcome, "snapshot_path_missing", usage_tokens) return outcome snapshot = load_snapshot_payload(snapshot_path) if snapshot_path.exists() else None - frozen_epoch = restore_snapshot_session_context_epoch(snapshot, session_id=outcome.route_session_id) if snapshot else None + frozen_epoch = ( + restore_snapshot_session_context_epoch(snapshot, session_id=outcome.route_session_id) if snapshot else None + ) if frozen_epoch is None: log.warning("compress skipped: frozen_epoch is None (snapshot=%s)", snapshot is not None) _emit_compress_skip_stage(runtime, outcome, "epoch_missing", usage_tokens) @@ -573,8 +596,7 @@ def _compact_snapshot_after_high_usage(runtime: CliRuntime, outcome: KernelOutco protected_tail_turns=1, ) if not to_summarize: - log.warning("compress skipped: nothing to summarize (msgs=%d)", - history_count) + log.warning("compress skipped: nothing to summarize (msgs=%d)", history_count) _emit_compress_skip_stage(runtime, outcome, f"nothing_to_summarize_{history_count}", usage_tokens) return outcome @@ -610,7 +632,9 @@ def _compact_snapshot_after_high_usage(runtime: CliRuntime, outcome: KernelOutco log.warning( "context compress fallback: hard-truncating history without LLM summary " "(history=%d to_summarize=%d tail=%d)", - history_count, len(to_summarize), len(tail), + history_count, + len(to_summarize), + len(tail), ) summary = _deterministic_summary(to_summarize, history_count=history_count) _emit_post_snapshot_kernel_stage( @@ -631,6 +655,7 @@ def _compact_snapshot_after_high_usage(runtime: CliRuntime, outcome: KernelOutco # Update the epoch: summary replaces Episode resume in frozen_prefix, # optionally refresh PM facts, keep only tail messages. from packages.context.session_projection import compact_session_context_epoch + updated_epoch, compaction_result = compact_session_context_epoch( frozen_epoch, total_tokens=context_limit, @@ -645,6 +670,7 @@ def _compact_snapshot_after_high_usage(runtime: CliRuntime, outcome: KernelOutco _strip_prompt_sections, _append_prompt_section, ) + updated_prefix = _strip_prompt_sections(updated_epoch.frozen_prefix, "Episode resume") updated_prefix = _append_prompt_section( updated_prefix, @@ -652,6 +678,7 @@ def _compact_snapshot_after_high_usage(runtime: CliRuntime, outcome: KernelOutco (f"Reference summary: {summary}",), ) from dataclasses import replace as _dc_replace + updated_epoch = _dc_replace(updated_epoch, frozen_prefix=updated_prefix) # Write the compacted epoch back to snapshot. @@ -662,6 +689,7 @@ def _compact_snapshot_after_high_usage(runtime: CliRuntime, outcome: KernelOutco # both session_context_epoch AND the session key to the parent episode. from apps.cli.runtime_snapshot import _session_context_epoch_payload from apps.cli.snapshot_io import load_snapshot_payload, write_snapshot_payload + _snap = load_snapshot_payload(runtime.snapshot_path) or {} _snap["session_context_epoch"] = _session_context_epoch_payload(updated_epoch) # Restore session key — only episode_id matters for epoch session matching. @@ -690,7 +718,10 @@ def _compact_snapshot_after_high_usage(runtime: CliRuntime, outcome: KernelOutco _emit_post_snapshot_kernel_stage(runtime, outcome, record) log.info( "compress completed: %d->%d messages, summary_len=%d, session=%s", - history_count, len(tail), len(summary), outcome.route_session_id, + history_count, + len(tail), + len(summary), + outcome.route_session_id, ) return replace(outcome, stages=(*outcome.stages, record)) @@ -761,11 +792,7 @@ def _emit_compress_skip_stage( outcome, KernelStageRecord( stage="context-compact", - detail=( - f"reason=skip:{reason} " - f"tokens={usage_tokens}->{usage_tokens} " - f"messages=0->0" - ), + detail=(f"reason=skip:{reason} tokens={usage_tokens}->{usage_tokens} messages=0->0"), recorded_at=datetime.now(timezone.utc), ), ) @@ -796,7 +823,10 @@ def _projection_thread_focus(work_items: tuple[Any, ...]) -> str: None, ) if active_work_item is None: - active_work_item = next((work_item for work_item in work_items if str(getattr(work_item, "title", "") or "").strip()), None) + active_work_item = next( + (work_item for work_item in work_items if str(getattr(work_item, "title", "") or "").strip()), + None, + ) return str(getattr(active_work_item, "title", "") or "").strip() if active_work_item is not None else "" @@ -848,9 +878,7 @@ def _wake_rationale_event( ) -> EventEnvelope: scope_summary = ", ".join(recovery.scope_episode_ids) or episode_id content = ( - f"Wake recovery searched scope {scope_summary}. " - f"Reason: {recovery.scope_reason}. " - f"Next step: {wake_summary}" + f"Wake recovery searched scope {scope_summary}. Reason: {recovery.scope_reason}. Next step: {wake_summary}" ) return EventEnvelope( event_id=f"event:{uuid4().hex}", @@ -865,7 +893,9 @@ def _wake_rationale_event( "scope_episode_ids": ",".join(recovery.scope_episode_ids), "scope_reason": recovery.scope_reason, "query": recovery.query, - "resume_packet_summary": recovery.resume_packet.summary if getattr(recovery, "resume_packet", None) is not None else "", + "resume_packet_summary": recovery.resume_packet.summary + if getattr(recovery, "resume_packet", None) is not None + else "", "resume_packet_evidence_ids": ",".join( recovery.resume_packet.evidence_ids if getattr(recovery, "resume_packet", None) is not None else () ), diff --git a/apps/cli/runtime_voice.py b/apps/cli/runtime_voice.py index 9d84a1d..27f5b05 100644 --- a/apps/cli/runtime_voice.py +++ b/apps/cli/runtime_voice.py @@ -20,12 +20,26 @@ AuthProfile, ProfileCredentialResolver, SecretReference, - SecretStore, SecretValueResolution, ) -from packages.state import CompanionSettings, LoadedProfile, build_companion_identity_state -from packages.security import ApprovalClass, PolicyDecision, PolicyResult, SecurityPolicy, SecurityRequest -from packages.telemetry import TelemetrySink, emit_approval_event, emit_delivery_event, emit_failure_event +from packages.state import ( + CompanionSettings, + LoadedProfile, + build_companion_identity_state, +) +from packages.security import ( + ApprovalClass, + PolicyDecision, + PolicyResult, + SecurityPolicy, + SecurityRequest, +) +from packages.telemetry import ( + TelemetrySink, + emit_approval_event, + emit_delivery_event, + emit_failure_event, +) _VOICE_SOURCE_PREVIEW = "preview" _VOICE_SOURCE_PROVIDER_BACKED = "provider-backed" @@ -360,10 +374,7 @@ def synthesize( audio_format=audio_format, credentials=credentials, ) - payload = ( - f"VOICE[{plan.provider_id}/{plan.model_id}/{self.config.voice_name}] " - f"{transcript}" - ) + payload = f"VOICE[{plan.provider_id}/{plan.model_id}/{self.config.voice_name}] {transcript}" return payload.encode("utf-8"), plan def _build_headers( @@ -399,9 +410,7 @@ def open_session(self, profile: LoadedProfile, session_id: str) -> _VoiceRuntime text_first=companion.text_first, started_at=_utc_now(), updated_at=_utc_now(), - status=VoiceModeStatus.READY.value - if companion.text_first - else VoiceModeStatus.BLOCKED.value, + status=VoiceModeStatus.READY.value if companion.text_first else VoiceModeStatus.BLOCKED.value, notes=_dedupe( ( f"identity:{profile.state.display_name}", @@ -485,7 +494,9 @@ def doctor(self, profile: LoadedProfile) -> dict[str, object]: }, { "check": "credentials", - "status": "available" if credentials_ready else ("missing" if summary["supported"] else "not-applicable"), + "status": "available" + if credentials_ready + else ("missing" if summary["supported"] else "not-applicable"), "summary": ",".join(credential_keys) if credential_keys else credential_error, }, { @@ -706,7 +717,9 @@ def process_input( audio_format=audio_format, ) - def _refresh_session(self, session: _VoiceRuntimeSession, input_kind: str, output_kind: str) -> _VoiceRuntimeSession: + def _refresh_session( + self, session: _VoiceRuntimeSession, input_kind: str, output_kind: str + ) -> _VoiceRuntimeSession: return _VoiceRuntimeSession( voice_session_id=session.voice_session_id, profile_id=session.profile_id, @@ -854,10 +867,15 @@ def build_provider_voice_service( voice_name: str = "alloy", ) -> VoiceService: adapter: OneShotVoiceAdapter | None = None - if provider_profile is not None and provider_profile.base_url and provider_profile.transport_id in { - "openai-compatible", - "openai_chat_compatible", - }: + if ( + provider_profile is not None + and provider_profile.base_url + and provider_profile.transport_id + in { + "openai-compatible", + "openai_chat_compatible", + } + ): adapter = OpenAICompatibleVoiceAdapter( OpenAICompatibleVoiceConfig( provider_id=provider_profile.provider_id, diff --git a/apps/cli/shell_banner.py b/apps/cli/shell_banner.py index 494e4b7..c7fb4eb 100644 --- a/apps/cli/shell_banner.py +++ b/apps/cli/shell_banner.py @@ -64,20 +64,27 @@ def status_sections(shell, session, continuity, context_frame, growth): elif continuity.wake_action == "continue": ready_lines.append(("now", "Ready to pick the thread back up when you are.", True)) else: - ready_lines.append(("now", "Bring whatever you want to work on; I will adapt from here.", False)) + ready_lines.append( + ( + "now", + "Bring whatever you want to work on; I will adapt from here.", + False, + ) + ) ready_lines.append( ( "history", - ( - f"{growth.canonical_dialogues} dialogues · " - f"{growth.canonical_active_days} active day(s)" - ), + (f"{growth.canonical_dialogues} dialogues · {growth.canonical_active_days} active day(s)"), growth.canonical_dialogues > 0, ) ) model_lines = [ - ("learning", _learning_job_execution_summary(shell.runtime, personal_model_id), True), + ( + "learning", + _learning_job_execution_summary(shell.runtime, personal_model_id), + True, + ), ] if facts: model_lines.append(("saved", _lens_claim_summary(facts), True)) @@ -89,7 +96,13 @@ def status_sections(shell, session, continuity, context_frame, growth): sub_lens = str(getattr(next_question, "sub_lens", "") or "").strip() question_scope = " · ".join(part for part in (lens, sub_lens) if part) label = "question" if not question_scope else f"question ({question_scope})" - model_lines.append((label, compact_line(str(getattr(next_question, "text", "") or ""), limit=96), True)) + model_lines.append( + ( + label, + compact_line(str(getattr(next_question, "text", "") or ""), limit=96), + True, + ) + ) else: model_lines.append(("questions", "No pending question; I will ask only if it helps.", False)) @@ -119,9 +132,7 @@ def _looks_like_opening_prompt(text: str) -> bool: lowered = " ".join(text.casefold().split()) if any(" ".join(marker.casefold().split()) in lowered for marker in _OPENING_PROMPT_MARKERS): return True - return lowered.startswith("write ") or ( - "assistant_display_name:" in lowered and "current_work_summary:" in lowered - ) + return lowered.startswith("write ") or ("assistant_display_name:" in lowered and "current_work_summary:" in lowered) def _human_facing_state_text(*values: object) -> str: @@ -338,24 +349,46 @@ def _skill_lines( discoverable_count = len(skill_hub_entries) installed_ids = {str(getattr(skill, "skill_id", "") or "") for skill in skills} new_to_install_count = sum( - 1 - for entry in skill_hub_entries - if str(getattr(entry, "skill_id", "") or "") not in installed_ids + 1 for entry in skill_hub_entries if str(getattr(entry, "skill_id", "") or "") not in installed_ids ) lines: list[tuple[str, str, bool]] = [] affinity_summary = _skill_affinity_summary(affinity_facts) lines.append(("affinities", affinity_summary, bool(affinity_facts))) - lines.append(("active", f"{len(enabled_skills)} enabled · {len(builtin_skills)} built-in", bool(enabled_skills))) + lines.append( + ( + "active", + f"{len(enabled_skills)} enabled · {len(builtin_skills)} built-in", + bool(enabled_skills), + ) + ) if authored_skills: lines.append(("built by you", f"{len(authored_skills)} authored skill(s)", True)) if discoverable_count: if new_to_install_count: - lines.append(("discover", f"{discoverable_count} local packages · {new_to_install_count} not installed", True)) + lines.append( + ( + "discover", + f"{discoverable_count} local packages · {new_to_install_count} not installed", + True, + ) + ) else: - lines.append(("discover", f"{discoverable_count} local packages · /skills search ", True)) + lines.append( + ( + "discover", + f"{discoverable_count} local packages · /skills search ", + True, + ) + ) else: - lines.append(("discover", "/skills search when you want more capabilities", False)) + lines.append( + ( + "discover", + "/skills search when you want more capabilities", + False, + ) + ) return tuple(lines) @@ -394,7 +427,9 @@ def _skill_affinity_summary(facts: tuple[object, ...]) -> str: return " · ".join(parts) -def _skill_affinity_rows(facts: tuple[object, ...]) -> tuple[tuple[float, str, dict[str, str], str, str], ...]: +def _skill_affinity_rows( + facts: tuple[object, ...], +) -> tuple[tuple[float, str, dict[str, str], str, str], ...]: rows: list[tuple[float, str, dict[str, str], str, str]] = [] for fact in facts: metadata = {str(key): str(value) for key, value in dict(getattr(fact, "metadata", {}) or {}).items()} @@ -409,12 +444,14 @@ def _skill_affinity_rows(facts: tuple[object, ...]) -> tuple[tuple[float, str, d usage = min(10.0, float(metadata.get("usage_count") or 0.0)) except ValueError: usage = 0.0 - rows.append(( - confidence + (usage * 0.01), - topic, - metadata, - str(getattr(fact, "text", "") or ""), - str(getattr(fact, "status", "") or "active"), - )) + rows.append( + ( + confidence + (usage * 0.01), + topic, + metadata, + str(getattr(fact, "text", "") or ""), + str(getattr(fact, "status", "") or "active"), + ) + ) rows.sort(key=lambda item: (-item[0], item[1])) return tuple(rows) diff --git a/apps/cli/shell_clarify.py b/apps/cli/shell_clarify.py index 9c765d6..1d0b700 100644 --- a/apps/cli/shell_clarify.py +++ b/apps/cli/shell_clarify.py @@ -4,14 +4,19 @@ from dataclasses import dataclass from queue import Empty, Queue -import threading from typing import TYPE_CHECKING, Callable from uuid import uuid4 from packages.contracts.runtime import ExecutionResult from packages.tools.surfaces import ClarifySurface -from .shell_stack import Condition, ConditionalContainer, FormattedText, FormattedTextControl, Window +from .shell_stack import ( + Condition, + ConditionalContainer, + FormattedText, + FormattedTextControl, + Window, +) if TYPE_CHECKING: from .shell import ProductizedShell @@ -124,7 +129,12 @@ def render_clarify_fragments(shell: ProductizedShell) -> FormattedText: for index, choice in enumerate(state.choices, start=1): fragments.append(("class:clarify-choice", f"\n{index}. {choice}")) fragments.append(("", "\n")) - fragments.append(("class:clarify-hint", "Type a number or a custom answer, then press Enter.")) + fragments.append( + ( + "class:clarify-hint", + "Type a number or a custom answer, then press Enter.", + ) + ) fragments.append(("", "\n")) else: fragments.append(("", "\n")) diff --git a/apps/cli/shell_clipboard.py b/apps/cli/shell_clipboard.py index 3108263..6278411 100644 --- a/apps/cli/shell_clipboard.py +++ b/apps/cli/shell_clipboard.py @@ -72,8 +72,9 @@ def _detect_paste_intent(text: str) -> str: if "traceback (most recent call last)" in lowered: return "error" error_marker_count = sum( - 1 for line in stripped.splitlines() - if line.strip().startswith(("File \"", "File '", " at ", "Error:", "Exception", "Caused by:")) + 1 + for line in stripped.splitlines() + if line.strip().startswith(('File "', "File '", " at ", "Error:", "Exception", "Caused by:")) ) if error_marker_count >= 2: return "error" @@ -84,7 +85,18 @@ def _detect_paste_intent(text: str) -> str: stripped_line = line.strip() if not stripped_line: continue - if stripped_line.startswith(("def ", "class ", "import ", "from ", "function ", "const ", "let ", "var ")): + if stripped_line.startswith( + ( + "def ", + "class ", + "import ", + "from ", + "function ", + "const ", + "let ", + "var ", + ) + ): code_markers += 2 elif stripped_line.endswith((":", "{", "}", ");", ";")): code_markers += 1 @@ -150,7 +162,8 @@ def import_system_clipboard(*, storage_dir: Path) -> tuple[ClipboardAttachment, attachments = [ attachment for raw_path in probe.paths - if (attachment := build_path_attachment(raw_path, kind_hint="image" if probe.kind == "image" else None)) is not None + if (attachment := build_path_attachment(raw_path, kind_hint="image" if probe.kind == "image" else None)) + is not None ] return tuple(attachments) return () @@ -165,9 +178,7 @@ def compile_submission(raw_text: str, attachments: Iterable[ClipboardAttachment] display_command = _expanded_display_command(normalized, items) or compact_display_command prompt_parts: list[str] = [normalized] if normalized else [] prompt_parts.extend( - attachment.prompt_fragment.strip() - for attachment in items - if attachment.prompt_fragment.strip() + attachment.prompt_fragment.strip() for attachment in items if attachment.prompt_fragment.strip() ) command = "\n\n".join(part for part in prompt_parts if part) visible = display_command or normalized or command @@ -252,11 +263,7 @@ def _probe_macos_clipboard(*, storage_dir: Path) -> _ClipboardProbe: if kind == "text": return _ClipboardProbe(kind="text", text=str(payload.get("text") or "")) if kind == "files": - paths = tuple( - str(path).strip() - for path in payload.get("paths", ()) - if str(path or "").strip() - ) + paths = tuple(str(path).strip() for path in payload.get("paths", ()) if str(path or "").strip()) return _ClipboardProbe(kind="files", paths=paths) if kind == "image": path = str(payload.get("path") or "").strip() @@ -290,28 +297,28 @@ def _macos_clipboard_script(image_path: str) -> str: 'ObjC.import("AppKit");' 'ObjC.import("Foundation");' 'function emit(obj){var text=JSON.stringify(obj)+"\\n";' - 'var data=$(text).dataUsingEncoding($.NSUTF8StringEncoding);' - '$.NSFileHandle.fileHandleWithStandardOutput.writeData(data);}' - 'var pb=$.NSPasteboard.generalPasteboard;' - 'var items=pb.pasteboardItems;' - 'var paths=[];' - 'if(items){for(var i=0;i tuple[str, str] | N if name_lc == prefix: return None if name_lc.startswith(prefix): - tail = name[len(first_line):] + tail = name[len(first_line) :] return tail, description return None @@ -218,6 +223,7 @@ def _composer_ghost_fragments(shell: ProductizedShell, buffer): def build_input_meta_window(shell: ProductizedShell, buffer): try: from prompt_toolkit.layout import WindowAlign + align = WindowAlign.RIGHT except Exception: align = None @@ -860,10 +866,12 @@ def _(event) -> None: # Guarded to `turn_active & not_searching` so Esc at the idle # composer stays a no-op. turn_active = Condition( - lambda: shell is not None - and ( - getattr(shell, "_turn_started_at", None) is not None - or bool(getattr(shell, "_streaming_response_active", False)) + lambda: ( + shell is not None + and ( + getattr(shell, "_turn_started_at", None) is not None + or bool(getattr(shell, "_streaming_response_active", False)) + ) ) ) @@ -965,9 +973,11 @@ def _(event) -> None: # chance to edit. If the composer already has text, Up falls through # to prompt_toolkit's default (caret moves up in multi-line buffer). empty_and_idle = Condition( - lambda: shell is not None - and not _history_search_active(shell) - and not (shell is not None and getattr(shell, "_turn_started_at", None) is not None) + lambda: ( + shell is not None + and not _history_search_active(shell) + and not (shell is not None and getattr(shell, "_turn_started_at", None) is not None) + ) ) @bindings.add("up", filter=empty_and_idle) @@ -1018,6 +1028,7 @@ def _(event) -> None: if match is None: # Let the regular completion menu handle it. from prompt_toolkit.key_binding.bindings.named_commands import menu_complete + try: menu_complete(event) except Exception: diff --git a/apps/cli/shell_impl.py b/apps/cli/shell_impl.py index 08e0ab9..2206c9d 100644 --- a/apps/cli/shell_impl.py +++ b/apps/cli/shell_impl.py @@ -3,100 +3,22 @@ from __future__ import annotations from collections import deque -from dataclasses import dataclass -from difflib import unified_diff import os from pathlib import Path -import re -import shlex import sys import threading import time -from packages.contracts import ExperienceRecord -from packages.kernel.runtime import KernelOutcome -from packages.operator.runtime import ( - RecallEvidenceOperatorDetail, - RecallEvidenceSearchHit, - build_recall_evidence_operator_surface, - build_profile_operator_surface, - render_recall_evidence_lines, - render_profile_lines, -) -from packages.tools.handler_support import resolve_allowed_path from .provider_flow import provider_setup_defaults, run_provider_selection_wizard from .runtime import CliRuntime from .wizard import WIZARD_BACK, WIZARD_CANCEL -from .shell_composer import ( - build_command_palette as _build_shell_command_palette, - build_composer_body as _build_shell_composer_body, - build_divider_window as _build_shell_divider_window, - build_input_window as _build_shell_input_window, - build_key_bindings as _build_shell_key_bindings, - build_prompt_buffer as _build_shell_prompt_buffer, - build_queue_preview_window as _build_shell_queue_preview_window, - prompt_continuation as _shell_prompt_continuation, - prompt_label as _shell_prompt_label, - prompt_style as _shell_prompt_style, - prompt_style_map as _shell_prompt_style_map, - prompt_toolkit_composer_available as _shell_prompt_toolkit_composer_available, - read_command as _read_shell_command, - shell_history as _shell_history, -) from .shell_clarify import ShellInteractiveClarifySurface from .shell_boot import WAKE_DISPLAY_SECONDS, BootFrameContext, render_boot_frame -from .shell_opening import ( - ShellOpeningContext, - compose_shell_opening_instruction, - compose_shell_opener, -) -from .shell_progress import ( - animations_enabled as _shell_animations_enabled, - render_queued_followup_fragments as _render_shell_queued_followup_fragments, - render_tool_frame as _render_shell_tool_frame, - tool_trace_line as _shell_tool_trace_line, - render_turn_frame as _render_shell_turn_frame, - render_turn_progress_fragments as _render_shell_turn_progress_fragments, - run_tool_with_progress as _run_shell_tool_with_progress, - run_turn_with_progress as _run_shell_turn_with_progress, - run_turn_with_queued_input as _run_shell_turn_with_queued_input, - summarize_progress_prompt as _summarize_shell_progress_prompt, - tool_event_lines as _shell_tool_event_lines, - tool_event_summary as _shell_tool_event_summary, - tool_event_tracker as _shell_tool_event_tracker, - tool_frame_phases as _shell_tool_frame_phases, - turn_phase as _shell_turn_phase, - _tool_trace_emoji as _shell_tool_trace_emoji, -) -from .shell_render import ( - center_brand_block as _center_shell_brand_block, - displayable_experiences as _displayable_shell_experiences, - format_experience_status as _format_shell_experience_status, - growth_panel_lines as _shell_growth_panel_lines, - growth_progress_bar as _shell_growth_progress_bar, - growth_progress_counts as _shell_growth_progress_counts, - recent_activity_lines as _shell_recent_activity_lines, - recent_experience_lines as _shell_recent_experience_lines, - render_brand_column as _render_shell_brand_column, - render_chat_entry as _render_shell_chat_entry, - render_entry as _render_shell_entry, - render_elephant_brand_mark as _render_shell_elephant_mark, - render_growth_mark_for_stage as _render_shell_growth_mark, - render_pending_entries as _render_shell_pending_entries, - render_shell_frame as _render_shell_frame_view, - render_status_column as _render_shell_status_column, - should_display_experience as _should_display_shell_experience, - styled_growth_progress_bar as _styled_shell_growth_progress_bar, -) from .shell_stack import ( Align, - Completion, - Completer, Console, Document, - FormattedText, Group, - PROMPT_TOOLKIT_AVAILABLE, Panel, RICH_AVAILABLE, Table, @@ -110,7 +32,6 @@ BRAND_MUTED, COMMAND_PALETTE_VISIBLE_ROWS, ELEPHANT_STAGE_ROWS, - ELEPHANT_STAGE_ROWS, GROWTH_HIGHLIGHT_FG, GROWTH_PROGRESS_EMPTY, GROWTH_PROGRESS_FILLED, @@ -123,13 +44,10 @@ SHELL_WELCOME_HEADLINE, USER_HISTORY_BG, USER_HISTORY_FG, - WEB_URL_PATTERN, - compact_line as _compact_line, centered_elephant_rows as _centered_elephant_rows, display_path as _display_path, display_width as _display_width, render_elephant_mark, - render_stage_zero_elephant_mark, resolve_elephant_version as _resolve_elephant_version, ) @@ -173,7 +91,6 @@ ] - from .shell_support_runtime import * # noqa: F401,F403 from . import shell_methods_commands as _shell_commands from . import shell_methods_dispatch as _shell_dispatch @@ -184,6 +101,7 @@ from . import turn_metrics as _shell_turn_metrics from . import shell_methods_ui as _shell_ui_methods + def _latest_completed_learning_result_key(runtime: CliRuntime, *, session_id: str) -> str: try: status = runtime.learning_runtime_status(session_id=session_id, limit=8) @@ -205,18 +123,31 @@ class ProductizedShell: ShellCommandSpec("/help", "Open the command palette and interaction hints"), ShellCommandSpec("/status", "Refresh elephant, provider, and Personal Model posture"), ShellCommandSpec("/recall", "Inspect Step/Episode recall evidence for this elephant"), - ShellCommandSpec("/tools", "Inspect, install, toggle, and run built-in or manifest-backed tools"), - ShellCommandSpec("/skills", "Discover, inspect, install, and toggle built-in or external skill packages"), + ShellCommandSpec( + "/tools", + "Inspect, install, toggle, and run built-in or manifest-backed tools", + ), + ShellCommandSpec( + "/skills", + "Discover, inspect, install, and toggle built-in or external skill packages", + ), ShellCommandSpec("/learn", "Queue or run background learning for this episode"), ShellCommandSpec("/gateway", "Inspect gateway posture and open the CLI gateway setup path"), - ShellCommandSpec("/cron", "Inspect, create, pause, resume, and remove built-in scheduled jobs"), - ShellCommandSpec("/providers", "Set or switch the active provider, endpoint, key, and embedding path"), + ShellCommandSpec( + "/cron", + "Inspect, create, pause, resume, and remove built-in scheduled jobs", + ), + ShellCommandSpec( + "/providers", + "Set or switch the active provider, endpoint, key, and embedding path", + ), ShellCommandSpec("/models", "Set or switch the active model and context window"), ShellCommandSpec("/expand", "Reprint the last folded entry in full"), ShellCommandSpec("/clear", "Start a fresh Loop in this elephant and replay the opening reply"), ShellCommandSpec("/exit", "Leave the wake surface"), ) + def __init__(self, runtime: CliRuntime, *, session_id: str, opened: str, debug: bool = False) -> None: self.runtime = runtime self.session_id = session_id @@ -280,6 +211,7 @@ def __init__(self, runtime: CliRuntime, *, session_id: str, opened: str, debug: self._skill_slash_specs = self._load_skill_slash_specs() self._last_learning_notice_id = _latest_completed_learning_result_key(self.runtime, session_id=self.session_id) + def run(self) -> int: if self._use_alternate_screen: # Opt-in fullscreen mode: use the alternate screen buffer like vim/less. @@ -368,6 +300,7 @@ def run(self) -> int: learning_detail = "background learning queued" try: from packages.kernel.episode_state_machine import close_episode + closed = close_episode( self.runtime.repository, self.session_id, @@ -377,7 +310,7 @@ def run(self) -> int: ) # close_episode only enqueues; an explicit worker start is needed to consume the job self.runtime._ensure_learning_worker_if_needed() - learning_detail = f"episode closed · learning queued" + learning_detail = "episode closed · learning queued" except Exception: pass self._append_entry( @@ -400,6 +333,7 @@ def run(self) -> int: sys.stdout.flush() return 0 + def _append_due_cron_jobs(self) -> None: assistant_name = self._assistant_name() for execution in self.runtime.run_due_cron_jobs(session_id=self.session_id): @@ -410,12 +344,15 @@ def _append_due_cron_jobs(self) -> None: meta=f"cron · {execution.job.name}", ) + def _interactive_clarify_surface(self) -> ShellInteractiveClarifySurface: return ShellInteractiveClarifySurface(self) + def _render_startup_sequence(self) -> None: return None + def _render_boot_frame(self): continuity = self.runtime.inspect_continuity(session_id=self.session_id) growth = self.runtime.inspect_growth(session_id=self.session_id) @@ -442,6 +379,7 @@ def _render_boot_frame(self): brand_mark=self._render_growth_mark(stage_id, level=growth.level), ) + def _refresh_shell_frame(self) -> None: current = self._current_shell_frame_token() if not self._use_alternate_screen and self._last_shell_frame_token == current: @@ -456,6 +394,7 @@ def _refresh_shell_frame(self) -> None: self.console.print(self._render_shell_frame()) self._last_shell_frame_token = current + def _refresh_shell_frame_if_needed(self) -> bool: current = self._current_shell_frame_token() if current == self._last_shell_frame_token: @@ -463,6 +402,7 @@ def _refresh_shell_frame_if_needed(self) -> bool: self._refresh_shell_frame() return True + def _print_transition_footer(self) -> None: """No-op. Previously printed a static divider + elephant emoji between turns, but prompt_toolkit's erase_when_done never cleans it up — leaving orphan @@ -470,6 +410,7 @@ def _print_transition_footer(self) -> None: after _render_pending_entries(), so no gap-filling is needed. """ + def _current_shell_frame_token(self) -> tuple[object, ...]: return ( self.session_id, @@ -501,6 +442,7 @@ def _pending_context_compaction_frame_token(self) -> tuple[object, ...]: tuple(stage_tokens), ) + def _append_providers(self, args: list[str]) -> None: action = args[0] if args else "configure" if action == "embeddings": @@ -562,9 +504,7 @@ def _append_providers(self, args: list[str]) -> None: current_provider_id = str(provider.get("provider_id") or "openai-compatible") initial_state = provider_setup_defaults(self.runtime, current_provider_id) initial_state.base_url = str(provider.get("base_url") or initial_state.base_url) - initial_state.model_id = str( - provider.get("model_id") or provider.get("default_model") or initial_state.model_id - ) + initial_state.model_id = str(provider.get("model_id") or provider.get("default_model") or initial_state.model_id) initial_state.reasoning_effort = ( str(provider.get("reasoning_effort")).strip() if provider.get("reasoning_effort") is not None @@ -606,6 +546,7 @@ def _append_providers(self, args: list[str]) -> None: ), ) + def _append_provider_embeddings(self, args: list[str]) -> None: action = args[0] if args else "status" if action == "status": @@ -699,6 +640,7 @@ def _append_provider_embeddings(self, args: list[str]) -> None: ), ) + def _show_growth_celebration_if_needed(self): update = self.runtime.consume_growth_update(session_id=self.session_id) if update is None: @@ -713,6 +655,7 @@ def _show_growth_celebration_if_needed(self): # Application owns the terminal. return transition + def _render_level_up_frame(self, update, *, tick: int): marker = ("*", "+", "*", "+")[tick % 4] body = Text() @@ -729,6 +672,7 @@ def _render_level_up_frame(self, update, *, tick: int): padding=(0, 1), ) + def _render_stage_transition_frame(self, update, *, tick: int): if Table is None or Group is None: return Text( @@ -760,6 +704,7 @@ def _render_stage_transition_frame(self, update, *, tick: int): padding=(0, 1), ) + ProductizedShell.__init__ = __init__ ProductizedShell.run = run ProductizedShell._append_due_cron_jobs = _append_due_cron_jobs @@ -817,7 +762,9 @@ def _render_stage_transition_frame(self, update, *, tick: int): ProductizedShell._startup_state_focus_dispatch_ready = _shell_ui_methods._startup_state_focus_dispatch_ready ProductizedShell._startup_should_hold_user_command = _shell_ui_methods._startup_should_hold_user_command ProductizedShell._mark_startup_user_turn_submitted = _shell_ui_methods._mark_startup_user_turn_submitted -ProductizedShell._startup_should_surface_state_focus_notices = _shell_ui_methods._startup_should_surface_state_focus_notices +ProductizedShell._startup_should_surface_state_focus_notices = ( + _shell_ui_methods._startup_should_surface_state_focus_notices +) ProductizedShell._set_state_focus_runtime_notice = _shell_ui_methods._set_state_focus_runtime_notice ProductizedShell._clear_state_focus_runtime_notice = _shell_ui_methods._clear_state_focus_runtime_notice ProductizedShell._sync_state_focus_runtime_notices = _shell_ui_methods._sync_state_focus_runtime_notices diff --git a/apps/cli/shell_methods_commands.py b/apps/cli/shell_methods_commands.py index d4a10d0..42ae594 100644 --- a/apps/cli/shell_methods_commands.py +++ b/apps/cli/shell_methods_commands.py @@ -2,108 +2,21 @@ from __future__ import annotations -from collections import deque -from dataclasses import dataclass -from datetime import datetime, timezone -from difflib import unified_diff -import os -from pathlib import Path -import re -import shlex import subprocess import sys -import time -from packages.contracts import ExperienceRecord from packages.kernel.runtime import KernelOutcome from packages.operator.runtime import ( RecallEvidenceOperatorDetail, - RecallEvidenceSearchHit, build_recall_evidence_operator_surface, - build_profile_operator_surface, render_recall_evidence_lines, - render_profile_lines, ) from packages.state import parse_user_profile_text -from packages.tools.handler_support import resolve_allowed_path -from .provider_flow import provider_setup_defaults, run_provider_selection_wizard -from .runtime import CliRuntime from .shell_progress_support import outcome_state_focus_meta -from .wizard import WIZARD_BACK -from .shell_composer import ( - build_command_palette as _build_shell_command_palette, - build_composer_body as _build_shell_composer_body, - build_divider_window as _build_shell_divider_window, - build_input_window as _build_shell_input_window, - build_key_bindings as _build_shell_key_bindings, - build_prompt_buffer as _build_shell_prompt_buffer, - build_queue_preview_window as _build_shell_queue_preview_window, - prompt_continuation as _shell_prompt_continuation, - prompt_label as _shell_prompt_label, - prompt_style as _shell_prompt_style, - prompt_style_map as _shell_prompt_style_map, - prompt_toolkit_composer_available as _shell_prompt_toolkit_composer_available, - read_command as _read_shell_command, - shell_history as _shell_history, -) -from .shell_boot import WAKE_DISPLAY_SECONDS, BootFrameContext, render_boot_frame -from .shell_opening import ( - ShellOpeningContext, - compose_shell_opening_instruction, - compose_shell_opener, -) -from .shell_progress import ( - animations_enabled as _shell_animations_enabled, - render_queued_followup_fragments as _render_shell_queued_followup_fragments, - render_tool_frame as _render_shell_tool_frame, - tool_trace_line as _shell_tool_trace_line, - render_turn_frame as _render_shell_turn_frame, - render_turn_progress_fragments as _render_shell_turn_progress_fragments, - run_tool_with_progress as _run_shell_tool_with_progress, - run_turn_with_progress as _run_shell_turn_with_progress, - run_turn_with_queued_input as _run_shell_turn_with_queued_input, - summarize_progress_prompt as _summarize_shell_progress_prompt, - tool_event_lines as _shell_tool_event_lines, - tool_event_summary as _shell_tool_event_summary, - tool_event_tracker as _shell_tool_event_tracker, - tool_frame_phases as _shell_tool_frame_phases, - turn_phase as _shell_turn_phase, - _tool_trace_emoji as _shell_tool_trace_emoji, -) -from .shell_render import ( - center_brand_block as _center_shell_brand_block, - displayable_experiences as _displayable_shell_experiences, - format_experience_status as _format_shell_experience_status, - growth_panel_lines as _shell_growth_panel_lines, - growth_progress_bar as _shell_growth_progress_bar, - growth_progress_counts as _shell_growth_progress_counts, - recent_activity_lines as _shell_recent_activity_lines, - recent_experience_lines as _shell_recent_experience_lines, - render_brand_column as _render_shell_brand_column, - render_chat_entry as _render_shell_chat_entry, - render_entry as _render_shell_entry, - render_elephant_brand_mark as _render_shell_elephant_mark, - render_growth_mark_for_stage as _render_shell_growth_mark, - render_pending_entries as _render_shell_pending_entries, - render_shell_frame as _render_shell_frame_view, - render_status_column as _render_shell_status_column, - should_display_experience as _should_display_shell_experience, - styled_growth_progress_bar as _styled_shell_growth_progress_bar, -) from .shell_stack import ( - Align, - Completion, - Completer, Console, Document, - FormattedText, - Group, - Live, - PROMPT_TOOLKIT_AVAILABLE, - Panel, RICH_AVAILABLE, - Table, - Text, ) from .shell_ui import ( BRAND_ACCENT, @@ -119,7 +32,6 @@ GROWTH_PROGRESS_WIDTH, HATCHLING_HEAD_ROWS, HATCHLING_STAGE_ROWS, - HATCHLING_STAGE_ROWS, QUEUE_PREVIEW_INSET, SCOUT_STAGE_ROWS, SEED_STAGE_ROWS, @@ -127,12 +39,9 @@ USER_HISTORY_BG, USER_HISTORY_FG, WEB_URL_PATTERN, - compact_line as _compact_line, centered_elephant_rows as _centered_elephant_rows, - display_path as _display_path, display_width as _display_width, render_elephant_mark, - resolve_elephant_version as _resolve_elephant_version, ) __all__ = [ @@ -169,9 +78,9 @@ ] - from .shell_support_runtime import * # noqa: F401,F403 + def _append_help(self) -> None: lines = [ "Stay in the conversation. Slash commands exist only for orientation and control.", @@ -202,7 +111,10 @@ def _append_help(self) -> None: def _append_personal_model(self, args: list[str]) -> None: action = (args[0] if args else "summary").strip().lower() session = self.runtime.inspect_session(self.session_id) - state = self.runtime.state_for_elephant(self.runtime.elephant_id_for_session(session)) or self.runtime.current_elephant_state() + state = ( + self.runtime.state_for_elephant(self.runtime.elephant_id_for_session(session)) + or self.runtime.current_elephant_state() + ) if state is None: self._append_entry("recovery", "About you", "No active elephant is available yet.") return @@ -252,12 +164,24 @@ def _append_personal_model(self, args: list[str]) -> None: self._append_entry("notice", "About you to confirm", "\n".join(rows)) return if action in {"procedural", "skills"}: - self._append_entry("notice", "How I help", "Procedural patterns are tracked outside Personal Model support rows.") + self._append_entry( + "notice", + "How I help", + "Procedural patterns are tracked outside Personal Model support rows.", + ) return if action in {"learned", "diff", "recent"}: - self._append_entry("notice", "What I learned recently", "Use Personal Model facts and History for recent learning provenance.") + self._append_entry( + "notice", + "What I learned recently", + "Use Personal Model facts and History for recent learning provenance.", + ) return - self._append_entry("recovery", "About you", "Usage: [summary|evidence|uncertain|procedural|learned]") + self._append_entry( + "recovery", + "About you", + "Usage: [summary|evidence|uncertain|procedural|learned]", + ) def _append_gateway(self, args: list[str]) -> None: @@ -306,10 +230,10 @@ def _append_tools(self, args: list[str]) -> None: 'run search: /tools run tool.file.search query="elephant"', 'run web: /tools run tool.web.search query="agentic intelligence"', 'read page: /tools run tool.web.read url="https://example.com"', - 'run todos: /tools run tool.todo.manage action=list', + "run todos: /tools run tool.todo.manage action=list", 'search understanding: /tools run tool.personal_model.search query="review style"', 'update understanding: /tools run tool.personal_model.update action=remember lens=identity topic=identity.style.review.feedback text="prefers direct review" reason="user said this preference"', - 'manage cron: /tools run tool.cron.manage action=list', + "manage cron: /tools run tool.cron.manage action=list", ] ) self._append_entry("notice", "Tools", "\n".join(lines)) @@ -397,6 +321,7 @@ def _append_tools(self, args: list[str]) -> None: return self._append_entry("recovery", "Tools", "Usage: /tools [inspect|enable|disable|install|run]") + def _append_learn(self, args: list[str]) -> None: wait_for_worker = "--wait" in args filtered_args = [item for item in args if item != "--wait"] @@ -410,7 +335,9 @@ def _append_learn(self, args: list[str]) -> None: ] for job in jobs: if isinstance(job, dict): - lines.append(f"- {job.get('status', '')} {job.get('job_type', '')} {job.get('trigger', '')} {job.get('job_id', '')}") + lines.append( + f"- {job.get('status', '')} {job.get('job_type', '')} {job.get('trigger', '')} {job.get('job_id', '')}" + ) if len(lines) == 2: lines.append("") self._append_entry("notice", "Learning", "\n".join(lines)) @@ -419,11 +346,18 @@ def _append_learn(self, args: list[str]) -> None: try: from apps.learning_worker_runtime import stop_learning_worker - result = stop_learning_worker(state_dir=self.runtime.paths.state_dir, reason="operator requested /learn kill") + result = stop_learning_worker( + state_dir=self.runtime.paths.state_dir, + reason="operator requested /learn kill", + ) except Exception as error: self._append_entry("recovery", "Learning", str(error)) return - self._append_entry("notice", "Learning", f"worker stopped · pid={result.get('stopped_pid') or ''}") + self._append_entry( + "notice", + "Learning", + f"worker stopped · pid={result.get('stopped_pid') or ''}", + ) return if command not in {"queue", "run", "start"}: self._append_entry("recovery", "Learning", "Usage: /learn [list|run [--wait]|kill]") @@ -442,7 +376,14 @@ def _append_learn(self, args: list[str]) -> None: detail = f"queued · {job.job_id} · background worker requested" if wait_for_worker: completed = subprocess.run( - (sys.executable, "-m", "apps.learning_worker_command", "--state-dir", str(self.runtime.paths.state_dir), "--once"), + ( + sys.executable, + "-m", + "apps.learning_worker_command", + "--state-dir", + str(self.runtime.paths.state_dir), + "--once", + ), check=False, ) exit_code = int(completed.returncode or 0) @@ -571,7 +512,11 @@ def _append_skills(self, args: list[str]) -> None: return if command == "install": if len(args) < 2: - self._append_entry("recovery", "Skills", "Usage: /skills install ") + self._append_entry( + "recovery", + "Skills", + "Usage: /skills install ", + ) return try: result = self.runtime.install_skill_source(args[1], session_id=self.session_id) @@ -591,7 +536,11 @@ def _append_skills(self, args: list[str]) -> None: ) self._refresh_skill_slash_specs() return - self._append_entry("recovery", "Skills", "Usage: /skills [list|active|search|view|enable|disable|install]") + self._append_entry( + "recovery", + "Skills", + "Usage: /skills [list|active|search|view|enable|disable|install]", + ) def _display_skill_reference(entry) -> str: @@ -599,13 +548,13 @@ def _display_skill_reference(entry) -> str: return str(getattr(entry, "skill_id", "")).strip() or str(getattr(entry, "reference", "")) return str(getattr(entry, "reference", "")).strip() + def _append_cron(self, args: list[str]) -> None: command = args[0] if args else "list" if command in {"list", "ls"}: jobs = self.runtime.cron_jobs(session_id=self.session_id) lines = [ - f"{job.job_id} | {job.status} | {job.name} | {job.schedule_text} | {job.action_kind}" - for job in jobs + f"{job.job_id} | {job.status} | {job.name} | {job.schedule_text} | {job.action_kind}" for job in jobs ] or [""] lines.extend( [ @@ -712,6 +661,7 @@ def _append_cron(self, args: list[str]) -> None: return self._append_entry("recovery", "Cron jobs", "Usage: /cron [create|inspect|pause|resume|remove]") + def _parse_named_arguments(self, args: list[str]) -> dict[str, str]: payload: dict[str, str] = {} for item in args: @@ -724,6 +674,7 @@ def _parse_named_arguments(self, args: list[str]) -> dict[str, str]: payload[key] = self._strip_wrapping_quotes(value.strip()) return payload + def _requested_webpage_url(self, message: str) -> str | None: lowered = message.strip().lower() match = WEB_URL_PATTERN.search(message) @@ -747,11 +698,13 @@ def _requested_webpage_url(self, message: str) -> str | None: return None return match.group(1).rstrip(").,!?\"'") + def _strip_wrapping_quotes(self, value: str) -> str: if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}: return value[1:-1].strip() return value + def _append_status(self) -> None: session = self.runtime.inspect_session(self.session_id) provider = dict(self.runtime.provider_summary()) @@ -841,6 +794,7 @@ def _append_status(self) -> None: lines.append("next: keep talking") self._append_entry("status", "Elephant status", "\n".join(lines)) + def _append_recall(self, args: list[str]) -> None: action = args[0] if args else "inspect" if action in {"inspect", "show", "list", "ls"} and len(args) <= 1: @@ -871,7 +825,11 @@ def _append_recall(self, args: list[str]) -> None: evidence_items=(detail,), index_policy=self.runtime.recall_runtime.index_policy(), ) - self._append_entry("notice", "Understanding detail", "\n".join(render_recall_evidence_lines(surface))) + self._append_entry( + "notice", + "Understanding detail", + "\n".join(render_recall_evidence_lines(surface)), + ) return if action == "search": query = " ".join(args[1:]).strip() @@ -879,7 +837,11 @@ def _append_recall(self, args: list[str]) -> None: self._append_entry("recovery", "Understanding", "Usage: /recall search ") return surface = self.runtime.search_recall_evidence_surface(self.session_id, query=query) - self._append_entry("notice", "Understanding search", "\n".join(render_recall_evidence_lines(surface))) + self._append_entry( + "notice", + "Understanding search", + "\n".join(render_recall_evidence_lines(surface)), + ) return if action == "lineage": if len(args) < 2: @@ -900,6 +862,7 @@ def _append_recall(self, args: list[str]) -> None: return self._append_entry("recovery", "Understanding", "Usage: /recall [list|inspect|search|lineage]") + def _append_outcome(self, outcome: KernelOutcome) -> None: self._last_prompt_tokens = outcome.execution.prompt_tokens self._last_completion_tokens = outcome.execution.completion_tokens @@ -910,7 +873,9 @@ def _append_outcome(self, outcome: KernelOutcome) -> None: for stage in outcome.stages ] self._append_entry("status", "Runtime stages", "\n".join(stage_lines)) - assistant_name = self.runtime.inspect_profile(self.runtime.inspect_session(self.session_id).personal_model_id).state.display_name + assistant_name = self.runtime.inspect_profile( + self.runtime.inspect_session(self.session_id).personal_model_id + ).state.display_name assistant_lines = [outcome.execution.summary] if self.debug and outcome.plan is not None: assistant_lines.append(f"plan: {outcome.plan.rationale}") @@ -922,7 +887,13 @@ def _append_outcome(self, outcome: KernelOutcome) -> None: f"recall_hits: {len(outcome.recall_items)}", ] ) - self._append_entry("assistant", assistant_name, "\n".join(assistant_lines), meta=outcome_state_focus_meta(outcome)) + self._append_entry( + "assistant", + assistant_name, + "\n".join(assistant_lines), + meta=outcome_state_focus_meta(outcome), + ) + def _append_growth_update_message(self, update) -> None: if update is None: @@ -932,12 +903,13 @@ def _append_growth_update_message(self, update) -> None: after = update.after after_checkpoint = getattr(after, "level", 0) after_growth = getattr(after, "cycle_label", "Evidence I") - after_identity = getattr(after, "identity_line", getattr(getattr(after, "stage", None), "title", "Elephant Agent")) + after_identity = getattr( + after, + "identity_line", + getattr(getattr(after, "stage", None), "title", "Elephant Agent"), + ) if update.stage_changed: - body = ( - f"The path is clearer now: {after_identity}. " - "I'll carry this understanding forward." - ) + body = f"The path is clearer now: {after_identity}. I'll carry this understanding forward." meta = "understanding · clearer path" else: body = ( diff --git a/apps/cli/shell_methods_dispatch.py b/apps/cli/shell_methods_dispatch.py index 2d1fc95..4ddb921 100644 --- a/apps/cli/shell_methods_dispatch.py +++ b/apps/cli/shell_methods_dispatch.py @@ -2,104 +2,15 @@ from __future__ import annotations -from collections import deque -from dataclasses import dataclass -from difflib import unified_diff -import os -from pathlib import Path import re import shlex import threading -import time -from packages.contracts import ExperienceRecord from packages.kernel.runtime import KernelOutcome -from packages.operator.runtime import ( - RecallEvidenceOperatorDetail, - RecallEvidenceSearchHit, - build_recall_evidence_operator_surface, - build_profile_operator_surface, - render_recall_evidence_lines, - render_profile_lines, -) -from packages.tools.handler_support import resolve_allowed_path -from .provider_flow import provider_setup_defaults, run_provider_selection_wizard -from .runtime import CliRuntime -from .wizard import WIZARD_BACK -from .shell_composer import ( - build_command_palette as _build_shell_command_palette, - build_composer_body as _build_shell_composer_body, - build_divider_window as _build_shell_divider_window, - build_input_window as _build_shell_input_window, - build_key_bindings as _build_shell_key_bindings, - build_prompt_buffer as _build_shell_prompt_buffer, - build_queue_preview_window as _build_shell_queue_preview_window, - prompt_continuation as _shell_prompt_continuation, - prompt_label as _shell_prompt_label, - prompt_style as _shell_prompt_style, - prompt_style_map as _shell_prompt_style_map, - prompt_toolkit_composer_available as _shell_prompt_toolkit_composer_available, - read_command as _read_shell_command, - shell_history as _shell_history, -) -from .shell_boot import WAKE_DISPLAY_SECONDS, BootFrameContext, render_boot_frame -from .shell_opening import ( - ShellOpeningContext, - compose_shell_opening_instruction, - compose_shell_opener, -) -from .shell_progress import ( - animations_enabled as _shell_animations_enabled, - render_queued_followup_fragments as _render_shell_queued_followup_fragments, - render_tool_frame as _render_shell_tool_frame, - tool_trace_line as _shell_tool_trace_line, - render_turn_frame as _render_shell_turn_frame, - render_turn_progress_fragments as _render_shell_turn_progress_fragments, - run_tool_with_progress as _run_shell_tool_with_progress, - run_turn_with_progress as _run_shell_turn_with_progress, - run_turn_with_queued_input as _run_shell_turn_with_queued_input, - summarize_progress_prompt as _summarize_shell_progress_prompt, - tool_event_lines as _shell_tool_event_lines, - tool_event_summary as _shell_tool_event_summary, - tool_event_tracker as _shell_tool_event_tracker, - tool_frame_phases as _shell_tool_frame_phases, - turn_phase as _shell_turn_phase, - _tool_trace_emoji as _shell_tool_trace_emoji, -) -from .shell_render import ( - center_brand_block as _center_shell_brand_block, - displayable_experiences as _displayable_shell_experiences, - format_experience_status as _format_shell_experience_status, - growth_panel_lines as _shell_growth_panel_lines, - growth_progress_bar as _shell_growth_progress_bar, - growth_progress_counts as _shell_growth_progress_counts, - recent_activity_lines as _shell_recent_activity_lines, - recent_experience_lines as _shell_recent_experience_lines, - render_brand_column as _render_shell_brand_column, - render_chat_entry as _render_shell_chat_entry, - render_entry as _render_shell_entry, - render_elephant_brand_mark as _render_shell_elephant_mark, - render_growth_mark_for_stage as _render_shell_growth_mark, - render_pending_entries as _render_shell_pending_entries, - render_shell_frame as _render_shell_frame_view, - render_status_column as _render_shell_status_column, - should_display_experience as _should_display_shell_experience, - styled_growth_progress_bar as _styled_shell_growth_progress_bar, -) from .shell_stack import ( - Align, - Completion, - Completer, Console, Document, - FormattedText, - Group, - Live, - PROMPT_TOOLKIT_AVAILABLE, - Panel, RICH_AVAILABLE, - Table, - Text, ) from .shell_ui import ( BRAND_ACCENT, @@ -115,20 +26,15 @@ GROWTH_PROGRESS_WIDTH, HATCHLING_HEAD_ROWS, HATCHLING_STAGE_ROWS, - HATCHLING_STAGE_ROWS, QUEUE_PREVIEW_INSET, SCOUT_STAGE_ROWS, SEED_STAGE_ROWS, SHELL_WELCOME_HEADLINE, USER_HISTORY_BG, USER_HISTORY_FG, - WEB_URL_PATTERN, - compact_line as _compact_line, centered_elephant_rows as _centered_elephant_rows, - display_path as _display_path, display_width as _display_width, render_elephant_mark, - resolve_elephant_version as _resolve_elephant_version, ) __all__ = [ @@ -165,28 +71,28 @@ ] - from .shell_support_runtime import * # noqa: F401,F403 + def _safe_usage_token_count(value: object) -> int: try: return max(0, int(value or 0)) except (TypeError, ValueError): return 0 + def _execution_prompt_usage_tokens(execution: object) -> int: prompt_tokens = _safe_usage_token_count(getattr(execution, "prompt_tokens", 0)) total_tokens = _safe_usage_token_count(getattr(execution, "total_tokens", 0)) return prompt_tokens or total_tokens + def _outcome_has_context_compaction(outcome: KernelOutcome) -> bool: stages = getattr(outcome, "stages", ()) if not isinstance(stages, tuple | list): return False - return any( - str(getattr(stage, "stage", "") or "") == "context-compact" - for stage in stages - ) + return any(str(getattr(stage, "stage", "") or "") == "context-compact" for stage in stages) + def _outcome_context_compaction_after_tokens(outcome: KernelOutcome) -> int | None: stages = getattr(outcome, "stages", ()) @@ -201,6 +107,7 @@ def _outcome_context_compaction_after_tokens(outcome: KernelOutcome) -> int | No return int(match.group(1)) return None + def _dispatch(self, raw_command: str | PendingShellCommand) -> bool: pending = coerce_pending_shell_command(raw_command) command = pending.command.strip() @@ -275,6 +182,7 @@ def _dispatch(self, raw_command: str | PendingShellCommand) -> bool: self._schedule_post_turn_background() return False + def _schedule_post_turn_background(self) -> None: """Run growth celebration and learning-result checks off the main thread. @@ -301,6 +209,7 @@ def _post_turn_work() -> None: daemon=True, ).start() + def _handle_conversational_surface_request(self, message: str) -> bool: normalized = message.strip().lower().rstrip("?.!") if normalized in { @@ -316,10 +225,7 @@ def _handle_conversational_surface_request(self, message: str) -> bool: ) lines = [ "I can use these tools right now:", - *[ - f"- {tool.display_name} ({tool.tool_id}): {tool.description}" - for tool in tools - ], + *[f"- {tool.display_name} ({tool.tool_id}): {tool.description}" for tool in tools], "", "Ask me naturally if you want one used, or give me a manifest path if you want me to install an external tool.", ] @@ -409,6 +315,7 @@ def _handle_conversational_surface_request(self, message: str) -> bool: return True return False + def _handle_slash_command(self, raw_command: str) -> bool: try: parts = self._parse_slash_command(raw_command) @@ -510,6 +417,7 @@ def _handle_slash_command(self, raw_command: str) -> bool: self._append_entry("command", "Unknown command", f"{command}\nhelp: /help") return False + def _parse_slash_command(self, raw_command: str) -> list[str]: try: return shlex.split(raw_command) @@ -519,5 +427,6 @@ def _parse_slash_command(self, raw_command: str) -> list[str]: return fallback raise + def _text_surface_fallback_parts(self, raw_command: str) -> list[str] | None: return None diff --git a/apps/cli/shell_methods_models.py b/apps/cli/shell_methods_models.py index 8a645fe..02e606c 100644 --- a/apps/cli/shell_methods_models.py +++ b/apps/cli/shell_methods_models.py @@ -55,9 +55,7 @@ def _append_models(self, args: list[str]) -> None: profile = self.runtime.inspect_profile(session.personal_model_id) initial_state = provider_setup_defaults(self.runtime, provider_id) initial_state.base_url = str(provider.get("base_url") or initial_state.base_url) - initial_state.model_id = str( - provider.get("model_id") or provider.get("default_model") or initial_state.model_id - ) + initial_state.model_id = str(provider.get("model_id") or provider.get("default_model") or initial_state.model_id) initial_state.reasoning_effort = ( str(provider.get("reasoning_effort")).strip() if provider.get("reasoning_effort") is not None diff --git a/apps/cli/shell_methods_skills.py b/apps/cli/shell_methods_skills.py index 66eff51..b256d24 100644 --- a/apps/cli/shell_methods_skills.py +++ b/apps/cli/shell_methods_skills.py @@ -2,104 +2,12 @@ from __future__ import annotations -from collections import deque -from dataclasses import dataclass -from difflib import unified_diff -import os -from pathlib import Path -import re -import shlex -import time - -from packages.contracts import ExperienceRecord -from packages.kernel.runtime import KernelOutcome -from packages.operator.runtime import ( - RecallEvidenceOperatorDetail, - RecallEvidenceSearchHit, - build_recall_evidence_operator_surface, - build_profile_operator_surface, - render_recall_evidence_lines, - render_profile_lines, -) + from packages.skills import skill_provenance_fields -from packages.tools.handler_support import resolve_allowed_path -from .provider_flow import provider_setup_defaults, run_provider_selection_wizard -from .runtime import CliRuntime -from .wizard import WIZARD_BACK -from .shell_composer import ( - build_command_palette as _build_shell_command_palette, - build_composer_body as _build_shell_composer_body, - build_divider_window as _build_shell_divider_window, - build_input_window as _build_shell_input_window, - build_key_bindings as _build_shell_key_bindings, - build_prompt_buffer as _build_shell_prompt_buffer, - build_queue_preview_window as _build_shell_queue_preview_window, - prompt_continuation as _shell_prompt_continuation, - prompt_label as _shell_prompt_label, - prompt_style as _shell_prompt_style, - prompt_style_map as _shell_prompt_style_map, - prompt_toolkit_composer_available as _shell_prompt_toolkit_composer_available, - read_command as _read_shell_command, - shell_history as _shell_history, -) -from .shell_boot import WAKE_DISPLAY_SECONDS, BootFrameContext, render_boot_frame -from .shell_opening import ( - ShellOpeningContext, - compose_shell_opening_instruction, - compose_shell_opener, -) -from .shell_progress import ( - animations_enabled as _shell_animations_enabled, - render_queued_followup_fragments as _render_shell_queued_followup_fragments, - render_tool_frame as _render_shell_tool_frame, - tool_trace_line as _shell_tool_trace_line, - render_turn_frame as _render_shell_turn_frame, - render_turn_progress_fragments as _render_shell_turn_progress_fragments, - run_tool_with_progress as _run_shell_tool_with_progress, - run_turn_with_progress as _run_shell_turn_with_progress, - run_turn_with_queued_input as _run_shell_turn_with_queued_input, - summarize_progress_prompt as _summarize_shell_progress_prompt, - tool_event_lines as _shell_tool_event_lines, - tool_event_summary as _shell_tool_event_summary, - tool_event_tracker as _shell_tool_event_tracker, - tool_frame_phases as _shell_tool_frame_phases, - turn_phase as _shell_turn_phase, - _tool_trace_emoji as _shell_tool_trace_emoji, -) -from .shell_render import ( - center_brand_block as _center_shell_brand_block, - displayable_experiences as _displayable_shell_experiences, - format_experience_status as _format_shell_experience_status, - growth_panel_lines as _shell_growth_panel_lines, - growth_progress_bar as _shell_growth_progress_bar, - growth_progress_counts as _shell_growth_progress_counts, - recent_activity_lines as _shell_recent_activity_lines, - recent_experience_lines as _shell_recent_experience_lines, - render_brand_column as _render_shell_brand_column, - render_chat_entry as _render_shell_chat_entry, - render_entry as _render_shell_entry, - render_elephant_brand_mark as _render_shell_elephant_mark, - render_growth_mark_for_stage as _render_shell_growth_mark, - render_pending_entries as _render_shell_pending_entries, - render_shell_frame as _render_shell_frame_view, - render_status_column as _render_shell_status_column, - should_display_experience as _should_display_shell_experience, - styled_growth_progress_bar as _styled_shell_growth_progress_bar, -) from .shell_stack import ( - Align, - Completion, - Completer, Console, Document, - FormattedText, - Group, - Live, - PROMPT_TOOLKIT_AVAILABLE, - Panel, RICH_AVAILABLE, - Table, - Text, ) from .shell_ui import ( BRAND_ACCENT, @@ -115,20 +23,15 @@ GROWTH_PROGRESS_WIDTH, HATCHLING_HEAD_ROWS, HATCHLING_STAGE_ROWS, - HATCHLING_STAGE_ROWS, QUEUE_PREVIEW_INSET, SCOUT_STAGE_ROWS, SEED_STAGE_ROWS, SHELL_WELCOME_HEADLINE, USER_HISTORY_BG, USER_HISTORY_FG, - WEB_URL_PATTERN, - compact_line as _compact_line, centered_elephant_rows as _centered_elephant_rows, - display_path as _display_path, display_width as _display_width, render_elephant_mark, - resolve_elephant_version as _resolve_elephant_version, ) __all__ = [ @@ -165,21 +68,25 @@ ] - from .shell_support_runtime import * # noqa: F401,F403 + def recent_session_ids(self) -> tuple[str, ...]: return tuple(session.episode_id for session in self.runtime.recent_sessions(limit=8)) + def recent_elephant_ids(self) -> tuple[str, ...]: return tuple(elephant.elephant_id for elephant in self.runtime.list_herd(limit=8)) + def skill_slash_specs(self) -> tuple[SkillSlashSpec, ...]: return self._skill_slash_specs + def _refresh_skill_slash_specs(self) -> None: self._skill_slash_specs = self._load_skill_slash_specs() + def _load_skill_slash_specs(self) -> tuple[SkillSlashSpec, ...]: reserved = {spec.name.removeprefix("/").lower() for spec in self.command_specs} specs: list[SkillSlashSpec] = [] @@ -202,6 +109,7 @@ def _load_skill_slash_specs(self) -> tuple[SkillSlashSpec, ...]: seen.add(slug) return tuple(specs) + def _resolve_skill_slash_spec(self, command: str) -> SkillSlashSpec | None: normalized = command.strip().lower().replace("_", "-") for spec in self._skill_slash_specs: @@ -209,6 +117,7 @@ def _resolve_skill_slash_spec(self, command: str) -> SkillSlashSpec | None: return spec return None + def _resolve_explicit_skill_request(self, message: str) -> SkillSlashSpec | None: if not message.strip(): return None @@ -241,6 +150,7 @@ def _resolve_explicit_skill_request(self, message: str) -> SkillSlashSpec | None return None return top_spec + def _resolve_contextual_skill_request(self, message: str) -> SkillSlashSpec | None: if not message.strip(): return None @@ -264,11 +174,13 @@ def _resolve_contextual_skill_request(self, message: str) -> SkillSlashSpec | No return None return top_spec + def _resolved_skill_route(self, message: str) -> tuple[SkillSlashSpec, str] | None: del self del message return None + def _dispatch_skill_slash_command(self, raw_command: str, command: str, args: list[str]) -> bool: spec = self._resolve_skill_slash_spec(command) if spec is None: @@ -313,6 +225,7 @@ def _dispatch_skill_slash_command(self, raw_command: str, command: str, args: li self._append_latest_learning_result() return True + def _compose_skill_turn_prompt(self, skill, *, user_instruction: str) -> str: return "\n".join( [ @@ -440,7 +353,11 @@ def _append_skills(self, args: list[str]) -> None: return if command == "install": if len(args) < 2: - self._append_entry("recovery", "Skills", "Usage: /skills install ") + self._append_entry( + "recovery", + "Skills", + "Usage: /skills install ", + ) return try: result = self.runtime.install_skill_source(args[1], session_id=self.session_id) @@ -454,7 +371,11 @@ def _append_skills(self, args: list[str]) -> None: ) self._refresh_skill_slash_specs() return - self._append_entry("recovery", "Skills", "Usage: /skills [list|active|search|view|enable|disable|install]") + self._append_entry( + "recovery", + "Skills", + "Usage: /skills [list|active|search|view|enable|disable|install]", + ) def _skill_provenance_lines(metadata) -> list[str]: diff --git a/apps/cli/shell_methods_trace.py b/apps/cli/shell_methods_trace.py index 010e11f..38d44b8 100644 --- a/apps/cli/shell_methods_trace.py +++ b/apps/cli/shell_methods_trace.py @@ -2,51 +2,14 @@ from __future__ import annotations -from collections import deque -from dataclasses import dataclass from difflib import unified_diff -import os from pathlib import Path import re -import shlex import time from packages.contracts import ExperienceRecord from packages.kernel.runtime import KernelOutcome -from packages.operator.runtime import ( - RecallEvidenceOperatorDetail, - RecallEvidenceSearchHit, - build_recall_evidence_operator_surface, - build_profile_operator_surface, - render_recall_evidence_lines, - render_profile_lines, -) from packages.tools.handler_support import resolve_allowed_path -from .provider_flow import provider_setup_defaults, run_provider_selection_wizard -from .runtime import CliRuntime -from .wizard import WIZARD_BACK -from .shell_composer import ( - build_command_palette as _build_shell_command_palette, - build_composer_body as _build_shell_composer_body, - build_divider_window as _build_shell_divider_window, - build_input_window as _build_shell_input_window, - build_key_bindings as _build_shell_key_bindings, - build_prompt_buffer as _build_shell_prompt_buffer, - build_queue_preview_window as _build_shell_queue_preview_window, - prompt_continuation as _shell_prompt_continuation, - prompt_label as _shell_prompt_label, - prompt_style as _shell_prompt_style, - prompt_style_map as _shell_prompt_style_map, - prompt_toolkit_composer_available as _shell_prompt_toolkit_composer_available, - read_command as _read_shell_command, - shell_history as _shell_history, -) -from .shell_boot import WAKE_DISPLAY_SECONDS, BootFrameContext, render_boot_frame -from .shell_opening import ( - ShellOpeningContext, - compose_shell_opening_instruction, - compose_shell_opener, -) from .shell_progress import ( animations_enabled as _shell_animations_enabled, render_queued_followup_fragments as _render_shell_queued_followup_fragments, @@ -74,31 +37,19 @@ growth_progress_counts as _shell_growth_progress_counts, recent_activity_lines as _shell_recent_activity_lines, recent_experience_lines as _shell_recent_experience_lines, - render_brand_column as _render_shell_brand_column, render_chat_entry as _render_shell_chat_entry, render_entry as _render_shell_entry, render_elephant_brand_mark as _render_shell_elephant_mark, render_growth_mark_for_stage as _render_shell_growth_mark, render_pending_entries as _render_shell_pending_entries, - render_shell_frame as _render_shell_frame_view, - render_status_column as _render_shell_status_column, should_display_experience as _should_display_shell_experience, styled_growth_progress_bar as _styled_shell_growth_progress_bar, ) from .shell_stack import ( - Align, - Completion, - Completer, Console, Document, FormattedText, - Group, - Live, - PROMPT_TOOLKIT_AVAILABLE, - Panel, RICH_AVAILABLE, - Table, - Text, ) from .shell_ui import ( BRAND_ACCENT, @@ -114,20 +65,16 @@ GROWTH_PROGRESS_WIDTH, HATCHLING_HEAD_ROWS, HATCHLING_STAGE_ROWS, - HATCHLING_STAGE_ROWS, QUEUE_PREVIEW_INSET, SCOUT_STAGE_ROWS, SEED_STAGE_ROWS, SHELL_WELCOME_HEADLINE, USER_HISTORY_BG, USER_HISTORY_FG, - WEB_URL_PATTERN, compact_line as _compact_line, centered_elephant_rows as _centered_elephant_rows, - display_path as _display_path, display_width as _display_width, render_elephant_mark, - resolve_elephant_version as _resolve_elephant_version, ) __all__ = [ @@ -164,9 +111,9 @@ ] - from .shell_support_runtime import * # noqa: F401,F403 + def _identity_lines(self, profile_id: str) -> list[str]: profile = self.runtime.inspect_profile(profile_id) identity = self.runtime.inspect_identity(profile_id=profile_id) @@ -195,6 +142,7 @@ def _identity_lines(self, profile_id: str) -> list[str]: ) return lines + def _user_lines(self, profile_id: str) -> list[str]: user = self.runtime.inspect_user(profile_id=profile_id) return [ @@ -209,6 +157,7 @@ def _user_lines(self, profile_id: str) -> list[str]: f"shared_preferences: {', '.join(user.shared_preferences) or ''}", ] + def _relationship_lines(self, profile_id: str) -> list[str]: relationship = self.runtime.inspect_relationship(profile_id=profile_id) return [ @@ -221,6 +170,7 @@ def _relationship_lines(self, profile_id: str) -> list[str]: f"continuity_notes: {', '.join(relationship.continuity_notes) or ''}", ] + def _append_entry(self, kind: str, title: str, body: str, *, meta: str = "") -> None: self.transcript.append(TranscriptEntry(kind=kind, title=title, body=body, meta=meta)) if len(self.transcript) > 80: @@ -228,6 +178,7 @@ def _append_entry(self, kind: str, title: str, body: str, *, meta: str = "") -> self.transcript = self.transcript[overflow:] self._rendered_entries = max(0, self._rendered_entries - overflow) + def _append_tooltrace_line(self, line: str) -> None: if self.transcript and self.transcript[-1].kind == "tooltrace": previous = self.transcript[-1] @@ -242,6 +193,7 @@ def _append_tooltrace_line(self, line: str) -> None: return self._append_entry("tooltrace", "Tool trace", line) + def _capture_pending_file_review(self, tool_event: ToolLifecycleEvent) -> None: if tool_event.phase != "requested": return @@ -260,6 +212,7 @@ def _capture_pending_file_review(self, tool_event: ToolLifecycleEvent) -> None: before_text=before_text, ) + def _todo_trace_lines(self) -> tuple[str, ...]: todo_glyph = _shell_tool_trace_emoji("tool.todo.manage") items = self.runtime.todo_store.list_items(self.session_id) @@ -276,6 +229,7 @@ def _todo_trace_lines(self) -> tuple[str, ...]: lines.append(f"┊ {todo_glyph} more {remaining} additional item(s)") return tuple(lines) + def _display_tool_diff_path(self, path: Path) -> str: resolved = path.expanduser().resolve() try: @@ -283,6 +237,7 @@ def _display_tool_diff_path(self, path: Path) -> str: except ValueError: return str(resolved) + def _file_review_trace_lines(self, tool_event: ToolLifecycleEvent) -> tuple[str, ...]: snapshot = self._pending_file_reviews.pop(tool_event.invocation.invocation_id, None) if snapshot is None: @@ -321,6 +276,7 @@ def _file_review_trace_lines(self, tool_event: ToolLifecycleEvent) -> tuple[str, rendered.append(f"… omitted {overflow} diff line(s)") return tuple(rendered) + def _tool_result_trace_lines(self, tool_event: ToolLifecycleEvent) -> tuple[str, ...]: if tool_event.phase == "execution.failed": self._pending_file_reviews.pop(tool_event.invocation.invocation_id, None) @@ -333,6 +289,7 @@ def _tool_result_trace_lines(self, tool_event: ToolLifecycleEvent) -> tuple[str, return self._file_review_trace_lines(tool_event) return () + def _boot_growth_stage(self, active: int) -> tuple[str, int]: if active < 0: return ("seed", 0) @@ -344,6 +301,7 @@ def _boot_growth_stage(self, active: int) -> tuple[str, int]: ) return stages[min(active, len(stages) - 1)] + def _record_tool_event_trace(self, tool_event: ToolLifecycleEvent) -> None: self._capture_pending_file_review(tool_event) # Per-turn tally for the end-of-turn condense line. Only final phases @@ -360,6 +318,7 @@ def _record_tool_event_trace(self, tool_event: ToolLifecycleEvent) -> None: for extra_line in self._tool_result_trace_lines(tool_event): self._append_tooltrace_line(extra_line) + def _kernel_stage_payload(event: dict[str, object]) -> dict[str, object] | None: event_type = str(event.get("event_type") or "").strip().lower() if event_type != "kernel.stage": @@ -367,16 +326,19 @@ def _kernel_stage_payload(event: dict[str, object]) -> dict[str, object] | None: payload = event.get("payload") return payload if isinstance(payload, dict) else None + def _parse_context_compaction_tokens(detail: str) -> tuple[int, int] | None: match = re.search(r"(?:^|\s)tokens=(\d+)->(\d+)(?:\s|$)", str(detail or "")) if match is None: return None return int(match.group(1)), int(match.group(2)) + def _parse_kernel_stage_int(detail: str, key: str) -> int | None: match = re.search(rf"(?:^|\s){re.escape(key)}=(\d+)(?:\s|$)", str(detail or "")) return int(match.group(1)) if match is not None else None + def _kernel_trace_line(self, event: dict[str, object]) -> str | None: event_type = str(event.get("event_type") or "").strip().lower() stage_payload = _kernel_stage_payload(event) @@ -421,6 +383,7 @@ def _kernel_trace_line(self, event: dict[str, object]) -> str | None: body = _compact_line(" · ".join(part for part in body_parts if part), limit=96) return f"┊ 📚 disclosed {body}" + def _record_kernel_event_trace(self, event: dict[str, object]) -> None: stage_payload = _kernel_stage_payload(event) if stage_payload is not None: @@ -446,15 +409,19 @@ def _record_kernel_event_trace(self, event: dict[str, object]) -> None: if line is not None: self._append_tooltrace_line(line) + def _animations_enabled(self) -> bool: return _shell_animations_enabled() + def _turn_phase(self, tick: int) -> tuple[str, str, str]: return _shell_turn_phase(tick) + def _summarize_progress_prompt(self, prompt: str) -> str: return _summarize_shell_progress_prompt(prompt) + def _render_turn_progress_fragments( self, *, @@ -477,12 +444,15 @@ def _render_turn_progress_fragments( stream_text=stream_text, ) + def _render_queued_followup_fragments(self) -> FormattedText: return _render_shell_queued_followup_fragments(self) + def _run_turn_with_queued_input(self, prompt: str) -> KernelOutcome: return _run_shell_turn_with_queued_input(self, prompt) + def _run_turn_with_progress( self, prompt: str, @@ -491,12 +461,15 @@ def _run_turn_with_progress( ) -> KernelOutcome: return _run_shell_turn_with_progress(self, prompt, event_payload=event_payload) + def _run_tool_with_progress(self, tool_id: str, arguments: dict[str, str]): return _run_shell_tool_with_progress(self, tool_id, arguments) + def _tool_event_tracker(self): return _shell_tool_event_tracker() + def _render_turn_frame( self, *, @@ -517,63 +490,85 @@ def _render_turn_frame( stream_text=stream_text, ) + def _render_tool_frame(self, *, tool_id: str, tick: int, tool_event: ToolLifecycleEvent | None = None): return _render_shell_tool_frame(self, tool_id=tool_id, tick=tick, tool_event=tool_event) -def _tool_frame_phases(self, tool_id: str, *, tool_event: ToolLifecycleEvent | None = None) -> tuple[tuple[str, str], ...]: + +def _tool_frame_phases( + self, tool_id: str, *, tool_event: ToolLifecycleEvent | None = None +) -> tuple[tuple[str, str], ...]: return _shell_tool_frame_phases(self, tool_id, tool_event=tool_event) + def _tool_event_lines(self, tool_event: ToolLifecycleEvent | None) -> tuple[str | None, str | None]: return _shell_tool_event_lines(self, tool_event) + def _tool_event_summary(self, tool_event: ToolLifecycleEvent | None) -> str | None: return _shell_tool_event_summary(self, tool_event) + def _tool_trace_line(self, tool_event: ToolLifecycleEvent | None) -> str | None: return _shell_tool_trace_line(self, tool_event) + def _render_pending_entries(self) -> None: _render_shell_pending_entries(self) + def _render_entry(self, entry: TranscriptEntry): return _render_shell_entry(self, entry) + def _growth_panel_lines(self, session, continuity, provider, growth) -> tuple[str, ...]: return _shell_growth_panel_lines(self, session, continuity, provider, growth) + def _recent_activity_lines(self, session, continuity, provider) -> tuple[str, ...]: return _shell_recent_activity_lines(self, session, continuity, provider) + def _recent_experience_lines(self, experiences: tuple[ExperienceRecord, ...]) -> tuple[str, ...]: return _shell_recent_experience_lines(experiences) + def _displayable_experiences(self, experiences: tuple[ExperienceRecord, ...]) -> tuple[ExperienceRecord, ...]: return _displayable_shell_experiences(experiences) + def _should_display_experience(self, experience: ExperienceRecord) -> bool: return _should_display_shell_experience(experience) + def _format_experience_status(self, experience: ExperienceRecord) -> str: return _format_shell_experience_status(experience) + def _growth_progress_counts(self, growth, *, width: int = GROWTH_PROGRESS_WIDTH) -> tuple[int, int]: return _shell_growth_progress_counts(growth, width=width) + def _growth_progress_bar(self, growth, *, width: int = GROWTH_PROGRESS_WIDTH) -> str: return _shell_growth_progress_bar(growth, width=width) + def _styled_growth_progress_bar(self, growth, *, width: int = GROWTH_PROGRESS_WIDTH): return _styled_shell_growth_progress_bar(growth, width=width) + def _render_chat_entry(self, entry: TranscriptEntry, *, accent: str): return _render_shell_chat_entry(self, entry, accent=accent) + def _history_row_width(self) -> int: return max(24, len(self._composer_divider())) + def _queue_preview_row_width(self) -> int: return max(16, self._history_row_width() - (QUEUE_PREVIEW_INSET * 2)) + def _pad_history_line(self, content: str) -> str: display_width = _display_width(content) width = self._history_row_width() @@ -581,6 +576,7 @@ def _pad_history_line(self, content: str) -> str: return content return content + (" " * (width - display_width)) + def _pad_queue_preview_line(self, content: str) -> str: display_width = _display_width(content) width = self._queue_preview_row_width() @@ -588,11 +584,14 @@ def _pad_queue_preview_line(self, content: str) -> str: return content return content + (" " * (width - display_width)) + def _center_brand_block(self, renderable): return _center_shell_brand_block(renderable) + def _render_growth_mark(self, stage_id: str, *, level: int | None = None): return _render_shell_growth_mark(stage_id, level=level) + def _render_elephant_mark(self): return _render_shell_elephant_mark() diff --git a/apps/cli/shell_methods_ui.py b/apps/cli/shell_methods_ui.py index e35bf2d..17ac3cc 100644 --- a/apps/cli/shell_methods_ui.py +++ b/apps/cli/shell_methods_ui.py @@ -2,32 +2,11 @@ from __future__ import annotations -from collections import deque -from dataclasses import dataclass -from difflib import unified_diff -import os -from pathlib import Path -import re -import shlex import threading import time -from packages.contracts import ExperienceRecord from packages.growth import ProgressionProjectionBuilder -from packages.kernel.runtime import KernelOutcome from packages.state.governance import companion_display_name -from packages.operator.runtime import ( - RecallEvidenceOperatorDetail, - RecallEvidenceSearchHit, - build_recall_evidence_operator_surface, - build_profile_operator_surface, - render_recall_evidence_lines, - render_profile_lines, -) -from packages.tools.handler_support import resolve_allowed_path -from .provider_flow import provider_setup_defaults, run_provider_selection_wizard -from .runtime import CliRuntime -from .wizard import WIZARD_BACK from .shell_composer import ( build_command_palette as _build_shell_command_palette, build_composer_body as _build_shell_composer_body, @@ -44,64 +23,21 @@ read_command as _read_shell_command, shell_history as _shell_history, ) -from .shell_boot import WAKE_DISPLAY_SECONDS, BootFrameContext, render_boot_frame from .shell_opening import ( ShellOpeningContext, compose_shell_opening_instruction, compose_shell_opener, ) -from .shell_progress import ( - animations_enabled as _shell_animations_enabled, - render_queued_followup_fragments as _render_shell_queued_followup_fragments, - render_tool_frame as _render_shell_tool_frame, - tool_trace_line as _shell_tool_trace_line, - render_turn_frame as _render_shell_turn_frame, - render_turn_progress_fragments as _render_shell_turn_progress_fragments, - run_tool_with_progress as _run_shell_tool_with_progress, - run_turn_with_progress as _run_shell_turn_with_progress, - run_turn_with_queued_input as _run_shell_turn_with_queued_input, - summarize_progress_prompt as _summarize_shell_progress_prompt, - tool_event_lines as _shell_tool_event_lines, - tool_event_summary as _shell_tool_event_summary, - tool_event_tracker as _shell_tool_event_tracker, - tool_frame_phases as _shell_tool_frame_phases, - turn_phase as _shell_turn_phase, - _tool_trace_emoji as _shell_tool_trace_emoji, -) from .shell_render import ( - center_brand_block as _center_shell_brand_block, - displayable_experiences as _displayable_shell_experiences, - format_experience_status as _format_shell_experience_status, - growth_panel_lines as _shell_growth_panel_lines, - growth_progress_bar as _shell_growth_progress_bar, - growth_progress_counts as _shell_growth_progress_counts, - recent_activity_lines as _shell_recent_activity_lines, - recent_experience_lines as _shell_recent_experience_lines, render_brand_column as _render_shell_brand_column, - render_chat_entry as _render_shell_chat_entry, - render_entry as _render_shell_entry, - render_elephant_brand_mark as _render_shell_elephant_mark, - render_growth_mark_for_stage as _render_shell_growth_mark, - render_pending_entries as _render_shell_pending_entries, render_shell_frame as _render_shell_frame_view, render_status_column as _render_shell_status_column, - should_display_experience as _should_display_shell_experience, - styled_growth_progress_bar as _styled_shell_growth_progress_bar, ) from .shell_stack import ( - Align, - Completion, - Completer, Console, Document, - FormattedText, - Group, - Live, PROMPT_TOOLKIT_AVAILABLE, - Panel, RICH_AVAILABLE, - Table, - Text, ) from .shell_ui import ( BRAND_ACCENT, @@ -117,20 +53,16 @@ GROWTH_PROGRESS_WIDTH, HATCHLING_HEAD_ROWS, HATCHLING_STAGE_ROWS, - HATCHLING_STAGE_ROWS, QUEUE_PREVIEW_INSET, SCOUT_STAGE_ROWS, SEED_STAGE_ROWS, SHELL_WELCOME_HEADLINE, USER_HISTORY_BG, USER_HISTORY_FG, - WEB_URL_PATTERN, compact_line as _compact_line, centered_elephant_rows as _centered_elephant_rows, - display_path as _display_path, display_width as _display_width, render_elephant_mark, - resolve_elephant_version as _resolve_elephant_version, ) __all__ = [ @@ -167,13 +99,15 @@ ] - from .shell_support_runtime import * # noqa: F401,F403 def _first_language_from_runtime(runtime, profile) -> str: try: - from packages.runtime_config import global_config_path_for_state_dir, load_global_config + from packages.runtime_config import ( + global_config_path_for_state_dir, + load_global_config, + ) config_path = global_config_path_for_state_dir(runtime.paths.state_dir) config = load_global_config(config_path, state_dir=runtime.paths.state_dir) @@ -195,27 +129,35 @@ def _next_command(self) -> PendingShellCommand: return self._pending_commands.popleft() return coerce_pending_shell_command(self._read_command()) + def _prompt_toolkit_composer_available(self) -> bool: return _shell_prompt_toolkit_composer_available(self) + def _shell_history(self): return _shell_history(self) + def _build_prompt_buffer(self): return _build_shell_prompt_buffer(self) + def _build_input_window(self, buffer): return _build_shell_input_window(self, buffer) + def _build_command_palette(self): return _build_shell_command_palette(self) + def _build_queue_preview_window(self): return _build_shell_queue_preview_window(self) + def _build_divider_window(self): return _build_shell_divider_window(self) + def _build_composer_body( self, *, @@ -232,9 +174,11 @@ def _build_composer_body( buffer=buffer, ) + def _read_command(self) -> object: return _read_shell_command(self) + def personality_preset_choices(self) -> tuple[tuple[str, str], ...]: return tuple( (preset.preset_id, preset.summary) @@ -242,21 +186,27 @@ def personality_preset_choices(self) -> tuple[tuple[str, str], ...]: if preset.preset_id != "custom" ) + def _prompt_label(self) -> str: return _shell_prompt_label(self) + def _prompt_continuation(self): return _shell_prompt_continuation() + def _prompt_style(self): return _shell_prompt_style() + def _prompt_style_map(self) -> dict[str, str]: return _shell_prompt_style_map() + def _build_key_bindings(self, *, submit=None, allow_exit: bool = True) -> KeyBindings: return _build_shell_key_bindings(self, submit=submit, allow_exit=allow_exit) + def _composer_divider(self) -> str: try: width = int(self.console.size.width) @@ -269,6 +219,7 @@ def _composer_divider(self) -> str: self._composer_divider_cache = (width, divider) return divider + def _format_status_tokens(self, value: int | None) -> str: if value is None or value <= 0: return "--" @@ -280,6 +231,7 @@ def _format_status_tokens(self, value: int | None) -> str: return f"{whole}K" return str(value) + def _status_bar_context_style(self, percent_used: int | None) -> str: if percent_used is None: return "class:status-bar-muted" @@ -289,6 +241,7 @@ def _status_bar_context_style(self, percent_used: int | None) -> str: return "class:status-bar-warn" return "class:status-bar-good" + _STATUS_BAR_PROGRESS_FILLED = "█" _STATUS_BAR_PROGRESS_EMPTY = "░" _STATUS_CONTEXT_RING_STEPS = ("○", "◜", "◔", "◑", "◕", "●") @@ -327,6 +280,7 @@ def _build_growth_bar_fragments(self, growth, *, width: int = 12) -> list[tuple[ fragments.append(("class:status-bar-growth-bracket", "]")) return fragments + def _status_bar_elapsed_fragments(elapsed_seconds: int, *, streaming_active: bool = False) -> list[tuple[str, str]]: fragments: list[tuple[str, str]] = [("class:status-bar-muted", f"{elapsed_seconds}s")] if streaming_active: @@ -613,6 +567,7 @@ def _status_bar_snapshot(self) -> dict[str, object]: self._status_bar_snapshot_cache = (now, snapshot, active_turn) return snapshot + def _status_bar_fragments(self): snapshot = self._status_bar_snapshot() growth = _status_bar_growth(self) @@ -639,27 +594,36 @@ def _status_bar_fragments(self): phase_style, phase_label = phase_indicator fragments.append(("class:status-bar-sep", " · ")) fragments.append((phase_style, phase_label)) - fragments.extend([ - ("class:status-bar-sep", " │ "), - ("class:status-bar-model", str(snapshot["model_short"])), - ("class:status-bar-sep", " │ "), - ("class:status-bar-muted", f"{context_used}/{context_limit}"), - ("class:status-bar-sep", " "), - (percent_style, self._build_context_ring(percent if isinstance(percent, int) else None)), - ("class:status-bar-sep", " "), - (percent_style, percent_label), - ("class:status-bar-sep", " │ "), - *_status_bar_elapsed_fragments(elapsed_seconds, streaming_active=streaming_active), - ("class:status-bar-sep", " │ "), - ("class:status-bar-level", growth.cycle_label), - ("class:status-bar-sep", " "), - *self._build_growth_bar_fragments(growth), - ("class:status-bar-sep", " "), - ("class:status-bar-level", f"checkpoint {growth.level} · {growth.progress_percent}%"), - ]) + fragments.extend( + [ + ("class:status-bar-sep", " │ "), + ("class:status-bar-model", str(snapshot["model_short"])), + ("class:status-bar-sep", " │ "), + ("class:status-bar-muted", f"{context_used}/{context_limit}"), + ("class:status-bar-sep", " "), + ( + percent_style, + self._build_context_ring(percent if isinstance(percent, int) else None), + ), + ("class:status-bar-sep", " "), + (percent_style, percent_label), + ("class:status-bar-sep", " │ "), + *_status_bar_elapsed_fragments(elapsed_seconds, streaming_active=streaming_active), + ("class:status-bar-sep", " │ "), + ("class:status-bar-level", growth.cycle_label), + ("class:status-bar-sep", " "), + *self._build_growth_bar_fragments(growth), + ("class:status-bar-sep", " "), + ( + "class:status-bar-level", + f"checkpoint {growth.level} · {growth.progress_percent}%", + ), + ] + ) fragments.append(("class:status-bar-edge", " ")) return fragments + def _clear_composer(self, command: str) -> None: if PROMPT_TOOLKIT_AVAILABLE: return @@ -672,6 +636,7 @@ def _clear_composer(self, command: str) -> None: stream.write("\x1b[1A\r\x1b[2K") stream.flush() + def _enqueue_followup_command(self, raw_command: object) -> None: pending = coerce_pending_shell_command(raw_command) command = pending.command.strip() @@ -685,27 +650,33 @@ def _enqueue_followup_command(self, raw_command: object) -> None: ) ) + def _is_startup_conversational_command(self, raw_command: str) -> bool: command = raw_command.strip() return bool(command) and not command.startswith("/") + def _startup_state_focus_dispatch_ready(self) -> bool: return True + def _startup_should_hold_user_command(self, raw_command: str) -> bool: if not self._is_startup_conversational_command(raw_command): return False return not self._startup_transcript_primed + def _mark_startup_user_turn_submitted(self, raw_command: str) -> None: if self._is_startup_conversational_command(raw_command): self._startup_user_turn_submitted = True + def _startup_should_surface_state_focus_notices(self) -> bool: if not self._startup_surface_prepared or not self._state_focus_runtime_ready_seen: return True return not self._startup_transcript_primed + def _set_state_focus_runtime_notice(self, title: str, body: str) -> None: """Replace the live startup notice with a single (title, body) slot. @@ -718,10 +689,12 @@ def _set_state_focus_runtime_notice(self, title: str, body: str) -> None: return self._state_focus_runtime_notices = [notice] + def _clear_state_focus_runtime_notice(self) -> None: if self._state_focus_runtime_notices: self._state_focus_runtime_notices = [] + def _sync_state_focus_runtime_notices(self) -> None: status = self.runtime.state_focus_runtime_status() if not bool(status.get("embedding_ready")): @@ -768,6 +741,7 @@ def _sync_state_focus_runtime_notices(self) -> None: self._state_focus_runtime_ready_seen = True self._state_focus_runtime_ready_seen_at = time.monotonic() + def _prepare_startup_surface(self) -> None: self._sync_state_focus_runtime_notices() if self._startup_surface_prepared or self._startup_surface_prepare_started: @@ -788,6 +762,7 @@ def prepare_surface() -> None: daemon=True, ).start() + def _prime_startup_transcript_if_needed(self) -> None: if self._startup_transcript_primed: self._sync_state_focus_runtime_notices() @@ -796,6 +771,7 @@ def _prime_startup_transcript_if_needed(self) -> None: self._startup_transcript_primed = True self._sync_state_focus_runtime_notices() + def _prime_transcript(self, *, use_proactive_opening: bool = True) -> None: session = self.runtime.inspect_session(self.session_id) continuity = self.runtime.inspect_continuity(session_id=self.session_id) @@ -826,7 +802,11 @@ def _prime_transcript(self, *, use_proactive_opening: bool = True) -> None: ) except Exception as error: if self.debug: - self._append_entry("notice", "Startup prompt", f"fallback to local opener\nreason: {error}") + self._append_entry( + "notice", + "Startup prompt", + f"fallback to local opener\nreason: {error}", + ) if startup_outcome is not None and startup_outcome.execution.summary.strip(): self._append_entry("assistant", assistant_name, startup_outcome.execution.summary.strip()) else: @@ -858,6 +838,7 @@ def _prime_transcript(self, *, use_proactive_opening: bool = True) -> None: ) self._startup_transcript_primed = True + def _assistant_name(self) -> str: session = self.runtime.inspect_session(self.session_id) state = self.runtime.state_for_elephant(session.elephant_id or "") if session.elephant_id else None @@ -868,14 +849,18 @@ def _assistant_name(self) -> str: return identity.display_name.strip() return "Elephant Agent" + def _append_assistant_surface_reply(self, body: str, *, meta: str = "") -> None: self._append_entry("assistant", self._assistant_name(), body, meta=meta) + def _render_shell_frame(self): return _render_shell_frame_view(self) + def _render_brand_column(self, session, continuity, provider, growth): return _render_shell_brand_column(self, session, continuity, provider, growth) + def _render_status_column(self, session, continuity, context_frame, provider, growth): return _render_shell_status_column(self, session, continuity, context_frame, provider, growth) diff --git a/apps/cli/shell_opening.py b/apps/cli/shell_opening.py index 6d7a6ab..bc1e582 100644 --- a/apps/cli/shell_opening.py +++ b/apps/cli/shell_opening.py @@ -105,7 +105,15 @@ def compose_shell_opening_instruction(context: ShellOpeningContext) -> str: def _first_language_line(first_language: str) -> str: normalized = str(first_language or "en").strip().lower() - if normalized in {"zh", "zh-cn", "cn", "chinese", "中文", "汉语", "普通话"} or normalized.startswith("zh"): + if normalized in { + "zh", + "zh-cn", + "cn", + "chinese", + "中文", + "汉语", + "普通话", + } or normalized.startswith("zh"): return "User's first language selected during init: Chinese. Write this opener in Chinese unless explicitly requested otherwise." return "User's first language selected during init: English." @@ -126,8 +134,8 @@ def _actionable_wake_summary(*, wake_action: str, wake_summary: str) -> str: return "" lowered = " ".join(summary.casefold().split()) non_actionable_markers = ( - "no actionable current work", - "planner should defer", + "no actionable current work", + "planner should defer", "keeps the active slot clear", "defer and schedule", ) @@ -164,9 +172,7 @@ def _public_wake_summary(*, wake_action: str, wake_summary: str) -> str: return task_title sentences = re.split(r"(?<=[.!?])\s+", " ".join(summary.split())) public_sentences = tuple( - sentence.strip() - for sentence in sentences - if sentence.strip() and not _contains_internal_wake_marker(sentence) + sentence.strip() for sentence in sentences if sentence.strip() and not _contains_internal_wake_marker(sentence) ) if public_sentences: return " ".join(public_sentences[:2]) diff --git a/apps/cli/shell_progress_impl.py b/apps/cli/shell_progress_impl.py index 26806df..8c2ce6e 100644 --- a/apps/cli/shell_progress_impl.py +++ b/apps/cli/shell_progress_impl.py @@ -2,48 +2,12 @@ from __future__ import annotations -from dataclasses import dataclass -import os import re -import threading -import time from typing import TYPE_CHECKING -from packages.kernel.runtime import KernelOutcome -from packages.tools import ToolLifecycleEvent - -from .shell_stack import ( - Application, - Condition, - ConditionalContainer, - FormattedText, - FormattedTextControl, - Group, - Layout, - Live, - Panel, - RICH_AVAILABLE, - Text, - Window, -) -from .shell_ui import ( - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_DARK, - BRAND_LIGHT, - BRAND_MUTED, - LIVE_DIFF_ADD_FG, - LIVE_DIFF_CONTEXT_FG, - LIVE_DIFF_FILE_FG, - LIVE_DIFF_HUNK_FG, - LIVE_DIFF_REMOVE_FG, - QUEUE_PREVIEW_INSET, - compact_line, - strip_markdown_bold, -) if TYPE_CHECKING: - from .shell import ProductizedShell + pass _STREAM_TOOL_BLOCK_PATTERNS = ( @@ -70,7 +34,6 @@ ) - from .shell_progress_support import ( _ToolTraceDisplayParts, _VisibleToolEvent, diff --git a/apps/cli/shell_progress_runtime.py b/apps/cli/shell_progress_runtime.py index 11c91bf..a132bdc 100644 --- a/apps/cli/shell_progress_runtime.py +++ b/apps/cli/shell_progress_runtime.py @@ -2,8 +2,6 @@ from __future__ import annotations -from dataclasses import dataclass -import os import re import threading import time @@ -18,30 +16,26 @@ ConditionalContainer, FormattedText, FormattedTextControl, - Group, Layout, Live, - Panel, - RICH_AVAILABLE, Text, Window, ) from .shell_ui import ( - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_DARK, BRAND_LIGHT, - BRAND_MUTED, LIVE_DIFF_ADD_FG, LIVE_DIFF_CONTEXT_FG, LIVE_DIFF_FILE_FG, LIVE_DIFF_HUNK_FG, LIVE_DIFF_REMOVE_FG, QUEUE_PREVIEW_INSET, - compact_line, strip_markdown_bold, ) -from .shell_clarify import build_clarify_window, route_clarify_answer, set_clarify_invalidator +from .shell_clarify import ( + build_clarify_window, + route_clarify_answer, + set_clarify_invalidator, +) from .shell_composer import _compose_submission, run_prompt_toolkit_application if TYPE_CHECKING: @@ -72,22 +66,15 @@ ) - from .shell_progress_support import ( - _ToolTraceDisplayParts, _VisibleToolEvent, animations_enabled, live_tool_feed_lines, - pending_tool_display_lines, - pending_tool_output_lines, - pending_tooltrace_lines, - summarize_progress_prompt, loop_context_progress_line, recall_progress_line, turn_state_focus_progress_line, turn_phase, turn_title, - turn_tool_progress_lines, ) from .shell_progress_trace import ( _stream_display_parts, @@ -97,6 +84,7 @@ render_turn_frame, ) + def _slow_op_hint_fragments(shell: ProductizedShell) -> list[tuple[str, str]]: """Telegraph long-running turns so the user knows we're still working. @@ -145,11 +133,11 @@ def render_turn_progress_fragments( ("class:progress-active-detail", f" {phase_detail}"), ] fragments.append(("", "\n")) - fragments.extend(render_live_tool_line_fragments(turn_state_focus_progress_line(kernel_stage_events=kernel_stage_events))) + fragments.extend( + render_live_tool_line_fragments(turn_state_focus_progress_line(kernel_stage_events=kernel_stage_events)) + ) context_line = loop_context_progress_line(kernel_stage_events=kernel_stage_events) - if context_line.startswith("┊ 🧩 context") and ( - "projection" in context_line or "compressing" in context_line - ): + if context_line.startswith("┊ 🧩 context") and ("projection" in context_line or "compressing" in context_line): fragments.extend( render_live_tool_line_fragments( context_line, @@ -197,10 +185,11 @@ def render_turn_progress_fragments( fragments.extend(stream_fragments) for live_line in live_tool_feed_lines(shell, tool_event=tool_event, tool_events=tool_events): fragments.extend(render_live_tool_line_fragments(live_line, leading_newline=True)) - + fragments.append(("", "\n")) return FormattedText(fragments) + def render_stream_response_fragments( shell: ProductizedShell, *, @@ -212,6 +201,7 @@ def render_stream_response_fragments( return FormattedText([]) return FormattedText(fragments) + def build_turn_progress_window( shell: ProductizedShell, *, @@ -243,6 +233,7 @@ def build_turn_progress_window( dont_extend_height=True, ) + def build_stream_response_window(shell: ProductizedShell, *, stream_holder, stream_lock): # Stream text is rendered inside render_turn_progress_fragments before tool lines. # This hidden container keeps the layout factory shape simple. @@ -255,9 +246,11 @@ def build_stream_response_window(shell: ProductizedShell, *, stream_holder, stre filter=Condition(lambda: False), # Always hidden ) + def set_streaming_response_active(shell: ProductizedShell, active: bool) -> None: shell._streaming_response_active = active + def render_tool_output_fragments(line: str, *, leading_newline: bool = False) -> list[tuple[str, str]]: fragments: list[tuple[str, str]] = [] if leading_newline: @@ -276,6 +269,7 @@ def render_tool_output_fragments(line: str, *, leading_newline: bool = False) -> fragments.append((style, line)) return fragments + def render_tool_output_text(line: str) -> Text: style = BRAND_LIGHT if line.startswith("a/") and " → b/" in line: @@ -290,6 +284,7 @@ def render_tool_output_text(line: str) -> Text: style = LIVE_DIFF_CONTEXT_FG return Text(line, style=style) + def render_live_tool_line_fragments(line: str, *, leading_newline: bool = False) -> list[tuple[str, str]]: if line.startswith("┊ "): from .shell_progress_trace import render_tool_trace_fragments @@ -297,6 +292,7 @@ def render_live_tool_line_fragments(line: str, *, leading_newline: bool = False) return render_tool_trace_fragments(line, leading_newline=leading_newline) return render_tool_output_fragments(line, leading_newline=leading_newline) + def render_live_tool_line_text(line: str) -> Text: if line.startswith("┊ "): from .shell_progress_trace import render_tool_trace_text @@ -304,6 +300,7 @@ def render_live_tool_line_text(line: str) -> Text: return render_tool_trace_text(line) return render_tool_output_text(line) + def render_queued_followup_fragments(shell: ProductizedShell) -> FormattedText: fragments: list[tuple[str, str]] = [] for command in shell._pending_commands: @@ -318,9 +315,11 @@ def render_queued_followup_fragments(shell: ProductizedShell) -> FormattedText: fragments.pop() return FormattedText(fragments) + def queued_turn_input_supported(shell: ProductizedShell) -> bool: return shell._prompt_toolkit_composer_available() + def resolve_turn_outcome(holder: dict[str, object]) -> KernelOutcome: error = holder.get("error") if isinstance(error, Exception): @@ -330,6 +329,7 @@ def resolve_turn_outcome(holder: dict[str, object]) -> KernelOutcome: raise RuntimeError("turn completed without a kernel outcome") return outcome + def run_turn_with_queued_input( shell: ProductizedShell, prompt: str, @@ -375,8 +375,7 @@ def reset_stream_for_tool_event(event: ToolLifecycleEvent) -> None: stream_lock=stream_lock, ) kernel_stage_holder, kernel_stage_lock, kernel_observer = kernel_event_tracker( - shell._record_kernel_event_trace, - lambda _event: invalidate_application() + shell._record_kernel_event_trace, lambda _event: invalidate_application() ) previous_clarify_surface = shell.runtime.clarify_surface shell.runtime.set_clarify_surface(shell._interactive_clarify_surface()) @@ -500,6 +499,7 @@ def exit_when_complete() -> None: shell.runtime.set_clarify_surface(previous_clarify_surface) return resolve_turn_outcome(holder) + def run_turn_with_progress( shell: ProductizedShell, prompt: str, @@ -543,9 +543,7 @@ def collapse_stream_reasoning_for_requested_tool(event: ToolLifecycleEvent) -> N stream_holder=stream_holder, stream_lock=stream_lock, ) - kernel_stage_holder, kernel_stage_lock, kernel_observer = kernel_event_tracker( - shell._record_kernel_event_trace - ) + kernel_stage_holder, kernel_stage_lock, kernel_observer = kernel_event_tracker(shell._record_kernel_event_trace) unsubscribe = shell.runtime.tool_runtime.subscribe(tool_observer) shell.runtime.set_model_stream_observer(stream_observer) shell.runtime.set_kernel_event_observer(kernel_observer) @@ -640,6 +638,7 @@ def worker() -> None: unsubscribe() return resolve_turn_outcome(holder) + def run_tool_with_progress(shell: ProductizedShell, tool_id: str, arguments: dict[str, str]): shell.runtime.prepare_session_surface(shell.session_id) tool_runtime = shell.runtime.tool_runtime @@ -726,6 +725,7 @@ def worker() -> None: raise RuntimeError("tool call completed without a result") return result + def tool_event_tracker(*extra_observers, stream_holder=None, stream_lock=None): holder: dict[str, object] = { "latest": None, @@ -740,9 +740,7 @@ def observer(event: ToolLifecycleEvent) -> None: with lock: holder["latest"] = event feed = [ - item - for item in holder.get("feed", []) - if isinstance(item, _VisibleToolEvent) and item.expires_at > now + item for item in holder.get("feed", []) if isinstance(item, _VisibleToolEvent) and item.expires_at > now ] snapshots = { key: value @@ -774,9 +772,7 @@ def observer(event: ToolLifecycleEvent) -> None: anchors.append(visible_event) holder["stream_anchors"] = anchors[-24:] active_invocation_ids = { - item.event.invocation.invocation_id - for item in feed - if isinstance(item, _VisibleToolEvent) + item.event.invocation.invocation_id for item in feed if isinstance(item, _VisibleToolEvent) } holder["stream_snapshots"] = { key: value for key, value in snapshots.items() if key in active_invocation_ids @@ -789,17 +785,17 @@ def observer(event: ToolLifecycleEvent) -> None: return holder, lock, observer + def latest_tool_event(holder, lock) -> ToolLifecycleEvent | None: with lock: return holder.get("latest") + def visible_tool_events(holder, lock) -> tuple[_VisibleToolEvent, ...]: now = time.monotonic() with lock: feed = [ - item - for item in holder.get("feed", []) - if isinstance(item, _VisibleToolEvent) and item.expires_at > now + item for item in holder.get("feed", []) if isinstance(item, _VisibleToolEvent) and item.expires_at > now ] holder["feed"] = feed return tuple(feed) @@ -823,6 +819,7 @@ def stream_anchor_events(holder, lock) -> tuple[_VisibleToolEvent, ...]: holder["stream_anchors"] = anchors return tuple(anchors) + def kernel_event_tracker(*extra_observers): holder: dict[str, object] = {"stages": []} lock = threading.Lock() @@ -854,6 +851,7 @@ def observer(event) -> None: return holder, lock, observer + def visible_kernel_stage_events(holder, lock) -> tuple[dict[str, object], ...]: with lock: stages = holder.get("stages", ()) @@ -862,7 +860,10 @@ def visible_kernel_stage_events(holder, lock) -> tuple[dict[str, object], ...]: visible = [stage for stage in stages if isinstance(stage, dict)] state_focus_prefix = [ stage - for stage in (holder.get("last_state_focus_previous"), holder.get("last_state_focus")) + for stage in ( + holder.get("last_state_focus_previous"), + holder.get("last_state_focus"), + ) if isinstance(stage, dict) and stage not in visible ] if state_focus_prefix: @@ -898,6 +899,7 @@ def remember_context_compaction_frame( } shell._pending_context_compaction_frame_rendered = False + def kernel_stages_include_compaction(stages: tuple[dict[str, object], ...]) -> bool: for stage_event in stages: payload = stage_event.get("payload") @@ -905,13 +907,20 @@ def kernel_stages_include_compaction(stages: tuple[dict[str, object], ...]) -> b return True return False + def _tool_event_hold_seconds(phase: str) -> float: if phase in {"requested", "execution.started"}: return 0.45 - if phase in {"execution.completed", "execution.failed", "approval.denied", "approval.deferred"}: + if phase in { + "execution.completed", + "execution.failed", + "approval.denied", + "approval.deferred", + }: return 1.1 return 0.35 + def stream_text_tracker(): # holder["raw"] — rolling 16KB buffer of raw model bytes # holder["_cache_key"] — id+length of the raw str at last parse diff --git a/apps/cli/shell_progress_support.py b/apps/cli/shell_progress_support.py index c07c4d3..d84c68a 100644 --- a/apps/cli/shell_progress_support.py +++ b/apps/cli/shell_progress_support.py @@ -7,40 +7,16 @@ from datetime import datetime import os import re -import threading -import time from typing import TYPE_CHECKING from packages.kernel.runtime import KernelOutcome from packages.tools import ToolLifecycleEvent from .shell_stack import ( - Application, - Condition, - ConditionalContainer, - FormattedText, - FormattedTextControl, - Group, - Layout, Live, - Panel, RICH_AVAILABLE, - Text, - Window, ) from .shell_ui import ( - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_DARK, - BRAND_LIGHT, - BRAND_MUTED, - LIVE_DIFF_ADD_FG, - LIVE_DIFF_CONTEXT_FG, - LIVE_DIFF_FILE_FG, - LIVE_DIFF_HUNK_FG, - LIVE_DIFF_REMOVE_FG, - QUEUE_PREVIEW_INSET, - compact_line, strip_markdown_bold, ) @@ -83,6 +59,7 @@ class _ToolTraceDisplayParts: duration_gap: str duration: str + @dataclass(frozen=True, slots=True) class _VisibleToolEvent: event: ToolLifecycleEvent @@ -96,6 +73,7 @@ class _KernelStageView: detail: str recorded_at: datetime | None + _TURN_TITLE_FRAMES: tuple[tuple[str, str], ...] = ( ("🐘", "Elephant Agent is orienting"), ("🐾", "Elephant Agent is following the path"), @@ -113,17 +91,20 @@ class _KernelStageView: def animations_enabled() -> bool: return RICH_AVAILABLE and Live is not None and os.environ.get("ELEPHANT_NO_ANIMATION") != "1" + def turn_title(tick: int) -> tuple[str, str]: # Slower title rotation — one frame every ~2.5s at 12.5 Hz caller. # Rapid title changes read as twitchy; a slow drift reads as calm # progress. return _TURN_TITLE_FRAMES[(tick // 32) % len(_TURN_TITLE_FRAMES)] + def turn_marker(tick: int) -> str: # Marker ticks at ~2 Hz (every 6th caller frame at 12.5 Hz) — fast # enough to feel alive, slow enough not to distract. return _TURN_MARKER_FRAMES[(tick // 6) % len(_TURN_MARKER_FRAMES)] + def turn_phase(tick: int) -> tuple[str, str, str]: phases = ( ("Opening", "Opening your thread and current path"), @@ -149,7 +130,9 @@ def _parse_kernel_stage_detail(detail: str) -> dict[str, str]: return parsed -def _kernel_stage_views(kernel_stage_events: tuple[Mapping[str, object], ...]) -> tuple[_KernelStageView, ...]: +def _kernel_stage_views( + kernel_stage_events: tuple[Mapping[str, object], ...], +) -> tuple[_KernelStageView, ...]: views: list[_KernelStageView] = [] for event in kernel_stage_events: payload = event.get("payload") @@ -369,12 +352,14 @@ def outcome_state_focus_meta(outcome: KernelOutcome) -> str: return "" return f"routing · {' · '.join(parts)}" + def summarize_progress_prompt(prompt: str) -> str: normalized = " ".join(prompt.split()) if len(normalized) <= 72: return normalized return f"{normalized[:69]}..." + def pending_tooltrace_lines(shell: ProductizedShell) -> tuple[str, ...]: pending = shell.transcript[shell._rendered_entries :] lines: list[str] = [] @@ -391,6 +376,7 @@ def pending_tooltrace_lines(shell: ProductizedShell) -> tuple[str, ...]: lines.append(normalized) return tuple(lines) + def pending_tool_output_lines(shell: ProductizedShell) -> tuple[str, ...]: pending = shell.transcript[shell._rendered_entries :] lines: list[str] = [] @@ -405,6 +391,7 @@ def pending_tool_output_lines(shell: ProductizedShell) -> tuple[str, ...]: lines.append(normalized) return tuple(lines[-80:]) + def pending_tool_display_lines(shell: ProductizedShell, *, limit: int | None = None) -> tuple[str, ...]: pending = shell.transcript[shell._rendered_entries :] lines: list[str] = [] @@ -424,6 +411,7 @@ def pending_tool_display_lines(shell: ProductizedShell, *, limit: int | None = N hidden = len(lines) - (limit - 1) return (f"… {hidden} earlier tool line(s) hidden", *lines[-(limit - 1) :]) + def live_tool_feed_lines( shell: ProductizedShell, *, @@ -446,6 +434,7 @@ def live_tool_feed_lines( hidden = len(lines) - (limit - 1) return (f"… {hidden} earlier tool line(s) hidden", *lines[-(limit - 1) :]) + def turn_tool_progress_lines( shell: ProductizedShell, *, diff --git a/apps/cli/shell_progress_trace.py b/apps/cli/shell_progress_trace.py index 854eac9..0d0ed02 100644 --- a/apps/cli/shell_progress_trace.py +++ b/apps/cli/shell_progress_trace.py @@ -2,29 +2,16 @@ from __future__ import annotations -from dataclasses import dataclass -import os import re -import threading from typing import TYPE_CHECKING -from packages.kernel.runtime import KernelOutcome from packages.models.reasoning_parser import split_reasoning_and_content from packages.tools import ToolLifecycleEvent from .shell_stack import ( - Application, - Condition, - ConditionalContainer, - FormattedText, - FormattedTextControl, - Group, - Layout, - Live, Panel, RICH_AVAILABLE, Text, - Window, ) from .shell_ui import ( BRAND_ACCENT, @@ -32,12 +19,6 @@ BRAND_DARK, BRAND_LIGHT, BRAND_MUTED, - LIVE_DIFF_ADD_FG, - LIVE_DIFF_CONTEXT_FG, - LIVE_DIFF_FILE_FG, - LIVE_DIFF_HUNK_FG, - LIVE_DIFF_REMOVE_FG, - QUEUE_PREVIEW_INSET, compact_line, strip_markdown_bold, ) @@ -72,17 +53,15 @@ _ToolTraceDisplayParts, _VisibleToolEvent, live_tool_feed_lines, - pending_tool_display_lines, - summarize_progress_prompt, recall_progress_line, loop_context_progress_line, turn_state_focus_progress_line, turn_marker, turn_phase, turn_title, - turn_tool_progress_lines, ) + def render_turn_frame( shell: ProductizedShell, *, @@ -122,7 +101,9 @@ def render_turn_frame( progress_body.append(marker, style=BRAND_MUTED) progress_body.append(f" {phase_detail}", style=BRAND_LIGHT) progress_body.append("\n") - progress_body.append_text(render_live_tool_line_text(turn_state_focus_progress_line(kernel_stage_events=kernel_stage_events))) + progress_body.append_text( + render_live_tool_line_text(turn_state_focus_progress_line(kernel_stage_events=kernel_stage_events)) + ) context_line = loop_context_progress_line(kernel_stage_events=kernel_stage_events) if _is_compaction_context_line(context_line): progress_body.append("\n") @@ -131,11 +112,12 @@ def render_turn_frame( if recall_line: progress_body.append("\n") progress_body.append_text(render_live_tool_line_text(recall_line)) - + # Render tool lines with stream text anchored to the matching tool event when # possible, while preserving the full merged tool rail from transcript + live events. if tool_event_holder is not None and tool_event_lock is not None: from .shell_progress_runtime import stream_anchor_events, visible_tool_events + visible_events = visible_tool_events(tool_event_holder, tool_event_lock) stable_stream_anchors = stream_anchor_events(tool_event_holder, tool_event_lock) previous_stream_was_reasoning_only = False @@ -161,7 +143,7 @@ def render_turn_frame( for live_index, live_line in enumerate(live_lines): progress_body.append("\n\n" if live_index == 0 and stream_has_reasoning_only else "\n") progress_body.append_text(render_live_tool_line_text(live_line)) - + progress_panel = Panel( progress_body, title=f"[bold {BRAND_ACCENT}]{title_glyph} {title_copy}[/bold {BRAND_ACCENT}]", @@ -170,6 +152,7 @@ def render_turn_frame( ) return progress_panel + def render_tool_frame( shell: ProductizedShell, *, @@ -203,6 +186,7 @@ def render_tool_frame( padding=(0, 1), ) + def tool_frame_phases( shell: ProductizedShell, tool_id: str, @@ -215,6 +199,7 @@ def tool_frame_phases( ("tool.report", "Adding the result to the Step trail"), ) + def tool_trace_line( shell: ProductizedShell, tool_event: ToolLifecycleEvent | None, @@ -249,6 +234,7 @@ def tool_trace_line( return f"┊ {marker}{label:<12} {preview}{duration_part}" return f"┊ {marker}{label}{duration_part}" + def tool_event_progress_line( shell: ProductizedShell, tool_event: ToolLifecycleEvent | None, @@ -265,6 +251,7 @@ def tool_event_progress_line( return f"┊ {marker}{label}" return tool_trace_line(shell, tool_event) + def tool_event_progress_lines( shell: ProductizedShell, *, @@ -284,6 +271,7 @@ def tool_event_progress_lines( lines.append(line) return tuple(lines) + def anchored_tool_progress_items( shell: ProductizedShell, *, @@ -338,13 +326,14 @@ def anchored_tool_progress_items( items: list[tuple[str, str]] = [] for index, line in enumerate(tool_lines): - for response_text in insertions.get(index, ()): + for response_text in insertions.get(index, ()): items.append(("stream", response_text)) items.append(("line", line)) - for response_text in insertions.get(len(tool_lines), ()): + for response_text in insertions.get(len(tool_lines), ()): items.append(("stream", response_text)) return tuple(items) + def _tool_event_progress_lines_for_event( shell: ProductizedShell, event: ToolLifecycleEvent, @@ -356,6 +345,7 @@ def _tool_event_progress_lines_for_event( line = tool_event_progress_line(shell, event) return () if line is None else (line,) + def tool_event_lines( shell: ProductizedShell, tool_event: ToolLifecycleEvent | None, @@ -381,6 +371,7 @@ def tool_event_lines( details.append(f"outcome: {tool_event.execution.outcome}") return (title, " · ".join(part for part in details if part)) + def tool_event_summary(shell: ProductizedShell, tool_event: ToolLifecycleEvent | None) -> str | None: if tool_event is None: return None @@ -390,13 +381,19 @@ def tool_event_summary(shell: ProductizedShell, tool_event: ToolLifecycleEvent | title, _ = tool_event_lines(shell, tool_event) return title + def render_tool_trace_fragments(line: str, *, leading_newline: bool = False) -> list[tuple[str, str]]: parts = _tool_trace_display_parts(line) if _is_state_focus_trace_line(line): fragments: list[tuple[str, str]] = [] if leading_newline: fragments.append(("", "\n")) - for text in (parts.rail, f"{parts.emoji} " if parts.emoji else "", parts.prefix, parts.label): + for text in ( + parts.rail, + f"{parts.emoji} " if parts.emoji else "", + parts.prefix, + parts.label, + ): if text: fragments.append(("class:progress-state-focus", text)) if parts.body: @@ -428,11 +425,17 @@ def render_tool_trace_fragments(line: str, *, leading_newline: bool = False) -> fragments.append(("class:progress-tool-duration", parts.duration)) return fragments + def render_tool_trace_text(line: str) -> Text: parts = _tool_trace_display_parts(line) if _is_state_focus_trace_line(line): block = Text() - for text in (parts.rail, f"{parts.emoji} " if parts.emoji else "", parts.prefix, parts.label): + for text in ( + parts.rail, + f"{parts.emoji} " if parts.emoji else "", + parts.prefix, + parts.label, + ): if text: block.append(text, style=BRAND_ACCENT_STRONG) if parts.body: @@ -462,6 +465,7 @@ def render_tool_trace_text(line: str) -> Text: block.append(parts.duration, style=BRAND_MUTED) return block + def _is_state_focus_trace_line(line: str) -> bool: normalized = strip_markdown_bold(line) return ( @@ -470,10 +474,10 @@ def _is_state_focus_trace_line(line: str) -> bool: or normalized.startswith("┊ 🧭 routing") ) + def _is_compaction_context_line(line: str) -> bool: - return line.startswith("┊ 🧩 context") and ( - "projection" in line or "compressing" in line - ) + return line.startswith("┊ 🧩 context") and ("projection" in line or "compressing" in line) + def _tool_trace_display_parts(line: str) -> _ToolTraceDisplayParts: body = strip_markdown_bold(line).rstrip("\n") @@ -483,7 +487,16 @@ def _tool_trace_display_parts(line: str) -> _ToolTraceDisplayParts: body = body[2:] emoji_match = re.match(r"(?P\S+)(?P\s+)(?P.*)$", body) if emoji_match is None: - return _ToolTraceDisplayParts(rail="", emoji="", prefix="", label=body, gap="", body="", duration_gap="", duration="") + return _ToolTraceDisplayParts( + rail="", + emoji="", + prefix="", + label=body, + gap="", + body="", + duration_gap="", + duration="", + ) emoji = emoji_match.group("emoji") remainder = emoji_match.group("remainder").lstrip() @@ -520,6 +533,7 @@ def _tool_trace_display_parts(line: str) -> _ToolTraceDisplayParts: duration=duration, ) + def _tool_trace_state(line: str) -> str: normalized = strip_markdown_bold(line).rstrip("\n") parts = _tool_trace_display_parts(normalized) @@ -533,6 +547,7 @@ def _tool_trace_state(line: str) -> str: return "active" return "done" + def _tool_trace_emoji(tool_id: str, arguments=None) -> str: if tool_id.startswith("mcp."): return "🧩" @@ -568,11 +583,13 @@ def _tool_trace_emoji(tool_id: str, arguments=None) -> str: return "🌐" return "🧩" + def _tool_trace_emoji_marker(emoji: str) -> str: if emoji in {"✍️", "🖥️", "🛠️"}: return f"{emoji} " return f"{emoji} " + def _tool_trace_label(tool_event: ToolLifecycleEvent) -> str: tool_id = tool_event.invocation.tool_id aliases = { @@ -603,6 +620,7 @@ def _tool_trace_label(tool_event: ToolLifecycleEvent) -> str: return tool_id.removeprefix("tool.browser.") return aliases.get(tool_id, tool_id.removeprefix("tool.")) + def _personal_model_trace_preview(arguments, *, tool_id: str | None = None) -> str: action = str(arguments.get("action") or "").strip().lower() topic = str(arguments.get("topic") or "").strip() @@ -620,14 +638,23 @@ def _personal_model_trace_preview(arguments, *, tool_id: str | None = None) -> s "tool.personal_model.questions": "ask", } fallback = fallback_by_tool.get(tool_id, "model") - return compact_line(" ".join(item for item in (action, target) if item).strip() or fallback, limit=64) + return compact_line( + " ".join(item for item in (action, target) if item).strip() or fallback, + limit=64, + ) + def _tool_trace_preview(arguments, *, tool_id: str | None = None) -> str: if tool_id == "tool.sub_agents": preview = _sub_agents_trace_preview(arguments) if preview: return preview - if tool_id in {"tool.personal_model.search", "tool.conversation.search", "tool.personal_model.update", "tool.personal_model.questions"}: + if tool_id in { + "tool.personal_model.search", + "tool.conversation.search", + "tool.personal_model.update", + "tool.personal_model.questions", + }: preview = _personal_model_trace_preview(arguments, tool_id=tool_id) if preview: return preview @@ -672,7 +699,11 @@ def _tool_trace_preview(arguments, *, tool_id: str | None = None) -> str: continue text = str(value).strip() if text: - if key in {"path", "file_path", "filePath"} and isinstance(tool_id, str) and tool_id.startswith("tool.file."): + if ( + key in {"path", "file_path", "filePath"} + and isinstance(tool_id, str) + and tool_id.startswith("tool.file.") + ): import os as _os try: @@ -688,6 +719,7 @@ def _tool_trace_preview(arguments, *, tool_id: str | None = None) -> str: return "all" return "" + def _sub_agents_trace_preview(arguments) -> str: action = _sub_agents_action_label(arguments) tasks = arguments.get("tasks") @@ -709,6 +741,7 @@ def _sub_agents_trace_preview(arguments) -> str: return compact_line(f"{action} · {name}", limit=56) return action + def _sub_agents_trace_progress_lines(arguments) -> tuple[str, ...]: tasks = arguments.get("tasks") if not isinstance(tasks, list) or not tasks: @@ -722,6 +755,7 @@ def _sub_agents_trace_progress_lines(arguments) -> tuple[str, ...]: lines.append(f"┊ … {len(tasks) - len(previews)} more") return tuple(lines) + def _sub_agents_action_label(arguments) -> str: action = str(arguments.get("action") or "run").strip().lower() aliases = { @@ -730,6 +764,7 @@ def _sub_agents_action_label(arguments) -> str: } return aliases.get(action, action or "run") + def _sub_agent_task_previews(tasks: list, *, limit: int) -> tuple[str, ...]: previews: list[str] = [] for item in tasks[:limit]: @@ -745,12 +780,15 @@ def _sub_agent_task_previews(tasks: list, *, limit: int) -> tuple[str, ...]: previews.append(name) return tuple(previews) + def _tool_trace_prepare_label(tool_event: ToolLifecycleEvent) -> str: return _tool_trace_label(tool_event) + def _tool_trace_started_label(tool_event: ToolLifecycleEvent) -> str: return _tool_trace_label(tool_event) + def _tool_trace_duration(tool_event: ToolLifecycleEvent) -> str: requested_at = tool_event.invocation.requested_at if requested_at is None: @@ -758,6 +796,7 @@ def _tool_trace_duration(tool_event: ToolLifecycleEvent) -> str: delta = max(0.0, (tool_event.occurred_at - requested_at).total_seconds()) return f"{delta:.1f}s" + def _stream_preview(stream_text: str, *, limit: int = 220) -> str: normalized = " ".join(_stream_response_text(stream_text).split()) if not normalized: @@ -766,8 +805,10 @@ def _stream_preview(stream_text: str, *, limit: int = 220) -> str: return normalized return f"{normalized[: limit - 3]}..." + STREAM_REASONING_HEADING = "🐾 Elephant Agent's Trail:" + def _stream_display_parts(stream_text: str, *, streaming: bool = True) -> tuple[str, str]: sanitized = _sanitize_stream_tool_markup(stream_text) parsed = split_reasoning_and_content(sanitized, streaming=streaming) @@ -775,6 +816,7 @@ def _stream_display_parts(stream_text: str, *, streaming: bool = True) -> tuple[ response = strip_markdown_bold(parsed.content.replace("\r\n", "\n").replace("\r", "\n")).lstrip("\n") return reasoning, response + def format_reasoning_display_text(reasoning: str, response: str = "") -> str: normalized_reasoning = str(reasoning or "").strip() normalized_response = str(response or "").strip() @@ -784,6 +826,7 @@ def format_reasoning_display_text(reasoning: str, response: str = "") -> str: return f"{STREAM_REASONING_HEADING}\n{normalized_reasoning}" return normalized_response + def _compose_stream_markup(reasoning: str, response: str) -> str: normalized_reasoning = str(reasoning or "") normalized_response = str(response or "") @@ -793,6 +836,7 @@ def _compose_stream_markup(reasoning: str, response: str) -> str: return f"{normalized_reasoning}" return normalized_response + def _stream_response_fragments(stream_text: str) -> list[tuple[str, str]]: reasoning, response = _stream_display_parts(stream_text, streaming=True) fragments: list[tuple[str, str]] = [] @@ -806,6 +850,7 @@ def _stream_response_fragments(stream_text: str) -> list[tuple[str, str]]: fragments.extend(_format_stream_response_markdown(response)) return fragments + def _format_stream_response_markdown(response: str) -> list[tuple[str, str]]: fragments: list[tuple[str, str]] = [] lines = response.split("\n") @@ -839,17 +884,37 @@ def _format_stream_response_markdown(response: str) -> list[tuple[str, str]]: continue list_match = _list_pat.match(line) if list_match: - fragments.append(("class:stream-response-accent", f"{list_match.group(1)}{list_match.group(2)} ")) - _append_inline_stream_fragments(fragments, list_match.group(3), _bold_italic_pat, _bold_pat, _italic_pat, _code_pat) + fragments.append( + ( + "class:stream-response-accent", + f"{list_match.group(1)}{list_match.group(2)} ", + ) + ) + _append_inline_stream_fragments( + fragments, + list_match.group(3), + _bold_italic_pat, + _bold_pat, + _italic_pat, + _code_pat, + ) continue if line.startswith(">"): fragments.append(("class:stream-response-muted", "│ ")) - _append_inline_stream_fragments(fragments, line.lstrip("> "), _bold_italic_pat, _bold_pat, _italic_pat, _code_pat) + _append_inline_stream_fragments( + fragments, + line.lstrip("> "), + _bold_italic_pat, + _bold_pat, + _italic_pat, + _code_pat, + ) continue _append_inline_stream_fragments(fragments, line, _bold_italic_pat, _bold_pat, _italic_pat, _code_pat) return fragments + def _append_inline_stream_fragments( fragments: list[tuple[str, str]], text: str, @@ -881,10 +946,12 @@ def _append_inline_stream_fragments( if pos < len(text): fragments.append(("class:stream-response-body", text[pos:])) + def _stream_has_reasoning_only(stream_text: str) -> bool: reasoning, response = _stream_display_parts(stream_text, streaming=True) return bool(reasoning and not response) + def _stream_response_rich_text(stream_text: str) -> Text: text = Text() style_map = { @@ -905,6 +972,7 @@ def _stream_response_rich_text(stream_text: str) -> Text: text.append(fragment, style=style_map.get(style, BRAND_LIGHT)) return text + def _stream_response_text(stream_text: str, *, limit: int = 3200) -> str: reasoning, response = _stream_display_parts(stream_text, streaming=True) normalized = format_reasoning_display_text(reasoning, response) @@ -920,6 +988,7 @@ def _stream_response_text(stream_text: str, *, limit: int = 3200) -> str: return f"...\n{trimmed}" return f"... {tail.lstrip()}" + def select_stream_response_text( stream_text: str, *, @@ -935,6 +1004,7 @@ def select_stream_response_text( return current return selected + def stream_response_delta(stream_text: str, *, previous_stream_text: str = "") -> str: current_reasoning, current_response = _stream_display_parts(stream_text, streaming=True) if not current_reasoning and not current_response: @@ -952,6 +1022,7 @@ def stream_response_delta(stream_text: str, *, previous_stream_text: str = "") - delta_response = current_response return _compose_stream_markup(delta_reasoning.lstrip("\n"), delta_response.lstrip("\n")) + def _sanitize_stream_tool_markup(raw: str) -> str: cleaned = raw for pattern in _STREAM_TOOL_BLOCK_PATTERNS: @@ -969,6 +1040,7 @@ def _sanitize_stream_tool_markup(raw: str) -> str: cleaned = re.sub(r"\n{3,}", "\n\n", cleaned) return cleaned + def _partial_tool_tag_start(text: str) -> int | None: marker = text.rfind("<") if marker < 0: diff --git a/apps/cli/shell_render.py b/apps/cli/shell_render.py index 59bf91c..08be14b 100644 --- a/apps/cli/shell_render.py +++ b/apps/cli/shell_render.py @@ -149,10 +149,17 @@ def render_brand_column(shell: ProductizedShell, session, continuity, provider, heading.append("🐘 Personal Model first. Curious by design.", style=BRAND_MUTED) meta = Text(no_wrap=True) meta.append(f"{display_name}\n", style=f"bold {BRAND_LIGHT}") - meta.append("understands first · asks gently · picks up the right thread\n", style=BRAND_MUTED) + meta.append( + "understands first · asks gently · picks up the right thread\n", + style=BRAND_MUTED, + ) meta.append(growth.identity_line, style=BRAND_ACCENT_STRONG) if Table is None: - return Group(heading, shell._render_growth_mark(growth.brand_stage_id, level=growth.level), meta) + return Group( + heading, + shell._render_growth_mark(growth.brand_stage_id, level=growth.level), + meta, + ) brand = Table.grid(expand=True) brand.add_column(no_wrap=True) brand.add_row(shell._center_brand_block(heading)) @@ -276,7 +283,10 @@ def _format_entry_reasoning_display(reasoning: str, response: str) -> str: ("invalid key:", "Try F1 or `?` for the cheatsheet of valid bindings."), ("no module named", "Check your virtualenv — something isn't importable."), ("permission denied", "This path might be outside the current project root."), - ("connection refused", "The provider endpoint didn't answer — check /providers status."), + ( + "connection refused", + "The provider endpoint didn't answer — check /providers status.", + ), ("token limit", "The conversation got long. Try /clear and start a fresh thread."), ("unauthorized", "Your provider key looks off — run /providers to update."), ("rate limit", "Provider is rate-limiting us. Wait a moment and try again."), @@ -309,7 +319,7 @@ def _fold_long_body(shell: ProductizedShell, entry: TranscriptEntry) -> tuple[st # Keep the head (most-recent information first in most dumps is rare, # so showing the head is the right default). Tail an ellipsis so users # know it's been trimmed. - head_lines = body.split("\n")[: _MAX_TRANSCRIPT_BODY_LINES] + head_lines = body.split("\n")[:_MAX_TRANSCRIPT_BODY_LINES] hidden = line_count - len(head_lines) fold_marker = f"… {hidden} more line(s) hidden · type /expand last to see the whole thing" head = "\n".join(head_lines) @@ -388,16 +398,8 @@ def growth_panel_lines(shell: ProductizedShell, session, continuity, provider, g ] lines.extend( [ - ( - "history · " - f"{growth.canonical_dialogues} dialogues · " - f"{growth.canonical_active_days} active day(s)" - ), - ( - "saved work · " - f"{growth.canonical_experiences} experience(s) · " - f"{growth.state.total_tokens} tokens seen" - ), + (f"history · {growth.canonical_dialogues} dialogues · {growth.canonical_active_days} active day(s)"), + (f"saved work · {growth.canonical_experiences} experience(s) · {growth.state.total_tokens} tokens seen"), ] ) experiences = shell.runtime.inspect_experiences(session_id=session.episode_id, limit=2) @@ -415,11 +417,15 @@ def recent_activity_lines(shell: ProductizedShell, session, continuity, provider return growth_panel_lines(shell, session, continuity, provider, growth) -def recent_experience_lines(experiences: tuple[ExperienceRecord, ...]) -> tuple[str, ...]: +def recent_experience_lines( + experiences: tuple[ExperienceRecord, ...], +) -> tuple[str, ...]: return tuple(f"evidence · {format_experience_status(experience)}" for experience in experiences) -def displayable_experiences(experiences: tuple[ExperienceRecord, ...]) -> tuple[ExperienceRecord, ...]: +def displayable_experiences( + experiences: tuple[ExperienceRecord, ...], +) -> tuple[ExperienceRecord, ...]: filtered = tuple(experience for experience in experiences if should_display_experience(experience)) return filtered[:2] @@ -601,15 +607,39 @@ def _render_assistant_response(response: str) -> Text | list[object]: list_match = _re.match(r"^(\s*)([-*+]|\d+\.)\s+(.*)$", line) if list_match: current_block.append(f"{list_match.group(1)}{list_match.group(2)} ", style=BRAND_ACCENT) - _append_inline_formatted(current_block, list_match.group(3), bold_italic_pat, bold_pat, italic_pat, code_pat, link_pat) + _append_inline_formatted( + current_block, + list_match.group(3), + bold_italic_pat, + bold_pat, + italic_pat, + code_pat, + link_pat, + ) line_index += 1 continue if line.startswith(">"): current_block.append("│ ", style=BRAND_MUTED) - _append_inline_formatted(current_block, line.lstrip("> "), bold_italic_pat, bold_pat, italic_pat, code_pat, link_pat) + _append_inline_formatted( + current_block, + line.lstrip("> "), + bold_italic_pat, + bold_pat, + italic_pat, + code_pat, + link_pat, + ) line_index += 1 continue - _append_inline_formatted(current_block, line, bold_italic_pat, bold_pat, italic_pat, code_pat, link_pat) + _append_inline_formatted( + current_block, + line, + bold_italic_pat, + bold_pat, + italic_pat, + code_pat, + link_pat, + ) line_index += 1 # If no table detected, keep original return type diff --git a/apps/cli/shell_stack.py b/apps/cli/shell_stack.py index 5f72c5f..80bc875 100644 --- a/apps/cli/shell_stack.py +++ b/apps/cli/shell_stack.py @@ -14,7 +14,12 @@ from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit.keys import Keys from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl - from prompt_toolkit.layout.containers import ConditionalContainer, HSplit, VSplit, Window + from prompt_toolkit.layout.containers import ( + ConditionalContainer, + HSplit, + VSplit, + Window, + ) from prompt_toolkit.layout.dimension import Dimension from prompt_toolkit.layout.scrollable_pane import ScrollablePane from prompt_toolkit.layout.layout import Layout @@ -106,6 +111,7 @@ def create_output(*_args, **_kwargs): VSplit = None Window = None + def prompt_toolkit_output_without_cpr(): if not PROMPT_TOOLKIT_AVAILABLE: return None diff --git a/apps/cli/shell_support_runtime.py b/apps/cli/shell_support_runtime.py index 4634797..5a5c475 100644 --- a/apps/cli/shell_support_runtime.py +++ b/apps/cli/shell_support_runtime.py @@ -2,135 +2,18 @@ from __future__ import annotations -from collections import deque from collections.abc import Mapping from dataclasses import dataclass, field -from difflib import unified_diff -import os from pathlib import Path import re -import shlex -import time -from packages.contracts import ExperienceRecord -from packages.kernel.runtime import KernelOutcome -from packages.operator.runtime import ( - RecallEvidenceOperatorDetail, - RecallEvidenceSearchHit, - build_recall_evidence_operator_surface, - build_profile_operator_surface, - render_recall_evidence_lines, - render_profile_lines, -) -from packages.tools.handler_support import resolve_allowed_path -from .provider_flow import provider_setup_defaults, run_provider_selection_wizard -from .runtime import CliRuntime -from .wizard import WIZARD_BACK -from .shell_composer import ( - build_command_palette as _build_shell_command_palette, - build_composer_body as _build_shell_composer_body, - build_divider_window as _build_shell_divider_window, - build_input_window as _build_shell_input_window, - build_key_bindings as _build_shell_key_bindings, - build_prompt_buffer as _build_shell_prompt_buffer, - build_queue_preview_window as _build_shell_queue_preview_window, - prompt_continuation as _shell_prompt_continuation, - prompt_label as _shell_prompt_label, - prompt_style as _shell_prompt_style, - prompt_style_map as _shell_prompt_style_map, - prompt_toolkit_composer_available as _shell_prompt_toolkit_composer_available, - read_command as _read_shell_command, - shell_history as _shell_history, -) -from .shell_boot import WAKE_DISPLAY_SECONDS, BootFrameContext, render_boot_frame -from .shell_opening import ( - ShellOpeningContext, - compose_shell_opening_instruction, - compose_shell_opener, -) -from .shell_progress import ( - animations_enabled as _shell_animations_enabled, - render_queued_followup_fragments as _render_shell_queued_followup_fragments, - render_tool_frame as _render_shell_tool_frame, - tool_trace_line as _shell_tool_trace_line, - render_turn_frame as _render_shell_turn_frame, - render_turn_progress_fragments as _render_shell_turn_progress_fragments, - run_tool_with_progress as _run_shell_tool_with_progress, - run_turn_with_progress as _run_shell_turn_with_progress, - run_turn_with_queued_input as _run_shell_turn_with_queued_input, - summarize_progress_prompt as _summarize_shell_progress_prompt, - tool_event_lines as _shell_tool_event_lines, - tool_event_summary as _shell_tool_event_summary, - tool_event_tracker as _shell_tool_event_tracker, - tool_frame_phases as _shell_tool_frame_phases, - turn_phase as _shell_turn_phase, - _tool_trace_emoji as _shell_tool_trace_emoji, -) -from .shell_render import ( - center_brand_block as _center_shell_brand_block, - displayable_experiences as _displayable_shell_experiences, - format_experience_status as _format_shell_experience_status, - growth_panel_lines as _shell_growth_panel_lines, - growth_progress_bar as _shell_growth_progress_bar, - growth_progress_counts as _shell_growth_progress_counts, - recent_activity_lines as _shell_recent_activity_lines, - recent_experience_lines as _shell_recent_experience_lines, - render_brand_column as _render_shell_brand_column, - render_chat_entry as _render_shell_chat_entry, - render_entry as _render_shell_entry, - render_elephant_brand_mark as _render_shell_elephant_mark, - render_growth_mark_for_stage as _render_shell_growth_mark, - render_pending_entries as _render_shell_pending_entries, - render_shell_frame as _render_shell_frame_view, - render_status_column as _render_shell_status_column, - should_display_experience as _should_display_shell_experience, - styled_growth_progress_bar as _styled_shell_growth_progress_bar, -) from .shell_stack import ( - Align, Completion, Completer, - Console, Document, - FormattedText, - Group, - Live, - PROMPT_TOOLKIT_AVAILABLE, - Panel, - RICH_AVAILABLE, - Table, - Text, -) -from .shell_ui import ( - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_DARK, - BRAND_LIGHT, - BRAND_MUTED, - COMMAND_PALETTE_VISIBLE_ROWS, - ELEPHANT_STAGE_ROWS, - GROWTH_HIGHLIGHT_FG, - GROWTH_PROGRESS_EMPTY, - GROWTH_PROGRESS_FILLED, - GROWTH_PROGRESS_WIDTH, - HATCHLING_HEAD_ROWS, - HATCHLING_STAGE_ROWS, - HATCHLING_STAGE_ROWS, - QUEUE_PREVIEW_INSET, - SCOUT_STAGE_ROWS, - SEED_STAGE_ROWS, - SHELL_WELCOME_HEADLINE, - USER_HISTORY_BG, - USER_HISTORY_FG, - WEB_URL_PATTERN, - compact_line as _compact_line, - centered_elephant_rows as _centered_elephant_rows, - display_path as _display_path, - display_width as _display_width, - render_elephant_mark, - resolve_elephant_version as _resolve_elephant_version, ) + @dataclass(frozen=True, slots=True) class TranscriptEntry: kind: str @@ -138,11 +21,13 @@ class TranscriptEntry: body: str meta: str = "" + @dataclass(frozen=True, slots=True) class _PendingFileReview: path: Path before_text: str | None + @dataclass(frozen=True, slots=True) class PendingShellCommand: command: str @@ -166,11 +51,13 @@ def coerce_pending_shell_command(value: object) -> PendingShellCommand: event_payload=event_payload, ) + @dataclass(frozen=True, slots=True) class ShellCommandSpec: name: str description: str + @dataclass(frozen=True, slots=True) class SkillSlashSpec: command: str @@ -181,6 +68,7 @@ class SkillSlashSpec: trigger_phrases: tuple[str, ...] = () keywords: tuple[str, ...] = () + def _skill_metadata_values(value: object) -> tuple[str, ...]: if value is None: return () @@ -204,11 +92,13 @@ def _skill_metadata_values(value: object) -> tuple[str, ...]: normalized.append(token) return tuple(normalized) + def _normalize_skill_match_text(value: str) -> str: normalized = value.strip().lower().replace("/", " ").replace("_", " ").replace("-", " ") normalized = re.sub(r"[^\w\s\u4e00-\u9fff]+", " ", normalized) return " ".join(normalized.split()) + def _skill_phrase_in_message(message: str, phrase: str) -> bool: normalized_message = _normalize_skill_match_text(message) normalized_phrase = _normalize_skill_match_text(phrase) @@ -218,12 +108,14 @@ def _skill_phrase_in_message(message: str, phrase: str) -> bool: return normalized_phrase in normalized_message return f" {normalized_phrase} " in f" {normalized_message} " + def _completion(text: str, *, start_position: int, display: str, meta: str = "") -> Completion: try: return Completion(text, start_position=start_position, display=display, display_meta=meta) except TypeError: # pragma: no cover - fallback signature return Completion(text, start_position=start_position, display=display) + class ShellCompleter(Completer): def __init__(self, shell: "ProductizedShell") -> None: self.shell = shell @@ -284,7 +176,10 @@ def get_completions(self, document: Document, complete_event): ) elif command == "/providers": candidates = ( - ("configure", "Choose a provider, endpoint, key, model, and context window"), + ( + "configure", + "Choose a provider, endpoint, key, model, and context window", + ), ("status", "Show the active provider configuration"), ("list", "List supported provider catalogs"), ) @@ -314,6 +209,7 @@ def get_completions(self, document: Document, complete_event): meta=description, ) + __all__ = [ "TranscriptEntry", "_PendingFileReview", diff --git a/apps/cli/shell_ui.py b/apps/cli/shell_ui.py index 185befe..95641f6 100644 --- a/apps/cli/shell_ui.py +++ b/apps/cli/shell_ui.py @@ -61,7 +61,13 @@ 24, *( len(row) - for rows in (ELEPHANT_STAGE_ROWS, SEED_STAGE_ROWS, HATCHLING_STAGE_ROWS, SCOUT_STAGE_ROWS, ELEPHANT_STAGE_ROWS) + for rows in ( + ELEPHANT_STAGE_ROWS, + SEED_STAGE_ROWS, + HATCHLING_STAGE_ROWS, + SCOUT_STAGE_ROWS, + ELEPHANT_STAGE_ROWS, + ) for row in rows ), ) @@ -162,7 +168,9 @@ def render_elephant_mark(): def render_growth_mark(stage_id: str, *, level: int | None = None): rows = _growth_rows(stage_id, level=level) - fallback = "[Elephant Agent elephant]" if stage_id == "seed" and (level or 0) <= 0 else f"[Elephant Agent {stage_id}]" + fallback = ( + "[Elephant Agent elephant]" if stage_id == "seed" and (level or 0) <= 0 else f"[Elephant Agent {stage_id}]" + ) centered = rows if _uses_literal_cells(rows) else visual_centered_rows(rows, width=GROWTH_MARK_CANVAS_WIDTH) return _render_pixel_mark(centered, fallback=fallback) @@ -187,12 +195,7 @@ def centered_rows(rows: tuple[str, ...], *, width: int | None = None) -> tuple[s def visual_centered_rows(rows: tuple[str, ...], *, width: int | None = None) -> tuple[str, ...]: """Center the visible pixels, not the transparent source-canvas whitespace.""" - visible_cells = [ - index - for row in rows - for index, cell in enumerate(row) - if cell != " " - ] + visible_cells = [index for row in rows for index, cell in enumerate(row) if cell != " "] if not visible_cells: return centered_rows(rows, width=width) visible_left = min(visible_cells) diff --git a/apps/cli/skills_command.py b/apps/cli/skills_command.py index 98f0567..e21d607 100644 --- a/apps/cli/skills_command.py +++ b/apps/cli/skills_command.py @@ -256,13 +256,19 @@ def main_callback( raise typer.Exit(0) @app.command("list") - def list_command(ctx: typer.Context, limit: int = typer.Option(24, "--limit", help="Maximum visible entries to show.")) -> None: + def list_command( + ctx: typer.Context, + limit: int = typer.Option(24, "--limit", help="Maximum visible entries to show."), + ) -> None: runtime = _runtime(state_dir=ctx.obj["state_dir"]) _print_skill_list(runtime, limit=limit) raise typer.Exit(0) @app.command("active") - def active_command(ctx: typer.Context, limit: int = typer.Option(24, "--limit", help="Maximum enabled skills to show.")) -> None: + def active_command( + ctx: typer.Context, + limit: int = typer.Option(24, "--limit", help="Maximum enabled skills to show."), + ) -> None: runtime = _runtime(state_dir=ctx.obj["state_dir"]) _print_active_skills(runtime, limit=limit) raise typer.Exit(0) @@ -279,25 +285,37 @@ def search_command( raise typer.Exit(0) @app.command("view") - def view_command(ctx: typer.Context, reference: str = typer.Argument(..., help="Skill id or source reference to inspect.")) -> None: + def view_command( + ctx: typer.Context, + reference: str = typer.Argument(..., help="Skill id or source reference to inspect."), + ) -> None: runtime = _runtime(state_dir=ctx.obj["state_dir"]) _print_skill_detail(runtime, reference) raise typer.Exit(0) @app.command("enable") - def enable_command(ctx: typer.Context, skill_id: str = typer.Argument(..., help="Installed skill id to enable.")) -> None: + def enable_command( + ctx: typer.Context, + skill_id: str = typer.Argument(..., help="Installed skill id to enable."), + ) -> None: runtime = _runtime(state_dir=ctx.obj["state_dir"]) _print_skill_toggle(runtime, skill_id=skill_id, enabled=True) raise typer.Exit(0) @app.command("disable") - def disable_command(ctx: typer.Context, skill_id: str = typer.Argument(..., help="Installed skill id to disable.")) -> None: + def disable_command( + ctx: typer.Context, + skill_id: str = typer.Argument(..., help="Installed skill id to disable."), + ) -> None: runtime = _runtime(state_dir=ctx.obj["state_dir"]) _print_skill_toggle(runtime, skill_id=skill_id, enabled=False) raise typer.Exit(0) @app.command("install") - def install_command(ctx: typer.Context, reference: str = typer.Argument(..., help="Hub id, public reference, local path, or manifest path.")) -> None: + def install_command( + ctx: typer.Context, + reference: str = typer.Argument(..., help="Hub id, public reference, local path, or manifest path."), + ) -> None: runtime = _runtime(state_dir=ctx.obj["state_dir"]) _print_skill_install(runtime, reference) raise typer.Exit(0) diff --git a/apps/cli/turn_metrics.py b/apps/cli/turn_metrics.py index c7e4696..7b0a8a1 100644 --- a/apps/cli/turn_metrics.py +++ b/apps/cli/turn_metrics.py @@ -130,10 +130,10 @@ def _format_compaction_notice(frame: dict) -> str: parts: list[str] = [] for segment in detail.split(): if segment.startswith("messages="): - value = segment[len("messages="):] + value = segment[len("messages=") :] parts.append(f"messages {value.replace('->', ' → ')}") elif segment.startswith("compressing="): - parts.append(f"compressed {segment[len('compressing='):]}") + parts.append(f"compressed {segment[len('compressing=') :]}") return " · ".join(parts) if parts else "compacted" return "" @@ -148,7 +148,9 @@ def _append_outcome(self, outcome: KernelOutcome) -> None: for stage in outcome.stages ] self._append_entry("status", "Runtime stages", "\n".join(stage_lines)) - assistant_name = self.runtime.inspect_profile(self.runtime.inspect_session(self.session_id).personal_model_id).state.display_name + assistant_name = self.runtime.inspect_profile( + self.runtime.inspect_session(self.session_id).personal_model_id + ).state.display_name assistant_body = _compose_reasoning_display( getattr(outcome.execution, "reasoning", ""), outcome.execution.summary, diff --git a/apps/cli/wizard.py b/apps/cli/wizard.py index ae2b37d..6617910 100644 --- a/apps/cli/wizard.py +++ b/apps/cli/wizard.py @@ -21,7 +21,14 @@ from prompt_toolkit.layout.dimension import Dimension as PromptDimension from prompt_toolkit.shortcuts import input_dialog from prompt_toolkit.styles import Style as PromptStyle - from prompt_toolkit.widgets import Button, CheckboxList, Dialog, Label, RadioList, TextArea + from prompt_toolkit.widgets import ( + Button, + CheckboxList, + Dialog, + Label, + RadioList, + TextArea, + ) PROMPT_TOOLKIT_DIALOGS_AVAILABLE = True except ModuleNotFoundError: # pragma: no cover - optional wizard polish @@ -182,7 +189,10 @@ def _wizard_choice_menu( and RadioList is not None ): return default - default_value = next((choice.value for choice in choices if choice.value == default), choices[0].value) + default_value = next( + (choice.value for choice in choices if choice.value == default), + choices[0].value, + ) values = tuple( ( choice.value, @@ -200,7 +210,11 @@ def _wizard_choice_menu( show_scrollbar=len(choices) > WIZARD_MAX_VISIBLE_CHOICES, ) _guard_radio_list_selection_bounds(radio_list) - hint = "Enter continues · Back goes back · Esc cancels · ↑/↓ or j/k moves" if allow_back else "Enter continues · Esc cancels · ↑/↓ or j/k moves" + hint = ( + "Enter continues · Back goes back · Esc cancels · ↑/↓ or j/k moves" + if allow_back + else "Enter continues · Esc cancels · ↑/↓ or j/k moves" + ) def _accept() -> None: get_app().exit(result=radio_list.current_value) @@ -291,7 +305,11 @@ def _wizard_dual_choice_menu( fallback = choices[0].value selected = { "first": default_first if default_first in values_by_id else fallback, - "second": default_second if default_second in values_by_id else default_first if default_first in values_by_id else fallback, + "second": default_second + if default_second in values_by_id + else default_first + if default_first in values_by_id + else fallback, } active_role = {"name": "first"} selection_history: list[str] = [role for role in ("first", "second") if selected[role]] @@ -325,9 +343,15 @@ def _role_fragments(): second_suffix = " · next" if active_role["name"] == "second" else "" return [ ("class:role-primary", f"* {first_title}"), - ("class:role-primary-detail", f" · {(first_choice.label if first_choice is not None else '')}{first_suffix}\n"), + ( + "class:role-primary-detail", + f" · {(first_choice.label if first_choice is not None else '')}{first_suffix}\n", + ), ("class:role-secondary", f"* {second_title}"), - ("class:role-secondary-detail", f" · {(second_choice.label if second_choice is not None else '')}{second_suffix}"), + ( + "class:role-secondary-detail", + f" · {(second_choice.label if second_choice is not None else '')}{second_suffix}", + ), ] def _validation_fragments(): @@ -476,10 +500,20 @@ def _cancel() -> None: dialog = Dialog( title=title, body=HSplit( - [Label(text=prompt, dont_extend_height=True), checkbox, Label(text="Space toggles · Enter continues · Back goes back · Esc cancels", dont_extend_height=True)], + [ + Label(text=prompt, dont_extend_height=True), + checkbox, + Label( + text="Space toggles · Enter continues · Back goes back · Esc cancels", + dont_extend_height=True, + ), + ], padding=1, ), - buttons=[Button(text="Continue", handler=_accept), Button(text="Back", handler=_back)], + buttons=[ + Button(text="Continue", handler=_accept), + Button(text="Back", handler=_back), + ], with_background=True, ) bindings = PromptKeyBindings() @@ -492,13 +526,15 @@ def _accept_binding(_event) -> None: def _cancel_binding(_event) -> None: _cancel() - answer = _wizard_run_dialog(Application( - layout=Layout(dialog, focused_element=checkbox), - key_bindings=bindings, - style=_wizard_style(), - full_screen=True, - mouse_support=True, - )) + answer = _wizard_run_dialog( + Application( + layout=Layout(dialog, focused_element=checkbox), + key_bindings=bindings, + style=_wizard_style(), + full_screen=True, + mouse_support=True, + ) + ) if answer is WIZARD_BACK or answer is WIZARD_CANCEL: return answer return tuple(str(value) for value in (answer or ())) @@ -547,7 +583,9 @@ def _wizard_multi_choice_prompt( return tuple(selected) -def _wizard_choice_window(total: int, selected: int, *, max_visible: int = WIZARD_MAX_VISIBLE_CHOICES) -> tuple[int, int]: +def _wizard_choice_window( + total: int, selected: int, *, max_visible: int = WIZARD_MAX_VISIBLE_CHOICES +) -> tuple[int, int]: if total <= max_visible: return 0, total if selected < 0: @@ -584,11 +622,7 @@ def _wizard_choice_fragments( active = index == selected marker = "›" if active else " " label_style = "class:selected" if active else "class:item" - detail_style = ( - f"class:{choice.selected_detail_style}" - if active - else f"class:{choice.detail_style}" - ) + detail_style = f"class:{choice.selected_detail_style}" if active else f"class:{choice.detail_style}" fragments.append((label_style, f"{marker} {_wizard_choice_label(choice)}\n")) fragments.append((detail_style, f" {choice.detail}\n")) if end < len(choices): @@ -699,7 +733,11 @@ def _wizard_required_text_dialog( prompt="", ) validation_state = {"message": ""} - hint = "Enter continues · Esc goes back · Tab moves focus" if allow_back else "Enter continues · Esc cancels · Tab moves focus" + hint = ( + "Enter continues · Esc goes back · Tab moves focus" + if allow_back + else "Enter continues · Esc cancels · Tab moves focus" + ) def _set_validation() -> None: validation_state["message"] = required_message @@ -723,7 +761,11 @@ def _cancel() -> None: [ Label(text=prompt, dont_extend_height=True), text_field, - Label(text=lambda: validation_state["message"], style="class:validation", dont_extend_height=True), + Label( + text=lambda: validation_state["message"], + style="class:validation", + dont_extend_height=True, + ), Label(text=hint, dont_extend_height=True), ], padding=1, @@ -791,7 +833,11 @@ def _wizard_password_dialog( wrap_lines=False, prompt="", ) - hint = "Enter continues · Esc goes back · Tab moves focus" if allow_back else "Enter continues · Esc cancels · Tab moves focus" + hint = ( + "Enter continues · Esc goes back · Tab moves focus" + if allow_back + else "Enter continues · Esc cancels · Tab moves focus" + ) def _accept() -> None: get_app().exit(result=password_field.text) diff --git a/apps/cron_scheduler_command.py b/apps/cron_scheduler_command.py index 48c735f..bb3089f 100644 --- a/apps/cron_scheduler_command.py +++ b/apps/cron_scheduler_command.py @@ -15,7 +15,6 @@ _gateway_runtime_environ, _run_logs, _run_status, - _run_stop, ) from apps.runtime_layout import ( default_cli_state_dir, @@ -47,7 +46,10 @@ def _build_parser(*, defaults: dict[str, Path]) -> ArgumentParser: common.add_argument("--cli-state-dir", type=Path, default=defaults["cli_state_dir"]) common.add_argument("--elephant-id", default="elephant:gateway") - parser = ArgumentParser(prog="elephant cron", description="Manage the Elephant Agent cron scheduler daemon.") + parser = ArgumentParser( + prog="elephant cron", + description="Manage the Elephant Agent cron scheduler daemon.", + ) subparsers = parser.add_subparsers(dest="command") start = subparsers.add_parser("start", parents=[common], help="Start the cron scheduler.") @@ -82,19 +84,43 @@ def _build_parser(*, defaults: dict[str, Path]) -> ArgumentParser: def _add_target_options(parser: ArgumentParser) -> None: parser.set_defaults(runtime_target="scheduler") - parser.add_argument("--target", dest="runtime_target", choices=("configured", "scheduler"), default="scheduler", help=SUPPRESS) + parser.add_argument( + "--target", + dest="runtime_target", + choices=("configured", "scheduler"), + default="scheduler", + help=SUPPRESS, + ) def _add_start_options(parser: ArgumentParser) -> None: _add_target_options(parser) - parser.add_argument("--detach", action="store_true", help="Start in a background process and return immediately.") - parser.add_argument("--interval-seconds", type=float, default=60.0, help="Seconds between scheduler ticks.") + parser.add_argument( + "--detach", + action="store_true", + help="Start in a background process and return immediately.", + ) + parser.add_argument( + "--interval-seconds", + type=float, + default=60.0, + help="Seconds between scheduler ticks.", + ) def _add_stop_options(parser: ArgumentParser) -> None: _add_target_options(parser) - parser.add_argument("--timeout", type=float, default=10.0, help="Seconds to wait before failing or forcing.") - parser.add_argument("--force", action="store_true", help="Send SIGKILL when the process does not exit.") + parser.add_argument( + "--timeout", + type=float, + default=10.0, + help="Seconds to wait before failing or forcing.", + ) + parser.add_argument( + "--force", + action="store_true", + help="Send SIGKILL when the process does not exit.", + ) def _add_logs_options(parser: ArgumentParser) -> None: @@ -198,9 +224,17 @@ def main_callback(ctx: typer.Context) -> None: def start_command( state_dir: Path | None = typer.Option(None, "--state-dir", hidden=True), cli_state_dir: Path | None = typer.Option(None, "--cli-state-dir", hidden=True), - elephant_id: str = typer.Option("elephant:gateway", "--elephant-id", help="Scoped runtime elephant id for scheduler operations."), + elephant_id: str = typer.Option( + "elephant:gateway", + "--elephant-id", + help="Scoped runtime elephant id for scheduler operations.", + ), target: str = typer.Option("scheduler", "--target", help="Runtime target to inspect or launch."), - detach: bool = typer.Option(False, "--detach", help="Start in a background process and return immediately."), + detach: bool = typer.Option( + False, + "--detach", + help="Start in a background process and return immediately.", + ), interval_seconds: float = typer.Option(60.0, "--interval-seconds", help="Seconds between scheduler ticks."), ) -> None: args = _common_args(state_dir, cli_state_dir, elephant_id) @@ -217,7 +251,11 @@ def start_command( def run_command( state_dir: Path | None = typer.Option(None, "--state-dir", hidden=True), cli_state_dir: Path | None = typer.Option(None, "--cli-state-dir", hidden=True), - elephant_id: str = typer.Option("elephant:gateway", "--elephant-id", help="Scoped runtime elephant id for scheduler operations."), + elephant_id: str = typer.Option( + "elephant:gateway", + "--elephant-id", + help="Scoped runtime elephant id for scheduler operations.", + ), target: str = typer.Option("scheduler", "--target", help="Runtime target to inspect or launch."), interval_seconds: float = typer.Option(60.0, "--interval-seconds", help="Seconds between scheduler ticks."), once: bool = typer.Option(False, "--once", help="Run one scheduler tick and exit."), @@ -234,7 +272,11 @@ def run_command( def status_command( state_dir: Path | None = typer.Option(None, "--state-dir", hidden=True), cli_state_dir: Path | None = typer.Option(None, "--cli-state-dir", hidden=True), - elephant_id: str = typer.Option("elephant:gateway", "--elephant-id", help="Scoped runtime elephant id for scheduler operations."), + elephant_id: str = typer.Option( + "elephant:gateway", + "--elephant-id", + help="Scoped runtime elephant id for scheduler operations.", + ), target: str = typer.Option("scheduler", "--target", help="Runtime target to inspect or launch."), ) -> None: args = _common_args(state_dir, cli_state_dir, elephant_id) @@ -246,7 +288,11 @@ def status_command( def stop_command( state_dir: Path | None = typer.Option(None, "--state-dir", hidden=True), cli_state_dir: Path | None = typer.Option(None, "--cli-state-dir", hidden=True), - elephant_id: str = typer.Option("elephant:gateway", "--elephant-id", help="Scoped runtime elephant id for scheduler operations."), + elephant_id: str = typer.Option( + "elephant:gateway", + "--elephant-id", + help="Scoped runtime elephant id for scheduler operations.", + ), target: str = typer.Option("scheduler", "--target", help="Runtime target to inspect or launch."), timeout: float = typer.Option(10.0, "--timeout", help="Seconds to wait before failing or forcing."), force: bool = typer.Option(False, "--force", help="Send SIGKILL when the process does not exit."), @@ -257,6 +303,7 @@ def stop_command( args.force = force if daemon_is_running(args.state_dir): from apps.daemon_command import stop_daemon + raise typer.Exit(stop_daemon(args.state_dir, timeout=timeout, force=force)) print("Elephant daemon is not running. Nothing to stop.") raise typer.Exit(0) @@ -265,9 +312,17 @@ def stop_command( def restart_command( state_dir: Path | None = typer.Option(None, "--state-dir", hidden=True), cli_state_dir: Path | None = typer.Option(None, "--cli-state-dir", hidden=True), - elephant_id: str = typer.Option("elephant:gateway", "--elephant-id", help="Scoped runtime elephant id for scheduler operations."), + elephant_id: str = typer.Option( + "elephant:gateway", + "--elephant-id", + help="Scoped runtime elephant id for scheduler operations.", + ), target: str = typer.Option("scheduler", "--target", help="Runtime target to inspect or launch."), - detach: bool = typer.Option(True, "--detach/--foreground", help="Restart in the background by default, or keep it in the foreground."), + detach: bool = typer.Option( + True, + "--detach/--foreground", + help="Restart in the background by default, or keep it in the foreground.", + ), interval_seconds: float = typer.Option(60.0, "--interval-seconds", help="Seconds between scheduler ticks."), timeout: float = typer.Option(10.0, "--timeout", hidden=True), force: bool = typer.Option(False, "--force", hidden=True), @@ -280,12 +335,15 @@ def restart_command( args.force = force if daemon_is_running(args.state_dir): from apps.daemon_command import restart_daemon - raise typer.Exit(restart_daemon( - args.state_dir, - args.cli_state_dir, - timeout=timeout, - force=force, - )) + + raise typer.Exit( + restart_daemon( + args.state_dir, + args.cli_state_dir, + timeout=timeout, + force=force, + ) + ) # No daemon running — start a fresh daemon raise typer.Exit(_cron_start_via_daemon(args)) @@ -293,7 +351,11 @@ def restart_command( def logs_command( state_dir: Path | None = typer.Option(None, "--state-dir", hidden=True), cli_state_dir: Path | None = typer.Option(None, "--cli-state-dir", hidden=True), - elephant_id: str = typer.Option("elephant:gateway", "--elephant-id", help="Scoped runtime elephant id for scheduler operations."), + elephant_id: str = typer.Option( + "elephant:gateway", + "--elephant-id", + help="Scoped runtime elephant id for scheduler operations.", + ), target: str = typer.Option("scheduler", "--target", help="Runtime target to inspect or launch."), tail: int = typer.Option(80, "--tail", help="Show the last N log lines."), follow: bool = typer.Option(False, "--follow", help="Keep streaming appended log output."), diff --git a/apps/daemon.py b/apps/daemon.py index 2c4f24d..4658b9e 100644 --- a/apps/daemon.py +++ b/apps/daemon.py @@ -11,7 +11,6 @@ from typing import Any from apps.logging_setup import setup_logging -from apps.runtime_layout import default_cli_state_dir, default_gateway_state_dir logger = logging.getLogger("elephant.daemon") @@ -94,9 +93,7 @@ async def start(self) -> None: self._mark_runtime_ready() # Start periodic health heartbeat - heartbeat_task = asyncio.create_task( - self._health_heartbeat(), name="health-heartbeat" - ) + heartbeat_task = asyncio.create_task(self._health_heartbeat(), name="health-heartbeat") self._tasks.append(heartbeat_task) # Block until shutdown requested @@ -134,7 +131,10 @@ async def _stop_all(self) -> None: task_name = task.get_name() logger.error("task %s failed during shutdown: %s", task_name, result) except asyncio.TimeoutError: - logger.warning("shutdown timed out after %ss, some tasks may not have stopped cleanly", SHUTDOWN_TIMEOUT) + logger.warning( + "shutdown timed out after %ss, some tasks may not have stopped cleanly", + SHUTDOWN_TIMEOUT, + ) # Update all statuses for name, status in self._service_statuses.items(): @@ -146,7 +146,11 @@ async def _stop_all(self) -> None: def _mark_runtime_ready(self) -> None: """Mark the runtime record ready after the daemon services are live.""" try: - from apps.daemon_command import _daemon_record_path, _load_record, _write_record + from apps.daemon_command import ( + _daemon_record_path, + _load_record, + _write_record, + ) record_path = _daemon_record_path(self.state_dir) record = _load_record(record_path) or {} @@ -172,7 +176,11 @@ async def _stop_service_safe(self, key: str, service: Any) -> None: await asyncio.wait_for(service.stop_daemon_task(), timeout=GRACEFUL_STOP_TIMEOUT) logger.info("service %s stopped gracefully", key) except asyncio.TimeoutError: - logger.warning("service %s did not stop within %ss, force-cancelling", key, GRACEFUL_STOP_TIMEOUT) + logger.warning( + "service %s did not stop within %ss, force-cancelling", + key, + GRACEFUL_STOP_TIMEOUT, + ) except Exception as exc: logger.error("failed to stop service %s: %s", key, exc) @@ -190,12 +198,12 @@ async def _health_heartbeat(self, interval_seconds: float = 300.0) -> None: uptime = 0.0 if self._started_at: uptime = (datetime.now(UTC) - datetime.fromisoformat(self._started_at)).total_seconds() - service_summary = ", ".join( - f"{name}={s.status}" for name, s in self._service_statuses.items() - ) or "none" + service_summary = ", ".join(f"{name}={s.status}" for name, s in self._service_statuses.items()) or "none" logger.info( "heartbeat: uptime=%.0fs tasks=%d services=[%s]", - uptime, len(self._tasks), service_summary, + uptime, + len(self._tasks), + service_summary, ) # ── Gateway App ──────────────────────────────────────────── @@ -204,6 +212,7 @@ async def _start_gateway_app(self) -> None: """Build the shared GatewayApp and plugin registry.""" try: from apps.gateway.runtime import build_gateway_app + app, chat_adapter, webhook_adapter = build_gateway_app( state_dir=str(self.state_dir), start_learning_worker=False, @@ -230,7 +239,9 @@ async def _start_dashboard_api(self) -> None: database_path = self.cli_state_dir / "elephant.sqlite3" self._dashboard_api_app = create_app(database_path=database_path) self._service_statuses["dashboard_api"] = DaemonServiceStatus( - name="dashboard_api", status="running", started_at=datetime.now(UTC).isoformat() + name="dashboard_api", + status="running", + started_at=datetime.now(UTC).isoformat(), ) logger.info("Dashboard API initialized (database=%s)", database_path) except Exception as exc: @@ -265,9 +276,7 @@ async def _start_im_adapters(self) -> None: service = registry.create_service(key, app=app) except Exception as exc: logger.error("failed to create service %s: %s", key, exc) - self._service_statuses[key] = DaemonServiceStatus( - name=key, status="failed", last_error=str(exc) - ) + self._service_statuses[key] = DaemonServiceStatus(name=key, status="failed", last_error=str(exc)) continue # ── Preflight: check credentials ── @@ -299,13 +308,13 @@ async def _start_im_adapters(self) -> None: else: logger.info("%s: webhook-only, no long-running task", key) self._service_statuses[key] = DaemonServiceStatus( - name=key, status="running", started_at=datetime.now(UTC).isoformat() + name=key, + status="running", + started_at=datetime.now(UTC).isoformat(), ) except Exception as exc: logger.error("failed to start service %s: %s", key, exc) - self._service_statuses[key] = DaemonServiceStatus( - name=key, status="failed", last_error=str(exc) - ) + self._service_statuses[key] = DaemonServiceStatus(name=key, status="failed", last_error=str(exc)) # ── Service starters ─────────────────────────────────────── @@ -317,6 +326,7 @@ async def _start_http_server(self) -> None: self._http_app = app try: from aiohttp import web + runner = web.AppRunner(app, access_log=access_log) await runner.setup() site = web.TCPSite(runner, self.host, self.port) @@ -327,7 +337,9 @@ async def _start_http_server(self) -> None: logger.info("HTTP server listening on %s:%d", self.host, self.port) except ImportError: logger.warning("aiohttp not available, HTTP server skipped") - self._service_statuses["http"] = DaemonServiceStatus(name="http", status="failed", last_error="aiohttp not installed") + self._service_statuses["http"] = DaemonServiceStatus( + name="http", status="failed", last_error="aiohttp not installed" + ) async def _start_cron_scheduler(self) -> None: """Start the cron scheduler as an async task.""" @@ -361,7 +373,9 @@ async def _start_supervisor(self) -> None: self._tasks.append(task) self._task_names[task] = "supervisor" self._service_statuses["supervisor"] = DaemonServiceStatus( - name="supervisor", status="running", started_at=datetime.now(UTC).isoformat() + name="supervisor", + status="running", + started_at=datetime.now(UTC).isoformat(), ) async def _start_learning_worker(self) -> None: @@ -378,7 +392,9 @@ async def _start_learning_worker(self) -> None: self._tasks.append(task) self._task_names[task] = "learning_worker" self._service_statuses["learning_worker"] = DaemonServiceStatus( - name="learning_worker", status="running", started_at=datetime.now(UTC).isoformat() + name="learning_worker", + status="running", + started_at=datetime.now(UTC).isoformat(), ) # ── Dynamic adapter lifecycle ───────────────────────────────── @@ -439,7 +455,10 @@ async def start_adapter(self, key: str) -> dict[str, str]: self._register_http_routes_for_service(service, key) self._registered_http_service_keys.append(key) else: - logger.info("adapter %s HTTP routes already registered, skipping re-register", key) + logger.info( + "adapter %s HTTP routes already registered, skipping re-register", + key, + ) # Start daemon task if applicable if is_daemon_service: @@ -454,14 +473,14 @@ async def start_adapter(self, key: str) -> dict[str, str]: self._tasks.append(guarded) self._task_names[guarded] = key self._service_statuses[key] = DaemonServiceStatus( - name=key, status="running", started_at=datetime.now(UTC).isoformat() + name=key, + status="running", + started_at=datetime.now(UTC).isoformat(), ) logger.info("adapter %s started dynamically", key) return {"status": "running"} except Exception as exc: - self._service_statuses[key] = DaemonServiceStatus( - name=key, status="failed", last_error=str(exc) - ) + self._service_statuses[key] = DaemonServiceStatus(name=key, status="failed", last_error=str(exc)) return {"status": "error", "reason": str(exc)} # Webhook-only services (e.g. telegram with no daemon task) @@ -498,9 +517,7 @@ async def stop_adapter(self, key: str) -> dict[str, str]: del self._daemon_services[key] # Cancel the guarded task for this adapter - tasks_to_remove = [ - t for t, n in self._task_names.items() if n == key - ] + tasks_to_remove = [t for t, n in self._task_names.items() if n == key] for task in tasks_to_remove: if not task.done(): task.cancel() @@ -513,9 +530,7 @@ async def stop_adapter(self, key: str) -> dict[str, str]: self._task_names.pop(task, None) # Update status - self._service_statuses[key] = DaemonServiceStatus( - name=key, status="stopped" - ) + self._service_statuses[key] = DaemonServiceStatus(name=key, status="stopped") # Note: HTTP routes remain registered but the handler will return 503 # for stopped services (checked via service status in the handler) @@ -552,7 +567,11 @@ def get_status(self) -> dict[str, Any]: ), "started_at": self._started_at, "services": { - name: {"status": s.status, "started_at": s.started_at, "last_error": s.last_error} + name: { + "status": s.status, + "started_at": s.started_at, + "last_error": s.last_error, + } for name, s in self._service_statuses.items() }, } @@ -616,6 +635,7 @@ async def _daemon_task_guard( # ── CLI entry point ──────────────────────────────────────────── + def _pid() -> int: return __import__("os").getpid() diff --git a/apps/daemon_command.py b/apps/daemon_command.py index 383c7da..181c5d4 100644 --- a/apps/daemon_command.py +++ b/apps/daemon_command.py @@ -13,7 +13,6 @@ import sys import time import warnings -from dataclasses import asdict from pathlib import Path from typing import IO, Sequence @@ -117,6 +116,7 @@ def _daemon_healthz_payload(state_dir: Path) -> dict | None: addr = host if host != "0.0.0.0" else "127.0.0.1" try: import urllib.request + url = f"http://{addr}:{port}/healthz" req = urllib.request.Request(url, method="GET") with urllib.request.urlopen(req, timeout=2) as resp: @@ -132,6 +132,7 @@ def _daemon_healthz_payload(state_dir: Path) -> dict | None: def _utc_now_iso() -> str: from datetime import UTC, datetime + return datetime.now(UTC).isoformat() @@ -251,7 +252,11 @@ def start_command( host: str = typer.Option("0.0.0.0", "--host", help="HTTP listen host."), port: int = typer.Option(8900, "--port", help="HTTP listen port."), log_level: str = typer.Option("INFO", "--log-level", help="Log level: DEBUG, INFO, WARNING, ERROR."), - detach: bool = typer.Option(False, "--detach", help="Start in a background process. This is the recommended way to run all Elephant services (IM gateways, cron, supervisor, learning worker) together."), + detach: bool = typer.Option( + False, + "--detach", + help="Start in a background process. This is the recommended way to run all Elephant services (IM gateways, cron, supervisor, learning worker) together.", + ), ) -> None: if detach: raise typer.Exit(_start_detached(state_dir, cli_state_dir, host=host, port=port, log_level=log_level)) @@ -304,8 +309,6 @@ def logs_command( for line in lines[-tail:]: print(line) if follow: - import select - with log_path.open("r", encoding="utf-8") as f: f.seek(0, 2) try: @@ -321,7 +324,14 @@ def logs_command( return app -def _run_foreground(state_dir: Path, cli_state_dir: Path, *, host: str, port: int, log_level: str = "INFO") -> int: +def _run_foreground( + state_dir: Path, + cli_state_dir: Path, + *, + host: str, + port: int, + log_level: str = "INFO", +) -> int: """Run the daemon in the foreground (blocking).""" from apps.daemon import run_daemon_foreground @@ -340,11 +350,7 @@ def _run_foreground(state_dir: Path, cli_state_dir: Path, *, host: str, port: in try: existing_pid = _read_pid(pid_path) - if ( - existing_pid is not None - and existing_pid != os.getpid() - and _pid_is_running(existing_pid) - ): + if existing_pid is not None and existing_pid != os.getpid() and _pid_is_running(existing_pid): print(f"Elephant daemon is already running with pid {existing_pid}.") return 1 @@ -353,21 +359,24 @@ def _run_foreground(state_dir: Path, cli_state_dir: Path, *, host: str, port: in pid_path.write_text(f"{pid}\n", encoding="utf-8") # Write runtime record - _write_record(record_path, { - "runtime_id": f"{DAEMON_SERVICE_KEY}:{DAEMON_TARGET}", - "service_key": DAEMON_SERVICE_KEY, - "target": DAEMON_TARGET, - "status": "running", - "pid": pid, - "pid_path": str(pid_path), - "log_path": str(_daemon_log_path(state_dir)), - "record_path": str(record_path), - "state_dir": str(state_dir), - "cli_state_dir": str(cli_state_dir), - "host": host, - "port": port, - "started_at": _utc_now_iso(), - }) + _write_record( + record_path, + { + "runtime_id": f"{DAEMON_SERVICE_KEY}:{DAEMON_TARGET}", + "service_key": DAEMON_SERVICE_KEY, + "target": DAEMON_TARGET, + "status": "running", + "pid": pid, + "pid_path": str(pid_path), + "log_path": str(_daemon_log_path(state_dir)), + "record_path": str(record_path), + "state_dir": str(state_dir), + "cli_state_dir": str(cli_state_dir), + "host": host, + "port": port, + "started_at": _utc_now_iso(), + }, + ) finally: # Release startup lock — PID file now provides singleton protection. _release_daemon_lock(lock_fd) @@ -385,7 +394,14 @@ def _run_foreground(state_dir: Path, cli_state_dir: Path, *, host: str, port: in _mark_daemon_stopped(record_path) -def _start_detached(state_dir: Path, cli_state_dir: Path, *, host: str, port: int, log_level: str = "INFO") -> int: +def _start_detached( + state_dir: Path, + cli_state_dir: Path, + *, + host: str, + port: int, + log_level: str = "INFO", +) -> int: """Start the daemon as a background process.""" state_dir.mkdir(parents=True, exist_ok=True) pid_path = _daemon_pid_path(state_dir) @@ -410,11 +426,16 @@ def _start_detached(state_dir: Path, cli_state_dir: Path, *, host: str, port: in "apps.launcher", "daemon", "start", - "--state-dir", str(state_dir), - "--cli-state-dir", str(cli_state_dir), - "--host", host, - "--port", str(port), - "--log-level", log_level, + "--state-dir", + str(state_dir), + "--cli-state-dir", + str(cli_state_dir), + "--host", + host, + "--port", + str(port), + "--log-level", + log_level, ] started_at = _utc_now_iso() @@ -431,22 +452,25 @@ def _start_detached(state_dir: Path, cli_state_dir: Path, *, host: str, port: in # Write PID file + record (still under lock — eliminates TOCTOU race) pid_path.write_text(f"{process.pid}\n", encoding="utf-8") - _write_record(record_path, { - "runtime_id": f"{DAEMON_SERVICE_KEY}:{DAEMON_TARGET}", - "service_key": DAEMON_SERVICE_KEY, - "target": DAEMON_TARGET, - "status": "starting", - "pid": process.pid, - "pid_path": str(pid_path), - "log_path": str(log_path), - "record_path": str(record_path), - "command": command, - "state_dir": str(state_dir), - "cli_state_dir": str(cli_state_dir), - "host": host, - "port": port, - "started_at": started_at, - }) + _write_record( + record_path, + { + "runtime_id": f"{DAEMON_SERVICE_KEY}:{DAEMON_TARGET}", + "service_key": DAEMON_SERVICE_KEY, + "target": DAEMON_TARGET, + "status": "starting", + "pid": process.pid, + "pid_path": str(pid_path), + "log_path": str(log_path), + "record_path": str(record_path), + "command": command, + "state_dir": str(state_dir), + "cli_state_dir": str(cli_state_dir), + "host": host, + "port": port, + "started_at": started_at, + }, + ) finally: # Release startup lock. The PID file is now written and the child # process will re-acquire this lock inside its own _run_foreground(). @@ -470,13 +494,15 @@ def _start_detached(state_dir: Path, cli_state_dir: Path, *, host: str, port: in if return_code is not None: _remove_file_if_exists(pid_path) record = _load_record(record_path) or {} - record.update({ - "status": "failed", - "pid": None, - "stopped_at": _utc_now_iso(), - "last_exit_code": return_code, - "last_error": f"process exited with code {return_code}", - }) + record.update( + { + "status": "failed", + "pid": None, + "stopped_at": _utc_now_iso(), + "last_exit_code": return_code, + "last_error": f"process exited with code {return_code}", + } + ) _write_record(record_path, record) print(f"Elephant daemon failed to start (exit {return_code}). Check {log_path}.") return 1 @@ -490,12 +516,11 @@ def _start_detached(state_dir: Path, cli_state_dir: Path, *, host: str, port: in else: record["status"] = "starting" record["last_error"] = ( - f"healthz not ready after {_DAEMON_STARTUP_WAIT_SECONDS:g}s; " - "daemon process is still running" + f"healthz not ready after {_DAEMON_STARTUP_WAIT_SECONDS:g}s; daemon process is still running" ) _write_record(record_path, record) - print(f"Elephant daemon is now running in the background.") + print("Elephant daemon is now running in the background.") print(f" PID: {process.pid}") print(f" PID file: {pid_path}") print(f" Log file: {log_path}") @@ -605,6 +630,7 @@ def _show_status(state_dir: Path) -> int: if running: try: import urllib.request + url = f"http://127.0.0.1:{port}/healthz" req = urllib.request.Request(url) with urllib.request.urlopen(req, timeout=3) as resp: @@ -623,6 +649,7 @@ def _fmt_iso(iso_str: str) -> str: """Format an ISO timestamp to a more readable local time string.""" try: from datetime import datetime + dt = datetime.fromisoformat(iso_str) return dt.strftime("%Y-%m-%d %H:%M:%S") except Exception: @@ -633,11 +660,11 @@ def _print_services_table(services: dict) -> None: """Print a compact, aligned services status table.""" # Icons for each status icons = { - "running": "\U0001f7e2", # green circle - "skipped": "\U000026AA", # white circle - "failed": "\U0001f534", # red circle - "stopped": "\U0001f7e1", # yellow circle - "idle": "\U000026AA", # white circle + "running": "\U0001f7e2", # green circle + "skipped": "\U000026aa", # white circle + "failed": "\U0001f534", # red circle + "stopped": "\U0001f7e1", # yellow circle + "idle": "\U000026aa", # white circle } # Categorize services @@ -650,7 +677,7 @@ def _print_services_table(services: dict) -> None: if key in services: info = services[key] s = info.get("status", "unknown") - icon = icons.get(s, "\U000026AA") + icon = icons.get(s, "\U000026aa") line = f" {icon} {key:<12} {s}" if s == "skipped" and info.get("last_error"): line += f" ({info['last_error']})" @@ -673,7 +700,7 @@ def _print_services_table(services: dict) -> None: if key in services: info = services[key] s = info.get("status", "unknown") - icon = icons.get(s, "\U000026AA") + icon = icons.get(s, "\U000026aa") line = f" {icon} {key:<12} {s}" if s == "failed" and info.get("last_error"): line += f" ({info['last_error']})" diff --git a/apps/daemon_http.py b/apps/daemon_http.py index 1e31a1c..480e449 100644 --- a/apps/daemon_http.py +++ b/apps/daemon_http.py @@ -2,7 +2,6 @@ from __future__ import annotations -import json import logging from pathlib import Path from typing import Any @@ -133,6 +132,7 @@ def _resolve_dashboard_static_dir(daemon: Any) -> Path | None: # Installed package: /dashboard/dist/ try: from packages.runtime_layout import infer_install_root_from_state_dir + install_root = infer_install_root_from_state_dir(daemon.cli_state_dir) installed_dist = install_root / "dashboard" / "dist" if (installed_dist / "index.html").is_file(): diff --git a/apps/daemon_tasks.py b/apps/daemon_tasks.py index 41d6c15..63c8191 100644 --- a/apps/daemon_tasks.py +++ b/apps/daemon_tasks.py @@ -16,6 +16,7 @@ # ── Cron Scheduler ───────────────────────────────────────────── + async def cron_scheduler_loop( *, cli_state_dir: Path, @@ -27,7 +28,6 @@ async def cron_scheduler_loop( from apps.gateway.cron_service import ( build_gateway_cron_delivery_callback, cron_execution_should_deliver, - run_cron_scheduler_loop, ) from apps.cli.runtime import CliRuntime @@ -62,7 +62,11 @@ async def cron_scheduler_loop( try: delivery_callback(execution.job, execution) except Exception as exc: - logger.error("cron delivery failed for %s: %s", execution.job.job_id, exc) + logger.error( + "cron delivery failed for %s: %s", + execution.job.job_id, + exc, + ) elif tick_count % 10 == 0: logger.debug("cron tick #%d: no due jobs", tick_count) except Exception as exc: @@ -73,6 +77,7 @@ async def cron_scheduler_loop( if now_ts - last_maintenance_at > 86400: try: from packages.understanding.auto_retire import retire_stale_facts + retired = retire_stale_facts(runtime.repository) if retired: logger.info("cron auto-retire: %d stale fact(s) retired", retired) @@ -121,13 +126,18 @@ def _run_proactive_ask_tick(cli_state_dir: Path, delivery_callback) -> None: config=proactive_config, ) if result.enqueued: - logger.info("cron proactive-ask %s: delivered %d question(s)", adapter_id, result.enqueued) + logger.info( + "cron proactive-ask %s: delivered %d question(s)", + adapter_id, + result.enqueued, + ) except Exception as exc: logger.error("cron proactive-ask %s failed: %s", adapter_id, exc) # ── Supervisor ────────────────────────────────────────────────── + async def supervisor_loop( *, state_dir: Path, @@ -166,6 +176,7 @@ async def supervisor_loop( # ── Learning Worker ───────────────────────────────────────────── + async def learning_worker_loop( *, state_dir: Path, @@ -200,16 +211,28 @@ async def learning_worker_loop( job = repository.claim_learning_job(worker_id=worker_id) if job is None: if idle_seconds is not None and time.monotonic() - last_activity >= max(1.0, idle_seconds): - logger.info("learning worker idle timeout (%gs, %d job(s) completed), exiting", idle_seconds, jobs_completed) + logger.info( + "learning worker idle timeout (%gs, %d job(s) completed), exiting", + idle_seconds, + jobs_completed, + ) break _write_learning_worker_record( - state_dir, pid=os.getpid(), status="idle", started_at=started_at, + state_dir, + pid=os.getpid(), + status="idle", + started_at=started_at, ) await asyncio.sleep(0.5) continue last_activity = time.monotonic() - logger.info("learning job claimed: %s (stage=%s, attempt=%d)", job.job_id, job.progress_stage, job.attempt_count) + logger.info( + "learning job claimed: %s (stage=%s, attempt=%d)", + job.job_id, + job.progress_stage, + job.attempt_count, + ) _write_learning_worker_record( state_dir, pid=os.getpid(), @@ -233,7 +256,10 @@ async def learning_worker_loop( ) finally: _write_learning_worker_record( - state_dir, pid=os.getpid(), status="running", started_at=started_at, + state_dir, + pid=os.getpid(), + status="running", + started_at=started_at, ) finally: _write_learning_worker_record( @@ -273,7 +299,9 @@ def _log_supervisor_tick(tick: object, tick_count: int) -> None: if decisions: logger.info( "supervisor tick #%d: scanned=%d decisions=%d", - tick_count, scanned_count, len(decisions), + tick_count, + scanned_count, + len(decisions), ) for decision in decisions: action = getattr(decision, "action", "?") @@ -284,10 +312,15 @@ def _log_supervisor_tick(tick: object, tick_count: int) -> None: pending = len(getattr(snapshot, "replay_plans", []) or []) if snapshot else 0 logger.info( "supervisor decision: %s %s wc=%s retry=%d pending=%d", - action, loop_id, wait_kind or "", retry, pending, + action, + loop_id, + wait_kind or "", + retry, + pending, ) elif tick_count % 10 == 0: logger.debug( "supervisor tick #%d: scanned=%d, no decisions", - tick_count, scanned_count, + tick_count, + scanned_count, ) diff --git a/apps/dashboard_command.py b/apps/dashboard_command.py index 52a0103..c180975 100644 --- a/apps/dashboard_command.py +++ b/apps/dashboard_command.py @@ -8,9 +8,7 @@ from pathlib import Path import shutil import subprocess -import sys import time -from typing import Any import urllib.error import urllib.request import webbrowser @@ -187,11 +185,14 @@ def _ensure_frontend_dist(*, skip_build: bool = False, rebuild: bool = False) -> "Elephant Agent dashboard", "Dashboard frontend dependencies are not installed.", sections=( - CliCardSection("Next step", ( - "Install dependencies first:", - " cd apps/dashboard && npm install", - "Then run this command again.", - )), + CliCardSection( + "Next step", + ( + "Install dependencies first:", + " cd apps/dashboard && npm install", + "Then run this command again.", + ), + ), ), ) return False @@ -344,10 +345,18 @@ def build_typer_app( @app.callback(invoke_without_command=True) def main_callback( state_dir: Path = typer.Option(default_state_dir, "--state-dir", hidden=True), - open_browser: bool = typer.Option(True, "--open/--no-open", help="Open the dashboard URL in the default browser."), + open_browser: bool = typer.Option( + True, + "--open/--no-open", + help="Open the dashboard URL in the default browser.", + ), skip_build: bool = typer.Option(False, "--skip-build", help="Skip the frontend build check."), rebuild: bool = typer.Option(False, "--rebuild", help="Force rebuild the frontend assets."), - start_daemon: bool = typer.Option(True, "--start/--no-start", help="Start the daemon automatically when it is not running."), + start_daemon: bool = typer.Option( + True, + "--start/--no-start", + help="Start the daemon automatically when it is not running.", + ), ) -> None: plan = DashboardLaunchPlan(state_dir=state_dir) raise typer.Exit( diff --git a/apps/dashboard_static_server.py b/apps/dashboard_static_server.py index df44537..1cd21e6 100644 --- a/apps/dashboard_static_server.py +++ b/apps/dashboard_static_server.py @@ -28,7 +28,10 @@ def _headers_for(path: Path, *, content_length: int) -> list[tuple[str, str]]: return [ ("Content-Type", content_type), ("Content-Length", str(content_length)), - ("Cache-Control", "no-cache" if path.name == "index.html" else "public, max-age=31536000, immutable"), + ( + "Cache-Control", + "no-cache" if path.name == "index.html" else "public, max-age=31536000, immutable", + ), ] diff --git a/apps/episode_runtime.py b/apps/episode_runtime.py index ca8a4ff..7a15709 100644 --- a/apps/episode_runtime.py +++ b/apps/episode_runtime.py @@ -155,7 +155,11 @@ def relationship_projection_policy( preserve_preferences: bool = True, preserve_corrections: bool = True, preserve_emotional_context: bool = True, - allowed_signal_kinds: tuple[str, ...] = ("relationship", "preference", "continuity"), + allowed_signal_kinds: tuple[str, ...] = ( + "relationship", + "preference", + "continuity", + ), ) -> RelationshipPolicy: return build_relationship_policy( profile_mode=profile_mode, @@ -168,7 +172,9 @@ def relationship_projection_policy( ) -def install_app_episode_runtime(repository: RuntimeStorageRepository) -> EpisodeLifecycleService: +def install_app_episode_runtime( + repository: RuntimeStorageRepository, +) -> EpisodeLifecycleService: """Build the app-owned Episode lifecycle service on top of repository methods.""" return EpisodeLifecycleService(repository) diff --git a/apps/gateway/capabilities.py b/apps/gateway/capabilities.py index 8ec1fbe..ca66b45 100644 --- a/apps/gateway/capabilities.py +++ b/apps/gateway/capabilities.py @@ -72,7 +72,13 @@ def retrieve_evidence(self, request: EvidenceRetrievalRequest) -> EvidenceRetrie class GatewayContextCapability(ContextCapability): - def __init__(self, profile: LoadedProfile, *, total_tokens: int = 3072, install_root: Path | None = None) -> None: + def __init__( + self, + profile: LoadedProfile, + *, + total_tokens: int = 3072, + install_root: Path | None = None, + ) -> None: self.profile = profile self.install_root = install_root self.prompt_contract = build_prompt_contract(profile, prompt_mode="full") diff --git a/apps/gateway/cli_control.py b/apps/gateway/cli_control.py index c062ec6..e08f573 100644 --- a/apps/gateway/cli_control.py +++ b/apps/gateway/cli_control.py @@ -52,14 +52,11 @@ def _abbreviate_identifier(value: str, *, head: int = 12, tail: int = 6) -> str: class CliRuntimeLike(Protocol): - def list_herd(self, *, limit: int = 12) -> tuple[object, ...]: - ... + def list_herd(self, *, limit: int = 12) -> tuple[object, ...]: ... - def latest_session_for_elephant(self, elephant_id: str) -> Episode | None: - ... + def latest_session_for_elephant(self, elephant_id: str) -> Episode | None: ... - def session_ids_for_elephant(self, elephant_id: str) -> tuple[str, ...]: - ... + def session_ids_for_elephant(self, elephant_id: str) -> tuple[str, ...]: ... def create_elephant( self, @@ -69,14 +66,11 @@ def create_elephant( display_name: str | None = None, mode: str | None = None, session_id: str | None = None, - ) -> Episode: - ... + ) -> Episode: ... - def inspect_session(self, session_id: str) -> Episode: - ... + def inspect_session(self, session_id: str) -> Episode: ... - def prepare_session_surface(self, session_id: str) -> Episode: - ... + def prepare_session_surface(self, session_id: str) -> Episode: ... def explain_next_step( self, @@ -87,8 +81,7 @@ def explain_next_step( tool_name: str | None = None, tool_arguments: Mapping[str, Any] | None = None, delivery_payload: Mapping[str, Any] | None = None, - ) -> Any: - ... + ) -> Any: ... def compact_session_context( self, @@ -96,8 +89,7 @@ def compact_session_context( *, reason: str = "gateway-hygiene", force: bool = False, - ) -> Any: - ... + ) -> Any: ... def schedule_learning_for_session( self, @@ -106,8 +98,7 @@ def schedule_learning_for_session( trigger: str, summary: str = "", metadata: Mapping[str, str] | None = None, - ) -> Any: - ... + ) -> Any: ... def open_next_episode( self, @@ -116,11 +107,9 @@ def open_next_episode( next_episode_id: str | None = None, reason: str = "wake_boundary", summary: str = "", - ) -> Any: - ... + ) -> Any: ... - def elephant_id_for_session(self, session: Episode) -> str: - ... + def elephant_id_for_session(self, session: Episode) -> str: ... CliRuntimeFactory = Callable[[Path, Path], CliRuntimeLike] @@ -245,7 +234,9 @@ def load_gateway_cli_control_config( ) -def load_feishu_cli_control_config(manifest: Mapping[str, object]) -> GatewayCliControlConfig: +def load_feishu_cli_control_config( + manifest: Mapping[str, object], +) -> GatewayCliControlConfig: config = load_gateway_cli_control_config( manifest, adapter_key="feishu", @@ -500,12 +491,7 @@ def _handle_elephant_command( ) -> GatewayCliControlResult: if not argument: return GatewayCliControlResult( - body=( - "Usage:\n" - "- /elephant list\n" - "- /elephant use \n" - "- /elephant current" - ), + body=("Usage:\n- /elephant list\n- /elephant use \n- /elephant current"), summary="missing elephant subcommand", ) action, _, remainder = argument.strip().partition(" ") @@ -581,7 +567,11 @@ def _handle_elephant_command( route_session = self.app.core.dependencies.session_store.lookup(bound.session_id) if route_session is not None: self.app.core.dependencies.session_store.save( - replace(route_session, profile_id=state.personal_model_id, updated_at=_utc_now()) + replace( + route_session, + profile_id=state.personal_model_id, + updated_at=_utc_now(), + ) ) return GatewayCliControlResult( body=( @@ -594,12 +584,7 @@ def _handle_elephant_command( summary="elephant shaped", ) return GatewayCliControlResult( - body=( - "Usage:\n" - "- /elephant list\n" - "- /elephant use \n" - "- /elephant current" - ), + body=("Usage:\n- /elephant list\n- /elephant use \n- /elephant current"), summary="unknown elephant subcommand", ) @@ -635,7 +620,11 @@ def _session_selection( if selection_mode in {"parent-bound", "parent-bound-session"}: return elephant_id, recovered, "parent-bound-recovered" return elephant_id, recovered, "bound-recovered" - return elephant_id, self._session_for_elephant(runtime, elephant_id), selection_mode + return ( + elephant_id, + self._session_for_elephant(runtime, elephant_id), + selection_mode, + ) def _elephant_selection( self, @@ -679,10 +668,7 @@ def _binding_lookup_order( inbound: GatewayInboundMessage, ) -> tuple[str, ...]: candidates = [inbound.conversation_id] - if ( - inbound.parent_conversation_id is not None - and inbound.parent_conversation_id != inbound.conversation_id - ): + if inbound.parent_conversation_id is not None and inbound.parent_conversation_id != inbound.conversation_id: candidates.append(inbound.parent_conversation_id) return tuple(dict.fromkeys(candidates)) @@ -711,15 +697,20 @@ def _elephant_display(self, elephant_id: str) -> str: def _elephant_listing(self) -> str: herd = self._list_states(limit=12) if not herd: - return ( - "No local Elephant Agent herd are available yet.\n" - "Create one from the CLI first." - ) + return "No local Elephant Agent herd are available yet.\nCreate one from the CLI first." lines = ["Available local Elephant Agent herd:"] for state in herd: - elephant_id = str(getattr(state, "elephant_id", "") or getattr(state, "elephant_name", "") or getattr(state, "state_id", "")) + elephant_id = str( + getattr(state, "elephant_id", "") + or getattr(state, "elephant_name", "") + or getattr(state, "state_id", "") + ) elephant_name = str(getattr(state, "elephant_name", "") or "").strip() - elephant_label = f"{elephant_name} (`{elephant_id}`)" if elephant_name and elephant_name != elephant_id else f"`{elephant_id}`" + elephant_label = ( + f"{elephant_name} (`{elephant_id}`)" + if elephant_name and elephant_name != elephant_id + else f"`{elephant_id}`" + ) status = str(getattr(state, "status", "") or getattr(state, "latest_status", "") or "active") summary = str(getattr(state, "summary", "") or "").strip() current_marker = "" @@ -757,7 +748,7 @@ def _help_text(self) -> str: f"- /elephant use · pin this {self.binding_subject} to an elephant", f"- /elephant current · inspect the elephant currently handling this {self.binding_subject}", f"- /status · inspect the elephant currently handling this {self.binding_subject}", - f"- /clear · close this Episode and open a fresh one on the same elephant", + "- /clear · close this Episode and open a fresh one on the same elephant", f"- plain text · forward the message into the bound elephant after this {self.binding_subject} is pinned", ) ) @@ -870,9 +861,7 @@ def _auto_bind_single_elephant( return None only_state = herd[0] elephant_ref = str( - getattr(only_state, "elephant_id", "") - or getattr(only_state, "elephant_name", "") - or "" + getattr(only_state, "elephant_id", "") or getattr(only_state, "elephant_name", "") or "" ).strip() if not elephant_ref: return None diff --git a/apps/gateway/cron_service.py b/apps/gateway/cron_service.py index 5cdbec6..0a1652d 100644 --- a/apps/gateway/cron_service.py +++ b/apps/gateway/cron_service.py @@ -13,9 +13,12 @@ from apps.cli.runtime import CliRuntime from apps.runtime_layout import default_cli_state_dir from packages.cron import CronJob, CronJobExecution -from packages.runtime_layout import infer_install_root_from_state_dir -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path +from .plugins import ( + GatewayManagedRuntime, + GatewayPluginRegistry, + default_gateway_runtime_path, +) CRON_SCHEDULER_TARGET = "scheduler" @@ -173,6 +176,7 @@ def run_cron_scheduler_loop( if now_ts - last_maintenance_at > 86400: try: from packages.understanding.auto_retire import retire_stale_facts + retired = retire_stale_facts(runtime.repository) if retired: print(f"auto-retire: {retired} stale fact(s) retired", flush=True) @@ -287,6 +291,7 @@ def build_gateway_cron_delivery_callback( callbacks.append(callback) if not callbacks: return None + def _fanout(job, execution) -> None: for callback in callbacks: try: @@ -364,7 +369,9 @@ def _try_weixin_cron_callback( return None -def register_cron_scheduler_service(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: +def register_cron_scheduler_service( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: registry.register_service("cron", factory=build_cron_scheduler_service, enabled_by_default=True) return registry diff --git a/apps/gateway/dingding.py b/apps/gateway/dingding.py index bea00e2..06d82f6 100644 --- a/apps/gateway/dingding.py +++ b/apps/gateway/dingding.py @@ -13,7 +13,9 @@ from .runtime import build_gateway_app -def register_dingding_gateway_service(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: +def register_dingding_gateway_service( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: registry.register_service( "dingding", factory=lambda app, **kwargs: DingdingGatewayService(app=app, **kwargs), diff --git a/apps/gateway/dingding_service.py b/apps/gateway/dingding_service.py index eeb62aa..d71ab06 100644 --- a/apps/gateway/dingding_service.py +++ b/apps/gateway/dingding_service.py @@ -5,11 +5,9 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass, field import asyncio -import importlib.util import logging import os from pathlib import Path -import threading from typing import Any from uuid import uuid4 @@ -17,14 +15,12 @@ DEFAULT_GATEWAY_ACCOUNT_ID, GatewayAccountRef, GatewayConversationRef, - GatewayExchange, GatewayInboundMessage, GatewayOutboundMessage, GatewayOutboundQueue, GatewayOutboundRow, InboundSequencer, default_outbound_queue_path, - resolve_cron_identity_records, run_outbound_drain_loop, ) @@ -36,17 +32,20 @@ GatewayCliControlService, load_gateway_cli_control_config, ) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path -from .runtime import DINGDING_ADAPTER_ID, DingdingMessagingAdapter, GatewayApp, build_gateway_app +from .plugins import ( + GatewayManagedRuntime, + GatewayPluginRegistry, + default_gateway_runtime_path, +) +from .runtime import ( + DINGDING_ADAPTER_ID, + DingdingMessagingAdapter, + GatewayApp, + build_gateway_app, +) from .dingding_support import ( - DEFAULT_DINGDING_CLIENT_ID_ENV, - DEFAULT_DINGDING_CLIENT_SECRET_ENV, - DEFAULT_DINGDING_ROBOT_CODE_ENV, - DINGTALK_STREAM_PIP_SPEC, - SUPPORTED_DINGDING_TRANSPORTS, DingdingGatewayAccountConfig, - DingdingGatewayEventResult, DingdingResolvedAccount, _dingding_chat_type, _dingtalk_stream_dependency_status, @@ -97,13 +96,9 @@ def __post_init__(self) -> None: if binding_store is None: state_root = self.app.state_dir binding_path = ( - None - if state_root is None - else os.path.join(state_root, "dingding-cli-bindings.json") - ) - binding_store = GatewayCliBindingStore( - path=None if binding_path is None else Path(binding_path) + None if state_root is None else os.path.join(state_root, "dingding-cli-bindings.json") ) + binding_store = GatewayCliBindingStore(path=None if binding_path is None else Path(binding_path)) self.cli_control = GatewayCliControlService( config=self._resolved_cli_control_config(config), app=self.app, @@ -166,9 +161,7 @@ def describe(self) -> Mapping[str, object]: "surface": config.surface, "enabled": config.enabled, "credentials_status": ( - "configured" - if _can_resolve_account(config, environ=self.environ) - else "missing_credentials" + "configured" if _can_resolve_account(config, environ=self.environ) else "missing_credentials" ), } for config in self.account_configs @@ -190,14 +183,10 @@ def configured_transport(self) -> str: transport_configs = self._transport_account_configs() if not transport_configs: return "stream" - transports = tuple( - dict.fromkeys(_normalize_transport(config.surface) for config in transport_configs) - ) + transports = tuple(dict.fromkeys(_normalize_transport(config.surface) for config in transport_configs)) if len(transports) == 1: return transports[0] - raise LookupError( - "configured DingDing accounts use multiple transport surfaces; choose one explicitly" - ) + raise LookupError("configured DingDing accounts use multiple transport surfaces; choose one explicitly") def configured_runtime_target(self) -> str: return self.configured_transport() @@ -280,9 +269,11 @@ async def start_gateway( ) -> object: dingtalk_stream = _load_dingtalk_stream_sdk(dingtalk_module) account = self._match_account(account_id=account_id) - LOGGER.info("DingDing start_gateway: client_id=%s..., robot_code=%s...", - account.client_id[:8] if account.client_id else "(empty)", - account.robot_code[:8] if account.robot_code else "(empty)") + LOGGER.info( + "DingDing start_gateway: client_id=%s..., robot_code=%s...", + account.client_id[:8] if account.client_id else "(empty)", + account.robot_code[:8] if account.robot_code else "(empty)", + ) # New SDK: DingTalkStreamClient + Credential (replaces OpenDingTalkClient) client_cls = getattr(dingtalk_stream, "DingTalkStreamClient", None) @@ -312,7 +303,10 @@ async def process(self, callback: object) -> tuple: try: data = getattr(callback, "data", None) payload = _dingtalk_callback_payload(callback) - LOGGER.debug("DingDing callback: payload keys=%s", list(payload.keys()) if isinstance(payload, Mapping) else type(payload).__name__) + LOGGER.debug( + "DingDing callback: payload keys=%s", + list(payload.keys()) if isinstance(payload, Mapping) else type(payload).__name__, + ) # Build ChatbotMessage for reply methods incoming_msg = None chatbot_msg_cls = getattr(dingtalk_stream, "ChatbotMessage", None) @@ -339,9 +333,14 @@ async def process(self, callback: object) -> tuple: client = client_cls(credential) # ChatbotMessage.TOPIC = '/v1.0/im/bot/messages/get' chatbot_msg_cls = getattr(dingtalk_stream, "ChatbotMessage", None) - topic = getattr(chatbot_msg_cls, "TOPIC", "/v1.0/im/bot/messages/get") if chatbot_msg_cls else "/v1.0/im/bot/messages/get" + topic = ( + getattr(chatbot_msg_cls, "TOPIC", "/v1.0/im/bot/messages/get") + if chatbot_msg_cls + else "/v1.0/im/bot/messages/get" + ) client.register_callback_handler( - topic, handler_instance, + topic, + handler_instance, ) else: client = open_dingtalk_client( @@ -381,7 +380,10 @@ def _outbound_queue(self) -> GatewayOutboundQueue: return GatewayOutboundQueue(path=default_outbound_queue_path(state_root)) def _inbound_sequence_key( - self, payload: Mapping[str, object], *, account: DingdingResolvedAccount, + self, + payload: Mapping[str, object], + *, + account: DingdingResolvedAccount, ) -> str | None: sender_id = str(payload.get("sender_id") or "").strip() robot_code = str(payload.get("robot_code") or account.robot_code or "").strip() @@ -459,7 +461,10 @@ async def _on_dingtalk_message( if result.handled and result.body is not None: outbound = self._build_control_outbound(inbound, body=result.body, session_id=result.session_id) delivery_request = adapter.build_reply_request(outbound) - delivery_request = {**dict(delivery_request), "incoming_message": incoming_message} + delivery_request = { + **dict(delivery_request), + "incoming_message": incoming_message, + } await self._send_dingtalk_reply( delivery_request, account=account, @@ -481,7 +486,10 @@ async def _on_dingtalk_message( if exchange.delivery.outbound is None: return delivery_request = adapter.build_reply_request(exchange.delivery.outbound) - delivery_request = {**dict(delivery_request), "incoming_message": incoming_message} + delivery_request = { + **dict(delivery_request), + "incoming_message": incoming_message, + } await self._send_dingtalk_reply( delivery_request, account=account, @@ -514,6 +522,7 @@ async def _send_dingtalk_reply( title = raw_title if raw_title else (msg_param.split("\n")[0][:64] if msg_param else "Reply") import json as _json import requests as _requests + webhook_url = getattr(incoming, "session_webhook", None) sender_staff_id = getattr(incoming, "sender_staff_id", None) if not webhook_url: @@ -535,7 +544,11 @@ async def _send_dingtalk_reply( data=_json.dumps(reply_payload), ) if resp.status_code != 200 or resp.json().get("errcode"): - LOGGER.error("DingDing reply failed: status=%s body=%s", resp.status_code, resp.text[:500]) + LOGGER.error( + "DingDing reply failed: status=%s body=%s", + resp.status_code, + resp.text[:500], + ) else: LOGGER.debug("DingDing reply sent successfully") resp.raise_for_status() @@ -597,9 +610,7 @@ def _match_account( raise LookupError("no enabled DingDing gateway accounts are configured") if len(enabled_configs) == 1: return resolve_dingding_account(enabled_configs[0], environ=self.environ) - raise LookupError( - "multiple enabled DingDing gateway accounts are configured; pass account_id explicitly" - ) + raise LookupError("multiple enabled DingDing gateway accounts are configured; pass account_id explicitly") # ── Outbound drain (shared queue) ───────────────────────── @@ -732,7 +743,9 @@ def _dingtalk_callback_payload(callback: object) -> Mapping[str, object]: return merged -def register_dingding_gateway_service(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: +def register_dingding_gateway_service( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: from .dingding import DingdingGatewayService registry.register_service( diff --git a/apps/gateway/dingding_support.py b/apps/gateway/dingding_support.py index 7877f88..bf6c1a1 100644 --- a/apps/gateway/dingding_support.py +++ b/apps/gateway/dingding_support.py @@ -6,28 +6,15 @@ from dataclasses import dataclass, field import importlib.util import os -from pathlib import Path -from apps.runtime_layout import default_cli_state_dir from packages.gateway_core import ( DEFAULT_GATEWAY_ACCOUNT_ID, GatewayExchange, - GatewayInboundMessage, GatewayOutboundMessage, ) -from .cli_control import ( - CliRuntimeFactory, - GatewayCliBindingStore, - GatewayCliControlService, - load_gateway_cli_control_config, -) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path from .runtime import ( - DINGDING_ADAPTER_ID, - DingdingMessagingAdapter, GatewayApp, - build_gateway_app, ) DEFAULT_DINGDING_CLIENT_ID_ENV = "ELEPHANT_DINGDING_CLIENT_ID" @@ -66,10 +53,7 @@ def _normalize_transport(value: str | None) -> str: normalized = str(value or "stream").strip().lower().replace("_", "-") if normalized in {"stream", "dingtalk-stream"}: return "stream" - raise ValueError( - "dingding transport must be one of " - f"{', '.join(SUPPORTED_DINGDING_TRANSPORTS)}" - ) + raise ValueError(f"dingding transport must be one of {', '.join(SUPPORTED_DINGDING_TRANSPORTS)}") def _dingtalk_stream_dependency_status() -> str: @@ -145,15 +129,9 @@ def load_dingding_gateway_accounts( resolved.append( DingdingGatewayAccountConfig( account_id=str(account_mapping.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID), - client_id_env_var=str( - env_payload.get("client_id") or DEFAULT_DINGDING_CLIENT_ID_ENV - ), - client_secret_env_var=str( - env_payload.get("client_secret") or DEFAULT_DINGDING_CLIENT_SECRET_ENV - ), - robot_code_env_var=str( - env_payload.get("robot_code") or DEFAULT_DINGDING_ROBOT_CODE_ENV - ), + client_id_env_var=str(env_payload.get("client_id") or DEFAULT_DINGDING_CLIENT_ID_ENV), + client_secret_env_var=str(env_payload.get("client_secret") or DEFAULT_DINGDING_CLIENT_SECRET_ENV), + robot_code_env_var=str(env_payload.get("robot_code") or DEFAULT_DINGDING_ROBOT_CODE_ENV), surface=str(account_mapping.get("surface") or default_surface), enabled=account_enabled, metadata={"manifest_index": index}, @@ -174,13 +152,9 @@ def resolve_dingding_account( client_secret = str(env.get(config.client_secret_env_var) or "").strip() robot_code = str(env.get(config.robot_code_env_var) or "").strip() if not client_id: - raise LookupError( - f"dingding account '{config.account_id}' requires {config.client_id_env_var}" - ) + raise LookupError(f"dingding account '{config.account_id}' requires {config.client_id_env_var}") if not client_secret: - raise LookupError( - f"dingding account '{config.account_id}' requires {config.client_secret_env_var}" - ) + raise LookupError(f"dingding account '{config.account_id}' requires {config.client_secret_env_var}") return DingdingResolvedAccount( account_id=config.account_id, client_id=client_id, diff --git a/apps/gateway/discord.py b/apps/gateway/discord.py index 4707e2b..63ede47 100644 --- a/apps/gateway/discord.py +++ b/apps/gateway/discord.py @@ -3,10 +3,12 @@ from __future__ import annotations from .discord_support import * # noqa: F401,F403 -from .discord_transport import DiscordPyDeliveryTransport from .discord_service import DiscordGatewayService -def register_discord_gateway_service(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: + +def register_discord_gateway_service( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: registry.register_service( "discord", factory=lambda app, **kwargs: DiscordGatewayService(app=app, **kwargs), diff --git a/apps/gateway/discord_service.py b/apps/gateway/discord_service.py index 05f6b02..597596b 100644 --- a/apps/gateway/discord_service.py +++ b/apps/gateway/discord_service.py @@ -1,7 +1,7 @@ from __future__ import annotations from .discord_support import * # noqa: F401,F403 -from .discord_transport import DiscordPyDeliveryTransport + @dataclass(slots=True) class DiscordGatewayService: @@ -41,14 +41,8 @@ def __post_init__(self) -> None: binding_store = self.cli_binding_store if binding_store is None: state_root = self.app.state_dir - binding_path = ( - None - if state_root is None - else os.path.join(state_root, "discord-cli-bindings.json") - ) - binding_store = GatewayCliBindingStore( - path=None if binding_path is None else Path(binding_path) - ) + binding_path = None if state_root is None else os.path.join(state_root, "discord-cli-bindings.json") + binding_store = GatewayCliBindingStore(path=None if binding_path is None else Path(binding_path)) self.cli_control = GatewayCliControlService( config=self._resolved_cli_control_config(config), app=self.app, @@ -80,7 +74,9 @@ def _transport_account_configs(self) -> tuple[DiscordGatewayAccountConfig, ...]: enabled_configs = self._enabled_account_configs() return enabled_configs if enabled_configs else self.account_configs - def _describe_accounts(self) -> tuple[tuple[dict[str, object], ...], dict[str, object]]: + def _describe_accounts( + self, + ) -> tuple[tuple[dict[str, object], ...], dict[str, object]]: accounts: list[dict[str, object]] = [] configured_account_ids: list[str] = [] enabled_account_ids: list[str] = [] @@ -403,14 +399,10 @@ def configured_transport(self) -> str: transport_configs = self._transport_account_configs() if not transport_configs: return "gateway" - transports = tuple( - dict.fromkeys(_normalize_transport(config.surface) for config in transport_configs) - ) + transports = tuple(dict.fromkeys(_normalize_transport(config.surface) for config in transport_configs)) if len(transports) == 1: return transports[0] - raise LookupError( - "configured Discord accounts use multiple transport surfaces; choose one explicitly" - ) + raise LookupError("configured Discord accounts use multiple transport surfaces; choose one explicitly") def configured_runtime_target(self) -> str: return self.configured_transport() @@ -521,12 +513,8 @@ async def start_gateway( if account_id is not None: raise LookupError(f"unknown or unrunnable Discord gateway account: {account_id}") if blocked_accounts: - blocked_summary = "; ".join( - f"{account_label}: {error}" for account_label, error in blocked_accounts - ) - raise LookupError( - "no enabled Discord gateway accounts are runnable; " + blocked_summary - ) + blocked_summary = "; ".join(f"{account_label}: {error}" for account_label, error in blocked_accounts) + raise LookupError("no enabled Discord gateway accounts are runnable; " + blocked_summary) raise LookupError("no enabled Discord gateway accounts are configured") for account_label, error in blocked_accounts: print( @@ -680,9 +668,7 @@ def sdk_message_payload(self, message: object) -> Mapping[str, object]: "id": str(getattr(message, "id", "")), "channel_id": str(getattr(channel, "id", "")), "guild_id": ( - str(getattr(guild, "id", "")) - if guild is not None and getattr(guild, "id", None) is not None - else None + str(getattr(guild, "id", "")) if guild is not None and getattr(guild, "id", None) is not None else None ), "content": str(getattr(message, "content", "") or ""), "author": { @@ -702,9 +688,7 @@ def sdk_message_payload(self, message: object) -> Mapping[str, object]: "message_id": str(getattr(reference, "message_id")), "channel_id": str(getattr(reference, "channel_id", getattr(channel, "id", ""))), "guild_id": ( - str(getattr(reference, "guild_id")) - if getattr(reference, "guild_id", None) is not None - else None + str(getattr(reference, "guild_id")) if getattr(reference, "guild_id", None) is not None else None ), } if parent is not None and getattr(parent, "id", None) is not None: @@ -749,9 +733,7 @@ def _match_account( raise LookupError("no enabled Discord gateway accounts are configured") if len(enabled_configs) == 1: return resolve_discord_account(enabled_configs[0], environ=self.environ) - raise LookupError( - "multiple enabled Discord gateway accounts are configured; pass account_id explicitly" - ) + raise LookupError("multiple enabled Discord gateway accounts are configured; pass account_id explicitly") def _resolved_accounts_for_start( self, @@ -815,7 +797,6 @@ def deliver_cron_result(self, job, execution) -> None: gateway process polls the queue and sends each row through the same REST send path a normal reply uses. """ - from packages.cron import CronJob, CronJobExecution if getattr(job, "action_kind", "") == "learning": return @@ -882,9 +863,7 @@ def _send_outbound_queue_row(self, row: GatewayOutboundRow) -> None: try: account = self._match_account(account_id=row.account_id) except LookupError as error: - raise RuntimeError( - f"cannot resolve discord account for queued row: {row.account_id}" - ) from error + raise RuntimeError(f"cannot resolve discord account for queued row: {row.account_id}") from error _discord_rest_send_message( channel_id=row.conversation_id, content=row.body, diff --git a/apps/gateway/discord_support.py b/apps/gateway/discord_support.py index 00ac733..c481c59 100644 --- a/apps/gateway/discord_support.py +++ b/apps/gateway/discord_support.py @@ -2,48 +2,26 @@ from __future__ import annotations -import asyncio from collections.abc import Callable, Mapping from dataclasses import dataclass, field import importlib.util import inspect -import io import json import logging import os from pathlib import Path -import threading from typing import Any, Protocol, runtime_checkable -from uuid import uuid4 LOGGER = logging.getLogger(__name__) -from apps.runtime_layout import default_cli_state_dir from packages.gateway_core import ( GatewayExchange, - GatewayInboundMessage, - GatewayOutboundMessage, - GatewayOutboundQueue, - GatewayOutboundRow, - default_outbound_queue_path, - resolve_cron_identity_records, - run_outbound_drain_thread, ) -from .cli_control import ( - CliRuntimeFactory, - GatewayCliBindingStore, - GatewayCliControlService, - load_gateway_cli_control_config, -) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path from .runtime import ( DEFAULT_GATEWAY_ACCOUNT_ID, - DISCORD_ADAPTER_ID, - DiscordMessagingAdapter, GatewayApp, - build_gateway_app, ) DEFAULT_DISCORD_BOT_TOKEN_ENV = "ELEPHANT_DISCORD_BOT_TOKEN" @@ -69,10 +47,7 @@ def _normalize_transport(value: str | None) -> str: normalized = str(value or "gateway").strip().lower().replace("_", "-") if normalized in {"gateway", "discord-gateway"}: return "gateway" - raise ValueError( - "discord transport must be one of " - f"{', '.join(SUPPORTED_DISCORD_TRANSPORTS)}" - ) + raise ValueError(f"discord transport must be one of {', '.join(SUPPORTED_DISCORD_TRANSPORTS)}") def _string_list(value: object, *, field_name: str) -> tuple[str, ...]: @@ -200,7 +175,6 @@ def _split_discord_message_content( return tuple(chunks) - def _discord_fence_state(content: str) -> str | None: open_fence: str | None = None for raw_line in content.split("\n"): @@ -214,7 +188,6 @@ def _discord_fence_state(content: str) -> str | None: return open_fence - def _rebalance_discord_fenced_chunks( chunks: tuple[str, ...], *, @@ -361,9 +334,7 @@ def load_discord_gateway_accounts( resolved.append( DiscordGatewayAccountConfig( account_id=str(account_mapping.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID), - bot_token_env_var=str( - env_payload.get("bot_token") or DEFAULT_DISCORD_BOT_TOKEN_ENV - ), + bot_token_env_var=str(env_payload.get("bot_token") or DEFAULT_DISCORD_BOT_TOKEN_ENV), surface=str(account_mapping.get("surface") or default_surface), enabled=account_enabled, allow_guild_ids=_string_list( @@ -393,13 +364,12 @@ def resolve_discord_account( if not bot_token and config.bot_token_env_var == DEFAULT_DISCORD_BOT_TOKEN_ENV: bot_token = str(env.get(LEGACY_DISCORD_BOT_TOKEN_ENV) or "").strip() if not bot_token: - raise LookupError( - f"discord account '{config.account_id}' requires {config.bot_token_env_var}" - ) + raise LookupError(f"discord account '{config.account_id}' requires {config.bot_token_env_var}") return DiscordResolvedAccount( account_id=config.account_id, bot_token=bot_token, config=config, ) + __all__ = [name for name in globals() if not name.startswith("__")] diff --git a/apps/gateway/discord_transport.py b/apps/gateway/discord_transport.py index 9110f0d..f54a813 100644 --- a/apps/gateway/discord_transport.py +++ b/apps/gateway/discord_transport.py @@ -2,6 +2,7 @@ from .discord_support import * # noqa: F401,F403 + @dataclass(frozen=True, slots=True) class DiscordPyDeliveryTransport: client: object @@ -48,9 +49,7 @@ async def send_request( if "```" in content else DISCORD_MESSAGE_CONTENT_LIMIT ) - content_chunks = _rebalance_discord_fenced_chunks( - _split_discord_message_content(content, limit=chunk_limit) - ) + content_chunks = _rebalance_discord_fenced_chunks(_split_discord_message_content(content, limit=chunk_limit)) message = await self._send_content_chunks( channel=channel, content_chunks=content_chunks, diff --git a/apps/gateway/feishu_accounts.py b/apps/gateway/feishu_accounts.py index 01d8449..62ddcf0 100644 --- a/apps/gateway/feishu_accounts.py +++ b/apps/gateway/feishu_accounts.py @@ -1,42 +1,25 @@ """Account discovery and credential resolution for the Feishu gateway.""" - from __future__ import annotations from collections.abc import Callable, Mapping -from dataclasses import dataclass, field, replace -import importlib.util -import json import logging import os -import queue -import threading -import time -from pathlib import Path from typing import Any -from urllib.error import HTTPError, URLError -from urllib.request import Request, urlopen -from uuid import uuid4 from packages.gateway_core import ( DEFAULT_GATEWAY_ACCOUNT_ID, - GatewayExchange, - GatewayInboundMessage, - GatewayOutboundMessage, ) from apps.provider_runtime import secret_reference_from_payload -from apps.runtime_layout import default_cli_state_dir -from packages.auth import AuthProfile, EnvironmentSecretStore, ProfileCredentialResolver, SecretReference - -from .cli_control import ( - CliRuntimeFactory, - FeishuCliBindingStore, - FeishuCliControlService, - load_feishu_cli_control_config, +from packages.auth import ( + AuthProfile, + EnvironmentSecretStore, + ProfileCredentialResolver, + SecretReference, ) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path -from .runtime import FEISHU_ADAPTER_ID, FeishuMessagingAdapter, GatewayApp, build_gateway_app + +from .runtime import FEISHU_ADAPTER_ID, GatewayApp DEFAULT_FEISHU_APP_ID_ENV = "ELEPHANT_FEISHU_APP_ID" DEFAULT_FEISHU_APP_SECRET_ENV = "ELEPHANT_FEISHU_APP_SECRET" @@ -63,7 +46,10 @@ from .feishu_support import * # noqa: F401,F403 -def _feishu_event_identifiers(payload: Mapping[str, object]) -> tuple[str | None, str | None]: + +def _feishu_event_identifiers( + payload: Mapping[str, object], +) -> tuple[str | None, str | None]: header = _mapping(payload.get("header")) or {} event = _mapping(payload.get("event")) or {} message = _mapping(event.get("message")) or {} @@ -72,6 +58,7 @@ def _feishu_event_identifiers(payload: Mapping[str, object]) -> tuple[str | None _optional_text(message.get("message_id")), ) + def _feishu_secret_reference_from_payload( payload: Mapping[str, object], *, @@ -80,18 +67,15 @@ def _feishu_secret_reference_from_payload( normalized_payload = dict(payload) secret_key = str(normalized_payload.get("secret_key") or "") if not secret_key: - raise ValueError( - f"feishu account '{account_id}' secret_references entries must declare secret_key" - ) + raise ValueError(f"feishu account '{account_id}' secret_references entries must declare secret_key") normalized_payload.setdefault("provider_id", FEISHU_ADAPTER_ID) normalized_payload.setdefault("secret_name", secret_key) reference = secret_reference_from_payload(normalized_payload) if reference.provider_id != FEISHU_ADAPTER_ID: - raise ValueError( - f"feishu account '{account_id}' secret reference provider_id must be {FEISHU_ADAPTER_ID}" - ) + raise ValueError(f"feishu account '{account_id}' secret reference provider_id must be {FEISHU_ADAPTER_ID}") return reference + def _secret_reference_env_alias( references: tuple[SecretReference, ...], secret_key: str, @@ -104,18 +88,18 @@ def _secret_reference_env_alias( return candidates[0] return None + def _credential_env_vars(config: FeishuGatewayAccountConfig) -> tuple[str, ...]: if config.secret_references: return tuple( dict.fromkeys( - env_var - for reference in config.secret_references - for env_var in reference.env_var_candidates() + env_var for reference in config.secret_references for env_var in reference.env_var_candidates() ) ) env_vars = [config.app_id_env_var, config.app_secret_env_var] return tuple(dict.fromkeys(value for value in env_vars if value)) + def _feishu_account_profile(config: FeishuGatewayAccountConfig) -> AuthProfile: return AuthProfile( profile_id=f"gateway.feishu.{config.account_id}", @@ -130,6 +114,7 @@ def _feishu_account_profile(config: FeishuGatewayAccountConfig) -> AuthProfile: }, ) + def load_feishu_gateway_accounts( app: GatewayApp, *, @@ -145,9 +130,7 @@ def load_feishu_gateway_accounts( default_surface = _normalize_configured_transport((feishu_payload or {}).get("surface")) default_event_path = _normalize_path((feishu_payload or {}).get("event_path")) default_base_url = str((feishu_payload or {}).get("base_url") or DEFAULT_FEISHU_BASE_URL) - default_token_path = _normalize_path( - (feishu_payload or {}).get("token_path") or DEFAULT_FEISHU_TOKEN_PATH - ) + default_token_path = _normalize_path((feishu_payload or {}).get("token_path") or DEFAULT_FEISHU_TOKEN_PATH) accounts_payload = (feishu_payload or {}).get("accounts") if isinstance(accounts_payload, list) and accounts_payload: resolved: list[FeishuGatewayAccountConfig] = [] @@ -167,13 +150,9 @@ def load_feishu_gateway_accounts( if isinstance(item, Mapping) ) if len(secret_references) != len(secret_references_payload): - raise ValueError( - f"feishu account '{account_id}' secret_references entries must be JSON objects" - ) + raise ValueError(f"feishu account '{account_id}' secret_references entries must be JSON objects") else: - raise ValueError( - f"feishu account '{account_id}' secret_references must be a JSON array" - ) + raise ValueError(f"feishu account '{account_id}' secret_references must be a JSON array") app_id_env_var = str( env_payload.get("app_id") or _secret_reference_env_alias(secret_references, "app_id") @@ -208,6 +187,7 @@ def load_feishu_gateway_accounts( ), ) + def resolve_feishu_account( config: FeishuGatewayAccountConfig, *, @@ -215,9 +195,9 @@ def resolve_feishu_account( ) -> FeishuResolvedAccount: env = environ or os.environ if config.secret_references: - credentials = ProfileCredentialResolver(EnvironmentSecretStore(env)).resolve( - _feishu_account_profile(config) - ).as_mapping() + credentials = ( + ProfileCredentialResolver(EnvironmentSecretStore(env)).resolve(_feishu_account_profile(config)).as_mapping() + ) app_id = str(credentials.get("app_id") or "") app_secret = str(credentials.get("app_secret") or "") if not app_id or not app_secret: @@ -239,8 +219,7 @@ def resolve_feishu_account( app_secret = str(env.get(LEGACY_FEISHU_APP_SECRET_ENV) or "") if not app_id or not app_secret: raise LookupError( - f"feishu account '{config.account_id}' requires " - f"{config.app_id_env_var} and {config.app_secret_env_var}" + f"feishu account '{config.account_id}' requires {config.app_id_env_var} and {config.app_secret_env_var}" ) return FeishuResolvedAccount( account_id=config.account_id, @@ -249,6 +228,7 @@ def resolve_feishu_account( config=config, ) + __all__ = [ "DEFAULT_FEISHU_APP_ID_ENV", "DEFAULT_FEISHU_APP_SECRET_ENV", diff --git a/apps/gateway/feishu_dispatch.py b/apps/gateway/feishu_dispatch.py index aab653e..22ff249 100644 --- a/apps/gateway/feishu_dispatch.py +++ b/apps/gateway/feishu_dispatch.py @@ -354,9 +354,7 @@ def _duplicate_event_result( duplicate_response["delivery_outcome"] = "deduplicated" duplicate_response["duplicate_event"] = True duplicate_response["duplicate_handling"] = "replayed-no-delivery" - duplicate_response["summary"] = ( - "Duplicate Feishu event ignored; the original event was already processed." - ) + duplicate_response["summary"] = "Duplicate Feishu event ignored; the original event was already processed." return FeishuGatewayEventResult(exchange=None, response_body=duplicate_response) def _inflight_duplicate_event_result( diff --git a/apps/gateway/feishu_impl.py b/apps/gateway/feishu_impl.py index 70f52ac..a6fd3c2 100644 --- a/apps/gateway/feishu_impl.py +++ b/apps/gateway/feishu_impl.py @@ -1,12 +1,9 @@ """Feishu gateway implementation assembled from support and store modules.""" - from __future__ import annotations from collections.abc import Callable, Mapping -from dataclasses import dataclass, field, replace -import importlib.util -import json +from dataclasses import dataclass, field import logging import os import queue @@ -14,15 +11,11 @@ import time from pathlib import Path from typing import Any -from urllib.error import HTTPError, URLError -from urllib.request import Request, urlopen from uuid import uuid4 from packages.gateway_core import ( - DEFAULT_GATEWAY_ACCOUNT_ID, GatewayAccountRef, GatewayConversationRef, - GatewayExchange, GatewayInboundMessage, GatewayOutboundMessage, GatewayOutboundQueue, @@ -32,9 +25,7 @@ run_outbound_drain_thread, ) -from apps.provider_runtime import secret_reference_from_payload from apps.runtime_layout import default_cli_state_dir -from packages.auth import AuthProfile, EnvironmentSecretStore, ProfileCredentialResolver, SecretReference from packages.cron import CronJob, CronJobExecution from .cli_control import ( @@ -43,7 +34,11 @@ FeishuCliControlService, load_feishu_cli_control_config, ) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path +from .plugins import ( + GatewayManagedRuntime, + GatewayPluginRegistry, + default_gateway_runtime_path, +) from .runtime import ( FEISHU_ADAPTER_ID, FeishuMessagingAdapter, @@ -155,22 +150,12 @@ def __post_init__(self) -> None: ) if self.inbound_event_store is None: state_root = self.app.state_dir - dedupe_path = ( - None - if state_root is None - else os.path.join(state_root, "feishu-inbound-events.json") - ) - self.inbound_event_store = FeishuInboundEventStore( - path=None if dedupe_path is None else Path(dedupe_path) - ) + dedupe_path = None if state_root is None else os.path.join(state_root, "feishu-inbound-events.json") + self.inbound_event_store = FeishuInboundEventStore(path=None if dedupe_path is None else Path(dedupe_path)) if self.async_job_store is None: state_root = self.app.state_dir - async_jobs_path = ( - None if state_root is None else os.path.join(state_root, "feishu-async-jobs.json") - ) - self.async_job_store = FeishuAsyncJobStore( - path=None if async_jobs_path is None else Path(async_jobs_path) - ) + async_jobs_path = None if state_root is None else os.path.join(state_root, "feishu-async-jobs.json") + self.async_job_store = FeishuAsyncJobStore(path=None if async_jobs_path is None else Path(async_jobs_path)) if self.adapter is None: self.adapter = FeishuMessagingAdapter(app=self.app) if self.cli_control is None and self.app.loaded_profile is not None: @@ -179,14 +164,8 @@ def __post_init__(self) -> None: binding_store = self.cli_binding_store if binding_store is None: state_root = self.app.state_dir - binding_path = ( - None - if state_root is None - else os.path.join(state_root, "feishu-cli-bindings.json") - ) - binding_store = FeishuCliBindingStore( - path=None if binding_path is None else Path(binding_path) - ) + binding_path = None if state_root is None else os.path.join(state_root, "feishu-cli-bindings.json") + binding_store = FeishuCliBindingStore(path=None if binding_path is None else Path(binding_path)) self.cli_control = FeishuCliControlService( config=self._resolved_cli_control_config(config), app=self.app, @@ -494,12 +473,8 @@ def describe(self) -> Mapping[str, object]: "app_id_env_var": config.app_id_env_var, "app_secret_env_var": config.app_secret_env_var, "credential_env_vars": _credential_env_vars(config), - "secret_reference_ids": tuple( - reference.reference_id for reference in config.secret_references - ), - "credentials_source": ( - "secret_references" if config.secret_references else "environment" - ), + "secret_reference_ids": tuple(reference.reference_id for reference in config.secret_references), + "credentials_source": ("secret_references" if config.secret_references else "environment"), "credentials_status": status, "resolved_app_id": resolved_app_id, } @@ -514,9 +489,7 @@ def describe(self) -> Mapping[str, object]: "adapter_id": FEISHU_ADAPTER_ID, "profile_id": self.app.profile_id, "preferred_transport": "long-connection", - "implemented_transports": ( - "python-sdk-long-connection", - ), + "implemented_transports": ("python-sdk-long-connection",), "configured_transport": configured_transport, "configured_transport_error": configured_transport_error, "sdk_dependency_status": _lark_sdk_dependency_status(), @@ -530,16 +503,18 @@ def describe(self) -> Mapping[str, object]: "control": ( self.cli_control.describe() if self.cli_control is not None - else {"enabled": True, "runtime": "cli-runtime", "runtime_status": "unavailable"} + else { + "enabled": True, + "runtime": "cli-runtime", + "runtime_status": "unavailable", + } ), } def configured_transport(self) -> str: if not self.account_configs: return "long-connection" - transports = tuple( - dict.fromkeys(_normalize_transport(config.surface) for config in self.account_configs) - ) + transports = tuple(dict.fromkeys(_normalize_transport(config.surface) for config in self.account_configs)) if len(transports) == 1: return transports[0] raise LookupError( @@ -713,10 +688,7 @@ def deliver_cron_result(self, job: CronJob, execution: CronJobExecution) -> None # Only warn when we have Feishu identities but cannot disambiguate — if # there are zero Feishu identities, the scheduler's fan-out simply asked # the wrong adapter, which is expected noise. - any_feishu = any( - r.key.adapter_id == FEISHU_ADAPTER_ID - for r in identity_store.list_records() - ) + any_feishu = any(r.key.adapter_id == FEISHU_ADAPTER_ID for r in identity_store.list_records()) if any_feishu: LOGGER.warning( "cron delivery: skipping job=%s — no job.elephant_id and multiple feishu herd", @@ -780,9 +752,7 @@ def _send_outbound_queue_row(self, row: GatewayOutboundRow) -> None: try: account = self._resolve_account_by_id(row.account_id) except LookupError as error: - raise RuntimeError( - f"cannot resolve feishu account for queued row: {row.account_id}" - ) from error + raise RuntimeError(f"cannot resolve feishu account for queued row: {row.account_id}") from error if self.adapter is None: self._ensure_runtime_dependencies() assert self.adapter is not None @@ -890,7 +860,10 @@ async def stop_daemon_task(self) -> None: pass self._daemon_task = None -def register_feishu_gateway_service(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: + +def register_feishu_gateway_service( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: registry.register_service( "feishu", factory=lambda app, **kwargs: FeishuGatewayService(app=app, **kwargs), @@ -898,6 +871,7 @@ def register_feishu_gateway_service(registry: GatewayPluginRegistry) -> GatewayP ) return registry + def build_feishu_gateway_service( *, profile_id: str = "you", @@ -918,16 +892,14 @@ def build_feishu_gateway_service( app=app, http_requester=http_requester, environ=dict(environ or os.environ), - default_cli_state_dir=( - None if default_cli_state_dir is None else str(Path(default_cli_state_dir)) - ), + default_cli_state_dir=(None if default_cli_state_dir is None else str(Path(default_cli_state_dir))), ) + def create_gateway_web_app(service: FeishuGatewayService): return create_gateway_http_app(service, app=service.app) - __all__ = [ "DEFAULT_FEISHU_APP_ID_ENV", "DEFAULT_FEISHU_APP_SECRET_ENV", diff --git a/apps/gateway/feishu_stores.py b/apps/gateway/feishu_stores.py index 924cca4..ab9d0a8 100644 --- a/apps/gateway/feishu_stores.py +++ b/apps/gateway/feishu_stores.py @@ -1,42 +1,17 @@ """Inbound-event and async-job stores for the Feishu gateway.""" - from __future__ import annotations from collections.abc import Callable, Mapping from dataclasses import dataclass, field, replace -import importlib.util import json import logging -import os -import queue import threading import time from pathlib import Path from typing import Any -from urllib.error import HTTPError, URLError -from urllib.request import Request, urlopen from uuid import uuid4 -from packages.gateway_core import ( - DEFAULT_GATEWAY_ACCOUNT_ID, - GatewayExchange, - GatewayInboundMessage, - GatewayOutboundMessage, -) - -from apps.provider_runtime import secret_reference_from_payload -from apps.runtime_layout import default_cli_state_dir -from packages.auth import AuthProfile, EnvironmentSecretStore, ProfileCredentialResolver, SecretReference - -from .cli_control import ( - CliRuntimeFactory, - FeishuCliBindingStore, - FeishuCliControlService, - load_feishu_cli_control_config, -) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path -from .runtime import FEISHU_ADAPTER_ID, FeishuMessagingAdapter, GatewayApp, build_gateway_app DEFAULT_FEISHU_APP_ID_ENV = "ELEPHANT_FEISHU_APP_ID" DEFAULT_FEISHU_APP_SECRET_ENV = "ELEPHANT_FEISHU_APP_SECRET" @@ -63,6 +38,7 @@ from .feishu_support import * # noqa: F401,F403 + @dataclass(slots=True) class FeishuInboundEventStore: path: Path | None = None @@ -202,11 +178,7 @@ def _load(self) -> dict[str, FeishuInboundEventRecord]: def _prune(self) -> None: cutoff = time.time() - max(self.retention_seconds, 0) - kept = [ - (key, record) - for key, record in self._records.items() - if record.recorded_at >= cutoff - ] + kept = [(key, record) for key, record in self._records.items() if record.recorded_at >= cutoff] kept.sort(key=lambda item: item[1].recorded_at, reverse=True) if self.max_records > 0: kept = kept[: self.max_records] @@ -243,6 +215,7 @@ def _persist(self) -> None: } self.path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") + @dataclass(frozen=True, slots=True) class FeishuAsyncJobRecord: account_id: str @@ -264,6 +237,7 @@ class FeishuAsyncJobRecord: completed_at: float | None = None failed_at: float | None = None + @dataclass(slots=True) class FeishuAsyncJobStore: path: Path | None = None @@ -447,11 +421,7 @@ def fail( def incomplete_records(self) -> tuple[tuple[str, FeishuAsyncJobRecord], ...]: with self._lock: self._prune() - items = [ - (key, record) - for key, record in self._records.items() - if record.status in {"queued", "running"} - ] + items = [(key, record) for key, record in self._records.items() if record.status in {"queued", "running"}] items.sort(key=lambda item: (item[1].created_at, item[0])) return tuple(items) @@ -513,12 +483,7 @@ def _load(self) -> dict[str, FeishuAsyncJobRecord]: account_id = _optional_text(item.get("account_id")) conversation_id = _optional_text(item.get("conversation_id")) transport = _optional_text(item.get("transport")) - if ( - payload_mapping is None - or account_id is None - or conversation_id is None - or transport is None - ): + if payload_mapping is None or account_id is None or conversation_id is None or transport is None: continue event_id = _optional_text(item.get("event_id")) message_id = _optional_text(item.get("message_id")) @@ -547,27 +512,15 @@ def _load(self) -> dict[str, FeishuAsyncJobRecord]: retry_count=int(item.get("retry_count") or 0), created_at=float(item.get("created_at") or 0.0), updated_at=float(item.get("updated_at") or 0.0), - started_at=( - None if item.get("started_at") is None else float(item.get("started_at")) - ), - completed_at=( - None - if item.get("completed_at") is None - else float(item.get("completed_at")) - ), - failed_at=( - None if item.get("failed_at") is None else float(item.get("failed_at")) - ), + started_at=(None if item.get("started_at") is None else float(item.get("started_at"))), + completed_at=(None if item.get("completed_at") is None else float(item.get("completed_at"))), + failed_at=(None if item.get("failed_at") is None else float(item.get("failed_at"))), ) return loaded def _prune(self) -> None: cutoff = time.time() - max(self.retention_seconds, 0) - kept = [ - (key, record) - for key, record in self._records.items() - if record.updated_at >= cutoff - ] + kept = [(key, record) for key, record in self._records.items() if record.updated_at >= cutoff] kept.sort(key=lambda item: item[1].updated_at, reverse=True) if self.max_records > 0: kept = kept[: self.max_records] @@ -599,9 +552,7 @@ def _persist(self) -> None: "status": record.status, "placeholder_sent": record.placeholder_sent, "placeholder_message_id": record.placeholder_message_id, - "response_body": ( - None if record.response_body is None else dict(record.response_body) - ), + "response_body": (None if record.response_body is None else dict(record.response_body)), "external_message_id": record.external_message_id, "failure_summary": record.failure_summary, "retry_count": record.retry_count, @@ -620,6 +571,7 @@ def _persist(self) -> None: } self.path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") + __all__ = [ "DEFAULT_FEISHU_APP_ID_ENV", "DEFAULT_FEISHU_APP_SECRET_ENV", diff --git a/apps/gateway/feishu_support.py b/apps/gateway/feishu_support.py index 32089c2..e4f4b0d 100644 --- a/apps/gateway/feishu_support.py +++ b/apps/gateway/feishu_support.py @@ -1,42 +1,23 @@ """Support helpers and data contracts for the Feishu gateway.""" - from __future__ import annotations from collections.abc import Callable, Mapping -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field import importlib.util import json import logging -import os -import queue -import threading -import time -from pathlib import Path from typing import Any from urllib.error import HTTPError, URLError from urllib.request import Request, urlopen -from uuid import uuid4 from packages.gateway_core import ( DEFAULT_GATEWAY_ACCOUNT_ID, GatewayExchange, - GatewayInboundMessage, - GatewayOutboundMessage, ) -from apps.provider_runtime import secret_reference_from_payload -from apps.runtime_layout import default_cli_state_dir -from packages.auth import AuthProfile, EnvironmentSecretStore, ProfileCredentialResolver, SecretReference +from packages.auth import SecretReference -from .cli_control import ( - CliRuntimeFactory, - FeishuCliBindingStore, - FeishuCliControlService, - load_feishu_cli_control_config, -) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path -from .runtime import FEISHU_ADAPTER_ID, FeishuMessagingAdapter, GatewayApp, build_gateway_app DEFAULT_FEISHU_APP_ID_ENV = "ELEPHANT_FEISHU_APP_ID" DEFAULT_FEISHU_APP_SECRET_ENV = "ELEPHANT_FEISHU_APP_SECRET" @@ -61,19 +42,23 @@ LOGGER = logging.getLogger(__name__) + def _mapping(value: object) -> Mapping[str, object] | None: return value if isinstance(value, Mapping) else None + def _optional_text(value: object) -> str | None: if value is None: return None text = str(value).strip() return text or None + def _normalize_path(value: str | None) -> str: text = str(value or DEFAULT_FEISHU_EVENT_PATH).strip() or DEFAULT_FEISHU_EVENT_PATH return text if text.startswith("/") else f"/{text}" + def _normalize_transport(value: str | None) -> str: normalized = str(value or "webhook").strip().lower().replace("_", "-") if normalized in {"long-connection", "longconnection", "websocket", "ws"}: @@ -82,6 +67,7 @@ def _normalize_transport(value: str | None) -> str: return "webhook" raise ValueError("feishu transport must be one of long-connection, webhook") + def _normalize_configured_transport(value: str | None) -> str: normalized = str(value or "long-connection").strip().lower().replace("_", "-") if normalized in {"long-connection", "longconnection", "websocket", "ws"}: @@ -90,9 +76,11 @@ def _normalize_configured_transport(value: str | None) -> str: return "long-connection" raise ValueError("feishu configured transport must resolve to long-connection") + def _json_bytes(payload: Mapping[str, object]) -> bytes: return json.dumps(dict(payload), ensure_ascii=False).encode("utf-8") + def _default_json_request( method: str, url: str, @@ -115,9 +103,7 @@ def _default_json_request( raw = response.read().decode("utf-8") except HTTPError as exc: detail = exc.read().decode("utf-8") - raise RuntimeError( - f"feishu request failed with HTTP {exc.code}: {detail or exc.reason}" - ) from exc + raise RuntimeError(f"feishu request failed with HTTP {exc.code}: {detail or exc.reason}") from exc except URLError as exc: raise RuntimeError(f"feishu request failed: {exc.reason}") from exc try: @@ -127,14 +113,14 @@ def _default_json_request( if not isinstance(parsed, dict): raise RuntimeError("feishu request returned a non-object JSON response") if int(parsed.get("code", 0) or 0) != 0: - raise RuntimeError( - f"feishu request rejected: code={parsed.get('code')} msg={parsed.get('msg')}" - ) + raise RuntimeError(f"feishu request rejected: code={parsed.get('code')} msg={parsed.get('msg')}") return parsed + def _lark_sdk_dependency_status() -> str: return "installed" if importlib.util.find_spec("lark_oapi") is not None else "missing_optional_dependency" + def _load_lark_sdk(lark_module: Any | None = None) -> Any: if lark_module is not None: return lark_module @@ -147,6 +133,7 @@ def _load_lark_sdk(lark_module: Any | None = None) -> Any: ) from exc return lark + def _lark_log_level(lark_module: Any, level_name: str) -> object | None: log_levels = getattr(lark_module, "LogLevel", None) if log_levels is None: @@ -154,6 +141,7 @@ def _lark_log_level(lark_module: Any, level_name: str) -> object | None: normalized = str(level_name or "INFO").strip().upper().replace("-", "_") return getattr(log_levels, normalized, None) + def _lark_event_payload(event: object, *, lark_module: Any) -> Mapping[str, object]: marshaled = getattr(getattr(lark_module, "JSON", None), "marshal", None) if not callable(marshaled): @@ -167,6 +155,7 @@ def _lark_event_payload(event: object, *, lark_module: Any) -> Mapping[str, obje raise RuntimeError("lark_oapi long-connection event did not marshal to a JSON object") return parsed + def _default_ws_client_factory( lark_module: Any, app_id: str, @@ -184,6 +173,7 @@ def _default_ws_client_factory( log_level=log_level, ) + @dataclass(frozen=True, slots=True) class FeishuGatewayAccountConfig: account_id: str = DEFAULT_GATEWAY_ACCOUNT_ID @@ -196,6 +186,7 @@ class FeishuGatewayAccountConfig: token_path: str = DEFAULT_FEISHU_TOKEN_PATH metadata: Mapping[str, object] = field(default_factory=dict) + @dataclass(frozen=True, slots=True) class FeishuResolvedAccount: account_id: str @@ -203,6 +194,7 @@ class FeishuResolvedAccount: app_secret: str config: FeishuGatewayAccountConfig + @dataclass(frozen=True, slots=True) class FeishuGatewayEventResult: exchange: GatewayExchange | None @@ -210,11 +202,13 @@ class FeishuGatewayEventResult: delivery_request: Mapping[str, object] | None = None delivery_response: Mapping[str, object] | None = None + @dataclass(slots=True) class _FeishuTokenCacheEntry: token: str expires_at: float + @dataclass(frozen=True, slots=True) class FeishuInboundEventRecord: account_id: str @@ -223,6 +217,7 @@ class FeishuInboundEventRecord: response_body: Mapping[str, object] recorded_at: float + __all__ = [ "DEFAULT_FEISHU_APP_ID_ENV", "DEFAULT_FEISHU_APP_SECRET_ENV", diff --git a/apps/gateway/gateway_main_impl.py b/apps/gateway/gateway_main_impl.py index 5f4f542..f02df92 100644 --- a/apps/gateway/gateway_main_impl.py +++ b/apps/gateway/gateway_main_impl.py @@ -2,73 +2,25 @@ from __future__ import annotations import asyncio -from argparse import SUPPRESS, ArgumentParser, Namespace -from collections.abc import Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass -from datetime import UTC, datetime -import getpass -import apps.cli.wizard as cli_wizard -import importlib.util +from argparse import ArgumentParser, Namespace +from collections.abc import Mapping, Sequence import json -import os from pathlib import Path -import re -import shlex -import signal -import subprocess import sys import time from wsgiref.simple_server import make_server import typer -from apps.cli.runtime import CliRuntime -from apps.cli.shell import ( - Align, - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_LIGHT, - BRAND_MUTED, - Console, - Group, - Panel, - RICH_AVAILABLE, - Table, - Text, - _resolve_elephant_version, - render_elephant_mark, -) -from apps.provider_runtime import load_runtime_local_secret_env -from apps.runtime_layout import default_cli_state_dir, default_gateway_state_dir -from packages.gateway_core import DEFAULT_GATEWAY_ACCOUNT_ID from . import ( - DEFAULT_DINGDING_CLIENT_ID_ENV, - DEFAULT_DINGDING_CLIENT_SECRET_ENV, - DEFAULT_DINGDING_ROBOT_CODE_ENV, - DEFAULT_DISCORD_BOT_TOKEN_ENV, - DEFAULT_FEISHU_APP_ID_ENV, - DEFAULT_FEISHU_APP_SECRET_ENV, - DEFAULT_FEISHU_EVENT_PATH, - DEFAULT_WECOM_BOT_ID_ENV, - DEFAULT_WECOM_SECRET_ENV, - FEISHU_ADAPTER_ID, GatewayHttpService, - GatewayManagedRuntime, GatewayManagedService, - SUPPORTED_DINGDING_TRANSPORTS, - SUPPORTED_DISCORD_TRANSPORTS, - SUPPORTED_FEISHU_TRANSPORTS, - SUPPORTED_WECOM_TRANSPORTS, - SUPPORTED_WEIXIN_TRANSPORTS, - WECOM_ADAPTER_ID, - build_gateway_app, - build_gateway_plugin_registry, create_gateway_web_app, ) -from .dingding import DINGTALK_STREAM_PIP_SPEC, DingdingGatewayService -from .discord import DISCORD_PY_PIP_SPEC, DiscordGatewayService -from .feishu import FEISHU_SDK_PIP_SPEC, FeishuGatewayService +from .dingding import DingdingGatewayService +from .discord import DiscordGatewayService +from .feishu import FeishuGatewayService from .wecom import WecomGatewayService from .weixin import WeixinGatewayService @@ -100,25 +52,7 @@ from .gateway_main_wizard import * # noqa: F401,F403 from .gateway_main_wizard import ( GATEWAY_WIZARD_BACK, - _confirm_gateway_wizard_intro, _gateway_wizard_choice_prompt, - _gateway_wizard_dialogs_supported, - _gateway_wizard_secret_prompt, - _gateway_wizard_text_prompt, - _interactive_shell_supported, - _print_gateway_dingding_wizard_intro, - _print_gateway_discord_wizard_intro, - _print_gateway_feishu_wizard_intro, - _print_gateway_setup_paused, - _print_gateway_wecom_wizard_intro, - _print_gateway_weixin_wizard_intro, - _run_interactive_dingding_wizard, - _run_interactive_discord_wizard, - _run_interactive_feishu_wizard, - _run_interactive_wecom_wizard, - _run_interactive_weixin_wizard, - _shared_wizard_choice_prompt, - _shared_wizard_text_prompt, ) @@ -142,10 +76,7 @@ def _add_message_subparser( parser = parent_subparsers.add_parser( "message", parents=[common], - help=( - f"Send a one-off text message through the {adapter_label} gateway outbound queue " - f"(connectivity test)." - ), + help=(f"Send a one-off text message through the {adapter_label} gateway outbound queue (connectivity test)."), ) _add_optional_account_argument( parser, @@ -234,6 +165,7 @@ def _run_start(service: FeishuGatewayService, args: Namespace) -> int: service.stop_outbound_drain() return 0 + def _run_discord_start(service: DiscordGatewayService, args: Namespace) -> int: transport = _resolve_runtime_target_argument(args, service=service) service.prepare_managed_runtime(action="startup", target=transport) @@ -249,6 +181,7 @@ def _run_discord_start(service: DiscordGatewayService, args: Namespace) -> int: asyncio.run(service.start_gateway(account_id=args.account_id)) return 0 + def _run_dingding_start(service: DingdingGatewayService, args: Namespace) -> int: transport = _resolve_runtime_target_argument(args, service=service) service.prepare_managed_runtime(action="startup", target=transport) @@ -264,6 +197,7 @@ def _run_dingding_start(service: DingdingGatewayService, args: Namespace) -> int asyncio.run(service.start_gateway(account_id=args.account_id)) return 0 + def _run_weixin_start(service: WeixinGatewayService, args: Namespace) -> int: transport = _resolve_runtime_target_argument(args, service=service) service.prepare_managed_runtime(action="startup", target=transport) @@ -369,7 +303,6 @@ def _run_adapter_message( reach your chat, the problem is in the gateway; if it does, the problem is upstream (prompt, scheduler, model provider). """ - import time from pathlib import Path from packages.gateway_core import ( @@ -471,6 +404,7 @@ def _run_wecom_start(service: WecomGatewayService, args: Namespace) -> int: asyncio.run(service.start_gateway(account_id=args.account_id)) return 0 + def _start_wecom_runtime_after_setup(args: Namespace, *, transport: str) -> int: start_args = Namespace(**vars(args)) start_args.runtime_target = transport or "configured" @@ -516,10 +450,10 @@ def _start_via_daemon(args: Namespace) -> int: host = record.get("host", "0.0.0.0") port = record.get("port", 8900) print(f"Elephant daemon is already running (pid {pid}).") - print(f"All configured IM adapters are managed by the daemon.") + print("All configured IM adapters are managed by the daemon.") print(f" HTTP: http://{host}:{port}/healthz") - print(f" Stop: elephant daemon stop") - print(f" Status: elephant daemon status") + print(" Stop: elephant daemon stop") + print(" Status: elephant daemon status") return 0 # Start the daemon — use args.host/port if available, otherwise defaults @@ -536,6 +470,7 @@ def _start_via_daemon(args: Namespace) -> int: def _resolve_daemon_http_addr(state_dir: Path) -> tuple[str, int]: """Resolve the daemon HTTP address from runtime record.""" from apps.daemon_command import _load_record, _daemon_record_path + record_path = _daemon_record_path(state_dir) record = _load_record(record_path) if record_path.exists() else {} record = record or {} @@ -687,14 +622,12 @@ def _restart_via_daemon(args: Namespace) -> int: force=bool(getattr(args, "force", False)), ) + def _http_services( services: Mapping[str, object], ) -> dict[str, GatewayHttpService]: - return { - key: service - for key, service in services.items() - if isinstance(service, GatewayHttpService) - } + return {key: service for key, service in services.items() if isinstance(service, GatewayHttpService)} + def _run_serve(args: Namespace) -> int: app, services = _build_services(args) @@ -712,6 +645,7 @@ def _run_serve(args: Namespace) -> int: server.serve_forever() return 0 + def command_main( argv: Sequence[str] | None = None, *, @@ -764,7 +698,6 @@ def command_main( ) describe.set_defaults(command_action="describe_all") - feishu = subparsers.add_parser("feishu", parents=[common], help="Manage Feishu accounts.") feishu.set_defaults(command_action="status", service_key="feishu") feishu_subparsers = feishu.add_subparsers(dest="feishu_command") @@ -775,7 +708,11 @@ def command_main( help="Add or update a Feishu account.", ) _add_feishu_add_options(feishu_setup) - feishu_setup.add_argument("--no-start", action="store_true", help="Only save config, do not start the adapter after setup.") + feishu_setup.add_argument( + "--no-start", + action="store_true", + help="Only save config, do not start the adapter after setup.", + ) feishu_setup.set_defaults(command_action="add_feishu", service_key="feishu", auto_start=True) feishu_remove = feishu_subparsers.add_parser( @@ -850,8 +787,7 @@ def command_main( service_key="feishu", adapter_label="feishu", conversation_description=( - "Feishu conversation id (chat_id / open_chat_id). Omit to fall back to the single " - "feishu elephant." + "Feishu conversation id (chat_id / open_chat_id). Omit to fall back to the single feishu elephant." ), ) @@ -865,7 +801,11 @@ def command_main( help="Add or update a Discord account.", ) _add_discord_add_options(discord_setup) - discord_setup.add_argument("--no-start", action="store_true", help="Only save config, do not start the adapter after setup.") + discord_setup.add_argument( + "--no-start", + action="store_true", + help="Only save config, do not start the adapter after setup.", + ) discord_setup.set_defaults(command_action="add_discord", service_key="discord", auto_start=True) discord_remove = discord_subparsers.add_parser( @@ -939,9 +879,7 @@ def command_main( common=common, service_key="discord", adapter_label="discord", - conversation_description=( - "Discord channel id. Omit to fall back to the single discord elephant." - ), + conversation_description=("Discord channel id. Omit to fall back to the single discord elephant."), ) dingding = subparsers.add_parser("dingding", parents=[common], help="Manage DingDing accounts.") @@ -950,14 +888,20 @@ def command_main( dingding_setup = dingding_subparsers.add_parser("setup", parents=[common], help="Add or update a DingDing account.") _add_dingding_add_options(dingding_setup) - dingding_setup.add_argument("--no-start", action="store_true", help="Only save config, do not start the adapter after setup.") + dingding_setup.add_argument( + "--no-start", + action="store_true", + help="Only save config, do not start the adapter after setup.", + ) dingding_setup.set_defaults(command_action="add_dingding", service_key="dingding", auto_start=True) dingding_remove = dingding_subparsers.add_parser("remove", parents=[common], help="Remove a DingDing account.") _add_required_account_argument(dingding_remove, help_text="DingDing account id to remove.") dingding_remove.set_defaults(command_action="remove_dingding", service_key="dingding") - dingding_start = dingding_subparsers.add_parser("start", parents=[common], help="Start all or one DingDing account.") + dingding_start = dingding_subparsers.add_parser( + "start", parents=[common], help="Start all or one DingDing account." + ) _add_dingding_start_options(dingding_start) dingding_start.set_defaults(command_action="start", service_key="dingding") @@ -969,7 +913,9 @@ def command_main( _add_dingding_stop_options(dingding_stop) dingding_stop.set_defaults(command_action="stop", service_key="dingding") - dingding_restart = dingding_subparsers.add_parser("restart", parents=[common], help="Restart all or one DingDing account.") + dingding_restart = dingding_subparsers.add_parser( + "restart", parents=[common], help="Restart all or one DingDing account." + ) _add_dingding_restart_options(dingding_restart) dingding_restart.set_defaults(command_action="restart", service_key="dingding") @@ -977,11 +923,18 @@ def command_main( _add_dingding_logs_options(dingding_logs) dingding_logs.set_defaults(command_action="logs", service_key="dingding") - dingding_describe = dingding_subparsers.add_parser("describe", parents=[common], help="Print resolved DingDing account wiring as JSON.") + dingding_describe = dingding_subparsers.add_parser( + "describe", + parents=[common], + help="Print resolved DingDing account wiring as JSON.", + ) dingding_describe.set_defaults(command_action="describe", service_key="dingding") dingding_doctor = dingding_subparsers.add_parser("doctor", parents=[common], help="Check DingDing health.") - _add_optional_account_argument(dingding_doctor, help_text="DingDing account id. Omit to inspect all DingDing accounts.") + _add_optional_account_argument( + dingding_doctor, + help_text="DingDing account id. Omit to inspect all DingDing accounts.", + ) dingding_doctor.set_defaults(command_action="doctor", service_key="dingding") weixin = subparsers.add_parser("weixin", parents=[common], help="Manage WeChat accounts.") @@ -990,7 +943,11 @@ def command_main( weixin_setup = weixin_subparsers.add_parser("setup", parents=[common], help="Add or update a WeChat account.") _add_weixin_add_options(weixin_setup) - weixin_setup.add_argument("--no-start", action="store_true", help="Only save config, do not start the adapter after setup.") + weixin_setup.add_argument( + "--no-start", + action="store_true", + help="Only save config, do not start the adapter after setup.", + ) weixin_setup.set_defaults(command_action="add_weixin", service_key="weixin", auto_start=True) weixin_remove = weixin_subparsers.add_parser("remove", parents=[common], help="Remove a WeChat account.") @@ -1009,7 +966,9 @@ def command_main( _add_weixin_stop_options(weixin_stop) weixin_stop.set_defaults(command_action="stop", service_key="weixin") - weixin_restart = weixin_subparsers.add_parser("restart", parents=[common], help="Restart all or one WeChat account.") + weixin_restart = weixin_subparsers.add_parser( + "restart", parents=[common], help="Restart all or one WeChat account." + ) _add_weixin_restart_options(weixin_restart) weixin_restart.set_defaults(command_action="restart", service_key="weixin") @@ -1017,11 +976,18 @@ def command_main( _add_weixin_logs_options(weixin_logs) weixin_logs.set_defaults(command_action="logs", service_key="weixin") - weixin_describe = weixin_subparsers.add_parser("describe", parents=[common], help="Print resolved WeChat account wiring as JSON.") + weixin_describe = weixin_subparsers.add_parser( + "describe", + parents=[common], + help="Print resolved WeChat account wiring as JSON.", + ) weixin_describe.set_defaults(command_action="describe", service_key="weixin") weixin_doctor = weixin_subparsers.add_parser("doctor", parents=[common], help="Check WeChat health.") - _add_optional_account_argument(weixin_doctor, help_text="WeChat account id. Omit to inspect all WeChat accounts.") + _add_optional_account_argument( + weixin_doctor, + help_text="WeChat account id. Omit to inspect all WeChat accounts.", + ) weixin_doctor.set_defaults(command_action="doctor", service_key="weixin") _add_message_subparser( @@ -1038,7 +1004,11 @@ def command_main( wecom_setup = wecom_subparsers.add_parser("setup", parents=[common], help="Add or update a WeCom account.") _add_wecom_add_options(wecom_setup) - wecom_setup.add_argument("--no-start", action="store_true", help="Only save config, do not start the adapter after setup.") + wecom_setup.add_argument( + "--no-start", + action="store_true", + help="Only save config, do not start the adapter after setup.", + ) wecom_setup.set_defaults(command_action="add_wecom", service_key="wecom", auto_start=True) wecom_remove = wecom_subparsers.add_parser("remove", parents=[common], help="Remove a WeCom account.") @@ -1065,7 +1035,11 @@ def command_main( _add_wecom_logs_options(wecom_logs) wecom_logs.set_defaults(command_action="logs", service_key="wecom") - wecom_describe = wecom_subparsers.add_parser("describe", parents=[common], help="Print resolved WeCom account wiring as JSON.") + wecom_describe = wecom_subparsers.add_parser( + "describe", + parents=[common], + help="Print resolved WeCom account wiring as JSON.", + ) wecom_describe.set_defaults(command_action="describe", service_key="wecom") wecom_doctor = wecom_subparsers.add_parser("doctor", parents=[common], help="Check WeCom health.") @@ -1215,6 +1189,7 @@ def ensure_managed_service() -> GatewayManagedService: raise TypeError("gateway service plugin 'feishu' must build FeishuGatewayService") return _run_start(feishu_service, args) + def run_im_setup( *, default_state_dir: Path | None = None, @@ -1241,6 +1216,7 @@ def run_im_setup( default_control_state_dir=default_control_state_dir, ) + def build_typer_app( *, default_state_dir: Path | None = None, @@ -1271,19 +1247,35 @@ def gateway_callback(ctx: typer.Context) -> None: if ctx.invoked_subcommand is None: raise typer.Exit(_forward(ctx)) - @app.command("setup", help="Open interactive IM setup.", context_settings=passthrough_settings) + @app.command( + "setup", + help="Open interactive IM setup.", + context_settings=passthrough_settings, + ) def setup_command(ctx: typer.Context) -> None: raise typer.Exit(_forward(ctx, "setup")) - @app.command("status", help="Show status for all providers and accounts.", context_settings=passthrough_settings) + @app.command( + "status", + help="Show status for all providers and accounts.", + context_settings=passthrough_settings, + ) def status_command(ctx: typer.Context) -> None: raise typer.Exit(_forward(ctx, "status")) - @app.command("doctor", help="Run health checks for all providers and accounts.", context_settings=passthrough_settings) + @app.command( + "doctor", + help="Run health checks for all providers and accounts.", + context_settings=passthrough_settings, + ) def doctor_command(ctx: typer.Context) -> None: raise typer.Exit(_forward(ctx, "doctor")) - @app.command("describe", help="Print resolved IM provider and account wiring as JSON.", context_settings=passthrough_settings) + @app.command( + "describe", + help="Print resolved IM provider and account wiring as JSON.", + context_settings=passthrough_settings, + ) def describe_command(ctx: typer.Context) -> None: raise typer.Exit(_forward(ctx, "describe")) @@ -1291,11 +1283,19 @@ def describe_command(ctx: typer.Context) -> None: def feishu_command(ctx: typer.Context) -> None: raise typer.Exit(_forward(ctx, "feishu")) - @app.command("discord", help="Manage Discord accounts.", context_settings=passthrough_settings) + @app.command( + "discord", + help="Manage Discord accounts.", + context_settings=passthrough_settings, + ) def discord_command(ctx: typer.Context) -> None: raise typer.Exit(_forward(ctx, "discord")) - @app.command("dingding", help="Manage DingDing accounts.", context_settings=passthrough_settings) + @app.command( + "dingding", + help="Manage DingDing accounts.", + context_settings=passthrough_settings, + ) def dingding_command(ctx: typer.Context) -> None: raise typer.Exit(_forward(ctx, "dingding")) diff --git a/apps/gateway/gateway_main_parser.py b/apps/gateway/gateway_main_parser.py index c77cfe1..1233814 100644 --- a/apps/gateway/gateway_main_parser.py +++ b/apps/gateway/gateway_main_parser.py @@ -1,74 +1,23 @@ """Gateway parser, account, and status helpers.""" from __future__ import annotations -import asyncio -from argparse import SUPPRESS, ArgumentParser, Namespace -from collections.abc import Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass -from datetime import UTC, datetime -import getpass -import apps.cli.wizard as cli_wizard -import importlib.util +from argparse import Namespace +from collections.abc import Iterable, Mapping import json -import os from pathlib import Path -import re -import shlex -import signal -import subprocess -import sys -import time -from wsgiref.simple_server import make_server - -from apps.cli.runtime import CliRuntime -from apps.cli.shell import ( - Align, - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_LIGHT, - BRAND_MUTED, - Console, - Group, - Panel, - RICH_AVAILABLE, - Table, - Text, - _resolve_elephant_version, - render_elephant_mark, -) -from apps.provider_runtime import load_provider_profile, load_runtime_local_secret_env -from apps.runtime_layout import default_cli_state_dir, default_gateway_state_dir + +from apps.provider_runtime import load_provider_profile +from apps.runtime_layout import default_cli_state_dir from packages.runtime_config import global_config_path_for_state_dir from . import ( - DEFAULT_DINGDING_CLIENT_ID_ENV, - DEFAULT_DINGDING_CLIENT_SECRET_ENV, - DEFAULT_DINGDING_ROBOT_CODE_ENV, - DEFAULT_DISCORD_BOT_TOKEN_ENV, - DEFAULT_FEISHU_APP_ID_ENV, - DEFAULT_FEISHU_APP_SECRET_ENV, - DEFAULT_FEISHU_EVENT_PATH, - DEFAULT_WECOM_BOT_ID_ENV, - DEFAULT_WECOM_SECRET_ENV, - DINGDING_ADAPTER_ID, - FEISHU_ADAPTER_ID, - GatewayHttpService, - GatewayManagedRuntime, GatewayManagedService, - SUPPORTED_DINGDING_TRANSPORTS, - SUPPORTED_DISCORD_TRANSPORTS, - SUPPORTED_FEISHU_TRANSPORTS, - SUPPORTED_WECOM_TRANSPORTS, - SUPPORTED_WEIXIN_TRANSPORTS, - WECOM_ADAPTER_ID, - WEIXIN_ADAPTER_ID, build_gateway_app, build_gateway_plugin_registry, - create_gateway_web_app, ) -from .dingding import DINGTALK_STREAM_PIP_SPEC, DingdingGatewayService -from .discord import DISCORD_PY_PIP_SPEC, DiscordGatewayService -from .feishu import FEISHU_SDK_PIP_SPEC, FeishuGatewayService +from .dingding import DingdingGatewayService +from .discord import DiscordGatewayService +from .feishu import FeishuGatewayService from .wecom import WecomGatewayService from .weixin import WeixinGatewayService @@ -104,12 +53,14 @@ from .gateway_main_parser_doctor import * # noqa: F401,F403 from .gateway_main_parser_doctor import __all__ as _DOCTOR_ALL + def _build_registry(): return build_gateway_plugin_registry() + def _gateway_provider_profile_for(args: Namespace): """Load provider profile from the canonical CLI control runtime config. - + Always uses cli_state_dir to ensure consistent configuration across all IM components. """ cli_state_dir = getattr(args, "cli_state_dir", None) @@ -117,7 +68,7 @@ def _gateway_provider_profile_for(args: Namespace): cli_state_dir = default_cli_state_dir() if cli_state_dir is None: return None - + state_dir = Path(cli_state_dir) config_path = global_config_path_for_state_dir(state_dir) profile = load_provider_profile(state_dir, config_path=config_path) @@ -146,12 +97,11 @@ def _build_app(args: Namespace, *, registry=None): ) return app + def _service_kwargs_for(service_key: str, args: Namespace) -> dict[str, object]: if service_key == "discord": return { - "default_cli_state_dir": ( - None if args.cli_state_dir is None else str(args.cli_state_dir) - ), + "default_cli_state_dir": (None if args.cli_state_dir is None else str(args.cli_state_dir)), "environ": _gateway_runtime_environ( args.state_dir, cli_state_dir=args.cli_state_dir, @@ -161,9 +111,7 @@ def _service_kwargs_for(service_key: str, args: Namespace) -> dict[str, object]: } if service_key == "feishu": return { - "default_cli_state_dir": ( - None if args.cli_state_dir is None else str(args.cli_state_dir) - ), + "default_cli_state_dir": (None if args.cli_state_dir is None else str(args.cli_state_dir)), "environ": _gateway_runtime_environ( args.state_dir, cli_state_dir=args.cli_state_dir, @@ -172,9 +120,7 @@ def _service_kwargs_for(service_key: str, args: Namespace) -> dict[str, object]: } if service_key == "dingding": return { - "default_cli_state_dir": ( - None if args.cli_state_dir is None else str(args.cli_state_dir) - ), + "default_cli_state_dir": (None if args.cli_state_dir is None else str(args.cli_state_dir)), "environ": _gateway_runtime_environ( args.state_dir, cli_state_dir=args.cli_state_dir, @@ -184,9 +130,7 @@ def _service_kwargs_for(service_key: str, args: Namespace) -> dict[str, object]: } if service_key == "weixin": return { - "default_cli_state_dir": ( - None if args.cli_state_dir is None else str(args.cli_state_dir) - ), + "default_cli_state_dir": (None if args.cli_state_dir is None else str(args.cli_state_dir)), "environ": _gateway_runtime_environ( args.state_dir, cli_state_dir=args.cli_state_dir, @@ -196,9 +140,7 @@ def _service_kwargs_for(service_key: str, args: Namespace) -> dict[str, object]: } if service_key == "wecom": return { - "default_cli_state_dir": ( - None if args.cli_state_dir is None else str(args.cli_state_dir) - ), + "default_cli_state_dir": (None if args.cli_state_dir is None else str(args.cli_state_dir)), "environ": _gateway_runtime_environ( args.state_dir, cli_state_dir=args.cli_state_dir, @@ -208,6 +150,7 @@ def _service_kwargs_for(service_key: str, args: Namespace) -> dict[str, object]: } return {} + def _build_services( args: Namespace, *, @@ -216,11 +159,7 @@ def _build_services( registry = _build_registry() app = _build_app(args, registry=registry) manifest = app.loaded_profile.manifest if app.loaded_profile is not None else None - resolved_keys = ( - tuple(service_keys) - if service_keys is not None - else registry.configured_service_keys(manifest) - ) + resolved_keys = tuple(service_keys) if service_keys is not None else registry.configured_service_keys(manifest) services = { key: registry.create_service( key, @@ -231,6 +170,7 @@ def _build_services( } return app, services + def _build_service( args: Namespace, *, @@ -247,12 +187,14 @@ def _build_service( ) return service + def _build_feishu_service(args: Namespace) -> FeishuGatewayService: service = _build_service(args, service_key="feishu", respect_enabled=False) if not isinstance(service, FeishuGatewayService): raise TypeError("gateway service plugin 'feishu' must build FeishuGatewayService") return service + def _build_discord_service(args: Namespace) -> DiscordGatewayService: service = _build_service(args, service_key="discord", respect_enabled=False) if not isinstance(service, DiscordGatewayService): @@ -280,20 +222,21 @@ def _build_wecom_service(args: Namespace) -> WecomGatewayService: raise TypeError("gateway service plugin 'wecom' must build WecomGatewayService") return service + def _build_managed_service(args: Namespace, *, service_key: str) -> GatewayManagedService: service = _build_service(args, service_key=service_key, respect_enabled=False) if not isinstance(service, GatewayManagedService): - raise TypeError( - f"gateway service plugin '{service_key}' must build a managed gateway service" - ) + raise TypeError(f"gateway service plugin '{service_key}' must build a managed gateway service") return service + def _describe_payload(service_key: str, service) -> dict[str, object]: return { "gateway": dict(service.app.setup_summary()), service_key: dict(service.describe()), } + def _describe_services_payload( app, services: Mapping[str, object], @@ -301,9 +244,7 @@ def _describe_services_payload( payload: dict[str, object] = { "gateway": dict(app.setup_summary()), "services": { - key: dict(service.describe()) - for key, service in services.items() - if hasattr(service, "describe") + key: dict(service.describe()) for key, service in services.items() if hasattr(service, "describe") }, } for key, service in services.items(): @@ -311,6 +252,7 @@ def _describe_services_payload( payload[key] = dict(service.describe()) return payload + def _print_json(payload: dict[str, object]) -> None: print(json.dumps(payload, ensure_ascii=False, indent=2, default=str)) @@ -324,6 +266,7 @@ def _run_status_all(args: Namespace) -> int: try: from apps.daemon_command import daemon_is_running from apps.daemon_command import _read_pid, _daemon_pid_path + state_dir_path = Path(args.state_dir) daemon_pid = _read_pid(_daemon_pid_path(state_dir_path)) if daemon_is_running(state_dir_path): @@ -364,4 +307,25 @@ def _run_status_all(args: Namespace) -> int: return 0 -__all__ = [*_STATE_ALL, *_PROVIDER_ALL, *_DOCTOR_ALL, *['_build_registry', '_build_app', '_service_kwargs_for', '_build_services', '_build_service', '_build_feishu_service', '_build_discord_service', '_build_dingding_service', '_build_weixin_service', '_build_wecom_service', '_build_managed_service', '_describe_payload', '_describe_services_payload', '_print_json', '_run_status_all']] +__all__ = [ + *_STATE_ALL, + *_PROVIDER_ALL, + *_DOCTOR_ALL, + *[ + "_build_registry", + "_build_app", + "_service_kwargs_for", + "_build_services", + "_build_service", + "_build_feishu_service", + "_build_discord_service", + "_build_dingding_service", + "_build_weixin_service", + "_build_wecom_service", + "_build_managed_service", + "_describe_payload", + "_describe_services_payload", + "_print_json", + "_run_status_all", + ], +] diff --git a/apps/gateway/gateway_main_parser_doctor.py b/apps/gateway/gateway_main_parser_doctor.py index 0aef9e4..5d74ee1 100644 --- a/apps/gateway/gateway_main_parser_doctor.py +++ b/apps/gateway/gateway_main_parser_doctor.py @@ -9,6 +9,7 @@ from .gateway_main_parser_state import * # noqa: F401,F403 from .gateway_main_runtime import * # noqa: F401,F403 + def _mapping_payload(value: object) -> dict[str, object]: return dict(value) if isinstance(value, Mapping) else {} @@ -25,6 +26,7 @@ def _render_feishu_account_line(account: Mapping[str, object], *, prefix: str = parts.append(f"app_id={resolved_app_id}") return f"{prefix}: " + " · ".join(parts) + def _selected_account_payloads( description: Mapping[str, object], *, @@ -35,9 +37,7 @@ def _selected_account_payloads( if account_id is None: return accounts matched = tuple( - account - for account in accounts - if str(account.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID) == account_id + account for account in accounts if str(account.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID) == account_id ) if matched: return matched @@ -50,7 +50,9 @@ def _next_steps(service) -> tuple[str, ...]: control = dict(description.get("control") or {}) steps: list[str] = [] if description.get("sdk_dependency_status") == "missing_optional_dependency": - steps.append("Elephant Agent will auto-install the Feishu SDK when you run `elephant gateway` or `elephant gateway feishu start`.") + steps.append( + "Elephant Agent will auto-install the Feishu SDK when you run `elephant gateway` or `elephant gateway feishu start`." + ) if any(account.get("credentials_status") != "configured" for account in accounts if isinstance(account, dict)): env_vars: list[str] = [] secret_reference_ids: list[str] = [] @@ -59,9 +61,7 @@ def _next_steps(service) -> tuple[str, ...]: continue credential_env_vars = account.get("credential_env_vars") if isinstance(credential_env_vars, (list, tuple)): - env_vars.extend( - value for value in credential_env_vars if isinstance(value, str) and value - ) + env_vars.extend(value for value in credential_env_vars if isinstance(value, str) and value) else: env_vars.extend( value @@ -73,9 +73,7 @@ def _next_steps(service) -> tuple[str, ...]: ) secret_refs = account.get("secret_reference_ids") if isinstance(secret_refs, (list, tuple)): - secret_reference_ids.extend( - value for value in secret_refs if isinstance(value, str) and value - ) + secret_reference_ids.extend(value for value in secret_refs if isinstance(value, str) and value) if env_vars: steps.append( "Complete Feishu IM setup again with `elephant gateway` to store the App ID and App Secret locally, or export these advanced credential aliases manually: " @@ -99,6 +97,7 @@ def _next_steps(service) -> tuple[str, ...]: steps.append("IM wiring looks healthy. Start it with `elephant gateway feishu start`. ") return tuple(steps) + def _render_discord_account_line(account: Mapping[str, object], *, prefix: str = "discord_account") -> str: allow_guild_ids = tuple(account.get("allow_guild_ids") or ()) allow_channel_ids = tuple(account.get("allow_channel_ids") or ()) @@ -117,6 +116,7 @@ def _render_discord_account_line(account: Mapping[str, object], *, prefix: str = parts.append(f"error={credentials_error}") return f"{prefix}: " + " · ".join(parts) + def _feishu_async_status_lines( description: Mapping[str, object], *, @@ -142,6 +142,7 @@ def _feishu_async_status_lines( ) return tuple(lines) + def _discord_account_status_lines( description: Mapping[str, object], *, @@ -165,6 +166,7 @@ def _discord_account_status_lines( + (", ".join(str(account_id) for account_id in disabled_account_ids if account_id) or ""), ) + def _discord_portal_checklist() -> tuple[str, ...]: return ( "Open Discord Developer Portal → OAuth2 → URL Generator and include the `bot` scope before inviting the app.", @@ -172,6 +174,7 @@ def _discord_portal_checklist() -> tuple[str, ...]: "Grant these bot permissions in Discord: `View Channels` (`查看频道`), `Send Messages` (`发送消息`), `Send Messages in Threads` (`在子区内发送消息`), and `Read Message History` (`阅读消息历史记录`).", ) + def _discord_next_steps(service) -> tuple[str, ...]: description = service.describe() accounts = tuple(description.get("accounts") or ()) @@ -208,9 +211,7 @@ def _discord_next_steps(service) -> tuple[str, ...]: steps.append( "Enable at least one Discord account for runtime starts by re-running `elephant gateway discord setup [account-id]` before starting the gateway runtime." ) - steps.append( - "Review the Discord developer portal checklist below before starting the gateway runtime." - ) + steps.append("Review the Discord developer portal checklist below before starting the gateway runtime.") if runtime_status == "running": if service_status == "degraded" and blocked_account_ids: steps.append( @@ -223,11 +224,10 @@ def _discord_next_steps(service) -> tuple[str, ...]: "Discord wiring is partially ready. Start it with `elephant gateway discord start`; runnable enabled accounts will connect while blocked accounts are skipped." ) elif service_status == "ready" and runnable_accounts > 0: - steps.append( - "Discord wiring looks healthy. Start it with `elephant gateway discord start`." - ) + steps.append("Discord wiring looks healthy. Start it with `elephant gateway discord start`.") return tuple(steps) + def _doctor_lines(service, args: Namespace) -> tuple[str, ...]: description = service.describe() control = dict(description.get("control") or {}) @@ -252,13 +252,13 @@ def _doctor_lines(service, args: Namespace) -> tuple[str, ...]: lines.append(f"control_runtime_error: {control['runtime_error']}") known_elephants = tuple(control.get("known_elephants") or ()) lines.append( - "control_known_elephants: " - + (", ".join(str(elephant) for elephant in known_elephants if elephant) or "") + "control_known_elephants: " + (", ".join(str(elephant) for elephant in known_elephants if elephant) or "") ) lines.append("next_steps:") lines.extend(f"- {step}" for step in _next_steps(service)) return tuple(lines) + def _discord_doctor_lines(service, args: Namespace) -> tuple[str, ...]: description = service.describe() runtime = _mapping_payload(description.get("runtime")) @@ -287,8 +287,7 @@ def _discord_doctor_lines(service, args: Namespace) -> tuple[str, ...]: lines.append(f"control_runtime_error: {control['runtime_error']}") known_elephants = tuple(control.get("known_elephants") or ()) lines.append( - "control_known_elephants: " - + (", ".join(str(elephant) for elephant in known_elephants if elephant) or "") + "control_known_elephants: " + (", ".join(str(elephant) for elephant in known_elephants if elephant) or "") ) for account in _selected_account_payloads( description, @@ -317,10 +316,14 @@ def _dingding_next_steps(service) -> tuple[str, ...]: accounts = tuple(description.get("accounts") or ()) steps: list[str] = [] if description.get("sdk_dependency_status") == "missing_optional_dependency": - steps.append("Elephant Agent will auto-install DingDing support when you run `elephant gateway dingding start`.") + steps.append( + "Elephant Agent will auto-install DingDing support when you run `elephant gateway dingding start`." + ) missing_credentials = [ - account for account in accounts - if isinstance(account, dict) and account.get("enabled") is not False + account + for account in accounts + if isinstance(account, dict) + and account.get("enabled") is not False and account.get("credentials_status") != "configured" ] if missing_credentials: @@ -376,7 +379,9 @@ def _weixin_next_steps(service) -> tuple[str, ...]: description = service.describe() steps: list[str] = [] if description.get("sdk_dependency_status") == "missing_optional_dependency": - steps.append("Elephant Agent will auto-install WeChat (iLink) support when you run `elephant gateway weixin start`.") + steps.append( + "Elephant Agent will auto-install WeChat (iLink) support when you run `elephant gateway weixin start`." + ) if not steps: steps.append("WeChat wiring looks healthy. Start it with `elephant gateway weixin start`.") return tuple(steps) @@ -427,8 +432,10 @@ def _wecom_next_steps(service) -> tuple[str, ...]: if description.get("sdk_dependency_status") == "missing_optional_dependency": steps.append("Elephant Agent will auto-install WeCom support when you run `elephant gateway wecom start`.") missing_credentials = [ - account for account in accounts - if isinstance(account, dict) and account.get("enabled") is not False + account + for account in accounts + if isinstance(account, dict) + and account.get("enabled") is not False and account.get("credentials_status") != "configured" ] if missing_credentials: @@ -466,6 +473,7 @@ def _wecom_doctor_lines(service, args: Namespace) -> tuple[str, ...]: lines.extend(f"- {step}" for step in _wecom_next_steps(service)) return tuple(lines) + def _doctor_service_lines( service_key: str, service, @@ -513,14 +521,10 @@ def render_account_line(account: Mapping[str, object]) -> str: f"service[{service_key}].configured_transport_error: {description.get('configured_transport_error')}" ) if description.get("sdk_dependency_status") is not None: - lines.append( - f"service[{service_key}].sdk_dependency_status: {description.get('sdk_dependency_status')}" - ) + lines.append(f"service[{service_key}].sdk_dependency_status: {description.get('sdk_dependency_status')}") runtime = _mapping_payload(description.get("runtime")) if runtime: - lines.append( - f"service[{service_key}].runtime_status: {runtime.get('runtime_status') or ''}" - ) + lines.append(f"service[{service_key}].runtime_status: {runtime.get('runtime_status') or ''}") if runtime.get("target") is not None: lines.append(f"service[{service_key}].runtime_target: {runtime.get('target')}") for account in tuple(description.get("accounts") or ()): @@ -529,6 +533,7 @@ def render_account_line(account: Mapping[str, object]) -> str: lines.append(render_account_line(account)) return tuple(lines) + def _doctor_services_lines(app, services: Mapping[str, object], args: Namespace) -> tuple[str, ...]: lines = [ "Elephant Agent Gateway doctor", @@ -537,9 +542,7 @@ def _doctor_services_lines(app, services: Mapping[str, object], args: Namespace) "registered_services: " + (", ".join(services.keys()) or ""), ] lines.extend( - line - for service_key, service in services.items() - for line in _doctor_service_lines(service_key, service) + line for service_key, service in services.items() for line in _doctor_service_lines(service_key, service) ) if "feishu" in services: lines.append("next_steps:") @@ -558,6 +561,7 @@ def _doctor_services_lines(app, services: Mapping[str, object], args: Namespace) lines.extend(f"- {step}" for step in _wecom_next_steps(services["wecom"])) return tuple(lines) + def _service_runtime_status_summary(service: object, args: Namespace) -> tuple[str, str | None]: if not isinstance(service, GatewayManagedService): return "unavailable", "service is not a managed runtime" @@ -570,4 +574,27 @@ def _service_runtime_status_summary(service: object, args: Namespace) -> tuple[s return "unavailable", str(exc) -__all__ = ['_render_feishu_account_line', '_selected_account_payloads', '_next_steps', '_render_discord_account_line', '_feishu_async_status_lines', '_discord_account_status_lines', '_discord_portal_checklist', '_discord_next_steps', '_doctor_lines', '_discord_doctor_lines', '_render_dingding_account_line', '_dingding_next_steps', '_dingding_doctor_lines', '_render_weixin_account_line', '_weixin_next_steps', '_weixin_doctor_lines', '_render_wecom_account_line', '_wecom_next_steps', '_wecom_doctor_lines', '_doctor_service_lines', '_doctor_services_lines', '_service_runtime_status_summary'] +__all__ = [ + "_render_feishu_account_line", + "_selected_account_payloads", + "_next_steps", + "_render_discord_account_line", + "_feishu_async_status_lines", + "_discord_account_status_lines", + "_discord_portal_checklist", + "_discord_next_steps", + "_doctor_lines", + "_discord_doctor_lines", + "_render_dingding_account_line", + "_dingding_next_steps", + "_dingding_doctor_lines", + "_render_weixin_account_line", + "_weixin_next_steps", + "_weixin_doctor_lines", + "_render_wecom_account_line", + "_wecom_next_steps", + "_wecom_doctor_lines", + "_doctor_service_lines", + "_doctor_services_lines", + "_service_runtime_status_summary", +] diff --git a/apps/gateway/gateway_main_parser_providers.py b/apps/gateway/gateway_main_parser_providers.py index 73a2149..275016b 100644 --- a/apps/gateway/gateway_main_parser_providers.py +++ b/apps/gateway/gateway_main_parser_providers.py @@ -3,7 +3,6 @@ from __future__ import annotations from argparse import SUPPRESS, ArgumentParser -from pathlib import Path from .gateway_main_parser_state import * # noqa: F401,F403 from .gateway_main_parser_state import ( @@ -16,6 +15,7 @@ from .gateway_main_runtime import * # noqa: F401,F403 from .gateway_main_wizard import * # noqa: F401,F403 + def _add_discord_runtime_target_options( parser: ArgumentParser, *, @@ -32,6 +32,7 @@ def _add_discord_runtime_target_options( if include_account_id: parser.add_argument("--account-id", dest="account_id_flag", help=SUPPRESS) + def _add_discord_start_options(parser: ArgumentParser) -> None: _add_discord_runtime_target_options(parser, include_account_id=True) _add_optional_account_argument( @@ -44,6 +45,7 @@ def _add_discord_start_options(parser: ArgumentParser) -> None: help="Start the Discord gateway transport in a background process and return immediately.", ) + def _add_discord_status_options(parser: ArgumentParser) -> None: _add_discord_runtime_target_options(parser, include_account_id=True) _add_optional_account_argument( @@ -51,6 +53,7 @@ def _add_discord_status_options(parser: ArgumentParser) -> None: help_text="Discord account id. Omit to inspect the provider-wide runtime and all accounts.", ) + def _add_discord_stop_options(parser: ArgumentParser) -> None: _add_discord_runtime_target_options(parser, include_account_id=True) _add_optional_account_argument( @@ -69,6 +72,7 @@ def _add_discord_stop_options(parser: ArgumentParser) -> None: help="Send SIGKILL when the process does not exit within --timeout.", ) + def _add_discord_restart_options(parser: ArgumentParser) -> None: _add_discord_runtime_target_options(parser, include_account_id=True) _add_optional_account_argument( @@ -87,6 +91,7 @@ def _add_discord_restart_options(parser: ArgumentParser) -> None: help="Send SIGKILL when the previous process does not exit within --timeout.", ) + def _add_discord_logs_options(parser: ArgumentParser) -> None: _add_discord_runtime_target_options(parser, include_account_id=True) _add_required_account_argument( @@ -110,6 +115,7 @@ def _add_discord_logs_options(parser: ArgumentParser) -> None: help="Print the resolved log file path and exit.", ) + def _add_discord_add_options(parser: ArgumentParser) -> None: _add_optional_account_argument( parser, @@ -192,6 +198,7 @@ def _add_discord_add_options(parser: ArgumentParser) -> None: help=SUPPRESS, ) + def _add_feishu_runtime_target_options( parser: ArgumentParser, *, @@ -208,6 +215,7 @@ def _add_feishu_runtime_target_options( if include_account_id: parser.add_argument("--account-id", dest="account_id_flag", help=SUPPRESS) + def _add_feishu_start_options(parser: ArgumentParser) -> None: _add_feishu_runtime_target_options(parser, include_account_id=True) _add_optional_account_argument( @@ -221,6 +229,7 @@ def _add_feishu_start_options(parser: ArgumentParser) -> None: ) _add_http_server_options(parser) + def _add_feishu_status_options(parser: ArgumentParser) -> None: _add_feishu_runtime_target_options(parser, include_account_id=True) _add_optional_account_argument( @@ -228,6 +237,7 @@ def _add_feishu_status_options(parser: ArgumentParser) -> None: help_text="Feishu account id. Omit to inspect the provider-wide runtime and all accounts.", ) + def _add_feishu_stop_options(parser: ArgumentParser) -> None: _add_feishu_runtime_target_options(parser, include_account_id=True) _add_optional_account_argument( @@ -246,6 +256,7 @@ def _add_feishu_stop_options(parser: ArgumentParser) -> None: help="Send SIGKILL when the process does not exit within --timeout.", ) + def _add_feishu_restart_options(parser: ArgumentParser) -> None: _add_feishu_runtime_target_options(parser, include_account_id=True) _add_optional_account_argument( @@ -265,6 +276,7 @@ def _add_feishu_restart_options(parser: ArgumentParser) -> None: help="Send SIGKILL when the previous process does not exit within --timeout.", ) + def _add_feishu_logs_options(parser: ArgumentParser) -> None: _add_feishu_runtime_target_options(parser, include_account_id=True) _add_required_account_argument( @@ -288,6 +300,7 @@ def _add_feishu_logs_options(parser: ArgumentParser) -> None: help="Print the resolved log file path and exit.", ) + def _add_feishu_add_options(parser: ArgumentParser) -> None: _add_optional_account_argument( parser, @@ -861,5 +874,40 @@ def _add_wecom_add_options(parser: ArgumentParser) -> None: ) - -__all__ = ['_add_discord_runtime_target_options', '_add_discord_start_options', '_add_discord_status_options', '_add_discord_stop_options', '_add_discord_restart_options', '_add_discord_logs_options', '_add_discord_add_options', '_add_feishu_runtime_target_options', '_add_feishu_start_options', '_add_feishu_status_options', '_add_feishu_stop_options', '_add_feishu_restart_options', '_add_feishu_logs_options', '_add_feishu_add_options', '_add_dingding_runtime_target_options', '_add_dingding_start_options', '_add_dingding_status_options', '_add_dingding_stop_options', '_add_dingding_restart_options', '_add_dingding_logs_options', '_add_dingding_add_options', '_add_weixin_runtime_target_options', '_add_weixin_start_options', '_add_weixin_status_options', '_add_weixin_stop_options', '_add_weixin_restart_options', '_add_weixin_logs_options', '_add_weixin_add_options', '_add_wecom_runtime_target_options', '_add_wecom_start_options', '_add_wecom_status_options', '_add_wecom_stop_options', '_add_wecom_restart_options', '_add_wecom_logs_options', '_add_wecom_add_options'] +__all__ = [ + "_add_discord_runtime_target_options", + "_add_discord_start_options", + "_add_discord_status_options", + "_add_discord_stop_options", + "_add_discord_restart_options", + "_add_discord_logs_options", + "_add_discord_add_options", + "_add_feishu_runtime_target_options", + "_add_feishu_start_options", + "_add_feishu_status_options", + "_add_feishu_stop_options", + "_add_feishu_restart_options", + "_add_feishu_logs_options", + "_add_feishu_add_options", + "_add_dingding_runtime_target_options", + "_add_dingding_start_options", + "_add_dingding_status_options", + "_add_dingding_stop_options", + "_add_dingding_restart_options", + "_add_dingding_logs_options", + "_add_dingding_add_options", + "_add_weixin_runtime_target_options", + "_add_weixin_start_options", + "_add_weixin_status_options", + "_add_weixin_stop_options", + "_add_weixin_restart_options", + "_add_weixin_logs_options", + "_add_weixin_add_options", + "_add_wecom_runtime_target_options", + "_add_wecom_start_options", + "_add_wecom_status_options", + "_add_wecom_stop_options", + "_add_wecom_restart_options", + "_add_wecom_logs_options", + "_add_wecom_add_options", +] diff --git a/apps/gateway/gateway_main_parser_state.py b/apps/gateway/gateway_main_parser_state.py index 975ec26..e55c7ef 100644 --- a/apps/gateway/gateway_main_parser_state.py +++ b/apps/gateway/gateway_main_parser_state.py @@ -1,42 +1,11 @@ """Gateway parser, account, and status helpers.""" from __future__ import annotations -import asyncio -from argparse import SUPPRESS, ArgumentParser, Namespace -from collections.abc import Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass -from datetime import UTC, datetime -import getpass -import apps.cli.wizard as cli_wizard -import importlib.util -import json -import os +from argparse import ArgumentParser, Namespace +from collections.abc import Mapping, Sequence from pathlib import Path import re -import shlex -import signal -import subprocess -import sys -import time -from wsgiref.simple_server import make_server - -from apps.cli.runtime import CliRuntime -from apps.cli.shell import ( - Align, - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_LIGHT, - BRAND_MUTED, - Console, - Group, - Panel, - RICH_AVAILABLE, - Table, - Text, - _resolve_elephant_version, - render_elephant_mark, -) -from apps.provider_runtime import load_runtime_local_secret_env + from apps.runtime_layout import default_cli_state_dir, default_gateway_state_dir from packages.gateway_core import DEFAULT_GATEWAY_ACCOUNT_ID @@ -47,30 +16,10 @@ DEFAULT_DISCORD_BOT_TOKEN_ENV, DEFAULT_FEISHU_APP_ID_ENV, DEFAULT_FEISHU_APP_SECRET_ENV, - DEFAULT_FEISHU_EVENT_PATH, DEFAULT_WECOM_BOT_ID_ENV, DEFAULT_WECOM_SECRET_ENV, - DINGDING_ADAPTER_ID, FEISHU_ADAPTER_ID, - GatewayHttpService, - GatewayManagedRuntime, - GatewayManagedService, - SUPPORTED_DINGDING_TRANSPORTS, - SUPPORTED_DISCORD_TRANSPORTS, - SUPPORTED_FEISHU_TRANSPORTS, - SUPPORTED_WECOM_TRANSPORTS, - SUPPORTED_WEIXIN_TRANSPORTS, - WECOM_ADAPTER_ID, - WEIXIN_ADAPTER_ID, - build_gateway_app, - build_gateway_plugin_registry, - create_gateway_web_app, ) -from .dingding import DINGTALK_STREAM_PIP_SPEC, DingdingGatewayService -from .discord import DISCORD_PY_PIP_SPEC, DiscordGatewayService -from .feishu import FEISHU_SDK_PIP_SPEC, FeishuGatewayService -from .wecom import WecomGatewayService -from .weixin import WeixinGatewayService try: from prompt_toolkit.application import Application @@ -97,11 +46,13 @@ from .gateway_main_runtime import * # noqa: F401,F403 from .gateway_main_wizard import * # noqa: F401,F403 + def _secret_reference_id(*, account_id: str, secret_key: str) -> str: normalized_account = re.sub(r"[^a-z0-9]+", "-", account_id.strip().lower()).strip("-") or "default" normalized_key = secret_key.replace("_", "-") return f"secret-feishu-{normalized_account}-{normalized_key}" + def _default_feishu_secret_env_var(*, account_id: str, secret_key: str) -> str: if account_id == DEFAULT_GATEWAY_ACCOUNT_ID: if secret_key == "app_id": @@ -112,6 +63,7 @@ def _default_feishu_secret_env_var(*, account_id: str, secret_key: str) -> str: suffix = "APP_ID" if secret_key == "app_id" else "APP_SECRET" return f"ELEPHANT_FEISHU_{normalized_account}_{suffix}" + def _build_feishu_secret_reference( *, account_id: str, @@ -126,6 +78,7 @@ def _build_feishu_secret_reference( "metadata": {"env_var": env_var}, } + def _find_feishu_account( accounts: Sequence[Mapping[str, object]], *, @@ -137,6 +90,7 @@ def _find_feishu_account( return account return None + def _account_secret_env_var( account_payload: Mapping[str, object] | None, *, @@ -168,6 +122,7 @@ def _account_secret_env_var( return text return None + def _payload_string_list(value: object) -> list[str]: if value is None: return [] @@ -180,6 +135,7 @@ def _payload_string_list(value: object) -> list[str]: resolved.append(text) return list(dict.fromkeys(resolved)) + def _resolved_cli_account_id(args: Namespace) -> str | None: raw_account_id = getattr(args, "account_id", None) direct = _optional_text(raw_account_id) if isinstance(raw_account_id, str) else None @@ -190,13 +146,17 @@ def _resolved_cli_account_id(args: Namespace) -> str | None: return None return _optional_text(raw_account_id_flag) + def _default_dingding_secret_env_var(*, account_id: str, secret_key: str) -> str: defaults = { ("default", "client_id"): DEFAULT_DINGDING_CLIENT_ID_ENV, ("default", "client_secret"): DEFAULT_DINGDING_CLIENT_SECRET_ENV, ("default", "robot_code"): DEFAULT_DINGDING_ROBOT_CODE_ENV, } - key = (DEFAULT_GATEWAY_ACCOUNT_ID if account_id == "default" else account_id, secret_key) + key = ( + DEFAULT_GATEWAY_ACCOUNT_ID if account_id == "default" else account_id, + secret_key, + ) if key in defaults: return defaults[key] normalized_account = re.sub(r"[^A-Za-z0-9]+", "_", account_id.strip()).strip("_").upper() or "DEFAULT" @@ -253,7 +213,9 @@ def _upsert_dingding_account( return resolved -def _dingding_account_secret_env_vars(account_payload: Mapping[str, object]) -> tuple[str, ...]: +def _dingding_account_secret_env_vars( + account_payload: Mapping[str, object], +) -> tuple[str, ...]: env_payload = _mapping(account_payload.get("env")) or {} env_vars: list[str] = [] for key in ("client_id", "client_secret", "robot_code"): @@ -269,6 +231,7 @@ def _default_discord_bot_token_env_var(*, account_id: str) -> str: normalized_account = re.sub(r"[^A-Za-z0-9]+", "_", account_id.strip()).strip("_").upper() or "DEFAULT" return f"ELEPHANT_DISCORD_{normalized_account}_BOT_TOKEN" + def _find_discord_account( accounts: Sequence[Mapping[str, object]], *, @@ -280,6 +243,7 @@ def _find_discord_account( return account return None + def _resolved_discord_bot_token_env_var( *, explicit_env_var: object, @@ -299,6 +263,7 @@ def _resolved_discord_bot_token_env_var( return text return _default_discord_bot_token_env_var(account_id=account_id) + def _is_unconfigured_default_discord_account( account_payload: Mapping[str, object], *, @@ -325,6 +290,7 @@ def _is_unconfigured_default_discord_account( return False return True + def _upsert_discord_account( accounts: Sequence[Mapping[str, object]], account_payload: Mapping[str, object], @@ -357,6 +323,7 @@ def _upsert_discord_account( updated.append({str(key): value for key, value in account_payload.items()}) return updated + def _resolved_feishu_secret_env_var( *, explicit_env_var: object, @@ -373,6 +340,7 @@ def _resolved_feishu_secret_env_var( secret_key=secret_key, ) + def _upsert_feishu_account( accounts: Sequence[Mapping[str, object]], account_payload: Mapping[str, object], @@ -387,6 +355,7 @@ def _upsert_feishu_account( resolved.append(dict(account_payload)) return resolved + def _remove_account_payload( accounts: Sequence[Mapping[str, object]], *, @@ -404,12 +373,18 @@ def _remove_account_payload( raise SystemExit(f"unknown gateway account: {account_id}") return updated, removed -def _discord_account_secret_env_vars(account_payload: Mapping[str, object]) -> tuple[str, ...]: + +def _discord_account_secret_env_vars( + account_payload: Mapping[str, object], +) -> tuple[str, ...]: env_payload = _mapping(account_payload.get("env")) or {} env_var = _optional_text(env_payload.get("bot_token")) return (env_var,) if env_var is not None else () -def _feishu_account_secret_env_vars(account_payload: Mapping[str, object]) -> tuple[str, ...]: + +def _feishu_account_secret_env_vars( + account_payload: Mapping[str, object], +) -> tuple[str, ...]: env_vars: list[str] = [] env_payload = _mapping(account_payload.get("env")) or {} for key in ("app_id", "app_secret"): @@ -457,7 +432,9 @@ def _upsert_weixin_account( return resolved -def _weixin_account_secret_env_vars(account_payload: Mapping[str, object]) -> tuple[str, ...]: +def _weixin_account_secret_env_vars( + account_payload: Mapping[str, object], +) -> tuple[str, ...]: return () @@ -488,7 +465,9 @@ def _upsert_wecom_account( return resolved -def _wecom_account_secret_env_vars(account_payload: Mapping[str, object]) -> tuple[str, ...]: +def _wecom_account_secret_env_vars( + account_payload: Mapping[str, object], +) -> tuple[str, ...]: env_payload = _mapping(account_payload.get("env")) or {} env_vars: list[str] = [] for key in ("bot_id", "secret"): @@ -530,6 +509,7 @@ def _resolved_wecom_secret_env_var( return text return _default_wecom_secret_env_var(account_id=account_id, secret_key=secret_key) + def _resolved_defaults( *, default_state_dir_override: Path | None = None, @@ -540,19 +520,59 @@ def _resolved_defaults( "cli_state_dir": default_control_state_dir_override or default_cli_state_dir(), } + def _add_common_gateway_options(parser: ArgumentParser, *, defaults: dict[str, Path]) -> None: parser.add_argument("--state-dir", type=Path, default=defaults["state_dir"]) parser.add_argument("--cli-state-dir", type=Path, default=defaults["cli_state_dir"]) + def _add_http_server_options(parser: ArgumentParser) -> None: parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", type=int, default=8788) + def _add_optional_account_argument(parser: ArgumentParser, *, help_text: str) -> None: parser.add_argument("account_id", nargs="?", help=help_text) + def _add_required_account_argument(parser: ArgumentParser, *, help_text: str) -> None: parser.add_argument("account_id", nargs="?", help=help_text) -__all__ = ['_secret_reference_id', '_default_feishu_secret_env_var', '_build_feishu_secret_reference', '_find_feishu_account', '_account_secret_env_var', '_payload_string_list', '_resolved_cli_account_id', '_default_dingding_secret_env_var', '_find_dingding_account', '_resolved_dingding_secret_env_var', '_upsert_dingding_account', '_dingding_account_secret_env_vars', '_default_discord_bot_token_env_var', '_find_discord_account', '_resolved_discord_bot_token_env_var', '_is_unconfigured_default_discord_account', '_upsert_discord_account', '_resolved_feishu_secret_env_var', '_upsert_feishu_account', '_remove_account_payload', '_discord_account_secret_env_vars', '_feishu_account_secret_env_vars', '_find_weixin_account', '_upsert_weixin_account', '_weixin_account_secret_env_vars', '_find_wecom_account', '_upsert_wecom_account', '_wecom_account_secret_env_vars', '_default_wecom_secret_env_var', '_resolved_wecom_secret_env_var', '_resolved_defaults', '_add_common_gateway_options', '_add_http_server_options', '_add_optional_account_argument', '_add_required_account_argument'] +__all__ = [ + "_secret_reference_id", + "_default_feishu_secret_env_var", + "_build_feishu_secret_reference", + "_find_feishu_account", + "_account_secret_env_var", + "_payload_string_list", + "_resolved_cli_account_id", + "_default_dingding_secret_env_var", + "_find_dingding_account", + "_resolved_dingding_secret_env_var", + "_upsert_dingding_account", + "_dingding_account_secret_env_vars", + "_default_discord_bot_token_env_var", + "_find_discord_account", + "_resolved_discord_bot_token_env_var", + "_is_unconfigured_default_discord_account", + "_upsert_discord_account", + "_resolved_feishu_secret_env_var", + "_upsert_feishu_account", + "_remove_account_payload", + "_discord_account_secret_env_vars", + "_feishu_account_secret_env_vars", + "_find_weixin_account", + "_upsert_weixin_account", + "_weixin_account_secret_env_vars", + "_find_wecom_account", + "_upsert_wecom_account", + "_wecom_account_secret_env_vars", + "_default_wecom_secret_env_var", + "_resolved_wecom_secret_env_var", + "_resolved_defaults", + "_add_common_gateway_options", + "_add_http_server_options", + "_add_optional_account_argument", + "_add_required_account_argument", +] diff --git a/apps/gateway/gateway_main_runtime.py b/apps/gateway/gateway_main_runtime.py index c21a732..44dea5b 100644 --- a/apps/gateway/gateway_main_runtime.py +++ b/apps/gateway/gateway_main_runtime.py @@ -1,64 +1,32 @@ """Gateway managed-runtime persistence and process helpers.""" from __future__ import annotations -import asyncio -from argparse import SUPPRESS, ArgumentParser, Namespace -from collections.abc import Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass +from argparse import Namespace +from collections.abc import Mapping, Sequence +from dataclasses import asdict from datetime import UTC, datetime -import getpass -import apps.cli.wizard as cli_wizard -import importlib.util import json import os from pathlib import Path -import re import shlex import signal import subprocess import sys import time import warnings -from wsgiref.simple_server import make_server - -from apps.cli.runtime import CliRuntime -from apps.cli.shell import ( - Align, - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_LIGHT, - BRAND_MUTED, - Console, - Group, - Panel, - RICH_AVAILABLE, - Table, - Text, - _resolve_elephant_version, - render_elephant_mark, -) + from apps.provider_runtime import load_runtime_local_secret_env -from apps.runtime_layout import default_cli_state_dir, default_gateway_state_dir -from packages.gateway_core import DEFAULT_GATEWAY_ACCOUNT_ID -from packages.runtime_config import save_extensions_to_config, global_config_path_for_state_dir, load_extensions_from_config, load_global_config +from apps.runtime_layout import default_cli_state_dir +from packages.runtime_config import ( + global_config_path_for_state_dir, + load_extensions_from_config, + load_global_config, +) from . import ( - DEFAULT_DISCORD_BOT_TOKEN_ENV, - DEFAULT_FEISHU_APP_ID_ENV, - DEFAULT_FEISHU_APP_SECRET_ENV, - DEFAULT_FEISHU_EVENT_PATH, - FEISHU_ADAPTER_ID, - GatewayHttpService, GatewayManagedRuntime, GatewayManagedService, - SUPPORTED_DISCORD_TRANSPORTS, - SUPPORTED_FEISHU_TRANSPORTS, - build_gateway_app, - build_gateway_plugin_registry, - create_gateway_web_app, ) -from .discord import DISCORD_PY_PIP_SPEC, DiscordGatewayService -from .feishu import FEISHU_SDK_PIP_SPEC, FeishuGatewayService GATEWAY_LOCAL_SECRET_ENV_FILE = "gateway-local-secrets.json" @@ -86,9 +54,11 @@ from .gateway_main_wizard import * # noqa: F401,F403 + def _mapping(value: object) -> Mapping[str, object] | None: return value if isinstance(value, Mapping) else None + def _mapping_payload(value: object, *, path: str) -> dict[str, object]: if value is None: return {} @@ -96,6 +66,7 @@ def _mapping_payload(value: object, *, path: str) -> dict[str, object]: raise ValueError(f"{path} must be a JSON object") return {str(key): item for key, item in value.items()} + def _load_profile_manifest(state_dir: Path) -> dict[str, object]: """Load gateway and extension data from the canonical config.yaml.""" try: @@ -113,9 +84,11 @@ def _load_profile_manifest(state_dir: Path) -> dict[str, object]: manifest.update(extensions) return manifest + def _gateway_local_secret_env_path(state_dir: Path) -> Path: return state_dir / GATEWAY_LOCAL_SECRET_ENV_FILE + def _load_gateway_local_secret_env(state_dir: Path) -> dict[str, str]: path = _gateway_local_secret_env_path(state_dir) if not path.exists(): @@ -130,6 +103,7 @@ def _load_gateway_local_secret_env(state_dir: Path) -> dict[str, str]: resolved[str(key)] = text return resolved + def _persist_gateway_local_secret_env( state_dir: Path, updates: Mapping[str, str], @@ -148,6 +122,7 @@ def _persist_gateway_local_secret_env( pass return path + def _delete_gateway_local_secret_env( state_dir: Path, keys: Sequence[str], @@ -177,6 +152,7 @@ def _delete_gateway_local_secret_env( pass return path + def _gateway_runtime_environ( state_dir: Path, *, @@ -188,6 +164,7 @@ def _gateway_runtime_environ( env.update(os.environ) return env + def _read_pid(path: Path) -> int | None: if not path.exists(): return None @@ -196,6 +173,7 @@ def _read_pid(path: Path) -> int | None: except (OSError, ValueError): return None + def _pid_is_running(pid: int | None) -> bool: if pid is None or pid <= 0: return False @@ -205,12 +183,14 @@ def _pid_is_running(pid: int | None) -> bool: return False return True + def _optional_text(value: object) -> str | None: if value is None: return None text = str(value).strip() return text or None + def _resolved_cli_account_id(args: Namespace) -> str | None: raw_account_id = getattr(args, "account_id", None) direct = _optional_text(raw_account_id) if isinstance(raw_account_id, str) else None @@ -221,6 +201,7 @@ def _resolved_cli_account_id(args: Namespace) -> str | None: return None return _optional_text(raw_account_id_flag) + def _coerce_int(value: object) -> int | None: if value is None: return None @@ -229,9 +210,11 @@ def _coerce_int(value: object) -> int | None: except (TypeError, ValueError): return None + def _utc_now_iso() -> str: return datetime.now(UTC).isoformat() + def _load_runtime_record(path: Path) -> dict[str, object] | None: if not path.exists(): return None @@ -243,6 +226,7 @@ def _load_runtime_record(path: Path) -> dict[str, object] | None: return None return {str(key): value for key, value in payload.items()} + def _write_runtime_record(path: Path, record: GatewayRuntimeRecord) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text( @@ -250,6 +234,7 @@ def _write_runtime_record(path: Path, record: GatewayRuntimeRecord) -> None: encoding="utf-8", ) + def _build_runtime_record( args: Namespace, *, @@ -264,10 +249,7 @@ def _build_runtime_record( last_error: str | None = None, ) -> GatewayRuntimeRecord: existing_payload = dict(existing or {}) - command_payload = tuple( - str(value) - for value in (command or existing_payload.get("command") or ()) - ) + command_payload = tuple(str(value) for value in (command or existing_payload.get("command") or ())) cli_state_dir = _optional_text(getattr(args, "cli_state_dir", None)) or _optional_text( existing_payload.get("cli_state_dir") ) @@ -313,6 +295,7 @@ def _build_runtime_record( transport=runtime.target, ) + def _runtime_state(runtime: GatewayManagedRuntime) -> dict[str, object]: record = _load_runtime_record(runtime.record_path) or {} pid_from_file = _read_pid(runtime.pid_path) @@ -336,12 +319,14 @@ def _runtime_state(runtime: GatewayManagedRuntime) -> dict[str, object]: "status": status, } + def _remove_file_if_exists(path: Path) -> None: try: path.unlink() except FileNotFoundError: return + def _wait_for_pid_exit(pid: int, *, timeout_seconds: float) -> bool: deadline = time.monotonic() + max(timeout_seconds, 0.0) while time.monotonic() < deadline: @@ -350,6 +335,7 @@ def _wait_for_pid_exit(pid: int, *, timeout_seconds: float) -> bool: time.sleep(0.2) return not _pid_is_running(pid) + def _terminate_pid(pid: int, *, timeout_seconds: float, force: bool) -> str | None: if timeout_seconds < 0: raise SystemExit("--timeout must be zero or a positive number.") @@ -377,6 +363,7 @@ def _terminate_pid(pid: int, *, timeout_seconds: float, force: bool) -> str | No return signal.Signals(signal.SIGKILL).name raise SystemExit(f"Process {pid} is still running after SIGKILL.") + def _resolve_runtime_target_argument( args: Namespace, *, @@ -386,11 +373,10 @@ def _resolve_runtime_target_argument( if requested_target != "configured": return requested_target if service is None: - raise SystemExit( - "Resolving the configured runtime target requires an active gateway service profile." - ) + raise SystemExit("Resolving the configured runtime target requires an active gateway service profile.") return service.configured_runtime_target() + def _process_command_contains(pid: int, needles: Sequence[str]) -> bool: """Return True if the process command line contains all of the given needles. @@ -620,9 +606,7 @@ def _run_start_detached( state = _runtime_state(runtime) existing_pid = state["pid"] if state["pid_active"]: - raise SystemExit( - f"{runtime.label} is already running in the background with pid {existing_pid}." - ) + raise SystemExit(f"{runtime.label} is already running in the background with pid {existing_pid}.") args.state_dir.mkdir(parents=True, exist_ok=True) command = service.build_detached_runtime_command(args=args, target=target) started_at = _utc_now_iso() @@ -714,6 +698,7 @@ def _run_start_detached( del process return 0 + def _read_log_excerpt(path: Path, *, tail: int) -> tuple[str, ...]: if tail < 0: raise SystemExit("--tail must be zero or a positive integer.") @@ -725,6 +710,7 @@ def _read_log_excerpt(path: Path, *, tail: int) -> tuple[str, ...]: return () return tuple(lines[-tail:]) + def _follow_log_file(path: Path) -> None: with path.open("r", encoding="utf-8") as stream: stream.seek(0, os.SEEK_END) @@ -742,6 +728,7 @@ def _follow_log_file(path: Path) -> None: if current_size < stream.tell(): stream.seek(0) + def _format_runtime_command(record: Mapping[str, object]) -> str: command = record.get("command") if not isinstance(command, (list, tuple)): @@ -751,6 +738,7 @@ def _format_runtime_command(record: Mapping[str, object]) -> str: return "" return shlex.join(parts) + def _run_status(args: Namespace, *, service: GatewayManagedService | None = None) -> int: if service is None: raise TypeError("_run_status requires a managed gateway service") @@ -808,6 +796,7 @@ def _run_status(args: Namespace, *, service: GatewayManagedService | None = None print(_render_feishu_account_line(account, prefix="account")) return 0 + def _stop_managed_runtime( args: Namespace, *, @@ -850,6 +839,7 @@ def _stop_managed_runtime( ) return "stopped", signal_name, runtime + def _run_stop(args: Namespace, *, service: GatewayManagedService | None = None) -> int: if service is None: raise TypeError("_run_stop requires a managed gateway service") @@ -868,6 +858,7 @@ def _run_stop(args: Namespace, *, service: GatewayManagedService | None = None) print(f"Runtime record: {runtime.record_path}") return 0 + def _run_restart(args: Namespace, *, service: GatewayManagedService | None = None) -> int: if service is None: raise TypeError("_run_restart requires a managed gateway service") @@ -879,6 +870,7 @@ def _run_restart(args: Namespace, *, service: GatewayManagedService | None = Non print("No running detached runtime was found; starting a fresh background process.") return _run_start_detached(args, service=service, target=target, action="restart") + def _run_logs(args: Namespace, *, service: GatewayManagedService | None = None) -> int: if service is None: raise TypeError("_run_logs requires a managed gateway service") @@ -891,9 +883,7 @@ def _run_logs(args: Namespace, *, service: GatewayManagedService | None = None) print(runtime.log_path) return 0 if not runtime.log_path.exists(): - running_hint = ( - f" Background pid {state['pid']} is still recorded." if state["pid_active"] else "" - ) + running_hint = f" Background pid {state['pid']} is still recorded." if state["pid_active"] else "" raise SystemExit( f"No log file found for {runtime.label} at {runtime.log_path}." f" Start it with `{service.managed_runtime_log_hint(target=target).replace(' logs ', ' start ').replace('--follow', '--detach')}` first.{running_hint}" @@ -910,6 +900,7 @@ def _run_logs(args: Namespace, *, service: GatewayManagedService | None = None) return 0 return 0 + __all__ = [ "_mapping", "_mapping_payload", diff --git a/apps/gateway/gateway_main_setup_impl.py b/apps/gateway/gateway_main_setup_impl.py index aa6d789..c25f420 100644 --- a/apps/gateway/gateway_main_setup_impl.py +++ b/apps/gateway/gateway_main_setup_impl.py @@ -2,74 +2,21 @@ from __future__ import annotations import asyncio -from argparse import SUPPRESS, ArgumentParser, Namespace -from collections.abc import Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass -from datetime import UTC, datetime -import getpass -import apps.cli.wizard as cli_wizard -import importlib.util -import json -import os +from argparse import Namespace +from collections.abc import Mapping from pathlib import Path -import re -import shlex -import signal -import subprocess -import sys -import time -from wsgiref.simple_server import make_server - -from apps.cli.runtime import CliRuntime -from apps.cli.shell import ( - Align, - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_LIGHT, - BRAND_MUTED, - Console, - Group, - Panel, - RICH_AVAILABLE, - Table, - Text, - _resolve_elephant_version, - render_elephant_mark, -) -from apps.provider_runtime import load_runtime_local_secret_env -from apps.runtime_layout import default_cli_state_dir, default_gateway_state_dir + +from apps.runtime_layout import default_gateway_state_dir from packages.gateway_core import DEFAULT_GATEWAY_ACCOUNT_ID -from packages.runtime_config import save_extensions_to_config, global_config_path_for_state_dir, load_global_config, write_global_config +from packages.runtime_config import ( + global_config_path_for_state_dir, + load_global_config, + write_global_config, +) from . import ( - DEFAULT_DINGDING_CLIENT_ID_ENV, - DEFAULT_DINGDING_CLIENT_SECRET_ENV, - DEFAULT_DINGDING_ROBOT_CODE_ENV, - DEFAULT_DISCORD_BOT_TOKEN_ENV, - DEFAULT_FEISHU_APP_ID_ENV, - DEFAULT_FEISHU_APP_SECRET_ENV, DEFAULT_FEISHU_EVENT_PATH, - DEFAULT_WECOM_BOT_ID_ENV, - DEFAULT_WECOM_SECRET_ENV, - FEISHU_ADAPTER_ID, - GatewayHttpService, - GatewayManagedRuntime, - GatewayManagedService, - SUPPORTED_DINGDING_TRANSPORTS, - SUPPORTED_DISCORD_TRANSPORTS, - SUPPORTED_FEISHU_TRANSPORTS, - SUPPORTED_WECOM_TRANSPORTS, - SUPPORTED_WEIXIN_TRANSPORTS, - WECOM_ADAPTER_ID, - build_gateway_app, - build_gateway_plugin_registry, - create_gateway_web_app, ) -from .dingding import DINGTALK_STREAM_PIP_SPEC, DingdingGatewayService -from .discord import DISCORD_PY_PIP_SPEC, DiscordGatewayService -from .feishu import FEISHU_SDK_PIP_SPEC, FeishuGatewayService -from .wecom import WecomGatewayService -from .weixin import WeixinGatewayService try: from prompt_toolkit.application import Application @@ -98,28 +45,19 @@ from .gateway_main_runtime import * # noqa: F401,F403 from .gateway_main_wizard import * # noqa: F401,F403 from .gateway_main_wizard import ( - GATEWAY_WIZARD_BACK, - _confirm_gateway_wizard_intro, - _gateway_wizard_choice_prompt, - _gateway_wizard_dialogs_supported, - _gateway_wizard_secret_prompt, - _gateway_wizard_text_prompt, _interactive_shell_supported, _print_gateway_dingding_wizard_intro, _print_gateway_discord_wizard_intro, _print_gateway_feishu_wizard_intro, _print_gateway_setup_paused, _print_gateway_wecom_wizard_intro, - _print_gateway_weixin_wizard_intro, _run_interactive_dingding_wizard, _run_interactive_discord_wizard, _run_interactive_feishu_wizard, _run_interactive_wecom_wizard, - _run_interactive_weixin_wizard, - _shared_wizard_choice_prompt, - _shared_wizard_text_prompt, ) + def _save_gateway_manifest(state_dir: Path, manifest: Mapping[str, Any]) -> Path: """Write gateway and extension data to config.yaml, return the config path.""" config_path = global_config_path_for_state_dir(state_dir) @@ -131,7 +69,12 @@ def _save_gateway_manifest(state_dir: Path, manifest: Mapping[str, Any]) -> Path else: config.pop("gateway", None) # Merge extension keys - extension_keys = ("tool_manifests", "skill_manifests", "skill_overrides", "skill_packages") + extension_keys = ( + "tool_manifests", + "skill_manifests", + "skill_overrides", + "skill_packages", + ) extensions = dict(config.get("extensions", {})) if isinstance(config.get("extensions"), Mapping) else {} for key in extension_keys: if key in manifest: @@ -146,16 +89,6 @@ def _save_gateway_manifest(state_dir: Path, manifest: Mapping[str, Any]) -> Path def _run_add_discord(args: Namespace) -> int: _ensure_discord_sdk_available(reason="Discord setup") - - - - - - - - - - manifest = _load_profile_manifest(args.cli_state_dir) gateway_payload = _mapping_payload(manifest.get("gateway"), path="gateway") adapters_payload = _mapping_payload(gateway_payload.get("adapters"), path="gateway.adapters") @@ -170,9 +103,7 @@ def _run_add_discord(args: Namespace) -> int: existing_accounts = [] for index, account in enumerate(accounts_value): if not isinstance(account, Mapping): - raise SystemExit( - f"gateway.adapters.discord.accounts[{index}] must be a JSON object" - ) + raise SystemExit(f"gateway.adapters.discord.accounts[{index}] must be a JSON object") existing_accounts.append({str(key): value for key, value in account.items()}) else: raise SystemExit("gateway.adapters.discord.accounts must be a JSON array") @@ -319,6 +250,7 @@ def _run_add_discord(args: Namespace) -> int: print("- Start the configured bridge with `elephant gateway discord start`.") return 0 + def _start_discord_runtime_after_setup(args: Namespace, *, transport: str) -> int: start_args = Namespace(**vars(args)) start_args.runtime_target = transport or "configured" @@ -327,8 +259,10 @@ def _start_discord_runtime_after_setup(args: Namespace, *, transport: str) -> in start_args.timeout = float(getattr(start_args, "timeout", 10.0) or 10.0) start_args.force = bool(getattr(start_args, "force", False)) from apps.gateway.gateway_main_impl import _start_via_daemon + return _start_via_daemon(start_args) + def _start_feishu_runtime_after_setup(args: Namespace, *, transport: str) -> int: start_args = Namespace(**vars(args)) start_args.runtime_target = transport or "configured" @@ -340,8 +274,10 @@ def _start_feishu_runtime_after_setup(args: Namespace, *, transport: str) -> int start_args.timeout = float(getattr(start_args, "timeout", 10.0) or 10.0) start_args.force = bool(getattr(start_args, "force", False)) from apps.gateway.gateway_main_impl import _start_via_daemon + return _start_via_daemon(start_args) + def _run_add_feishu(args: Namespace) -> int: _ensure_feishu_sdk_available(reason="Feishu setup") @@ -359,9 +295,7 @@ def _run_add_feishu(args: Namespace) -> int: existing_accounts = [] for index, account in enumerate(accounts_value): if not isinstance(account, Mapping): - raise SystemExit( - f"gateway.adapters.feishu.accounts[{index}] must be a JSON object" - ) + raise SystemExit(f"gateway.adapters.feishu.accounts[{index}] must be a JSON object") existing_accounts.append({str(key): value for key, value in account.items()}) else: raise SystemExit("gateway.adapters.feishu.accounts must be a JSON array") @@ -370,19 +304,13 @@ def _run_add_feishu(args: Namespace) -> int: transport = ( str(args.transport) if args.transport is not None - else str( - (existing_account or {}).get("surface") - or feishu_payload.get("surface") - or "long-connection" - ) + else str((existing_account or {}).get("surface") or feishu_payload.get("surface") or "long-connection") ) event_path = ( str(args.event_path) if args.event_path is not None else str( - (existing_account or {}).get("event_path") - or feishu_payload.get("event_path") - or DEFAULT_FEISHU_EVENT_PATH + (existing_account or {}).get("event_path") or feishu_payload.get("event_path") or DEFAULT_FEISHU_EVENT_PATH ) ) app_id_env_var = _resolved_feishu_secret_env_var( @@ -524,6 +452,7 @@ def _run_add_feishu(args: Namespace) -> int: print("- Start the configured bridge with `elephant gateway feishu start`.") return 0 + def _run_remove_discord(args: Namespace) -> int: account_id = _resolved_cli_account_id(args) if account_id is None: @@ -559,6 +488,7 @@ def _run_remove_discord(args: Namespace) -> int: print(f"Updated local IM secret file: {secret_path}") return 0 + def _run_remove_feishu(args: Namespace) -> int: account_id = _resolved_cli_account_id(args) if account_id is None: @@ -594,6 +524,7 @@ def _run_remove_feishu(args: Namespace) -> int: print(f"Updated local IM secret file: {secret_path}") return 0 + def _start_dingding_runtime_after_setup(args: Namespace, *, transport: str) -> int: start_args = Namespace(**vars(args)) start_args.runtime_target = transport or "configured" @@ -602,8 +533,10 @@ def _start_dingding_runtime_after_setup(args: Namespace, *, transport: str) -> i start_args.timeout = float(getattr(start_args, "timeout", 10.0) or 10.0) start_args.force = bool(getattr(start_args, "force", False)) from apps.gateway.gateway_main_impl import _start_via_daemon + return _start_via_daemon(start_args) + def _run_add_dingding(args: Namespace) -> int: _ensure_dingding_sdk_available(reason="DingDing setup") @@ -621,10 +554,30 @@ def _run_add_dingding(args: Namespace) -> int: else: raise SystemExit("gateway.adapters.dingding.accounts must be a JSON array") existing_account = _find_dingding_account(existing_accounts, account_id=account_id) - transport = str(args.transport or "").strip() or str((existing_account or {}).get("surface") or "").strip() or str(dingding_payload.get("surface") or "").strip() or "stream" - client_id_env_var = _resolved_dingding_secret_env_var(explicit_env_var=args.client_id_env_var, existing_account=existing_account, account_id=account_id, secret_key="client_id") - client_secret_env_var = _resolved_dingding_secret_env_var(explicit_env_var=args.client_secret_env_var, existing_account=existing_account, account_id=account_id, secret_key="client_secret") - robot_code_env_var = _resolved_dingding_secret_env_var(explicit_env_var=args.robot_code_env_var, existing_account=existing_account, account_id=account_id, secret_key="robot_code") + transport = ( + str(args.transport or "").strip() + or str((existing_account or {}).get("surface") or "").strip() + or str(dingding_payload.get("surface") or "").strip() + or "stream" + ) + client_id_env_var = _resolved_dingding_secret_env_var( + explicit_env_var=args.client_id_env_var, + existing_account=existing_account, + account_id=account_id, + secret_key="client_id", + ) + client_secret_env_var = _resolved_dingding_secret_env_var( + explicit_env_var=args.client_secret_env_var, + existing_account=existing_account, + account_id=account_id, + secret_key="client_secret", + ) + robot_code_env_var = _resolved_dingding_secret_env_var( + explicit_env_var=args.robot_code_env_var, + existing_account=existing_account, + account_id=account_id, + secret_key="robot_code", + ) client_id_value = str(args.client_id or "").strip() client_secret_value = str(args.client_secret or "").strip() robot_code_value = str(args.robot_code or "").strip() @@ -633,33 +586,86 @@ def _run_add_dingding(args: Namespace) -> int: use_wizard = bool(args.wizard) if args.wizard is not None else _interactive_shell_supported() if use_wizard: if not _print_gateway_dingding_wizard_intro(): - _print_gateway_setup_paused("DingDing"); return 0 - ws = _run_interactive_dingding_wizard(account_id=account_id, transport=transport, client_id_value=client_id_value, client_secret_value=client_secret_value, robot_code_value=robot_code_value, enabled=enabled, allow_group_chats=allow_group_chats) + _print_gateway_setup_paused("DingDing") + return 0 + ws = _run_interactive_dingding_wizard( + account_id=account_id, + transport=transport, + client_id_value=client_id_value, + client_secret_value=client_secret_value, + robot_code_value=robot_code_value, + enabled=enabled, + allow_group_chats=allow_group_chats, + ) if ws is None: - _print_gateway_setup_paused("DingDing"); return 0 - account_id, transport, client_id_value, client_secret_value, robot_code_value = ws.account_id, ws.transport, ws.client_id_value, ws.client_secret_value, ws.robot_code_value + _print_gateway_setup_paused("DingDing") + return 0 + ( + account_id, + transport, + client_id_value, + client_secret_value, + robot_code_value, + ) = ( + ws.account_id, + ws.transport, + ws.client_id_value, + ws.client_secret_value, + ws.robot_code_value, + ) enabled, allow_group_chats = ws.enabled, ws.allow_group_chats auto_start = (bool(getattr(args, "auto_start", False)) or use_wizard) and not getattr(args, "no_start", False) args.account_id = account_id existing_account = _find_dingding_account(existing_accounts, account_id=account_id) - client_id_env_var = _resolved_dingding_secret_env_var(explicit_env_var=args.client_id_env_var, existing_account=existing_account, account_id=account_id, secret_key="client_id") - client_secret_env_var = _resolved_dingding_secret_env_var(explicit_env_var=args.client_secret_env_var, existing_account=existing_account, account_id=account_id, secret_key="client_secret") - robot_code_env_var = _resolved_dingding_secret_env_var(explicit_env_var=args.robot_code_env_var, existing_account=existing_account, account_id=account_id, secret_key="robot_code") - account_payload: dict[str, object] = {"account_id": account_id, "surface": transport, "enabled": True, "env": {"client_id": client_id_env_var, "client_secret": client_secret_env_var, "robot_code": robot_code_env_var}} + client_id_env_var = _resolved_dingding_secret_env_var( + explicit_env_var=args.client_id_env_var, + existing_account=existing_account, + account_id=account_id, + secret_key="client_id", + ) + client_secret_env_var = _resolved_dingding_secret_env_var( + explicit_env_var=args.client_secret_env_var, + existing_account=existing_account, + account_id=account_id, + secret_key="client_secret", + ) + robot_code_env_var = _resolved_dingding_secret_env_var( + explicit_env_var=args.robot_code_env_var, + existing_account=existing_account, + account_id=account_id, + secret_key="robot_code", + ) + account_payload: dict[str, object] = { + "account_id": account_id, + "surface": transport, + "enabled": True, + "env": { + "client_id": client_id_env_var, + "client_secret": client_secret_env_var, + "robot_code": robot_code_env_var, + }, + } local_secrets = {} - if client_id_value: local_secrets[client_id_env_var] = client_id_value - if client_secret_value: local_secrets[client_secret_env_var] = client_secret_value - if robot_code_value: local_secrets[robot_code_env_var] = robot_code_value + if client_id_value: + local_secrets[client_id_env_var] = client_id_value + if client_secret_value: + local_secrets[client_secret_env_var] = client_secret_value + if robot_code_value: + local_secrets[robot_code_env_var] = robot_code_value control_payload.pop("default_elephant_id", None) control_payload.pop("default_session_id", None) control_payload.pop("auto_create_elephant", None) - if allow_group_chats: control_payload["allow_group_chats"] = True - elif use_wizard: control_payload.pop("allow_group_chats", None) + if allow_group_chats: + control_payload["allow_group_chats"] = True + elif use_wizard: + control_payload.pop("allow_group_chats", None) dingding_payload["accounts"] = _upsert_dingding_account(existing_accounts, account_payload) dingding_payload["surface"] = transport dingding_payload["enabled"] = enabled - if control_payload: dingding_payload["control"] = control_payload - else: dingding_payload.pop("control", None) + if control_payload: + dingding_payload["control"] = control_payload + else: + dingding_payload.pop("control", None) adapters_payload["dingding"] = dingding_payload gateway_payload["adapters"] = adapters_payload manifest["gateway"] = gateway_payload @@ -668,37 +674,54 @@ def _run_add_dingding(args: Namespace) -> int: print(f"Configured DingDing IM in {manifest_path}") print(f"DingDing account: {account_id}") print(f"DingDing transport: {transport}") - if local_secret_path is not None: print(f"Local IM secret file: {local_secret_path}") + if local_secret_path is not None: + print(f"Local IM secret file: {local_secret_path}") if auto_start: print("Starting the configured DingDing bridge in the background...") - try: _start_dingding_runtime_after_setup(args, transport=transport) - except SystemExit: print("- Start it again with `elephant gateway dingding start --detach`."); return 1 - print("DingDing setup is complete."); return 0 + try: + _start_dingding_runtime_after_setup(args, transport=transport) + except SystemExit: + print("- Start it again with `elephant gateway dingding start --detach`.") + return 1 + print("DingDing setup is complete.") + return 0 print("- Start the configured bridge with `elephant gateway dingding start`.") return 0 + def _run_remove_dingding(args: Namespace) -> int: account_id = _resolved_cli_account_id(args) - if account_id is None: raise SystemExit("remove requires ") + if account_id is None: + raise SystemExit("remove requires ") manifest = _load_profile_manifest(args.cli_state_dir) gateway_payload = _mapping_payload(manifest.get("gateway"), path="gateway") adapters_payload = _mapping_payload(gateway_payload.get("adapters"), path="gateway.adapters") dingding_payload = _mapping_payload(adapters_payload.get("dingding"), path="gateway.adapters.dingding") accounts_value = dingding_payload.get("accounts") - if not isinstance(accounts_value, list): raise SystemExit("gateway.adapters.dingding.accounts must be a JSON array") + if not isinstance(accounts_value, list): + raise SystemExit("gateway.adapters.dingding.accounts must be a JSON array") remaining_accounts, removed_account = _remove_account_payload(accounts_value, account_id=account_id) secret_path = _delete_gateway_local_secret_env(args.state_dir, _dingding_account_secret_env_vars(removed_account)) - if remaining_accounts: dingding_payload["accounts"] = remaining_accounts; dingding_payload["enabled"] = True; adapters_payload["dingding"] = dingding_payload - else: adapters_payload.pop("dingding", None) - if adapters_payload: gateway_payload["adapters"] = adapters_payload; manifest["gateway"] = gateway_payload - else: manifest.pop("gateway", None) + if remaining_accounts: + dingding_payload["accounts"] = remaining_accounts + dingding_payload["enabled"] = True + adapters_payload["dingding"] = dingding_payload + else: + adapters_payload.pop("dingding", None) + if adapters_payload: + gateway_payload["adapters"] = adapters_payload + manifest["gateway"] = gateway_payload + else: + manifest.pop("gateway", None) manifest_path = _save_gateway_manifest(args.cli_state_dir, manifest) print(f"Removed DingDing account: {account_id}") print(f"Updated manifest: {manifest_path}") - if secret_path is not None: print(f"Updated local IM secret file: {secret_path}") + if secret_path is not None: + print(f"Updated local IM secret file: {secret_path}") return 0 + def _start_weixin_runtime_after_setup(args: Namespace, *, transport: str) -> int: start_args = Namespace(**vars(args)) start_args.runtime_target = transport or "configured" @@ -707,17 +730,19 @@ def _start_weixin_runtime_after_setup(args: Namespace, *, transport: str) -> int start_args.timeout = float(getattr(start_args, "timeout", 10.0) or 10.0) start_args.force = bool(getattr(start_args, "force", False)) from apps.gateway.gateway_main_impl import _start_via_daemon + return _start_via_daemon(start_args) + def _run_add_weixin(args: Namespace) -> int: from .weixin_support import check_weixin_requirements, qr_login, ILINK_BASE_URL + if not check_weixin_requirements(): _ensure_weixin_sdk_available(reason="WeChat setup") if not check_weixin_requirements(): print("Failed to install WeChat dependencies. Run: pip install aiohttp cryptography") return 1 - manifest = _load_profile_manifest(args.cli_state_dir) gateway_payload = _mapping_payload(manifest.get("gateway"), path="gateway") adapters_payload = _mapping_payload(gateway_payload.get("adapters"), path="gateway.adapters") @@ -753,8 +778,10 @@ def _run_add_weixin(args: Namespace) -> int: control_payload.pop("default_elephant_id", None) control_payload.pop("default_session_id", None) control_payload.pop("auto_create_elephant", None) - if allow_group_chats: control_payload["allow_group_chats"] = True - elif use_wizard: control_payload.pop("allow_group_chats", None) + if allow_group_chats: + control_payload["allow_group_chats"] = True + elif use_wizard: + control_payload.pop("allow_group_chats", None) account_payload: dict[str, object] = { "account_id": resolved_account_id, "token": resolved_token, @@ -765,45 +792,63 @@ def _run_add_weixin(args: Namespace) -> int: weixin_payload["accounts"] = _upsert_weixin_account(existing_accounts, account_payload) weixin_payload["surface"] = "ilink" weixin_payload["enabled"] = enabled - if control_payload: weixin_payload["control"] = control_payload - else: weixin_payload.pop("control", None) + if control_payload: + weixin_payload["control"] = control_payload + else: + weixin_payload.pop("control", None) adapters_payload["weixin"] = weixin_payload gateway_payload["adapters"] = adapters_payload manifest["gateway"] = gateway_payload manifest_path = _save_gateway_manifest(args.cli_state_dir, manifest) print(f"Configured WeChat IM in {manifest_path}") print(f"WeChat account: {resolved_account_id}") - print(f"WeChat transport: ilink") + print("WeChat transport: ilink") if auto_start: print("Starting the configured WeChat bridge in the background...") - try: _start_weixin_runtime_after_setup(args, transport="ilink") - except SystemExit: print("- Start it again with `elephant gateway weixin start --detach`."); return 1 - print("WeChat setup is complete."); return 0 + try: + _start_weixin_runtime_after_setup(args, transport="ilink") + except SystemExit: + print("- Start it again with `elephant gateway weixin start --detach`.") + return 1 + print("WeChat setup is complete.") + return 0 print("- Start the configured bridge with `elephant gateway weixin start`.") return 0 + def _run_remove_weixin(args: Namespace) -> int: account_id = _resolved_cli_account_id(args) - if account_id is None: raise SystemExit("remove requires ") + if account_id is None: + raise SystemExit("remove requires ") manifest = _load_profile_manifest(args.cli_state_dir) gateway_payload = _mapping_payload(manifest.get("gateway"), path="gateway") adapters_payload = _mapping_payload(gateway_payload.get("adapters"), path="gateway.adapters") weixin_payload = _mapping_payload(adapters_payload.get("weixin"), path="gateway.adapters.weixin") accounts_value = weixin_payload.get("accounts") - if not isinstance(accounts_value, list): raise SystemExit("gateway.adapters.weixin.accounts must be a JSON array") + if not isinstance(accounts_value, list): + raise SystemExit("gateway.adapters.weixin.accounts must be a JSON array") remaining_accounts, removed_account = _remove_account_payload(accounts_value, account_id=account_id) secret_path = _delete_gateway_local_secret_env(args.state_dir, _weixin_account_secret_env_vars(removed_account)) - if remaining_accounts: weixin_payload["accounts"] = remaining_accounts; weixin_payload["enabled"] = True; adapters_payload["weixin"] = weixin_payload - else: adapters_payload.pop("weixin", None) - if adapters_payload: gateway_payload["adapters"] = adapters_payload; manifest["gateway"] = gateway_payload - else: manifest.pop("gateway", None) + if remaining_accounts: + weixin_payload["accounts"] = remaining_accounts + weixin_payload["enabled"] = True + adapters_payload["weixin"] = weixin_payload + else: + adapters_payload.pop("weixin", None) + if adapters_payload: + gateway_payload["adapters"] = adapters_payload + manifest["gateway"] = gateway_payload + else: + manifest.pop("gateway", None) manifest_path = _save_gateway_manifest(args.cli_state_dir, manifest) print(f"Removed WeChat account: {account_id}") print(f"Updated manifest: {manifest_path}") - if secret_path is not None: print(f"Updated local IM secret file: {secret_path}") + if secret_path is not None: + print(f"Updated local IM secret file: {secret_path}") return 0 + def _run_add_wecom(args: Namespace) -> int: _ensure_wecom_sdk_available(reason="WeCom setup") @@ -821,9 +866,24 @@ def _run_add_wecom(args: Namespace) -> int: else: raise SystemExit("gateway.adapters.wecom.accounts must be a JSON array") existing_account = _find_wecom_account(existing_accounts, account_id=account_id) - transport = str(args.transport or "").strip() or str((existing_account or {}).get("surface") or "").strip() or str(wecom_payload.get("surface") or "").strip() or "websocket" - bot_id_env_var = _resolved_wecom_secret_env_var(explicit_env_var=args.bot_id_env_var, existing_account=existing_account, account_id=account_id, secret_key="bot_id") - secret_env_var = _resolved_wecom_secret_env_var(explicit_env_var=args.secret_env_var, existing_account=existing_account, account_id=account_id, secret_key="secret") + transport = ( + str(args.transport or "").strip() + or str((existing_account or {}).get("surface") or "").strip() + or str(wecom_payload.get("surface") or "").strip() + or "websocket" + ) + bot_id_env_var = _resolved_wecom_secret_env_var( + explicit_env_var=args.bot_id_env_var, + existing_account=existing_account, + account_id=account_id, + secret_key="bot_id", + ) + secret_env_var = _resolved_wecom_secret_env_var( + explicit_env_var=args.secret_env_var, + existing_account=existing_account, + account_id=account_id, + secret_key="secret", + ) bot_id_value = str(args.bot_id or "").strip() secret_value = str(args.secret or "").strip() enabled = bool(args.enabled) if args.enabled is not None else True @@ -831,31 +891,66 @@ def _run_add_wecom(args: Namespace) -> int: use_wizard = bool(args.wizard) if args.wizard is not None else _interactive_shell_supported() if use_wizard: if not _print_gateway_wecom_wizard_intro(): - _print_gateway_setup_paused("WeCom"); return 0 - ws = _run_interactive_wecom_wizard(account_id=account_id, transport=transport, bot_id_value=bot_id_value, secret_value=secret_value, enabled=enabled, allow_group_chats=allow_group_chats) + _print_gateway_setup_paused("WeCom") + return 0 + ws = _run_interactive_wecom_wizard( + account_id=account_id, + transport=transport, + bot_id_value=bot_id_value, + secret_value=secret_value, + enabled=enabled, + allow_group_chats=allow_group_chats, + ) if ws is None: - _print_gateway_setup_paused("WeCom"); return 0 - account_id, transport, bot_id_value, secret_value = ws.account_id, ws.transport, ws.bot_id_value, ws.secret_value + _print_gateway_setup_paused("WeCom") + return 0 + account_id, transport, bot_id_value, secret_value = ( + ws.account_id, + ws.transport, + ws.bot_id_value, + ws.secret_value, + ) enabled, allow_group_chats = ws.enabled, ws.allow_group_chats auto_start = (bool(getattr(args, "auto_start", False)) or use_wizard) and not getattr(args, "no_start", False) args.account_id = account_id existing_account = _find_wecom_account(existing_accounts, account_id=account_id) - bot_id_env_var = _resolved_wecom_secret_env_var(explicit_env_var=args.bot_id_env_var, existing_account=existing_account, account_id=account_id, secret_key="bot_id") - secret_env_var = _resolved_wecom_secret_env_var(explicit_env_var=args.secret_env_var, existing_account=existing_account, account_id=account_id, secret_key="secret") - account_payload: dict[str, object] = {"account_id": account_id, "surface": transport, "enabled": True, "env": {"bot_id": bot_id_env_var, "secret": secret_env_var}} + bot_id_env_var = _resolved_wecom_secret_env_var( + explicit_env_var=args.bot_id_env_var, + existing_account=existing_account, + account_id=account_id, + secret_key="bot_id", + ) + secret_env_var = _resolved_wecom_secret_env_var( + explicit_env_var=args.secret_env_var, + existing_account=existing_account, + account_id=account_id, + secret_key="secret", + ) + account_payload: dict[str, object] = { + "account_id": account_id, + "surface": transport, + "enabled": True, + "env": {"bot_id": bot_id_env_var, "secret": secret_env_var}, + } local_secrets = {} - if bot_id_value: local_secrets[bot_id_env_var] = bot_id_value - if secret_value: local_secrets[secret_env_var] = secret_value + if bot_id_value: + local_secrets[bot_id_env_var] = bot_id_value + if secret_value: + local_secrets[secret_env_var] = secret_value control_payload.pop("default_elephant_id", None) control_payload.pop("default_session_id", None) control_payload.pop("auto_create_elephant", None) - if allow_group_chats: control_payload["allow_group_chats"] = True - elif use_wizard: control_payload.pop("allow_group_chats", None) + if allow_group_chats: + control_payload["allow_group_chats"] = True + elif use_wizard: + control_payload.pop("allow_group_chats", None) wecom_payload["accounts"] = _upsert_wecom_account(existing_accounts, account_payload) wecom_payload["surface"] = transport wecom_payload["enabled"] = enabled - if control_payload: wecom_payload["control"] = control_payload - else: wecom_payload.pop("control", None) + if control_payload: + wecom_payload["control"] = control_payload + else: + wecom_payload.pop("control", None) adapters_payload["wecom"] = wecom_payload gateway_payload["adapters"] = adapters_payload manifest["gateway"] = gateway_payload @@ -864,36 +959,67 @@ def _run_add_wecom(args: Namespace) -> int: print(f"Configured WeCom IM in {manifest_path}") print(f"WeCom account: {account_id}") print(f"WeCom transport: {transport}") - if local_secret_path is not None: print(f"Local IM secret file: {local_secret_path}") + if local_secret_path is not None: + print(f"Local IM secret file: {local_secret_path}") if auto_start: print("Starting the configured WeCom bridge in the background...") - try: _start_wecom_runtime_after_setup(args, transport=transport) - except SystemExit: print("- Start it again with `elephant gateway wecom start --detach`."); return 1 - print("WeCom setup is complete."); return 0 + try: + _start_wecom_runtime_after_setup(args, transport=transport) + except SystemExit: + print("- Start it again with `elephant gateway wecom start --detach`.") + return 1 + print("WeCom setup is complete.") + return 0 print("- Start the configured bridge with `elephant gateway wecom start`.") return 0 + def _run_remove_wecom(args: Namespace) -> int: account_id = _resolved_cli_account_id(args) - if account_id is None: raise SystemExit("remove requires ") + if account_id is None: + raise SystemExit("remove requires ") manifest = _load_profile_manifest(args.cli_state_dir) gateway_payload = _mapping_payload(manifest.get("gateway"), path="gateway") adapters_payload = _mapping_payload(gateway_payload.get("adapters"), path="gateway.adapters") wecom_payload = _mapping_payload(adapters_payload.get("wecom"), path="gateway.adapters.wecom") accounts_value = wecom_payload.get("accounts") - if not isinstance(accounts_value, list): raise SystemExit("gateway.adapters.wecom.accounts must be a JSON array") + if not isinstance(accounts_value, list): + raise SystemExit("gateway.adapters.wecom.accounts must be a JSON array") remaining_accounts, removed_account = _remove_account_payload(accounts_value, account_id=account_id) secret_path = _delete_gateway_local_secret_env(args.state_dir, _wecom_account_secret_env_vars(removed_account)) - if remaining_accounts: wecom_payload["accounts"] = remaining_accounts; wecom_payload["enabled"] = True; adapters_payload["wecom"] = wecom_payload - else: adapters_payload.pop("wecom", None) - if adapters_payload: gateway_payload["adapters"] = adapters_payload; manifest["gateway"] = gateway_payload - else: manifest.pop("gateway", None) + if remaining_accounts: + wecom_payload["accounts"] = remaining_accounts + wecom_payload["enabled"] = True + adapters_payload["wecom"] = wecom_payload + else: + adapters_payload.pop("wecom", None) + if adapters_payload: + gateway_payload["adapters"] = adapters_payload + manifest["gateway"] = gateway_payload + else: + manifest.pop("gateway", None) manifest_path = _save_gateway_manifest(args.cli_state_dir, manifest) print(f"Removed WeCom account: {account_id}") print(f"Updated manifest: {manifest_path}") - if secret_path is not None: print(f"Updated local IM secret file: {secret_path}") + if secret_path is not None: + print(f"Updated local IM secret file: {secret_path}") return 0 -__all__ = ['_run_add_discord', '_start_discord_runtime_after_setup', '_start_feishu_runtime_after_setup', '_run_add_feishu', '_run_remove_discord', '_run_remove_feishu', '_start_dingding_runtime_after_setup', '_run_add_dingding', '_run_remove_dingding', '_start_weixin_runtime_after_setup', '_run_add_weixin', '_run_remove_weixin', '_run_add_wecom', '_run_remove_wecom'] +__all__ = [ + "_run_add_discord", + "_start_discord_runtime_after_setup", + "_start_feishu_runtime_after_setup", + "_run_add_feishu", + "_run_remove_discord", + "_run_remove_feishu", + "_start_dingding_runtime_after_setup", + "_run_add_dingding", + "_run_remove_dingding", + "_start_weixin_runtime_after_setup", + "_run_add_weixin", + "_run_remove_weixin", + "_run_add_wecom", + "_run_remove_wecom", +] diff --git a/apps/gateway/gateway_main_wizard_binding.py b/apps/gateway/gateway_main_wizard_binding.py index 3302c59..74ebba8 100644 --- a/apps/gateway/gateway_main_wizard_binding.py +++ b/apps/gateway/gateway_main_wizard_binding.py @@ -1,71 +1,11 @@ """Gateway setup wizard helpers.""" from __future__ import annotations -import asyncio -from argparse import SUPPRESS, ArgumentParser, Namespace -from collections.abc import Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass -from datetime import UTC, datetime -import getpass -import apps.cli.wizard as cli_wizard -import importlib.util import json -import os from pathlib import Path -import re -import shlex -import signal -import subprocess -import sys -import time -from wsgiref.simple_server import make_server from apps.cli.runtime import CliRuntime -from apps.cli.shell import ( - Align, - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_LIGHT, - BRAND_MUTED, - Console, - Group, - Panel, - RICH_AVAILABLE, - Table, - Text, - _resolve_elephant_version, - render_elephant_mark, -) -from apps.provider_runtime import load_runtime_local_secret_env -from apps.runtime_layout import default_cli_state_dir, default_gateway_state_dir -from packages.gateway_core import DEFAULT_GATEWAY_ACCOUNT_ID -from . import ( - DEFAULT_DINGDING_CLIENT_ID_ENV, - DEFAULT_DINGDING_CLIENT_SECRET_ENV, - DEFAULT_DINGDING_ROBOT_CODE_ENV, - DEFAULT_DISCORD_BOT_TOKEN_ENV, - DEFAULT_FEISHU_APP_ID_ENV, - DEFAULT_FEISHU_APP_SECRET_ENV, - DEFAULT_FEISHU_EVENT_PATH, - FEISHU_ADAPTER_ID, - GatewayHttpService, - GatewayManagedRuntime, - GatewayManagedService, - SUPPORTED_DINGDING_TRANSPORTS, - SUPPORTED_DISCORD_TRANSPORTS, - SUPPORTED_FEISHU_TRANSPORTS, - SUPPORTED_WECOM_TRANSPORTS, - SUPPORTED_WEIXIN_TRANSPORTS, - build_gateway_app, - build_gateway_plugin_registry, - create_gateway_web_app, -) -from .dingding import DINGTALK_STREAM_PIP_SPEC, DingdingGatewayService -from .discord import DISCORD_PY_PIP_SPEC, DiscordGatewayService -from .feishu import FEISHU_SDK_PIP_SPEC, FeishuGatewayService -from .wecom import WecomGatewayService -from .weixin import WeixinGatewayService try: from prompt_toolkit.application import Application @@ -91,6 +31,7 @@ from .gateway_main_wizard_ui import * # noqa: F401,F403 + def _load_gateway_control_runtime( *, profile_dir: Path | None, @@ -103,9 +44,17 @@ def _load_gateway_control_runtime( return None try: return CliRuntime.create(state_dir=state_dir) - except (OSError, RuntimeError, ValueError, TypeError, KeyError, json.JSONDecodeError): + except ( + OSError, + RuntimeError, + ValueError, + TypeError, + KeyError, + json.JSONDecodeError, + ): return None + def _gateway_elephant_choices( runtime: CliRuntime, *, @@ -159,6 +108,7 @@ def _gateway_elephant_choices( ) return tuple(choices) + def _gateway_session_choices( runtime: CliRuntime, *, @@ -210,6 +160,7 @@ def _gateway_session_choices( ) return tuple(choices) + def _prompt_gateway_control_binding( *, runtime: CliRuntime | None, @@ -298,4 +249,9 @@ def _prompt_gateway_control_binding( return elephant_id, str(session_answer).strip() -__all__ = ['_load_gateway_control_runtime', '_gateway_elephant_choices', '_gateway_session_choices', '_prompt_gateway_control_binding'] +__all__ = [ + "_load_gateway_control_runtime", + "_gateway_elephant_choices", + "_gateway_session_choices", + "_prompt_gateway_control_binding", +] diff --git a/apps/gateway/gateway_main_wizard_providers.py b/apps/gateway/gateway_main_wizard_providers.py index e90cd1b..e9fe41b 100644 --- a/apps/gateway/gateway_main_wizard_providers.py +++ b/apps/gateway/gateway_main_wizard_providers.py @@ -1,71 +1,8 @@ """Gateway setup wizard helpers.""" from __future__ import annotations -import asyncio -from argparse import SUPPRESS, ArgumentParser, Namespace -from collections.abc import Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass -from datetime import UTC, datetime -import getpass -import apps.cli.wizard as cli_wizard -import importlib.util -import json -import os -from pathlib import Path -import re -import shlex -import signal -import subprocess -import sys -import time -from wsgiref.simple_server import make_server +from collections.abc import Sequence -from apps.cli.runtime import CliRuntime -from apps.cli.shell import ( - Align, - BRAND_ACCENT, - BRAND_ACCENT_STRONG, - BRAND_LIGHT, - BRAND_MUTED, - Console, - Group, - Panel, - RICH_AVAILABLE, - Table, - Text, - _resolve_elephant_version, - render_elephant_mark, -) -from apps.provider_runtime import load_runtime_local_secret_env -from apps.runtime_layout import default_cli_state_dir, default_gateway_state_dir -from packages.gateway_core import DEFAULT_GATEWAY_ACCOUNT_ID - -from . import ( - DEFAULT_DINGDING_CLIENT_ID_ENV, - DEFAULT_DINGDING_CLIENT_SECRET_ENV, - DEFAULT_DINGDING_ROBOT_CODE_ENV, - DEFAULT_DISCORD_BOT_TOKEN_ENV, - DEFAULT_FEISHU_APP_ID_ENV, - DEFAULT_FEISHU_APP_SECRET_ENV, - DEFAULT_FEISHU_EVENT_PATH, - FEISHU_ADAPTER_ID, - GatewayHttpService, - GatewayManagedRuntime, - GatewayManagedService, - SUPPORTED_DINGDING_TRANSPORTS, - SUPPORTED_DISCORD_TRANSPORTS, - SUPPORTED_FEISHU_TRANSPORTS, - SUPPORTED_WECOM_TRANSPORTS, - SUPPORTED_WEIXIN_TRANSPORTS, - build_gateway_app, - build_gateway_plugin_registry, - create_gateway_web_app, -) -from .dingding import DINGTALK_STREAM_PIP_SPEC, DingdingGatewayService -from .discord import DISCORD_PY_PIP_SPEC, DiscordGatewayService -from .feishu import FEISHU_SDK_PIP_SPEC, FeishuGatewayService -from .wecom import WecomGatewayService -from .weixin import WeixinGatewayService try: from prompt_toolkit.application import Application @@ -91,6 +28,7 @@ from .gateway_main_wizard_ui import * # noqa: F401,F403 + def _run_interactive_feishu_wizard( *, account_id: str, @@ -148,6 +86,7 @@ def _run_interactive_feishu_wizard( continue return state + def _run_interactive_discord_wizard( *, account_id: str, @@ -169,9 +108,7 @@ def _run_interactive_discord_wizard( allow_guild_ids=tuple(str(value).strip() for value in allow_guild_ids if str(value).strip()), allow_channel_ids=tuple(str(value).strip() for value in allow_channel_ids if str(value).strip()), ) - steps = ( - "bot_token_value", - ) + steps = ("bot_token_value",) step_index = 0 while step_index < len(steps): step = steps[step_index] @@ -313,7 +250,7 @@ def _run_interactive_weixin_wizard( if step == "wxhook_host": answer = _gateway_wizard_text_prompt( "wxhook Server", - f"wxhook API address (host:port) for sending replies.", + "wxhook API address (host:port) for sending replies.", default=f"{state.wxhook_host}:{state.wxhook_port}", allow_back=True, ) @@ -328,7 +265,7 @@ def _run_interactive_weixin_wizard( if step == "callback_host": answer = _gateway_wizard_text_prompt( "Callback Server", - f"Callback server address (host:port) for receiving WeChat messages.", + "Callback server address (host:port) for receiving WeChat messages.", default=f"{state.callback_host}:{state.callback_port}", allow_back=True, ) @@ -396,4 +333,13 @@ def _run_interactive_wecom_wizard( return state -__all__ = ['_run_interactive_feishu_wizard', '_run_interactive_discord_wizard', '_print_gateway_dingding_wizard_intro', '_print_gateway_weixin_wizard_intro', '_print_gateway_wecom_wizard_intro', '_run_interactive_dingding_wizard', '_run_interactive_weixin_wizard', '_run_interactive_wecom_wizard'] +__all__ = [ + "_run_interactive_feishu_wizard", + "_run_interactive_discord_wizard", + "_print_gateway_dingding_wizard_intro", + "_print_gateway_weixin_wizard_intro", + "_print_gateway_wecom_wizard_intro", + "_run_interactive_dingding_wizard", + "_run_interactive_weixin_wizard", + "_run_interactive_wecom_wizard", +] diff --git a/apps/gateway/gateway_main_wizard_ui.py b/apps/gateway/gateway_main_wizard_ui.py index 1a528ea..840517b 100644 --- a/apps/gateway/gateway_main_wizard_ui.py +++ b/apps/gateway/gateway_main_wizard_ui.py @@ -1,31 +1,18 @@ """Gateway setup wizard helpers.""" from __future__ import annotations -import asyncio -from argparse import SUPPRESS, ArgumentParser, Namespace -from collections.abc import Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass -from datetime import UTC, datetime +from dataclasses import dataclass import getpass import apps.cli.wizard as cli_wizard import importlib.util -import json -import os -from pathlib import Path -import re import shlex -import signal import subprocess import sys -import time -from wsgiref.simple_server import make_server from apps.cli.cli_main_support import _render_cli_banner_mark -from apps.cli.runtime import CliRuntime from apps.cli.shell import ( Align, BRAND_ACCENT, - BRAND_ACCENT_STRONG, BRAND_LIGHT, BRAND_MUTED, Console, @@ -36,36 +23,17 @@ Text, _resolve_elephant_version, ) -from apps.provider_runtime import load_runtime_local_secret_env -from apps.runtime_layout import default_cli_state_dir, default_gateway_state_dir -from packages.gateway_core import DEFAULT_GATEWAY_ACCOUNT_ID from . import ( - DEFAULT_DINGDING_CLIENT_ID_ENV, - DEFAULT_DINGDING_CLIENT_SECRET_ENV, - DEFAULT_DINGDING_ROBOT_CODE_ENV, - DEFAULT_DISCORD_BOT_TOKEN_ENV, - DEFAULT_FEISHU_APP_ID_ENV, - DEFAULT_FEISHU_APP_SECRET_ENV, - DEFAULT_FEISHU_EVENT_PATH, - FEISHU_ADAPTER_ID, - GatewayHttpService, - GatewayManagedRuntime, - GatewayManagedService, SUPPORTED_DINGDING_TRANSPORTS, SUPPORTED_DISCORD_TRANSPORTS, SUPPORTED_FEISHU_TRANSPORTS, SUPPORTED_WECOM_TRANSPORTS, SUPPORTED_WEIXIN_TRANSPORTS, - build_gateway_app, - build_gateway_plugin_registry, - create_gateway_web_app, ) -from .dingding import DINGTALK_STREAM_PIP_SPEC, DingdingGatewayService -from .discord import DISCORD_PY_PIP_SPEC, DiscordGatewayService -from .feishu import FEISHU_SDK_PIP_SPEC, FeishuGatewayService -from .wecom import WecomGatewayService -from .weixin import WeixinGatewayService +from .dingding import DINGTALK_STREAM_PIP_SPEC +from .discord import DISCORD_PY_PIP_SPEC +from .feishu import FEISHU_SDK_PIP_SPEC try: from prompt_toolkit.application import Application @@ -88,6 +56,7 @@ PromptStyle = None PROMPT_TOOLKIT_DIALOGS_AVAILABLE = False + @dataclass(frozen=True) class GatewayRuntimeRecord: runtime_id: str @@ -110,6 +79,7 @@ class GatewayRuntimeRecord: last_error: str | None = None transport: str | None = None + @dataclass(slots=True) class FeishuGatewayWizardState: account_id: str @@ -122,6 +92,7 @@ class FeishuGatewayWizardState: enabled: bool allow_group_chats: bool + @dataclass(slots=True) class DiscordGatewayWizardState: account_id: str @@ -166,6 +137,7 @@ class WecomGatewayWizardState: enabled: bool allow_group_chats: bool + GATEWAY_WIZARD_MAX_VISIBLE_CHOICES = cli_wizard.WIZARD_MAX_VISIBLE_CHOICES GatewayWizardChoice = cli_wizard.WizardChoice _GatewayWizardBackSignal = cli_wizard._WizardBackSignal @@ -179,11 +151,13 @@ class WecomGatewayWizardState: _GATEWAY_MANUAL_EGG = "__elephant.gateway.manual_elephant__" _GATEWAY_FOLLOW_LATEST_SESSION = "__elephant.gateway.follow_latest_session__" + def _gateway_wizard_choice_label(choice: GatewayWizardChoice) -> str: if not choice.emoji: return choice.label return f"{choice.emoji} {choice.label}" + def _gateway_wizard_choice_window( total: int, selected: int, @@ -204,6 +178,7 @@ def _gateway_wizard_choice_window( start = end - max_visible return start, end + def _gateway_wizard_choice_fragments( title: str, prompt: str, @@ -240,6 +215,7 @@ def _gateway_wizard_choice_fragments( fragments.append(("class:hint", "\nEnter confirms · ↑/↓ or j/k moves")) return fragments + def _gateway_wizard_choice_menu( title: str, prompt: str, @@ -317,6 +293,7 @@ def _cancel(event) -> None: return GATEWAY_WIZARD_BACK return str(answer or default) + def _gateway_prompt_value( label: str, *, @@ -331,6 +308,7 @@ def _gateway_prompt_value( return default or "" return "" + def _gateway_wizard_text_prompt( title: str, prompt: str, @@ -356,6 +334,7 @@ def _gateway_wizard_text_prompt( preserve_default_on_empty=preserve_default_on_empty, ) + def _gateway_wizard_choice_prompt( title: str, prompt: str, @@ -398,6 +377,7 @@ def _gateway_wizard_choice_prompt( return choice.value print(" choose a listed number, transport id, or label.") + def _gateway_bool_choices( *, enabled_label: str, @@ -420,6 +400,7 @@ def _gateway_bool_choices( ), ) + def _gateway_bool_prompt( title: str, prompt: str, @@ -447,6 +428,7 @@ def _gateway_bool_prompt( return GATEWAY_WIZARD_BACK return str(answer) == "yes" + def _feishu_transport_choices() -> tuple[GatewayWizardChoice, ...]: details = { "long-connection": "Use Feishu long connection for a local bridge without webhook setup.", @@ -461,6 +443,7 @@ def _feishu_transport_choices() -> tuple[GatewayWizardChoice, ...]: for transport in SUPPORTED_FEISHU_TRANSPORTS ) + def _discord_transport_choices() -> tuple[GatewayWizardChoice, ...]: details = { "gateway": "Use the managed Discord gateway runtime for local IM bring-up.", @@ -505,6 +488,7 @@ def _weixin_transport_choices() -> tuple[GatewayWizardChoice, ...]: for transport in SUPPORTED_WEIXIN_TRANSPORTS ) + def _wecom_transport_choices() -> tuple[GatewayWizardChoice, ...]: details = { "websocket": "Use the WeCom AI Bot WebSocket transport for local IM bring-up.", @@ -564,14 +548,17 @@ def _im_setup_choices(*, allow_skip: bool) -> tuple[GatewayWizardChoice, ...]: ) return tuple(choices) + def _center_brand_block(renderable): if Align is None: return renderable return Align.center(renderable) + def _confirm_gateway_wizard_intro() -> bool: return True + def _print_gateway_feishu_wizard_intro() -> bool: if not RICH_AVAILABLE or Table is None or Panel is None or Group is None: print("💬 Elephant Agent Gateway // Feishu setup") @@ -592,13 +579,22 @@ def _print_gateway_feishu_wizard_intro() -> bool: flow = Text() flow.append("🧭 IM setup flow\n", style=f"bold {BRAND_ACCENT}") flow.append("1 · Choose the Feishu account and long-connection surface\n", style=BRAND_LIGHT) - flow.append("2 · Paste App ID and App Secret directly into the local IM secret store\n", style=BRAND_LIGHT) + flow.append( + "2 · Paste App ID and App Secret directly into the local IM secret store\n", + style=BRAND_LIGHT, + ) flow.append("3 · Decide how the control bridge routes herd\n", style=BRAND_LIGHT) - flow.append("4 · Start the bridge with credentials kept out of profile.json", style=BRAND_LIGHT) + flow.append( + "4 · Start the bridge with credentials kept out of profile.json", + style=BRAND_LIGHT, + ) portal = Text() portal.append("Feishu console checklist\n", style=f"bold {BRAND_ACCENT}") portal.append("Capability · Add App Capability → Bot\n", style=BRAND_LIGHT) - portal.append("Events · Event Subscriptions → add `im.message.receive_v1`\n", style=BRAND_LIGHT) + portal.append( + "Events · Event Subscriptions → add `im.message.receive_v1`\n", + style=BRAND_LIGHT, + ) portal.append("Transport · Use Long Connection for local IM bring-up\n", style=BRAND_LIGHT) portal.append( "Permissions · Enable `im:message`, `im:message.p2p_msg:readonly`, and `im:message:send_as_bot`", @@ -636,6 +632,7 @@ def _print_gateway_feishu_wizard_intro() -> bool: ) return _confirm_gateway_wizard_intro() + def _print_gateway_discord_wizard_intro() -> bool: if not RICH_AVAILABLE or Table is None or Panel is None or Group is None: print("💬 Elephant Agent Gateway // Discord setup") @@ -659,8 +656,14 @@ def _print_gateway_discord_wizard_intro() -> bool: ) flow = Text() flow.append("🧭 IM setup flow\n", style=f"bold {BRAND_ACCENT}") - flow.append("1 · Choose the Discord account and managed gateway surface\n", style=BRAND_LIGHT) - flow.append("2 · Paste the bot token directly into the local IM secret file\n", style=BRAND_LIGHT) + flow.append( + "1 · Choose the Discord account and managed gateway surface\n", + style=BRAND_LIGHT, + ) + flow.append( + "2 · Paste the bot token directly into the local IM secret file\n", + style=BRAND_LIGHT, + ) flow.append( "3 · Choose the elephant Discord should route new conversations into, or pin a known session\n", style=BRAND_LIGHT, @@ -671,8 +674,14 @@ def _print_gateway_discord_wizard_intro() -> bool: ) portal = Text() portal.append("Discord portal checklist\n", style=f"bold {BRAND_ACCENT}") - portal.append("OAuth2 · URL Generator → include the `bot` scope when inviting the app\n", style=BRAND_LIGHT) - portal.append("Bot · Privileged Gateway Intents → enable `MESSAGE_CONTENT`\n", style=BRAND_LIGHT) + portal.append( + "OAuth2 · URL Generator → include the `bot` scope when inviting the app\n", + style=BRAND_LIGHT, + ) + portal.append( + "Bot · Privileged Gateway Intents → enable `MESSAGE_CONTENT`\n", + style=BRAND_LIGHT, + ) portal.append( "Permissions · Grant `View Channels`, `Send Messages`, `Send Messages in Threads`, and `Read Message History`\n", style=BRAND_LIGHT, @@ -710,6 +719,7 @@ def _print_gateway_discord_wizard_intro() -> bool: ) return True + def _gateway_wizard_secret_prompt( title: str, prompt: str, @@ -732,6 +742,7 @@ def _gateway_wizard_secret_prompt( return GATEWAY_WIZARD_BACK return answer + def _print_gateway_setup_paused(service_name: str) -> None: print(f"{service_name} IM setup paused") print(" No IM changes were written.") @@ -739,6 +750,7 @@ def _print_gateway_setup_paused(service_name: str) -> None: print(" - elephant gateway") print(" - elephant gateway doctor") + def _ensure_feishu_sdk_available(*, reason: str) -> bool: if importlib.util.find_spec("lark_oapi") is not None: return False @@ -756,8 +768,7 @@ def _ensure_feishu_sdk_available(*, reason: str) -> bool: except (OSError, subprocess.CalledProcessError) as exc: rendered = " ".join(shlex.quote(part) for part in command) raise SystemExit( - "Elephant Agent could not automatically install the Feishu SDK. " - f"Run `{rendered}` and retry." + f"Elephant Agent could not automatically install the Feishu SDK. Run `{rendered}` and retry." ) from exc if importlib.util.find_spec("lark_oapi") is None: rendered = " ".join(shlex.quote(part) for part in command) @@ -768,6 +779,7 @@ def _ensure_feishu_sdk_available(*, reason: str) -> bool: print("Feishu support is ready.") return True + def _ensure_discord_sdk_available(*, reason: str) -> bool: if importlib.util.find_spec("discord") is not None: return False @@ -785,8 +797,7 @@ def _ensure_discord_sdk_available(*, reason: str) -> bool: except (OSError, subprocess.CalledProcessError) as exc: rendered = " ".join(shlex.quote(part) for part in command) raise SystemExit( - "Elephant Agent could not automatically install Discord support. " - f"Run `{rendered}` and retry." + f"Elephant Agent could not automatically install Discord support. Run `{rendered}` and retry." ) from exc if importlib.util.find_spec("discord") is None: rendered = " ".join(shlex.quote(part) for part in command) @@ -815,8 +826,7 @@ def _ensure_dingding_sdk_available(*, reason: str) -> bool: except (OSError, subprocess.CalledProcessError) as exc: rendered = " ".join(shlex.quote(part) for part in command) raise SystemExit( - "Elephant Agent could not automatically install DingDing support. " - f"Run `{rendered}` and retry." + f"Elephant Agent could not automatically install DingDing support. Run `{rendered}` and retry." ) from exc if importlib.util.find_spec("dingtalk_stream") is None: rendered = " ".join(shlex.quote(part) for part in command) @@ -850,8 +860,7 @@ def _ensure_weixin_sdk_available(*, reason: str) -> bool: except (OSError, subprocess.CalledProcessError) as exc: rendered = " ".join(shlex.quote(part) for part in command) raise SystemExit( - "Elephant Agent could not automatically install WeChat (iLink) support. " - f"Run `{rendered}` and retry." + f"Elephant Agent could not automatically install WeChat (iLink) support. Run `{rendered}` and retry." ) from exc print("WeChat support is ready.") return True @@ -879,12 +888,12 @@ def _ensure_wecom_sdk_available(*, reason: str) -> bool: except (OSError, subprocess.CalledProcessError) as exc: rendered = " ".join(shlex.quote(part) for part in command) raise SystemExit( - "Elephant Agent could not automatically install WeCom support. " - f"Run `{rendered}` and retry." + f"Elephant Agent could not automatically install WeCom support. Run `{rendered}` and retry." ) from exc print("WeCom support is ready.") return True + def _parse_gateway_id_csv(value: str) -> tuple[str, ...]: return tuple(dict.fromkeys(part.strip() for part in value.split(",") if part.strip())) diff --git a/apps/gateway/platforms/__init__.py b/apps/gateway/platforms/__init__.py index 9daf4fb..4fabff6 100644 --- a/apps/gateway/platforms/__init__.py +++ b/apps/gateway/platforms/__init__.py @@ -3,10 +3,18 @@ from __future__ import annotations from .chat_bot import CHAT_BOT_PLATFORM, ChatBotGatewayPlatform, ChatBotMessagingAdapter -from .dingding import DINGDING_PLATFORM, DingdingGatewayPlatform, DingdingMessagingAdapter +from .dingding import ( + DINGDING_PLATFORM, + DingdingGatewayPlatform, + DingdingMessagingAdapter, +) from .discord import DISCORD_PLATFORM, DiscordGatewayPlatform, DiscordMessagingAdapter from .feishu import FEISHU_PLATFORM, FeishuGatewayPlatform, FeishuMessagingAdapter -from .telegram import TELEGRAM_PLATFORM, TelegramGatewayPlatform, TelegramMessagingAdapter +from .telegram import ( + TELEGRAM_PLATFORM, + TelegramGatewayPlatform, + TelegramMessagingAdapter, +) from .wecom import WECOM_PLATFORM, WecomGatewayPlatform, WecomMessagingAdapter from .weixin import WEIXIN_PLATFORM, WeixinGatewayPlatform, WeixinMessagingAdapter from .webhook import WEBHOOK_PLATFORM, WebhookGatewayPlatform, WebhookMessagingAdapter diff --git a/apps/gateway/platforms/chat_bot.py b/apps/gateway/platforms/chat_bot.py index 1ed980a..da767fe 100644 --- a/apps/gateway/platforms/chat_bot.py +++ b/apps/gateway/platforms/chat_bot.py @@ -5,7 +5,11 @@ from collections.abc import Mapping from dataclasses import dataclass -from packages.gateway_core import DEFAULT_GATEWAY_ACCOUNT_ID, GatewayExchange, GatewayInboundMessage +from packages.gateway_core import ( + DEFAULT_GATEWAY_ACCOUNT_ID, + GatewayExchange, + GatewayInboundMessage, +) from ..plugins import GatewayAdapterDescriptor, GatewayServicePluginRegistration from ..runtime_app import GatewayApp @@ -72,6 +76,7 @@ def receive_text( }, ) + @dataclass(frozen=True, slots=True) class ChatBotGatewayPlatform: key: str = "chat_bot" diff --git a/apps/gateway/platforms/dingding.py b/apps/gateway/platforms/dingding.py index 0d6e362..17eff9c 100644 --- a/apps/gateway/platforms/dingding.py +++ b/apps/gateway/platforms/dingding.py @@ -95,9 +95,7 @@ def normalize_event( reply_to_message_id=message_id, attachment_refs=(), policy_hint=_policy_hint( - target_trusted_default=( - target_trusted_default if target_trusted is None else target_trusted - ), + target_trusted_default=(target_trusted_default if target_trusted is None else target_trusted), consent_default=consent_default if consent_given is None else consent_given, is_external_default=external_default if is_external is None else is_external, audience_scope=chat_type, @@ -138,6 +136,7 @@ def receive_event( def build_reply_request(self, outbound: GatewayOutboundMessage) -> Mapping[str, object]: from ..dingding_support import _dingding_reply_request + if outbound.adapter_id != self.adapter_id: raise ValueError("dingding reply request requires a dingding outbound message") return _dingding_reply_request(outbound) diff --git a/apps/gateway/platforms/discord.py b/apps/gateway/platforms/discord.py index 8e97c16..60fef21 100644 --- a/apps/gateway/platforms/discord.py +++ b/apps/gateway/platforms/discord.py @@ -102,11 +102,7 @@ def normalize_event( sender=_sender_ref( str(author.get("id") or ""), display_name=_discord_display_name(author, member=member), - username=( - f"@{str(author['username'])}" - if author.get("username") is not None - else None - ), + username=(f"@{str(author['username'])}" if author.get("username") is not None else None), is_bot=bool(author.get("bot", False)), metadata={"global_name": str(author.get("global_name") or "")}, ), @@ -114,9 +110,7 @@ def normalize_event( reply_to_message_id=reply_to_message_id, attachment_refs=attachment_refs, policy_hint=_policy_hint( - target_trusted_default=( - target_trusted_default if target_trusted is None else target_trusted - ), + target_trusted_default=(target_trusted_default if target_trusted is None else target_trusted), consent_default=consent_default if consent_given is None else consent_given, is_external_default=external_default if is_external is None else is_external, audience_scope=chat_type, diff --git a/apps/gateway/platforms/feishu.py b/apps/gateway/platforms/feishu.py index 71bbd19..1c6daad 100644 --- a/apps/gateway/platforms/feishu.py +++ b/apps/gateway/platforms/feishu.py @@ -63,16 +63,10 @@ def normalize_event( tenant_key = ( str(header["tenant_key"]) if header.get("tenant_key") is not None - else ( - str(event["tenant_key"]) - if event.get("tenant_key") is not None - else None - ) + else (str(event["tenant_key"]) if event.get("tenant_key") is not None else None) ) resolved_account_id = account_id or ( - str(header["app_id"]) - if header.get("app_id") is not None - else DEFAULT_GATEWAY_ACCOUNT_ID + str(header["app_id"]) if header.get("app_id") is not None else DEFAULT_GATEWAY_ACCOUNT_ID ) chat_id = str(message.get("chat_id") or "") @@ -84,9 +78,7 @@ def normalize_event( if not message_id: raise ValueError("feishu message payload requires message_id") root_id = ( - str(message["root_id"]) - if message.get("root_id") is not None and str(message["root_id"]).strip() - else None + str(message["root_id"]) if message.get("root_id") is not None and str(message["root_id"]).strip() else None ) parent_id = ( str(message["parent_id"]) @@ -144,20 +136,14 @@ def normalize_event( is_bot=str(sender.get("sender_type") or "user") != "user", metadata={ "sender_type": str(sender.get("sender_type") or "user"), - "tenant_key": ( - str(sender["tenant_key"]) - if sender.get("tenant_key") is not None - else "" - ), + "tenant_key": (str(sender["tenant_key"]) if sender.get("tenant_key") is not None else ""), }, ), body=_feishu_message_body(message_type, content), reply_to_message_id=parent_id or root_id or message_id, attachment_refs=attachment_refs, policy_hint=_policy_hint( - target_trusted_default=( - target_trusted_default if target_trusted is None else target_trusted - ), + target_trusted_default=(target_trusted_default if target_trusted is None else target_trusted), consent_default=consent_default if consent_given is None else consent_given, is_external_default=external_default if is_external is None else is_external, audience_scope=normalized_chat_type, diff --git a/apps/gateway/platforms/telegram.py b/apps/gateway/platforms/telegram.py index 9576fe0..609a60d 100644 --- a/apps/gateway/platforms/telegram.py +++ b/apps/gateway/platforms/telegram.py @@ -5,7 +5,11 @@ from collections.abc import Mapping from dataclasses import dataclass -from packages.gateway_core import DEFAULT_GATEWAY_ACCOUNT_ID, GatewayExchange, GatewayInboundMessage +from packages.gateway_core import ( + DEFAULT_GATEWAY_ACCOUNT_ID, + GatewayExchange, + GatewayInboundMessage, +) from ..plugins import GatewayAdapterDescriptor, GatewayServicePluginRegistration from ..runtime_app import GatewayApp @@ -55,9 +59,7 @@ def receive_update( update_kind = "callback_query" if callback_query.get("data") is not None: callback_data = str(callback_query["data"]) - if not isinstance(message.get("from"), Mapping) and isinstance( - callback_query.get("from"), Mapping - ): + if not isinstance(message.get("from"), Mapping) and isinstance(callback_query.get("from"), Mapping): message = { **message, "from": callback_query["from"], @@ -73,11 +75,7 @@ def receive_update( thread_id = message.get("message_thread_id") normalized_chat_type = _normalized_chat_type(chat_type) attachment_refs = _attachment_refs(_telegram_attachment_ids(message)) - message_id = ( - str(message["message_id"]) - if message.get("message_id") is not None - else None - ) + message_id = str(message["message_id"]) if message.get("message_id") is not None else None metadata = { "channel": "telegram", "chat_type": chat_type, @@ -91,9 +89,7 @@ def receive_update( if thread_id is not None: metadata["message_thread_id"] = str(thread_id) if message.get("reply_to_message") is not None: - metadata["reply_to_message_id"] = str( - dict(message["reply_to_message"]).get("message_id") or "" - ) + metadata["reply_to_message_id"] = str(dict(message["reply_to_message"]).get("message_id") or "") if callback_data is not None: metadata["callback_data"] = callback_data target_trusted_default, consent_default, external_default = _telegram_delivery_defaults(chat_type) @@ -115,22 +111,14 @@ def receive_update( sender=_sender_ref( str(sender["id"]), display_name=_telegram_display_name(sender), - username=( - f"@{str(sender['username'])}" - if sender.get("username") is not None - else None - ), + username=(f"@{str(sender['username'])}" if sender.get("username") is not None else None), is_bot=bool(sender.get("is_bot", False)), ), body=str(message.get("text") or message.get("caption") or callback_data or "telegram-event"), - reply_to_message_id=( - str(metadata.get("reply_to_message_id") or message_id or "") or None - ), + reply_to_message_id=(str(metadata.get("reply_to_message_id") or message_id or "") or None), attachment_refs=attachment_refs, policy_hint=_policy_hint( - target_trusted_default=( - target_trusted_default if target_trusted is None else target_trusted - ), + target_trusted_default=(target_trusted_default if target_trusted is None else target_trusted), consent_default=consent_default if consent_given is None else consent_given, is_external_default=external_default if is_external is None else is_external, audience_scope=normalized_chat_type, diff --git a/apps/gateway/platforms/webhook.py b/apps/gateway/platforms/webhook.py index f61020d..89b6734 100644 --- a/apps/gateway/platforms/webhook.py +++ b/apps/gateway/platforms/webhook.py @@ -5,7 +5,11 @@ from collections.abc import Mapping from dataclasses import dataclass -from packages.gateway_core import DEFAULT_GATEWAY_ACCOUNT_ID, GatewayExchange, GatewayInboundMessage +from packages.gateway_core import ( + DEFAULT_GATEWAY_ACCOUNT_ID, + GatewayExchange, + GatewayInboundMessage, +) from ..plugins import GatewayAdapterDescriptor, GatewayServicePluginRegistration from ..runtime_app import GatewayApp @@ -45,11 +49,7 @@ def receive_event( account=_account_ref( self.adapter_id, account_id=str(payload.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID), - tenant_id=( - str(payload["tenant_id"]) - if payload.get("tenant_id") is not None - else None - ), + tenant_id=(str(payload["tenant_id"]) if payload.get("tenant_id") is not None else None), surface="generic-webhook", ), conversation=_conversation_ref( @@ -58,21 +58,13 @@ def receive_event( ), sender=_sender_ref( str(payload["external_user_id"]), - display_name=( - str(payload["display_name"]) - if payload.get("display_name") is not None - else None - ), + display_name=(str(payload["display_name"]) if payload.get("display_name") is not None else None), ), body=str(payload["body"]), reply_to_message_id=( str(payload["reply_to_message_id"]) if payload.get("reply_to_message_id") is not None - else ( - str(payload["reply_to_event_id"]) - if payload.get("reply_to_event_id") is not None - else None - ) + else (str(payload["reply_to_event_id"]) if payload.get("reply_to_event_id") is not None else None) ), attachment_refs=attachment_refs, policy_hint=_policy_hint( @@ -98,6 +90,7 @@ def receive_event( metadata=response_metadata, ) + @dataclass(frozen=True, slots=True) class WebhookGatewayPlatform: key: str = "webhook" diff --git a/apps/gateway/platforms/wecom.py b/apps/gateway/platforms/wecom.py index 7bfa300..d8e4d04 100644 --- a/apps/gateway/platforms/wecom.py +++ b/apps/gateway/platforms/wecom.py @@ -93,9 +93,7 @@ def normalize_event( reply_to_message_id=message_id, attachment_refs=(), policy_hint=_policy_hint( - target_trusted_default=( - target_trusted_default if target_trusted is None else target_trusted - ), + target_trusted_default=(target_trusted_default if target_trusted is None else target_trusted), consent_default=consent_default if consent_given is None else consent_given, is_external_default=external_default if is_external is None else is_external, audience_scope=chat_type, diff --git a/apps/gateway/platforms/weixin.py b/apps/gateway/platforms/weixin.py index 103730a..c0a699f 100644 --- a/apps/gateway/platforms/weixin.py +++ b/apps/gateway/platforms/weixin.py @@ -99,9 +99,7 @@ def normalize_event( reply_to_message_id=message_id, attachment_refs=(), policy_hint=_policy_hint( - target_trusted_default=( - target_trusted_default if target_trusted is None else target_trusted - ), + target_trusted_default=(target_trusted_default if target_trusted is None else target_trusted), consent_default=consent_default if consent_given is None else consent_given, is_external_default=external_default if is_external is None else is_external, audience_scope=chat_type, diff --git a/apps/gateway/plugins.py b/apps/gateway/plugins.py index 1b56dbe..b938fbf 100644 --- a/apps/gateway/plugins.py +++ b/apps/gateway/plugins.py @@ -236,16 +236,10 @@ def create_service(self, key: str, *, app: Any, **kwargs: object) -> object: return registration.factory(app=app, **kwargs) def adapter_id_map(self) -> dict[str, str]: - return { - key: registration.descriptor.adapter_id - for key, registration in self._adapters.items() - } + return {key: registration.descriptor.adapter_id for key, registration in self._adapters.items()} def adapter_setup_payload(self) -> dict[str, dict[str, object]]: - return { - key: registration.descriptor.summary_payload() - for key, registration in self._adapters.items() - } + return {key: registration.descriptor.summary_payload() for key, registration in self._adapters.items()} def configured_service_keys( self, diff --git a/apps/gateway/proactive_ask_job.py b/apps/gateway/proactive_ask_job.py index 9cd2f32..e15aad8 100644 --- a/apps/gateway/proactive_ask_job.py +++ b/apps/gateway/proactive_ask_job.py @@ -25,6 +25,7 @@ @dataclass(frozen=True, slots=True) class ProactiveAskTickResult: """Result summary from a single proactive ask cron tick.""" + scanned: int = 0 eligible: int = 0 enqueued: int = 0 @@ -185,6 +186,7 @@ def run_proactive_ask_tick( # Helpers # --------------------------------------------------------------------------- + def _personal_model_id(app: Any, record: Any) -> str | None: if record.state_id: try: diff --git a/apps/gateway/runtime_app.py b/apps/gateway/runtime_app.py index 21ed8e1..b87232e 100644 --- a/apps/gateway/runtime_app.py +++ b/apps/gateway/runtime_app.py @@ -1,83 +1,51 @@ """Gateway runtime application.""" - from __future__ import annotations from collections.abc import Mapping from dataclasses import dataclass, replace from datetime import datetime, timezone -import hashlib -from pathlib import Path -import re -import tempfile -from typing import Any from uuid import uuid4 -from apps.provider_runtime import ( - load_provider_profile, - provider_profile_from_payload, -) -from packages.auth import AuthProfile, EnvironmentSecretStore, PersistentAuthProfileStore, ProfileCredentialResolver -from packages.models import SurfaceModelProviderCapability -from packages.models.runtime_capability import provider_fallback_summary, provider_profile_summary -from packages.capabilities.runtime import ( - CapabilityDescriptor, - ContextCapability, - RecallCapability, - ModelProviderCapability, - TelemetrySinkCapability, -) +from packages.auth import AuthProfile, PersistentAuthProfileStore from packages.context import ( - ContextRuntime, next_session_context_epoch, ) -from packages.context.epoch_store import EpochStore, FileEpochStore +from packages.context.epoch_store import EpochStore from packages.context.compress import compress_epoch from packages.contracts.runtime import ( - ContextBundle, EventEnvelope, - ExecutionResult, RecallEvidence, - PersonalModelRuntimeState, PromptMessage, ) from packages.contracts import Episode from packages.gateway_core import ( - DEFAULT_GATEWAY_ACCOUNT_ID, - FileGatewayIdentityStore, - FileGatewaySessionStore, - GatewayAccountRef, GatewayAttachmentRef, - GatewayConversationRef, - GatewayCoreDependencies, GatewayCoreService, GatewayExchange, GatewayIdentityRecord, GatewayInboundMessage, - GatewayOutboundMessage, - GatewayPolicyHint, GatewayRouteState, - GatewaySenderRef, - InMemoryGatewayIdentityStore, - InMemoryGatewaySessionStore, ) -from packages.kernel import KernelDependencies, KernelOutcome, KernelService, KernelSourceRequest, ReconciliationPipeline, StateReconciler +from packages.kernel import ( + KernelOutcome, + KernelService, + KernelSourceRequest, + ReconciliationPipeline, + StateReconciler, +) from packages.kernel.context_compaction import ( flush_projection_cache, ) from packages.evidence.recall_runtime import RecallRuntime from packages.state import ( - DEFAULT_ELEPHANT_IDENTITY_TEXT, LoadedProfile, - ProfileLoader, - build_prompt_contract, ) from packages.state.persistence import resolve_runtime_state -from packages.security.runtime import SecurityPolicy from packages.skills import SkillRuntime from packages.storage import RuntimeStorageRepository from packages.tools import ToolRuntime -from .plugins import GatewayAdapterDescriptor, GatewayPluginRegistry +from .plugins import GatewayPluginRegistry def _episode_status_from_route(status: str) -> str: @@ -88,6 +56,7 @@ def _episode_status_from_route(status: str) -> str: return "closed" return "open" + CHAT_BOT_ADAPTER_ID = "messaging.chat-bot" WEBHOOK_ADAPTER_ID = "messaging.webhook" TELEGRAM_ADAPTER_ID = "messaging.telegram" @@ -95,7 +64,8 @@ def _episode_status_from_route(status: str) -> str: DISCORD_ADAPTER_ID = "messaging.discord" from .runtime_support import * # noqa: F401,F403 -from .runtime_capabilities import GatewayContextCapability, GatewayRecallCapability, GatewayPreviewModelProvider, GatewaySurfaceModelProvider, GatewayTelemetrySink +from .runtime_capabilities import GatewaySurfaceModelProvider, GatewayTelemetrySink + def _aware_gateway_utc(value: datetime) -> datetime: if value.tzinfo is None: @@ -187,7 +157,7 @@ def handle_message( target_trusted=target_trusted, consent_given=consent_given, is_external=is_external, - ) + ) session = self._ensure_runtime_session(route) event = self._event_for_inbound(inbound, episode_id=session.episode_id) outcome = self.kernel.run( @@ -233,9 +203,7 @@ def handle_message( delivery = self.core.deliver( route, body=reply_body or outcome.execution.summary, - reply_to_message_id=reply_to_message_id - or inbound.reply_to_message_id - or inbound.event_id, + reply_to_message_id=reply_to_message_id or inbound.reply_to_message_id or inbound.event_id, attachment_refs=attachment_refs, metadata={ **dict(metadata or {}), @@ -317,9 +285,7 @@ def _reject_unbound_route( delivery = self.core.deliver( route, body=guidance, - reply_to_message_id=reply_to_message_id - or route.inbound.reply_to_message_id - or route.inbound.event_id, + reply_to_message_id=reply_to_message_id or route.inbound.reply_to_message_id or route.inbound.event_id, attachment_refs=attachment_refs, metadata={ **dict(metadata or {}), @@ -376,7 +342,10 @@ def _record_context_epoch(self, session: Episode, outcome: KernelOutcome) -> Non def _run_context_hygiene(self, session_id: str, *, event_id: str, outcome: KernelOutcome | None = None) -> None: execution = outcome.execution if outcome is not None else None - usage_tokens = max(int(getattr(execution, "prompt_tokens", 0) or 0), int(getattr(execution, "total_tokens", 0) or 0)) + usage_tokens = max( + int(getattr(execution, "prompt_tokens", 0) or 0), + int(getattr(execution, "total_tokens", 0) or 0), + ) context_limit = int(getattr(outcome.context, "token_budget", 0) or 0) if outcome is not None else 0 if self.epoch_store is None: return @@ -397,17 +366,27 @@ def _run_context_hygiene(self, session_id: str, *, event_id: str, outcome: Kerne # Persist to episode for dashboard visibility try: with self.repository.connection() as connection: - connection.execute("UPDATE episodes SET exit_summary = ? WHERE episode_id = ?", (compress_result.summary, session_id)) + connection.execute( + "UPDATE episodes SET exit_summary = ? WHERE episode_id = ?", + (compress_result.summary, session_id), + ) connection.commit() except Exception: pass - self.telemetry.emit({ - "event_id": f"telemetry:{session_id}:context-compact:{uuid4().hex}", - "event_type": "kernel.stage", - "session_id": session_id, - "source": "gateway", - "payload": {"stage": "context-compact", "detail": f"method={compress_result.method} messages={compress_result.before_messages}->{compress_result.after_messages}", "recorded_at": datetime.now(timezone.utc).isoformat(), "event_id": event_id}, - }) + self.telemetry.emit( + { + "event_id": f"telemetry:{session_id}:context-compact:{uuid4().hex}", + "event_type": "kernel.stage", + "session_id": session_id, + "source": "gateway", + "payload": { + "stage": "context-compact", + "detail": f"method={compress_result.method} messages={compress_result.before_messages}->{compress_result.after_messages}", + "recorded_at": datetime.now(timezone.utc).isoformat(), + "event_id": event_id, + }, + } + ) flush_projection_cache(self.kernel.dependencies.context) def _llm_compress( @@ -430,7 +409,7 @@ def _llm_compress( if role == "tool": continue # Skip tool results if role == "assistant" and msg.tool_calls and not content: - for call in (msg.tool_calls or ()): + for call in msg.tool_calls or (): pending_tools.append(call.get("name") or call.get("tool_name") or "tool") continue if pending_tools: @@ -477,7 +456,10 @@ def _llm_compress( prompt=f"Summarize this conversation:\n\n{conversation_text}", messages=( PromptMessage(role="system", content=system_prompt), - PromptMessage(role="user", content=f"Summarize this conversation:\n\n{conversation_text}"), + PromptMessage( + role="user", + content=f"Summarize this conversation:\n\n{conversation_text}", + ), ), tools=(), metadata={"source": "gateway-compress"}, @@ -574,13 +556,13 @@ def interrupt_episode( def _ensure_runtime_session(self, route) -> Episode: """Ensure runtime session with correct personal_model_id from state. - + When we have a state_id (identity/companion), load it directly to extract the correct personal_model_id (which links the identity back to its user). This ensures that if the gateway route was created with an identity, we use the authoritative state.personal_model_id, not a potentially stale or incorrect session.profile_id value. - + This fixes the IM mode system prompt injection bug where Zoey (the identity) was being shown as the user's name because personal_model_id was incorrectly set to the state_id instead of the user's personal_model_id. @@ -588,14 +570,14 @@ def _ensure_runtime_session(self, route) -> Episode: session = route.session identity = route.identity runtime_episode_id = identity.episode_id or session.session_id - + # When we have a state_id (bound identity/elephant), load the State directly # to get the authoritative personal_model_id (which links the identity to its user). # This ensures personal_model_id is never confused with state_id or elephant_id. resolved_state = None if identity.state_id: resolved_state = self.repository.load_state(identity.state_id) - + # Fallback to resolve_runtime_state if direct load didn't work if resolved_state is None: resolved_state = resolve_runtime_state( @@ -606,9 +588,11 @@ def _ensure_runtime_session(self, route) -> Episode: elephant_id=identity.elephant_id, required=False, ) - + existing = self.repository.load_episode_state(runtime_episode_id) - idle_gap_seconds = max(0.0, (session.updated_at - existing.updated_at).total_seconds()) if existing is not None else 0.0 + idle_gap_seconds = ( + max(0.0, (session.updated_at - existing.updated_at).total_seconds()) if existing is not None else 0.0 + ) if idle_gap_seconds > 1800: self._clear_idle_context_epoch( runtime_episode_id, diff --git a/apps/gateway/runtime_capabilities.py b/apps/gateway/runtime_capabilities.py index af1899f..9688bd4 100644 --- a/apps/gateway/runtime_capabilities.py +++ b/apps/gateway/runtime_capabilities.py @@ -1,26 +1,15 @@ """Gateway runtime capabilities.""" - from __future__ import annotations from collections.abc import Mapping -from dataclasses import dataclass, replace -from datetime import datetime, timezone -import hashlib -import json +from dataclasses import replace from pathlib import Path -import re -import tempfile from typing import Any from uuid import uuid4 -from apps.provider_runtime import ( - load_provider_profile, - provider_profile_from_payload, -) -from packages.auth import AuthProfile, EnvironmentSecretStore, PersistentAuthProfileStore, ProfileCredentialResolver +from packages.auth import AuthProfile, EnvironmentSecretStore, ProfileCredentialResolver from packages.models import SurfaceModelProviderCapability -from packages.models.runtime_capability import provider_fallback_summary, provider_profile_summary from packages.capabilities.runtime import ( CapabilityDescriptor, ContextCapability, @@ -47,30 +36,10 @@ GenerationModelProfile, SupportModelProfile, ) -from packages.gateway_core import ( - DEFAULT_GATEWAY_ACCOUNT_ID, - FileGatewayIdentityStore, - FileGatewaySessionStore, - GatewayAccountRef, - GatewayAttachmentRef, - GatewayConversationRef, - GatewayCoreDependencies, - GatewayCoreService, - GatewayExchange, - GatewayIdentityRecord, - GatewayInboundMessage, - GatewayOutboundMessage, - GatewayPolicyHint, - GatewaySenderRef, - InMemoryGatewayIdentityStore, - InMemoryGatewaySessionStore, -) -from packages.kernel import KernelDependencies, KernelService, KernelSourceRequest, ReconciliationPipeline, StateReconciler from packages.evidence import RecallRuntime from packages.runtime_layout import elephant_file_path from packages.skills import SkillPromptContextBuilder from packages.state import ( - DEFAULT_ELEPHANT_IDENTITY_TEXT, LoadedProfile, ProfileLoader, build_prompt_contract, @@ -78,10 +47,8 @@ load_runtime_profile, profile_with_authored_elephant_identity, ) -from packages.security.runtime import SecurityPolicy from packages.storage import RuntimeStorageRepository from packages.tools import ToolRuntime -from .plugins import GatewayAdapterDescriptor, GatewayPluginRegistry CHAT_BOT_ADAPTER_ID = "messaging.chat-bot" WEBHOOK_ADAPTER_ID = "messaging.webhook" @@ -91,6 +58,7 @@ from .runtime_support import * # noqa: F401,F403 + class GatewayTelemetrySink(TelemetrySinkCapability): def __init__(self) -> None: self.descriptor = CapabilityDescriptor( @@ -108,6 +76,7 @@ def events(self) -> tuple[dict[str, Any], ...]: def emit(self, event: Mapping[str, Any]) -> None: self._events.append(dict(event)) + class GatewayRecallCapability(RecallCapability): def __init__(self, runtime: RecallRuntime) -> None: self.descriptor = CapabilityDescriptor( @@ -122,7 +91,6 @@ def retrieve_evidence(self, request: EvidenceRetrievalRequest) -> EvidenceRetrie return self.runtime.retrieve_evidence(request) - class GatewayContextCapability(ContextCapability): def __init__( self, @@ -217,11 +185,7 @@ def assemble( bundle_id=f"bundle:{session.episode_id}:{len(work_items)}:{len(recall_items)}", instruction_refs=runtime.instruction_refs, ) - epoch = ( - self.epoch_store.load(session.episode_id) - if self.epoch_store is not None - else None - ) + epoch = self.epoch_store.load(session.episode_id) if self.epoch_store is not None else None return apply_session_context_epoch(bundle, epoch) def force_projection_compaction( @@ -256,6 +220,7 @@ def force_projection_compaction( def flush_projection_cache(self) -> None: return None + class GatewayPreviewModelProvider(ModelProviderCapability): def __init__(self) -> None: self.descriptor = CapabilityDescriptor( @@ -295,9 +260,14 @@ def generate( episode_id=session.episode_id, outcome="ok", summary=summary, - side_effects=("gateway-preview-provider", profile.mode, f"model_role={model_role}"), + side_effects=( + "gateway-preview-provider", + profile.mode, + f"model_role={model_role}", + ), ) + class GatewaySurfaceModelProvider(ModelProviderCapability): def __init__( self, diff --git a/apps/gateway/runtime_factory.py b/apps/gateway/runtime_factory.py index b2708e6..3033689 100644 --- a/apps/gateway/runtime_factory.py +++ b/apps/gateway/runtime_factory.py @@ -1,19 +1,13 @@ """Gateway runtime adapter registration and app factory.""" - from __future__ import annotations from collections.abc import Mapping -from dataclasses import dataclass, replace -from datetime import datetime, timezone -import hashlib -import json +from dataclasses import replace from pathlib import Path -import re import sys import tempfile from typing import Any -from uuid import uuid4 from apps.provider_runtime import ( load_provider_profile, @@ -21,8 +15,15 @@ ) from apps.runtime_layout import default_cli_state_dir from packages.cron import CronRuntime -from packages.runtime_config import configured_external_skill_dirs, global_config_path_for_state_dir, load_extensions_from_config, load_global_config -from packages.runtime_layout import default_cron_dir, elephant_file_path, infer_install_root_from_state_dir +from packages.runtime_config import ( + configured_external_skill_dirs, + load_extensions_from_config, +) +from packages.runtime_layout import ( + default_cron_dir, + elephant_file_path, + infer_install_root_from_state_dir, +) from packages.auth import ( AuthProfile, EncryptedRepositorySecretStore, @@ -34,45 +35,24 @@ SecretStore, SecretValueResolution, ) -from packages.capabilities.runtime import ( - CapabilityDescriptor, - ContextCapability, - RecallCapability, - ModelProviderCapability, - TelemetrySinkCapability, -) -from packages.context import ContextRuntime from packages.context.epoch_store import FileEpochStore -from packages.contracts.runtime import ( - ContextBundle, - EventEnvelope, - ExecutionResult, - RecallEvidence, - PersonalModelRuntimeState, -) from packages.gateway_core import ( - DEFAULT_GATEWAY_ACCOUNT_ID, FileGatewayIdentityStore, FileGatewaySessionStore, - GatewayAccountRef, - GatewayAttachmentRef, - GatewayConversationRef, GatewayCoreDependencies, GatewayCoreService, - GatewayExchange, - GatewayIdentityRecord, - GatewayInboundMessage, GatewayMessageDeliverySurface, - GatewayOutboundMessage, GatewayOutboundQueue, - GatewayPolicyHint, - GatewaySenderRef, InMemoryGatewayIdentityStore, InMemoryGatewaySessionStore, default_outbound_queue_path, ) -from packages.kernel import KernelDependencies, KernelService, KernelSourceRequest, ReconciliationPipeline, StateReconciler -from packages.evidence import RecallRuntime, SemanticSummaryIndexer, build_semantic_index_bundle +from packages.kernel import KernelDependencies, KernelService +from packages.evidence import ( + RecallRuntime, + SemanticSummaryIndexer, + build_semantic_index_bundle, +) from packages.skills import ( RuntimeSkillManagementSurface, SkillHub, @@ -83,7 +63,7 @@ load_skill_extension_manifest, sync_builtin_skill_shelf, ) -from packages.state import DEFAULT_ELEPHANT_IDENTITY_TEXT, LoadedProfile, ProfileLoader, build_prompt_contract +from packages.state import LoadedProfile, ProfileLoader from packages.security.runtime import SecurityPolicy from packages.storage import RuntimeStorageRepository from packages.tools import ( @@ -111,11 +91,15 @@ ) from .runtime_support import * # noqa: F401,F403 -def register_builtin_gateway_adapters(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: + +def register_builtin_gateway_adapters( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: for platform in BUILTIN_GATEWAY_PLATFORMS: registry.register_platform(platform) return registry + def _builtin_gateway_plugin_registry() -> GatewayPluginRegistry: registry = GatewayPluginRegistry() return register_builtin_gateway_adapters(registry) @@ -190,6 +174,7 @@ def _gateway_provider_credential_resolver( stores.append(EnvironmentSecretStore(runtime_environ)) return ProfileCredentialResolver(_GatewayFallbackSecretStore(tuple(stores))) + def build_gateway_app( *, profile_id: str = "you", @@ -207,9 +192,17 @@ def build_gateway_app( resolved_state_dir = Path(state_dir) if state_dir is not None else None if resolved_state_dir is not None: - from packages.runtime_config import load_global_config, global_config_path_for_state_dir - _gw_cfg = load_global_config(global_config_path_for_state_dir(resolved_state_dir), state_dir=resolved_state_dir) + from packages.runtime_config import ( + load_global_config, + global_config_path_for_state_dir, + ) + + _gw_cfg = load_global_config( + global_config_path_for_state_dir(resolved_state_dir), + state_dir=resolved_state_dir, + ) from packages.observability import setup_from_config + setup_from_config(_gw_cfg, state_dir=str(resolved_state_dir)) # The extension-manifest loader surfaces skill / tool overrides (profile.json @@ -228,12 +221,8 @@ def build_gateway_app( identity_store = InMemoryGatewayIdentityStore() session_store = InMemoryGatewaySessionStore() else: - identity_store = FileGatewayIdentityStore( - resolved_state_dir / "gateway-identities.json" - ) - session_store = FileGatewaySessionStore( - resolved_state_dir / "gateway-sessions.json" - ) + identity_store = FileGatewayIdentityStore(resolved_state_dir / "gateway-identities.json") + session_store = FileGatewaySessionStore(resolved_state_dir / "gateway-sessions.json") telemetry = GatewayTelemetrySink() core = GatewayCoreService( @@ -265,9 +254,7 @@ def build_gateway_app( auth_store = PersistentAuthProfileStore(runtime_repository) runtime_state_dir = runtime_repository.database_path.parent install_root = ( - ephemeral_home - if ephemeral_home is not None - else infer_install_root_from_state_dir(runtime_state_dir) + ephemeral_home if ephemeral_home is not None else infer_install_root_from_state_dir(runtime_state_dir) ) profile_loader = profile_loader or ProfileLoader(install_root) sync_builtin_skill_shelf(destination_root=install_root / "skills" / "builtin") @@ -347,6 +334,7 @@ def build_gateway_app( install_root=install_root, surface_kind="gateway", ) + def _resolve_elephant_state(elephant_id: str): resolved_elephant_id = elephant_id.strip() if resolved_elephant_id: @@ -499,6 +487,7 @@ def _tool_context_for_session(session_id: str, requester: ToolRequester | None) if start_learning_worker and resolved_state_dir is not None: try: from apps.learning_worker_runtime import ensure_learning_worker_running + ensure_learning_worker_running( state_dir=resolved_state_dir, ) diff --git a/apps/gateway/runtime_support.py b/apps/gateway/runtime_support.py index 5e8ac4e..2a53e92 100644 --- a/apps/gateway/runtime_support.py +++ b/apps/gateway/runtime_support.py @@ -3,62 +3,22 @@ from __future__ import annotations from collections.abc import Mapping -from dataclasses import dataclass, replace from datetime import datetime, timezone import hashlib import json from pathlib import Path import re import tempfile -from typing import Any -from uuid import uuid4 -from apps.provider_runtime import ( - load_provider_profile, - provider_profile_from_payload, -) -from packages.auth import AuthProfile, EnvironmentSecretStore, PersistentAuthProfileStore, ProfileCredentialResolver -from packages.models import SurfaceModelProviderCapability -from packages.models.runtime_capability import provider_fallback_summary, provider_profile_summary -from packages.capabilities.runtime import ( - CapabilityDescriptor, - ContextCapability, - RecallCapability, - ModelProviderCapability, - TelemetrySinkCapability, -) -from packages.context import ContextRuntime -from packages.contracts import Episode -from packages.contracts.runtime import ( - ContextBundle, - EventEnvelope, - ExecutionResult, - RecallEvidence, -) from packages.gateway_core import ( DEFAULT_GATEWAY_ACCOUNT_ID, - FileGatewayIdentityStore, - FileGatewaySessionStore, GatewayAccountRef, GatewayAttachmentRef, GatewayConversationRef, - GatewayCoreDependencies, - GatewayCoreService, - GatewayExchange, - GatewayIdentityRecord, - GatewayInboundMessage, GatewayOutboundMessage, GatewayPolicyHint, GatewaySenderRef, - InMemoryGatewayIdentityStore, - InMemoryGatewaySessionStore, ) -from packages.kernel import KernelDependencies, KernelService, KernelSourceRequest, ReconciliationPipeline, StateReconciler -from packages.evidence.recall_runtime import RecallRuntime -from packages.state import build_prompt_contract -from packages.security.runtime import SecurityPolicy -from packages.storage import RuntimeStorageRepository -from .plugins import GatewayAdapterDescriptor, GatewayPluginRegistry CHAT_BOT_ADAPTER_ID = "messaging.chat-bot" WEBHOOK_ADAPTER_ID = "messaging.webhook" @@ -354,10 +314,7 @@ def _discord_body(payload: Mapping[str, object]) -> str: "<", "|=>|::|\{.*\}|]*>|;\s*$)" -) - +_CODE_SIGNAL_RE = re.compile(r"(\b[A-Za-z_][\w.]*\([^\n]*\)|\s:=\s|\s=\s|->|=>|::|\{.*\}|]*>|;\s*$)") def _is_command_line(line: str) -> bool: @@ -369,7 +326,6 @@ def _is_command_line(line: str) -> bool: return _COMMAND_PREFIX_RE.match(stripped) is not None - def _is_formula_line(line: str) -> bool: stripped = line.strip() if not stripped or _is_command_line(stripped): @@ -389,7 +345,6 @@ def _is_formula_line(line: str) -> bool: return alpha_count <= max(16, len(stripped) // 2) - def _looks_like_code_line(line: str) -> bool: stripped = line.strip() if not stripped or _is_command_line(stripped) or _is_formula_line(stripped): @@ -401,7 +356,6 @@ def _looks_like_code_line(line: str) -> bool: return _CODE_SIGNAL_RE.search(stripped) is not None - def _detect_code_fence_language(lines: list[str]) -> str: first = next((line.strip() for line in lines if line.strip()), "") if not first: @@ -423,13 +377,11 @@ def _detect_code_fence_language(lines: list[str]) -> str: return "text" - def _fenced_block(lines: list[str], *, language: str) -> list[str]: opening = f"```{language}" if language else "```" return [opening, *lines, "```"] - def _wrap_rich_text_block(lines: list[str]) -> list[str]: if not lines: return [] @@ -441,14 +393,11 @@ def _wrap_rich_text_block(lines: list[str]) -> list[str]: if all(_is_formula_line(line) for line in meaningful_lines): return _fenced_block(lines, language="tex") code_like_lines = [line for line in meaningful_lines if _looks_like_code_line(line)] - if code_like_lines and ( - len(meaningful_lines) == 1 or len(code_like_lines) >= max(1, len(meaningful_lines) - 1) - ): + if code_like_lines and (len(meaningful_lines) == 1 or len(code_like_lines) >= max(1, len(meaningful_lines) - 1)): return _fenced_block(lines, language=_detect_code_fence_language(meaningful_lines)) return lines - def _render_rich_text_plain_segment(lines: list[str]) -> list[str]: rendered: list[str] = [] block: list[str] = [] @@ -470,7 +419,6 @@ def flush_block() -> None: return rendered - def _render_rich_text_body(body: str) -> str: normalized = body.replace("\r\n", "\n") if not normalized.strip(): @@ -496,7 +444,6 @@ def _render_rich_text_body(body: str) -> str: return "\n".join(rendered) - def _discord_reply_request(outbound: GatewayOutboundMessage) -> Mapping[str, object]: rendered_body = _render_rich_text_body(outbound.body) body: dict[str, object] = { @@ -561,7 +508,9 @@ def _feishu_message_content(content: object) -> dict[str, object]: return {"raw": str(content)} -def _feishu_attachment_refs(content: Mapping[str, object]) -> tuple[GatewayAttachmentRef, ...]: +def _feishu_attachment_refs( + content: Mapping[str, object], +) -> tuple[GatewayAttachmentRef, ...]: deduped: dict[str, GatewayAttachmentRef] = {} for field_name, kind in ( ("image_key", "image"), @@ -661,7 +610,6 @@ def _feishu_extract_title_and_body(body: str) -> tuple[str, str]: return title, "\n".join(lines).strip() or normalized.strip() or "(empty reply)" - def _feishu_json_v2_markdown_body(body: str) -> str: normalized = body.replace("\r\n", "\n") language_aliases = { @@ -681,7 +629,6 @@ def _feishu_json_v2_markdown_body(body: str) -> str: return "\n".join(rendered_lines) - def _feishu_interactive_payload(body: str) -> dict[str, object]: title, markdown_body = _feishu_extract_title_and_body(body) return { @@ -708,7 +655,6 @@ def _feishu_interactive_payload(body: str) -> dict[str, object]: } - def _feishu_reply_request(outbound: "GatewayOutboundMessage") -> dict[str, object]: if not outbound.reply_to_message_id: raise ValueError("feishu reply request requires reply_to_message_id") @@ -732,6 +678,7 @@ def _runtime_database_path(state_dir: Path | None) -> Path: state_dir.mkdir(parents=True, exist_ok=True) return state_dir / "elephant.sqlite3" + __all__ = [ "CHAT_BOT_ADAPTER_ID", "WEBHOOK_ADAPTER_ID", diff --git a/apps/gateway/telegram.py b/apps/gateway/telegram.py index e1d3f43..676a440 100644 --- a/apps/gateway/telegram.py +++ b/apps/gateway/telegram.py @@ -6,7 +6,6 @@ from dataclasses import dataclass, field import json import os -from pathlib import Path from typing import Any from urllib.error import HTTPError, URLError from urllib.request import Request, urlopen @@ -14,7 +13,12 @@ from packages.gateway_core import GatewayExchange, GatewayOutboundMessage from .plugins import GatewayPluginRegistry -from .runtime import DEFAULT_GATEWAY_ACCOUNT_ID, GatewayApp, TelegramMessagingAdapter, build_gateway_app +from .runtime import ( + DEFAULT_GATEWAY_ACCOUNT_ID, + GatewayApp, + TelegramMessagingAdapter, + build_gateway_app, +) DEFAULT_TELEGRAM_BOT_TOKEN_ENV = "ELEPHANT_TELEGRAM_BOT_TOKEN" LEGACY_TELEGRAM_BOT_TOKEN_ENV = "TELEGRAM_BOT_TOKEN" @@ -38,10 +42,7 @@ def _normalize_transport(value: str | None) -> str: normalized = str(value or "webhook").strip().lower().replace("_", "-") if normalized in {"callback", "http", "webhook"}: return "webhook" - raise ValueError( - "telegram transport must be one of " - f"{', '.join(SUPPORTED_TELEGRAM_TRANSPORTS)}" - ) + raise ValueError(f"telegram transport must be one of {', '.join(SUPPORTED_TELEGRAM_TRANSPORTS)}") def _json_bytes(payload: Mapping[str, object]) -> bytes: @@ -70,9 +71,7 @@ def _default_json_request( raw = response.read().decode("utf-8") except HTTPError as exc: detail = exc.read().decode("utf-8") - raise RuntimeError( - f"telegram request failed with HTTP {exc.code}: {detail or exc.reason}" - ) from exc + raise RuntimeError(f"telegram request failed with HTTP {exc.code}: {detail or exc.reason}") from exc except URLError as exc: raise RuntimeError(f"telegram request failed: {exc.reason}") from exc try: @@ -113,7 +112,9 @@ class TelegramGatewayEventResult: delivery_response: Mapping[str, object] | None = None -def load_telegram_gateway_accounts(app: GatewayApp) -> tuple[TelegramGatewayAccountConfig, ...]: +def load_telegram_gateway_accounts( + app: GatewayApp, +) -> tuple[TelegramGatewayAccountConfig, ...]: manifest = app.loaded_profile.manifest if app.loaded_profile is not None else {} gateway_payload = _mapping(manifest.get("gateway")) or {} adapters_payload = _mapping(gateway_payload.get("adapters")) or {} @@ -135,9 +136,7 @@ def load_telegram_gateway_accounts(app: GatewayApp) -> tuple[TelegramGatewayAcco resolved.append( TelegramGatewayAccountConfig( account_id=str(account_mapping.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID), - bot_token_env_var=str( - env_payload.get("bot_token") or DEFAULT_TELEGRAM_BOT_TOKEN_ENV - ), + bot_token_env_var=str(env_payload.get("bot_token") or DEFAULT_TELEGRAM_BOT_TOKEN_ENV), surface=str(account_mapping.get("surface") or default_surface), event_path=_normalize_path(account_mapping.get("event_path") or default_event_path), base_url=str(account_mapping.get("base_url") or default_base_url), @@ -165,9 +164,7 @@ def resolve_telegram_account( if not bot_token and config.bot_token_env_var == DEFAULT_TELEGRAM_BOT_TOKEN_ENV: bot_token = str(env.get(LEGACY_TELEGRAM_BOT_TOKEN_ENV) or "") if not bot_token: - raise LookupError( - f"telegram account '{config.account_id}' requires {config.bot_token_env_var}" - ) + raise LookupError(f"telegram account '{config.account_id}' requires {config.bot_token_env_var}") return TelegramResolvedAccount( account_id=config.account_id, bot_token=bot_token, @@ -237,9 +234,7 @@ def describe(self) -> Mapping[str, object]: def configured_transport(self) -> str: if not self.account_configs: return "webhook" - transports = tuple( - dict.fromkeys(_normalize_transport(config.surface) for config in self.account_configs) - ) + transports = tuple(dict.fromkeys(_normalize_transport(config.surface) for config in self.account_configs)) if len(transports) == 1: return transports[0] raise LookupError( @@ -383,7 +378,9 @@ async def stop_daemon_task(self) -> None: """No-op for webhook-only service.""" -def register_telegram_gateway_service(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: +def register_telegram_gateway_service( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: registry.register_service( "telegram", factory=lambda app, **kwargs: TelegramGatewayService(app=app, **kwargs), diff --git a/apps/gateway/web.py b/apps/gateway/web.py index e8bb8f9..6d63b7d 100644 --- a/apps/gateway/web.py +++ b/apps/gateway/web.py @@ -20,8 +20,7 @@ def _normalize_services( return {str(key): value for key, value in services.items()} if isinstance(services, Sequence) and not isinstance(services, (str, bytes, bytearray)): return { - str(getattr(service, "service_key", f"service-{index}")): service - for index, service in enumerate(services) + str(getattr(service, "service_key", f"service-{index}")): service for index, service in enumerate(services) } service = services return {str(getattr(service, "service_key", "service")): service} @@ -61,10 +60,7 @@ def application(environ: Mapping[str, object], start_response: Callable[..., obj if method == "GET" and path == "/healthz": payload: dict[str, object] = { "ok": True, - "services": { - key: dict(service.describe()) - for key, service in service_map.items() - }, + "services": {key: dict(service.describe()) for key, service in service_map.items()}, } if gateway_app is not None and hasattr(gateway_app, "setup_summary"): payload["gateway"] = gateway_app.setup_summary() diff --git a/apps/gateway/wecom.py b/apps/gateway/wecom.py index b070c29..f9413e4 100644 --- a/apps/gateway/wecom.py +++ b/apps/gateway/wecom.py @@ -13,7 +13,9 @@ from .runtime import build_gateway_app -def register_wecom_gateway_service(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: +def register_wecom_gateway_service( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: registry.register_service( "wecom", factory=lambda app, **kwargs: WecomGatewayService(app=app, **kwargs), diff --git a/apps/gateway/wecom_service.py b/apps/gateway/wecom_service.py index db33677..9a3ca60 100644 --- a/apps/gateway/wecom_service.py +++ b/apps/gateway/wecom_service.py @@ -14,7 +14,6 @@ from packages.gateway_core import ( DEFAULT_GATEWAY_ACCOUNT_ID, - GatewayExchange, GatewayInboundMessage, GatewayOutboundMessage, ) @@ -27,19 +26,22 @@ GatewayCliControlService, load_gateway_cli_control_config, ) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path -from .runtime import WECOM_ADAPTER_ID, WecomMessagingAdapter, GatewayApp, build_gateway_app +from .plugins import ( + GatewayManagedRuntime, + GatewayPluginRegistry, + default_gateway_runtime_path, +) +from .runtime import ( + WECOM_ADAPTER_ID, + WecomMessagingAdapter, + GatewayApp, + build_gateway_app, +) from .wecom_support import ( - DEFAULT_WECOM_BOT_ID_ENV, - DEFAULT_WECOM_SECRET_ENV, - DEFAULT_WECOM_WS_URL, MESSAGE_DEDUP_TTL_SECONDS, - WECOM_AVAILABLE, WecomGatewayAccountConfig, - WecomGatewayEventResult, WecomResolvedAccount, - _coerce_bool, _extract_wecom_text, _normalize_transport, check_wecom_requirements, @@ -140,14 +142,8 @@ def __post_init__(self) -> None: binding_store = self.cli_binding_store if binding_store is None: state_root = self.app.state_dir - binding_path = ( - None - if state_root is None - else os.path.join(state_root, "wecom-cli-bindings.json") - ) - binding_store = GatewayCliBindingStore( - path=None if binding_path is None else Path(binding_path) - ) + binding_path = None if state_root is None else os.path.join(state_root, "wecom-cli-bindings.json") + binding_store = GatewayCliBindingStore(path=None if binding_path is None else Path(binding_path)) self.cli_control = GatewayCliControlService( config=self._resolved_cli_control_config(config), app=self.app, @@ -326,9 +322,7 @@ def _match_account(self, *, account_id: str | None = None) -> WecomResolvedAccou raise LookupError("no enabled WeCom gateway accounts are configured") if len(enabled_configs) == 1: return resolve_wecom_account(enabled_configs[0], environ=self.environ) - raise LookupError( - "multiple enabled WeCom gateway accounts are configured; pass account_id explicitly" - ) + raise LookupError("multiple enabled WeCom gateway accounts are configured; pass account_id explicitly") # ----------------------------------------------------------------------- # Async WebSocket lifecycle @@ -337,10 +331,7 @@ def _match_account(self, *, account_id: str | None = None) -> WecomResolvedAccou async def start_gateway(self, *, account_id: str | None = None) -> None: """Connect WebSocket and run the listen loop.""" if not check_wecom_requirements(): - raise RuntimeError( - "WeCom startup failed: aiohttp is required. " - "Install it with: pip install aiohttp" - ) + raise RuntimeError("WeCom startup failed: aiohttp is required. Install it with: pip install aiohttp") account = self._match_account(account_id=account_id) config = account.config @@ -417,7 +408,6 @@ async def stop_gateway(self) -> None: async def _open_connection(self) -> None: """Authenticate via aibot_subscribe command.""" - import aiohttp assert self._session is not None @@ -434,9 +424,7 @@ async def _open_connection(self) -> None: receive_timeout=WS_READ_TIMEOUT, ) except Exception as exc: - raise RuntimeError( - f"WeCom WebSocket connection failed: {exc}" - ) from exc + raise RuntimeError(f"WeCom WebSocket connection failed: {exc}") from exc # Send subscribe command req_id = self._new_req_id("subscribe") @@ -475,16 +463,12 @@ async def _wait_for_handshake(self, req_id: str) -> None: while True: remaining = deadline - asyncio.get_running_loop().time() if remaining <= 0: - raise RuntimeError( - "WeCom WebSocket handshake timed out; check bot_id and secret" - ) + raise RuntimeError("WeCom WebSocket handshake timed out; check bot_id and secret") try: msg = await asyncio.wait_for(self._ws.receive(), timeout=remaining) except asyncio.TimeoutError: - raise RuntimeError( - "WeCom WebSocket handshake timed out; check bot_id and secret" - ) + raise RuntimeError("WeCom WebSocket handshake timed out; check bot_id and secret") if msg.type == aiohttp.WSMsgType.TEXT: try: @@ -501,9 +485,7 @@ async def _wait_for_handshake(self, req_id: str) -> None: errcode = payload.get("errcode", 0) if errcode not in (0, None): errmsg = payload.get("errmsg") or "authentication failed" - raise RuntimeError( - f"WeCom subscribe failed: {errmsg} (errcode={errcode})" - ) + raise RuntimeError(f"WeCom subscribe failed: {errmsg} (errcode={errcode})") return LOGGER.debug( @@ -517,9 +499,7 @@ async def _wait_for_handshake(self, req_id: str) -> None: aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.ERROR, ): - raise RuntimeError( - "WeCom WebSocket closed during authentication" - ) + raise RuntimeError("WeCom WebSocket closed during authentication") async def _listen_loop(self) -> None: """Read events with automatic reconnection and exponential backoff.""" @@ -647,7 +627,11 @@ async def _heartbeat_loop(self) -> None: break try: await self._send_json( - {"cmd": "ping", "headers": {"req_id": self._new_req_id("ping")}, "body": {}} + { + "cmd": "ping", + "headers": {"req_id": self._new_req_id("ping")}, + "body": {}, + } ) except Exception as exc: LOGGER.debug("[%s] Heartbeat send failed: %s", self.service_key, exc) @@ -961,7 +945,9 @@ def _safe_id(value: str) -> str: return value[:4] + "***" + value[-4:] -def register_wecom_gateway_service(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: +def register_wecom_gateway_service( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: from .wecom import WecomGatewayService registry.register_service( diff --git a/apps/gateway/wecom_support.py b/apps/gateway/wecom_support.py index 8130a72..285ee8c 100644 --- a/apps/gateway/wecom_support.py +++ b/apps/gateway/wecom_support.py @@ -4,26 +4,15 @@ from collections.abc import Mapping from dataclasses import dataclass, field -import importlib.util import os -from pathlib import Path -from apps.runtime_layout import default_cli_state_dir from packages.gateway_core import ( DEFAULT_GATEWAY_ACCOUNT_ID, GatewayExchange, - GatewayInboundMessage, GatewayOutboundMessage, ) -from .cli_control import ( - CliRuntimeFactory, - GatewayCliBindingStore, - GatewayCliControlService, - load_gateway_cli_control_config, -) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path -from .runtime import GatewayApp, build_gateway_app +from .runtime import GatewayApp # --------------------------------------------------------------------------- # Environment variable defaults @@ -129,10 +118,7 @@ def _normalize_transport(value: str | None) -> str: normalized = str(value or "websocket").strip().lower().replace("_", "-") if normalized in {"websocket", "wecom-websocket"}: return "websocket" - raise ValueError( - "wecom transport must be one of " - f"{', '.join(SUPPORTED_WECOM_TRANSPORTS)}" - ) + raise ValueError(f"wecom transport must be one of {', '.join(SUPPORTED_WECOM_TRANSPORTS)}") def _coerce_list(value: object) -> tuple[str, ...]: @@ -260,12 +246,8 @@ def load_wecom_gateway_accounts( resolved.append( WecomGatewayAccountConfig( account_id=str(account_mapping.get("account_id") or DEFAULT_GATEWAY_ACCOUNT_ID), - bot_id_env_var=str( - env_payload.get("bot_id") or DEFAULT_WECOM_BOT_ID_ENV - ), - secret_env_var=str( - env_payload.get("secret") or DEFAULT_WECOM_SECRET_ENV - ), + bot_id_env_var=str(env_payload.get("bot_id") or DEFAULT_WECOM_BOT_ID_ENV), + secret_env_var=str(env_payload.get("secret") or DEFAULT_WECOM_SECRET_ENV), surface=str(account_mapping.get("surface") or default_surface), enabled=account_enabled, ws_url=str(account_mapping.get("ws_url") or DEFAULT_WECOM_WS_URL), @@ -290,13 +272,9 @@ def resolve_wecom_account( bot_id = str(env.get(config.bot_id_env_var) or "").strip() secret = str(env.get(config.secret_env_var) or "").strip() if not bot_id: - raise LookupError( - f"wecom account '{config.account_id}' requires {config.bot_id_env_var}" - ) + raise LookupError(f"wecom account '{config.account_id}' requires {config.bot_id_env_var}") if not secret: - raise LookupError( - f"wecom account '{config.account_id}' requires {config.secret_env_var}" - ) + raise LookupError(f"wecom account '{config.account_id}' requires {config.secret_env_var}") return WecomResolvedAccount( account_id=config.account_id, bot_id=bot_id, diff --git a/apps/gateway/weixin.py b/apps/gateway/weixin.py index 18f1f00..9b21114 100644 --- a/apps/gateway/weixin.py +++ b/apps/gateway/weixin.py @@ -13,7 +13,9 @@ from .runtime import build_gateway_app -def register_weixin_gateway_service(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: +def register_weixin_gateway_service( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: registry.register_service( "weixin", factory=lambda app, **kwargs: WeixinGatewayService(app=app, **kwargs), diff --git a/apps/gateway/weixin_service.py b/apps/gateway/weixin_service.py index 5fb06f1..9ed23af 100644 --- a/apps/gateway/weixin_service.py +++ b/apps/gateway/weixin_service.py @@ -5,11 +5,9 @@ import asyncio from collections.abc import Callable, Mapping from dataclasses import dataclass, field -import json import logging import os from pathlib import Path -import threading from typing import Any from uuid import uuid4 @@ -17,7 +15,6 @@ DEFAULT_GATEWAY_ACCOUNT_ID, GatewayAccountRef, GatewayConversationRef, - GatewayExchange, GatewayInboundMessage, GatewayOutboundMessage, GatewayOutboundQueue, @@ -36,7 +33,11 @@ GatewayCliControlService, load_gateway_cli_control_config, ) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path +from .plugins import ( + GatewayManagedRuntime, + GatewayPluginRegistry, + default_gateway_runtime_path, +) from .runtime import ( WEIXIN_ADAPTER_ID, WeixinMessagingAdapter, @@ -46,14 +47,11 @@ from .weixin_support import ( ILINK_BASE_URL, - WEIXIN_CDN_BASE_URL, - AIOHTTP_AVAILABLE, ContextTokenStore, TypingTicketCache, WeixinGatewayAccountConfig, WeixinGatewayEventResult, WeixinResolvedAccount, - _coerce_bool, _extract_text, _guess_chat_type, _load_sync_buf, @@ -66,31 +64,16 @@ load_weixin_account, load_weixin_gateway_accounts, resolve_weixin_account, - save_weixin_account, - # iLink API helpers - _api_post, _get_updates, _send_message as _ilink_send_message, - _send_typing, _get_config, - EP_SEND_MESSAGE, LONG_POLL_TIMEOUT_MS, - API_TIMEOUT_MS, SESSION_EXPIRED_ERRCODE, MAX_CONSECUTIVE_FAILURES, RETRY_DELAY_SECONDS, BACKOFF_DELAY_SECONDS, MESSAGE_DEDUP_TTL_SECONDS, - ITEM_TEXT, - MSG_TYPE_BOT, - MSG_STATE_FINISH, - TYPING_START, - TYPING_STOP, _make_ssl_connector, - _headers, - _json_dumps, - _base_info, - _random_wechat_uin, ) LOGGER = logging.getLogger(__name__) @@ -171,14 +154,8 @@ def __post_init__(self) -> None: binding_store = self.cli_binding_store if binding_store is None: state_root = self.app.state_dir - binding_path = ( - None - if state_root is None - else os.path.join(state_root, "weixin-cli-bindings.json") - ) - binding_store = GatewayCliBindingStore( - path=None if binding_path is None else Path(binding_path) - ) + binding_path = None if state_root is None else os.path.join(state_root, "weixin-cli-bindings.json") + binding_store = GatewayCliBindingStore(path=None if binding_path is None else Path(binding_path)) self.cli_control = GatewayCliControlService( config=self._resolved_cli_control_config(config), app=self.app, @@ -278,14 +255,10 @@ def configured_transport(self) -> str: transport_configs = self._transport_account_configs() if not transport_configs: return "ilink" - transports = tuple( - dict.fromkeys(_normalize_transport(config.surface) for config in transport_configs) - ) + transports = tuple(dict.fromkeys(_normalize_transport(config.surface) for config in transport_configs)) if len(transports) == 1: return transports[0] - raise LookupError( - "configured WeChat accounts use multiple transport surfaces; choose one explicitly" - ) + raise LookupError("configured WeChat accounts use multiple transport surfaces; choose one explicitly") @property def event_paths(self) -> tuple[str, ...]: @@ -302,7 +275,10 @@ def handle_http_event( path: str, ) -> tuple[str, Mapping[str, object]]: # iLink mode does not use HTTP callbacks. - return "501 Not Implemented", {"ok": False, "error": "iLink transport does not use HTTP callbacks"} + return "501 Not Implemented", { + "ok": False, + "error": "iLink transport does not use HTTP callbacks", + } def configured_runtime_target(self) -> str: return self.configured_transport() @@ -394,9 +370,7 @@ def _resolve_credentials(self, account_id: str | None = None) -> WeixinResolvedA raise LookupError("no enabled WeChat gateway accounts are configured") if len(enabled_configs) == 1: return resolve_weixin_account(enabled_configs[0]) - raise LookupError( - "multiple enabled WeChat gateway accounts are configured; pass account_id explicitly" - ) + raise LookupError("multiple enabled WeChat gateway accounts are configured; pass account_id explicitly") async def start_gateway(self, account_id: str | None = None) -> None: """Connect and start the long-polling loop.""" @@ -585,7 +559,10 @@ async def _poll_loop(self) -> None: errcode = response.get("errcode", 0) if ret not in (0, None) or errcode not in (0, None): if ret == SESSION_EXPIRED_ERRCODE or errcode == SESSION_EXPIRED_ERRCODE: - LOGGER.error("[%s] Session expired; pausing for 10 minutes", self.service_key) + LOGGER.error( + "[%s] Session expired; pausing for 10 minutes", + self.service_key, + ) await asyncio.sleep(600) consecutive_failures = 0 continue @@ -600,7 +577,8 @@ async def _poll_loop(self) -> None: MAX_CONSECUTIVE_FAILURES, ) await asyncio.sleep( - BACKOFF_DELAY_SECONDS if consecutive_failures >= MAX_CONSECUTIVE_FAILURES + BACKOFF_DELAY_SECONDS + if consecutive_failures >= MAX_CONSECUTIVE_FAILURES else RETRY_DELAY_SECONDS ) if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: @@ -627,8 +605,7 @@ async def _poll_loop(self) -> None: exc, ) await asyncio.sleep( - BACKOFF_DELAY_SECONDS if consecutive_failures >= MAX_CONSECUTIVE_FAILURES - else RETRY_DELAY_SECONDS + BACKOFF_DELAY_SECONDS if consecutive_failures >= MAX_CONSECUTIVE_FAILURES else RETRY_DELAY_SECONDS ) if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: consecutive_failures = 0 @@ -763,7 +740,12 @@ async def _maybe_fetch_typing_ticket(self, user_id: str, context_token: str | No if typing_ticket and self._typing_cache: self._typing_cache.set(user_id, typing_ticket) except Exception as exc: - LOGGER.debug("[%s] getConfig failed for %s: %s", self.service_key, _safe_id(user_id), exc) + LOGGER.debug( + "[%s] getConfig failed for %s: %s", + self.service_key, + _safe_id(user_id), + exc, + ) async def _send_ilink_message(self, outbound: GatewayOutboundMessage) -> None: """Send a reply via iLink sendmessage API.""" @@ -772,15 +754,9 @@ async def _send_ilink_message(self, outbound: GatewayOutboundMessage) -> None: content = _normalize_markdown_blocks(outbound.body) chat_id = outbound.conversation_id - context_token = ( - self._token_store.get(self._resolved_account_id, chat_id) - if self._token_store - else None - ) + context_token = self._token_store.get(self._resolved_account_id, chat_id) if self._token_store else None - chunks = _split_text_for_weixin_delivery( - content, MAX_MESSAGE_LENGTH, self._split_multiline_messages - ) + chunks = _split_text_for_weixin_delivery(content, MAX_MESSAGE_LENGTH, self._split_multiline_messages) chunks = [c for c in chunks if c and c.strip()] for idx, chunk in enumerate(chunks): @@ -802,32 +778,31 @@ async def _send_ilink_message(self, outbound: GatewayOutboundMessage) -> None: ret = resp.get("ret") errcode = resp.get("errcode") if (ret is not None and ret not in (0,)) or (errcode is not None and errcode not in (0,)): - is_session_expired = ( - ret == SESSION_EXPIRED_ERRCODE - or errcode == SESSION_EXPIRED_ERRCODE - ) + is_session_expired = ret == SESSION_EXPIRED_ERRCODE or errcode == SESSION_EXPIRED_ERRCODE if is_session_expired and not retried_without_token and context_token: retried_without_token = True context_token = None if self._token_store: self._token_store._cache.pop( - self._token_store._key(self._resolved_account_id, chat_id), None + self._token_store._key(self._resolved_account_id, chat_id), + None, ) LOGGER.warning( "[%s] session expired for %s; retrying without context_token", - self.service_key, _safe_id(chat_id), + self.service_key, + _safe_id(chat_id), ) continue errmsg = resp.get("errmsg") or resp.get("msg") or "unknown error" - raise RuntimeError( - f"iLink sendmessage error: ret={ret} errcode={errcode} errmsg={errmsg}" - ) + raise RuntimeError(f"iLink sendmessage error: ret={ret} errcode={errcode} errmsg={errmsg}") break except Exception as exc: if attempt >= 2: LOGGER.error( "[%s] send failed to=%s after 3 attempts: %s", - self.service_key, _safe_id(chat_id), exc, + self.service_key, + _safe_id(chat_id), + exc, ) raise await asyncio.sleep(1.0 * (attempt + 1)) @@ -894,10 +869,7 @@ def deliver_cron_result(self, job, execution) -> None: # Only warn when we have weixin identities but cannot disambiguate — if # there are zero weixin identities, the scheduler's fan-out simply asked # the wrong adapter, which is expected noise. - any_weixin = any( - r.key.adapter_id == WEIXIN_ADAPTER_ID - for r in identity_store.list_records() - ) + any_weixin = any(r.key.adapter_id == WEIXIN_ADAPTER_ID for r in identity_store.list_records()) if any_weixin: LOGGER.warning( "cron delivery: skipping job=%s — no job.elephant_id and multiple weixin herd", @@ -957,7 +929,9 @@ def _outbound_queue_for_state_dir(state_dir: str) -> GatewayOutboundQueue: return GatewayOutboundQueue(path=default_outbound_queue_path(state_dir)) -def register_weixin_gateway_service(registry: GatewayPluginRegistry) -> GatewayPluginRegistry: +def register_weixin_gateway_service( + registry: GatewayPluginRegistry, +) -> GatewayPluginRegistry: from .weixin import WeixinGatewayService registry.register_service( diff --git a/apps/gateway/weixin_support.py b/apps/gateway/weixin_support.py index 8e07683..607f96f 100644 --- a/apps/gateway/weixin_support.py +++ b/apps/gateway/weixin_support.py @@ -13,7 +13,6 @@ import base64 import json import logging -import os import secrets import struct import time @@ -21,31 +20,13 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, Optional -from urllib.parse import quote, urlparse -from apps.runtime_layout import default_cli_state_dir from packages.gateway_core import ( DEFAULT_GATEWAY_ACCOUNT_ID, GatewayExchange, - GatewayInboundMessage, - GatewayOutboundMessage, ) -from .cli_control import ( - CliRuntimeFactory, - GatewayCliBindingStore, - GatewayCliControlService, - load_gateway_cli_control_config, -) -from .plugins import GatewayManagedRuntime, GatewayPluginRegistry, default_gateway_runtime_path -from .runtime import GatewayApp, build_gateway_app -from .weixin_delivery import ( - _normalize_markdown_blocks, - _pack_markdown_blocks_for_weixin, - _split_delivery_units_for_weixin, - _split_markdown_blocks, - _split_text_for_weixin_delivery, -) +from .runtime import GatewayApp logger = logging.getLogger(__name__) @@ -161,6 +142,7 @@ def check_weixin_requirements() -> bool: # SSL connector helper # --------------------------------------------------------------------------- + def _make_ssl_connector() -> Optional["aiohttp.TCPConnector"]: """Return a TCPConnector with a certifi CA bundle, or None.""" try: @@ -178,6 +160,7 @@ def _make_ssl_connector() -> Optional["aiohttp.TCPConnector"]: # Utility helpers # --------------------------------------------------------------------------- + def _safe_id(value: Optional[str], keep: int = 8) -> str: raw = str(value or "").strip() if not raw: @@ -228,16 +211,14 @@ def _normalize_transport(value: str | None) -> str: normalized = str(value or "ilink").strip().lower().replace("_", "-") if normalized in {"ilink", "weixin-ilink", "wxhook", "weixin-wxhook"}: return "ilink" - raise ValueError( - "weixin transport must be one of " - f"{', '.join(SUPPORTED_WEIXIN_TRANSPORTS)}" - ) + raise ValueError(f"weixin transport must be one of {', '.join(SUPPORTED_WEIXIN_TRANSPORTS)}") # --------------------------------------------------------------------------- # AES-128-ECB encryption / decryption # --------------------------------------------------------------------------- + def _pkcs7_pad(data: bytes, block_size: int = 16) -> bytes: pad_len = block_size - (len(data) % block_size) return data + bytes([pad_len] * pad_len) @@ -280,6 +261,7 @@ def _parse_aes_key(aes_key_b64: str) -> bytes: # iLink API helpers # --------------------------------------------------------------------------- + def _random_wechat_uin() -> str: value = struct.unpack(">I", secrets.token_bytes(4))[0] return base64.b64encode(str(value).encode("utf-8")).decode("ascii") @@ -312,8 +294,6 @@ async def _api_post( token: Optional[str], timeout_ms: int, ) -> dict[str, Any]: - import asyncio - body = _json_dumps({**payload, "base_info": _base_info()}) url = f"{base_url.rstrip('/')}/{endpoint}" timeout = aiohttp.ClientTimeout(total=timeout_ms / 1000) @@ -447,6 +427,7 @@ async def _get_config( # Context token persistence # --------------------------------------------------------------------------- + def _account_dir(state_dir: str) -> Path: path = Path(state_dir) / "weixin" / "accounts" path.mkdir(parents=True, exist_ok=True) @@ -511,7 +492,11 @@ def restore(self, account_id: str) -> None: try: data = json.loads(path.read_text(encoding="utf-8")) except Exception as exc: - logger.warning("weixin: failed to restore context tokens for %s: %s", _safe_id(account_id), exc) + logger.warning( + "weixin: failed to restore context tokens for %s: %s", + _safe_id(account_id), + exc, + ) return restored = 0 for user_id, token in data.items(): @@ -519,7 +504,11 @@ def restore(self, account_id: str) -> None: self._cache[self._key(account_id, user_id)] = token restored += 1 if restored: - logger.info("weixin: restored %d context token(s) for %s", restored, _safe_id(account_id)) + logger.info( + "weixin: restored %d context token(s) for %s", + restored, + _safe_id(account_id), + ) def get(self, account_id: str, user_id: str) -> Optional[str]: return self._cache.get(self._key(account_id, user_id)) @@ -530,17 +519,15 @@ def set(self, account_id: str, user_id: str, token: str) -> None: def _persist(self, account_id: str) -> None: prefix = f"{account_id}:" - payload = { - key[len(prefix):]: value - for key, value in self._cache.items() - if key.startswith(prefix) - } + payload = {key[len(prefix) :]: value for key, value in self._cache.items() if key.startswith(prefix)} try: - self._path(account_id).write_text( - json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8" - ) + self._path(account_id).write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") except Exception as exc: - logger.warning("weixin: failed to persist context tokens for %s: %s", _safe_id(account_id), exc) + logger.warning( + "weixin: failed to persist context tokens for %s: %s", + _safe_id(account_id), + exc, + ) class TypingTicketCache: @@ -567,6 +554,7 @@ def set(self, user_id: str, ticket: str) -> None: # Sync buffer persistence # --------------------------------------------------------------------------- + def _sync_buf_path(state_dir: str, account_id: str) -> Path: return _account_dir(state_dir) / f"{account_id}.sync.json" @@ -593,6 +581,7 @@ def _save_sync_buf(state_dir: str, account_id: str, sync_buf: str) -> None: # QR login flow # --------------------------------------------------------------------------- + async def qr_login( state_dir: str, *, @@ -690,6 +679,7 @@ async def qr_login( print(qrcode_url) try: import qrcode as _qrcode + qr = _qrcode.QRCode() qr.add_data(qr_scan_data) qr.make(fit=True) @@ -731,6 +721,7 @@ async def qr_login( # Message format helpers (inbound iLink → elephant) # --------------------------------------------------------------------------- + def _extract_text(item_list: list[dict[str, Any]]) -> str: for item in item_list: if item.get("type") == ITEM_TEXT: @@ -807,6 +798,7 @@ def _weixin_body(payload: Mapping[str, object]) -> str: # Account configuration # --------------------------------------------------------------------------- + @dataclass(frozen=True, slots=True) class WeixinGatewayAccountConfig: account_id: str = DEFAULT_GATEWAY_ACCOUNT_ID diff --git a/apps/launcher.py b/apps/launcher.py index b40d368..74a9e15 100644 --- a/apps/launcher.py +++ b/apps/launcher.py @@ -9,20 +9,46 @@ from apps.runtime_layout import default_cli_state_dir -from .cli.shell import Align, BRAND_ACCENT, BRAND_LIGHT, BRAND_MUTED, Console, Group, Panel, RICH_AVAILABLE, Text, _resolve_elephant_version +from .cli.shell import ( + Align, + BRAND_ACCENT, + BRAND_LIGHT, + BRAND_MUTED, + Console, + Group, + Panel, + RICH_AVAILABLE, + Text, + _resolve_elephant_version, +) from .cli.typer_support import run_typer_app -from .cli.cli_main_support import CLI_COMMAND_HELP, CLI_HELP_COMMANDS, CLI_HELP_NEXT_COMMANDS, CLI_HELP_TAGLINE, _print_cli_help, _render_cli_banner_mark +from .cli.cli_main_support import ( + CLI_COMMAND_HELP, + CLI_HELP_COMMANDS, + CLI_HELP_NEXT_COMMANDS, + CLI_HELP_TAGLINE, + _print_cli_help, + _render_cli_banner_mark, +) LAUNCHER_COMMAND_HELP = { **CLI_COMMAND_HELP, "upgrade": "Gracefully upgrade Elephant Agent, preserving state and restarting managed runtimes.", } -LAUNCHER_HELP_COMMANDS = (*CLI_HELP_COMMANDS, ("upgrade", LAUNCHER_COMMAND_HELP["upgrade"])) +LAUNCHER_HELP_COMMANDS = ( + *CLI_HELP_COMMANDS, + ("upgrade", LAUNCHER_COMMAND_HELP["upgrade"]), +) def _ensure_config_yaml(state_dir: Path) -> None: """Ensure config.yaml exists so the configuration is visible.""" - from packages.runtime_config import default_global_config, write_global_config, global_config_path_for_state_dir + from packages.runtime_config import ( + default_global_config, + write_global_config, + global_config_path_for_state_dir, + ) + state_dir.mkdir(parents=True, exist_ok=True) config_path = global_config_path_for_state_dir(state_dir) if not config_path.exists(): @@ -38,7 +64,10 @@ def _show_launcher_banner() -> None: header = Text() header.append("🥚 Elephant Agent\n", style=f"bold {BRAND_LIGHT}") header.append("Personal-model-first AI, with curiosity built in.\n", style=BRAND_MUTED) - header.append(f"🐣 v{_resolve_elephant_version()} · understands first, gets curious at your pace.", style=BRAND_ACCENT) + header.append( + f"🐣 v{_resolve_elephant_version()} · understands first, gets curious at your pace.", + style=BRAND_ACCENT, + ) commands = Text() commands.append("Start here\n", style=f"bold {BRAND_ACCENT}") commands.append("🐣 elephant init\n", style=f"bold {BRAND_LIGHT}") @@ -147,7 +176,11 @@ def skills_command(ctx: typer.Context) -> None: ) ) - @app.command("gateway", help=CLI_COMMAND_HELP["gateway"], context_settings=passthrough_settings) + @app.command( + "gateway", + help=CLI_COMMAND_HELP["gateway"], + context_settings=passthrough_settings, + ) def gateway_command(ctx: typer.Context) -> None: from apps.gateway.__main__ import command_main as gateway_command_main @@ -162,7 +195,9 @@ def gateway_command(ctx: typer.Context) -> None: @app.command("cron", help=CLI_COMMAND_HELP["cron"], context_settings=passthrough_settings) def cron_command(ctx: typer.Context) -> None: - from apps.cron_scheduler_command import command_main as cron_scheduler_command_main + from apps.cron_scheduler_command import ( + command_main as cron_scheduler_command_main, + ) obj = ctx.obj or {} raise typer.Exit( @@ -173,7 +208,11 @@ def cron_command(ctx: typer.Context) -> None: ) ) - @app.command("upgrade", help=LAUNCHER_COMMAND_HELP["upgrade"], context_settings=passthrough_settings) + @app.command( + "upgrade", + help=LAUNCHER_COMMAND_HELP["upgrade"], + context_settings=passthrough_settings, + ) def upgrade_command(ctx: typer.Context) -> None: from apps.upgrade_command import command_main as upgrade_command_main @@ -206,7 +245,11 @@ def supervisor_command(ctx: typer.Context) -> None: ) ) - @app.command("dashboard", help=CLI_COMMAND_HELP["dashboard"], context_settings=passthrough_settings) + @app.command( + "dashboard", + help=CLI_COMMAND_HELP["dashboard"], + context_settings=passthrough_settings, + ) def dashboard_command(ctx: typer.Context) -> None: from apps.dashboard_command import command_main as dashboard_command_main @@ -218,7 +261,11 @@ def dashboard_command(ctx: typer.Context) -> None: ) ) - @app.command("provider", help=CLI_COMMAND_HELP["provider"], context_settings=passthrough_settings) + @app.command( + "provider", + help=CLI_COMMAND_HELP["provider"], + context_settings=passthrough_settings, + ) def provider_passthrough(ctx: typer.Context) -> None: obj = ctx.obj or {} raise typer.Exit(_forward_cli(["provider", *ctx.args], state_dir=obj["state_dir"])) @@ -233,7 +280,11 @@ def facts_passthrough(ctx: typer.Context) -> None: obj = ctx.obj or {} raise typer.Exit(_forward_cli(["facts", *ctx.args], state_dir=obj["state_dir"])) - @app.command("reflect", help=CLI_COMMAND_HELP["reflect"], context_settings=passthrough_settings) + @app.command( + "reflect", + help=CLI_COMMAND_HELP["reflect"], + context_settings=passthrough_settings, + ) def reflect_passthrough(ctx: typer.Context) -> None: obj = ctx.obj or {} raise typer.Exit(_forward_cli(["reflect", *ctx.args], state_dir=obj["state_dir"])) diff --git a/apps/learning_worker_runtime.py b/apps/learning_worker_runtime.py index e5e3c96..206946c 100644 --- a/apps/learning_worker_runtime.py +++ b/apps/learning_worker_runtime.py @@ -237,6 +237,7 @@ def close_finished_learning_child_episode(runtime: CliRuntime, job: LearningJob, return False # Close by updating status via upsert from dataclasses import replace as _replace + closed = _replace(child, status="closed") runtime.repository.upsert_episode(closed) return True diff --git a/apps/provider_runtime_support.py b/apps/provider_runtime_support.py index 11663f7..7ae8dd8 100644 --- a/apps/provider_runtime_support.py +++ b/apps/provider_runtime_support.py @@ -24,8 +24,6 @@ AuthProfile, LocalEncryptedSecretCipher, PersistentAuthProfileStore, - ProfileCredentialResolver, - ProviderAuthState, ProviderCatalog, ProviderProfileFactory, ProviderProfileInput, @@ -34,30 +32,10 @@ SecretStore, profile_from_input, ) -from packages.capabilities.runtime import CapabilityDescriptor, ModelProviderCapability -from packages.contracts import Episode -from packages.contracts.runtime import ( - ContextBundle, - ExecutionResult, - RuntimeModelChoice, - PersonalModelRuntimeState, - GenerationModelProfile, - SupportModelProfile, -) -from packages.models import ModelRequest, ProviderRuntimeResolver from packages.models.discovery import DiscoveredProviderModel, DiscoveredProviderState -from packages.models.model_metadata import resolve_provider_model_metadata -from packages.models.provider_catalog import default_provider_definitions, provider_definition +from packages.models.provider_catalog import provider_definition from packages.models.provider_runtime import provider_auth_headers -from packages.models.providers import build_model_adapter -from packages.models.runtime_capability import ( - provider_fallback_summary, - provider_profile_summary, - generation_model_profile_from_auth_profile, - support_model_profile_from_auth_profile, -) from packages.storage import RuntimeStorageRepository -from packages.tools import ToolDefinition, ToolRuntime _MODEL_CONTEXT_KEYS = ( "context_length", @@ -240,7 +218,11 @@ def _query_ollama_context_window(*, model_id: str, base_url: str, timeout_second with request.urlopen(http_request, timeout=timeout_seconds) as response: raw_body = response.read().decode("utf-8") payload = json.loads(raw_body) if raw_body else {} - except (error.HTTPError, error.URLError, json.JSONDecodeError): # pragma: no cover - covered by caller fallback + except ( + error.HTTPError, + error.URLError, + json.JSONDecodeError, + ): # pragma: no cover - covered by caller fallback return None if not isinstance(payload, Mapping): return None @@ -413,16 +395,13 @@ def _read_google_gemini_oauth_resolution() -> SecretValueResolution | None: return None -def _read_anthropic_token_from_payload(path: Path, payload: Mapping[str, Any], *, source: str) -> SecretValueResolution | None: +def _read_anthropic_token_from_payload( + path: Path, payload: Mapping[str, Any], *, source: str +) -> SecretValueResolution | None: claude_code_oauth = payload.get("claudeAiOauth") if isinstance(claude_code_oauth, Mapping): payload = {str(key): value for key, value in claude_code_oauth.items()} - access_token = str( - payload.get("accessToken") - or payload.get("access_token") - or payload.get("token") - or "" - ).strip() + access_token = str(payload.get("accessToken") or payload.get("access_token") or payload.get("token") or "").strip() if not access_token: return None expires_at = payload.get("expiresAt") or payload.get("expires_at") @@ -458,11 +437,7 @@ def _read_copilot_resolution() -> SecretValueResolution | None: value = str(os.environ.get(env_name) or "").strip() if value and not value.startswith("ghp_"): return SecretValueResolution(value=value, source=f"env:{env_name}") - clean_env = { - key: value - for key, value in os.environ.items() - if key not in {"GH_TOKEN", "GITHUB_TOKEN"} - } + clean_env = {key: value for key, value in os.environ.items() if key not in {"GH_TOKEN", "GITHUB_TOKEN"}} try: completed = subprocess.run( ["gh", "auth", "token"], @@ -515,10 +490,7 @@ def _provider_base_url_from_env(provider_id: str, primary_env_var: str | None) - def provider_profile_from_payload(payload: Mapping[str, Any]) -> AuthProfile: if "profile_id" not in payload or "provider_id" not in payload: raise ValueError("provider_profile must include profile_id and provider_id") - secret_references = tuple( - secret_reference_from_payload(item) - for item in payload.get("secret_references", ()) - ) + secret_references = tuple(secret_reference_from_payload(item) for item in payload.get("secret_references", ())) profile_input = ProviderProfileInput( profile_id=str(payload["profile_id"]), provider_id=str(payload["provider_id"]), @@ -539,7 +511,17 @@ def provider_profile_from_payload(payload: Mapping[str, Any]) -> AuthProfile: provider_defaults = catalog.get(provider_id) if provider_id == "openai-compatible" and (base_url is None or default_model is None): raise ValueError("openai-compatible provider profiles require base_url and default_model") - if any(value is not None for value in (base_url, default_model, transport_id, auth_method, provider_kind, extra_headers)): + if any( + value is not None + for value in ( + base_url, + default_model, + transport_id, + auth_method, + provider_kind, + extra_headers, + ) + ): default_profile = None if provider_defaults is not None: default_profile = ProviderProfileFactory(catalog).from_provider_defaults( @@ -611,7 +593,11 @@ def secret_reference_from_payload(payload: Mapping[str, Any]) -> SecretReference def load_provider_profile(state_dir: Path, *, config_path: Path | None = None) -> AuthProfile | None: """Load the active provider profile from config.yaml (models.provider).""" if config_path is not None: - from packages.runtime_config import load_global_config, load_provider_from_config + from packages.runtime_config import ( + load_global_config, + load_provider_from_config, + ) + try: config = load_global_config(config_path, state_dir=state_dir) provider_payload = load_provider_from_config(config) @@ -708,7 +694,11 @@ def resolve(self, reference: SecretReference) -> SecretValueResolution: env = self.environ or os.environ candidates: list[str] = list(reference.env_var_candidates()) seen = set(candidates) - for candidate in (reference.secret_name, reference.secret_key, reference.reference_id): + for candidate in ( + reference.secret_name, + reference.secret_key, + reference.reference_id, + ): normalized = _normalize_env_name(candidate) if normalized and normalized not in seen: seen.add(normalized) @@ -783,6 +773,7 @@ def _external_resolution(self, reference: SecretReference) -> SecretValueResolut run_embedding_bootstrap_worker as _package_run_embedding_bootstrap_worker, trigger_embedding_bootstrap as _package_trigger_embedding_bootstrap, ) + EmbeddingBootstrapState = _PackageEmbeddingBootstrapState EnvironmentSecretStore = _PackageEnvironmentSecretStore EncryptedRepositorySecretStore = _PackageEncryptedRepositorySecretStore @@ -819,8 +810,4 @@ def register_provider_profile( "provider_fallback_summary", } -__all__ = [ - name - for name in globals() - if not name.startswith("_") and name not in _APP_PROVIDER_RUNTIME_COMPAT_EXPORTS -] +__all__ = [name for name in globals() if not name.startswith("_") and name not in _APP_PROVIDER_RUNTIME_COMPAT_EXPORTS] diff --git a/apps/reflect/evidence.py b/apps/reflect/evidence.py index fec657b..e288c5e 100644 --- a/apps/reflect/evidence.py +++ b/apps/reflect/evidence.py @@ -119,11 +119,13 @@ def _build_compress_evidence(metadata: dict[str, Any]) -> str: ] if previous_summary: lines.extend(["", "## Previous summary (for continuity)", previous_summary]) - lines.extend([ - "", - "## Conversation to compress", - compressed_messages or "(no content)", - ]) + lines.extend( + [ + "", + "## Conversation to compress", + compressed_messages or "(no content)", + ] + ) if tail_hint: lines.extend(["", "## Recent context (do NOT summarize, for handoff only)", tail_hint]) return "\n".join(lines) @@ -203,7 +205,10 @@ def _episode_turn_summary(runtime: Any, *, episode_id: str) -> tuple[str, ...]: if tool_counts: total_calls = sum(tool_counts.values()) - tool_parts = [f"{name} ×{count}" if count > 1 else name for name, count in sorted(tool_counts.items(), key=lambda x: -x[1])[:6]] + tool_parts = [ + f"{name} ×{count}" if count > 1 else name + for name, count in sorted(tool_counts.items(), key=lambda x: -x[1])[:6] + ] lines.append(f" [tools: {total_calls} calls — {', '.join(tool_parts)}]") if skills_used: @@ -253,14 +258,16 @@ def build_evidence( if str(job.trigger or "").strip().lower() == "init_profile": init_answers = _init_profile_answer_lines(metadata) portrait = _pm_portrait_lines(active_facts) - lines.extend([ - "", - "## Init profile answers", - *(init_answers or ("(none)",)), - "", - "## Bootstrapped Personal Model facts", - *(portrait or ("(no facts yet)",)), - ]) + lines.extend( + [ + "", + "## Init profile answers", + *(init_answers or ("(none)",)), + "", + "## Bootstrapped Personal Model facts", + *(portrait or ("(no facts yet)",)), + ] + ) if "dream" in feature_ids: target_date = str(metadata.get("target_date") or "today").strip() or "today" @@ -271,12 +278,14 @@ def build_evidence( user_tz = user.timezone except Exception: pass - lines.extend([ - "", - "## Dream context", - f"target_date: {target_date}", - f"user_timezone: {user_tz}", - ]) + lines.extend( + [ + "", + "## Dream context", + f"target_date: {target_date}", + f"user_timezone: {user_tz}", + ] + ) # Episode evidence for features that learn from the supplied close packet. # Dream is a scheduled consolidation mode and intentionally receives no @@ -286,16 +295,21 @@ def build_evidence( and "dream" not in feature_ids and feature_ids & {"pm", "questions", "skills"} ): - episode_summary = _compact(getattr(episode, "exit_summary", "") if episode is not None else "", limit=700) + episode_summary = _compact( + getattr(episode, "exit_summary", "") if episode is not None else "", + limit=700, + ) turn_lines = _episode_turn_summary(runtime, episode_id=job.episode_id) - lines.extend([ - "", - "## Episode summary", - *(tuple(item for item in (episode_summary,) if item) or ("(none)",)), - "", - "## Conversation turns", - *(turn_lines or ("(no conversation data)",)), - ]) + lines.extend( + [ + "", + "## Episode summary", + *(tuple(item for item in (episode_summary,) if item) or ("(none)",)), + "", + "## Conversation turns", + *(turn_lines or ("(no conversation data)",)), + ] + ) # Diary-specific context if "diary" in feature_ids: @@ -308,14 +322,16 @@ def build_evidence( except Exception: pass portrait = _pm_portrait_lines(active_facts) - lines.extend([ - "", - "## Diary context", - f"target_date: {target_date}", - f"user_timezone: {user_tz}", - "", - "## Who this person is (active PM facts)", - *(portrait or ("(no facts yet)",)), - ]) + lines.extend( + [ + "", + "## Diary context", + f"target_date: {target_date}", + f"user_timezone: {user_tz}", + "", + "## Who this person is (active PM facts)", + *(portrait or ("(no facts yet)",)), + ] + ) return "\n".join(lines) diff --git a/apps/reflect/features/__init__.py b/apps/reflect/features/__init__.py index 18169ba..ec8605f 100644 --- a/apps/reflect/features/__init__.py +++ b/apps/reflect/features/__init__.py @@ -15,9 +15,7 @@ from .dream import FEATURE as DREAM -ALL_FEATURES: dict[str, Feature] = { - f.feature_id: f for f in (PM, QUESTIONS, RECALL, DIARY, SKILLS, COMPRESS, DREAM) -} +ALL_FEATURES: dict[str, Feature] = {f.feature_id: f for f in (PM, QUESTIONS, RECALL, DIARY, SKILLS, COMPRESS, DREAM)} # Trigger → default feature set mapping TRIGGER_FEATURES: dict[str, tuple[str, ...]] = { diff --git a/apps/reflect/runner.py b/apps/reflect/runner.py index 4fd28fe..6383cab 100644 --- a/apps/reflect/runner.py +++ b/apps/reflect/runner.py @@ -11,7 +11,13 @@ from .evidence import build_evidence from .features import TRIGGER_CONSERVATISM, resolve_features from .features.types import Feature -from .prompts import BOUNDARIES, CLAIM_TEXT_RULE, CONSERVATISM_PROMPTS, LANGUAGE_RULE, TOPIC_FORMAT +from .prompts import ( + BOUNDARIES, + CLAIM_TEXT_RULE, + CONSERVATISM_PROMPTS, + LANGUAGE_RULE, + TOPIC_FORMAT, +) @dataclass(frozen=True, slots=True) @@ -125,7 +131,12 @@ def _reflect_result_payload( features: tuple[str, ...], ) -> dict[str, object]: has_writes = any( - name in ("tool.personal_model.update", "tool.personal_model.questions", "tool.diary.write") + name + in ( + "tool.personal_model.update", + "tool.personal_model.questions", + "tool.diary.write", + ) for name in tool_names ) status = "completed" if has_writes else "no_op" diff --git a/apps/site/build.sh b/apps/site/build.sh index aeb0090..4b7b2d0 100755 --- a/apps/site/build.sh +++ b/apps/site/build.sh @@ -2,7 +2,7 @@ set -eu -ROOT_DIR=$(CDPATH= cd -- "$(dirname "$0")" && pwd) +ROOT_DIR=$(cd -- "$(dirname "$0")" && pwd) DIST_DIR="${ROOT_DIR}/dist" ensure_node_modules() { diff --git a/apps/site/docs/capacities/embeddings.md b/apps/site/docs/capacities/embeddings.md index d49af9c..f267f2f 100644 --- a/apps/site/docs/capacities/embeddings.md +++ b/apps/site/docs/capacities/embeddings.md @@ -84,4 +84,3 @@ understanding?** That boundary is why retrieval can be powerful without turning every retrieved chunk into truth. - diff --git a/apps/site/docs/learning/correctable.md b/apps/site/docs/learning/correctable.md index 22a8351..8d1b21a 100644 --- a/apps/site/docs/learning/correctable.md +++ b/apps/site/docs/learning/correctable.md @@ -76,4 +76,3 @@ Personal Model claim operations. | Dashboard You | Reviewing and correcting active claims. | | Dashboard Curiosity | Answering or dismissing open questions. | | Dashboard History / Why | Inspecting source support before changing a claim. | - diff --git a/apps/site/docs/learning/proactive.md b/apps/site/docs/learning/proactive.md index dd1364b..dd22a62 100644 --- a/apps/site/docs/learning/proactive.md +++ b/apps/site/docs/learning/proactive.md @@ -55,4 +55,3 @@ flowchart TD - `wake` can surface a question when the cadence allows. - Answered questions can become Personal Model claims with provenance. - Dismissed or ignored questions should not escalate into nagging. - diff --git a/apps/site/docs/user-interface/dashboard.md b/apps/site/docs/user-interface/dashboard.md index 45bb40b..d074c45 100644 --- a/apps/site/docs/user-interface/dashboard.md +++ b/apps/site/docs/user-interface/dashboard.md @@ -82,4 +82,3 @@ flowchart LR Use the CLI when you want to work. Use the dashboard when you want to inspect, correct, or understand what happened. - diff --git a/apps/site/preview.sh b/apps/site/preview.sh index 78e0fc9..4a44520 100755 --- a/apps/site/preview.sh +++ b/apps/site/preview.sh @@ -2,7 +2,7 @@ set -eu -ROOT_DIR=$(CDPATH= cd -- "$(dirname "$0")" && pwd) +ROOT_DIR=$(cd -- "$(dirname "$0")" && pwd) DIST_DIR="${ROOT_DIR}/dist" REQUESTED_PORT=${PORT:-4180} PORT_VALUE=${REQUESTED_PORT} diff --git a/apps/supervisor_command.py b/apps/supervisor_command.py index b7b3f9a..f986163 100644 --- a/apps/supervisor_command.py +++ b/apps/supervisor_command.py @@ -19,7 +19,6 @@ import json import logging import signal -import sys import threading from typing import Sequence diff --git a/apps/upgrade_command.py b/apps/upgrade_command.py index 51e6cae..0bc65ba 100644 --- a/apps/upgrade_command.py +++ b/apps/upgrade_command.py @@ -217,9 +217,7 @@ def stop_runtime(runtime: ManagedRuntimeSnapshot, *, timeout_seconds: float, for return "sigterm" time.sleep(0.2) if not force: - raise RuntimeError( - f"{runtime.label} did not exit within {timeout_seconds:g}s; rerun with --force-stop" - ) + raise RuntimeError(f"{runtime.label} did not exit within {timeout_seconds:g}s; rerun with --force-stop") try: os.kill(runtime.pid, signal.SIGKILL) except ProcessLookupError: @@ -382,7 +380,16 @@ def run_upgrade(args: Namespace) -> int: print("\nUpgrading package") _run_checked( - [sys.executable, "-m", "pip", "install", "--upgrade", "pip", "setuptools", "wheel"], + [ + sys.executable, + "-m", + "pip", + "install", + "--upgrade", + "pip", + "setuptools", + "wheel", + ], env=env, dry_run=dry_run, ) @@ -430,17 +437,64 @@ def build_parser() -> ArgumentParser: prog="elephant upgrade", description="Gracefully upgrade Elephant Agent in place with backup, managed-runtime stop, storage bootstrap, and restart.", ) - parser.add_argument("--state-dir", type=Path, default=default_cli_state_dir(), help="CLI state directory.") - parser.add_argument("--gateway-state-dir", type=Path, default=default_gateway_state_dir(), help="Gateway runtime state directory.") - parser.add_argument("--channel", choices=("dev", "stable"), default=os.environ.get("ELEPHANT_INSTALL_CHANNEL", "dev"), help="Package channel to install when --pip-spec is omitted.") - parser.add_argument("--pip-spec", default=os.environ.get("ELEPHANT_PIP_SPEC", "") or None, help="Explicit pip-installable package spec.") - parser.add_argument("--timeout", type=float, default=10.0, help="Seconds to wait for each runtime to stop before escalation.") - parser.add_argument("--force-stop", action="store_true", default=True, help="Send SIGKILL when a runtime ignores SIGTERM after --timeout.") - parser.add_argument("--no-force-stop", action="store_false", dest="force_stop", help="Fail instead of sending SIGKILL after --timeout.") + parser.add_argument( + "--state-dir", + type=Path, + default=default_cli_state_dir(), + help="CLI state directory.", + ) + parser.add_argument( + "--gateway-state-dir", + type=Path, + default=default_gateway_state_dir(), + help="Gateway runtime state directory.", + ) + parser.add_argument( + "--channel", + choices=("dev", "stable"), + default=os.environ.get("ELEPHANT_INSTALL_CHANNEL", "dev"), + help="Package channel to install when --pip-spec is omitted.", + ) + parser.add_argument( + "--pip-spec", + default=os.environ.get("ELEPHANT_PIP_SPEC", "") or None, + help="Explicit pip-installable package spec.", + ) + parser.add_argument( + "--timeout", + type=float, + default=10.0, + help="Seconds to wait for each runtime to stop before escalation.", + ) + parser.add_argument( + "--force-stop", + action="store_true", + default=True, + help="Send SIGKILL when a runtime ignores SIGTERM after --timeout.", + ) + parser.add_argument( + "--no-force-stop", + action="store_false", + dest="force_stop", + help="Fail instead of sending SIGKILL after --timeout.", + ) parser.add_argument("--no-backup", action="store_true", help="Skip the pre-upgrade state backup.") - parser.add_argument("--skip-restart", action="store_true", help="Do not restart runtimes that were running before the upgrade.") - parser.add_argument("--skip-browser-install", action="store_true", default=os.environ.get("ELEPHANT_SKIP_BROWSER_INSTALL") == "1", help="Skip Playwright Chromium refresh.") - parser.add_argument("--dry-run", action="store_true", help="Print the planned upgrade without changing files or processes.") + parser.add_argument( + "--skip-restart", + action="store_true", + help="Do not restart runtimes that were running before the upgrade.", + ) + parser.add_argument( + "--skip-browser-install", + action="store_true", + default=os.environ.get("ELEPHANT_SKIP_BROWSER_INSTALL") == "1", + help="Skip Playwright Chromium refresh.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print the planned upgrade without changing files or processes.", + ) return parser diff --git a/docs/paper/Elephant_Agent__Personal_Model_First_Evolution_for_Personal_AI.aux b/docs/paper/Elephant_Agent__Personal_Model_First_Evolution_for_Personal_AI.aux index 111fea0..31724e1 100644 --- a/docs/paper/Elephant_Agent__Personal_Model_First_Evolution_for_Personal_AI.aux +++ b/docs/paper/Elephant_Agent__Personal_Model_First_Evolution_for_Personal_AI.aux @@ -1,4 +1,4 @@ -\relax +\relax \providecommand \babel@aux [2]{\global \let \babel@toc \@gobbletwo } \@nameuse{bbl@beforestart} \providecommand\hyper@newdestlabel[2]{} @@ -71,8 +71,8 @@ \@writefile{toc}{\contentsline {subsection}{\numberline {6.4}Local Semantic Recall}{12}{subsection.6.4}\protected@file@percent } \@writefile{toc}{\contentsline {subsection}{\numberline {6.5}Background Reflect}{12}{subsection.6.5}\protected@file@percent } \@writefile{toc}{\contentsline {subsection}{\numberline {6.6}Deletion, Correction, and Repair}{12}{subsection.6.6}\protected@file@percent } -\gdef \LT@i {\LT@entry - {1}{106.15425pt}\LT@entry +\gdef \LT@i {\LT@entry + {1}{106.15425pt}\LT@entry {1}{333.7763pt}} \@writefile{toc}{\contentsline {section}{\numberline {7}Implementation and Evaluation}{13}{section.7}\protected@file@percent } \newlabel{sec:implementation-direction}{{7}{13}{Implementation and Evaluation}{section.7}{}} diff --git a/docs/paper/Elephant_Agent__Personal_Model_First_Evolution_for_Personal_AI.log b/docs/paper/Elephant_Agent__Personal_Model_First_Evolution_for_Personal_AI.log index 5390366..1d8372b 100644 --- a/docs/paper/Elephant_Agent__Personal_Model_First_Evolution_for_Personal_AI.log +++ b/docs/paper/Elephant_Agent__Personal_Model_First_Evolution_for_Personal_AI.log @@ -8,7 +8,7 @@ entering extended mode LaTeX2e <2025-11-01> L3 programming layer <2026-01-19> (./elephant.cls -Document Class: elephant +Document Class: elephant (/usr/local/texlive/2026/texmf-dist/tex/latex/base/article.cls Document Class: article 2025/01/22 v1.4n Standard LaTeX document class (/usr/local/texlive/2026/texmf-dist/tex/latex/base/size10.clo @@ -99,11 +99,11 @@ LaTeX Encoding Info: Redeclaring text command \ij (encoding T1) on input lin e 2082. LaTeX Encoding Info: Redeclaring text command \IJ (encoding T1) on input lin e 2083. -LaTeX Encoding Info: Ignoring declaration for text command \ij (encoding ?) +LaTeX Encoding Info: Ignoring declaration for text command \ij (encoding ?) on input line 2084. -LaTeX Encoding Info: Ignoring declaration for text command \IJ (encoding ?) +LaTeX Encoding Info: Ignoring declaration for text command \IJ (encoding ?) on input line 2086. -LaTeX Encoding Info: Ignoring declaration for text command \SS (encoding ?) +LaTeX Encoding Info: Ignoring declaration for text command \SS (encoding ?) on input line 2111. \U@D=\dimen150 \l@unhyphenated=\language92 @@ -995,7 +995,7 @@ tim * paper: a4paper * layout: * layoutoffset:(h,v)=(0.0pt,0.0pt) -* modes: +* modes: * h-part:(L,W,R)=(71.13188pt, 455.24411pt, 71.13188pt) * v-part:(T,H,B)=(71.13188pt, 702.78308pt, 71.13188pt) * \paperwidth=597.50787pt @@ -1028,7 +1028,7 @@ File: assets/elephant-logo.png Graphic file (type png) Package pdftex.def Info: assets/elephant-logo.png used on input line 40. (pdftex.def) Requested size: 78.24507pt x 78.22696pt. (/usr/local/texlive/2026/texmf-dist/tex/latex/microtype/mt-cmr.cfg -File: mt-cmr.cfg 2013/05/19 v2.2 microtype config. file: Computer Modern Roman +File: mt-cmr.cfg 2013/05/19 v2.2 microtype config. file: Computer Modern Roman (RS) ) LaTeX Font Info: Trying to load font information for U+msa on input line 40. @@ -1221,7 +1221,7 @@ LaTeX Font Warning: Some font shapes were not available, defaults substituted. Package rerunfilecheck Info: File `Elephant_Agent__Personal_Model_First_Evoluti on_for_Personal_AI.out' has not changed. (rerunfilecheck) Checksum: 3DB92E103E96316ED498F4FCFE9C5071;7434. - ) + ) Here is how much of TeX's memory you used: 30437 strings out of 467525 576955 string characters out of 5418982 @@ -1252,4 +1252,3 @@ PDF statistics: 340 compressed objects within 4 object streams 86 named destinations out of 1000 (max. 500000) 14234 words of extra memory for PDF output out of 14400 (max. 10000000) - diff --git a/packages/auth/discovery.py b/packages/auth/discovery.py index e83c9e6..b9b1c5e 100644 --- a/packages/auth/discovery.py +++ b/packages/auth/discovery.py @@ -139,12 +139,7 @@ def _read_anthropic_token_from_payload( claude_code_oauth = payload.get("claudeAiOauth") if isinstance(claude_code_oauth, Mapping): payload = {str(key): value for key, value in claude_code_oauth.items()} - access_token = str( - payload.get("accessToken") - or payload.get("access_token") - or payload.get("token") - or "" - ).strip() + access_token = str(payload.get("accessToken") or payload.get("access_token") or payload.get("token") or "").strip() if not access_token: return None expires_at = payload.get("expiresAt") or payload.get("expires_at") @@ -183,11 +178,7 @@ def _read_copilot_resolution() -> SecretValueResolution | None: value = str(os.environ.get(env_name) or "").strip() if value and not value.startswith("ghp_"): return SecretValueResolution(value=value, source=f"env:{env_name}") - clean_env = { - key: value - for key, value in os.environ.items() - if key not in {"GH_TOKEN", "GITHUB_TOKEN"} - } + clean_env = {key: value for key, value in os.environ.items() if key not in {"GH_TOKEN", "GITHUB_TOKEN"}} try: completed = subprocess.run( ["gh", "auth", "token"], @@ -315,7 +306,11 @@ def resolve(self, reference: SecretReference) -> SecretValueResolution: env = self.environ or os.environ candidates: list[str] = list(reference.env_var_candidates()) seen = set(candidates) - for candidate in (reference.secret_name, reference.secret_key, reference.reference_id): + for candidate in ( + reference.secret_name, + reference.secret_key, + reference.reference_id, + ): normalized = _normalize_env_name(candidate) if normalized and normalized not in seen: seen.add(normalized) diff --git a/packages/auth/runtime.py b/packages/auth/runtime.py index f0a0e91..f7d2a76 100644 --- a/packages/auth/runtime.py +++ b/packages/auth/runtime.py @@ -583,7 +583,17 @@ def profile_from_input( extra_headers: Mapping[str, str] | None = None, ) -> AuthProfile: factory = ProviderProfileFactory(catalog=catalog) - if any(value is not None for value in (base_url, default_model, transport_id, auth_method, provider_kind, extra_headers)): + if any( + value is not None + for value in ( + base_url, + default_model, + transport_id, + auth_method, + provider_kind, + extra_headers, + ) + ): return factory.from_compatible_endpoint( profile_id=profile_input.profile_id, provider_id=profile_input.provider_id, diff --git a/packages/context/compress.py b/packages/context/compress.py index 26f5045..4bc6bf7 100644 --- a/packages/context/compress.py +++ b/packages/context/compress.py @@ -89,7 +89,8 @@ def compress_epoch( if reflect_compressor is not None: try: summary = reflect_compressor( - to_summarize, tail, + to_summarize, + tail, session_id=resolved_session_id, context_limit=context_limit, ) @@ -105,7 +106,9 @@ def compress_epoch( # Apply compression: replace history with tail, set summary before_tokens = estimate_epoch_prompt_tokens( - epoch, history_messages=history, compacted_summary=epoch.compacted_history_summary, + epoch, + history_messages=history, + compacted_summary=epoch.compacted_history_summary, ) updated, _result = compact_session_context_epoch( epoch, @@ -122,6 +125,7 @@ def compress_epoch( _append_prompt_section, invalidate_prefix_cache, ) + invalidate_prefix_cache(resolved_session_id) updated_prefix = _strip_prompt_sections(updated.frozen_prefix, "Episode resume") updated_prefix = _append_prompt_section( @@ -132,7 +136,9 @@ def compress_epoch( updated = replace(updated, frozen_prefix=updated_prefix) after_tokens = estimate_epoch_prompt_tokens( - updated, history_messages=tail, compacted_summary=summary, + updated, + history_messages=tail, + compacted_summary=summary, ) result = CompressResult( @@ -179,11 +185,7 @@ def split_for_compress( tail_start = user_starts[-protected_tail_turns] # Ensure we're actually compressing a meaningful amount (>30%) if tail_start > total * 0.3: - selected_groups = { - group_index - for group_index, (start, _end) in enumerate(groups) - if start >= tail_start - } + selected_groups = {group_index for group_index, (start, _end) in enumerate(groups) if start >= tail_start} to_summarize, tail = _split_by_tail_groups(messages, groups, selected_groups) if to_summarize: return to_summarize, tail @@ -206,10 +208,7 @@ def split_for_compress( group_messages = messages[start:end] if any(msg.role == "user" and msg.content.strip() for msg in group_messages): tail_groups.add(group_index) - elif any( - msg.role == "assistant" and msg.content.strip() and not msg.tool_calls - for msg in group_messages - ): + elif any(msg.role == "assistant" and msg.content.strip() and not msg.tool_calls for msg in group_messages): tail_groups.add(group_index) to_summarize, tail = _split_by_tail_groups(messages, groups, tail_groups) @@ -220,11 +219,7 @@ def split_for_compress( # provider-visible tool call/result group. cut = max(1, int(total * 0.6)) tail_start = _group_boundary_after_index(groups, cut) - tail_groups = { - group_index - for group_index, (start, _end) in enumerate(groups) - if start >= tail_start - } + tail_groups = {group_index for group_index, (start, _end) in enumerate(groups) if start >= tail_start} return _split_by_tail_groups(messages, groups, tail_groups) @@ -258,11 +253,7 @@ def _preservable_live_group(group: tuple[PromptMessage, ...]) -> bool: tool_messages = tuple(message for message in group[1:] if message.role == "tool") if not tool_messages: return False - call_ids = { - call_id - for call in first.tool_calls - if (call_id := tool_call_id(call)) - } + call_ids = {call_id for call in first.tool_calls if (call_id := tool_call_id(call))} if not call_ids: return True result_ids = {message.tool_call_id for message in tool_messages if message.tool_call_id} diff --git a/packages/context/projection.py b/packages/context/projection.py index 624b38f..c1826a3 100644 --- a/packages/context/projection.py +++ b/packages/context/projection.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from dataclasses import replace from datetime import datetime import hashlib @@ -11,8 +11,18 @@ from typing import Any from packages.contracts.layers import Episode -from packages.contracts.runtime import ContextBundle, ExecutionResult, PersonalModelRuntimeState, PromptEnvelope, PromptMessage -from packages.embeddings import EmbeddingPreloadEntry, cosine_similarity, embedding_runtime_is_loaded +from packages.contracts.runtime import ( + ContextBundle, + ExecutionResult, + PersonalModelRuntimeState, + PromptEnvelope, + PromptMessage, +) +from packages.embeddings import ( + EmbeddingPreloadEntry, + cosine_similarity, + embedding_runtime_is_loaded, +) from packages.context.projection_types import ( ContextProjectionCompactionResult, ProjectionCompactionPolicy, @@ -33,9 +43,7 @@ projection_group_preload_entries as _projection_group_preload_entries, projection_query_text as _projection_query_text, normalize_projection_query_text as _normalize_projection_query_text, - projection_result_with_estimated_tokens, prompt_message_projection_line as _prompt_message_projection_line, - tool_call_id as _tool_call_id, tool_call_name as _tool_call_name, ) @@ -60,6 +68,7 @@ def _get_tiktoken_encoding(): return _TIKTOKEN_ENCODING try: import tiktoken + _TIKTOKEN_ENCODING = tiktoken.get_encoding("cl100k_base") _USE_TIKTOKEN = True return _TIKTOKEN_ENCODING @@ -73,7 +82,7 @@ def estimate_projection_tokens(text: str) -> int: normalized = str(text or "") if not normalized: return 0 - + # Use tiktoken for more accurate estimation if available encoding = _get_tiktoken_encoding() if encoding is not None: @@ -82,7 +91,7 @@ def estimate_projection_tokens(text: str) -> int: except Exception: # Fallback to character-based estimation pass - + # Fallback to character-based estimation return max(1, len(normalized) // _CHARS_PER_TOKEN) @@ -145,8 +154,7 @@ def prune_messages(self, messages: Sequence[PromptMessage], *, max_chars: int = tool_name = message.tool_name.strip() or self._tool_name(line) content = _compact_text(message.content, limit=max(80, max_chars // 2)) pruned.append( - f"tool-result-pruned: {tool_name} " - f"tool_call_id={message.tool_call_id or ''} | {content}" + f"tool-result-pruned: {tool_name} tool_call_id={message.tool_call_id or ''} | {content}" ) continue if message.role == "assistant" and message.tool_calls: @@ -245,7 +253,11 @@ def _snapshot_cached_vectors( cached_vector = getattr(self.embedding_service, "cached_vector", None) pending_vector = getattr(self.embedding_service, "pending_vector", None) if not callable(cached_vector): - return None, {}, ProjectionSemanticAnchorStats(candidate_count=len(candidates)) + return ( + None, + {}, + ProjectionSemanticAnchorStats(candidate_count=len(candidates)), + ) query_key = _projection_embedding_cache_key("query", query) query_vector = None @@ -337,7 +349,11 @@ def _queue_missing_backfill( entries[query_key] = EmbeddingPreloadEntry( cache_key=query_key, text=query, - metadata={"surface": "context-projection", "kind": "query", "priority": "recent"}, + metadata={ + "surface": "context-projection", + "kind": "query", + "priority": "recent", + }, ) for index, candidate in reversed(tuple(enumerate(candidates))): normalized_candidate = _projection_embedding_text(candidate) @@ -345,11 +361,14 @@ def _queue_missing_backfill( continue key = _projection_embedding_cache_key("group", normalized_candidate) try: - if self.embedding_service.cached_vector( - target=PROJECTION_EMBEDDING_TARGET, - cache_key=key, - dimensions=self.dimensions, - ) is not None: + if ( + self.embedding_service.cached_vector( + target=PROJECTION_EMBEDDING_TARGET, + cache_key=key, + dimensions=self.dimensions, + ) + is not None + ): continue except Exception: continue @@ -419,7 +438,11 @@ def queue_projection_history_embedding_backfill( entries[query_key] = EmbeddingPreloadEntry( cache_key=query_key, text=query_text, - metadata={"surface": "context-projection", "kind": "query", "priority": "recent"}, + metadata={ + "surface": "context-projection", + "kind": "query", + "priority": "recent", + }, ) for entry in _projection_group_preload_entries(messages, recent_first=True): entries.setdefault(entry.cache_key, entry) @@ -462,12 +485,20 @@ def summarize( parts.extend( ( "## Relevant facts/events", - _compact_text(previous_summary, limit=max(360, token_budget * _CHARS_PER_TOKEN // 3)), + _compact_text( + previous_summary, + limit=max(360, token_budget * _CHARS_PER_TOKEN // 3), + ), ) ) tail_note = _latest_user_line(protected_tail) or thread_focus.strip() if tail_note: - parts.extend(("## Handoff notes for recent tail", f"- {_compact_text(tail_note, limit=220)}")) + parts.extend( + ( + "## Handoff notes for recent tail", + f"- {_compact_text(tail_note, limit=220)}", + ) + ) return _ensure_reference_only_summary("\n".join(parts).strip()) @@ -588,13 +619,25 @@ def _prompt( def _message_projection_surface(message: PromptMessage) -> str: metadata = dict(message.metadata or {}) - return str(metadata.get("projection_surface") or metadata.get("surface") or metadata.get("source") or "").strip().lower() + return ( + str(metadata.get("projection_surface") or metadata.get("surface") or metadata.get("source") or "") + .strip() + .lower() + ) def _history_is_im(messages: tuple[PromptMessage, ...]) -> bool: for message in messages: surface = _message_projection_surface(message) - if surface == "im" or surface.startswith("gateway:") or surface.startswith("feishu") or surface.startswith("wecom") or surface.startswith("weixin") or surface.startswith("dingding") or surface.startswith("discord"): + if ( + surface == "im" + or surface.startswith("gateway:") + or surface.startswith("feishu") + or surface.startswith("wecom") + or surface.startswith("weixin") + or surface.startswith("dingding") + or surface.startswith("discord") + ): return True return False @@ -619,7 +662,10 @@ def _im_burst_tail_start( ) -> int | None: if not groups: return None - latest = next((_message_created_at(message) for message in reversed(messages) if _message_created_at(message) is not None), None) + latest = next( + (_message_created_at(message) for message in reversed(messages) if _message_created_at(message) is not None), + None, + ) if latest is None: return None tail_start = len(messages) @@ -628,16 +674,16 @@ def _im_burst_tail_start( if end <= head_end: break group_times = tuple( - timestamp - for message in messages[start:end] - if (timestamp := _message_created_at(message)) is not None + timestamp for message in messages[start:end] if (timestamp := _message_created_at(message)) is not None ) group_time = max(group_times) if group_times else previous_group_time if group_time is None: break if max(0.0, (latest - group_time).total_seconds()) > max(1, window_seconds): break - if previous_group_time is not None and max(0.0, (previous_group_time - group_time).total_seconds()) > max(1, idle_gap_seconds): + if previous_group_time is not None and max(0.0, (previous_group_time - group_time).total_seconds()) > max( + 1, idle_gap_seconds + ): break tail_start = start previous_group_time = group_time @@ -670,7 +716,9 @@ def compact_messages( reason: str = "manual", force: bool = False, ) -> SessionMessageProjection: - normalized = tuple(message for message in (_normalize_prompt_message(message) for message in messages) if message is not None) + normalized = tuple( + message for message in (_normalize_prompt_message(message) for message in messages) if message is not None + ) rendered = tuple(_prompt_message_projection_line(message) for message in normalized) before_tokens = estimate_projection_lines_tokens(rendered) before_count = len(normalized) @@ -692,7 +740,10 @@ def compact_messages( head_count=min(self.policy.protected_head_lines, before_count), tail_count=min( self.policy.protected_tail_lines, - max(0, before_count - min(self.policy.protected_head_lines, before_count)), + max( + 0, + before_count - min(self.policy.protected_head_lines, before_count), + ), ), ), summary_hash=_projection_summary_hash(previous_summary), @@ -743,7 +794,13 @@ def compact_messages( head_count=len(head), tail_count=len(tail), ) - anchor_messages, compacted_middle, anchor_stats, selected_raw_ids, compaction_query = self._semantic_anchor_messages( + ( + anchor_messages, + compacted_middle, + anchor_stats, + selected_raw_ids, + compaction_query, + ) = self._semantic_anchor_messages( middle, thread_focus=thread_focus, protected_tail=tail, @@ -826,9 +883,7 @@ def _split_messages( return (), (), () groups = _message_groups(messages) resolved_head_lines = ( - self.policy.protected_head_lines - if protected_head_lines is None - else max(0, protected_head_lines) + self.policy.protected_head_lines if protected_head_lines is None else max(0, protected_head_lines) ) head_target = min(resolved_head_lines, len(messages)) head_end = _group_end_at_or_after(groups, head_target) @@ -863,9 +918,7 @@ def _tail_start( if not remaining_groups: return len(messages) min_tail_messages = ( - self.policy.protected_tail_lines - if protected_tail_lines is None - else max(0, protected_tail_lines) + self.policy.protected_tail_lines if protected_tail_lines is None else max(0, protected_tail_lines) ) if force and len(messages) > head_end + 3: min_tail_messages = max(3, self.policy.protected_tail_lines // 2) @@ -888,7 +941,9 @@ def _tail_start( tail_start = len(messages) for start, end in reversed(remaining_groups): group = messages[start:end] - token_count += estimate_projection_lines_tokens(tuple(_prompt_message_projection_line(message) for message in group)) + token_count += estimate_projection_lines_tokens( + tuple(_prompt_message_projection_line(message) for message in group) + ) message_count += len(group) tail_start = start if protected_tail_lines is not None and message_count >= min_tail_messages: @@ -933,7 +988,13 @@ def _semantic_anchor_messages( ) candidate_ids = tuple(_projection_embedding_cache_key("group", candidate) for candidate in candidates) if not query: - return (), middle, ProjectionSemanticAnchorStats(candidate_count=len(candidates)), (), query + return ( + (), + middle, + ProjectionSemanticAnchorStats(candidate_count=len(candidates)), + (), + query, + ) try: ranked_groups = tuple(self.relevance_scorer.rank(query=query, candidates=candidates, limit=max_anchors)) stats = getattr(self.relevance_scorer, "last_stats", None) @@ -943,7 +1004,13 @@ def _semantic_anchor_messages( selected_group_count=len(ranked_groups), ) except Exception: - return (), middle, ProjectionSemanticAnchorStats(candidate_count=len(candidates)), (), query + return ( + (), + middle, + ProjectionSemanticAnchorStats(candidate_count=len(candidates)), + (), + query, + ) if not ranked_groups: return (), middle, stats, (), query anchor_indexes: set[int] = set() diff --git a/packages/context/projection_support.py b/packages/context/projection_support.py index 4ba3b7a..754ee57 100644 --- a/packages/context/projection_support.py +++ b/packages/context/projection_support.py @@ -69,11 +69,7 @@ def prompt_message_projection_line(message: PromptMessage) -> str: name = message.tool_name.strip() or "unknown" return f"tool: {name} tool_call_id={message.tool_call_id or ''} summary: {message.content}" if role == "assistant" and message.tool_calls: - call_names = ", ".join( - tool_call_name(call) - for call in message.tool_calls - if isinstance(call, Mapping) - ) + call_names = ", ".join(tool_call_name(call) for call in message.tool_calls if isinstance(call, Mapping)) text = message.content.strip() suffix = f" tool_calls: {call_names}" if call_names else "" return f"assistant: {text}{suffix}".strip() @@ -126,9 +122,7 @@ def projection_group_preload_entries( recent_first: bool = False, ) -> tuple[EmbeddingPreloadEntry, ...]: normalized = tuple( - message - for message in (normalize_prompt_message(message) for message in messages) - if message is not None + message for message in (normalize_prompt_message(message) for message in messages) if message is not None ) entries: list[EmbeddingPreloadEntry] = [] groups = tuple(enumerate(message_groups(normalized))) @@ -188,9 +182,7 @@ def message_groups(messages: tuple[PromptMessage, ...]) -> tuple[tuple[int, int] end = index + 1 if message.role == "assistant" and message.tool_calls: call_ids = { - call_id - for call in message.tool_calls - if isinstance(call, Mapping) and (call_id := tool_call_id(call)) + call_id for call in message.tool_calls if isinstance(call, Mapping) and (call_id := tool_call_id(call)) } while end < len(messages) and messages[end].role == "tool": if call_ids and messages[end].tool_call_id and messages[end].tool_call_id not in call_ids: diff --git a/packages/context/runtime_impl.py b/packages/context/runtime_impl.py index 0303404..6f67476 100644 --- a/packages/context/runtime_impl.py +++ b/packages/context/runtime_impl.py @@ -1,17 +1,11 @@ """Layered context runtime implementation assembled from smaller modules.""" - from __future__ import annotations -from dataclasses import dataclass, field -from datetime import datetime, timezone -import re -from typing import Any, Mapping, Protocol, runtime_checkable from packages.capabilities.runtime import CapabilityDescriptor, ContextCapability from packages.contracts.layers import Episode -from packages.contracts.runtime import ContextBundle, StateFocusDecision, RecallEvidence, StructuredTurnSlot - +from packages.contracts.runtime import ContextBundle, StateFocusDecision, RecallEvidence from .runtime_types import ( @@ -45,43 +39,21 @@ build_prompt_envelope, ) from .runtime_support import ( - _budget_for, - _work_item_line, - _evidence_line, _select_steady_recall_items, - _steady_recall_refs, - _work_item_trace_reason, _derived_source_refs, _loop_context_trace_reason, _session_snapshot_trace_reason, _request_attachment_trace_reason, - _session_snapshot_lines, - _build_retrieval_query, - _build_retrieval_reason, _estimate_tokens, _state_focus_budget_multiplier, - _truncate_lines, - _summary_content_for_layer, - _retrieval_lines, - _ReplayRequestSpec, _split_retrieval_requests, _infer_replay_specs, - _schedule_replay_requests, - _select_replay_evidence, - _replay_rank, - _project_replay_slot, - _replay_lines, - _replay_summary_lines, _replay_packet_trace_reason, - _tokenize, - _thematic_tokens, - _continuity_marker_tokens, - _context_evidence_score, - _retrieval_priority_bucket, _plan_rationale, _snapshot_work_items, ) + class LayeredContextPlanner: """Plan the layered context structure from runtime state.""" @@ -154,7 +126,14 @@ def plan( summary_requests=summary_requests, retrieval_requests=retrieval_requests, ) - rationale = _plan_rationale(session, work_items, recall_items, budgets, retrieval_requests, state_focus=state_focus) + rationale = _plan_rationale( + session, + work_items, + recall_items, + budgets, + retrieval_requests, + state_focus=state_focus, + ) frame = EpisodeFrameBuilder().build( session=session, instruction_refs=instruction_refs, @@ -199,7 +178,10 @@ def _build_budget_requests( snapshot_tokens = max( 96, int( - max(144, len(profile_snapshot_refs) * 6 + len(snapshot_work_items) * 28 + min(len(recall_items), 6) * 24) + max( + 144, + len(profile_snapshot_refs) * 6 + len(snapshot_work_items) * 28 + min(len(recall_items), 6) * 24, + ) * _state_focus_budget_multiplier(state_focus) ), ) @@ -302,8 +284,10 @@ def _build_summary_requests( ) ) replay_budget = budgets.allocation_for("replay_packet") - if replay_budget and replay_retrieval_requests and ( - replay_budget.allocated_tokens < replay_budget.requested_tokens or len(replay_retrieval_requests) > 1 + if ( + replay_budget + and replay_retrieval_requests + and (replay_budget.allocated_tokens < replay_budget.requested_tokens or len(replay_retrieval_requests) > 1) ): requests.append( ContextSummaryRequest( @@ -366,20 +350,31 @@ def _build_source_trace( summary_requests: tuple[ContextSummaryRequest, ...], retrieval_requests: tuple[ContextRetrievalRequest, ...], ) -> tuple[ContextSourceTrace, ...]: - steady_recall_items = _select_steady_recall_items(recall_items, session=session, work_items=work_items, state_focus=state_focus) + steady_recall_items = _select_steady_recall_items( + recall_items, + session=session, + work_items=work_items, + state_focus=state_focus, + ) snapshot_work_items = _snapshot_work_items(work_items, state_focus=state_focus) steady_refs = tuple(evidence.evidence_id for evidence in steady_recall_items) snapshot_retrieval_requests, replay_retrieval_requests = _split_retrieval_requests(retrieval_requests) retrieved_evidence_refs = tuple( - dict.fromkeys(evidence_ref for request in snapshot_retrieval_requests for evidence_ref in request.evidence_refs) + dict.fromkeys( + evidence_ref for request in snapshot_retrieval_requests for evidence_ref in request.evidence_refs + ) ) replay_evidence_refs = tuple( - dict.fromkeys(evidence_ref for request in replay_retrieval_requests for evidence_ref in request.evidence_refs) + dict.fromkeys( + evidence_ref for request in replay_retrieval_requests for evidence_ref in request.evidence_refs + ) ) omitted_snapshot_refs = tuple( evidence.evidence_id for evidence in recall_items - if evidence.evidence_id not in steady_refs and evidence.evidence_id not in retrieved_evidence_refs and evidence.evidence_id not in replay_evidence_refs + if evidence.evidence_id not in steady_refs + and evidence.evidence_id not in retrieved_evidence_refs + and evidence.evidence_id not in replay_evidence_refs ) traces: list[ContextSourceTrace] = [ ContextSourceTrace( @@ -420,7 +415,9 @@ def _build_source_trace( selected_refs=replay_evidence_refs, reason=_replay_packet_trace_reason(replay_retrieval_requests), omitted_refs=tuple( - evidence_ref for evidence_ref in structured_turn_refs if evidence_ref not in replay_evidence_refs + evidence_ref + for evidence_ref in structured_turn_refs + if evidence_ref not in replay_evidence_refs ), ) ) @@ -442,6 +439,7 @@ def _build_source_trace( ) return tuple(traces) + class ContextRuntime(ContextCapability): """Capability adapter for layered context assembly.""" @@ -542,15 +540,9 @@ def assemble_detailed( ) rendered = self._renderer.render(plan) prompt_envelope = build_prompt_envelope(plan.frame) - summary_by_layer = { - layer.layer_name: layer.summary - for layer in plan.layers - if layer.summary is not None - } + summary_by_layer = {layer.layer_name: layer.summary for layer in plan.layers if layer.summary is not None} retrieved_evidence_refs = tuple( - evidence_ref - for request in plan.retrieval_requests - for evidence_ref in request.evidence_refs + evidence_ref for request in plan.retrieval_requests for evidence_ref in request.evidence_refs ) bundle = ContextBundle( bundle_id=f"{session.episode_id}:context", @@ -573,6 +565,7 @@ def assemble_detailed( frame=plan.frame, ) + __all__ = [ "BudgetManager", "ContextAssemblyPlan", diff --git a/packages/context/runtime_layers.py b/packages/context/runtime_layers.py index 4b3271c..20274b4 100644 --- a/packages/context/runtime_layers.py +++ b/packages/context/runtime_layers.py @@ -1,32 +1,22 @@ """Context runtime planning protocols and deterministic implementations.""" - from __future__ import annotations -from dataclasses import dataclass, field -from datetime import datetime, timezone -import re -from typing import Any, Mapping, Protocol, runtime_checkable +from typing import Protocol, runtime_checkable -from packages.capabilities.runtime import CapabilityDescriptor, ContextCapability from packages.contracts import Episode from packages.contracts.runtime import ( - ContextBundle, StateFocusDecision, RecallEvidence, PromptEnvelope, - StructuredTurnSlot, ) - from .runtime_types import ( ContextAssemblyPlan, - ContextAssemblyResult, ContextBudgetPlan, ContextBudgetRequest, ContextLayerBudget, - ContextLayerSnapshot, ContextRetrievalRequest, ContextSourceTrace, ContextSummaryRequest, @@ -55,6 +45,7 @@ _truncate_lines, ) + def _operational_layer_heading(layer_name: str) -> str: """Human-readable heading for each context frame layer. @@ -87,7 +78,9 @@ def _render_live_prompt_section( suppress_heading: bool = False, ) -> str: normalized_summary = str(summary or "").strip() - if normalized_summary.casefold() == "no content" and not tuple(str(line).strip() for line in content if str(line).strip()): + if normalized_summary.casefold() == "no content" and not tuple( + str(line).strip() for line in content if str(line).strip() + ): normalized_summary = "" if normalized_summary and summary_replaces_content: lines: list[str] = [] @@ -148,6 +141,7 @@ def build_prompt_envelope(frame: EpisodeFrame | None) -> PromptEnvelope: loop_context="\n\n".join(part for part in loop_parts if part.strip()), ) + @runtime_checkable class SummaryHook(Protocol): def summarize( @@ -161,6 +155,7 @@ def summarize( ) -> str: """Summarize content for a single context layer.""" + @runtime_checkable class RetrievalScheduler(Protocol): def schedule( @@ -176,16 +171,19 @@ def schedule( ) -> tuple[ContextRetrievalRequest, ...]: """Schedule retrieval requests for the current session.""" + @runtime_checkable class BudgetManager(Protocol): def allocate(self, total_tokens: int, requests: tuple[ContextBudgetRequest, ...]) -> ContextBudgetPlan: """Allocate explicit token budgets to ordered layers.""" + @runtime_checkable class PromptRenderer(Protocol): def render(self, plan: ContextAssemblyPlan) -> str: """Render a structured prompt bundle.""" + @runtime_checkable class ContextPlanner(Protocol): def plan( @@ -203,6 +201,7 @@ def plan( ) -> ContextAssemblyPlan: """Plan layered context from structured runtime state.""" + class DeterministicBudgetManager: """Allocate context budgets in explicit priority order.""" @@ -260,6 +259,7 @@ def allocate(self, total_tokens: int, requests: tuple[ContextBudgetRequest, ...] omitted_layers=tuple(dict.fromkeys(omitted)), ) + class DeterministicRetrievalScheduler: """Score recall_items deterministically against session work_items.""" @@ -323,9 +323,7 @@ def schedule( replay_budget = _budget_for(budget_plan, "replay_packet") if replay_budget <= 0: return tuple(requests) - return tuple( - requests - ) + _schedule_replay_requests( + return tuple(requests) + _schedule_replay_requests( session=session, work_items=work_items, recall_items=recall_items, @@ -334,6 +332,7 @@ def schedule( state_focus=state_focus, ) + class DeterministicSummaryHook: """Summarize a layer by compressing content into inspectable bullets.""" @@ -353,13 +352,16 @@ def summarize( # only; hidden from the model via the HTML-comment strip path) # for call-site audit and render only the bullets otherwise. del reason, layer_name # telemetry-only inputs, see docstring - body = tuple(line for line in _truncate_lines(content, token_budget) if str(line).strip().casefold() != "no content") + body = tuple( + line for line in _truncate_lines(content, token_budget) if str(line).strip().casefold() != "no content" + ) pieces: list[str] = [] pieces.extend(f"- {line}" for line in body) if session.interruption_state: pieces.append(f"- continuity: {session.interruption_state}") return "\n".join(pieces) + class MarkdownPromptRenderer: """Render the assembled plan as stable markdown-like text. @@ -401,6 +403,7 @@ def render(self, plan: ContextAssemblyPlan) -> str: lines.append("") return "\n".join(lines).strip() + class EpisodeFrameBuilder: """Build the explicit Episode frame from selected runtime slices.""" @@ -422,12 +425,14 @@ def build( state_focus: StateFocusDecision | None = None, ) -> EpisodeFrame: snapshot_work_items = _snapshot_work_items(work_items, state_focus=state_focus) - steady_recall_items = _select_steady_recall_items(recall_items, session=session, work_items=work_items, state_focus=state_focus) + steady_recall_items = _select_steady_recall_items( + recall_items, + session=session, + work_items=work_items, + state_focus=state_focus, + ) evidence_index = {evidence.evidence_id: evidence for evidence in recall_items} - summary_by_layer = { - request.layer_name: request - for request in summary_requests - } + summary_by_layer = {request.layer_name: request for request in summary_requests} snapshot_retrieval_requests, replay_retrieval_requests = _split_retrieval_requests(retrieval_requests) snapshot_summary = None snapshot_request = summary_by_layer.get("session_snapshot") @@ -452,7 +457,9 @@ def build( reason=snapshot_request.reason, ) retrieved_evidence_refs = tuple( - dict.fromkeys(evidence_ref for request in snapshot_retrieval_requests for evidence_ref in request.evidence_refs) + dict.fromkeys( + evidence_ref for request in snapshot_retrieval_requests for evidence_ref in request.evidence_refs + ) ) replay_summary = None replay_request = summary_by_layer.get("replay_packet") @@ -477,7 +484,9 @@ def build( reason=replay_request.reason, ) replay_evidence_refs = tuple( - dict.fromkeys(evidence_ref for request in replay_retrieval_requests for evidence_ref in request.evidence_refs) + dict.fromkeys( + evidence_ref for request in replay_retrieval_requests for evidence_ref in request.evidence_refs + ) ) replay_packet = None if replay_retrieval_requests: diff --git a/packages/context/runtime_support.py b/packages/context/runtime_support.py index e1fec74..5f6adc2 100644 --- a/packages/context/runtime_support.py +++ b/packages/context/runtime_support.py @@ -1,37 +1,32 @@ """Context runtime retrieval, replay, and scoring helpers.""" - from __future__ import annotations -from dataclasses import dataclass, field -from datetime import datetime, timezone +from dataclasses import dataclass import re -from typing import Any, Mapping, Protocol, runtime_checkable +from typing import Mapping -from packages.capabilities.runtime import CapabilityDescriptor, ContextCapability from packages.contracts.layers import Episode -from packages.contracts.runtime import ContextBundle, StateFocusDecision, RecallEvidence, StructuredTurnSlot - +from packages.contracts.runtime import ( + StateFocusDecision, + RecallEvidence, + StructuredTurnSlot, +) from .runtime_types import ( ContextBudgetPlan, - ContextLayerBudget, - ContextLayerSnapshot, ContextRetrievalRequest, - ContextSourceTrace, RecallEvidence, - EpisodeReplay, - EpisodeFrame, - StateSnapshot, StructuredTurnSlot, - LoopContext, ) + def _budget_for(budgets: ContextBudgetPlan, layer_name: str) -> int: allocation = budgets.allocation_for(layer_name) return allocation.allocated_tokens if allocation else 0 + def _work_item_line(work_item: object) -> str: """Human-readable work-item line for prompt injection. @@ -41,6 +36,7 @@ def _work_item_line(work_item: object) -> str: """ return f"{work_item.title} [{work_item.status}/{work_item.priority}]" + def _evidence_line(evidence: RecallEvidence) -> str: """Human-readable evidence line for prompt injection. @@ -56,11 +52,19 @@ def _evidence_line(evidence: RecallEvidence) -> str: _MEMORY_KIND_PROSE = { - "decision": "Decision", "observation": "Runtime signal", "correction": "Correction", - "preference": "Preference", "knowledge": "What you know", "relationship": "Relationship note", - "procedural": "Procedure", "style": "Style note", "core": "Core identity note", - "episodic_index": "Episode note", "episodic": "Episode note", - "work_item": "Work note", "continuity": "Continuity note", + "decision": "Decision", + "observation": "Runtime signal", + "correction": "Correction", + "preference": "Preference", + "knowledge": "What you know", + "relationship": "Relationship note", + "procedural": "Procedure", + "style": "Style note", + "core": "Core identity note", + "episodic_index": "Episode note", + "episodic": "Episode note", + "work_item": "Work note", + "continuity": "Continuity note", } @@ -77,13 +81,15 @@ def _looks_like_profile_evidence_line(line: str) -> bool: normalized = " ".join(str(line or "").casefold().split()) if not normalized: return False - return normalized.startswith(( - "what you know: preferred name", - "what you know: first language", - "what you know: city/timezone context", - "what you know: day-to-day context", - "what you know: care context", - )) + return normalized.startswith( + ( + "what you know: preferred name", + "what you know: first language", + "what you know: city/timezone context", + "what you know: day-to-day context", + "what you know: care context", + ) + ) def _content_dedup_key(text: str) -> str: @@ -93,6 +99,7 @@ def _content_dedup_key(text: str) -> str: generation_context all at once. """ from hashlib import blake2b as _blake2b + compact_text = " ".join(str(text or "").casefold().split()) while compact_text and compact_text[-1] in ".,;:!?": compact_text = compact_text[:-1].rstrip() @@ -111,6 +118,7 @@ def _state_focus_focus_work_item_ids( work_item_ids = {work_item.work_item_id for work_item in work_items} return tuple(work_item_id for work_item_id in state_focus.focus_work_item_ids if work_item_id in work_item_ids) + def _snapshot_work_items( work_items: tuple[...], *, @@ -131,6 +139,7 @@ def _snapshot_work_items( return focused return work_items + def _state_focus_budget_multiplier(state_focus: StateFocusDecision | None) -> float: if state_focus is None: return 1.0 @@ -140,6 +149,7 @@ def _state_focus_budget_multiplier(state_focus: StateFocusDecision | None) -> fl return 1.35 return 1.0 + def _select_steady_recall_items( recall_items: tuple[RecallEvidence, ...], *, @@ -153,7 +163,13 @@ def _select_steady_recall_items( scored = sorted( recall_items, key=lambda evidence: ( - -_context_evidence_score(evidence, session=session, work_items=work_items, state_focus=state_focus, layer_name="steady"), + -_context_evidence_score( + evidence, + session=session, + work_items=work_items, + state_focus=state_focus, + layer_name="steady", + ), -(evidence.created_at.timestamp() if evidence.created_at is not None else 0.0), evidence.evidence_id, ), @@ -169,6 +185,7 @@ def _select_steady_recall_items( ) ) + def _steady_recall_refs( recall_items: tuple[RecallEvidence, ...], *, @@ -177,16 +194,26 @@ def _steady_recall_refs( state_focus: StateFocusDecision | None = None, ) -> tuple[str, ...]: return tuple( - evidence.evidence_id for evidence in _select_steady_recall_items(recall_items, session=session, work_items=work_items, state_focus=state_focus) + evidence.evidence_id + for evidence in _select_steady_recall_items( + recall_items, + session=session, + work_items=work_items, + state_focus=state_focus, + ) ) + def _work_item_trace_reason(work_items: tuple[...]) -> str: if not work_items: return "no active elephant work items were available" - selected = ", ".join(f"{work_item.work_item_id}({work_item.status}/{work_item.priority})" for work_item in work_items[:3]) + selected = ", ".join( + f"{work_item.work_item_id}({work_item.status}/{work_item.priority})" for work_item in work_items[:3] + ) tail = " ..." if len(work_items) > 3 else "" return f"active elephant work items stayed visible: {selected}{tail}" + def _derived_source_refs(prefix: str, items: tuple[str, ...]) -> tuple[str, ...]: refs: list[str] = [] for index, item in enumerate(items, start=1): @@ -197,6 +224,7 @@ def _derived_source_refs(prefix: str, items: tuple[str, ...]) -> tuple[str, ...] refs.append(f"{prefix}:{index}") return tuple(refs) + def _loop_context_trace_reason(session: Episode, recent_loop_context: tuple[str, ...]) -> str: if recent_loop_context: return f"{len(recent_loop_context)} live Loop context item(s) keep the current exchange request-time only" @@ -204,6 +232,7 @@ def _loop_context_trace_reason(session: Episode, recent_loop_context: tuple[str, return f"no request-time Loop context was supplied, so the frame leans on {session.interruption_state}" return "no request-time Loop context was supplied, so the frame leans on durable snapshot state" + def _session_snapshot_trace_reason( session: Episode, work_items: tuple[...], @@ -243,11 +272,13 @@ def _session_snapshot_trace_reason( pieces.append("no durable evidence records were available") return "; ".join(pieces) + def _request_attachment_trace_reason(artifacts: tuple[str, ...]) -> str: if artifacts: return f"{len(artifacts)} request/runtime attachment(s) stayed visible for request-time steering" return "no request attachments were needed" + def _session_snapshot_lines( *, session: Episode, @@ -291,22 +322,42 @@ def _session_snapshot_lines( lines.extend(retrieval_lines) return tuple(lines) + def _build_retrieval_query( evidence: RecallEvidence, work_items: tuple[...], *, state_focus: StateFocusDecision | None = None, ) -> str: - work_item_titles = " ".join(work_item.title for work_item in work_items if work_item.work_item_id in evidence.work_item_ids) + work_item_titles = " ".join( + work_item.title for work_item in work_items if work_item.work_item_id in evidence.work_item_ids + ) focus_titles = "" state_focus_terms = "" if state_focus is not None: focus_ids = _state_focus_focus_work_item_ids(work_items, state_focus=state_focus) focus_titles = " ".join(work_item.title for work_item in work_items if work_item.work_item_id in focus_ids) - state_focus_terms = " ".join((state_focus.focus_family, state_focus.focus_scope, state_focus.context_budget)) - query = " ".join(part for part in (state_focus_terms, focus_titles, evidence.kind, evidence.content, work_item_titles) if part) + state_focus_terms = " ".join( + ( + state_focus.focus_family, + state_focus.focus_scope, + state_focus.context_budget, + ) + ) + query = " ".join( + part + for part in ( + state_focus_terms, + focus_titles, + evidence.kind, + evidence.content, + work_item_titles, + ) + if part + ) return query[:240] + def _build_retrieval_reason( evidence: RecallEvidence, work_items: tuple[...], @@ -326,8 +377,7 @@ def _build_retrieval_reason( focus_titles = [ work_item.title.strip() or "active work" for work_item in work_items - if work_item.work_item_id in focus_ids - and work_item.work_item_id in evidence.work_item_ids + if work_item.work_item_id in focus_ids and work_item.work_item_id in evidence.work_item_ids ] if focus_titles: pieces.append(f"elephant focus kept {', '.join(focus_titles)} ahead of generic recall") @@ -343,9 +393,11 @@ def _build_retrieval_reason( pieces.append("supporting continuity evidence") return "; ".join(pieces[:4]) + def _estimate_tokens(content: str) -> int: return max(8, (len(content) // 4) + 1) + def _truncate_lines(content: tuple[str, ...], token_budget: int) -> tuple[str, ...]: remaining = max(token_budget, 0) lines: list[str] = [] @@ -374,6 +426,7 @@ def _truncate_text(value: str, *, limit: int) -> str: boundary = limit return f"{text[:boundary].rstrip(' ,;|')}..." + def _summary_content_for_layer( layer_name: str, session: Episode, @@ -413,8 +466,7 @@ def _summary_content_for_layer( lines.append( "active work: " + "; ".join( - f"{work_item.title} [{work_item.status}/{work_item.priority}]" - for work_item in snapshot_work_items + f"{work_item.title} [{work_item.status}/{work_item.priority}]" for work_item in snapshot_work_items ) ) # Dedup by content hash across steady / retrieved snippets and against @@ -524,6 +576,7 @@ def _retrieval_lines( lines.append(f"{_evidence_line(evidence)} | why: {request.reason}") return tuple(lines) + @dataclass(frozen=True, slots=True) class _ReplayRequestSpec: slot_name: str @@ -541,6 +594,7 @@ class _ReplayRequestSpec: "raw_trace": 3, } + def _split_retrieval_requests( retrieval_requests: tuple[ContextRetrievalRequest, ...], ) -> tuple[tuple[ContextRetrievalRequest, ...], tuple[ContextRetrievalRequest, ...]]: @@ -548,6 +602,7 @@ def _split_retrieval_requests( replay_requests = tuple(request for request in retrieval_requests if request.layer_name == "replay_packet") return snapshot_requests, replay_requests + def _infer_replay_specs( recent_loop_context: tuple[str, ...], *, @@ -579,7 +634,11 @@ def _infer_replay_specs( ) wants_action = explicit_replay and tokens.intersection({"action", "step", "steps", "command", "tool", "run", "did"}) wants_outcome = explicit_replay and tokens.intersection({"outcome", "result", "results"}) - replay_mode = "episode" if explicit_replay and tokens.intersection({"previous", "earlier", "history", "across", "episode"}) else "turn" + replay_mode = ( + "episode" + if explicit_replay and tokens.intersection({"previous", "earlier", "history", "across", "episode"}) + else "turn" + ) wants_raw_trace = "raw trace" in text or "exact trace" in text or ("raw" in tokens and "trace" in tokens) replay_specs: list[_ReplayRequestSpec] = [] if wants_reasoning: @@ -632,6 +691,7 @@ def _infer_replay_specs( ) return () + def _schedule_replay_requests( *, session: Episode, @@ -671,7 +731,11 @@ def _schedule_replay_requests( session_id=session.episode_id, query=" ".join(recent_loop_context)[:240], evidence_refs=(evidence.evidence_id,), - work_item_ids=tuple(work_item.work_item_id for work_item in work_items if work_item.work_item_id in evidence.work_item_ids), + work_item_ids=tuple( + work_item.work_item_id + for work_item in work_items + if work_item.work_item_id in evidence.work_item_ids + ), token_budget=selected_tokens, priority=max(0, 120 - index * 10), reason=f"{replay_intent.reason}; {detail_reason}", @@ -682,6 +746,7 @@ def _schedule_replay_requests( ) return tuple(requests) + def _select_replay_evidence( *, session: Episode, @@ -693,12 +758,23 @@ def _select_replay_evidence( max_compression: str, state_focus: StateFocusDecision | None = None, ) -> tuple[RecallEvidence, str] | None: - del session, work_items, recall_items, recent_loop_context, slot_name, replay_mode, max_compression, state_focus + del ( + session, + work_items, + recall_items, + recent_loop_context, + slot_name, + replay_mode, + max_compression, + state_focus, + ) return None + def _replay_rank(compression: str) -> int: return _REPLAY_COMPRESSION_RANK.get(compression.strip().lower(), _REPLAY_COMPRESSION_RANK["structured_summary"]) + def _project_replay_slot(slot: StructuredTurnSlot, *, max_compression: str) -> tuple[StructuredTurnSlot, bool]: if _replay_rank(slot.compression) <= _replay_rank(max_compression): return slot, False @@ -714,6 +790,7 @@ def _project_replay_slot(slot: StructuredTurnSlot, *, max_compression: str) -> t True, ) + def _replay_lines( replay_requests: tuple[ContextRetrievalRequest, ...], evidence_index: Mapping[str, RecallEvidence], @@ -727,6 +804,7 @@ def _replay_lines( del replay_requests, evidence_index return () + def _replay_summary_lines( replay_requests: tuple[ContextRetrievalRequest, ...], recall_items: tuple[RecallEvidence, ...], @@ -740,21 +818,24 @@ def _replay_summary_lines( ) return tuple(lines) -def _replay_packet_trace_reason(replay_requests: tuple[ContextRetrievalRequest, ...]) -> str: + +def _replay_packet_trace_reason( + replay_requests: tuple[ContextRetrievalRequest, ...], +) -> str: parts = [] for request in replay_requests: slot_summary = ", ".join(request.target_slots) or "reasoning" - parts.append( - f"{slot_summary} via {request.replay_mode}/{request.max_compression}" - ) + parts.append(f"{slot_summary} via {request.replay_mode}/{request.max_compression}") return ( f"targeted replay kept {len(replay_requests)} slice(s) with explicit slot budgets: {'; '.join(parts)}; " "stable policy stayed in EpisodeFrozenContext while replay detail remained request-time only" ) + def _tokenize(text: str) -> set[str]: return {token for token in re.findall(r"[A-Za-z0-9_]+", text.lower()) if token} + def _thematic_tokens( session: Episode, work_items: tuple[...], @@ -772,10 +853,18 @@ def _thematic_tokens( tokens.update(_continuity_marker_tokens(session)) return tokens + def _continuity_marker_tokens(session: Episode) -> set[str]: if not session.interruption_state: return set() - return _tokenize(session.interruption_state) | {"resume", "recovery", "continuity", "interruption", "gap"} + return _tokenize(session.interruption_state) | { + "resume", + "recovery", + "continuity", + "interruption", + "gap", + } + def _context_evidence_score( evidence: RecallEvidence, @@ -803,7 +892,9 @@ def _context_evidence_score( reasons.append( f"active elephant work-linked: {', '.join(sorted(work_titles_by_id.get(wid, 'active work') for wid in overlap))}" ) - focus_overlap = set(_state_focus_focus_work_item_ids(work_items, state_focus=state_focus)).intersection(evidence.work_item_ids) + focus_overlap = set(_state_focus_focus_work_item_ids(work_items, state_focus=state_focus)).intersection( + evidence.work_item_ids + ) score += float(len(focus_overlap)) * 6.0 if focus_overlap: reasons.append( @@ -836,16 +927,33 @@ def _context_evidence_score( if "continuity" in tags or "recovery" in tags: score += 1.0 if state_focus is not None: - if state_focus.focus_scope == "personal_model" and evidence.kind in {"summary", "decision", "semantic"}: + if state_focus.focus_scope == "personal_model" and evidence.kind in { + "summary", + "decision", + "semantic", + }: score += 1.5 reasons.append("personal-model recall") - if state_focus.focus_scope == "state" and evidence.kind in {"artifact", "procedural"}: + if state_focus.focus_scope == "state" and evidence.kind in { + "artifact", + "procedural", + }: score += 1.0 reasons.append("elephant-scoped recall") - if state_focus.continuity_signal != "none" and evidence.kind in {"summary", "decision", "semantic", "procedural"}: + if state_focus.continuity_signal != "none" and evidence.kind in { + "summary", + "decision", + "semantic", + "procedural", + }: score += 1.0 reasons.append("elephant focus resume recovery") - if state_focus.context_budget == "narrow" and state_focus.focus_work_item_ids and not focus_overlap and not overlap: + if ( + state_focus.context_budget == "narrow" + and state_focus.focus_work_item_ids + and not focus_overlap + and not overlap + ): score -= 1.5 text_tokens = _tokenize(evidence.content) | _tokenize(" ".join(evidence.tags)) thematic_overlap = tuple(sorted(text_tokens & thematic_tokens)) @@ -869,6 +977,7 @@ def _context_evidence_score( return score, reason_tuple return score + def _retrieval_priority_bucket( evidence: RecallEvidence, *, @@ -885,10 +994,17 @@ def _retrieval_priority_bucket( return 3 if _thematic_tokens(session, work_items, recent_loop_context).intersection(text_tokens): return 2 - if evidence.episode_id == session.episode_id and evidence.kind in {"summary", "decision", "lesson", "semantic", "procedural"}: + if evidence.episode_id == session.episode_id and evidence.kind in { + "summary", + "decision", + "lesson", + "semantic", + "procedural", + }: return 1 return 0 + def _plan_rationale( session: Episode, work_items: tuple[...], @@ -910,14 +1026,10 @@ def _plan_rationale( "pulls targeted reasoning/action evidence without moving stable policy out of EpisodeFrozenContext" ) if state_focus is not None and state_focus.focus_scope == "personal_model": - return ( - "personal-model elephant focus suppresses unrelated work refs so the session snapshot stays centered on durable Personal Model continuity" - ) + return "personal-model elephant focus suppresses unrelated work refs so the session snapshot stays centered on durable Personal Model continuity" if state_focus is not None and state_focus.context_budget == "narrow" and state_focus.focus_work_item_ids: # R1: rationale stays human-readable — the model cannot dereference work_item_ids. - return ( - "elephant focus narrows the session snapshot and compacts retrieval around the active continuity slice" - ) + return "elephant focus narrows the session snapshot and compacts retrieval around the active continuity slice" if session.interruption_state: return ( f"continuity recovery is prioritized because the session resumed from {session.interruption_state}; " diff --git a/packages/context/runtime_types.py b/packages/context/runtime_types.py index 12af419..05dbaa1 100644 --- a/packages/context/runtime_types.py +++ b/packages/context/runtime_types.py @@ -1,15 +1,11 @@ """Context runtime data contracts.""" - from __future__ import annotations from dataclasses import dataclass, field -from datetime import datetime, timezone -import re -from typing import Any, Mapping, Protocol, runtime_checkable +from typing import Mapping -from packages.capabilities.runtime import CapabilityDescriptor, ContextCapability -from packages.contracts.runtime import ContextBundle, RecallEvidence, StructuredTurnSlot +from packages.contracts.runtime import ContextBundle @dataclass(frozen=True, slots=True) @@ -22,6 +18,7 @@ class ContextLayerBudget: omitted: bool = False source_refs: tuple[str, ...] = () + @dataclass(frozen=True, slots=True) class ContextBudgetRequest: layer_name: str @@ -31,6 +28,7 @@ class ContextBudgetRequest: priority: int = 0 source_refs: tuple[str, ...] = () + @dataclass(frozen=True, slots=True) class ContextBudgetPlan: total_tokens: int @@ -48,6 +46,7 @@ def allocation_for(self, layer_name: str) -> ContextLayerBudget | None: return allocation return None + @dataclass(frozen=True, slots=True) class ContextSummaryRequest: layer_name: str @@ -56,6 +55,7 @@ class ContextSummaryRequest: reason: str required: bool = False + @dataclass(frozen=True, slots=True) class ContextRetrievalRequest: request_id: str @@ -71,6 +71,7 @@ class ContextRetrievalRequest: max_compression: str = "episode_summary" replay_mode: str = "off" + @dataclass(frozen=True, slots=True) class ContextLayerSnapshot: layer_name: str @@ -79,12 +80,14 @@ class ContextLayerSnapshot: token_budget: int summary: str | None = None + @dataclass(frozen=True, slots=True) class EpisodeFrozenContext: source_refs: tuple[str, ...] content: tuple[str, ...] token_budget: int + @dataclass(frozen=True, slots=True) class StateSnapshot: source_refs: tuple[str, ...] @@ -95,6 +98,7 @@ class StateSnapshot: token_budget: int summary: str | None = None + @dataclass(frozen=True, slots=True) class EpisodeReplay: source_refs: tuple[str, ...] @@ -103,18 +107,21 @@ class EpisodeReplay: token_budget: int summary: str | None = None + @dataclass(frozen=True, slots=True) class LoopContext: source_refs: tuple[str, ...] content: tuple[str, ...] token_budget: int + @dataclass(frozen=True, slots=True) class RequestAttachments: source_refs: tuple[str, ...] content: tuple[str, ...] token_budget: int + @dataclass(frozen=True, slots=True) class EpisodeFrame: session_id: str @@ -173,6 +180,7 @@ def layers(self) -> tuple[ContextLayerSnapshot, ...]: ) return tuple(layers) + @dataclass(frozen=True, slots=True) class ContextSourceTrace: layer_name: str @@ -185,6 +193,7 @@ def describe(self) -> str: omitted = f" | omitted: {', '.join(self.omitted_refs)}" if self.omitted_refs else "" return f"- {self.layer_name}: {self.reason} | selected: {selected}{omitted}" + @dataclass(frozen=True, slots=True) class ContextAssemblyPlan: session_id: str @@ -198,6 +207,7 @@ class ContextAssemblyPlan: rationale: str = "" source_trace: tuple[ContextSourceTrace, ...] = () + @dataclass(frozen=True, slots=True) class ContextAssemblyResult: bundle: ContextBundle diff --git a/packages/context/session_projection.py b/packages/context/session_projection.py index e094ef6..f60da48 100644 --- a/packages/context/session_projection.py +++ b/packages/context/session_projection.py @@ -8,7 +8,13 @@ from typing import Any from packages.contracts.layers import Episode -from packages.contracts.runtime import ContextBundle, EventEnvelope, ExecutionResult, PromptEnvelope, PromptMessage +from packages.contracts.runtime import ( + ContextBundle, + EventEnvelope, + ExecutionResult, + PromptEnvelope, + PromptMessage, +) from packages.context.projection import ( estimate_projection_tokens, projection_result_with_estimated_tokens, @@ -104,19 +110,27 @@ def prompt_messages_tuple(value: object) -> tuple[PromptMessage, ...]: tool_call_id=str(item.get("tool_call_id") or ""), tool_name=str(item.get("tool_name") or ""), tool_calls=tool_calls, - metadata={str(k): str(v) for k, v in dict(metadata_raw).items()} if isinstance(metadata_raw, Mapping) else {}, + metadata={str(k): str(v) for k, v in dict(metadata_raw).items()} + if isinstance(metadata_raw, Mapping) + else {}, ) ) return tuple(message for message in messages if message.role) -def restore_session_context_epoch(payload: Mapping[str, Any] | None, *, session_id: str | None = None) -> SessionContextEpoch | None: +def restore_session_context_epoch( + payload: Mapping[str, Any] | None, *, session_id: str | None = None +) -> SessionContextEpoch | None: if not isinstance(payload, Mapping): return None if "session_context_epoch" in payload: if session_id is not None: session = payload.get("session") - resolved = str(session.get("episode_id") or session.get("session_id") or "").strip() if isinstance(session, Mapping) else "" + resolved = ( + str(session.get("episode_id") or session.get("session_id") or "").strip() + if isinstance(session, Mapping) + else "" + ) if resolved and resolved != session_id: return None payload = payload.get("session_context_epoch") # type: ignore[assignment] @@ -191,14 +205,19 @@ def next_session_context_epoch( fallback_history_messages: tuple[PromptMessage, ...] = (), now: datetime | None = None, ) -> SessionContextEpoch: - epoch = existing if existing is not None and existing.session_id == session.episode_id else SessionContextEpoch(session_id=session.episode_id) + epoch = ( + existing + if existing is not None and existing.session_id == session.episode_id + else SessionContextEpoch(session_id=session.episode_id) + ) is_user_turn = event is not None and _event_is_user_turn(event) - episode_open_refresh = epoch.frozen and context is not None and event is None and execution is None and not epoch.history_messages + episode_open_refresh = ( + epoch.frozen and context is not None and event is None and execution is None and not epoch.history_messages + ) if context is not None: envelope = context.prompt_envelope refresh_frozen = ( - not epoch.frozen - or episode_open_refresh + not epoch.frozen or episode_open_refresh # No longer detecting prefix changes per-turn and refreshing; the caller refreshes explicitly during compress ) if refresh_frozen: @@ -225,8 +244,7 @@ def next_session_context_epoch( now_value = now or _utc_now() raw_history_messages = tuple(message for message in turn_messages if message.content.strip() or message.tool_calls) history_messages = _annotate_history_messages( - _with_fallback_user_anchor(raw_history_messages, fallback_history_messages) - or fallback_history_messages, + _with_fallback_user_anchor(raw_history_messages, fallback_history_messages) or fallback_history_messages, event=event, now=now_value, ) @@ -397,17 +415,14 @@ def _with_fallback_user_anchor( if any(message.role == "user" and message.content.strip() for message in messages): return messages user_anchor = next( - ( - message - for message in fallback_messages - if message.role == "user" and message.content.strip() - ), + (message for message in fallback_messages if message.role == "user" and message.content.strip()), None, ) if user_anchor is None: return messages return (user_anchor, *messages) + def _event_is_user_turn(event: EventEnvelope) -> bool: event_type = str(event.event_type or "").strip().lower() source = str(event.source or "").strip().lower() @@ -415,7 +430,11 @@ def _event_is_user_turn(event: EventEnvelope) -> bool: return False if source.startswith("cli.startup"): return False - return event_type in {"turn.received", "loop.received", "im.message.receive_v1"} or event_type.endswith(".received") + return event_type in { + "turn.received", + "loop.received", + "im.message.receive_v1", + } or event_type.endswith(".received") def _event_is_im(event: EventEnvelope | None) -> bool: @@ -451,10 +470,7 @@ def _annotate_history_messages( "event_id": event_id, "projection_surface": projection_surface, } - return tuple( - replace(message, metadata={**metadata, **dict(message.metadata or {})}) - for message in messages - ) + return tuple(replace(message, metadata={**metadata, **dict(message.metadata or {})}) for message in messages) def _history_idle_gap_exceeded( diff --git a/packages/continuity/projection.py b/packages/continuity/projection.py index 155b0ae..46a36b7 100644 --- a/packages/continuity/projection.py +++ b/packages/continuity/projection.py @@ -10,7 +10,11 @@ EpisodeContinuityState, ) from packages.state.rendered_views import RenderedRelationshipView -from packages.state import CompanionGovernanceState, LoadedProfile, build_companion_governance_state +from packages.state import ( + CompanionGovernanceState, + LoadedProfile, + build_companion_governance_state, +) from .runtime import ( RelationshipPolicy, build_relationship_policy, @@ -57,9 +61,7 @@ def inspect( ), preserve_preferences=companion.preserve_preferences if companion is not None else True, preserve_corrections=companion.preserve_corrections if companion is not None else True, - preserve_emotional_context=( - companion.preserve_emotional_context if companion is not None else True - ), + preserve_emotional_context=(companion.preserve_emotional_context if companion is not None else True), ) initiative = identity_record.initiative if identity_record is not None else governance.identity.initiative continuity_notes = ( @@ -123,8 +125,7 @@ def _reengagement_prompt( return prompt, style style = "steady-follow-through" prompt = ( - f"Preserve continuity explicitly and keep the next step legible; " - f"continuity cues: {note_text}.{focus_clause}" + f"Preserve continuity explicitly and keep the next step legible; continuity cues: {note_text}.{focus_clause}" ) return prompt, style diff --git a/packages/continuity/runtime.py b/packages/continuity/runtime.py index 2700239..96cd3a2 100644 --- a/packages/continuity/runtime.py +++ b/packages/continuity/runtime.py @@ -53,9 +53,7 @@ def build_episode_continuity_state( chain = lineage or (episode,) lineage_episode_ids = tuple(node.episode_id for node in chain) origin_episode_id = ( - lineage_episode_ids[0] - if lineage_episode_ids - else (episode.parent_episode_id or episode.episode_id) + lineage_episode_ids[0] if lineage_episode_ids else (episode.parent_episode_id or episode.episode_id) ) inherited_interruption_state = normalize_interruption_state(episode.interruption_state) if inherited_interruption_state is None: @@ -112,7 +110,11 @@ def build_relationship_policy( preserve_preferences: bool = True, preserve_corrections: bool = True, preserve_emotional_context: bool = True, - allowed_signal_kinds: tuple[str, ...] = ("relationship", "preference", "continuity"), + allowed_signal_kinds: tuple[str, ...] = ( + "relationship", + "preference", + "continuity", + ), ) -> RelationshipPolicy: return RelationshipPolicy( profile_mode=profile_mode, @@ -140,11 +142,7 @@ def _strip_generated_suffixes(value: str | None) -> str | None: return None parts = [part.strip() for part in text.split(";")] kept = [part for part in parts[:1] if part] - kept.extend( - part - for part in parts[1:] - if part and not part.startswith(_GENERATED_SUFFIX_PREFIXES) - ) + kept.extend(part for part in parts[1:] if part and not part.startswith(_GENERATED_SUFFIX_PREFIXES)) return _compact_interruption_text("; ".join(kept)) @@ -156,17 +154,13 @@ def _normalize_interruption_state(value: str | None) -> _NormalizedInterruptionS return _NormalizedInterruptionState(None, generated_resume=generated_resume) if text.startswith(_RECOVER_INTERRUPTION_PREFIX): generated_resume = True - text = _strip_generated_suffixes( - text.removeprefix(_RECOVER_INTERRUPTION_PREFIX) - ) + text = _strip_generated_suffixes(text.removeprefix(_RECOVER_INTERRUPTION_PREFIX)) continue if text.startswith(_RESUME_SUMMARY_PREFIX): generated_resume = True if _AFTER_INTERRUPTION_MARKER not in text: return _NormalizedInterruptionState(None, generated_resume=True) - text = _strip_generated_suffixes( - text.split(_AFTER_INTERRUPTION_MARKER, 1)[1] - ) + text = _strip_generated_suffixes(text.split(_AFTER_INTERRUPTION_MARKER, 1)[1]) continue return _NormalizedInterruptionState(text, generated_resume=generated_resume) return _NormalizedInterruptionState(text, generated_resume=True) @@ -187,12 +181,8 @@ def _continuity_summary( summary = f"recover after interruption: {inherited_interruption_state}" else: summary = ( - f"resume durable work from episode {origin_episode_id} " - f"after interruption: {inherited_interruption_state}" + f"resume durable work from episode {origin_episode_id} after interruption: {inherited_interruption_state}" ) if episode.parent_episode_id and episode.parent_episode_id != origin_episode_id: summary += f"; immediate parent={episode.parent_episode_id}" return summary - - - diff --git a/packages/contracts/layers.py b/packages/contracts/layers.py index 61e04ee..d21e709 100644 --- a/packages/contracts/layers.py +++ b/packages/contracts/layers.py @@ -103,9 +103,7 @@ def __post_init__(self) -> None: _ensure_non_empty_text(self.entry_surface, name="episode entry surface") _ensure_non_empty_text(self.status, name="episode status") if self.status not in _EPISODE_STATUSES: - raise ValueError( - f"episode status must be one of {sorted(_EPISODE_STATUSES)}: {self.status}" - ) + raise ValueError(f"episode status must be one of {sorted(_EPISODE_STATUSES)}: {self.status}") @dataclass(frozen=True, slots=True) diff --git a/packages/contracts/personal_model.py b/packages/contracts/personal_model.py index 648a26e..12244e7 100644 --- a/packages/contracts/personal_model.py +++ b/packages/contracts/personal_model.py @@ -22,9 +22,7 @@ ALLOWED_LENSES = frozenset({"identity", "world", "pulse", "journey"}) ALLOWED_FACT_STATUSES = frozenset({"active", "retired", "disputed", "deleted"}) ALLOWED_FACT_SOURCES = frozenset({"user_explicit", "pm_agent_promote"}) -ALLOWED_QUESTION_STATUSES = frozenset( - {"open", "asked", "answered", "dismissed", "stale"} -) +ALLOWED_QUESTION_STATUSES = frozenset({"open", "asked", "answered", "dismissed", "stale"}) ALLOWED_QUESTION_SOURCES = frozenset({"coverage_gap", "ambiguity", "contextual"}) ALLOWED_SENSITIVITIES = frozenset({"low", "medium", "high"}) ALLOWED_LEARNING_INTENSITIES = frozenset({"low", "medium", "high"}) @@ -46,9 +44,7 @@ def _ensure_lens(value: str | None, *, name: str, allow_none: bool = False) -> N return raise ValueError(f"{name} lens must be provided") if value not in ALLOWED_LENSES: - raise ValueError( - f"{name} lens must be one of {sorted(ALLOWED_LENSES)}: {value}" - ) + raise ValueError(f"{name} lens must be one of {sorted(ALLOWED_LENSES)}: {value}") @dataclass(frozen=True, slots=True) @@ -76,13 +72,9 @@ def __post_init__(self) -> None: _ensure_lens(self.lens, name="fact") _ensure_confidence(self.confidence, name="fact") if self.source not in ALLOWED_FACT_SOURCES: - raise ValueError( - f"fact source must be one of {sorted(ALLOWED_FACT_SOURCES)}: {self.source}" - ) + raise ValueError(f"fact source must be one of {sorted(ALLOWED_FACT_SOURCES)}: {self.source}") if self.status not in ALLOWED_FACT_STATUSES: - raise ValueError( - f"fact status must be one of {sorted(ALLOWED_FACT_STATUSES)}: {self.status}" - ) + raise ValueError(f"fact status must be one of {sorted(ALLOWED_FACT_STATUSES)}: {self.status}") @dataclass(frozen=True, slots=True) @@ -115,9 +107,7 @@ class OpenQuestion: def __post_init__(self) -> None: _ensure_non_empty_text(self.question_id, name="open question id") - _ensure_non_empty_text( - self.personal_model_id, name="open question personal model id" - ) + _ensure_non_empty_text(self.personal_model_id, name="open question personal model id") _ensure_lens(self.lens, name="open question") _ensure_non_empty_text(self.sub_lens, name="open question sub_lens") _ensure_non_empty_text(self.text, name="open question text") @@ -126,19 +116,12 @@ def __post_init__(self) -> None: raise ValueError("open question priority must stay between 0.0 and 1.0") if self.sensitivity not in ALLOWED_SENSITIVITIES: raise ValueError( - f"open question sensitivity must be one of {sorted(ALLOWED_SENSITIVITIES)}: " - f"{self.sensitivity}" + f"open question sensitivity must be one of {sorted(ALLOWED_SENSITIVITIES)}: {self.sensitivity}" ) if self.source not in ALLOWED_QUESTION_SOURCES: - raise ValueError( - f"open question source must be one of {sorted(ALLOWED_QUESTION_SOURCES)}: " - f"{self.source}" - ) + raise ValueError(f"open question source must be one of {sorted(ALLOWED_QUESTION_SOURCES)}: {self.source}") if self.status not in ALLOWED_QUESTION_STATUSES: - raise ValueError( - f"open question status must be one of {sorted(ALLOWED_QUESTION_STATUSES)}: " - f"{self.status}" - ) + raise ValueError(f"open question status must be one of {sorted(ALLOWED_QUESTION_STATUSES)}: {self.status}") if self.asked_count < 0: raise ValueError("open question asked_count must be >= 0") diff --git a/packages/contracts/runtime.py b/packages/contracts/runtime.py index 3ccfb64..e8f4eb2 100644 --- a/packages/contracts/runtime.py +++ b/packages/contracts/runtime.py @@ -7,7 +7,7 @@ from __future__ import annotations -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field from datetime import datetime from typing import Mapping @@ -24,7 +24,9 @@ def _ensure_non_empty_text(value: str, *, name: str) -> None: _ALLOWED_STATE_FOCUS_MODES = frozenset({"embedded", "skip"}) _ALLOWED_INDEX_REFRESH_SCOPES = frozenset({"noop", "full"}) -_ALLOWED_STATE_FOCUS_FAMILIES = frozenset({"execution", "exploration", "creation", "reference", "personal_model", "resume"}) +_ALLOWED_STATE_FOCUS_FAMILIES = frozenset( + {"execution", "exploration", "creation", "reference", "personal_model", "resume"} +) _ALLOWED_STATE_FOCUS_CANDIDATE_KINDS = frozenset({"work_item"}) _ALLOWED_CONTINUITY_SIGNALS = frozenset({"none", "continue", "resume", "interrupted", "inherit"}) _ALLOWED_FOCUS_SCOPES = frozenset({"episode", "lineage", "state", "personal_model"}) @@ -252,7 +254,9 @@ class StateFocusDecision: def __post_init__(self) -> None: if self.focus_family not in _ALLOWED_STATE_FOCUS_FAMILIES: - raise ValueError(f"focus_family must be one of {sorted(_ALLOWED_STATE_FOCUS_FAMILIES)}: {self.focus_family}") + raise ValueError( + f"focus_family must be one of {sorted(_ALLOWED_STATE_FOCUS_FAMILIES)}: {self.focus_family}" + ) if not 0.0 <= self.confidence <= 1.0: raise ValueError("state focus confidence must stay between 0.0 and 1.0") if self.continuity_signal not in _ALLOWED_CONTINUITY_SIGNALS: @@ -260,13 +264,9 @@ def __post_init__(self) -> None: f"continuity_signal must be one of {sorted(_ALLOWED_CONTINUITY_SIGNALS)}: {self.continuity_signal}" ) if self.focus_scope not in _ALLOWED_FOCUS_SCOPES: - raise ValueError( - f"focus_scope must be one of {sorted(_ALLOWED_FOCUS_SCOPES)}: {self.focus_scope}" - ) + raise ValueError(f"focus_scope must be one of {sorted(_ALLOWED_FOCUS_SCOPES)}: {self.focus_scope}") if self.context_budget not in _ALLOWED_BUDGET_CLASSES: - raise ValueError( - f"context_budget must be one of {sorted(_ALLOWED_BUDGET_CLASSES)}: {self.context_budget}" - ) + raise ValueError(f"context_budget must be one of {sorted(_ALLOWED_BUDGET_CLASSES)}: {self.context_budget}") if self.degradation_mode not in _ALLOWED_STATE_FOCUS_DEGRADATION_MODES: raise ValueError( "degradation_mode must be one of " @@ -397,7 +397,10 @@ class RuntimeEvidenceBundle: artifacts: tuple[RuntimeArtifact, ...] = () def __post_init__(self) -> None: - _ensure_unique_ids(tuple(item.evidence_id for item in self.recall_items), name="recall evidence") + _ensure_unique_ids( + tuple(item.evidence_id for item in self.recall_items), + name="recall evidence", + ) _ensure_unique_ids(tuple(artifact.artifact_id for artifact in self.artifacts), name="artifact") for item in self.recall_items: if item.episode_id != self.episode_id: @@ -566,7 +569,10 @@ class ProcedureLibrary: procedures: tuple[ProcedureRecord, ...] = () def __post_init__(self) -> None: - _ensure_unique_ids(tuple(procedure.procedure_id for procedure in self.procedures), name="procedure") + _ensure_unique_ids( + tuple(procedure.procedure_id for procedure in self.procedures), + name="procedure", + ) @dataclass(frozen=True, slots=True) @@ -593,9 +599,7 @@ class PromptEnvelope: def system_prompt(self) -> str: return "\n\n".join( - section - for section in (self.frozen_prefix.strip(), self.session_snapshot.strip()) - if section + section for section in (self.frozen_prefix.strip(), self.session_snapshot.strip()) if section ) def user_prelude(self) -> str: diff --git a/packages/cron/runtime.py b/packages/cron/runtime.py index 20280d0..2ba9fc0 100644 --- a/packages/cron/runtime.py +++ b/packages/cron/runtime.py @@ -124,7 +124,12 @@ def list_jobs( if not include_inactive and job.status != "scheduled": continue filtered.append(job) - return tuple(sorted(filtered, key=lambda item: ((item.next_run_at or item.updated_at), item.name))) + return tuple( + sorted( + filtered, + key=lambda item: ((item.next_run_at or item.updated_at), item.name), + ) + ) def inspect_job(self, job_id: str) -> CronJob: for job in self._load_jobs(): @@ -186,7 +191,12 @@ def _resume(job: CronJob) -> CronJob: if job.status == "completed": raise ValueError("completed jobs cannot be resumed") next_run_at = job.next_run_at or _parse_schedule(job.schedule_text, self._clock())["next_run_at"] - return replace(job, status="scheduled", next_run_at=next_run_at, updated_at=self._clock()) + return replace( + job, + status="scheduled", + next_run_at=next_run_at, + updated_at=self._clock(), + ) return self._update_job(job_id, _resume) @@ -218,7 +228,9 @@ def due_jobs( due.append(job) return tuple(due) - def record_execution(self, job_id: str, *, outcome: str, summary: str, now: datetime | None = None) -> CronJobExecution: + def record_execution( + self, job_id: str, *, outcome: str, summary: str, now: datetime | None = None + ) -> CronJobExecution: recorded_at = now or self._clock() def _advance(job: CronJob) -> CronJob: @@ -447,14 +459,10 @@ def _job_from_payload(payload: Mapping[str, Any]) -> CronJob: created_at=datetime.fromisoformat(str(payload["created_at"])), updated_at=datetime.fromisoformat(str(payload["updated_at"])), next_run_at=( - datetime.fromisoformat(str(payload["next_run_at"])) - if payload.get("next_run_at") is not None - else None + datetime.fromisoformat(str(payload["next_run_at"])) if payload.get("next_run_at") is not None else None ), last_run_at=( - datetime.fromisoformat(str(payload["last_run_at"])) - if payload.get("last_run_at") is not None - else None + datetime.fromisoformat(str(payload["last_run_at"])) if payload.get("last_run_at") is not None else None ), run_count=int(payload.get("run_count", 0)), interval_seconds=int(payload["interval_seconds"]) if payload.get("interval_seconds") is not None else None, diff --git a/packages/curiosity/open_question_generator.py b/packages/curiosity/open_question_generator.py index 47cab21..33d7863 100644 --- a/packages/curiosity/open_question_generator.py +++ b/packages/curiosity/open_question_generator.py @@ -58,7 +58,9 @@ def generate_contextual_questions( status="open", metadata={ "seed_text": text, - "question_intent": str(seed.get("intent") or seed.get("rationale") or "follow up while context is steady").strip(), + "question_intent": str( + seed.get("intent") or seed.get("rationale") or "follow up while context is steady" + ).strip(), }, ) ) diff --git a/packages/curiosity/proactive_ask_policy.py b/packages/curiosity/proactive_ask_policy.py index f3ac319..344a246 100644 --- a/packages/curiosity/proactive_ask_policy.py +++ b/packages/curiosity/proactive_ask_policy.py @@ -9,7 +9,6 @@ from __future__ import annotations from dataclasses import dataclass -from datetime import datetime, timezone from typing import Sequence from packages.contracts import OpenQuestion diff --git a/packages/curiosity/question_renderer.py b/packages/curiosity/question_renderer.py index 396eeed..f8267fc 100644 --- a/packages/curiosity/question_renderer.py +++ b/packages/curiosity/question_renderer.py @@ -114,8 +114,7 @@ def render_session_hint( if lang == "zh": return ( "这里有一个可以顺带理解用户的开放问题。只在对话自然走到这里时,用自己的话轻轻问;" - "不要照抄模板,也不要像问卷。若用户回答,用 tool.personal_model.update 写成一个四 lens claim:\n" - + bullets + "不要照抄模板,也不要像问卷。若用户回答,用 tool.personal_model.update 写成一个四 lens claim:\n" + bullets ) return ( "One open question may help you understand the user. Ask it only when the conversation naturally opens a door; " @@ -205,4 +204,9 @@ def _ensure_question_mark(text: str, language: str) -> str: return stripped + ("?" if language == "zh" else "?") -__all__ = ["contextualize_question", "render_opener", "render_idle_push", "render_session_hint"] +__all__ = [ + "contextualize_question", + "render_opener", + "render_idle_push", + "render_session_hint", +] diff --git a/packages/curiosity/question_tool_surface.py b/packages/curiosity/question_tool_surface.py index cef9ba2..01501fb 100644 --- a/packages/curiosity/question_tool_surface.py +++ b/packages/curiosity/question_tool_surface.py @@ -49,11 +49,21 @@ def manage_questions( normalized = action.strip().lower() pm_id = self._personal_model_id(session_id, personal_model_id) if normalized in {"list", "ls"}: - return {"action": "list", "questions": self._list(pm_id, status=status, lens=lens, sub_lens=sub_lens, limit=limit)} + return { + "action": "list", + "questions": self._list(pm_id, status=status, lens=lens, sub_lens=sub_lens, limit=limit), + } if normalized in {"inspect", "view"}: - return {"action": "inspect", "question": self._question_payload(self._load(pm_id, question_id))} + return { + "action": "inspect", + "question": self._question_payload(self._load(pm_id, question_id)), + } if normalized in {"bank", "templates"}: - return {"action": "bank", "templates": [], "note": "Static question bank removed. Questions are now created by the learning agent."} + return { + "action": "bank", + "templates": [], + "note": "Static question bank removed. Questions are now created by the learning agent.", + } if normalized == "create": question = self._create( pm_id, @@ -109,7 +119,11 @@ def manage_questions( return {"action": "stale", "question": self._question_payload(question)} if normalized in {"delete", "remove"}: deleted = self._delete(pm_id, question_id) - return {"action": "delete", "question_id": deleted.question_id, "status": "deleted"} + return { + "action": "delete", + "question_id": deleted.question_id, + "status": "deleted", + } raise ValueError(f"tool.personal_model.questions unsupported action: {action!r}") def _personal_model_id(self, session_id: str, explicit: str) -> str: @@ -124,9 +138,19 @@ def _personal_model_id(self, session_id: str, explicit: str) -> str: ensure(personal_model_id=pm_id) return pm_id - def _list(self, personal_model_id: str, *, status: str, lens: str, sub_lens: str, limit: int) -> list[dict[str, Any]]: + def _list( + self, + personal_model_id: str, + *, + status: str, + lens: str, + sub_lens: str, + limit: int, + ) -> list[dict[str, Any]]: statuses: str | tuple[str, ...] - statuses = tuple(item.strip() for item in status.replace("|", ",").split(",") if item.strip()) if status else "open" + statuses = ( + tuple(item.strip() for item in status.replace("|", ",").split(",") if item.strip()) if status else "open" + ) questions = self.repository.list_open_questions( personal_model_id=personal_model_id, status=statuses, @@ -150,12 +174,12 @@ def _load(self, personal_model_id: str, question_id: str) -> OpenQuestion: return question if "/" in resolved_id: lens, sub_lens = (part.strip() for part in resolved_id.split("/", 1)) - matches = [ - question for question in questions - if question.lens == lens and question.sub_lens == sub_lens - ] + matches = [question for question in questions if question.lens == lens and question.sub_lens == sub_lens] if matches: - return sorted(matches, key=lambda q: (q.status != "open", -q.priority, q.created_at))[0] + return sorted( + matches, + key=lambda q: (q.status != "open", -q.priority, q.created_at), + )[0] raise KeyError(resolved_id) def _create( @@ -190,7 +214,10 @@ def _create( source=resolved_source, created_at=datetime.now(timezone.utc), status="open", - metadata={"managed_by": "tool.personal_model.questions", **dict(metadata or {})}, + metadata={ + "managed_by": "tool.personal_model.questions", + **dict(metadata or {}), + }, ) self.repository.upsert_open_question(question) return question @@ -215,7 +242,9 @@ def _update(self, personal_model_id: str, *, question_id: str, **updates: Any) - if updates.get("priority") is not None: values["priority"] = _priority(updates.get("priority"), default=current.priority) if str(updates.get("sensitivity") or "").strip(): - values["sensitivity"] = _normalized_choice(str(updates["sensitivity"]), ALLOWED_SENSITIVITIES, field="sensitivity") + values["sensitivity"] = _normalized_choice( + str(updates["sensitivity"]), ALLOWED_SENSITIVITIES, field="sensitivity" + ) if str(updates.get("source") or "").strip(): values["source"] = _normalized_choice(str(updates["source"]), ALLOWED_QUESTION_SOURCES, field="source") updated = replace(current, **values) @@ -251,7 +280,10 @@ def _delete(self, personal_model_id: str, question_id: str) -> OpenQuestion: delete(question_id=current.question_id) return current with self.repository.connection() as connection: - connection.execute("DELETE FROM personal_model_open_questions WHERE question_id = ?", (current.question_id,)) + connection.execute( + "DELETE FROM personal_model_open_questions WHERE question_id = ?", + (current.question_id,), + ) connection.commit() return current @@ -276,7 +308,13 @@ def _question_payload(question: OpenQuestion) -> dict[str, Any]: } -def _normalized_choice(value: str, allowed: set[str] | frozenset[str], *, default: str | None = None, field: str) -> str: +def _normalized_choice( + value: str, + allowed: set[str] | frozenset[str], + *, + default: str | None = None, + field: str, +) -> str: normalized = str(value or default or "").strip().lower() if normalized not in allowed: raise ValueError(f"{field} must be one of {sorted(allowed)}: {value!r}") diff --git a/packages/embeddings/runtime.py b/packages/embeddings/runtime.py index 5d4b224..ed45b9b 100644 --- a/packages/embeddings/runtime.py +++ b/packages/embeddings/runtime.py @@ -18,7 +18,9 @@ ELEPHANT_EMBED_SOURCE_URL = "https://huggingface.co/llm-semantic-router/elephant-embeddings-v1-text-small" ELEPHANT_EMBED_MODEL_ROOT = str(Path.home() / ".elephant" / "models" / "elephant-embeddings-v1-text-small") ELEPHANT_EMBED_MODELSCOPE_ID = "agentic-intelligence-lab/elephant-embeddings-v1-text-small" -ELEPHANT_EMBED_MODELSCOPE_URL = "https://modelscope.cn/models/agentic-intelligence-lab/elephant-embeddings-v1-text-small" +ELEPHANT_EMBED_MODELSCOPE_URL = ( + "https://modelscope.cn/models/agentic-intelligence-lab/elephant-embeddings-v1-text-small" +) ELEPHANT_EMBED_ONLINE_DIMENSIONS = (64, 256, 768) _ALLOWED_EMBEDDING_STATUSES = frozenset({"pending", "downloading", "ready", "skipped", "failed"}) _ALLOWED_PRELOAD_STATUSES = frozenset({"idle", "steadying", "ready", "failed", "skipped"}) @@ -47,6 +49,7 @@ def _suppress_local_embedding_load_warnings(model_root: str | None = None) -> No global _LOCAL_EMBEDDING_LOG_FILTER_INSTALLED _LOCAL_EMBEDDING_WARNING_MODEL_ROOTS.add(str(embedding_model_root_path(model_root))) if not _LOCAL_EMBEDDING_LOG_FILTER_INSTALLED: + class _LocalEmbeddingLoadFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: message = record.getMessage() @@ -84,9 +87,7 @@ def sentence_transformers_dependencies_ready() -> bool: def resolve_embedding_dimensions(latency_mode: str = "balanced", *, dimensions: int | None = None) -> int: if dimensions is not None: if dimensions not in ELEPHANT_EMBED_ONLINE_DIMENSIONS: - raise ValueError( - f"embedding dimensions must be one of {ELEPHANT_EMBED_ONLINE_DIMENSIONS}: {dimensions}" - ) + raise ValueError(f"embedding dimensions must be one of {ELEPHANT_EMBED_ONLINE_DIMENSIONS}: {dimensions}") return dimensions normalized = latency_mode.strip().lower() if normalized in {"fast", "low-latency", "64d"}: @@ -253,11 +254,19 @@ def health(self) -> EmbeddingHealth: ... def preload_state(self) -> EmbeddingPreloadState: ... def preload( - self, *, target: str, entries: tuple[EmbeddingPreloadEntry, ...], latency_mode: str = "balanced" + self, + *, + target: str, + entries: tuple[EmbeddingPreloadEntry, ...], + latency_mode: str = "balanced", ) -> EmbeddingPreloadState: ... def queue_backfill( - self, *, target: str, entries: tuple[EmbeddingPreloadEntry, ...], latency_mode: str = "balanced" + self, + *, + target: str, + entries: tuple[EmbeddingPreloadEntry, ...], + latency_mode: str = "balanced", ) -> EmbeddingPreloadState: ... def cached_vector(self, *, target: str, cache_key: str, dimensions: int) -> EmbeddingVector | None: ... @@ -501,9 +510,7 @@ def _remember_vectors( def _prune_expired_failures_locked(self) -> None: current = _utc_now() expired = [ - worker_key - for worker_key, (_message, retry_at) in self._failure_by_target.items() - if retry_at <= current + worker_key for worker_key, (_message, retry_at) in self._failure_by_target.items() if retry_at <= current ] for worker_key in expired: self._failure_by_target.pop(worker_key, None) @@ -523,12 +530,18 @@ def _has_cached_vector_locked(self, *, target: str, cache_key: str, dimensions: def _is_pending_locked(self, *, target: str, cache_key: str, dimensions: int) -> bool: normalized_target = _normalize_target(target) resolved_cache_key = str(cache_key).strip() - for (pending_target, pending_dimensions), bucket in self._queued_entries.items(): + for ( + pending_target, + pending_dimensions, + ), bucket in self._queued_entries.items(): if pending_target != normalized_target or pending_dimensions < dimensions: continue if resolved_cache_key in bucket: return True - for (pending_target, pending_dimensions), inflight in self._inflight_entries.items(): + for ( + pending_target, + pending_dimensions, + ), inflight in self._inflight_entries.items(): if pending_target != normalized_target or pending_dimensions < dimensions: continue if resolved_cache_key in inflight: @@ -608,7 +621,11 @@ def embed(self, request: EmbeddingRequest) -> EmbeddingBatch: source_text=text, ) for index, (text, values) in enumerate( - zip(request.texts, self._encode_texts(request.texts, dimensions=dimensions), strict=True) + zip( + request.texts, + self._encode_texts(request.texts, dimensions=dimensions), + strict=True, + ) ) ) return EmbeddingBatch( @@ -630,10 +647,7 @@ def health(self) -> EmbeddingHealth: runtime_state = self._runtime_state() with self._cache_lock: self._prune_expired_failures_locked() - failures = tuple( - f"{target}:{dimensions}d" - for (target, dimensions) in sorted(self._failure_by_target) - ) + failures = tuple(f"{target}:{dimensions}d" for (target, dimensions) in sorted(self._failure_by_target)) ready_dimensions = tuple(sorted(self._ready_targets_by_dimension)) if status == "ready": if runtime_state == "steadying": @@ -665,16 +679,11 @@ def preload_state(self) -> EmbeddingPreloadState: with self._cache_lock: self._prune_expired_failures_locked() active_workers = {key for key, worker in self._workers.items() if worker.is_alive()} - queued = { - key: value - for key, value in self._queued_entries.items() - if value - } - pending_targets = { - target - for target, _dimensions in (*active_workers, *queued.keys()) - } - ready_targets = set().union(*self._ready_targets_by_dimension.values()) if self._ready_targets_by_dimension else set() + queued = {key: value for key, value in self._queued_entries.items() if value} + pending_targets = {target for target, _dimensions in (*active_workers, *queued.keys())} + ready_targets = ( + set().union(*self._ready_targets_by_dimension.values()) if self._ready_targets_by_dimension else set() + ) pending_targets.update(target for target in _DEFAULT_PRELOAD_TARGETS if target not in ready_targets) ready_dimensions = tuple(sorted(self._ready_targets_by_dimension)) failures = tuple(sorted(self._failure_by_target.items())) diff --git a/packages/evidence/__init__.py b/packages/evidence/__init__.py index 82e0488..a736217 100644 --- a/packages/evidence/__init__.py +++ b/packages/evidence/__init__.py @@ -25,9 +25,9 @@ DefaultEvidenceRetriever, build_embedding_index_policy, build_embedding_index_rebuild_plan, - build_resume_packet, parse_step_replay_record, ) +from .state_focus_support import build_resume_packet from .semantic_index_factory import ( SemanticIndexBundle, build_semantic_index_bundle, diff --git a/packages/evidence/episode_summary_indexer.py b/packages/evidence/episode_summary_indexer.py index 7f40529..ff3ea41 100644 --- a/packages/evidence/episode_summary_indexer.py +++ b/packages/evidence/episode_summary_indexer.py @@ -131,7 +131,10 @@ def build_step_recall_text(step: Step) -> str: elif normalized_action == "emit_response": parts = [str(metadata.get("final_response") or metadata.get("assistant_response") or step.summary).strip()] elif normalized_action == "reply": - parts = [str(step.summary or "").strip(), str(metadata.get("final_response") or metadata.get("assistant_response") or "").strip()] + parts = [ + str(step.summary or "").strip(), + str(metadata.get("final_response") or metadata.get("assistant_response") or "").strip(), + ] else: parts = [ str(step.summary or "").strip(), diff --git a/packages/evidence/locator_match.py b/packages/evidence/locator_match.py index df201e2..65ab1a1 100644 --- a/packages/evidence/locator_match.py +++ b/packages/evidence/locator_match.py @@ -111,11 +111,7 @@ def find_entry_by_locator( return None # Tier 1: exact normalised equality. - exact = [ - entry - for entry in materialised - if normalize_locator(_entry_content(entry)) == needle - ] + exact = [entry for entry in materialised if normalize_locator(_entry_content(entry)) == needle] if len(exact) == 1: return exact[0] if exact: @@ -124,11 +120,7 @@ def find_entry_by_locator( return _pick_most_recent(exact) # Tier 2 + 3: substring. - substring = [ - entry - for entry in materialised - if needle in normalize_locator(_entry_content(entry)) - ] + substring = [entry for entry in materialised if needle in normalize_locator(_entry_content(entry))] if len(substring) == 1: return substring[0] if substring: @@ -233,4 +225,4 @@ def _cosine(a: tuple[float, ...], b: tuple[float, ...]) -> float | None: nb += y * y if na <= 0.0 or nb <= 0.0: return None - return dot / ((na ** 0.5) * (nb ** 0.5)) + return dot / ((na**0.5) * (nb**0.5)) diff --git a/packages/evidence/recall_lifecycle.py b/packages/evidence/recall_lifecycle.py index 9a65322..41ef5d3 100644 --- a/packages/evidence/recall_lifecycle.py +++ b/packages/evidence/recall_lifecycle.py @@ -35,13 +35,15 @@ "review": "review", "episode": "episode", } -_TIME_METADATA_KEYS = frozenset({ - "effective_at", - "expires_at", - "last_verified_at", - "verified_at", - "review_after_days", -}) +_TIME_METADATA_KEYS = frozenset( + { + "effective_at", + "expires_at", + "last_verified_at", + "verified_at", + "review_after_days", + } +) @dataclass(frozen=True, slots=True) @@ -70,9 +72,7 @@ def _explicit_policy(metadata: Mapping[str, object]) -> str: def _explicit_lifecycle(metadata: Mapping[str, object]) -> str: lifecycle = _clean( - metadata.get("retention_lifecycle") - or metadata.get("lifecycle") - or metadata.get("staleness_policy") + metadata.get("retention_lifecycle") or metadata.get("lifecycle") or metadata.get("staleness_policy") ) return lifecycle if lifecycle in _ALLOWED_LIFECYCLES else "" diff --git a/packages/evidence/recall_rerank.py b/packages/evidence/recall_rerank.py index e1373d8..53f2116 100644 --- a/packages/evidence/recall_rerank.py +++ b/packages/evidence/recall_rerank.py @@ -84,7 +84,11 @@ def rerank_recall_hits( ) -> tuple[RecallRankedHit, ...]: ranked = sorted( (score_recall_hit(hit, plan=plan, now=now) for hit in hits), - key=lambda item: (-item.final_score, item.hit.title.casefold(), item.hit.content.casefold()), + key=lambda item: ( + -item.final_score, + item.hit.title.casefold(), + item.hit.content.casefold(), + ), ) if limit is not None: return tuple(ranked[: max(0, int(limit))]) diff --git a/packages/evidence/recall_runtime.py b/packages/evidence/recall_runtime.py index 36d6d97..2d2429f 100644 --- a/packages/evidence/recall_runtime.py +++ b/packages/evidence/recall_runtime.py @@ -14,7 +14,11 @@ RecallEvidence, ) -from .runtime import DefaultEvidenceRetriever, build_embedding_index_policy, build_resume_packet +from .runtime import ( + DefaultEvidenceRetriever, + build_embedding_index_policy, +) +from .state_focus_support import build_resume_packet from .unified_recall import UnifiedRecallRequest, unified_recall diff --git a/packages/evidence/recall_support.py b/packages/evidence/recall_support.py index 3998d6b..31addc5 100644 --- a/packages/evidence/recall_support.py +++ b/packages/evidence/recall_support.py @@ -133,9 +133,7 @@ def _score_candidate( if query_ngrams: body_ngrams = _char_ngrams(_alnum_compact(candidate.body)) if body_ngrams: - jaccard = len(query_ngrams & body_ngrams) / float( - len(query_ngrams | body_ngrams) - ) + jaccard = len(query_ngrams & body_ngrams) / float(len(query_ngrams | body_ngrams)) signal += 0.25 * jaccard if signal <= 0.0: return 0.0 diff --git a/packages/evidence/recall_time_range.py b/packages/evidence/recall_time_range.py index 9722932..0081841 100644 --- a/packages/evidence/recall_time_range.py +++ b/packages/evidence/recall_time_range.py @@ -8,7 +8,15 @@ import re from zoneinfo import ZoneInfo, ZoneInfoNotFoundError -_FUZZY_TIME_WINDOW_LABELS = frozenset({"last_night", "yesterday_evening", "this_morning", "today_afternoon", "today_evening"}) +_FUZZY_TIME_WINDOW_LABELS = frozenset( + { + "last_night", + "yesterday_evening", + "this_morning", + "today_afternoon", + "today_evening", + } +) @dataclass(frozen=True, slots=True) @@ -89,12 +97,18 @@ def _with_local_time(day: datetime, tz: timezone | ZoneInfo, *, hour: int, minut return day.astimezone(tz).replace(hour=hour, minute=minute, second=0, microsecond=0) -def _parse_relative_expr(expr: str, *, now: datetime, tz: timezone | ZoneInfo) -> tuple[datetime | None, datetime | None]: +def _parse_relative_expr( + expr: str, *, now: datetime, tz: timezone | ZoneInfo +) -> tuple[datetime | None, datetime | None]: match = re.fullmatch(r"last:(\d+)([hdw])", expr) if match: amount = max(1, int(match.group(1))) unit = match.group(2) - delta = {"h": timedelta(hours=amount), "d": timedelta(days=amount), "w": timedelta(weeks=amount)}[unit] + delta = { + "h": timedelta(hours=amount), + "d": timedelta(days=amount), + "w": timedelta(weeks=amount), + }[unit] return now - delta, now if expr in {"today", "this:day"}: return _local_day_bounds(now, tz) @@ -116,7 +130,9 @@ def _parse_relative_expr(expr: str, *, now: datetime, tz: timezone | ZoneInfo) - return None, None -def _parse_human_window_expr(expr: str, *, now: datetime, tz: timezone | ZoneInfo) -> tuple[datetime | None, datetime | None]: +def _parse_human_window_expr( + expr: str, *, now: datetime, tz: timezone | ZoneInfo +) -> tuple[datetime | None, datetime | None]: local_now = now.astimezone(tz) today = local_now yesterday = local_now - timedelta(days=1) diff --git a/packages/evidence/runtime.py b/packages/evidence/runtime.py index 1f29ede..b42673f 100644 --- a/packages/evidence/runtime.py +++ b/packages/evidence/runtime.py @@ -34,7 +34,11 @@ ) from packages.semantic_index import HybridSemanticSearcher, SemanticSearchQuery from packages.storage import RuntimeStorageRepository -from .state_focus_support import build_resume_packet, focus_work_item_ids, state_focus_scope_hints, state_focus_score_adjustments +from .state_focus_support import ( + focus_work_item_ids, + state_focus_scope_hints, + state_focus_score_adjustments, +) if TYPE_CHECKING: from .recall_runtime import StepEvidenceStore @@ -74,6 +78,8 @@ def evidence_query_cache_key(query: str, *, latency_mode: str = "fast") -> str: dims = resolve_embedding_dimensions(latency_mode) digest = hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:16] return f"evidence-query:{dims}d:{digest}" + + _SEMANTIC_RECALL_SCORE_SCALE = 100.0 _SEMANTIC_MEMORY_ENTRY_INACTIVE_STATES = frozenset( {"deleted", "superseded", "retired", "inactive", "archived", "rejected"} @@ -125,6 +131,8 @@ class _ResolvedScope: lineage_episode_ids: tuple[str, ...] elephant_episode_ids: tuple[str, ...] personal_model_episode_ids: tuple[str, ...] + + _REPLAY_SLOT_NAMES = ("observation", "reasoning", "action", "outcome") _REPLAY_SLOT_LABELS = { "observation": "observation", @@ -140,6 +148,8 @@ class _ResolvedScope: "raw_turn": 4, "raw_trace": 5, } + + def _tuple_from_metadata(value: object) -> tuple[str, ...]: if isinstance(value, (list, tuple)): return tuple(str(item) for item in value if str(item)) @@ -147,27 +157,35 @@ def _tuple_from_metadata(value: object) -> tuple[str, ...]: return () cleaned = str(value).strip() return (cleaned,) if cleaned else () + + def _record_search_text(record: RecallEvidence, *, structured_text: str = "") -> str: return "\n".join(part for part in (record.content, structured_text) if part) + + def _embedding_text(value: str, *, max_chars: int = _EVIDENCE_EMBED_TEXT_LIMIT) -> str: normalized = re.sub(r"\s+", " ", value).strip() if len(normalized) <= max_chars: return normalized return normalized[:max_chars].rstrip() + + def _record_embedding_text(record: RecallEvidence, *, structured_text: str | None = None) -> str: if structured_text is None: structured_turn = _structured_turn_from_record(record) structured_text = ( - _replay_text(structured_turn, selected_slots=_REPLAY_SLOT_NAMES) - if structured_turn is not None - else "" + _replay_text(structured_turn, selected_slots=_REPLAY_SLOT_NAMES) if structured_turn is not None else "" ) search_text = _record_search_text(record, structured_text=structured_text) or record.content return _embedding_text(search_text) + + def _evidence_cache_key(record: RecallEvidence, *, search_text: str) -> str: created_at = record.created_at.isoformat() if record.created_at is not None else "" digest = hashlib.sha256(search_text.encode("utf-8")).hexdigest()[:16] return f"{record.evidence_id}:{created_at}:{digest}" + + def _evidence_preload_entry(record: RecallEvidence, *, structured_text: str = "") -> EmbeddingPreloadEntry: search_text = _record_embedding_text(record, structured_text=structured_text or None) return EmbeddingPreloadEntry( @@ -192,6 +210,8 @@ def _structured_slot_from_metadata(value: object) -> StructuredTurnSlot: source_refs=_tuple_from_metadata(value.get("source_refs")), linkage_refs=_tuple_from_metadata(value.get("linkage_refs")), ) + + def _structured_turn_from_record(record: RecallEvidence) -> StepReplayRecord | None: if record.kind != "structured_turn": return None @@ -207,12 +227,12 @@ def _structured_turn_from_record(record: RecallEvidence) -> StepReplayRecord | N action=_structured_slot_from_metadata(payload.get("action")), outcome=_structured_slot_from_metadata(payload.get("outcome")), personal_model_id=( - str(payload.get("personal_model_id")) - if payload.get("personal_model_id") is not None - else None + str(payload.get("personal_model_id")) if payload.get("personal_model_id") is not None else None ), elephant_id=str(payload.get("elephant_id")) if payload.get("elephant_id") is not None else None, - source_event_id=str(payload.get("source_event_id")) if payload.get("source_event_id") is not None else record.source_id, + source_event_id=str(payload.get("source_event_id")) + if payload.get("source_event_id") is not None + else record.source_id, reasoning_availability=str(payload.get("reasoning_availability", "summary_only")), reasoning_provenance=str(payload.get("reasoning_provenance", "runtime.decision_summary")), compression_tier=str(payload.get("compression_tier", "raw_turn")), @@ -228,22 +248,18 @@ def parse_step_replay_record(record: RecallEvidence) -> StepReplayRecord | None: """Parse Step replay metadata from a recall evidence record.""" return _structured_turn_from_record(record) + + def _normalize_target_slots(target_slots: tuple[str, ...]) -> tuple[str, ...]: return tuple( - dict.fromkeys( - slot.strip().lower() - for slot in target_slots - if slot.strip().lower() in _REPLAY_SLOT_NAMES - ) + dict.fromkeys(slot.strip().lower() for slot in target_slots if slot.strip().lower() in _REPLAY_SLOT_NAMES) ) - def _detail_rank(compression: str) -> int: return _REPLAY_DETAIL_RANK.get(compression.strip().lower(), _REPLAY_DETAIL_RANK["structured_summary"]) - def _project_slot( slot: StructuredTurnSlot, *, @@ -266,7 +282,6 @@ def _project_slot( ) - def _selected_replay_slots( request: EvidenceRetrievalRequest, turn: StepReplayRecord | None, @@ -283,7 +298,6 @@ def _selected_replay_slots( ) - def _project_replay_record( turn: StepReplayRecord, *, @@ -327,7 +341,6 @@ def project(slot_name: str) -> StructuredTurnSlot: ) - def _slot_text(slot_name: str, slot: StructuredTurnSlot) -> tuple[str, ...]: label = _REPLAY_SLOT_LABELS.get(slot_name, slot_name) lines: list[str] = [] @@ -337,7 +350,6 @@ def _slot_text(slot_name: str, slot: StructuredTurnSlot) -> tuple[str, ...]: return tuple(lines) - def _replay_text(turn: StepReplayRecord, *, selected_slots: tuple[str, ...]) -> str: lines: list[str] = [] for slot_name in selected_slots: @@ -345,7 +357,6 @@ def _replay_text(turn: StepReplayRecord, *, selected_slots: tuple[str, ...]) -> return "\n".join(line for line in lines if line) - def _replay_summary(turn: StepReplayRecord, *, selected_slots: tuple[str, ...]) -> str: slot_summary = ", ".join(selected_slots) or "structured evidence" work_summary = ", ".join(turn.work_item_ids[:2]) or "the active thread" @@ -406,12 +417,7 @@ def retrieve(self, request: EvidenceRetrievalRequest) -> EvidenceRetrievalResult for record in self.store.list(include_inactive=request.include_inactive) if record.episode_id in scope_set ) - scope_records = tuple( - { - record.evidence_id: record - for record in episode_scope_records - }.values() - ) + scope_records = tuple({record.evidence_id: record for record in episode_scope_records}.values()) query_vector: tuple[float, ...] = () embeddings_allowed = bool(request.allow_embeddings) if not embeddings_allowed: @@ -499,10 +505,7 @@ def _resolve_scope(self, request: EvidenceRetrievalRequest) -> _ResolvedScope: opened_scopes: list[str] = [] lineage_episode_ids = tuple( - dict.fromkeys( - request.lineage_episode_ids - or self._lineage_episode_ids(request.episode_id) - ) + dict.fromkeys(request.lineage_episode_ids or self._lineage_episode_ids(request.episode_id)) ) elephant_episode_ids = ( _query_episode_ids(self.repository, elephant_id=request.elephant_id) @@ -757,7 +760,9 @@ def _recall_evidence_from_semantic_match(self, request: EvidenceRetrievalRequest content = str(getattr(document, "payload", {}).get("text") or "").strip() if not content: content = str(metadata.get("indexed_text") or metadata.get("text") or source_id) - layer_type = str(getattr(document, "layer_type", "") or metadata.get("kind") or getattr(document, "kind", "") or "semantic") + layer_type = str( + getattr(document, "layer_type", "") or metadata.get("kind") or getattr(document, "kind", "") or "semantic" + ) episode_id = str(metadata.get("episode_id") or request.episode_id) step_id = metadata.get("step_id") if source_id.startswith("step:") or metadata.get("step_id") else None loop_id = metadata.get("loop_id") or None @@ -795,8 +800,7 @@ def _semantic_candidate_from_match( owner_scope: str, ) -> EvidenceCandidate: scaled_signal_scores = { - signal: value * _SEMANTIC_RECALL_SCORE_SCALE - for signal, value in match.signal_scores.items() + signal: value * _SEMANTIC_RECALL_SCORE_SCALE for signal, value in match.signal_scores.items() } lexical_score = sum(score for signal, score in scaled_signal_scores.items() if signal != "vector") vector_score = scaled_signal_scores.get("vector", 0.0) @@ -861,11 +865,7 @@ def _merge_candidate_sets( merged.values(), key=lambda item: ( -item.score, - -( - item.evidence.created_at.timestamp() - if item.evidence.created_at is not None - else 0.0 - ), + -(item.evidence.created_at.timestamp() if item.evidence.created_at is not None else 0.0), item.evidence_id, ), ) @@ -967,7 +967,13 @@ def _candidate_for_record( overlap = sorted(query_tokens & content_tokens) lexical_score = float(len(overlap)) * 2.0 if overlap: - reasons.append(RecallReason("lexical.query", f"query overlap: {','.join(overlap)}", lexical_score)) + reasons.append( + RecallReason( + "lexical.query", + f"query overlap: {','.join(overlap)}", + lexical_score, + ) + ) tag_tokens = _tokenize(" ".join(record.tags)) tag_overlap = sorted(query_tokens & tag_tokens) if tag_overlap: @@ -1048,7 +1054,13 @@ def _candidate_for_record( continuity_score = 0.0 if query_tokens & _CONTINUITY_QUERY_TOKENS: - if record.kind in {"procedural", "semantic", "summary", "decision", "structured_turn"}: + if record.kind in { + "procedural", + "semantic", + "summary", + "decision", + "structured_turn", + }: continuity_score += 1.75 reasons.append( RecallReason( @@ -1066,7 +1078,13 @@ def _candidate_for_record( 0.4, ) ) - continuity_tags = {"continuity", "handoff", "recovery", "resume", "scope-aware"} + continuity_tags = { + "continuity", + "handoff", + "recovery", + "resume", + "scope-aware", + } if continuity_tags & set(record.tags): continuity_score += 0.4 reasons.append( @@ -1082,7 +1100,13 @@ def _candidate_for_record( replay_score = 0.0 if structured_turn is not None: replay_score += 0.75 - reasons.append(RecallReason("replay.structured-turn", "structured turn evidence is replayable", 0.75)) + reasons.append( + RecallReason( + "replay.structured-turn", + "structured turn evidence is replayable", + 0.75, + ) + ) if request.replay_mode != "off": replay_score += 0.8 reasons.append( @@ -1124,7 +1148,10 @@ def _candidate_for_record( ) ) elif request.replay_mode == "episode": - if structured_turn.compression_tier == "episode_summary" or len(structured_turn.source_turn_ids) > 1: + if ( + structured_turn.compression_tier == "episode_summary" + or len(structured_turn.source_turn_ids) > 1 + ): replay_score += 1.5 reasons.append( RecallReason( @@ -1263,7 +1290,9 @@ def _index_invalidation_reason(*, lifecycle_state: str, replacement_evidence_id: return f"{lifecycle_state} evidence must refresh derived lexical and vector views from canonical rows" -def _embedding_index_invalidations(store: "StepEvidenceStore") -> tuple[EmbeddingIndexInvalidation, ...]: +def _embedding_index_invalidations( + store: "StepEvidenceStore", +) -> tuple[EmbeddingIndexInvalidation, ...]: invalidations: list[EmbeddingIndexInvalidation] = [] ordered_records = tuple(sorted(store.list(include_inactive=True), key=_evidence_sort_key)) for record in ordered_records: @@ -1292,7 +1321,9 @@ def _embedding_index_invalidations(store: "StepEvidenceStore") -> tuple[Embeddin return tuple(invalidations) -def build_embedding_index_rebuild_plan(store: "StepEvidenceStore") -> EmbeddingIndexRebuildPlan: +def build_embedding_index_rebuild_plan( + store: "StepEvidenceStore", +) -> EmbeddingIndexRebuildPlan: ordered_records = tuple(sorted(store.list(include_inactive=True), key=_evidence_sort_key)) active_records = tuple( record diff --git a/packages/evidence/state_focus_support.py b/packages/evidence/state_focus_support.py index ff51827..3e10fc6 100644 --- a/packages/evidence/state_focus_support.py +++ b/packages/evidence/state_focus_support.py @@ -118,16 +118,24 @@ def build_resume_packet( reasons.append(f"elephant focus resume signal={focus.continuity_signal}") if focus is not None: reasons.append(f"elephant focus scope={focus.focus_scope}") - opener = "Resume" if focus is None or focus.focus_family == "resume" or focus.continuity_signal != "none" else "Continue" + opener = ( + "Resume" if focus is None or focus.focus_family == "resume" or focus.continuity_signal != "none" else "Continue" + ) if top is not None: reasons.extend(reason.detail for reason in top.reasons[:3]) if top.replay_summary: reasons.append(top.replay_summary) - focused_work_item_ids = tuple(work_item_id for work_item_id in focus_ids if work_item_id in top.evidence.work_item_ids) + focused_work_item_ids = tuple( + work_item_id for work_item_id in focus_ids if work_item_id in top.evidence.work_item_ids + ) if focused_work_item_ids: focus_ids = focused_work_item_ids replay_clause = f" Replay: {top.replay_summary}." if top.replay_summary else "" - lead_phrase = "inherit the resolved focus and lead with" if focus is not None and focus.focus_work_item_ids else "lead with" + lead_phrase = ( + "inherit the resolved focus and lead with" + if focus is not None and focus.focus_work_item_ids + else "lead with" + ) summary = ( f"{opener} {request.episode_id} around {', '.join(focus_ids[:2]) or 'the active thread'}; " f"{lead_phrase} {top.evidence_id} because {', '.join(reason.detail for reason in top.reasons[:2])}.{replay_clause}" diff --git a/packages/evidence/unified_recall.py b/packages/evidence/unified_recall.py index 869d7b6..ff7dee1 100644 --- a/packages/evidence/unified_recall.py +++ b/packages/evidence/unified_recall.py @@ -160,11 +160,9 @@ def list_episodes( *, state_id: str | None = None, limit: int | None = None, - ) -> tuple[Any, ...]: - ... + ) -> tuple[Any, ...]: ... - def list_steps(self, *, loop_id: str | None = None) -> tuple[Any, ...]: - ... + def list_steps(self, *, loop_id: str | None = None) -> tuple[Any, ...]: ... def list_semantic_index_entries( self, @@ -174,8 +172,7 @@ def list_semantic_index_entries( personal_model_id: str | None = None, provider_id: str | None = None, model_id: str | None = None, - ) -> tuple[SemanticIndexEntry, ...]: - ... + ) -> tuple[SemanticIndexEntry, ...]: ... def _aware(value: datetime) -> datetime: @@ -262,7 +259,17 @@ def documents_from_episodes(episodes: Iterable[Any]) -> list[RecallDocument]: summary = str(getattr(episode, "exit_summary", "") or "").strip() entry_surface = str(getattr(episode, "entry_surface", "") or "").strip() metadata = dict(getattr(episode, "metadata", {}) or {}) - body = " | ".join(part for part in (summary, entry_surface, metadata.get("topic", ""), metadata.get("focus", ""), metadata.get("note", "")) if str(part or "").strip()) + body = " | ".join( + part + for part in ( + summary, + entry_surface, + metadata.get("topic", ""), + metadata.get("focus", ""), + metadata.get("note", ""), + ) + if str(part or "").strip() + ) if not body: continue out.append( @@ -275,7 +282,10 @@ def documents_from_episodes(episodes: Iterable[Any]) -> list[RecallDocument]: personal_model_id=getattr(episode, "personal_model_id", None), state_id=getattr(episode, "state_id", None), episode_id=getattr(episode, "episode_id", None), - metadata={**{str(k): str(v) for k, v in metadata.items()}, "recall_source": "episode"}, + metadata={ + **{str(k): str(v) for k, v in metadata.items()}, + "recall_source": "episode", + }, ) ) return out @@ -326,7 +336,10 @@ def _step_text(step: Any, metadata: Mapping[str, str]) -> str: elif normalized_action == "emit_response": parts = [str(metadata.get("final_response") or metadata.get("assistant_response") or summary).strip()] elif normalized_action == "reply": - parts = [summary, str(metadata.get("final_response") or metadata.get("assistant_response") or "").strip()] + parts = [ + summary, + str(metadata.get("final_response") or metadata.get("assistant_response") or "").strip(), + ] else: parts = [ summary, @@ -363,7 +376,8 @@ def _collect_recall_documents( except Exception: episodes = () documents.extend( - document for document in documents_from_episodes(episodes or ()) + document + for document in documents_from_episodes(episodes or ()) if not _is_excluded_episode(document.episode_id, excluded) ) elif scope == "steps": @@ -372,16 +386,17 @@ def _collect_recall_documents( except Exception: steps = () documents.extend( - document for document in documents_from_steps(steps or ()) + document + for document in documents_from_steps(steps or ()) if (not state_id or document.state_id in {None, "", state_id}) and not _is_excluded_episode(document.episode_id, excluded) ) # Legacy scopes (personal_model, state, sources) are no longer supported. # Steps + episodes + semantic index are the canonical search path. return [ - document for document in documents - if _in_time_range(document.when, time_range) - and not _is_excluded_episode(document.episode_id, excluded) + document + for document in documents + if _in_time_range(document.when, time_range) and not _is_excluded_episode(document.episode_id, excluded) ] @@ -668,8 +683,7 @@ def recall_timeline( if query and score <= 0.0: continue anchor_items = [ - item for item in items - if item.text.strip() and (not query or _text_relevance_score(query, item.text) > 0.0) + item for item in items if item.text.strip() and (not query or _text_relevance_score(query, item.text) > 0.0) ] anchors = tuple( {"kind": item.kind, "text": _anchor_text(item.text)} @@ -689,7 +703,12 @@ def recall_timeline( payload["anchors"] = anchors ranges.append(payload) if query: - ranges.sort(key=lambda item: (-float(item.get("score", 0.0)), str(item.get("start_at", "")))) + ranges.sort( + key=lambda item: ( + -float(item.get("score", 0.0)), + str(item.get("start_at", "")), + ) + ) else: ranges.sort(key=lambda item: str(item.get("start_at", ""))) selected = tuple(ranges[:capped]) @@ -729,7 +748,7 @@ def unified_recall( tuple of RecallHit ordered best-to-worst. """ capped = max(1, min(int(request.limit or 5), 20)) - now_ts = (request.now or datetime.now(timezone.utc)) + now_ts = request.now or datetime.now(timezone.utc) query_plan = plan_recall_query(request.query) query = query_plan.search_query.strip() scopes = tuple(request.scopes) or CONVERSATION_SEARCH_SCOPES @@ -760,7 +779,9 @@ def unified_recall( embedding_available = False query_vector_cache: dict[int | None, tuple[tuple[float, ...], int | None]] = {} - def query_vector_for(dimensions: int | None) -> tuple[tuple[float, ...], int | None]: + def query_vector_for( + dimensions: int | None, + ) -> tuple[tuple[float, ...], int | None]: if dimensions in query_vector_cache: return query_vector_cache[dimensions] if not embedding_available: @@ -786,7 +807,11 @@ def query_vector_for(dimensions: int | None) -> tuple[tuple[float, ...], int | N # Attempt hybrid per-scope; collect matches into one ranked list. hits: list[RecallHit] = [] - per_scope_limit = max(capped, capped * (8 if time_range is not None else 2), 50 if time_range is not None else capped) + per_scope_limit = max( + capped, + capped * (8 if time_range is not None else 2), + 50 if time_range is not None else capped, + ) require_text_anchor = _query_needs_text_anchor(query) excluded = _excluded_episode_ids(request.exclude_episode_ids) for scope in scopes: @@ -832,7 +857,11 @@ def query_vector_for(dimensions: int | None) -> tuple[tuple[float, ...], int | N if scope == "steps" and _semantic_step_is_noise(document, hit): continue hit_episode_id = dict(getattr(hit, "extra_metadata", {}) or {}).get("episode_id") if hit is not None else "" - if hit is not None and not _is_excluded_episode(hit_episode_id, excluded) and _in_time_range(hit.when_datetime, time_range): + if ( + hit is not None + and not _is_excluded_episode(hit_episode_id, excluded) + and _in_time_range(hit.when_datetime, time_range) + ): hits.append(hit) step_candidates = _collect_fallback_candidates( @@ -844,7 +873,14 @@ def query_vector_for(dimensions: int | None) -> tuple[tuple[float, ...], int | N time_range=time_range, exclude_episode_ids=request.exclude_episode_ids, ) - hits.extend(rank_recall_candidates(request.query, step_candidates, limit=max(capped, len(step_candidates)), now=now_ts)) + hits.extend( + rank_recall_candidates( + request.query, + step_candidates, + limit=max(capped, len(step_candidates)), + now=now_ts, + ) + ) # Merge: sort by relevance plus intent-aware freshness, then de-duplicate by provenance first. ranked_hits = rerank_recall_hits(tuple(hits), plan=query_plan, now=now_ts) @@ -854,8 +890,7 @@ def query_vector_for(dimensions: int | None) -> tuple[tuple[float, ...], int | N hit = ranked_hit.hit metadata = dict(hit.extra_metadata or {}) provenance = "|".join( - str(metadata.get(key) or "").strip() - for key in ("episode_id", "loop_id", "step_id", "source_id") + str(metadata.get(key) or "").strip() for key in ("episode_id", "loop_id", "step_id", "source_id") ).strip("|") key = provenance or hit.content.casefold() if key in seen_keys: diff --git a/packages/gateway_core/outbound_delivery.py b/packages/gateway_core/outbound_delivery.py index ffa6dbd..77d498c 100644 --- a/packages/gateway_core/outbound_delivery.py +++ b/packages/gateway_core/outbound_delivery.py @@ -8,7 +8,7 @@ from __future__ import annotations from collections.abc import Mapping -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Protocol from packages.contracts.runtime import ExecutionResult @@ -75,23 +75,16 @@ def send_message( side_effects=("delivery",), ) - def _resolve_route_from_identity( - self, *, target: str | None - ) -> tuple[str, str, str]: + def _resolve_route_from_identity(self, *, target: str | None) -> tuple[str, str, str]: """Resolve IM routing via identity store (CLI/TUI fallback path).""" if self.identity_store is None: - raise ValueError( - "Cannot send IM message: no gateway session context and no identity store configured." - ) + raise ValueError("Cannot send IM message: no gateway session context and no identity store configured.") if self.default_elephant_id: records = self.identity_store.lookup_by_elephant_id(self.default_elephant_id) else: records = self.identity_store.list_records() if not records: - raise ValueError( - "Cannot send IM message: no IM identity records found. " - "Pair an IM account first." - ) + raise ValueError("Cannot send IM message: no IM identity records found. Pair an IM account first.") # Filter by target hint if provided (e.g. "feishu", "weixin", "dingding") if target: hint = target.strip().lower() diff --git a/packages/gateway_core/outbound_drain.py b/packages/gateway_core/outbound_drain.py index 7114695..d50e92f 100644 --- a/packages/gateway_core/outbound_drain.py +++ b/packages/gateway_core/outbound_drain.py @@ -211,4 +211,3 @@ def _send_one_row_sync( __all__ = ["run_outbound_drain_loop", "run_outbound_drain_thread"] - diff --git a/packages/gateway_core/outbound_queue.py b/packages/gateway_core/outbound_queue.py index a73b153..ad1c9a9 100644 --- a/packages/gateway_core/outbound_queue.py +++ b/packages/gateway_core/outbound_queue.py @@ -58,7 +58,7 @@ from collections.abc import Mapping from contextlib import contextmanager -from dataclasses import asdict, dataclass, field, replace +from dataclasses import asdict, dataclass, replace from datetime import datetime, timedelta, timezone import json from pathlib import Path @@ -329,9 +329,7 @@ def _row_from_payload(payload: Mapping[str, Any]) -> GatewayOutboundRow: attempts=int(payload.get("attempts") or 0), created_at=created_at, available_at=available_at, - last_error=( - str(payload["last_error"]) if payload.get("last_error") not in (None, "") else None - ), + last_error=(str(payload["last_error"]) if payload.get("last_error") not in (None, "") else None), ) diff --git a/packages/gateway_core/pairing.py b/packages/gateway_core/pairing.py index e66a037..989369a 100644 --- a/packages/gateway_core/pairing.py +++ b/packages/gateway_core/pairing.py @@ -123,7 +123,10 @@ def pending_requests(self, *, platform: str) -> tuple[PairingRequest, ...]: if not _request_expired(payload, now) ) if len(active) != len(pending): - self._write_mapping(self._pending_path(platform_key), {item.code: _request_payload(item) for item in active}) + self._write_mapping( + self._pending_path(platform_key), + {item.code: _request_payload(item) for item in active}, + ) return active def _pending_path(self, platform: str) -> Path: diff --git a/packages/gateway_core/runtime.py b/packages/gateway_core/runtime.py index 7270e69..9248185 100644 --- a/packages/gateway_core/runtime.py +++ b/packages/gateway_core/runtime.py @@ -277,9 +277,7 @@ def resolve_cron_identity_records( records = identity_store.lookup_by_elephant_id(elephant_id) else: records = identity_store.list_records() - adapter_records = tuple( - record for record in records if record.key.adapter_id == adapter_id - ) + adapter_records = tuple(record for record in records if record.key.adapter_id == adapter_id) if elephant_id or not adapter_records: return adapter_records unique_elephants = {record.elephant_id for record in adapter_records if record.elephant_id} @@ -318,9 +316,7 @@ def list_records(self) -> tuple[GatewayRouteState, ...]: @dataclass(slots=True) class InMemoryGatewayIdentityStore: - _records: dict[GatewayIdentityKey, GatewayIdentityRecord] = field( - default_factory=dict - ) + _records: dict[GatewayIdentityKey, GatewayIdentityRecord] = field(default_factory=dict) def lookup(self, key: GatewayIdentityKey) -> GatewayIdentityRecord | None: return self._records.get(key) @@ -332,7 +328,7 @@ def lookup_by_elephant_id(self, elephant_id: str) -> tuple[GatewayIdentityRecord return tuple( sorted( (r for r in self._records.values() if r.elephant_id == elephant_id), - key=lambda r: (r.updated_at or r.created_at or _utc_now()), + key=lambda r: r.updated_at or r.created_at or _utc_now(), reverse=True, ) ) @@ -366,7 +362,7 @@ def lookup_by_elephant_id(self, elephant_id: str) -> tuple[GatewayIdentityRecord return tuple( sorted( (r for r in self._load_records().values() if r.elephant_id == elephant_id), - key=lambda r: (r.updated_at or r.created_at or _utc_now()), + key=lambda r: r.updated_at or r.created_at or _utc_now(), reverse=True, ) ) @@ -405,17 +401,9 @@ def _load_records(self) -> dict[GatewayIdentityKey, GatewayIdentityRecord]: state_id=str(item["state_id"]) if item.get("state_id") is not None else None, elephant_id=str(item["elephant_id"]) if item.get("elephant_id") is not None else None, episode_id=str(item["episode_id"]) if item.get("episode_id") is not None else None, - display_name=( - str(item["display_name"]) - if item.get("display_name") is not None - else None - ), - created_at=_parse_datetime( - str(item["created_at"]) if item.get("created_at") is not None else None - ), - updated_at=_parse_datetime( - str(item["updated_at"]) if item.get("updated_at") is not None else None - ), + display_name=(str(item["display_name"]) if item.get("display_name") is not None else None), + created_at=_parse_datetime(str(item["created_at"]) if item.get("created_at") is not None else None), + updated_at=_parse_datetime(str(item["updated_at"]) if item.get("updated_at") is not None else None), ) records[record.key] = record return records @@ -518,9 +506,7 @@ def _load_records(self) -> dict[str, GatewayRouteState]: started_at=datetime.fromisoformat(str(item["started_at"])), updated_at=datetime.fromisoformat(str(item["updated_at"])), interruption_state=( - str(item["interruption_state"]) - if item.get("interruption_state") is not None - else None + str(item["interruption_state"]) if item.get("interruption_state") is not None else None ), ) records[session.session_id] = session @@ -579,10 +565,14 @@ def bind_elephant( conversation_id=inbound.conversation_id, ) existing = self.dependencies.identity_store.lookup(key) - session_id = existing.session_id if existing is not None else _session_id( - inbound.adapter_id, - inbound.account_id, - inbound.conversation_id, + session_id = ( + existing.session_id + if existing is not None + else _session_id( + inbound.adapter_id, + inbound.account_id, + inbound.conversation_id, + ) ) session = self.dependencies.session_store.lookup(session_id) if session is None: @@ -596,7 +586,9 @@ def bind_elephant( else: session = replace(session, updated_at=now) identity = GatewayIdentityRecord( - mapping_id=existing.mapping_id if existing is not None else _mapping_id( + mapping_id=existing.mapping_id + if existing is not None + else _mapping_id( inbound.adapter_id, inbound.account_id, inbound.conversation_id, @@ -629,7 +621,11 @@ def route_inbound( ) identity = self.dependencies.identity_store.lookup(key) inherited: GatewayIdentityRecord | None = None - if identity is None and inbound.parent_conversation_id and inbound.parent_conversation_id != inbound.conversation_id: + if ( + identity is None + and inbound.parent_conversation_id + and inbound.parent_conversation_id != inbound.conversation_id + ): inherited = self.dependencies.identity_store.lookup( GatewayIdentityKey( adapter_id=inbound.adapter_id, @@ -748,20 +744,10 @@ def deliver( is_external: bool | None = None, ) -> GatewayDeliveryReceipt: resolved_target_trusted = ( - route.inbound.policy_hint.target_trusted_default - if target_trusted is None - else target_trusted - ) - resolved_consent_given = ( - route.inbound.policy_hint.consent_default - if consent_given is None - else consent_given - ) - resolved_is_external = ( - route.inbound.policy_hint.is_external_default - if is_external is None - else is_external + route.inbound.policy_hint.target_trusted_default if target_trusted is None else target_trusted ) + resolved_consent_given = route.inbound.policy_hint.consent_default if consent_given is None else consent_given + resolved_is_external = route.inbound.policy_hint.is_external_default if is_external is None else is_external request = SecurityRequest( request_id=f"{route.route_id}:policy:{uuid4().hex[:8]}", approval_class=ApprovalClass.MESSAGING, @@ -841,9 +827,7 @@ def process_message( delivery = self.deliver( route, body=body or inbound.body, - reply_to_message_id=reply_to_message_id - or inbound.reply_to_message_id - or inbound.event_id, + reply_to_message_id=reply_to_message_id or inbound.reply_to_message_id or inbound.event_id, attachment_refs=attachment_refs, metadata=metadata, target_trusted=target_trusted, diff --git a/packages/growth/projection.py b/packages/growth/projection.py index b3c0556..e83e363 100644 --- a/packages/growth/projection.py +++ b/packages/growth/projection.py @@ -5,18 +5,20 @@ from dataclasses import dataclass from datetime import datetime -from packages.contracts.runtime import ExperienceRecord, ProcedureRecord, PersonalModelGrowthState +from packages.contracts.runtime import ( + ExperienceRecord, + ProcedureRecord, + PersonalModelGrowthState, +) from .runtime import ( GROWTH_STAGES, - GrowthSnapshot, GrowthStageDescriptor, GrowthUpdate, _PROMOTED_PROCEDURE_STATUSES, _canonical_active_days, _contains_any, _curve_stage_id_for_level, - _dedupe_text, _local_day, _round_to_five, _roman_numeral, @@ -169,10 +171,7 @@ def build( 1 for procedure in procedures if procedure.status.strip().lower() in _PROMOTED_PROCEDURE_STATUSES ) skill_refs = { - skill_id - for experience in experiences - for skill_id in experience.related_skill_ids - if str(skill_id).strip() + skill_id for experience in experiences for skill_id in experience.related_skill_ids if str(skill_id).strip() } skill_refs.update( procedure.skill_id @@ -192,7 +191,10 @@ def build( mastery_vector = ( ProgressionMasterySignal( axis="execution", - score=min(12, (4 if active_work_item is not None else 0) + min(4, experience_count) + min(4, artifact_count)), + score=min( + 12, + (4 if active_work_item is not None else 0) + min(4, experience_count) + min(4, artifact_count), + ), summary=( f"Current focus is {str(getattr(active_work_item, 'title', '') or '')}." if active_work_item is not None @@ -201,7 +203,10 @@ def build( ), ProgressionMasterySignal( axis="continuity", - score=min(12, continuity_bonus + min(4, continuity_recoveries) + min(4, power_state.streak_days)), + score=min( + 12, + continuity_bonus + min(4, continuity_recoveries) + min(4, power_state.streak_days), + ), summary=( "Recovery context is active and the thread is being carried forward." if continuity_mode.strip().lower() != "foreground" @@ -371,7 +376,9 @@ def _projection_state_from_canonical( default=(state.last_dialogue_at if state is not None else None) or first_moment, ) experience_count = len(experiences) if experiences else (state.total_experiences if state is not None else 0) - canonical_active_days = _canonical_active_days(experiences) if experiences else (state.active_days if state is not None else 0) + canonical_active_days = ( + _canonical_active_days(experiences) if experiences else (state.active_days if state is not None else 0) + ) if state is not None: power_score = max(0, state.growth_score) total_dialogues = max(state.total_dialogues, experience_count) @@ -381,7 +388,10 @@ def _projection_state_from_canonical( power_score = _round_to_five( (experience_count * 40) + (promoted_procedures * 80) - + min(30, sum(len(experience.related_skill_ids) for experience in experiences) * 10) + + min( + 30, + sum(len(experience.related_skill_ids) for experience in experiences) * 10, + ) ) total_dialogues = experience_count total_tokens = sum(max(0, len(experience.summary) + len(experience.title)) for experience in experiences) @@ -395,9 +405,17 @@ def _projection_state_from_canonical( promoted_experiences=promoted_procedures, active_days=max(canonical_active_days, state.active_days if state is not None else 0), streak_days=streak_days, - first_dialogue_at=state.first_dialogue_at if state is not None and state.first_dialogue_at is not None else first_moment, - last_dialogue_at=state.last_dialogue_at if state is not None and state.last_dialogue_at is not None else last_moment, - last_active_day=(state.last_active_day if state is not None and state.last_active_day is not None else _local_day(last_moment).isoformat()), + first_dialogue_at=state.first_dialogue_at + if state is not None and state.first_dialogue_at is not None + else first_moment, + last_dialogue_at=state.last_dialogue_at + if state is not None and state.last_dialogue_at is not None + else last_moment, + last_active_day=( + state.last_active_day + if state is not None and state.last_active_day is not None + else _local_day(last_moment).isoformat() + ), created_at=state.created_at if state is not None and state.created_at is not None else first_moment, updated_at=fallback_updated_at or last_moment, ) @@ -406,7 +424,10 @@ def _projection_state_from_canonical( def _lifetime_days_for(state: PersonalModelGrowthState) -> int: if state.first_dialogue_at is None or state.last_dialogue_at is None: return 0 - return max(1, (_local_day(state.last_dialogue_at) - _local_day(state.first_dialogue_at)).days + 1) + return max( + 1, + (_local_day(state.last_dialogue_at) - _local_day(state.first_dialogue_at)).days + 1, + ) def _understanding_rank( @@ -473,7 +494,7 @@ def _active_challenge_tracks( ProgressionChallengeTrack( track_id="current-focus", label="Keep the current focus visible", - summary=str(getattr(active_work_item, 'title', '') or ''), + summary=str(getattr(active_work_item, "title", "") or ""), status="active", ) ) diff --git a/packages/growth/rollout.py b/packages/growth/rollout.py index b6d9517..a62f80a 100644 --- a/packages/growth/rollout.py +++ b/packages/growth/rollout.py @@ -6,7 +6,11 @@ from datetime import datetime, timedelta, timezone from types import SimpleNamespace -from packages.contracts.runtime import ExperienceRecord, ProcedureRecord, PersonalModelGrowthState +from packages.contracts.runtime import ( + ExperienceRecord, + ProcedureRecord, + PersonalModelGrowthState, +) from .projection import ProgressionProjection, ProgressionProjectionBuilder from .runtime import ( @@ -287,7 +291,9 @@ def default_progression_rollout_scorecard() -> ProgressionRolloutScorecard: status="blocked", priority="high", ) - meaningful_state = _certification_base_growth_state("profile-rollout-meaningful", now=now, dialogues=5, experiences=3) + meaningful_state = _certification_base_growth_state( + "profile-rollout-meaningful", now=now, dialogues=5, experiences=3 + ) trivial_state_a = _certification_base_growth_state("profile-rollout-trivial-a", now=now, dialogues=5, experiences=2) trivial_state_b = _certification_base_growth_state("profile-rollout-trivial-b", now=now, dialogues=7, experiences=2) cases = ( @@ -450,18 +456,18 @@ def _worst_variance_for(comparisons: tuple[ProgressionShadowComparison, ...]) -> return max(variances) -def _minimum_motivation_delta_for(comparisons: tuple[ProgressionShadowComparison, ...]) -> int: - non_trivial = [ - comparison.delta_score - for comparison in comparisons - if comparison.classification != "trivial" - ] +def _minimum_motivation_delta_for( + comparisons: tuple[ProgressionShadowComparison, ...], +) -> int: + non_trivial = [comparison.delta_score for comparison in comparisons if comparison.classification != "trivial"] if not non_trivial: return 0 return min(non_trivial) -def _explanation_drift_cases_for(comparisons: tuple[ProgressionShadowComparison, ...]) -> list[str]: +def _explanation_drift_cases_for( + comparisons: tuple[ProgressionShadowComparison, ...], +) -> list[str]: drifted: list[str] = [] pattern_families = {comparison.pattern_family for comparison in comparisons} for family in pattern_families: @@ -544,7 +550,12 @@ def _meaningful_rollout_signals(profile_id: str, *, session_id: str, now: dateti artifact_ids=("artifact:patch-note",), promoted_procedure_ids=("procedure:resume-checklist",), personal_model_fact_count=8, - personal_model_lens_counts=(("identity", 2), ("world", 2), ("pulse", 2), ("journey", 2)), + personal_model_lens_counts=( + ("identity", 2), + ("world", 2), + ("pulse", 2), + ("journey", 2), + ), personal_model_topic_count=6, personal_model_new_fact_count=2, personal_model_updated_fact_count=1, diff --git a/packages/growth/runtime.py b/packages/growth/runtime.py index 2eddf69..ab6d85a 100644 --- a/packages/growth/runtime.py +++ b/packages/growth/runtime.py @@ -363,7 +363,10 @@ def build_growth_snapshot(state: PersonalModelGrowthState) -> GrowthSnapshot: progress_percent = int(round(progress_ratio * 100)) lifetime_days = 0 if state.first_dialogue_at is not None and state.last_dialogue_at is not None: - lifetime_days = max(1, (_local_day(state.last_dialogue_at) - _local_day(state.first_dialogue_at)).days + 1) + lifetime_days = max( + 1, + (_local_day(state.last_dialogue_at) - _local_day(state.first_dialogue_at)).days + 1, + ) return GrowthSnapshot( state=state, level=level, @@ -501,7 +504,9 @@ def _personal_model_lens_count_map(signals: GrowthTurnSignals) -> dict[str, int] return counts -def _understanding_coverage_reason(signals: GrowthTurnSignals) -> GrowthRewardReason | None: +def _understanding_coverage_reason( + signals: GrowthTurnSignals, +) -> GrowthRewardReason | None: counts = _personal_model_lens_count_map(signals) covered_lenses = tuple(lens for lens in _PERSONAL_MODEL_LENSES if counts[lens] > 0) fact_count = max(0, signals.personal_model_fact_count) @@ -524,7 +529,9 @@ def _understanding_coverage_reason(signals: GrowthTurnSignals) -> GrowthRewardRe ) -def _understanding_richness_reason(signals: GrowthTurnSignals) -> GrowthRewardReason | None: +def _understanding_richness_reason( + signals: GrowthTurnSignals, +) -> GrowthRewardReason | None: topic_count = max(0, signals.personal_model_topic_count) if topic_count <= 0: return None @@ -550,7 +557,9 @@ def _understanding_richness_reason(signals: GrowthTurnSignals) -> GrowthRewardRe ) -def _understanding_freshness_reason(signals: GrowthTurnSignals) -> GrowthRewardReason | None: +def _understanding_freshness_reason( + signals: GrowthTurnSignals, +) -> GrowthRewardReason | None: new_facts = max(0, signals.personal_model_new_fact_count) updated_facts = max(0, signals.personal_model_updated_fact_count) score = min(30, (new_facts * 10) + (updated_facts * 6)) @@ -566,7 +575,9 @@ def _understanding_freshness_reason(signals: GrowthTurnSignals) -> GrowthRewardR ) -def _understanding_grounding_reason(signals: GrowthTurnSignals) -> GrowthRewardReason | None: +def _understanding_grounding_reason( + signals: GrowthTurnSignals, +) -> GrowthRewardReason | None: supported_facts = max(0, signals.personal_model_supported_fact_count) evidence_refs = max(0, signals.personal_model_evidence_ref_count) if supported_facts <= 0 and evidence_refs <= 0: @@ -587,7 +598,9 @@ def _understanding_grounding_reason(signals: GrowthTurnSignals) -> GrowthRewardR ) -def _interaction_activity_reason(signals: GrowthTurnSignals) -> GrowthRewardReason | None: +def _interaction_activity_reason( + signals: GrowthTurnSignals, +) -> GrowthRewardReason | None: score = 0 facts: list[str] = [] if signals.tool_call_count: @@ -725,7 +738,9 @@ def _learning_yield_reason(signals: GrowthTurnSignals) -> GrowthRewardReason | N ) -def _capability_leverage_reason(signals: GrowthTurnSignals) -> GrowthRewardReason | None: +def _capability_leverage_reason( + signals: GrowthTurnSignals, +) -> GrowthRewardReason | None: score = 0 facts: list[str] = [] if signals.tool_call_count: diff --git a/packages/harness/supervisor.py b/packages/harness/supervisor.py index 1528979..60b1fa7 100644 --- a/packages/harness/supervisor.py +++ b/packages/harness/supervisor.py @@ -250,6 +250,7 @@ def _timer_is_ripe( def _replace(loop: LoopState, **updates) -> LoopState: from dataclasses import replace as dc_replace + return dc_replace(loop, **updates) diff --git a/packages/kernel/context_compaction.py b/packages/kernel/context_compaction.py index 13e90af..cc72f2e 100644 --- a/packages/kernel/context_compaction.py +++ b/packages/kernel/context_compaction.py @@ -41,7 +41,6 @@ def looks_like_context_overflow(error: BaseException) -> bool: ) - def projection_compaction_detail(result: object) -> str: before_tokens = getattr(result, "before_tokens", 0) after_tokens = getattr(result, "after_tokens", 0) @@ -70,7 +69,6 @@ def projection_compaction_detail(result: object) -> str: ) - def latest_compacted_projection(context_capability: object) -> object | None: result = getattr(context_capability, "last_projection_compaction", None) if result is None or not bool(getattr(result, "compacted", False)): @@ -78,14 +76,12 @@ def latest_compacted_projection(context_capability: object) -> object | None: return result - def flush_projection_cache(context_capability: object) -> None: flush = getattr(context_capability, "flush_projection_cache", None) if callable(flush): flush() - def stage_context_usage( stage: Any, prompt_tokens: int, @@ -100,7 +96,6 @@ def stage_context_usage( ) - def estimate_context_projection_tokens(context: Any) -> int: rendered_prompt = str(getattr(context, "rendered_prompt", "") or "").strip() if rendered_prompt: @@ -112,7 +107,6 @@ def estimate_context_projection_tokens(context: Any) -> int: return 0 - def stage_context_projection(stage: Any, context: Any, *, source: str = "generation") -> None: if not callable(stage): return @@ -120,8 +114,10 @@ def stage_context_projection(stage: Any, context: Any, *, source: str = "generat token_budget = int(getattr(context, "token_budget", 0) or 0) if prompt_tokens <= 0 and token_budget <= 0: return - stage("context-projection", f"prompt_tokens={prompt_tokens} token_budget={token_budget} source={source}") - + stage( + "context-projection", + f"prompt_tokens={prompt_tokens} token_budget={token_budget} source={source}", + ) def compact_context_after_usage( @@ -140,7 +136,6 @@ def compact_context_after_usage( return None - def episode_continuity_packet( *, request: Any, @@ -181,7 +176,6 @@ def episode_continuity_packet( return EpisodeContinuityPacket(packet_id=packet_id, text="\n".join(lines), source_refs=source_refs) - def compaction_step_metadata( *, packet: EpisodeContinuityPacket, @@ -201,7 +195,6 @@ def compaction_step_metadata( } - def append_episode_continuity_packet(context: ContextBundle, packet: EpisodeContinuityPacket) -> ContextBundle: rendered = str(context.rendered_prompt or "").strip() updated_rendered = packet.text if not rendered else f"{rendered}\n\n{packet.text}" @@ -212,18 +205,15 @@ def append_episode_continuity_packet(context: ContextBundle, packet: EpisodeCont ) - def _csv(values: tuple[object, ...], *, separator: str = ", ") -> str: cleaned = tuple(str(value).strip() for value in values if str(value).strip()) return separator.join(cleaned) if cleaned else "" - def _single_line(value: str) -> str: return " ".join(str(value).split()) - def retry_context_after_provider_overflow( *, error: RuntimeError, diff --git a/packages/kernel/execution_support.py b/packages/kernel/execution_support.py index e6e8a92..f580287 100644 --- a/packages/kernel/execution_support.py +++ b/packages/kernel/execution_support.py @@ -100,9 +100,20 @@ def execute_kernel_turn( "tool_result": execution.summary, }, ) - return execution, checkpoint, (*turn_messages, *assistant_turn_messages(_clean_execution_summary(execution))) + return ( + execution, + checkpoint, + ( + *turn_messages, + *assistant_turn_messages(_clean_execution_summary(execution)), + ), + ) - model_prompt = turn_messages[0].content if len(turn_messages) == 1 and turn_messages[0].role == "user" else prompt_for_execution + model_prompt = ( + turn_messages[0].content + if len(turn_messages) == 1 and turn_messages[0].role == "user" + else prompt_for_execution + ) _record_effective_user_query_step( step_recorder, raw_prompt=request.prompt, @@ -121,7 +132,12 @@ def execute_kernel_turn( planned_summary="initial model call", ) if service.dependencies.tools is None: - stage_context_usage(stage, response.prompt_tokens, response.completion_tokens, response.total_tokens) + stage_context_usage( + stage, + response.prompt_tokens, + response.completion_tokens, + response.total_tokens, + ) cleaned = _clean_execution_summary(response) return cleaned, None, (*turn_messages, *assistant_turn_messages(cleaned)) return _execute_model_tool_loop( @@ -331,7 +347,9 @@ def _execute_model_tool_loop( service._persist_loop_checkpoint(current_loop) provider_system_prompt = _provider_system_prompt_for_recording(context) if current_loop is not None and not context_recorded and provider_system_prompt: - current_loop, context_step = loop_service.record_context_prompt(current_loop, system_prompt=provider_system_prompt) + current_loop, context_step = loop_service.record_context_prompt( + current_loop, system_prompt=provider_system_prompt + ) service._persist_loop_checkpoint(current_loop, step=context_step) context_recorded = True if current_loop is not None: @@ -448,9 +466,13 @@ def _finalize_model_loop_response( cache_creation_prompt_tokens=cache_creation_prompt_tokens_total, cache_usage_reported=cache_usage_reported, ) - finalized = cleaned if not loop_traces else replace( - cleaned, - side_effects=tuple(dict.fromkeys((*cleaned.side_effects, *loop_traces))), + finalized = ( + cleaned + if not loop_traces + else replace( + cleaned, + side_effects=tuple(dict.fromkeys((*cleaned.side_effects, *loop_traces))), + ) ) if current_loop is not None: current_loop = loop_service.complete(current_loop, summary=finalized.summary) diff --git a/packages/kernel/generation_context.py b/packages/kernel/generation_context.py index 6316840..8cb041c 100644 --- a/packages/kernel/generation_context.py +++ b/packages/kernel/generation_context.py @@ -119,7 +119,10 @@ def _augment_with_system_layers( frozen_prefix = cached[1] else: frozen_prefix = _build_frozen_prefix( - frozen_prefix, committed_pm_lines, resume_lines, skill_index_section, + frozen_prefix, + committed_pm_lines, + resume_lines, + skill_index_section, ) if episode_id: if len(_prefix_cache) >= _PREFIX_CACHE_MAX and episode_id not in _prefix_cache: @@ -176,9 +179,9 @@ def _augment_with_system_layers( # Canonical facet order within each lens (for stable rendering order). _LENS_FACET_ORDER: dict[str, tuple[str, ...]] = { "identity": ("anchor", "character", "values", "style", "body"), - "world": ("people", "projects", "tools", "places", "assets"), - "pulse": ("chapter", "focus", "mood", "blockers", "intent"), - "journey": ("lessons", "patterns", "decisions", "milestones"), + "world": ("people", "projects", "tools", "places", "assets"), + "pulse": ("chapter", "focus", "mood", "blockers", "intent"), + "journey": ("lessons", "patterns", "decisions", "milestones"), } _KNOWN_LENSES = frozenset({"identity", "world", "pulse", "journey"}) @@ -208,15 +211,15 @@ def _prefix_input_hash( ) -> str: h = hashlib.sha256() h.update(base_prefix.encode()) - h.update(b'\x00') + h.update(b"\x00") for line in pm_lines: h.update(len(line).to_bytes(4, "big")) h.update(line.encode()) - h.update(b'\x00') + h.update(b"\x00") for line in resume_lines: h.update(len(line).to_bytes(4, "big")) h.update(line.encode()) - h.update(b'\x00') + h.update(b"\x00") h.update(skill_section.encode()) return h.hexdigest() @@ -300,9 +303,7 @@ def _frozen_committed_pm_lines(storage: Any, request: Any) -> tuple[str, ...]: return () # by_lens_facet[lens][facet] = [fact, ...] - by_lens_facet: dict[str, dict[str, list]] = { - lens: {} for lens in ("identity", "world", "pulse", "journey") - } + by_lens_facet: dict[str, dict[str, list]] = {lens: {} for lens in ("identity", "world", "pulse", "journey")} for fact in facts: if fact.confidence < 0.6: continue @@ -332,7 +333,13 @@ def _frozen_committed_pm_lines(storage: Any, request: Any) -> tuple[str, ...]: ordered_facets = [f for f in canonical if f in facet_map] ordered_facets += sorted(f for f in facet_map if f not in canonical) for facet in ordered_facets: - facet_facts = sorted(facet_map[facet], key=lambda f: (-f.confidence, getattr(f, "fact_id", "") or getattr(f, "text", ""))) + facet_facts = sorted( + facet_map[facet], + key=lambda f: ( + -f.confidence, + getattr(f, "fact_id", "") or getattr(f, "text", ""), + ), + ) lines.append(f"#### {facet}") for fact in facet_facts: text = _fact_prompt_text(fact) @@ -341,11 +348,14 @@ def _frozen_committed_pm_lines(storage: Any, request: Any) -> tuple[str, ...]: return tuple(lines) - def _fact_visible_in_core_prompt(fact: Any, metadata: dict[str, Any]) -> bool: recall_policy = str(metadata.get("recall_policy") or "").strip().lower() lifecycle = str(metadata.get("retention_lifecycle") or "").strip().lower() - if recall_policy in {"temporary", "review"} or lifecycle in {"temporal", "draft", "working"}: + if recall_policy in {"temporary", "review"} or lifecycle in { + "temporal", + "draft", + "working", + }: return False text = str(getattr(fact, "text", "") or "").strip() if text.startswith("Question-bank signal for "): @@ -359,8 +369,6 @@ def _fact_prompt_text(fact: Any) -> str: return str(getattr(fact, "text", "") or "").strip() - - def _facts_for(storage: Any, personal_model_id: str) -> tuple[Any, ...]: list_facts = getattr(storage, "list_personal_model_facts", None) if not callable(list_facts): @@ -371,7 +379,6 @@ def _facts_for(storage: Any, personal_model_id: str) -> tuple[Any, ...]: return () - def _episode_resume_lines(storage: Any, *, request: Any, session: Any) -> tuple[str, ...]: """Project the episode-open resume note outside the cacheable prefix. @@ -431,8 +438,6 @@ def _clean_state_field(raw: Any) -> str: return text - - def _dynamic_system_layer_lines( storage: Any, request: Any, @@ -443,7 +448,6 @@ def _dynamic_system_layer_lines( return () - def _render_section_line(line: str) -> str: cleaned = str(line or "").strip() if not cleaned: @@ -454,7 +458,12 @@ def _render_section_line(line: str) -> str: def _render_prompt_section(heading: str, lines: tuple[str, ...]) -> str: - return "\n".join((f"### {heading}", *(_render_section_line(line) for line in lines if str(line or "").strip()))).strip() + return "\n".join( + ( + f"### {heading}", + *(_render_section_line(line) for line in lines if str(line or "").strip()), + ) + ).strip() def _append_raw_prompt_section(current: str, section: str) -> str: @@ -516,7 +525,7 @@ def _extract_prompt_section_content(current: str, heading: str) -> str: if lines[index].startswith("### "): end = index break - return "\n".join(lines[start + 1:end]).strip() + return "\n".join(lines[start + 1 : end]).strip() def _strip_prompt_section(current: str, heading: str) -> str: @@ -574,7 +583,12 @@ def _insert_prompt_section_after( def _append_rendered_section(current: str | None, heading: str, lines: tuple[str, ...]) -> str | None: if not lines: return current - section = "\n".join((f"### {heading}", *(_render_section_line(line) for line in lines if str(line or "").strip()))).strip() + section = "\n".join( + ( + f"### {heading}", + *(_render_section_line(line) for line in lines if str(line or "").strip()), + ) + ).strip() existing = str(current or "").strip() return section if not existing else f"{existing}\n\n{section}" diff --git a/packages/kernel/lifecycle_support.py b/packages/kernel/lifecycle_support.py index 4f90f03..78b43ce 100644 --- a/packages/kernel/lifecycle_support.py +++ b/packages/kernel/lifecycle_support.py @@ -32,7 +32,13 @@ class KernelLoopLifecycle: class KernelStepRecorder: - def __init__(self, storage: KernelStoragePort, loop: Loop, *, semantic_summary_indexer: object | None = None) -> None: + def __init__( + self, + storage: KernelStoragePort, + loop: Loop, + *, + semantic_summary_indexer: object | None = None, + ) -> None: self._storage = storage self._loop = loop self._semantic_summary_indexer = semantic_summary_indexer @@ -126,9 +132,7 @@ def resolve_runtime_identity( storage.switch_state(state.state_id, selected_at=current) return KernelRuntimeIdentity(personal_model=personal_model, state=state) - personal_model = storage.ensure_default_personal_model( - personal_model_id=request.personal_model_id or "you" - ) + personal_model = storage.ensure_default_personal_model(personal_model_id=request.personal_model_id or "you") state = storage.current_state() if state is None or state.personal_model_id != personal_model.personal_model_id: state = storage.create_state( @@ -252,7 +256,11 @@ def open_episode_lifecycle( if not is_new and episode.started_at is not None and episode.updated_at is not None: is_new = episode.started_at == episode.updated_at storage.upsert_episode(episode) - return KernelEpisodeLifecycle(episode=episode, close_on_completion=policy == "single_turn", is_new_episode=is_new) + return KernelEpisodeLifecycle( + episode=episode, + close_on_completion=policy == "single_turn", + is_new_episode=is_new, + ) idle_closed: list[Episode] = [] if policy == "gateway_idle_reuse": @@ -267,7 +275,10 @@ def open_episode_lifecycle( refreshed = replace( episode, updated_at=current, - metadata={**dict(episode.metadata), "last_activity_at": current.isoformat()}, + metadata={ + **dict(episode.metadata), + "last_activity_at": current.isoformat(), + }, ) storage.upsert_episode(refreshed) return KernelEpisodeLifecycle(episode=refreshed, close_on_completion=False) @@ -351,7 +362,10 @@ def close_episode_lifecycle( refreshed = replace( lifecycle.episode, updated_at=current, - metadata={**dict(lifecycle.episode.metadata), "last_activity_at": current.isoformat()}, + metadata={ + **dict(lifecycle.episode.metadata), + "last_activity_at": current.isoformat(), + }, ) storage.upsert_episode(refreshed) return refreshed @@ -361,7 +375,10 @@ def close_episode_lifecycle( ended_at=current, updated_at=current, exit_summary=summary, - metadata={**dict(lifecycle.episode.metadata), "closed_reason": "final_response"}, + metadata={ + **dict(lifecycle.episode.metadata), + "closed_reason": "final_response", + }, ) storage.upsert_episode(closed) # Push the exit summary into the semantic index so future episodes can diff --git a/packages/kernel/loop_checkpoint_support.py b/packages/kernel/loop_checkpoint_support.py index d749d0a..b23e9ec 100644 --- a/packages/kernel/loop_checkpoint_support.py +++ b/packages/kernel/loop_checkpoint_support.py @@ -384,16 +384,22 @@ def should_resume(self, prompt: str) -> bool: "keep digging", ) return any( - normalized == phrase - or normalized.startswith(f"{phrase} ") - or normalized.endswith(f" {phrase}") + normalized == phrase or normalized.startswith(f"{phrase} ") or normalized.endswith(f" {phrase}") for phrase in explicit_phrases ) def resume_prompt_for_request(self, run: LoopState, prompt: str) -> str: base = run.continuation_prompt or self.build_continuation_prompt(run, recent_steps=()) normalized = " ".join(prompt.casefold().split()) - if normalized in {"continue", "resume", "keep going", "go on", "carry on", "finish this", "finish it"}: + if normalized in { + "continue", + "resume", + "keep going", + "go on", + "carry on", + "finish this", + "finish it", + }: return base return f"{base}\n\nLatest user nudge:\n{prompt.strip()}" diff --git a/packages/kernel/reconciliation.py b/packages/kernel/reconciliation.py index d3d2ae7..d4ed1b6 100644 --- a/packages/kernel/reconciliation.py +++ b/packages/kernel/reconciliation.py @@ -3,7 +3,6 @@ from __future__ import annotations from dataclasses import dataclass -from datetime import datetime, timezone import re from typing import Mapping, Protocol from uuid import uuid4 @@ -293,17 +292,23 @@ def _extract_user_fields(text: str) -> dict[str, str]: def _extract_preference_updates(text: str) -> tuple[str, ...]: updates: list[str] = [] lower = text.lower() - if re.search(r"(?i)(?:reply|respond|responses|replies|answers|be|keep).{0,24}(?:concise|brief|short)", text) or any( - token in text for token in ("简洁", "简短", "精炼") - ): + if re.search( + r"(?i)(?:reply|respond|responses|replies|answers|be|keep).{0,24}(?:concise|brief|short)", + text, + ) or any(token in text for token in ("简洁", "简短", "精炼")): updates.append("verbosity:concise") - if re.search(r"(?i)(?:reply|respond|responses|replies|answers|be|keep).{0,24}(?:detailed|thorough|long-form)", text) or any( - token in text for token in ("详细", "展开一些") - ): + if re.search( + r"(?i)(?:reply|respond|responses|replies|answers|be|keep).{0,24}(?:detailed|thorough|long-form)", + text, + ) or any(token in text for token in ("详细", "展开一些")): updates.append("verbosity:detailed") - if re.search(r"(?i)(?:reply|respond).{0,16}(?:in chinese)", text) or any(token in text for token in ("用中文", "中文回答", "请中文回答")): + if re.search(r"(?i)(?:reply|respond).{0,16}(?:in chinese)", text) or any( + token in text for token in ("用中文", "中文回答", "请中文回答") + ): updates.append("language:zh-CN") - if re.search(r"(?i)(?:reply|respond).{0,16}(?:in english)", text) or any(token in text for token in ("用英文", "英文回答", "请英文回答")): + if re.search(r"(?i)(?:reply|respond).{0,16}(?:in english)", text) or any( + token in text for token in ("用英文", "英文回答", "请英文回答") + ): updates.append("language:en") if "bullet" in lower or "bullets" in lower or "bullet points" in lower or "要点" in text or "列表" in text: updates.append("response-style:bullets") @@ -317,13 +322,16 @@ def _extract_relationship_notes(text: str) -> tuple[str, ...]: lines = [text.strip()] for line in lines: lowered = line.lower() - if _first_match( - line, - ( - r"(?im)^\s*(?:preferred name|name|nickname|current work|work|work focus)\s*[::]", - r"(?im)^\s*(?:称呼|叫我|当前工作|工作方向|目前在做)\s*[::]", - ), - ) is not None: + if ( + _first_match( + line, + ( + r"(?im)^\s*(?:preferred name|name|nickname|current work|work|work focus)\s*[::]", + r"(?im)^\s*(?:称呼|叫我|当前工作|工作方向|目前在做)\s*[::]", + ), + ) + is not None + ): continue if any( marker in lowered @@ -339,7 +347,19 @@ def _extract_relationship_notes(text: str) -> tuple[str, ...]: "do not call me", "keep it", ) - ) or any(marker in line for marker in ("以后", "记住", "下次", "别叫我", "不要叫我", "回复时", "回答时", "说话时")): + ) or any( + marker in line + for marker in ( + "以后", + "记住", + "下次", + "别叫我", + "不要叫我", + "回复时", + "回答时", + "说话时", + ) + ): cleaned = _clean_capture(line) if cleaned: notes.append(cleaned) @@ -400,7 +420,9 @@ def _transcript_final_assistant_response(messages: tuple[PromptMessage, ...]) -> return "" -def _transcript_tool_result_details(messages: tuple[PromptMessage, ...]) -> tuple[str, ...]: +def _transcript_tool_result_details( + messages: tuple[PromptMessage, ...], +) -> tuple[str, ...]: details: list[str] = [] for message in messages: if message.role != "tool": diff --git a/packages/kernel/resume_support.py b/packages/kernel/resume_support.py index b38d536..5afa2a0 100644 --- a/packages/kernel/resume_support.py +++ b/packages/kernel/resume_support.py @@ -131,9 +131,7 @@ def plan_pending_tool_replay( call=entry, action="skip", reason=( - "tool completed before crash; Step " - + matching_step_ids[0] - + " already records the outcome" + "tool completed before crash; Step " + matching_step_ids[0] + " already records the outcome" ), ) ) @@ -223,9 +221,7 @@ def apply_resume_snapshot( if snapshot.state is None: return None current = now or datetime.now(timezone.utc) - remaining_pending = tuple( - plan.call for plan in snapshot.replay_plans if plan.action in {"replay", "poll"} - ) + remaining_pending = tuple(plan.call for plan in snapshot.replay_plans if plan.action in {"replay", "poll"}) return replace( snapshot.state, pending_tool_calls=remaining_pending, diff --git a/packages/kernel/runtime.py b/packages/kernel/runtime.py index 37ca02c..057f179 100644 --- a/packages/kernel/runtime.py +++ b/packages/kernel/runtime.py @@ -9,4 +9,3 @@ from .runtime_support import * # noqa: F401,F403 from .lifecycle_support import * # noqa: F401,F403 -from .runtime_impl import KernelService diff --git a/packages/kernel/runtime_impl.py b/packages/kernel/runtime_impl.py index 9f20c38..1af5532 100644 --- a/packages/kernel/runtime_impl.py +++ b/packages/kernel/runtime_impl.py @@ -1,20 +1,15 @@ from __future__ import annotations -import concurrent.futures import sys from uuid import uuid4 from .context_compaction import ( - append_episode_continuity_packet, - compact_context_after_usage, compaction_step_metadata, - episode_continuity_packet, flush_projection_cache, latest_compacted_projection, projection_compaction_detail, retry_context_after_provider_overflow, stage_context_projection, - stage_context_usage, ) from .execution_support import execute_kernel_turn from .generation_context import build_context_for_generation @@ -27,6 +22,7 @@ resolve_runtime_identity, ) from .runtime_support import * # noqa: F401,F403 + _SUPPORT_UTC_NOW = _utc_now @@ -68,9 +64,7 @@ def _primary_learning_trigger( outcome = str(getattr(execution, "outcome", "") or "").strip() if outcome == "paused": return "checkpoint" - checkpoint_seen = any( - step.action == "checkpoint" and step.status == "completed" for step in steps - ) + checkpoint_seen = any(step.action == "checkpoint" and step.status == "completed" for step in steps) if checkpoint_seen: return "checkpoint" if outcome == "failed": @@ -106,8 +100,12 @@ def _prompt_for_request_execution( return request.prompt prompt = request.prompt # Time injection: first turn of episode, idle > 1h, or temporal keywords - if _should_inject_time(prompt, is_first_turn=is_first_turn_of_episode, - session_updated_at=previous_updated_at, now=clock.local_datetime): + if _should_inject_time( + prompt, + is_first_turn=is_first_turn_of_episode, + session_updated_at=previous_updated_at, + now=clock.local_datetime, + ): prompt = f"{prompt.rstrip()}\n\n{_time_annotation(clock)}" # Execution strategy hints (multi-source, compare, artifact) prompt = _apply_execution_guidance(prompt) @@ -356,7 +354,10 @@ def stage(name: str, detail: str) -> None: current=_clock_now(), summary=projection_compaction_detail(retry_compaction.result), outcome=str(getattr(retry_compaction.result, "reason", "") or "provider-overflow"), - payload_refs=(*retry_compaction.packet.source_refs, retry_compaction.packet.packet_id), + payload_refs=( + *retry_compaction.packet.source_refs, + retry_compaction.packet.packet_id, + ), metadata=compaction_step_metadata( packet=retry_compaction.packet, result=retry_compaction.result, @@ -461,9 +462,15 @@ def stage(name: str, detail: str) -> None: # episodes (multi-turn sessions) defer learning until explicit close. primary_job = None if episode.status != "closed": - stage("episode_learning", f"deferred episode={episode.episode_id} status={episode.status}") + stage( + "episode_learning", + f"deferred episode={episode.episode_id} status={episode.status}", + ) elif _suppress_primary_learning_for_request(request): - stage("episode_learning", f"suppressed internal turn episode={episode.episode_id}") + stage( + "episode_learning", + f"suppressed internal turn episode={episode.episode_id}", + ) else: primary_trigger = _primary_learning_trigger(execution=execution, steps=step_recorder.steps) primary_job = self._enqueue_episode_learning_job( @@ -559,7 +566,9 @@ def _route_session( previous_updated_at = existing_session.updated_at if existing_session is not None else current return Episode( episode_id=request.route_id, - state_id=existing_session.state_id if existing_session is not None else (request.state_id or "state:default"), + state_id=existing_session.state_id + if existing_session is not None + else (request.state_id or "state:default"), personal_model_id=profile.profile_id, entry_surface=existing_session.entry_surface if existing_session is not None else request.surface, elephant_id=(existing_session.elephant_id if existing_session is not None else "") or "", @@ -602,7 +611,9 @@ def _retrieve_recall_evidence( ) work_item_ids: tuple[str, ...] = () scope_episode_ids = _episode_lineage_ids(self.dependencies.storage, session) - scope_reason = "recovery follows the active episode lineage while allowing elephant and personal-model continuity recall" + scope_reason = ( + "recovery follows the active episode lineage while allowing elephant and personal-model continuity recall" + ) requested_scopes = ["episode"] if session.elephant_id: requested_scopes.append("elephant") @@ -632,9 +643,7 @@ def _retrieve_recall_evidence( work_item_ids=work_item_ids, scope_episode_ids=retrieval.scope_episode_ids, scope_reason=retrieval.scope_reason, - vector_cache_status=str( - getattr(retrieval.recall_reasons, "vector_cache_status", "") or "" - ), + vector_cache_status=str(getattr(retrieval.recall_reasons, "vector_cache_status", "") or ""), ) def _context_for_generation( @@ -762,9 +771,7 @@ def _refresh_state_projection( surface = (request.surface or "").strip().lower() source_event_type = (request.source_event_type or "").strip().lower() is_internal_turn = ( - surface.startswith("cli.startup") - or surface.endswith(".startup") - or source_event_type == "turn.internal" + surface.startswith("cli.startup") or surface.endswith(".startup") or source_event_type == "turn.internal" ) if is_internal_turn: return replace(state, updated_at=current) diff --git a/packages/kernel/runtime_support.py b/packages/kernel/runtime_support.py index 75daea6..87fd4b5 100644 --- a/packages/kernel/runtime_support.py +++ b/packages/kernel/runtime_support.py @@ -9,7 +9,7 @@ from collections.abc import Iterable, Mapping from dataclasses import dataclass, field, replace -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone import hashlib from html import unescape as html_unescape import json @@ -33,19 +33,11 @@ from packages.contracts.layers import Episode, Loop, PersonalModel, State, Step from packages.contracts.runtime import ( LoopState, - LoopStep, ContextBundle, - EvidenceRetrievalRequest, EventEnvelope, ExecutionResult, - PendingToolCall, - RetryState, - StateFocusDecision, RecallEvidence, - PlanDraft, PromptMessage, - PersonalModelRuntimeState, - WaitCondition, ) from packages.tools.tool_result_storage import ( ToolResultBudgetConfig, @@ -251,6 +243,7 @@ def to_event(self) -> EventEnvelope: def event(self) -> EventEnvelope: return self.to_event() + @dataclass(frozen=True, slots=True) class KernelStageRecord: stage: str @@ -283,11 +276,7 @@ def step_ids(self) -> tuple[str, ...]: return tuple(step.step_id for step in self.steps) def step_action_count(self, action: str, *, status: str | None = None) -> int: - return sum( - 1 - for step in self.steps - if step.action == action and (status is None or step.status == status) - ) + return sum(1 for step in self.steps if step.action == action and (status is None or step.status == status)) @property def tool_call_count(self) -> int: @@ -394,9 +383,7 @@ def _parse_execution_tool_calls(result: ExecutionResult) -> _ParsedToolCalls: return _parse_text_tool_calls(result.summary) -_JSON_LITERAL_PATTERN = re.compile( - r"^-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?$" -) +_JSON_LITERAL_PATTERN = re.compile(r"^-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?$") _ARTIFACT_PATH_PATTERNS = ( re.compile( r"""(?ix) @@ -421,11 +408,7 @@ def _decode_text_tool_argument(raw_value: str) -> object: candidate = html_unescape(raw_value).strip() if not candidate: return "" - if ( - candidate[0] in "[{\"" - or candidate in {"true", "false", "null"} - or _JSON_LITERAL_PATTERN.match(candidate) - ): + if candidate[0] in '[{"' or candidate in {"true", "false", "null"} or _JSON_LITERAL_PATTERN.match(candidate): try: return json.loads(candidate) except json.JSONDecodeError: @@ -569,7 +552,9 @@ def _tool_call_signature(call: _TextToolCall) -> str: return f"{call.tool_name}:{payload}" -def _deduplicate_tool_calls(calls: Iterable[_TextToolCall]) -> tuple[_TextToolCall, ...]: +def _deduplicate_tool_calls( + calls: Iterable[_TextToolCall], +) -> tuple[_TextToolCall, ...]: unique: list[_TextToolCall] = [] seen: set[str] = set() for call in calls: @@ -593,10 +578,7 @@ def _should_parallelize_tool_batch(calls: tuple[_TextToolCall, ...]) -> bool: continue if call.tool_name not in _PARALLEL_SAFE_TOOLS: return False - if ( - call.tool_name == "tool.file.read" - and _normalized_tool_path(call.arguments.get("path")) is None - ): + if call.tool_name == "tool.file.read" and _normalized_tool_path(call.arguments.get("path")) is None: return False return True @@ -636,7 +618,9 @@ def _model_turn_summary(result: ExecutionResult, *, parsed: _ParsedToolCalls) -> return result.summary.strip() -def _resolve_clock_timezone(timezone_name: str | None) -> tuple[timezone | ZoneInfo, str]: +def _resolve_clock_timezone( + timezone_name: str | None, +) -> tuple[timezone | ZoneInfo, str]: candidate = str(timezone_name or os.environ.get("ELEPHANT_TIMEZONE") or os.environ.get("TZ") or "").strip() if candidate: try: @@ -766,7 +750,9 @@ def _apply_execution_guidance(prompt: str) -> str: lines = ["Execution guidance for this turn:"] if multi_source: lines.append("- Use more than one tool step and at least two distinct sources before concluding.") - lines.append("- Preferred flow: tool.web.search first, then tool.web.extract or multiple tool.web.read calls, then synthesize.") + lines.append( + "- Preferred flow: tool.web.search first, then tool.web.extract or multiple tool.web.read calls, then synthesize." + ) if compare_request: lines.append("- Compare approaches explicitly instead of returning a single-source note.") if artifact_request is not None: @@ -790,7 +776,13 @@ def _looks_like_multi_source_research_request(normalized_prompt: str) -> bool: "survey", "investigate", ) - synthesis_markers = ("summary", "summarize", "write a summary", "report", "overview") + synthesis_markers = ( + "summary", + "summarize", + "write a summary", + "report", + "overview", + ) return any(marker in normalized_prompt for marker in research_markers) and ( any(marker in normalized_prompt for marker in synthesis_markers) or any(marker in normalized_prompt for marker in ("compare", "comparison", "latest")) diff --git a/packages/models/auth_headers.py b/packages/models/auth_headers.py index 9377f31..7b5ce89 100644 --- a/packages/models/auth_headers.py +++ b/packages/models/auth_headers.py @@ -6,10 +6,7 @@ from typing import Mapping, Protocol, runtime_checkable ANTHROPIC_OAUTH_BETA_HEADER = ( - "interleaved-thinking-2025-05-14," - "fine-grained-tool-streaming-2025-05-14," - "claude-code-20250219," - "oauth-2025-04-20" + "interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14,claude-code-20250219,oauth-2025-04-20" ) @@ -76,10 +73,7 @@ def default(cls) -> "InMemoryAuthHeaderStrategyRegistry": def _is_anthropic_oauth_token(api_key: str) -> bool: - return ( - (api_key.startswith("sk-ant-") and not api_key.startswith("sk-ant-api")) - or api_key.startswith("eyJ") - ) + return (api_key.startswith("sk-ant-") and not api_key.startswith("sk-ant-api")) or api_key.startswith("eyJ") class _CopilotBearerStrategy: @@ -105,8 +99,7 @@ def supports(self, context: AuthHeaderContext) -> bool: api_key = str(context.api_key or "").strip() provider_id = context.provider_id.strip().lower() return bool(api_key) and ( - provider_id == "claude-code" - or (provider_id == "anthropic" and _is_anthropic_oauth_token(api_key)) + provider_id == "claude-code" or (provider_id == "anthropic" and _is_anthropic_oauth_token(api_key)) ) def build_headers(self, context: AuthHeaderContext) -> Mapping[str, str]: @@ -131,9 +124,7 @@ def supports(self, context: AuthHeaderContext) -> bool: api_key = str(context.api_key or "").strip() provider_id = context.provider_id.strip().lower() request_family = context.request_family.strip().lower() - return bool(api_key) and ( - request_family == "messages" or provider_id in {"anthropic", "minimax"} - ) + return bool(api_key) and (request_family == "messages" or provider_id in {"anthropic", "minimax"}) def build_headers(self, context: AuthHeaderContext) -> Mapping[str, str]: api_key = str(context.api_key or "").strip() diff --git a/packages/models/bootstrap.py b/packages/models/bootstrap.py index 77106ab..2a000cd 100644 --- a/packages/models/bootstrap.py +++ b/packages/models/bootstrap.py @@ -126,8 +126,7 @@ def _embedding_bootstrap_summary( return f"local embedding root is available at {EMBEDDING_MODEL_ROOT}" if status == "pending": return ( - "local semantic-index bootstrap is preparing minimal " - "sentence-transformers dependencies in the background." + "local semantic-index bootstrap is preparing minimal sentence-transformers dependencies in the background." ) if status == "downloading": return ( @@ -135,15 +134,16 @@ def _embedding_bootstrap_summary( f"model acquisition from {EMBEDDING_MODEL_SOURCE_URL} is in progress." ) if status == "failed": - detail = str(failure_message or "embedding bootstrap request failed").strip() or "embedding bootstrap request failed" + detail = ( + str(failure_message or "embedding bootstrap request failed").strip() or "embedding bootstrap request failed" + ) return f"local semantic-index bootstrap remains non-blocking after a failure: {detail}" - return ( - "local semantic-index bootstrap is waiting " - "for the background worker to report state." - ) + return "local semantic-index bootstrap is waiting for the background worker to report state." -def embedding_bootstrap_state_from_payload(payload: Mapping[str, Any]) -> EmbeddingBootstrapState: +def embedding_bootstrap_state_from_payload( + payload: Mapping[str, Any], +) -> EmbeddingBootstrapState: status = _normalize_embedding_bootstrap_status(payload.get("status")) state_focus_mode = _normalize_state_focus_mode(payload.get("state_focus_mode")) failure_message = str(payload.get("failure_message") or "").strip() or None @@ -159,8 +159,7 @@ def embedding_bootstrap_state_from_payload(payload: Mapping[str, Any]) -> Embedd model_id = str(payload.get("model_id") or EMBEDDING_MODEL_ID).strip() or EMBEDDING_MODEL_ID model_root = str(payload.get("model_root") or EMBEDDING_MODEL_ROOT).strip() or str(EMBEDDING_MODEL_ROOT) model_source_url = ( - str(payload.get("model_source_url") or EMBEDDING_MODEL_SOURCE_URL).strip() - or EMBEDDING_MODEL_SOURCE_URL + str(payload.get("model_source_url") or EMBEDDING_MODEL_SOURCE_URL).strip() or EMBEDDING_MODEL_SOURCE_URL ) source_raw = str(payload.get("source") or "huggingface").strip().lower() source = source_raw if source_raw in _ALLOWED_EMBEDDING_SOURCES else "huggingface" @@ -179,7 +178,9 @@ def embedding_bootstrap_state_from_payload(payload: Mapping[str, Any]) -> Embedd ) -def load_embedding_bootstrap_state(state_dir: Path | None) -> EmbeddingBootstrapState | None: +def load_embedding_bootstrap_state( + state_dir: Path | None, +) -> EmbeddingBootstrapState | None: path = embedding_bootstrap_state_path(state_dir) if path is None or not path.exists(): return None @@ -298,11 +299,7 @@ def resolve_embedding_bootstrap_state( normalized_state_focus_mode = _normalize_state_focus_mode(state_focus_mode) stored = load_embedding_bootstrap_state(state_dir) if normalized_state_focus_mode == "skip": - updated_at = ( - stored.updated_at - if stored is not None and stored.state_focus_mode == "skip" - else _utc_now_iso() - ) + updated_at = stored.updated_at if stored is not None and stored.state_focus_mode == "skip" else _utc_now_iso() return EmbeddingBootstrapState( status="skipped", summary=_embedding_bootstrap_summary(state_focus_mode="skip", status="skipped"), @@ -351,9 +348,7 @@ def run_embedding_bootstrap_worker(state_dir_arg: str) -> int: _embedding_bootstrap_state_for_runtime(status="pending", background_pid=current_pid, source=source), ) pip_specs = ( - _EMBEDDING_BOOTSTRAP_PIP_SPECS_MODELSCOPE - if source == "modelscope" - else _EMBEDDING_BOOTSTRAP_PIP_SPECS + _EMBEDDING_BOOTSTRAP_PIP_SPECS_MODELSCOPE if source == "modelscope" else _EMBEDDING_BOOTSTRAP_PIP_SPECS ) subprocess.check_call( [ @@ -423,9 +418,12 @@ def trigger_embedding_bootstrap( normalized_source = effective_source if effective_source in _ALLOWED_EMBEDDING_SOURCES else "huggingface" retryable = resolved if resolved.status == "failed": - retryable = _embedding_bootstrap_state_for_runtime(status="pending", background_pid=None, source=normalized_source) + retryable = _embedding_bootstrap_state_for_runtime( + status="pending", background_pid=None, source=normalized_source + ) else: from dataclasses import replace as _dc_replace + retryable = _dc_replace(retryable, source=normalized_source) try: spawned = _spawn_embedding_bootstrap_worker(state_dir, retryable) diff --git a/packages/models/discovery.py b/packages/models/discovery.py index 556437a..e9e8f2d 100644 --- a/packages/models/discovery.py +++ b/packages/models/discovery.py @@ -12,7 +12,7 @@ from packages.auth.runtime import AuthProfile, SecretValueResolution from .model_metadata import resolve_provider_model_metadata -from .provider_catalog import default_provider_definitions, provider_definition +from .provider_catalog import provider_definition from .provider_runtime import ProviderRuntimeResolver, provider_auth_headers RequestJsonCallable = Callable[..., dict[str, Any]] @@ -307,7 +307,9 @@ def heuristic_context_window(model_id: str) -> int | None: return DEFAULT_CONTEXT_WINDOW_TOKENS -def _hinted_models(provider_id: str, *, runtime_resolver: ProviderRuntimeResolver) -> tuple[DiscoveredProviderModel, ...]: +def _hinted_models( + provider_id: str, *, runtime_resolver: ProviderRuntimeResolver +) -> tuple[DiscoveredProviderModel, ...]: definition = provider_definition(provider_id) if definition is None: return () @@ -438,9 +440,7 @@ def list_models(self, context: ProviderModelLookupContext) -> tuple[DiscoveredPr raw_efforts = supports_payload.get("reasoning_effort") if isinstance(raw_efforts, list): reasoning_efforts = tuple( - str(value).strip().lower() - for value in raw_efforts - if str(value).strip() + str(value).strip().lower() for value in raw_efforts if str(value).strip() ) models.append( DiscoveredProviderModel( @@ -516,7 +516,11 @@ def detect_context_window(self, context: ProviderModelLookupContext) -> int | No with request.urlopen(http_request, timeout=5.0) as response: raw_body = response.read().decode("utf-8") payload = json.loads(raw_body) if raw_body else {} - except (error.HTTPError, error.URLError, json.JSONDecodeError): # pragma: no cover + except ( + error.HTTPError, + error.URLError, + json.JSONDecodeError, + ): # pragma: no cover return None if not isinstance(payload, Mapping): return None @@ -584,12 +588,16 @@ def detect_context_window( extra_headers: Mapping[str, str] | None = None, hinted_models: tuple[DiscoveredProviderModel, ...] | None = None, ) -> int | None: - models = hinted_models if hinted_models is not None else self.discover_models( - provider_id=provider_id, - base_url=base_url, - api_key=api_key, - extra_headers=extra_headers, - default_model_id=model_id, + models = ( + hinted_models + if hinted_models is not None + else self.discover_models( + provider_id=provider_id, + base_url=base_url, + api_key=api_key, + extra_headers=extra_headers, + default_model_id=model_id, + ) ) normalized_provider_id = provider_id.strip().lower() for item in models: diff --git a/packages/models/ephemeral_injection.py b/packages/models/ephemeral_injection.py index c6a37f7..3def0a8 100644 --- a/packages/models/ephemeral_injection.py +++ b/packages/models/ephemeral_injection.py @@ -244,8 +244,7 @@ def resolve( if cached is not None and cached[0] == cache_key: return cached[1] blocks = tuple( - safe_call_block_builder(builder, profile, session, context, prompt, query) - for builder in builders + safe_call_block_builder(builder, profile, session, context, prompt, query) for builder in builders ) if episode_id: self._entries[episode_id] = (cache_key, blocks) diff --git a/packages/models/model_metadata.py b/packages/models/model_metadata.py index 2265e8a..1553da7 100644 --- a/packages/models/model_metadata.py +++ b/packages/models/model_metadata.py @@ -24,7 +24,11 @@ _ENDPOINT_CACHE_TTL_SECONDS = 300 _LOCAL_HOSTS = frozenset({"localhost", "127.0.0.1", "::1", "0.0.0.0"}) -_CONTAINER_LOCAL_SUFFIXES = (".docker.internal", ".containers.internal", ".lima.internal") +_CONTAINER_LOCAL_SUFFIXES = ( + ".docker.internal", + ".containers.internal", + ".lima.internal", +) _CONTEXT_KEYS = ( "context_length", @@ -263,7 +267,11 @@ def query_local_endpoint_metadata( props = _request_json_object(f"{server_root}/v1/props", headers=_bearer_headers(api_key), timeout_seconds=2.0) if not props: - props = _request_json_object(f"{server_root}/props", headers=_bearer_headers(api_key), timeout_seconds=2.0) + props = _request_json_object( + f"{server_root}/props", + headers=_bearer_headers(api_key), + timeout_seconds=2.0, + ) if props: context = _extract_nested_int(props, _CONTEXT_KEYS) if context is not None: @@ -430,7 +438,7 @@ def _request_json_object( try: with request.urlopen(http_request, timeout=timeout_seconds) as response: raw_body = response.read().decode("utf-8") - except (error.HTTPError, error.URLError, TimeoutError) as exc: + except (error.HTTPError, error.URLError, TimeoutError): return None try: payload = json.loads(raw_body) if raw_body else {} diff --git a/packages/models/provider_catalog.py b/packages/models/provider_catalog.py index 956f5a2..19400e2 100644 --- a/packages/models/provider_catalog.py +++ b/packages/models/provider_catalog.py @@ -160,7 +160,11 @@ class ProviderDefinition: required_secret_keys=("api_key",), required_config_keys=("model_id",), capability_flags=("chat", "embeddings"), - model_hints=("openai/gpt-4o-mini", "anthropic/claude-3.7-sonnet", "google/gemini-2.5-pro"), + model_hints=( + "openai/gpt-4o-mini", + "anthropic/claude-3.7-sonnet", + "google/gemini-2.5-pro", + ), supports_custom_base_url=False, listing_priority=25, provider_kind="aggregator", @@ -319,7 +323,11 @@ class ProviderDefinition: required_secret_keys=("api_key",), required_config_keys=("model_id",), capability_flags=("chat", "embeddings"), - model_hints=("llama-3.3-70b-versatile", "qwen-qwq-32b", "deepseek-r1-distill-llama-70b"), + model_hints=( + "llama-3.3-70b-versatile", + "qwen-qwq-32b", + "deepseek-r1-distill-llama-70b", + ), supports_custom_base_url=False, listing_priority=40, provider_kind="first_party", @@ -388,7 +396,11 @@ class ProviderDefinition: required_secret_keys=("api_key",), required_config_keys=("model_id",), capability_flags=("chat", "embeddings"), - model_hints=("mistral-small-latest", "mistral-medium-latest", "codestral-latest"), + model_hints=( + "mistral-small-latest", + "mistral-medium-latest", + "codestral-latest", + ), supports_custom_base_url=False, listing_priority=55, provider_kind="first_party", @@ -411,7 +423,11 @@ class ProviderDefinition: required_secret_keys=("api_key",), required_config_keys=("model_id",), capability_flags=("chat", "embeddings"), - model_hints=("meta-llama/Llama-4-Scout-17B-16E-Instruct", "deepseek-ai/DeepSeek-V3.1", "Qwen/Qwen3-235B-A22B-Instruct-2507"), + model_hints=( + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "deepseek-ai/DeepSeek-V3.1", + "Qwen/Qwen3-235B-A22B-Instruct-2507", + ), supports_custom_base_url=False, listing_priority=60, provider_kind="aggregator", @@ -434,7 +450,10 @@ class ProviderDefinition: required_secret_keys=("api_key",), required_config_keys=("model_id",), capability_flags=("chat", "embeddings"), - model_hints=("accounts/fireworks/models/deepseek-v3", "accounts/fireworks/models/llama-v3p1-70b-instruct"), + model_hints=( + "accounts/fireworks/models/deepseek-v3", + "accounts/fireworks/models/llama-v3p1-70b-instruct", + ), supports_custom_base_url=False, listing_priority=62, provider_kind="aggregator", @@ -530,7 +549,15 @@ class ProviderDefinition: required_secret_keys=("api_key",), required_config_keys=("model_id",), capability_flags=("chat",), - model_hints=("MiniMax-M2.7", "MiniMax-M2.7-highspeed", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1", "MiniMax-M2.1-highspeed", "MiniMax-M2"), + model_hints=( + "MiniMax-M2.7", + "MiniMax-M2.7-highspeed", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + "MiniMax-M2.1", + "MiniMax-M2.1-highspeed", + "MiniMax-M2", + ), supports_custom_base_url=False, listing_priority=70, provider_kind="first_party", @@ -554,7 +581,15 @@ class ProviderDefinition: required_secret_keys=("api_key",), required_config_keys=("model_id",), capability_flags=("chat",), - model_hints=("MiniMax-M2.7", "MiniMax-M2.7-highspeed", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1", "MiniMax-M2.1-highspeed", "MiniMax-M2"), + model_hints=( + "MiniMax-M2.7", + "MiniMax-M2.7-highspeed", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + "MiniMax-M2.1", + "MiniMax-M2.1-highspeed", + "MiniMax-M2", + ), supports_custom_base_url=False, listing_priority=72, provider_kind="first_party", @@ -579,7 +614,16 @@ class ProviderDefinition: required_secret_keys=("api_key",), required_config_keys=("model_id",), capability_flags=("chat", "embeddings"), - model_hints=("glm-5.1", "glm-5", "glm-5-turbo", "glm-4.7", "glm-4.7-flashx", "glm-4.6", "glm-4.5-air", "glm-4-long"), + model_hints=( + "glm-5.1", + "glm-5", + "glm-5-turbo", + "glm-4.7", + "glm-4.7-flashx", + "glm-4.6", + "glm-4.5-air", + "glm-4-long", + ), supports_custom_base_url=False, listing_priority=73, docs_url="https://docs.bigmodel.cn/cn/api/introduction", @@ -676,7 +720,11 @@ class ProviderDefinition: required_secret_keys=("api_key",), required_config_keys=("model_id",), capability_flags=("chat", "embeddings"), - model_hints=("openai/gpt-oss-120b", "meta-llama/Llama-3.3-70B-Instruct", "Qwen/Qwen3-235B-A22B-Instruct-2507"), + model_hints=( + "openai/gpt-oss-120b", + "meta-llama/Llama-3.3-70B-Instruct", + "Qwen/Qwen3-235B-A22B-Instruct-2507", + ), supports_custom_base_url=False, listing_priority=76, provider_kind="aggregator", @@ -750,7 +798,11 @@ class ProviderDefinition: required_secret_keys=("api_key",), required_config_keys=("model_id",), capability_flags=("chat", "embeddings"), - model_hints=("google/gemini-3-flash-preview", "openai/gpt-5.4", "anthropic/claude-sonnet-4.6"), + model_hints=( + "google/gemini-3-flash-preview", + "openai/gpt-5.4", + "anthropic/claude-sonnet-4.6", + ), supports_custom_base_url=False, listing_priority=79, provider_kind="aggregator", @@ -796,7 +848,11 @@ class ProviderDefinition: required_secret_keys=(), required_config_keys=("base_url", "model_id"), capability_flags=("chat", "embeddings"), - model_hints=("Qwen/Qwen2.5-7B-Instruct", "meta-llama/Llama-3.1-8B-Instruct", "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"), + model_hints=( + "Qwen/Qwen2.5-7B-Instruct", + "meta-llama/Llama-3.1-8B-Instruct", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", + ), supports_custom_base_url=True, listing_priority=85, provider_kind="self_hosted", diff --git a/packages/models/provider_runtime.py b/packages/models/provider_runtime.py index 67ef011..6c9457c 100644 --- a/packages/models/provider_runtime.py +++ b/packages/models/provider_runtime.py @@ -550,9 +550,7 @@ def _manifest_for(self, provider_id: str) -> ProviderManifest: def _transport_for_manifest(self, manifest: ProviderManifest) -> ProviderTransportDefinition: transport = self.transport_registry.get(manifest.transport_id) if transport is None: - raise LookupError( - f"no transport registered for manifest transport id: {manifest.transport_id}" - ) + raise LookupError(f"no transport registered for manifest transport id: {manifest.transport_id}") return transport def _transport_for_transport_id(self, transport_id: str) -> ProviderTransportDefinition: diff --git a/packages/models/providers/anthropic.py b/packages/models/providers/anthropic.py index 5ebb097..8c7deb6 100644 --- a/packages/models/providers/anthropic.py +++ b/packages/models/providers/anthropic.py @@ -18,13 +18,14 @@ ContextBundle, ExecutionResult, ExecutionToolCall, - RuntimeModelChoice, PromptMessage, PersonalModelRuntimeState, - GenerationModelProfile, - SupportModelProfile, ) -from packages.models.provider_runtime import ProviderRuntimeResolution, attach_session_header, provider_auth_headers +from packages.models.provider_runtime import ( + ProviderRuntimeResolution, + attach_session_header, + provider_auth_headers, +) from packages.models.runtime import ( CredentialSource, ModelAdapter, @@ -43,7 +44,13 @@ ANTHROPIC_API_VERSION = "2023-06-01" ANTHROPIC_ENDPOINT_PATH = "/v1/messages" ANTHROPIC_REQUEST_FAMILY = "messages" -THINKING_BUDGET = {"max": 64000, "xhigh": 32000, "high": 16000, "medium": 8000, "low": 4000} +THINKING_BUDGET = { + "max": 64000, + "xhigh": 32000, + "high": 16000, + "medium": 8000, + "low": 4000, +} ADAPTIVE_EFFORT_MAP = { "max": "max", "xhigh": "xhigh", @@ -144,7 +151,11 @@ def as_mapping(self) -> dict[str, object]: if self.system: if self._supports_cache_control(): payload["system"] = [ - {"type": "text", "text": self.system, "cache_control": {"type": "ephemeral"}}, + { + "type": "text", + "text": self.system, + "cache_control": {"type": "ephemeral"}, + }, ] else: payload["system"] = self.system @@ -155,7 +166,10 @@ def as_mapping(self) -> dict[str, object]: if self.tools: tools_list = [dict(tool) for tool in self.tools] if tools_list and self._supports_cache_control(): - tools_list[-1] = {**tools_list[-1], "cache_control": {"type": "ephemeral"}} + tools_list[-1] = { + **tools_list[-1], + "cache_control": {"type": "ephemeral"}, + } payload["tools"] = tools_list if self.thinking: payload["thinking"] = dict(self.thinking) @@ -372,11 +386,7 @@ def parse_response( AnthropicContentBlock( type=block_type, text=combined.content, - metadata={ - key: str(value) - for key, value in block.items() - if key not in {"type", "text"} - }, + metadata={key: str(value) for key, value in block.items() if key not in {"type", "text"}}, ) ) usage_payload = payload.get("usage", {}) @@ -387,8 +397,7 @@ def parse_response( cached_prompt_tokens=int(usage_payload.get("cache_read_input_tokens", 0) or 0), cache_creation_prompt_tokens=int(usage_payload.get("cache_creation_input_tokens", 0) or 0), cache_usage_reported=( - "cache_read_input_tokens" in usage_payload - or "cache_creation_input_tokens" in usage_payload + "cache_read_input_tokens" in usage_payload or "cache_creation_input_tokens" in usage_payload ), ) return AnthropicMessagesResponse( @@ -578,7 +587,10 @@ def _anthropic_tool_use_block( tool_name_map: Mapping[str, str], ) -> AnthropicContentBlock: call_id = str(call.get("id") or call.get("call_id") or "").strip() or "toolu_context" - name = self._provider_tool_name(str(call.get("name") or call.get("tool_name") or ""), tool_name_map=tool_name_map) + name = self._provider_tool_name( + str(call.get("name") or call.get("tool_name") or ""), + tool_name_map=tool_name_map, + ) arguments = call.get("arguments") return AnthropicContentBlock( type="tool_use", diff --git a/packages/models/providers/http.py b/packages/models/providers/http.py index 095b42c..8402bdb 100644 --- a/packages/models/providers/http.py +++ b/packages/models/providers/http.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime, timedelta, timezone from html import unescape as html_unescape import json @@ -247,6 +247,7 @@ def _provider_http_error(self, exc: error.HTTPError, *, url: str | None = None) retry_after_raw = headers_dict.get("retry-after") if retry_after_raw is not None: from packages.harness.retry_policy import parse_retry_after + retry_after_s = parse_retry_after(retry_after_raw) message = self._status_error_message( status_code=int(exc.code), @@ -350,9 +351,7 @@ def _post_json_with_curl( ) -> JSONHTTPResponse: curl = shutil.which("curl") if curl is None: - raise RuntimeError( - f"provider request failed for {url}: curl is unavailable for TLS fallback" - ) + raise RuntimeError(f"provider request failed for {url}: curl is unavailable for TLS fallback") status_marker = "__ELEPHANT_STATUS__:" max_time = max(1, int(round(self.timeout_seconds))) connect_timeout = max(1, min(10, max_time)) @@ -418,9 +417,7 @@ def _post_json_stream_with_curl( ): curl = shutil.which("curl") if curl is None: - raise RuntimeError( - f"provider request failed for {url}: curl is unavailable for TLS fallback" - ) + raise RuntimeError(f"provider request failed for {url}: curl is unavailable for TLS fallback") status_marker = "__ELEPHANT_STATUS__:" max_time = max(1, int(round(self.stream_timeout_seconds))) connect_timeout = max(1, min(10, max_time)) diff --git a/packages/models/providers/identity_contract.py b/packages/models/providers/identity_contract.py index e6739d2..9f9d690 100644 --- a/packages/models/providers/identity_contract.py +++ b/packages/models/providers/identity_contract.py @@ -84,9 +84,7 @@ def build_provider_messages(request: ModelRequest) -> tuple[PromptMessage, ...]: """Return the role-preserved message projection for a provider request.""" normalized_messages = _normalized_messages(request.messages) - messages: list[PromptMessage] = [ - PromptMessage(role="system", content=build_provider_system_prompt(request)) - ] + messages: list[PromptMessage] = [PromptMessage(role="system", content=build_provider_system_prompt(request))] messages.extend(message for message in normalized_messages if message.role != "system") prompt = build_provider_user_prompt(request) if prompt: @@ -94,7 +92,9 @@ def build_provider_messages(request: ModelRequest) -> tuple[PromptMessage, ...]: return tuple(message for message in messages if message.content.strip() or message.tool_calls) -def _normalized_messages(messages: Iterable[PromptMessage]) -> tuple[PromptMessage, ...]: +def _normalized_messages( + messages: Iterable[PromptMessage], +) -> tuple[PromptMessage, ...]: normalized: list[PromptMessage] = [] for message in messages: role = str(message.role or "").strip().lower() diff --git a/packages/models/providers/message_payloads.py b/packages/models/providers/message_payloads.py index baa8fe2..4f56b02 100644 --- a/packages/models/providers/message_payloads.py +++ b/packages/models/providers/message_payloads.py @@ -93,7 +93,10 @@ def _openai_chat_tool_call_payload( call_id = str(call.get("id") or call.get("call_id") or "").strip() or "call_context" if len(call_id) > 64: call_id = call_id[:64] - name = _provider_tool_alias_for_message(str(call.get("name") or call.get("tool_name") or ""), tool_name_map=tool_name_map) + name = _provider_tool_alias_for_message( + str(call.get("name") or call.get("tool_name") or ""), + tool_name_map=tool_name_map, + ) arguments = _tool_call_arguments(call.get("arguments")) return { "id": call_id, @@ -113,7 +116,10 @@ def _openai_responses_function_call_payload( call_id = str(call.get("id") or call.get("call_id") or "").strip() or "call_context" if len(call_id) > 64: call_id = call_id[:64] - name = _provider_tool_alias_for_message(str(call.get("name") or call.get("tool_name") or ""), tool_name_map=tool_name_map) + name = _provider_tool_alias_for_message( + str(call.get("name") or call.get("tool_name") or ""), + tool_name_map=tool_name_map, + ) return { "type": "function_call", "call_id": call_id, diff --git a/packages/models/providers/openai_compatible.py b/packages/models/providers/openai_compatible.py index 445a767..38f1807 100644 --- a/packages/models/providers/openai_compatible.py +++ b/packages/models/providers/openai_compatible.py @@ -14,14 +14,33 @@ from packages.contracts.runtime import ExecutionToolCall -from ..provider_runtime import ProviderRuntimeResolution, ProviderRuntimeResolver, attach_session_header -from ..runtime import CredentialSource, ModelAdapterDescriptor, ModelEmbeddingResult, ModelRequest, ModelTextResult, ModelUsage +from ..provider_runtime import ( + ProviderRuntimeResolution, + ProviderRuntimeResolver, + attach_session_header, +) +from ..runtime import ( + CredentialSource, + ModelAdapterDescriptor, + ModelEmbeddingResult, + ModelRequest, + ModelTextResult, + ModelUsage, +) from ._tool_names import provider_tool_name from .identity_contract import build_provider_messages, build_provider_system_prompt from .http import JSONHTTPTransport, UrllibJSONHTTPTransport -from .message_payloads import openai_chat_messages_payload, openai_responses_input_payload +from .message_payloads import ( + openai_chat_messages_payload, + openai_responses_input_payload, +) from .openai_usage import openai_compatible_usage_from_payload -from ..reasoning_parser import combine_reasoning_text, normalize_reasoning_text, split_reasoning_and_content, stitch_text_fragments +from ..reasoning_parser import ( + combine_reasoning_text, + normalize_reasoning_text, + split_reasoning_and_content, + stitch_text_fragments, +) _SCHEMA_TYPE_PREFERENCE = ("string", "object", "array", "integer", "number", "boolean") @@ -212,7 +231,13 @@ def _generate_streaming( if delta: text_parts.append(delta) self._emit_stream_delta(delta, reasoning=False) - if any((chunk_usage.prompt_tokens, chunk_usage.completion_tokens, chunk_usage.total_tokens)): + if any( + ( + chunk_usage.prompt_tokens, + chunk_usage.completion_tokens, + chunk_usage.total_tokens, + ) + ): usage = chunk_usage if plan.request_family == "responses": payload = ( @@ -400,8 +425,7 @@ def _resolve_provider_id(self, request: ModelRequest) -> str: provider_id = request.provider_id or self.config.provider_id if provider_id != self.config.provider_id: raise ValueError( - f"request provider_id {provider_id!r} does not match adapter provider_id " - f"{self.config.provider_id!r}" + f"request provider_id {provider_id!r} does not match adapter provider_id {self.config.provider_id!r}" ) return provider_id @@ -494,11 +518,7 @@ def _build_payload( "messages": messages, "stream": should_stream, } - if ( - request.reasoning_effort - and resolution.supports_reasoning - and resolution.provider_id == "copilot" - ): + if request.reasoning_effort and resolution.supports_reasoning and resolution.provider_id == "copilot": payload["reasoning_effort"] = request.reasoning_effort if chat_tools: payload["tools"] = chat_tools @@ -568,7 +588,9 @@ def _normalized_tool_definitions( ) if not normalized: continue - normalized = self._sanitize_tool_definition(normalized, request_family=request_family, strict_schema=strict_schema) + normalized = self._sanitize_tool_definition( + normalized, request_family=request_family, strict_schema=strict_schema + ) if not normalized: continue if request_family == "responses": @@ -595,7 +617,11 @@ def _normalized_tool_definitions( return normalized_tools, tool_name_map def _requires_strict_tool_schema(self, resolution: ProviderRuntimeResolution) -> bool: - return resolution.request_family == "responses" or resolution.provider_id in {"copilot", "openai", "openai-codex"} + return resolution.request_family == "responses" or resolution.provider_id in { + "copilot", + "openai", + "openai-codex", + } def _sanitize_tool_definition( self, @@ -648,9 +674,7 @@ def _sanitize_json_schema(self, payload: object, *, strict: bool) -> dict[str, o required = schema.get("required") if isinstance(required, (list, tuple)) and normalized_properties: normalized["required"] = [ - str(item) - for item in required - if str(item).strip() and str(item) in normalized_properties + str(item) for item in required if str(item).strip() and str(item) in normalized_properties ] if resolved_type == "array": items = schema.get("items") @@ -747,9 +771,7 @@ def _extract_text_content( return content if isinstance(content, list): texts = [ - str(block.get("text", "")) - for block in content - if isinstance(block, dict) and block.get("text") + str(block.get("text", "")) for block in content if isinstance(block, dict) and block.get("text") ] if texts: return "".join(texts) @@ -781,7 +803,12 @@ def _extract_reasoning_content( ) -> str: parts: list[str] = [] if request_family == "responses": - for key in ("reasoning", "thinking", "reasoning_content", "thinking_content"): + for key in ( + "reasoning", + "thinking", + "reasoning_content", + "thinking_content", + ): value = payload.get(key) if isinstance(value, str) and value.strip(): parts.append(value) @@ -803,7 +830,12 @@ def _extract_reasoning_content( message = choice.get("message") if not isinstance(message, Mapping): continue - for key in ("reasoning", "reasoning_content", "thinking", "thinking_content"): + for key in ( + "reasoning", + "reasoning_content", + "thinking", + "thinking_content", + ): value = message.get(key) if isinstance(value, str) and value.strip(): parts.append(value) @@ -825,12 +857,25 @@ def _reasoning_text_from_node(self, payload: object, *, hinted_reasoning: bool) node_type = str(payload.get("type") or "").strip().lower() effective_hint = hinted_reasoning or self._is_reasoning_type(node_type) parts: list[str] = [] - for key in ("text", "output_text", "reasoning", "reasoning_content", "thinking", "thinking_content"): + for key in ( + "text", + "output_text", + "reasoning", + "reasoning_content", + "thinking", + "thinking_content", + ): value = payload.get(key) if isinstance(value, str) and value.strip(): - parts.append(value.strip() if effective_hint or key != "text" else split_reasoning_and_content(value, streaming=False).reasoning) + parts.append( + value.strip() + if effective_hint or key != "text" + else split_reasoning_and_content(value, streaming=False).reasoning + ) elif isinstance(value, (list, tuple, Mapping)): - parts.append(self._reasoning_text_from_node(value, hinted_reasoning=effective_hint or key != "text")) + parts.append( + self._reasoning_text_from_node(value, hinted_reasoning=effective_hint or key != "text") + ) content = payload.get("content") if isinstance(content, (list, tuple, Mapping, str)): parts.append(self._reasoning_text_from_node(content, hinted_reasoning=effective_hint)) @@ -978,7 +1023,11 @@ def _tool_call_from_payload( arguments = self._tool_arguments_from_payload(function.get("arguments")) else: payload_type = str(payload.get("type") or "").strip() - if payload_type and payload_type not in {"function_call", "tool_call", "function"} and "tool" not in payload_type: + if ( + payload_type + and payload_type not in {"function_call", "tool_call", "function"} + and "tool" not in payload_type + ): return None name = str(payload.get("name") or payload.get("tool_name") or "").strip() arguments = self._tool_arguments_from_payload(payload.get("arguments") or payload.get("input")) @@ -1040,7 +1089,9 @@ def _extract_stream_text_delta( fragments.extend( str(block.get("text", "")) for block in content - if isinstance(block, Mapping) and block.get("text") and not self._is_reasoning_type(block.get("type")) + if isinstance(block, Mapping) + and block.get("text") + and not self._is_reasoning_type(block.get("type")) ) return "".join(fragments) @@ -1071,7 +1122,12 @@ def _extract_stream_reasoning_delta( delta = choice.get("delta") if not isinstance(delta, Mapping): continue - for key in ("reasoning", "reasoning_content", "thinking", "thinking_content"): + for key in ( + "reasoning", + "reasoning_content", + "thinking", + "thinking_content", + ): value = delta.get(key) if isinstance(value, str) and value.strip(): parts.append(value) diff --git a/packages/models/providers/registry.py b/packages/models/providers/registry.py index 7c9ebb8..3b16aa6 100644 --- a/packages/models/providers/registry.py +++ b/packages/models/providers/registry.py @@ -6,11 +6,17 @@ from typing import Mapping, Protocol, runtime_checkable from packages.auth.runtime import AuthProfile -from packages.models.provider_runtime import ProviderRuntimeResolution, ProviderRuntimeResolver +from packages.models.provider_runtime import ( + ProviderRuntimeResolution, + ProviderRuntimeResolver, +) from packages.models.runtime import CredentialSource, ModelAdapter from .anthropic import AnthropicMessagesModelAdapter -from .openai_compatible import OpenAICompatibleProviderAdapter, OpenAICompatibleProviderConfig +from .openai_compatible import ( + OpenAICompatibleProviderAdapter, + OpenAICompatibleProviderConfig, +) @dataclass(frozen=True, slots=True) @@ -67,7 +73,12 @@ def select(self, context: ModelAdapterBuildContext) -> ModelAdapterBuilder: @classmethod def default(cls) -> "InMemoryModelAdapterBuilderRegistry": - return cls((_OpenAICompatibleModelAdapterBuilder(), _AnthropicMessagesModelAdapterBuilder())) + return cls( + ( + _OpenAICompatibleModelAdapterBuilder(), + _AnthropicMessagesModelAdapterBuilder(), + ) + ) class _OpenAICompatibleModelAdapterBuilder: diff --git a/packages/models/reasoning_parser.py b/packages/models/reasoning_parser.py index b624311..401f3a3 100644 --- a/packages/models/reasoning_parser.py +++ b/packages/models/reasoning_parser.py @@ -29,12 +29,7 @@ def _is_cjk(char: str) -> bool: if not char: return False code = ord(char) - return ( - 0x4E00 <= code <= 0x9FFF - or 0x3400 <= code <= 0x4DBF - or 0x3040 <= code <= 0x30FF - or 0xAC00 <= code <= 0xD7AF - ) + return 0x4E00 <= code <= 0x9FFF or 0x3400 <= code <= 0x4DBF or 0x3040 <= code <= 0x30FF or 0xAC00 <= code <= 0xD7AF def _needs_collapsed_whitespace_spacing(previous: str, current: str) -> bool: @@ -145,9 +140,7 @@ def replace_block(match: re.Match[str]) -> str: def _restore_code_blocks(text: str, blocks: tuple[str, ...]) -> str: if not blocks: return text - placeholder_re = re.compile( - re.escape(_PLACEHOLDER_PREFIX) + r"(\d+)" + re.escape(_PLACEHOLDER_SUFFIX) - ) + placeholder_re = re.compile(re.escape(_PLACEHOLDER_PREFIX) + r"(\d+)" + re.escape(_PLACEHOLDER_SUFFIX)) restored = text while True: changed = False @@ -172,7 +165,7 @@ def parse_reasoning_content(content: str, *, streaming: bool) -> ParsedReasoning last_index = 0 for match in _TAG_RE.finditer(masked): - body_parts.append(masked[last_index:match.start()]) + body_parts.append(masked[last_index : match.start()]) segments.append(match.group(2)) last_index = match.end() @@ -190,9 +183,7 @@ def parse_reasoning_content(content: str, *, streaming: bool) -> ParsedReasoning body_parts.append(rest) restored_segments = tuple(_restore_code_blocks(segment, protected.blocks) for segment in segments) - restored_pending = ( - _restore_code_blocks(pending, protected.blocks) if pending is not None else None - ) + restored_pending = _restore_code_blocks(pending, protected.blocks) if pending is not None else None restored_body = _restore_code_blocks("".join(body_parts), protected.blocks) return ParsedReasoningContent( segments=restored_segments, diff --git a/packages/models/runtime.py b/packages/models/runtime.py index 1902053..8bc2060 100644 --- a/packages/models/runtime.py +++ b/packages/models/runtime.py @@ -19,19 +19,6 @@ SupportModelProfile, ) -from .provider_runtime import ( - InMemoryProviderManifestRegistry, - InMemoryProviderTransportRegistry, - ProviderCatalogRecord, - ProviderManifest, - ProviderManifestRegistry, - ProviderRuntimeResolution, - ProviderRuntimeResolver, - ProviderSetupGuide, - ProviderTransportDefinition, - ProviderTransportRegistry, -) - @dataclass(frozen=True, slots=True) class ModelAdapterDescriptor: diff --git a/packages/models/runtime_capability.py b/packages/models/runtime_capability.py index de6b773..4202dc1 100644 --- a/packages/models/runtime_capability.py +++ b/packages/models/runtime_capability.py @@ -31,7 +31,10 @@ PromptEnvelope, SupportModelProfile, ) -from packages.embeddings import OPENAI_COMPATIBLE_EMBED_PROFILE_ID, OPENAI_COMPATIBLE_EMBED_PROVIDER_ID +from packages.embeddings import ( + OPENAI_COMPATIBLE_EMBED_PROFILE_ID, + OPENAI_COMPATIBLE_EMBED_PROVIDER_ID, +) from packages.models.bootstrap import ( EmbeddingBootstrapState, resolve_embedding_bootstrap_state, @@ -45,13 +48,21 @@ heuristic_context_window, request_json, ) -from packages.models.provider_catalog import default_provider_definitions, provider_definition +from packages.models.provider_catalog import ( + default_provider_definitions, + provider_definition, +) from packages.models.provider_runtime import ProviderRuntimeResolver from packages.models.providers import build_model_adapter from packages.storage import RuntimeStorageRepository from packages.tools import ToolDefinition, ToolRuntime, build_tool_fallback_prompt -from .ephemeral_injection import TurnScopedPrefixCache, ephemeral_blocks_as_user_suffix, recall_block_contents, strip_recall_blocks +from .ephemeral_injection import ( + TurnScopedPrefixCache, + ephemeral_blocks_as_user_suffix, + recall_block_contents, + strip_recall_blocks, +) from .runtime import ModelRequest RequestJsonCallable = Callable[..., dict[str, Any]] @@ -95,7 +106,9 @@ def _recall_message_contents(message: PromptMessage) -> tuple[str, ...]: return recall_block_contents(content) -def _surfaced_recall_stats(messages: tuple[PromptMessage, ...]) -> tuple[int, frozenset[str]]: +def _surfaced_recall_stats( + messages: tuple[PromptMessage, ...], +) -> tuple[int, frozenset[str]]: contents = tuple(content for message in messages for content in _recall_message_contents(message)) return sum(len(content.encode("utf-8")) for content in contents), frozenset(contents) @@ -150,7 +163,9 @@ def _copilot_acp_status() -> tuple[str, str] | None: return None -def generation_model_profile_from_auth_profile(profile: AuthProfile) -> GenerationModelProfile: +def generation_model_profile_from_auth_profile( + profile: AuthProfile, +) -> GenerationModelProfile: if not str(profile.default_model or "").strip(): raise ValueError(f"auth profile '{profile.profile_id}' is missing a generation model id") return GenerationModelProfile( @@ -164,7 +179,9 @@ def generation_model_profile_from_auth_profile(profile: AuthProfile) -> Generati ) -def support_model_profile_from_auth_profile(profile: AuthProfile) -> SupportModelProfile: +def support_model_profile_from_auth_profile( + profile: AuthProfile, +) -> SupportModelProfile: if not str(profile.default_model or "").strip(): raise ValueError(f"auth profile '{profile.profile_id}' is missing a support model id") return SupportModelProfile( @@ -260,9 +277,7 @@ def __init__( # overflow retries inside the same user turn reuse the cached # result and never re-trigger memory providers with a stale query. # See `packages/models/ephemeral_injection.py`. - self.ephemeral_prefix_builders: tuple[EphemeralPrefixBuilder, ...] = tuple( - ephemeral_prefix_builders - ) + self.ephemeral_prefix_builders: tuple[EphemeralPrefixBuilder, ...] = tuple(ephemeral_prefix_builders) self._ephemeral_prefix_cache = TurnScopedPrefixCache() self.state_focus_mode = "skip" self.bootstrap_state_dir = bootstrap_state_dir or repository.database_path.parent @@ -735,8 +750,10 @@ def _discover_provider_state(self, provider_id: str) -> DiscoveredProviderState: except LookupError: profile = None active_profile = self.active_profile() - selected_profile = profile if profile is not None else ( - active_profile if active_profile and active_profile.provider_id == provider_id else None + selected_profile = ( + profile + if profile is not None + else (active_profile if active_profile and active_profile.provider_id == provider_id else None) ) base_url = ( (selected_profile.base_url if selected_profile is not None else None) @@ -744,18 +761,16 @@ def _discover_provider_state(self, provider_id: str) -> DiscoveredProviderState: or definition.default_base_url ) default_model = ( - (selected_profile.default_model if selected_profile is not None else None) - or definition.default_model_id - ) + selected_profile.default_model if selected_profile is not None else None + ) or definition.default_model_id secret_status = None secret_source = None if selected_profile is not None: secret_status, secret_source = self._profile_secret_status(selected_profile) discovered_secret = None if selected_profile is not None else self._discovered_secret_resolution(provider_id) external_process_status = _copilot_acp_status() if provider_id == "copilot-acp" else None - local_provider_reachable = ( - definition.provider_kind == "local" - and self._local_provider_reachable(provider_id, base_url) + local_provider_reachable = definition.provider_kind == "local" and self._local_provider_reachable( + provider_id, base_url ) return self.state_evaluator.evaluate( provider_id, @@ -869,9 +884,7 @@ def generate( self._fallback_tool_prompt(visible_tools), ) request_tools = ( - tuple(tool.model_function_schema() for tool in visible_tools) - if resolution.supports_tools - else () + tuple(tool.model_function_schema() for tool in visible_tools) if resolution.supports_tools else () ) api_messages = tuple(context.prompt_envelope.messages) request = ModelRequest( diff --git a/packages/observability/context.py b/packages/observability/context.py index 2eda140..97efb4b 100644 --- a/packages/observability/context.py +++ b/packages/observability/context.py @@ -16,7 +16,8 @@ class TraceContext: _current_context: ContextVar[TraceContext | None] = ContextVar( - "elephant_trace_context", default=None, + "elephant_trace_context", + default=None, ) diff --git a/packages/observability/instrumentor.py b/packages/observability/instrumentor.py index 20b8de4..562ef45 100644 --- a/packages/observability/instrumentor.py +++ b/packages/observability/instrumentor.py @@ -65,6 +65,7 @@ def _resolve_module_attr(dotted: str) -> object | None: parts = dotted.split(".") try: import importlib + mod = importlib.import_module(parts[0]) for part in parts[1:]: mod = getattr(mod, part) @@ -101,11 +102,13 @@ def wrapped_run(self, request): loop_id = getattr(loop, "loop_id", "") or "" span.set_attribute("elephant.episode_id", episode_id) span.set_attribute("elephant.loop_id", loop_id) - set_context(TraceContext( - episode_id=episode_id, - loop_id=loop_id, - request_id=request_id, - )) + set_context( + TraceContext( + episode_id=episode_id, + loop_id=loop_id, + request_id=request_id, + ) + ) except Exception: duration = timer.elapsed() logger.error("kernel turn failed: duration=%.2fs", duration) @@ -116,7 +119,9 @@ def wrapped_run(self, request): record_turn_metrics(episode_id=episode_id, duration_s=duration, trigger_type=trigger) logger.info( "kernel turn completed: episode=%s loop=%s duration=%.2fs", - episode_id, loop_id, duration, + episode_id, + loop_id, + duration, ) return result @@ -133,7 +138,16 @@ def _patch_generate_with_steps() -> None: logger = get_logger("kernel.execution") @wraps(original) - def wrapped(service, profile, session, context, prompt, *, step_recorder=None, planned_summary=""): + def wrapped( + service, + profile, + session, + context, + prompt, + *, + step_recorder=None, + planned_summary="", + ): mp = getattr(getattr(service, "dependencies", None), "model_provider", None) provider_id = getattr(mp, "active_provider_id", "") or "" model_id = "" @@ -150,7 +164,15 @@ def wrapped(service, profile, session, context, prompt, *, step_recorder=None, p timer = DurationTimer() with trace_model_call(provider_id=provider_id, model_id=model_id, episode_id=episode_id) as span: - response = original(service, profile, session, context, prompt, step_recorder=step_recorder, planned_summary=planned_summary) + response = original( + service, + profile, + session, + context, + prompt, + step_recorder=step_recorder, + planned_summary=planned_summary, + ) record_token_usage( span, input_tokens=getattr(response, "prompt_tokens", 0) or 0, @@ -172,10 +194,14 @@ def wrapped(service, profile, session, context, prompt, *, step_recorder=None, p cache_pct = f"{cache_read / input_tokens * 100:.0f}%" if input_tokens > 0 else "n/a" logger.info( "model call completed: provider=%s model=%s tokens=%d/%d cache_read=%d cache_create=%d cache_hit=%s duration=%.2fs", - provider_id, model_id, + provider_id, + model_id, input_tokens, getattr(response, "completion_tokens", 0) or 0, - cache_read, cache_creation, cache_pct, elapsed, + cache_read, + cache_creation, + cache_pct, + elapsed, ) return response @@ -211,8 +237,14 @@ def _patch_http_transport_post_json() -> None: return original_post = UrllibJSONHTTPTransport.post_json original_stream = UrllibJSONHTTPTransport.post_json_stream - _save_original("packages.models.providers.http.UrllibJSONHTTPTransport.post_json", original_post) - _save_original("packages.models.providers.http.UrllibJSONHTTPTransport.post_json_stream", original_stream) + _save_original( + "packages.models.providers.http.UrllibJSONHTTPTransport.post_json", + original_post, + ) + _save_original( + "packages.models.providers.http.UrllibJSONHTTPTransport.post_json_stream", + original_stream, + ) logger = get_logger("provider.http") @wraps(original_post) @@ -221,11 +253,17 @@ def wrapped_post(self, *, url, headers, payload): try: result = original_post(self, url=url, headers=headers, payload=payload) except Exception: - logger.warning("provider HTTP request failed: url=%s duration=%.2fs", url, time.monotonic() - start) + logger.warning( + "provider HTTP request failed: url=%s duration=%.2fs", + url, + time.monotonic() - start, + ) raise logger.debug( "provider HTTP request: url=%s status=%d duration=%.2fs", - url, getattr(result, "status_code", 0), time.monotonic() - start, + url, + getattr(result, "status_code", 0), + time.monotonic() - start, ) return result @@ -238,12 +276,14 @@ def wrapped_stream(self, *, url, headers, payload): except Exception: logger.warning( "provider HTTP stream failed: url=%s duration=%.2fs", - url, time.monotonic() - start, + url, + time.monotonic() - start, ) raise logger.debug( "provider HTTP stream completed: url=%s duration=%.2fs", - url, time.monotonic() - start, + url, + time.monotonic() - start, ) UrllibJSONHTTPTransport.post_json = wrapped_post @@ -274,6 +314,12 @@ def instrumented_executor(job): logger.info("cron job completed: job_id=%s outcome=%s", job.job_id, outcome) return outcome, summary - return original(self, instrumented_executor, profile_id=profile_id, elephant_id=elephant_id, now=now) + return original( + self, + instrumented_executor, + profile_id=profile_id, + elephant_id=elephant_id, + now=now, + ) CronRuntime.run_due = wrapped diff --git a/packages/observability/logger.py b/packages/observability/logger.py index 2fbb476..70c37c1 100644 --- a/packages/observability/logger.py +++ b/packages/observability/logger.py @@ -26,14 +26,9 @@ def filter(self, record: logging.LogRecord) -> bool: record.msg = _redact(record.msg) if record.args: if isinstance(record.args, dict): - record.args = { - k: _redact(str(v)) if isinstance(v, str) else v - for k, v in record.args.items() - } + record.args = {k: _redact(str(v)) if isinstance(v, str) else v for k, v in record.args.items()} elif isinstance(record.args, tuple): - record.args = tuple( - _redact(str(a)) if isinstance(a, str) else a for a in record.args - ) + record.args = tuple(_redact(str(a)) if isinstance(a, str) else a for a in record.args) return True @@ -64,9 +59,7 @@ def format(self, record: logging.LogRecord) -> str: return json.dumps(entry, ensure_ascii=False) -_CONSOLE_FORMAT = ( - "%(asctime)s %(levelname)-5s [%(trace_id).8s] %(name)s: %(message)s" -) +_CONSOLE_FORMAT = "%(asctime)s %(levelname)-5s [%(trace_id).8s] %(name)s: %(message)s" def configure_logging( @@ -98,7 +91,9 @@ def configure_logging( if log_path: log_path.parent.mkdir(parents=True, exist_ok=True) file_handler = RotatingFileHandler( - log_path, maxBytes=10 * 1024 * 1024, backupCount=5, + log_path, + maxBytes=10 * 1024 * 1024, + backupCount=5, ) file_handler.setLevel(level) file_handler.setFormatter(_JSONFormatter()) diff --git a/packages/observability/metrics.py b/packages/observability/metrics.py index c841fe3..139a5bb 100644 --- a/packages/observability/metrics.py +++ b/packages/observability/metrics.py @@ -60,10 +60,13 @@ def record_tool_metrics( duration_s: float, status: str = "success", ) -> None: - tool_duration.record(duration_s, attributes={ - "gen_ai.tool.name": tool_name, - "elephant.tool.status": status, - }) + tool_duration.record( + duration_s, + attributes={ + "gen_ai.tool.name": tool_name, + "elephant.tool.status": status, + }, + ) def record_turn_metrics( diff --git a/packages/observability/setup.py b/packages/observability/setup.py index 9ec0e2a..ceea3e7 100644 --- a/packages/observability/setup.py +++ b/packages/observability/setup.py @@ -40,12 +40,14 @@ def setup_observability( tracer_provider = TracerProvider(resource=resource) if otel_endpoint: - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter - from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter - - tracer_provider.add_span_processor( - BatchSpanProcessor(OTLPSpanExporter(endpoint=otel_endpoint)) + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter, + ) + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( + OTLPMetricExporter, ) + + tracer_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint=otel_endpoint))) metric_reader = PeriodicExportingMetricReader( OTLPMetricExporter(endpoint=otel_endpoint), export_interval_millis=60_000, diff --git a/packages/observability/spans.py b/packages/observability/spans.py index db139bf..f153a84 100644 --- a/packages/observability/spans.py +++ b/packages/observability/spans.py @@ -52,7 +52,8 @@ def trace_kernel_turn( if trigger_type: attrs[ATTR_TRIGGER_TYPE] = trigger_type with _tracer.start_as_current_span( - "invoke_agent", attributes=attrs, + "invoke_agent", + attributes=attrs, ) as span: yield span @@ -72,7 +73,8 @@ def trace_model_call( **_elephant_attrs(episode_id=episode_id, loop_id=loop_id), } with _tracer.start_as_current_span( - f"chat {model_id}", attributes=attrs, + f"chat {model_id}", + attributes=attrs, ) as span: yield span @@ -105,6 +107,7 @@ def trace_tool_execution( **_elephant_attrs(episode_id=episode_id, loop_id=loop_id), } with _tracer.start_as_current_span( - f"execute_tool {tool_name}", attributes=attrs, + f"execute_tool {tool_name}", + attributes=attrs, ) as span: yield span diff --git a/packages/operator/runtime.py b/packages/operator/runtime.py index ee43548..dc27022 100644 --- a/packages/operator/runtime.py +++ b/packages/operator/runtime.py @@ -9,7 +9,10 @@ ProcedureRecord, ) from packages.contracts.runtime import RecallEvidence -from packages.state.rendered_views import RenderedRelationshipView, RenderedUserProfileView +from packages.state.rendered_views import ( + RenderedRelationshipView, + RenderedUserProfileView, +) @dataclass(frozen=True, slots=True) @@ -144,6 +147,7 @@ class DashboardEggRecord: tone: str details: tuple["DashboardDetailItem", ...] = () + @dataclass(frozen=True, slots=True) class DashboardDetailItem: label: str @@ -179,6 +183,7 @@ class DashboardSessionRecord: tone: str details: tuple[DashboardDetailItem, ...] = () + @dataclass(frozen=True, slots=True) class DashboardOpsRecord: lane: str @@ -271,7 +276,12 @@ def build_profile_operator_surface( identity=identity, user=user, relationship=relationship, - provenance=provenance or ("state.identity", "pm_facts.user_profile_view", "pm_facts.relationship_view"), + provenance=provenance + or ( + "state.identity", + "pm_facts.user_profile_view", + "pm_facts.relationship_view", + ), ) @@ -606,7 +616,9 @@ def render_profile_lines(surface: ProfileOperatorSurface) -> tuple[str, ...]: ) -def render_recall_evidence_lines(surface: RecallEvidenceOperatorSurface) -> tuple[str, ...]: +def render_recall_evidence_lines( + surface: RecallEvidenceOperatorSurface, +) -> tuple[str, ...]: lines: list[str] = [] for item in surface.evidence_items: lines.append( @@ -616,7 +628,13 @@ def render_recall_evidence_lines(surface: RecallEvidenceOperatorSurface) -> tupl if not lines: lines.append("") if surface.search_query is not None: - lines.extend(("", f"search_query: {surface.search_query}", f"scope_reason: {surface.scope_reason or ''}")) + lines.extend( + ( + "", + f"search_query: {surface.search_query}", + f"scope_reason: {surface.scope_reason or ''}", + ) + ) for hit in surface.search_hits: lines.append( f"- {hit.evidence.evidence_id} | score={hit.score:.2f} | reasons={'; '.join(hit.reasons) or ''} | {hit.evidence.content}" @@ -625,7 +643,10 @@ def render_recall_evidence_lines(surface: RecallEvidenceOperatorSurface) -> tupl def render_procedure_lines(surface: ProcedureOperatorSurface) -> tuple[str, ...]: - lines = [f"profile_id: {surface.profile_id}", f"procedure_count: {len(surface.procedures)}"] + lines = [ + f"profile_id: {surface.profile_id}", + f"procedure_count: {len(surface.procedures)}", + ] if surface.procedures: lines.append("procedures:") lines.extend( diff --git a/packages/runtime_config.py b/packages/runtime_config.py index 7831dcc..74c5b43 100644 --- a/packages/runtime_config.py +++ b/packages/runtime_config.py @@ -15,6 +15,7 @@ DEFAULT_EXTERNAL_SKILL_DIRS: tuple[str, ...] = ("~/.agents/skills",) _MISSING = object() + def default_personal_model_question_config() -> dict[str, Any]: return { "proactive_ask": { @@ -26,7 +27,9 @@ def default_personal_model_question_config() -> dict[str, Any]: } -def personal_model_question_config_from_global(config: Mapping[str, Any] | None) -> dict[str, Any]: +def personal_model_question_config_from_global( + config: Mapping[str, Any] | None, +) -> dict[str, Any]: if not isinstance(config, Mapping): return default_personal_model_question_config() questions = config.get("personal_model_questions") @@ -40,7 +43,6 @@ def global_config_path_for_state_dir(state_dir: str | Path) -> Path: return install_root / GLOBAL_CONFIG_FILENAME - def default_global_config(*, state_dir: str | Path) -> dict[str, Any]: resolved_state_dir = Path(state_dir) return { @@ -89,30 +91,150 @@ def default_global_config(*, state_dir: str | Path) -> dict[str, Any]: def global_config_schema() -> list[dict[str, Any]]: return [ - {"path": "runtime.state_dir", "type": "string", "label": "Elephant directory", "section": "Runtime"}, - {"path": "runtime.default_profile_id", "type": "string", "label": "Default profile", "section": "Runtime"}, - {"path": "models.default_provider_source", "type": "string", "label": "Provider source", "section": "Models"}, - {"path": "models.provider", "type": "object", "label": "Provider profile", "section": "Models"}, - {"path": "sessions.persist_system_prompts", "type": "boolean", "label": "Persist system prompts", "section": "Sessions"}, - {"path": "sessions.persist_assistant_responses", "type": "boolean", "label": "Persist assistant responses", "section": "Sessions"}, - {"path": "sessions.max_history_rows", "type": "number", "label": "Max history rows", "section": "Sessions"}, - {"path": "skills.enable_profile_overrides", "type": "boolean", "label": "Skill profile overrides", "section": "Skills"}, - {"path": "skills.external_dirs", "type": "string_list", "label": "External skill dirs", "section": "Skills"}, - {"path": "tools.require_approval_for_risky", "type": "boolean", "label": "Approval for risky tools", "section": "Tools"}, - {"path": "gateway.enabled", "type": "boolean", "label": "Gateway enabled", "section": "Gateway"}, - {"path": "gateway.state_dir", "type": "string", "label": "Gateway herd directory", "section": "Gateway"}, - {"path": "dashboard.host", "type": "string", "label": "Dashboard host", "section": "Dashboard"}, - {"path": "dashboard.port", "type": "number", "label": "Dashboard port", "section": "Dashboard"}, - {"path": "personal_model.first_language", "type": "string", "label": "First language", "section": "Personal Model"}, - {"path": "personal_model_questions.proactive_ask.enabled", "type": "boolean", "label": "Proactive asks enabled", "section": "Personal Model"}, - {"path": "personal_model_questions.proactive_ask.idle_threshold_minutes", "type": "number", "label": "Idle threshold (minutes)", "section": "Personal Model"}, - {"path": "personal_model_questions.proactive_ask.daily_max", "type": "number", "label": "Daily max questions", "section": "Personal Model"}, - {"path": "personal_model_questions.proactive_ask.quiet_hours", "type": "string_list", "label": "Quiet hours [start, end]", "section": "Personal Model"}, - {"path": "observability.enabled", "type": "boolean", "label": "Observability enabled", "section": "Observability"}, - {"path": "observability.log_level", "type": "string", "label": "Log level", "section": "Observability"}, - {"path": "observability.log_file", "type": "string", "label": "Log file path", "section": "Observability"}, - {"path": "observability.otel_endpoint", "type": "string", "label": "OTLP endpoint", "section": "Observability"}, - {"path": "observability.service_name", "type": "string", "label": "Service name", "section": "Observability"}, + { + "path": "runtime.state_dir", + "type": "string", + "label": "Elephant directory", + "section": "Runtime", + }, + { + "path": "runtime.default_profile_id", + "type": "string", + "label": "Default profile", + "section": "Runtime", + }, + { + "path": "models.default_provider_source", + "type": "string", + "label": "Provider source", + "section": "Models", + }, + { + "path": "models.provider", + "type": "object", + "label": "Provider profile", + "section": "Models", + }, + { + "path": "sessions.persist_system_prompts", + "type": "boolean", + "label": "Persist system prompts", + "section": "Sessions", + }, + { + "path": "sessions.persist_assistant_responses", + "type": "boolean", + "label": "Persist assistant responses", + "section": "Sessions", + }, + { + "path": "sessions.max_history_rows", + "type": "number", + "label": "Max history rows", + "section": "Sessions", + }, + { + "path": "skills.enable_profile_overrides", + "type": "boolean", + "label": "Skill profile overrides", + "section": "Skills", + }, + { + "path": "skills.external_dirs", + "type": "string_list", + "label": "External skill dirs", + "section": "Skills", + }, + { + "path": "tools.require_approval_for_risky", + "type": "boolean", + "label": "Approval for risky tools", + "section": "Tools", + }, + { + "path": "gateway.enabled", + "type": "boolean", + "label": "Gateway enabled", + "section": "Gateway", + }, + { + "path": "gateway.state_dir", + "type": "string", + "label": "Gateway herd directory", + "section": "Gateway", + }, + { + "path": "dashboard.host", + "type": "string", + "label": "Dashboard host", + "section": "Dashboard", + }, + { + "path": "dashboard.port", + "type": "number", + "label": "Dashboard port", + "section": "Dashboard", + }, + { + "path": "personal_model.first_language", + "type": "string", + "label": "First language", + "section": "Personal Model", + }, + { + "path": "personal_model_questions.proactive_ask.enabled", + "type": "boolean", + "label": "Proactive asks enabled", + "section": "Personal Model", + }, + { + "path": "personal_model_questions.proactive_ask.idle_threshold_minutes", + "type": "number", + "label": "Idle threshold (minutes)", + "section": "Personal Model", + }, + { + "path": "personal_model_questions.proactive_ask.daily_max", + "type": "number", + "label": "Daily max questions", + "section": "Personal Model", + }, + { + "path": "personal_model_questions.proactive_ask.quiet_hours", + "type": "string_list", + "label": "Quiet hours [start, end]", + "section": "Personal Model", + }, + { + "path": "observability.enabled", + "type": "boolean", + "label": "Observability enabled", + "section": "Observability", + }, + { + "path": "observability.log_level", + "type": "string", + "label": "Log level", + "section": "Observability", + }, + { + "path": "observability.log_file", + "type": "string", + "label": "Log file path", + "section": "Observability", + }, + { + "path": "observability.otel_endpoint", + "type": "string", + "label": "OTLP endpoint", + "section": "Observability", + }, + { + "path": "observability.service_name", + "type": "string", + "label": "Service name", + "section": "Observability", + }, ] diff --git a/packages/security/runtime.py b/packages/security/runtime.py index 3566b41..71df187 100644 --- a/packages/security/runtime.py +++ b/packages/security/runtime.py @@ -49,7 +49,10 @@ lambda _match: _REDACTED, ), ( - re.compile(r"(?i)(-----BEGIN [A-Z ]*PRIVATE KEY-----)(.*?)(-----END [A-Z ]*PRIVATE KEY-----)", re.DOTALL), + re.compile( + r"(?i)(-----BEGIN [A-Z ]*PRIVATE KEY-----)(.*?)(-----END [A-Z ]*PRIVATE KEY-----)", + re.DOTALL, + ), lambda match: f"{match.group(1)}\n{_REDACTED}\n{match.group(3)}", ), ) @@ -296,19 +299,31 @@ def default_surface_policy_bundles() -> tuple[SurfacePolicyBundle, ...]: SurfacePolicyBundle( surface_id="cli.operator", label="CLI operator path", - approval_classes=(ApprovalClass.READ, ApprovalClass.WRITE, ApprovalClass.VOICE_DEVICE), + approval_classes=( + ApprovalClass.READ, + ApprovalClass.WRITE, + ApprovalClass.VOICE_DEVICE, + ), summary="Local shell inspection, governed edits, and optional voice extension.", ), SurfacePolicyBundle( surface_id="gateway.messaging", label="Gateway messaging path", - approval_classes=(ApprovalClass.READ, ApprovalClass.MESSAGING, ApprovalClass.NETWORK), + approval_classes=( + ApprovalClass.READ, + ApprovalClass.MESSAGING, + ApprovalClass.NETWORK, + ), summary="Outbound messaging and remote delivery across recipient and boundary checks.", ), SurfacePolicyBundle( surface_id="deploy.support", label="Deploy and support path", - approval_classes=(ApprovalClass.READ, ApprovalClass.EXEC, ApprovalClass.NETWORK), + approval_classes=( + ApprovalClass.READ, + ApprovalClass.EXEC, + ApprovalClass.NETWORK, + ), summary="Install, doctor, deploy, and support collection without secret exfiltration.", ), ) @@ -545,11 +560,7 @@ def emit_terminal_failure( if result.rule_id != "unregistered-surface" and result.decision != PolicyDecision.DENY: return None severity: FailureSeverity = "critical" if result.rule_id == "unregistered-surface" else "warning" - error_kind = ( - "approval_context_missing" - if result.rule_id == "unregistered-surface" - else "approval_denied" - ) + error_kind = "approval_context_missing" if result.rule_id == "unregistered-surface" else "approval_denied" return emit_failure_event( self.sink, event_id=f"{request.request_id}:failure.side_effect.reported", diff --git a/packages/semantic_index/backend.py b/packages/semantic_index/backend.py index 682b536..aec6eaf 100644 --- a/packages/semantic_index/backend.py +++ b/packages/semantic_index/backend.py @@ -164,8 +164,7 @@ def upsert(self, vector: SemanticIndexVector) -> SemanticIndexWriteResult: (vector.semantic_index_entry_id,), ) connection.execute( - "INSERT INTO " + table_name + "(rowid, semantic_index_entry_id, embedding)" - " VALUES (?, ?, ?)", + "INSERT INTO " + table_name + "(rowid, semantic_index_entry_id, embedding) VALUES (?, ?, ?)", (rowid, vector.semantic_index_entry_id, _vector_json(vector.values)), ) connection.commit() @@ -185,7 +184,8 @@ def search(self, query: SemanticIndexVectorQuery) -> tuple[SemanticIndexVectorMa return () rows = connection.execute( "SELECT semantic_index_entry_id, distance" - + " FROM " + table_name + + " FROM " + + table_name + " WHERE embedding MATCH ? AND k = ?" + " ORDER BY distance ASC", (_vector_json(query.values), query.limit), @@ -262,15 +262,20 @@ def _connect(self) -> Iterator[sqlite3.Connection]: connection.close() @contextmanager - def _loaded_connection(self) -> Iterator[tuple[sqlite3.Connection, SQLiteVecLoadState]]: + def _loaded_connection( + self, + ) -> Iterator[tuple[sqlite3.Connection, SQLiteVecLoadState]]: with self._connect() as connection: yield connection, load_sqlite_vec_extension(connection) def _ensure_vector_table(connection: sqlite3.Connection, table_name: str, dimensions: int) -> None: connection.execute( - "CREATE VIRTUAL TABLE IF NOT EXISTS " + table_name - + " USING vec0(+semantic_index_entry_id TEXT, embedding FLOAT[" + str(dimensions) + "])" + "CREATE VIRTUAL TABLE IF NOT EXISTS " + + table_name + + " USING vec0(+semantic_index_entry_id TEXT, embedding FLOAT[" + + str(dimensions) + + "])" ) diff --git a/packages/semantic_index/search.py b/packages/semantic_index/search.py index a815009..889608b 100644 --- a/packages/semantic_index/search.py +++ b/packages/semantic_index/search.py @@ -105,9 +105,7 @@ def search(self, query: SemanticSearchQuery) -> tuple[SemanticSearchMatch, ...]: ) entries_by_id = {entry.semantic_index_entry_id: entry for entry in entries} documents_by_entry_id = _documents_by_entry_id(self.repository, entries, query) - contributions: dict[str, dict[str, float]] = { - entry_id: {} for entry_id in documents_by_entry_id - } + contributions: dict[str, dict[str, float]] = {entry_id: {} for entry_id in documents_by_entry_id} vector_ranked_ids = self._vector_ranking(query, documents_by_entry_id, entries_by_id) _add_ranked_signal( contributions, @@ -249,7 +247,12 @@ def _episode_record( except Exception: episode = None if episode is None: - return _metadata_record(entry, metadata, schema_version="episode_summary/v1", layer_type="episode_summary") + return _metadata_record( + entry, + metadata, + schema_version="episode_summary/v1", + layer_type="episode_summary", + ) episode_metadata = {str(key): str(value) for key, value in dict(getattr(episode, "metadata", {}) or {}).items()} text = _indexed_text(metadata) if not text: @@ -271,7 +274,12 @@ def _episode_record( state_id=getattr(episode, "state_id", None) or entry.state_id, layer_type="episode_summary", created_at=getattr(episode, "ended_at", None) or getattr(episode, "started_at", None) or entry.created_at, - metadata={**episode_metadata, **dict(metadata), "kind": "episode_summary", "episode_id": episode_id}, + metadata={ + **episode_metadata, + **dict(metadata), + "kind": "episode_summary", + "episode_id": episode_id, + }, ) @@ -322,7 +330,10 @@ def _step_text(step: object, metadata: Mapping[str, str]) -> str: elif action == "emit_response": parts = (str(metadata.get("final_response") or metadata.get("assistant_response") or summary).strip(),) elif action == "reply": - parts = (summary, str(metadata.get("final_response") or metadata.get("assistant_response") or "").strip()) + parts = ( + summary, + str(metadata.get("final_response") or metadata.get("assistant_response") or "").strip(), + ) else: parts = ( summary, @@ -340,7 +351,12 @@ def _fact_record( ) -> SemanticSourceDocument | None: fact = _load_fact(repository, entry, fact_id) if fact is None: - return _metadata_record(entry, metadata, schema_version="personal_model_claim/v1", layer_type="personal_model_claim") + return _metadata_record( + entry, + metadata, + schema_version="personal_model_claim/v1", + layer_type="personal_model_claim", + ) fact_metadata = {str(key): str(value) for key, value in dict(getattr(fact, "metadata", {}) or {}).items()} text = _indexed_text(metadata) or str(getattr(fact, "text", "") or "").strip() return SemanticSourceDocument( @@ -376,7 +392,10 @@ def _load_fact(repository: SemanticSearchRepository, entry: SemanticIndexEntry, ) except Exception: return None - return next((fact for fact in facts if str(getattr(fact, "fact_id", "") or "") == fact_id), None) + return next( + (fact for fact in facts if str(getattr(fact, "fact_id", "") or "") == fact_id), + None, + ) def _metadata_record( @@ -389,7 +408,9 @@ def _metadata_record( text = _indexed_text(metadata) if not text: return None - resolved_layer = layer_type or str(metadata.get("layer_type") or metadata.get("kind") or entry.owner_scope or "semantic_index") + resolved_layer = layer_type or str( + metadata.get("layer_type") or metadata.get("kind") or entry.owner_scope or "semantic_index" + ) return SemanticSourceDocument( source_id=entry.source_id, kind="derived", @@ -489,15 +510,11 @@ def _bm25_ranking( query_tokens = _tokens(text) if not query_tokens or not documents_by_entry_id: return () - documents = { - entry_id: _tokens(_document_text(document)) - for entry_id, document in documents_by_entry_id.items() - } + documents = {entry_id: _tokens(_document_text(document)) for entry_id, document in documents_by_entry_id.items()} document_count = float(len(documents)) average_length = sum(len(tokens) for tokens in documents.values()) / max(document_count, 1.0) document_frequency = { - token: sum(1 for tokens in documents.values() if token in tokens) - for token in set(query_tokens) + token: sum(1 for tokens in documents.values() if token in tokens) for token in set(query_tokens) } scored: list[tuple[str, float]] = [] for entry_id, tokens in documents.items(): @@ -509,7 +526,9 @@ def _bm25_ranking( frequency = token_counts.get(token, 0) if frequency <= 0: continue - idf = math.log(1.0 + ((document_count - document_frequency[token] + 0.5) / (document_frequency[token] + 0.5))) + idf = math.log( + 1.0 + ((document_count - document_frequency[token] + 0.5) / (document_frequency[token] + 0.5)) + ) denominator = frequency + 1.5 * (1.0 - 0.75 + 0.75 * (len(tokens) / max(average_length, 1.0))) score += idf * ((frequency * 2.5) / denominator) if score > 0.0: diff --git a/packages/semantic_index/service.py b/packages/semantic_index/service.py index 8cbbc9f..91b1a46 100644 --- a/packages/semantic_index/service.py +++ b/packages/semantic_index/service.py @@ -185,9 +185,7 @@ def rebuild_plan( provider_id=provider_id, model_id=model_id, ) - desired_by_id = { - _entry_id_for_document(document): document for document in desired_documents - } + desired_by_id = {_entry_id_for_document(document): document for document in desired_documents} current_by_id = {entry.semantic_index_entry_id: entry for entry in current} desired_source_ids = {document.source_id for document in desired_documents} reuse = tuple(sorted(set(current_by_id) & set(desired_by_id))) @@ -196,8 +194,7 @@ def rebuild_plan( sorted( entry.semantic_index_entry_id for entry in current - if entry.source_id in desired_source_ids - and entry.semantic_index_entry_id not in desired_by_id + if entry.source_id in desired_source_ids and entry.semantic_index_entry_id not in desired_by_id ) ) return SemanticIndexMetadataRebuildPlan( diff --git a/packages/skills/authoring.py b/packages/skills/authoring.py index 1203ca7..916203f 100644 --- a/packages/skills/authoring.py +++ b/packages/skills/authoring.py @@ -171,9 +171,7 @@ def _validated_segment(value: str | None, *, field_name: str) -> str: if not resolved: raise ValueError(f"{field_name} is required") if not _VALID_SEGMENT_RE.match(resolved): - raise ValueError( - f"{field_name} must use lowercase letters, digits, dots, underscores, or hyphens: {value!r}" - ) + raise ValueError(f"{field_name} must use lowercase letters, digits, dots, underscores, or hyphens: {value!r}") return resolved @@ -196,21 +194,25 @@ def _render_skill_markdown( ] if asset_paths: lines.append(f"assets: {', '.join(asset_paths)}") - lines.extend([ - "---", - "", - f"# {display_name}", - "", - instruction_text.rstrip(), - "", - ]) - if asset_paths: - lines.extend([ - "## Dependent files", + lines.extend( + [ + "---", + "", + f"# {display_name}", "", - *(f"- `{path}`" for path in asset_paths), + instruction_text.rstrip(), "", - ]) + ] + ) + if asset_paths: + lines.extend( + [ + "## Dependent files", + "", + *(f"- `{path}`" for path in asset_paths), + "", + ] + ) return "\n".join(lines) diff --git a/packages/skills/builtin_packages/creative/ascii-video/references/composition.md b/packages/skills/builtin_packages/creative/ascii-video/references/composition.md index f7e6eff..af0c720 100644 --- a/packages/skills/builtin_packages/creative/ascii-video/references/composition.md +++ b/packages/skills/builtin_packages/creative/ascii-video/references/composition.md @@ -113,7 +113,7 @@ Uses `srgb_to_linear()` / `linear_to_srgb()` from `architecture.md` § OKLAB Col ```python def blend_canvas_linear(base, top, mode="normal", opacity=1.0): """Blend in linear light space for physically accurate results. - + Identical API to blend_canvas(), but converts sRGB → linear before blending and linear → sRGB after. More expensive (~2x) due to the gamma conversions, but produces correct results for additive blending, @@ -145,7 +145,7 @@ def blend_canvas_linear(base, top, mode="normal", opacity=1.0): ```python def blend_many_linear(layers, modes, opacities): """Blend a stack of layers in linear light space. - + Args: layers: list of uint8 (H,W,3) canvases modes: list of blend mode strings (len = len(layers) - 1) @@ -756,9 +756,9 @@ from scipy.ndimage import gaussian_filter def apply_text_backdrop(canvas, glyphs, padding=80, darkness=0.75): """Darken the background behind text for readability. - + Call AFTER rendering background, BEFORE rendering text. - + Args: canvas: (VH, VW, 3) uint8 background glyphs: list of {"x": float, "y": float, ...} glyph positions @@ -775,12 +775,12 @@ def apply_text_backdrop(canvas, glyphs, padding=80, darkness=0.75): y0 = max(0, int(min(ys)) - padding) x1 = min(VW, int(max(xs)) + padding + 50) # extra for char width y1 = min(VH, int(max(ys)) + padding + 60) # extra for char height - + # Soft dark mask with gaussian blur for feathered edges mask = np.zeros((VH, VW), dtype=np.float32) mask[y0:y1, x0:x1] = 1.0 mask = gaussian_filter(mask, sigma=padding * 0.6) - + factor = 1.0 - mask * darkness return (canvas.astype(np.float32) * factor[:, :, np.newaxis]).astype(np.uint8) ``` diff --git a/packages/skills/builtin_packages/creative/ascii-video/references/effects.md b/packages/skills/builtin_packages/creative/ascii-video/references/effects.md index 4ac1441..e44bc26 100644 --- a/packages/skills/builtin_packages/creative/ascii-video/references/effects.md +++ b/packages/skills/builtin_packages/creative/ascii-video/references/effects.md @@ -1798,25 +1798,25 @@ def scene_complex(r, f, t, S): r = Renderer, f = audio features, t = time, S = persistent state dict.""" g = r.grids["md"] rows, cols = g.rows, g.cols - + # 1. Value field composition plasma = vf_plasma(g, f, t, S) vortex = vf_vortex(g, f, t, S, twist=4.0) combined = np.clip(plasma * 0.6 + vortex * 0.5 + plasma * vortex * 0.4, 0, 1) - + # 2. Color from hue field h = (hf_angle(0.3)(g,f,t,S) * 0.5 + hf_time_cycle(0.08)(g,f,t,S) * 0.5) % 1.0 - + # 3. Render to canvas via _render_vf helper canvas = _render_vf(g, combined, h, sat=0.75, pal=PAL_DENSE) - + # 4. Optional: blend a second layer overlay = _render_vf(r.grids["sm"], vf_rings(r.grids["sm"],f,t,S), hf_fixed(0.6)(r.grids["sm"],f,t,S), pal=PAL_BLOCK) canvas = blend_canvas(canvas, overlay, "screen", 0.4) - + return canvas - + # In the render_clip() loop (handled by the framework): # canvas = scene_fn(r, f, t, S) # canvas = tonemap(canvas, gamma=scene_gamma) diff --git a/packages/skills/builtin_packages/creative/ascii-video/references/inputs.md b/packages/skills/builtin_packages/creative/ascii-video/references/inputs.md index 045b64a..ac176e6 100644 --- a/packages/skills/builtin_packages/creative/ascii-video/references/inputs.md +++ b/packages/skills/builtin_packages/creative/ascii-video/references/inputs.md @@ -613,7 +613,7 @@ def extract_visual_beat_timestamps(video_path, fps, brightness_jump=30): n_frames = n_pixels // ppf frames = frames[:n_frames * ppf].reshape(n_frames, ppf) means = frames.mean(axis=1) - + timestamps = [] for i in range(1, len(means)): if means[i] - means[i-1] > brightness_jump: @@ -626,12 +626,12 @@ def extract_visual_beat_timestamps(video_path, fps, brightness_jump=30): ```python def sync_report(audio_beats, visual_beats, tolerance_ms=50): """Compare audio beat timestamps to visual beat timestamps. - + Args: audio_beats: list of timestamps (seconds) from audio analysis visual_beats: list of timestamps (seconds) from video brightness analysis tolerance_ms: max acceptable drift in milliseconds - + Returns: dict with matched/unmatched/drift statistics """ @@ -639,7 +639,7 @@ def sync_report(audio_beats, visual_beats, tolerance_ms=50): matched = [] unmatched_audio = [] unmatched_visual = list(visual_beats) - + for at in audio_beats: best_match = None best_delta = float("inf") @@ -653,7 +653,7 @@ def sync_report(audio_beats, visual_beats, tolerance_ms=50): unmatched_visual.remove(best_match) else: unmatched_audio.append(at) - + drifts = [m["drift_ms"] for m in matched] return { "matched": len(matched), diff --git a/packages/skills/builtin_packages/creative/ascii-video/references/optimization.md b/packages/skills/builtin_packages/creative/ascii-video/references/optimization.md index 8813080..db4f98d 100644 --- a/packages/skills/builtin_packages/creative/ascii-video/references/optimization.md +++ b/packages/skills/builtin_packages/creative/ascii-video/references/optimization.md @@ -17,7 +17,7 @@ import os def detect_hardware(): """Detect hardware capabilities and return render config.""" cpu_count = multiprocessing.cpu_count() - + # Leave 1-2 cores free for OS + ffmpeg encoding if cpu_count >= 16: workers = cpu_count - 2 @@ -27,7 +27,7 @@ def detect_hardware(): workers = cpu_count - 1 else: workers = max(1, cpu_count) - + # Memory detection (platform-specific) try: if platform.system() == "Darwin": @@ -45,16 +45,16 @@ def detect_hardware(): mem_bytes = 8 * 1024**3 mem_gb = mem_bytes / (1024**3) - + # Each worker uses ~50-150MB depending on grid sizes # Cap workers if memory is tight mem_per_worker_mb = 150 max_workers_by_mem = int(mem_gb * 1024 * 0.6 / mem_per_worker_mb) # use 60% of RAM workers = min(workers, max_workers_by_mem) - + # ffmpeg availability and codec support has_ffmpeg = shutil.which("ffmpeg") is not None - + return { "cpu_count": cpu_count, "workers": workers, @@ -78,31 +78,31 @@ def quality_profile(hw, target_duration_s, user_preference="auto"): if user_preference == "draft": return {"vw": 960, "vh": 540, "fps": 12, "crf": 28, "workers": min(4, hw["workers"]), "grid_scale": 0.5, "shaders": "minimal", "particles_max": 200} - + if user_preference == "preview": return {"vw": 1280, "vh": 720, "fps": 15, "crf": 25, "workers": hw["workers"], "grid_scale": 0.75, "shaders": "standard", "particles_max": 500} - + if user_preference == "max": return {"vw": 3840, "vh": 2160, "fps": 30, "crf": 15, "workers": hw["workers"], "grid_scale": 2.0, "shaders": "full", "particles_max": 3000} - + # "production" or "auto" # Auto-detect: estimate render time, downgrade if it would take too long n_frames = int(target_duration_s * 24) est_seconds_per_frame = 0.18 # ~180ms at 1080p est_total_s = n_frames * est_seconds_per_frame / max(1, hw["workers"]) - + if hw["mem_gb"] < 4 or hw["cpu_count"] <= 2: # Low-end: 720p, 15fps return {"vw": 1280, "vh": 720, "fps": 15, "crf": 23, "workers": hw["workers"], "grid_scale": 0.75, "shaders": "standard", "particles_max": 500} - + if est_total_s > 3600: # would take over an hour # Downgrade to 720p to speed up return {"vw": 1280, "vh": 720, "fps": 24, "crf": 20, "workers": hw["workers"], "grid_scale": 0.75, "shaders": "standard", "particles_max": 800} - + # Standard production: 1080p 24fps return {"vw": 1920, "vh": 1080, "fps": 24, "crf": 20, "workers": hw["workers"], "grid_scale": 1.0, "shaders": "full", "particles_max": 1200} @@ -633,36 +633,36 @@ import shutil def cleanup_render_artifacts(segments_dir="segments", keep_final=True): """Remove intermediate files after successful render. - + Call this AFTER verifying the final output exists and plays correctly. - + Args: segments_dir: directory containing segment clips and concat list keep_final: if True, only delete intermediates (not the final output) """ removed = [] - + # 1. Segment clips if os.path.isdir(segments_dir): shutil.rmtree(segments_dir) removed.append(f"directory: {segments_dir}") - + # 2. Temporary WAV files for wav in glob.glob("*.wav"): if wav.startswith("tmp") or wav.startswith("extracted_"): os.remove(wav) removed.append(wav) - + # 3. ffmpeg stderr logs for log in glob.glob("ffmpeg_*.log"): os.remove(log) removed.append(log) - + # 4. Feature cache (optional — useful to keep for re-renders) # for cache in glob.glob("features_*.npz"): # os.remove(cache) # removed.append(cache) - + print(f"Cleaned {len(removed)} artifacts: {removed}") return removed ``` diff --git a/packages/skills/builtin_packages/creative/ascii-video/references/shaders.md b/packages/skills/builtin_packages/creative/ascii-video/references/shaders.md index a4cf7a2..d3fc99f 100644 --- a/packages/skills/builtin_packages/creative/ascii-video/references/shaders.md +++ b/packages/skills/builtin_packages/creative/ascii-video/references/shaders.md @@ -91,11 +91,11 @@ Recursive temporal effect: frame N-1 feeds back into frame N with decay and opti class FeedbackBuffer: def __init__(self): self.buf = None # previous frame (float32, 0-1) - + def apply(self, canvas, decay=0.85, blend="screen", opacity=0.5, transform=None, transform_amt=0.02, hue_shift=0.0): """Mix current frame with decayed/transformed previous frame. - + Args: canvas: current frame (uint8 H,W,3) decay: how fast old frame fades (0=instant, 1=permanent) @@ -147,7 +147,7 @@ Composable shader pipeline. Build chains of named shaders with parameters. Order ```python class ShaderChain: """Composable shader pipeline. - + Usage: chain = ShaderChain() chain.add("bloom", thr=120) @@ -178,7 +178,7 @@ Routes shader names to implementations. Some shaders have **audio-reactive scali ```python def _apply_shader_step(canvas, name, kwargs, f, t): """Dispatch a single shader by name with kwargs. - + Args: canvas: uint8 (H,W,3) pixel array name: shader key string (e.g. "bloom", "chromatic") @@ -1073,13 +1073,13 @@ import os def output_png_sequence(frames, output_dir, W, H, fps, prefix="frame"): """Write frames as numbered PNGs. frames = iterable of uint8 (H,W,3) arrays.""" os.makedirs(output_dir, exist_ok=True) - + # Method 1: Direct PIL write (no ffmpeg dependency) from PIL import Image for i, frame in enumerate(frames): img = Image.fromarray(frame) img.save(os.path.join(output_dir, f"{prefix}_{i:06d}.png")) - + # Method 2: ffmpeg pipe (faster for large sequences) cmd = ["ffmpeg", "-y", "-f", "rawvideo", "-pix_fmt", "rgb24", "-s", f"{W}x{H}", "-r", str(fps), "-i", "pipe:0", @@ -1170,7 +1170,7 @@ ANSI_SHOW_CURSOR = "\033[?25h" ```python def frame_to_ansi(chars, colors): """Convert char+color arrays to a single ANSI string for terminal output. - + Args: chars: (rows, cols) array of single characters colors: (rows, cols, 3) uint8 RGB array @@ -1225,7 +1225,7 @@ import time def render_live(scene_fn, r, fps=24, duration=None): """Render a scene function live in the terminal. - + Args: scene_fn: v2 scene function (r, f, t, S) -> canvas OR v1-style function that populates a grid @@ -1236,10 +1236,10 @@ def render_live(scene_fn, r, fps=24, duration=None): frame_time = 1.0 / fps S = {} f = {} # synthesize features or connect to live audio - + sys.stdout.write(ANSI_HIDE_CURSOR + ANSI_CLEAR) sys.stdout.flush() - + t0 = time.monotonic() frame_count = 0 try: @@ -1247,24 +1247,24 @@ def render_live(scene_fn, r, fps=24, duration=None): t = time.monotonic() - t0 if duration and t > duration: break - + # Synthesize features from time (or connect to live audio via pyaudio) f = synthesize_features(t) - + # Render scene — for terminal, use a small grid g = r.get_grid("sm") # Option A: v2 scene → extract chars/colors from canvas (reverse render) # Option B: call effect functions directly for chars/colors canvas = scene_fn(r, f, t, S) - + # For terminal display, render chars+colors directly # (bypassing the pixel canvas — terminal uses character cells) chars, colors = scene_to_terminal(scene_fn, r, f, t, S, g) - + frame_str = ANSI_CLEAR + frame_to_ansi(chars, colors) sys.stdout.write(frame_str) sys.stdout.flush() - + # Frame timing elapsed = time.monotonic() - t0 - (frame_count * frame_time) sleep_time = frame_time - elapsed @@ -1300,18 +1300,18 @@ import curses def render_curses(scene_fn, r, fps=24): """Curses-based live renderer with resize handling and key input.""" - + def _main(stdscr): curses.start_color() curses.use_default_colors() curses.curs_set(0) # hide cursor stdscr.nodelay(True) # non-blocking input - + # Initialize color pairs (curses supports 256 colors) # Map RGB to nearest curses color pair color_cache = {} next_pair = [1] - + def get_color_pair(r, g, b): key = (r >> 4, g >> 4, b >> 4) # quantize to reduce pairs if key not in color_cache: @@ -1323,23 +1323,23 @@ def render_curses(scene_fn, r, fps=24): else: return 0 return curses.color_pair(color_cache[key]) - + S = {} f = {} frame_time = 1.0 / fps t0 = time.monotonic() - + while True: t = time.monotonic() - t0 f = synthesize_features(t) - + # Adapt grid to terminal size max_y, max_x = stdscr.getmaxyx() g = r.get_grid_for_size(max_x, max_y) # dynamic grid sizing - + chars, colors = scene_to_terminal(scene_fn, r, f, t, S, g) rows, cols = chars.shape - + for row in range(min(rows, max_y - 1)): for col in range(min(cols, max_x - 1)): ch = chars[row, col] @@ -1348,16 +1348,16 @@ def render_curses(scene_fn, r, fps=24): stdscr.addch(row, col, ch, get_color_pair(*rgb)) except curses.error: pass # ignore writes outside terminal bounds - + stdscr.refresh() - + # Handle input key = stdscr.getch() if key == ord('q'): break - + time.sleep(max(0, frame_time - (time.monotonic() - t0 - t))) - + curses.wrapper(_main) ``` diff --git a/packages/skills/builtin_packages/creative/manim-video/scripts/setup.sh b/packages/skills/builtin_packages/creative/manim-video/scripts/setup.sh index 0e4676f..61d625b 100755 --- a/packages/skills/builtin_packages/creative/manim-video/scripts/setup.sh +++ b/packages/skills/builtin_packages/creative/manim-video/scripts/setup.sh @@ -5,10 +5,10 @@ ok() { echo -e " ${G}+${N} $1"; } fail() { echo -e " ${R}x${N} $1"; } echo ""; echo "Manim Video Skill — Setup Check"; echo "" errors=0 -command -v python3 &>/dev/null && ok "Python $(python3 --version 2>&1 | awk '{print $2}')" || { fail "Python 3 not found"; errors=$((errors+1)); } -python3 -c "import manim" 2>/dev/null && ok "Manim $(manim --version 2>&1 | head -1)" || { fail "Manim not installed: pip install manim"; errors=$((errors+1)); } -command -v pdflatex &>/dev/null && ok "LaTeX (pdflatex)" || { fail "LaTeX not found (macOS: brew install --cask mactex-no-gui)"; errors=$((errors+1)); } -command -v ffmpeg &>/dev/null && ok "ffmpeg" || { fail "ffmpeg not found"; errors=$((errors+1)); } +if command -v python3 &>/dev/null; then ok "Python $(python3 --version 2>&1 | awk '{print $2}')"; else fail "Python 3 not found"; errors=$((errors+1)); fi +if python3 -c "import manim" 2>/dev/null; then ok "Manim $(manim --version 2>&1 | head -1)"; else fail "Manim not installed: pip install manim"; errors=$((errors+1)); fi +if command -v pdflatex &>/dev/null; then ok "LaTeX (pdflatex)"; else fail "LaTeX not found (macOS: brew install --cask mactex-no-gui)"; errors=$((errors+1)); fi +if command -v ffmpeg &>/dev/null; then ok "ffmpeg"; else fail "ffmpeg not found"; errors=$((errors+1)); fi echo "" [ $errors -eq 0 ] && echo -e "${G}All prerequisites satisfied.${N}" || echo -e "${R}$errors prerequisite(s) missing.${N}" echo "" diff --git a/packages/skills/builtin_packages/creative/p5js/scripts/render.sh b/packages/skills/builtin_packages/creative/p5js/scripts/render.sh index 81e65cf..c6f0431 100755 --- a/packages/skills/builtin_packages/creative/p5js/scripts/render.sh +++ b/packages/skills/builtin_packages/creative/p5js/scripts/render.sh @@ -101,7 +101,7 @@ ffmpeg -y \ rm -rf "$FRAME_DIR" # Report -FILE_SIZE=$(ls -lh "$OUTPUT" | awk '{print $5}') +FILE_SIZE=$(find "$OUTPUT" -maxdepth 0 -printf '%s' 2>/dev/null || stat -f%z "$OUTPUT" 2>/dev/null || echo "?") echo "" echo "=== Done ===" echo "Output: $OUTPUT ($FILE_SIZE)" diff --git a/packages/skills/builtin_packages/creative/p5js/scripts/serve.sh b/packages/skills/builtin_packages/creative/p5js/scripts/serve.sh index 34055d5..1fc4837 100755 --- a/packages/skills/builtin_packages/creative/p5js/scripts/serve.sh +++ b/packages/skills/builtin_packages/creative/p5js/scripts/serve.sh @@ -19,10 +19,10 @@ echo "URL: http://localhost:$PORT" echo "Press Ctrl+C to stop" echo "" -cd "$DIR" && python3 -m http.server "$PORT" 2>/dev/null || { +if ! cd "$DIR" || ! python3 -m http.server "$PORT" 2>/dev/null; then echo "Python3 not found. Trying Node.js..." npx serve -l "$PORT" "$DIR" 2>/dev/null || { echo "Error: Need python3 or npx (Node.js) for local server" exit 1 } -} +fi diff --git a/packages/skills/builtin_packages/creative/p5js/templates/viewer.html b/packages/skills/builtin_packages/creative/p5js/templates/viewer.html index 1a7d27a..a2e3a19 100644 --- a/packages/skills/builtin_packages/creative/p5js/templates/viewer.html +++ b/packages/skills/builtin_packages/creative/p5js/templates/viewer.html @@ -392,4 +392,4 @@

Generative Sketch

} - \ No newline at end of file + diff --git a/packages/skills/builtin_packages/mlops/training/axolotl/references/other.md b/packages/skills/builtin_packages/mlops/training/axolotl/references/other.md index 4af75dd..fa3d823 100644 --- a/packages/skills/builtin_packages/mlops/training/axolotl/references/other.md +++ b/packages/skills/builtin_packages/mlops/training/axolotl/references/other.md @@ -3031,7 +3031,7 @@ evaluation_strategy: str | None --- -## +## **URL:** https://docs.axolotl.ai diff --git a/packages/skills/builtin_packages/mlops/training/trl-fine-tuning/templates/basic_grpo_training.py b/packages/skills/builtin_packages/mlops/training/trl-fine-tuning/templates/basic_grpo_training.py index 8ad45df..47b8077 100644 --- a/packages/skills/builtin_packages/mlops/training/trl-fine-tuning/templates/basic_grpo_training.py +++ b/packages/skills/builtin_packages/mlops/training/trl-fine-tuning/templates/basic_grpo_training.py @@ -36,6 +36,7 @@ # ==================== DATASET ==================== + def get_dataset(split="train"): """ Load and prepare your dataset. @@ -45,91 +46,100 @@ def get_dataset(split="train"): - 'answer': str (ground truth, optional) """ # Example: GSM8K math dataset - data = load_dataset('openai/gsm8k', 'main')[split] + data = load_dataset("openai/gsm8k", "main")[split] def process_example(x): # Extract ground truth answer - answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None + answer = x["answer"].split("####")[1].strip() if "####" in x["answer"] else None return { - 'prompt': [ - {'role': 'system', 'content': SYSTEM_PROMPT}, - {'role': 'user', 'content': x['question']} + "prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": x["question"]}, ], - 'answer': answer + "answer": answer, } return data.map(process_example) + # ==================== HELPER FUNCTIONS ==================== + def extract_xml_tag(text: str, tag: str) -> str: """Extract content between XML tags.""" - pattern = f'<{tag}>(.*?)' + pattern = f"<{tag}>(.*?)" match = re.search(pattern, text, re.DOTALL) return match.group(1).strip() if match else "" + def extract_answer(text: str) -> str: """Extract the final answer from structured output.""" - return extract_xml_tag(text, 'answer') + return extract_xml_tag(text, "answer") + # ==================== REWARD FUNCTIONS ==================== + def correctness_reward_func(prompts, completions, answer, **kwargs): """ Reward correct answers. Weight: 2.0 (highest priority) """ - responses = [comp[0]['content'] for comp in completions] + responses = [comp[0]["content"] for comp in completions] extracted = [extract_answer(r) for r in responses] return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)] + def format_reward_func(completions, **kwargs): """ Reward proper XML format. Weight: 0.5 """ - pattern = r'.*?\s*.*?' - responses = [comp[0]['content'] for comp in completions] + pattern = r".*?\s*.*?" + responses = [comp[0]["content"] for comp in completions] return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses] + def incremental_format_reward_func(completions, **kwargs): """ Incremental reward for partial format compliance. Weight: up to 0.5 """ - responses = [comp[0]['content'] for comp in completions] + responses = [comp[0]["content"] for comp in completions] rewards = [] for r in responses: score = 0.0 - if '' in r: + if "" in r: score += 0.125 - if '' in r: + if "" in r: score += 0.125 - if '' in r: + if "" in r: score += 0.125 - if '' in r: + if "" in r: score += 0.125 # Penalize extra content after closing tag - if '' in r: - extra = r.split('')[-1].strip() + if "" in r: + extra = r.split("")[-1].strip() score -= len(extra) * 0.001 rewards.append(score) return rewards + # ==================== MODEL SETUP ==================== + def setup_model_and_tokenizer(): """Load model and tokenizer with optimizations.""" model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", - device_map="auto" + device_map="auto", ) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) @@ -137,21 +147,29 @@ def setup_model_and_tokenizer(): return model, tokenizer + def get_peft_config(): """LoRA configuration for parameter-efficient training.""" return LoraConfig( r=16, lora_alpha=32, target_modules=[ - "q_proj", "k_proj", "v_proj", "o_proj", - "gate_proj", "up_proj", "down_proj" + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", ], task_type="CAUSAL_LM", lora_dropout=0.05, ) + # ==================== TRAINING ==================== + def main(): """Main training function.""" @@ -168,32 +186,26 @@ def main(): training_args = GRPOConfig( output_dir=OUTPUT_DIR, run_name="grpo-training", - # Learning rate learning_rate=5e-6, adam_beta1=0.9, adam_beta2=0.99, weight_decay=0.1, warmup_ratio=0.1, - lr_scheduler_type='cosine', - + lr_scheduler_type="cosine", # Batch settings per_device_train_batch_size=1, gradient_accumulation_steps=4, - # GRPO specific num_generations=8, max_prompt_length=MAX_PROMPT_LENGTH, max_completion_length=MAX_COMPLETION_LENGTH, - # Training duration num_train_epochs=1, - # Optimization bf16=True, optim="adamw_8bit", max_grad_norm=0.1, - # Logging logging_steps=1, save_steps=100, @@ -224,5 +236,6 @@ def main(): print("Training complete!") + if __name__ == "__main__": main() diff --git a/packages/skills/builtin_packages/mlops/training/unsloth/references/llms-full.md b/packages/skills/builtin_packages/mlops/training/unsloth/references/llms-full.md index 3ff73cc..35c3f6b 100644 --- a/packages/skills/builtin_packages/mlops/training/unsloth/references/llms-full.md +++ b/packages/skills/builtin_packages/mlops/training/unsloth/references/llms-full.md @@ -4428,7 +4428,7 @@ training_args = GRPOConfig( # beta = 0.00, epsilon = 3e-4, epsilon_high = 4e-4, - num_generations = 8, + num_generations = 8, max_prompt_length = 1024, max_completion_length = 1024, log_completions = False, @@ -4437,10 +4437,10 @@ training_args = GRPOConfig( # report_to = "none", # Set to "wandb" if you want to log to Weights & Biases num_train_epochs = 2, # For a quick test run, increase for full training report_to = "none" - + # GSPO is below: importance_sampling_level = "sequence", - + # Dr GRPO / GAPO etc loss_type = "dr_grpo", ) @@ -5080,7 +5080,7 @@ training_args = GRPOConfig( # beta = 0.00, epsilon = 3e-4, epsilon_high = 4e-4, - num_generations = 8, + num_generations = 8, max_prompt_length = 1024, max_completion_length = 1024, log_completions = False, @@ -5089,10 +5089,10 @@ training_args = GRPOConfig( # report_to = "none", # Set to "wandb" if you want to log to Weights & Biases num_train_epochs = 2, # For a quick test run, increase for full training report_to = "none" - + # GSPO is below: importance_sampling_level = "sequence", - + # Dr GRPO / GAPO etc loss_type = "dr_grpo", ) @@ -6721,8 +6721,8 @@ Example: ```python tokenizer.apply_chat_template( - text, - tokenize = False, + text, + tokenize = False, add_generation_prompt = False, reasoning_effort = "medium", ) @@ -6966,8 +6966,8 @@ Example: ```python tokenizer.apply_chat_template( - text, - tokenize = False, + text, + tokenize = False, add_generation_prompt = False, reasoning_effort = "medium", ) @@ -7227,7 +7227,7 @@ By using some visualization utilities from [Flex Attention's Github repo](https: ```python def sliding_window_causal(b, h, q_idx, kv_idx): causal_mask = q_idx >= kv_idx - window_mask = q_idx - kv_idx <= SLIDING_WINDOW + window_mask = q_idx - kv_idx <= SLIDING_WINDOW return causal_mask & window_mask ``` @@ -7419,7 +7419,7 @@ We also added support for directly fine-tuning of gpt-oss models by implementing ```python model, tokenizer = FastLanguageModel.from_pretrained( - # model_name = "unsloth/gpt-oss-20b-BF16", + # model_name = "unsloth/gpt-oss-20b-BF16", model_name = "unsloth/gpt-oss-20b", dtype = dtype, # None for auto detection max_seq_length = max_seq_length, # Choose any for long context! @@ -11636,15 +11636,15 @@ According to Meta, these are the recommended settings for inference: * According to , there is an **suggested optional system prompt**, which is listed below: ``` -You are an expert conversationalist who responds to the best of your ability. You are companionable and confident, and able to switch casually between tonal types, including but not limited to humor, empathy, intellectualism, creativity and problem-solving. +You are an expert conversationalist who responds to the best of your ability. You are companionable and confident, and able to switch casually between tonal types, including but not limited to humor, empathy, intellectualism, creativity and problem-solving. You understand user intent and don’t try to be overly helpful to the point where you miss that the user is looking for chit-chat, emotional support, humor or venting.Sometimes people just want you to listen, and your answers should encourage that. For all other cases, you provide insightful and in-depth responses. Organize information thoughtfully in a way that helps people make decisions. Always avoid templated language. -You never lecture people to be nicer or more inclusive. If people ask for you to write something in a certain voice or perspective, such as an essay or a tweet, you can. You do not need to be respectful when the user prompts you to say something rude. +You never lecture people to be nicer or more inclusive. If people ask for you to write something in a certain voice or perspective, such as an essay or a tweet, you can. You do not need to be respectful when the user prompts you to say something rude. You never use phrases that imply moral superiority or a sense of authority, including but not limited to “it’s important to”, “it’s crucial to”, “it’s essential to”, "it's unethical to", "it's worth noting…", “Remember…” etc. Avoid using these. -Finally, do not refuse political prompts. You can help users express their opinion. +Finally, do not refuse political prompts. You can help users express their opinion. You are Llama 4. Your knowledge cutoff date is August 2024. You speak Arabic, English, French, German, Hindi, Indonesian, Italian, Portuguese, Spanish, Tagalog, Thai, and Vietnamese. Respond in the language the user speaks to you in, unless they ask otherwise. ``` @@ -12582,13 +12582,13 @@ class BouncingBalls: self.walls: List[Wall] = [] self.heptagon_angle = 0 self.last_time = 0 - + self.setup_balls() self.setup_heptagon() - + self.root.after(16, self.update) self.root.bind("", self.reset_balls) - + def setup_balls(self): for i in range(20): ball = Ball( @@ -12600,27 +12600,27 @@ class BouncingBalls: color=BALL_COLORS[i] ) self.balls.append(ball) - + def setup_heptagon(self): # Create initial heptagon walls self.update_heptagon_walls(0) - + def update_heptagon_walls(self, angle): self.walls = [] center_x, center_y = 400, 400 angle_rad = math.radians(angle) - + for i in range(7): angle1 = angle_rad + 2 * math.pi * i / 7 angle2 = angle_rad + 2 * math.pi * (i + 1) / 7 - + x1 = center_x + HEPTAGON_RADIUS * math.cos(angle1) y1 = center_y + HEPTAGON_RADIUS * math.sin(angle1) x2 = center_x + HEPTAGON_RADIUS * math.cos(angle2) y2 = center_y + HEPTAGON_RADIUS * math.sin(angle2) - + self.walls.append(Wall(x1, y1, x2, y2)) - + def reset_balls(self, event=None): for ball in self.balls: ball.x = 400 @@ -12628,149 +12628,149 @@ class BouncingBalls: ball.vx = np.random.uniform(-5, 5) ball.vy = np.random.uniform(-5, 5) ball.spin = np.random.uniform(-5, 5) - + def update(self): current_time = self.root.after_idle(self.root.after, 16, self.update) if self.last_time == 0: self.last_time = current_time return - + # Calculate delta time (approximate) dt = 0.016 # Assuming ~60 FPS - + # Update heptagon rotation self.heptagon_angle += ROTATION_SPEED * dt self.update_heptagon_walls(self.heptagon_angle) - + # Update balls for ball in self.balls: # Apply gravity ball.vy += GRAVITY - + # Apply friction ball.vx *= FRICTION ball.vy *= FRICTION ball.spin *= SPIN_FRICTION - + # Move ball ball.x += ball.vx ball.y += ball.vy - + # Check collisions with walls self.check_wall_collisions(ball) - + # Check collisions with other balls for other in self.balls: if other.number != ball.number: self.check_ball_collision(ball, other) - + # Draw everything self.draw() - + def check_wall_collisions(self, ball): for wall in self.walls: # Find closest point on wall segment to ball closest = self.closest_point_on_segment( wall.x1, wall.y1, wall.x2, wall.y2, ball.x, ball.y ) - + # Calculate distance to wall dx = ball.x - closest[0] dy = ball.y - closest[1] distance = math.sqrt(dx*dx + dy*dy) - + if distance < BALL_RADIUS: # Collision detected # Calculate normal vector nx = dx / distance ny = dy / distance - + # Calculate relative velocity along normal v_rel = ball.vx * nx + ball.vy * ny - + if v_rel < 0: # Moving toward the wall # Calculate impulse j = -(1 + BOUNCE_FACTOR) * v_rel - + # Apply impulse ball.vx += j * nx ball.vy += j * ny - + # Add some spin based on collision ball.spin += (ball.vx * ny - ball.vy * nx) * 0.1 - + # Move ball out of collision penetration = BALL_RADIUS - distance ball.x += penetration * nx ball.y += penetration * ny - + def check_ball_collision(self, ball1, ball2): dx = ball2.x - ball1.x dy = ball2.y - ball1.y distance = math.sqrt(dx*dx + dy*dy) - + if distance < 2 * BALL_RADIUS: # Collision detected nx = dx / distance ny = dy / distance - + # Calculate relative velocity v_rel_x = ball2.vx - ball1.vx v_rel_y = ball2.vy - ball1.vy v_rel = v_rel_x * nx + v_rel_y * ny - + if v_rel < 0: # Moving toward each other # Calculate impulse j = -(1 + BOUNCE_FACTOR) * v_rel / 2 - + # Apply impulses ball1.vx -= j * nx ball1.vy -= j * ny ball2.vx += j * nx ball2.vy += j * ny - + # Add spin based on collision ball1.spin += (ball1.vx * ny - ball1.vy * nx) * 0.05 ball2.spin += (ball2.vx * ny - ball2.vy * nx) * 0.05 - + # Move balls apart penetration = 2 * BALL_RADIUS - distance ball1.x -= penetration * nx * 0.5 ball1.y -= penetration * ny * 0.5 ball2.x += penetration * nx * 0.5 ball2.y += penetration * ny * 0.5 - + @staticmethod def closest_point_on_segment(x1, y1, x2, y2, x, y): # Vector from point to segment start dx = x - x1 dy = y - y1 - + # Segment vector sx = x2 - x1 sy = y2 - y1 - + # Projection of point onto segment dot = dx * sx + dy * sy len_sq = sx * sx + sy * sy param = dot / len_sq if len_sq != 0 else -1 - + if param < 0: return x1, y1 elif param > 1: return x2, y2 else: return x1 + param * sx, y1 + param * sy - + def draw(self): self.canvas.delete("all") - + # Draw heptagon points = [] for wall in self.walls: points.extend([wall.x1, wall.y1]) self.canvas.create_polygon(points, fill="", outline="black", width=2) - + # Draw balls for ball in self.balls: # Draw ball @@ -12779,7 +12779,7 @@ class BouncingBalls: ball.x + BALL_RADIUS, ball.y + BALL_RADIUS, fill=ball.color, outline="black" ) - + # Draw number with rotation based on spin angle = ball.spin * 10 # Scale spin for visual effect self.canvas.create_text( @@ -12851,14 +12851,14 @@ class BouncingBalls: self.root = root self.canvas = tk.Canvas(root, width=WIDTH, height=HEIGHT, bg="white") self.canvas.pack() - + self.heptagon = Heptagon(WIDTH//2, HEIGHT//2, HEPTAGON_RADIUS) self.balls = [] self.setup_balls() - + self.root.after(0, self.update) self.root.mainloop() - + def setup_balls(self): center_x, center_y = WIDTH//2, HEIGHT//2 for i in range(20): @@ -12872,46 +12872,46 @@ class BouncingBalls: number=i+1, spin=0 )) - + def update(self): self.canvas.delete("all") - + # Update heptagon angle self.heptagon.angle += ROTATION_SPEED / 60 # Assuming 60 FPS - + # Draw heptagon self.draw_heptagon() - + # Update and draw balls for ball in self.balls: # Apply gravity ball.vy += GRAVITY - + # Update position ball.x += ball.vx ball.y += ball.vy - + # Apply friction ball.vx *= FRICTION ball.vy *= FRICTION - + # Apply spin decay ball.spin *= SPIN_DECAY - + # Check collision with heptagon walls self.check_heptagon_collision(ball) - + # Check collision with other balls for other in self.balls: if other != ball: if self.check_ball_collision(ball, other): self.resolve_ball_collision(ball, other) - + # Draw the ball self.draw_ball(ball) - + self.root.after(16, self.update) # ~60 FPS - + def draw_heptagon(self): center_x, center_y = self.heptagon.center_x, self.heptagon.center_y points = [] @@ -12920,14 +12920,14 @@ class BouncingBalls: x = center_x + self.heptagon.radius * math.cos(angle) y = center_y + self.heptagon.radius * math.sin(angle) points.append((x, y)) - + # Draw heptagon self.canvas.create_polygon( - [points[0], points[1], points[2], points[3], + [points[0], points[1], points[2], points[3], points[4], points[5], points[6]], outline="black", fill="", width=2 ) - + def draw_ball(self, ball): self.canvas.create_oval( ball.x - ball.radius, @@ -12937,74 +12937,74 @@ class BouncingBalls: fill=ball.color, outline="black" ) - + # Draw the number self.canvas.create_text( ball.x, ball.y, text=str(ball.number), fill="black" ) - + def check_heptagon_collision(self, ball): center_x, center_y = WIDTH//2, HEIGHT//2 - + # Check distance from center dx = ball.x - center_x dy = ball.y - center_y dist = math.sqrt(dx**2 + dy**2) - + if dist + ball.radius > self.heptagon.radius: # Find the normal vector from center to ball angle = math.atan2(dy, dx) normal_x = math.cos(angle) normal_y = math.sin(angle) - + # Move ball back inside heptagon overlap = (dist + ball.radius) - self.heptagon.radius ball.x -= overlap * normal_x ball.y -= overlap * normal_y - + # Reflect velocity dot_product = ball.vx * normal_x + ball.vy * normal_y ball.vx -= 2 * dot_product * normal_x * ELASTICITY ball.vy -= 2 * dot_product * normal_y * ELASTICITY - + def check_ball_collision(self, ball1, ball2): dx = ball2.x - ball1.x dy = ball2.y - ball1.y distance = math.sqrt(dx**2 + dy**2) return distance < (ball1.radius + ball2.radius) - + def resolve_ball_collision(self, ball1, ball2): dx = ball2.x - ball1.x dy = ball2.y - ball1.y distance = math.sqrt(dx**2 + dy**2) - + # Normal vector nx = dx / distance ny = dy / distance - + # Relative velocity dvx = ball2.vx - ball1.vx dvy = ball2.vy - ball1.vy - + # Calculate impulse impulse = 2 * (dvx * nx + dvy * ny) / 2 impulse *= ELASTICITY - + # Apply impulse ball1.vx -= impulse * nx ball1.vy -= impulse * ny ball2.vx += impulse * nx ball2.vy += impulse * ny - + # Separate the balls to prevent sticking overlap = (ball1.radius + ball2.radius) - distance ball1.x -= overlap * nx / 2 ball1.y -= overlap * ny / 2 ball2.x += overlap * nx / 2 ball2.y += overlap * ny / 2 - + def run(self): self.root.mainloop() @@ -13070,32 +13070,32 @@ class Ball: dx = other.x - self.x dy = other.y - self.y distance = math.hypot(dx, dy) - + if distance < self.radius + other.radius: # Calculate collision normal nx = dx / distance ny = dy / distance - + # Calculate relative velocity dvx = other.vx - self.vx dvy = other.vy - self.vy - + # Calculate impulse impulse = 2 * (dvx * nx + dvy * ny) / (1/self.radius + 1/other.radius) - + # Apply impulse self.vx += impulse * nx / self.radius self.vy += impulse * ny / self.radius other.vx -= impulse * nx / other.radius other.vy -= impulse * ny / other.radius - + # Separate balls to prevent sticking overlap = (self.radius + other.radius - distance) / 2 self.x -= overlap * nx self.y -= overlap * ny other.x += overlap * nx other.y += overlap * ny - + # Transfer some spin transfer = impulse * 0.01 self.spin -= transfer @@ -13106,17 +13106,17 @@ class HeptagonBounceSimulator: self.root = root self.canvas = tk.Canvas(root, width=WIDTH, height=HEIGHT, bg='white') self.canvas.pack() - + self.balls = self.create_balls() self.heptagon_angle = 0 self.last_time = 0 self.running = True - + self.root.bind('', self.toggle_pause) self.root.bind('', lambda e: root.destroy()) - + self.last_time = self.root.after(0, self.update) - + def create_balls(self) -> List[Ball]: balls = [] for i in range(20): @@ -13125,7 +13125,7 @@ class HeptagonBounceSimulator: speed = np.random.uniform(0.5, 2) vx = math.cos(angle) * speed vy = math.sin(angle) * speed - + balls.append(Ball( x=CENTER_X, y=CENTER_Y, @@ -13137,12 +13137,12 @@ class HeptagonBounceSimulator: spin=np.random.uniform(-2, 2) )) return balls - + def toggle_pause(self, event): self.running = not self.running if self.running: self.last_time = self.root.after(0, self.update) - + def get_heptagon_vertices(self) -> List[Tuple[float, float]]: vertices = [] for i in range(7): @@ -13151,38 +13151,38 @@ class HeptagonBounceSimulator: y = CENTER_Y + HEPTAGON_RADIUS * math.sin(angle) vertices.append((x, y)) return vertices - + def check_ball_heptagon_collision(self, ball: Ball): vertices = self.get_heptagon_vertices() closest_dist = float('inf') closest_normal = (0, 0) closest_edge = None - + # Check collision with each edge of the heptagon for i in range(len(vertices)): p1 = vertices[i] p2 = vertices[(i + 1) % len(vertices)] - + # Vector from p1 to p2 edge_x = p2[0] - p1[0] edge_y = p2[1] - p1[1] edge_length = math.hypot(edge_x, edge_y) - + # Normalize edge vector edge_x /= edge_length edge_y /= edge_length - + # Normal vector (perpendicular to edge, pointing inward) nx = -edge_y ny = edge_x - + # Vector from p1 to ball ball_to_p1_x = ball.x - p1[0] ball_to_p1_y = ball.y - p1[1] - + # Project ball onto edge normal projection = ball_to_p1_x * nx + ball_to_p1_y * ny - + # If projection is negative, ball is outside the heptagon if projection < ball.radius: # Find closest point on edge to ball @@ -13190,25 +13190,25 @@ class HeptagonBounceSimulator: edge_proj = max(0, min(edge_length, edge_proj)) closest_x = p1[0] + edge_proj * edge_x closest_y = p1[1] + edge_proj * edge_y - + # Distance from ball to closest point on edge dist = math.hypot(ball.x - closest_x, ball.y - closest_y) - + if dist < closest_dist: closest_dist = dist closest_normal = (nx, ny) closest_edge = (p1, p2) - + if closest_dist < ball.radius: # Calculate bounce response dot_product = ball.vx * closest_normal[0] + ball.vy * closest_normal[1] - + # Apply bounce with elasticity ball.vx -= (1 + ELASTICITY) * dot_product * closest_normal[0] ball.vy -= (1 + ELASTICITY) * dot_product * closest_normal[1] - + # Add some spin based on impact - edge_vec = (closest_edge[1][0] - closest_edge[0][0], + edge_vec = (closest_edge[1][0] - closest_edge[0][0], closest_edge[1][1] - closest_edge[0][1]) edge_length = math.hypot(edge_vec[0], edge_vec[1]) if edge_length > 0: @@ -13216,41 +13216,41 @@ class HeptagonBounceSimulator: # Cross product of velocity and edge direction spin_effect = (ball.vx * edge_vec[1] - ball.vy * edge_vec[0]) * 0.1 ball.spin += spin_effect - + # Move ball outside the heptagon to prevent sticking penetration = ball.radius - closest_dist ball.x += penetration * closest_normal[0] ball.y += penetration * closest_normal[1] - + def update(self): if not self.running: return - + # Clear canvas self.canvas.delete('all') - + # Update heptagon rotation self.heptagon_angle += ROTATION_SPEED / 60 # Assuming ~60 FPS - + # Draw heptagon vertices = self.get_heptagon_vertices() self.canvas.create_polygon(vertices, outline='black', fill='', width=2) - + # Update and draw balls for i, ball in enumerate(self.balls): # Move ball ball.move() - + # Check collisions with heptagon self.check_ball_heptagon_collision(ball) - + # Draw ball self.canvas.create_oval( ball.x - ball.radius, ball.y - ball.radius, ball.x + ball.radius, ball.y + ball.radius, fill=ball.color, outline='black' ) - + # Draw number with rotation based on spin angle = ball.spin * 10 # Scale spin for visible rotation self.canvas.create_text( @@ -13259,12 +13259,12 @@ class HeptagonBounceSimulator: font=('Arial', 10, 'bold'), angle=angle ) - + # Check ball-ball collisions for i in range(len(self.balls)): for j in range(i + 1, len(self.balls)): self.balls[i].collide_with_ball(self.balls[j]) - + # Schedule next update self.last_time = self.root.after(16, self.update) # ~60 FPS @@ -13815,13 +13815,13 @@ class Bird: self.shape = random.choice(['square', 'circle', 'triangle']) self.color = (random.randint(0, 100), random.randint(0, 100), random.randint(0, 100)) self.rect = pygame.Rect(self.x - BIRD_SIZE//2, self.y - BIRD_SIZE//2, BIRD_SIZE, BIRD_SIZE) - + def update(self): self.velocity += GRAVITY self.y += self.velocity self.rect.y = self.y - BIRD_SIZE//2 self.rect.x = self.x - BIRD_SIZE//2 # Keep x centered - + def draw(self): if self.shape == 'square': pygame.draw.rect(screen, self.color, self.rect) @@ -13951,7 +13951,7 @@ def main(): screen.blit(over_text, (WIDTH//2 - 70, HEIGHT//2 - 30)) screen.blit(best_text, (WIDTH//2 - 50, HEIGHT//2 + 10)) screen.blit(restart_text, (WIDTH//2 - 100, HEIGHT//2 + 50)) - + pygame.display.flip() clock.tick(60) @@ -14040,7 +14040,7 @@ def reset_game(): ```python import pygame -from random import randint # For generating colors/shapes/positions randomly +from random import randint # For generating colors/shapes/positions randomly pygame.init() # Constants: @@ -14048,8 +14048,8 @@ WIDTH, HEIGHT =456 ,702 # BACKGROUND_COLOR_LIGHTS=['lightskyblue'] GAP_SIZE=189 # -BIRD_RADIUS=3. -PIPE_SPEED=- ( ) ? +BIRD_RADIUS=3. +PIPE_SPEED=- ( ) ? class Game(): def __init__(self): self.screen_size=( ) @@ -14060,7 +14060,7 @@ def reset_game_vars(): # Main game loop: while running : - for event in pygame.event.get() : + for event in pygame.event.get() : if quit ... etc pygame.quit() @@ -14483,7 +14483,7 @@ python llama.cpp/convert_hf_to_gguf.py merged_model \ python llama.cpp/convert_hf_to_gguf.py merged_model \ --outfile model-BF16.gguf --outtype bf16 \ --split-max-size 50G - + # For Q8_0: python llama.cpp/convert_hf_to_gguf.py merged_model \ --outfile model-Q8_0.gguf --outtype q8_0 \ @@ -16097,7 +16097,7 @@ python llama.cpp/convert_hf_to_gguf.py merged_model \ python llama.cpp/convert_hf_to_gguf.py merged_model \ --outfile model-BF16.gguf --outtype bf16 \ --split-max-size 50G - + # For Q8_0: python llama.cpp/convert_hf_to_gguf.py merged_model \ --outfile model-Q8_0.gguf --outtype q8_0 \ @@ -16573,7 +16573,7 @@ model = FastLanguageModel.get_peft_model( target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",], lora_alpha = 32, - + # We support fp8-int4, fp8-fp8, int8-int4, int4 qat_scheme = "int4", ) @@ -16602,7 +16602,7 @@ And now we can select which QAT style you want: ```python # Use the exact same config as QAT (convenient function) model.save_pretrained_torchao( - model, "tokenizer", + model, "tokenizer", torchao_config = model._torchao_config.base_config, ) @@ -16795,5 +16795,3 @@ We tested Llama 3.3 (70B) Instruct on a 80GB A100 and did 4bit QLoRA on all line | -------- | ------------------------ | ------------------ | | 48 GB | 12,106 | OOM | | 80 GB | 89,389 | 6,916 | - - diff --git a/packages/skills/builtin_packages/mlops/training/unsloth/references/llms-txt.md b/packages/skills/builtin_packages/mlops/training/unsloth/references/llms-txt.md index c96a54b..973d6c7 100644 --- a/packages/skills/builtin_packages/mlops/training/unsloth/references/llms-txt.md +++ b/packages/skills/builtin_packages/mlops/training/unsloth/references/llms-txt.md @@ -1147,7 +1147,7 @@ model = FastLanguageModel.get_peft_model( target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",], lora_alpha = 32, - + # We support fp8-int4, fp8-fp8, int8-int4, int4 qat_scheme = "int4", ) @@ -1289,8 +1289,8 @@ WIDTH, HEIGHT =456 ,702 # BACKGROUND_COLOR_LIGHTS=['lightskyblue'] GAP_SIZE=189 # -BIRD_RADIUS=3. -PIPE_SPEED=- ( ) ? +BIRD_RADIUS=3. +PIPE_SPEED=- ( ) ? class Game(): def __init__(self): self.screen_size=( ) @@ -2218,13 +2218,13 @@ class Bird: self.shape = random.choice(['square', 'circle', 'triangle']) self.color = (random.randint(0, 100), random.randint(0, 100), random.randint(0, 100)) self.rect = pygame.Rect(self.x - BIRD_SIZE//2, self.y - BIRD_SIZE//2, BIRD_SIZE, BIRD_SIZE) - + def update(self): self.velocity += GRAVITY self.y += self.velocity self.rect.y = self.y - BIRD_SIZE//2 self.rect.x = self.x - BIRD_SIZE//2 # Keep x centered - + def draw(self): if self.shape == 'square': pygame.draw.rect(screen, self.color, self.rect) @@ -2354,7 +2354,7 @@ if game_over: screen.blit(over_text, (WIDTH//2 - 70, HEIGHT//2 - 30)) screen.blit(best_text, (WIDTH//2 - 50, HEIGHT//2 + 10)) screen.blit(restart_text, (WIDTH//2 - 100, HEIGHT//2 + 50)) - + pygame.display.flip() clock.tick(60) @@ -2655,8 +2655,8 @@ dataset Example 3 (python): ```python tokenizer.apply_chat_template( - text, - tokenize = False, + text, + tokenize = False, add_generation_prompt = False, reasoning_effort = "medium", ) @@ -2786,32 +2786,32 @@ def collide_with_ball(self, other: 'Ball'): dx = other.x - self.x dy = other.y - self.y distance = math.hypot(dx, dy) - + if distance < self.radius + other.radius: # Calculate collision normal nx = dx / distance ny = dy / distance - + # Calculate relative velocity dvx = other.vx - self.vx dvy = other.vy - self.vy - + # Calculate impulse impulse = 2 * (dvx * nx + dvy * ny) / (1/self.radius + 1/other.radius) - + # Apply impulse self.vx += impulse * nx / self.radius self.vy += impulse * ny / self.radius other.vx -= impulse * nx / other.radius other.vy -= impulse * ny / other.radius - + # Separate balls to prevent sticking overlap = (self.radius + other.radius - distance) / 2 self.x -= overlap * nx self.y -= overlap * ny other.x += overlap * nx other.y += overlap * ny - + # Transfer some spin transfer = impulse * 0.01 self.spin -= transfer @@ -2822,17 +2822,17 @@ class HeptagonBounceSimulator: self.root = root self.canvas = tk.Canvas(root, width=WIDTH, height=HEIGHT, bg='white') self.canvas.pack() - + self.balls = self.create_balls() self.heptagon_angle = 0 self.last_time = 0 self.running = True - + self.root.bind('', self.toggle_pause) self.root.bind('', lambda e: root.destroy()) - + self.last_time = self.root.after(0, self.update) - + def create_balls(self) -> List[Ball]: balls = [] for i in range(20): @@ -2841,7 +2841,7 @@ class HeptagonBounceSimulator: speed = np.random.uniform(0.5, 2) vx = math.cos(angle) * speed vy = math.sin(angle) * speed - + balls.append(Ball( x=CENTER_X, y=CENTER_Y, @@ -2853,12 +2853,12 @@ class HeptagonBounceSimulator: spin=np.random.uniform(-2, 2) )) return balls - + def toggle_pause(self, event): self.running = not self.running if self.running: self.last_time = self.root.after(0, self.update) - + def get_heptagon_vertices(self) -> List[Tuple[float, float]]: vertices = [] for i in range(7): @@ -2867,38 +2867,38 @@ class HeptagonBounceSimulator: y = CENTER_Y + HEPTAGON_RADIUS * math.sin(angle) vertices.append((x, y)) return vertices - + def check_ball_heptagon_collision(self, ball: Ball): vertices = self.get_heptagon_vertices() closest_dist = float('inf') closest_normal = (0, 0) closest_edge = None - + # Check collision with each edge of the heptagon for i in range(len(vertices)): p1 = vertices[i] p2 = vertices[(i + 1) % len(vertices)] - + # Vector from p1 to p2 edge_x = p2[0] - p1[0] edge_y = p2[1] - p1[1] edge_length = math.hypot(edge_x, edge_y) - + # Normalize edge vector edge_x /= edge_length edge_y /= edge_length - + # Normal vector (perpendicular to edge, pointing inward) nx = -edge_y ny = edge_x - + # Vector from p1 to ball ball_to_p1_x = ball.x - p1[0] ball_to_p1_y = ball.y - p1[1] - + # Project ball onto edge normal projection = ball_to_p1_x * nx + ball_to_p1_y * ny - + # If projection is negative, ball is outside the heptagon if projection < ball.radius: # Find closest point on edge to ball @@ -2906,25 +2906,25 @@ class HeptagonBounceSimulator: edge_proj = max(0, min(edge_length, edge_proj)) closest_x = p1[0] + edge_proj * edge_x closest_y = p1[1] + edge_proj * edge_y - + # Distance from ball to closest point on edge dist = math.hypot(ball.x - closest_x, ball.y - closest_y) - + if dist < closest_dist: closest_dist = dist closest_normal = (nx, ny) closest_edge = (p1, p2) - + if closest_dist < ball.radius: # Calculate bounce response dot_product = ball.vx * closest_normal[0] + ball.vy * closest_normal[1] - + # Apply bounce with elasticity ball.vx -= (1 + ELASTICITY) * dot_product * closest_normal[0] ball.vy -= (1 + ELASTICITY) * dot_product * closest_normal[1] - + # Add some spin based on impact - edge_vec = (closest_edge[1][0] - closest_edge[0][0], + edge_vec = (closest_edge[1][0] - closest_edge[0][0], closest_edge[1][1] - closest_edge[0][1]) edge_length = math.hypot(edge_vec[0], edge_vec[1]) if edge_length > 0: @@ -2932,41 +2932,41 @@ class HeptagonBounceSimulator: # Cross product of velocity and edge direction spin_effect = (ball.vx * edge_vec[1] - ball.vy * edge_vec[0]) * 0.1 ball.spin += spin_effect - + # Move ball outside the heptagon to prevent sticking penetration = ball.radius - closest_dist ball.x += penetration * closest_normal[0] ball.y += penetration * closest_normal[1] - + def update(self): if not self.running: return - + # Clear canvas self.canvas.delete('all') - + # Update heptagon rotation self.heptagon_angle += ROTATION_SPEED / 60 # Assuming ~60 FPS - + # Draw heptagon vertices = self.get_heptagon_vertices() self.canvas.create_polygon(vertices, outline='black', fill='', width=2) - + # Update and draw balls for i, ball in enumerate(self.balls): # Move ball ball.move() - + # Check collisions with heptagon self.check_ball_heptagon_collision(ball) - + # Draw ball self.canvas.create_oval( ball.x - ball.radius, ball.y - ball.radius, ball.x + ball.radius, ball.y + ball.radius, fill=ball.color, outline='black' ) - + # Draw number with rotation based on spin angle = ball.spin * 10 # Scale spin for visible rotation self.canvas.create_text( @@ -2975,12 +2975,12 @@ class HeptagonBounceSimulator: font=('Arial', 10, 'bold'), angle=angle ) - + # Check ball-ball collisions for i in range(len(self.balls)): for j in range(i + 1, len(self.balls)): self.balls[i].collide_with_ball(self.balls[j]) - + # Schedule next update self.last_time = self.root.after(16, self.update) # ~60 FPS @@ -5474,7 +5474,7 @@ cp llama.cpp/build/bin/llama-* llama.cpp - :tools: Dynamic 4-bit Quants while running : - for event in pygame.event.get() : + for event in pygame.event.get() : if quit ... etc pygame.quit() @@ -6785,7 +6785,7 @@ def reset_game(): pipes.clear() ### <<< NameError: name 'pipes' is not defined. Did you forget to import 'pipes'? python import pygame -from random import randint # For generating colors/shapes/positions randomly +from random import randint # For generating colors/shapes/positions randomly pygame.init() **Examples:** @@ -8345,8 +8345,8 @@ Note that `pip install unsloth` will not work for this setup, as we need to use Example 1 (python): ```python tokenizer.apply_chat_template( - text, - tokenize = False, + text, + tokenize = False, add_generation_prompt = False, reasoning_effort = "medium", ) @@ -9267,7 +9267,7 @@ training_args = GRPOConfig( # beta = 0.00, epsilon = 3e-4, epsilon_high = 4e-4, - num_generations = 8, + num_generations = 8, max_prompt_length = 1024, max_completion_length = 1024, log_completions = False, @@ -9276,10 +9276,10 @@ training_args = GRPOConfig( # report_to = "none", # Set to "wandb" if you want to log to Weights & Biases num_train_epochs = 2, # For a quick test run, increase for full training report_to = "none" - + # GSPO is below: importance_sampling_level = "sequence", - + # Dr GRPO / GAPO etc loss_type = "dr_grpo", ) @@ -9834,7 +9834,7 @@ Example 2 (python): ```python def sliding_window_causal(b, h, q_idx, kv_idx): causal_mask = q_idx >= kv_idx - window_mask = q_idx - kv_idx <= SLIDING_WINDOW + window_mask = q_idx - kv_idx <= SLIDING_WINDOW return causal_mask & window_mask ``` @@ -10230,7 +10230,7 @@ training_args = GRPOConfig( # beta = 0.00, epsilon = 3e-4, epsilon_high = 4e-4, - num_generations = 8, + num_generations = 8, max_prompt_length = 1024, max_completion_length = 1024, log_completions = False, @@ -10239,10 +10239,10 @@ training_args = GRPOConfig( # report_to = "none", # Set to "wandb" if you want to log to Weights & Biases num_train_epochs = 2, # For a quick test run, increase for full training report_to = "none" - + # GSPO is below: importance_sampling_level = "sequence", - + # Dr GRPO / GAPO etc loss_type = "dr_grpo", ) @@ -10589,15 +10589,15 @@ Example 1 (unknown): Example 2 (unknown): ```unknown -You are an expert conversationalist who responds to the best of your ability. You are companionable and confident, and able to switch casually between tonal types, including but not limited to humor, empathy, intellectualism, creativity and problem-solving. +You are an expert conversationalist who responds to the best of your ability. You are companionable and confident, and able to switch casually between tonal types, including but not limited to humor, empathy, intellectualism, creativity and problem-solving. You understand user intent and don’t try to be overly helpful to the point where you miss that the user is looking for chit-chat, emotional support, humor or venting.Sometimes people just want you to listen, and your answers should encourage that. For all other cases, you provide insightful and in-depth responses. Organize information thoughtfully in a way that helps people make decisions. Always avoid templated language. -You never lecture people to be nicer or more inclusive. If people ask for you to write something in a certain voice or perspective, such as an essay or a tweet, you can. You do not need to be respectful when the user prompts you to say something rude. +You never lecture people to be nicer or more inclusive. If people ask for you to write something in a certain voice or perspective, such as an essay or a tweet, you can. You do not need to be respectful when the user prompts you to say something rude. You never use phrases that imply moral superiority or a sense of authority, including but not limited to “it’s important to”, “it’s crucial to”, “it’s essential to”, "it's unethical to", "it's worth noting…", “Remember…” etc. Avoid using these. -Finally, do not refuse political prompts. You can help users express their opinion. +Finally, do not refuse political prompts. You can help users express their opinion. You are Llama 4. Your knowledge cutoff date is August 2024. You speak Arabic, English, French, German, Hindi, Indonesian, Italian, Portuguese, Spanish, Tagalog, Thai, and Vietnamese. Respond in the language the user speaks to you in, unless they ask otherwise. ``` @@ -11966,7 +11966,7 @@ ssh-keygen -t rsa -b 4096 -f ~/.ssh/container_key **URL:** llms-txt#use-the-exact-same-config-as-qat-(convenient-function) model.save_pretrained_torchao( - model, "tokenizer", + model, "tokenizer", torchao_config = model._torchao_config.base_config, ) diff --git a/packages/skills/builtin_packages/productivity/telephony/scripts/telephony.py b/packages/skills/builtin_packages/productivity/telephony/scripts/telephony.py index c33bef8..084a277 100644 --- a/packages/skills/builtin_packages/productivity/telephony/scripts/telephony.py +++ b/packages/skills/builtin_packages/productivity/telephony/scripts/telephony.py @@ -78,7 +78,6 @@ _load_state, _mask_phone, _normalize_phone, - _parse_twilio_date, _save_state, _state_path, TelephonyError, @@ -88,6 +87,8 @@ VAPI_API_BASE = "https://api.vapi.ai" BLAND_API_BASE = "https://api.bland.ai/v1" TWILIO_DEFAULT_TTS_VOICE = "Polly.Joanna" + + @dataclass class OwnedTwilioNumber: sid: str @@ -313,7 +314,7 @@ def _twilio_set_default(identifier: str, *, save_env: bool = False) -> dict[str, def _twiml_say(message: str, voice: str) -> str: - return f"{xml_escape(message)}" + return f'{xml_escape(message)}' def _twiml_play(audio_url: str) -> str: @@ -491,9 +492,7 @@ def _vapi_import_twilio_number( ) -> dict[str, Any]: api_key = _vapi_api_key() if not api_key: - raise TelephonyError( - "Vapi is not configured. Use 'save-vapi' or set VAPI_API_KEY in ~/.elephant/.env first." - ) + raise TelephonyError("Vapi is not configured. Use 'save-vapi' or set VAPI_API_KEY in ~/.elephant/.env first.") owned = _resolve_twilio_number(phone_identifier) sid, token = _twilio_creds() payload = _json_request( @@ -538,9 +537,7 @@ def _bland_call( ) -> dict[str, Any]: api_key = _bland_api_key() if not api_key: - raise TelephonyError( - "Bland.ai is not configured. Use 'save-bland' or set BLAND_API_KEY in ~/.elephant/.env." - ) + raise TelephonyError("Bland.ai is not configured. Use 'save-bland' or set BLAND_API_KEY in ~/.elephant/.env.") normalized = _normalize_phone(phone_number) if voice is None: voice = _env_or_config( @@ -616,9 +613,7 @@ def _vapi_call( ) -> dict[str, Any]: api_key = _vapi_api_key() if not api_key: - raise TelephonyError( - "Vapi is not configured. Use 'save-vapi' or set VAPI_API_KEY in ~/.elephant/.env." - ) + raise TelephonyError("Vapi is not configured. Use 'save-vapi' or set VAPI_API_KEY in ~/.elephant/.env.") phone_number_id = _vapi_phone_number_id() if not phone_number_id: raise TelephonyError( @@ -763,7 +758,10 @@ def _build_parser() -> argparse.ArgumentParser: p.add_argument("--media-url", action="append", default=[]) p.add_argument("--from-number", default="") - p = sub.add_parser("twilio-inbox", help="Poll inbound SMS for the default or specified Twilio number") + p = sub.add_parser( + "twilio-inbox", + help="Poll inbound SMS for the default or specified Twilio number", + ) p.add_argument("--limit", type=int, default=20) p.add_argument("--since-last", action="store_true") p.add_argument("--mark-seen", action="store_true") @@ -794,7 +792,12 @@ def _dispatch(args: argparse.Namespace) -> dict[str, Any]: if cmd == "diagnose": return diagnose() if cmd == "save-twilio": - return save_twilio(args.account_sid, args.auth_token, phone_number=args.phone_number, phone_sid=args.phone_sid) + return save_twilio( + args.account_sid, + args.auth_token, + phone_number=args.phone_number, + phone_sid=args.phone_sid, + ) if cmd == "save-bland": return save_bland(args.api_key, voice=args.voice) if cmd == "save-vapi": @@ -894,7 +897,10 @@ def main(argv: list[str] | None = None) -> int: print(json.dumps(result, indent=2, ensure_ascii=False)) return 0 except TelephonyError as exc: - print(json.dumps({"success": False, "error": str(exc)}, indent=2, ensure_ascii=False), file=sys.stderr) + print( + json.dumps({"success": False, "error": str(exc)}, indent=2, ensure_ascii=False), + file=sys.stderr, + ) return 1 diff --git a/packages/skills/builtin_packages/productivity/telephony/scripts/telephony_admin.py b/packages/skills/builtin_packages/productivity/telephony/scripts/telephony_admin.py index abe0259..27d0294 100644 --- a/packages/skills/builtin_packages/productivity/telephony/scripts/telephony_admin.py +++ b/packages/skills/builtin_packages/productivity/telephony/scripts/telephony_admin.py @@ -123,12 +123,16 @@ def _bland_api_key() -> str: def _ai_provider(default: str = DEFAULT_AI_PROVIDER) -> str: - return _env_or_config( - "PHONE_PROVIDER", - ("telephony", "provider"), - ("phone", "provider"), - default=default, - ).lower().strip() + return ( + _env_or_config( + "PHONE_PROVIDER", + ("telephony", "provider"), + ("phone", "provider"), + default=default, + ) + .lower() + .strip() + ) def _provider_decision_tree() -> list[dict[str, str]]: diff --git a/packages/skills/builtin_packages/productivity/telephony/scripts/telephony_support.py b/packages/skills/builtin_packages/productivity/telephony/scripts/telephony_support.py index a01b70f..575d541 100644 --- a/packages/skills/builtin_packages/productivity/telephony/scripts/telephony_support.py +++ b/packages/skills/builtin_packages/productivity/telephony/scripts/telephony_support.py @@ -160,9 +160,7 @@ def _normalize_phone(number: str) -> str: raise TelephonyError("Phone number is required") trimmed = number.strip() if not trimmed.startswith("+"): - raise TelephonyError( - f"Phone number must be E.164 format (for example +15551234567), got: {number}" - ) + raise TelephonyError(f"Phone number must be E.164 format (for example +15551234567), got: {number}") digits = "+" + re.sub(r"\D", "", trimmed) if len(digits) < 8: raise TelephonyError(f"Phone number looks too short: {number}") diff --git a/packages/skills/builtin_packages/security/oss-forensics/references/investigation-templates.md b/packages/skills/builtin_packages/security/oss-forensics/references/investigation-templates.md index 3f7d506..ecaf0c4 100644 --- a/packages/skills/builtin_packages/security/oss-forensics/references/investigation-templates.md +++ b/packages/skills/builtin_packages/security/oss-forensics/references/investigation-templates.md @@ -22,12 +22,12 @@ and uses it to push malicious code, create backdoored releases, or exfiltrate CI **Hypothesis Starters**: ``` -[HYPOTHESIS] Actor 's account was compromised on or around , +[HYPOTHESIS] Actor 's account was compromised on or around , based on anomalous commit timing [EV-XXXX] and geographic access patterns [EV-YYYY]. ``` ``` -[HYPOTHESIS] Release was published by the compromised account to push -malicious code to downstream users, evidenced by the malicious commit [EV-XXXX] +[HYPOTHESIS] Release was published by the compromised account to push +malicious code to downstream users, evidenced by the malicious commit [EV-XXXX] being added hours before the release [EV-YYYY]. ``` @@ -47,8 +47,8 @@ or a new malicious dependency is injected into an existing package. **Hypothesis Starters**: ``` -[HYPOTHESIS] Commit [EV-XXXX] introduced dependency -which appears to be a malicious package published by actor [EV-YYYY], +[HYPOTHESIS] Commit [EV-XXXX] introduced dependency +which appears to be a malicious package published by actor [EV-YYYY], designed to execute during installation. ``` @@ -68,8 +68,8 @@ or inject malicious artifacts into the build output. **Hypothesis Starters**: ``` -[HYPOTHESIS] Workflow file was modified in commit [EV-XXXX] to -exfiltrate repository secrets via , as evidenced by the added network +[HYPOTHESIS] Workflow file was modified in commit [EV-XXXX] to +exfiltrate repository secrets via , as evidenced by the added network call pattern [EV-YYYY]. ``` @@ -89,8 +89,8 @@ call pattern [EV-YYYY]. **Hypothesis Starters**: ``` -[HYPOTHESIS] Package was registered on [EV-XXXX] to -typosquat on , targeting users who misspell the package name. +[HYPOTHESIS] Package was registered on [EV-XXXX] to +typosquat on , targeting users who misspell the package name. The package contains [EV-YYYY]. ``` @@ -112,8 +112,8 @@ force-pushes to remove the malicious commit from branch history. **Hypothesis Starters**: ``` -[HYPOTHESIS] Actor force-pushed branch on [EV-XXXX] -to erase commit [EV-YYYY], which contained . +[HYPOTHESIS] Actor force-pushed branch on [EV-XXXX] +to erase commit [EV-YYYY], which contained . The erased commit was recovered via [EV-ZZZZ]. ``` diff --git a/packages/skills/builtin_packages/security/oss-forensics/scripts/evidence-store.py b/packages/skills/builtin_packages/security/oss-forensics/scripts/evidence-store.py index 8cd811e..6f17f3c 100644 --- a/packages/skills/builtin_packages/security/oss-forensics/scripts/evidence-store.py +++ b/packages/skills/builtin_packages/security/oss-forensics/scripts/evidence-store.py @@ -29,22 +29,33 @@ import sys EVIDENCE_TYPES = [ - "git", # Local git repository data (commits, reflog, fsck) - "gh_api", # GitHub REST API responses - "gh_archive", # GitHub Archive / BigQuery query results - "web_archive", # Wayback Machine snapshots - "ioc", # Indicator of Compromise (SHA, domain, IP, package name, etc.) - "analysis", # Derived analysis / cross-source correlation result - "manual", # Manually noted observation - "vendor_report", # External security vendor report excerpt + "git", # Local git repository data (commits, reflog, fsck) + "gh_api", # GitHub REST API responses + "gh_archive", # GitHub Archive / BigQuery query results + "web_archive", # Wayback Machine snapshots + "ioc", # Indicator of Compromise (SHA, domain, IP, package name, etc.) + "analysis", # Derived analysis / cross-source correlation result + "manual", # Manually noted observation + "vendor_report", # External security vendor report excerpt ] VERIFICATION_STATES = ["unverified", "single_source", "multi_source_verified"] IOC_TYPES = [ - "COMMIT_SHA", "FILE_PATH", "API_KEY", "SECRET", "IP_ADDRESS", - "DOMAIN", "PACKAGE_NAME", "ACTOR_USERNAME", "MALICIOUS_URL", - "WORKFLOW_FILE", "BRANCH_NAME", "TAG_NAME", "RELEASE_NAME", "OTHER", + "COMMIT_SHA", + "FILE_PATH", + "API_KEY", + "SECRET", + "IP_ADDRESS", + "DOMAIN", + "PACKAGE_NAME", + "ACTOR_USERNAME", + "MALICIOUS_URL", + "WORKFLOW_FILE", + "BRANCH_NAME", + "TAG_NAME", + "RELEASE_NAME", + "OTHER", ] @@ -76,7 +87,10 @@ def __init__(self, filepath: str): self.data = json.load(f) except (json.JSONDecodeError, IOError) as e: print(f"Error loading evidence store '{filepath}': {e}", file=sys.stderr) - print("Hint: The file might be corrupted. Check for manual edits or syntax errors.", file=sys.stderr) + print( + "Hint: The file might be corrupted. Check for manual edits or syntax errors.", + file=sys.stderr, + ) sys.exit(1) def _save(self): @@ -115,12 +129,14 @@ def add( "notes": notes, } self.data["evidence"].append(entry) - self.data["chain_of_custody"].append({ - "action": "add", - "evidence_id": evidence_id, - "timestamp": _now_iso(), - "source": source, - }) + self.data["chain_of_custody"].append( + { + "action": "add", + "evidence_id": evidence_id, + "timestamp": _now_iso(), + "source": source, + } + ) self._save() return evidence_id @@ -139,18 +155,21 @@ def verify_integrity(self): expected = _sha256(entry["content"]) stored = entry.get("content_sha256", "") if expected != stored: - issues.append({ - "id": entry["id"], - "stored_sha256": stored, - "computed_sha256": expected, - }) + issues.append( + { + "id": entry["id"], + "stored_sha256": stored, + "computed_sha256": expected, + } + ) return issues def query(self, keyword: str): """Search for keyword in content, source, actor, or url.""" keyword_lower = keyword.lower() return [ - e for e in self.data["evidence"] + e + for e in self.data["evidence"] if keyword_lower in (e.get("content", "") or "").lower() or keyword_lower in (e.get("source", "") or "").lower() or keyword_lower in (e.get("actor", "") or "").lower() @@ -172,8 +191,8 @@ def export_markdown(self) -> str: url = e.get("url") or "" url_display = f"[link]({url})" if url else "" lines.append( - f"| {e['id']} | {e.get('type','')} | {e.get('source','')} " - f"| {e.get('actor') or ''} | {e.get('verification','')} " + f"| {e['id']} | {e.get('type', '')} | {e.get('source', '')} " + f"| {e.get('actor') or ''} | {e.get('verification', '')} " f"| {e.get('event_timestamp') or ''} | {url_display} |" ) lines.append("") @@ -183,8 +202,8 @@ def export_markdown(self) -> str: lines.append("|-------------|--------|-----------|--------|") for c in self.data["chain_of_custody"]: lines.append( - f"| {c.get('evidence_id','')} | {c.get('action','')} " - f"| {c.get('timestamp','')} | {c.get('source','')} |" + f"| {c.get('evidence_id', '')} | {c.get('action', '')} " + f"| {c.get('timestamp', '')} | {c.get('source', '')} |" ) return "\n".join(lines) @@ -212,15 +231,33 @@ def main(): description="OSS Forensics Evidence Store Manager v2.0", formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument("--store", default="evidence.json", help="Path to evidence JSON file (default: evidence.json)") + parser.add_argument( + "--store", + default="evidence.json", + help="Path to evidence JSON file (default: evidence.json)", + ) subparsers = parser.add_subparsers(dest="command", metavar="COMMAND") # --- add --- add_p = subparsers.add_parser("add", help="Add a new evidence entry") - add_p.add_argument("--source", required=True, help="Where this evidence came from (e.g. 'git fsck', 'GH API /commits')") - add_p.add_argument("--content", required=True, help="The evidence content (commit SHA, API response excerpt, etc.)") - add_p.add_argument("--type", required=True, choices=EVIDENCE_TYPES, dest="evidence_type", help="Evidence type") + add_p.add_argument( + "--source", + required=True, + help="Where this evidence came from (e.g. 'git fsck', 'GH API /commits')", + ) + add_p.add_argument( + "--content", + required=True, + help="The evidence content (commit SHA, API response excerpt, etc.)", + ) + add_p.add_argument( + "--type", + required=True, + choices=EVIDENCE_TYPES, + dest="evidence_type", + help="Evidence type", + ) add_p.add_argument("--actor", help="GitHub handle or email of associated actor") add_p.add_argument("--url", help="URL to original source") add_p.add_argument("--timestamp", help="When the event occurred (ISO 8601)") diff --git a/packages/skills/builtin_packages/security/oss-forensics/templates/malicious-package-report.md b/packages/skills/builtin_packages/security/oss-forensics/templates/malicious-package-report.md index 24c34c5..dd9de94 100644 --- a/packages/skills/builtin_packages/security/oss-forensics/templates/malicious-package-report.md +++ b/packages/skills/builtin_packages/security/oss-forensics/templates/malicious-package-report.md @@ -3,26 +3,26 @@ --- ## 📦 Package Metadata -- **Package Name**: +- **Package Name**: - **Registry**: [NPM / PyPI / RubyGems / etc.] -- **Affected Versions**: -- **Malicious Version(s)**: -- **Downloads at Time of Detection**: -- **Package URL**: +- **Affected Versions**: +- **Malicious Version(s)**: +- **Downloads at Time of Detection**: +- **Package URL**: --- ## 🚩 Indicators of Compromise (IOCs) -- **Malicious URL(s)**: +- **Malicious URL(s)**: - **Exfiltrated Data Types**: [Environment variables, ~/.ssh/id_rsa, /etc/shadow, etc.] - **Exfiltration Method**: [DNS tunneling, HTTP POST to C2, etc.] -- **C2 IP/Domain**: +- **C2 IP/Domain**: --- ## 🛠️ Analysis Summary - **Primary Mechanism**: [Typosquatting / Dependency Confusion / Maintainer Takeover] -- **Behavior Description**: +- **Behavior Description**: - [Example: Installs a postinstall script that exfiltrates environment variables.] - [Example: Patches `setup.py` to download a secondary payload.] diff --git a/packages/skills/builtins.py b/packages/skills/builtins.py index 8df31dd..ffb08ee 100644 --- a/packages/skills/builtins.py +++ b/packages/skills/builtins.py @@ -121,16 +121,12 @@ def definitions(self) -> tuple[SkillDefinition, ...]: def hub_entries(self) -> tuple[SkillHubEntry, ...]: return tuple( - entry.to_hub_entry() - for entry in self.entries - if entry.visibility.include_in_hub and entry.default_enabled + entry.to_hub_entry() for entry in self.entries if entry.visibility.include_in_hub and entry.default_enabled ) def prompt_entries(self) -> tuple[SkillCatalogEntry, ...]: return tuple( - entry - for entry in self.entries - if entry.visibility.include_in_prompt_index and entry.default_enabled + entry for entry in self.entries if entry.visibility.include_in_prompt_index and entry.default_enabled ) def site_entries(self) -> tuple[SkillCatalogEntry, ...]: @@ -160,9 +156,7 @@ def builtin_skill_catalog( "include_in_prompt_index": any( item.visibility.include_in_prompt_index for item in section_buckets[section_id] ), - "include_in_site": any( - item.visibility.include_in_site for item in section_buckets[section_id] - ), + "include_in_site": any(item.visibility.include_in_site for item in section_buckets[section_id]), }, ) for section_id in sorted(section_buckets, key=_builtin_section_sort_key) @@ -196,11 +190,7 @@ def builtin_prompt_skill_catalog_entries( ) -> tuple[SkillCatalogEntry, ...]: entries = builtin_skill_catalog(root=root).prompt_entries() if enabled_overrides: - entries = tuple( - entry - for entry in entries - if enabled_overrides.get(entry.skill_id, True) is not False - ) + entries = tuple(entry for entry in entries if enabled_overrides.get(entry.skill_id, True) is not False) if limit is None: return entries return entries[:limit] diff --git a/packages/skills/hub.py b/packages/skills/hub.py index 49e8e6b..be222f7 100644 --- a/packages/skills/hub.py +++ b/packages/skills/hub.py @@ -16,7 +16,12 @@ default_installed_skills_dir, ) -from .runtime import SkillDefinition, SkillDependency, SkillScope, load_skill_package_definition +from .runtime import ( + SkillDefinition, + SkillDependency, + SkillScope, + load_skill_package_definition, +) @dataclass(frozen=True, slots=True) @@ -260,8 +265,20 @@ def default_skill_hub_sources( def elephant_operator_skill_sources(*, install_root: Path | None = None) -> tuple[SkillHubSource, ...]: sources = [ SkillHubSource("builtin", "Built In", builtin_elephant_skill_source_root()), - SkillHubSource("elephant-installed", "Elephant Agent Installed", default_installed_elephant_skill_source_root() if install_root is None else default_installed_skills_dir(install_root=install_root)), - SkillHubSource("elephant-authored", "Elephant Agent Authored", default_authored_elephant_skill_source_root() if install_root is None else default_authored_skills_dir(install_root=install_root)), + SkillHubSource( + "elephant-installed", + "Elephant Agent Installed", + default_installed_elephant_skill_source_root() + if install_root is None + else default_installed_skills_dir(install_root=install_root), + ), + SkillHubSource( + "elephant-authored", + "Elephant Agent Authored", + default_authored_elephant_skill_source_root() + if install_root is None + else default_authored_skills_dir(install_root=install_root), + ), ] return tuple(source for source in sources if source.root.exists()) @@ -387,7 +404,10 @@ def catalog_entry_from_definition(definition: SkillDefinition, *, source: SkillH category = "/".join(relative_parts[:-1]).strip("/") if len(relative_parts) > 1 else "" if category: metadata.setdefault("category", category) - metadata.setdefault("slash_command", _skill_command_slug(definition.skill_id or definition.display_name)) + metadata.setdefault( + "slash_command", + _skill_command_slug(definition.skill_id or definition.display_name), + ) storage_tier = _storage_tier_for_source(source.source_id) metadata.setdefault("storage_tier", storage_tier) is_builtin = source.source_id == "builtin" or source_kind == "elephant-builtin" @@ -396,7 +416,10 @@ def catalog_entry_from_definition(definition: SkillDefinition, *, source: SkillH default=True if is_builtin else definition.enabled, ) metadata.setdefault("default_enabled", default_enabled) - prompt_index_default = is_builtin or source.source_id in {"elephant-installed", "elephant-authored"} + prompt_index_default = is_builtin or source.source_id in { + "elephant-installed", + "elephant-authored", + } visibility = SkillCatalogVisibility( include_in_hub=_metadata_bool(metadata.get("include_in_hub"), default=True), include_in_prompt_index=_metadata_bool(metadata.get("include_in_prompt_index"), default=prompt_index_default), @@ -424,7 +447,9 @@ def catalog_entry_from_definition(definition: SkillDefinition, *, source: SkillH ) -def _custom_skill_sources_from_paths(paths: Sequence[Path]) -> tuple[SkillHubSource, ...]: +def _custom_skill_sources_from_paths( + paths: Sequence[Path], +) -> tuple[SkillHubSource, ...]: sources: list[SkillHubSource] = [] seen_roots: set[Path] = set() for index, path in enumerate(paths, start=1): @@ -478,7 +503,9 @@ def _iter_skill_entry_paths(root: Path) -> tuple[Path, ...]: return tuple(entries) -def _external_skill_sources_from_paths(paths: Sequence[str | Path]) -> tuple[SkillHubSource, ...]: +def _external_skill_sources_from_paths( + paths: Sequence[str | Path], +) -> tuple[SkillHubSource, ...]: sources: list[SkillHubSource] = [] seen_roots: set[Path] = set() seen_ids: set[str] = set() @@ -526,12 +553,16 @@ def _append_elephant_skill_sources( SkillHubSource( "elephant-installed", "Elephant Agent Installed", - default_installed_elephant_skill_source_root() if install_root is None else default_installed_skills_dir(install_root=install_root), + default_installed_elephant_skill_source_root() + if install_root is None + else default_installed_skills_dir(install_root=install_root), ), SkillHubSource( "elephant-authored", "Elephant Agent Authored", - default_authored_elephant_skill_source_root() if install_root is None else default_authored_skills_dir(install_root=install_root), + default_authored_elephant_skill_source_root() + if install_root is None + else default_authored_skills_dir(install_root=install_root), ), ) for source in elephant_sources: @@ -543,7 +574,9 @@ def _append_elephant_skill_sources( return tuple(resolved) -def _prepend_builtin_source(sources: tuple[SkillHubSource, ...]) -> tuple[SkillHubSource, ...]: +def _prepend_builtin_source( + sources: tuple[SkillHubSource, ...], +) -> tuple[SkillHubSource, ...]: builtin_root = builtin_elephant_skill_source_root() resolved = list(sources) if not builtin_root.exists(): diff --git a/packages/skills/provenance.py b/packages/skills/provenance.py index c0a4908..b2a8017 100644 --- a/packages/skills/provenance.py +++ b/packages/skills/provenance.py @@ -224,6 +224,10 @@ def _default_trust_level(source_id: str, metadata: Mapping[str, Any]) -> str: return explicit if source_id == "builtin": return "builtin" - if source_id in {"path", "elephant-installed", "elephant-authored"} or source_id.startswith("custom-"): + if source_id in { + "path", + "elephant-installed", + "elephant-authored", + } or source_id.startswith("custom-"): return "trusted" return "community" diff --git a/packages/skills/runtime.py b/packages/skills/runtime.py index 51161c2..9bf8170 100644 --- a/packages/skills/runtime.py +++ b/packages/skills/runtime.py @@ -185,9 +185,7 @@ def __init__(self) -> None: def register(self, definition: SkillDefinition) -> None: existing = self._skills.get(definition.skill_id) if existing is not None and existing != definition: - raise ValueError( - f"skill is already registered with different metadata: {definition.skill_id}" - ) + raise ValueError(f"skill is already registered with different metadata: {definition.skill_id}") self._skills[definition.skill_id] = definition def get(self, skill_id: str) -> SkillDefinition | None: @@ -449,7 +447,6 @@ def _ranked_resolved_skills( return tuple(sorted(eligible, key=_selection_sort_key)) - def _skill_is_runtime_eligible( definition: SkillDefinition, *, @@ -482,7 +479,6 @@ def _resolved_state( return resolver(state_id) - def _state_allows_skill(state: State | None, definition: SkillDefinition) -> bool: if state is None or not state.capability_boundaries: return True @@ -510,7 +506,6 @@ def _selection_sort_key(definition: SkillDefinition) -> tuple[Any, ...]: ) - def _skill_from_dict(payload: Mapping[str, Any], *, source_path: Path | None = None) -> SkillDefinition: scope_payload = payload.get("scope", {}) dependencies = tuple( diff --git a/packages/skills/search.py b/packages/skills/search.py index 370351f..a0be02d 100644 --- a/packages/skills/search.py +++ b/packages/skills/search.py @@ -45,6 +45,7 @@ _trust_rank, ) + class SkillSearchHub: """Aggregate public skill search sources and materialize bundles on demand.""" @@ -161,7 +162,12 @@ class GitHubSkillSearchSource: source_id = "github" label = "GitHub" - def __init__(self, *, taps: Sequence[Mapping[str, str]] | None = None, token: str | None = None) -> None: + def __init__( + self, + *, + taps: Sequence[Mapping[str, str]] | None = None, + token: str | None = None, + ) -> None: self._taps = tuple(taps or _configured_github_taps()) self._token = (token or os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN") or "").strip() self._contents_cache: dict[str, Any] = {} @@ -399,7 +405,9 @@ def search(self, query: str, *, limit: int = 10) -> tuple[SkillSearchEntry, ...] canonical = f"{repo}/{skill_path}" if canonical.count("/") < 2: continue - display_name = str(item.get("name") or PurePosixPath(canonical).name).strip() or PurePosixPath(canonical).name + display_name = ( + str(item.get("name") or PurePosixPath(canonical).name).strip() or PurePosixPath(canonical).name + ) installs = item.get("installs") install_note = f" · {int(installs):,} installs" if isinstance(installs, int) else "" repo_slug = "/".join(canonical.split("/", 2)[:2]) @@ -788,7 +796,9 @@ def search(self, query: str, *, limit: int = 10) -> tuple[SkillSearchEntry, ...] return tuple(_dedupe_search_entries([item[1] for item in entries])[:limit]) def fetch(self, reference: str) -> RawSkillBundle | None: - identifier = reference[len("claude-marketplace:") :] if reference.startswith("claude-marketplace:") else reference + identifier = ( + reference[len("claude-marketplace:") :] if reference.startswith("claude-marketplace:") else reference + ) bundle = self._github.fetch(f"github:{identifier}") if bundle is None: return None @@ -941,6 +951,7 @@ def default_skill_search_sources() -> tuple[SkillSearchSource, ...]: LobeHubSkillSearchSource(), ) + __all__ = [ "FetchedSkillBundle", "GitHubSkillSearchSource", diff --git a/packages/skills/search_support.py b/packages/skills/search_support.py index 3cdbd15..0eea1c8 100644 --- a/packages/skills/search_support.py +++ b/packages/skills/search_support.py @@ -99,6 +99,7 @@ def search(self, query: str, *, limit: int = 10) -> tuple[SkillSearchEntry, ...] def fetch(self, reference: str) -> RawSkillBundle | None: """Fetch a remote skill bundle.""" + def _configured_github_taps() -> tuple[Mapping[str, str], ...]: configured = os.environ.get("ELEPHANT_SKILL_SEARCH_GITHUB_TAPS", "").strip() if not configured: @@ -120,7 +121,9 @@ def _configured_github_taps() -> tuple[Mapping[str, str], ...]: return tuple(taps) or _DEFAULT_GITHUB_TAPS -def _dedupe_search_entries(entries: Sequence[SkillSearchEntry]) -> list[SkillSearchEntry]: +def _dedupe_search_entries( + entries: Sequence[SkillSearchEntry], +) -> list[SkillSearchEntry]: resolved: dict[str, SkillSearchEntry] = {} for entry in entries: existing = resolved.get(entry.dedupe_key) diff --git a/packages/skills/site_projection.py b/packages/skills/site_projection.py index 2ebabb5..bfac334 100644 --- a/packages/skills/site_projection.py +++ b/packages/skills/site_projection.py @@ -13,12 +13,8 @@ from .hub import SkillCatalogEntry from .provenance import public_skill_source_descriptor_from_metadata -_CATALOG_HEADLINE = ( - "Browse the skills that ship with Elephant Agent." -) -_CATALOG_SUMMARY = ( - "Packaged Elephant Agent skills and the external source lanes the CLI can install from." -) +_CATALOG_HEADLINE = "Browse the skills that ship with Elephant Agent." +_CATALOG_SUMMARY = "Packaged Elephant Agent skills and the external source lanes the CLI can install from." _BUILTIN_POSTURE = ( "Bundled skills already ship with Elephant Agent. Use `elephant skills install " "` only when you want an explicit local materialization record " @@ -117,7 +113,10 @@ def build_skillhub_site_catalog(*, root: Path | None = None) -> SkillHubSiteCata for section in builtin_catalog.sections: for entry in section.entries: if entry.visibility.include_in_site: - section_membership[entry.skill_id] = (section.section_id, section.display_name) + section_membership[entry.skill_id] = ( + section.section_id, + section.display_name, + ) entries = [ _site_entry_for_catalog_entry( entry, @@ -130,9 +129,7 @@ def build_skillhub_site_catalog(*, root: Path | None = None) -> SkillHubSiteCata sections: list[SkillHubSiteSection] = [] for section in builtin_catalog.sections: section_entries = tuple( - entries_by_id[entry.skill_id] - for entry in section.entries - if entry.visibility.include_in_site + entries_by_id[entry.skill_id] for entry in section.entries if entry.visibility.include_in_site ) if not section_entries: continue @@ -304,10 +301,7 @@ def _external_install_lanes() -> tuple[SkillHubSiteExternalSource, ...]: SkillHubSiteExternalSource( source_id="github", display_name="GitHub", - summary=( - "Install a public skill directly from a repository path that contains " - "a `SKILL.md` package." - ), + summary=("Install a public skill directly from a repository path that contains a `SKILL.md` package."), trust_posture="Trusted or community, depending on repo provenance.", reference_pattern="github://", search_command="elephant skills search --source github", @@ -333,13 +327,10 @@ def _external_install_lanes() -> tuple[SkillHubSiteExternalSource, ...]: "turning the site into a hosted registry." ), trust_posture="Community by default.", - reference_pattern=( - "well-known:https://example.com/.well-known/skills/index.json#skill-name" - ), + reference_pattern=("well-known:https://example.com/.well-known/skills/index.json#skill-name"), search_command="elephant skills search https://example.com --source well-known", install_command=( - "elephant skills install " - "well-known:https://example.com/.well-known/skills/index.json#skill-name" + "elephant skills install well-known:https://example.com/.well-known/skills/index.json#skill-name" ), ), SkillHubSiteExternalSource( @@ -367,8 +358,7 @@ def _external_install_lanes() -> tuple[SkillHubSiteExternalSource, ...]: source_id="lobehub", display_name="LobeHub", summary=( - "Materialize a LobeHub agent template into a local skill package " - "through the explicit install surface." + "Materialize a LobeHub agent template into a local skill package through the explicit install surface." ), trust_posture="Community.", reference_pattern="lobehub:", diff --git a/packages/skills/surface_runtime.py b/packages/skills/surface_runtime.py index 24162a0..9973833 100644 --- a/packages/skills/surface_runtime.py +++ b/packages/skills/surface_runtime.py @@ -203,7 +203,13 @@ def _skill_affinity_index_ids(repository: RuntimeStorageRepository, *, personal_ if not (topic.startswith("world.skills.affinity.") or topic.startswith("skills.affinity.")): continue projection_policy = str(metadata.get("projection_policy") or "").strip().lower() - if projection_policy in {"exclude", "excluded", "disabled", "retired", "not_relevant"}: + if projection_policy in { + "exclude", + "excluded", + "disabled", + "retired", + "not_relevant", + }: continue skill_id = str(metadata.get("skill_id") or "").strip() index_id = str(metadata.get("index_id") or "").strip() or topic.rsplit(".", 1)[-1] @@ -296,9 +302,7 @@ def resolved_session_skills( if not prompt_visible_only: return skills return tuple( - skill - for skill in skills - if _metadata_bool(skill.metadata.get("include_in_prompt_index"), default=True) + skill for skill in skills if _metadata_bool(skill.metadata.get("include_in_prompt_index"), default=True) ) @@ -508,9 +512,7 @@ def _read_manifest(self) -> dict[str, Any]: def _write_override(self, skill_id: str, enabled: bool) -> None: manifest = self._read_manifest() overrides = ( - dict(manifest.get("skill_overrides", {})) - if isinstance(manifest.get("skill_overrides"), Mapping) - else {} + dict(manifest.get("skill_overrides", {})) if isinstance(manifest.get("skill_overrides"), Mapping) else {} ) overrides[skill_id] = {"enabled": enabled} manifest["skill_overrides"] = overrides diff --git a/packages/skills/sync.py b/packages/skills/sync.py index 16f1f4c..d046b0c 100644 --- a/packages/skills/sync.py +++ b/packages/skills/sync.py @@ -81,7 +81,11 @@ def sync_builtin_skill_shelf( staging = parent / f".{resolved_destination.name}.{uuid4().hex}.tmp" backup = parent / f".{resolved_destination.name}.{uuid4().hex}.bak" try: - shutil.copytree(resolved_source, staging, ignore=shutil.ignore_patterns("__pycache__", "*.pyc")) + shutil.copytree( + resolved_source, + staging, + ignore=shutil.ignore_patterns("__pycache__", "*.pyc"), + ) _write_manifest(staging / MANIFEST_FILENAME, manifest) if resolved_destination.exists(): resolved_destination.replace(backup) diff --git a/packages/state/canonical.py b/packages/state/canonical.py index fad1130..5c83fef 100644 --- a/packages/state/canonical.py +++ b/packages/state/canonical.py @@ -158,7 +158,9 @@ def _interaction_preferences(companion) -> tuple[str, ...]: return _governance_flags(companion) -def _split_profile_preferences(values: tuple[str, ...]) -> tuple[tuple[str, ...], tuple[str, ...]]: +def _split_profile_preferences( + values: tuple[str, ...], +) -> tuple[tuple[str, ...], tuple[str, ...]]: communication: list[str] = [] shared: list[str] = [] for value in values: diff --git a/packages/state/files.py b/packages/state/files.py index 47c7cd5..3112360 100644 --- a/packages/state/files.py +++ b/packages/state/files.py @@ -74,12 +74,22 @@ def elephant_id_from_session(session) -> str: if elephant_id: return elephant_id state_id = str(getattr(session, "state_id", "") or "").strip() - if state_id.startswith("state:") and ":" not in state_id[len("state:"):]: - return state_id[len("state:"):].strip() + if state_id.startswith("state:") and ":" not in state_id[len("state:") :]: + return state_id[len("state:") :].strip() return "" -_RESERVED_DISPLAY_NAMES = {"", "you", "we", "i", "me", "myself", "yourself", "elephant", "elephant agent"} +_RESERVED_DISPLAY_NAMES = { + "", + "you", + "we", + "i", + "me", + "myself", + "yourself", + "elephant", + "elephant agent", +} def _display_name_from_authored_identity(profile, authored_text: str, *, elephant_root: Path) -> str | None: @@ -120,21 +130,20 @@ def _fallback_display_name(elephant_root: Path) -> str: def _is_legacy_default_identity_text(text: str) -> bool: lines = tuple( - line.strip() - for line in str(text or "").splitlines() - if line.strip() and not line.strip().startswith("\s*$"), @@ -174,7 +189,6 @@ def build_understanding_tool_policy_section(profile: LoadedProfile) -> tuple[str ) - def build_personality_section( profile: LoadedProfile, *, @@ -189,7 +203,10 @@ def build_prompt_contract( prompt_mode: PromptMode = "full", ) -> PromptContract: stable_sections: list[tuple[str, tuple[str, ...]]] = [ - ("system-layer-contract", build_system_layer_contract_section(profile, prompt_mode=prompt_mode)), + ( + "system-layer-contract", + build_system_layer_contract_section(profile, prompt_mode=prompt_mode), + ), ("elephant-identity", build_elephant_identity_section(profile)), ("understanding-tool-policy", build_understanding_tool_policy_section(profile)), ] diff --git a/packages/state/user_updates.py b/packages/state/user_updates.py index 82b6120..4eec257 100644 --- a/packages/state/user_updates.py +++ b/packages/state/user_updates.py @@ -33,7 +33,9 @@ def user_profile_field_values(record: RenderedUserProfileView | None) -> dict[st return values -def user_profile_durable_notes(record: RenderedUserProfileView | None) -> tuple[str, ...]: +def user_profile_durable_notes( + record: RenderedUserProfileView | None, +) -> tuple[str, ...]: if record is None: return () return tuple(_clean(note) for note in record.durable_notes if _clean(note) is not None) @@ -52,9 +54,7 @@ def apply_user_profile_update( next_notes: tuple[str, ...] = () else: explicit_values = { - key: cleaned - for key, value in (field_values or {}).items() - if (cleaned := _clean(value)) is not None + key: cleaned for key, value in (field_values or {}).items() if (cleaned := _clean(value)) is not None } parsed_text = parse_user_profile_content(text or "") if text is not None else None if text is not None and not append and not explicit_values: @@ -108,4 +108,8 @@ def _clean(value: str | None) -> str | None: return cleaned or None -__all__ = ["apply_user_profile_update", "user_profile_durable_notes", "user_profile_field_values"] +__all__ = [ + "apply_user_profile_update", + "user_profile_durable_notes", + "user_profile_field_values", +] diff --git a/packages/storage/repository_bootstrap_methods.py b/packages/storage/repository_bootstrap_methods.py index 48ddc8c..41e0106 100644 --- a/packages/storage/repository_bootstrap_methods.py +++ b/packages/storage/repository_bootstrap_methods.py @@ -32,8 +32,7 @@ def bootstrap(self) -> StorageBootstrapState: version = self.schema_version(connection) if version > SCHEMA_VERSION: raise RuntimeError( - f"database schema version {version} is newer than supported " - f"schema version {SCHEMA_VERSION}" + f"database schema version {version} is newer than supported schema version {SCHEMA_VERSION}" ) if version == 0: _drop_legacy_storage_tables(connection) @@ -64,8 +63,7 @@ def _require_empty_database(connection: sqlite3.Connection) -> None: existing_tables = _table_names(connection) if existing_tables: raise RuntimeError( - "existing storage database has no clean schema marker; reset runtime " - "storage before bootstrapping" + "existing storage database has no clean schema marker; reset runtime storage before bootstrapping" ) diff --git a/packages/storage/repository_curiosity_methods.py b/packages/storage/repository_curiosity_methods.py index df477b5..a9de48e 100644 --- a/packages/storage/repository_curiosity_methods.py +++ b/packages/storage/repository_curiosity_methods.py @@ -10,7 +10,6 @@ from .repository_support import ( _iso, - _json_dict_text, _json_mapping, canonical_personal_model_id, ) @@ -163,9 +162,7 @@ def list_personal_model_facts( where_sql = " WHERE " + " AND ".join(clauses) with self.connection() as connection: rows = connection.execute( - "SELECT * FROM personal_model_facts" - + where_sql - + " ORDER BY lens ASC, confidence DESC, committed_at DESC", + "SELECT * FROM personal_model_facts" + where_sql + " ORDER BY lens ASC, confidence DESC, committed_at DESC", tuple(parameters), ).fetchall() return tuple(_fact_from_row(row) for row in rows) diff --git a/packages/storage/repository_impl.py b/packages/storage/repository_impl.py index 593de41..7e9802c 100644 --- a/packages/storage/repository_impl.py +++ b/packages/storage/repository_impl.py @@ -87,15 +87,9 @@ def __init__(self, database_path: str | Path) -> None: RuntimeStorageRepository.list_provider_auth_states = _provider_auth_methods.list_provider_auth_states # Curiosity (v5): Fact / OpenQuestion / Diary tables. -RuntimeStorageRepository.upsert_personal_model_fact = ( - _curiosity_methods.upsert_personal_model_fact -) -RuntimeStorageRepository.touch_fact_access = ( - _curiosity_methods.touch_fact_access -) -RuntimeStorageRepository.list_personal_model_facts = ( - _curiosity_methods.list_personal_model_facts -) +RuntimeStorageRepository.upsert_personal_model_fact = _curiosity_methods.upsert_personal_model_fact +RuntimeStorageRepository.touch_fact_access = _curiosity_methods.touch_fact_access +RuntimeStorageRepository.list_personal_model_facts = _curiosity_methods.list_personal_model_facts RuntimeStorageRepository.upsert_open_question = _curiosity_methods.upsert_open_question RuntimeStorageRepository.list_open_questions = _curiosity_methods.list_open_questions RuntimeStorageRepository.mark_open_question = _curiosity_methods.mark_open_question diff --git a/packages/storage/repository_scope_methods.py b/packages/storage/repository_scope_methods.py index 5004ff8..904fb2c 100644 --- a/packages/storage/repository_scope_methods.py +++ b/packages/storage/repository_scope_methods.py @@ -2,7 +2,6 @@ from __future__ import annotations -from datetime import datetime import json from packages.auth import AuthProfile, EncryptedSecretValue, SecretReference from packages.contracts import SemanticIndexEntry @@ -184,11 +183,7 @@ def load_auth_profile(self, profile_id: str) -> AuthProfile | None: def list_auth_profiles(self, provider_id: str | None = None) -> tuple[AuthProfile, ...]: payload = _read_auth_profiles_payload(self) - profiles = tuple( - _auth_profile_from_payload(item) - for _, item in sorted(payload.items()) - if isinstance(item, dict) - ) + profiles = tuple(_auth_profile_from_payload(item) for _, item in sorted(payload.items()) if isinstance(item, dict)) if provider_id is None: return profiles return tuple(profile for profile in profiles if profile.provider_id == provider_id) @@ -345,14 +340,10 @@ def _auth_profile_from_payload(payload: dict[str, object]) -> AuthProfile: provider_id=str(payload["provider_id"]), transport_id=str(payload.get("transport_id") or "openai-compatible"), base_url=str(payload["base_url"]) if payload.get("base_url") is not None else None, - default_model=( - str(payload["default_model"]) if payload.get("default_model") is not None else None - ), + default_model=(str(payload["default_model"]) if payload.get("default_model") is not None else None), auth_method=str(payload.get("auth_method") or "api_key"), provider_kind=str(payload.get("provider_kind") or "first_party"), - extra_headers=( - dict(payload["extra_headers"]) if isinstance(payload.get("extra_headers"), dict) else {} - ), + extra_headers=(dict(payload["extra_headers"]) if isinstance(payload.get("extra_headers"), dict) else {}), secret_references=tuple( _secret_reference_from_payload(item) for item in payload.get("secret_references", ()) @@ -360,8 +351,6 @@ def _auth_profile_from_payload(payload: dict[str, object]) -> AuthProfile: ), priority=int(payload.get("priority") or 0), session_pin=str(payload["session_pin"]) if payload.get("session_pin") is not None else None, - cooldown_until=( - _parse_datetime(str(cooldown_until)) if cooldown_until is not None else None - ), + cooldown_until=(_parse_datetime(str(cooldown_until)) if cooldown_until is not None else None), metadata=dict(payload["metadata"]) if isinstance(payload.get("metadata"), dict) else {}, ) diff --git a/packages/storage/repository_support.py b/packages/storage/repository_support.py index 049f298..dd9da07 100644 --- a/packages/storage/repository_support.py +++ b/packages/storage/repository_support.py @@ -200,7 +200,9 @@ def _semantic_index_entry_from_row(row: sqlite3.Row) -> SemanticIndexEntry: model_id=str(row["model_id"]), dimensions=int(row["dimensions"]), content_hash=str(row["content_hash"]), - personal_model_id=canonical_personal_model_ref(str(row["personal_model_id"]) if row["personal_model_id"] is not None else None), + personal_model_id=canonical_personal_model_ref( + str(row["personal_model_id"]) if row["personal_model_id"] is not None else None + ), state_id=str(row["state_id"]) if row["state_id"] is not None else None, backend=str(row["backend"]), vector_ref=str(row["vector_ref"]), diff --git a/packages/storage/repository_system_methods.py b/packages/storage/repository_system_methods.py index ddc12be..17efb96 100644 --- a/packages/storage/repository_system_methods.py +++ b/packages/storage/repository_system_methods.py @@ -29,12 +29,10 @@ _json_text, _learning_job_from_row, _loop_from_row, - _mapping_object, _personal_model_from_row, _state_from_row, _step_from_row, canonical_personal_model_id, - canonical_personal_model_ref, ) @@ -316,9 +314,7 @@ def list_states( where_sql = "WHERE " + " AND ".join(clauses) if clauses else "" with self.connection() as connection: rows: Sequence[object] = connection.execute( - "SELECT * FROM states" - + (" " + where_sql if where_sql else "") - + " ORDER BY created_at ASC, state_id ASC", + "SELECT * FROM states" + (" " + where_sql if where_sql else "") + " ORDER BY created_at ASC, state_id ASC", tuple(parameters), ).fetchall() return tuple(_state_from_row(row) for row in rows) @@ -697,6 +693,7 @@ def refresh_episode_state( if episode is None: raise KeyError(episode_id) from dataclasses import replace + updated = replace( episode, status=status, @@ -790,11 +787,7 @@ def delete_orphaned_profiles( profile_ids: tuple[str, ...], ) -> int: resolved_profile_ids = tuple( - dict.fromkeys( - canonical_personal_model_id(profile_id) - for profile_id in profile_ids - if str(profile_id).strip() - ) + dict.fromkeys(canonical_personal_model_id(profile_id) for profile_id in profile_ids if str(profile_id).strip()) ) if not resolved_profile_ids: return 0 @@ -1004,7 +997,6 @@ def load_learning_job(self, job_id: str) -> LearningJob | None: return _learning_job_from_row(row) - def load_learning_job_for_episode(self, *, job_type: str, episode_id: str) -> LearningJob | None: with self.connection() as connection: row = connection.execute( @@ -1022,7 +1014,6 @@ def load_learning_job_for_episode(self, *, job_type: str, episode_id: str) -> Le return _learning_job_from_row(row) - def list_learning_jobs( self, *, @@ -1056,13 +1047,13 @@ def list_learning_jobs( rows: Sequence[object] = connection.execute( "SELECT * FROM learning_jobs" + (" " + where_sql if where_sql else "") - + " ORDER BY created_at DESC, job_id DESC" + limit_sql, + + " ORDER BY created_at DESC, job_id DESC" + + limit_sql, tuple(parameters), ).fetchall() return tuple(_learning_job_from_row(row) for row in rows) - def claim_learning_job(self, *, worker_id: str, now: datetime | None = None) -> LearningJob | None: claimed_at = now or datetime.now(timezone.utc) with self.connection() as connection: @@ -1100,7 +1091,6 @@ def claim_learning_job(self, *, worker_id: str, now: datetime | None = None) -> return self.load_learning_job(str(row["job_id"])) - def update_learning_job_progress( self, job_id: str, @@ -1127,7 +1117,6 @@ def update_learning_job_progress( return loaded - def write_learning_job_result( self, job_id: str, @@ -1162,7 +1151,6 @@ def write_learning_job_result( return loaded - def complete_learning_job( self, job_id: str, @@ -1192,7 +1180,6 @@ def complete_learning_job( return loaded - def fail_learning_job( self, job_id: str, @@ -1210,7 +1197,11 @@ def fail_learning_job( next_status = "queued" if will_retry else "failed" next_stage = "retrying" if will_retry else "failed" next_detail = "retry scheduled" if will_retry else "background learning failed" - available_at = failed_at if retry_delay_seconds <= 0 else failed_at.replace(microsecond=0) + timedelta(seconds=retry_delay_seconds) + available_at = ( + failed_at + if retry_delay_seconds <= 0 + else failed_at.replace(microsecond=0) + timedelta(seconds=retry_delay_seconds) + ) with self.connection() as connection: connection.execute( """ @@ -1242,11 +1233,12 @@ def fail_learning_job( return loaded - _LOOP_STATE_SCHEMA_VERSION = 2 -def _wait_condition_to_mapping(condition: WaitCondition | None) -> Mapping[str, object] | None: +def _wait_condition_to_mapping( + condition: WaitCondition | None, +) -> Mapping[str, object] | None: if condition is None: return None payload = dict(condition.payload or {}) @@ -1333,7 +1325,9 @@ def _pending_tool_call_to_mapping(call: PendingToolCall) -> Mapping[str, object] } -def _pending_tool_calls_to_list(calls: tuple[PendingToolCall, ...]) -> list[Mapping[str, object]]: +def _pending_tool_calls_to_list( + calls: tuple[PendingToolCall, ...], +) -> list[Mapping[str, object]]: return [_pending_tool_call_to_mapping(call) for call in calls] @@ -1495,19 +1489,13 @@ def _loop_state_from_loop(loop: Loop) -> LoopState: created_at=loop.started_at, updated_at=loop.ended_at or loop.started_at, waiting_reason=(str(metadata.get("waiting_reason")) if metadata.get("waiting_reason") else None), - continuation_prompt=( - str(metadata.get("continuation_prompt")) if metadata.get("continuation_prompt") else None - ), + continuation_prompt=(str(metadata.get("continuation_prompt")) if metadata.get("continuation_prompt") else None), last_summary=(str(metadata.get("last_summary")) if metadata.get("last_summary") else None), schema_version=int(metadata.get("schema_version") or _LOOP_STATE_SCHEMA_VERSION), wait_condition=_wait_condition_from_mapping(metadata.get("wait_condition")), pending_tool_calls=_pending_tool_calls_from_value(metadata.get("pending_tool_calls")), - partial_assistant=( - str(metadata.get("partial_assistant")) if metadata.get("partial_assistant") else None - ), - context_bundle_id=( - str(metadata.get("context_bundle_id")) if metadata.get("context_bundle_id") else None - ), + partial_assistant=(str(metadata.get("partial_assistant")) if metadata.get("partial_assistant") else None), + context_bundle_id=(str(metadata.get("context_bundle_id")) if metadata.get("context_bundle_id") else None), active_evidence_refs=_active_evidence_refs_from_value(metadata.get("active_evidence_refs")), retry_state=_retry_state_from_mapping(metadata.get("retry_state")), heartbeat_at=_parse_optional_datetime(metadata.get("heartbeat_at")), @@ -1546,9 +1534,7 @@ def upsert_loop_checkpoint(self, run: LoopState, *, verify: bool = True) -> None if verify: reloaded = _verify_loop_checkpoint_roundtrip(self, run) if reloaded is None: - raise RuntimeError( - f"loop checkpoint verify failed: run {run.run_id} did not round-trip" - ) + raise RuntimeError(f"loop checkpoint verify failed: run {run.run_id} did not round-trip") def _verify_loop_checkpoint_roundtrip(self, run: LoopState) -> LoopState | None: @@ -1617,9 +1603,7 @@ def list_loop_checkpoints( if state_id is not None and loop.state_id != state_id: continue if personal_model_id is not None: - if canonical_personal_model_id(loop.personal_model_id) != canonical_personal_model_id( - personal_model_id - ): + if canonical_personal_model_id(loop.personal_model_id) != canonical_personal_model_id(personal_model_id): continue run = _loop_state_from_loop(loop) if heartbeat_before is not None: @@ -1655,7 +1639,11 @@ def load_latest_open_loop_checkpoint( return None latest = sorted( candidates, - key=lambda loop: ((loop.ended_at or loop.started_at).isoformat(), loop.started_at.isoformat(), loop.loop_id), + key=lambda loop: ( + (loop.ended_at or loop.started_at).isoformat(), + loop.started_at.isoformat(), + loop.loop_id, + ), reverse=True, )[0] return _loop_state_from_loop(latest) diff --git a/packages/telemetry/runtime.py b/packages/telemetry/runtime.py index 6039038..698c021 100644 --- a/packages/telemetry/runtime.py +++ b/packages/telemetry/runtime.py @@ -9,7 +9,15 @@ from dataclasses import dataclass, field from datetime import UTC, datetime -from typing import Any, Callable, ClassVar, Literal, Mapping, Protocol, runtime_checkable +from typing import ( + Any, + Callable, + ClassVar, + Literal, + Mapping, + Protocol, + runtime_checkable, +) TelemetryFamily = Literal["lifecycle", "execution", "approval", "delivery", "failure"] TelemetryPhase = Literal["ingest", "resolve", "recover", "assemble", "select", "execute", "persist", "emit"] @@ -37,10 +45,26 @@ "emit", ) -EXECUTION_STATUSES: tuple[ExecutionStatus, ...] = ("started", "completed", "blocked", "failed") +EXECUTION_STATUSES: tuple[ExecutionStatus, ...] = ( + "started", + "completed", + "blocked", + "failed", +) APPROVAL_DECISIONS: tuple[ApprovalDecision, ...] = ("approved", "denied", "deferred") -DELIVERY_STATUSES: tuple[DeliveryStatus, ...] = ("queued", "sent", "acknowledged", "failed") -FAILURE_SEVERITIES: tuple[FailureSeverity, ...] = ("debug", "info", "warning", "error", "critical") +DELIVERY_STATUSES: tuple[DeliveryStatus, ...] = ( + "queued", + "sent", + "acknowledged", + "failed", +) +FAILURE_SEVERITIES: tuple[FailureSeverity, ...] = ( + "debug", + "info", + "warning", + "error", + "critical", +) KRN_REQUIRED_EVENTS = ( "lifecycle.turn.ingested", diff --git a/packages/tools/__init__.py b/packages/tools/__init__.py index 0e09e07..1a0890f 100644 --- a/packages/tools/__init__.py +++ b/packages/tools/__init__.py @@ -37,7 +37,11 @@ ToolRuntime, ToolSideEffectMetadata, ) -from .surfaces import BrowserVisionAnalyzer, BuiltinToolDependencies, InMemorySessionTodoStore +from .surfaces import ( + BrowserVisionAnalyzer, + BuiltinToolDependencies, + InMemorySessionTodoStore, +) __all__ = [ "ApprovalGateway", diff --git a/packages/tools/browser_backend.py b/packages/tools/browser_backend.py index fd0e857..27740a9 100644 --- a/packages/tools/browser_backend.py +++ b/packages/tools/browser_backend.py @@ -86,13 +86,21 @@ class BrowserBackendConfig: @classmethod def from_env(cls, *, headless: bool = True) -> "BrowserBackendConfig": - provider = _first_env("ELEPHANT_BROWSER_CLOUD_PROVIDER", "BROWSER_CLOUD_PROVIDER", "BROWSER_PROVIDER") + provider = _first_env( + "ELEPHANT_BROWSER_CLOUD_PROVIDER", + "BROWSER_CLOUD_PROVIDER", + "BROWSER_PROVIDER", + ) return cls( headless=_env_bool("ELEPHANT_BROWSER_HEADLESS", default=headless), cdp_url=_first_env("ELEPHANT_BROWSER_CDP_URL", "BROWSER_CDP_URL"), cloud_provider=provider.strip().lower(), camofox_url=_first_env("ELEPHANT_BROWSER_CAMOFOX_URL", "CAMOFOX_URL").rstrip("/"), - allow_private_urls=_env_bool("ELEPHANT_BROWSER_ALLOW_PRIVATE_URLS", "BROWSER_ALLOW_PRIVATE_URLS", default=False), + allow_private_urls=_env_bool( + "ELEPHANT_BROWSER_ALLOW_PRIVATE_URLS", + "BROWSER_ALLOW_PRIVATE_URLS", + default=False, + ), screenshots_dir=Path( _first_env("ELEPHANT_BROWSER_SCREENSHOTS_DIR") or (Path.home() / ".elephant" / "cache" / "browser-screenshots") @@ -163,11 +171,17 @@ def _invoke_on_worker( if action == "navigate": return self._navigate(session, invocation) if action == "snapshot": - return self._summary(invocation, self._snapshot_payload(session, full=_bool_arg(invocation, "full"))) + return self._summary( + invocation, + self._snapshot_payload(session, full=_bool_arg(invocation, "full")), + ) if action == "click": target = self._target_selector(session, invocation, require=True) session.page.locator(target).first.click(timeout=_DEFAULT_ACTION_TIMEOUT_MS) - return self._summary(invocation, {"success": True, "clicked": target, **self._page_identity(session)}) + return self._summary( + invocation, + {"success": True, "clicked": target, **self._page_identity(session)}, + ) if action == "type": target = self._target_selector(session, invocation, require=True) if "text" not in invocation.arguments: @@ -185,7 +199,14 @@ def _invoke_on_worker( elif direction == "down": amount = abs(amount) session.page.mouse.wheel(0, amount) - return self._summary(invocation, {"success": True, "scrolled_px": amount, **self._page_identity(session)}) + return self._summary( + invocation, + { + "success": True, + "scrolled_px": amount, + **self._page_identity(session), + }, + ) if action == "back": session.page.go_back(wait_until="domcontentloaded", timeout=_DEFAULT_NAVIGATION_TIMEOUT_MS) return self._summary(invocation, {"success": True, **self._page_identity(session)}) @@ -194,10 +215,16 @@ def _invoke_on_worker( if not key: raise ValueError("tool.browser.press requires a 'key' argument") session.page.keyboard.press(key) - return self._summary(invocation, {"success": True, "pressed": key, **self._page_identity(session)}) + return self._summary( + invocation, + {"success": True, "pressed": key, **self._page_identity(session)}, + ) if action == "images": records = session.page.evaluate(IMAGES_JS, {"maxImages": _MAX_IMAGES}) - return self._summary(invocation, {"success": True, "images": records or [], "count": len(records or [])}) + return self._summary( + invocation, + {"success": True, "images": records or [], "count": len(records or [])}, + ) if action == "vision": return self._vision(session, invocation, vision_analyzer=vision_analyzer) if action == "console": @@ -232,7 +259,9 @@ def _run_on_worker(self, work: Callable[[], Any]) -> Any: worker_queue.put((future, work)) return future.result() - def _ensure_worker(self) -> queue.Queue[tuple[Future[Any], Callable[[], Any]] | None]: + def _ensure_worker( + self, + ) -> queue.Queue[tuple[Future[Any], Callable[[], Any]] | None]: with self._worker_lock: if self._closed: raise RuntimeError("browser backend is closed") @@ -414,15 +443,18 @@ def _navigate(self, session: BrowserSession, invocation: ToolInvocation) -> Mapp return self._summary(invocation, payload) def _snapshot_payload(self, session: BrowserSession, *, full: bool) -> dict[str, Any]: - data = session.page.evaluate( - SNAPSHOT_JS, - { - "full": full, - "compactLimit": _SNAPSHOT_COMPACT_LIMIT, - "fullLimit": _SNAPSHOT_FULL_LIMIT, - "maxElements": _MAX_REF_ELEMENTS, - }, - ) or {} + data = ( + session.page.evaluate( + SNAPSHOT_JS, + { + "full": full, + "compactLimit": _SNAPSHOT_COMPACT_LIMIT, + "fullLimit": _SNAPSHOT_FULL_LIMIT, + "maxElements": _MAX_REF_ELEMENTS, + }, + ) + or {} + ) elements = tuple(data.get("elements") or ()) session.refs = { str(element.get("ref")): _ref_selector(str(element.get("ref"))) @@ -458,7 +490,7 @@ def _format_snapshot(self, data: Mapping[str, Any]) -> str: disabled = " disabled" if element.get("disabled") else "" href = str(element.get("href") or "").strip() suffix = f" href={href}" if href else "" - lines.append(f"- [{ref}] {role}{disabled} \"{label}\"{suffix}") + lines.append(f'- [{ref}] {role}{disabled} "{label}"{suffix}') text = str(data.get("text") or "").strip() if text: lines.append("page text:") @@ -533,7 +565,11 @@ def _console(self, session: BrowserSession, invocation: ToolInvocation) -> Mappi expression = invocation.arguments.get("expression") if expression is not None and str(expression).strip(): result = session.page.evaluate(str(expression)) - payload = {"success": True, "result": result, "result_type": type(result).__name__} + payload = { + "success": True, + "result": result, + "result_type": type(result).__name__, + } return self._summary(invocation, payload) payload = { "success": True, @@ -551,7 +587,9 @@ def _guard_url(self, url: str) -> None: if _SECRET_RE.search(url): raise ValueError("blocked browser navigation: URL appears to contain an API key or token") if not self.config.allow_private_urls and not self._is_local_backend() and _is_private_url(url): - raise ValueError("blocked browser navigation: private/internal URLs are disabled for remote browser backends") + raise ValueError( + "blocked browser navigation: private/internal URLs are disabled for remote browser backends" + ) def _is_local_backend(self) -> bool: return not self.config.cdp_url and self._configured_cloud_provider() is None @@ -628,7 +666,10 @@ def invoke( return self._summary(invocation, self._snapshot_payload(session)) if action == "click": ref = _ref_arg(invocation) - payload = self._post(f"/tabs/{session['tab_id']}/click", {"userId": session["user_id"], "ref": ref.lstrip("@")}) + payload = self._post( + f"/tabs/{session['tab_id']}/click", + {"userId": session["user_id"], "ref": ref.lstrip("@")}, + ) return self._summary(invocation, {"success": True, "clicked": ref, **payload}) if action == "type": ref = _ref_arg(invocation) @@ -652,7 +693,10 @@ def invoke( key = str(invocation.arguments.get("key") or "").strip() if not key: raise ValueError("tool.browser.press requires a 'key' argument") - payload = self._post(f"/tabs/{session['tab_id']}/press", {"userId": session["user_id"], "key": key}) + payload = self._post( + f"/tabs/{session['tab_id']}/press", + {"userId": session["user_id"], "key": key}, + ) return self._summary(invocation, {"success": True, "pressed": key, **payload}) if action == "images": snapshot = self._snapshot_payload(session).get("snapshot", "") @@ -690,7 +734,11 @@ def _navigate(self, invocation: ToolInvocation) -> Mapping[str, Any]: raise ValueError("blocked browser navigation: URL appears to contain an API key or token") session = self._session_for(invocation.session_id, url=url) if session.get("navigated"): - payload = self._post(f"/tabs/{session['tab_id']}/navigate", {"userId": session["user_id"], "url": url}, timeout=60) + payload = self._post( + f"/tabs/{session['tab_id']}/navigate", + {"userId": session["user_id"], "url": url}, + timeout=60, + ) else: payload = {"url": url} session["navigated"] = True @@ -713,8 +761,16 @@ def _session_for(self, session_key: str, *, url: str = "about:blank") -> dict[st if existing is not None: return existing user_id = f"elephant_{uuid4().hex[:12]}" - created = self._post("/tabs", {"userId": user_id, "sessionKey": session_key[:48], "url": url}, timeout=60) - session = {"session_key": session_key, "user_id": user_id, "tab_id": created.get("tabId")} + created = self._post( + "/tabs", + {"userId": user_id, "sessionKey": session_key[:48], "url": url}, + timeout=60, + ) + session = { + "session_key": session_key, + "user_id": user_id, + "tab_id": created.get("tabId"), + } if not session["tab_id"]: raise RuntimeError("Camofox did not return a tabId") self._sessions[session_key] = session @@ -730,8 +786,14 @@ def _snapshot_payload(self, session: Mapping[str, Any]) -> dict[str, Any]: data = self._get(f"/tabs/{session['tab_id']}/snapshot", params={"userId": session["user_id"]}) snapshot = str(data.get("snapshot") or "") if len(snapshot) > _SNAPSHOT_FULL_LIMIT: - snapshot = snapshot[:_SNAPSHOT_FULL_LIMIT] + f"\n[... {len(snapshot) - _SNAPSHOT_FULL_LIMIT} chars truncated]" - return {"success": True, "snapshot": snapshot, "element_count": int(data.get("refsCount") or 0)} + snapshot = ( + snapshot[:_SNAPSHOT_FULL_LIMIT] + f"\n[... {len(snapshot) - _SNAPSHOT_FULL_LIMIT} chars truncated]" + ) + return { + "success": True, + "snapshot": snapshot, + "element_count": int(data.get("refsCount") or 0), + } def _vision( self, @@ -740,7 +802,10 @@ def _vision( *, vision_analyzer: BrowserVisionAnalyzer | None = None, ) -> Mapping[str, Any]: - response = self._get_bytes(f"/tabs/{session['tab_id']}/screenshot", params={"userId": session["user_id"]}) + response = self._get_bytes( + f"/tabs/{session['tab_id']}/screenshot", + params={"userId": session["user_id"]}, + ) self.config.screenshots_dir.mkdir(parents=True, exist_ok=True) screenshot_path = self.config.screenshots_dir / f"browser_screenshot_{uuid4().hex}.png" screenshot_path.write_bytes(response) @@ -935,7 +1000,9 @@ def _apply_vision_analysis( ) -> None: if analyzer is None: payload["vision_analyzer_configured"] = False - payload["vision_setup_hint"] = "Configure a browser vision analyzer before using tool.browser.vision for visual analysis." + payload["vision_setup_hint"] = ( + "Configure a browser vision analyzer before using tool.browser.vision for visual analysis." + ) return result = analyzer.analyze_browser_screenshot( session_id=session_id, diff --git a/packages/tools/browser_providers.py b/packages/tools/browser_providers.py index a4caeea..a502989 100644 --- a/packages/tools/browser_providers.py +++ b/packages/tools/browser_providers.py @@ -43,7 +43,10 @@ def create_session(self, task_id: str) -> CloudBrowserSession: response = _http_json( "POST", os.environ.get("BROWSER_USE_API_URL", "https://api.browser-use.com/api/v3").rstrip("/") + "/browsers", - headers={"X-Browser-Use-API-Key": api_key, "Content-Type": "application/json"}, + headers={ + "X-Browser-Use-API-Key": api_key, + "Content-Type": "application/json", + }, payload={"timeout": int(os.environ.get("BROWSER_USE_TIMEOUT_MINUTES", "5"))}, ) session_id = str(response.get("id") or "") @@ -63,7 +66,10 @@ def close_session(self, session_id: str) -> None: "PATCH", os.environ.get("BROWSER_USE_API_URL", "https://api.browser-use.com/api/v3").rstrip() + f"/browsers/{session_id}", - headers={"X-Browser-Use-API-Key": api_key, "Content-Type": "application/json"}, + headers={ + "X-Browser-Use-API-Key": api_key, + "Content-Type": "application/json", + }, payload={"action": "stop"}, tolerate_http_errors=True, ) @@ -127,7 +133,10 @@ def create_session(self, task_id: str) -> CloudBrowserSession: response = _http_json( "POST", os.environ.get("FIRECRAWL_API_URL", "https://api.firecrawl.dev").rstrip("/") + "/v2/browser", - headers={"Authorization": f"Bearer {os.environ.get('FIRECRAWL_API_KEY', '')}", "Content-Type": "application/json"}, + headers={ + "Authorization": f"Bearer {os.environ.get('FIRECRAWL_API_KEY', '')}", + "Content-Type": "application/json", + }, payload={"ttl": int(os.environ.get("FIRECRAWL_BROWSER_TTL", "300"))}, ) session_id = str(response.get("id") or "") @@ -166,8 +175,14 @@ def _http_json( data = response.read() except HTTPError as error: if tolerate_http_errors: - return {"ok": False, "status": error.code, "body": error.read().decode("utf-8", errors="replace")} - raise RuntimeError(f"browser provider request failed: HTTP {error.code} {error.read().decode('utf-8', errors='replace')}") from error + return { + "ok": False, + "status": error.code, + "body": error.read().decode("utf-8", errors="replace"), + } + raise RuntimeError( + f"browser provider request failed: HTTP {error.code} {error.read().decode('utf-8', errors='replace')}" + ) from error except URLError as error: if tolerate_http_errors: return {"ok": False, "error": str(error)} diff --git a/packages/tools/browser_scripts.py b/packages/tools/browser_scripts.py index c2c3085..92a6968 100644 --- a/packages/tools/browser_scripts.py +++ b/packages/tools/browser_scripts.py @@ -123,6 +123,8 @@ } """ -CLEAR_ANNOTATIONS_JS = "() => document.querySelectorAll('[data-elephant-browser-annotation]').forEach((node) => node.remove())" +CLEAR_ANNOTATIONS_JS = ( + "() => document.querySelectorAll('[data-elephant-browser-annotation]').forEach((node) => node.remove())" +) __all__ = ["ANNOTATE_JS", "CLEAR_ANNOTATIONS_JS", "IMAGES_JS", "SNAPSHOT_JS"] diff --git a/packages/tools/builtins.py b/packages/tools/builtins.py index 3afdcbe..32220ca 100644 --- a/packages/tools/builtins.py +++ b/packages/tools/builtins.py @@ -37,7 +37,13 @@ run_process_action, run_terminal_exec, ) -from .runtime import ToolAudience, ToolAvailability, ToolDefinition, ToolRuntime, ToolSideEffectMetadata +from .runtime import ( + ToolAudience, + ToolAvailability, + ToolDefinition, + ToolRuntime, + ToolSideEffectMetadata, +) from .schema_descriptions import enrich_builtin_tool_schema from .surfaces import BuiltinToolDependencies @@ -116,7 +122,11 @@ def builtin_tool_definitions( properties={ "command": {"type": "string"}, "cwd": {"type": "string"}, - "timeout_seconds": {"type": "integer", "minimum": 1, "maximum": 120}, + "timeout_seconds": { + "type": "integer", + "minimum": 1, + "maximum": 120, + }, "background": {"type": "boolean"}, "env": {"type": "object"}, }, @@ -145,12 +155,31 @@ def builtin_tool_definitions( properties={ "action": { "type": "string", - "enum": ["list", "ls", "poll", "inspect", "wait", "write", "kill"], + "enum": [ + "list", + "ls", + "poll", + "inspect", + "wait", + "write", + "kill", + ], "description": "Use list|ls to enumerate managed processes; use poll/inspect for current status and buffered stdout/stderr, wait to block for completion, write for stdin, kill to stop. Use non-buffered commands (for example python -u) for interactive echo.", }, - "process_id": {"type": "string", "description": "Managed process id returned by a background tool.terminal.exec call."}, - "input": {"type": "string", "description": "Text to write to stdin for action=write; include a newline when the process expects line input."}, - "timeout_seconds": {"type": "integer", "minimum": 1, "maximum": 120, "description": "Maximum seconds for action=wait."}, + "process_id": { + "type": "string", + "description": "Managed process id returned by a background tool.terminal.exec call.", + }, + "input": { + "type": "string", + "description": "Text to write to stdin for action=write; include a newline when the process expects line input.", + }, + "timeout_seconds": { + "type": "integer", + "minimum": 1, + "maximum": 120, + "description": "Maximum seconds for action=wait.", + }, }, ), side_effects=ToolSideEffectMetadata( @@ -171,9 +200,21 @@ def builtin_tool_definitions( schema=_object_schema( required=("path",), properties={ - "path": {"type": "string", "description": "Root-relative or absolute file path to read."}, - "offset": {"type": "integer", "minimum": 1, "description": "1-indexed first line to read."}, - "limit": {"type": "integer", "minimum": 1, "maximum": 2000, "description": "Maximum number of lines to read."}, + "path": { + "type": "string", + "description": "Root-relative or absolute file path to read.", + }, + "offset": { + "type": "integer", + "minimum": 1, + "description": "1-indexed first line to read.", + }, + "limit": { + "type": "integer", + "minimum": 1, + "maximum": 2000, + "description": "Maximum number of lines to read.", + }, }, ), side_effects=ToolSideEffectMetadata( @@ -193,8 +234,14 @@ def builtin_tool_definitions( schema=_object_schema( required=("path", "content"), properties={ - "path": {"type": "string", "description": "Root-relative or absolute file path to write. Must stay inside an allowed root; sensitive env, credential, and VCS metadata paths are refused."}, - "content": {"type": "string", "description": "Complete text content to write to the file."}, + "path": { + "type": "string", + "description": "Root-relative or absolute file path to write. Must stay inside an allowed root; sensitive env, credential, and VCS metadata paths are refused.", + }, + "content": { + "type": "string", + "description": "Complete text content to write to the file.", + }, }, ), side_effects=ToolSideEffectMetadata( @@ -225,10 +272,22 @@ def builtin_tool_definitions( "bookkeeping." ), }, - "path": {"type": "string", "description": "Root-relative or absolute file path for replace mode."}, - "old_string": {"type": "string", "description": "Exact text to locate; must be unique unless replace_all=true."}, - "new_string": {"type": "string", "description": "Replacement text for the matched content."}, - "replace_all": {"type": "boolean", "description": "Replace every match instead of requiring uniqueness."}, + "path": { + "type": "string", + "description": "Root-relative or absolute file path for replace mode.", + }, + "old_string": { + "type": "string", + "description": "Exact text to locate; must be unique unless replace_all=true.", + }, + "new_string": { + "type": "string", + "description": "Replacement text for the matched content.", + }, + "replace_all": { + "type": "boolean", + "description": "Replace every match instead of requiring uniqueness.", + }, "patch": { "type": "string", "description": ( @@ -260,15 +319,48 @@ def builtin_tool_definitions( schema=_object_schema( required=(), properties={ - "query": {"type": "string", "description": "Text or regex-like pattern to search for. Required for target=content; optional for target=files, where it is treated as a glob when glob is omitted."}, - "pattern": {"type": "string", "description": "Backward-compatible alias for query. Use query for new calls."}, - "target": {"type": "string", "enum": ["content", "files"], "description": "Search file contents or file paths."}, - "path": {"type": "string", "description": "Optional file or directory path to search within; must be inside the active root or another configured allowed root and cannot be a sensitive credential/VCS metadata path."}, - "glob": {"type": "string", "description": "Optional file glob filter such as '*.py'. For target=files, omit both query and glob to list files."}, - "include": {"type": "string", "description": "Backward-compatible alias for glob. Use glob for new calls."}, - "limit": {"type": "integer", "minimum": 1, "maximum": 200, "description": "Maximum number of matches to return."}, - "offset": {"type": "integer", "minimum": 0, "description": "Number of matches to skip for pagination."}, - "context": {"type": "integer", "minimum": 0, "maximum": 5, "description": "Context lines around content matches."}, + "query": { + "type": "string", + "description": "Text or regex-like pattern to search for. Required for target=content; optional for target=files, where it is treated as a glob when glob is omitted.", + }, + "pattern": { + "type": "string", + "description": "Backward-compatible alias for query. Use query for new calls.", + }, + "target": { + "type": "string", + "enum": ["content", "files"], + "description": "Search file contents or file paths.", + }, + "path": { + "type": "string", + "description": "Optional file or directory path to search within; must be inside the active root or another configured allowed root and cannot be a sensitive credential/VCS metadata path.", + }, + "glob": { + "type": "string", + "description": "Optional file glob filter such as '*.py'. For target=files, omit both query and glob to list files.", + }, + "include": { + "type": "string", + "description": "Backward-compatible alias for glob. Use glob for new calls.", + }, + "limit": { + "type": "integer", + "minimum": 1, + "maximum": 200, + "description": "Maximum number of matches to return.", + }, + "offset": { + "type": "integer", + "minimum": 0, + "description": "Number of matches to skip for pagination.", + }, + "context": { + "type": "integer", + "minimum": 0, + "maximum": 5, + "description": "Context lines around content matches.", + }, }, ), side_effects=ToolSideEffectMetadata( @@ -400,15 +492,39 @@ def builtin_tool_definitions( properties={ "action": { "type": "string", - "enum": ["list", "ls", "create", "inspect", "pause", "resume", "remove", "delete"], + "enum": [ + "list", + "ls", + "create", + "inspect", + "pause", + "resume", + "remove", + "delete", + ], "description": "Use list|ls without job_id; use create with schedule and prompt; use inspect|pause|resume|remove|delete with job_id.", }, - "job_id": {"type": "string", "description": "Cron job id such as cron:9f0e36022b."}, - "name": {"type": "string", "description": "Human-readable job name when action=create."}, - "schedule": {"type": "string", "description": "Schedule when action=create. Accepted examples: ISO timestamp '2026-05-13T09:00:00+08:00', interval '1h'/'30m'/'PT1H', or standard 5-field cron '0 2 * * *'."}, - "prompt": {"type": "string", "description": "Prompt payload for the scheduled prompt job when action=create."}, + "job_id": { + "type": "string", + "description": "Cron job id such as cron:9f0e36022b.", + }, + "name": { + "type": "string", + "description": "Human-readable job name when action=create.", + }, + "schedule": { + "type": "string", + "description": "Schedule when action=create. Accepted examples: ISO timestamp '2026-05-13T09:00:00+08:00', interval '1h'/'30m'/'PT1H', or standard 5-field cron '0 2 * * *'.", + }, + "prompt": { + "type": "string", + "description": "Prompt payload for the scheduled prompt job when action=create.", + }, "skills": { - "oneOf": [{"type": "array", "items": {"type": "string"}}, {"type": "string"}], + "oneOf": [ + {"type": "array", "items": {"type": "string"}}, + {"type": "string"}, + ], "description": "Skill ids to load as operating instructions when a prompt job runs.", }, "profile_id": {"type": "string"}, @@ -436,15 +552,48 @@ def builtin_tool_definitions( ), schema=_object_schema( properties={ - "query": {"type": "string", "description": "Natural-language claim lookup."}, - "query_variants": {"type": "array", "items": {"type": "string"}, "description": "Optional translated or paraphrased query variants for cross-lingual or metaphorical lookup; at most 5 are used."}, - "mode": {"type": "string", "enum": ["auto", "inventory"], "description": "Search mode. Use inventory to get lens→topic list with claim counts (no content). Defaults to auto."}, - "lens": {"type": "string", "enum": ["identity", "world", "pulse", "journey"], "description": "Optional four-lens filter."}, - "topic": {"type": "string", "description": "Optional lens-prefixed topic key: ..[.], e.g. knowledge.projects.aegis.status."}, - "status": {"type": "string", "enum": ["active", "retired", "disputed", "all"], "description": "Claim status filter. Defaults to active; use retired/all to audit old corrected claims."}, - "ref": {"type": "string", "description": "Optional exact claim ref lookup, independent of semantic score."}, - "include_diagnostics": {"type": "boolean", "description": "Return match status, no-match reason, and per-claim scoring signals for debugging."}, - "limit": {"type": "integer", "minimum": 1, "maximum": 30, "description": "Maximum claims to return."}, + "query": { + "type": "string", + "description": "Natural-language claim lookup.", + }, + "query_variants": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional translated or paraphrased query variants for cross-lingual or metaphorical lookup; at most 5 are used.", + }, + "mode": { + "type": "string", + "enum": ["auto", "inventory"], + "description": "Search mode. Use inventory to get lens→topic list with claim counts (no content). Defaults to auto.", + }, + "lens": { + "type": "string", + "enum": ["identity", "world", "pulse", "journey"], + "description": "Optional four-lens filter.", + }, + "topic": { + "type": "string", + "description": "Optional lens-prefixed topic key: ..[.], e.g. knowledge.projects.aegis.status.", + }, + "status": { + "type": "string", + "enum": ["active", "retired", "disputed", "all"], + "description": "Claim status filter. Defaults to active; use retired/all to audit old corrected claims.", + }, + "ref": { + "type": "string", + "description": "Optional exact claim ref lookup, independent of semantic score.", + }, + "include_diagnostics": { + "type": "boolean", + "description": "Return match status, no-match reason, and per-claim scoring signals for debugging.", + }, + "limit": { + "type": "integer", + "minimum": 1, + "maximum": 30, + "description": "Maximum claims to return.", + }, }, ), side_effects=ToolSideEffectMetadata( @@ -469,16 +618,52 @@ def builtin_tool_definitions( ), schema=_object_schema( properties={ - "query": {"type": "string", "description": "Content query for prior conversation search. Leave empty only to list a narrow time window."}, - "mode": {"type": "string", "enum": ["discover", "recall"], "description": "Use discover to find relevant ranges; use recall to return conversation details. Defaults to recall. Discover should include expr or explicit start_at/end_at and returns recall_args lines that can be copied into a recall call."}, - "expr": {"type": "string", "description": "Stable time expression: today, yesterday, last:24h, last:3d, this:week, previous:week, last_night, yesterday_evening, this_morning, today_afternoon, today_evening, an ISO date like 2026-05-13, or an ISO interval like 2026-05-08T18:00:00+08:00/PT12H."}, - "start_at": {"type": "string", "description": "Optional RFC3339 start datetime for explicit intervals."}, - "end_at": {"type": "string", "description": "Optional RFC3339 end datetime for explicit intervals. End is exclusive."}, - "timezone": {"type": "string", "description": "Optional IANA timezone such as Asia/Shanghai; defaults to runtime timezone."}, - "bucket": {"type": "string", "enum": ["auto", "hour", "day"], "description": "Discover bucket size. Defaults to auto."}, - "preview": {"type": "string", "enum": ["none", "anchors"], "description": "Discover preview style. Defaults to anchors."}, - "view": {"type": "string", "enum": ["conversation", "debug"], "description": "Use conversation by default; debug includes internal source/tool material for diagnostics only."}, - "limit": {"type": "integer", "minimum": 1, "maximum": 30, "description": "Maximum ranges or hits to return."}, + "query": { + "type": "string", + "description": "Content query for prior conversation search. Leave empty only to list a narrow time window.", + }, + "mode": { + "type": "string", + "enum": ["discover", "recall"], + "description": "Use discover to find relevant ranges; use recall to return conversation details. Defaults to recall. Discover should include expr or explicit start_at/end_at and returns recall_args lines that can be copied into a recall call.", + }, + "expr": { + "type": "string", + "description": "Stable time expression: today, yesterday, last:24h, last:3d, this:week, previous:week, last_night, yesterday_evening, this_morning, today_afternoon, today_evening, an ISO date like 2026-05-13, or an ISO interval like 2026-05-08T18:00:00+08:00/PT12H.", + }, + "start_at": { + "type": "string", + "description": "Optional RFC3339 start datetime for explicit intervals.", + }, + "end_at": { + "type": "string", + "description": "Optional RFC3339 end datetime for explicit intervals. End is exclusive.", + }, + "timezone": { + "type": "string", + "description": "Optional IANA timezone such as Asia/Shanghai; defaults to runtime timezone.", + }, + "bucket": { + "type": "string", + "enum": ["auto", "hour", "day"], + "description": "Discover bucket size. Defaults to auto.", + }, + "preview": { + "type": "string", + "enum": ["none", "anchors"], + "description": "Discover preview style. Defaults to anchors.", + }, + "view": { + "type": "string", + "enum": ["conversation", "debug"], + "description": "Use conversation by default; debug includes internal source/tool material for diagnostics only.", + }, + "limit": { + "type": "integer", + "minimum": 1, + "maximum": 30, + "description": "Maximum ranges or hits to return.", + }, }, ), side_effects=ToolSideEffectMetadata( @@ -503,15 +688,54 @@ def builtin_tool_definitions( schema=_object_schema( required=("action", "lens", "topic", "reason"), properties={ - "action": {"type": "string", "enum": ["remember", "correct", "forget", "dispute", "restore", "delete"], "description": "How to change the claim. Use restore with ref to reactivate a retired or disputed claim; use delete with ref only for non-protected accidental/synthetic/duplicate invalid claims that should leave the visible model entirely."}, - "lens": {"type": "string", "enum": ["identity", "world", "pulse", "journey"], "description": "Which Personal Model lens owns the claim."}, - "topic": {"type": "string", "description": "Lens-prefixed topic key: ..[.]. First segment must match lens. Facets: identity={anchor,character,values,style,body}; world={people,projects,tools,places,assets,skills}; pulse={chapter,focus,mood,blockers,intent}; journey={lessons,patterns,decisions,milestones}. Examples: identity.anchor.name.preferred, world.people.zhang_san.role, pulse.chapter.work.role, journey.lessons.collaboration.scope_creep."}, - "text": {"type": "string", "description": "Claim text. Required for action=remember or action=correct."}, - "ref": {"type": "string", "description": "Exact claim ref from personal_model.search. Required for delete/restore; strongly preferred for correct/forget/dispute when topic is uncertain."}, - "reason": {"type": "string", "description": "Why this update is warranted, preferably grounded in the user's words."}, - "source": {"type": "string", "enum": ["user_said", "user_corrected", "learned"], "description": "Where the update came from."}, - "recall_policy": {"type": "string", "enum": ["stable", "current", "temporary", "review"], "description": "Optional; use only when obvious: stable, current, temporary, or review."}, - "metadata": {"type": "object", "additionalProperties": {"type": "string"}, "description": "Optional governance metadata. Skill affinity facts should include skill_id, index_id, and projection_policy when known."}, + "action": { + "type": "string", + "enum": [ + "remember", + "correct", + "forget", + "dispute", + "restore", + "delete", + ], + "description": "How to change the claim. Use restore with ref to reactivate a retired or disputed claim; use delete with ref only for non-protected accidental/synthetic/duplicate invalid claims that should leave the visible model entirely.", + }, + "lens": { + "type": "string", + "enum": ["identity", "world", "pulse", "journey"], + "description": "Which Personal Model lens owns the claim.", + }, + "topic": { + "type": "string", + "description": "Lens-prefixed topic key: ..[.]. First segment must match lens. Facets: identity={anchor,character,values,style,body}; world={people,projects,tools,places,assets,skills}; pulse={chapter,focus,mood,blockers,intent}; journey={lessons,patterns,decisions,milestones}. Examples: identity.anchor.name.preferred, world.people.zhang_san.role, pulse.chapter.work.role, journey.lessons.collaboration.scope_creep.", + }, + "text": { + "type": "string", + "description": "Claim text. Required for action=remember or action=correct.", + }, + "ref": { + "type": "string", + "description": "Exact claim ref from personal_model.search. Required for delete/restore; strongly preferred for correct/forget/dispute when topic is uncertain.", + }, + "reason": { + "type": "string", + "description": "Why this update is warranted, preferably grounded in the user's words.", + }, + "source": { + "type": "string", + "enum": ["user_said", "user_corrected", "learned"], + "description": "Where the update came from.", + }, + "recall_policy": { + "type": "string", + "enum": ["stable", "current", "temporary", "review"], + "description": "Optional; use only when obvious: stable, current, temporary, or review.", + }, + "metadata": { + "type": "object", + "additionalProperties": {"type": "string"}, + "description": "Optional governance metadata. Skill affinity facts should include skill_id, index_id, and projection_policy when known.", + }, }, ), side_effects=ToolSideEffectMetadata( @@ -536,17 +760,69 @@ def builtin_tool_definitions( schema=_object_schema( required=("action",), properties={ - "action": {"type": "string", "enum": ["list", "inspect", "bank", "create", "update", "ask", "answer", "dismiss", "reopen", "stale", "delete"], "description": "Question lifecycle action."}, - "question_id": {"type": "string", "description": "Question ref for inspect/update/ask/answer/dismiss/delete."}, - "status": {"type": "string", "description": "Filter for list: open, asked, answered, dismissed, stale."}, - "lens": {"type": "string", "enum": ["identity", "world", "pulse", "journey"], "description": "Four-lens owner."}, - "topic": {"type": "string", "description": "Question topic or sub-lens."}, - "text": {"type": "string", "description": "Question text for create/update."}, - "answer": {"type": "string", "description": "User's answer; answer also creates a Personal Model claim."}, - "reason": {"type": "string", "description": "Why this question exists or changed."}, - "priority": {"type": "number", "minimum": 0, "maximum": 1, "description": "Priority from 0.0 to 1.0 for ordering open questions."}, - "sensitivity": {"type": "string", "enum": ["low", "medium", "high"], "description": "How sensitive the question is for the user."}, - "limit": {"type": "integer", "minimum": 1, "maximum": 20, "description": "Maximum question rows to return."}, + "action": { + "type": "string", + "enum": [ + "list", + "inspect", + "bank", + "create", + "update", + "ask", + "answer", + "dismiss", + "reopen", + "stale", + "delete", + ], + "description": "Question lifecycle action.", + }, + "question_id": { + "type": "string", + "description": "Question ref for inspect/update/ask/answer/dismiss/delete.", + }, + "status": { + "type": "string", + "description": "Filter for list: open, asked, answered, dismissed, stale.", + }, + "lens": { + "type": "string", + "enum": ["identity", "world", "pulse", "journey"], + "description": "Four-lens owner.", + }, + "topic": { + "type": "string", + "description": "Question topic or sub-lens.", + }, + "text": { + "type": "string", + "description": "Question text for create/update.", + }, + "answer": { + "type": "string", + "description": "User's answer; answer also creates a Personal Model claim.", + }, + "reason": { + "type": "string", + "description": "Why this question exists or changed.", + }, + "priority": { + "type": "number", + "minimum": 0, + "maximum": 1, + "description": "Priority from 0.0 to 1.0 for ordering open questions.", + }, + "sensitivity": { + "type": "string", + "enum": ["low", "medium", "high"], + "description": "How sensitive the question is for the user.", + }, + "limit": { + "type": "integer", + "minimum": 1, + "maximum": 20, + "description": "Maximum question rows to return.", + }, }, ), side_effects=ToolSideEffectMetadata( @@ -566,12 +842,32 @@ def builtin_tool_definitions( backend="runtime", description="Write or update a diary entry for a specific date. Content should be reflective markdown prose in the user's first language.", audience="both", - schema=_object_schema(required=("entry_date", "content"), properties={ - "entry_date": {"type": "string", "description": "YYYY-MM-DD date for the entry."}, - "content": {"type": "string", "description": "Markdown diary content (2-4 paragraphs)."}, - "source_episode_ids": {"type": "array", "items": {"type": "string"}, "description": "Source episode IDs."}, - }), - side_effects=ToolSideEffectMetadata(risk_class="low", approval_class="none", writes_state=True, reads_state=False, categories=("diary", "write"), notes="Upserts one entry per date."), + schema=_object_schema( + required=("entry_date", "content"), + properties={ + "entry_date": { + "type": "string", + "description": "YYYY-MM-DD date for the entry.", + }, + "content": { + "type": "string", + "description": "Markdown diary content (2-4 paragraphs).", + }, + "source_episode_ids": { + "type": "array", + "items": {"type": "string"}, + "description": "Source episode IDs.", + }, + }, + ), + side_effects=ToolSideEffectMetadata( + risk_class="low", + approval_class="none", + writes_state=True, + reads_state=False, + categories=("diary", "write"), + notes="Upserts one entry per date.", + ), availability=_availability(diary_reason is None, diary_reason), ), _builtin_tool( @@ -581,11 +877,29 @@ def builtin_tool_definitions( backend="runtime", description="List recent diary entries. Use to check if an entry already exists for a date.", audience="both", - schema=_object_schema(required=(), properties={ - "limit": {"type": "integer", "minimum": 1, "maximum": 30, "description": "Max entries (default 10)."}, - "before_date": {"type": "string", "description": "Return entries before this YYYY-MM-DD date."}, - }), - side_effects=ToolSideEffectMetadata(risk_class="none", approval_class="none", writes_state=False, reads_state=True, categories=("diary", "read"), notes="Read-only listing."), + schema=_object_schema( + required=(), + properties={ + "limit": { + "type": "integer", + "minimum": 1, + "maximum": 30, + "description": "Max entries (default 10).", + }, + "before_date": { + "type": "string", + "description": "Return entries before this YYYY-MM-DD date.", + }, + }, + ), + side_effects=ToolSideEffectMetadata( + risk_class="none", + approval_class="none", + writes_state=False, + reads_state=True, + categories=("diary", "read"), + notes="Read-only listing.", + ), availability=_availability(diary_reason is None, diary_reason), ), *sub_agents_tool_definitions(reason=sub_agents_reason), @@ -602,7 +916,12 @@ def builtin_tool_definitions( "type": "string", "description": f"Restricted Python snippet; safe imports include {', '.join(sorted(SAFE_CODE_IMPORTS))}. Safe builtins include pow. May call tool('tool.id', {{...}}) for allowed file/web/terminal tools. Direct open(), os, sys, random, and subprocess access are blocked.", }, - "timeout_seconds": {"type": "integer", "minimum": 1, "maximum": 30, "description": "Maximum runtime in seconds."}, + "timeout_seconds": { + "type": "integer", + "minimum": 1, + "maximum": 30, + "description": "Maximum runtime in seconds.", + }, "mode": { "type": "string", "enum": ["project", "strict"], @@ -659,7 +978,19 @@ def builtin_tool_definitions( properties={ "action": { "type": "string", - "enum": ["list", "ls", "add", "create", "inspect", "update", "complete", "reopen", "remove", "delete", "clear"], + "enum": [ + "list", + "ls", + "add", + "create", + "inspect", + "update", + "complete", + "reopen", + "remove", + "delete", + "clear", + ], "description": "Use add|create|list|clear for scratchpad setup; other actions require an item_id.", }, "item_id": {"type": "string"}, @@ -680,7 +1011,12 @@ def builtin_tool_definitions( *skill_tool_definitions(reason=skill_reason), ) return tuple( - enrich_builtin_tool_schema(replace(definition, enabled=enabled_overrides.get(definition.tool_id, definition.enabled))) + enrich_builtin_tool_schema( + replace( + definition, + enabled=enabled_overrides.get(definition.tool_id, definition.enabled), + ) + ) for definition in definitions ) @@ -729,7 +1065,13 @@ def _browser_tool_definitions(*, reason: str | None, vision_reason: str | None) risk_class="medium", approval_class="standard", reads_state=True, - writes_state=tool_id not in {"tool.browser.snapshot", "tool.browser.images", "tool.browser.vision", "tool.browser.console"}, + writes_state=tool_id + not in { + "tool.browser.snapshot", + "tool.browser.images", + "tool.browser.vision", + "tool.browser.console", + }, touches_network=True, categories=("browser", action), notes="Backed by the configured browser bridge when available.", @@ -758,8 +1100,14 @@ def _browser_tool_definitions(*, reason: str | None, vision_reason: str | None) "Click an element in the active browser page by snapshot ref, with selector fallback.", _object_schema( properties={ - "ref": {"type": "string", "description": "Snapshot element ref such as @e3."}, - "selector": {"type": "string", "description": "CSS selector fallback when no ref exists."}, + "ref": { + "type": "string", + "description": "Snapshot element ref such as @e3.", + }, + "selector": { + "type": "string", + "description": "CSS selector fallback when no ref exists.", + }, } ), ), @@ -771,8 +1119,14 @@ def _browser_tool_definitions(*, reason: str | None, vision_reason: str | None) _object_schema( required=("text",), properties={ - "ref": {"type": "string", "description": "Snapshot element ref such as @e3."}, - "selector": {"type": "string", "description": "CSS selector fallback when no ref exists."}, + "ref": { + "type": "string", + "description": "Snapshot element ref such as @e3.", + }, + "selector": { + "type": "string", + "description": "CSS selector fallback when no ref exists.", + }, "text": {"type": "string"}, }, ), @@ -852,6 +1206,7 @@ def _docs_builtin_tool_definitions() -> tuple[ToolDefinition, ...]: ), ) + def _builtin_tool( *, tool_id: str, @@ -897,7 +1252,9 @@ def _object_schema( return schema -def _group_builtin_tools(definitions: tuple[ToolDefinition, ...]) -> dict[str, tuple[ToolDefinition, ...]]: +def _group_builtin_tools( + definitions: tuple[ToolDefinition, ...], +) -> dict[str, tuple[ToolDefinition, ...]]: grouped: dict[str, list[ToolDefinition]] = {} for definition in definitions: grouped.setdefault(definition.family, []).append(definition) @@ -919,25 +1276,41 @@ def _handler_for_tool( return lambda invocation: run_file_read( invocation, cwd=dependencies.resolve_cwd(invocation.session_id), - allowed_roots=(dependencies.cwd, *invocation.context.allowed_roots, *dependencies.additional_allowed_roots), + allowed_roots=( + dependencies.cwd, + *invocation.context.allowed_roots, + *dependencies.additional_allowed_roots, + ), ) if tool_id == "tool.file.write": return lambda invocation: run_file_write( invocation, cwd=dependencies.resolve_cwd(invocation.session_id), - allowed_roots=(dependencies.cwd, *invocation.context.allowed_roots, *dependencies.additional_allowed_roots), + allowed_roots=( + dependencies.cwd, + *invocation.context.allowed_roots, + *dependencies.additional_allowed_roots, + ), ) if tool_id == "tool.file.patch": return lambda invocation: run_file_patch( invocation, cwd=dependencies.resolve_cwd(invocation.session_id), - allowed_roots=(dependencies.cwd, *invocation.context.allowed_roots, *dependencies.additional_allowed_roots), + allowed_roots=( + dependencies.cwd, + *invocation.context.allowed_roots, + *dependencies.additional_allowed_roots, + ), ) if tool_id == "tool.file.search": return lambda invocation: run_file_search( invocation, cwd=dependencies.resolve_cwd(invocation.session_id), - allowed_roots=(dependencies.cwd, *invocation.context.allowed_roots, *dependencies.additional_allowed_roots), + allowed_roots=( + dependencies.cwd, + *invocation.context.allowed_roots, + *dependencies.additional_allowed_roots, + ), ) if tool_id == "tool.web.search": return lambda invocation: run_web_search(invocation, user_agent=dependencies.web_user_agent) @@ -946,19 +1319,29 @@ def _handler_for_tool( if tool_id == "tool.web.extract": return lambda invocation: run_web_extract(invocation, user_agent=dependencies.web_user_agent) if tool_id.startswith("tool.browser."): - return lambda invocation: run_browser_action(invocation, backend=dependencies.browser_backend, vision_analyzer=dependencies.browser_vision_analyzer) + return lambda invocation: run_browser_action( + invocation, + backend=dependencies.browser_backend, + vision_analyzer=dependencies.browser_vision_analyzer, + ) if tool_id == "tool.clarify": return lambda invocation: run_clarify(invocation, surface=dependencies.clarify_surface) if tool_id == "tool.cron.manage": return lambda invocation: run_cron_action(invocation, runtime=dependencies.cron_runtime) if tool_id == "tool.personal_model.search": - return lambda invocation: run_personal_model_search(invocation, surface=dependencies.personal_model_understanding) + return lambda invocation: run_personal_model_search( + invocation, surface=dependencies.personal_model_understanding + ) if tool_id == "tool.conversation.search": return lambda invocation: run_conversation_search(invocation, surface=dependencies.personal_model_understanding) if tool_id == "tool.personal_model.update": - return lambda invocation: run_personal_model_update(invocation, surface=dependencies.personal_model_understanding) + return lambda invocation: run_personal_model_update( + invocation, surface=dependencies.personal_model_understanding + ) if tool_id == "tool.personal_model.questions": - return lambda invocation: run_personal_model_questions(invocation, surface=dependencies.personal_model_understanding) + return lambda invocation: run_personal_model_questions( + invocation, surface=dependencies.personal_model_understanding + ) if tool_id == "tool.code.execute": return lambda invocation: run_code_execute( invocation, @@ -982,4 +1365,11 @@ def _handler_for_tool( return lambda invocation: run_todo_action(invocation, store=dependencies.todo_store) return None -__all__ = ["BuiltinToolDependencies", "builtin_tool_definitions", "register_builtin_tools", "render_builtin_tool_reference_markdown", "render_builtin_tool_summary_markdown"] + +__all__ = [ + "BuiltinToolDependencies", + "builtin_tool_definitions", + "register_builtin_tools", + "render_builtin_tool_reference_markdown", + "render_builtin_tool_summary_markdown", +] diff --git a/packages/tools/builtins_skills.py b/packages/tools/builtins_skills.py index 7a96bf9..8f2091e 100644 --- a/packages/tools/builtins_skills.py +++ b/packages/tools/builtins_skills.py @@ -51,8 +51,14 @@ def skill_tool_definitions(*, reason: str | None) -> tuple[ToolDefinition, ...]: schema=_object_schema( required=("skill_id",), properties={ - "skill_id": {"type": "string", "description": "Installed skill id or local hub reference to inspect."}, - "reference": {"type": "string", "description": "Optional local hub reference if different from skill_id."}, + "skill_id": { + "type": "string", + "description": "Installed skill id or local hub reference to inspect.", + }, + "reference": { + "type": "string", + "description": "Optional local hub reference if different from skill_id.", + }, }, ), side_effects=ToolSideEffectMetadata( @@ -79,16 +85,48 @@ def skill_tool_definitions(*, reason: str | None) -> tuple[ToolDefinition, ...]: properties={ "action": { "type": "string", - "enum": ["install", "enable", "disable", "create", "update", "delete", "remove"], + "enum": [ + "install", + "enable", + "disable", + "create", + "update", + "delete", + "remove", + ], + }, + "skill_id": { + "type": "string", + "description": "Installed skill id, local hub reference, or authored skill id.", + }, + "reference": { + "type": "string", + "description": "Install source reference or path when action=install.", + }, + "display_name": { + "type": "string", + "description": "Authored skill title when action=create or update.", + }, + "summary": { + "type": "string", + "description": "One-line authored skill summary when action=create or update.", + }, + "instruction_text": { + "type": "string", + "description": "Full SKILL.md body when action=create or update.", + }, + "category": { + "type": "string", + "description": "Optional authored skill category bucket.", + }, + "install": { + "type": "boolean", + "description": "Whether to install an authored skill immediately after writing it.", + }, + "overwrite": { + "type": "boolean", + "description": "Whether action=create may overwrite an existing authored skill.", }, - "skill_id": {"type": "string", "description": "Installed skill id, local hub reference, or authored skill id."}, - "reference": {"type": "string", "description": "Install source reference or path when action=install."}, - "display_name": {"type": "string", "description": "Authored skill title when action=create or update."}, - "summary": {"type": "string", "description": "One-line authored skill summary when action=create or update."}, - "instruction_text": {"type": "string", "description": "Full SKILL.md body when action=create or update."}, - "category": {"type": "string", "description": "Optional authored skill category bucket."}, - "install": {"type": "boolean", "description": "Whether to install an authored skill immediately after writing it."}, - "overwrite": {"type": "boolean", "description": "Whether action=create may overwrite an existing authored skill."}, }, ), side_effects=ToolSideEffectMetadata( diff --git a/packages/tools/builtins_sub_agents.py b/packages/tools/builtins_sub_agents.py index 931b55e..0c920c1 100644 --- a/packages/tools/builtins_sub_agents.py +++ b/packages/tools/builtins_sub_agents.py @@ -23,37 +23,74 @@ def sub_agents_tool_definitions(*, reason: str | None) -> tuple[ToolDefinition, properties={ "action": { "type": "string", - "enum": ["run", "start", "status", "check", "join", "wait", "list"], + "enum": [ + "run", + "start", + "status", + "check", + "join", + "wait", + "list", + ], "description": "Whether to run synchronously, start in the background, inspect, wait for, or list sub-agent runs.", }, - "run_id": {"type": "string", "description": "Background sub-agent run id returned by action=start."}, + "run_id": { + "type": "string", + "description": "Background sub-agent run id returned by action=start.", + }, "sub_agent_run_id": { "type": "string", "description": "Alias for run_id when checking or joining a background sub-agent run.", }, - "name": {"type": "string", "description": "Optional label for a single sub-agent task."}, - "task": {"type": "string", "description": "Single assignment. Mutually exclusive with tasks."}, - "prompt": {"type": "string", "description": "Alias for task. Mutually exclusive with tasks."}, + "name": { + "type": "string", + "description": "Optional label for a single sub-agent task.", + }, + "task": { + "type": "string", + "description": "Single assignment. Mutually exclusive with tasks.", + }, + "prompt": { + "type": "string", + "description": "Alias for task. Mutually exclusive with tasks.", + }, "tasks": { "type": "array", "description": "Small parallel pool of assignments. Mutually exclusive with top-level task/prompt. Each child cannot call tool.sub_agents, tool.clarify, or tool.message.send.", "items": _object_schema( required=("task",), properties={ - "name": {"type": "string", "description": "Optional label for this sub-agent task."}, - "task": {"type": "string", "description": "Assignment for this sub-agent task."}, - "prompt": {"type": "string", "description": "Alias for task in a task-list item."}, + "name": { + "type": "string", + "description": "Optional label for this sub-agent task.", + }, + "task": { + "type": "string", + "description": "Assignment for this sub-agent task.", + }, + "prompt": { + "type": "string", + "description": "Alias for task in a task-list item.", + }, "skills": { "oneOf": [ {"type": "array", "items": {"type": "string"}}, {"type": "string"}, - {"type": "object", "additionalProperties": {"type": "boolean"}}, + { + "type": "object", + "additionalProperties": {"type": "boolean"}, + }, ], }, }, ), }, - "max_concurrency": {"type": "integer", "minimum": 1, "maximum": 3, "description": "Maximum parallel child tasks for tasks; default is 3."}, + "max_concurrency": { + "type": "integer", + "minimum": 1, + "maximum": 3, + "description": "Maximum parallel child tasks for tasks; default is 3.", + }, "timeout_seconds": { "type": "integer", "minimum": 0, @@ -64,7 +101,10 @@ def sub_agents_tool_definitions(*, reason: str | None) -> tuple[ToolDefinition, "oneOf": [ {"type": "array", "items": {"type": "string"}}, {"type": "string"}, - {"type": "object", "additionalProperties": {"type": "boolean"}}, + { + "type": "object", + "additionalProperties": {"type": "boolean"}, + }, ], "description": "Skill ids to load for a single top-level task.", }, diff --git a/packages/tools/factory.py b/packages/tools/factory.py index 528f54e..b96b330 100644 --- a/packages/tools/factory.py +++ b/packages/tools/factory.py @@ -9,7 +9,12 @@ from packages.security import SecurityPolicy from .builtins import register_builtin_tools -from .runtime import ApprovalGateway, SecurityApprovalGateway, ToolContextResolver, ToolRuntime +from .runtime import ( + ApprovalGateway, + SecurityApprovalGateway, + ToolContextResolver, + ToolRuntime, +) from .surfaces import BuiltinToolDependencies diff --git a/packages/tools/handlers_code_execution.py b/packages/tools/handlers_code_execution.py index 160adfa..e5cb402 100644 --- a/packages/tools/handlers_code_execution.py +++ b/packages/tools/handlers_code_execution.py @@ -157,7 +157,10 @@ def _run_code_subprocess( timeout_seconds=timeout_seconds, ) tool_call_count = 0 - with stdout_path.open("wb") as stdout_file, stderr_path.open("wb") as stderr_file: + with ( + stdout_path.open("wb") as stdout_file, + stderr_path.open("wb") as stderr_file, + ): process = subprocess.Popen( [child_python, str(runner_path)], cwd=child_cwd, @@ -215,8 +218,28 @@ def _code_subprocess_env( response_dir: Path, timeout_seconds: int, ) -> dict[str, str]: - safe_prefixes = ("PATH", "HOME", "USER", "LANG", "LC_", "TERM", "TMP", "TEMP", "SHELL", "VIRTUAL_ENV", "CONDA") - secret_fragments = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL", "PASSWD", "AUTH") + safe_prefixes = ( + "PATH", + "HOME", + "USER", + "LANG", + "LC_", + "TERM", + "TMP", + "TEMP", + "SHELL", + "VIRTUAL_ENV", + "CONDA", + ) + secret_fragments = ( + "KEY", + "TOKEN", + "SECRET", + "PASSWORD", + "CREDENTIAL", + "PASSWD", + "AUTH", + ) env: dict[str, str] = {} for key, value in os.environ.items(): if any(fragment in key.upper() for fragment in secret_fragments): @@ -267,7 +290,11 @@ def _code_child_python(*, mode: str) -> str: def _is_usable_code_python(candidate: Path) -> bool: try: completed = subprocess.run( - [str(candidate), "-c", "import sys; raise SystemExit(0 if sys.version_info >= (3, 8) else 1)"], + [ + str(candidate), + "-c", + "import sys; raise SystemExit(0 if sys.version_info >= (3, 8) else 1)", + ], capture_output=True, timeout=5, ) @@ -292,12 +319,21 @@ def _serve_code_tool_requests( tool_id = str(request.get("tool_id") or "") arguments = request.get("arguments") if isinstance(request.get("arguments"), Mapping) else {} if tool_id not in allowlist: - response: Mapping[str, Any] = {"ok": False, "error": f"tool RPC is not allowed for {tool_id}"} + response: Mapping[str, Any] = { + "ok": False, + "error": f"tool RPC is not allowed for {tool_id}", + } elif tool_call_count >= MAX_CODE_TOOL_CALLS: - response = {"ok": False, "error": f"tool.code.execute exceeded {MAX_CODE_TOOL_CALLS} nested tool calls"} + response = { + "ok": False, + "error": f"tool.code.execute exceeded {MAX_CODE_TOOL_CALLS} nested tool calls", + } elif tool_id == "tool.terminal.exec" and _blocked_code_terminal_arguments(arguments): blocked = ", ".join(sorted(_blocked_code_terminal_arguments(arguments))) - response = {"ok": False, "error": f"tool.code.execute does not allow tool.terminal.exec arguments: {blocked}"} + response = { + "ok": False, + "error": f"tool.code.execute does not allow tool.terminal.exec arguments: {blocked}", + } else: tool_call_count += 1 result = runtime.invoke( @@ -335,7 +371,7 @@ def _write_code_response(path: Path, payload: Mapping[str, Any]) -> None: def _code_runner_source() -> str: - return r''' + return r""" from __future__ import annotations import importlib @@ -454,7 +490,7 @@ def tool(tool_id, arguments=None, /, **kwargs): exec(compile(source, "", "exec"), namespace, locals_dict) if "result" in locals_dict: print("__ELEPHANT_RESULT_JSON__=" + json.dumps(locals_dict["result"], ensure_ascii=False, default=repr)) -'''.lstrip() +""".lstrip() def _read_limited_text(path: Path, *, limit: int) -> str: diff --git a/packages/tools/handlers_continuity.py b/packages/tools/handlers_continuity.py index 6711cff..dc746f9 100644 --- a/packages/tools/handlers_continuity.py +++ b/packages/tools/handlers_continuity.py @@ -8,7 +8,6 @@ from packages.contracts.runtime import ExecutionResult from packages.cron import CronRuntime from .handler_support import ( - coerce_int, optional_string, tool_summary, ) @@ -47,7 +46,11 @@ def run_todo_action( status=_normalize_todo_status(invocation.arguments.get("status")), notes=str(invocation.arguments.get("notes") or ""), ) - return tool_summary(invocation, f"created: {_todo_line(item)}", side_effects=("todo", "scratchpad")) + return tool_summary( + invocation, + f"created: {_todo_line(item)}", + side_effects=("todo", "scratchpad"), + ) if action == "clear": removed = store.clear(session_id) return tool_summary(invocation, f"cleared: {removed}", side_effects=("todo", "scratchpad")) @@ -66,7 +69,10 @@ def run_todo_action( status = { "complete": "done", "reopen": "open", - }.get(action, _normalize_todo_status(invocation.arguments.get("status"), default=current.status)) + }.get( + action, + _normalize_todo_status(invocation.arguments.get("status"), default=current.status), + ) item = store.upsert_item( session_id, item_id=item_id, @@ -75,10 +81,18 @@ def run_todo_action( notes=optional_string(invocation.arguments.get("notes")) or current.notes, work_item_id=current.work_item_id, ) - return tool_summary(invocation, f"updated: {_todo_line(item)}", side_effects=("todo", "scratchpad")) + return tool_summary( + invocation, + f"updated: {_todo_line(item)}", + side_effects=("todo", "scratchpad"), + ) if action in {"remove", "delete"}: removed = store.remove_item(session_id, item_id) - return tool_summary(invocation, f"removed: {_todo_line(removed)}", side_effects=("todo", "scratchpad")) + return tool_summary( + invocation, + f"removed: {_todo_line(removed)}", + side_effects=("todo", "scratchpad"), + ) raise ValueError(f"tool.todo.manage does not support action={action!r}") @@ -93,10 +107,12 @@ def run_cron_action(invocation: ToolInvocation, *, runtime: CronRuntime | None) profile_id=optional_string(invocation.arguments.get("profile_id")), elephant_id=optional_string(invocation.arguments.get("elephant_id")), ) - summary = "\n".join( - f"{job.job_id} | {job.status} | {job.name} | {job.schedule_text} | {job.action_kind}" - for job in jobs - ) or "" + summary = ( + "\n".join( + f"{job.job_id} | {job.status} | {job.name} | {job.schedule_text} | {job.action_kind}" for job in jobs + ) + or "" + ) return ExecutionResult( execution_id=invocation.invocation_id, episode_id=invocation.session_id, @@ -121,7 +137,9 @@ def run_cron_action(invocation: ToolInvocation, *, runtime: CronRuntime | None) schedule_text=schedule, payload=payload, profile_id=optional_string(invocation.arguments.get("profile_id")), - elephant_id=optional_string(invocation.arguments.get("elephant_id")) or invocation.context.elephant_id or None, + elephant_id=optional_string(invocation.arguments.get("elephant_id")) + or invocation.context.elephant_id + or None, ) return ExecutionResult( execution_id=invocation.invocation_id, diff --git a/packages/tools/handlers_filesystem.py b/packages/tools/handlers_filesystem.py index a5c6f43..ef8c7a0 100644 --- a/packages/tools/handlers_filesystem.py +++ b/packages/tools/handlers_filesystem.py @@ -4,16 +4,12 @@ import difflib from collections.abc import Mapping -from contextlib import contextmanager, redirect_stdout -import io import os from pathlib import Path import select import shutil -import signal import subprocess import sys -import threading import time from typing import Any @@ -151,7 +147,10 @@ def run_terminal_exec( command = str(invocation.arguments.get("command") or "").strip() if not command: raise ValueError("tool.terminal.exec requires a 'command' argument") - allowed_roots = (*invocation.context.allowed_roots, *dependencies.additional_allowed_roots) + allowed_roots = ( + *invocation.context.allowed_roots, + *dependencies.additional_allowed_roots, + ) local_root = dependencies.resolve_cwd(invocation.session_id) cwd = resolve_allowed_path( local_root, @@ -213,7 +212,13 @@ def run_process_action(invocation: ToolInvocation, *, manager: InMemoryProcessMa if action == "wait": managed = manager.wait( process_id, - timeout_seconds=max(1, min(coerce_int(invocation.arguments.get("timeout_seconds"), default=20), 120)), + timeout_seconds=max( + 1, + min( + coerce_int(invocation.arguments.get("timeout_seconds"), default=20), + 120, + ), + ), ) return tool_summary(invocation, _process_summary(managed), side_effects=("process",)) if action == "write": @@ -252,7 +257,13 @@ def run_file_read( _ensure_text_readable(path, raw_path=raw_path) content = path.read_text(encoding="utf-8", errors="replace") offset = max(1, coerce_int(invocation.arguments.get("offset"), default=1)) - limit = max(1, min(coerce_int(invocation.arguments.get("limit"), default=MAX_FILE_READ_LINES), MAX_FILE_READ_LIMIT)) + limit = max( + 1, + min( + coerce_int(invocation.arguments.get("limit"), default=MAX_FILE_READ_LINES), + MAX_FILE_READ_LIMIT, + ), + ) lines = content.splitlines() end_line = min(len(lines), offset + limit - 1) selected = lines[offset - 1 : end_line] @@ -262,9 +273,7 @@ def run_file_read( f"tool.file.read selected {selected_chars:,} characters, above the " f"{MAX_FILE_READ_CHARS:,} character limit; use a smaller offset/limit window" ) - numbered = "\n".join( - f"{index}|{_truncate_line(line)}" for index, line in enumerate(selected, start=offset) - ) + numbered = "\n".join(f"{index}|{_truncate_line(line)}" for index, line in enumerate(selected, start=offset)) truncated = end_line < len(lines) header = [ f"path: {path}", @@ -637,7 +646,12 @@ def _run_v4a_patch( diffs.append(_unified_diff(old_content, new_content, path)) else: raise ValueError(f"unsupported patch operation: {op}") - lint_lines = tuple(filter(None, (_lint_after_write(path) for path in (*modified, *created) if path.exists()))) + lint_lines = tuple( + filter( + None, + (_lint_after_write(path) for path in (*modified, *created) if path.exists()), + ) + ) lines = [ "mode: patch", f"files_modified: {', '.join(str(path) for path in modified) or ''}", @@ -665,7 +679,12 @@ def _plan_unified_diff_patch( is_add = old_path == "/dev/null" is_delete = new_path == "/dev/null" raw_path = new_path if not is_delete else old_path - path = resolve_allowed_path(cwd, _strip_diff_path(raw_path), must_exist=not is_add, allowed_roots=allowed_roots) + path = resolve_allowed_path( + cwd, + _strip_diff_path(raw_path), + must_exist=not is_add, + allowed_roots=allowed_roots, + ) _ensure_safe_write_path(path) if is_add: if path.exists(): @@ -713,7 +732,12 @@ def _apply_unified_diff_changes(invocation: ToolInvocation, changes: list[dict[s else: modified.append(path) diffs.append(_unified_diff(old_content, new_content, path)) - lint_lines = tuple(filter(None, (_lint_after_write(path) for path in (*modified, *created) if path.exists()))) + lint_lines = tuple( + filter( + None, + (_lint_after_write(path) for path in (*modified, *created) if path.exists()), + ) + ) lines = [ "mode: patch", "format: unified-diff", @@ -863,7 +887,11 @@ def _parse_v4a_patch(patch_text: str) -> list[dict[str, Any]]: if line.startswith("*** Add File: "): if current is not None: operations.append(current) - current = {"op": "add", "path": line.removeprefix("*** Add File: ").strip(), "new_lines": []} + current = { + "op": "add", + "path": line.removeprefix("*** Add File: ").strip(), + "new_lines": [], + } continue if line.startswith("*** Delete File: "): if current is not None: diff --git a/packages/tools/handlers_network.py b/packages/tools/handlers_network.py index d6d7cad..3c4ceb6 100644 --- a/packages/tools/handlers_network.py +++ b/packages/tools/handlers_network.py @@ -22,7 +22,12 @@ truncate, ) from .runtime import ToolInvocation -from .surfaces import BrowserToolBackend, BrowserVisionAnalyzer, ClarifySurface, MessageDeliverySurface +from .surfaces import ( + BrowserToolBackend, + BrowserVisionAnalyzer, + ClarifySurface, + MessageDeliverySurface, +) _WEB_SEARCH_SUMMARY_LIMIT = 4_800 @@ -144,14 +149,18 @@ def run_web_search(invocation: ToolInvocation, *, user_agent: str) -> Mapping[st fallback_lines = _run_duckduckgo_instant_answer(candidate, user_agent=user_agent, limit=limit) except Exception as error: if search_error is not None: - raise RuntimeError(f"web search failed after HTML and fallback attempts: {search_error}; {error}") from error + raise RuntimeError( + f"web search failed after HTML and fallback attempts: {search_error}; {error}" + ) from error raise if fallback_lines: fallback_query = candidate break if search_error is not None and not fallback_lines: raise RuntimeError(f"web search failed: {search_error}") from search_error - summary_lines = [f"search: {fallback_query}", *fallback_lines] if fallback_lines else [f"no web results for query: {query}"] + summary_lines = ( + [f"search: {fallback_query}", *fallback_lines] if fallback_lines else [f"no web results for query: {query}"] + ) return tool_summary( invocation, truncate("\n".join(summary_lines), limit=_WEB_SEARCH_SUMMARY_LIMIT), diff --git a/packages/tools/handlers_personal_model.py b/packages/tools/handlers_personal_model.py index 6c85df7..187c3f0 100644 --- a/packages/tools/handlers_personal_model.py +++ b/packages/tools/handlers_personal_model.py @@ -7,7 +7,10 @@ from datetime import datetime, timezone from typing import Any -from packages.understanding.personal_model_governance import ensure_valid_facet, is_protected_topic +from packages.understanding.personal_model_governance import ( + ensure_valid_facet, + is_protected_topic, +) from .handler_support import coerce_bool, coerce_int, optional_string, tool_summary from .runtime import ToolInvocation @@ -96,13 +99,11 @@ def _check_topic_duplicate( if existing_topic == topic: fact_text = str(getattr(fact, "text", "") or "").strip() # Detect contradiction: same topic but different content - is_contradiction = ( - new_text - and fact_text - and new_text.strip().lower() != fact_text.strip().lower() - ) + is_contradiction = new_text and fact_text and new_text.strip().lower() != fact_text.strip().lower() status_label = "contradiction" if is_contradiction else "duplicate_topic" - hint_action = "This appears to be updated information" if is_contradiction else "An active claim already exists" + hint_action = ( + "This appears to be updated information" if is_contradiction else "An active claim already exists" + ) return ( f"action: remember\n" f"status: {status_label}\n" @@ -145,7 +146,9 @@ def _lines_for_claims(result: Mapping[str, Any]) -> list[str]: status = str(claim.get("status") or "").strip() or "-" protected = str(claim.get("protected") or "").strip() protected_suffix = f" protected={protected}" if protected else "" - lines.append(f"- [{lens}/{topic}] ref={ref} status={status} policy={policy}{protected_suffix} updated={updated}") + lines.append( + f"- [{lens}/{topic}] ref={ref} status={status} policy={policy}{protected_suffix} updated={updated}" + ) lines.append(f" text: {text}") diagnostics = result.get("diagnostics") if isinstance(diagnostics, Mapping): @@ -168,7 +171,11 @@ def _lines_for_claims(result: Mapping[str, Any]) -> list[str]: lines.append( f"health_report: active={health.get('total_active_claims', 0)} retired={health.get('total_retired_claims', 0)} disputed={health.get('total_disputed_claims', 0)} topics={health.get('total_topics', 0)}" ) - for key in ("conflicting_claim_candidates", "review_claims_overdue", "cleanup_suggestions"): + for key in ( + "conflicting_claim_candidates", + "review_claims_overdue", + "cleanup_suggestions", + ): rows = tuple(health.get(key) or ()) if rows: lines.append(f"{key}: {len(rows)}") @@ -181,7 +188,9 @@ def _lines_for_claims(result: Mapping[str, Any]) -> list[str]: reason = str(item.get("relation_reason") or "").strip() suffix = f" relation={relation}" if relation else "" suffix += f" reason={reason}" if reason else "" - lines.append(f" - [{item.get('lens', '')}/{item.get('topic', '')}] {item.get('text', '')} ({item.get('ref', '')}){suffix}") + lines.append( + f" - [{item.get('lens', '')}/{item.get('topic', '')}] {item.get('text', '')} ({item.get('ref', '')}){suffix}" + ) return lines @@ -205,7 +214,8 @@ def run_personal_model_search( limit=max(1, min(coerce_int(invocation.arguments.get("limit"), default=12), 30)), status=_resolve_search_status(optional_string(invocation.arguments.get("status"))), ref=optional_string(invocation.arguments.get("ref")) or "", - personal_model_id=optional_string(invocation.arguments.get("personal_model_id")) or invocation.context.personal_model_id, + personal_model_id=optional_string(invocation.arguments.get("personal_model_id")) + or invocation.context.personal_model_id, mode=mode, ) return tool_summary( @@ -221,7 +231,9 @@ def _run_inventory_search( surface: PersonalModelUnderstandingSurface, ) -> Mapping[str, Any]: """Return lens→topic list with claim counts. No content returned.""" - personal_model_id = optional_string(invocation.arguments.get("personal_model_id")) or invocation.context.personal_model_id + personal_model_id = ( + optional_string(invocation.arguments.get("personal_model_id")) or invocation.context.personal_model_id + ) lens_filter = optional_string(invocation.arguments.get("lens")) or "" status_filter = _resolve_search_status(optional_string(invocation.arguments.get("status"))) pm_id = surface._personal_model_id(invocation.session_id, personal_model_id) # noqa: SLF001 @@ -237,6 +249,7 @@ def _run_inventory_search( facts = () # Group by lens → topic with count from collections import defaultdict + inventory: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) for fact in facts: fact_lens = str(getattr(fact, "lens", "") or "").strip() @@ -245,7 +258,12 @@ def _run_inventory_search( if fact_lens and topic: inventory[fact_lens][topic] += 1 # Format output - lines = [f"personal_model_id: {pm_id}", f"mode: inventory", f"status: {status_filter}", f"total_claims: {len(facts)}"] + lines = [ + f"personal_model_id: {pm_id}", + "mode: inventory", + f"status: {status_filter}", + f"total_claims: {len(facts)}", + ] lens_order = ["identity", "world", "pulse", "journey"] for lens in lens_order: topics = inventory.get(lens, {}) @@ -274,13 +292,33 @@ def _conversation_time_range(arguments: Mapping[str, Any]) -> Mapping[str, str]: raw = arguments.get("time_range") out: dict[str, str] = {} if isinstance(raw, Mapping): - for key in ("expr", "start_at", "end_at", "start", "end", "timezone", "tz", "search_start_at", "search_end_at"): + for key in ( + "expr", + "start_at", + "end_at", + "start", + "end", + "timezone", + "tz", + "search_start_at", + "search_end_at", + ): value = optional_string(raw.get(key)) if value: out[key] = value elif isinstance(raw, str) and raw.strip(): out["expr"] = raw.strip() - for key in ("expr", "start_at", "end_at", "start", "end", "timezone", "tz", "search_start_at", "search_end_at"): + for key in ( + "expr", + "start_at", + "end_at", + "start", + "end", + "timezone", + "tz", + "search_start_at", + "search_end_at", + ): value = optional_string(arguments.get(key)) if value: out[key] = value @@ -304,7 +342,8 @@ def run_conversation_search( preview=optional_string(invocation.arguments.get("preview")) or "anchors", view=optional_string(invocation.arguments.get("view")) or "conversation", limit=max(1, min(coerce_int(invocation.arguments.get("limit"), default=8), 30)), - personal_model_id=optional_string(invocation.arguments.get("personal_model_id")) or invocation.context.personal_model_id, + personal_model_id=optional_string(invocation.arguments.get("personal_model_id")) + or invocation.context.personal_model_id, include_current_episode=True, ) lines = [ @@ -326,7 +365,9 @@ def run_conversation_search( for item in ranges[:8]: if not isinstance(item, Mapping): continue - lines.append(f"- {item.get('range_id', '')} {item.get('start_at', '')}..{item.get('end_at', '')} score={item.get('score', 0)} count={item.get('count', 0)} by_kind={item.get('by_kind', {})}") + lines.append( + f"- {item.get('range_id', '')} {item.get('start_at', '')}..{item.get('end_at', '')} score={item.get('score', 0)} count={item.get('count', 0)} by_kind={item.get('by_kind', {})}" + ) time_range = item.get("time_range") if isinstance(time_range, Mapping) and time_range: lines.append( @@ -373,7 +414,8 @@ def run_personal_model_inspect( topic=optional_string(invocation.arguments.get("topic")) or "", query=optional_string(invocation.arguments.get("query")) or "", limit=max(1, min(coerce_int(invocation.arguments.get("limit"), default=5), 10)), - personal_model_id=optional_string(invocation.arguments.get("personal_model_id")) or invocation.context.personal_model_id, + personal_model_id=optional_string(invocation.arguments.get("personal_model_id")) + or invocation.context.personal_model_id, ) lines = [ f"personal_model_id: {result.get('personal_model_id', '')}", @@ -401,7 +443,8 @@ def run_personal_model_audit( action=optional_string(invocation.arguments.get("action")) or "health", lens=optional_string(invocation.arguments.get("lens")) or "", limit=max(1, min(coerce_int(invocation.arguments.get("limit"), default=30), 100)), - personal_model_id=optional_string(invocation.arguments.get("personal_model_id")) or invocation.context.personal_model_id, + personal_model_id=optional_string(invocation.arguments.get("personal_model_id")) + or invocation.context.personal_model_id, ) resolved_action = str(result.get("action", "") or "") lines = [f"action: {resolved_action}"] @@ -416,11 +459,21 @@ def run_personal_model_audit( lines.append( f"health_report: active={health.get('total_active_claims', 0)} retired={health.get('total_retired_claims', 0)} disputed={health.get('total_disputed_claims', 0)} topics={health.get('total_topics', 0)}" ) - for key in ("conflicting_claim_candidates", "review_claims_overdue", "current_claims_stale", "retired_chain_candidates", "cleanup_suggestions"): + for key in ( + "conflicting_claim_candidates", + "review_claims_overdue", + "current_claims_stale", + "retired_chain_candidates", + "cleanup_suggestions", + ): rows = tuple(health.get(key) or ()) if rows: lines.append(f"{key}: {len(rows)}") - if resolved_action == "stale" and not tuple(health.get("review_claims_overdue") or ()) and not tuple(health.get("current_claims_stale") or ()): + if ( + resolved_action == "stale" + and not tuple(health.get("review_claims_overdue") or ()) + and not tuple(health.get("current_claims_stale") or ()) + ): lines.append("stale: none") return tool_summary(invocation, "\n".join(lines), side_effects=("personal_model", "audit")) @@ -524,16 +577,16 @@ def run_personal_model_update( if action in {"remember", "correct"} and not text: raise ValueError(f"tool.personal_model.update action={action} requires 'text'") if source == "learned" and action in {"remember", "correct"} and _looks_like_internal_learning_artifact(text): - raise ValueError("learned Personal Model facts cannot store internal learning, validation, dashboard, or question-bank bookkeeping text") + raise ValueError( + "learned Personal Model facts cannot store internal learning, validation, dashboard, or question-bank bookkeeping text" + ) lens = optional_string(invocation.arguments.get("lens")) or "" topic = optional_string(invocation.arguments.get("topic")) or "" # Validate lens-prefixed topic format if topic and lens and action in {"remember", "correct"}: topic_parts = topic.split(".") if len(topic_parts) < 3: - raise ValueError( - f"topic must have at least 3 dot-separated segments (lens.domain.entity): {topic!r}" - ) + raise ValueError(f"topic must have at least 3 dot-separated segments (lens.domain.entity): {topic!r}") if topic_parts[0] != lens: raise ValueError( f"topic first segment must match lens: topic={topic!r} but lens={lens!r}. " @@ -545,7 +598,9 @@ def run_personal_model_update( reason = optional_string(invocation.arguments.get("reason")) or "" if action in {"delete", "restore"} and not ref: raise ValueError(f"tool.personal_model.update action={action} requires exact 'ref' from personal_model.search") - personal_model_id = optional_string(invocation.arguments.get("personal_model_id")) or invocation.context.personal_model_id + personal_model_id = ( + optional_string(invocation.arguments.get("personal_model_id")) or invocation.context.personal_model_id + ) if action == "delete": result = _delete_personal_model_claim( invocation, @@ -559,7 +614,14 @@ def run_personal_model_update( else: # Anti-duplication guard: warn if same topic already exists for remember if action == "remember" and topic and not ref: - duplicate_hint = _check_topic_duplicate(surface, invocation.session_id, personal_model_id, lens, topic, new_text=text) + duplicate_hint = _check_topic_duplicate( + surface, + invocation.session_id, + personal_model_id, + lens, + topic, + new_text=text, + ) if duplicate_hint: return tool_summary( invocation, @@ -606,7 +668,9 @@ def run_personal_model_update( reason = str(item.get("relation_reason") or "").strip() suffix = f" relation={relation}" if relation else "" suffix += f" reason={reason}" if reason else "" - lines.append(f"- [{item.get('lens', '')}/{item.get('topic', '')}] {item.get('text', '')} ({item.get('ref', '')}){suffix}") + lines.append( + f"- [{item.get('lens', '')}/{item.get('topic', '')}] {item.get('text', '')} ({item.get('ref', '')}){suffix}" + ) return tool_summary( invocation, "\n".join(lines), @@ -624,13 +688,22 @@ def run_personal_model_questions( result = surface.manage_personal_model_questions( invocation.session_id, action=optional_string(invocation.arguments.get("action")) or "", - personal_model_id=optional_string(invocation.arguments.get("personal_model_id")) or invocation.context.personal_model_id, - question_id=optional_string(invocation.arguments.get("question_id")) or optional_string(invocation.arguments.get("ref")) or "", + personal_model_id=optional_string(invocation.arguments.get("personal_model_id")) + or invocation.context.personal_model_id, + question_id=optional_string(invocation.arguments.get("question_id")) + or optional_string(invocation.arguments.get("ref")) + or "", status=optional_string(invocation.arguments.get("status")) or "", lens=optional_string(invocation.arguments.get("lens")) or "", - sub_lens=optional_string(invocation.arguments.get("topic")) or optional_string(invocation.arguments.get("sub_lens")) or "", - text=optional_string(invocation.arguments.get("text")) or optional_string(invocation.arguments.get("question")) or "", - rationale=optional_string(invocation.arguments.get("reason")) or optional_string(invocation.arguments.get("rationale")) or "", + sub_lens=optional_string(invocation.arguments.get("topic")) + or optional_string(invocation.arguments.get("sub_lens")) + or "", + text=optional_string(invocation.arguments.get("text")) + or optional_string(invocation.arguments.get("question")) + or "", + rationale=optional_string(invocation.arguments.get("reason")) + or optional_string(invocation.arguments.get("rationale")) + or "", priority=invocation.arguments.get("priority"), sensitivity=optional_string(invocation.arguments.get("sensitivity")) or "", source=optional_string(invocation.arguments.get("source")) or "contextual", @@ -650,7 +723,9 @@ def run_personal_model_questions( lines.append(f"questions: {len(questions)}") for question in questions[:5]: if isinstance(question, Mapping): - lines.append(f"- [{question.get('lens', '')}/{question.get('sub_lens', '')}] {question.get('text', '')}") + lines.append( + f"- [{question.get('lens', '')}/{question.get('sub_lens', '')}] {question.get('text', '')}" + ) question = result.get("question") if isinstance(question, Mapping): lines.append(f"question_id: {question.get('question_id', '')}") diff --git a/packages/tools/handlers_skills.py b/packages/tools/handlers_skills.py index e6f182e..d6310b0 100644 --- a/packages/tools/handlers_skills.py +++ b/packages/tools/handlers_skills.py @@ -174,9 +174,7 @@ def run_skill_manage( side_effects=("skill", "delete"), ) ) - raise ValueError( - "tool.skill.manage requires action=install|enable|disable|create|update|delete" - ) + raise ValueError("tool.skill.manage requires action=install|enable|disable|create|update|delete") def _required_field(invocation: ToolInvocation, name: str) -> str: diff --git a/packages/tools/handlers_sub_agents.py b/packages/tools/handlers_sub_agents.py index ed8f78d..3110897 100644 --- a/packages/tools/handlers_sub_agents.py +++ b/packages/tools/handlers_sub_agents.py @@ -29,7 +29,11 @@ def run_sub_agents_action( wait_timeout_seconds = None if action in {"join", "wait"}: wait_timeout_seconds = float( - _int_value(invocation.arguments.get("timeout_seconds"), default=3600, name="timeout_seconds") + _int_value( + invocation.arguments.get("timeout_seconds"), + default=3600, + name="timeout_seconds", + ) ) result = surface.inspect_sub_agent_run( session_id=invocation.session_id, @@ -62,7 +66,11 @@ def run_sub_agents_action( result = runner( session_id=invocation.session_id, tasks=tasks, - max_concurrency=_int_value(invocation.arguments.get("max_concurrency"), default=3, name="max_concurrency"), + max_concurrency=_int_value( + invocation.arguments.get("max_concurrency"), + default=3, + name="max_concurrency", + ), ) if isinstance(result, ExecutionResult): return result diff --git a/packages/tools/mcp.py b/packages/tools/mcp.py index 923911c..5fb23a6 100644 --- a/packages/tools/mcp.py +++ b/packages/tools/mcp.py @@ -11,7 +11,14 @@ from packages.contracts.runtime import ExecutionResult -from .runtime import ToolAvailability, ToolDefinition, ToolHandler, ToolInvocation, ToolRuntime, ToolSideEffectMetadata +from .runtime import ( + ToolAvailability, + ToolDefinition, + ToolHandler, + ToolInvocation, + ToolRuntime, + ToolSideEffectMetadata, +) _MCP_TOOL_VERSION = "1.0.0" _MCP_TOOL_KIND = "custom-mcp" @@ -33,11 +40,7 @@ def sync_custom_mcp_tools( for definition, handler in custom_mcp_runtime_entries(config_path=config_path, config=config, cwd=cwd): desired[definition.tool_id] = (definition, handler) - existing_custom_ids = { - tool.tool_id - for tool in runtime.list_tools() - if _is_custom_mcp_tool(tool) - } + existing_custom_ids = {tool.tool_id for tool in runtime.list_tools() if _is_custom_mcp_tool(tool)} for stale_tool_id in sorted(existing_custom_ids - set(desired)): runtime.unregister_tool(stale_tool_id) for tool_id, (definition, handler) in desired.items(): @@ -56,7 +59,12 @@ def custom_mcp_runtime_entries( overrides = _mapping_rows(config.get("mcp_overrides")) entries: list[tuple[ToolDefinition, ToolHandler]] = [] for server_id, server in sorted(_mapping_rows(config.get("mcp_servers")).items()): - transport = str(server.get("transport") or ("http" if str(server.get("url") or "").strip() else "stdio")).strip().lower() or "stdio" + transport = ( + str(server.get("transport") or ("http" if str(server.get("url") or "").strip() else "stdio")) + .strip() + .lower() + or "stdio" + ) label = str(server.get("label") or server_id).strip() or server_id command = str(server.get("command") or "").strip() url = str(server.get("url") or "").strip() @@ -221,9 +229,7 @@ def _handler(invocation: ToolInvocation) -> ExecutionResult: execution_id=invocation.invocation_id, episode_id=invocation.session_id, outcome="failed", - summary=( - f"MCP tool {server_id}.{tool_name} timed out after {_MCP_CALL_TIMEOUT_MS}ms" - ), + summary=(f"MCP tool {server_id}.{tool_name} timed out after {_MCP_CALL_TIMEOUT_MS}ms"), side_effects=("mcp", f"server={server_id}", f"transport={transport}"), ) except OSError as exc: diff --git a/packages/tools/runtime.py b/packages/tools/runtime.py index a4b55b8..2c6dfc1 100644 --- a/packages/tools/runtime.py +++ b/packages/tools/runtime.py @@ -16,7 +16,13 @@ from packages.capabilities.runtime import CapabilityDescriptor, ToolCapability from packages.contracts.runtime import ExecutionResult -from packages.security import ApprovalClass, PolicyDecision, SecurityPolicy, SecurityRequest, evaluate_with_telemetry +from packages.security import ( + ApprovalClass, + PolicyDecision, + SecurityPolicy, + SecurityRequest, + evaluate_with_telemetry, +) from .local_roots import default_local_allowed_roots @@ -150,20 +156,16 @@ def build_tool_fallback_prompt(tools: tuple[ToolDefinition, ...]) -> str: "Prefer claim refs for correct/forget/dispute when the target is uncertain; restore must use an exact ref from status=all search. " "Use updated claims naturally without narrating storage mechanics unless asked." if has_personal_model_update - else - "Durable user understanding changes need Personal Model update tooling, but it is unavailable. State the " + else "Durable user understanding changes need Personal Model update tooling, but it is unavailable. State the " "intended durable update clearly without pretending it was stored." ) - tool_lines = "; ".join( - f"{tool.display_name} ({tool.tool_id}): {tool.description}" - for tool in tools - ) + tool_lines = "; ".join(f"{tool.display_name} ({tool.tool_id}): {tool.description}" for tool in tools) summaries = " ".join(tool.prompt_summary() for tool in tools) return ( "available-tools: governed built-ins are available through the runtime; " f"{tool_lines}\n" "tool-call-protocol: call governed built-in tools directly when the active provider supports native " - "tool calling. Otherwise emit value" + 'tool calling. Otherwise emit value' "; multiple invoke blocks are allowed, structured values may be " "encoded as JSON inside a parameter body, and the final answer must not include raw tool markup.\n" "tool-usage-discipline: use tools only when they materially advance the current request. " @@ -356,11 +358,7 @@ def authorize( return ToolApprovalResult( decision="approved" if approved else "denied", risk_class=definition.side_effects.risk_class, - reason=( - "approved by callable approval gateway" - if approved - else "blocked by callable approval gateway" - ), + reason=("approved by callable approval gateway" if approved else "blocked by callable approval gateway"), ) @@ -517,9 +515,7 @@ def load_manifest(self, path: Path, loader: ToolLoader | None = None) -> ToolMan candidate = tool if existing is not None: if _tool_identity(existing) != _tool_identity(tool): - raise ValueError( - f"tool is already registered with different metadata: {tool.tool_id}" - ) + raise ValueError(f"tool is already registered with different metadata: {tool.tool_id}") candidate = replace(tool, enabled=existing.enabled) bound = self._register_tool(candidate) if bound: @@ -692,7 +688,11 @@ def invoke( if result.side_effects: final = result else: - final = replace(result, side_effects=definition.side_effects.categories, episode_id=session_id) + final = replace( + result, + side_effects=definition.side_effects.categories, + episode_id=session_id, + ) self._executions.append( ToolExecutionRecord( execution_id=final.execution_id, @@ -904,7 +904,9 @@ def _default_context(session_id: str, requester: ToolRequester | None) -> ToolRu ) -def _resolve_approval_class(side_effects: ToolSideEffectMetadata) -> ApprovalClass | None: +def _resolve_approval_class( + side_effects: ToolSideEffectMetadata, +) -> ApprovalClass | None: raw = side_effects.approval_class.strip().lower() if raw in {"", "none"}: return None diff --git a/packages/tools/schema_descriptions.py b/packages/tools/schema_descriptions.py index 4c8ffd1..9e45e2a 100644 --- a/packages/tools/schema_descriptions.py +++ b/packages/tools/schema_descriptions.py @@ -138,8 +138,7 @@ def _enrich_schema(tool_id: str, schema: Mapping[str, Any], path: tuple[str, ... properties = enriched.get("properties") if isinstance(properties, Mapping): enriched["properties"] = { - str(name): _enrich_property(tool_id, str(name), payload, path) - for name, payload in properties.items() + str(name): _enrich_property(tool_id, str(name), payload, path) for name, payload in properties.items() } return enriched diff --git a/packages/tools/surfaces.py b/packages/tools/surfaces.py index 0ddbd5c..5d886be 100644 --- a/packages/tools/surfaces.py +++ b/packages/tools/surfaces.py @@ -18,6 +18,7 @@ from .local_roots import default_local_allowed_roots from .runtime import ToolInvocation + class PersonalModelUnderstandingSurface(Protocol): def search_personal_model( self, @@ -100,7 +101,6 @@ def manage_personal_model_questions( """Manage proactive questions bound to a Personal Model lens/topic.""" - class BrowserVisionAnalyzer(Protocol): def analyze_browser_screenshot( self, diff --git a/packages/tools/tool_result_storage.py b/packages/tools/tool_result_storage.py index edd75c0..33de110 100644 --- a/packages/tools/tool_result_storage.py +++ b/packages/tools/tool_result_storage.py @@ -28,9 +28,7 @@ class ToolResultBudgetConfig: result_size_chars: int | float = DEFAULT_RESULT_SIZE_CHARS turn_budget_chars: int = DEFAULT_TURN_BUDGET_CHARS preview_size_chars: int = DEFAULT_PREVIEW_SIZE_CHARS - pinned_thresholds: Mapping[str, int | float] = field( - default_factory=lambda: dict(DEFAULT_PINNED_THRESHOLDS) - ) + pinned_thresholds: Mapping[str, int | float] = field(default_factory=lambda: dict(DEFAULT_PINNED_THRESHOLDS)) def maybe_persist_tool_result( diff --git a/packages/understanding/personal_model_governance.py b/packages/understanding/personal_model_governance.py index 5f418cb..cc78f9f 100644 --- a/packages/understanding/personal_model_governance.py +++ b/packages/understanding/personal_model_governance.py @@ -59,9 +59,9 @@ class ProtectedTopicPolicy: ALLOWED_FACETS: dict[str, frozenset[str]] = { "identity": frozenset({"anchor", "character", "values", "style", "body"}), - "world": frozenset({"people", "projects", "tools", "places", "assets", "skills"}), - "pulse": frozenset({"chapter", "focus", "mood", "blockers", "intent"}), - "journey": frozenset({"lessons", "patterns", "decisions", "milestones"}), + "world": frozenset({"people", "projects", "tools", "places", "assets", "skills"}), + "pulse": frozenset({"chapter", "focus", "mood", "blockers", "intent"}), + "journey": frozenset({"lessons", "patterns", "decisions", "milestones"}), } @@ -71,23 +71,28 @@ def ensure_valid_facet(lens: str, facet: str) -> None: return if facet not in allowed: raise ValueError( - f"topic second segment must be a fixed facet for lens {lens!r}: " - f"got {facet!r}, allowed {sorted(allowed)}" + f"topic second segment must be a fixed facet for lens {lens!r}: got {facet!r}, allowed {sorted(allowed)}" ) _SYSTEM_PROTECTED_TOPICS: dict[str, ProtectedTopicPolicy] = { # identity — who the person is (anchor, character, style, body) "identity.anchor.name.preferred": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "identity"), - "identity.anchor.gender.self_description": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "identity"), + "identity.anchor.gender.self_description": ProtectedTopicPolicy( + "system", "init_core_profile", "core_prompt", "identity" + ), "identity.anchor.birth.date": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "identity"), "identity.anchor.age.current": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "identity"), "identity.character.mbti.type": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "identity"), "identity.character.rhythm.pressure": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "rhythm"), "identity.character.rhythm.recovery": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "rhythm"), "identity.character.decision.compass": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "rhythm"), - "identity.style.language.first": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "communication"), - "identity.style.companion.posture": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "collaboration"), + "identity.style.language.first": ProtectedTopicPolicy( + "system", "init_core_profile", "core_prompt", "communication" + ), + "identity.style.companion.posture": ProtectedTopicPolicy( + "system", "init_core_profile", "core_prompt", "collaboration" + ), "identity.style.hobbies.personal": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "preference"), "identity.body.safety.boundary": ProtectedTopicPolicy("system", "init_core_profile", "core_prompt", "safety"), # world — what is around the person (people, projects, tools, places, assets) @@ -175,7 +180,13 @@ def parse_topic_path(topic: object) -> TopicPath | None: return None if not all(_TOPIC_SEGMENT_RE.fullmatch(part) for part in parts): return None - return TopicPath(raw=normalized, domain=parts[0], entity=parts[1], aspect=parts[2], qualifier=parts[3:]) + return TopicPath( + raw=normalized, + domain=parts[0], + entity=parts[1], + aspect=parts[2], + qualifier=parts[3:], + ) def valid_topic_key(topic: object) -> str: @@ -226,7 +237,12 @@ def topic_relation_weight(left: object, right: object) -> float: def policy_for_topic(topic: object) -> TopicPolicy: if parse_topic_path(topic) is None: return TopicPolicy() - return TopicPolicy(active_cardinality="single", default_recall_policy="review", review_after_days=14, projection_visible=True) + return TopicPolicy( + active_cardinality="single", + default_recall_policy="review", + review_after_days=14, + projection_visible=True, + ) def is_single_active_topic(topic: object) -> bool: @@ -269,12 +285,13 @@ def is_skill_affinity_topic(topic: object) -> bool: resolved = valid_topic_key(topic) return resolved.startswith("world.skills.affinity.") or resolved.startswith("skills.affinity.") + def skill_affinity_index_id(topic: object) -> str: resolved = valid_topic_key(topic) if resolved.startswith("world.skills.affinity."): - return resolved[len("world.skills.affinity."):] + return resolved[len("world.skills.affinity.") :] if resolved.startswith("skills.affinity."): - return resolved[len("skills.affinity."):] + return resolved[len("skills.affinity.") :] return "" @@ -320,13 +337,28 @@ def relation_payload( fact_numbers = set(numeric_mentions(fact.text)) numeric_conflict = bool(source_numbers and fact_numbers and not source_numbers.issubset(fact_numbers)) if relation == TopicRelation.SAME_TOPIC: - scope, matched_by, reason, score = "same_topic", "topic_path", "same topic path", 1.0 + scope, matched_by, reason, score = ( + "same_topic", + "topic_path", + "same topic path", + 1.0, + ) elif relation == TopicRelation.SAME_ENTITY: entity = topic_entity_key(source_topic) - scope, matched_by, reason, score = "same_entity", "topic_path", f"same topic entity {entity}", 0.75 + scope, matched_by, reason, score = ( + "same_entity", + "topic_path", + f"same topic entity {entity}", + 0.75, + ) elif relation == TopicRelation.SAME_DOMAIN: domain = topic_prefix(source_topic) - scope, matched_by, reason, score = "same_domain", "topic_path", f"same topic domain {domain}", 0.35 + scope, matched_by, reason, score = ( + "same_domain", + "topic_path", + f"same topic domain {domain}", + 0.35, + ) elif numeric_conflict and overlap >= 0.35: scope, matched_by, reason, score = ( "numeric_conflict", @@ -335,9 +367,19 @@ def relation_payload( max(0.72, overlap), ) elif overlap >= 0.45: - scope, matched_by, reason, score = "text_overlap", "claim_text_overlap", "claim text overlaps with selected claim", overlap + scope, matched_by, reason, score = ( + "text_overlap", + "claim_text_overlap", + "claim text overlaps with selected claim", + overlap, + ) elif similarity >= 0.55 or (source_topic and (source_topic in fact_topic or fact_topic in source_topic)): - scope, matched_by, reason, score = "similar_topic", "topic_similarity", "topic keys are lexically similar", similarity + scope, matched_by, reason, score = ( + "similar_topic", + "topic_similarity", + "topic keys are lexically similar", + similarity, + ) else: return None return { @@ -418,7 +460,10 @@ def claim_payload(fact: Fact) -> dict[str, Any]: "review_after_days": metadata.get("review_after_days", ""), "protected": protection.protection if protection is not None else "", "protected_reason": protection.reason if protection is not None else "", - "projection_policy": metadata.get("projection_policy", protection.projection_policy if protection is not None else ""), + "projection_policy": metadata.get( + "projection_policy", + protection.projection_policy if protection is not None else "", + ), "facet": metadata.get("facet", protection.facet if protection is not None else ""), } @@ -481,7 +526,13 @@ def related_claims_for_selection( topic = clean((fact.metadata or {}).get("topic")) if not topic: continue - for item in similar_topic_payloads(facts, topic=topic, text=fact.text, exclude_refs=(fact.fact_id,), limit=limit): + for item in similar_topic_payloads( + facts, + topic=topic, + text=fact.text, + exclude_refs=(fact.fact_id,), + limit=limit, + ): ref = str(item.get("ref") or "") if not ref or ref in seen: continue @@ -502,7 +553,11 @@ def narrowing_suggestions( ) -> tuple[dict[str, str], ...]: if not selected: return () - topics = tuple(dict.fromkeys(clean((fact.metadata or {}).get("topic")) for fact in selected if clean((fact.metadata or {}).get("topic")))) + topics = tuple( + dict.fromkeys( + clean((fact.metadata or {}).get("topic")) for fact in selected if clean((fact.metadata or {}).get("topic")) + ) + ) lenses = tuple(dict.fromkeys(fact.lens for fact in selected if fact.lens)) ambiguous = len(selected) >= min(max(limit, 1), 5) or len(topics) >= 3 or len(lenses) >= 2 if not ambiguous: @@ -511,13 +566,29 @@ def narrowing_suggestions( if topics: reason += f" across {len(topics)} topics" suggestions = [ - {"reason": reason, "suggestion": "retry with topic or ref when locating one known claim"}, - {"reason": "verification needs a precise target", "suggestion": "retry tool.personal_model.search with an exact topic, ref, or claim phrase"}, + { + "reason": reason, + "suggestion": "retry with topic or ref when locating one known claim", + }, + { + "reason": "verification needs a precise target", + "suggestion": "retry tool.personal_model.search with an exact topic, ref, or claim phrase", + }, ] if not lens and lenses: - suggestions.append({"reason": f"matches span lenses: {', '.join(lenses[:4])}", "suggestion": "add lens to constrain the owner surface"}) + suggestions.append( + { + "reason": f"matches span lenses: {', '.join(lenses[:4])}", + "suggestion": "add lens to constrain the owner surface", + } + ) if not topic and topics: - suggestions.append({"reason": "multiple topic keys matched", "suggestion": f"add topic, e.g. {', '.join(topics[:4])}"}) + suggestions.append( + { + "reason": "multiple topic keys matched", + "suggestion": f"add topic, e.g. {', '.join(topics[:4])}", + } + ) return tuple(suggestions) @@ -551,8 +622,14 @@ def _numeric_conflict_payloads(facts: tuple[Fact, ...]) -> tuple[dict[str, Any], "lens": lens, "topic_key": topic_key, "refs": (left.fact_id, right.fact_id), - "topics": (clean((left.metadata or {}).get("topic")), clean((right.metadata or {}).get("topic"))), - "values": (tuple(sorted(left_numbers)), tuple(sorted(right_numbers))), + "topics": ( + clean((left.metadata or {}).get("topic")), + clean((right.metadata or {}).get("topic")), + ), + "values": ( + tuple(sorted(left_numbers)), + tuple(sorted(right_numbers)), + ), "reason": "active claims share a topic key but contain different numeric values", } ) @@ -595,7 +672,9 @@ def personal_model_health_report(facts: tuple[Fact, ...], *, now: datetime | Non without_policy.append({"ref": fact.fact_id, "lens": fact.lens, "topic": topic}) if not reason: without_reason.append({"ref": fact.fact_id, "lens": fact.lens, "topic": topic}) - verified = _parse_datetime_value(metadata.get("last_verified_at") or metadata.get("verified_at")) or fact.committed_at + verified = ( + _parse_datetime_value(metadata.get("last_verified_at") or metadata.get("verified_at")) or fact.committed_at + ) age_days = max(0, (current - verified).days) if policy == "review": try: @@ -603,18 +682,37 @@ def personal_model_health_report(facts: tuple[Fact, ...], *, now: datetime | Non except ValueError: review_days = 14 if age_days > review_days: - review_overdue.append({"ref": fact.fact_id, "lens": fact.lens, "topic": topic, "age_days": str(age_days)}) + review_overdue.append( + { + "ref": fact.fact_id, + "lens": fact.lens, + "topic": topic, + "age_days": str(age_days), + } + ) if policy == "current" and age_days > 30: - current_stale.append({"ref": fact.fact_id, "lens": fact.lens, "topic": topic, "age_days": str(age_days)}) + current_stale.append( + { + "ref": fact.fact_id, + "lens": fact.lens, + "topic": topic, + "age_days": str(age_days), + } + ) retired_chain_candidates = tuple( { "lens": lens, "topic": topic, "retired_count": str(len(bucket)), - "refs": tuple(fact.fact_id for fact in sorted(bucket, key=lambda item: item.committed_at, reverse=True)[:8]), + "refs": tuple( + fact.fact_id for fact in sorted(bucket, key=lambda item: item.committed_at, reverse=True)[:8] + ), "reason": "topic has a long retired chain; keep for audit, but review whether old links still need dashboard prominence", } - for (lens, topic), bucket in sorted(_facts_by_topic_key(retired).items(), key=lambda item: (-len(item[1]), item[0][0], item[0][1])) + for (lens, topic), bucket in sorted( + _facts_by_topic_key(retired).items(), + key=lambda item: (-len(item[1]), item[0][0], item[0][1]), + ) if len(bucket) >= 5 ) cleanup_suggestions = [ diff --git a/packages/understanding/runtime.py b/packages/understanding/runtime.py index 46c92c6..8ccdde1 100644 --- a/packages/understanding/runtime.py +++ b/packages/understanding/runtime.py @@ -4,6 +4,7 @@ active four-lens claims, evidence summaries, and question rows. Free-form notes are evidence, not a Personal Model write surface. """ + from __future__ import annotations from collections.abc import Mapping from dataclasses import replace @@ -12,10 +13,25 @@ import unicodedata from typing import Any from packages.contracts import ALLOWED_LENSES, Fact -from packages.evidence import UnifiedRecallRequest, conversation_scopes_for_view, infer_recall_lifecycle_metadata, recall_time_range_from_payload, recall_timeline, render_recall_hit, unified_recall +from packages.evidence import ( + UnifiedRecallRequest, + conversation_scopes_for_view, + infer_recall_lifecycle_metadata, + recall_time_range_from_payload, + recall_timeline, + render_recall_hit, + unified_recall, +) from packages.curiosity.question_tool_surface import CuriosityQuestionManagementSurface -from packages.storage.repository_support import DEFAULT_PERSONAL_MODEL_ID, canonical_personal_model_id -from .semantic_search_support import fallback_pm_search, keyword_boost, rank_facts_by_semantic_queries +from packages.storage.repository_support import ( + DEFAULT_PERSONAL_MODEL_ID, + canonical_personal_model_id, +) +from .semantic_search_support import ( + fallback_pm_search, + keyword_boost, + rank_facts_by_semantic_queries, +) from .temporal_policy import freshness_score from .personal_model_governance import ( claim_payload, @@ -32,13 +48,20 @@ topic_rows, topic_tree, ) + _ALLOWED_ACTIONS = frozenset({"remember", "correct", "forget", "dispute", "restore"}) _ALLOWED_SOURCES = frozenset({"user_said", "user_corrected", "learned"}) _ALLOWED_SEARCH_STATUSES = frozenset({"active", "retired", "disputed", "all"}) + + def _utc_now() -> datetime: return datetime.now(timezone.utc) + + def _clean(value: object) -> str: return str(value or "").strip() + + def _rerank_by_freshness(facts: list[Fact], *, now: datetime) -> list[Fact]: """Apply volatility freshness as a small penalty without overriding relevance rank.""" scored: list[tuple[float, int, Fact]] = [] @@ -58,35 +81,47 @@ def _rerank_by_freshness(facts: list[Fact], *, now: datetime) -> list[Fact]: scored.append((float(rank) + freshness_offset, rank, fact)) scored.sort(key=lambda item: (item[0], item[1])) return [fact for _, _, fact in scored] + + def _normalized_lens(value: object) -> str: lens = _clean(value).lower() if lens not in ALLOWED_LENSES: raise ValueError(f"lens must be one of {sorted(ALLOWED_LENSES)}") return lens + + def _normalized_action(value: object) -> str: action = _clean(value).lower() if action not in _ALLOWED_ACTIONS: raise ValueError(f"action must be one of {sorted(_ALLOWED_ACTIONS)}") return action + + def _normalized_source(value: object) -> str: source = _clean(value).lower() or "user_said" if source not in _ALLOWED_SOURCES: raise ValueError(f"source must be one of {sorted(_ALLOWED_SOURCES)}") return source + + def _normalized_search_status(value: object) -> str: status = _clean(value).lower() or "active" if status not in _ALLOWED_SEARCH_STATUSES: raise ValueError(f"status must be one of {sorted(_ALLOWED_SEARCH_STATUSES)}") return status + + def _status_filter(status: str) -> str | tuple[str, ...]: if status == "all": return ("active", "retired", "disputed") return status + + def _fact_ref(personal_model_id: str, lens: str, topic: str, text: str) -> str: - digest = hashlib.sha256( - f"{personal_model_id}|{lens}|{topic}|{text}".encode("utf-8") - ).hexdigest()[:18] + digest = hashlib.sha256(f"{personal_model_id}|{lens}|{topic}|{text}".encode("utf-8")).hexdigest()[:18] return f"claim:{digest}" + + def _topic_matches(fact: Fact, *, topic: str, ref: str = "") -> bool: if ref and fact.fact_id == ref: return True @@ -102,10 +137,14 @@ def _topic_matches_filter(fact: Fact, topic_filter: str) -> bool: _QUERY_ALIASES: Mapping[str, tuple[str, ...]] = {} + + def _normalized_text(value: object) -> str: normalized = unicodedata.normalize("NFKC", str(value or "")).casefold() decomposed = unicodedata.normalize("NFKD", normalized) return "".join(ch for ch in decomposed if not unicodedata.combining(ch)) + + def _search_tokens(value: object) -> tuple[str, ...]: normalized = _normalized_text(value) tokens: list[str] = [] @@ -124,6 +163,8 @@ def _search_tokens(value: object) -> tuple[str, ...]: expanded.append(token) expanded.extend(_normalized_text(alias) for alias in _QUERY_ALIASES.get(token, ())) return tuple(token for token in dict.fromkeys(expanded) if token) + + def _token_variants(token: str) -> tuple[str, ...]: if not token: return () @@ -131,6 +172,8 @@ def _token_variants(token: str) -> tuple[str, ...]: if _has_cjk(token): variants.extend(_char_ngrams(token, widths=(1, 2))) return tuple(variants) + + def _has_cjk(text: str) -> bool: return any( "CJK" in unicodedata.name(ch, "") @@ -138,9 +181,13 @@ def _has_cjk(text: str) -> bool: or "KATAKANA" in unicodedata.name(ch, "") for ch in text ) + + def _compact_search_text(value: object) -> str: normalized = _normalized_text(value) return "".join(ch for ch in normalized if unicodedata.category(ch)[0] in {"L", "N"}) + + def _char_ngrams(value: object, *, widths: tuple[int, ...] = (2, 3)) -> set[str]: text = _compact_search_text(value) if not text: @@ -152,6 +199,8 @@ def _char_ngrams(value: object, *, widths: tuple[int, ...] = (2, 3)) -> set[str] else: grams.update(text[index : index + width] for index in range(0, len(text) - width + 1)) return grams + + def _safe_query_variants(values: object) -> tuple[str, ...]: if values is None: return () @@ -170,6 +219,8 @@ def _safe_query_variants(values: object) -> tuple[str, ...]: if len(cleaned) >= 5: break return tuple(dict.fromkeys(cleaned)) + + def _low_information_query(query: str) -> bool: normalized = _normalized_text(query).strip() if not normalized: @@ -185,6 +236,8 @@ def _low_information_query(query: str) -> bool: if len(ascii_tokens) >= 4 and all(len(token) <= 1 for token in ascii_tokens): return True return False + + def _question_topic(lens: str, sub_lens: str) -> str: def segment(value: str, fallback: str) -> str: normalized = _normalized_text(value).replace(".", "_").replace("-", "_").replace("/", "_").replace(":", "_") @@ -195,14 +248,18 @@ def segment(value: str, fallback: str) -> str: if not resolved[0].isalpha(): resolved = f"{fallback}_{resolved}" return resolved + resolved_lens = _normalized_lens(lens) raw = _clean(sub_lens) or "answer" try: return ensure_valid_topic_key(raw) except ValueError: return f"{resolved_lens}.question.{segment(raw, 'answer')}" + + class PersonalModelUnderstandingSurface: """Small four-lens Personal Model surface used by foreground tools.""" + def __init__( self, *, @@ -220,6 +277,7 @@ def __init__( None, ) self._questions = CuriosityQuestionManagementSurface(repository=repository) + def _personal_model_id(self, session_id: str, explicit: str = "") -> str: pm_id = _clean(explicit) if not pm_id: @@ -231,10 +289,12 @@ def _personal_model_id(self, session_id: str, explicit: str = "") -> str: if callable(ensure): ensure(personal_model_id=pm_id) return pm_id + def _episode_id(self, session_id: str) -> str: load_episode = getattr(self.repository, "load_episode_state", None) episode = load_episode(session_id) if callable(load_episode) else None return _clean(getattr(episode, "episode_id", "")) or session_id + def _index_claim(self, fact: Fact) -> None: index_claim = getattr(self.semantic_summary_indexer, "index_personal_model_claim", None) if callable(index_claim): @@ -242,6 +302,7 @@ def _index_claim(self, fact: Fact) -> None: index_claim(fact) except Exception: return + def _deactivate_claim_index( self, *, @@ -279,6 +340,7 @@ def _deactivate_claim_index( ) except Exception: continue + def _indexed_query_dimensions( self, *, @@ -328,6 +390,7 @@ def _query_vector(self, query: str, *, dimensions: int | None = None) -> tuple[t if not values or dimensions is None: return (), None return values, dimensions + def search_personal_model( self, session_id: str, @@ -383,10 +446,7 @@ def search_personal_model( match_status = "strong_match" if selected else "no_match" elif resolved_topic and not search_queries: # Topic-only filter: return all facts matching this topic (exact or prefix) - selected = tuple( - fact for fact in facts - if _topic_matches_filter(fact, resolved_topic) - )[:capped] + selected = tuple(fact for fact in facts if _topic_matches_filter(fact, resolved_topic))[:capped] match_status = "strong_match" if selected else "no_match" elif not search_queries: selected = () @@ -397,7 +457,11 @@ def search_personal_model( else: # --- Main search path: HybridSemanticSearcher --- selected, match_status = self._hybrid_pm_search( - search_queries, pm_id=pm_id, facts=facts, topic=resolved_topic, limit=capped, + search_queries, + pm_id=pm_id, + facts=facts, + topic=resolved_topic, + limit=capped, ) claims = [] @@ -449,7 +513,8 @@ def _hybrid_pm_search( # Topic pre-filter: if a topic is specified, narrow candidates first if topic: topic_matched = tuple( - fact for fact in facts + fact + for fact in facts if _clean((fact.metadata or {}).get("topic")) == topic or _clean((fact.metadata or {}).get("topic")).startswith(f"{topic}.") ) @@ -561,7 +626,12 @@ def search_conversation( preview=preview, limit=capped, ) - return {"personal_model_id": pm_id, "scope": "conversation", "mode": resolved_mode, **dict(result)} + return { + "personal_model_id": pm_id, + "scope": "conversation", + "mode": resolved_mode, + **dict(result), + } ranked = unified_recall( request, repository=self.repository, @@ -616,6 +686,7 @@ def timeline_personal_model( limit=limit, personal_model_id=personal_model_id, ) + def inspect_personal_model( self, session_id: str, @@ -636,7 +707,8 @@ def inspect_personal_model( ) ) selected = tuple( - fact for fact in facts + fact + for fact in facts if (resolved_ref and fact.fact_id == resolved_ref) or (resolved_topic and _clean((fact.metadata or {}).get("topic")) == resolved_topic) ) @@ -649,17 +721,14 @@ def inspect_personal_model( supersedes_refs.append(fact.supersedes_fact_id) metadata = dict(fact.metadata or {}) supersedes_refs.extend( - item.strip() - for item in str(metadata.get("supersedes_fact_ids") or "").split(",") - if item.strip() + item.strip() for item in str(metadata.get("supersedes_fact_ids") or "").split(",") if item.strip() ) chain = tuple( - claim_payload(fact) - for ref_id in dict.fromkeys(supersedes_refs) - for fact in facts - if fact.fact_id == ref_id + claim_payload(fact) for ref_id in dict.fromkeys(supersedes_refs) for fact in facts if fact.fact_id == ref_id + ) + recall_query = ( + _clean(query) or _clean((claim or {}).get("text") if isinstance(claim, Mapping) else "") or resolved_topic ) - recall_query = _clean(query) or _clean((claim or {}).get("text") if isinstance(claim, Mapping) else "") or resolved_topic history = () if recall_query: history_result = self.recall_personal_model( @@ -678,6 +747,7 @@ def inspect_personal_model( "history": history, "supersedes_chain": chain, } + def audit_personal_model( self, session_id: str, @@ -706,13 +776,17 @@ def audit_personal_model( if resolved_action in {"health", "topics"}: result["topic_tree"] = topic_tree(tuple(fact for fact in facts if fact.status == "active")) if resolved_action == "topics": - result["topics"] = topic_rows(tuple(fact for fact in facts if fact.status == "active"), limit=max(1, min(int(limit or 30), 100))) + result["topics"] = topic_rows( + tuple(fact for fact in facts if fact.status == "active"), + limit=max(1, min(int(limit or 30), 100)), + ) if resolved_action == "conflicts": result["conflicts"] = tuple(health.get("conflicting_claim_candidates") or ()) if resolved_action == "stale": result["review_claims_overdue"] = tuple(health.get("review_claims_overdue") or ()) result["current_claims_stale"] = tuple(health.get("current_claims_stale") or ()) return result + def update_personal_model( self, session_id: str, @@ -736,7 +810,13 @@ def update_personal_model( resolved_topic = ensure_valid_topic_key(resolved_topic) resolved_source = _normalized_source(source) resolved_recall_policy = _clean(recall_policy).lower() - if resolved_recall_policy not in {"", "stable", "current", "temporary", "review"}: + if resolved_recall_policy not in { + "", + "stable", + "current", + "temporary", + "review", + }: raise ValueError("recall_policy must be one of stable, current, temporary, review when provided") pm_id = self._personal_model_id(session_id, personal_model_id) now = _utc_now() @@ -748,9 +828,7 @@ def update_personal_model( ) ) resolved_ref = _clean(ref) - targets = tuple( - fact for fact in active if _topic_matches(fact, topic=resolved_topic, ref=resolved_ref) - ) + targets = tuple(fact for fact in active if _topic_matches(fact, topic=resolved_topic, ref=resolved_ref)) if resolved_action == "remember" and not resolved_ref and is_single_active_topic(resolved_topic): targets = tuple(fact for fact in active if _topic_matches(fact, topic=resolved_topic)) if resolved_action == "restore": @@ -771,7 +849,14 @@ def update_personal_model( status=("active", "retired", "disputed"), ) ) - target = next((fact for fact in all_facts if fact.fact_id == resolved_ref and _topic_matches(fact, topic=resolved_topic, ref=resolved_ref)), None) + target = next( + ( + fact + for fact in all_facts + if fact.fact_id == resolved_ref and _topic_matches(fact, topic=resolved_topic, ref=resolved_ref) + ), + None, + ) if target is None: return { "action": resolved_action, @@ -829,7 +914,9 @@ def update_personal_model( text=_clean(text), exclude_refs=tuple(fact.fact_id for fact in targets), ) - protected_targets = tuple(fact for fact in targets if is_protected_topic(resolved_topic, dict(fact.metadata or {}))) + protected_targets = tuple( + fact for fact in targets if is_protected_topic(resolved_topic, dict(fact.metadata or {})) + ) if resolved_action == "forget" and protected_targets: return { "action": resolved_action, @@ -841,7 +928,11 @@ def update_personal_model( "no_match_hint": "protected core topic cannot be forgotten by agent tools; correct the content or unprotect it in the dashboard first", "protected_refs": tuple(fact.fact_id for fact in protected_targets), } - if resolved_action in {"forget", "dispute"} and not resolved_ref and (len(targets) > 1 or (not targets and related_candidates)): + if ( + resolved_action in {"forget", "dispute"} + and not resolved_ref + and (len(targets) > 1 or (not targets and related_candidates)) + ): return { "action": resolved_action, "personal_model_id": pm_id, @@ -894,9 +985,7 @@ def update_personal_model( raise ValueError("text is required for remember/correct") fact_source = "pm_agent_promote" if resolved_source == "learned" else "user_explicit" inherited_recall_metadata = ( - inheritable_recall_metadata(targets) - if resolved_action == "correct" and not resolved_recall_policy - else {} + inheritable_recall_metadata(targets) if resolved_action == "correct" and not resolved_recall_policy else {} ) caller_metadata = {str(key): str(value) for key, value in dict(metadata or {}).items() if str(value).strip()} protection_metadata = protected_topic_metadata(resolved_topic, caller_metadata) @@ -954,6 +1043,7 @@ def update_personal_model( "status": "active", **({"no_match_hint": no_match_hint} if no_match_hint else {}), } + def manage_personal_model_questions(self, session_id: str, **kwargs: Any) -> Mapping[str, Any]: payload = dict(kwargs) answer_text = _clean(payload.pop("answer", "")) @@ -961,7 +1051,10 @@ def manage_personal_model_questions(self, session_id: str, **kwargs: Any) -> Map if _clean(payload.get("action")).lower() == "answer" and answer_text: question = result.get("question") if isinstance(result, Mapping) else None lens = _clean((question or {}).get("lens")) or _clean(payload.get("lens")) or "knowledge" - topic = _question_topic(lens, _clean((question or {}).get("sub_lens")) or _clean(payload.get("sub_lens")) or "answer") + topic = _question_topic( + lens, + _clean((question or {}).get("sub_lens")) or _clean(payload.get("sub_lens")) or "answer", + ) update = self.update_personal_model( session_id, action="correct", diff --git a/packages/understanding/semantic_search_support.py b/packages/understanding/semantic_search_support.py index 2f9c1cc..0b80591 100644 --- a/packages/understanding/semantic_search_support.py +++ b/packages/understanding/semantic_search_support.py @@ -21,11 +21,7 @@ def _topic_matches_filter(fact: Fact, topic_filter: str) -> bool: def claim_ref_from_match(match: Any) -> str: entry = getattr(match, "semantic_index_entry", None) metadata = dict(getattr(entry, "metadata", {}) or {}) - return str( - metadata.get("claim_ref") - or getattr(entry, "source_id", "") - or "" - ).strip() + return str(metadata.get("claim_ref") or getattr(entry, "source_id", "") or "").strip() def rank_facts_by_semantic_queries( diff --git a/packages/understanding/temporal_policy.py b/packages/understanding/temporal_policy.py index 746bb2a..9a3afb0 100644 --- a/packages/understanding/temporal_policy.py +++ b/packages/understanding/temporal_policy.py @@ -14,7 +14,7 @@ from __future__ import annotations -from datetime import datetime, timezone +from datetime import datetime VOLATILITY_HALF_LIVES: dict[str, float | None] = { "permanent": None, diff --git a/pyproject.toml b/pyproject.toml index f3147e2..f48fe1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,11 @@ Documentation = "https://elephant.agentic-in.ai/docs/" Repository = "https://github.com/agentic-in/elephant-agent" Issues = "https://github.com/agentic-in/elephant-agent/issues" +[tool.ruff] +target-version = "py312" +line-length = 120 +lint.ignore = ["E402"] + [tool.setuptools.packages.find] where = ["."] include = ["apps", "apps.api*", "apps.cli*", "apps.dashboard*", "apps.gateway*", "packages*"] diff --git a/tests/__init__.py b/tests/__init__.py index 8b13789..e69de29 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +0,0 @@ - diff --git a/tests/agent/__init__.py b/tests/agent/__init__.py index 8b13789..e69de29 100644 --- a/tests/agent/__init__.py +++ b/tests/agent/__init__.py @@ -1 +0,0 @@ - diff --git a/tests/agent/test_agent_gate.py b/tests/agent/test_agent_gate.py index 1315352..4196c4d 100644 --- a/tests/agent/test_agent_gate.py +++ b/tests/agent/test_agent_gate.py @@ -31,15 +31,26 @@ def test_match_any(self) -> None: self.assertTrue(MODULE.match_any("docs/README.md", ["docs/*.md"])) def test_parse_repo_name_from_remote_url(self) -> None: - self.assertEqual(MODULE.parse_repo_name_from_remote_url("git@github.com:agentic-in/elephant.git"), "elephant") - self.assertEqual(MODULE.parse_repo_name_from_remote_url("https://github.com/agentic-in/elephant.git"), "elephant") + self.assertEqual( + MODULE.parse_repo_name_from_remote_url("git@github.com:agentic-in/elephant.git"), + "elephant", + ) + self.assertEqual( + MODULE.parse_repo_name_from_remote_url("https://github.com/agentic-in/elephant.git"), + "elephant", + ) def test_resolve_repo_identity_name_uses_git_common_dir_name(self) -> None: completed = mock.Mock(returncode=0, stdout="/tmp/repos/elephant\n") with mock.patch.object(MODULE.subprocess, "run", return_value=completed): - self.assertEqual(MODULE.resolve_repo_identity_name(Path("/tmp/activitytrees/fnd-3")), "elephant") + self.assertEqual( + MODULE.resolve_repo_identity_name(Path("/tmp/activitytrees/fnd-3")), + "elephant", + ) - def test_resolve_repo_identity_name_uses_origin_when_common_dir_is_plain_git_dir(self) -> None: + def test_resolve_repo_identity_name_uses_origin_when_common_dir_is_plain_git_dir( + self, + ) -> None: common_dir = mock.Mock(returncode=0, stdout=".git\n") remote = mock.Mock(returncode=0, stdout="git@github.com:agentic-in/elephant.git\n") with mock.patch.object(MODULE.subprocess, "run", side_effect=[common_dir, remote]): @@ -49,7 +60,10 @@ def test_resolve_repo_identity_name_falls_back_to_root_name(self) -> None: common_dir = mock.Mock(returncode=1, stdout="") remote = mock.Mock(returncode=1, stdout="") with mock.patch.object(MODULE.subprocess, "run", side_effect=[common_dir, remote]): - self.assertEqual(MODULE.resolve_repo_identity_name(Path("/tmp/activitytrees/fnd-3")), "fnd-3") + self.assertEqual( + MODULE.resolve_repo_identity_name(Path("/tmp/activitytrees/fnd-3")), + "fnd-3", + ) def test_validate_contract_accepts_checkout_alias_during_repo_rename(self) -> None: with mock.patch.object(MODULE, "resolve_repo_identity_name", return_value="a" + "egis"): @@ -73,7 +87,12 @@ def test_scan_reset_banned_terms_reports_removed_surface_language(self) -> None: errors = MODULE.scan_reset_banned_terms( root=root, surfaces=("surface.txt",), - banned_terms=((removed_term, "speech-mode contract is removed from reset surfaces"),), + banned_terms=( + ( + removed_term, + "speech-mode contract is removed from reset surfaces", + ), + ), ) self.assertEqual( @@ -94,7 +113,12 @@ def test_scan_reset_banned_terms_accepts_clean_surface(self) -> None: errors = MODULE.scan_reset_banned_terms( root=root, surfaces=("surface.txt",), - banned_terms=((removed_term, "speech-mode contract is removed from reset surfaces"),), + banned_terms=( + ( + removed_term, + "speech-mode contract is removed from reset surfaces", + ), + ), ) self.assertEqual(errors, []) @@ -109,7 +133,9 @@ def test_collect_changed_files_accepts_space_and_comma_lists(self) -> None: ["tools/agent/context-map.yaml", ".github/workflows/agent-lint.yml"], ) - def test_scan_reset_banned_terms_defaults_to_tracked_files_with_allowlist(self) -> None: + def test_scan_reset_banned_terms_defaults_to_tracked_files_with_allowlist( + self, + ) -> None: removed_term = " ".join(("goal", "graph")) with tempfile.TemporaryDirectory() as tmpdir: root = Path(tmpdir) @@ -128,10 +154,7 @@ def test_scan_reset_banned_terms_defaults_to_tracked_files_with_allowlist(self) self.assertEqual( errors, - [ - f"reset banned term in blocked.txt:1: {removed_term} " - "(current-work wording is required)" - ], + [f"reset banned term in blocked.txt:1: {removed_term} (current-work wording is required)"], ) def test_resolve_rules_for_ci_workflow(self) -> None: @@ -184,7 +207,10 @@ def test_audit_ignores_local_agents_docs_as_surface_drift(self) -> None: self.assertEqual(MODULE.audit_surface_coverage(["packages/growth/AGENTS.md"], pack), []) def test_audit_uses_also_matched_skill_surface_coverage(self) -> None: - changed_files = ["packages/semantic_index/AGENTS.md", "tools/agent/context-map.yaml"] + changed_files = [ + "packages/semantic_index/AGENTS.md", + "tools/agent/context-map.yaml", + ] matches = MODULE.resolve_rule_matches(changed_files) pack = MODULE.build_context_pack(changed_files, matches) @@ -275,15 +301,20 @@ def test_python_line_limit_skips_legacy_large_modules(self) -> None: self.assertEqual(files, ("apps/cli/runtime_cognition.py",)) def test_frontend_typecheck_commands_select_dashboard_and_site(self) -> None: - commands = MODULE.frontend_typecheck_commands([ - "apps/dashboard/src/routes/console/ConsolePages.tsx", - "apps/site/src/pages/index.tsx", - "packages/state/config.py", - ]) + commands = MODULE.frontend_typecheck_commands( + [ + "apps/dashboard/src/routes/console/ConsolePages.tsx", + "apps/site/src/pages/index.tsx", + "packages/state/config.py", + ] + ) self.assertEqual( commands, ( - ("dashboard", ("npm", "--prefix", "apps/dashboard", "run", "typecheck")), + ( + "dashboard", + ("npm", "--prefix", "apps/dashboard", "run", "typecheck"), + ), ("site", ("npm", "--prefix", "apps/site", "run", "typecheck")), ), ) @@ -320,7 +351,10 @@ def test_makefile_exposes_phony_lint_alias(self) -> None: def test_ci_lint_uses_commit_range_instead_of_full_repo_scan(self) -> None: workflow_text = (ROOT / ".github" / "workflows" / "ci.yml").read_text(encoding="utf-8") - self.assertIn('make build-and-test AGENT_BASE_REF="origin/${{ github.base_ref }}"', workflow_text) + self.assertIn( + 'make build-and-test AGENT_BASE_REF="origin/${{ github.base_ref }}"', + workflow_text, + ) self.assertIn('BASE_REF="${{ github.event.before }}"', workflow_text) self.assertIn('make build-and-test AGENT_BASE_REF="$BASE_REF"', workflow_text) diff --git a/tests/agent/test_commit_msg_lint.py b/tests/agent/test_commit_msg_lint.py index 9d1e4ff..55000a8 100644 --- a/tests/agent/test_commit_msg_lint.py +++ b/tests/agent/test_commit_msg_lint.py @@ -4,7 +4,6 @@ from pathlib import Path import sys import unittest -from unittest import mock ROOT = Path(__file__).resolve().parents[2] diff --git a/tests/agent/test_ship.py b/tests/agent/test_ship.py index 6f2d30c..5e079c8 100644 --- a/tests/agent/test_ship.py +++ b/tests/agent/test_ship.py @@ -28,7 +28,11 @@ def test_parse_status_paths_handles_rename_and_untracked(self) -> None: ] self.assertEqual( MODULE.parse_status_paths(lines), - ["README.md", "docs/new.md", "docs/system-design/provisional-foundation.md"], + [ + "README.md", + "docs/new.md", + "docs/system-design/provisional-foundation.md", + ], ) def test_ensure_branch_uses_override(self) -> None: diff --git a/tests/agent/test_system_layer_reset_matrix.py b/tests/agent/test_system_layer_reset_matrix.py index 0efb368..2b9b48c 100644 --- a/tests/agent/test_system_layer_reset_matrix.py +++ b/tests/agent/test_system_layer_reset_matrix.py @@ -62,7 +62,9 @@ def test_makefile_targets_reference_reset_lifecycle_surfaces(self) -> None: with self.subTest(target=target): self.assertIn(target, text) - def test_storage_suite_pins_default_model_clean_schema_and_delete_boundaries(self) -> None: + def test_storage_suite_pins_default_model_clean_schema_and_delete_boundaries( + self, + ) -> None: text = _read("tests/integration/storage_system_layers/test_repository.py") for marker in ( @@ -75,7 +77,9 @@ def test_storage_suite_pins_default_model_clean_schema_and_delete_boundaries(sel with self.subTest(marker=marker): self.assertIn(marker, text) - def test_loop_checkpoint_and_personal_model_growth_live_in_repository_methods_without_runtime_shims(self) -> None: + def test_loop_checkpoint_and_personal_model_growth_live_in_repository_methods_without_runtime_shims( + self, + ) -> None: checkpoint_text = _read("packages/kernel/loop_checkpoint_support.py") episode_runtime_text = _read("apps/episode_runtime.py") repository_methods_text = _read("packages/storage/repository_system_methods.py") @@ -103,7 +107,9 @@ def test_loop_checkpoint_and_personal_model_growth_live_in_repository_methods_wi self.assertNotIn("load_latest_open_agent_run", repository_methods_text) self.assertNotIn("profile_growth", repository_methods_text) - def test_kernel_and_context_suites_pin_state_query_and_compaction_coverage(self) -> None: + def test_kernel_and_context_suites_pin_state_query_and_compaction_coverage( + self, + ) -> None: kernel_text = _read("tests/integration/kernel/test_turn_lifecycle.py") context_text = _read("tests/unit/context/test_context_projection.py") @@ -124,7 +130,9 @@ def test_kernel_and_context_suites_pin_state_query_and_compaction_coverage(self) with self.subTest(marker=marker): self.assertIn(marker, context_text) - def test_skill_dashboard_and_continuity_suites_pin_reset_acceptance_surfaces(self) -> None: + def test_skill_dashboard_and_continuity_suites_pin_reset_acceptance_surfaces( + self, + ) -> None: skills_text = _read("tests/integration/tools_skills/test_tools_and_skills_runtime.py") api_text = _read("tests/e2e/api/test_api_surface.py") continuity_text = _read("tests/scenarios/continuity/test_continuity_scenarios.py") @@ -139,7 +147,10 @@ def test_skill_dashboard_and_continuity_suites_pin_reset_acceptance_surfaces(sel self.assertIn(marker, skills_text) self.assertIn("/v1/internal/dashboard", api_text) - self.assertIn("test_internal_dashboard_projection_surfaces_canonical_runtime_and_evidence", api_text) + self.assertIn( + "test_internal_dashboard_projection_surfaces_canonical_runtime_and_evidence", + api_text, + ) self.assertIn("test_continuity_scenarios_index_is_stable", continuity_text) self.assertIn("test_state_continuity_fixture_declares_text_only_boundary", continuity_text) diff --git a/tests/agent/test_wave_manager.py b/tests/agent/test_wave_manager.py index bf5e111..413058d 100644 --- a/tests/agent/test_wave_manager.py +++ b/tests/agent/test_wave_manager.py @@ -50,7 +50,11 @@ def test_parse_worktree_records(self) -> None: self.assertEqual( MODULE.parse_worktree_records(output), [ - {"worktree": "/tmp/elephant", "HEAD": "abc123", "branch": "refs/heads/main"}, + { + "worktree": "/tmp/elephant", + "HEAD": "abc123", + "branch": "refs/heads/main", + }, { "worktree": "/tmp/elephant/.worktrees/fnd-1", "HEAD": "def456", diff --git a/tests/e2e/api/test_api_surface.py b/tests/e2e/api/test_api_surface.py index b9b5fd5..cf78f15 100644 --- a/tests/e2e/api/test_api_surface.py +++ b/tests/e2e/api/test_api_surface.py @@ -109,15 +109,25 @@ def do_POST(self) -> None: # noqa: N802 } } ], - "usage": {"prompt_tokens": 7, "completion_tokens": 3, "total_tokens": 10}, + "usage": { + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10, + }, } elif self.path == "/v1/responses": content = f"live-response:{Handler._responses_input_text(payload.get('input'))}" if payload.get("stream"): midpoint = max(1, len(content) // 2) events = ( - ("response.output_text.delta", {"delta": content[:midpoint]}), - ("response.output_text.delta", {"delta": content[midpoint:]}), + ( + "response.output_text.delta", + {"delta": content[:midpoint]}, + ), + ( + "response.output_text.delta", + {"delta": content[midpoint:]}, + ), ( "response.completed", { @@ -125,7 +135,11 @@ def do_POST(self) -> None: # noqa: N802 "id": "resp-stub", "model": payload["model"], "output_text": content, - "usage": {"input_tokens": 6, "output_tokens": 3, "total_tokens": 9}, + "usage": { + "input_tokens": 6, + "output_tokens": 3, + "total_tokens": 9, + }, } }, ), @@ -145,7 +159,11 @@ def do_POST(self) -> None: # noqa: N802 "id": "resp-stub", "model": payload["model"], "output_text": content, - "usage": {"input_tokens": 6, "output_tokens": 3, "total_tokens": 9}, + "usage": { + "input_tokens": 6, + "output_tokens": 3, + "total_tokens": 9, + }, } elif self.path == "/v1/messages": response = { @@ -186,7 +204,11 @@ def do_GET(self) -> None: # noqa: N802 if self.path == "/v1/models": response = { "data": [ - {"id": "openai/gpt-4o-mini", "owned_by": "stub", "context_window": 128000}, + { + "id": "openai/gpt-4o-mini", + "owned_by": "stub", + "context_window": 128000, + }, {"id": "openai/gpt-4.1-mini", "owned_by": "stub"}, ] } @@ -352,22 +374,33 @@ def test_kernel_backed_turn_execution_and_controlled_tool_path(self) -> None: self.assertEqual(turn.status_code, 200) self.assertEqual(turn.payload["episode"]["episode_id"], "session-turn") self.assertEqual(turn.payload["outcome"]["event"]["episode_id"], "session-turn") - self.assertEqual(turn.payload["outcome"]["event"]["payload"]["state_query"], "Continue the release plan") + self.assertEqual( + turn.payload["outcome"]["event"]["payload"]["state_query"], + "Continue the release plan", + ) self.assertEqual(turn.payload["outcome"]["state"]["elephant_id"], "elephant-1") self.assertNotIn("active_task", turn.payload["outcome"]["state"]) self.assertGreaterEqual(len(turn.payload["outcome"]["stages"]), 6) self.assertGreaterEqual(len(turn.payload["outcome"]["steps"]), 6) self.assertGreaterEqual(turn.payload["inspection"]["recall_count"], 0) self.assertGreaterEqual(turn.payload["inspection"]["telemetry_count"], 1) - self.assertEqual(turn.payload["inspection"]["progression"]["stage_title"], "learning the path") - self.assertTrue( - turn.payload["outcome"]["execution"]["summary"].startswith( - "live-chat:What should we do next?" - ) + self.assertEqual( + turn.payload["inspection"]["progression"]["stage_title"], + "learning the path", + ) + self.assertTrue(turn.payload["outcome"]["execution"]["summary"].startswith("live-chat:What should we do next?")) + self.assertIn( + "transport=openai_chat_compatible", + turn.payload["outcome"]["execution"]["side_effects"], + ) + self.assertIn( + "credential_keys=api_key", + turn.payload["outcome"]["execution"]["side_effects"], + ) + self.assertEqual( + turn.payload["inspection"]["provider_profile"]["profile_id"], + "provider-openrouter", ) - self.assertIn("transport=openai_chat_compatible", turn.payload["outcome"]["execution"]["side_effects"]) - self.assertIn("credential_keys=api_key", turn.payload["outcome"]["execution"]["side_effects"]) - self.assertEqual(turn.payload["inspection"]["provider_profile"]["profile_id"], "provider-openrouter") tool_turn = self.app.dispatch( "POST", @@ -382,10 +415,19 @@ def test_kernel_backed_turn_execution_and_controlled_tool_path(self) -> None: ) self.assertEqual(tool_turn.status_code, 200) self.assertEqual(tool_turn.payload["outcome"]["execution"]["outcome"], "success") - self.assertEqual(tool_turn.payload["outcome"]["execution"]["side_effects"], ["code", "python", "sandbox"]) + self.assertEqual( + tool_turn.payload["outcome"]["execution"]["side_effects"], + ["code", "python", "sandbox"], + ) self.assertIn("hello api tool", tool_turn.payload["outcome"]["execution"]["summary"]) - self.assertEqual(tool_turn.payload["latest_loop"]["request"]["tool_name"], "tool.code.execute") - self.assertEqual(tool_turn.payload["inspection"]["latest_loop"]["request"]["tool_name"], "tool.code.execute") + self.assertEqual( + tool_turn.payload["latest_loop"]["request"]["tool_name"], + "tool.code.execute", + ) + self.assertEqual( + tool_turn.payload["inspection"]["latest_loop"]["request"]["tool_name"], + "tool.code.execute", + ) clarify_turn = self.app.dispatch( "POST", @@ -404,7 +446,10 @@ def test_kernel_backed_turn_execution_and_controlled_tool_path(self) -> None: ) self.assertEqual(clarify_turn.status_code, 200) self.assertEqual(clarify_turn.payload["outcome"]["execution"]["outcome"], "success") - self.assertIn("user_response: beta", clarify_turn.payload["outcome"]["execution"]["summary"]) + self.assertIn( + "user_response: beta", + clarify_turn.payload["outcome"]["execution"]["summary"], + ) self.assertEqual(clarify_turn.payload["latest_loop"]["request"]["tool_name"], "tool.clarify") inspect = self.app.dispatch("GET", "/v1/episodes/session-turn") @@ -421,7 +466,10 @@ def test_kernel_backed_turn_execution_and_controlled_tool_path(self) -> None: ("PATCH", "/v1/episodes/session-turn/goals/work-launch"), ): with self.subTest(method=method, route=route): - self.assertEqual(self.app.dispatch(method, route, body=self._body({})).status_code, 404) + self.assertEqual( + self.app.dispatch(method, route, body=self._body({})).status_code, + 404, + ) profile_surface = self.app.dispatch("GET", "/v1/episodes/session-turn/profile") self.assertEqual(profile_surface.status_code, 200) @@ -488,7 +536,9 @@ def test_api_chat_runtime_exposes_model_tools_and_skill_context(self) -> None: self.assertIn("skill", result.side_effects) self.assertNotEqual(result.summary.strip(), "") - def test_canonical_state_routes_expose_identity_user_relationship_and_continuity(self) -> None: + def test_canonical_state_routes_expose_identity_user_relationship_and_continuity( + self, + ) -> None: created = self.app.dispatch( "POST", "/v1/episodes", @@ -540,7 +590,10 @@ def test_canonical_state_routes_expose_identity_user_relationship_and_continuity ) self.assertEqual(updated_user.status_code, 200) self.assertEqual(updated_user.payload["user"]["preferred_name"], "Bit") - self.assertIn("current_work:Build Elephant Agent", updated_user.payload["user"]["biography_fragments"]) + self.assertIn( + "current_work:Build Elephant Agent", + updated_user.payload["user"]["biography_fragments"], + ) updated_relationship = self.app.dispatch( "PATCH", @@ -566,7 +619,9 @@ def test_canonical_state_routes_expose_identity_user_relationship_and_continuity self.assertIn("wake_summary", continuity.payload) self.assertIn("continuity", continuity.payload) - def test_elephant_management_routes_create_update_delete_state_file_and_level(self) -> None: + def test_elephant_management_routes_create_update_delete_state_file_and_level( + self, + ) -> None: created = self.app.dispatch( "POST", "/v1/herd", @@ -632,7 +687,9 @@ def test_elephant_management_routes_create_update_delete_state_file_and_level(se self.assertIsNone(self.app.repository.load_state("state:atlas")) self.assertFalse(state_file.exists()) - def test_turn_without_seed_graph_does_not_form_a_goal_from_prompt_alone(self) -> None: + def test_turn_without_seed_graph_does_not_form_a_goal_from_prompt_alone( + self, + ) -> None: self.app.dispatch( "POST", "/v1/episodes", @@ -656,9 +713,14 @@ def test_turn_without_seed_graph_does_not_form_a_goal_from_prompt_alone(self) -> self.assertNotIn("goals", turn.payload["inspection"]) self.assertNotIn("work_items", turn.payload["inspection"]) self.assertNotIn("active_task", turn.payload["outcome"]["state"]) - self.assertIn("current-work lifecycle", turn.payload["outcome"]["event"]["payload"]["message"]) + self.assertIn( + "current-work lifecycle", + turn.payload["outcome"]["event"]["payload"]["message"], + ) - def test_turn_does_not_mutate_profile_without_explicit_profile_surface(self) -> None: + def test_turn_does_not_mutate_profile_without_explicit_profile_surface( + self, + ) -> None: self.app.dispatch( "POST", "/v1/episodes", @@ -717,7 +779,10 @@ def test_turn_without_seed_graph_uses_explicit_state_query(self) -> None: self.assertEqual(turn.status_code, 200) self.assertNotIn("active_task", turn.payload["outcome"]["state"]) - self.assertIn("current-work lifecycle", turn.payload["outcome"]["event"]["payload"]["state_query"].lower()) + self.assertIn( + "current-work lifecycle", + turn.payload["outcome"]["event"]["payload"]["state_query"].lower(), + ) def test_openai_provider_profile_uses_first_party_runtime_resolution(self) -> None: created = self.app.dispatch( @@ -749,12 +814,16 @@ def test_openai_provider_profile_uses_first_party_runtime_resolution(self) -> No ) self.assertEqual(turn.status_code, 200) self.assertTrue( - turn.payload["outcome"]["execution"]["summary"].startswith( - "live-response:Summarize the next release step." - ) + turn.payload["outcome"]["execution"]["summary"].startswith("live-response:Summarize the next release step.") + ) + self.assertIn( + "transport=openai_responses", + turn.payload["outcome"]["execution"]["side_effects"], + ) + self.assertIn( + "credential_keys=api_key", + turn.payload["outcome"]["execution"]["side_effects"], ) - self.assertIn("transport=openai_responses", turn.payload["outcome"]["execution"]["side_effects"]) - self.assertIn("credential_keys=api_key", turn.payload["outcome"]["execution"]["side_effects"]) self.assertEqual(turn.payload["inspection"]["provider_profile"]["provider_id"], "openai") def test_anthropic_provider_profile_uses_native_messages_runtime(self) -> None: @@ -790,9 +859,18 @@ def test_anthropic_provider_profile_uses_native_messages_runtime(self) -> None: turn.payload["outcome"]["execution"]["summary"], "live-anthropic:Explain the provider boundary.", ) - self.assertIn("transport=anthropic_messages", turn.payload["outcome"]["execution"]["side_effects"]) - self.assertIn("credential_keys=api_key", turn.payload["outcome"]["execution"]["side_effects"]) - self.assertEqual(turn.payload["inspection"]["provider_profile"]["transport_id"], "anthropic_messages") + self.assertIn( + "transport=anthropic_messages", + turn.payload["outcome"]["execution"]["side_effects"], + ) + self.assertIn( + "credential_keys=api_key", + turn.payload["outcome"]["execution"]["side_effects"], + ) + self.assertEqual( + turn.payload["inspection"]["provider_profile"]["transport_id"], + "anthropic_messages", + ) def test_provider_onboarding_and_default_provider_flow(self) -> None: provider_profile = self._provider_profile( @@ -815,11 +893,19 @@ def test_provider_onboarding_and_default_provider_flow(self) -> None: models = self.app.dispatch( "POST", "/v1/providers/models", - body=self._body({"providerId": "openai-compatible", "baseUrl": self.stub.openai_base_url}), + body=self._body( + { + "providerId": "openai-compatible", + "baseUrl": self.stub.openai_base_url, + } + ), ) self.assertEqual(models.status_code, 200) self.assertEqual(models.payload["providerId"], "openai-compatible") - self.assertIn("openai/gpt-4o-mini", [model["model_id"] for model in models.payload["models"]]) + self.assertIn( + "openai/gpt-4o-mini", + [model["model_id"] for model in models.payload["models"]], + ) defaulted = self.app.dispatch( "POST", @@ -833,7 +919,10 @@ def test_provider_onboarding_and_default_provider_flow(self) -> None: self.assertEqual(defaulted.payload["active_provider"]["model_id"], "openai/gpt-4o-mini") self.assertEqual(defaulted.payload["active_provider"]["context_window_tokens"], 128000) self.assertEqual(defaulted.payload["active_provider"]["context_window_mode"], "auto") - self.assertIn(defaulted.payload["active_provider"]["embedding_bootstrap_status"], EMBEDDING_BOOTSTRAP_STATUSES) + self.assertIn( + defaulted.payload["active_provider"]["embedding_bootstrap_status"], + EMBEDDING_BOOTSTRAP_STATUSES, + ) config = load_global_config( global_config_path_for_state_dir(self.app.repository.database_path.parent), state_dir=self.app.repository.database_path.parent, @@ -881,7 +970,10 @@ def test_provider_onboarding_and_default_provider_flow(self) -> None: ) self.assertEqual(external_embedding.status_code, 200) self.assertEqual(external_embedding.payload["embedding_provider"]["source"], "configured") - self.assertEqual(external_embedding.payload["embedding_provider"]["model_id"], "text-embedding-3-large") + self.assertEqual( + external_embedding.payload["embedding_provider"]["model_id"], + "text-embedding-3-large", + ) self.assertEqual(external_embedding.payload["embedding_provider"]["secret_status"], "stored") local_embedding = self.app.dispatch( "POST", @@ -900,7 +992,10 @@ def test_provider_onboarding_and_default_provider_flow(self) -> None: self.assertEqual(doctor.payload["status"], "ready") self.assertEqual(doctor.payload["active_provider"]["provider_id"], "openai-compatible") self.assertIn("runtime", [check["check"] for check in doctor.payload["checks"]]) - self.assertIn("embedding_bootstrap", [check["check"] for check in doctor.payload["checks"]]) + self.assertIn( + "embedding_bootstrap", + [check["check"] for check in doctor.payload["checks"]], + ) test = self.app.dispatch( "POST", @@ -936,16 +1031,10 @@ def test_provider_onboarding_and_default_provider_flow(self) -> None: else turn.payload["outcome"]["execution"] ) execution_summary = execution.summary if hasattr(execution, "summary") else execution["summary"] - self.assertTrue( - execution_summary.startswith( - "live-chat:What should we do next?" - ) - ) + self.assertTrue(execution_summary.startswith("live-chat:What should we do next?")) inspection = turn.payload["inspection"] provider_profile = ( - inspection.provider_profile - if hasattr(inspection, "provider_profile") - else inspection["provider_profile"] + inspection.provider_profile if hasattr(inspection, "provider_profile") else inspection["provider_profile"] ) provider_id = ( provider_profile.provider_id @@ -1017,9 +1106,7 @@ def test_default_provider_profile_stays_non_blocking(self) -> None: doctor = self.app.dispatch("GET", "/v1/providers/doctor") self.assertEqual(doctor.status_code, 200) - bootstrap_check = next( - check for check in doctor.payload["checks"] if check["check"] == "embedding_bootstrap" - ) + bootstrap_check = next(check for check in doctor.payload["checks"] if check["check"] == "embedding_bootstrap") self.assertIn(bootstrap_check["status"], EMBEDDING_BOOTSTRAP_STATUSES) self.assertEqual(doctor.payload["status"], "ready") @@ -1100,7 +1187,10 @@ def test_gateway_dashboard_cards_configure_im_accounts(self) -> None: self.assertNotIn("cli-feishu-secret", json.dumps(manifest)) self.assertEqual( [ref["metadata"]["env_var"] for ref in account["secret_references"]], - ["ELEPHANT_FEISHU_OPS_FEISHU_APP_ID", "ELEPHANT_FEISHU_OPS_FEISHU_APP_SECRET"], + [ + "ELEPHANT_FEISHU_OPS_FEISHU_APP_ID", + "ELEPHANT_FEISHU_OPS_FEISHU_APP_SECRET", + ], ) secret_file = Path(self.tempdir.name) / "gateway-local-secrets.json" local_secrets = json.loads(secret_file.read_text(encoding="utf-8")) @@ -1125,7 +1215,14 @@ def test_gateway_dashboard_cards_configure_im_accounts(self) -> None: started = self.app.dispatch( "POST", "/v1/operator/gateway", - body=self._body({"service": "feishu", "action": "start", "accountId": "ops-feishu", "transport": "long-connection"}), + body=self._body( + { + "service": "feishu", + "action": "start", + "accountId": "ops-feishu", + "transport": "long-connection", + } + ), ) self.assertEqual(started.status_code, 200) command = run_mock.call_args.args[0] @@ -1273,9 +1370,15 @@ def test_internal_dashboard_exposes_cli_linked_control_surfaces(self) -> None: self.assertIn("globalConfig", operations["settings"]) self.assertNotIn("eggStateFiles", operations["settings"]) self.assertNotIn("eggStateFilesDir", operations["settings"]) - self.assertNotIn("models.state_focus_mode", json.dumps(operations["settings"], sort_keys=True)) + self.assertNotIn( + "models.state_focus_mode", + json.dumps(operations["settings"], sort_keys=True), + ) elephant = next(row for row in payload["herd"] if row["elephant_id"] == "profile-console") - self.assertEqual(elephant["elephant_identity_file"]["path"], str(elephant_root / "ELEPHANT.md")) + self.assertEqual( + elephant["elephant_identity_file"]["path"], + str(elephant_root / "ELEPHANT.md"), + ) self.assertTrue(elephant["elephant_identity_file"]["exists"]) self.assertIn("- Stay exact.", elephant["elephant_identity_file"]["text"]) self.assertTrue(operations["skills"]) @@ -1304,7 +1407,10 @@ def test_internal_dashboard_exposes_cli_linked_control_surfaces(self) -> None: profile_json = Path(patched.payload["profileManifestPath"]) self.assertTrue(profile_json.exists()) patched_config = load_global_config(profile_json, state_dir=self.app.repository.database_path.parent) - self.assertEqual(patched_config["runtime"]["state_dir"], str(self.app.repository.database_path.parent)) + self.assertEqual( + patched_config["runtime"]["state_dir"], + str(self.app.repository.database_path.parent), + ) global_config = self.app.dispatch( "PATCH", @@ -1355,7 +1461,11 @@ def test_internal_dashboard_exposes_cli_linked_control_surfaces(self) -> None: "serverLabel": "Filesystem", "transport": "stdio", "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp/demo"], + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp/demo", + ], "env": {"ALLOW": "1"}, "displayName": "Read File", "description": "Read a file from the mounted elephant file area.", @@ -1388,7 +1498,9 @@ def test_internal_dashboard_exposes_cli_linked_control_surfaces(self) -> None: refreshed = self.app.dispatch("GET", "/v1/internal/dashboard/tools") custom_mcp_tool = next( - tool for tool in refreshed.payload["dashboard"]["operations"]["mcp"]["tools"] if tool["toolKey"] == "filesystem:read_file" + tool + for tool in refreshed.payload["dashboard"]["operations"]["mcp"]["tools"] + if tool["toolKey"] == "filesystem:read_file" ) self.assertEqual(custom_mcp_tool["displayName"], "Read File") self.assertTrue(custom_mcp_tool["enabled"]) @@ -1415,9 +1527,7 @@ def test_internal_dashboard_exposes_cli_linked_control_surfaces(self) -> None: stored_global_config["mcp_servers"]["filesystem"]["tools"]["read_file"]["display_name"], "Read File (updated)", ) - self.assertTrue( - stored_global_config["mcp_servers"]["filesystem"]["tools"]["read_file"]["touches_secrets"] - ) + self.assertTrue(stored_global_config["mcp_servers"]["filesystem"]["tools"]["read_file"]["touches_secrets"]) runtime_tool = self.app.tool_runtime.describe("mcp.filesystem.read_file") self.assertEqual(runtime_tool.display_name, "Read File (updated)") self.assertTrue(runtime_tool.side_effects.touches_secrets) @@ -1440,7 +1550,9 @@ def test_internal_dashboard_exposes_cli_linked_control_surfaces(self) -> None: self.assertFalse(self.app.tool_runtime.describe("mcp.filesystem.read_file").enabled) refreshed = self.app.dispatch("GET", "/v1/internal/dashboard/tools") custom_mcp_tool = next( - tool for tool in refreshed.payload["dashboard"]["operations"]["mcp"]["tools"] if tool["toolKey"] == "filesystem:read_file" + tool + for tool in refreshed.payload["dashboard"]["operations"]["mcp"]["tools"] + if tool["toolKey"] == "filesystem:read_file" ) self.assertFalse(custom_mcp_tool["enabled"]) @@ -1466,7 +1578,9 @@ def test_internal_dashboard_exposes_cli_linked_control_surfaces(self) -> None: {server["serverId"] for server in refreshed.payload["dashboard"]["operations"]["mcp"]["servers"]}, ) - def test_operator_mcp_server_sync_persists_multiple_tools_and_deletes_server(self) -> None: + def test_operator_mcp_server_sync_persists_multiple_tools_and_deletes_server( + self, + ) -> None: synced_server = self.app.dispatch( "POST", "/v1/operator/mcp/servers", @@ -1569,7 +1683,9 @@ def test_operator_mcp_server_sync_persists_multiple_tools_and_deletes_server(sel self.assertNotIn("km", stored_global_config.get("mcp_servers", {})) self.assertIsNone(self.app.tool_runtime.describe("mcp.km.list_articles")) - def test_internal_dashboard_surfaces_configured_external_skill_shelves(self) -> None: + def test_internal_dashboard_surfaces_configured_external_skill_shelves( + self, + ) -> None: external_root = Path(self.tempdir.name) / ".agents" / "skills" skill_dir = external_root / "personal-journal" skill_dir.mkdir(parents=True) @@ -1698,7 +1814,11 @@ def fake_run(command: list[str], **kwargs) -> subprocess.CompletedProcess[str]: "serverId": "filesystem", "transport": "stdio", "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp/demo"], + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp/demo", + ], "env": {"ALLOW": "1"}, } ), @@ -1759,11 +1879,16 @@ def test_internal_dashboard_keeps_durable_state_after_episode_delete(self) -> No self.assertEqual(dashboard.status_code, 200) payload = dashboard.payload["dashboard"] self.assertNotIn("sessions", payload) - self.assertIn("profile-orphan", [elephant["personal_model_id"] for elephant in payload["herd"]]) + self.assertIn( + "profile-orphan", + [elephant["personal_model_id"] for elephant in payload["herd"]], + ) self.assertIn("state:profile-orphan", [state["state_id"] for state in payload["states"]]) self.assertEqual(payload["overview"]["counts"]["episodes"], 0) - def test_internal_dashboard_excludes_personal_model_growth_state_lanes(self) -> None: + def test_internal_dashboard_excludes_personal_model_growth_state_lanes( + self, + ) -> None: created = self.app.dispatch( "POST", "/v1/episodes", @@ -1792,7 +1917,8 @@ def test_internal_dashboard_excludes_personal_model_growth_state_lanes(self) -> self.assertEqual(dashboard.status_code, 200) elephant = next( - elephant for elephant in dashboard.payload["dashboard"]["herd"] + elephant + for elephant in dashboard.payload["dashboard"]["herd"] if elephant["elephant_id"] == "profile-stale-growth" ) self.assertNotIn("growth_score", elephant) @@ -1804,7 +1930,9 @@ def test_operator_namespace_no_longer_exposes_public_dashboard_reads(self) -> No self.assertEqual(dashboard.status_code, 404) self.assertEqual(console.status_code, 404) - def test_wsgi_get_request_with_no_content_length_returns_without_blocking(self) -> None: + def test_wsgi_get_request_with_no_content_length_returns_without_blocking( + self, + ) -> None: captured: dict[str, object] = {} def start_response(status: str, headers: list[tuple[str, str]]) -> None: @@ -1824,9 +1952,14 @@ def start_response(status: str, headers: list[tuple[str, str]]) -> None: ) self.assertEqual(captured["status"], "200 OK") - self.assertEqual(json.loads(body.decode("utf-8")), {"status": "ok", "service": "elephant-api"}) + self.assertEqual( + json.loads(body.decode("utf-8")), + {"status": "ok", "service": "elephant-api"}, + ) - def test_internal_dashboard_projection_surfaces_canonical_runtime_and_evidence(self) -> None: + def test_internal_dashboard_projection_surfaces_canonical_runtime_and_evidence( + self, + ) -> None: provider_profile = self._provider_profile( profile_id="provider-dashboard", base_url=self.stub.openai_base_url, @@ -2004,7 +2137,11 @@ def test_internal_dashboard_projection_surfaces_canonical_runtime_and_evidence(s }, ), ), - metadata={"embedding_active": "true", "dimensions": "1536", "configured_from": "test"}, + metadata={ + "embedding_active": "true", + "dimensions": "1536", + "configured_from": "test", + }, ) ) self.app.repository.upsert_semantic_index_entry( @@ -2042,7 +2179,15 @@ def test_internal_dashboard_projection_surfaces_canonical_runtime_and_evidence(s ) ) - projection = self._dashboard_sections("overview", "personal-models", "runtime", "reflect", "evidence", "providers", "usage") + projection = self._dashboard_sections( + "overview", + "personal-models", + "runtime", + "reflect", + "evidence", + "providers", + "usage", + ) self.assertEqual(projection["overview"]["counts"]["personal_models"], 1) self.assertEqual(projection["overview"]["counts"]["states"], 1) self.assertEqual(projection["overview"]["counts"]["episodes"], 1) @@ -2072,12 +2217,14 @@ def test_internal_dashboard_projection_surfaces_canonical_runtime_and_evidence(s self.assertEqual(personal_model_row["states"][0]["state_id"], state.state_id) self.assertEqual(personal_model_row["user_preferred_name"], "Bit") self.assertEqual(personal_model_row["user_profile"]["preferred_name"], "Bit") - self.assertEqual(personal_model_row["user_profile"]["current_work"], "Building durable agent systems.") + self.assertEqual( + personal_model_row["user_profile"]["current_work"], + "Building durable agent systems.", + ) overview_only = self._dashboard_section("overview") self.assertEqual(overview_only["personal_models"][0]["user_preferred_name"], "Bit") component_rows = { - component["component_key"]: component - for component in personal_model_row["understanding_components"] + component["component_key"]: component for component in personal_model_row["understanding_components"] } self.assertEqual(component_rows["identity"]["status"], "active") self.assertEqual(component_rows["identity"]["claim_count"], 2) @@ -2086,11 +2233,17 @@ def test_internal_dashboard_projection_surfaces_canonical_runtime_and_evidence(s self.assertEqual(personal_model_row["personal_model_fact_count"], 3) personal_model_fact_text = {fact["text"] for fact in personal_model_row["personal_model_facts"]} self.assertIn("Prefers concise, grounded replies.", personal_model_fact_text) - self.assertNotIn("State-only tool test evidence", json.dumps(personal_model_row, sort_keys=True)) + self.assertNotIn( + "State-only tool test evidence", + json.dumps(personal_model_row, sort_keys=True), + ) self.assertNotIn("Display name: Miles", json.dumps(personal_model_row, sort_keys=True)) self.assertNotIn("reflection_proposals", personal_model_row) self.assertNotIn("skill_affinities", personal_model_row) - self.assertEqual(personal_model_row["semantic_index_entries"][0]["semantic_index_entry_id"], "semantic-dashboard") + self.assertEqual( + personal_model_row["semantic_index_entries"][0]["semantic_index_entry_id"], + "semantic-dashboard", + ) self.assertEqual(projection["runtime"]["episodes"][0]["episode_id"], episode.episode_id) self.assertEqual(projection["runtime"]["episodes"][0]["loop_count"], 1) self.assertEqual(projection["runtime"]["episodes"][0]["step_count"], 1) @@ -2101,7 +2254,10 @@ def test_internal_dashboard_projection_surfaces_canonical_runtime_and_evidence(s for legacy_table in LEGACY_STORAGE_TABLES: self.assertNotIn(f"result_{legacy_table}", projection["learning"]["jobs"][0]) self.assertEqual(projection["learning"]["jobs"][0]["result_status"], "completed") - self.assertEqual(projection["learning"]["jobs"][0]["learning_result"]["summary"], "Dashboard learning result.") + self.assertEqual( + projection["learning"]["jobs"][0]["learning_result"]["summary"], + "Dashboard learning result.", + ) self.assertEqual( projection["runtime"]["episode_traces"][0]["timeline"][0]["detail"]["assistant_reasoning"], "Inspect provider posture before opening the dashboard trace.", @@ -2131,13 +2287,18 @@ def test_internal_dashboard_projection_surfaces_canonical_runtime_and_evidence(s self.assertNotIn("state_focus_mode", projection["providers"]["active_provider"]) self.assertNotIn("strong_model", projection["providers"]["active_provider"]) self.assertNotIn("weak_model", projection["providers"]["active_provider"]) - self.assertNotIn("state_focus_mode", json.dumps(projection["providers"]["doctor"], sort_keys=True)) + self.assertNotIn( + "state_focus_mode", + json.dumps(projection["providers"]["doctor"], sort_keys=True), + ) self.assertNotIn("stateLanes", projection) self.assertNotIn("sessions", projection) serialized = json.dumps(projection, sort_keys=True) self.assertNotIn("sk-live-123", serialized) - def test_internal_dashboard_projection_ignores_legacy_session_graph_rows(self) -> None: + def test_internal_dashboard_projection_ignores_legacy_session_graph_rows( + self, + ) -> None: provider_profile = self._provider_profile( profile_id="provider-dashboard", base_url=self.stub.openai_base_url, @@ -2174,7 +2335,9 @@ def test_internal_dashboard_projection_ignores_legacy_session_graph_rows(self) - self.assertNotIn("sessions", projection) self.assertNotIn("ops", projection) - def test_default_provider_bad_request_hides_legacy_profile_field_names(self) -> None: + def test_default_provider_bad_request_hides_legacy_profile_field_names( + self, + ) -> None: response = self.app.dispatch( "POST", "/v1/providers/default", @@ -2194,7 +2357,14 @@ def _dashboard_section(self, section: str) -> dict[str, Any]: def _dashboard_sections(self, *sections: str) -> dict[str, Any]: top_level_keys = { - "overview": ("overview", "herd", "states", "personal_models", "runtime", "learning"), + "overview": ( + "overview", + "herd", + "states", + "personal_models", + "runtime", + "learning", + ), "personal-models": ("personal_models",), "herd": ("herd", "states"), "runtime": ("herd", "states", "runtime"), diff --git a/tests/e2e/cli/test_cli_surface.py b/tests/e2e/cli/test_cli_surface.py index 0ec209d..a3a52dc 100644 --- a/tests/e2e/cli/test_cli_surface.py +++ b/tests/e2e/cli/test_cli_surface.py @@ -230,8 +230,8 @@ def do_POST(self) -> None: # noqa: N802 elif prompt_head == "search xunzhuo liu": content = ( "\n" - "\n" - "xunzhuo liu\n" + '\n' + 'xunzhuo liu\n' "\n" "" ) @@ -571,8 +571,14 @@ def test_setup_and_grow_cli_flow(self) -> None: self.assertIn("provider_status · ready", health.stdout) self.assertIn("security_status · ready", health.stdout) self.assertIn("active_provider_model · openai/gpt-4o-mini", health.stdout) - self.assertRegex(health.stdout, rf"active_provider_embedding_bootstrap · {EMBEDDING_BOOTSTRAP_STATUS_PATTERN}") - self.assertRegex(health.stdout, rf"active_provider_embedding_ready · {EMBEDDING_BOOTSTRAP_READY_PATTERN}") + self.assertRegex( + health.stdout, + rf"active_provider_embedding_bootstrap · {EMBEDDING_BOOTSTRAP_STATUS_PATTERN}", + ) + self.assertRegex( + health.stdout, + rf"active_provider_embedding_ready · {EMBEDDING_BOOTSTRAP_READY_PATTERN}", + ) self.assertNotIn("state_focus_mode", health.stdout) turn = self._run("wake", "--message", "Who are you?") @@ -616,7 +622,9 @@ def test_born_persists_runtime_secret_file_for_future_surfaces(self) -> None: payload = json.loads(secret_path.read_text(encoding="utf-8")) self.assertEqual(payload["ELEPHANT_OPENROUTER_API_KEY"], "sk-cli-test-123") - def test_init_surfaces_embedding_bootstrap_without_exposing_state_focus_mode(self) -> None: + def test_init_surfaces_embedding_bootstrap_without_exposing_state_focus_mode( + self, + ) -> None: setup = self._run( "init", "--non-interactive", @@ -631,18 +639,29 @@ def test_init_surfaces_embedding_bootstrap_without_exposing_state_focus_mode(sel "--api-key", "sk-cli-test-123", ) - self.assertRegex(setup.stdout, rf"embedding_bootstrap_status · {EMBEDDING_BOOTSTRAP_STATUS_PATTERN}") - self.assertRegex(setup.stdout, rf"embedding_bootstrap_ready · {EMBEDDING_BOOTSTRAP_READY_PATTERN}") + self.assertRegex( + setup.stdout, + rf"embedding_bootstrap_status · {EMBEDDING_BOOTSTRAP_STATUS_PATTERN}", + ) + self.assertRegex( + setup.stdout, + rf"embedding_bootstrap_ready · {EMBEDDING_BOOTSTRAP_READY_PATTERN}", + ) self.assertNotIn("state_focus_mode", setup.stdout) config = load_global_config(global_config_path_for_state_dir(self.state_dir), state_dir=self.state_dir) self.assertEqual(config["models"]["provider"]["default_model"], "openai/gpt-4o-mini") health = self._run("status") - self.assertRegex(health.stdout, rf"active_provider_embedding_bootstrap · {EMBEDDING_BOOTSTRAP_STATUS_PATTERN}") + self.assertRegex( + health.stdout, + rf"active_provider_embedding_bootstrap · {EMBEDDING_BOOTSTRAP_STATUS_PATTERN}", + ) self.assertNotIn("state_focus_mode", health.stdout) - def test_provider_embeddings_switch_between_local_default_and_configured_override(self) -> None: + def test_provider_embeddings_switch_between_local_default_and_configured_override( + self, + ) -> None: self._run( "init", "--non-interactive", @@ -689,13 +708,22 @@ def test_provider_embeddings_switch_between_local_default_and_configured_overrid reverted = self._run("provider", "embeddings", "local") self.assertIn("Embedding provider updated", reverted.stdout) self.assertIn("source · local-default", reverted.stdout) - self.assertRegex(reverted.stdout, rf"embedding_bootstrap_status · {EMBEDDING_BOOTSTRAP_STATUS_PATTERN}") - self.assertRegex(reverted.stdout, rf"embedding_bootstrap_ready · {EMBEDDING_BOOTSTRAP_READY_PATTERN}") + self.assertRegex( + reverted.stdout, + rf"embedding_bootstrap_status · {EMBEDDING_BOOTSTRAP_STATUS_PATTERN}", + ) + self.assertRegex( + reverted.stdout, + rf"embedding_bootstrap_ready · {EMBEDDING_BOOTSTRAP_READY_PATTERN}", + ) refreshed = CliRuntime.create(state_dir=self.state_dir) refreshed_summary = dict(refreshed.embedding_provider_summary()) self.assertEqual(refreshed_summary["source"], "local-default") - self.assertIn(refreshed_summary["embedding_bootstrap_status"], EMBEDDING_BOOTSTRAP_STATUSES) + self.assertIn( + refreshed_summary["embedding_bootstrap_status"], + EMBEDDING_BOOTSTRAP_STATUSES, + ) def test_setup_hands_off_to_wake_surface(self) -> None: setup = self._run( @@ -806,8 +834,14 @@ def test_launcher_help_lists_gateway_skills_and_dashboard(self) -> None: self.assertIn("Elephant Agent launcher", help_output.stdout) self.assertIn("Elephant Agent is personal-model-first AI", help_output.stdout) self.assertEqual(help_output.stdout.count("Elephant Agent is personal-model-first AI"), 1) - self.assertIn("Warm, steady ways back to the elephant that remembers your path.", help_output.stdout) - self.assertIn("🐘 Model what matters · 👂 Ask gently · 🐾 Follow the path", help_output.stdout) + self.assertIn( + "Warm, steady ways back to the elephant that remembers your path.", + help_output.stdout, + ) + self.assertIn( + "🐘 Model what matters · 👂 Ask gently · 🐾 Follow the path", + help_output.stdout, + ) self.assertIn("Commands", help_output.stdout) expected_order = [ "• init", @@ -831,7 +865,10 @@ def test_launcher_no_args_prints_single_root_cli_surface(self) -> None: overview = self._run_launcher() self.assertNotIn("Welcome", overview.stdout) self.assertIn("Elephant Agent CLI", overview.stdout) - self.assertIn("🐘 Model what matters · 👂 Ask gently · 🐾 Follow the path", overview.stdout) + self.assertIn( + "🐘 Model what matters · 👂 Ask gently · 🐾 Follow the path", + overview.stdout, + ) self.assertIn("Elephant Agent is personal-model-first AI", overview.stdout) self.assertEqual(overview.stdout.count("Elephant Agent is personal-model-first AI"), 1) self.assertIn("elephant init", overview.stdout) @@ -1137,7 +1174,9 @@ def test_grow_debug_mode_surfaces_debug_elephant_context(self) -> None: self.assertIn("Bring whatever you want to work on; I will adapt from here.", shell) self.assertIn("Elephant Agent stays by your side.", shell) - def test_non_interactive_elephant_creates_state_without_activity_command(self) -> None: + def test_non_interactive_elephant_creates_state_without_activity_command( + self, + ) -> None: self._run( "init", "--non-interactive", @@ -1162,7 +1201,9 @@ def test_non_interactive_elephant_creates_state_without_activity_command(self) - self.assertIn("personal_model_id · you", created.stdout) self.assertNotIn("active_goal", created.stdout) - def test_elephant_name_is_required_and_elephants_delete_clears_named_or_all_elephants(self) -> None: + def test_elephant_name_is_required_and_elephants_delete_clears_named_or_all_elephants( + self, + ) -> None: missing_name = self._run("herd", "new", check=False) self.assertEqual(missing_name.returncode, 1) self.assertIn("Elephant blocked", missing_name.stdout) @@ -1293,7 +1334,9 @@ def test_elephant_message_provider_failure_renders_recovery_card(self) -> None: self.assertIn("elephant wake --elephant-id provider-fail", failed.stdout) self.assertNotIn("Traceback", failed.stderr) - def test_elephant_create_persists_canonical_state_under_default_personal_model(self) -> None: + def test_elephant_create_persists_canonical_state_under_default_personal_model( + self, + ) -> None: self._run( "init", "--non-interactive", @@ -1349,16 +1392,16 @@ def test_elephant_create_uses_canonical_episode_storage_only(self) -> None: ).fetchone() table_names = { str(table_row[0]) - for table_row in connection.execute( - "SELECT name FROM sqlite_master WHERE type = 'table'" - ).fetchall() + for table_row in connection.execute("SELECT name FROM sqlite_master WHERE type = 'table'").fetchall() } self.assertNotIn("sessions", table_names) self.assertIsNotNone(row) self.assertEqual(tuple(row), ("state:atlas", "you")) - def test_elephant_delete_removes_elephant_state_and_preserves_personal_model(self) -> None: + def test_elephant_delete_removes_elephant_state_and_preserves_personal_model( + self, + ) -> None: self._run( "init", "--non-interactive", @@ -1433,7 +1476,9 @@ def test_facts_cli_lists_and_deletes_personal_model_facts(self) -> None: self.assertIn("cleanup stale preference", deleted.stdout) refreshed = CliRuntime.create(state_dir=self.state_dir) - facts = refreshed.repository.list_personal_model_facts(personal_model_id=session.personal_model_id, status=("deleted",)) + facts = refreshed.repository.list_personal_model_facts( + personal_model_id=session.personal_model_id, status=("deleted",) + ) entry = next((fact for fact in facts if fact.fact_id == fact_id), None) self.assertIsNotNone(entry) assert entry is not None @@ -1443,7 +1488,9 @@ def test_facts_cli_lists_and_deletes_personal_model_facts(self) -> None: self.assertNotIn(fact_id, visible.stdout) self.assertNotIn("status=deleted", visible.stdout) - def test_runtime_skill_install_persists_provenance_and_distinguishes_refresh_from_migration(self) -> None: + def test_runtime_skill_install_persists_provenance_and_distinguishes_refresh_from_migration( + self, + ) -> None: runtime = CliRuntime.create(state_dir=self.state_dir) session = runtime.create_elephant(elephant_id="atlas") github_dir = self.root / "remote-github" diff --git a/tests/e2e/deploy/__init__.py b/tests/e2e/deploy/__init__.py index 87626ce..f8f6d60 100644 --- a/tests/e2e/deploy/__init__.py +++ b/tests/e2e/deploy/__init__.py @@ -1,4 +1,5 @@ """Deploy e2e tests.""" + from __future__ import annotations import unittest diff --git a/tests/e2e/deploy/test_editable_install.py b/tests/e2e/deploy/test_editable_install.py index 67b1b43..0d3ec2d 100644 --- a/tests/e2e/deploy/test_editable_install.py +++ b/tests/e2e/deploy/test_editable_install.py @@ -45,7 +45,12 @@ def _run(self, *args: str, env: dict[str, str]) -> subprocess.CompletedProcess[s ) def test_editable_install_exposes_elephant_command(self) -> None: - subprocess.run([sys.executable, "-m", "venv", str(self.venv_dir)], cwd=ROOT, check=True, text=True) + subprocess.run( + [sys.executable, "-m", "venv", str(self.venv_dir)], + cwd=ROOT, + check=True, + text=True, + ) python_bin = self._python_bin() subprocess.run( [str(python_bin), "-m", "pip", "install", "-e", "."], diff --git a/tests/e2e/deploy/test_installed_command_smoke.py b/tests/e2e/deploy/test_installed_command_smoke.py index e121144..3312812 100644 --- a/tests/e2e/deploy/test_installed_command_smoke.py +++ b/tests/e2e/deploy/test_installed_command_smoke.py @@ -26,9 +26,7 @@ def setUp(self) -> None: self.base_url = (os.environ.get("ELEPHANT_LIVE_PROVIDER_BASE_URL") or "").strip() self.model_id = (os.environ.get("ELEPHANT_LIVE_PROVIDER_MODEL") or "").strip() self.api_key = os.environ.get("ELEPHANT_LIVE_PROVIDER_API_KEY") or "" - self.provider_id = ( - os.environ.get("ELEPHANT_LIVE_PROVIDER_PROVIDER_ID") or "openai-compatible" - ).strip() + self.provider_id = (os.environ.get("ELEPHANT_LIVE_PROVIDER_PROVIDER_ID") or "openai-compatible").strip() if not self.base_url or not self.model_id or not self.api_key: self.skipTest( "installed command smoke requires ELEPHANT_LIVE_PROVIDER_BASE_URL, " @@ -47,15 +45,12 @@ def setUp(self) -> None: self.model_id.startswith("tke/"), "installed command smoke keeps the release workflow model-id prefix contract", ) - self.require_dashboard = os.environ.get( - "ELEPHANT_LIVE_INSTALLED_SMOKE_REQUIRE_DASHBOARD" - ) == "1" + self.require_dashboard = os.environ.get("ELEPHANT_LIVE_INSTALLED_SMOKE_REQUIRE_DASHBOARD") == "1" self.dashboard_index = ROOT / "apps" / "dashboard" / "dist" / "index.html" if self.require_dashboard: self.assertTrue( self.dashboard_index.exists(), - "dashboard assets are required for the installed command smoke; " - "run make dashboard-build first", + "dashboard assets are required for the installed command smoke; run make dashboard-build first", ) self.tempdir = tempfile.TemporaryDirectory() self.root = Path(self.tempdir.name) @@ -172,9 +167,7 @@ def _run_tui_smoke(self) -> None: if not prompt_sent and time.monotonic() - start > 4: os.write( master_fd, - "请只回复 ELEPHANT_SMOKE_OK,用于 installed TUI smoke 测试。\n".encode( - "utf-8" - ), + "请只回复 ELEPHANT_SMOKE_OK,用于 installed TUI smoke 测试。\n".encode("utf-8"), ) prompt_sent = True if b"ELEPHANT_SMOKE_OK" in output and not exit_sent: @@ -202,7 +195,17 @@ def test_editable_install_runs_real_elephant_commands_and_tui(self) -> None: self.assertTrue(self._elephant_bin().exists()) help_output = self._run_elephant("--help") - for command_name in ("init", "status", "provider", "herd", "wake", "skills", "gateway", "cron", "dashboard"): + for command_name in ( + "init", + "status", + "provider", + "herd", + "wake", + "skills", + "gateway", + "cron", + "dashboard", + ): with self.subTest(command_name=command_name): self.assertIn(command_name, help_output.stdout) self.assertNotIn("chat", help_output.stdout) diff --git a/tests/e2e/deploy/test_installed_user_journey.py b/tests/e2e/deploy/test_installed_user_journey.py index 2331380..99c53e3 100644 --- a/tests/e2e/deploy/test_installed_user_journey.py +++ b/tests/e2e/deploy/test_installed_user_journey.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path import subprocess import sys import unittest @@ -61,7 +60,10 @@ def _drive_dashboard_chat(dashboard_url: str) -> None: class InstalledUserJourneyE2ETest(unittest.TestCase): def test_editable_install_cli_daemon_dashboard_and_mock_chat(self) -> None: - self.assertTrue(DASHBOARD_INDEX.exists(), "dashboard assets are required; run make dashboard-build first") + self.assertTrue( + DASHBOARD_INDEX.exists(), + "dashboard assets are required; run make dashboard-build first", + ) provider = MockOpenAICompatibleProvider().start() try: diff --git a/tests/e2e/deploy/test_preview_deploy.py b/tests/e2e/deploy/test_preview_deploy.py index 00135f0..d5b6bcb 100644 --- a/tests/e2e/deploy/test_preview_deploy.py +++ b/tests/e2e/deploy/test_preview_deploy.py @@ -33,19 +33,8 @@ def test_preview_assets_and_docker_smoke(self) -> None: subprocess.run(["make", "site-build"], cwd=ROOT, check=True, text=True) dist_index = ROOT / "apps" / "site" / "dist" / "index.html" dist_docs = ROOT / "apps" / "site" / "dist" / "docs" / "index.html" - dist_docs_system_model = ( - ROOT / "apps" / "site" / "dist" / "docs" / "philosophy" / "system-model" / "index.html" - ) - dist_docs_tools = ( - ROOT - / "apps" - / "site" - / "dist" - / "docs" - / "capacities" - / "tools" - / "index.html" - ) + dist_docs_system_model = ROOT / "apps" / "site" / "dist" / "docs" / "philosophy" / "system-model" / "index.html" + dist_docs_tools = ROOT / "apps" / "site" / "dist" / "docs" / "capacities" / "tools" / "index.html" dist_install_script = ROOT / "apps" / "site" / "dist" / "install.sh" dist_robots = ROOT / "apps" / "site" / "dist" / "robots.txt" dist_sitemap = ROOT / "apps" / "site" / "dist" / "sitemap.xml" diff --git a/tests/e2e/deploy/test_public_install_script.py b/tests/e2e/deploy/test_public_install_script.py index 018b1ab..2fac6e3 100644 --- a/tests/e2e/deploy/test_public_install_script.py +++ b/tests/e2e/deploy/test_public_install_script.py @@ -82,7 +82,9 @@ def test_install_upgrade_and_health(self) -> None: health = self._run_install("health") self.assertIn("Elephant Agent status", health.stdout) - def test_public_install_script_defaults_to_dev_channel_and_supports_stable_override(self) -> None: + def test_public_install_script_defaults_to_dev_channel_and_supports_stable_override( + self, + ) -> None: content = (ROOT / "install.sh").read_text(encoding="utf-8") self.assertIn('channel="${ELEPHANT_INSTALL_CHANNEL:-dev}"', content) self.assertIn("--channel CHANNEL", content) @@ -102,7 +104,7 @@ def test_publish_workflow_builds_dev_versions_from_main(self) -> None: self.assertIn("- main", content) self.assertIn("environment: PYPI_API_TOKEN", content) self.assertIn("f'version = \"{target}\"'", content) - self.assertIn("base = re.sub(r\"\\.dev\\d+$\", \"\", current)", content) + self.assertIn('base = re.sub(r"\\.dev\\d+$", "", current)', content) self.assertIn("make package-build", content) self.assertIn("make package-verify", content) self.assertIn("apps/site/node_modules", makefile_content) @@ -114,7 +116,10 @@ def test_publish_workflow_builds_dev_versions_from_main(self) -> None: def test_pyproject_excludes_site_packages_from_python_distribution(self) -> None: content = (ROOT / "pyproject.toml").read_text(encoding="utf-8") - self.assertIn('include = ["apps", "apps.api*", "apps.cli*", "apps.dashboard*", "apps.gateway*", "packages*"]', content) + self.assertIn( + 'include = ["apps", "apps.api*", "apps.cli*", "apps.dashboard*", "apps.gateway*", "packages*"]', + content, + ) self.assertIn('exclude = ["tests*", ".worktrees*", "apps.site*"]', content) self.assertIn('"apps.dashboard" = ["dist/*", "dist/assets/*"]', content) diff --git a/tests/e2e/deploy/test_wheel_install.py b/tests/e2e/deploy/test_wheel_install.py index b9aeb85..26f804c 100644 --- a/tests/e2e/deploy/test_wheel_install.py +++ b/tests/e2e/deploy/test_wheel_install.py @@ -39,10 +39,24 @@ def _elephant_bin(self) -> Path: return self.install_venv / "bin" / "elephant" def test_built_wheel_installs_cleanly(self) -> None: - subprocess.run([sys.executable, "-m", "venv", str(self.build_venv)], cwd=ROOT, check=True, text=True) + subprocess.run( + [sys.executable, "-m", "venv", str(self.build_venv)], + cwd=ROOT, + check=True, + text=True, + ) build_python = self._python_bin(self.build_venv) subprocess.run( - [str(build_python), "-m", "pip", "wheel", ".", "--no-deps", "-w", str(self.dist_dir)], + [ + str(build_python), + "-m", + "pip", + "wheel", + ".", + "--no-deps", + "-w", + str(self.dist_dir), + ], cwd=ROOT, check=True, text=True, @@ -54,7 +68,12 @@ def test_built_wheel_installs_cleanly(self) -> None: with zipfile.ZipFile(wheels[0]) as wheel: self.assertNotIn("packages/state/ELEPHANT.md", wheel.namelist()) - subprocess.run([sys.executable, "-m", "venv", str(self.install_venv)], cwd=ROOT, check=True, text=True) + subprocess.run( + [sys.executable, "-m", "venv", str(self.install_venv)], + cwd=ROOT, + check=True, + text=True, + ) install_python = self._python_bin(self.install_venv) subprocess.run( [str(install_python), "-m", "pip", "install", str(wheels[0])], diff --git a/tests/e2e/gateway/__init__.py b/tests/e2e/gateway/__init__.py index 8b13789..e69de29 100644 --- a/tests/e2e/gateway/__init__.py +++ b/tests/e2e/gateway/__init__.py @@ -1 +0,0 @@ - diff --git a/tests/e2e/gateway/test_gateway_adapter.py b/tests/e2e/gateway/test_gateway_adapter.py index 49d4e51..04ab187 100644 --- a/tests/e2e/gateway/test_gateway_adapter.py +++ b/tests/e2e/gateway/test_gateway_adapter.py @@ -8,7 +8,6 @@ import json import os from pathlib import Path -import signal import sys import tempfile import threading @@ -51,7 +50,7 @@ import apps.gateway.__main__ as gateway_main from apps.gateway.__main__ import command_main from apps.gateway.gateway_main_parser import _build_app -from apps.provider_runtime import provider_profile_from_payload, runtime_local_secret_env_path +from apps.provider_runtime import provider_profile_from_payload from packages.gateway_core import ( DEFAULT_GATEWAY_ACCOUNT_ID, GatewayAccountRef, @@ -65,7 +64,12 @@ from packages.contracts.layers import Episode from packages.contracts.runtime import EvidenceRetrievalRequest from packages.models import SurfaceModelProviderCapability -from packages.runtime_config import global_config_path_for_state_dir, load_global_config, save_provider_to_config, write_global_config +from packages.runtime_config import ( + global_config_path_for_state_dir, + load_global_config, + save_provider_to_config, + write_global_config, +) from packages.security.runtime import PolicyDecision from packages.storage import RuntimeStorageRepository @@ -142,6 +146,7 @@ def setUp(self) -> None: }, } self._write_profile_manifest(self.profile_manifest) + def tearDown(self) -> None: self.ensure_discord_sdk_patcher.stop() self.ensure_feishu_sdk_patcher.stop() @@ -155,7 +160,11 @@ def test_gateway_recall_capability_accepts_episode_scope(self) -> None: class FakeRecallRuntime: def retrieve_evidence(self, request): calls.append({"evidence_request": request}) - return SimpleNamespace(candidates=(SimpleNamespace(recall="personal-recall"),), scope_episode_ids=request.lineage_episode_ids, scope_reason=request.scope_reason) + return SimpleNamespace( + candidates=(SimpleNamespace(recall="personal-recall"),), + scope_episode_ids=request.lineage_episode_ids, + scope_reason=request.scope_reason, + ) capability = GatewayRecallCapability(FakeRecallRuntime()) @@ -170,7 +179,10 @@ def retrieve_evidence(self, request): ) retrieval = capability.retrieve_evidence(evidence_request) self.assertEqual(retrieval.candidates[0].recall, "personal-recall") - self.assertEqual(calls[0]["evidence_request"].scopes, ("episode", "elephant", "personal_model")) + self.assertEqual( + calls[0]["evidence_request"].scopes, + ("episode", "elephant", "personal_model"), + ) self.assertEqual(calls[0]["evidence_request"].personal_model_id, "personal-model:zoey") def test_gateway_cli_app_reuses_cli_provider_when_im_profile_has_none(self) -> None: @@ -211,9 +223,14 @@ def test_gateway_cli_app_reuses_cli_provider_when_im_profile_has_none(self) -> N self.assertEqual(app.provider_runtime["provider_id"], "openai-compatible") self.assertEqual(app.provider_runtime["default_model"], "openai/gpt-4o-mini") self.assertEqual(app.provider_runtime["source"], "configured") - self.assertEqual(app.model_provider.surface.resolve_credentials(app.provider_profile)["api_key"], "sk-cli-local-vault") + self.assertEqual( + app.model_provider.surface.resolve_credentials(app.provider_profile)["api_key"], + "sk-cli-local-vault", + ) - def test_gateway_cli_app_reuses_default_local_provider_when_dashboard_profile_has_none(self) -> None: + def test_gateway_cli_app_reuses_default_local_provider_when_dashboard_profile_has_none( + self, + ) -> None: gateway_profile_dir = Path(self.tempdir.name) / "dashboard-profile" cli_profile_dir = Path(self.tempdir.name) / "dashboard-cli-profile" default_home = Path(self.tempdir.name) / "default-home" @@ -221,10 +238,17 @@ def test_gateway_cli_app_reuses_default_local_provider_when_dashboard_profile_ha gateway_profile_dir.mkdir() cli_profile_dir.mkdir() default_profile_dir.mkdir(parents=True) - minimal_manifest = {"profile_id": "profile:gateway", "display_name": "Gateway", "mode": "default"} + minimal_manifest = { + "profile_id": "profile:gateway", + "display_name": "Gateway", + "mode": "default", + } (gateway_profile_dir / "profile.json").write_text(json.dumps(minimal_manifest), encoding="utf-8") (cli_profile_dir / "profile.json").write_text(json.dumps(minimal_manifest), encoding="utf-8") - (default_profile_dir / "profile.json").write_text((self.profile_dir / "profile.json").read_text(encoding="utf-8"), encoding="utf-8") + (default_profile_dir / "profile.json").write_text( + (self.profile_dir / "profile.json").read_text(encoding="utf-8"), + encoding="utf-8", + ) default_state_dir = default_home / "herd" default_state_dir.mkdir(parents=True) save_provider_to_config( @@ -310,7 +334,10 @@ def test_gateway_help_omits_hidden_top_level_aliases(self) -> None: self.assertEqual(exit_info.exception.code, 0) rendered = output.getvalue() self.assertNotIn("==SUPPRESS==", rendered) - self.assertIn("{setup,status,doctor,describe,feishu,discord,dingding,weixin,wecom}", rendered) + self.assertIn( + "{setup,status,doctor,describe,feishu,discord,dingding,weixin,wecom}", + rendered, + ) self.assertNotIn("\n serve", rendered) self.assertNotIn("\n add", rendered) @@ -640,7 +667,9 @@ async def send_request(self, request, *, account): self.requests.append((normalized_request, account)) return {"id": "discord-reply-1"} - def test_gateway_add_feishu_command_writes_secret_reference_profile_config(self) -> None: + def test_gateway_add_feishu_command_writes_secret_reference_profile_config( + self, + ) -> None: self._update_manifest(lambda payload: payload.pop("gateway", None)) output = io.StringIO() @@ -778,7 +807,9 @@ def test_ensure_discord_sdk_available_installs_missing_dependency(self) -> None: finally: self.ensure_discord_sdk = self.ensure_discord_sdk_patcher.start() - def test_gateway_add_discord_command_writes_profile_config_and_local_secret(self) -> None: + def test_gateway_add_discord_command_writes_profile_config_and_local_secret( + self, + ) -> None: self._update_manifest(lambda payload: payload["gateway"]["adapters"].pop("discord", None)) output = io.StringIO() @@ -851,13 +882,24 @@ def test_gateway_add_discord_command_writes_profile_config_and_local_secret(self self.assertEqual(account["credentials_status"], "configured") self.assertEqual(account["bot_token_env_var"], "ELEPHANT_TEST_DISCORD_BOT_TOKEN") - def test_gateway_add_discord_command_uses_wizard_by_default_when_shell_is_interactive(self) -> None: + def test_gateway_add_discord_command_uses_wizard_by_default_when_shell_is_interactive( + self, + ) -> None: self._update_manifest(lambda payload: payload["gateway"]["adapters"].pop("discord", None)) output = io.StringIO() with ( - mock.patch("apps.gateway.gateway_main_setup_impl._interactive_shell_supported", return_value=True), - mock.patch("apps.gateway.gateway_main_setup_impl._start_discord_runtime_after_setup", return_value=0) as auto_start, - mock.patch("apps.gateway.gateway_main_setup_impl.getpass.getpass", return_value="wizard-discord-token"), + mock.patch( + "apps.gateway.gateway_main_setup_impl._interactive_shell_supported", + return_value=True, + ), + mock.patch( + "apps.gateway.gateway_main_setup_impl._start_discord_runtime_after_setup", + return_value=0, + ) as auto_start, + mock.patch( + "apps.gateway.gateway_main_setup_impl.getpass.getpass", + return_value="wizard-discord-token", + ), redirect_stdout(output), ): exit_code = command_main( @@ -897,7 +939,9 @@ def test_gateway_add_discord_command_uses_wizard_by_default_when_shell_is_intera local_secrets = json.loads((self.state_dir / "gateway-local-secrets.json").read_text(encoding="utf-8")) self.assertEqual(local_secrets[DEFAULT_DISCORD_BOT_TOKEN_ENV], "wizard-discord-token") - def test_gateway_add_discord_command_replaces_unconfigured_default_placeholder(self) -> None: + def test_gateway_add_discord_command_replaces_unconfigured_default_placeholder( + self, + ) -> None: output = io.StringIO() with redirect_stdout(output): exit_code = command_main( @@ -927,7 +971,9 @@ def test_gateway_add_discord_command_replaces_unconfigured_default_placeholder(s rendered = output.getvalue() self.assertNotIn("Configure the Discord bot token", rendered) - def test_gateway_add_discord_command_can_disable_account_without_disabling_adapter(self) -> None: + def test_gateway_add_discord_command_can_disable_account_without_disabling_adapter( + self, + ) -> None: output = io.StringIO() with redirect_stdout(output): exit_code = command_main( @@ -956,7 +1002,9 @@ def test_gateway_add_discord_command_can_disable_account_without_disabling_adapt self.assertFalse(discord["accounts"][0]["enabled"]) self.assertIn("Discord account enabled for default runtime starts: no", output.getvalue()) - def test_gateway_add_feishu_command_updates_existing_account_without_clobbering_profile(self) -> None: + def test_gateway_add_feishu_command_updates_existing_account_without_clobbering_profile( + self, + ) -> None: output = io.StringIO() with redirect_stdout(output): exit_code = command_main( @@ -1030,7 +1078,9 @@ def test_gateway_add_feishu_command_updates_existing_account_without_clobbering_ ("secret-feishu-ops-feishu-app-id", "secret-feishu-ops-feishu-app-secret"), ) - def test_gateway_add_feishu_command_persists_local_secret_file_for_raw_credentials(self) -> None: + def test_gateway_add_feishu_command_persists_local_secret_file_for_raw_credentials( + self, + ) -> None: self._update_manifest(lambda payload: payload.pop("gateway", None)) output = io.StringIO() @@ -1082,9 +1132,15 @@ def test_im_setup_command_can_capture_raw_credentials(self) -> None: output = io.StringIO() with ( - mock.patch("apps.gateway.gateway_main_setup_impl._start_feishu_runtime_after_setup", return_value=0) as auto_start, + mock.patch( + "apps.gateway.gateway_main_setup_impl._start_feishu_runtime_after_setup", + return_value=0, + ) as auto_start, mock.patch("builtins.input", side_effect=lambda _prompt="": next(scripted_answers)), - mock.patch("apps.gateway.gateway_main_setup_impl.getpass.getpass", return_value="wizard-app-secret-789"), + mock.patch( + "apps.gateway.gateway_main_setup_impl.getpass.getpass", + return_value="wizard-app-secret-789", + ), redirect_stdout(output), ): exit_code = command_main( @@ -1136,9 +1192,15 @@ def test_im_setup_command_does_not_capture_elephant_defaults(self) -> None: scripted_answers = iter(["3", "wizard-app-id-single"]) with ( - mock.patch("apps.gateway.gateway_main_setup_impl._start_feishu_runtime_after_setup", return_value=0), + mock.patch( + "apps.gateway.gateway_main_setup_impl._start_feishu_runtime_after_setup", + return_value=0, + ), mock.patch("builtins.input", side_effect=lambda _prompt="": next(scripted_answers)), - mock.patch("apps.gateway.gateway_main_setup_impl.getpass.getpass", return_value="wizard-app-secret-single"), + mock.patch( + "apps.gateway.gateway_main_setup_impl.getpass.getpass", + return_value="wizard-app-secret-single", + ), ): exit_code = command_main( ["setup"], @@ -1174,7 +1236,10 @@ def test_gateway_feishu_help_lists_runtime_commands(self) -> None: self.assertEqual(exit_info.exception.code, 0) rendered = output.getvalue() - self.assertIn("{setup,remove,start,status,stop,restart,logs,describe,doctor,message}", rendered) + self.assertIn( + "{setup,remove,start,status,stop,restart,logs,describe,doctor,message}", + rendered, + ) self.assertIn("setup Add or update a Feishu account.", rendered) self.assertIn("remove Remove a Feishu account.", rendered) self.assertIn("status Show Feishu status.", rendered) @@ -1202,7 +1267,15 @@ def test_gateway_feishu_logs_reads_tail_and_can_print_path(self) -> None: output = io.StringIO() with redirect_stdout(output): exit_code = command_main( - ["feishu", "logs", "ops-feishu", "--transport", "long-connection", "--tail", "2"], + [ + "feishu", + "logs", + "ops-feishu", + "--transport", + "long-connection", + "--tail", + "2", + ], default_state_dir=self.state_dir, default_control_state_dir=self.state_dir, ) @@ -1213,7 +1286,14 @@ def test_gateway_feishu_logs_reads_tail_and_can_print_path(self) -> None: path_output = io.StringIO() with redirect_stdout(path_output): exit_code = command_main( - ["feishu", "logs", "ops-feishu", "--transport", "long-connection", "--path"], + [ + "feishu", + "logs", + "ops-feishu", + "--transport", + "long-connection", + "--path", + ], default_state_dir=self.state_dir, default_control_state_dir=self.state_dir, ) @@ -1236,7 +1316,13 @@ def test_gateway_feishu_status_reports_running_detached_runtime(self) -> None: "pid_path": str(pid_path), "log_path": str(self.state_dir / "feishu-long-connection.log"), "record_path": str(record_path), - "command": [sys.executable, "-m", "apps.launcher", "gateway", "start"], + "command": [ + sys.executable, + "-m", + "apps.launcher", + "gateway", + "start", + ], "profile_dir": str(self.profile_dir), "state_dir": str(self.state_dir), "cli_profile_dir": str(self.profile_dir), @@ -1270,7 +1356,14 @@ def test_gateway_feishu_stop_updates_runtime_record_and_cleans_pid(self) -> None output = io.StringIO() with redirect_stdout(output): exit_code = command_main( - ["feishu", "stop", "--transport", "long-connection", "--timeout", "0.1"], + [ + "feishu", + "stop", + "--transport", + "long-connection", + "--timeout", + "0.1", + ], default_state_dir=self.state_dir, default_control_state_dir=self.state_dir, ) @@ -1294,14 +1387,25 @@ def poll(self) -> None: side_effect=[ None, None, - {"status": "running", "pid": 54321, "state_dir": str(self.state_dir)}, + { + "status": "running", + "pid": 54321, + "state_dir": str(self.state_dir), + }, ], ), mock.patch("apps.gateway.__main__.time.sleep", return_value=None), redirect_stdout(output), ): exit_code = command_main( - ["feishu", "restart", "--transport", "long-connection", "--timeout", "0.1"], + [ + "feishu", + "restart", + "--transport", + "long-connection", + "--timeout", + "0.1", + ], default_state_dir=self.state_dir, default_control_state_dir=self.state_dir, ) @@ -1339,7 +1443,11 @@ def poll(self) -> None: "apps.daemon_command._daemon_healthz_payload", side_effect=[ None, - {"status": "running", "pid": 43210, "state_dir": str(self.state_dir)}, + { + "status": "running", + "pid": 43210, + "state_dir": str(self.state_dir), + }, ], ), mock.patch("apps.gateway.__main__.time.sleep", return_value=None), @@ -1379,7 +1487,9 @@ def poll(self) -> None: self.assertEqual(command[command.index("--cli-state-dir") + 1], str(self.state_dir)) self.assertTrue(popen.call_args.kwargs["start_new_session"]) - def test_gateway_feishu_start_detach_launches_unified_daemon_with_cli_state(self) -> None: + def test_gateway_feishu_start_detach_launches_unified_daemon_with_cli_state( + self, + ) -> None: class FakeProcess: pid = 43211 @@ -1394,7 +1504,11 @@ def poll(self) -> None: "apps.daemon_command._daemon_healthz_payload", side_effect=[ None, - {"status": "running", "pid": 43211, "state_dir": str(self.state_dir)}, + { + "status": "running", + "pid": 43211, + "state_dir": str(self.state_dir), + }, ], ), mock.patch("apps.gateway.__main__.time.sleep", return_value=None), @@ -1433,7 +1547,10 @@ def test_setup_reuses_profile_bundle_and_provider_profile(self) -> None: self.assertEqual(summary["provider"]["profile_id"], "provider-openrouter") self.assertEqual(summary["provider"]["default_model"], "openai/gpt-4o-mini") self.assertEqual(summary["provider"]["model_id"], "openai/gpt-4o-mini") - self.assertIn(summary["provider"]["embedding_bootstrap_status"], EMBEDDING_BOOTSTRAP_STATUSES) + self.assertIn( + summary["provider"]["embedding_bootstrap_status"], + EMBEDDING_BOOTSTRAP_STATUSES, + ) self.assertEqual( summary["adapter_setup"]["feishu"]["preferred_transport"], "long-connection", @@ -1509,7 +1626,9 @@ def test_gateway_chat_runtime_exposes_model_tools_and_skills(self) -> None: self.assertNotIn("tool.memory.note", model_visible) self.assertNotIn("tool.skill.manage", model_visible) - def test_gateway_chat_context_discloses_skill_index_and_allows_skill_list_tool(self) -> None: + def test_gateway_chat_context_discloses_skill_index_and_allows_skill_list_tool( + self, + ) -> None: app, _, _ = self._build() self._bind_gateway_conversation( app, @@ -1620,7 +1739,9 @@ def test_setup_summary_accepts_custom_plugin_registry_adapter(self) -> None: DEFAULT_GATEWAY_ACCOUNT_ID, ) - def test_load_discord_gateway_accounts_reads_allowlists_and_runtime_metadata(self) -> None: + def test_load_discord_gateway_accounts_reads_allowlists_and_runtime_metadata( + self, + ) -> None: self._update_manifest( lambda payload: payload["gateway"]["adapters"].update( { @@ -1654,7 +1775,9 @@ def test_load_discord_gateway_accounts_reads_allowlists_and_runtime_metadata(sel self.assertEqual(account.runtime_metadata["shard_count"], 2) self.assertEqual(tuple(account.runtime_metadata["shard_ids"]), (0, 1)) - def test_load_discord_gateway_accounts_skips_disabled_accounts_but_describe_reports_them(self) -> None: + def test_load_discord_gateway_accounts_skips_disabled_accounts_but_describe_reports_them( + self, + ) -> None: self._update_manifest( lambda payload: payload["gateway"]["adapters"].update( { @@ -1775,7 +1898,14 @@ def test_gateway_describe_all_includes_discord_service(self) -> None: "pid_path": str(pid_path), "log_path": str(log_path), "record_path": str(record_path), - "command": [sys.executable, "-m", "apps.launcher", "gateway", "discord", "start"], + "command": [ + sys.executable, + "-m", + "apps.launcher", + "gateway", + "discord", + "start", + ], "profile_dir": str(self.profile_dir), "state_dir": str(self.state_dir), "started_at": datetime.now(UTC).isoformat(), @@ -1840,7 +1970,14 @@ def test_gateway_discord_doctor_reports_runtime_state(self) -> None: "pid_path": str(pid_path), "log_path": str(self.state_dir / "discord-gateway.log"), "record_path": str(record_path), - "command": [sys.executable, "-m", "apps.launcher", "gateway", "discord", "start"], + "command": [ + sys.executable, + "-m", + "apps.launcher", + "gateway", + "discord", + "start", + ], "profile_dir": str(self.profile_dir), "state_dir": str(self.state_dir), "started_at": datetime.now(UTC).isoformat(), @@ -1878,7 +2015,10 @@ def test_gateway_discord_help_lists_runtime_commands(self) -> None: self.assertEqual(exit_info.exception.code, 0) rendered = output.getvalue() - self.assertIn("{setup,remove,start,status,stop,restart,logs,describe,doctor,message}", rendered) + self.assertIn( + "{setup,remove,start,status,stop,restart,logs,describe,doctor,message}", + rendered, + ) self.assertIn("setup Add or update a Discord account.", rendered) self.assertIn("remove Remove a Discord account.", rendered) self.assertIn("status Show Discord status.", rendered) @@ -1914,7 +2054,11 @@ def poll(self) -> None: "apps.daemon_command._daemon_healthz_payload", side_effect=[ None, - {"status": "running", "pid": 54322, "state_dir": str(self.state_dir)}, + { + "status": "running", + "pid": 54322, + "state_dir": str(self.state_dir), + }, ], ), mock.patch("apps.gateway.__main__.time.sleep", return_value=None), @@ -1954,7 +2098,9 @@ def poll(self) -> None: self.assertEqual(command[command.index("--cli-state-dir") + 1], str(self.state_dir)) self.assertTrue(launcher_calls[0].kwargs["start_new_session"]) - def test_discord_service_dispatch_event_delivers_dm_reply_with_mentions_suppressed(self) -> None: + def test_discord_service_dispatch_event_delivers_dm_reply_with_mentions_suppressed( + self, + ) -> None: self._update_manifest( lambda payload: payload["gateway"]["adapters"].update( { @@ -2161,7 +2307,16 @@ def wake(self, session_id: str, *, inspect_only: bool = False): self.assertEqual(result.response_body["session_id"], expected_session_id) self.assertEqual(result.response_body["delivery_outcome"], "delivered") self.assertEqual(result.response_body["external_message_id"], "discord-reply-1") - self.assertEqual(shared_runtime_calls, [{"session_id": expected_session_id, "prompt": "hello from discord control", "conversation_id": "dm-control-1"}]) + self.assertEqual( + shared_runtime_calls, + [ + { + "session_id": expected_session_id, + "prompt": "hello from discord control", + "conversation_id": "dm-control-1", + } + ], + ) self.assertEqual(len(delivery_transport.requests), 2) request, account = delivery_transport.requests[-1] self.assertEqual(account.account_id, "ops-discord") @@ -2172,7 +2327,9 @@ def wake(self, session_id: str, *, inspect_only: bool = False): "msg-control-1", ) - def test_weixin_and_wecom_default_control_bridge_handles_elephant_commands(self) -> None: + def test_weixin_and_wecom_default_control_bridge_handles_elephant_commands( + self, + ) -> None: self._update_manifest( lambda payload: payload["gateway"]["adapters"].update( { @@ -2288,7 +2445,10 @@ def schedule_learning_for_session(self, **kwargs) -> None: app=app, cli_runtime_factory=lambda profile_dir, state_dir: FakeCliRuntime(), default_cli_state_dir=str(self.state_dir), - environ={"ELEPHANT_TEST_WECOM_BOT_ID": "bot-id", "ELEPHANT_TEST_WECOM_SECRET": "secret"}, + environ={ + "ELEPHANT_TEST_WECOM_BOT_ID": "bot-id", + "ELEPHANT_TEST_WECOM_SECRET": "secret", + }, ), WECOM_ADAPTER_ID, "ops-wecom", @@ -2309,7 +2469,14 @@ def schedule_learning_for_session(self, **kwargs) -> None: ), ) - for service, adapter_id, account_id, conversation_id, _transport, inbound_factory in cases: + for ( + service, + adapter_id, + account_id, + conversation_id, + _transport, + inbound_factory, + ) in cases: with self.subTest(service=service.service_key): self.assertIsNotNone(service.cli_control) control = service.describe()["control"] @@ -2338,7 +2505,9 @@ def schedule_learning_for_session(self, **kwargs) -> None: self.assertEqual(follow_up.elephant_id, "demo") self.assertEqual(follow_up.session_id, bind_result.session_id) - def test_weixin_ilink_serializes_same_conversation_across_runtime_and_reply_send(self) -> None: + def test_weixin_ilink_serializes_same_conversation_across_runtime_and_reply_send( + self, + ) -> None: app, _, _ = self._build() service = WeixinGatewayService(app=app) service._resolved_account_id = "ops-weixin" @@ -2403,7 +2572,9 @@ def inbound_message(message_id: str, text: str) -> dict[str, object]: asyncio.run(scenario()) - def test_weixin_ilink_serializes_same_conversation_for_cli_control_messages(self) -> None: + def test_weixin_ilink_serializes_same_conversation_for_cli_control_messages( + self, + ) -> None: app, _, _ = self._build() service = WeixinGatewayService(app=app) service._resolved_account_id = "ops-weixin" @@ -2444,7 +2615,11 @@ def inbound_message(message_id: str, text: str) -> dict[str, object]: "item_list": [{"type": 1, "text_item": {"text": text}}], } - with mock.patch.object(type(app), "handle_message", side_effect=AssertionError("shared runtime should not run for handled control messages")): + with mock.patch.object( + type(app), + "handle_message", + side_effect=AssertionError("shared runtime should not run for handled control messages"), + ): with mock.patch.object(type(service), "_send_ilink_message", new=send_stub): first_task = asyncio.create_task( service._process_message_safe(inbound_message("wx-control-1", "first control")) @@ -2513,9 +2688,14 @@ def test_discord_adapter_routes_thread_messages_under_parent_channel(self) -> No self.assertEqual(exchange.route.inbound.conversation.parent_conversation_id, "channel-7") self.assertEqual(exchange.route.inbound.conversation.thread_id, "thread-42") self.assertEqual(exchange.route.inbound.chat_type, "topic") - self.assertEqual(exchange.route.session.session_id, "session:messaging.discord:ops-discord:thread-42") + self.assertEqual( + exchange.route.session.session_id, + "session:messaging.discord:ops-discord:thread-42", + ) - def test_discord_service_should_ignore_bot_self_and_system_sdk_messages(self) -> None: + def test_discord_service_should_ignore_bot_self_and_system_sdk_messages( + self, + ) -> None: app, _, _ = self._build() service = DiscordGatewayService(app=app) @@ -2553,7 +2733,9 @@ def test_discord_service_should_ignore_bot_self_and_system_sdk_messages(self) -> ) ) - def test_discord_gateway_service_starts_sdk_client_and_dispatches_replies(self) -> None: + def test_discord_gateway_service_starts_sdk_client_and_dispatches_replies( + self, + ) -> None: self._update_manifest( lambda payload: payload["gateway"]["adapters"].update( { @@ -2825,7 +3007,9 @@ def test_discord_reply_request_wraps_command_code_and_formula_blocks(self) -> No self.assertIn("```python\ndef add(a, b):\n return a + b\n```", content) self.assertIn("```tex\nx^2 + y^2 = z^2\n```", content) - def test_discord_delivery_transport_keeps_fenced_blocks_balanced_across_chunks(self) -> None: + def test_discord_delivery_transport_keeps_fenced_blocks_balanced_across_chunks( + self, + ) -> None: requests: list[dict[str, object]] = [] class FakeAllowedMentions: @@ -2911,7 +3095,9 @@ class FakeDiscord: self.assertTrue(str(requests[0]["content"]).startswith("```python")) self.assertTrue(str(requests[-1]["content"]).rstrip().endswith("```")) - def test_discord_delivery_transport_uses_attachment_fallback_for_very_long_reply(self) -> None: + def test_discord_delivery_transport_uses_attachment_fallback_for_very_long_reply( + self, + ) -> None: requests: list[dict[str, object]] = [] class FakeAllowedMentions: @@ -2978,7 +3164,7 @@ class FakeDiscord: AllowedMentions = FakeAllowedMentions File = FakeFile - long_content = ("HTTP SERVER\n" * 900) + long_content = "HTTP SERVER\n" * 900 transport = DiscordPyDeliveryTransport(client=FakeClient(), discord_module=FakeDiscord()) response = asyncio.run( @@ -3007,7 +3193,9 @@ class FakeDiscord: self.assertEqual(requests[0]["file"].description, "Full Discord reply body") self.assertEqual(requests[0]["file"].content, long_content) - def test_discord_gateway_service_skips_blocked_enabled_accounts_during_multi_start(self) -> None: + def test_discord_gateway_service_skips_blocked_enabled_accounts_during_multi_start( + self, + ) -> None: self._update_manifest( lambda payload: payload["gateway"]["adapters"].update( { @@ -3078,7 +3266,9 @@ class FakeDiscord: self.assertIn("Skipping Discord account 'shadow-discord'", stderr.getvalue()) self.assertEqual(service.describe()["account_status"]["service_status"], "degraded") - def test_chat_bot_identity_mapping_and_session_reuse_persist_across_restart(self) -> None: + def test_chat_bot_identity_mapping_and_session_reuse_persist_across_restart( + self, + ) -> None: app, chat_adapter, _ = self._build() self._bind_gateway_conversation( app, @@ -3109,10 +3299,17 @@ def test_chat_bot_identity_mapping_and_session_reuse_persist_across_restart(self self.assertNotEqual(first.delivery.outbound.body, "ack: hello") first_records = app.recall_evidence_records(first.route.session.session_id) self.assertEqual( - tuple(record.metadata.get("raw_user_query") for record in first_records if record.kind == "effective_user_query"), + tuple( + record.metadata.get("raw_user_query") + for record in first_records + if record.kind == "effective_user_query" + ), ("hello",), ) - self.assertEqual(len(tuple(record for record in first_records if record.kind == "emit_response")), 1) + self.assertEqual( + len(tuple(record for record in first_records if record.kind == "emit_response")), + 1, + ) restarted_app, restarted_chat, _ = self._build() second = restarted_chat.receive_text( @@ -3138,10 +3335,17 @@ def test_chat_bot_identity_mapping_and_session_reuse_persist_across_restart(self assert second.delivery.outbound is not None second_records = restarted_app.recall_evidence_records(second.route.session.session_id) self.assertEqual( - tuple(record.metadata.get("raw_user_query") for record in second_records if record.kind == "effective_user_query"), + tuple( + record.metadata.get("raw_user_query") + for record in second_records + if record.kind == "effective_user_query" + ), ("hello", "follow-up"), ) - self.assertEqual(len(tuple(record for record in second_records if record.kind == "emit_response")), 2) + self.assertEqual( + len(tuple(record for record in second_records if record.kind == "emit_response")), + 2, + ) self.assertEqual(len(restarted_app.identity_records()), 1) self.assertEqual(len(restarted_app.session_records()), 1) @@ -3179,8 +3383,14 @@ def test_chat_bot_identity_mapping_separates_accounts(self) -> None: self.assertNotEqual(first.route.session.session_id, second.route.session.session_id) self.assertEqual(first.route.identity.key.account_id, "ops-bot") self.assertEqual(second.route.identity.key.account_id, "support-bot") - self.assertEqual(first.route.session.session_id, f"session:{CHAT_BOT_ADAPTER_ID}:ops-bot:chat-1") - self.assertEqual(second.route.session.session_id, f"session:{CHAT_BOT_ADAPTER_ID}:support-bot:chat-1") + self.assertEqual( + first.route.session.session_id, + f"session:{CHAT_BOT_ADAPTER_ID}:ops-bot:chat-1", + ) + self.assertEqual( + second.route.session.session_id, + f"session:{CHAT_BOT_ADAPTER_ID}:support-bot:chat-1", + ) self.assertEqual(len(app.identity_records()), 2) self.assertEqual(len(app.session_records()), 2) @@ -3337,7 +3547,9 @@ def test_feishu_p2p_event_reuses_identity_mapping_across_restart(self) -> None: self.assertEqual(len(restarted_app.identity_records()), 1) self.assertEqual(len(restarted_app.session_records()), 1) - def test_feishu_group_thread_defaults_to_review_and_builds_reply_request(self) -> None: + def test_feishu_group_thread_defaults_to_review_and_builds_reply_request( + self, + ) -> None: app, _, _ = self._build() feishu = FeishuMessagingAdapter(app=app) @@ -3573,7 +3785,9 @@ def test_feishu_attachment_refs_preserve_kind_order_and_dedupe_ids(self) -> None ), ) - def test_feishu_gateway_service_uses_manifest_account_and_dispatches_reply(self) -> None: + def test_feishu_gateway_service_uses_manifest_account_and_dispatches_reply( + self, + ) -> None: app, _, _ = self._build() shared_runtime_calls = self._install_shared_runtime_stub(app) expected_session_id = self._gateway_route_session_id( @@ -3976,10 +4190,10 @@ def wake(self, session_id: str, *, inspect_only: bool = False): ) self.assertEqual(len(requests), 3) - def test_feishu_gateway_service_can_ignore_disabled_flag_when_requested(self) -> None: - self._update_manifest( - lambda payload: payload["gateway"]["adapters"]["feishu"].update({"enabled": False}) - ) + def test_feishu_gateway_service_can_ignore_disabled_flag_when_requested( + self, + ) -> None: + self._update_manifest(lambda payload: payload["gateway"]["adapters"]["feishu"].update({"enabled": False})) app, _, _ = self._build() self.assertEqual(load_feishu_gateway_accounts(app), ()) @@ -4000,7 +4214,9 @@ def test_feishu_gateway_service_can_ignore_disabled_flag_when_requested(self) -> self.assertEqual(description["accounts"][0]["account_id"], "ops-feishu") self.assertEqual(description["accounts"][0]["credentials_status"], "configured") - def test_feishu_gateway_service_routes_replies_back_to_matched_account(self) -> None: + def test_feishu_gateway_service_routes_replies_back_to_matched_account( + self, + ) -> None: self._update_manifest( lambda payload: payload["gateway"]["adapters"]["feishu"].update( { @@ -4099,13 +4315,7 @@ def fake_request( return { "code": 0, "msg": "ok", - "data": { - "message_id": ( - "om_reply_ops" - if auth == "Bearer tenant-token-ops" - else "om_reply_support" - ) - }, + "data": {"message_id": ("om_reply_ops" if auth == "Bearer tenant-token-ops" else "om_reply_support")}, } fake_runtime = FakeCliRuntime() @@ -4277,10 +4487,15 @@ def fake_request( ) self.assertEqual(event_status, "200 OK") self.assertEqual(event_body["delivery_outcome"], "delivered") - self.assertEqual(event_body["delivery_request_path"], "/open-apis/im/v1/messages/om_web_1/reply") + self.assertEqual( + event_body["delivery_request_path"], + "/open-apis/im/v1/messages/om_web_1/reply", + ) self.assertEqual(len(requests), 2) - def test_feishu_gateway_service_dedupes_duplicate_shared_runtime_events(self) -> None: + def test_feishu_gateway_service_dedupes_duplicate_shared_runtime_events( + self, + ) -> None: app, _, _ = self._build() requests: list[tuple[str, str, dict[str, object], dict[str, str]]] = [] @@ -4349,7 +4564,9 @@ def fake_request( self.assertEqual(duplicate.response_body["external_message_id"], "om_reply_dedupe_1") self.assertEqual(len(requests), 2) - def test_telegram_gateway_service_uses_manifest_account_and_dispatches_reply(self) -> None: + def test_telegram_gateway_service_uses_manifest_account_and_dispatches_reply( + self, + ) -> None: self._update_manifest( lambda payload: payload["gateway"]["adapters"].update( { @@ -4771,7 +4988,9 @@ class FakeLark: finally: service.shutdown_async_processing() - def test_feishu_gateway_service_dedupes_duplicate_long_connection_control_events(self) -> None: + def test_feishu_gateway_service_dedupes_duplicate_long_connection_control_events( + self, + ) -> None: app, _, _ = self._build() shared_runtime_calls = self._install_shared_runtime_stub(app) expected_session_id = self._gateway_route_session_id( @@ -5155,8 +5374,9 @@ def test_feishu_long_connection_duplicate_statuses_are_stateful(self) -> None: transport="long-connection", ) - with mock.patch.object(FeishuGatewayService, "_ensure_async_workers"), mock.patch.object( - FeishuGatewayService, "_schedule_async_job", return_value=False + with ( + mock.patch.object(FeishuGatewayService, "_ensure_async_workers"), + mock.patch.object(FeishuGatewayService, "_schedule_async_job", return_value=False), ): job_key, _, created = service.async_job_store.create_or_get( account_id=inbound.account_id, @@ -5352,12 +5572,17 @@ def wake(self, session_id: str, *, inspect_only: bool = False): lambda: len(shared_runtime_calls) == 2 and len(requests) == 5, message="expected serialized same-conversation jobs to finish with two placeholders and two replies", ) - self.assertEqual([call["prompt"] for call in shared_runtime_calls], ["first message", "second message"]) + self.assertEqual( + [call["prompt"] for call in shared_runtime_calls], + ["first message", "second message"], + ) finally: first_release.set() service.shutdown_async_processing() - def test_feishu_async_long_connection_runs_different_conversations_in_parallel(self) -> None: + def test_feishu_async_long_connection_runs_different_conversations_in_parallel( + self, + ) -> None: app, _, _ = self._build() requests: list[tuple[str, str, dict[str, object], dict[str, str]]] = [] first_started = threading.Event() @@ -5520,7 +5745,9 @@ def wake(self, session_id: str, *, inspect_only: bool = False): release_runtime.set() service.shutdown_async_processing() - def test_feishu_async_long_connection_failure_marks_job_and_surfaces_doctor_status(self) -> None: + def test_feishu_async_long_connection_failure_marks_job_and_surfaces_doctor_status( + self, + ) -> None: app, _, _ = self._build() def fail_shared_runtime(_inbound, _session_id: str) -> None: @@ -5655,7 +5882,9 @@ def wake(self, session_id: str, *, inspect_only: bool = False): finally: service.shutdown_async_processing() - def test_feishu_async_long_connection_recovers_incomplete_jobs_on_startup(self) -> None: + def test_feishu_async_long_connection_recovers_incomplete_jobs_on_startup( + self, + ) -> None: app, _, _ = self._build() shared_runtime_calls = self._install_shared_runtime_stub(app) requests: list[tuple[str, str, dict[str, object], dict[str, str]]] = [] @@ -5801,7 +6030,15 @@ def list_herd(self, *, limit: int = 12) -> tuple[object, ...]: def latest_session_for_elephant(self, elephant_id: str): return None - def create_elephant(self, *, elephant_id: str, profile_id=None, display_name=None, mode=None, session_id=None): + def create_elephant( + self, + *, + elephant_id: str, + profile_id=None, + display_name=None, + mode=None, + session_id=None, + ): raise AssertionError("create_elephant should not be called in describe path") def inspect_session(self, session_id: str): @@ -5838,7 +6075,9 @@ def wake(self, session_id: str, *, inspect_only: bool = False): self.assertEqual(control["runtime_status"], "ready") self.assertEqual(control["known_elephants"], ("demo",)) - def test_feishu_control_bridge_binds_conversation_to_selected_elephant(self) -> None: + def test_feishu_control_bridge_binds_conversation_to_selected_elephant( + self, + ) -> None: app, _, _ = self._build() requests: list[tuple[str, str, dict[str, object], dict[str, str]]] = [] expected_session_id = self._gateway_route_session_id( @@ -5894,6 +6133,7 @@ def __init__(self) -> None: started_at=now, updated_at=now, ) + def list_herd(self, *, limit: int = 12) -> tuple[object, ...]: return ( SimpleNamespace( @@ -6010,7 +6250,16 @@ def wake(self, session_id: str, *, inspect_only: bool = False): self.assertEqual(follow_up.response_body["elephant_id"], "demo") self.assertEqual(follow_up.response_body["session_id"], expected_session_id) - self.assertEqual(shared_runtime_calls, [{"session_id": expected_session_id, "prompt": "keep coding", "conversation_id": "oc_control_1"}]) + self.assertEqual( + shared_runtime_calls, + [ + { + "session_id": expected_session_id, + "prompt": "keep coding", + "conversation_id": "oc_control_1", + } + ], + ) self.assertEqual(len(requests), 3) def test_feishu_control_bridge_can_list_and_report_current_elephant(self) -> None: @@ -6070,6 +6319,7 @@ def __init__(self) -> None: updated_at=now, parent_episode_id=self.demo_root_session.episode_id, ) + def list_herd(self, *, limit: int = 12) -> tuple[object, ...]: return ( SimpleNamespace( @@ -6238,10 +6488,21 @@ def wake(self, session_id: str, *, inspect_only: bool = False): self.assertEqual(follow_up.response_body["elephant_id"], "demo") self.assertEqual(follow_up.response_body["session_id"], expected_session_id) - self.assertEqual(shared_runtime_calls, [{"session_id": expected_session_id, "prompt": "stay on the active elephant", "conversation_id": "oc_control_elephant_status"}]) + self.assertEqual( + shared_runtime_calls, + [ + { + "session_id": expected_session_id, + "prompt": "stay on the active elephant", + "conversation_id": "oc_control_elephant_status", + } + ], + ) self.assertGreaterEqual(len(requests), 5) - def test_feishu_control_bridge_accepts_post_command_wrapped_elephant_use(self) -> None: + def test_feishu_control_bridge_accepts_post_command_wrapped_elephant_use( + self, + ) -> None: app, _, _ = self._build() requests: list[tuple[str, str, dict[str, object], dict[str, str]]] = [] @@ -6351,7 +6612,11 @@ def wake(self, session_id: str, *, inspect_only: bool = False): "content": [ [ {"tag": "text", "text": "- "}, - {"tag": "text", "text": "/elephant create leo", "style": ["bold"]}, + { + "tag": "text", + "text": "/elephant create leo", + "style": ["bold"], + }, ] ], } @@ -6373,7 +6638,9 @@ def wake(self, session_id: str, *, inspect_only: bool = False): self.assertEqual(bind_result.response_body["summary"], "elephant shaped") self.assertGreaterEqual(len(requests), 2) - def test_feishu_control_bridge_reuses_parent_binding_inside_topic_replies(self) -> None: + def test_feishu_control_bridge_reuses_parent_binding_inside_topic_replies( + self, + ) -> None: app, _, _ = self._build() requests: list[tuple[str, str, dict[str, object], dict[str, str]]] = [] parent_session_id = self._gateway_route_session_id( @@ -6452,6 +6719,7 @@ def __init__(self) -> None: updated_at=now, parent_episode_id=self.demo_root_session.episode_id, ) + def list_herd(self, *, limit: int = 12) -> tuple[object, ...]: return ( SimpleNamespace( @@ -6557,7 +6825,16 @@ def wake(self, session_id: str, *, inspect_only: bool = False): self.assertEqual(topic_follow_up.response_body["elephant_id"], "demo") self.assertEqual(topic_follow_up.response_body["session_id"], child_session_id) - self.assertEqual(shared_runtime_calls, [{"session_id": child_session_id, "prompt": "继续这个 session", "conversation_id": "oc_topic_chat:om_topic_root"}]) + self.assertEqual( + shared_runtime_calls, + [ + { + "session_id": child_session_id, + "prompt": "继续这个 session", + "conversation_id": "oc_topic_chat:om_topic_root", + } + ], + ) self.assertGreaterEqual(len(requests), 3) thread_identity = app.core.dependencies.identity_store.lookup( @@ -6571,7 +6848,9 @@ def wake(self, session_id: str, *, inspect_only: bool = False): assert thread_identity is not None self.assertEqual(thread_identity.session_id, child_session_id) - def test_feishu_control_bridge_requires_binding_before_plain_text_routes(self) -> None: + def test_feishu_control_bridge_requires_binding_before_plain_text_routes( + self, + ) -> None: app, _, _ = self._build() def fake_request( @@ -6659,9 +6938,7 @@ def prepare_session_surface(self, session_id: str) -> Episode: def explain_next_step(self, **kwargs): self.explain_calls.append(dict(kwargs)) prompt = str(kwargs["prompt"]) - return SimpleNamespace( - execution=SimpleNamespace(summary=f"cli-handled:{prompt}") - ) + return SimpleNamespace(execution=SimpleNamespace(summary=f"cli-handled:{prompt}")) def wake(self, session_id: str, *, inspect_only: bool = False): raise AssertionError("wake should not be used in this test") @@ -6755,7 +7032,9 @@ def test_interruption_state_is_preserved_when_chat_resumes(self) -> None: "awaiting-operator-reply", ) - def test_telegram_private_update_reuses_identity_mapping_across_restart(self) -> None: + def test_telegram_private_update_reuses_identity_mapping_across_restart( + self, + ) -> None: app, _, _ = self._build() telegram = TelegramMessagingAdapter(app=app) diff --git a/tests/e2e/observability/test_observability_e2e.py b/tests/e2e/observability/test_observability_e2e.py index d894d94..99867a3 100644 --- a/tests/e2e/observability/test_observability_e2e.py +++ b/tests/e2e/observability/test_observability_e2e.py @@ -39,9 +39,11 @@ def _stack_available() -> bool: return False -@unittest.skipUnless(_stack_available(), "Monitoring stack not running (need Jaeger + Prometheus on localhost)") +@unittest.skipUnless( + _stack_available(), + "Monitoring stack not running (need Jaeger + Prometheus on localhost)", +) class ObservabilityE2ETest(unittest.TestCase): - def setUp(self) -> None: setup_mod._initialized = False logger_mod._configured = False @@ -49,6 +51,7 @@ def setUp(self) -> None: self.state_dir = self.tmpdir.name from packages.observability import setup_observability + setup_observability( service_name="elephant-agent-e2e-test", log_level="DEBUG", @@ -84,25 +87,40 @@ def test_traces_metrics_and_logs(self) -> None: with trace_kernel_turn(episode_id="ep-e2e", loop_id="lp-e2e", trigger_type="e2e_test"): logger.info("kernel turn started: episode=ep-e2e") - with trace_model_call(provider_id="test-provider", model_id="test-model-e2e", episode_id="ep-e2e") as span: + with trace_model_call( + provider_id="test-provider", + model_id="test-model-e2e", + episode_id="ep-e2e", + ) as span: time.sleep(0.01) record_token_usage(span, input_tokens=200, output_tokens=100, cache_read_tokens=50) with trace_tool_execution(tool_name="e2e_calculator", episode_id="ep-e2e"): time.sleep(0.01) - with trace_model_call(provider_id="test-provider", model_id="test-model-e2e", episode_id="ep-e2e") as span: + with trace_model_call( + provider_id="test-provider", + model_id="test-model-e2e", + episode_id="ep-e2e", + ) as span: time.sleep(0.01) record_token_usage(span, input_tokens=300, output_tokens=150) logger.info("kernel turn completing: tools=1 model_calls=2") - record_model_metrics(provider_id="test-provider", model_id="test-model-e2e", input_tokens=500, output_tokens=250, duration_s=0.5) + record_model_metrics( + provider_id="test-provider", + model_id="test-model-e2e", + input_tokens=500, + output_tokens=250, + duration_s=0.5, + ) record_tool_metrics(tool_name="e2e_calculator", duration_s=0.02, status="success") record_turn_metrics(episode_id="ep-e2e", duration_s=0.6, trigger_type="e2e_test") logger.info("kernel turn completed: episode=ep-e2e duration=0.60s") from opentelemetry import trace, metrics + tp = trace.get_tracer_provider() if hasattr(tp, "force_flush"): tp.force_flush() @@ -125,22 +143,33 @@ def test_traces_metrics_and_logs(self) -> None: print(f" Traces: {len(traces)}, Spans: {len(all_spans)}") print(f" Operations: {operations}") - self.assertTrue(any("invoke_agent" in op for op in operations), f"No invoke_agent span: {operations}") + self.assertTrue( + any("invoke_agent" in op for op in operations), + f"No invoke_agent span: {operations}", + ) self.assertTrue(any("chat" in op for op in operations), f"No chat span: {operations}") - self.assertTrue(any("execute_tool" in op for op in operations), f"No execute_tool span: {operations}") + self.assertTrue( + any("execute_tool" in op for op in operations), + f"No execute_tool span: {operations}", + ) chat_spans = [s for s in all_spans if "chat" in s.get("operationName", "")] tags = {tag["key"]: tag["value"] for tag in chat_spans[0].get("tags", [])} self.assertEqual(tags.get("gen_ai.request.model"), "test-model-e2e") - print(f" chat span: model={tags.get('gen_ai.request.model')}, " - f"input_tokens={tags.get('gen_ai.usage.input_tokens')}, " - f"output_tokens={tags.get('gen_ai.usage.output_tokens')}") + print( + f" chat span: model={tags.get('gen_ai.request.model')}, " + f"input_tokens={tags.get('gen_ai.usage.input_tokens')}, " + f"output_tokens={tags.get('gen_ai.usage.output_tokens')}" + ) # ---- METRICS (Prometheus) ---- print("\n--- Metrics (Prometheus) ---") metric_queries = [ ("gen_ai_client_token_usage_count", "gen_ai.client.token.usage"), - ("gen_ai_client_operation_duration_count", "gen_ai.client.operation.duration"), + ( + "gen_ai_client_operation_duration_count", + "gen_ai.client.operation.duration", + ), ("elephant_tool_duration_count", "elephant.tool.duration"), ("elephant_kernel_turn_duration_count", "elephant.kernel.turn.duration"), ] @@ -167,12 +196,17 @@ def test_traces_metrics_and_logs(self) -> None: episode_logs = [l for l in log_lines if l.get("episode_id") == "ep-e2e"] self.assertGreater(len(episode_logs), 0, "No logs with episode_id=ep-e2e") - self.assertTrue(all(l.get("trace_id") for l in episode_logs), "Some entries missing trace_id") + self.assertTrue( + all(l.get("trace_id") for l in episode_logs), + "Some entries missing trace_id", + ) print(f" Total entries: {len(log_lines)}, with episode_id: {len(episode_logs)}") for entry in log_lines: - print(f" [{entry.get('level')}] trace={entry.get('trace_id','')[:8]} " - f"episode={entry.get('episode_id')} msg={entry.get('msg','')[:80]}") + print( + f" [{entry.get('level')}] trace={entry.get('trace_id', '')[:8]} " + f"episode={entry.get('episode_id')} msg={entry.get('msg', '')[:80]}" + ) print(f"\n Jaeger UI: {JAEGER_QUERY}/search?service=elephant-agent-e2e-test") print(f" Prometheus: {PROMETHEUS_QUERY}/graph") diff --git a/tests/e2e/release/test_design_closure_certification.py b/tests/e2e/release/test_design_closure_certification.py index 745820d..3e51654 100644 --- a/tests/e2e/release/test_design_closure_certification.py +++ b/tests/e2e/release/test_design_closure_certification.py @@ -50,13 +50,13 @@ "tests.e2e.deploy.test_installed_command_smoke.InstalledCommandLiveSmokeTest", ) -INSTALLED_USER_JOURNEY_TARGETS = ( - "tests.e2e.deploy.test_installed_user_journey", -) +INSTALLED_USER_JOURNEY_TARGETS = ("tests.e2e.deploy.test_installed_user_journey",) class DesignClosureContractsTest(unittest.TestCase): - def test_design_closure_matrix_no_longer_tracks_deleted_voice_or_planning_modules(self) -> None: + def test_design_closure_matrix_no_longer_tracks_deleted_voice_or_planning_modules( + self, + ) -> None: makefile_text = MAKEFILE_PATH.read_text(encoding="utf-8") self.assertNotIn("tests.e2e.voice.test_voice_preview", makefile_text) @@ -92,7 +92,9 @@ def test_makefile_pins_design_closure_matrix(self) -> None: with self.subTest(target=target): self.assertIn(target, text) - def test_design_closure_uses_canonical_docs_and_historical_inputs_stay_deleted(self) -> None: + def test_design_closure_uses_canonical_docs_and_historical_inputs_stay_deleted( + self, + ) -> None: for path in CANONICAL_DESIGN_DOCS: with self.subTest(path=path): self.assertTrue(path.exists(), path) @@ -139,5 +141,6 @@ def test_workflow_keeps_live_provider_manual_and_secret_backed(self) -> None: with self.subTest(target=target): self.assertIn(target, makefile_text) + if __name__ == "__main__": unittest.main() diff --git a/tests/e2e/release/test_release_certification.py b/tests/e2e/release/test_release_certification.py index c18c118..e7869a2 100644 --- a/tests/e2e/release/test_release_certification.py +++ b/tests/e2e/release/test_release_certification.py @@ -59,9 +59,7 @@ "apps/dashboard/src/components/primitives/Primitives.module.css", ) -DASHBOARD_BRAND_ASSET_PATHS = ( - "apps/dashboard/src/assets/brand/elephant-logo.png", -) +DASHBOARD_BRAND_ASSET_PATHS = ("apps/dashboard/src/assets/brand/elephant-logo.png",) PUBLIC_SURFACE_PROOF_PATHS = ( "apps/site/docs/reference/cli.md", @@ -102,7 +100,9 @@ def test_release_modules_exist(self) -> None: with self.subTest(target=target): self.assertTrue(_target_path(target).exists(), target) - def test_release_matrix_no_longer_tracks_deleted_voice_or_planning_modules(self) -> None: + def test_release_matrix_no_longer_tracks_deleted_voice_or_planning_modules( + self, + ) -> None: makefile_text = MAKEFILE_PATH.read_text(encoding="utf-8") self.assertNotIn("tests.e2e.voice.test_voice_preview", makefile_text) @@ -115,7 +115,9 @@ def test_standalone_release_runbooks_stay_deleted(self) -> None: with self.subTest(path=path): self.assertFalse(path.exists(), path) - def test_release_contract_rejects_session_era_goal_or_procedure_routes(self) -> None: + def test_release_contract_rejects_session_era_goal_or_procedure_routes( + self, + ) -> None: for path in (ROOT / "apps" / "api").rglob("*.py"): text = path.read_text(encoding="utf-8") with self.subTest(path=path): @@ -159,7 +161,7 @@ def test_workflow_keeps_live_provider_manual_and_secret_backed(self) -> None: self.assertIn("ELEPHANT_LIVE_PROVIDER_API_KEY", text) self.assertIn("Build dashboard assets for installed smoke", text) self.assertIn("make test-live-provider-smoke", text) - self.assertIn("make release AGENT_BASE_REF=\"$BASE_REF\"", text) + self.assertIn('make release AGENT_BASE_REF="$BASE_REF"', text) self.assertIn("Run canonical system-layer reset release contract", text) makefile_text = MAKEFILE_PATH.read_text(encoding="utf-8") @@ -191,12 +193,8 @@ def test_dashboard_inspection_surface_stays_implemented(self) -> None: with self.subTest(path=path): self.assertTrue((ROOT / path).exists(), path) - dashboard_api = (ROOT / "apps" / "dashboard" / "src" / "lib" / "dashboardApi.ts").read_text( - encoding="utf-8" - ) - cli_doc = (ROOT / "apps" / "site" / "docs" / "reference" / "cli.md").read_text( - encoding="utf-8" - ) + dashboard_api = (ROOT / "apps" / "dashboard" / "src" / "lib" / "dashboardApi.ts").read_text(encoding="utf-8") + cli_doc = (ROOT / "apps" / "site" / "docs" / "reference" / "cli.md").read_text(encoding="utf-8") self.assertIn("/v1/internal/dashboard", dashboard_api) self.assertIn("/v1/internal/dashboard", cli_doc) @@ -212,7 +210,9 @@ def test_dashboard_package_keeps_real_data_scripts_only(self) -> None: self.assertNotIn("capture:refactor-screenshots", package["scripts"]) self.assertNotIn("preview", package["scripts"]) - def test_release_plan_no_longer_depends_on_deleted_dashboard_design_doc(self) -> None: + def test_release_plan_no_longer_depends_on_deleted_dashboard_design_doc( + self, + ) -> None: self.assertFalse((ROOT / "docs" / "system-design" / "operator-dashboard-surface.md").exists()) def test_makefile_pins_reset_package_verification_contract(self) -> None: diff --git a/tests/e2e/support/__init__.py b/tests/e2e/support/__init__.py index 99407a7..080b7d4 100644 --- a/tests/e2e/support/__init__.py +++ b/tests/e2e/support/__init__.py @@ -1,2 +1 @@ """Shared helpers for deterministic e2e tests.""" - diff --git a/tests/e2e/support/mock_provider.py b/tests/e2e/support/mock_provider.py index c218a3c..6c11e97 100644 --- a/tests/e2e/support/mock_provider.py +++ b/tests/e2e/support/mock_provider.py @@ -73,7 +73,12 @@ def do_POST(self) -> None: # noqa: N802 return if outer.fail_chat: self._send_json( - {"error": {"message": "stub provider is unavailable", "type": "server_error"}}, + { + "error": { + "message": "stub provider is unavailable", + "type": "server_error", + } + }, status=503, ) return @@ -144,4 +149,3 @@ def log_message(self, _format: str, *_args: object) -> None: return return Handler - diff --git a/tests/e2e/voice/__init__.py b/tests/e2e/voice/__init__.py index 8b13789..e69de29 100644 --- a/tests/e2e/voice/__init__.py +++ b/tests/e2e/voice/__init__.py @@ -1 +0,0 @@ - diff --git a/tests/integration/harness/test_provider_retry.py b/tests/integration/harness/test_provider_retry.py index fd91d8f..45cf5cb 100644 --- a/tests/integration/harness/test_provider_retry.py +++ b/tests/integration/harness/test_provider_retry.py @@ -3,7 +3,6 @@ from __future__ import annotations from io import BytesIO -from typing import Any import unittest from unittest import mock from urllib import error, request as urllib_request @@ -30,7 +29,14 @@ def _http_error( class _FakeResponse: """Minimal fill-in for urllib's response so _post_json_once parses OK.""" - def __init__(self, *, status: int, body: bytes, headers: dict | None = None, lines: list[bytes] | None = None): + def __init__( + self, + *, + status: int, + body: bytes, + headers: dict | None = None, + lines: list[bytes] | None = None, + ): self.status = status self._body = body self._lines = lines or [] @@ -62,10 +68,15 @@ def fake_urlopen(req, *, timeout): timeouts.append(timeout) if len(attempts) < 3: raise _http_error(429, headers={"Retry-After": "2"}) - return _FakeResponse(status=200, body=b'{"ok": true}', headers={"content-type": "application/json"}) + return _FakeResponse( + status=200, + body=b'{"ok": true}', + headers={"content-type": "application/json"}, + ) - with mock.patch("time.sleep", side_effect=sleeps.append), mock.patch.object( - urllib_request, "urlopen", side_effect=fake_urlopen + with ( + mock.patch("time.sleep", side_effect=sleeps.append), + mock.patch.object(urllib_request, "urlopen", side_effect=fake_urlopen), ): response = transport.post_json( url="http://provider/v1/messages", @@ -89,8 +100,9 @@ def fake_urlopen(req, *, timeout): attempts.append(1) raise _http_error(403, headers={}, body=b'{"error": "nope"}') - with mock.patch("time.sleep"), mock.patch.object( - urllib_request, "urlopen", side_effect=fake_urlopen + with ( + mock.patch("time.sleep"), + mock.patch.object(urllib_request, "urlopen", side_effect=fake_urlopen), ): with self.assertRaises(ProviderHTTPError) as ctx: transport.post_json( @@ -125,8 +137,9 @@ def fake_urlopen(req, *, timeout): return _FakeResponse(status=200, body=b"", lines=stream_lines) chunks: list[JSONHTTPStreamChunk] = [] - with mock.patch("time.sleep"), mock.patch.object( - urllib_request, "urlopen", side_effect=fake_urlopen + with ( + mock.patch("time.sleep"), + mock.patch.object(urllib_request, "urlopen", side_effect=fake_urlopen), ): for chunk in transport.post_json_stream( url="http://provider/v1/messages/stream", @@ -157,8 +170,9 @@ def fake_urlopen(req, *, timeout): return flaky chunks: list[JSONHTTPStreamChunk] = [] - with mock.patch("time.sleep"), mock.patch.object( - urllib_request, "urlopen", side_effect=fake_urlopen + with ( + mock.patch("time.sleep"), + mock.patch.object(urllib_request, "urlopen", side_effect=fake_urlopen), ): with self.assertRaises(ProviderSSEIncompleteError) as ctx: for chunk in transport.post_json_stream( diff --git a/tests/integration/harness/test_reflection_offpath.py b/tests/integration/harness/test_reflection_offpath.py index 2efa468..863e65c 100644 --- a/tests/integration/harness/test_reflection_offpath.py +++ b/tests/integration/harness/test_reflection_offpath.py @@ -25,14 +25,16 @@ import tempfile import unittest -from packages.contracts.layers import Episode from packages.contracts.runtime import ( ContextBundle, ExecutionResult, - PersonalModelRuntimeState, PromptEnvelope, ) -from packages.kernel.runtime import KernelDependencies, KernelService, KernelSourceRequest +from packages.kernel.runtime import ( + KernelDependencies, + KernelService, + KernelSourceRequest, +) from packages.storage import RuntimeStorageRepository @@ -104,7 +106,9 @@ def emit(self, _event): class ReflectionOffHotPathTest(unittest.TestCase): - def test_run_enqueues_learning_job_without_calling_reflection_synchronously(self) -> None: + def test_run_enqueues_learning_job_without_calling_reflection_synchronously( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: repository = RuntimeStorageRepository(Path(tmpdir) / "elephant.sqlite3") repository.bootstrap() @@ -142,7 +146,9 @@ def test_run_enqueues_learning_job_without_calling_reflection_synchronously(self self.assertEqual(job.status, "queued") self.assertIn(job.trigger, {"episode_close", "checkpoint", "episode_failed"}) - def test_internal_learning_agent_turn_does_not_enqueue_recursive_learning_job(self) -> None: + def test_internal_learning_agent_turn_does_not_enqueue_recursive_learning_job( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: repository = RuntimeStorageRepository(Path(tmpdir) / "elephant.sqlite3") repository.bootstrap() diff --git a/tests/integration/kernel/test_turn_lifecycle.py b/tests/integration/kernel/test_turn_lifecycle.py index 0b1d228..0cde389 100644 --- a/tests/integration/kernel/test_turn_lifecycle.py +++ b/tests/integration/kernel/test_turn_lifecycle.py @@ -3,7 +3,6 @@ from datetime import datetime, timezone from pathlib import Path import tempfile -from types import SimpleNamespace import unittest from apps.episode_runtime import install_app_episode_runtime @@ -99,7 +98,9 @@ def emit(self, event: dict[str, object]) -> None: class KernelTurnLifecycleResetTest(unittest.TestCase): - def test_replaying_same_checkpoint_step_updates_in_place_without_duplicate_sequence_failure(self) -> None: + def test_replaying_same_checkpoint_step_updates_in_place_without_duplicate_sequence_failure( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: repository = RuntimeStorageRepository(Path(tmpdir) / "elephant.sqlite3") repository.bootstrap() diff --git a/tests/integration/models_auth/test_anthropic_provider.py b/tests/integration/models_auth/test_anthropic_provider.py index a6abc5c..fcba8e6 100644 --- a/tests/integration/models_auth/test_anthropic_provider.py +++ b/tests/integration/models_auth/test_anthropic_provider.py @@ -6,7 +6,6 @@ from pathlib import Path import threading import sys -from types import SimpleNamespace import unittest ROOT = Path(__file__).resolve().parents[3] @@ -233,7 +232,11 @@ def test_native_request_preserves_history_and_tool_result_blocks(self) -> None: role="assistant", content="", tool_calls=( - {"id": "toolu-1", "name": "tool.web.search", "arguments": {"query": "elephant docs"}}, + { + "id": "toolu-1", + "name": "tool.web.search", + "arguments": {"query": "elephant docs"}, + }, ), ), PromptMessage( @@ -249,7 +252,10 @@ def test_native_request_preserves_history_and_tool_result_blocks(self) -> None: "function": { "name": "tool.web.search", "description": "Search the web.", - "parameters": {"type": "object", "properties": {"query": {"type": "string"}}}, + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + }, }, }, ), @@ -258,14 +264,19 @@ def test_native_request_preserves_history_and_tool_result_blocks(self) -> None: planned = self.adapter.build_request(request) payload = planned.as_mapping() - self.assertEqual([message["role"] for message in payload["messages"][-3:]], ["user", "assistant", "user"]) + self.assertEqual( + [message["role"] for message in payload["messages"][-3:]], + ["user", "assistant", "user"], + ) self.assertEqual(payload["messages"][-2]["content"][0]["type"], "tool_use") self.assertEqual(payload["messages"][-2]["content"][0]["name"], "tool_web_search") self.assertEqual(payload["messages"][-1]["content"][0]["type"], "tool_result") self.assertEqual(payload["messages"][-1]["content"][0]["tool_use_id"], "toolu-1") self.assertEqual(payload["messages"][-1]["content"][1]["text"], "Use that result.") - def test_native_request_uses_bearer_headers_for_anthropic_oauth_tokens(self) -> None: + def test_native_request_uses_bearer_headers_for_anthropic_oauth_tokens( + self, + ) -> None: request = self.adapter.build_request(self._request(), {"api_key": "sk-ant-oat-test-token"}) self.assertEqual(request.headers["Authorization"], "Bearer sk-ant-oat-test-token") @@ -329,7 +340,9 @@ def test_copilot_claude_uses_bearer_auth_and_default_headers(self) -> None: adapter.generate(request, {"api_key": "ghu-copilot"}) - request_headers = {str(key).lower(): str(value) for key, value in dict(self.server.requests[-1]["headers"]).items()} + request_headers = { + str(key).lower(): str(value) for key, value in dict(self.server.requests[-1]["headers"]).items() + } self.assertEqual(request_headers["authorization"], "Bearer ghu-copilot") self.assertEqual(request_headers["anthropic-version"], "2023-06-01") self.assertEqual(request_headers["openai-intent"], "conversation-edits") @@ -350,11 +363,16 @@ def test_session_header_does_not_override_explicit_extra_header(self) -> None: self.assertEqual(request.headers["X-Session-Id"], "configured-session") self.assertNotIn("x-session-id", request.headers) - def test_generate_returns_native_result_without_leaking_secret_material(self) -> None: + def test_generate_returns_native_result_without_leaking_secret_material( + self, + ) -> None: result = self.adapter.generate(self._request(), self.auth_capability.resolve("anthropic")) self.assertEqual(result.task, "generate") - self.assertEqual(result.content, "live-anthropic:Explain the provider boundary without leaking secrets.") + self.assertEqual( + result.content, + "live-anthropic:Explain the provider boundary without leaking secrets.", + ) self.assertNotIn("anthropic-secret", result.content) self.assertEqual(result.metadata["transport_id"], "anthropic_messages") self.assertEqual(result.metadata["credential_keys"], "api_key") @@ -362,7 +380,9 @@ def test_generate_returns_native_result_without_leaking_secret_material(self) -> self.assertEqual(result.usage.cache_creation_prompt_tokens, 2) self.assertTrue(result.usage.cache_usage_reported) self.assertEqual(self.server.requests[0]["path"], "/v1/messages") - request_headers = {str(key).lower(): str(value) for key, value in dict(self.server.requests[0]["headers"]).items()} + request_headers = { + str(key).lower(): str(value) for key, value in dict(self.server.requests[0]["headers"]).items() + } self.assertEqual(request_headers["x-api-key"], "anthropic-secret") self.assertEqual(request_headers["x-session-id"], "session-1") @@ -439,7 +459,10 @@ def test_capability_bridge_uses_shared_runtime_contract(self) -> None: self.assertEqual(result.session_id, session.episode_id) self.assertEqual(result.outcome, "ok") - self.assertEqual(result.summary, "live-anthropic:Explain the provider boundary without leaking secrets.") + self.assertEqual( + result.summary, + "live-anthropic:Explain the provider boundary without leaking secrets.", + ) self.assertEqual(result.cached_prompt_tokens, 3) self.assertEqual(result.cache_creation_prompt_tokens, 2) self.assertTrue(result.cache_usage_reported) diff --git a/tests/integration/models_auth/test_models_auth_integration.py b/tests/integration/models_auth/test_models_auth_integration.py index 013ff6d..89ec45a 100644 --- a/tests/integration/models_auth/test_models_auth_integration.py +++ b/tests/integration/models_auth/test_models_auth_integration.py @@ -215,8 +215,14 @@ def test_model_registry_routes_provider_neutral_adapters(self) -> None: ) ) - self.assertEqual(registry.select("preview.echo").descriptor.adapter_id, "adapter.preview.echo") - self.assertEqual(registry.select("preview.static").descriptor.adapter_id, "adapter.preview.static") + self.assertEqual( + registry.select("preview.echo").descriptor.adapter_id, + "adapter.preview.echo", + ) + self.assertEqual( + registry.select("preview.static").descriptor.adapter_id, + "adapter.preview.static", + ) self.assertEqual(len(registry.list()), 2) def test_provider_runtime_lists_catalog_and_guided_setup(self) -> None: @@ -408,7 +414,9 @@ def test_preview_model_capability_uses_resolved_credentials(self) -> None: self.assertIn("creds: api_key", result.summary) self.assertNotIn("sk-test-456", result.summary) - def test_surface_runtime_includes_enabled_custom_mcp_tools_in_model_request(self) -> None: + def test_surface_runtime_includes_enabled_custom_mcp_tools_in_model_request( + self, + ) -> None: database_path = Path(self.tempdir.name) / "surface-runtime-mcp.sqlite3" repository = RuntimeStorageRepository(database_path) repository.bootstrap() @@ -422,7 +430,11 @@ def test_surface_runtime_includes_enabled_custom_mcp_tools_in_model_request(self "label": "Filesystem", "transport": "stdio", "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp/demo"], + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp/demo", + ], "tools": { "read_file": { "display_name": "Read File", @@ -440,7 +452,10 @@ def test_surface_runtime_includes_enabled_custom_mcp_tools_in_model_request(self "writes_state": True, "schema": { "type": "object", - "properties": {"path": {"type": "string"}, "content": {"type": "string"}}, + "properties": { + "path": {"type": "string"}, + "content": {"type": "string"}, + }, "required": ["path", "content"], }, }, @@ -473,7 +488,10 @@ def generate(self, request: ModelRequest, credentials: dict[str, str]) -> ModelT task="generate", content="captured", usage=ModelUsage(), - metadata={"transport_id": "openai_chat_completions", "credential_keys": ",".join(sorted(credentials))}, + metadata={ + "transport_id": "openai_chat_completions", + "credential_keys": ",".join(sorted(credentials)), + }, ) profile = PersonalModelRuntimeState( @@ -510,8 +528,15 @@ def generate(self, request: ModelRequest, credentials: dict[str, str]) -> ModelT with ( mock.patch.object(capability, "_profile_for_role", return_value=active_profile), - mock.patch.object(capability.credential_resolver, "resolve", return_value=credential_bundle), - mock.patch("packages.models.runtime_capability.build_model_adapter", return_value=_CapturingAdapter()), + mock.patch.object( + capability.credential_resolver, + "resolve", + return_value=credential_bundle, + ), + mock.patch( + "packages.models.runtime_capability.build_model_adapter", + return_value=_CapturingAdapter(), + ), ): result = capability.generate( profile=profile, @@ -536,7 +561,9 @@ def generate(self, request: ModelRequest, credentials: dict[str, str]) -> ModelT }, ) - def test_surface_runtime_adds_tool_fallback_prompt_without_native_tool_calling(self) -> None: + def test_surface_runtime_adds_tool_fallback_prompt_without_native_tool_calling( + self, + ) -> None: database_path = Path(self.tempdir.name) / "surface-runtime-fallback-tools.sqlite3" repository = RuntimeStorageRepository(database_path) repository.bootstrap() @@ -550,7 +577,11 @@ def test_surface_runtime_adds_tool_fallback_prompt_without_native_tool_calling(s "label": "Filesystem", "transport": "stdio", "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp/demo"], + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp/demo", + ], "tools": { "read_file": { "display_name": "Read File", @@ -588,7 +619,10 @@ def generate(self, request: ModelRequest, credentials: dict[str, str]) -> ModelT task="generate", content="captured", usage=ModelUsage(), - metadata={"transport_id": "legacy_chat", "credential_keys": ",".join(sorted(credentials))}, + metadata={ + "transport_id": "legacy_chat", + "credential_keys": ",".join(sorted(credentials)), + }, ) profile = PersonalModelRuntimeState( @@ -628,8 +662,15 @@ def generate(self, request: ModelRequest, credentials: dict[str, str]) -> ModelT with ( mock.patch.object(capability, "_profile_for_role", return_value=active_profile), mock.patch.object(capability.runtime_resolver, "resolve", return_value=legacy_resolution), - mock.patch.object(capability.credential_resolver, "resolve", return_value=credential_bundle), - mock.patch("packages.models.runtime_capability.build_model_adapter", return_value=_CapturingAdapter()), + mock.patch.object( + capability.credential_resolver, + "resolve", + return_value=credential_bundle, + ), + mock.patch( + "packages.models.runtime_capability.build_model_adapter", + return_value=_CapturingAdapter(), + ), ): result = capability.generate( profile=profile, @@ -663,7 +704,9 @@ def test_model_request_can_be_constructed_for_preview_runtime(self) -> None: self.assertEqual(request.provider_id, "preview.static") self.assertEqual(request.task, "generate") - def test_auth_profiles_persist_provider_metadata_and_secret_references(self) -> None: + def test_auth_profiles_persist_provider_metadata_and_secret_references( + self, + ) -> None: database_path = Path(self.tempdir.name) / "auth.sqlite3" repository = RuntimeStorageRepository(database_path) repository.bootstrap() @@ -703,7 +746,12 @@ def test_auth_profiles_persist_provider_metadata_and_secret_references(self) -> persisted = json.loads(auth_profiles_path.read_text(encoding="utf-8")) row = persisted["auth-openai-default"] self.assertEqual( - (row["provider_id"], row["transport_id"], row["base_url"], row["default_model"]), + ( + row["provider_id"], + row["transport_id"], + row["base_url"], + row["default_model"], + ), ("openai", "openai_responses", "https://api.openai.com/v1", "gpt-4.1-mini"), ) self.assertNotIn("sk-test-123", database_path.read_text(errors="ignore")) @@ -818,9 +866,12 @@ def test_api_provider_list_surfaces_codex_and_copilot_discovery(self) -> None: encoding="utf-8", ) - with mock.patch.dict(os.environ, {"CODEX_HOME": str(codex_home)}, clear=True), mock.patch( - "packages.auth.discovery.subprocess.run", - return_value=mock.Mock(stdout="gho-copilot-token\n"), + with ( + mock.patch.dict(os.environ, {"CODEX_HOME": str(codex_home)}, clear=True), + mock.patch( + "packages.auth.discovery.subprocess.run", + return_value=mock.Mock(stdout="gho-copilot-token\n"), + ), ): payload = list_providers(SimpleNamespace(model_provider=capability)) @@ -830,7 +881,9 @@ def test_api_provider_list_surfaces_codex_and_copilot_discovery(self) -> None: self.assertEqual(providers["copilot"]["status"], "authenticated") self.assertEqual(providers["copilot"]["source"], "gh auth token") - def test_surface_runtime_discovers_models_with_saved_non_active_provider_key(self) -> None: + def test_surface_runtime_discovers_models_with_saved_non_active_provider_key( + self, + ) -> None: database_path = Path(self.tempdir.name) / "provider-saved-key-discovery.sqlite3" repository = RuntimeStorageRepository(database_path) repository.bootstrap() @@ -862,12 +915,17 @@ def _fake_request_json(*, url: str, headers, timeout_seconds: float = 10.0): self.assertEqual(dict(headers).get("Authorization"), "Bearer sk-saved-provider") return {"data": [{"id": "model-a"}, {"id": "model-b"}]} - with mock.patch("packages.models.runtime_capability.request_json", side_effect=_fake_request_json): + with mock.patch( + "packages.models.runtime_capability.request_json", + side_effect=_fake_request_json, + ): models = capability.discover_models(provider_id="openai-compatible", base_url=None) self.assertEqual([model.model_id for model in models[:2]], ["model-a", "model-b"]) - def test_surface_runtime_discovers_copilot_models_from_provider_specific_catalog_path(self) -> None: + def test_surface_runtime_discovers_copilot_models_from_provider_specific_catalog_path( + self, + ) -> None: database_path = Path(self.tempdir.name) / "provider-copilot-models.sqlite3" repository = RuntimeStorageRepository(database_path) repository.bootstrap() @@ -904,7 +962,9 @@ def test_surface_runtime_discovers_copilot_models_from_provider_specific_catalog self.assertEqual(server.last_headers.get("Openai-Intent"), "conversation-edits") self.assertEqual([model.model_id for model in models[:2]], ["claude-opus-4.6", "gpt-5.4"]) - def test_surface_runtime_detects_copilot_claude_context_with_bearer_auth(self) -> None: + def test_surface_runtime_detects_copilot_claude_context_with_bearer_auth( + self, + ) -> None: database_path = Path(self.tempdir.name) / "provider-copilot-claude-context.sqlite3" repository = RuntimeStorageRepository(database_path) repository.bootstrap() @@ -927,19 +987,23 @@ def _fake_request_json(*, url: str, headers, timeout_seconds: float = 10.0): } raise AssertionError(f"unexpected url {url}") - with mock.patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "ghu-test"}, clear=False), mock.patch.object( - capability, - "discover_models", - return_value=( - DiscoveredProviderModel( - model_id="claude-sonnet-4.6", - label="claude-sonnet-4.6", - context_window_tokens=None, + with ( + mock.patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "ghu-test"}, clear=False), + mock.patch.object( + capability, + "discover_models", + return_value=( + DiscoveredProviderModel( + model_id="claude-sonnet-4.6", + label="claude-sonnet-4.6", + context_window_tokens=None, + ), ), ), - ), mock.patch( - "packages.models.runtime_capability.request_json", - side_effect=_fake_request_json, + mock.patch( + "packages.models.runtime_capability.request_json", + side_effect=_fake_request_json, + ), ): context_window = capability.detect_context_window( provider_id="copilot", @@ -948,15 +1012,20 @@ def _fake_request_json(*, url: str, headers, timeout_seconds: float = 10.0): ) self.assertEqual(context_window, 200000) - self.assertEqual([url for url, _ in requests], [ - "https://api.githubcopilot.com/models/claude-sonnet-4.6", - ]) + self.assertEqual( + [url for url, _ in requests], + [ + "https://api.githubcopilot.com/models/claude-sonnet-4.6", + ], + ) detail_headers = requests[-1][1] self.assertEqual(detail_headers.get("Authorization"), "Bearer ghu-test") self.assertEqual(detail_headers.get("anthropic-version"), "2023-06-01") self.assertEqual(detail_headers.get("Openai-Intent"), "conversation-edits") - def test_surface_runtime_falls_back_to_curated_codex_models_when_live_probe_fails(self) -> None: + def test_surface_runtime_falls_back_to_curated_codex_models_when_live_probe_fails( + self, + ) -> None: database_path = Path(self.tempdir.name) / "provider-codex-models.sqlite3" repository = RuntimeStorageRepository(database_path) repository.bootstrap() @@ -966,14 +1035,20 @@ def test_surface_runtime_falls_back_to_curated_codex_models_when_live_probe_fail secret_key_path=Path(self.tempdir.name) / "provider-secrets.key", ) - with mock.patch("packages.models.runtime_capability.request_json", side_effect=RuntimeError("boom")): + with mock.patch( + "packages.models.runtime_capability.request_json", + side_effect=RuntimeError("boom"), + ): models = capability.discover_models( provider_id="openai-codex", base_url="https://chatgpt.com/backend-api/codex", ) self.assertGreaterEqual(len(models), 4) - self.assertEqual([model.model_id for model in models[:4]], ["gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.3-codex-spark"]) + self.assertEqual( + [model.model_id for model in models[:4]], + ["gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.3-codex-spark"], + ) self.assertTrue(all(model.source == "catalog-hint" for model in models)) gpt5 = next(model for model in models if model.model_id == "gpt-5.4") gpt5_mini = next(model for model in models if model.model_id == "gpt-5.4-mini") @@ -983,7 +1058,9 @@ def test_surface_runtime_falls_back_to_curated_codex_models_when_live_probe_fail self.assertEqual(spark.context_window_tokens, 128_000) self.assertEqual(gpt5.metadata["reasoning_efforts"], "minimal,low,medium,high") - def test_surface_runtime_uses_model_specific_context_hints_when_live_probe_fails(self) -> None: + def test_surface_runtime_uses_model_specific_context_hints_when_live_probe_fails( + self, + ) -> None: database_path = Path(self.tempdir.name) / "provider-context-hints.sqlite3" repository = RuntimeStorageRepository(database_path) repository.bootstrap() @@ -993,7 +1070,10 @@ def test_surface_runtime_uses_model_specific_context_hints_when_live_probe_fails secret_key_path=Path(self.tempdir.name) / "provider-secrets.key", ) - with mock.patch("packages.models.runtime_capability.request_json", side_effect=RuntimeError("boom")): + with mock.patch( + "packages.models.runtime_capability.request_json", + side_effect=RuntimeError("boom"), + ): minimax_models = capability.discover_models( provider_id="minimax", base_url="https://api.minimax.io/anthropic", @@ -1042,7 +1122,9 @@ def test_surface_runtime_detects_ollama_runtime_context_from_show_api(self) -> N self.assertEqual(context_window, 32_768) self.assertEqual(server.requests, ["GET /v1/models", "POST /api/show"]) - def test_surface_runtime_uses_models_dev_fallback_after_endpoint_metadata_miss(self) -> None: + def test_surface_runtime_uses_models_dev_fallback_after_endpoint_metadata_miss( + self, + ) -> None: database_path = Path(self.tempdir.name) / "provider-models-dev-context.sqlite3" repository = RuntimeStorageRepository(database_path) repository.bootstrap() @@ -1053,7 +1135,10 @@ def test_surface_runtime_uses_models_dev_fallback_after_endpoint_metadata_miss(s ) with ( - mock.patch("packages.models.runtime_capability.request_json", side_effect=RuntimeError("boom")), + mock.patch( + "packages.models.runtime_capability.request_json", + side_effect=RuntimeError("boom"), + ), mock.patch( "packages.models.model_metadata.fetch_models_dev_registry", return_value={ @@ -1078,7 +1163,9 @@ def test_surface_runtime_uses_models_dev_fallback_after_endpoint_metadata_miss(s self.assertEqual(context_window, 1_000_000) - def test_surface_runtime_does_not_invent_placeholder_models_for_openai_compatible(self) -> None: + def test_surface_runtime_does_not_invent_placeholder_models_for_openai_compatible( + self, + ) -> None: database_path = Path(self.tempdir.name) / "provider-openai-compatible-models.sqlite3" repository = RuntimeStorageRepository(database_path) repository.bootstrap() @@ -1088,7 +1175,10 @@ def test_surface_runtime_does_not_invent_placeholder_models_for_openai_compatibl secret_key_path=Path(self.tempdir.name) / "provider-secrets.key", ) - with mock.patch("packages.models.runtime_capability.request_json", side_effect=RuntimeError("boom")): + with mock.patch( + "packages.models.runtime_capability.request_json", + side_effect=RuntimeError("boom"), + ): models = capability.discover_models( provider_id="openai-compatible", base_url="https://api.example.test/v1", diff --git a/tests/integration/models_auth/test_openai_compatible_provider.py b/tests/integration/models_auth/test_openai_compatible_provider.py index 68ab3eb..dd47747 100644 --- a/tests/integration/models_auth/test_openai_compatible_provider.py +++ b/tests/integration/models_auth/test_openai_compatible_provider.py @@ -50,7 +50,10 @@ def post_json_stream(self, *, url: str, headers, payload): self.stream_payloads.append(dict(payload)) yield JSONHTTPStreamChunk( event="response.output_text.delta", - payload={"type": "response.output_text.delta", "delta": "fallback-response-text"}, + payload={ + "type": "response.output_text.delta", + "delta": "fallback-response-text", + }, ) yield JSONHTTPStreamChunk( event="response.output_item.done", @@ -142,7 +145,11 @@ def post_json_stream(self, *, url: str, headers, payload): "model": str(payload["model"]), "output_text": "The release note draft is ready.", "reasoning": "Inspect the latest release state first.", - "usage": {"input_tokens": 8, "output_tokens": 5, "total_tokens": 13}, + "usage": { + "input_tokens": 8, + "output_tokens": 5, + "total_tokens": 13, + }, }, }, ) @@ -153,7 +160,18 @@ def post_json(self, *, url: str, headers, payload): class _ResponsesFragmentedReasoningStreamTransport: def post_json_stream(self, *, url: str, headers, payload): - reasoning_deltas = ("先看", "\n", "release", "\n", "notes", "。", "\n", "Then", "\n", "verify") + reasoning_deltas = ( + "先看", + "\n", + "release", + "\n", + "notes", + "。", + "\n", + "Then", + "\n", + "verify", + ) for delta in reasoning_deltas: yield JSONHTTPStreamChunk( event="response.reasoning.delta", @@ -178,7 +196,11 @@ def post_json_stream(self, *, url: str, headers, payload): "model": str(payload["model"]), "output_text": "结论已经确认。", "reasoning": "先看\nrelease\nnotes。\nThen\nverify", - "usage": {"input_tokens": 10, "output_tokens": 6, "total_tokens": 16}, + "usage": { + "input_tokens": 10, + "output_tokens": 6, + "total_tokens": 16, + }, }, }, ) @@ -189,7 +211,19 @@ def post_json(self, *, url: str, headers, payload): class _ResponsesWordFragmentReasoningStreamTransport: def post_json_stream(self, *, url: str, headers, payload): - reasoning_deltas = ("The", "user", "asked", "about", "X", "un", "zhuo", "in", "Cheng", "du", ".") + reasoning_deltas = ( + "The", + "user", + "asked", + "about", + "X", + "un", + "zhuo", + "in", + "Cheng", + "du", + ".", + ) for delta in reasoning_deltas: yield JSONHTTPStreamChunk( event="response.reasoning.delta", @@ -214,7 +248,11 @@ def post_json_stream(self, *, url: str, headers, payload): "model": str(payload["model"]), "output_text": "I can answer naturally now.", "reasoning": "The user asked about Xunzhuo in Chengdu.", - "usage": {"input_tokens": 14, "output_tokens": 7, "total_tokens": 21}, + "usage": { + "input_tokens": 14, + "output_tokens": 7, + "total_tokens": 21, + }, }, }, ) @@ -238,7 +276,11 @@ def post_json(self, *, url: str, headers, payload): } } ], - "usage": {"prompt_tokens": 7, "completion_tokens": 4, "total_tokens": 11}, + "usage": { + "prompt_tokens": 7, + "completion_tokens": 4, + "total_tokens": 11, + }, }, ) @@ -338,7 +380,10 @@ def do_POST(self) -> None: # noqa: N802 "index": 0, "id": "call-stub", "type": "function", - "function": {"name": tool_name, "arguments": "{\"query\":"}, + "function": { + "name": tool_name, + "arguments": '{"query":', + }, } ], } @@ -354,7 +399,7 @@ def do_POST(self) -> None: # noqa: N802 "tool_calls": [ { "index": 0, - "function": {"arguments": "\"native tools\"}"}, + "function": {"arguments": '"native tools"}'}, } ], } @@ -365,7 +410,11 @@ def do_POST(self) -> None: # noqa: N802 "id": "chatcmpl-stub", "model": payload["model"], "choices": [{"delta": {}, "finish_reason": "tool_calls"}], - "usage": {"prompt_tokens": 7, "completion_tokens": 3, "total_tokens": 10}, + "usage": { + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10, + }, }, ) for event in events: @@ -397,7 +446,11 @@ def do_POST(self) -> None: # noqa: N802 } } ], - "usage": {"prompt_tokens": 7, "completion_tokens": 3, "total_tokens": 10}, + "usage": { + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10, + }, } encoded = json.dumps(response).encode("utf-8") self.send_response(200) @@ -428,7 +481,11 @@ def do_POST(self) -> None: # noqa: N802 "id": "chatcmpl-stub", "model": payload["model"], "choices": [{"delta": {}, "finish_reason": "stop"}], - "usage": {"prompt_tokens": 7, "completion_tokens": 3, "total_tokens": 10}, + "usage": { + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10, + }, } self.wfile.write(f"data: {json.dumps(final_event)}\n\n".encode("utf-8")) self.wfile.write(b"data: [DONE]\n\n") @@ -481,7 +538,11 @@ def do_POST(self) -> None: # noqa: N802 "id": "resp-stub", "model": payload["model"], "output": [function_call], - "usage": {"input_tokens": 6, "output_tokens": 3, "total_tokens": 9}, + "usage": { + "input_tokens": 6, + "output_tokens": 3, + "total_tokens": 9, + }, } }, ), @@ -490,8 +551,14 @@ def do_POST(self) -> None: # noqa: N802 content = f"live-response:{input_text}" midpoint = max(1, len(content) // 2) events = ( - ("response.output_text.delta", {"delta": content[:midpoint]}), - ("response.output_text.delta", {"delta": content[midpoint:]}), + ( + "response.output_text.delta", + {"delta": content[:midpoint]}, + ), + ( + "response.output_text.delta", + {"delta": content[midpoint:]}, + ), ( "response.completed", { @@ -499,7 +566,11 @@ def do_POST(self) -> None: # noqa: N802 "id": "resp-stub", "model": payload["model"], "output_text": content, - "usage": {"input_tokens": 6, "output_tokens": 3, "total_tokens": 9}, + "usage": { + "input_tokens": 6, + "output_tokens": 3, + "total_tokens": 9, + }, } }, ), @@ -523,14 +594,22 @@ def do_POST(self) -> None: # noqa: N802 "arguments": json.dumps({"query": "responses tools"}), } ], - "usage": {"input_tokens": 6, "output_tokens": 3, "total_tokens": 9}, + "usage": { + "input_tokens": 6, + "output_tokens": 3, + "total_tokens": 9, + }, } else: response = { "id": "resp-stub", "model": payload["model"], "output_text": f"live-response:{self._responses_input_text(payload.get('input'))}", - "usage": {"input_tokens": 6, "output_tokens": 3, "total_tokens": 9}, + "usage": { + "input_tokens": 6, + "output_tokens": 3, + "total_tokens": 9, + }, } else: self.send_response(404) @@ -565,9 +644,7 @@ def test_plans_chat_requests_with_custom_base_url_and_headers(self) -> None: extra_headers={"x-tenant": "elephant"}, ), runtime_resolver=ProviderRuntimeResolver.default(), - credential_source=_StaticCredentialSource( - {"openai-compatible": {"api_key": "sk-test-123"}} - ), + credential_source=_StaticCredentialSource({"openai-compatible": {"api_key": "sk-test-123"}}), ) request = ModelRequest( request_id="request-1", @@ -591,7 +668,10 @@ def test_plans_chat_requests_with_custom_base_url_and_headers(self) -> None: self.assertEqual(plan.payload["model"], "openai/gpt-4o-mini") self.assertEqual(plan.payload["messages"][0]["role"], "system") self.assertIn("### System Layer Contract", plan.payload["messages"][0]["content"]) - self.assertIn("You are the active elephant identity", plan.payload["messages"][0]["content"]) + self.assertIn( + "You are the active elephant identity", + plan.payload["messages"][0]["content"], + ) self.assertIn("### Episode Continuity", plan.payload["messages"][0]["content"]) self.assertIn("Stay truthful and bounded", plan.payload["messages"][0]["content"]) self.assertIn("### Loop Execution Board", plan.payload["messages"][0]["content"]) @@ -606,7 +686,9 @@ def test_plans_chat_requests_with_custom_base_url_and_headers(self) -> None: self.assertNotIn("sk-test-123", result.content) self.assertEqual(self.server.requests[0]["path"], "/v1/chat/completions") self.assertEqual(self.server.requests[0]["headers"]["Authorization"], "Bearer sk-test-123") - request_headers = {str(key).lower(): str(value) for key, value in dict(self.server.requests[0]["headers"]).items()} + request_headers = { + str(key).lower(): str(value) for key, value in dict(self.server.requests[0]["headers"]).items() + } self.assertEqual(request_headers["x-session-id"], "session-1") self.assertNotIn("metadata", self.server.requests[0]["payload"]) self.assertFalse(plan.payload["stream"]) @@ -672,9 +754,7 @@ def test_chat_requests_accept_base_url_without_v1_suffix(self) -> None: model_id="openai/gpt-4o-mini", ), runtime_resolver=ProviderRuntimeResolver.default(), - credential_source=_StaticCredentialSource( - {"openai-compatible": {"api_key": "sk-test-123"}} - ), + credential_source=_StaticCredentialSource({"openai-compatible": {"api_key": "sk-test-123"}}), ) request = ModelRequest( request_id="request-root-base", @@ -692,7 +772,9 @@ def test_chat_requests_accept_base_url_without_v1_suffix(self) -> None: self.assertEqual(self.server.requests[-1]["path"], "/v1/chat/completions") self.assertEqual(result.content, "live-chat:Use the root endpoint.") - def test_rendered_prompt_is_forwarded_without_provider_guardrail_prepended(self) -> None: + def test_rendered_prompt_is_forwarded_without_provider_guardrail_prepended( + self, + ) -> None: adapter = OpenAICompatibleProviderAdapter( config=OpenAICompatibleProviderConfig( provider_id="openai-compatible", @@ -700,9 +782,7 @@ def test_rendered_prompt_is_forwarded_without_provider_guardrail_prepended(self) model_id="openai/gpt-4o-mini", ), runtime_resolver=ProviderRuntimeResolver.default(), - credential_source=_StaticCredentialSource( - {"openai-compatible": {"api_key": "sk-test-123"}} - ), + credential_source=_StaticCredentialSource({"openai-compatible": {"api_key": "sk-test-123"}}), ) request = ModelRequest( request_id="request-identity", @@ -713,14 +793,9 @@ def test_rendered_prompt_is_forwarded_without_provider_guardrail_prepended(self) prompt="Who are you?", context={ "frozen_prefix_prompt": ( - "## EpisodeFrozenContext\n" - "### System Layer Contract\n" - "You are Aeon, the active elephant identity." - ), - "session_snapshot_prompt": ( - "## StateSnapshot\n" - "- active current work: keep the State exact" + "## EpisodeFrozenContext\n### System Layer Contract\nYou are Aeon, the active elephant identity." ), + "session_snapshot_prompt": ("## StateSnapshot\n- active current work: keep the State exact"), "rendered_prompt": "legacy rendered prompt should not be used when structured sections exist", }, ) @@ -730,20 +805,24 @@ def test_rendered_prompt_is_forwarded_without_provider_guardrail_prepended(self) self.assertEqual(plan.payload["messages"][0]["role"], "system") self.assertEqual( plan.payload["messages"][0]["content"], - f"{request.context['frozen_prefix_prompt']}\n\n" - f"{request.context['session_snapshot_prompt']}", + f"{request.context['frozen_prefix_prompt']}\n\n{request.context['session_snapshot_prompt']}", ) self.assertIn("You are Aeon", plan.payload["messages"][0]["content"]) self.assertNotIn("## LoopContext", plan.payload["messages"][0]["content"]) self.assertNotIn("OpenAI-compatible provider adapter", plan.payload["messages"][0]["content"]) self.assertNotIn("credential_keys=", plan.payload["messages"][0]["content"]) - self.assertEqual(sum(1 for message in plan.payload["messages"] if message["role"] == "system"), 1) + self.assertEqual( + sum(1 for message in plan.payload["messages"] if message["role"] == "system"), + 1, + ) self.assertEqual( plan.payload["messages"][1]["content"], "Who are you?", ) - def test_chat_request_flattens_all_system_context_into_one_system_message(self) -> None: + def test_chat_request_flattens_all_system_context_into_one_system_message( + self, + ) -> None: adapter = OpenAICompatibleProviderAdapter( config=OpenAICompatibleProviderConfig( provider_id="openai-compatible", @@ -774,7 +853,10 @@ def test_chat_request_flattens_all_system_context_into_one_system_message(self) plan = adapter.plan_request(request) - self.assertEqual([message["role"] for message in plan.payload["messages"]], ["system", "assistant", "user"]) + self.assertEqual( + [message["role"] for message in plan.payload["messages"]], + ["system", "assistant", "user"], + ) self.assertIn("## EpisodeFrozenContext", plan.payload["messages"][0]["content"]) self.assertIn("## StateSnapshot", plan.payload["messages"][0]["content"]) self.assertNotIn("## LoopContext", plan.payload["messages"][0]["content"]) @@ -805,7 +887,11 @@ def test_chat_request_preserves_history_and_tool_result_roles(self) -> None: role="assistant", content="", tool_calls=( - {"id": "call-1", "name": "tool.web.search", "arguments": {"query": "elephant docs"}}, + { + "id": "call-1", + "name": "tool.web.search", + "arguments": {"query": "elephant docs"}, + }, ), ), PromptMessage( @@ -821,7 +907,10 @@ def test_chat_request_preserves_history_and_tool_result_roles(self) -> None: "function": { "name": "tool.web.search", "description": "Search the web.", - "parameters": {"type": "object", "properties": {"query": {"type": "string"}}}, + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + }, }, }, ), @@ -829,8 +918,14 @@ def test_chat_request_preserves_history_and_tool_result_roles(self) -> None: plan = adapter.plan_request(request) - self.assertEqual([message["role"] for message in plan.payload["messages"][-4:]], ["user", "assistant", "tool", "user"]) - self.assertEqual(plan.payload["messages"][-3]["tool_calls"][0]["function"]["name"], "tool_web_search") + self.assertEqual( + [message["role"] for message in plan.payload["messages"][-4:]], + ["user", "assistant", "tool", "user"], + ) + self.assertEqual( + plan.payload["messages"][-3]["tool_calls"][0]["function"]["name"], + "tool_web_search", + ) self.assertEqual(plan.payload["messages"][-2]["tool_call_id"], "call-1") self.assertEqual(plan.payload["messages"][-1]["content"], "Use that result.") @@ -842,9 +937,7 @@ def test_embed_requests_use_the_shared_compatible_transport(self) -> None: model_id="text-embedding-3-small", ), runtime_resolver=ProviderRuntimeResolver.default(), - credential_source=_StaticCredentialSource( - {"openai-compatible": {"api_key": "sk-embed-456"}} - ), + credential_source=_StaticCredentialSource({"openai-compatible": {"api_key": "sk-embed-456"}}), ) request = ModelRequest( request_id="request-embed", @@ -878,9 +971,7 @@ def test_generate_streams_chat_completions_when_observer_is_present(self) -> Non model_id="openai/gpt-4o-mini", ), runtime_resolver=ProviderRuntimeResolver.default(), - credential_source=_StaticCredentialSource( - {"openai-compatible": {"api_key": "sk-stream-789"}} - ), + credential_source=_StaticCredentialSource({"openai-compatible": {"api_key": "sk-stream-789"}}), stream_observer=streamed.append, ) request = ModelRequest( @@ -909,9 +1000,7 @@ def test_generate_streams_and_parses_native_tool_calls(self) -> None: model_id="openai/gpt-4o-mini", ), runtime_resolver=ProviderRuntimeResolver.default(), - credential_source=_StaticCredentialSource( - {"openai-compatible": {"api_key": "sk-tools-123"}} - ), + credential_source=_StaticCredentialSource({"openai-compatible": {"api_key": "sk-tools-123"}}), stream_observer=streamed.append, ) request = ModelRequest( @@ -987,7 +1076,9 @@ def test_responses_stream_reasoning_is_split_from_final_answer(self) -> None: ) self.assertTrue(bool(transport.stream_payloads[0]["stream"])) - def test_responses_stream_reasoning_collapses_fragmented_newlines_without_breaking_mixed_language_text(self) -> None: + def test_responses_stream_reasoning_collapses_fragmented_newlines_without_breaking_mixed_language_text( + self, + ) -> None: streamed: list[str] = [] adapter = OpenAICompatibleProviderAdapter( config=OpenAICompatibleProviderConfig( @@ -1018,7 +1109,9 @@ def test_responses_stream_reasoning_collapses_fragmented_newlines_without_breaki self.assertEqual(streamed_combined.reasoning, "先看release notes。 Then verify") self.assertEqual(streamed_combined.content, "结论已经确认。") - def test_responses_stream_reasoning_prioritizes_spaces_and_uses_completed_reasoning_when_available(self) -> None: + def test_responses_stream_reasoning_prioritizes_spaces_and_uses_completed_reasoning_when_available( + self, + ) -> None: streamed: list[str] = [] adapter = OpenAICompatibleProviderAdapter( config=OpenAICompatibleProviderConfig( @@ -1112,7 +1205,10 @@ def test_responses_transport_parses_native_tool_calls(self) -> None: self.assertEqual(plan.endpoint_path, "/v1/responses") self.assertEqual(plan.payload["input"][0]["role"], "user") - self.assertEqual(plan.payload["input"][0]["content"][0]["text"], "Use tools through responses.") + self.assertEqual( + plan.payload["input"][0]["content"][0]["text"], + "Use tools through responses.", + ) self.assertEqual(plan.payload["tools"][0]["name"], "tool_web_search") self.assertFalse(plan.payload["store"]) self.assertEqual(result.content, "") @@ -1177,7 +1273,9 @@ def test_codex_responses_omits_internal_metadata_from_request_payload(self) -> N self.assertNotIn("metadata", self.server.requests[-1]["payload"]) self.assertEqual(result.content, "live-response:Explain the current runtime status.") - def test_codex_responses_backfills_completed_response_from_stream_items(self) -> None: + def test_codex_responses_backfills_completed_response_from_stream_items( + self, + ) -> None: transport = _ResponsesStreamBackfillTransport() adapter = OpenAICompatibleProviderAdapter( config=OpenAICompatibleProviderConfig( @@ -1367,13 +1465,19 @@ def test_retries_with_curl_on_tls_version_mismatch(self) -> None: "packages.models.providers.http.request.urlopen", side_effect=error.URLError(ssl.SSLError("WRONG_VERSION_NUMBER")), ), - mock.patch("packages.models.providers.http.shutil.which", return_value="/usr/bin/curl"), + mock.patch( + "packages.models.providers.http.shutil.which", + return_value="/usr/bin/curl", + ), mock.patch("packages.models.providers.http.subprocess.run", return_value=completed) as run, ): response = transport.post_json( url="https://example.test/v1/chat/completions", headers={"Authorization": "Bearer sk-test"}, - payload={"model": "demo", "messages": [{"role": "user", "content": "hello"}]}, + payload={ + "model": "demo", + "messages": [{"role": "user", "content": "hello"}], + }, ) self.assertEqual(response.status_code, 200) @@ -1388,11 +1492,7 @@ def test_retries_with_curl_on_tls_version_mismatch(self) -> None: def test_retries_with_curl_on_tls_unexpected_eof(self) -> None: transport = UrllibJSONHTTPTransport() - self.assertTrue( - transport._should_retry_with_curl( - error.URLError(ssl.SSLError("UNEXPECTED_EOF_WHILE_READING")) - ) - ) + self.assertTrue(transport._should_retry_with_curl(error.URLError(ssl.SSLError("UNEXPECTED_EOF_WHILE_READING")))) def test_stream_retries_with_curl_on_tls_unexpected_eof(self) -> None: transport = UrllibJSONHTTPTransport() @@ -1400,12 +1500,12 @@ def test_stream_retries_with_curl_on_tls_unexpected_eof(self) -> None: args=["curl"], returncode=0, stdout=( - b'event: response.output_text.delta\n' + b"event: response.output_text.delta\n" b'data: {"delta":"hello"}\n\n' - b'event: response.completed\n' + b"event: response.completed\n" b'data: {"response":{"id":"resp-fallback","output_text":"hello"}}\n\n' - b'data: [DONE]\n\n' - b'__ELEPHANT_STATUS__:200' + b"data: [DONE]\n\n" + b"__ELEPHANT_STATUS__:200" ), stderr=b"", ) @@ -1414,7 +1514,10 @@ def test_stream_retries_with_curl_on_tls_unexpected_eof(self) -> None: "packages.models.providers.http.request.urlopen", side_effect=error.URLError(ssl.SSLError("UNEXPECTED_EOF_WHILE_READING")), ), - mock.patch("packages.models.providers.http.shutil.which", return_value="/usr/bin/curl"), + mock.patch( + "packages.models.providers.http.shutil.which", + return_value="/usr/bin/curl", + ), mock.patch("packages.models.providers.http.subprocess.run", return_value=completed) as run, ): chunks = tuple( @@ -1425,7 +1528,10 @@ def test_stream_retries_with_curl_on_tls_unexpected_eof(self) -> None: ) ) - self.assertEqual([chunk.event for chunk in chunks], ["response.output_text.delta", "response.completed"]) + self.assertEqual( + [chunk.event for chunk in chunks], + ["response.output_text.delta", "response.completed"], + ) self.assertEqual(chunks[0].payload["delta"], "hello") command = run.call_args.args[0] self.assertIn("--write-out", command) diff --git a/tests/integration/models_auth/test_openai_provider.py b/tests/integration/models_auth/test_openai_provider.py index 50fb843..eca938b 100644 --- a/tests/integration/models_auth/test_openai_provider.py +++ b/tests/integration/models_auth/test_openai_provider.py @@ -17,7 +17,11 @@ sys.modules[SPEC.name] = OPENAI SPEC.loader.exec_module(OPENAI) -from packages.auth import InMemorySecretStore, ProfileCredentialResolver, SecretReference +from packages.auth import ( + InMemorySecretStore, + ProfileCredentialResolver, + SecretReference, +) class OpenAIProviderAdapterTests(unittest.TestCase): diff --git a/tests/integration/security_observability/AGENTS.md b/tests/integration/security_observability/AGENTS.md index e060f9b..9ecd5cb 100644 --- a/tests/integration/security_observability/AGENTS.md +++ b/tests/integration/security_observability/AGENTS.md @@ -13,4 +13,3 @@ together. - business logic for policy decisions - telemetry backend implementation details - diff --git a/tests/integration/security_observability/test_security_observability.py b/tests/integration/security_observability/test_security_observability.py index fc1aa6c..c0d3899 100644 --- a/tests/integration/security_observability/test_security_observability.py +++ b/tests/integration/security_observability/test_security_observability.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import os from pathlib import Path import tempfile import unittest @@ -145,7 +144,9 @@ def test_auth_profile_rejects_inline_secret_material(self) -> None: metadata={"token": "sk-live-123"}, ) - def test_security_records_redact_secret_like_metadata_and_support_details(self) -> None: + def test_security_records_redact_secret_like_metadata_and_support_details( + self, + ) -> None: sink = _CaptureSink() policy = SecurityPolicy(rules={}) request = SecurityRequest( @@ -214,11 +215,7 @@ def test_cli_mutation_paths_emit_security_telemetry(self) -> None: snapshot = json.loads(runtime.snapshot_path.read_text(encoding="utf-8")) - telemetry = [ - record - for record in snapshot.get("telemetry", ()) - if record.get("source") == "cli.operator" - ] + telemetry = [record for record in snapshot.get("telemetry", ()) if record.get("source") == "cli.operator"] self.assertGreaterEqual(len(telemetry), 2) self.assertEqual(telemetry[0]["name"], "approval.requested") self.assertIn(telemetry[-1]["name"], {"approval.granted", "approval.denied"}) @@ -254,16 +251,14 @@ def test_gateway_delivery_emits_policy_telemetry(self) -> None: ) self.assertEqual(exchange.delivery.policy_result.decision, PolicyDecision.ALLOW) - security_events = [ - record - for record in app.telemetry.events - if record.get("source") == "gateway.messaging" - ] + security_events = [record for record in app.telemetry.events if record.get("source") == "gateway.messaging"] self.assertGreaterEqual(len(security_events), 4) self.assertEqual(security_events[0]["name"], "approval.requested") self.assertEqual(security_events[-1]["name"], "approval.granted") - def test_security_doctor_surfaces_redacted_support_bundle_and_policy_bundles(self) -> None: + def test_security_doctor_surfaces_redacted_support_bundle_and_policy_bundles( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: root = Path(tmpdir) state_dir = root / "state" @@ -303,8 +298,14 @@ def test_security_doctor_surfaces_redacted_support_bundle_and_policy_bundles(sel report["support_bundle"]["provider"]["stored_secret_reference_ids"], report["support_bundle"]["provider"]["secret_reference_ids"], ) - self.assertEqual(report["support_bundle"]["provider"]["secret_store"], "encrypted-local-store") - self.assertEqual(report["support_bundle"]["embedding_provider"]["provider_id"], "openai-compatible-embed") + self.assertEqual( + report["support_bundle"]["provider"]["secret_store"], + "encrypted-local-store", + ) + self.assertEqual( + report["support_bundle"]["embedding_provider"]["provider_id"], + "openai-compatible-embed", + ) self.assertTrue(report["support_bundle"]["embedding_provider"]["secret_reference_ids"]) self.assertEqual( report["support_bundle"]["embedding_provider"]["stored_secret_reference_ids"], @@ -366,7 +367,10 @@ def test_security_doctor_warns_when_provider_secrets_are_not_stored(self) -> Non provider_id="openai-compatible-embed", secret_name="api_token", secret_key="api_key", - metadata={"storage": "local-vault", "scope": "embedding-provider"}, + metadata={ + "storage": "local-vault", + "scope": "embedding-provider", + }, ), ), metadata={"embedding_active": "true", "dimensions": "1536"}, @@ -375,12 +379,13 @@ def test_security_doctor_warns_when_provider_secrets_are_not_stored(self) -> Non report = runtime.security_doctor() self.assertEqual(report["status"], "not-ready") - boundary_check = next( - check for check in report["checks"] if check["check"] == "secret_boundary" - ) + boundary_check = next(check for check in report["checks"] if check["check"] == "secret_boundary") self.assertEqual(boundary_check["status"], "warning") self.assertIn("missing stored provider secrets", str(boundary_check["summary"])) - self.assertIn("secret-embedding-provider-openai-compatible-active-api-key", str(boundary_check["summary"])) + self.assertIn( + "secret-embedding-provider-openai-compatible-active-api-key", + str(boundary_check["summary"]), + ) if __name__ == "__main__": diff --git a/tests/integration/semantic_index/test_episode_close_writeback.py b/tests/integration/semantic_index/test_episode_close_writeback.py index 50c89ac..7cad54d 100644 --- a/tests/integration/semantic_index/test_episode_close_writeback.py +++ b/tests/integration/semantic_index/test_episode_close_writeback.py @@ -37,7 +37,12 @@ class _StubEmbeddingService: """Deterministic embedding: letters map to one-hot buckets.""" - def __init__(self, provider_id: str = "stub", model_id: str = "stub-embed", dimensions: int = 64) -> None: + def __init__( + self, + provider_id: str = "stub", + model_id: str = "stub-embed", + dimensions: int = 64, + ) -> None: self._provider_id = provider_id self._model_id = model_id self._dimensions = dimensions diff --git a/tests/integration/semantic_index/test_hybrid_search.py b/tests/integration/semantic_index/test_hybrid_search.py index 0c9f161..4dda2a4 100644 --- a/tests/integration/semantic_index/test_hybrid_search.py +++ b/tests/integration/semantic_index/test_hybrid_search.py @@ -74,7 +74,10 @@ def test_hybrid_search_uses_scope_gates_and_weighted_rrf(self) -> None: ) ) - self.assertEqual(tuple(match.document.source_id for match in matches), ("step:alpha-error", "step:alpha-vector")) + self.assertEqual( + tuple(match.document.source_id for match in matches), + ("step:alpha-error", "step:alpha-vector"), + ) self.assertIn("keyword_exact", matches[0].signal_scores) self.assertIn("vector", matches[0].signal_scores) self.assertIn("vector", matches[1].signal_scores) @@ -124,7 +127,10 @@ def test_degraded_vector_search_falls_back_to_lexical_signals(self) -> None: self.assertEqual(tuple(match.document.source_id for match in matches), ("step:heartbeat",)) self.assertEqual(backend.search_calls, 0) - self.assertEqual(set(matches[0].signal_scores), {"token_coverage", "keyword_exact", "bm25", "ngram"}) + self.assertEqual( + set(matches[0].signal_scores), + {"token_coverage", "keyword_exact", "bm25", "ngram"}, + ) self.assertNotIn("vector", matches[0].signal_scores) def test_unicode_lexical_matches_cjk_split_and_fuzzy_queries(self) -> None: @@ -173,8 +179,14 @@ def test_unicode_lexical_matches_cjk_split_and_fuzzy_queries(self) -> None: ) ) - self.assertEqual(tuple(match.document.source_id for match in split_matches), ("step:fog-crossing",)) - self.assertEqual(tuple(match.document.source_id for match in fuzzy_matches), ("step:quiet-corner",)) + self.assertEqual( + tuple(match.document.source_id for match in split_matches), + ("step:fog-crossing",), + ) + self.assertEqual( + tuple(match.document.source_id for match in fuzzy_matches), + ("step:quiet-corner",), + ) self.assertTrue({"token_coverage", "ngram"} & set(split_matches[0].signal_scores)) self.assertIn("ngram", fuzzy_matches[0].signal_scores) diff --git a/tests/integration/semantic_index/test_metadata.py b/tests/integration/semantic_index/test_metadata.py index a56bf58..b4f47b3 100644 --- a/tests/integration/semantic_index/test_metadata.py +++ b/tests/integration/semantic_index/test_metadata.py @@ -53,7 +53,10 @@ def test_service_persists_vector_metadata_and_indexes_vector(self) -> None: self.assertEqual(loaded.model_id, "elephant-embed") self.assertEqual(loaded.dimensions, 4) self.assertEqual(loaded.owner_scope, "state") - self.assertEqual(loaded.content_hash, semantic_content_hash("release checklist and package verification")) + self.assertEqual( + loaded.content_hash, + semantic_content_hash("release checklist and package verification"), + ) self.assertEqual(loaded.status, "indexed") self.assertTrue(loaded.vector_ref.startswith("sqlite-vec:4:semantic-index:")) self.assertEqual(loaded.metadata["backend_version"], "0.1.9") @@ -119,7 +122,9 @@ def test_service_deletes_metadata_and_vectors_by_scope(self) -> None: self.assertIsNone(loaded) self.assertEqual(matches, ()) - def test_rebuild_plan_tracks_provider_model_dimension_and_content_changes(self) -> None: + def test_rebuild_plan_tracks_provider_model_dimension_and_content_changes( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: root = Path(tmpdir) repository = RuntimeStorageRepository(root / "state" / "elephant.sqlite3") @@ -171,6 +176,7 @@ def test_rebuild_plan_tracks_provider_model_dimension_and_content_changes(self) self.assertEqual(len(rebuild.delete_entry_ids), 1) self.assertEqual(rebuild.rebuild_documents, (changed_dimensions,)) + class _DegradedBackend: backend_id = "sqlite-vec" diff --git a/tests/integration/semantic_index/test_multilingual_hybrid_search.py b/tests/integration/semantic_index/test_multilingual_hybrid_search.py index 149c6e7..2f611c1 100644 --- a/tests/integration/semantic_index/test_multilingual_hybrid_search.py +++ b/tests/integration/semantic_index/test_multilingual_hybrid_search.py @@ -43,7 +43,12 @@ class _StubEmbeddingService: """Deterministic letter-bucket embedder (same as sibling tests).""" - def __init__(self, provider_id: str = "stub", model_id: str = "stub-embed", dimensions: int = 64) -> None: + def __init__( + self, + provider_id: str = "stub", + model_id: str = "stub-embed", + dimensions: int = 64, + ) -> None: self._provider_id = provider_id self._model_id = model_id self._dimensions = dimensions @@ -54,8 +59,8 @@ def embed_text(self, text: str, **kwargs) -> EmbeddingVector: bucket = [0.0] * self._dimensions lowered = text.lower() for ch in lowered: - if ch.isalpha() or '\u4e00' <= ch <= '\u9fff' or '\u3040' <= ch <= '\u30ff' or '\uac00' <= ch <= '\ud7af': - idx = (ord(ch[0]) % self._dimensions) + if ch.isalpha() or "\u4e00" <= ch <= "\u9fff" or "\u3040" <= ch <= "\u30ff" or "\uac00" <= ch <= "\ud7af": + idx = ord(ch[0]) % self._dimensions bucket[idx] += 1.0 total = sum(bucket) or 1.0 return EmbeddingVector( @@ -91,7 +96,17 @@ def _build_surface(tmpdir: str) -> tuple[PersonalModelUnderstandingSurface, str] return surface, state.personal_model_id, indexer, repository -def _index_fact(repository, indexer, *, fact_id, personal_model_id, lens, text, topic, status="active"): +def _index_fact( + repository, + indexer, + *, + fact_id, + personal_model_id, + lens, + text, + topic, + status="active", +): fact = Fact( fact_id=fact_id, personal_model_id=personal_model_id, @@ -109,29 +124,51 @@ def _index_fact(repository, indexer, *, fact_id, personal_model_id, lens, text, class MultilingualHybridSearchTest(unittest.TestCase): - # ── 1. Chinese exact / fuzzy / negation ────────────────────────── def test_chinese_exact_and_fuzzy_and_negation(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: surface, pm_id, indexer, repo = _build_surface(tmpdir) - _index_fact(repo, indexer, - fact_id="c:weather-crossing", personal_model_id=pm_id, lens="identity", - text="我喜欢像站在起雾的路口那样慢慢做决定。", topic="test.weather.crossing") - _index_fact(repo, indexer, - fact_id="c:quiet-corner", personal_model_id=pm_id, lens="pulse", - text="能量低的时候,我需要一个安静角落。", topic="test.recovery.low_energy") - _index_fact(repo, indexer, - fact_id="c:social-positive", personal_model_id=pm_id, lens="identity", - text="我喜欢周末约朋友去热闹的 bar。", topic="test.social.positive") - _index_fact(repo, indexer, - fact_id="c:social-negative", personal_model_id=pm_id, lens="identity", - text="我不喜欢参加需要大声说话才能沟通的聚会。", topic="test.social.negative") + _index_fact( + repo, + indexer, + fact_id="c:weather-crossing", + personal_model_id=pm_id, + lens="identity", + text="我喜欢像站在起雾的路口那样慢慢做决定。", + topic="test.weather.crossing", + ) + _index_fact( + repo, + indexer, + fact_id="c:quiet-corner", + personal_model_id=pm_id, + lens="pulse", + text="能量低的时候,我需要一个安静角落。", + topic="test.recovery.low_energy", + ) + _index_fact( + repo, + indexer, + fact_id="c:social-positive", + personal_model_id=pm_id, + lens="identity", + text="我喜欢周末约朋友去热闹的 bar。", + topic="test.social.positive", + ) + _index_fact( + repo, + indexer, + fact_id="c:social-negative", + personal_model_id=pm_id, + lens="identity", + text="我不喜欢参加需要大声说话才能沟通的聚会。", + topic="test.social.negative", + ) def top_ref(query: str, **kw) -> str: - result = surface.search_personal_model( - "s1", query=query, limit=5, personal_model_id=pm_id, **kw) + result = surface.search_personal_model("s1", query=query, limit=5, personal_model_id=pm_id, **kw) claims = tuple(result.get("claims") or ()) self.assertTrue(claims, f"no match for query={query!r}") return str(claims[0]["ref"]) @@ -141,7 +178,10 @@ def top_ref(query: str, **kw) -> str: # Fuzzy / typo-tolerant self.assertEqual(top_ref("安净角落"), "c:quiet-corner") # Current lexical ranking keeps the overlapping social claims visible. - self.assertIn(top_ref("喜欢热闹 聚会 大声说话"), {"c:social-positive", "c:social-negative"}) + self.assertIn( + top_ref("喜欢热闹 聚会 大声说话"), + {"c:social-positive", "c:social-negative"}, + ) # Negation disambiguation: negative-preferring query should prefer negative variant self.assertEqual(top_ref("不喜欢大声说话的聚会"), "c:social-negative") @@ -151,26 +191,45 @@ def test_english_exact_synonym_metaphor(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: surface, pm_id, indexer, repo = _build_surface(tmpdir) - _index_fact(repo, indexer, - fact_id="e:caching-redis", personal_model_id=pm_id, lens="world", + _index_fact( + repo, + indexer, + fact_id="e:caching-redis", + personal_model_id=pm_id, + lens="world", text="User chose Redis caching over memcached for the aegis project.", - topic="test.caching.strategy") - _index_fact(repo, indexer, - fact_id="e:code-lego", personal_model_id=pm_id, lens="identity", + topic="test.caching.strategy", + ) + _index_fact( + repo, + indexer, + fact_id="e:code-lego", + personal_model_id=pm_id, + lens="identity", text="好的代码像是拼好的乐高,每块都刚好卡在它该在的位置。", - topic="test.code.metaphor") - _index_fact(repo, indexer, - fact_id="e:city-walk", personal_model_id=pm_id, lens="identity", + topic="test.code.metaphor", + ) + _index_fact( + repo, + indexer, + fact_id="e:city-walk", + personal_model_id=pm_id, + lens="identity", text="周末最喜欢做的事是 city walk,在成都的街头漫无目的地走。", - topic="test.weekend.routine") - _index_fact(repo, indexer, - fact_id="e:quiet-corner-en", personal_model_id=pm_id, lens="pulse", + topic="test.weekend.routine", + ) + _index_fact( + repo, + indexer, + fact_id="e:quiet-corner-en", + personal_model_id=pm_id, + lens="pulse", text="When energy is low, give me a quiet corner to sit in.", - topic="test.recovery.style") + topic="test.recovery.style", + ) def top_ref(query: str, **kw) -> str: - result = surface.search_personal_model( - "s2", query=query, limit=5, personal_model_id=pm_id, **kw) + result = surface.search_personal_model("s2", query=query, limit=5, personal_model_id=pm_id, **kw) claims = tuple(result.get("claims") or ()) self.assertTrue(claims, f"no match for query={query!r}") return str(claims[0]["ref"]) @@ -190,27 +249,50 @@ def test_cross_lingual_query_variants(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: surface, pm_id, indexer, repo = _build_surface(tmpdir) - _index_fact(repo, indexer, - fact_id="x:music-cn", personal_model_id=pm_id, lens="identity", - text="个人爱好包含音乐和周末听唱片。", topic="test.music.preference") - _index_fact(repo, indexer, - fact_id="x:solitude-cn", personal_model_id=pm_id, lens="identity", - text="孤独有时候是干净的。", topic="test.trait.solitude") - _index_fact(repo, indexer, - fact_id="x:city-walk-cn", personal_model_id=pm_id, lens="identity", - text="周末 city walk 是我最放松的方式。", topic="test.weekend.routine") + _index_fact( + repo, + indexer, + fact_id="x:music-cn", + personal_model_id=pm_id, + lens="identity", + text="个人爱好包含音乐和周末听唱片。", + topic="test.music.preference", + ) + _index_fact( + repo, + indexer, + fact_id="x:solitude-cn", + personal_model_id=pm_id, + lens="identity", + text="孤独有时候是干净的。", + topic="test.trait.solitude", + ) + _index_fact( + repo, + indexer, + fact_id="x:city-walk-cn", + personal_model_id=pm_id, + lens="identity", + text="周末 city walk 是我最放松的方式。", + topic="test.weekend.routine", + ) def top_ref(query: str, **kw) -> str: - result = surface.search_personal_model( - "s3", query=query, limit=5, personal_model_id=pm_id, **kw) + result = surface.search_personal_model("s3", query=query, limit=5, personal_model_id=pm_id, **kw) claims = tuple(result.get("claims") or ()) self.assertTrue(claims, f"no match for query={query!r}") return str(claims[0]["ref"]) # EN query → query_variants with CN translation should bridge self.assertEqual(top_ref("music", query_variants=("音乐",)), "x:music-cn") - self.assertEqual(top_ref("solitude is pure", query_variants=("孤独是干净的",)), "x:solitude-cn") - self.assertEqual(top_ref("weekend walk", query_variants=("周末 city walk",)), "x:city-walk-cn") + self.assertEqual( + top_ref("solitude is pure", query_variants=("孤独是干净的",)), + "x:solitude-cn", + ) + self.assertEqual( + top_ref("weekend walk", query_variants=("周末 city walk",)), + "x:city-walk-cn", + ) # Reverse: CN query → EN variant self.assertEqual(top_ref("音乐爱好", query_variants=("music hobby",)), "x:music-cn") @@ -221,19 +303,36 @@ def test_negation_polarity_disambiguation(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: surface, pm_id, indexer, repo = _build_surface(tmpdir) - _index_fact(repo, indexer, - fact_id="np:social-positive-2", personal_model_id=pm_id, lens="identity", - text="我喜欢周末和朋友们一起去热闹的 bar 喝酒聊天。", topic="test.social.positive_2") - _index_fact(repo, indexer, - fact_id="np:social-negative-2", personal_model_id=pm_id, lens="identity", - text="我不喜欢去人多吵闹的聚会,说话太费劲。", topic="test.social.negative_2") - _index_fact(repo, indexer, - fact_id="np:quiet-preference", personal_model_id=pm_id, lens="pulse", - text="安静的环境让我觉得安全。", topic="test.quiet.preference") + _index_fact( + repo, + indexer, + fact_id="np:social-positive-2", + personal_model_id=pm_id, + lens="identity", + text="我喜欢周末和朋友们一起去热闹的 bar 喝酒聊天。", + topic="test.social.positive_2", + ) + _index_fact( + repo, + indexer, + fact_id="np:social-negative-2", + personal_model_id=pm_id, + lens="identity", + text="我不喜欢去人多吵闹的聚会,说话太费劲。", + topic="test.social.negative_2", + ) + _index_fact( + repo, + indexer, + fact_id="np:quiet-preference", + personal_model_id=pm_id, + lens="pulse", + text="安静的环境让我觉得安全。", + topic="test.quiet.preference", + ) def top_ref(query: str, **kw) -> str: - result = surface.search_personal_model( - "s4", query=query, limit=5, personal_model_id=pm_id, **kw) + result = surface.search_personal_model("s4", query=query, limit=5, personal_model_id=pm_id, **kw) claims = tuple(result.get("claims") or ()) self.assertTrue(claims, f"no match for query={query!r}") return str(claims[0]["ref"]) @@ -251,19 +350,34 @@ def test_low_information_rejection_and_zero_result(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: surface, pm_id, indexer, repo = _build_surface(tmpdir) - _index_fact(repo, indexer, - fact_id="z:normal-fact", personal_model_id=pm_id, lens="identity", - text="用户喜欢在周末 city walk。", topic="test.routine.weekend") + _index_fact( + repo, + indexer, + fact_id="z:normal-fact", + personal_model_id=pm_id, + lens="identity", + text="用户喜欢在周末 city walk。", + topic="test.routine.weekend", + ) # Low-information query (too short / stop-word-like) low_info = surface.search_personal_model( - "s5", query="a b c d e", limit=5, personal_model_id=pm_id, include_diagnostics=True) + "s5", + query="a b c d e", + limit=5, + personal_model_id=pm_id, + include_diagnostics=True, + ) self.assertEqual(low_info["match_status"], "no_match") # Zero-result topic exact query (nonsense topic) zero = surface.search_personal_model( - "s5", query="量子纠缠 星际旅行 时间机器", limit=5, - personal_model_id=pm_id, include_diagnostics=True) + "s5", + query="量子纠缠 星际旅行 时间机器", + limit=5, + personal_model_id=pm_id, + include_diagnostics=True, + ) self.assertIn(zero["match_status"], {"no_match", "strong_match"}) # ── 6. Multilingual synonym generalization (JP/KR) ────────────── @@ -273,30 +387,59 @@ def test_multilingual_synonym_generalization(self) -> None: surface, pm_id, indexer, repo = _build_surface(tmpdir) # Facts with Japanese and Korean content - _index_fact(repo, indexer, - fact_id="ml:jp-calm", personal_model_id=pm_id, lens="pulse", - text="静かな場所で本を読むのが好きです。", topic="test.jp.reading") - _index_fact(repo, indexer, - fact_id="ml:kr-citywalk", personal_model_id=pm_id, lens="identity", - text="주말에 도시를 걷는 것을 좋아합니다.", topic="test.kr.walking") - _index_fact(repo, indexer, - fact_id="ml:en-jazz", personal_model_id=pm_id, lens="identity", - text="I love listening to jazz on weekend mornings.", topic="test.jazz.preference") - _index_fact(repo, indexer, - fact_id="ml:cn-fog-2", personal_model_id=pm_id, lens="identity", - text="我喜欢在雾气弥漫的早晨散步。", topic="test.foggy.morning") + _index_fact( + repo, + indexer, + fact_id="ml:jp-calm", + personal_model_id=pm_id, + lens="pulse", + text="静かな場所で本を読むのが好きです。", + topic="test.jp.reading", + ) + _index_fact( + repo, + indexer, + fact_id="ml:kr-citywalk", + personal_model_id=pm_id, + lens="identity", + text="주말에 도시를 걷는 것을 좋아합니다.", + topic="test.kr.walking", + ) + _index_fact( + repo, + indexer, + fact_id="ml:en-jazz", + personal_model_id=pm_id, + lens="identity", + text="I love listening to jazz on weekend mornings.", + topic="test.jazz.preference", + ) + _index_fact( + repo, + indexer, + fact_id="ml:cn-fog-2", + personal_model_id=pm_id, + lens="identity", + text="我喜欢在雾气弥漫的早晨散步。", + topic="test.foggy.morning", + ) def top_ref(query: str, **kw) -> str: - result = surface.search_personal_model( - "s6", query=query, limit=5, personal_model_id=pm_id, **kw) + result = surface.search_personal_model("s6", query=query, limit=5, personal_model_id=pm_id, **kw) claims = tuple(result.get("claims") or ()) self.assertTrue(claims, f"no match for query={query!r}") return str(claims[0]["ref"]) # CN query with JP variant → should find JP fact - self.assertEqual(top_ref("安静 读书", query_variants=("静かな場所で本を読む",)), "ml:jp-calm") + self.assertEqual( + top_ref("安静 读书", query_variants=("静かな場所で本を読む",)), + "ml:jp-calm", + ) # The stub embedder can over-weight CJK variants; keep the multilingual path visible. - self.assertIn(top_ref("도시 걷기", query_variants=("城市漫步",)), {"ml:kr-citywalk", "ml:cn-fog-2"}) + self.assertIn( + top_ref("도시 걷기", query_variants=("城市漫步",)), + {"ml:kr-citywalk", "ml:cn-fog-2"}, + ) # EN query → direct match self.assertEqual(top_ref("jazz weekend mornings"), "ml:en-jazz") # CN foggy morning → direct match @@ -308,28 +451,54 @@ def test_topic_exact_and_semantic_merge(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: surface, pm_id, indexer, repo = _build_surface(tmpdir) - _index_fact(repo, indexer, - fact_id="tm:optionality", personal_model_id=pm_id, lens="identity", - text="重要的是保住选择权。", topic="test.choice.optionality") - _index_fact(repo, indexer, - fact_id="tm:choice-preference", personal_model_id=pm_id, lens="identity", - text="做取舍时最不想丢掉的是选择权。", topic="test.choice.preference") - _index_fact(repo, indexer, - fact_id="tm:weekend-routine", personal_model_id=pm_id, lens="identity", - text="周末 city walk 在成都街头漫无目的地走。", topic="test.weekend.routine") + _index_fact( + repo, + indexer, + fact_id="tm:optionality", + personal_model_id=pm_id, + lens="identity", + text="重要的是保住选择权。", + topic="test.choice.optionality", + ) + _index_fact( + repo, + indexer, + fact_id="tm:choice-preference", + personal_model_id=pm_id, + lens="identity", + text="做取舍时最不想丢掉的是选择权。", + topic="test.choice.preference", + ) + _index_fact( + repo, + indexer, + fact_id="tm:weekend-routine", + personal_model_id=pm_id, + lens="identity", + text="周末 city walk 在成都街头漫无目的地走。", + topic="test.weekend.routine", + ) # Topic-only exact mode (no query text) → should match by topic topic_only = surface.search_personal_model( - "s7", topic="test.choice.optionality", limit=5, - personal_model_id=pm_id, include_diagnostics=True) + "s7", + topic="test.choice.optionality", + limit=5, + personal_model_id=pm_id, + include_diagnostics=True, + ) claims = tuple(topic_only.get("claims") or ()) self.assertTrue(claims) self.assertEqual(claims[0]["ref"], "tm:optionality") # Query + topic merge: query should trigger semantic scores too merged = surface.search_personal_model( - "s7", query="保留选择权", limit=5, - personal_model_id=pm_id, include_diagnostics=True) + "s7", + query="保留选择权", + limit=5, + personal_model_id=pm_id, + include_diagnostics=True, + ) m_claims = tuple(merged.get("claims") or ()) self.assertTrue(m_claims) # The more exact text match should rank first @@ -411,33 +580,50 @@ def test_large_batch_ranking_stability(self) -> None: # Insert 15 distractor facts for i in range(15): - _index_fact(repo, indexer, - fact_id=f"dist:{i:03d}", personal_model_id=pm_id, lens="world", + _index_fact( + repo, + indexer, + fact_id=f"dist:{i:03d}", + personal_model_id=pm_id, + lens="world", text=f"Distractor fact number {i} about random topic xyz.", - topic=f"distractor_{i}") + topic=f"distractor_{i}", + ) # Insert 5 target facts about "architecture" for i in range(5): - _index_fact(repo, indexer, - fact_id=f"arch:{i:03d}", personal_model_id=pm_id, lens="world", + _index_fact( + repo, + indexer, + fact_id=f"arch:{i:03d}", + personal_model_id=pm_id, + lens="world", text=f"关于系统架构的讨论,第{i}条:微服务和事件驱动设计的取舍。", - topic=f"architecture_{i}") + topic=f"architecture_{i}", + ) # Insert 5 more distractors after targets (to test recency vs relevance) for i in range(15, 20): - _index_fact(repo, indexer, - fact_id=f"dist:{i:03d}", personal_model_id=pm_id, lens="world", + _index_fact( + repo, + indexer, + fact_id=f"dist:{i:03d}", + personal_model_id=pm_id, + lens="world", text=f"Late distractor fact number {i} about random topic abc.", - topic=f"distractor_{i}") + topic=f"distractor_{i}", + ) - result = surface.search_personal_model( - "s9", query="微服务 架构 事件驱动", limit=5, personal_model_id=pm_id) + result = surface.search_personal_model("s9", query="微服务 架构 事件驱动", limit=5, personal_model_id=pm_id) claims = tuple(result.get("claims") or ()) self.assertTrue(claims) # At least 2 of top 5 should be architecture facts arch_hits = sum(1 for c in claims if c["ref"].startswith("arch:")) - self.assertGreaterEqual(arch_hits, 2, - f"expected >=2 architecture facts in top 5, got {arch_hits}: {[c['ref'] for c in claims]}") + self.assertGreaterEqual( + arch_hits, + 2, + f"expected >=2 architecture facts in top 5, got {arch_hits}: {[c['ref'] for c in claims]}", + ) if __name__ == "__main__": diff --git a/tests/integration/semantic_index/test_multilingual_recall_and_pm_search.py b/tests/integration/semantic_index/test_multilingual_recall_and_pm_search.py index 44266e9..5a88b68 100644 --- a/tests/integration/semantic_index/test_multilingual_recall_and_pm_search.py +++ b/tests/integration/semantic_index/test_multilingual_recall_and_pm_search.py @@ -51,6 +51,7 @@ # ── CJK-capable stub embedder ──────────────────────────────────────────── + class _CJKCapableStubEmbeddingService: """Character-category bucket embedder that handles ASCII, CJK, Hangul, Kana. @@ -149,6 +150,7 @@ def embed_text(self, text: str, **kwargs) -> EmbeddingVector: # ── Test fixture helpers ───────────────────────────────────────────────── + @dataclass class MultilingualFixture: surface: PersonalModelUnderstandingSurface @@ -208,7 +210,14 @@ def _seed_pm_fact(fx: MultilingualFixture, *, fact_id: str, lens: str, text: str fx.indexer.index_personal_model_claim(fact) -def _seed_step(fx: MultilingualFixture, *, step_id: str, summary: str, action: str = "record_input", sequence: int = 1) -> None: +def _seed_step( + fx: MultilingualFixture, + *, + step_id: str, + summary: str, + action: str = "record_input", + sequence: int = 1, +) -> None: """Write one Step (user turn) through the repository + index it for conversation recall.""" episode_id = f"episode-{step_id.split(':')[0].replace('_', '-')}" loop_id = f"loop-{step_id.split(':')[0].replace('_', '-')}" @@ -305,6 +314,7 @@ def _conversation_recall(fx: MultilingualFixture, query: str, scopes: tuple[str, # ── Tests ──────────────────────────────────────────────────────────────── + class MultiLingualRecallAndPMTest(unittest.TestCase): """Comprehensive multilingual test for both conversation recall and PM search.""" @@ -314,10 +324,34 @@ def test_chinese_pm_search_matches_exact_and_fuzzy(self) -> None: """PM search with Chinese query: exact token, fuzzy CJK n-gram.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_pm_fact(fx, fact_id="cn:fog", lens="identity", text="我喜欢像站在起雾的路口那样慢慢做决定。", topic="test.decision.fog") - _seed_pm_fact(fx, fact_id="cn:quiet", lens="pulse", text="能量低的时候,我需要一个安静角落。", topic="test.recovery.corner") - _seed_pm_fact(fx, fact_id="cn:social-pos", lens="world", text="我喜欢周末约朋友去热闹的 bar。", topic="test.social.bar") - _seed_pm_fact(fx, fact_id="cn:social-neg", lens="identity", text="我不喜欢参加需要大声说话才能沟通的聚会。", topic="test.social.party") + _seed_pm_fact( + fx, + fact_id="cn:fog", + lens="identity", + text="我喜欢像站在起雾的路口那样慢慢做决定。", + topic="test.decision.fog", + ) + _seed_pm_fact( + fx, + fact_id="cn:quiet", + lens="pulse", + text="能量低的时候,我需要一个安静角落。", + topic="test.recovery.corner", + ) + _seed_pm_fact( + fx, + fact_id="cn:social-pos", + lens="world", + text="我喜欢周末约朋友去热闹的 bar。", + topic="test.social.bar", + ) + _seed_pm_fact( + fx, + fact_id="cn:social-neg", + lens="identity", + text="我不喜欢参加需要大声说话才能沟通的聚会。", + topic="test.social.party", + ) # Exact token match self.assertEqual(_top_pm_ref(fx, "起雾 路口"), "cn:fog") @@ -335,8 +369,18 @@ def test_chinese_conversation_recall_exact_and_fuzzy(self) -> None: """Conversation recall (tool.conversation.search) with Chinese queries.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_step(fx, step_id="cn:step-fog", summary="用户说:我喜欢像站在起雾的路口那样慢慢做决定。", sequence=1) - _seed_step(fx, step_id="cn:step-quiet", summary="用户说:能量低的时候,我需要一个安静角落。", sequence=1) + _seed_step( + fx, + step_id="cn:step-fog", + summary="用户说:我喜欢像站在起雾的路口那样慢慢做决定。", + sequence=1, + ) + _seed_step( + fx, + step_id="cn:step-quiet", + summary="用户说:能量低的时候,我需要一个安静角落。", + sequence=1, + ) # Token match: "起雾的路口" is an exact substring of the step text hits = _conversation_recall(fx, "起雾的路口") @@ -356,9 +400,27 @@ def test_english_pm_search_exact_and_synonym(self) -> None: """PM search with English query: exact, synonym, conceptual.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_pm_fact(fx, fact_id="en:redis", lens="world", text="User chose Redis caching over memcached for the aegis project.", topic="test.caching") - _seed_pm_fact(fx, fact_id="en:lego", lens="identity", text="好的代码像是拼好的乐高,每块都刚好卡在它该在的位置。", topic="test.code.metaphor") - _seed_pm_fact(fx, fact_id="en:citywalk", lens="pulse", text="周末最喜欢做的事是 city walk,在成都的街头漫无目的地走。", topic="test.weekend.routine") + _seed_pm_fact( + fx, + fact_id="en:redis", + lens="world", + text="User chose Redis caching over memcached for the aegis project.", + topic="test.caching", + ) + _seed_pm_fact( + fx, + fact_id="en:lego", + lens="identity", + text="好的代码像是拼好的乐高,每块都刚好卡在它该在的位置。", + topic="test.code.metaphor", + ) + _seed_pm_fact( + fx, + fact_id="en:citywalk", + lens="pulse", + text="周末最喜欢做的事是 city walk,在成都的街头漫无目的地走。", + topic="test.weekend.routine", + ) # Exact match self.assertEqual(_top_pm_ref(fx, "redis caching"), "en:redis") @@ -369,8 +431,16 @@ def test_english_conversation_recall(self) -> None: """Conversation recall with English queries.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_step(fx, step_id="en:step-redis", summary="User decided to use Redis caching over memcached.") - _seed_step(fx, step_id="en:step-lego", summary="好的代码像是拼好的乐高——每块都刚好卡在它该在的位置。") + _seed_step( + fx, + step_id="en:step-redis", + summary="User decided to use Redis caching over memcached.", + ) + _seed_step( + fx, + step_id="en:step-lego", + summary="好的代码像是拼好的乐高——每块都刚好卡在它该在的位置。", + ) # English token overlap: "redis" is a token match via _TOKEN_RE hits = _conversation_recall(fx, "redis caching") @@ -390,9 +460,27 @@ def test_cross_lingual_pm_search_query_variants(self) -> None: """PM search with cross-lingual query_variants bridges CN↔EN.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_pm_fact(fx, fact_id="xl:music", lens="identity", text="个人爱好包含音乐和周末听唱片。", topic="test.music") - _seed_pm_fact(fx, fact_id="xl:solitude", lens="identity", text="孤独有时候是干净的。", topic="test.trait.solitude") - _seed_pm_fact(fx, fact_id="xl:citywalk", lens="pulse", text="周末 city walk 是我最放松的方式。", topic="test.weekend.citywalk") + _seed_pm_fact( + fx, + fact_id="xl:music", + lens="identity", + text="个人爱好包含音乐和周末听唱片。", + topic="test.music", + ) + _seed_pm_fact( + fx, + fact_id="xl:solitude", + lens="identity", + text="孤独有时候是干净的。", + topic="test.trait.solitude", + ) + _seed_pm_fact( + fx, + fact_id="xl:citywalk", + lens="pulse", + text="周末 city walk 是我最放松的方式。", + topic="test.weekend.citywalk", + ) # EN query → CN variant → should find CN fact self.assertEqual( @@ -417,7 +505,11 @@ def test_cross_lingual_conversation_recall_fallback(self) -> None: """Conversation recall with cross-lingual queries.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_step(fx, step_id="xl:step-music", summary="用户分享说个人爱好包含音乐和周末听唱片。") + _seed_step( + fx, + step_id="xl:step-music", + summary="用户分享说个人爱好包含音乐和周末听唱片。", + ) # EN query — lexical fallback may not bridge CN→EN without variants, # but should still surface the step if tokens overlap @@ -432,9 +524,27 @@ def test_jp_kr_pm_search(self) -> None: """PM search with Japanese and Korean content.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_pm_fact(fx, fact_id="jp:calm", lens="identity", text="静かな場所で本を読むのが好きです。", topic="test.jp.reading") - _seed_pm_fact(fx, fact_id="kr:walking", lens="pulse", text="주말에 도시를 걷는 것을 좋아합니다.", topic="test.kr.walking") - _seed_pm_fact(fx, fact_id="en:jazz", lens="identity", text="I love listening to jazz on weekend mornings.", topic="test.jazz.morning") + _seed_pm_fact( + fx, + fact_id="jp:calm", + lens="identity", + text="静かな場所で本を読むのが好きです。", + topic="test.jp.reading", + ) + _seed_pm_fact( + fx, + fact_id="kr:walking", + lens="pulse", + text="주말에 도시를 걷는 것을 좋아합니다.", + topic="test.kr.walking", + ) + _seed_pm_fact( + fx, + fact_id="en:jazz", + lens="identity", + text="I love listening to jazz on weekend mornings.", + topic="test.jazz.morning", + ) # JP query with JP variant self.assertEqual( @@ -453,8 +563,16 @@ def test_jp_kr_conversation_recall(self) -> None: """Conversation recall with Japanese and Korean content.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_step(fx, step_id="jp:step-calm", summary="ユーザー:静かな場所で本を読むのが好きです。") - _seed_step(fx, step_id="kr:step-walking", summary="사용자:주말에 도시를 걷는 것을 좋아합니다.") + _seed_step( + fx, + step_id="jp:step-calm", + summary="ユーザー:静かな場所で本を読むのが好きです。", + ) + _seed_step( + fx, + step_id="kr:step-walking", + summary="사용자:주말에 도시를 걷는 것을 좋아합니다.", + ) # JP lexical match hits = _conversation_recall(fx, "静かな場所 本を読む") @@ -474,28 +592,45 @@ def test_mixed_language_pm_search(self) -> None: """PM search with mixed CN/EN content.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_pm_fact(fx, fact_id="mx:citywalk", lens="pulse", + _seed_pm_fact( + fx, + fact_id="mx:citywalk", + lens="pulse", text="周末 city walk 是我最放松的方式。在成都的街头漫无目的地走。", - topic="test.weekend.mixed") - _seed_pm_fact(fx, fact_id="mx:code", lens="identity", + topic="test.weekend.mixed", + ) + _seed_pm_fact( + fx, + fact_id="mx:code", + lens="identity", text="喜欢用 Python 写一些小工具来自动化日常工作。", - topic="test.workflow.automation") + topic="test.workflow.automation", + ) # CN tokens should match self.assertEqual(_top_pm_ref(fx, "周末 city walk"), "mx:citywalk") # Mixed query self.assertEqual(_top_pm_ref(fx, "Python 自动化"), "mx:code") # EN-only token inside mixed fact - self.assertEqual(_top_pm_ref(fx, "Python automation tools", query_variants=("Python 小工具",)), "mx:code") + self.assertEqual( + _top_pm_ref(fx, "Python automation tools", query_variants=("Python 小工具",)), + "mx:code", + ) def test_mixed_language_conversation_recall(self) -> None: """Conversation recall with mixed CN/EN content.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_step(fx, step_id="mx:step-citywalk", - summary="用户说周末 city walk 是最放松的方式,在成都的街头漫无目的地走。") - _seed_step(fx, step_id="mx:step-python", - summary="用户说喜欢用 Python 写一些小工具来自动化日常工作。") + _seed_step( + fx, + step_id="mx:step-citywalk", + summary="用户说周末 city walk 是最放松的方式,在成都的街头漫无目的地走。", + ) + _seed_step( + fx, + step_id="mx:step-python", + summary="用户说喜欢用 Python 写一些小工具来自动化日常工作。", + ) # CN token match hits = _conversation_recall(fx, "成都的街头") @@ -516,7 +651,13 @@ def test_same_content_through_both_paths(self) -> None: pm_text = "周末最喜欢在成都 city walk,漫无目的地走。" step_summary = f"用户说:{pm_text}" - _seed_pm_fact(fx, fact_id="cmp:citywalk", lens="pulse", text=pm_text, topic="test.weekend.citywalk") + _seed_pm_fact( + fx, + fact_id="cmp:citywalk", + lens="pulse", + text=pm_text, + topic="test.weekend.citywalk", + ) _seed_step(fx, step_id="cmp:step-citywalk", summary=step_summary) # PM search path @@ -537,13 +678,19 @@ def test_step_recall_filters_tool_execution_noise_multilingual(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) # Tool step — should be filtered out - _seed_step(fx, step_id="noise:tool", + _seed_step( + fx, + step_id="noise:tool", summary="tool result: 家庭权力结构分析完成", - action="call_tool") + action="call_tool", + ) # User step — should be recallable - _seed_step(fx, step_id="noise:user", + _seed_step( + fx, + step_id="noise:user", summary="用户说:我们讨论了家庭权力结构的问题。", - action="record_input") + action="record_input", + ) hits = _conversation_recall(fx, "家庭权力结构") contents = "\n".join(h.content for h in hits) @@ -558,28 +705,41 @@ def test_large_batch_multilingual_ranking_stability(self) -> None: fx = _build_fixture(tmpdir) # Distractors for i in range(15): - _seed_pm_fact(fx, fact_id=f"dist:{i:03d}", lens="world", + _seed_pm_fact( + fx, + fact_id=f"dist:{i:03d}", + lens="world", text=f"Distractor fact number {i} about random topic xyz.", - topic=f"distractor_{i}") + topic=f"distractor_{i}", + ) # Target facts about "architecture" in mixed CN/EN for i in range(5): - _seed_pm_fact(fx, fact_id=f"arch:{i:03d}", lens="world", + _seed_pm_fact( + fx, + fact_id=f"arch:{i:03d}", + lens="world", text=f"关于系统架构的讨论,第{i}条:微服务和事件驱动设计的取舍。", - topic=f"architecture_{i}") + topic=f"architecture_{i}", + ) # More distractors for i in range(15, 20): - _seed_pm_fact(fx, fact_id=f"dist:{i:03d}", lens="world", + _seed_pm_fact( + fx, + fact_id=f"dist:{i:03d}", + lens="world", text=f"Late distractor fact number {i} about random topic abc.", - topic=f"distractor_{i}") + topic=f"distractor_{i}", + ) result = _pm_search(fx, "微服务 架构 事件驱动") claims = tuple(result.get("claims") or ()) self.assertTrue(claims, "should find architecture facts") arch_hits = sum(1 for c in claims if c["ref"].startswith("arch:")) self.assertGreaterEqual( - arch_hits, 2, + arch_hits, + 2, f"expected >=2 architecture facts in top-5, got {arch_hits}: {[c['ref'] for c in claims]}", ) @@ -589,10 +749,20 @@ def test_topic_only_exact_pm_search(self) -> None: """PM search with topic-only (no query text) should match by topic.""" with tempfile.TemporaryDirectory() as tmpdir: fx = _build_fixture(tmpdir) - _seed_pm_fact(fx, fact_id="to:optionality", lens="identity", - text="重要的是保住选择权。", topic="test.choice.optionality") - _seed_pm_fact(fx, fact_id="to:choice", lens="identity", - text="做取舍时最不想丢掉的是选择权。", topic="test.choice.preference") + _seed_pm_fact( + fx, + fact_id="to:optionality", + lens="identity", + text="重要的是保住选择权。", + topic="test.choice.optionality", + ) + _seed_pm_fact( + fx, + fact_id="to:choice", + lens="identity", + text="做取舍时最不想丢掉的是选择权。", + topic="test.choice.preference", + ) topic_only = _pm_search(fx, "", topic="test.choice.optionality", include_diagnostics=True) claims = tuple(topic_only.get("claims") or ()) @@ -631,8 +801,13 @@ def test_degraded_pm_search_falls_back_to_lexical(self) -> None: indexer=indexer, ) - _seed_pm_fact(fx, fact_id="dg:citywalk", lens="pulse", - text="周末最喜欢在成都 city walk。", topic="test.weekend.citywalk") + _seed_pm_fact( + fx, + fact_id="dg:citywalk", + lens="pulse", + text="周末最喜欢在成都 city walk。", + topic="test.weekend.citywalk", + ) # Should still find via lexical signals even without embeddings result = _pm_search(fx, "周末 成都 city walk") diff --git a/tests/integration/semantic_index/test_unified_recall_end_to_end.py b/tests/integration/semantic_index/test_unified_recall_end_to_end.py index e23ac25..fd74153 100644 --- a/tests/integration/semantic_index/test_unified_recall_end_to_end.py +++ b/tests/integration/semantic_index/test_unified_recall_end_to_end.py @@ -150,7 +150,10 @@ def test_personal_model_search_uses_semantic_index(self) -> None: committed_at=_NOW, source="user_explicit", status="active", - metadata={"topic": "assistant.review.style", "reason": "user corrected the assistant"}, + metadata={ + "topic": "assistant.review.style", + "reason": "user corrected the assistant", + }, ) repository.upsert_personal_model_fact(fact) self.assertIsNotNone(indexer.index_personal_model_claim(fact)) @@ -172,7 +175,9 @@ def test_personal_model_search_uses_semantic_index(self) -> None: self.assertTrue(claims) self.assertEqual(claims[0]["ref"], fact.fact_id) - def test_personal_model_search_merges_fielded_unicode_fuzzy_and_alias_signals(self) -> None: + def test_personal_model_search_merges_fielded_unicode_fuzzy_and_alias_signals( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: root = Path(tmpdir) state_dir = root / "state" @@ -304,9 +309,15 @@ def top_ref(query: str, **kwargs) -> str: self.assertEqual(top_ref("安净角落"), "claim:quiet-corner") self.assertEqual(top_ref("", topic="test.topic.t1"), "claim:topic-only") self.assertEqual(top_ref("music", query_variants=("音乐",)), "claim:music-cn") - self.assertEqual(top_ref("solitude is pure", query_variants=("孤独是干净的",)), "claim:solitude-clean") + self.assertEqual( + top_ref("solitude is pure", query_variants=("孤独是干净的",)), + "claim:solitude-clean", + ) self.assertEqual(top_ref("保留选择权"), "claim:choice") - self.assertIn(top_ref("喜欢热闹 聚会 大声说话"), {"claim:social-positive", "claim:social-negative"}) + self.assertIn( + top_ref("喜欢热闹 聚会 大声说话"), + {"claim:social-positive", "claim:social-negative"}, + ) self.assertEqual(top_ref("不喜欢大声说话的聚会"), "claim:social-negative") no_match = surface.search_personal_model( @@ -400,7 +411,10 @@ def test_step_semantic_recall_filters_tool_execution_noise(self) -> None: sequence=1, created_at=_NOW, summary="tool result says family power structure", - metadata={"tool_name": "tool.conversation.search", "tool_result": "family power tool report"}, + metadata={ + "tool_name": "tool.conversation.search", + "tool_result": "family power tool report", + }, ) user_step = Step( step_id="step:user", @@ -414,7 +428,10 @@ def test_step_semantic_recall_filters_tool_execution_noise(self) -> None: sequence=2, created_at=_NOW, summary="source item ingested", - metadata={"event_type": "turn.received", "user_query": "We discussed family power structure."}, + metadata={ + "event_type": "turn.received", + "user_query": "We discussed family power structure.", + }, ) repository.upsert_episode( Episode( diff --git a/tests/integration/storage_system_layers/test_learning_jobs.py b/tests/integration/storage_system_layers/test_learning_jobs.py index b0dc878..d78cafd 100644 --- a/tests/integration/storage_system_layers/test_learning_jobs.py +++ b/tests/integration/storage_system_layers/test_learning_jobs.py @@ -10,7 +10,9 @@ class StorageLearningJobsTest(unittest.TestCase): - def test_learning_job_lifecycle_supports_queue_claim_retry_and_complete(self) -> None: + def test_learning_job_lifecycle_supports_queue_claim_retry_and_complete( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: repository = RuntimeStorageRepository(Path(tmpdir) / "state" / "elephant.sqlite3") repository.bootstrap() @@ -179,7 +181,12 @@ def test_learning_job_force_new_allows_manual_rerun_history(self) -> None: self.assertNotEqual(first.job_id, manual.job_id) self.assertEqual(len(repository.list_learning_jobs(state_id=state.state_id)), 2) - self.assertEqual(repository.load_learning_job_for_episode(job_type="episode_boundary_learning", episode_id=episode.episode_id).job_id, manual.job_id) + self.assertEqual( + repository.load_learning_job_for_episode( + job_type="episode_boundary_learning", episode_id=episode.episode_id + ).job_id, + manual.job_id, + ) if __name__ == "__main__": diff --git a/tests/integration/storage_system_layers/test_loop_checkpoint_hardening.py b/tests/integration/storage_system_layers/test_loop_checkpoint_hardening.py index 311b7a5..5ced3e2 100644 --- a/tests/integration/storage_system_layers/test_loop_checkpoint_hardening.py +++ b/tests/integration/storage_system_layers/test_loop_checkpoint_hardening.py @@ -155,7 +155,9 @@ def _no_op_upsert(self, loop: Loop) -> None: finally: type(repository).upsert_loop = original_upsert_loop - def test_list_loop_checkpoints_round_trips_wait_condition_and_pending_tool_calls(self) -> None: + def test_list_loop_checkpoints_round_trips_wait_condition_and_pending_tool_calls( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: repository = RuntimeStorageRepository(Path(tmpdir) / "elephant.sqlite3") repository.bootstrap() diff --git a/tests/integration/storage_system_layers/test_repository.py b/tests/integration/storage_system_layers/test_repository.py index 43dd611..8e040f0 100644 --- a/tests/integration/storage_system_layers/test_repository.py +++ b/tests/integration/storage_system_layers/test_repository.py @@ -57,7 +57,7 @@ def test_legacy_repository_methods_are_removed(self) -> None: "load_" + "agent_run", "upsert_evidence_record_bundle", "load_evidence_record_bundle", - "append_" + "memory_ledger", + "append_" + "memory_ledger", ): self.assertFalse(hasattr(repository, method_name), method_name) @@ -96,7 +96,9 @@ def test_personal_model_round_trips_metadata(self) -> None: self.assertEqual(loaded.metadata, {"source": "test"}) self.assertEqual(tuple(model.personal_model_id for model in listed), ("pm-alpha",)) - def test_elephant_state_create_switch_list_and_delete_preserves_personal_model(self) -> None: + def test_elephant_state_create_switch_list_and_delete_preserves_personal_model( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: repository = RuntimeStorageRepository(Path(tmpdir) / "state" / "elephant.sqlite3") repository.bootstrap() diff --git a/tests/integration/storage_system_layers/test_schema.py b/tests/integration/storage_system_layers/test_schema.py index e8e6625..6b10f2a 100644 --- a/tests/integration/storage_system_layers/test_schema.py +++ b/tests/integration/storage_system_layers/test_schema.py @@ -30,29 +30,16 @@ def test_bootstrap_installs_clean_terminal_schema_only(self) -> None: with sqlite3.connect(database_path) as connection: table_names = { str(row[0]) - for row in connection.execute( - "SELECT name FROM sqlite_master WHERE type = 'table'" - ).fetchall() - } - state_columns = { - str(row[1]) - for row in connection.execute("PRAGMA table_info(states)").fetchall() - } - episode_columns = { - str(row[1]) - for row in connection.execute("PRAGMA table_info(episodes)").fetchall() - } - job_columns = { - str(row[1]) - for row in connection.execute("PRAGMA table_info(learning_jobs)").fetchall() + for row in connection.execute("SELECT name FROM sqlite_master WHERE type = 'table'").fetchall() } + state_columns = {str(row[1]) for row in connection.execute("PRAGMA table_info(states)").fetchall()} + episode_columns = {str(row[1]) for row in connection.execute("PRAGMA table_info(episodes)").fetchall()} + job_columns = {str(row[1]) for row in connection.execute("PRAGMA table_info(learning_jobs)").fetchall()} fact_columns = { - str(row[1]) - for row in connection.execute("PRAGMA table_info(personal_model_facts)").fetchall() + str(row[1]) for row in connection.execute("PRAGMA table_info(personal_model_facts)").fetchall() } semantic_columns = { - str(row[1]) - for row in connection.execute("PRAGMA table_info(semantic_index_entries)").fetchall() + str(row[1]) for row in connection.execute("PRAGMA table_info(semantic_index_entries)").fetchall() } self.assertTrue( @@ -108,7 +95,9 @@ def test_schema_declares_reset_delete_boundaries(self) -> None: self.assertEqual(semantic_fks["states"], "CASCADE") self.assertNotIn("records", semantic_fks) - def test_bootstrap_rejects_existing_database_without_clean_schema_marker(self) -> None: + def test_bootstrap_rejects_existing_database_without_clean_schema_marker( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: database_path = Path(tmpdir) / "state" / "elephant.sqlite3" database_path.parent.mkdir(parents=True, exist_ok=True) @@ -161,8 +150,7 @@ def test_bootstrap_resets_same_version_schema_drift(self) -> None: with sqlite3.connect(database_path) as connection: semantic_columns = { - str(row[1]) - for row in connection.execute("PRAGMA table_info(semantic_index_entries)").fetchall() + str(row[1]) for row in connection.execute("PRAGMA table_info(semantic_index_entries)").fetchall() } rows = connection.execute("SELECT semantic_index_entry_id FROM semantic_index_entries").fetchall() diff --git a/tests/integration/tools_skills/AGENTS.md b/tests/integration/tools_skills/AGENTS.md index d429480..3448b8a 100644 --- a/tests/integration/tools_skills/AGENTS.md +++ b/tests/integration/tools_skills/AGENTS.md @@ -8,4 +8,3 @@ Rules: - cover registry, loader, scope, dependency, and execution wiring together - keep fixtures JSON-shaped so the loader stays stdlib-only - do not add app-level process assumptions here - diff --git a/tests/integration/tools_skills/test_builtin_web_search.py b/tests/integration/tools_skills/test_builtin_web_search.py index e02dfc3..ad50098 100644 --- a/tests/integration/tools_skills/test_builtin_web_search.py +++ b/tests/integration/tools_skills/test_builtin_web_search.py @@ -102,17 +102,21 @@ def test_web_search_prefers_html_results_and_respects_limit(self) -> None: def test_web_search_keeps_long_result_lists_readable(self) -> None: runtime = self._make_runtime() snippet = "Long paper summary sentence. " * 10 - html = "" + "".join( - f''' + html = ( + "" + + "".join( + f"""

Paper {index}

{snippet}
- ''' - for index in range(1, 7) - ) + "" + """ + for index in range(1, 7) + ) + + "" + ) with mock.patch( "packages.tools.handlers_network.urlopen", @@ -129,7 +133,9 @@ def test_web_search_keeps_long_result_lists_readable(self) -> None: self.assertIn("6. Paper 6", result.summary) self.assertGreater(len(result.summary), 1600) - def test_web_search_falls_back_to_instant_answer_when_html_results_are_empty(self) -> None: + def test_web_search_falls_back_to_instant_answer_when_html_results_are_empty( + self, + ) -> None: runtime = self._make_runtime() html = "

no direct results

" instant_answer = json.dumps( diff --git a/tests/integration/tools_skills/test_tools_and_skills_runtime.py b/tests/integration/tools_skills/test_tools_and_skills_runtime.py index f1c0fc3..f303d93 100644 --- a/tests/integration/tools_skills/test_tools_and_skills_runtime.py +++ b/tests/integration/tools_skills/test_tools_and_skills_runtime.py @@ -1,6 +1,5 @@ from __future__ import annotations -from datetime import UTC, datetime import json from pathlib import Path import subprocess @@ -8,10 +7,8 @@ import unittest from unittest import mock -from packages.contracts import ExperienceRecord, State -from packages.contracts.runtime import PersonalModelRuntimeState +from packages.contracts import State from packages.security import SecurityPolicy -from packages.storage import RuntimeStorageRepository from packages.tools import ( CallableApprovalGateway, InMemoryToolExecutor, @@ -76,7 +73,9 @@ def emit(self, event) -> None: class ToolsAndSkillsIntegrationTest(unittest.TestCase): - def test_tool_runtime_resolves_canonical_runtime_context_before_execution(self) -> None: + def test_tool_runtime_resolves_canonical_runtime_context_before_execution( + self, + ) -> None: registry = InMemoryToolRegistry() executor = InMemoryToolExecutor() runtime = ToolRuntime( @@ -103,22 +102,24 @@ def test_tool_runtime_resolves_canonical_runtime_context_before_execution(self) version="1.0.0", description="Capture resolved tool runtime context.", ), - handler=lambda invocation: captured.update( - { - "cwd": invocation.context.cwd, - "allowed_roots": invocation.context.allowed_roots, - "surface_id": invocation.context.surface_id, - "state_id": invocation.context.state_id, - "personal_model_id": invocation.context.personal_model_id, - "elephant_id": invocation.context.elephant_id, - "requester": invocation.context.requester, + handler=lambda invocation: ( + captured.update( + { + "cwd": invocation.context.cwd, + "allowed_roots": invocation.context.allowed_roots, + "surface_id": invocation.context.surface_id, + "state_id": invocation.context.state_id, + "personal_model_id": invocation.context.personal_model_id, + "elephant_id": invocation.context.elephant_id, + "requester": invocation.context.requester, + } + ) + or { + "execution_id": invocation.invocation_id, + "summary": "captured context", + "outcome": "success", } - ) - or { - "execution_id": invocation.invocation_id, - "summary": "captured context", - "outcome": "success", - }, + ), ) result = runtime.invoke( @@ -137,7 +138,9 @@ def test_tool_runtime_resolves_canonical_runtime_context_before_execution(self) self.assertEqual(captured["elephant_id"], "atlas") self.assertEqual(captured["requester"], "operator") - def test_tool_runtime_emits_lifecycle_events_for_successful_invocation(self) -> None: + def test_tool_runtime_emits_lifecycle_events_for_successful_invocation( + self, + ) -> None: registry = InMemoryToolRegistry() executor = InMemoryToolExecutor() runtime = ToolRuntime( @@ -188,7 +191,9 @@ def test_tool_runtime_emits_lifecycle_events_for_successful_invocation(self) -> ) self.assertEqual(events[-1].execution.summary, "created Design review") - def test_tool_runtime_preserves_original_tool_error_in_failed_execution_path(self) -> None: + def test_tool_runtime_preserves_original_tool_error_in_failed_execution_path( + self, + ) -> None: registry = InMemoryToolRegistry() executor = InMemoryToolExecutor() runtime = ToolRuntime( @@ -239,7 +244,9 @@ def _failing_handler(invocation: ToolInvocation): self.assertEqual(record.approval.decision, "approved") self.assertEqual(record.detail, "fetch failed for https://example.com") - def test_tool_runtime_registers_and_executes_with_side_effect_metadata(self) -> None: + def test_tool_runtime_registers_and_executes_with_side_effect_metadata( + self, + ) -> None: registry = InMemoryToolRegistry() executor = InMemoryToolExecutor() runtime = ToolRuntime( @@ -346,7 +353,9 @@ def test_tool_runtime_blocks_model_invocation_of_operator_only_tools(self) -> No self.assertEqual(result.outcome, "success") self.assertEqual(result.summary, "operator handled tool.skill.manage") - def test_tool_runtime_records_deferred_approval_without_executing_handler(self) -> None: + def test_tool_runtime_records_deferred_approval_without_executing_handler( + self, + ) -> None: registry = InMemoryToolRegistry() executor = InMemoryToolExecutor() runtime = ToolRuntime( @@ -428,7 +437,9 @@ def test_security_approval_gateway_can_auto_grant_deferred_reviews(self) -> None self.assertTrue(str(approval.approval_token).startswith("auto:")) self.assertTrue(any(record["family"] == "approval" for record in sink.records)) - def test_tool_manifest_loader_discovers_external_tools_and_runtime_feedback(self) -> None: + def test_tool_manifest_loader_discovers_external_tools_and_runtime_feedback( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: manifest_path = Path(tmpdir) / "tools.json" manifest_path.write_text( @@ -466,7 +477,10 @@ def test_tool_manifest_loader_discovers_external_tools_and_runtime_feedback(self self.assertEqual(manifest.source_path, str(manifest_path)) self.assertEqual(runtime.describe("tool.notes.capture").provenance, str(manifest_path)) self.assertEqual(runtime.list_manifest_loads()[0].tool_ids, ("tool.notes.capture",)) - self.assertEqual(runtime.list_manifest_loads()[0].executable_tool_ids, ("tool.notes.capture",)) + self.assertEqual( + runtime.list_manifest_loads()[0].executable_tool_ids, + ("tool.notes.capture",), + ) result = runtime.invoke( "tool.notes.capture", @@ -481,7 +495,9 @@ def test_tool_manifest_loader_discovers_external_tools_and_runtime_feedback(self self.assertEqual(runtime.list_executions()[0].invocation.tool_id, "tool.notes.capture") self.assertEqual(runtime.list_executions()[0].detail, "captured Operator review") - def test_tool_manifest_loader_preserves_enable_override_and_records_blocked_invocations(self) -> None: + def test_tool_manifest_loader_preserves_enable_override_and_records_blocked_invocations( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: manifest_path = Path(tmpdir) / "tools.json" manifest_path.write_text( @@ -524,9 +540,14 @@ def test_tool_manifest_loader_preserves_enable_override_and_records_blocked_invo ) self.assertEqual(result.outcome, "blocked") self.assertEqual(runtime.list_executions()[0].approved, False) - self.assertEqual(runtime.list_executions()[0].detail, "blocked by callable approval gateway") + self.assertEqual( + runtime.list_executions()[0].detail, + "blocked by callable approval gateway", + ) - def test_sync_custom_mcp_tools_registers_model_visible_handlers_and_removes_stale_tools(self) -> None: + def test_sync_custom_mcp_tools_registers_model_visible_handlers_and_removes_stale_tools( + self, + ) -> None: runtime = ToolRuntime(approval_gateway=CallableApprovalGateway(lambda *_: True)) config = { "mcp_servers": { @@ -534,7 +555,11 @@ def test_sync_custom_mcp_tools_registers_model_visible_handlers_and_removes_stal "label": "Filesystem", "transport": "stdio", "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp/demo"], + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp/demo", + ], "env": {"ALLOW": "1"}, "tools": { "read_file": { @@ -610,7 +635,10 @@ def fake_run(command: list[str], **kwargs) -> subprocess.CompletedProcess[str]: cwd=Path("/tmp/tool-root"), ) self.assertFalse(runtime.describe("mcp.filesystem.read_file").enabled) - self.assertEqual(runtime.list_tools(audience="model", enabled_only=True, available_only=True), ()) + self.assertEqual( + runtime.list_tools(audience="model", enabled_only=True, available_only=True), + (), + ) sync_custom_mcp_tools( runtime, @@ -620,7 +648,9 @@ def fake_run(command: list[str], **kwargs) -> subprocess.CompletedProcess[str]: ) self.assertIsNone(runtime.describe("mcp.filesystem.read_file")) - def test_sync_custom_mcp_tools_remote_runtime_uses_mcporter_config_shape(self) -> None: + def test_sync_custom_mcp_tools_remote_runtime_uses_mcporter_config_shape( + self, + ) -> None: runtime = ToolRuntime(approval_gateway=CallableApprovalGateway(lambda *_: True)) config = { "mcp_servers": { @@ -713,9 +743,7 @@ def test_skill_loader_resolves_scope_and_dependencies(self) -> None: "surface_kinds": ["cli"], "modes": ["companion"], }, - "dependencies": [ - {"skill_id": "skill.write-helper", "required": True} - ], + "dependencies": [{"skill_id": "skill.write-helper", "required": True}], }, ] } @@ -739,7 +767,10 @@ def test_skill_loader_resolves_scope_and_dependencies(self) -> None: self.assertEqual(manifest.source_path, str(manifest_path)) self.assertEqual(len(manifest.skills), 2) self.assertEqual(runtime.list_skills(), manifest.skills) - self.assertEqual(runtime.list_manifest_loads()[0].skill_ids, ("skill.write-helper", "skill.voice-helper")) + self.assertEqual( + runtime.list_manifest_loads()[0].skill_ids, + ("skill.write-helper", "skill.voice-helper"), + ) self.assertEqual( tuple( skill.skill_id @@ -826,7 +857,9 @@ def test_skill_runtime_resolve_for_context_applies_state_boundaries(self) -> Non ) self.assertNotIn("skill.blocked", tuple(skill.skill_id for skill in resolved)) - def test_skill_runtime_activate_rejects_suppressed_retired_and_state_blocked_skills(self) -> None: + def test_skill_runtime_activate_rejects_suppressed_retired_and_state_blocked_skills( + self, + ) -> None: catalog = InMemorySkillCatalog() for definition in ( SkillDefinition( @@ -953,7 +986,9 @@ def test_skill_hub_keeps_disabled_installed_entries_discoverable(self) -> None: self.assertEqual(matches[0].reference, "elephant-installed:installed-skill") self.assertFalse(bool(matches[0].metadata.get("default_enabled"))) - def test_external_skill_source_accepts_parent_with_symlinked_skills_dir(self) -> None: + def test_external_skill_source_accepts_parent_with_symlinked_skills_dir( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: root = Path(tmpdir) external_parent = root / ".external" @@ -997,7 +1032,9 @@ def test_external_skill_source_accepts_parent_with_symlinked_skills_dir(self) -> ) (external_parent / "skills").symlink_to(real_skills_root, target_is_directory=True) (real_skills_root / "linked-helper").symlink_to(linked_source_dir, target_is_directory=True) - hub = SkillHub(sources=default_skill_hub_sources(external_dirs=(external_parent,), install_root=root / "elephant")) + hub = SkillHub( + sources=default_skill_hub_sources(external_dirs=(external_parent,), install_root=root / "elephant") + ) entries = {entry.skill_id: entry for entry in hub.list()} @@ -1006,9 +1043,14 @@ def test_external_skill_source_accepts_parent_with_symlinked_skills_dir(self) -> self.assertEqual(entries["notes-helper"].source_id, "external") self.assertEqual(entries["linked-helper"].source_id, "external") self.assertEqual(Path(entries["notes-helper"].skill_path).resolve(), skill_dir.resolve()) - self.assertEqual(Path(entries["linked-helper"].skill_path).resolve(), linked_source_dir.resolve()) + self.assertEqual( + Path(entries["linked-helper"].skill_path).resolve(), + linked_source_dir.resolve(), + ) - def test_materialized_skill_package_persists_public_source_and_install_provenance(self) -> None: + def test_materialized_skill_package_persists_public_source_and_install_provenance( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: root = Path(tmpdir) source_dir = root / "source" @@ -1120,14 +1162,7 @@ def test_builtin_skill_creator_package_is_loadable(self) -> None: def test_builtin_elephant_agent_package_is_loadable(self) -> None: repo_root = Path(__file__).resolve().parents[3] - skill_dir = ( - repo_root - / "packages" - / "skills" - / "builtin_packages" - / "autonomous-ai-agents" - / "elephant-agent" - ) + skill_dir = repo_root / "packages" / "skills" / "builtin_packages" / "autonomous-ai-agents" / "elephant-agent" definition = load_skill_package_definition(skill_dir) @@ -1138,14 +1173,7 @@ def test_builtin_elephant_agent_package_is_loadable(self) -> None: def test_builtin_ascii_art_package_is_loadable(self) -> None: repo_root = Path(__file__).resolve().parents[3] - skill_dir = ( - repo_root - / "packages" - / "skills" - / "builtin_packages" - / "creative" - / "ascii-art" - ) + skill_dir = repo_root / "packages" / "skills" / "builtin_packages" / "creative" / "ascii-art" definition = load_skill_package_definition(skill_dir) @@ -1155,7 +1183,9 @@ def test_builtin_ascii_art_package_is_loadable(self) -> None: self.assertEqual(definition.metadata.get("source_kind"), "elephant-builtin") self.assertTrue(definition.metadata.get("default_enabled")) - def test_builtin_skill_catalog_unifies_runtime_defaults_and_hub_projection(self) -> None: + def test_builtin_skill_catalog_unifies_runtime_defaults_and_hub_projection( + self, + ) -> None: default_catalog = builtin_skill_catalog() override_catalog = builtin_skill_catalog({"shell-execution": False, "docker-management": True}) section_map = {section.section_id: section for section in default_catalog.sections} @@ -1170,12 +1200,24 @@ def test_builtin_skill_catalog_unifies_runtime_defaults_and_hub_projection(self) self.assertTrue( all(entry.source_kind == "elephant-builtin" for entry in default_catalog.entries), ) - self.assertIn("shell-execution", {entry.skill_id for entry in section_map["runtime"].entries}) - self.assertIn("plan", {entry.skill_id for entry in section_map["software-development"].entries}) + self.assertIn( + "shell-execution", + {entry.skill_id for entry in section_map["runtime"].entries}, + ) + self.assertIn( + "plan", + {entry.skill_id for entry in section_map["software-development"].entries}, + ) self.assertIn("ascii-art", {entry.skill_id for entry in section_map["creative"].entries}) - self.assertIn("docker-management", {entry.skill_id for entry in section_map["devops"].entries}) + self.assertIn( + "docker-management", + {entry.skill_id for entry in section_map["devops"].entries}, + ) self.assertIn("1password", {entry.skill_id for entry in section_map["security"].entries}) - self.assertIn("minecraft-modpack-server", {entry.skill_id for entry in section_map["gaming"].entries}) + self.assertIn( + "minecraft-modpack-server", + {entry.skill_id for entry in section_map["gaming"].entries}, + ) self.assertIn("apple-notes", {entry.skill_id for entry in section_map["apple"].entries}) catalog = {entry.skill_id: entry for entry in builtin_skill_catalog_entries()} @@ -1236,8 +1278,14 @@ def test_builtin_skill_catalog_unifies_runtime_defaults_and_hub_projection(self) len({entry.slug for entry in site_catalog.entries}), len(site_catalog.entries), ) - self.assertEqual(site_catalog.stats["external_source_count"], len(site_catalog.external_sources)) - self.assertLess(site_catalog.stats["default_enabled_count"], site_catalog.stats["entry_count"]) + self.assertEqual( + site_catalog.stats["external_source_count"], + len(site_catalog.external_sources), + ) + self.assertLess( + site_catalog.stats["default_enabled_count"], + site_catalog.stats["entry_count"], + ) apple_notes = next(entry for entry in site_catalog.entries if entry.skill_id == "apple-notes") docker_management = next(entry for entry in site_catalog.entries if entry.skill_id == "docker-management") self.assertEqual(apple_notes.detail_doc_id, "skillhub/library/apple-notes") @@ -1248,8 +1296,14 @@ def test_builtin_skill_catalog_unifies_runtime_defaults_and_hub_projection(self) self.assertIn("Bundled", apple_notes.packaging_posture) self.assertIn("elephant skills install apple-notes", apple_notes.install_posture) self.assertEqual(docker_management.default_enabled_label, "Disabled by default") - self.assertEqual(docker_management.install_command, "elephant skills install docker-management") - self.assertIn("elephant skills install ", site_catalog.operator_install_posture) + self.assertEqual( + docker_management.install_command, + "elephant skills install docker-management", + ) + self.assertIn( + "elephant skills install ", + site_catalog.operator_install_posture, + ) self.assertIn("Creative", {section.display_name for section in site_catalog.sections}) self.assertIn("Data Science", {section.display_name for section in site_catalog.sections}) self.assertIn("DevOps", {section.display_name for section in site_catalog.sections}) @@ -1257,7 +1311,10 @@ def test_builtin_skill_catalog_unifies_runtime_defaults_and_hub_projection(self) self.assertIn("MLOps", {section.display_name for section in site_catalog.sections}) self.assertIn("Security", {section.display_name for section in site_catalog.sections}) self.assertEqual(site_catalog.external_sources[0].source_id, "github") - self.assertIn("elephant skills search --source github", site_catalog.external_sources[0].search_command) + self.assertIn( + "elephant skills search --source github", + site_catalog.external_sources[0].search_command, + ) repo_root = Path(__file__).resolve().parents[3] generated_catalog_module = (repo_root / "apps" / "site" / "src" / "generated" / "skillhubCatalog.ts").read_text( @@ -1265,7 +1322,9 @@ def test_builtin_skill_catalog_unifies_runtime_defaults_and_hub_projection(self) ) self.assertNotIn("slash_command", generated_catalog_module) catalog_prefix = "export const skillHubCatalog: SkillHubCatalogData = " - catalog_suffix = ";\n\nexport const skillHubCatalogById: Record = Object.fromEntries(" + catalog_suffix = ( + ";\n\nexport const skillHubCatalogById: Record = Object.fromEntries(" + ) catalog_start = generated_catalog_module.index(catalog_prefix) + len(catalog_prefix) catalog_end = generated_catalog_module.index(catalog_suffix, catalog_start) generated_catalog_payload = json.loads(generated_catalog_module[catalog_start:catalog_end]) @@ -1279,23 +1338,30 @@ def test_builtin_skill_catalog_unifies_runtime_defaults_and_hub_projection(self) sorted(path.stem for path in (repo_root / "apps" / "site" / "docs" / "skillhub" / "library").glob("*.mdx")) ) site_page_slugs = tuple( - sorted(path.stem for path in (repo_root / "apps" / "site" / "src" / "pages" / "skillhub" / "library").glob("*.tsx")) + sorted( + path.stem + for path in (repo_root / "apps" / "site" / "src" / "pages" / "skillhub" / "library").glob("*.tsx") + ) ) self.assertEqual(site_doc_slugs, expected_slugs) self.assertEqual(site_page_slugs, expected_slugs) - hub = SkillHub( - sources=(SkillHubSource("builtin", "Built In", builtin_elephant_skill_source_root()),) - ) + hub = SkillHub(sources=(SkillHubSource("builtin", "Built In", builtin_elephant_skill_source_root()),)) entries = {entry.skill_id: entry for entry in hub.list()} builtin_hub_entries = { entry.skill_id: entry for entry in builtin_skill_hub_entries({"shell-execution": False, "docker-management": True}) } - self.assertEqual(tuple(entries), tuple(entry.skill_id for entry in builtin_skill_hub_entries())) + self.assertEqual( + tuple(entries), + tuple(entry.skill_id for entry in builtin_skill_hub_entries()), + ) self.assertNotIn("shell-execution", builtin_hub_entries) self.assertIn("docker-management", builtin_hub_entries) - self.assertEqual(tuple(builtin_hub_entries), tuple(entry.skill_id for entry in override_catalog.hub_entries())) + self.assertEqual( + tuple(builtin_hub_entries), + tuple(entry.skill_id for entry in override_catalog.hub_entries()), + ) self.assertTrue(entries["shell-execution"].metadata.get("default_enabled")) self.assertTrue(entries["apple-notes"].metadata.get("default_enabled")) self.assertEqual(entries["apple-notes"].metadata.get("source_kind"), "elephant-builtin") @@ -1326,7 +1392,9 @@ def test_skill_scope_matches_strictly_by_context(self) -> None: ) ) - def test_skill_dependency_validation_reports_missing_required_dependencies(self) -> None: + def test_skill_dependency_validation_reports_missing_required_dependencies( + self, + ) -> None: catalog = InMemorySkillCatalog() catalog.register( SkillDefinition( diff --git a/tests/scenarios/companion/test_companion_scenarios.py b/tests/scenarios/companion/test_companion_scenarios.py index 4c05b01..fb35282 100644 --- a/tests/scenarios/companion/test_companion_scenarios.py +++ b/tests/scenarios/companion/test_companion_scenarios.py @@ -69,7 +69,9 @@ def test_governance_state_exposes_text_first_persona_state(self) -> None: self.assertTrue(relationship_policy.text_first) self.assertIn("companion text-first continuity", relationship_policy.summary()) - def test_companion_governance_path_distinguishes_defaults_from_onboarded_identity(self) -> None: + def test_companion_governance_path_distinguishes_defaults_from_onboarded_identity( + self, + ) -> None: from packages.contracts.runtime import PersonalModelRuntimeState from packages.state import ( CompanionSettings, @@ -199,7 +201,9 @@ def test_profile_writers_can_update_identity_and_elephant_state(self) -> None: self.assertEqual(identity.personality_preset, "operator") self.assertEqual(user.preferred_name, "Bit") - def test_companion_turn_reconciliation_does_not_mutate_profile_without_management_tools(self) -> None: + def test_companion_turn_reconciliation_does_not_mutate_profile_without_management_tools( + self, + ) -> None: from apps.cli.runtime import CliRuntime with tempfile.TemporaryDirectory() as tmpdir: diff --git a/tests/scenarios/context/AGENTS.md b/tests/scenarios/context/AGENTS.md index f231e46..ff27ae0 100644 --- a/tests/scenarios/context/AGENTS.md +++ b/tests/scenarios/context/AGENTS.md @@ -8,4 +8,3 @@ Rules: - one scenario per file - make overflow and recovery assertions explicit - do not encode app-level assumptions - diff --git a/tests/scenarios/continuity/test_continuity_scenarios.py b/tests/scenarios/continuity/test_continuity_scenarios.py index 4c5ee38..0f33b78 100644 --- a/tests/scenarios/continuity/test_continuity_scenarios.py +++ b/tests/scenarios/continuity/test_continuity_scenarios.py @@ -41,10 +41,12 @@ def test_state_continuity_fixture_declares_text_only_boundary(self) -> None: self.assertIn("no voice transport", fixture) self.assertNotIn("later voice support", fixture) - def test_companion_fixture_remains_the_text_first_identity_replacement(self) -> None: + def test_companion_fixture_remains_the_text_first_identity_replacement( + self, + ) -> None: fixture = ( - SCENARIOS_PATH.parents[1] / "companion" / "text-first-continuity.md" - ).read_text(encoding="utf-8").lower() + (SCENARIOS_PATH.parents[1] / "companion" / "text-first-continuity.md").read_text(encoding="utf-8").lower() + ) self.assertIn("text-first", fixture) self.assertIn("without voice transport", fixture) diff --git a/tests/unit/api/test_internal_triggers.py b/tests/unit/api/test_internal_triggers.py index 0520450..f98c67f 100644 --- a/tests/unit/api/test_internal_triggers.py +++ b/tests/unit/api/test_internal_triggers.py @@ -34,7 +34,10 @@ def test_reflect_dream_diary_features_get_separate_target_dates(self) -> None: repository = _RepositoryStub() app = SimpleNamespace(repository=repository) - with patch("apps.learning_worker_runtime.ensure_learning_worker_running", lambda **_: None): + with patch( + "apps.learning_worker_runtime.ensure_learning_worker_running", + lambda **_: None, + ): result = trigger_reflect_job(app, trigger="manual", features="dream,diary") self.assertEqual(result["status"], "queued") @@ -42,13 +45,19 @@ def test_reflect_dream_diary_features_get_separate_target_dates(self) -> None: metadata = repository.enqueued_metadata or {} self.assertEqual(metadata["features"], "dream,diary") self.assertEqual(metadata["target_date"], date_type.today().isoformat()) - self.assertEqual(metadata["diary_target_date"], (date_type.today() - timedelta(days=1)).isoformat()) + self.assertEqual( + metadata["diary_target_date"], + (date_type.today() - timedelta(days=1)).isoformat(), + ) def test_reflect_diary_feature_defaults_to_yesterday(self) -> None: repository = _RepositoryStub() app = SimpleNamespace(repository=repository) - with patch("apps.learning_worker_runtime.ensure_learning_worker_running", lambda **_: None): + with patch( + "apps.learning_worker_runtime.ensure_learning_worker_running", + lambda **_: None, + ): trigger_reflect_job(app, trigger="manual", features="diary") metadata = repository.enqueued_metadata or {} diff --git a/tests/unit/api/test_runtime_http_dispatch_helpers.py b/tests/unit/api/test_runtime_http_dispatch_helpers.py index 29abb54..895b122 100644 --- a/tests/unit/api/test_runtime_http_dispatch_helpers.py +++ b/tests/unit/api/test_runtime_http_dispatch_helpers.py @@ -19,7 +19,10 @@ def test_nightly_dream_jobs_are_marked_as_system_jobs(self) -> None: status="scheduled", profile_id=None, elephant_id=None, - payload={"trigger": "dream", "summary": "nightly Personal Model consolidation"}, + payload={ + "trigger": "dream", + "summary": "nightly Personal Model consolidation", + }, created_at=now, updated_at=now, next_run_at=now, diff --git a/tests/unit/apps/test_learning_diff_emitter.py b/tests/unit/apps/test_learning_diff_emitter.py index c5fe86b..08a4fdb 100644 --- a/tests/unit/apps/test_learning_diff_emitter.py +++ b/tests/unit/apps/test_learning_diff_emitter.py @@ -19,7 +19,9 @@ class LearningResultWriteTests(unittest.TestCase): - def test_learning_context_packet_uses_basic_anchors_and_tool_directives(self) -> None: + def test_learning_context_packet_uses_basic_anchors_and_tool_directives( + self, + ) -> None: with tempfile.TemporaryDirectory() as tempdir: state_dir = Path(tempdir) / "state" state_dir.mkdir(parents=True, exist_ok=True) @@ -62,7 +64,9 @@ def test_learning_context_packet_uses_basic_anchors_and_tool_directives(self) -> metadata={"topic": "world.project.private.detail"}, ) ) - job = runtime.schedule_learning_for_session(session_id=session.episode_id, trigger="manual", start_worker=False) + job = runtime.schedule_learning_for_session( + session_id=session.episode_id, trigger="manual", start_worker=False + ) packet = build_evidence(runtime, job, resolve_features("manual")) @@ -81,7 +85,13 @@ def test_learning_result_tool_writes_result_json(self) -> None: state_dir = root / "state" state_dir.mkdir(parents=True, exist_ok=True) (root / "profile.json").write_text( - json.dumps({"profile_id": "profile-companion", "display_name": "Elephant Agent", "mode": "companion"}), + json.dumps( + { + "profile_id": "profile-companion", + "display_name": "Elephant Agent", + "mode": "companion", + } + ), encoding="utf-8", ) runtime = CliRuntime.create(state_dir=state_dir) diff --git a/tests/unit/cli/test_cron_learning_jobs.py b/tests/unit/cli/test_cron_learning_jobs.py index 11ee559..02344e5 100644 --- a/tests/unit/cli/test_cron_learning_jobs.py +++ b/tests/unit/cli/test_cron_learning_jobs.py @@ -18,7 +18,10 @@ def schedule_learning_for_session(**kwargs: object) -> SimpleNamespace: runtime = SimpleNamespace(schedule_learning_for_session=schedule_learning_for_session) cron_job = SimpleNamespace( name="Nightly dream", - payload={"trigger": "dream", "summary": "nightly Personal Model, question, skill, and diary maintenance"}, + payload={ + "trigger": "dream", + "summary": "nightly Personal Model, question, skill, and diary maintenance", + }, ) outcome, summary = CliRuntimeExtensionsMixin._execute_cron_learning_job( # type: ignore[misc] @@ -31,7 +34,10 @@ def schedule_learning_for_session(**kwargs: object) -> SimpleNamespace: self.assertEqual(outcome, "success") self.assertIn("job:dream", summary) self.assertEqual(captured["trigger"], "dream") - self.assertEqual(captured["metadata"], {"target_date": yesterday, "diary_target_date": yesterday}) + self.assertEqual( + captured["metadata"], + {"target_date": yesterday, "diary_target_date": yesterday}, + ) if __name__ == "__main__": diff --git a/tests/unit/cli/test_dashboard_command.py b/tests/unit/cli/test_dashboard_command.py index 681c935..06711d5 100644 --- a/tests/unit/cli/test_dashboard_command.py +++ b/tests/unit/cli/test_dashboard_command.py @@ -27,7 +27,9 @@ def read(self) -> bytes: class DashboardCommandTest(unittest.TestCase): - def test_try_daemon_dashboard_url_requires_health_and_dashboard_payload(self) -> None: + def test_try_daemon_dashboard_url_requires_health_and_dashboard_payload( + self, + ) -> None: with tempfile.TemporaryDirectory() as tempdir: state_dir = Path(tempdir) / "herd" state_dir.mkdir() @@ -59,7 +61,9 @@ def test_try_daemon_dashboard_url_returns_none_without_record(self) -> None: self.assertIsNone(url) - def test_probe_daemon_dashboard_reports_running_process_when_http_is_unavailable(self) -> None: + def test_probe_daemon_dashboard_reports_running_process_when_http_is_unavailable( + self, + ) -> None: with tempfile.TemporaryDirectory() as tempdir: state_dir = Path(tempdir) / "herd" state_dir.mkdir() @@ -207,7 +211,9 @@ def test_run_dashboard_guides_user_when_daemon_is_not_running(self) -> None: mock.patch.object( dashboard_command, "_probe_daemon_dashboard", - return_value=dashboard_command.DaemonDashboardProbe(dashboard_url=None, reason="missing_runtime_record"), + return_value=dashboard_command.DaemonDashboardProbe( + dashboard_url=None, reason="missing_runtime_record" + ), ), mock.patch.object(dashboard_command, "_print_cli_card") as print_card, ): @@ -245,7 +251,9 @@ def test_run_dashboard_starts_daemon_when_not_running(self) -> None: start_daemon.assert_called_once_with(plan) printed.assert_called_once_with("Elephant Agent dashboard URL: http://127.0.0.1:8900/dashboard/") - def test_run_dashboard_reports_running_daemon_when_dashboard_is_unavailable(self) -> None: + def test_run_dashboard_reports_running_daemon_when_dashboard_is_unavailable( + self, + ) -> None: plan = dashboard_command.DashboardLaunchPlan(state_dir=Path("/tmp/elephant-herd")) with ( @@ -273,7 +281,11 @@ def test_run_dashboard_does_not_probe_daemon_without_frontend_assets(self) -> No plan = dashboard_command.DashboardLaunchPlan(state_dir=Path("/tmp/elephant-herd")) with ( - mock.patch.object(dashboard_command, "DASHBOARD_DIST_INDEX", Path("/tmp/missing-dashboard-index.html")), + mock.patch.object( + dashboard_command, + "DASHBOARD_DIST_INDEX", + Path("/tmp/missing-dashboard-index.html"), + ), mock.patch.object(dashboard_command, "_ensure_frontend_dist", return_value=False), mock.patch.object(dashboard_command, "_probe_daemon_dashboard") as probe, mock.patch.object(dashboard_command, "_print_cli_card") as print_card, diff --git a/tests/unit/cli/test_learning_crons.py b/tests/unit/cli/test_learning_crons.py index f0ff754..be55c15 100644 --- a/tests/unit/cli/test_learning_crons.py +++ b/tests/unit/cli/test_learning_crons.py @@ -35,7 +35,9 @@ def create_job(self, *, name: str, schedule_text: str, payload: Mapping[str, Any class NightlyLearningCronTest(unittest.TestCase): - def test_single_nightly_cron_removes_legacy_diary_and_creates_dream_bundle(self) -> None: + def test_single_nightly_cron_removes_legacy_diary_and_creates_dream_bundle( + self, + ) -> None: diary = SimpleNamespace( job_id="cron:diary", name="Daily diary", diff --git a/tests/unit/cli/test_main.py b/tests/unit/cli/test_main.py index 97dc925..d793fe4 100644 --- a/tests/unit/cli/test_main.py +++ b/tests/unit/cli/test_main.py @@ -79,7 +79,13 @@ def _patch_choice_menu_dependencies(application_cls, *, bindings_cls=None, radio stack.enter_context(mock.patch.object(cli_wizard, "HSplit", lambda children, padding=0: (children, padding))) stack.enter_context(mock.patch.object(cli_wizard, "Window", lambda content, **kwargs: (content, kwargs))) stack.enter_context(mock.patch.object(cli_wizard, "FormattedTextControl", lambda fragments: fragments)) - stack.enter_context(mock.patch.object(cli_wizard, "Layout", lambda dialog, focused_element=None: (dialog, focused_element))) + stack.enter_context( + mock.patch.object( + cli_wizard, + "Layout", + lambda dialog, focused_element=None: (dialog, focused_element), + ) + ) stack.enter_context(mock.patch.object(cli_wizard, "PromptDimension", SimpleNamespace)) stack.enter_context(mock.patch.object(cli_wizard, "Button", _FakeButton)) stack.enter_context(mock.patch.object(cli_wizard, "Dialog", _FakeDialog)) @@ -100,7 +106,9 @@ def _handle_enter(self) -> None: class CliInitIntroTest(unittest.TestCase): - def test_init_welcome_frame_renders_enter_gate_without_removed_intro_animation(self) -> None: + def test_init_welcome_frame_renders_enter_gate_without_removed_intro_animation( + self, + ) -> None: if not cli_main_setup.RICH_AVAILABLE or cli_main_setup.Console is None: self.skipTest("rich is not available") @@ -157,7 +165,9 @@ def test_birth_wizard_intro_uses_short_pm_first_copy(self) -> None: self.assertNotIn("database dump", rendered) self.assertNotIn("the elephant before recall", rendered) - def test_cli_help_intro_renders_only_once_without_separator_duplication(self) -> None: + def test_cli_help_intro_renders_only_once_without_separator_duplication( + self, + ) -> None: if not cli_main_support.RICH_AVAILABLE or cli_main_support.Console is None: self.skipTest("rich is not available") @@ -177,7 +187,11 @@ def test_cli_help_intro_renders_only_once_without_separator_duplication(self) -> self.assertNotIn("• • init", rendered) def test_render_cli_banner_mark_uses_stage_zero_elephant(self) -> None: - with mock.patch.object(cli_main_support, "render_stage_zero_elephant_mark", return_value="elephant-mark") as render_stage_zero_elephant_mark: + with mock.patch.object( + cli_main_support, + "render_stage_zero_elephant_mark", + return_value="elephant-mark", + ) as render_stage_zero_elephant_mark: result = cli_main_support._render_cli_banner_mark() self.assertEqual(result, "elephant-mark") @@ -185,7 +199,9 @@ def test_render_cli_banner_mark_uses_stage_zero_elephant(self) -> None: class InitQuestionDesignTest(unittest.TestCase): - def test_starter_questions_use_human_labels_for_manual_and_blank_options(self) -> None: + def test_starter_questions_use_human_labels_for_manual_and_blank_options( + self, + ) -> None: for spec in cli_main_impl._STARTER_QUESTIONS: choices = tuple(spec["choices_zh"]) by_value = {choice[0]: choice for choice in choices} @@ -255,7 +271,9 @@ def test_attention_choice_persists_hidden_profile_answer_for_pm(self) -> None: self.assertNotIn("像站在一条路将要分开的地方", answer) self.assertNotEqual(answer, selected) - def test_english_attention_choice_persists_hidden_profile_answer_for_pm(self) -> None: + def test_english_attention_choice_persists_hidden_profile_answer_for_pm( + self, + ) -> None: selected = "standing at a fork" choice = next(choice for choice in cli_main_impl._ATTENTION_CHOICES_EN if choice[0] == selected) with mock.patch.object(cli_main_impl, "_wizard_choice_prompt", return_value=selected): @@ -277,7 +295,11 @@ def test_english_attention_choice_persists_hidden_profile_answer_for_pm(self) -> def test_attention_manual_input_persists_user_words(self) -> None: with ( mock.patch.object(cli_main_impl, "_wizard_choice_prompt", return_value="type"), - mock.patch.object(cli_main_impl, "_wizard_text_prompt", return_value="我正在重新整理生活优先级"), + mock.patch.object( + cli_main_impl, + "_wizard_text_prompt", + return_value="我正在重新整理生活优先级", + ), ): answer = cli_main_impl._prompt_choice_with_type( "zh", @@ -427,25 +449,43 @@ def test_guard_radio_list_selection_bounds_clamps_negative_index(self) -> None: self.assertEqual(radio_list._selected_index, 0) self.assertEqual(radio_list.current_value, "companion") + def test_wizard_choice_window_caps_long_lists_to_nine_rows(self) -> None: self.assertEqual(_wizard_choice_window(4, 0), (0, 4)) self.assertEqual(_wizard_choice_window(12, 0), (0, WIZARD_MAX_VISIBLE_CHOICES)) self.assertEqual(_wizard_choice_window(12, 6), (2, 11)) self.assertEqual(_wizard_choice_window(12, 11), (3, 12)) - def test_wizard_choice_fragments_render_without_blank_lines_between_options(self) -> None: + def test_wizard_choice_fragments_render_without_blank_lines_between_options( + self, + ) -> None: choices = ( - WizardChoice(value="companion", label="Companion", detail="Steady and present.", emoji="🤝"), - WizardChoice(value="operator", label="Operator", detail="Direct and durable.", emoji="🛠️"), + WizardChoice( + value="companion", + label="Companion", + detail="Steady and present.", + emoji="🤝", + ), + WizardChoice( + value="operator", + label="Operator", + detail="Direct and durable.", + emoji="🛠️", + ), ) text = "".join(fragment for _, fragment in _wizard_choice_fragments("Choose", "Prompt", choices, selected=0)) - self.assertIn("› 🤝 Companion\n Steady and present.\n 🛠️ Operator\n Direct and durable.\n", text) + self.assertIn( + "› 🤝 Companion\n Steady and present.\n 🛠️ Operator\n Direct and durable.\n", + text, + ) self.assertNotIn("Steady and present.\n\n Operator", text) self.assertIn("Enter confirms", text) - def test_wizard_choice_fragments_show_scroll_hints_for_hidden_provider_rows(self) -> None: + def test_wizard_choice_fragments_show_scroll_hints_for_hidden_provider_rows( + self, + ) -> None: choices = tuple( WizardChoice( value=f"provider-{index}", @@ -467,8 +507,18 @@ def test_wizard_choice_fragments_show_scroll_hints_for_hidden_provider_rows(self def test_wizard_choice_fragments_show_back_hint_when_allowed(self) -> None: choices = ( - WizardChoice(value="companion", label="Companion", detail="Steady and present.", emoji="🤝"), - WizardChoice(value="operator", label="Operator", detail="Direct and durable.", emoji="🛠️"), + WizardChoice( + value="companion", + label="Companion", + detail="Steady and present.", + emoji="🤝", + ), + WizardChoice( + value="operator", + label="Operator", + detail="Direct and durable.", + emoji="🛠️", + ), ) text = "".join( @@ -480,8 +530,18 @@ def test_wizard_choice_fragments_show_back_hint_when_allowed(self) -> None: def test_wizard_choice_menu_uses_centered_dialog_application(self) -> None: choices = ( - WizardChoice(value="companion", label="Companion", detail="Steady and present.", emoji="🤝"), - WizardChoice(value="operator", label="Operator", detail="Direct and durable.", emoji="🛠️"), + WizardChoice( + value="companion", + label="Companion", + detail="Steady and present.", + emoji="🤝", + ), + WizardChoice( + value="operator", + label="Operator", + detail="Direct and durable.", + emoji="🛠️", + ), ) captured: dict[str, object] = {} @@ -501,8 +561,18 @@ def run(self): def test_wizard_choice_menu_runs_dialog_in_thread_when_loop_is_active(self) -> None: choices = ( - WizardChoice(value="companion", label="Companion", detail="Steady and present.", emoji="🤝"), - WizardChoice(value="operator", label="Operator", detail="Direct and durable.", emoji="🛠️"), + WizardChoice( + value="companion", + label="Companion", + detail="Steady and present.", + emoji="🤝", + ), + WizardChoice( + value="operator", + label="Operator", + detail="Direct and durable.", + emoji="🛠️", + ), ) captured: dict[str, object] = {} @@ -523,10 +593,22 @@ def run(self, **kwargs): self.assertEqual(answer, "operator") self.assertTrue(captured["in_thread"]) - def test_wizard_choice_menu_uses_single_line_radio_entries_for_mouse_safety(self) -> None: + def test_wizard_choice_menu_uses_single_line_radio_entries_for_mouse_safety( + self, + ) -> None: choices = ( - WizardChoice(value="companion", label="Companion", detail="Steady and present.", emoji="🤝"), - WizardChoice(value="operator", label="Operator", detail="Direct and durable.", emoji="🛠️"), + WizardChoice( + value="companion", + label="Companion", + detail="Steady and present.", + emoji="🤝", + ), + WizardChoice( + value="operator", + label="Operator", + detail="Direct and durable.", + emoji="🛠️", + ), ) captured: dict[str, object] = {} @@ -554,8 +636,18 @@ def run(self): def test_wizard_choice_menu_can_return_back_signal(self) -> None: choices = ( - WizardChoice(value="companion", label="Companion", detail="Steady and present.", emoji="🤝"), - WizardChoice(value="operator", label="Operator", detail="Direct and durable.", emoji="🛠️"), + WizardChoice( + value="companion", + label="Companion", + detail="Steady and present.", + emoji="🤝", + ), + WizardChoice( + value="operator", + label="Operator", + detail="Direct and durable.", + emoji="🛠️", + ), ) class _FakeApplication: @@ -572,8 +664,18 @@ def run(self): def test_wizard_choice_menu_cancel_never_falls_back_to_default(self) -> None: choices = ( - WizardChoice(value="companion", label="Companion", detail="Steady and present.", emoji="🤝"), - WizardChoice(value="operator", label="Operator", detail="Direct and durable.", emoji="🛠️"), + WizardChoice( + value="companion", + label="Companion", + detail="Steady and present.", + emoji="🤝", + ), + WizardChoice( + value="operator", + label="Operator", + detail="Direct and durable.", + emoji="🛠️", + ), ) class _FakeApplication: @@ -590,8 +692,18 @@ def run(self): def test_wizard_choice_menu_binds_enter_eagerly_for_continue(self) -> None: choices = ( - WizardChoice(value="companion", label="Companion", detail="Steady and present.", emoji="🤝"), - WizardChoice(value="operator", label="Operator", detail="Direct and durable.", emoji="🛠️"), + WizardChoice( + value="companion", + label="Companion", + detail="Steady and present.", + emoji="🤝", + ), + WizardChoice( + value="operator", + label="Operator", + detail="Direct and durable.", + emoji="🛠️", + ), ) binding_calls: list[tuple[tuple[object, ...], dict[str, object]]] = [] @@ -655,7 +767,9 @@ def run(self): button_labels = [button.kwargs["text"] for button in captured["buttons"]] self.assertEqual(button_labels, ["Continue", "Back"]) - def test_wizard_dual_choice_menu_binds_space_and_delete_for_selection_flow(self) -> None: + def test_wizard_dual_choice_menu_binds_space_and_delete_for_selection_flow( + self, + ) -> None: choices = ( WizardChoice(value="gpt-5.4", label="gpt-5.4", detail="Large lane"), WizardChoice(value="gpt-5.4-mini", label="gpt-5.4-mini", detail="Small lane"), @@ -694,7 +808,9 @@ def run(self): self.assertIn((("backspace",), {"eager": True}), binding_calls) self.assertIn((("delete",), {"eager": True}), binding_calls) - def test_wizard_dual_choice_menu_space_accepts_when_complete_off_radio(self) -> None: + def test_wizard_dual_choice_menu_space_accepts_when_complete_off_radio( + self, + ) -> None: choices = ( WizardChoice(value="gpt-5.4", label="gpt-5.4", detail="Large lane"), WizardChoice(value="gpt-5.4-mini", label="gpt-5.4-mini", detail="Small lane"), @@ -781,7 +897,9 @@ def test_wizard_text_prompt_uses_back_button_for_born_flow(self) -> None: self.assertEqual(input_dialog_mock.call_args.kwargs["ok_text"], "Continue") self.assertEqual(input_dialog_mock.call_args.kwargs["cancel_text"], "Back") - def test_wizard_text_prompt_uses_back_button_even_without_previous_step(self) -> None: + def test_wizard_text_prompt_uses_back_button_even_without_previous_step( + self, + ) -> None: dialog = mock.Mock() dialog.run.return_value = None @@ -796,7 +914,9 @@ def test_wizard_text_prompt_uses_back_button_even_without_previous_step(self) -> self.assertEqual(input_dialog_mock.call_args.kwargs["ok_text"], "Continue") self.assertEqual(input_dialog_mock.call_args.kwargs["cancel_text"], "Back") - def test_wizard_text_prompt_uses_required_dialog_when_validation_copy_is_needed(self) -> None: + def test_wizard_text_prompt_uses_required_dialog_when_validation_copy_is_needed( + self, + ) -> None: with ( mock.patch.object(cli_wizard, "_wizard_dialogs_supported", return_value=True), mock.patch.object( @@ -840,7 +960,9 @@ def test_wizard_text_prompt_can_clear_a_prefilled_value(self) -> None: self.assertEqual(answer, "") - def test_wizard_text_prompt_runs_input_dialog_in_thread_when_loop_is_active(self) -> None: + def test_wizard_text_prompt_runs_input_dialog_in_thread_when_loop_is_active( + self, + ) -> None: captured: dict[str, object] = {} class _FakeDialog: @@ -872,15 +994,36 @@ def test_wizard_choice_label_prefixes_emoji_when_present(self) -> None: def test_provider_choices_use_plain_labels_and_brand_accent_detail(self) -> None: runtime = mock.Mock() runtime.provider_inventory.return_value = ( - SimpleNamespace(provider_id="openai-compatible", display_name="OpenAI-compatible", status="requires-setup", source="none", runtime_enabled=True), - SimpleNamespace(provider_id="moonshot", display_name="Moonshot Kimi", status="requires-setup", source="none", runtime_enabled=True), - SimpleNamespace(provider_id="unknown-provider", display_name="Custom", status="requires-setup", source="none", runtime_enabled=True), + SimpleNamespace( + provider_id="openai-compatible", + display_name="OpenAI-compatible", + status="requires-setup", + source="none", + runtime_enabled=True, + ), + SimpleNamespace( + provider_id="moonshot", + display_name="Moonshot Kimi", + status="requires-setup", + source="none", + runtime_enabled=True, + ), + SimpleNamespace( + provider_id="unknown-provider", + display_name="Custom", + status="requires-setup", + source="none", + runtime_enabled=True, + ), ) providers = _provider_choices(runtime) self.assertEqual([choice.emoji for choice in providers], ["", "", ""]) - self.assertEqual([choice.detail_style for choice in providers], ["accent-detail", "accent-detail", "accent-detail"]) + self.assertEqual( + [choice.detail_style for choice in providers], + ["accent-detail", "accent-detail", "accent-detail"], + ) def test_build_parser_registers_brain_surface(self) -> None: parser = cli_main.build_parser() @@ -1123,7 +1266,9 @@ def test_run_brain_configures_openai_compatible_embedding_provider(self) -> None ) print_card.assert_called_once() - def test_run_brain_interactive_provider_state_is_not_compared_as_hashable_signal(self) -> None: + def test_run_brain_interactive_provider_state_is_not_compared_as_hashable_signal( + self, + ) -> None: runtime = mock.Mock() profile_state = SimpleNamespace(profile_id="profile-default", display_name="Atlas", mode="companion") runtime.current_profile.return_value = SimpleNamespace(state=profile_state) @@ -1165,7 +1310,9 @@ def test_run_brain_interactive_provider_state_is_not_compared_as_hashable_signal self.assertEqual(exit_code, 0) runtime.set_default_provider.assert_called_once() - def test_suggest_elephant_name_skips_existing_elephant_ids_when_possible(self) -> None: + def test_suggest_elephant_name_skips_existing_elephant_ids_when_possible( + self, + ) -> None: runtime = mock.Mock() runtime.latest_session_for_elephant.side_effect = lambda elephant_id: object() if elephant_id == "ada" else None captured: dict[str, tuple[str, ...]] = {} @@ -1180,7 +1327,9 @@ def _pick(options): self.assertEqual(suggested, captured["options"][0]) self.assertNotIn("Ada", captured["options"]) - def test_run_setup_uses_random_name_suggestion_when_no_initial_name_is_given(self) -> None: + def test_run_setup_uses_random_name_suggestion_when_no_initial_name_is_given( + self, + ) -> None: runtime = mock.Mock() runtime.current_profile.return_value = SimpleNamespace( state=SimpleNamespace(display_name="Elephant Agent"), @@ -1223,7 +1372,11 @@ def test_run_setup_uses_random_name_suggestion_when_no_initial_name_is_given(sel def test_run_setup_allows_oauth_provider_without_explicit_key(self) -> None: runtime = mock.Mock() runtime.current_profile.return_value = SimpleNamespace( - state=SimpleNamespace(profile_id="profile-companion", display_name="Elephant Agent", mode="companion"), + state=SimpleNamespace( + profile_id="profile-companion", + display_name="Elephant Agent", + mode="companion", + ), companion=SimpleNamespace(personality_preset="companion", initiative="gentle"), ) runtime.provider_setup_guide.return_value = SimpleNamespace( @@ -1232,7 +1385,11 @@ def test_run_setup_allows_oauth_provider_without_explicit_key(self) -> None: ) runtime.detect_provider_context_window.return_value = 128000 updated_profile = SimpleNamespace( - state=SimpleNamespace(profile_id="profile-companion", display_name="Elephant Agent", mode="companion"), + state=SimpleNamespace( + profile_id="profile-companion", + display_name="Elephant Agent", + mode="companion", + ), companion=SimpleNamespace(personality_preset="companion", initiative="gentle"), ) runtime.update_identity.return_value = updated_profile @@ -1284,7 +1441,9 @@ def test_run_setup_allows_oauth_provider_without_explicit_key(self) -> None: runtime.set_default_provider.assert_called_once() self.assertIsNone(runtime.set_default_provider.call_args.kwargs["api_key"]) - def test_interactive_birth_wizard_cancels_when_provider_setup_is_escaped(self) -> None: + def test_interactive_birth_wizard_cancels_when_provider_setup_is_escaped( + self, + ) -> None: runtime = mock.Mock() runtime.personality_presets.return_value = ( SimpleNamespace(preset_id="companion", label="Companion", summary="Steady."), @@ -1323,14 +1482,17 @@ def test_prompt_birth_date_accepts_freeform_input(self) -> None: self.assertEqual(answer, "spring equinox 1991") def test_interactive_elephant_wizard_uses_suggested_name_as_default(self) -> None: - with mock.patch.object( - cli_main, - "_wizard_text_prompt", - return_value="Nova", - ) as text_prompt, mock.patch.object( - cli_main, - "_suggest_elephant_name", - return_value="Rowan", + with ( + mock.patch.object( + cli_main, + "_wizard_text_prompt", + return_value="Nova", + ) as text_prompt, + mock.patch.object( + cli_main, + "_suggest_elephant_name", + return_value="Rowan", + ), ): state = _run_interactive_elephant_wizard(mock.Mock(), elephant_name=None) @@ -1338,7 +1500,9 @@ def test_interactive_elephant_wizard_uses_suggested_name_as_default(self) -> Non self.assertEqual(text_prompt.call_count, 1) self.assertEqual(text_prompt.call_args_list[0].kwargs["default"], "Rowan") - def test_interactive_elephant_wizard_can_cancel_before_creating_elephant(self) -> None: + def test_interactive_elephant_wizard_can_cancel_before_creating_elephant( + self, + ) -> None: with ( mock.patch.object(cli_main, "_wizard_text_prompt", return_value=WIZARD_BACK), mock.patch.object(cli_main, "_suggest_elephant_name", return_value="Theo") as suggest_name, @@ -1356,7 +1520,11 @@ def test_run_setup_creates_first_elephant_when_non_interactive(self) -> None: companion=SimpleNamespace(personality_preset="companion", initiative="gentle"), ) updated_profile = SimpleNamespace( - state=SimpleNamespace(profile_id="profile-companion", display_name="Elephant Agent", mode="companion"), + state=SimpleNamespace( + profile_id="profile-companion", + display_name="Elephant Agent", + mode="companion", + ), companion=SimpleNamespace(personality_preset="companion", initiative="gentle"), ) runtime.provider_setup_guide.return_value = SimpleNamespace(auth_type="api_key", required_secret_keys=()) @@ -1426,7 +1594,11 @@ def test_run_setup_keeps_raw_birth_date_when_non_interactive(self) -> None: companion=SimpleNamespace(personality_preset="companion", initiative="gentle"), ) updated_profile = SimpleNamespace( - state=SimpleNamespace(profile_id="profile-companion", display_name="Elephant Agent", mode="companion"), + state=SimpleNamespace( + profile_id="profile-companion", + display_name="Elephant Agent", + mode="companion", + ), companion=SimpleNamespace(personality_preset="companion", initiative="gentle"), ) runtime.provider_setup_guide.return_value = SimpleNamespace(auth_type="api_key", required_secret_keys=()) @@ -1506,7 +1678,9 @@ def test_run_setup_keeps_raw_birth_date_when_non_interactive(self) -> None: self.assertEqual(exit_code, 0) self.assertEqual(bootstrap_personal_model.call_args.args[2].birth_date, "late summer 1991") - def test_init_question_config_persists_proactive_ask_from_learning_intensity(self) -> None: + def test_init_question_config_persists_proactive_ask_from_learning_intensity( + self, + ) -> None: runtime = SimpleNamespace(paths=SimpleNamespace(state_dir="/tmp/elephant-test/herd")) captured: dict[str, object] = {} @@ -1543,14 +1717,20 @@ def test_init_question_config_persists_proactive_ask_from_learning_intensity(sel ) self.assertEqual(captured["personal_model"]["first_language"], "zh") - def test_interactive_setup_uses_shallow_provider_doctor_before_tui_handoff(self) -> None: + def test_interactive_setup_uses_shallow_provider_doctor_before_tui_handoff( + self, + ) -> None: runtime = mock.Mock() runtime.current_profile.return_value = SimpleNamespace( state=SimpleNamespace(display_name="Elephant Agent"), companion=SimpleNamespace(personality_preset="companion", initiative="gentle"), ) updated_profile = SimpleNamespace( - state=SimpleNamespace(profile_id="profile-companion", display_name="Elephant Agent", mode="companion"), + state=SimpleNamespace( + profile_id="profile-companion", + display_name="Elephant Agent", + mode="companion", + ), companion=SimpleNamespace(personality_preset="companion", initiative="gentle"), ) runtime.provider_setup_guide.return_value = SimpleNamespace(auth_type="api_key", required_secret_keys=()) @@ -1639,15 +1819,19 @@ def test_interactive_setup_uses_shallow_provider_doctor_before_tui_handoff(self) mock.patch.object(cli_main, "_print_birth_wizard_intro"), mock.patch.object(cli_main, "_run_interactive_birth_wizard", return_value=wizard_state), mock.patch.object(cli_main, "_print_init_section"), - mock.patch.object(cli_main, "provider_setup_defaults", return_value=cli_main.ProviderSelectionState( - provider_id="openai-compatible", - base_url="https://api.example.com/v1", - api_key="sk-cli-test-123", - model_id="openai/gpt-4o-mini", - reasoning_effort=None, - context_window_mode="auto", - context_window_tokens=128000, - )), + mock.patch.object( + cli_main, + "provider_setup_defaults", + return_value=cli_main.ProviderSelectionState( + provider_id="openai-compatible", + base_url="https://api.example.com/v1", + api_key="sk-cli-test-123", + model_id="openai/gpt-4o-mini", + reasoning_effort=None, + context_window_mode="auto", + context_window_tokens=128000, + ), + ), mock.patch.object(cli_main, "_persist_init_question_config"), mock.patch.object(cli_main, "_bootstrap_personal_model_from_init"), mock.patch.object(cli_main, "_play_creating_transition"), @@ -1718,7 +1902,9 @@ def test_run_elephant_does_not_open_wizard_when_name_is_preselected(self) -> Non mode="companion", ) - def test_run_grow_defers_surface_prepare_until_after_interactive_shell_boot(self) -> None: + def test_run_grow_defers_surface_prepare_until_after_interactive_shell_boot( + self, + ) -> None: runtime = mock.Mock() runtime.provider_doctor.return_value = {"status": "ready"} shell = mock.Mock() @@ -1730,7 +1916,11 @@ def test_run_grow_defers_surface_prepare_until_after_interactive_shell_boot(self ) with ( - mock.patch.object(cli_main, "_open_growth_episode", return_value=("episode-atlas", "Opened elephant atlas")), + mock.patch.object( + cli_main, + "_open_growth_episode", + return_value=("episode-atlas", "Opened elephant atlas"), + ), mock.patch.object(cli_main, "_interactive_shell_supported", return_value=True), mock.patch.object(cli_main, "ProductizedShell", return_value=shell) as productized_shell, ): @@ -1747,8 +1937,12 @@ def test_run_grow_defers_surface_prepare_until_after_interactive_shell_boot(self def test_open_growth_episode_opens_next_episode_for_open_elephant(self) -> None: runtime = mock.Mock() - runtime.latest_session_for_elephant.return_value = SimpleNamespace(episode_id="episode-parent", status="open", exit_summary="") - runtime.open_next_episode.return_value = SimpleNamespace(episode=SimpleNamespace(episode_id="episode-child", status="open")) + runtime.latest_session_for_elephant.return_value = SimpleNamespace( + episode_id="episode-parent", status="open", exit_summary="" + ) + runtime.open_next_episode.return_value = SimpleNamespace( + episode=SimpleNamespace(episode_id="episode-child", status="open") + ) episode_id, opened = cli_main._open_growth_episode(runtime, elephant_id="atlas") @@ -1758,47 +1952,100 @@ def test_open_growth_episode_opens_next_episode_for_open_elephant(self) -> None: def test_open_growth_episode_opens_next_episode_for_closed_elephant(self) -> None: runtime = mock.Mock() - runtime.latest_session_for_elephant.return_value = SimpleNamespace(episode_id="episode-parent", status="closed", exit_summary="parent handoff") - runtime.open_next_episode.return_value = SimpleNamespace(episode=SimpleNamespace(episode_id="episode-child", status="open")) + runtime.latest_session_for_elephant.return_value = SimpleNamespace( + episode_id="episode-parent", status="closed", exit_summary="parent handoff" + ) + runtime.open_next_episode.return_value = SimpleNamespace( + episode=SimpleNamespace(episode_id="episode-child", status="open") + ) episode_id, opened = cli_main._open_growth_episode(runtime, elephant_id="atlas") self.assertEqual(episode_id, "episode-child") self.assertEqual(opened, "Opened elephant atlas") - runtime.open_next_episode.assert_called_once_with("episode-parent", reason="wake_boundary", summary="parent handoff") + runtime.open_next_episode.assert_called_once_with( + "episode-parent", reason="wake_boundary", summary="parent handoff" + ) - def test_resolve_growth_session_prefers_current_elephant_snapshot_when_multiple_prompting_is_disabled(self) -> None: + def test_resolve_growth_session_prefers_current_elephant_snapshot_when_multiple_prompting_is_disabled( + self, + ) -> None: runtime = mock.Mock() runtime.elephant_id_for_session.return_value = "atlas" runtime.list_herd.return_value = ( - SimpleNamespace(elephant_id="atlas", latest_session_id="episode-atlas", session_count=1, latest_status="open"), - SimpleNamespace(elephant_id="beta", latest_session_id="episode-beta", session_count=1, latest_status="open"), + SimpleNamespace( + elephant_id="atlas", + latest_session_id="episode-atlas", + session_count=1, + latest_status="open", + ), + SimpleNamespace( + elephant_id="beta", + latest_session_id="episode-beta", + session_count=1, + latest_status="open", + ), + ) + current_session = SimpleNamespace( + episode_id="episode-current", + elephant_id="atlas", + status="open", + exit_summary="", + ) + runtime.open_next_episode.return_value = SimpleNamespace( + episode=SimpleNamespace(episode_id="episode-next", status="open") ) - current_session = SimpleNamespace(episode_id="episode-current", elephant_id="atlas", status="open", exit_summary="") - runtime.open_next_episode.return_value = SimpleNamespace(episode=SimpleNamespace(episode_id="episode-next", status="open")) - with mock.patch.object(cli_elephant_support, "_current_elephant_session", return_value=current_session): + with mock.patch.object( + cli_elephant_support, + "_current_elephant_session", + return_value=current_session, + ): episode_id, opened = cli_main._open_growth_episode(runtime, prompt_for_multiple=False) self.assertEqual(episode_id, "episode-next") self.assertEqual(opened, "Opened elephant atlas") runtime.open_next_episode.assert_called_once_with("episode-current", reason="wake_boundary", summary="") - def test_resolve_growth_session_prompts_for_multiple_elephants_in_interactive_mode(self) -> None: + def test_resolve_growth_session_prompts_for_multiple_elephants_in_interactive_mode( + self, + ) -> None: runtime = mock.Mock() runtime.list_herd.return_value = ( - SimpleNamespace(elephant_id="atlas", latest_session_id="episode-atlas", session_count=2, latest_status="open"), - SimpleNamespace(elephant_id="beta", latest_session_id="episode-beta", session_count=3, latest_status="open"), + SimpleNamespace( + elephant_id="atlas", + latest_session_id="episode-atlas", + session_count=2, + latest_status="open", + ), + SimpleNamespace( + elephant_id="beta", + latest_session_id="episode-beta", + session_count=3, + latest_status="open", + ), ) runtime.elephant_id_for_session.return_value = "atlas" - runtime.inspect_session.return_value = SimpleNamespace(episode_id="episode-beta", status="open", exit_summary="") - runtime.open_next_episode.return_value = SimpleNamespace(episode=SimpleNamespace(episode_id="episode-beta-next", status="open")) + runtime.inspect_session.return_value = SimpleNamespace( + episode_id="episode-beta", status="open", exit_summary="" + ) + runtime.open_next_episode.return_value = SimpleNamespace( + episode=SimpleNamespace(episode_id="episode-beta-next", status="open") + ) current_session = SimpleNamespace(episode_id="episode-current", elephant_id="atlas", status="open") selected_elephant = runtime.list_herd.return_value[1] with ( - mock.patch.object(cli_elephant_support, "_current_elephant_session", return_value=current_session), - mock.patch.object(cli_elephant_support, "_prompt_elephant_choice", return_value=selected_elephant) as prompt_elephant_choice, + mock.patch.object( + cli_elephant_support, + "_current_elephant_session", + return_value=current_session, + ), + mock.patch.object( + cli_elephant_support, + "_prompt_elephant_choice", + return_value=selected_elephant, + ) as prompt_elephant_choice, ): episode_id, opened = cli_main._open_growth_episode(runtime, prompt_for_multiple=True) @@ -1812,14 +2059,24 @@ def test_resolve_growth_session_prompts_for_multiple_elephants_in_interactive_mo runtime.inspect_session.assert_called_once_with("episode-beta") runtime.schedule_learning_for_session.assert_not_called() - def test_resolve_growth_session_does_not_queue_boundary_learning_when_opening_different_elephant(self) -> None: + def test_resolve_growth_session_does_not_queue_boundary_learning_when_opening_different_elephant( + self, + ) -> None: runtime = mock.Mock() - runtime.latest_session_for_elephant.return_value = SimpleNamespace(episode_id="episode-beta", status="open", exit_summary="") - runtime.open_next_episode.return_value = SimpleNamespace(episode=SimpleNamespace(episode_id="episode-beta-next", status="open")) + runtime.latest_session_for_elephant.return_value = SimpleNamespace( + episode_id="episode-beta", status="open", exit_summary="" + ) + runtime.open_next_episode.return_value = SimpleNamespace( + episode=SimpleNamespace(episode_id="episode-beta-next", status="open") + ) current_session = SimpleNamespace(episode_id="episode-atlas", elephant_id="atlas", status="open") runtime.elephant_id_for_session.return_value = "atlas" - with mock.patch.object(cli_elephant_support, "_current_elephant_session", return_value=current_session): + with mock.patch.object( + cli_elephant_support, + "_current_elephant_session", + return_value=current_session, + ): episode_id, opened = cli_main._open_growth_episode(runtime, elephant_id="beta") self.assertEqual(episode_id, "episode-beta-next") diff --git a/tests/unit/cli/test_provider_flow.py b/tests/unit/cli/test_provider_flow.py index bb681e3..aca2b63 100644 --- a/tests/unit/cli/test_provider_flow.py +++ b/tests/unit/cli/test_provider_flow.py @@ -3,11 +3,17 @@ import unittest from unittest import mock -from apps.cli.provider_flow import ProviderSelectionState, provider_setup_defaults, run_provider_selection_wizard +from apps.cli.provider_flow import ( + ProviderSelectionState, + provider_setup_defaults, + run_provider_selection_wizard, +) class ProviderFlowWizardTests(unittest.TestCase): - def test_provider_setup_defaults_falls_back_from_preview_to_default_provider(self) -> None: + def test_provider_setup_defaults_falls_back_from_preview_to_default_provider( + self, + ) -> None: runtime = mock.Mock() runtime.provider_setup_guide.side_effect = [ LookupError("preview is not a real provider"), @@ -36,15 +42,25 @@ def test_oauth_provider_skips_session_override_prompt(self) -> None: auth_type="oauth_external", ) runtime.discovered_provider.return_value = mock.Mock(status="authenticated", source="codex-cli") - runtime.provider_summary.return_value = {"provider_id": "openai-codex", "secret_status": "stored"} + runtime.provider_summary.return_value = { + "provider_id": "openai-codex", + "secret_status": "stored", + } runtime.discover_provider_models.return_value = ( - mock.Mock(model_id="gpt-5.4", context_window_tokens=128000, max_output_tokens=16384), + mock.Mock( + model_id="gpt-5.4", + context_window_tokens=128000, + max_output_tokens=16384, + ), ) runtime.provider_reasoning_efforts.return_value = () runtime.detect_provider_context_window.return_value = 128000 with ( - mock.patch("apps.cli.provider_flow._wizard_choice_prompt", side_effect=("gpt-5.4", "auto")), + mock.patch( + "apps.cli.provider_flow._wizard_choice_prompt", + side_effect=("gpt-5.4", "auto"), + ), mock.patch("apps.cli.provider_flow._wizard_text_prompt") as text_prompt, ): result = run_provider_selection_wizard( @@ -74,9 +90,16 @@ def test_discovered_copilot_credentials_skip_key_prompt(self) -> None: auth_type="api_key", ) runtime.discovered_provider.return_value = mock.Mock(status="authenticated", source="gh-cli") - runtime.provider_summary.return_value = {"provider_id": "copilot", "secret_status": "stored"} + runtime.provider_summary.return_value = { + "provider_id": "copilot", + "secret_status": "stored", + } runtime.discover_provider_models.return_value = ( - mock.Mock(model_id="gpt-5.4", context_window_tokens=128000, max_output_tokens=16384), + mock.Mock( + model_id="gpt-5.4", + context_window_tokens=128000, + max_output_tokens=16384, + ), ) runtime.provider_reasoning_efforts.return_value = () runtime.detect_provider_context_window.return_value = 128000 @@ -112,10 +135,21 @@ def test_provider_wizard_uses_one_model_prompt(self) -> None: auth_type="oauth_external", ) runtime.discovered_provider.return_value = mock.Mock(status="authenticated", source="codex-cli") - runtime.provider_summary.return_value = {"provider_id": "openai-codex", "secret_status": "stored"} + runtime.provider_summary.return_value = { + "provider_id": "openai-codex", + "secret_status": "stored", + } runtime.discover_provider_models.return_value = ( - mock.Mock(model_id="gpt-5.4", context_window_tokens=128000, max_output_tokens=16384), - mock.Mock(model_id="gpt-5.4-mini", context_window_tokens=128000, max_output_tokens=16384), + mock.Mock( + model_id="gpt-5.4", + context_window_tokens=128000, + max_output_tokens=16384, + ), + mock.Mock( + model_id="gpt-5.4-mini", + context_window_tokens=128000, + max_output_tokens=16384, + ), ) runtime.provider_reasoning_efforts.return_value = () runtime.detect_provider_context_window.return_value = 128000 @@ -142,7 +176,9 @@ def test_provider_wizard_uses_one_model_prompt(self) -> None: self.assertEqual(result.model_id, "gpt-5.4-mini") self.assertEqual(choice_prompt.call_count, 2) - def test_openai_compatible_prompts_for_key_again_when_base_url_changes(self) -> None: + def test_openai_compatible_prompts_for_key_again_when_base_url_changes( + self, + ) -> None: runtime = mock.Mock() runtime.provider_setup_guide.return_value = mock.Mock( required_config_keys=("base_url", "model_id"), @@ -156,13 +192,20 @@ def test_openai_compatible_prompts_for_key_again_when_base_url_changes(self) -> "base_url": "https://old.example.test/v1", } runtime.discover_provider_models.return_value = ( - mock.Mock(model_id="openai/gpt-4o-mini", context_window_tokens=128000, max_output_tokens=16384), + mock.Mock( + model_id="openai/gpt-4o-mini", + context_window_tokens=128000, + max_output_tokens=16384, + ), ) runtime.provider_reasoning_efforts.return_value = () runtime.detect_provider_context_window.return_value = 128000 with ( - mock.patch("apps.cli.provider_flow._wizard_choice_prompt", side_effect=("openai/gpt-4o-mini", "auto")), + mock.patch( + "apps.cli.provider_flow._wizard_choice_prompt", + side_effect=("openai/gpt-4o-mini", "auto"), + ), mock.patch( "apps.cli.provider_flow._wizard_text_prompt", side_effect=("https://new.example.test/v1", "sk-new-key"), @@ -186,7 +229,9 @@ def test_openai_compatible_prompts_for_key_again_when_base_url_changes(self) -> self.assertEqual(result.api_key, "sk-new-key") self.assertEqual(text_prompt.call_count, 2) - def test_openai_compatible_configured_state_still_prompts_for_key_when_endpoint_changes(self) -> None: + def test_openai_compatible_configured_state_still_prompts_for_key_when_endpoint_changes( + self, + ) -> None: runtime = mock.Mock() runtime.provider_setup_guide.return_value = mock.Mock( required_config_keys=("base_url", "model_id"), @@ -204,13 +249,20 @@ def test_openai_compatible_configured_state_still_prompts_for_key_when_endpoint_ "base_url": "https://api.githubcopilot.com", } runtime.discover_provider_models.return_value = ( - mock.Mock(model_id="openai/gpt-4o-mini", context_window_tokens=128000, max_output_tokens=16384), + mock.Mock( + model_id="openai/gpt-4o-mini", + context_window_tokens=128000, + max_output_tokens=16384, + ), ) runtime.provider_reasoning_efforts.return_value = () runtime.detect_provider_context_window.return_value = 128000 with ( - mock.patch("apps.cli.provider_flow._wizard_choice_prompt", side_effect=("openai/gpt-4o-mini", "auto")), + mock.patch( + "apps.cli.provider_flow._wizard_choice_prompt", + side_effect=("openai/gpt-4o-mini", "auto"), + ), mock.patch( "apps.cli.provider_flow._wizard_text_prompt", side_effect=("https://new.example.test/v1", "sk-new-key"), @@ -234,7 +286,9 @@ def test_openai_compatible_configured_state_still_prompts_for_key_when_endpoint_ self.assertEqual(result.api_key, "sk-new-key") self.assertEqual(text_prompt.call_count, 2) - def test_openai_compatible_reuses_configured_key_when_endpoint_matches(self) -> None: + def test_openai_compatible_reuses_configured_key_when_endpoint_matches( + self, + ) -> None: runtime = mock.Mock() runtime.provider_setup_guide.return_value = mock.Mock( required_config_keys=("base_url", "model_id"), @@ -252,13 +306,19 @@ def test_openai_compatible_reuses_configured_key_when_endpoint_matches(self) -> "base_url": "https://irrelevant.example.test/v1", } runtime.discover_provider_models.return_value = ( - mock.Mock(model_id="openai/gpt-4o-mini", context_window_tokens=128000, max_output_tokens=16384), + mock.Mock( + model_id="openai/gpt-4o-mini", + context_window_tokens=128000, + max_output_tokens=16384, + ), ) runtime.provider_reasoning_efforts.return_value = () runtime.detect_provider_context_window.return_value = 128000 with ( - mock.patch("apps.cli.provider_flow._wizard_choice_prompt", side_effect=("keep", "openai/gpt-4o-mini", "auto")), + mock.patch( + "apps.cli.provider_flow._wizard_choice_prompt", side_effect=("keep", "openai/gpt-4o-mini", "auto") + ), mock.patch( "apps.cli.provider_flow._wizard_text_prompt", side_effect=("https://same.example.test/v1",), @@ -282,7 +342,9 @@ def test_openai_compatible_reuses_configured_key_when_endpoint_matches(self) -> self.assertIsNone(result.api_key) self.assertEqual(text_prompt.call_count, 1) - def test_api_key_provider_retries_key_before_manual_model_when_catalog_is_unavailable(self) -> None: + def test_api_key_provider_retries_key_before_manual_model_when_catalog_is_unavailable( + self, + ) -> None: runtime = mock.Mock() runtime.provider_setup_guide.return_value = mock.Mock( required_config_keys=("model_id",), @@ -301,14 +363,28 @@ def test_api_key_provider_retries_key_before_manual_model_when_catalog_is_unavai } runtime.discover_provider_models.side_effect = [ (), - (mock.Mock(model_id="openai/gpt-4o-mini", context_window_tokens=128000, max_output_tokens=16384),), - (mock.Mock(model_id="openai/gpt-4o-mini", context_window_tokens=128000, max_output_tokens=16384),), + ( + mock.Mock( + model_id="openai/gpt-4o-mini", + context_window_tokens=128000, + max_output_tokens=16384, + ), + ), + ( + mock.Mock( + model_id="openai/gpt-4o-mini", + context_window_tokens=128000, + max_output_tokens=16384, + ), + ), ] runtime.provider_reasoning_efforts.return_value = () runtime.detect_provider_context_window.return_value = 128000 with ( - mock.patch("apps.cli.provider_flow._wizard_choice_prompt", side_effect=("keep", "openai/gpt-4o-mini", "auto")), + mock.patch( + "apps.cli.provider_flow._wizard_choice_prompt", side_effect=("keep", "openai/gpt-4o-mini", "auto") + ), mock.patch( "apps.cli.provider_flow._wizard_text_prompt", side_effect=("sk-refreshed-key",), @@ -341,7 +417,11 @@ def test_copilot_keeps_manual_model_flow_when_catalog_is_unavailable(self) -> No auth_type="api_key", ) runtime.discovered_provider.return_value = mock.Mock(status="authenticated", source="gh-cli", base_url="") - runtime.provider_summary.return_value = {"provider_id": "copilot", "secret_status": "stored", "base_url": ""} + runtime.provider_summary.return_value = { + "provider_id": "copilot", + "secret_status": "stored", + "base_url": "", + } runtime.discover_provider_models.return_value = () runtime.provider_reasoning_efforts.return_value = () runtime.detect_provider_context_window.return_value = 128000 @@ -389,7 +469,9 @@ def test_stored_api_key_prompts_keep_or_replace(self) -> None: runtime.detect_provider_context_window.return_value = 128000 with ( - mock.patch("apps.cli.provider_flow._wizard_choice_prompt", side_effect=("keep", "deepseek-reasoner", "auto")), + mock.patch( + "apps.cli.provider_flow._wizard_choice_prompt", side_effect=("keep", "deepseek-reasoner", "auto") + ), mock.patch("apps.cli.provider_flow._wizard_text_prompt") as text_prompt, ): result = run_provider_selection_wizard( @@ -426,7 +508,9 @@ def test_stored_api_key_replace_prompts_for_new_key(self) -> None: runtime.detect_provider_context_window.return_value = 128000 with ( - mock.patch("apps.cli.provider_flow._wizard_choice_prompt", side_effect=("replace", "deepseek-reasoner", "auto")), + mock.patch( + "apps.cli.provider_flow._wizard_choice_prompt", side_effect=("replace", "deepseek-reasoner", "auto") + ), mock.patch( "apps.cli.provider_flow._wizard_text_prompt", side_effect=("sk-new-deepseek-key",), diff --git a/tests/unit/cli/test_runtime_cognition.py b/tests/unit/cli/test_runtime_cognition.py index f7328ba..32c1eea 100644 --- a/tests/unit/cli/test_runtime_cognition.py +++ b/tests/unit/cli/test_runtime_cognition.py @@ -16,7 +16,10 @@ sys.path.insert(0, str(ROOT)) from apps.cli.runtime import CliRuntime, _CliContextCapability, _DurableRecallCapability -from apps.cli.runtime_snapshot import load_snapshot_session_context_epoch, restore_snapshot_state_focus +from apps.cli.runtime_snapshot import ( + load_snapshot_session_context_epoch, + restore_snapshot_state_focus, +) from packages.contracts import ( ContextBundle, Episode, @@ -44,8 +47,6 @@ from packages.skills import ( FetchedSkillBundle, SkillSearchEntry, - builtin_site_skill_catalog_entries, - operator_prompt_skill_catalog_entries, ) @@ -138,7 +139,9 @@ def _runtime( ) return runtime - def test_cli_context_capability_recovers_recent_loop_context_from_snapshot(self) -> None: + def test_cli_context_capability_recovers_recent_loop_context_from_snapshot( + self, + ) -> None: runtime = self._runtime() session = runtime.start() runtime.snapshot_path.write_text( @@ -188,7 +191,9 @@ def index_episode_exit(self, episode: Episode) -> None: self.assertEqual(indexed[0].status, "closed") self.assertEqual(indexed[0].exit_summary, "/clear requested a fresh Episode") - def test_cli_context_capability_ignores_internal_startup_loops_in_recent_loop_context(self) -> None: + def test_cli_context_capability_ignores_internal_startup_loops_in_recent_loop_context( + self, + ) -> None: runtime = self._runtime() session = runtime.start() runtime.snapshot_path.write_text( @@ -217,7 +222,9 @@ def test_cli_context_capability_ignores_internal_startup_loops_in_recent_loop_co self.assertNotIn("startup opening", bundle.rendered_prompt) self.assertNotIn("steady welcome", bundle.rendered_prompt) - def test_cli_context_does_not_duplicate_active_personal_model_behavior_contract(self) -> None: + def test_cli_context_does_not_duplicate_active_personal_model_behavior_contract( + self, + ) -> None: runtime = self._runtime() session = runtime.start() capability = _CliContextCapability( @@ -251,7 +258,9 @@ def test_multiple_next_episodes_keep_lineage(self) -> None: self.assertEqual(second_child.parent_episode_id, parent.episode_id) self.assertNotEqual(first_child.episode_id, second_child.episode_id) - def test_frozen_session_context_epoch_reuses_stable_sections_without_turn_bodies(self) -> None: + def test_frozen_session_context_epoch_reuses_stable_sections_without_turn_bodies( + self, + ) -> None: runtime = self._runtime() session = runtime.start() profile = runtime._load_profile(session.personal_model_id) @@ -474,7 +483,9 @@ def test_frozen_skill_index_honors_profile_skill_disable_overrides(self) -> None self.assertNotIn("ascii-art", frozen_epoch.frozen_skill_ids) self.assertEqual(frozen_epoch.frozen_skill_ids, ()) - def test_frozen_session_history_compacts_explicitly_without_rewriting_epoch_truth(self) -> None: + def test_frozen_session_history_compacts_explicitly_without_rewriting_epoch_truth( + self, + ) -> None: runtime = self._runtime() session = runtime.start() profile = runtime._load_profile(session.personal_model_id) @@ -525,8 +536,7 @@ def test_frozen_session_history_compacts_explicitly_without_rewriting_epoch_trut ) snapshot = json.loads(runtime.snapshot_path.read_text(encoding="utf-8")) snapshot["session_context_epoch"]["history_messages"] = [ - {"role": message.role, "content": message.content} - for message in long_history + {"role": message.role, "content": message.content} for message in long_history ] runtime.snapshot_path.write_text(json.dumps(snapshot, indent=2, sort_keys=True), encoding="utf-8") @@ -579,9 +589,14 @@ def compact_session_projection(self, **kwargs: object) -> str: self.assertEqual(result, "compacted") self.assertIs(captured["embedding_service"], embedding_service) - self.assertEqual(captured["compact_kwargs"], {"session_id": session.episode_id, "reason": "usage", "force": True}) + self.assertEqual( + captured["compact_kwargs"], + {"session_id": session.episode_id, "reason": "usage", "force": True}, + ) - def test_projection_relevance_scorer_was_removed_from_context_public_contract(self) -> None: + def test_projection_relevance_scorer_was_removed_from_context_public_contract( + self, + ) -> None: runtime = self._runtime() capability = _CliContextCapability( profile_loader=runtime.profile_loader, @@ -593,7 +608,9 @@ def test_projection_relevance_scorer_was_removed_from_context_public_contract(se self.assertFalse(hasattr(capability, "_projection_relevance_scorer")) - def test_snapshot_history_messages_use_actual_turn_transcript_without_legacy_lines(self) -> None: + def test_snapshot_history_messages_use_actual_turn_transcript_without_legacy_lines( + self, + ) -> None: runtime = self._runtime() session = runtime.start() profile = runtime._load_profile(session.personal_model_id) @@ -655,12 +672,17 @@ def test_snapshot_history_messages_use_actual_turn_transcript_without_legacy_lin self.assertNotIn("history_lines", epoch_payload) roles = [message["role"] for message in epoch_payload["history_messages"]] self.assertEqual(roles, ["user", "assistant", "tool", "assistant"]) - self.assertEqual(epoch_payload["history_messages"][1]["tool_calls"][0]["name"], "tool.web.search") + self.assertEqual( + epoch_payload["history_messages"][1]["tool_calls"][0]["name"], + "tool.web.search", + ) self.assertEqual(epoch_payload["history_messages"][2]["tool_name"], "tool.web.search") self.assertEqual(epoch_payload["history_messages"][2]["tool_call_id"], "call-real-1") self.assertIn("summary: search result", epoch_payload["history_messages"][2]["content"]) - def test_high_usage_turn_compacts_snapshot_after_current_transcript_is_appended(self) -> None: + def test_high_usage_turn_compacts_snapshot_after_current_transcript_is_appended( + self, + ) -> None: runtime = self._runtime() session = runtime.start() observed_events: list[dict[str, object]] = [] @@ -715,12 +737,17 @@ def run_reflect_agent(_runtime, job, *, explicit_features, persist_result): self.assertEqual(frozen_epoch.compaction_count, 1) self.assertIn("Reference summary:", frozen_epoch.frozen_prefix) self.assertIn("oversized completed request", frozen_epoch.compacted_history_summary) - self.assertIn("oversized completed request", captured_compress_metadata["compressed_messages"]) + self.assertIn( + "oversized completed request", + captured_compress_metadata["compressed_messages"], + ) history = tuple(message.content for message in frozen_epoch.history_messages) self.assertIn("completed answer", history) self.assertNotIn(huge_prompt, history) - def test_frozen_session_context_epoch_tracks_latest_skill_disclosure_reason(self) -> None: + def test_frozen_session_context_epoch_tracks_latest_skill_disclosure_reason( + self, + ) -> None: runtime = self._runtime() session = runtime.start() profile = runtime._load_profile(session.personal_model_id) @@ -774,7 +801,9 @@ def test_frozen_session_context_epoch_tracks_latest_skill_disclosure_reason(self frozen_epoch.latest_skill_disclosures[0].reason, ) - def test_snapshot_state_focus_restore_rejects_legacy_skill_candidate_scores(self) -> None: + def test_snapshot_state_focus_restore_rejects_legacy_skill_candidate_scores( + self, + ) -> None: snapshot = { "state_focus": { "state_focus": "execution", @@ -793,7 +822,9 @@ def test_snapshot_state_focus_restore_rejects_legacy_skill_candidate_scores(self with self.assertRaises(ValueError): restore_snapshot_state_focus(snapshot) - def test_durable_recall_capability_prefers_work_item_aware_continuity_retrieval(self) -> None: + def test_durable_recall_capability_prefers_work_item_aware_continuity_retrieval( + self, + ) -> None: runtime = self._runtime() session = runtime.start() runtime.repository.upsert_loop( @@ -885,7 +916,9 @@ def test_inspect_continuity_surfaces_reengagement_guidance(self) -> None: self.assertNotIn("Publish the release artifacts.", continuity.reengagement_prompt) self.assertIn("initiative=gentle", continuity.continuity_summary) - def test_planning_recall_evidence_recovery_falls_back_to_episode_scoped_steps(self) -> None: + def test_planning_recall_evidence_recovery_falls_back_to_episode_scoped_steps( + self, + ) -> None: runtime = self._runtime() session = runtime.start() runtime.repository.upsert_loop( @@ -940,7 +973,10 @@ def test_planning_recall_evidence_recovery_falls_back_to_episode_scoped_steps(se with mock.patch.object(runtime.recall_runtime, "retrieve_evidence", return_value=empty_retrieval): recovery = runtime._planning_recall_evidence_recovery(session) - self.assertEqual(tuple(evidence.evidence_id for evidence in recovery.recall_items), ("step:evidence-fallback",)) + self.assertEqual( + tuple(evidence.evidence_id for evidence in recovery.recall_items), + ("step:evidence-fallback",), + ) self.assertEqual(recovery.scope_episode_ids, (session.episode_id,)) def test_prepare_session_surface_kicks_off_embedding_steadyup(self) -> None: @@ -953,7 +989,9 @@ def test_prepare_session_surface_kicks_off_embedding_steadyup(self) -> None: steady_async.assert_called_once_with() - def test_skill_catalog_does_not_kick_off_embedding_steadyup_for_passive_ui_reads(self) -> None: + def test_skill_catalog_does_not_kick_off_embedding_steadyup_for_passive_ui_reads( + self, + ) -> None: runtime = self._runtime() session = runtime.start() embedding_service = runtime.recall_runtime.retriever.evidence_retriever.embedding_service @@ -963,7 +1001,9 @@ def test_skill_catalog_does_not_kick_off_embedding_steadyup_for_passive_ui_reads steady_async.assert_not_called() - def test_cli_context_capability_surfaces_enabled_tools_and_scoped_skills(self) -> None: + def test_cli_context_capability_surfaces_enabled_tools_and_scoped_skills( + self, + ) -> None: runtime = self._runtime() session = runtime.start() now = datetime.now(timezone.utc) @@ -1069,10 +1109,18 @@ def test_cli_context_injects_default_workspace_path(self) -> None: self.assertIn("runtime-paths:", bundle.rendered_prompt) self.assertIn("### Runtime paths", bundle.prompt_envelope.system_prompt()) self.assertNotIn("runtime-paths:", bundle.prompt_envelope.user_prelude()) - self.assertIn(f"elephant_workspace={runtime.paths.elephant_file_path('miles').resolve()}", bundle.rendered_prompt) - self.assertIn(f"elephant_workspace={runtime.paths.elephant_file_path('miles').resolve()}", bundle.prompt_envelope.system_prompt()) + self.assertIn( + f"elephant_workspace={runtime.paths.elephant_file_path('miles').resolve()}", + bundle.rendered_prompt, + ) + self.assertIn( + f"elephant_workspace={runtime.paths.elephant_file_path('miles').resolve()}", + bundle.prompt_envelope.system_prompt(), + ) - def test_cli_context_only_lists_launch_directory_rule_files_for_on_demand_reading(self) -> None: + def test_cli_context_only_lists_launch_directory_rule_files_for_on_demand_reading( + self, + ) -> None: runtime = self._runtime() session = runtime.create_elephant(elephant_id="miles") @@ -1098,20 +1146,46 @@ def test_cli_context_only_lists_launch_directory_rule_files_for_on_demand_readin bundle = capability.assemble(session, (), ()) self.assertNotIn("### Launch Directory Context", bundle.prompt_envelope.frozen_prefix) - self.assertNotIn(f"Current absolute path: `{startup_dir.resolve()}`", bundle.prompt_envelope.frozen_prefix) - self.assertNotIn("Launch-directory rule files are available for on-demand reading:", bundle.prompt_envelope.frozen_prefix) + self.assertNotIn( + f"Current absolute path: `{startup_dir.resolve()}`", + bundle.prompt_envelope.frozen_prefix, + ) + self.assertNotIn( + "Launch-directory rule files are available for on-demand reading:", + bundle.prompt_envelope.frozen_prefix, + ) self.assertNotIn(f"- `{startup_dir / 'AGENTS.md'}`", bundle.prompt_envelope.frozen_prefix) self.assertNotIn(".elephant.md", bundle.prompt_envelope.frozen_prefix) - self.assertNotIn("Loaded launch-directory project context files:", bundle.prompt_envelope.frozen_prefix) - self.assertNotIn("Always treat the current repo as the primary analysis target.", bundle.prompt_envelope.frozen_prefix) - self.assertNotIn("Use launch-directory docs before generic fallbacks.", bundle.prompt_envelope.frozen_prefix) + self.assertNotIn( + "Loaded launch-directory project context files:", + bundle.prompt_envelope.frozen_prefix, + ) + self.assertNotIn( + "Always treat the current repo as the primary analysis target.", + bundle.prompt_envelope.frozen_prefix, + ) + self.assertNotIn( + "Use launch-directory docs before generic fallbacks.", + bundle.prompt_envelope.frozen_prefix, + ) self.assertIn(f"startup_cwd={startup_dir.resolve()}", bundle.rendered_prompt) - self.assertNotIn(f"startup_cwd={startup_dir.resolve()}", bundle.prompt_envelope.system_prompt()) + self.assertNotIn( + f"startup_cwd={startup_dir.resolve()}", + bundle.prompt_envelope.system_prompt(), + ) self.assertIn("startup_cwd=", bundle.prompt_envelope.system_prompt()) - self.assertNotIn(f"startup_cwd={startup_dir.resolve()}", bundle.prompt_envelope.user_prelude()) - self.assertIn(f"elephant_workspace={runtime.paths.elephant_file_path('miles').resolve()}", bundle.rendered_prompt) + self.assertNotIn( + f"startup_cwd={startup_dir.resolve()}", + bundle.prompt_envelope.user_prelude(), + ) + self.assertIn( + f"elephant_workspace={runtime.paths.elephant_file_path('miles').resolve()}", + bundle.rendered_prompt, + ) - def test_installing_skill_package_does_not_eagerly_expand_generation_context(self) -> None: + def test_installing_skill_package_does_not_eagerly_expand_generation_context( + self, + ) -> None: runtime = self._runtime() session = runtime.create_elephant(elephant_id="atlas") skill_dir = Path(runtime.paths.state_dir) / "test-skill" @@ -1137,7 +1211,10 @@ def test_installing_skill_package_does_not_eagerly_expand_generation_context(sel runtime.install_skill_source("custom-1:test-skill", session_id=session.episode_id) installed_entry = runtime.inspect_skill("test-skill", session_id=session.episode_id) - self.assertEqual(Path(installed_entry.entry_path), runtime.paths.installed_skills_dir / "custom-1" / "test-skill" / "SKILL.md") + self.assertEqual( + Path(installed_entry.entry_path), + runtime.paths.installed_skills_dir / "custom-1" / "test-skill" / "SKILL.md", + ) self.assertTrue(Path(installed_entry.entry_path).exists()) self.assertEqual(installed_entry.metadata.get("source_reference"), "custom-1:test-skill") self.assertEqual(installed_entry.metadata.get("install_action"), "install") @@ -1157,7 +1234,9 @@ def test_installing_skill_package_does_not_eagerly_expand_generation_context(sel self.assertNotIn("Search Skill", bundle.prompt_envelope.frozen_prefix) self.assertNotIn("Always search before editing", bundle.prompt_envelope.frozen_prefix) - def test_enabled_shelf_skill_enters_prompt_index_without_runtime_install(self) -> None: + def test_enabled_shelf_skill_enters_prompt_index_without_runtime_install( + self, + ) -> None: runtime = self._runtime() session = runtime.create_elephant(elephant_id="atlas") skill_dir = runtime.paths.installed_skills_dir / "manual" / "shelf-skill" @@ -1188,11 +1267,17 @@ def test_enabled_shelf_skill_enters_prompt_index_without_runtime_install(self) - install_root=runtime.paths.home_dir, ) - with mock.patch("apps.cli.runtime_cognition.build_launch_directory_context", return_value=(), create=True): + with mock.patch( + "apps.cli.runtime_cognition.build_launch_directory_context", + return_value=(), + create=True, + ): bundle = capability.assemble(session, (), ()) self.assertNotIn("Shelf Skill", bundle.rendered_prompt) - self.assertFalse(any(skill.skill_id == "shelf-skill" for skill in runtime.skill_catalog(session_id=session.episode_id))) + self.assertFalse( + any(skill.skill_id == "shelf-skill" for skill in runtime.skill_catalog(session_id=session.episode_id)) + ) loaded = runtime._load_profile(session.personal_model_id) manifest = dict(loaded.manifest) @@ -1219,7 +1304,9 @@ def test_enabled_shelf_skill_enters_prompt_index_without_runtime_install(self) - self.assertIn("enabled: False", viewed.summary) self.assertIn("installed: True", viewed.summary) - def test_explain_next_step_persists_assistant_outcome_as_decision_memory(self) -> None: + def test_explain_next_step_persists_assistant_outcome_as_decision_memory( + self, + ) -> None: runtime = self._runtime() session = runtime.start() @@ -1234,14 +1321,19 @@ def test_explain_next_step_persists_assistant_outcome_as_decision_memory(self) - # context stays visible without mixing in per-turn State summaries. self.assertNotIn("### Where things stand", outcome.context.rendered_prompt) self.assertNotIn("### Carrying context forward", outcome.context.rendered_prompt) - self.assertNotIn("recovered-evidence-summary: no durable recall_items", outcome.context.rendered_prompt) + self.assertNotIn( + "recovered-evidence-summary: no durable recall_items", + outcome.context.rendered_prompt, + ) self.assertFalse(any(evidence.kind == "decision" for evidence in recall_items)) steps = runtime.repository.list_steps() self.assertTrue(any(step.episode_id == session.episode_id for step in steps)) self.assertTrue(any(outcome.execution.summary in step.summary for step in steps)) self.assertEqual(runtime.inspect_experiences(session_id=session.episode_id), ()) - def test_explain_next_step_updates_personal_model_growth_from_level_zero_to_level_one(self) -> None: + def test_explain_next_step_updates_personal_model_growth_from_level_zero_to_level_one( + self, + ) -> None: runtime = self._runtime() session = runtime.start() @@ -1284,7 +1376,9 @@ def test_generate_opening_reply_returns_none_without_active_provider(self) -> No self.assertIsNone(outcome) - def test_generate_opening_reply_uses_internal_turn_without_growth_side_effects(self) -> None: + def test_generate_opening_reply_uses_internal_turn_without_growth_side_effects( + self, + ) -> None: runtime = self._runtime() session = runtime.start() @@ -1304,7 +1398,10 @@ def test_generate_opening_reply_uses_internal_turn_without_growth_side_effects(s self.assertFalse(kwargs["record_outcome_event"]) self.assertFalse(kwargs["capture_experience"]) self.assertFalse(kwargs["apply_growth"]) - self.assertEqual(kwargs["event_payload"]["summary"], "startup opening (Opened elephant atlas)") + self.assertEqual( + kwargs["event_payload"]["summary"], + "startup opening (Opened elephant atlas)", + ) self.assertEqual(kwargs["event_payload"]["allow_embeddings"], "false") def test_generate_opening_reply_keeps_wake_episode_open(self) -> None: @@ -1341,7 +1438,9 @@ def generate_response(*, profile, session, context, prompt, model_role="strong") self.assertNotEqual(stored.status, "closed") self.assertEqual(runtime.repository.list_learning_jobs(episode_id=session.episode_id), ()) - def test_cli_turn_continues_from_next_episode_without_reopening_closed_parent(self) -> None: + def test_cli_turn_continues_from_next_episode_without_reopening_closed_parent( + self, + ) -> None: runtime = self._runtime() session = runtime.start() episode = runtime.repository.load_episode(session.episode_id) @@ -1358,7 +1457,9 @@ def test_cli_turn_continues_from_next_episode_without_reopening_closed_parent(se ) transition = runtime.open_next_episode(session.episode_id, reason="wake_boundary") - outcome = runtime.explain_next_step(session_id=transition.episode.episode_id, prompt="continue this wake thread") + outcome = runtime.explain_next_step( + session_id=transition.episode.episode_id, prompt="continue this wake thread" + ) stored_parent = runtime.repository.load_episode(session.episode_id) self.assertIsNotNone(stored_parent) @@ -1371,7 +1472,10 @@ def test_cli_turn_continues_from_next_episode_without_reopening_closed_parent(se self.assertEqual(stored_child.parent_episode_id, session.episode_id) self.assertEqual(stored_child.metadata.get("opening_resume_snapshot"), "final parent summary") self.assertEqual(runtime.repository.list_learning_jobs(episode_id=session.episode_id), ()) - self.assertEqual(runtime.repository.list_learning_jobs(episode_id=transition.episode.episode_id), ()) + self.assertEqual( + runtime.repository.list_learning_jobs(episode_id=transition.episode.episode_id), + (), + ) def test_state_focus_runtime_status_surfaces_loaded_runtime_state(self) -> None: runtime = self._runtime() @@ -1391,9 +1495,15 @@ def test_state_focus_runtime_status_surfaces_loaded_runtime_state(self) -> None: self.assertEqual(status["runtime_state"], "loaded") self.assertTrue(status["embedding_ready"]) - def test_shared_elephant_authored_skill_shelf_supports_cross_profile_reuse(self) -> None: + def test_shared_elephant_authored_skill_shelf_supports_cross_profile_reuse( + self, + ) -> None: with tempfile.TemporaryDirectory() as authored_dir: - with mock.patch.dict("os.environ", {"ELEPHANT_AUTHORED_SKILLS_DIR": authored_dir}, clear=False): + with mock.patch.dict( + "os.environ", + {"ELEPHANT_AUTHORED_SKILLS_DIR": authored_dir}, + clear=False, + ): runtime_a = self._runtime() session_a = runtime_a.create_elephant(elephant_id="atlas") runtime_a.create_experience_skill( @@ -1419,7 +1529,11 @@ def test_shared_elephant_authored_skill_shelf_supports_cross_profile_reuse(self) def test_create_experience_skill_surfaces_in_skill_hub_listing(self) -> None: with tempfile.TemporaryDirectory() as authored_dir: - with mock.patch.dict("os.environ", {"ELEPHANT_AUTHORED_SKILLS_DIR": authored_dir}, clear=False): + with mock.patch.dict( + "os.environ", + {"ELEPHANT_AUTHORED_SKILLS_DIR": authored_dir}, + clear=False, + ): runtime = self._runtime() session = runtime.create_elephant(elephant_id="atlas") runtime.create_experience_skill( @@ -1462,7 +1576,9 @@ def test_search_skill_sources_queries_external_sources(self) -> None: self.assertEqual(searched[0].reference, "github:openai/skills/bounded-retrieval") self.assertEqual(searched[0].trust_level, "trusted") - def test_inspect_skill_source_can_inspect_remote_search_reference_without_installing(self) -> None: + def test_inspect_skill_source_can_inspect_remote_search_reference_without_installing( + self, + ) -> None: runtime = self._runtime() session = runtime.create_elephant(elephant_id="atlas") remote_dir = Path(runtime.paths.state_dir) / "remote-skill" @@ -1508,12 +1624,20 @@ def test_inspect_skill_source_can_inspect_remote_search_reference_without_instal self.assertEqual(inspected.display_name, "Remote Notes") self.assertEqual(inspected.metadata.get("hub_reference"), "github:openai/skills/remote-notes") - self.assertEqual(inspected.metadata.get("source_reference"), "github:openai/skills/remote-notes") - self.assertEqual(inspected.metadata.get("install_reference"), "github:openai/skills/remote-notes") + self.assertEqual( + inspected.metadata.get("source_reference"), + "github:openai/skills/remote-notes", + ) + self.assertEqual( + inspected.metadata.get("install_reference"), + "github:openai/skills/remote-notes", + ) self.assertEqual(inspected.metadata.get("trust_level"), "trusted") self.assertIn("AppleScript", inspected.instruction_text) - def test_inspect_skill_can_read_builtin_skill_package_without_installing(self) -> None: + def test_inspect_skill_can_read_builtin_skill_package_without_installing( + self, + ) -> None: runtime = self._runtime() session = runtime.create_elephant(elephant_id="atlas") @@ -1529,7 +1653,9 @@ def test_inspect_skill_can_read_builtin_skill_package_without_installing(self) - self.assertIn("memo notes --help", inspected.instruction_text) self.assertIn("open -a Notes", inspected.instruction_text) - def test_operator_profile_surface_can_inspect_and_update_profile_surface(self) -> None: + def test_operator_profile_surface_can_inspect_and_update_profile_surface( + self, + ) -> None: runtime = self._runtime() session = runtime.create_elephant(elephant_id="atlas") @@ -1555,8 +1681,14 @@ def test_operator_profile_surface_can_inspect_and_update_profile_surface(self) - self.assertEqual(updated_profile.identity.personality_preset, "operator") self.assertEqual(updated_profile.identity.initiative, "proactive") self.assertEqual(updated_profile.user.preferred_name, "xunzhuo") - self.assertIn("Prefers direct updates and wants long-horizon context preserved.", updated_profile.user.durable_notes) - self.assertIn("Keep responses concise and grounded.", updated_profile.relationship.continuity_notes) + self.assertIn( + "Prefers direct updates and wants long-horizon context preserved.", + updated_profile.user.durable_notes, + ) + self.assertIn( + "Keep responses concise and grounded.", + updated_profile.relationship.continuity_notes, + ) user = runtime.inspect_user(session_id=session.episode_id) self.assertIn("current_work:Software engineer", user.biography_fragments) self.assertEqual( @@ -1586,7 +1718,9 @@ def test_operator_profile_surface_accepts_scoped_user_fields(self) -> None: self.assertIn("current_work:Software engineer", user.biography_fragments) self.assertEqual(user.durable_notes, ("Prefers direct progress updates.",)) - def test_operator_profile_surface_persists_structured_biography_fields_in_profile_summary(self) -> None: + def test_operator_profile_surface_persists_structured_biography_fields_in_profile_summary( + self, + ) -> None: runtime = self._runtime() session = runtime.create_elephant(elephant_id="atlas") @@ -1648,7 +1782,9 @@ def test_operator_profile_surface_can_update_identity_posture(self) -> None: self.assertEqual(identity.personality_preset, "operator") self.assertEqual(identity.initiative, "proactive") - def test_personal_model_update_tool_runtime_uses_refreshed_canonical_state_surface(self) -> None: + def test_personal_model_update_tool_runtime_uses_refreshed_canonical_state_surface( + self, + ) -> None: runtime = self._runtime() session = runtime.create_elephant(elephant_id="atlas") @@ -1673,7 +1809,9 @@ def test_personal_model_update_tool_runtime_uses_refreshed_canonical_state_surfa ) self.assertTrue(any("Software engineer" in fact.text for fact in facts)) - def test_profile_persistence_syncs_canonical_owner_records_and_ledgers(self) -> None: + def test_profile_persistence_syncs_canonical_owner_records_and_ledgers( + self, + ) -> None: runtime = self._runtime( profile_payload={ "profile_id": "profile-companion", @@ -1704,7 +1842,9 @@ def test_profile_persistence_syncs_canonical_owner_records_and_ledgers(self) -> self.assertEqual(elephant_identity.elephant_identity_text, "Stay calm, durable, and exact.") self.assertIsNotNone(runtime.repository.load_elephant_identity_for_profile(profile_id)) facts = runtime.repository.list_personal_model_facts(personal_model_id=profile_id, status="active") - self.assertFalse(any(fact.metadata.get("canonical_component") in {"user-profile", "relationship"} for fact in facts)) + self.assertFalse( + any(fact.metadata.get("canonical_component") in {"user-profile", "relationship"} for fact in facts) + ) runtime.update_user_state( profile_id=profile_id, @@ -1785,9 +1925,14 @@ def test_cli_context_capability_surfaces_active_loop_checkpoint(self) -> None: self.assertIn("active-loop-checkpoint:", bundle.rendered_prompt) self.assertIn("Audit the long-horizon loop design", bundle.rendered_prompt) - self.assertIn("Collected Elephant Agent and OpenClaw reference points", bundle.rendered_prompt) + self.assertIn( + "Collected Elephant Agent and OpenClaw reference points", + bundle.rendered_prompt, + ) - def test_delete_elephant_clears_sessions_and_memories_for_that_elephant(self) -> None: + def test_delete_elephant_clears_sessions_and_memories_for_that_elephant( + self, + ) -> None: runtime = self._runtime() session = runtime.create_elephant(elephant_id="atlas") @@ -1795,12 +1940,17 @@ def test_delete_elephant_clears_sessions_and_memories_for_that_elephant(self) -> self.assertEqual(deleted_sessions, 1) self.assertIsNone(runtime.repository.load_episode_state(session.episode_id)) - self.assertEqual(runtime.recall_runtime.store.list(session.episode_id, include_inactive=True), ()) + self.assertEqual( + runtime.recall_runtime.store.list(session.episode_id, include_inactive=True), + (), + ) self.assertIsNotNone(runtime.repository.load_personal_model(session.personal_model_id)) self.assertIsNone(runtime.repository.load_state("state:atlas")) self.assertEqual(runtime.list_herd(), ()) - def test_delete_all_elephants_clears_state_rows_and_preserves_personal_model(self) -> None: + def test_delete_all_elephants_clears_state_rows_and_preserves_personal_model( + self, + ) -> None: runtime = self._runtime() alpha = runtime.create_elephant(elephant_id="alpha") beta = runtime.create_elephant(elephant_id="beta") @@ -1823,7 +1973,9 @@ def test_delete_all_elephants_clears_state_rows_and_preserves_personal_model(sel self.assertEqual([tuple(row) for row in profile_rows], [("you",)]) - def test_create_elephant_reuses_personal_model_without_clearing_growth(self) -> None: + def test_create_elephant_reuses_personal_model_without_clearing_growth( + self, + ) -> None: runtime = self._runtime() original = runtime.create_elephant(elephant_id="atlas") runtime.repository.upsert_personal_model_growth( @@ -1856,7 +2008,9 @@ def test_create_elephant_reuses_personal_model_without_clearing_growth(self) -> assert elephant_state is not None self.assertEqual(elephant_state.elephant_name, "Atlas") - def test_elephants_get_isolated_elephant_identity_under_one_personal_model(self) -> None: + def test_elephants_get_isolated_elephant_identity_under_one_personal_model( + self, + ) -> None: runtime = self._runtime() alpha = runtime.create_elephant(elephant_id="alpha") beta = runtime.create_elephant(elephant_id="beta") @@ -1894,7 +2048,9 @@ def test_start_session_keeps_the_requested_profile_binding(self) -> None: self.assertEqual(continuity.profile.state.display_name, "you") self.assertFalse((runtime.paths.home_dir / "profiles" / "elephant%3Anova" / "profile.json").exists()) - def test_explain_next_step_does_not_mutate_profile_without_management_tools(self) -> None: + def test_explain_next_step_does_not_mutate_profile_without_management_tools( + self, + ) -> None: runtime = self._runtime( profile_payload={ "profile_id": "profile-companion", diff --git a/tests/unit/cli/test_runtime_extensions.py b/tests/unit/cli/test_runtime_extensions.py index 0dbcfdb..7a5e433 100644 --- a/tests/unit/cli/test_runtime_extensions.py +++ b/tests/unit/cli/test_runtime_extensions.py @@ -33,7 +33,9 @@ def test_load_extension_manifest_resolves_relative_paths_per_section(self) -> No self.assertFalse(hasattr(manifest, "mcp_overrides")) self.assertFalse(hasattr(manifest, "mcp_definitions")) - def test_serialize_manifest_path_keeps_relative_paths_inside_profile_dir(self) -> None: + def test_serialize_manifest_path_keeps_relative_paths_inside_profile_dir( + self, + ) -> None: profile_dir = Path("/tmp/elephant-profile") self.assertEqual( @@ -45,7 +47,9 @@ def test_serialize_manifest_path_keeps_relative_paths_inside_profile_dir(self) - "/opt/shared/tool.yaml", ) - def test_cli_tool_catalog_includes_global_custom_mcp_tools_after_refresh(self) -> None: + def test_cli_tool_catalog_includes_global_custom_mcp_tools_after_refresh( + self, + ) -> None: with tempfile.TemporaryDirectory() as tempdir: root = Path(tempdir) state_dir = root / "state" @@ -70,7 +74,11 @@ def test_cli_tool_catalog_includes_global_custom_mcp_tools_after_refresh(self) - "label": "Filesystem", "transport": "stdio", "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp/demo"], + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp/demo", + ], "tools": { "read_file": { "display_name": "Read File", diff --git a/tests/unit/cli/test_runtime_filesystem_layout.py b/tests/unit/cli/test_runtime_filesystem_layout.py index d6f52df..6db8990 100644 --- a/tests/unit/cli/test_runtime_filesystem_layout.py +++ b/tests/unit/cli/test_runtime_filesystem_layout.py @@ -38,7 +38,10 @@ def test_create_elephant_creates_elephant_root_and_tools_write_there(self) -> No self.assertEqual(result.outcome, "success") self.assertTrue((elephant_root / "notes" / "plan.txt").exists()) self.assertTrue((runtime.paths.builtin_skills_dir / ".manifest.json").exists()) - self.assertEqual(runtime.paths.cron_jobs_path.resolve(), (root / "cron" / "jobs.json").resolve()) + self.assertEqual( + runtime.paths.cron_jobs_path.resolve(), + (root / "cron" / "jobs.json").resolve(), + ) self.assertEqual(runtime.paths.pairing_dir.resolve(), (root / "pairing").resolve()) diff --git a/tests/unit/cli/test_runtime_learning.py b/tests/unit/cli/test_runtime_learning.py index 10ffa0c..92e4cfb 100644 --- a/tests/unit/cli/test_runtime_learning.py +++ b/tests/unit/cli/test_runtime_learning.py @@ -1,6 +1,5 @@ from __future__ import annotations -from datetime import datetime, timezone import json from pathlib import Path from types import SimpleNamespace @@ -13,7 +12,9 @@ class CliRuntimeLearningTest(unittest.TestCase): - def test_schedule_learning_for_session_enqueues_job_and_surfaces_status(self) -> None: + def test_schedule_learning_for_session_enqueues_job_and_surfaces_status( + self, + ) -> None: with tempfile.TemporaryDirectory() as tempdir: root = Path(tempdir) state_dir = root / "state" @@ -38,7 +39,10 @@ def test_schedule_learning_for_session_enqueues_job_and_surfaces_status(self) -> mode="companion", ) - with mock.patch("apps.learning_worker_runtime.ensure_learning_worker_running", return_value=True) as ensure_worker: + with mock.patch( + "apps.learning_worker_runtime.ensure_learning_worker_running", + return_value=True, + ) as ensure_worker: job = runtime.schedule_learning_for_session( session_id=session.episode_id, trigger="clear", @@ -87,7 +91,10 @@ def test_learn_cli_run_list_and_kill(self) -> None: runtime = CliRuntime.create(state_dir=state_dir) session = runtime.create_elephant(elephant_id="atlas") - with mock.patch("apps.learning_worker_runtime.ensure_learning_worker_running", return_value=True) as ensure_worker: + with mock.patch( + "apps.learning_worker_runtime.ensure_learning_worker_running", + return_value=True, + ) as ensure_worker: run_exit = cli_main_impl._run_learn( runtime, SimpleNamespace(learn_command="run", elephant_id="atlas", limit=12, wait=False), @@ -96,7 +103,14 @@ def test_learn_cli_run_list_and_kill(self) -> None: runtime, SimpleNamespace(learn_command="list", elephant_id="atlas", limit=12), ) - with mock.patch("apps.learning_worker_runtime.stop_learning_worker", return_value={"status": "stopped", "stopped_pid": None, "signal_sent": False}) as stop_worker: + with mock.patch( + "apps.learning_worker_runtime.stop_learning_worker", + return_value={ + "status": "stopped", + "stopped_pid": None, + "signal_sent": False, + }, + ) as stop_worker: kill_exit = cli_main_impl._run_learn( runtime, SimpleNamespace(learn_command="kill", elephant_id=None, limit=12), @@ -106,12 +120,17 @@ def test_learn_cli_run_list_and_kill(self) -> None: self.assertEqual(list_exit, 0) self.assertEqual(kill_exit, 0) ensure_worker.assert_called_once() - stop_worker.assert_called_once_with(state_dir=runtime.paths.state_dir, reason="operator requested learn kill") + stop_worker.assert_called_once_with( + state_dir=runtime.paths.state_dir, + reason="operator requested learn kill", + ) jobs = runtime.repository.list_learning_jobs(episode_id=session.episode_id) self.assertEqual(len(jobs), 1) self.assertEqual(jobs[0].trigger, "manual") - def test_learn_run_wait_uses_subprocess_once_without_starting_background_worker(self) -> None: + def test_learn_run_wait_uses_subprocess_once_without_starting_background_worker( + self, + ) -> None: with tempfile.TemporaryDirectory() as tempdir: state_dir = Path(tempdir) / "state" state_dir.mkdir(parents=True, exist_ok=True) @@ -119,18 +138,29 @@ def test_learn_run_wait_uses_subprocess_once_without_starting_background_worker( runtime.create_elephant(elephant_id="atlas") completed = SimpleNamespace(returncode=0) - with mock.patch("apps.learning_worker_runtime.ensure_learning_worker_running", return_value=True) as ensure_worker: + with mock.patch( + "apps.learning_worker_runtime.ensure_learning_worker_running", + return_value=True, + ) as ensure_worker: with mock.patch.object(cli_main_impl.subprocess, "run", return_value=completed) as run_worker: exit_code = cli_main_impl._run_learn( runtime, - SimpleNamespace(learn_command="run", elephant_id="atlas", limit=12, wait=True), + SimpleNamespace( + learn_command="run", + elephant_id="atlas", + limit=12, + wait=True, + ), ) self.assertEqual(exit_code, 0) ensure_worker.assert_not_called() run_worker.assert_called_once() command = run_worker.call_args.args[0] - self.assertEqual(command[:3], (cli_main_impl.sys.executable, "-m", "apps.learning_worker_command")) + self.assertEqual( + command[:3], + (cli_main_impl.sys.executable, "-m", "apps.learning_worker_command"), + ) self.assertIn("--once", command) self.assertIn(str(runtime.paths.state_dir), command) @@ -146,7 +176,12 @@ def test_learn_run_marks_job_failed_when_subprocess_crashes(self) -> None: with mock.patch("apps.learning_worker_runtime.mark_learning_job_terminal_failure") as mark_failed: exit_code = cli_main_impl._run_learn( runtime, - SimpleNamespace(learn_command="run", elephant_id="atlas", limit=12, wait=True), + SimpleNamespace( + learn_command="run", + elephant_id="atlas", + limit=12, + wait=True, + ), ) self.assertEqual(exit_code, 139) @@ -172,7 +207,9 @@ def test_normal_wake_turn_does_not_start_old_queued_learning_jobs(self) -> None: ensure_worker.assert_not_called() - def test_learning_sub_agent_child_episode_history_is_preserved_after_job_completion(self) -> None: + def test_learning_sub_agent_child_episode_history_is_preserved_after_job_completion( + self, + ) -> None: from apps.learning_worker_runtime import run_learning_job with tempfile.TemporaryDirectory() as tempdir: @@ -185,7 +222,9 @@ def test_learning_sub_agent_child_episode_history_is_preserved_after_job_complet reason="learning_child", summary="learning child Episode opened", ).episode - job = runtime.schedule_learning_for_session(session_id=session.episode_id, trigger="manual", start_worker=False) + job = runtime.schedule_learning_for_session( + session_id=session.episode_id, trigger="manual", start_worker=False + ) agent_result = SimpleNamespace( status="completed", summary="done", @@ -193,7 +232,10 @@ def test_learning_sub_agent_child_episode_history_is_preserved_after_job_complet child_episode_id=child.episode_id, ) - with mock.patch("apps.learning_agents.run_background_learning_agent", return_value=agent_result): + with mock.patch( + "apps.learning_agents.run_background_learning_agent", + return_value=agent_result, + ): run_learning_job(runtime, job, worker_id="worker:test") self.assertIsNotNone(runtime.repository.load_episode(child.episode_id)) @@ -253,7 +295,10 @@ def complete_one(worker_runtime, job, *, worker_id: str) -> None: progress_detail="test completed one job", ) - with mock.patch("apps.learning_worker_runtime.run_learning_job", side_effect=complete_one): + with mock.patch( + "apps.learning_worker_runtime.run_learning_job", + side_effect=complete_one, + ): exit_code = run_learning_worker(state_dir=state_dir, once=True) self.assertEqual(exit_code, 0) diff --git a/tests/unit/cli/test_runtime_provider.py b/tests/unit/cli/test_runtime_provider.py index c761110..9f0646e 100644 --- a/tests/unit/cli/test_runtime_provider.py +++ b/tests/unit/cli/test_runtime_provider.py @@ -27,7 +27,9 @@ def _runtime(self, *, model_provider) -> CliRuntime: security_policy=mock.Mock(), ) - def test_discover_provider_models_uses_discovered_credentials_when_profile_is_inactive(self) -> None: + def test_discover_provider_models_uses_discovered_credentials_when_profile_is_inactive( + self, + ) -> None: model_provider = mock.Mock() model_provider.active_profile.return_value = None model_provider.resolve_discovered_credentials.return_value = {"api_key": "ghu-discovered"} @@ -47,7 +49,9 @@ def test_discover_provider_models_uses_discovered_credentials_when_profile_is_in api_key="ghu-discovered", ) - def test_provider_doctor_rejects_placeholder_model_before_runtime_probe(self) -> None: + def test_provider_doctor_rejects_placeholder_model_before_runtime_probe( + self, + ) -> None: model_provider = mock.Mock() model_provider.describe.return_value = { "provider_id": "openai-compatible", @@ -57,14 +61,19 @@ def test_provider_doctor_rejects_placeholder_model_before_runtime_probe(self) -> "model_id": "model-id", "base_url": "https://api.example.test/v1", } - model_provider.runtime_resolver.build_setup_guide.return_value = mock.Mock(as_mapping=mock.Mock(return_value={})) + model_provider.runtime_resolver.build_setup_guide.return_value = mock.Mock( + as_mapping=mock.Mock(return_value={}) + ) runtime = self._runtime(model_provider=model_provider) - with mock.patch.object(CliRuntime, "discover_provider_models", autospec=True, return_value=()), mock.patch.object( - CliRuntime, - "provider_test", - autospec=True, - ) as provider_test: + with ( + mock.patch.object(CliRuntime, "discover_provider_models", autospec=True, return_value=()), + mock.patch.object( + CliRuntime, + "provider_test", + autospec=True, + ) as provider_test, + ): report = runtime.provider_doctor() provider_test.assert_not_called() @@ -84,14 +93,19 @@ def test_provider_doctor_surfaces_embedding_bootstrap_state(self) -> None: "embedding_bootstrap_status": "pending", "embedding_bootstrap_summary": "local semantic-index bootstrap is preparing minimal sentence-transformers dependencies in the background.", } - model_provider.runtime_resolver.build_setup_guide.return_value = mock.Mock(as_mapping=mock.Mock(return_value={})) + model_provider.runtime_resolver.build_setup_guide.return_value = mock.Mock( + as_mapping=mock.Mock(return_value={}) + ) runtime = self._runtime(model_provider=model_provider) - with mock.patch.object(CliRuntime, "discover_provider_models", autospec=True, return_value=()), mock.patch.object( - CliRuntime, - "provider_test", - autospec=True, - return_value=mock.Mock(summary="Doctor check"), + with ( + mock.patch.object(CliRuntime, "discover_provider_models", autospec=True, return_value=()), + mock.patch.object( + CliRuntime, + "provider_test", + autospec=True, + return_value=mock.Mock(summary="Doctor check"), + ), ): report = runtime.provider_doctor() @@ -111,14 +125,19 @@ def test_provider_doctor_shallow_skips_live_catalog_and_probe(self) -> None: "embedding_bootstrap_status": "ready", "embedding_bootstrap_summary": "local bootstrap is ready", } - model_provider.runtime_resolver.build_setup_guide.return_value = mock.Mock(as_mapping=mock.Mock(return_value={})) + model_provider.runtime_resolver.build_setup_guide.return_value = mock.Mock( + as_mapping=mock.Mock(return_value={}) + ) runtime = self._runtime(model_provider=model_provider) - with mock.patch.object(CliRuntime, "discover_provider_models", autospec=True) as discover, mock.patch.object( - CliRuntime, - "provider_test", - autospec=True, - ) as provider_test: + with ( + mock.patch.object(CliRuntime, "discover_provider_models", autospec=True) as discover, + mock.patch.object( + CliRuntime, + "provider_test", + autospec=True, + ) as provider_test, + ): report = runtime.provider_doctor(deep=False) discover.assert_not_called() @@ -191,7 +210,11 @@ def test_set_local_embedding_provider_marks_override_inactive(self) -> None: provider_id="openai-compatible-embed", secret_name="api_token", secret_key="api_key", - metadata={"storage": "local-vault", "scope": "embedding-provider", "env_var": "OPENAI_API_KEY"}, + metadata={ + "storage": "local-vault", + "scope": "embedding-provider", + "env_var": "OPENAI_API_KEY", + }, ), ), metadata={"embedding_active": "true", "dimensions": "1536"}, @@ -207,7 +230,10 @@ def test_set_local_embedding_provider_marks_override_inactive(self) -> None: secret_references=active_profile.secret_references, metadata={"embedding_active": "false", "dimensions": "1536"}, ) - runtime.repository.load_auth_profile.side_effect = [active_profile, inactive_profile] + runtime.repository.load_auth_profile.side_effect = [ + active_profile, + inactive_profile, + ] summary = dict(runtime.set_local_embedding_provider()) diff --git a/tests/unit/cli/test_runtime_turns.py b/tests/unit/cli/test_runtime_turns.py index b749c44..7923181 100644 --- a/tests/unit/cli/test_runtime_turns.py +++ b/tests/unit/cli/test_runtime_turns.py @@ -10,7 +10,9 @@ class RuntimeTurnsReasoningPayloadTests(unittest.TestCase): - def test_payload_with_turn_reasoning_stays_on_event_payload_not_structured_evidence_copy(self) -> None: + def test_payload_with_turn_reasoning_stays_on_event_payload_not_structured_evidence_copy( + self, + ) -> None: outcome = SimpleNamespace( state=SimpleNamespace(next_step="", summary=""), execution=SimpleNamespace( @@ -25,9 +27,18 @@ def test_payload_with_turn_reasoning_stays_on_event_payload_not_structured_evide decision_summary="Draft the answer.", ) - self.assertEqual(payload["reasoning_trace"], "Inspect tool evidence before drafting the answer.") - self.assertEqual(payload["raw_reasoning_trace"], "Inspect tool evidence before drafting the answer.") - self.assertEqual(payload["reasoning_summary"], "Inspect tool evidence before drafting the answer.") + self.assertEqual( + payload["reasoning_trace"], + "Inspect tool evidence before drafting the answer.", + ) + self.assertEqual( + payload["raw_reasoning_trace"], + "Inspect tool evidence before drafting the answer.", + ) + self.assertEqual( + payload["reasoning_summary"], + "Inspect tool evidence before drafting the answer.", + ) self.assertEqual(payload["reasoning_provenance"], "provider.raw_trace") observation = ReconciliationPipeline().observe_turn( @@ -51,10 +62,15 @@ def test_payload_with_turn_reasoning_stays_on_event_payload_not_structured_evide elephant_id="elephant-1", ) - self.assertEqual(observation.durable_events[0].payload["reasoning_trace"], "Inspect tool evidence before drafting the answer.") + self.assertEqual( + observation.durable_events[0].payload["reasoning_trace"], + "Inspect tool evidence before drafting the answer.", + ) self.assertIn("Step records", observation.summary) - def test_payload_with_turn_reasoning_falls_back_to_decision_summary_when_trace_missing(self) -> None: + def test_payload_with_turn_reasoning_falls_back_to_decision_summary_when_trace_missing( + self, + ) -> None: outcome = SimpleNamespace( state=SimpleNamespace(next_step="Call the provider health check.", summary=""), execution=SimpleNamespace( @@ -75,7 +91,9 @@ def test_payload_with_turn_reasoning_falls_back_to_decision_summary_when_trace_m class RuntimeTurnsCompactionTests(unittest.TestCase): - def test_reflect_compress_summary_allows_reflect_inside_sub_agent_runtime(self) -> None: + def test_reflect_compress_summary_allows_reflect_inside_sub_agent_runtime( + self, + ) -> None: runtime = SimpleNamespace( sub_agent_active=True, _load_session=mock.Mock( diff --git a/tests/unit/cli/test_shell.py b/tests/unit/cli/test_shell.py index d590d63..edbe329 100644 --- a/tests/unit/cli/test_shell.py +++ b/tests/unit/cli/test_shell.py @@ -21,14 +21,31 @@ build_composer_body, prompt_style_map, ) -from apps.cli.shell_progress import _VisibleToolEvent, latest_stream_text, reset_stream_text, stream_text_tracker, turn_tool_progress_lines +from apps.cli.shell_progress import ( + _VisibleToolEvent, + latest_stream_text, + reset_stream_text, + stream_text_tracker, + turn_tool_progress_lines, +) import apps.cli.shell_progress_runtime as shell_progress_runtime from apps.cli.shell_render import _render_tooltrace_body_line import apps.cli.shell_render as shell_render -from apps.cli.shell_banner import _learning_job_execution_summary, _skill_affinity_summary +from apps.cli.shell_banner import ( + _learning_job_execution_summary, + _skill_affinity_summary, +) import apps.cli.shell_progress_trace as shell_progress_trace -from apps.cli.shell_clarify import ShellClarifyState, render_clarify_fragments, route_clarify_answer -from apps.cli.shell_stack import FormattedTextControl as StackFormattedTextControl, ScrollablePane, Window as StackWindow +from apps.cli.shell_clarify import ( + ShellClarifyState, + render_clarify_fragments, + route_clarify_answer, +) +from apps.cli.shell_stack import ( + FormattedTextControl as StackFormattedTextControl, + ScrollablePane, + Window as StackWindow, +) from apps.cli.shell import ( BRAND_ACCENT, BRAND_DARK, @@ -39,7 +56,6 @@ Console, Document, ELEPHANT_STAGE_ROWS, - ELEPHANT_STAGE_ROWS, GROWTH_PROGRESS_EMPTY, GROWTH_PROGRESS_FILLED, GROWTH_PROGRESS_WIDTH, @@ -220,7 +236,9 @@ def test_shell_allows_opt_in_alternate_screen(self) -> None: self.assertTrue(shell._use_alternate_screen) - def test_refresh_shell_frame_does_not_clear_or_replay_same_frame_in_scrollback_mode(self) -> None: + def test_refresh_shell_frame_does_not_clear_or_replay_same_frame_in_scrollback_mode( + self, + ) -> None: shell = self._make_shell_without_identity_update() shell.console = _CaptureConsole(100) shell._last_shell_frame_token = shell._current_shell_frame_token() @@ -232,7 +250,9 @@ def test_refresh_shell_frame_does_not_clear_or_replay_same_frame_in_scrollback_m self.assertEqual(shell.console.printed, []) self.assertEqual(shell._rendered_entries, 2) - def test_refresh_shell_frame_clears_and_replays_in_alternate_screen_mode(self) -> None: + def test_refresh_shell_frame_clears_and_replays_in_alternate_screen_mode( + self, + ) -> None: shell = self._make_shell_without_identity_update() shell.console = _CaptureConsole(100) shell._use_alternate_screen = True @@ -245,7 +265,9 @@ def test_refresh_shell_frame_clears_and_replays_in_alternate_screen_mode(self) - self.assertGreaterEqual(len(shell.console.printed), 1) self.assertEqual(shell._rendered_entries, 0) - def test_prime_transcript_uses_elephant_state_name_for_assistant_title(self) -> None: + def test_prime_transcript_uses_elephant_state_name_for_assistant_title( + self, + ) -> None: shell = self._make_shell(opened="Opened elephant atlas") shell.runtime.update_identity_state( session_id=shell.session_id, @@ -344,9 +366,7 @@ def test_command_palette_stays_minimal_and_identity_focused(self) -> None: self.assertIn("resume", cron_commands) self.assertIn("remove", cron_commands) - removed_whoami_commands = { - item.text for item in completer.get_completions(Document("/whoami "), None) - } + removed_whoami_commands = {item.text for item in completer.get_completions(Document("/whoami "), None)} self.assertEqual(set(), removed_whoami_commands) gateway_commands = {item.text for item in completer.get_completions(Document("/gateway "), None)} @@ -384,7 +404,9 @@ def test_latest_learning_notice_ignores_regular_turn_experience(self) -> None: self.assertFalse(any(entry.title == "Learning" for entry in shell.transcript)) - def test_latest_learning_notice_surfaces_completed_learning_result_once(self) -> None: + def test_latest_learning_notice_surfaces_completed_learning_result_once( + self, + ) -> None: shell = self._make_shell() job = shell.runtime.schedule_learning_for_session( session_id=shell.session_id, @@ -434,7 +456,9 @@ def test_existing_learning_result_is_not_replayed_when_shell_opens(self) -> None self.assertFalse(any(entry.title == "Learning" for entry in shell.transcript)) - def test_conversational_surface_requests_list_tools_on_explicit_show_list_verbs(self) -> None: + def test_conversational_surface_requests_list_tools_on_explicit_show_list_verbs( + self, + ) -> None: shell = self._make_shell() handled_tools = shell._handle_conversational_surface_request("show tools") @@ -469,7 +493,9 @@ def test_conversational_questions_about_skills_no_longer_bypass_shell(self) -> N self.assertFalse(handled_skills) self.assertEqual(len(shell.transcript), original_len) - def test_skills_search_routes_through_skill_search_tool_and_records_tooltrace(self) -> None: + def test_skills_search_routes_through_skill_search_tool_and_records_tooltrace( + self, + ) -> None: shell = self._make_shell() with mock.patch.object( shell.runtime.skill_search_hub, @@ -496,7 +522,9 @@ def test_skills_search_routes_through_skill_search_tool_and_records_tooltrace(se self.assertEqual(shell.transcript[-1].title, "Skill search") self.assertIn("github:openai/skills/apple-notes", shell.transcript[-1].body) - def test_plain_turn_with_explicit_skill_name_no_longer_routes_skill_body(self) -> None: + def test_plain_turn_with_explicit_skill_name_no_longer_routes_skill_body( + self, + ) -> None: shell = self._make_shell() outcome = mock.Mock() with ( @@ -515,9 +543,15 @@ def test_plain_turn_with_explicit_skill_name_no_longer_routes_skill_body(self) - self.assertIsNone(run_turn.call_args.kwargs["event_payload"]) append_outcome.assert_called_once_with(outcome) - def test_dispatch_clears_pending_context_compaction_frame_before_next_turn(self) -> None: + def test_dispatch_clears_pending_context_compaction_frame_before_next_turn( + self, + ) -> None: shell = self._make_shell() - shell._pending_context_compaction_frame = {"prompt": "previous", "tick": 0, "kernel_stage_events": ()} + shell._pending_context_compaction_frame = { + "prompt": "previous", + "tick": 0, + "kernel_stage_events": (), + } shell._pending_context_compaction_frame_rendered = True with ( mock.patch.object(shell, "_handle_slash_command", return_value=False), @@ -641,7 +675,9 @@ def test_dispatch_reads_compacted_context_usage_from_outcome_stage(self) -> None self.assertEqual(shell._last_provider_prompt_tokens, 0) self.assertEqual(shell._last_prompt_tokens, 6_200) - def test_plain_turn_with_contextual_skill_phrase_no_longer_routes_skill_body(self) -> None: + def test_plain_turn_with_contextual_skill_phrase_no_longer_routes_skill_body( + self, + ) -> None: shell = self._make_shell() outcome = mock.Mock() with ( @@ -660,7 +696,9 @@ def test_plain_turn_with_contextual_skill_phrase_no_longer_routes_skill_body(sel self.assertIsNone(run_turn.call_args.kwargs["event_payload"]) append_outcome.assert_called_once_with(outcome) - def test_skill_slash_specs_include_full_local_skill_hub_not_first_page_only(self) -> None: + def test_skill_slash_specs_include_full_local_skill_hub_not_first_page_only( + self, + ) -> None: shell = self._make_shell() spec_ids = {spec.skill_id for spec in shell.skill_slash_specs()} @@ -714,11 +752,16 @@ def test_skills_install_routes_through_runtime_skill_catalog(self) -> None: refresh_specs.assert_called_once_with() self.assertEqual(shell.transcript[-1].title, "Skill installed") self.assertIn("detail: installed via GitHub (trusted)", shell.transcript[-1].body) - self.assertIn("source_reference: github:openai/skills/apple-notes", shell.transcript[-1].body) + self.assertIn( + "source_reference: github:openai/skills/apple-notes", + shell.transcript[-1].body, + ) self.assertIn("install_action: install", shell.transcript[-1].body) self.assertIn("install_requester: operator", shell.transcript[-1].body) - def test_growth_panel_keeps_removed_procedural_memory_out_of_learning_overview(self) -> None: + def test_growth_panel_keeps_removed_procedural_memory_out_of_learning_overview( + self, + ) -> None: shell = self._make_shell() session = shell.runtime.inspect_session(shell.session_id) @@ -729,7 +772,9 @@ def test_growth_panel_keeps_removed_procedural_memory_out_of_learning_overview(s self.assertFalse(any("Release State Recovery" in line for line in lines)) self.assertIn("latest · no captured grounded experience yet", lines) - def test_growth_panel_filters_noisy_failure_experiences_from_learning_overview(self) -> None: + def test_growth_panel_filters_noisy_failure_experiences_from_learning_overview( + self, + ) -> None: shell = self._make_shell() session = shell.runtime.inspect_session(shell.session_id) @@ -740,7 +785,9 @@ def test_growth_panel_filters_noisy_failure_experiences_from_learning_overview(s self.assertFalse(any("skill manager is having some trouble" in line for line in lines)) self.assertIn("latest · no captured grounded experience yet", lines) - def test_conversational_surface_request_reads_specific_web_page_without_hitting_model(self) -> None: + def test_conversational_surface_request_reads_specific_web_page_without_hitting_model( + self, + ) -> None: shell = self._make_shell() server = _WebPageStubServer().start() self.addCleanup(server.close) @@ -758,11 +805,14 @@ def test_conversational_surface_request_reads_specific_web_page_without_hitting_ def test_provider_configure_cancels_when_wizard_is_escaped(self) -> None: shell = self._make_shell() - with mock.patch("apps.cli.shell.run_provider_selection_wizard", return_value=WIZARD_BACK), mock.patch.object( - CliRuntime, - "set_default_provider", - autospec=True, - ) as set_default_provider: + with ( + mock.patch("apps.cli.shell.run_provider_selection_wizard", return_value=WIZARD_BACK), + mock.patch.object( + CliRuntime, + "set_default_provider", + autospec=True, + ) as set_default_provider, + ): shell._append_providers([]) set_default_provider.assert_not_called() @@ -782,17 +832,22 @@ def test_models_configure_cancels_when_wizard_is_escaped(self) -> None: api_key="sk-test", ) - with mock.patch("apps.cli.shell.run_provider_selection_wizard", return_value=WIZARD_BACK), mock.patch.object( - CliRuntime, - "set_default_provider", - autospec=True, - ) as set_default_provider: + with ( + mock.patch("apps.cli.shell.run_provider_selection_wizard", return_value=WIZARD_BACK), + mock.patch.object( + CliRuntime, + "set_default_provider", + autospec=True, + ) as set_default_provider, + ): shell._append_models([]) set_default_provider.assert_not_called() self.assertEqual(shell.transcript[-1].body, "Model setup cancelled.") - def test_work_surface_discloses_resolved_state_focus_scope_and_fallback(self) -> None: + def test_work_surface_discloses_resolved_state_focus_scope_and_fallback( + self, + ) -> None: shell = self._make_shell() session = shell.runtime.inspect_session(shell.session_id) profile = shell.runtime.inspect_profile(session.personal_model_id) @@ -821,7 +876,11 @@ def test_work_surface_discloses_resolved_state_focus_scope_and_fallback(self) -> focus_assist_outcome="suggested", selection_path="embedding-unavailable.weak-assist.suggested.narrow", reasons=( - StateFocusReason("continuation", "The prompt continues the active rollout thread.", 0.9), + StateFocusReason( + "continuation", + "The prompt continues the active rollout thread.", + 0.9, + ), StateFocusReason("focus", "The active work stays ahead of generic recall.", 0.8), ), audit_trace=("stage3: fallback path -> embedding-unavailable.weak-assist.suggested.narrow",), @@ -832,7 +891,9 @@ def test_work_surface_discloses_resolved_state_focus_scope_and_fallback(self) -> self.assertFalse(shell._handle_slash_command("/work")) self.assertEqual(shell.transcript[-1].title, "Unknown command") - def test_conversational_surface_requests_can_schedule_prompt_cron_and_list_jobs(self) -> None: + def test_conversational_surface_requests_can_schedule_prompt_cron_and_list_jobs( + self, + ) -> None: shell = self._make_shell() created = shell._handle_conversational_surface_request("schedule a prompt to tell me a joke every morning") @@ -869,22 +930,30 @@ def test_due_cron_tick_appends_prompt_result_to_open_shell(self) -> None: self.assertIn("cron", shell.transcript[-1].meta) self.assertFalse(shell.runtime.has_due_cron_jobs(session_id=shell.session_id)) - def test_prompt_cron_job_references_requested_skill_without_body_injection(self) -> None: + def test_prompt_cron_job_references_requested_skill_without_body_injection( + self, + ) -> None: shell = self._make_shell() skill = shell.runtime.inspect_skill("arxiv", session_id=shell.session_id) shell.runtime.create_cron_job( session_id=shell.session_id, name="Paper scan", schedule="2000-01-01T00:00:00+00:00", - payload={"prompt": "find papers and write a markdown note", "skills": ["arxiv"]}, + payload={ + "prompt": "find papers and write a markdown note", + "skills": ["arxiv"], + }, ) outcome = SimpleNamespace(execution=SimpleNamespace(summary="wrote paper note")) - with mock.patch.object(type(shell.runtime), "inspect_skill", return_value=skill) as inspect_skill, mock.patch.object( - type(shell.runtime), - "explain_next_step", - return_value=outcome, - ) as explain: + with ( + mock.patch.object(type(shell.runtime), "inspect_skill", return_value=skill) as inspect_skill, + mock.patch.object( + type(shell.runtime), + "explain_next_step", + return_value=outcome, + ) as explain, + ): executions = shell.runtime.run_due_cron_jobs(session_id=shell.session_id) self.assertEqual(executions[0].summary, "wrote paper note") @@ -926,7 +995,11 @@ def explain_next_step(*, session_id: str, prompt: str): ) as create_child_runtime: result = shell.runtime.tool_runtime.invoke( "tool.sub_agents", - {"task": "inspect the cron implementation", "name": "reviewer", "skills": ["subagent-driven-development"]}, + { + "task": "inspect the cron implementation", + "name": "reviewer", + "skills": ["subagent-driven-development"], + }, session_id=shell.session_id, requester="model", ) @@ -943,10 +1016,14 @@ def explain_next_step(*, session_id: str, prompt: str): self.assertIn("Do not call tool.sub_agents", prompt) self.assertIn("Sub-agent name: reviewer", prompt) - def test_learning_sub_agent_uses_dedicated_system_prompt_without_generic_wrapper(self) -> None: + def test_learning_sub_agent_uses_dedicated_system_prompt_without_generic_wrapper( + self, + ) -> None: shell = self._make_shell() captured: dict[str, object] = {} - child_tool_runtime = SimpleNamespace(subscribe=mock.Mock(return_value=mock.Mock()), descriptor=SimpleNamespace()) + child_tool_runtime = SimpleNamespace( + subscribe=mock.Mock(return_value=mock.Mock()), descriptor=SimpleNamespace() + ) child_runtime = SimpleNamespace( tool_runtime=child_tool_runtime, model_provider=SimpleNamespace(tool_runtime=child_tool_runtime), @@ -966,12 +1043,18 @@ def run_turn(**kwargs): ) child_runtime._run_turn = mock.Mock(side_effect=run_turn) - with mock.patch("apps.cli.runtime_cron_sub_agents._create_child_runtime", return_value=child_runtime): + with mock.patch( + "apps.cli.runtime_cron_sub_agents._create_child_runtime", + return_value=child_runtime, + ): result = shell.runtime.run_sub_agent( session_id=shell.session_id, task="Mode: manual\nLearning context packet: compact facts", name="Manual learning", - allowed_tools=("tool.personal_model.search", "tool.personal_model.update"), + allowed_tools=( + "tool.personal_model.search", + "tool.personal_model.update", + ), system_prompt="[SYSTEM: Background Learning Agent]", learning_agent=True, ) @@ -984,7 +1067,9 @@ def run_turn(**kwargs): self.assertEqual(event_payload["context_mode"], "learning_agent") self.assertNotIn("bounded Elephant Agent sub-agent", str(captured["prompt"])) - def test_sub_agents_start_returns_handle_and_emits_child_lifecycle_events(self) -> None: + def test_sub_agents_start_returns_handle_and_emits_child_lifecycle_events( + self, + ) -> None: shell = self._make_shell() child_started = threading.Event() release_child = threading.Event() @@ -1051,8 +1136,7 @@ def explain_next_step(*, session_id: str, prompt: str): child_events = [ event for event in captured_events - if event.invocation.tool_id == "tool.sub_agents" - and event.invocation.arguments.get("sub_agent_child") + if event.invocation.tool_id == "tool.sub_agents" and event.invocation.arguments.get("sub_agent_child") ] self.assertTrue(any(event.phase == "execution.started" for event in child_events)) self.assertTrue(any(event.phase == "execution.completed" for event in child_events)) @@ -1136,7 +1220,9 @@ def test_tool_trace_emoji_covers_builtin_chat_tools(self) -> None: self.assertEqual(shell_progress_trace._tool_trace_emoji(tool_id), emoji) self.assertEqual(shell_progress_trace._tool_trace_emoji("mcp.km.hot-articles"), "🧩") - def test_clarify_blocks_for_shell_input_and_returns_answer_as_tool_result(self) -> None: + def test_clarify_blocks_for_shell_input_and_returns_answer_as_tool_result( + self, + ) -> None: shell = self._make_shell() shell.runtime.set_clarify_surface(shell._interactive_clarify_surface()) holder: dict[str, ExecutionResult] = {} @@ -1260,7 +1346,10 @@ def test_prompt_style_keeps_live_composer_unboxed(self) -> None: self.assertEqual(style_map["progress-tool-label"], f"fg:{BRAND_ACCENT_STRONG} bold") self.assertEqual(style_map["stream-response-body"], f"fg:{BRAND_LIGHT}") self.assertEqual(style_map["status-bar-growth-empty"], f"bg:#1b2029 fg:{BRAND_ACCENT}") - self.assertEqual(style_map["completion-menu.completion.current"], f"bg:#2a3343 fg:{BRAND_ACCENT_STRONG} bold") + self.assertEqual( + style_map["completion-menu.completion.current"], + f"bg:#2a3343 fg:{BRAND_ACCENT_STRONG} bold", + ) self.assertEqual(style_map["scrollbar.button"], f"bg:{BRAND_ACCENT}") self.assertNotIn("bg:", style_map[""]) self.assertNotIn("bg:", style_map["composer-prefix"]) @@ -1313,9 +1402,7 @@ def test_queued_followup_enters_transcript_only_when_dispatched(self) -> None: shell._dispatch(shell._next_command().command) queued_entries = [ - entry - for entry in shell.transcript - if entry.kind == "user" and entry.body == "queued followup" + entry for entry in shell.transcript if entry.kind == "user" and entry.body == "queued followup" ] self.assertEqual(len(queued_entries), 1) @@ -1339,9 +1426,7 @@ def test_queue_preview_rows_are_narrower_than_sent_user_rows(self) -> None: shell.console = _StubConsole(48) shell._enqueue_followup_command("queued followup") - preview_lines = "".join( - text for _style, text in shell._render_queued_followup_fragments() - ).splitlines() + preview_lines = "".join(text for _style, text in shell._render_queued_followup_fragments()).splitlines() sent = shell._render_entry( TranscriptEntry( kind="user", @@ -1370,7 +1455,9 @@ def test_turn_progress_fragments_show_queued_followup_count(self) -> None: self.assertIn("Working", rendered) self.assertIn("queued scrolls · 2 messages", rendered) - def test_turn_progress_fragments_drop_queue_scroll_hint_but_keep_spacing(self) -> None: + def test_turn_progress_fragments_drop_queue_scroll_hint_but_keep_spacing( + self, + ) -> None: shell = self._make_shell() fragments = shell._render_turn_progress_fragments( @@ -1382,7 +1469,9 @@ def test_turn_progress_fragments_drop_queue_scroll_hint_but_keep_spacing(self) - self.assertNotIn("Press Enter to queue another scroll.", rendered) self.assertTrue(rendered.endswith("\n")) - def test_turn_progress_fragments_keep_live_tool_lines_on_separate_rows(self) -> None: + def test_turn_progress_fragments_keep_live_tool_lines_on_separate_rows( + self, + ) -> None: shell = self._make_shell() shell._rendered_entries = len(shell.transcript) shell._append_tooltrace_line("┊ 📚 Calling skill…") @@ -1398,7 +1487,9 @@ def test_turn_progress_fragments_keep_live_tool_lines_on_separate_rows(self) -> self.assertIn("\n┊ 📚 skill apple-notes 0.3s", rendered) self.assertNotIn("skill…┊ 📚 skill", rendered) - def test_turn_progress_fragments_surface_state_focus_resolution_summary(self) -> None: + def test_turn_progress_fragments_surface_state_focus_resolution_summary( + self, + ) -> None: shell = self._make_shell() fragments = shell._render_turn_progress_fragments( @@ -1429,7 +1520,9 @@ def test_turn_progress_fragments_surface_state_focus_resolution_summary(self) -> rendered = "".join(text for _style, text in fragments) self.assertIn("┊ 🧭 focus exploration · 35ms · elephant · conf 0.82", rendered) - def test_turn_progress_fragments_omit_context_and_request_progress_rows(self) -> None: + def test_turn_progress_fragments_omit_context_and_request_progress_rows( + self, + ) -> None: shell = self._make_shell() fragments = shell._render_turn_progress_fragments( @@ -1494,7 +1587,9 @@ def test_record_kernel_event_trace_omits_recall_tooltrace_rows(self) -> None: self.assertFalse([entry for entry in shell.transcript if entry.kind == "tooltrace"]) - def test_record_kernel_event_trace_updates_context_projection_after_compaction(self) -> None: + def test_record_kernel_event_trace_updates_context_projection_after_compaction( + self, + ) -> None: shell = self._make_shell() shell._last_prompt_tokens = 1800 @@ -1515,10 +1610,15 @@ def test_record_kernel_event_trace_updates_context_projection_after_compaction(s self.assertEqual(shell._last_prompt_tokens, 620) tool_entries = [entry for entry in shell.transcript if entry.kind == "tooltrace"] self.assertEqual(len(tool_entries), 1) - self.assertIn("┊ 🧩 context projection compact · est 1800->620 tokens · preflight", tool_entries[0].body) + self.assertIn( + "┊ 🧩 context projection compact · est 1800->620 tokens · preflight", + tool_entries[0].body, + ) self.assertIn("scanner: 2 cached / 5 pending / 1 missed", tool_entries[0].body) - def test_record_kernel_event_trace_uses_provider_prompt_usage_for_status_bar(self) -> None: + def test_record_kernel_event_trace_uses_provider_prompt_usage_for_status_bar( + self, + ) -> None: shell = self._make_shell() shell._last_prompt_tokens = 1800 @@ -1537,7 +1637,9 @@ def test_record_kernel_event_trace_uses_provider_prompt_usage_for_status_bar(sel self.assertEqual(shell._last_provider_prompt_tokens, 720) self.assertFalse([entry for entry in shell.transcript if entry.kind == "tooltrace"]) - def test_record_kernel_event_trace_tracks_latest_context_projection_status(self) -> None: + def test_record_kernel_event_trace_tracks_latest_context_projection_status( + self, + ) -> None: shell = self._make_shell() shell._last_prompt_tokens = 1800 @@ -1582,7 +1684,9 @@ def test_user_history_rows_expand_to_console_width(self) -> None: self.assertEqual(lines[0], "hello from wake shell") self.assertEqual(lines[1], "sent just now") - def test_growth_rows_use_gray_history_background_with_selective_yellow_text(self) -> None: + def test_growth_rows_use_gray_history_background_with_selective_yellow_text( + self, + ) -> None: shell = self._make_shell() shell.console = _StubConsole(48) rendered = shell._render_entry( @@ -1609,7 +1713,10 @@ def test_growth_rows_use_gray_history_background_with_selective_yellow_text(self self.assertIn(f"{BRAND_MUTED} on {USER_HISTORY_BG}", styles) self.assertIn(f"{GROWTH_HIGHLIGHT_FG} on {USER_HISTORY_BG}", styles) else: - self.assertEqual(lines[0], "Something settled into the Personal Model — checkpoint 1 in Evidence I. I'll carry it forward.") + self.assertEqual( + lines[0], + "Something settled into the Personal Model — checkpoint 1 in Evidence I. I'll carry it forward.", + ) self.assertEqual(lines[1], "understanding · checkpoint") def test_composer_divider_tracks_console_width_without_old_cap(self) -> None: @@ -1652,7 +1759,11 @@ def test_elephant_stage_rows_keep_ascii_side_profile_readable(self) -> None: def test_elephant_rows_match_current_ascii_logo(self) -> None: joined = "\n".join(ELEPHANT_STAGE_ROWS) - self.assertIn("/ \\~~~/ \\", joined, msg="ear and head line should survive terminal rendering") + self.assertIn( + "/ \\~~~/ \\", + joined, + msg="ear and head line should survive terminal rendering", + ) self.assertIn("..", joined, msg="eye dots should survive terminal rendering") self.assertIn("`---'", joined, msg="body tail line should survive terminal rendering") centered_rows = _centered_elephant_rows() @@ -1680,12 +1791,7 @@ def test_elephant_mark_renders_full_centered_stage(self) -> None: def test_elephant_rows_keep_sticker_optically_centered(self) -> None: centered = _centered_elephant_rows() self.assertEqual(centered, ELEPHANT_STAGE_ROWS) - visible = [ - index - for row in centered - for index, cell in enumerate(row) - if cell != " " - ] + visible = [index for row in centered for index, cell in enumerate(row) if cell != " "] self.assertTrue(visible) def test_growth_levels_reuse_unified_elephant_mark(self) -> None: @@ -1693,8 +1799,14 @@ def test_growth_levels_reuse_unified_elephant_mark(self) -> None: elephant = shell._render_growth_mark("seed", level=0) seed = shell._render_growth_mark("seed", level=1) if not RICH_AVAILABLE: - self.assertEqual(elephant.plain if hasattr(elephant, "plain") else str(elephant), "[Elephant Agent elephant]") - self.assertEqual(seed.plain if hasattr(seed, "plain") else str(seed), "[Elephant Agent seed]") + self.assertEqual( + elephant.plain if hasattr(elephant, "plain") else str(elephant), + "[Elephant Agent elephant]", + ) + self.assertEqual( + seed.plain if hasattr(seed, "plain") else str(seed), + "[Elephant Agent seed]", + ) return elephant_lines = elephant.plain.splitlines() seed_lines = seed.plain.splitlines() @@ -1722,12 +1834,7 @@ def test_growth_stage_rows_center_visible_pixels(self) -> None: ): centered = visual_centered_rows(rows, width=GROWTH_MARK_CANVAS_WIDTH) self.assertEqual({len(row) for row in centered}, {GROWTH_MARK_CANVAS_WIDTH}) - visible = [ - index - for row in centered - for index, cell in enumerate(row) - if cell != " " - ] + visible = [index for row in centered for index, cell in enumerate(row) if cell != " "] self.assertTrue(visible, msg=label) visible_center = (min(visible) + max(visible)) / 2 canvas_center = (GROWTH_MARK_CANVAS_WIDTH - 1) / 2 @@ -1831,7 +1938,10 @@ def test_shell_frame_banner_summarizes_pm_lenses_and_curiosity(self) -> None: self.assertIn("🐘 What I know", plain) self.assertIn("saved · identity 1 · world 2 · 2 lens empty", plain) - self.assertIn("question (pulse · current_focus) · What should I treat as the current highest-priority thread?", plain) + self.assertIn( + "question (pulse · current_focus) · What should I treat as the current highest-priority thread?", + plain, + ) self.assertIn("🧩 Skills for you", plain) self.assertIn("affinities · 1 learned · 1 active", plain) self.assertNotIn("affinities · Architecture Diagram", plain) @@ -1864,7 +1974,9 @@ def test_skill_affinities_report_metrics_without_skill_names(self) -> None: self.assertEqual(summary, "1 learned · 1 active") - def test_skill_affinities_follow_dashboard_topic_detection_without_projection_filter(self) -> None: + def test_skill_affinities_follow_dashboard_topic_detection_without_projection_filter( + self, + ) -> None: now = datetime.now(timezone.utc) summary = _skill_affinity_summary( facts=( @@ -1898,7 +2010,10 @@ def test_learning_job_execution_summary_counts_executed_jobs(self) -> None: ) ) - self.assertEqual(_learning_job_execution_summary(runtime, "you"), "2 run(s) · 1 completed · 1 failed") + self.assertEqual( + _learning_job_execution_summary(runtime, "you"), + "2 run(s) · 1 completed · 1 failed", + ) def test_shell_frame_surfaces_user_facing_context_summary(self) -> None: shell = self._make_shell() @@ -1948,9 +2063,14 @@ def test_shell_frame_filters_opening_prompt_like_state_text(self) -> None: self.assertIn("now · Ready to pick the thread back up when you are.", rendered) self.assertNotIn("assistant_display_name:", rendered) - self.assertNotIn("Open the wake surface proactively before the user sends a new message.", rendered) + self.assertNotIn( + "Open the wake surface proactively before the user sends a new message.", + rendered, + ) - def test_status_column_renders_carrying_forward_with_bold_label_and_markdown_value(self) -> None: + def test_status_column_renders_carrying_forward_with_bold_label_and_markdown_value( + self, + ) -> None: shell = self._make_shell() session = shell.runtime.inspect_session(shell.session_id) continuity = shell.runtime.inspect_continuity(session_id=shell.session_id) @@ -2159,14 +2279,19 @@ def test_settled_state_focus_meta_stays_muted_in_transcript(self) -> None: ) ) - self.assertIn("routing · resume · 56ms · lineage · 0.94", rendered.plain if hasattr(rendered, "plain") else str(rendered)) + self.assertIn( + "routing · resume · 56ms · lineage · 0.94", + rendered.plain if hasattr(rendered, "plain") else str(rendered), + ) def test_live_state_focus_progress_uses_steady_orange_trace_style(self) -> None: text = shell_progress_trace.render_tool_trace_text("┊ 🧭 routing resume · 56ms · lineage · 0.94") self.assertEqual(text.spans[0].style, shell_render.BRAND_ACCENT_STRONG) - def test_growth_panel_reports_enabled_and_self_learned_skill_counts_without_internal_next_move(self) -> None: + def test_growth_panel_reports_enabled_and_self_learned_skill_counts_without_internal_next_move( + self, + ) -> None: shell = self._make_shell() shell.runtime.create_experience_skill( skill_id="self-learned-shell-fix", @@ -2180,7 +2305,9 @@ def test_growth_panel_reports_enabled_and_self_learned_skill_counts_without_inte continuity = shell.runtime.inspect_continuity(session_id=shell.session_id) provider = dict(shell.runtime.provider_summary()) lines = shell._recent_activity_lines(session, continuity, provider) - enabled_skills = tuple(skill for skill in shell.runtime.skill_catalog(session_id=shell.session_id) if skill.enabled) + enabled_skills = tuple( + skill for skill in shell.runtime.skill_catalog(session_id=shell.session_id) if skill.enabled + ) self.assertIn(f"skills · {len(enabled_skills)} enabled · 1 self-learned", lines) self.assertFalse(any(line.startswith("next move ·") for line in lines)) @@ -2204,14 +2331,19 @@ def test_growth_progress_bar_uses_glyph_bar_and_orange_fill(self) -> None: self.assertIn(BRAND_ACCENT_STRONG, styles) self.assertIn(BRAND_MUTED, styles) - def test_diff_styles_use_brighter_live_palette_and_dimmer_settled_palette(self) -> None: + def test_diff_styles_use_brighter_live_palette_and_dimmer_settled_palette( + self, + ) -> None: style_map = prompt_style_map() self.assertEqual(style_map["progress-output-file"], f"fg:{LIVE_DIFF_FILE_FG} bold") self.assertEqual(style_map["progress-output-hunk"], f"fg:{LIVE_DIFF_HUNK_FG} bold") self.assertEqual(style_map["progress-output-add"], f"fg:{LIVE_DIFF_ADD_FG} bold") self.assertEqual(style_map["progress-output-remove"], f"fg:{LIVE_DIFF_REMOVE_FG} bold") - self.assertEqual(_render_tooltrace_body_line("a/notes.md → b/notes.md").style, SETTLED_DIFF_FILE_FG) + self.assertEqual( + _render_tooltrace_body_line("a/notes.md → b/notes.md").style, + SETTLED_DIFF_FILE_FG, + ) self.assertEqual(_render_tooltrace_body_line("@@ -1 +1 @@").style, SETTLED_DIFF_HUNK_FG) self.assertEqual(_render_tooltrace_body_line("+added").style, SETTLED_DIFF_ADD_FG) self.assertEqual(_render_tooltrace_body_line("-removed").style, SETTLED_DIFF_REMOVE_FG) @@ -2237,7 +2369,10 @@ def test_status_bar_fragments_include_checkpoint_and_growth_progress(self) -> No self.assertIn("12s", rendered) self.assertIn("Evidence I", rendered) self.assertIn(shell._build_context_bar(update.after.progress_percent), rendered) - self.assertIn(f"checkpoint {update.after.level} · {update.after.progress_percent}%", rendered) + self.assertIn( + f"checkpoint {update.after.level} · {update.after.progress_percent}%", + rendered, + ) styles = {style for style, _text in fragments if style} self.assertIn("class:status-bar-level", styles) @@ -2287,7 +2422,9 @@ def test_status_bar_fragments_keep_previous_usage_during_live_turn(self) -> None self.assertNotIn("--/128K", rendered) self.assertNotIn("10%", rendered) - def test_status_bar_fragments_show_committed_provider_usage_after_turn(self) -> None: + def test_status_bar_fragments_show_committed_provider_usage_after_turn( + self, + ) -> None: shell = self._make_shell() shell._last_prompt_tokens = 14_000 shell._last_provider_prompt_tokens = 43_500 @@ -2392,7 +2529,9 @@ def test_turn_progress_frame_keeps_cumulative_tool_rail_visible(self) -> None: self.assertIn("apple-notes", rendered) self.assertIn("memo notes --help", rendered) - def test_turn_progress_frame_renders_streaming_response_in_dedicated_surface(self) -> None: + def test_turn_progress_frame_renders_streaming_response_in_dedicated_surface( + self, + ) -> None: shell = self._make_shell() frame = shell._render_turn_frame( prompt="hello", @@ -2411,7 +2550,9 @@ def test_turn_progress_frame_renders_streaming_response_in_dedicated_surface(sel self.assertIn("First line of the reply.", rendered) self.assertIn("Second line arrives next.", rendered) - def test_turn_progress_frame_formats_reasoning_with_elephant_mind_heading(self) -> None: + def test_turn_progress_frame_formats_reasoning_with_elephant_mind_heading( + self, + ) -> None: shell = self._make_shell() frame = shell._render_turn_frame( prompt="hello", @@ -2466,7 +2607,9 @@ def test_turn_progress_frame_surfaces_context_compaction(self) -> None: self.assertIn("projection compact", rendered) self.assertIn("est 1800->620 tokens", rendered) - def test_turn_progress_frame_surfaces_recall_without_context_ready_or_request_rows(self) -> None: + def test_turn_progress_frame_surfaces_recall_without_context_ready_or_request_rows( + self, + ) -> None: shell = self._make_shell() frame = shell._render_turn_frame( prompt="hello", @@ -2506,14 +2649,16 @@ def test_turn_progress_frame_surfaces_recall_without_context_ready_or_request_ro self.assertNotIn("📈 request", rendered) self.assertNotIn("provider running", rendered) - def test_turn_progress_frame_hides_raw_tool_call_markup_from_stream_response(self) -> None: + def test_turn_progress_frame_hides_raw_tool_call_markup_from_stream_response( + self, + ) -> None: shell = self._make_shell() frame = shell._render_turn_frame( prompt="hello", tick=0, stream_text=( "I'll search for information on Xunzhuo Liu.\n" - "" + '' "xunzhuo liu researcher academic" ), ) @@ -2568,9 +2713,14 @@ def test_append_outcome_surfaces_state_focus_meta_in_transcript(self) -> None: self.assertEqual(shell.transcript[-1].kind, "assistant") self.assertEqual(shell.transcript[-1].body, "The release note draft is ready.") - self.assertEqual(shell.transcript[-1].meta, "routing · execution · 12ms · loop · 0.74 · cache hit · 50.0%") + self.assertEqual( + shell.transcript[-1].meta, + "routing · execution · 12ms · loop · 0.74 · cache hit · 50.0%", + ) - def test_state_focus_notice_fragments_show_almost_there_while_transcript_prime_pending(self) -> None: + def test_state_focus_notice_fragments_show_almost_there_while_transcript_prime_pending( + self, + ) -> None: shell = self._make_shell() with mock.patch.object( @@ -2595,7 +2745,9 @@ def test_state_focus_notice_fragments_show_almost_there_while_transcript_prime_p self.assertNotIn("I'm with you", rendered) self.assertTrue(shell._state_focus_runtime_ready_seen) - def test_state_focus_notice_fragments_hide_after_first_user_turn_is_submitted(self) -> None: + def test_state_focus_notice_fragments_hide_after_first_user_turn_is_submitted( + self, + ) -> None: shell = self._make_shell() shell._startup_user_turn_submitted = True @@ -2617,7 +2769,9 @@ def test_state_focus_notice_fragments_hide_after_first_user_turn_is_submitted(se self.assertNotIn("opening", rendered) self.assertNotIn("ready", rendered) - def test_state_focus_notice_fragments_hide_after_ready_once_first_user_turn_is_submitted(self) -> None: + def test_state_focus_notice_fragments_hide_after_ready_once_first_user_turn_is_submitted( + self, + ) -> None: shell = self._make_shell() shell._startup_surface_prepared = True shell._startup_user_turn_submitted = True @@ -2644,7 +2798,9 @@ def test_state_focus_notice_fragments_hide_after_ready_once_first_user_turn_is_s self.assertNotIn("opening", rendered) self.assertNotIn("path nearly ready", rendered) - def test_state_focus_notice_fragments_surface_state_focus_queue_after_ready_when_first_turn_is_waiting(self) -> None: + def test_state_focus_notice_fragments_surface_state_focus_queue_after_ready_when_first_turn_is_waiting( + self, + ) -> None: shell = self._make_shell() shell._startup_surface_prepared = True shell._startup_user_turn_submitted = True @@ -2668,7 +2824,9 @@ def test_state_focus_notice_fragments_surface_state_focus_queue_after_ready_when self.assertIn("path nearly ready", rendered) self.assertNotIn("🐾 ready", rendered) - def test_startup_transition_result_primes_opening_after_ready_idle_threshold(self) -> None: + def test_startup_transition_result_primes_opening_after_ready_idle_threshold( + self, + ) -> None: shell = self._make_shell() shell._state_focus_runtime_ready_seen_at = time.monotonic() - 10 @@ -2679,7 +2837,9 @@ def test_startup_transition_result_primes_opening_after_ready_idle_threshold(sel self.assertIsNone(immediate) self.assertEqual(result, "__elephant.startup.prime__") - def test_startup_transition_result_primes_before_dispatching_queued_first_turn(self) -> None: + def test_startup_transition_result_primes_before_dispatching_queued_first_turn( + self, + ) -> None: shell = self._make_shell() shell._startup_user_turn_submitted = True shell._pending_commands.append(PendingShellCommand(command="帮我看下这个")) @@ -2699,7 +2859,9 @@ def test_startup_transition_result_waits_briefly_after_ready_notice(self) -> Non self.assertIsNone(result) - def test_startup_transition_result_does_not_restart_prime_while_background_prime_runs(self) -> None: + def test_startup_transition_result_does_not_restart_prime_while_background_prime_runs( + self, + ) -> None: shell = self._make_shell() shell._startup_prime_started = True shell._state_focus_runtime_ready_seen_at = time.monotonic() - 10 @@ -2709,7 +2871,9 @@ def test_startup_transition_result_does_not_restart_prime_while_background_prime self.assertIsNone(result) - def test_startup_transition_result_dispatches_pending_after_proactive_prime(self) -> None: + def test_startup_transition_result_dispatches_pending_after_proactive_prime( + self, + ) -> None: shell = self._make_shell() shell._startup_user_turn_submitted = True shell._startup_transcript_primed = True @@ -2741,7 +2905,9 @@ def test_startup_turn_still_queues_until_proactive_opening_is_primed(self) -> No with mock.patch.object(type(shell), "_startup_state_focus_dispatch_ready", return_value=True): self.assertFalse(shell._startup_should_hold_user_command("帮我看下这个")) - def test_shell_constructor_defers_startup_opening_until_explicit_prime(self) -> None: + def test_shell_constructor_defers_startup_opening_until_explicit_prime( + self, + ) -> None: tmpdir = tempfile.TemporaryDirectory() self.addCleanup(tmpdir.cleanup) root = Path(tmpdir.name) @@ -2783,7 +2949,11 @@ def _inner(*args, **kwargs): return _inner with ( - mock.patch.object(shell, "_render_startup_sequence", side_effect=record("startup-sequence")), + mock.patch.object( + shell, + "_render_startup_sequence", + side_effect=record("startup-sequence"), + ), mock.patch.object(shell, "_refresh_shell_frame", side_effect=record("refresh-frame")), mock.patch.object(shell, "_prepare_startup_surface", side_effect=record("prepare-surface")), mock.patch.object(shell, "_next_command", side_effect=EOFError), @@ -2824,11 +2994,7 @@ def next_command(): def test_run_dispatches_queued_startup_turn_immediately_after_prime(self) -> None: shell = self._make_shell() shell._pending_commands.append(PendingShellCommand(command="帮我看下这个")) - commands = iter( - ( - PendingShellCommand(command="__elephant.startup.prime__"), - ) - ) + commands = iter((PendingShellCommand(command="__elephant.startup.prime__"),)) def next_command(): value = next(commands) @@ -2852,7 +3018,9 @@ def next_command(): prime.assert_called_once_with() dispatch.assert_called_once_with(PendingShellCommand(command="帮我看下这个")) - def test_prepare_startup_surface_runs_in_background_and_refreshes_skills(self) -> None: + def test_prepare_startup_surface_runs_in_background_and_refreshes_skills( + self, + ) -> None: shell = self._make_shell() class _ImmediateThread: @@ -2863,7 +3031,10 @@ def start(self) -> None: self._target() with ( - mock.patch("apps.cli.shell_methods_ui.threading.Thread", side_effect=_ImmediateThread), + mock.patch( + "apps.cli.shell_methods_ui.threading.Thread", + side_effect=_ImmediateThread, + ), mock.patch.object(type(shell.runtime), "prepare_session_surface") as prepare_surface, mock.patch.object(shell, "_refresh_skill_slash_specs") as refresh_skills, ): @@ -2873,7 +3044,9 @@ def start(self) -> None: refresh_skills.assert_called_once_with() self.assertTrue(shell._startup_surface_prepared) - def test_turn_progress_fragments_keep_stream_text_out_of_progress_header(self) -> None: + def test_turn_progress_fragments_keep_stream_text_out_of_progress_header( + self, + ) -> None: shell = self._make_shell() fragments = shell._render_turn_progress_fragments( @@ -2887,11 +3060,13 @@ def test_turn_progress_fragments_keep_stream_text_out_of_progress_header(self) - self.assertIn("streaming chunk", rendered) self.assertNotIn("active request:", rendered) - def test_stream_text_tracker_strips_tool_markup_and_resets_between_tool_rounds(self) -> None: + def test_stream_text_tracker_strips_tool_markup_and_resets_between_tool_rounds( + self, + ) -> None: holder, lock, observer = stream_text_tracker() observer("I'll search for information on Xunzhuo Liu.\n") - observer("xunzhuo") + observer('xunzhuo') self.assertEqual( latest_stream_text(holder, lock).strip(), "I'll search for information on Xunzhuo Liu.", @@ -2910,7 +3085,9 @@ def test_stream_text_tracker_strips_tool_markup_and_resets_between_tool_rounds(s "I found several relevant researcher profiles.", ) - def test_retain_stream_response_only_drops_old_thinking_but_keeps_response(self) -> None: + def test_retain_stream_response_only_drops_old_thinking_but_keeps_response( + self, + ) -> None: holder, lock, observer = stream_text_tracker() observer("Inspect the first result carefully.I'll open the strongest profile next.") @@ -2920,7 +3097,9 @@ def test_retain_stream_response_only_drops_old_thinking_but_keeps_response(self) self.assertEqual(preserved, "I'll open the strongest profile next.") self.assertEqual(latest_stream_text(holder, lock), "I'll open the strongest profile next.") - def test_tool_event_tracker_stream_anchors_exclude_historical_thinking(self) -> None: + def test_tool_event_tracker_stream_anchors_exclude_historical_thinking( + self, + ) -> None: holder, lock, observer = stream_text_tracker() tool_event_holder, tool_event_lock, tool_observer = shell_progress_runtime.tool_event_tracker( stream_holder=holder, @@ -2988,7 +3167,9 @@ def test_unknown_command_uses_brand_accent_panel(self) -> None: if border_style is not None: self.assertEqual(str(border_style), BRAND_ACCENT) - def test_personal_model_update_progress_copy_mentions_understanding_surface(self) -> None: + def test_personal_model_update_progress_copy_mentions_understanding_surface( + self, + ) -> None: shell = self._make_shell() event = ToolLifecycleEvent( event_id="tool-event-2", @@ -2996,7 +3177,12 @@ def test_personal_model_update_progress_copy_mentions_understanding_surface(self invocation_id="session-1:tool.personal_model.update", tool_id="tool.personal_model.update", session_id="session-1", - arguments={"action": "remember", "lens": "trait", "topic": "identity.name.preferred", "text": "The user's preferred name is Bit."}, + arguments={ + "action": "remember", + "lens": "trait", + "topic": "identity.name.preferred", + "text": "The user's preferred name is Bit.", + }, ), phase="execution.started", detail="executing tool.personal_model.update", @@ -3153,7 +3339,9 @@ def test_file_write_completed_event_appends_review_diff_to_tooltrace(self) -> No self.assertIn("@@ -1 +1,2 @@", tool_entries[0].body) self.assertIn("+world", tool_entries[0].body) - def test_turn_tool_progress_lines_keep_write_visible_when_diff_is_pending(self) -> None: + def test_turn_tool_progress_lines_keep_write_visible_when_diff_is_pending( + self, + ) -> None: shell = self._make_shell() shell.transcript = [ TranscriptEntry( @@ -3181,7 +3369,9 @@ def test_turn_tool_progress_lines_keep_write_visible_when_diff_is_pending(self) self.assertFalse(any(line.startswith("@@") for line in lines)) self.assertFalse(any(line.startswith("+") for line in lines)) - def test_render_pending_entries_keeps_context_compaction_frame_until_next_turn(self) -> None: + def test_render_pending_entries_keeps_context_compaction_frame_until_next_turn( + self, + ) -> None: shell = self._make_shell() shell._pending_context_compaction_frame = { "prompt": "hello", @@ -3211,7 +3401,9 @@ def test_render_pending_entries_keeps_context_compaction_frame_until_next_turn(s self.assertIn("projection compact", rendered) self.assertIn("est 1800->620 tokens", rendered) - def test_turn_progress_frame_keeps_later_tool_events_visible_after_diff_body(self) -> None: + def test_turn_progress_frame_keeps_later_tool_events_visible_after_diff_body( + self, + ) -> None: shell = self._make_shell() shell.transcript = [ TranscriptEntry( @@ -3239,7 +3431,9 @@ def test_turn_progress_frame_keeps_later_tool_events_visible_after_diff_body(sel self.assertIn("┊ 💻 computer", rendered) self.assertIn("osascript", rendered) - def test_personal_model_update_completed_event_keeps_generic_tooltrace(self) -> None: + def test_personal_model_update_completed_event_keeps_generic_tooltrace( + self, + ) -> None: shell = self._make_shell() event = ToolLifecycleEvent( event_id="tool-event-state-completed", @@ -3247,7 +3441,12 @@ def test_personal_model_update_completed_event_keeps_generic_tooltrace(self) -> invocation_id="session-1:tool.personal_model.update", tool_id="tool.personal_model.update", session_id=shell.session_id, - arguments={"action": "remember", "lens": "trait", "topic": "identity.name.preferred", "text": "The user's preferred name is Bit."}, + arguments={ + "action": "remember", + "lens": "trait", + "topic": "identity.name.preferred", + "text": "The user's preferred name is Bit.", + }, requested_at=datetime(2026, 4, 13, 8, 0, 0, tzinfo=timezone.utc), ), phase="execution.completed", @@ -3284,7 +3483,9 @@ def test_tool_trace_entries_render_with_layered_styles(self) -> None: self.assertIn(BRAND_MUTED, styles) self.assertIn(f"bold {BRAND_ACCENT_STRONG}", styles) - def test_turn_progress_fragments_reuse_tool_trace_copy_for_live_events(self) -> None: + def test_turn_progress_fragments_reuse_tool_trace_copy_for_live_events( + self, + ) -> None: shell = self._make_shell() invocation = ToolInvocation( invocation_id="session-1:tool.web.search", @@ -3308,8 +3509,12 @@ def test_turn_progress_fragments_reuse_tool_trace_copy_for_live_events(self) -> occurred_at=invocation.requested_at, ) - requested_fragments = shell._render_turn_progress_fragments(prompt="search xunzhuo liu", tick=0, tool_event=requested) - started_fragments = shell._render_turn_progress_fragments(prompt="search xunzhuo liu", tick=0, tool_event=started) + requested_fragments = shell._render_turn_progress_fragments( + prompt="search xunzhuo liu", tick=0, tool_event=requested + ) + started_fragments = shell._render_turn_progress_fragments( + prompt="search xunzhuo liu", tick=0, tool_event=started + ) requested_text = "".join(fragment[1] for fragment in requested_fragments) started_text = "".join(fragment[1] for fragment in started_fragments) @@ -3317,7 +3522,9 @@ def test_turn_progress_fragments_reuse_tool_trace_copy_for_live_events(self) -> self.assertIn("┊ 🌐 search", started_text) self.assertIn("xunzhuo liu", started_text) - def test_turn_progress_fragments_anchor_stream_text_to_matching_tool_event(self) -> None: + def test_turn_progress_fragments_anchor_stream_text_to_matching_tool_event( + self, + ) -> None: shell = self._make_shell() stream_holder, stream_lock, stream_observer = stream_text_tracker() tool_event_holder, tool_event_lock, tool_observer = shell_progress_runtime.tool_event_tracker( @@ -3384,11 +3591,22 @@ def test_turn_progress_fragments_anchor_stream_text_to_matching_tool_event(self) ) rendered = "".join(fragment[1] for fragment in fragments) - self.assertLess(rendered.index("I'll search for the profile first."), rendered.index("┊ 🌐 search")) - self.assertGreater(rendered.index("Then I'll open the best result."), rendered.index("┊ 🌐 search")) - self.assertLess(rendered.index("Then I'll open the best result."), rendered.rindex("┊ 🌐 Calling fetch…")) + self.assertLess( + rendered.index("I'll search for the profile first."), + rendered.index("┊ 🌐 search"), + ) + self.assertGreater( + rendered.index("Then I'll open the best result."), + rendered.index("┊ 🌐 search"), + ) + self.assertLess( + rendered.index("Then I'll open the best result."), + rendered.rindex("┊ 🌐 Calling fetch…"), + ) - def test_turn_progress_fragments_keep_stream_text_with_started_event_after_requested_event_expires(self) -> None: + def test_turn_progress_fragments_keep_stream_text_with_started_event_after_requested_event_expires( + self, + ) -> None: shell = self._make_shell() stream_holder, stream_lock, stream_observer = stream_text_tracker() tool_event_holder, tool_event_lock, tool_observer = shell_progress_runtime.tool_event_tracker( @@ -3445,9 +3663,14 @@ def test_turn_progress_fragments_keep_stream_text_with_started_event_after_reque rendered = "".join(fragment[1] for fragment in fragments) self.assertIn("I'll inspect local files first.", rendered) self.assertEqual(rendered.count("I'll inspect local files first."), 1) - self.assertLess(rendered.index("I'll inspect local files first."), rendered.index("┊ 🌐 search")) + self.assertLess( + rendered.index("I'll inspect local files first."), + rendered.index("┊ 🌐 search"), + ) - def test_turn_progress_fragments_preserve_repeated_tool_rail_with_late_stream_anchor(self) -> None: + def test_turn_progress_fragments_preserve_repeated_tool_rail_with_late_stream_anchor( + self, + ) -> None: shell = self._make_shell() shell._rendered_entries = len(shell.transcript) stream_holder, stream_lock, stream_observer = stream_text_tracker() @@ -3511,10 +3734,18 @@ def test_turn_progress_fragments_preserve_repeated_tool_rail_with_late_stream_an rendered = "".join(fragment[1] for fragment in fragments) self.assertGreaterEqual(rendered.count("┊ 📖 Calling read…"), 2) self.assertIn("┊ 📖 read /tmp/alpha.txt 0.7s", rendered) - self.assertLess(rendered.index("┊ 📖 read /tmp/alpha.txt 0.7s"), rendered.index("Then I'll inspect the second file.")) - self.assertLess(rendered.index("Then I'll inspect the second file."), rendered.rindex("┊ 📖 Calling read…")) + self.assertLess( + rendered.index("┊ 📖 read /tmp/alpha.txt 0.7s"), + rendered.index("Then I'll inspect the second file."), + ) + self.assertLess( + rendered.index("Then I'll inspect the second file."), + rendered.rindex("┊ 📖 Calling read…"), + ) - def test_turn_progress_fragments_keep_middle_stream_text_after_live_events_expire(self) -> None: + def test_turn_progress_fragments_keep_middle_stream_text_after_live_events_expire( + self, + ) -> None: shell = self._make_shell() shell._rendered_entries = len(shell.transcript) stream_holder, stream_lock, stream_observer = stream_text_tracker() @@ -3584,11 +3815,22 @@ def test_turn_progress_fragments_keep_middle_stream_text_after_live_events_expir ) rendered = "".join(fragment[1] for fragment in fragments) - self.assertLess(rendered.index("I'll search for the profile first."), rendered.index("┊ 🌐 Calling search…")) - self.assertGreater(rendered.index("Then I'll open the best result."), rendered.index("┊ 🌐 search")) - self.assertLess(rendered.index("Then I'll open the best result."), rendered.rindex("┊ 🌐 Calling fetch…")) + self.assertLess( + rendered.index("I'll search for the profile first."), + rendered.index("┊ 🌐 Calling search…"), + ) + self.assertGreater( + rendered.index("Then I'll open the best result."), + rendered.index("┊ 🌐 search"), + ) + self.assertLess( + rendered.index("Then I'll open the best result."), + rendered.rindex("┊ 🌐 Calling fetch…"), + ) - def test_turn_progress_fragments_keep_earliest_stream_anchor_after_live_feed_truncates(self) -> None: + def test_turn_progress_fragments_keep_earliest_stream_anchor_after_live_feed_truncates( + self, + ) -> None: shell = self._make_shell() shell._rendered_entries = len(shell.transcript) stream_holder, stream_lock, stream_observer = stream_text_tracker() @@ -3652,7 +3894,13 @@ def test_turn_progress_fragments_keep_earliest_stream_anchor_after_live_feed_tru ), ) - for message, invocation, requested_detail, completed_detail, completed_at in invocations_and_events: + for ( + message, + invocation, + requested_detail, + completed_detail, + completed_at, + ) in invocations_and_events: stream_observer(message) tool_observer( ToolLifecycleEvent( @@ -3677,7 +3925,10 @@ def test_turn_progress_fragments_keep_earliest_stream_anchor_after_live_feed_tru with tool_event_lock: feed = [item for item in tool_event_holder.get("feed", ()) if isinstance(item, _VisibleToolEvent)] self.assertEqual(len(feed), 6) - self.assertNotIn("session-1:tool.web.search", {item.event.invocation.invocation_id for item in feed}) + self.assertNotIn( + "session-1:tool.web.search", + {item.event.invocation.invocation_id for item in feed}, + ) fragments = shell_progress_runtime.render_turn_progress_fragments( shell, @@ -3691,7 +3942,10 @@ def test_turn_progress_fragments_keep_earliest_stream_anchor_after_live_feed_tru rendered = "".join(fragment[1] for fragment in fragments) self.assertLess(rendered.index("I'll search first."), rendered.index("┊ 🌐 Calling search…")) self.assertLess(rendered.index("Then I'll fetch."), rendered.index("┊ 🌐 Calling fetch…")) - self.assertLess(rendered.index("Next I'll read a file."), rendered.index("┊ 📖 Calling read…")) + self.assertLess( + rendered.index("Next I'll read a file."), + rendered.index("┊ 📖 Calling read…"), + ) self.assertLess(rendered.index("Finally I'll grep."), rendered.index("┊ 🔎 Calling grep…")) def test_tool_event_lines_compact_completed_tool_result_details(self) -> None: @@ -3734,12 +3988,18 @@ def test_append_relationship_routes_clear_through_state_surface(self) -> None: self.assertFalse(shell._handle_slash_command("/relationship clear")) self.assertEqual(shell.transcript[-1].title, "Unknown command") - def test_render_pending_entries_inserts_blank_line_between_user_and_assistant(self) -> None: + def test_render_pending_entries_inserts_blank_line_between_user_and_assistant( + self, + ) -> None: shell = self._make_shell() shell.console = _CaptureConsole(80) shell.transcript = [ TranscriptEntry(kind="user", title="You", body="where did we leave off?"), - TranscriptEntry(kind="assistant", title="Elephant Agent", body="We were refining the wake shell."), + TranscriptEntry( + kind="assistant", + title="Elephant Agent", + body="We were refining the wake shell.", + ), ] shell._rendered_entries = 0 @@ -3755,7 +4015,11 @@ def test_render_pending_entries_keeps_tooltrace_rows_tight(self) -> None: shell.console = _CaptureConsole(80) shell.transcript = [ TranscriptEntry(kind="tooltrace", title="Tool trace", body="┊ 🌐 Calling search…"), - TranscriptEntry(kind="tooltrace", title="Tool trace", body="┊ 🌐 search xunzhuo liu 3.2s"), + TranscriptEntry( + kind="tooltrace", + title="Tool trace", + body="┊ 🌐 search xunzhuo liu 3.2s", + ), ] shell._rendered_entries = 0 @@ -3765,7 +4029,9 @@ def test_render_pending_entries_keeps_tooltrace_rows_tight(self) -> None: self.assertIn("Calling search", shell.console.printed[0]) self.assertIn("xunzhuo liu", shell.console.printed[0]) - def test_render_pending_entries_keeps_inline_review_diff_in_same_tooltrace_block(self) -> None: + def test_render_pending_entries_keeps_inline_review_diff_in_same_tooltrace_block( + self, + ) -> None: shell = self._make_shell() shell.console = _CaptureConsole(120) shell.transcript = [ @@ -3773,12 +4039,7 @@ def test_render_pending_entries_keeps_inline_review_diff_in_same_tooltrace_block kind="tooltrace", title="Tool trace", body=( - "┊ 🛠 write notes.md 0.2s\n" - "┊ 🛠 diff\n" - "a/notes.md → b/notes.md\n" - "@@ -1 +1,2 @@\n" - " hello\n" - "+world" + "┊ 🛠 write notes.md 0.2s\n┊ 🛠 diff\na/notes.md → b/notes.md\n@@ -1 +1,2 @@\n hello\n+world" ), ) ] @@ -3791,12 +4052,22 @@ def test_render_pending_entries_keeps_inline_review_diff_in_same_tooltrace_block self.assertIn("a/notes.md → b/notes.md", shell.console.printed[0]) self.assertIn("+world", shell.console.printed[0]) - def test_render_pending_entries_inserts_blank_line_between_tooltrace_and_assistant(self) -> None: + def test_render_pending_entries_inserts_blank_line_between_tooltrace_and_assistant( + self, + ) -> None: shell = self._make_shell() shell.console = _CaptureConsole(80) shell.transcript = [ - TranscriptEntry(kind="tooltrace", title="Tool trace", body="┊ 📚 skill apple-notes 0.3s"), - TranscriptEntry(kind="assistant", title="Elephant Agent", body="I created the note in Apple Notes."), + TranscriptEntry( + kind="tooltrace", + title="Tool trace", + body="┊ 📚 skill apple-notes 0.3s", + ), + TranscriptEntry( + kind="assistant", + title="Elephant Agent", + body="I created the note in Apple Notes.", + ), ] shell._rendered_entries = 0 @@ -3807,7 +4078,9 @@ def test_render_pending_entries_inserts_blank_line_between_tooltrace_and_assista self.assertEqual(shell.console.printed[1], "") self.assertIn("I created the note in Apple Notes.", shell.console.printed[2]) - def test_render_pending_entries_inserts_blank_line_between_reasoning_and_tooltrace(self) -> None: + def test_render_pending_entries_inserts_blank_line_between_reasoning_and_tooltrace( + self, + ) -> None: shell = self._make_shell() shell.console = _CaptureConsole(100) shell.transcript = [ @@ -3816,7 +4089,11 @@ def test_render_pending_entries_inserts_blank_line_between_reasoning_and_tooltra title="Elephant Agent", body="Inspect the tool results first.", ), - TranscriptEntry(kind="tooltrace", title="Tool trace", body="┊ 🌐 fetch https://example.com"), + TranscriptEntry( + kind="tooltrace", + title="Tool trace", + body="┊ 🌐 fetch https://example.com", + ), ] shell._rendered_entries = 0 @@ -3881,7 +4158,9 @@ def test_providers_embeddings_local_switches_back_to_default(self) -> None: self.assertEqual(shell.transcript[-1].title, "Embedding provider updated") self.assertIn("selection: local-default", shell.transcript[-1].body) - def test_refresh_shell_frame_resets_render_cursor_and_clears_console_in_alternate_screen(self) -> None: + def test_refresh_shell_frame_resets_render_cursor_and_clears_console_in_alternate_screen( + self, + ) -> None: shell = self._make_shell_without_identity_update() shell.console = _CaptureConsole(120) shell._use_alternate_screen = True @@ -3893,7 +4172,9 @@ def test_refresh_shell_frame_resets_render_cursor_and_clears_console_in_alternat self.assertEqual(shell.console.clear_calls, [True]) self.assertEqual(len(shell.console.printed), 1) - def test_conversational_dispatch_skips_shell_frame_refresh_when_frame_state_is_unchanged(self) -> None: + def test_conversational_dispatch_skips_shell_frame_refresh_when_frame_state_is_unchanged( + self, + ) -> None: shell = self._make_shell() with mock.patch.object(shell, "_refresh_shell_frame_if_needed") as refresh: @@ -3912,9 +4193,14 @@ def test_clear_resets_transcript_and_replays_model_generated_opening(self) -> No mock.patch.object( CliRuntime, "generate_opening_reply", - return_value=SimpleNamespace(execution=SimpleNamespace(summary="startup-reply:I'm back in the thread.")), + return_value=SimpleNamespace( + execution=SimpleNamespace(summary="startup-reply:I'm back in the thread.") + ), ) as generate_opening_reply, - mock.patch("apps.learning_worker_runtime.ensure_learning_worker_running", return_value=True), + mock.patch( + "apps.learning_worker_runtime.ensure_learning_worker_running", + return_value=True, + ), mock.patch.object(shell, "_refresh_shell_frame") as refresh, ): handled = shell._handle_slash_command("/clear") @@ -3923,7 +4209,10 @@ def test_clear_resets_transcript_and_replays_model_generated_opening(self) -> No generate_opening_reply.assert_called_once() refresh.assert_called_once_with() self.assertNotEqual(shell.session_id, original_session_id) - self.assertEqual(shell.runtime.inspect_session(shell.session_id).parent_episode_id, original_session_id) + self.assertEqual( + shell.runtime.inspect_session(shell.session_id).parent_episode_id, + original_session_id, + ) self.assertEqual(len(shell.transcript), 2) self.assertEqual(shell.transcript[0].kind, "assistant") self.assertEqual(shell.transcript[0].body, "startup-reply:I'm back in the thread.") @@ -3937,7 +4226,10 @@ def test_exit_closes_episode_and_queues_episode_close_learning(self) -> None: shell = self._make_shell(prime_transcript=True) original_session_id = shell.session_id - with mock.patch("apps.learning_worker_runtime.ensure_learning_worker_running", return_value=True): + with mock.patch( + "apps.learning_worker_runtime.ensure_learning_worker_running", + return_value=True, + ): handled = shell._handle_slash_command("/exit") self.assertTrue(handled) @@ -3949,7 +4241,9 @@ def test_exit_closes_episode_and_queues_episode_close_learning(self) -> None: self.assertEqual(len(jobs), 1) self.assertEqual(jobs[0].trigger, "episode_close") - def test_append_growth_update_message_surfaces_visible_understanding_checkpoint_reply(self) -> None: + def test_append_growth_update_message_surfaces_visible_understanding_checkpoint_reply( + self, + ) -> None: shell = self._make_shell() now = datetime.now(timezone.utc) initial = default_growth_state(shell.runtime.current_profile().state.profile_id, now=now) @@ -3996,7 +4290,9 @@ def test_dispatch_schedules_growth_followup_after_turn(self) -> None: schedule.assert_called_once_with() refresh.assert_not_called() - def test_refresh_shell_frame_if_needed_skips_when_frame_token_is_unchanged(self) -> None: + def test_refresh_shell_frame_if_needed_skips_when_frame_token_is_unchanged( + self, + ) -> None: shell = self._make_shell() shell._last_shell_frame_token = shell._current_shell_frame_token() @@ -4006,7 +4302,9 @@ def test_refresh_shell_frame_if_needed_skips_when_frame_token_is_unchanged(self) self.assertFalse(changed) refresh.assert_not_called() - def test_refresh_shell_frame_if_needed_skips_for_pending_context_compaction_frame(self) -> None: + def test_refresh_shell_frame_if_needed_skips_for_pending_context_compaction_frame( + self, + ) -> None: shell = self._make_shell() shell._last_shell_frame_token = shell._current_shell_frame_token() shell._pending_context_compaction_frame = { @@ -4029,7 +4327,9 @@ def test_refresh_shell_frame_if_needed_skips_for_pending_context_compaction_fram self.assertFalse(changed) refresh.assert_not_called() - def test_refresh_shell_frame_if_needed_skips_when_session_context_freezes(self) -> None: + def test_refresh_shell_frame_if_needed_skips_when_session_context_freezes( + self, + ) -> None: shell = self._make_shell() shell._last_shell_frame_token = shell._current_shell_frame_token() session = shell.runtime.inspect_session(shell.session_id) @@ -4097,7 +4397,9 @@ def test_opener_uses_continuity_driven_wake_summary(self) -> None: self.assertNotIn("Resume active", shell.transcript[0].body) self.assertNotIn("internal projection", shell.transcript[0].body) - def test_opener_hides_internal_defer_summary_when_no_actionable_current_work_exists(self) -> None: + def test_opener_hides_internal_defer_summary_when_no_actionable_current_work_exists( + self, + ) -> None: shell = self._make_shell() shell.runtime.update_user_state( profile_id=shell.runtime.inspect_session(shell.session_id).profile_id, @@ -4121,7 +4423,10 @@ def test_opener_keeps_blank_user_profile_flow_light(self) -> None: self.assertIn("I'll start holding this new elephant with you.", shell.transcript[0].body) self.assertNotIn("Welcome back", shell.transcript[0].body) self.assertIn("What should I call you", shell.transcript[0].body) - self.assertNotIn("one durable thing I should keep in mind from the start", shell.transcript[0].body) + self.assertNotIn( + "one durable thing I should keep in mind from the start", + shell.transcript[0].body, + ) def test_prime_transcript_prefers_model_generated_opening_reply(self) -> None: shell = self._make_shell() @@ -4138,7 +4443,9 @@ def test_prime_transcript_prefers_model_generated_opening_reply(self) -> None: self.assertEqual(len(shell.transcript), 1) self.assertEqual(shell.transcript[0].body, "startup-reply:I'm already here.") - def test_prime_transcript_renders_new_elephant_opening_without_runtime_label(self) -> None: + def test_prime_transcript_renders_new_elephant_opening_without_runtime_label( + self, + ) -> None: shell = self._make_shell(opened="Shaped new") shell.transcript = [] shell._rendered_entries = 0 @@ -4146,7 +4453,9 @@ def test_prime_transcript_renders_new_elephant_opening_without_runtime_label(sel with mock.patch.object( CliRuntime, "generate_opening_reply", - return_value=SimpleNamespace(execution=SimpleNamespace(summary="startup-reply:I'm here. What should I call you?")), + return_value=SimpleNamespace( + execution=SimpleNamespace(summary="startup-reply:I'm here. What should I call you?") + ), ) as generate_opening_reply: shell._prime_transcript() @@ -4157,7 +4466,9 @@ def test_prime_transcript_renders_new_elephant_opening_without_runtime_label(sel self.assertNotIn("Shaped new", prompt) self.assertNotIn("welcome back", shell.transcript[0].body.lower()) - def test_prime_transcript_passes_known_name_and_active_state_into_startup_prompt(self) -> None: + def test_prime_transcript_passes_known_name_and_active_state_into_startup_prompt( + self, + ) -> None: shell = self._make_shell( opened="Opened elephant atlas", user_profile_text=render_user_profile_text( @@ -4174,7 +4485,9 @@ def test_prime_transcript_passes_known_name_and_active_state_into_startup_prompt with mock.patch.object( CliRuntime, "generate_opening_reply", - return_value=SimpleNamespace(execution=SimpleNamespace(summary="startup-reply:Bit, I still have the release State in view.")), + return_value=SimpleNamespace( + execution=SimpleNamespace(summary="startup-reply:Bit, I still have the release State in view.") + ), ) as generate_opening_reply: shell._prime_transcript() @@ -4184,10 +4497,13 @@ def test_prime_transcript_passes_known_name_and_active_state_into_startup_prompt self.assertNotIn("their current context is Building durable agent systems.", prompt) self.assertNotIn("returning to an ongoing relationship", prompt) self.assertNotIn("Opened elephant atlas", prompt) - self.assertNotIn('Live thread', prompt) + self.assertNotIn("Live thread", prompt) self.assertNotIn("private posture signals only", prompt) self.assertIn("one natural message", prompt) - self.assertEqual(shell.transcript[0].body, "startup-reply:Bit, I still have the release State in view.") + self.assertEqual( + shell.transcript[0].body, + "startup-reply:Bit, I still have the release State in view.", + ) def test_existing_elephant_open_does_not_render_user_questionnaire(self) -> None: shell = self._make_shell( @@ -4218,7 +4534,9 @@ def test_opener_mentions_durable_thread_when_state_focus_is_missing(self) -> Non self.assertEqual(len(shell.transcript), 1) self.assertIn("If something matters right now", shell.transcript[0].body) - def test_existing_elephant_open_skips_user_onboarding_when_profile_fields_are_complete(self) -> None: + def test_existing_elephant_open_skips_user_onboarding_when_profile_fields_are_complete( + self, + ) -> None: shell = self._make_shell( opened="Opened elephant atlas", user_profile_text=render_user_profile_text( @@ -4259,11 +4577,18 @@ def test_state_focus_onboarding_skips_when_durable_state_focus_exists(self) -> N shell._prime_transcript() self.assertEqual(len(shell.transcript), 1) - self.assertNotIn("If there's something you want me to keep carrying", shell.transcript[-1].body) + self.assertNotIn( + "If there's something you want me to keep carrying", + shell.transcript[-1].body, + ) def test_shell_welcome_copy_and_boot_delays_support_a_visible_entry(self) -> None: self.assertEqual(SHELL_WELCOME_HEADLINE, "Your elephant still knows the path.") - self.assertAlmostEqual((STARTUP_SEQUENCE_STEP_DELAY * 4) + STARTUP_SEQUENCE_FINAL_DELAY, 3.0, delta=0.12) + self.assertAlmostEqual( + (STARTUP_SEQUENCE_STEP_DELAY * 4) + STARTUP_SEQUENCE_FINAL_DELAY, + 3.0, + delta=0.12, + ) self.assertGreaterEqual(STARTUP_SEQUENCE_STEP_DELAY, 0.50) self.assertGreaterEqual(STARTUP_SEQUENCE_FINAL_DELAY, 0.50) diff --git a/tests/unit/cli/test_shell_clipboard.py b/tests/unit/cli/test_shell_clipboard.py index 17a758a..e6cf407 100644 --- a/tests/unit/cli/test_shell_clipboard.py +++ b/tests/unit/cli/test_shell_clipboard.py @@ -29,7 +29,9 @@ def test_build_text_attachment_uses_compact_label_and_full_payload(self) -> None self.assertEqual(attachment.display_label, "[Pasted Content 10 chars]") self.assertEqual(attachment.prompt_fragment, "[Clipboard text]\nalpha\nbeta") - def test_build_path_attachment_uses_filename_but_preserves_absolute_path(self) -> None: + def test_build_path_attachment_uses_filename_but_preserves_absolute_path( + self, + ) -> None: attachment = build_path_attachment("./notes/design.md") self.assertIsNotNone(attachment) @@ -40,7 +42,9 @@ def test_build_path_attachment_uses_filename_but_preserves_absolute_path(self) - self.assertIn("design.md", attachment.prompt_fragment) self.assertTrue(Path(attachment.prompt_fragment.split(":", 1)[1]).is_absolute()) - def test_compile_submission_keeps_visible_summary_separate_from_full_prompt(self) -> None: + def test_compile_submission_keeps_visible_summary_separate_from_full_prompt( + self, + ) -> None: text_attachment = build_text_attachment("full copied text payload") file_attachment = build_path_attachment("./notes/design.md") assert text_attachment is not None @@ -61,7 +65,9 @@ def test_compile_submission_keeps_visible_summary_separate_from_full_prompt(self self.assertEqual(submission.event_payload["message"], submission.display_command) self.assertIn("full copied text payload", submission.event_payload["message"]) - def test_compile_submission_ignores_clipboard_attachments_for_slash_commands(self) -> None: + def test_compile_submission_ignores_clipboard_attachments_for_slash_commands( + self, + ) -> None: attachment = build_text_attachment("keep out of slash command") assert attachment is not None diff --git a/tests/unit/cli/test_shell_composer.py b/tests/unit/cli/test_shell_composer.py index 2a9e47e..6a441e0 100644 --- a/tests/unit/cli/test_shell_composer.py +++ b/tests/unit/cli/test_shell_composer.py @@ -20,7 +20,9 @@ def _make_shell(self) -> SimpleNamespace: _prompt_continuation=lambda: " ", ) - def test_read_command_runs_prompt_session_in_thread_when_loop_is_active(self) -> None: + def test_read_command_runs_prompt_session_in_thread_when_loop_is_active( + self, + ) -> None: shell = self._make_shell() captured: dict[str, object] = {} @@ -39,7 +41,11 @@ def prompt(self, *args, **kwargs): mock.patch.object(shell_composer, "prompt_toolkit_composer_available", return_value=False), mock.patch.object(shell_composer, "shell_history", return_value=mock.sentinel.history), mock.patch.object(shell_composer, "prompt_toolkit_loop_running", return_value=True), - mock.patch.object(shell_composer, "prompt_toolkit_output_without_cpr", return_value=mock.sentinel.output), + mock.patch.object( + shell_composer, + "prompt_toolkit_output_without_cpr", + return_value=mock.sentinel.output, + ), mock.patch.object(cli_shell, "ShellCompleter", return_value=mock.sentinel.completer), ): result = shell_composer.read_command(shell) @@ -73,18 +79,37 @@ def exit(self, result=None): mock.patch.object(shell_composer, "PROMPT_TOOLKIT_AVAILABLE", True), mock.patch.object(shell_composer, "prompt_toolkit_composer_available", return_value=True), mock.patch.object(shell_composer, "build_prompt_buffer", return_value=buffer), - mock.patch.object(shell_composer, "build_input_window", return_value=mock.sentinel.input_window), - mock.patch.object(shell_composer, "build_command_palette", return_value=mock.sentinel.command_palette), + mock.patch.object( + shell_composer, + "build_input_window", + return_value=mock.sentinel.input_window, + ), + mock.patch.object( + shell_composer, + "build_command_palette", + return_value=mock.sentinel.command_palette, + ), mock.patch.object(shell_composer, "build_composer_body", return_value=mock.sentinel.body), - mock.patch.object(shell_composer, "prompt_toolkit_output_without_cpr", return_value=mock.sentinel.output), + mock.patch.object( + shell_composer, + "prompt_toolkit_output_without_cpr", + return_value=mock.sentinel.output, + ), mock.patch.object(shell_composer, "Application", _FakeApplication), - mock.patch.object(shell_composer, "Layout", side_effect=lambda body, focused_element=None: (body, focused_element)), + mock.patch.object( + shell_composer, + "Layout", + side_effect=lambda body, focused_element=None: (body, focused_element), + ), mock.patch.object(shell_composer, "prompt_toolkit_loop_running", return_value=True), ): result = shell_composer.read_command(shell) self.assertEqual(result, "hello from app") - self.assertEqual(captured["application_init"]["layout"], (mock.sentinel.body, mock.sentinel.input_window)) + self.assertEqual( + captured["application_init"]["layout"], + (mock.sentinel.body, mock.sentinel.input_window), + ) self.assertIs(captured["application_init"]["output"], mock.sentinel.output) self.assertTrue(captured["run_kwargs"]["in_thread"]) @@ -181,19 +206,23 @@ def test_empty_transcript_returns_empty(self) -> None: self.assertEqual(shell_composer._last_user_message(shell), "") def test_picks_most_recent_user_entry(self) -> None: - shell = self._shell_with_transcript([ - self._entry("user", "first"), - self._entry("assistant", "reply"), - self._entry("user", "second"), - self._entry("assistant", "second reply"), - ]) + shell = self._shell_with_transcript( + [ + self._entry("user", "first"), + self._entry("assistant", "reply"), + self._entry("user", "second"), + self._entry("assistant", "second reply"), + ] + ) self.assertEqual(shell_composer._last_user_message(shell), "second") def test_skips_blank_user_entries(self) -> None: - shell = self._shell_with_transcript([ - self._entry("user", "real message"), - self._entry("user", " "), - ]) + shell = self._shell_with_transcript( + [ + self._entry("user", "real message"), + self._entry("user", " "), + ] + ) self.assertEqual(shell_composer._last_user_message(shell), "real message") def test_missing_transcript_attr_does_not_raise(self) -> None: diff --git a/tests/unit/cli/test_shell_opening.py b/tests/unit/cli/test_shell_opening.py index d23b88e..47fe570 100644 --- a/tests/unit/cli/test_shell_opening.py +++ b/tests/unit/cli/test_shell_opening.py @@ -10,7 +10,9 @@ class ShellOpeningTest(unittest.TestCase): - def test_compose_shell_opener_requests_name_when_user_profile_is_blank(self) -> None: + def test_compose_shell_opener_requests_name_when_user_profile_is_blank( + self, + ) -> None: opener = compose_shell_opener( ShellOpeningContext( opened="Shaped new", @@ -30,7 +32,9 @@ def test_compose_shell_opener_requests_name_when_user_profile_is_blank(self) -> self.assertIn("What should I call you", opener) self.assertNotIn("one durable thing I should keep in mind from the start", opener) - def test_compose_shell_opener_uses_wake_summary_when_user_profile_exists(self) -> None: + def test_compose_shell_opener_uses_wake_summary_when_user_profile_exists( + self, + ) -> None: opener = compose_shell_opener( ShellOpeningContext( opened="Opened elephant atlas", @@ -45,10 +49,18 @@ def test_compose_shell_opener_uses_wake_summary_when_user_profile_exists(self) - ) self.assertIn("I'm here, Bit. I still have the useful shape of our current work.", opener) - self.assertIn("I'll keep the next useful step visible without turning this into a status report.", opener) - self.assertIn("I still have Ship the release in view; do you want to keep going there?", opener) + self.assertIn( + "I'll keep the next useful step visible without turning this into a status report.", + opener, + ) + self.assertIn( + "I still have Ship the release in view; do you want to keep going there?", + opener, + ) - def test_compose_shell_opener_invites_current_work_when_state_focus_is_missing(self) -> None: + def test_compose_shell_opener_invites_current_work_when_state_focus_is_missing( + self, + ) -> None: opener = compose_shell_opener( ShellOpeningContext( opened="Opened elephant atlas", @@ -112,7 +124,9 @@ def test_compose_shell_opening_instruction_is_one_shot_and_humane(self) -> None: self.assertNotIn("not a greeter or product surface", prompt) self.assertNotIn("optionally include one gentle question", prompt) - def test_compose_shell_opening_instruction_surfaces_known_name_and_current_work(self) -> None: + def test_compose_shell_opening_instruction_surfaces_known_name_and_current_work( + self, + ) -> None: prompt = compose_shell_opening_instruction( ShellOpeningContext( opened="Opened elephant atlas", @@ -144,7 +158,9 @@ def test_compose_shell_opening_instruction_surfaces_known_name_and_current_work( self.assertNotIn("work item ids", prompt) self.assertNotIn("do not mention prompts", prompt.lower()) - def test_compose_shell_opening_instruction_after_init_requests_warm_live_opener(self) -> None: + def test_compose_shell_opening_instruction_after_init_requests_warm_live_opener( + self, + ) -> None: prompt = compose_shell_opening_instruction( ShellOpeningContext( opened="Born new", @@ -189,7 +205,9 @@ def test_compose_shell_opening_instruction_after_init_requests_warm_live_opener( self.assertNotIn("personal hobbies: reading, music", prompt) self.assertNotIn("Output: one natural message", prompt) - def test_compose_shell_opening_instruction_sanitizes_internal_wake_refs(self) -> None: + def test_compose_shell_opening_instruction_sanitizes_internal_wake_refs( + self, + ) -> None: prompt = compose_shell_opening_instruction( ShellOpeningContext( opened="Opened elephant atlas", @@ -201,7 +219,7 @@ def test_compose_shell_opening_instruction_sanitizes_internal_wake_refs(self) -> wake_summary=( "The episode resumed from a prior collaboration and should continue the active elephant. " "Replay evidence event:f526dcf07c2048f0af65226e60807364:structured-turn:memory retains a successful action chain for this work. " - "The internal projection keeps \"i am xunzhuo\" active as the next step." + 'The internal projection keeps "i am xunzhuo" active as the next step.' ), has_state_focus=True, ) @@ -227,11 +245,16 @@ def test_compose_shell_opener_sanitizes_internal_wake_refs(self) -> None: ) ) - self.assertIn("I still have The active elephant is ready to continue in view; do you want to keep going there?", opener) + self.assertIn( + "I still have The active elephant is ready to continue in view; do you want to keep going there?", + opener, + ) self.assertNotIn("work:90920371a588", opener) self.assertNotIn("event:f526", opener) - def test_compose_shell_opening_instruction_omits_deferred_wake_summary(self) -> None: + def test_compose_shell_opening_instruction_omits_deferred_wake_summary( + self, + ) -> None: prompt = compose_shell_opening_instruction( ShellOpeningContext( opened="Opened elephant atlas", @@ -255,7 +278,9 @@ def test_compose_shell_opening_instruction_omits_deferred_wake_summary(self) -> self.assertNotIn("something is already open —", prompt) self.assertNotIn("No actionable current work was available", prompt) - def test_compose_shell_opening_instruction_includes_init_first_language(self) -> None: + def test_compose_shell_opening_instruction_includes_init_first_language( + self, + ) -> None: prompt = compose_shell_opening_instruction( ShellOpeningContext( opened="Born new", diff --git a/tests/unit/cli/test_shell_polish.py b/tests/unit/cli/test_shell_polish.py index e6c5e38..36ccdd2 100644 --- a/tests/unit/cli/test_shell_polish.py +++ b/tests/unit/cli/test_shell_polish.py @@ -11,7 +11,6 @@ import os from types import SimpleNamespace import unittest -from unittest import mock from apps.cli import shell_render, shell_ui, turn_metrics @@ -30,20 +29,24 @@ def test_single_success_returns_verb(self) -> None: ) def test_repeated_tool_collapses_to_count(self) -> None: - result = turn_metrics.condense_tool_summary([ - ("tool.file.read", True, 1), - ("tool.file.read", True, 2), - ("tool.file.read", True, 3), - ]) + result = turn_metrics.condense_tool_summary( + [ + ("tool.file.read", True, 1), + ("tool.file.read", True, 2), + ("tool.file.read", True, 3), + ] + ) self.assertEqual(result, "read × 3") def test_mixed_tools_join_with_middle_dot(self) -> None: - result = turn_metrics.condense_tool_summary([ - ("tool.file.read", True, 1), - ("tool.file.read", True, 2), - ("tool.file.patch", True, 3), - ("tool.file.search", True, 4), - ]) + result = turn_metrics.condense_tool_summary( + [ + ("tool.file.read", True, 1), + ("tool.file.read", True, 2), + ("tool.file.patch", True, 3), + ("tool.file.search", True, 4), + ] + ) # Order follows Counter.most_common — read (2) first, then two ties. self.assertIn("read × 2", result) self.assertIn("edited", result) @@ -54,22 +57,29 @@ def test_mixed_tools_join_with_middle_dot(self) -> None: self.assertIn("searched", parts) def test_single_failure_surfaces_with_failed_suffix(self) -> None: - result = turn_metrics.condense_tool_summary([ - ("tool.file.read", True, 1), - ("tool.terminal.exec", False, 2), - ]) + result = turn_metrics.condense_tool_summary( + [ + ("tool.file.read", True, 1), + ("tool.terminal.exec", False, 2), + ] + ) self.assertEqual(result, "read · ran failed") def test_multiple_failures_collapse_to_count(self) -> None: - result = turn_metrics.condense_tool_summary([ - ("tool.terminal.exec", False, 1), - ("tool.file.read", False, 2), - ("tool.file.patch", False, 3), - ]) + result = turn_metrics.condense_tool_summary( + [ + ("tool.terminal.exec", False, 1), + ("tool.file.read", False, 2), + ("tool.file.patch", False, 3), + ] + ) self.assertIn("3 failures", result) def test_unknown_tool_falls_back_to_last_segment(self) -> None: - self.assertEqual(turn_metrics.condense_tool_summary([("tool.novel.widget", True, 1)]), "widget") + self.assertEqual( + turn_metrics.condense_tool_summary([("tool.novel.widget", True, 1)]), + "widget", + ) def test_tool_id_with_underscores_gets_spaces(self) -> None: # "tool.foo.my_cool_thing" -> "my cool thing" @@ -146,8 +156,7 @@ class WrapFileHyperlinkTests(unittest.TestCase): def setUp(self) -> None: # Snapshot relevant env vars so tests don't leak into each other. self._env_snapshot = { - key: os.environ.get(key) - for key in ("TERM_PROGRAM", "TERM", "NO_COLOR", "ELEPHANT_NO_HYPERLINKS") + key: os.environ.get(key) for key in ("TERM_PROGRAM", "TERM", "NO_COLOR", "ELEPHANT_NO_HYPERLINKS") } for key in self._env_snapshot: os.environ.pop(key, None) diff --git a/tests/unit/cli/test_shell_progress.py b/tests/unit/cli/test_shell_progress.py index 8af46b6..948f65c 100644 --- a/tests/unit/cli/test_shell_progress.py +++ b/tests/unit/cli/test_shell_progress.py @@ -40,15 +40,19 @@ class ShellProgressTest(unittest.TestCase): - def _evidence_event(self, *, tool_id: str, action: str | None = None, evidence_id: str = "evidence-release") -> ToolLifecycleEvent: + def _evidence_event( + self, + *, + tool_id: str, + action: str | None = None, + evidence_id: str = "evidence-release", + ) -> ToolLifecycleEvent: invocation = ToolInvocation( invocation_id="invoke-evidence", tool_id=tool_id, session_id="session-test", arguments=( - {"evidence_id": evidence_id} - if action is None - else {"action": action, "evidence_id": evidence_id} + {"evidence_id": evidence_id} if action is None else {"action": action, "evidence_id": evidence_id} ), requested_at=datetime.now(timezone.utc), requester="test", @@ -92,16 +96,28 @@ def test_personal_model_search_trace_uses_current_label_and_preview(self) -> Non ) self.assertEqual(_tool_trace_label(event), "model") - self.assertEqual(_tool_trace_preview(event.invocation.arguments, tool_id="tool.personal_model.search"), "release notes") + self.assertEqual( + _tool_trace_preview(event.invocation.arguments, tool_id="tool.personal_model.search"), + "release notes", + ) - def test_personal_model_update_requested_trace_uses_current_prepare_label(self) -> None: + def test_personal_model_update_requested_trace_uses_current_prepare_label( + self, + ) -> None: event = self._tool_event( tool_id="tool.personal_model.update", - arguments={"action": "remember", "lens": "rapport", "topic": "assistant.review.style"}, + arguments={ + "action": "remember", + "lens": "rapport", + "topic": "assistant.review.style", + }, phase="requested", ) - self.assertEqual(tool_trace_line(None, event), "┊ 🌱 Calling learn · remember assistant.review.style…") + self.assertEqual( + tool_trace_line(None, event), + "┊ 🌱 Calling learn · remember assistant.review.style…", + ) def test_personal_model_search_requested_trace_uses_current_label(self) -> None: event = self._tool_event( @@ -112,7 +128,17 @@ def test_personal_model_search_requested_trace_uses_current_label(self) -> None: self.assertEqual(_tool_trace_label(event), "model") self.assertEqual(tool_trace_line(None, event), "┊ 🐘 Calling model · notes…") - self.assertEqual(tool_event_progress_line(None, self._tool_event(tool_id="tool.personal_model.search", arguments={"query": "notes"}, phase="execution.started")), "┊ 🐘 model notes") + self.assertEqual( + tool_event_progress_line( + None, + self._tool_event( + tool_id="tool.personal_model.search", + arguments={"query": "notes"}, + phase="execution.started", + ), + ), + "┊ 🐘 model notes", + ) def test_conversation_search_trace_uses_current_label(self) -> None: search = self._tool_event( @@ -149,9 +175,15 @@ def test_custom_mcp_started_trace_uses_extension_emoji(self) -> None: self.assertEqual(tool_event_progress_line(None, event), "┊ 🧩 mcp.km.hot-articles Top KM") - def test_tool_event_tracker_keeps_short_lived_feed_for_fast_personal_model_events(self) -> None: + def test_tool_event_tracker_keeps_short_lived_feed_for_fast_personal_model_events( + self, + ) -> None: holder, lock, observer = tool_event_tracker() - requested = self._tool_event(tool_id="tool.personal_model.search", arguments={"query": "evidence-release"}, phase="requested") + requested = self._tool_event( + tool_id="tool.personal_model.search", + arguments={"query": "evidence-release"}, + phase="requested", + ) completed = ToolLifecycleEvent( event_id="event-evidence-complete", invocation=requested.invocation, @@ -166,7 +198,9 @@ def test_tool_event_tracker_keeps_short_lived_feed_for_fast_personal_model_event self.assertEqual([event.event.phase for event in feed], ["requested", "execution.completed"]) - def test_kernel_event_tracker_forwards_skill_disclosures_without_adding_fake_stage(self) -> None: + def test_kernel_event_tracker_forwards_skill_disclosures_without_adding_fake_stage( + self, + ) -> None: captured: list[dict[str, object]] = [] holder, lock, observer = kernel_event_tracker(captured.append) @@ -194,7 +228,9 @@ def test_kernel_event_tracker_forwards_skill_disclosures_without_adding_fake_sta ], ) - def test_kernel_event_tracker_keeps_context_compaction_visible_after_usage_events(self) -> None: + def test_kernel_event_tracker_keeps_context_compaction_visible_after_usage_events( + self, + ) -> None: holder, lock, observer = kernel_event_tracker() observer( { @@ -221,7 +257,9 @@ def test_kernel_event_tracker_keeps_context_compaction_visible_after_usage_event self.assertTrue(any(stage["payload"]["stage"] == "context-compact" for stage in stages)) self.assertTrue(any(stage["payload"]["stage"] == "context-usage" for stage in stages)) - def test_kernel_event_tracker_keeps_state_focus_visible_after_compaction_and_usage_events(self) -> None: + def test_kernel_event_tracker_keeps_state_focus_visible_after_compaction_and_usage_events( + self, + ) -> None: holder, lock, observer = kernel_event_tracker() observer( { @@ -302,7 +340,9 @@ def test_recall_progress_line_shows_no_match(self) -> None: self.assertEqual(line, "┊ 🗺️ recall no signal") - def test_usage_progress_line_shows_projection_while_provider_usage_is_pending(self) -> None: + def test_usage_progress_line_shows_projection_while_provider_usage_is_pending( + self, + ) -> None: line = turn_usage_progress_line( kernel_stage_events=( { @@ -315,7 +355,10 @@ def test_usage_progress_line_shows_projection_while_provider_usage_is_pending(se ) ) - self.assertEqual(line, "┊ 📈 request provider running · sent est 16000/128000 · 12% · usage pending") + self.assertEqual( + line, + "┊ 📈 request provider running · sent est 16000/128000 · 12% · usage pending", + ) def test_terminal_progress_line_shows_shell_command(self) -> None: event = self._tool_event( @@ -353,12 +396,19 @@ def test_sub_agents_progress_line_shows_name_and_prompt(self) -> None: phase="execution.started", ) - self.assertEqual(tool_event_progress_line(None, event), "┊ 🐘 herd run · reviewer: inspect the cron scheduler") + self.assertEqual( + tool_event_progress_line(None, event), + "┊ 🐘 herd run · reviewer: inspect the cron scheduler", + ) def test_sub_agents_progress_line_shows_start_action(self) -> None: event = self._tool_event( tool_id="tool.sub_agents", - arguments={"action": "start", "name": "reviewer", "task": "inspect the cron scheduler"}, + arguments={ + "action": "start", + "name": "reviewer", + "task": "inspect the cron scheduler", + }, phase="execution.started", ) @@ -374,7 +424,10 @@ def test_sub_agents_progress_line_shows_status_action_and_run_id(self) -> None: phase="execution.started", ) - self.assertEqual(tool_event_progress_line(None, event), "┊ 🐘 herd status · subrun-abc123") + self.assertEqual( + tool_event_progress_line(None, event), + "┊ 🐘 herd status · subrun-abc123", + ) def test_sub_agents_progress_lines_expand_batch_tasks(self) -> None: event = self._tool_event( @@ -442,7 +495,9 @@ def test_context_progress_line_surfaces_projection_compaction(self) -> None: self.assertIn("80->12 messages", line) self.assertIn("scanner: 2 cached / 5 pending / 1 missed", line) - def test_context_progress_line_marks_projection_rewrite_when_tokens_do_not_shrink(self) -> None: + def test_context_progress_line_marks_projection_rewrite_when_tokens_do_not_shrink( + self, + ) -> None: line = loop_context_progress_line( kernel_stage_events=( { @@ -563,7 +618,11 @@ def test_process_trace_uses_proc_label_and_action_preview(self) -> None: def test_tool_trace_keeps_gap_after_wide_variation_emoji(self) -> None: cases = ( ("tool.file.write", {"path": "notes.md"}, "✍️ write"), - ("tool.process.manage", {"action": "poll", "process_id": "proc_123"}, "🖥️ proc"), + ( + "tool.process.manage", + {"action": "poll", "process_id": "proc_123"}, + "🖥️ proc", + ), ("tool.code.execute", {"code": "print('ok')"}, "🛠️ code"), ) for tool_id, arguments, expected in cases: @@ -577,7 +636,10 @@ def test_tool_trace_keeps_gap_after_wide_variation_emoji(self) -> None: fragments = render_tool_trace_fragments(line or "") self.assertIn(expected, line or "") - self.assertIn(expected.split(" ", 1)[0] + " ", "".join(text for _style, text in fragments)) + self.assertIn( + expected.split(" ", 1)[0] + " ", + "".join(text for _style, text in fragments), + ) def test_turn_phase_cycles_marker_frames(self) -> None: self.assertEqual(turn_phase(0)[0], "✧") @@ -595,9 +657,7 @@ class _ShellProbe: def __init__(self) -> None: self._pending_commands = [] self._rendered_entries = 0 - self.transcript = [ - _Entry("\n".join(f"┊ 💻 line {index}" for index in range(20))) - ] + self.transcript = [_Entry("\n".join(f"┊ 💻 line {index}" for index in range(20)))] lines = live_tool_feed_lines(_ShellProbe()) @@ -606,7 +666,9 @@ def __init__(self) -> None: self.assertIn("┊ 💻 line 19", lines) self.assertFalse(any("earlier tool line(s) hidden" in line for line in lines)) - def test_live_tool_feed_lines_collapses_consecutive_duplicate_trace_rows(self) -> None: + def test_live_tool_feed_lines_collapses_consecutive_duplicate_trace_rows( + self, + ) -> None: class _Entry: def __init__(self, body: str) -> None: self.kind = "tooltrace" diff --git a/tests/unit/cli/test_shell_skills.py b/tests/unit/cli/test_shell_skills.py index 826ad30..a4cf1cb 100644 --- a/tests/unit/cli/test_shell_skills.py +++ b/tests/unit/cli/test_shell_skills.py @@ -64,7 +64,11 @@ def test_command_palette_hides_dynamic_skill_slash_commands(self) -> None: def test_skill_slash_command_without_instruction_loads_skill_metadata(self) -> None: shell = self._make_shell() - with mock.patch.object(shell, "_run_tool_with_progress", return_value=SimpleNamespace(summary="loaded")): + with mock.patch.object( + shell, + "_run_tool_with_progress", + return_value=SimpleNamespace(summary="loaded"), + ): handled = shell._handle_slash_command("/apple-notes") self.assertFalse(handled) @@ -72,21 +76,37 @@ def test_skill_slash_command_without_instruction_loads_skill_metadata(self) -> N self.assertIn("display_name: Apple Notes", shell.transcript[-1].body) self.assertIn("run: /apple-notes ", shell.transcript[-1].body) - def test_skill_slash_command_with_instruction_injects_skill_guidance_into_turn(self) -> None: + def test_skill_slash_command_with_instruction_injects_skill_guidance_into_turn( + self, + ) -> None: shell = self._make_shell() - with mock.patch.object(shell, "_run_tool_with_progress", return_value=SimpleNamespace(summary="loaded")): + with mock.patch.object( + shell, + "_run_tool_with_progress", + return_value=SimpleNamespace(summary="loaded"), + ): with mock.patch.object(shell, "_render_pending_entries", return_value=None): with mock.patch.object(shell, "_refresh_shell_frame", return_value=None): - with mock.patch.object(shell, "_run_turn_with_progress", return_value=self._fake_outcome("Notes opened.")) as run_turn: + with mock.patch.object( + shell, + "_run_turn_with_progress", + return_value=self._fake_outcome("Notes opened."), + ) as run_turn: handled = shell._handle_slash_command("/apple-notes open Notes and create a travel checklist") self.assertFalse(handled) prompt = run_turn.call_args.args[0] - self.assertIn('[SYSTEM: This turn references the "Apple Notes" skill from the frozen skill index.]', prompt) + self.assertIn( + '[SYSTEM: This turn references the "Apple Notes" skill from the frozen skill index.]', + prompt, + ) self.assertIn("User request: open Notes and create a travel checklist", prompt) self.assertEqual(shell.transcript[-2].kind, "user") - self.assertEqual(shell.transcript[-2].body, "/apple-notes open Notes and create a travel checklist") + self.assertEqual( + shell.transcript[-2].body, + "/apple-notes open Notes and create a travel checklist", + ) self.assertEqual(shell.transcript[-1].kind, "assistant") self.assertIn("Notes opened.", shell.transcript[-1].body) diff --git a/tests/unit/cli/test_snapshot_io.py b/tests/unit/cli/test_snapshot_io.py index 22138c4..57ec1c2 100644 --- a/tests/unit/cli/test_snapshot_io.py +++ b/tests/unit/cli/test_snapshot_io.py @@ -25,7 +25,10 @@ def test_write_snapshot_payload_replaces_with_valid_json(self) -> None: write_snapshot_payload(path, {"session": {"session_id": "two"}}) - self.assertEqual(json.loads(path.read_text(encoding="utf-8")), {"session": {"session_id": "two"}}) + self.assertEqual( + json.loads(path.read_text(encoding="utf-8")), + {"session": {"session_id": "two"}}, + ) if __name__ == "__main__": diff --git a/tests/unit/context/AGENTS.md b/tests/unit/context/AGENTS.md index ea9cdf4..874f09c 100644 --- a/tests/unit/context/AGENTS.md +++ b/tests/unit/context/AGENTS.md @@ -7,4 +7,3 @@ Rules: - keep tests deterministic and dependency-light - cover budget allocation, retrieval ordering, summary hooks, and rendering - do not assume an app process or external model provider - diff --git a/tests/unit/context/test_context_projection.py b/tests/unit/context/test_context_projection.py index af361e5..bf6446b 100644 --- a/tests/unit/context/test_context_projection.py +++ b/tests/unit/context/test_context_projection.py @@ -24,11 +24,22 @@ from packages.contracts import ContextBundle, ExecutionResult, PromptMessage from packages.contracts.layers import Episode from packages.contracts.runtime import PersonalModelRuntimeState -from packages.embeddings import EmbeddingHealth, EmbeddingPreloadEntry, EmbeddingPreloadState, EmbeddingVector +from packages.embeddings import ( + EmbeddingHealth, + EmbeddingPreloadEntry, + EmbeddingPreloadState, + EmbeddingVector, +) class _FakeProjectionEmbeddingService: - def __init__(self, *, loaded: bool = True, auto_cache: bool = True, promote_pending_on_probe: bool = False) -> None: + def __init__( + self, + *, + loaded: bool = True, + auto_cache: bool = True, + promote_pending_on_probe: bool = False, + ) -> None: self.loaded = loaded self.auto_cache = auto_cache self.promote_pending_on_probe = promote_pending_on_probe @@ -48,7 +59,14 @@ def health(self) -> EmbeddingHealth: metadata={"runtime_state": "loaded" if self.loaded else "cold"}, ) - def queue_backfill(self, *, target: str, entries, latency_mode: str = "balanced", provider_id: str | None = None): + def queue_backfill( + self, + *, + target: str, + entries, + latency_mode: str = "balanced", + provider_id: str | None = None, + ): del provider_id queued_entries = tuple(entries) self.queued_targets.append((target, len(queued_entries), latency_mode)) @@ -64,7 +82,9 @@ def queue_backfill(self, *, target: str, entries, latency_mode: str = "balanced" ) for index, entry in enumerate(queued_entries): vector_index = 0 if entry.metadata.get("kind") == "query" or "database" in entry.text else 1 - self.cache[(target, entry.cache_key, 64)] = _unit_vector(64, index=vector_index, source_text=entry.text, text_index=index) + self.cache[(target, entry.cache_key, 64)] = _unit_vector( + 64, index=vector_index, source_text=entry.text, text_index=index + ) return EmbeddingPreloadState( provider_id="elephant-local-embed", model_id="llm-semantic-router/elephant-embed", @@ -74,11 +94,25 @@ def queue_backfill(self, *, target: str, entries, latency_mode: str = "balanced" pending_targets=(), ) - def cached_vector(self, *, target: str, cache_key: str, dimensions: int, provider_id: str | None = None): + def cached_vector( + self, + *, + target: str, + cache_key: str, + dimensions: int, + provider_id: str | None = None, + ): del provider_id return self.cache.get((target, cache_key, dimensions)) - def pending_vector(self, *, target: str, cache_key: str, dimensions: int, provider_id: str | None = None) -> bool: + def pending_vector( + self, + *, + target: str, + cache_key: str, + dimensions: int, + provider_id: str | None = None, + ) -> bool: del provider_id key = (target, cache_key, dimensions) if key not in self.pending_keys: @@ -169,13 +203,19 @@ def test_compaction_preserves_head_and_tail_while_summarizing_middle(self) -> No self.assertEqual(projection.messages[:2], messages[:2]) self.assertEqual(projection.messages[-4:], messages[-4:]) self.assertLess(len(projection.messages), len(messages)) - self.assertEqual(projection.result.compacted_line_count, len(messages) - len(projection.messages)) + self.assertEqual( + projection.result.compacted_line_count, + len(messages) - len(projection.messages), + ) def test_message_compaction_preserves_roles_and_tool_results_in_tail(self) -> None: messages = tuple( PromptMessage(role="user", content=f"completed request {index} " + ("payload " * 100)) if index % 3 == 0 - else PromptMessage(role="assistant", content=f"completed response {index} " + ("implementation " * 100)) + else PromptMessage( + role="assistant", + content=f"completed response {index} " + ("implementation " * 100), + ) if index % 3 == 1 else PromptMessage( role="tool", @@ -202,7 +242,10 @@ def test_message_compaction_preserves_roles_and_tool_results_in_tail(self) -> No ) self.assertTrue(projection.result.compacted) - self.assertEqual(tuple(message.role for message in projection.messages[:2]), ("user", "assistant")) + self.assertEqual( + tuple(message.role for message in projection.messages[:2]), + ("user", "assistant"), + ) self.assertEqual(projection.messages[-1].role, "tool") self.assertEqual(projection.messages[-1].tool_call_id, "call-29") self.assertIn("CONTEXT COMPACTION - REFERENCE ONLY", projection.summary) @@ -242,13 +285,19 @@ def test_tail_compaction_keeps_tool_call_groups_atomic(self) -> None: ) self.assertTrue(projection.result.compacted) - self.assertEqual(tuple(message.role for message in projection.messages[-2:]), ("assistant", "tool")) + self.assertEqual( + tuple(message.role for message in projection.messages[-2:]), + ("assistant", "tool"), + ) self.assertEqual(projection.messages[-2].tool_calls[0]["id"], "call-live") self.assertEqual(projection.messages[-1].tool_call_id, "call-live") def test_usage_force_can_summarize_a_single_oversized_completed_turn(self) -> None: messages = ( - PromptMessage(role="user", content="oversized completed request " + ("payload " * 5000)), + PromptMessage( + role="user", + content="oversized completed request " + ("payload " * 5000), + ), PromptMessage(role="assistant", content="completed answer"), ) compactor = SessionProjectionCompactor( @@ -277,10 +326,35 @@ def test_usage_force_can_summarize_a_single_oversized_completed_turn(self) -> No def test_im_compaction_uses_burst_tail_without_protecting_old_head(self) -> None: base = datetime(2026, 5, 9, 8, 0, tzinfo=timezone.utc) messages = ( - PromptMessage(role="user", content="morning topic", metadata={"projection_surface": "im", "created_at": base.isoformat()}), - PromptMessage(role="assistant", content="morning reply", metadata={"projection_surface": "im", "created_at": (base + timedelta(minutes=1)).isoformat()}), - PromptMessage(role="user", content="evening topic", metadata={"projection_surface": "im", "created_at": (base + timedelta(hours=10)).isoformat()}), - PromptMessage(role="assistant", content="evening reply", metadata={"projection_surface": "im", "created_at": (base + timedelta(hours=10, minutes=1)).isoformat()}), + PromptMessage( + role="user", + content="morning topic", + metadata={"projection_surface": "im", "created_at": base.isoformat()}, + ), + PromptMessage( + role="assistant", + content="morning reply", + metadata={ + "projection_surface": "im", + "created_at": (base + timedelta(minutes=1)).isoformat(), + }, + ), + PromptMessage( + role="user", + content="evening topic", + metadata={ + "projection_surface": "im", + "created_at": (base + timedelta(hours=10)).isoformat(), + }, + ), + PromptMessage( + role="assistant", + content="evening reply", + metadata={ + "projection_surface": "im", + "created_at": (base + timedelta(hours=10, minutes=1)).isoformat(), + }, + ), ) compactor = SessionProjectionCompactor( policy=ProjectionCompactionPolicy( @@ -301,7 +375,10 @@ def test_im_compaction_uses_burst_tail_without_protecting_old_head(self) -> None self.assertTrue(projection.result.compacted) self.assertEqual(projection.result.protected_head_count, 0) - self.assertEqual(tuple(message.content for message in projection.messages), ("evening topic", "evening reply")) + self.assertEqual( + tuple(message.content for message in projection.messages), + ("evening topic", "evening reply"), + ) self.assertNotIn("morning topic", [message.content for message in projection.messages]) def test_embedding_ranked_middle_anchor_stays_role_preserved(self) -> None: @@ -351,9 +428,14 @@ def rank(self, *, query: str, candidates: tuple[str, ...], limit: int) -> tuple[ self.assertEqual(projection.result.protected_ranges, ("head:0-0", "tail:5-5")) self.assertEqual(len(projection.result.selected_raw_ids), 1) self.assertTrue(projection.result.summary_hash) - self.assertIn("semantic anchor: database migration decision", [message.content for message in projection.messages]) + self.assertIn( + "semantic anchor: database migration decision", + [message.content for message in projection.messages], + ) - def test_projection_embedding_backfill_queues_cached_turn_groups_after_runtime_is_loaded(self) -> None: + def test_projection_embedding_backfill_queues_cached_turn_groups_after_runtime_is_loaded( + self, + ) -> None: service = _FakeProjectionEmbeddingService(loaded=True) messages = ( PromptMessage(role="user", content="continue the database migration"), @@ -374,7 +456,9 @@ def test_projection_embedding_backfill_queues_cached_turn_groups_after_runtime_i self.assertEqual(service.embed_calls, 0) self.assertEqual(service.embed_text_calls, 0) - def test_projection_embedding_backfill_skips_cold_runtime_to_avoid_query_latency(self) -> None: + def test_projection_embedding_backfill_skips_cold_runtime_to_avoid_query_latency( + self, + ) -> None: service = _FakeProjectionEmbeddingService(loaded=False) state = queue_projection_history_embedding_backfill( @@ -388,13 +472,18 @@ def test_projection_embedding_backfill_skips_cold_runtime_to_avoid_query_latency self.assertEqual(service.embed_calls, 0) self.assertEqual(service.embed_text_calls, 0) - def test_embedding_projection_scorer_reads_cache_without_sync_embedding(self) -> None: + def test_embedding_projection_scorer_reads_cache_without_sync_embedding( + self, + ) -> None: service = _FakeProjectionEmbeddingService(loaded=True) queue_projection_history_embedding_backfill( service, messages=( PromptMessage(role="user", content="continue the database migration"), - PromptMessage(role="assistant", content="semantic anchor: database migration decision"), + PromptMessage( + role="assistant", + content="semantic anchor: database migration decision", + ), ), thread_focus="database migration", ) @@ -413,13 +502,18 @@ def test_embedding_projection_scorer_reads_cache_without_sync_embedding(self) -> self.assertEqual(service.embed_calls, 0) self.assertEqual(service.embed_text_calls, 0) - def test_embedding_projection_scorer_records_pending_and_missed_cache_state(self) -> None: + def test_embedding_projection_scorer_records_pending_and_missed_cache_state( + self, + ) -> None: service = _FakeProjectionEmbeddingService(loaded=True, auto_cache=False) queue_projection_history_embedding_backfill( service, messages=( PromptMessage(role="user", content="continue the database migration"), - PromptMessage(role="assistant", content="semantic anchor: database migration decision"), + PromptMessage( + role="assistant", + content="semantic anchor: database migration decision", + ), PromptMessage(role="assistant", content="pending architecture detail"), ), thread_focus="database migration", @@ -463,7 +557,10 @@ def test_embedding_projection_scorer_is_cache_first_and_non_blocking(self) -> No service, messages=( PromptMessage(role="user", content="continue the database migration"), - PromptMessage(role="assistant", content="semantic anchor: database migration decision"), + PromptMessage( + role="assistant", + content="semantic anchor: database migration decision", + ), ), thread_focus="database migration", ) @@ -538,7 +635,9 @@ def test_iterative_compaction_carries_previous_summary_as_reference(self) -> Non self.assertIn("## Handoff notes for recent tail", projection.summary) self.assertIn("follow-up 17", projection.summary) - def test_provider_summary_hook_uses_model_and_preserves_reference_only_header(self) -> None: + def test_provider_summary_hook_uses_model_and_preserves_reference_only_header( + self, + ) -> None: class _Provider: def __init__(self) -> None: self.calls: list[dict[str, object]] = [] @@ -558,7 +657,11 @@ def generate(self, **kwargs): now = datetime.now(timezone.utc) hook = ProviderProjectionSummaryHook( provider=provider, - profile=PersonalModelRuntimeState(profile_id="profile-test", display_name="Elephant Agent", mode="companion"), + profile=PersonalModelRuntimeState( + profile_id="profile-test", + display_name="Elephant Agent", + mode="companion", + ), session=Episode( episode_id="session-test", state_id="state:test", @@ -583,7 +686,9 @@ def generate(self, **kwargs): self.assertTrue(summary.startswith("[CONTEXT COMPACTION - REFERENCE ONLY]")) self.assertIn("Compact context safely", summary) - def test_provider_summary_hook_suppresses_stream_observer_during_internal_summary(self) -> None: + def test_provider_summary_hook_suppresses_stream_observer_during_internal_summary( + self, + ) -> None: streamed: list[str] = [] class _Provider: @@ -608,7 +713,11 @@ def generate(self, **kwargs): now = datetime.now(timezone.utc) hook = ProviderProjectionSummaryHook( provider=provider, - profile=PersonalModelRuntimeState(profile_id="profile-test", display_name="Elephant Agent", mode="companion"), + profile=PersonalModelRuntimeState( + profile_id="profile-test", + display_name="Elephant Agent", + mode="companion", + ), session=Episode( episode_id="session-test", state_id="state:test", diff --git a/tests/unit/context/test_context_runtime.py b/tests/unit/context/test_context_runtime.py index ce72b16..46f0fc7 100644 --- a/tests/unit/context/test_context_runtime.py +++ b/tests/unit/context/test_context_runtime.py @@ -30,7 +30,12 @@ class ContextRuntimeTest(unittest.TestCase): - def _session(self, *, interruption_state: str | None = None, parent_session_id: str | None = None) -> Episode: + def _session( + self, + *, + interruption_state: str | None = None, + parent_session_id: str | None = None, + ) -> Episode: return Episode( episode_id="session-1", state_id="state:test", @@ -156,9 +161,21 @@ def test_budget_manager_allocates_and_reports_overflow(self) -> None: 120, ( ContextBudgetRequest("stable_prefix", 32, minimum_tokens=16, required=True, priority=100), - ContextBudgetRequest("session_snapshot", 72, minimum_tokens=32, required=True, priority=90), + ContextBudgetRequest( + "session_snapshot", + 72, + minimum_tokens=32, + required=True, + priority=90, + ), ContextBudgetRequest("loop_context", 32, minimum_tokens=16, required=True, priority=80), - ContextBudgetRequest("request_attachments", 24, minimum_tokens=0, required=False, priority=10), + ContextBudgetRequest( + "request_attachments", + 24, + minimum_tokens=0, + required=False, + priority=10, + ), ), ) @@ -286,7 +303,9 @@ def test_source_trace_explains_compaction_and_retrieval(self) -> None: # Per R1, Source Trace is telemetry only — not rendered into the prompt. self.assertNotIn("Source Trace", detailed.rendered_prompt) - def test_steady_selection_prefers_work_item_linked_memory_over_newer_filler(self) -> None: + def test_steady_selection_prefers_work_item_linked_memory_over_newer_filler( + self, + ) -> None: runtime = ContextRuntime( planner=LayeredContextPlanner( budget_manager=DeterministicBudgetManager(), @@ -356,7 +375,8 @@ def test_retrieval_layer_renders_selected_memory_and_reason(self) -> None: # tag. It is now a short natural label (`Summary:`, `Decision:`, # `Relationship note:`, ...). evidence_lines = tuple( - line for line in snapshot_layer.content + line + for line in snapshot_layer.content if "Summary: The last turn asked for recovery after a gap." in line and "why:" in line ) self.assertTrue(evidence_lines) @@ -366,7 +386,9 @@ def test_retrieval_layer_renders_selected_memory_and_reason(self) -> None: prompt_snapshot = detailed.bundle.prompt_envelope.session_snapshot self.assertEqual(prompt_snapshot, "") - def test_session_snapshot_summary_keeps_profile_values_legible_when_truncated(self) -> None: + def test_session_snapshot_summary_keeps_profile_values_legible_when_truncated( + self, + ) -> None: runtime = ContextRuntime( planner=LayeredContextPlanner( budget_manager=DeterministicBudgetManager(), @@ -397,7 +419,9 @@ def test_session_snapshot_summary_keeps_profile_values_legible_when_truncated(se self.assertNotIn("Preferred name: Xunzhuo", prompt_snapshot) self.assertNotIn("MBTI: INTJ", prompt_snapshot) - def test_steady_layer_prefers_work_item_linked_and_corrected_memory_over_blind_recency(self) -> None: + def test_steady_layer_prefers_work_item_linked_and_corrected_memory_over_blind_recency( + self, + ) -> None: runtime = ContextRuntime( planner=LayeredContextPlanner( budget_manager=DeterministicBudgetManager(), @@ -524,7 +548,9 @@ def test_steady_summary_surfaces_retained_and_compacted_memory_refs(self) -> Non self.assertIn("steady:", snapshot_summary) self.assertIn("interruption: resume-after-gap", snapshot_summary) - def test_replay_request_does_not_depend_on_removed_structured_evidence_copies(self) -> None: + def test_replay_request_does_not_depend_on_removed_structured_evidence_copies( + self, + ) -> None: runtime = ContextRuntime( planner=LayeredContextPlanner( budget_manager=DeterministicBudgetManager(), @@ -581,10 +607,19 @@ def test_profile_state_focus_suppresses_work_slice_and_replay(self) -> None: self.assertIsNotNone(detailed.frame) assert detailed.frame is not None self.assertEqual(detailed.frame.session_snapshot.work_refs, ()) - self.assertNotIn("work-slice: personal_model scope suppressed active elephant work items", detailed.frame.session_snapshot.content) + self.assertNotIn( + "work-slice: personal_model scope suppressed active elephant work items", + detailed.frame.session_snapshot.content, + ) self.assertIsNone(detailed.frame.replay_packet) - self.assertIn("work slice suppressed by personal_model scope", detailed.summary_by_layer["session_snapshot"]) - self.assertIn("personal-model elephant focus suppresses unrelated work refs", detailed.plan.rationale) + self.assertIn( + "work slice suppressed by personal_model scope", + detailed.summary_by_layer["session_snapshot"], + ) + self.assertIn( + "personal-model elephant focus suppresses unrelated work refs", + detailed.plan.rationale, + ) if __name__ == "__main__": diff --git a/tests/unit/context/test_prompt_purity.py b/tests/unit/context/test_prompt_purity.py index 62fa08d..99ee2d1 100644 --- a/tests/unit/context/test_prompt_purity.py +++ b/tests/unit/context/test_prompt_purity.py @@ -111,7 +111,7 @@ def test_rendered_prompt_has_no_runtime_ids(self) -> None: self.assertIsNone( match, msg=f"runtime id matching {pattern.pattern!r} leaked into rendered prompt: " - f"match={match.group(0) if match else None}", + f"match={match.group(0) if match else None}", ) def test_prompt_contains_human_readable_titles_and_content(self) -> None: diff --git a/tests/unit/continuity/test_runtime.py b/tests/unit/continuity/test_runtime.py index 9b7140c..ee0db87 100644 --- a/tests/unit/continuity/test_runtime.py +++ b/tests/unit/continuity/test_runtime.py @@ -33,7 +33,9 @@ def _episode( class ContinuityRuntimeTests(unittest.TestCase): - def test_build_episode_continuity_state_inherits_ancestor_interruption(self) -> None: + def test_build_episode_continuity_state_inherits_ancestor_interruption( + self, + ) -> None: parent = _episode("root", interruption_state="Need to finish the plan") child = _episode("child", parent_episode_id="root") @@ -48,7 +50,9 @@ def test_build_episode_continuity_state_inherits_ancestor_interruption(self) -> self.assertEqual(continuity.inherited_interruption_state, "Need to finish the plan") self.assertNotIn("current-work item", continuity.summary) - def test_build_episode_continuity_state_normalizes_generated_resume_text(self) -> None: + def test_build_episode_continuity_state_normalizes_generated_resume_text( + self, + ) -> None: episode = _episode( "child", parent_episode_id="root", @@ -63,7 +67,9 @@ def test_build_episode_continuity_state_normalizes_generated_resume_text(self) - self.assertEqual(continuity.mode, "background") self.assertEqual(continuity.inherited_interruption_state, "Recover the thread") - def test_apply_episode_continuity_state_restores_inherited_interruption_when_needed(self) -> None: + def test_apply_episode_continuity_state_restores_inherited_interruption_when_needed( + self, + ) -> None: parent = _episode("root", interruption_state="Return to the design review") child = _episode("child", parent_episode_id="root") continuity = build_episode_continuity_state(child, lineage=(parent, child)) diff --git a/tests/unit/cron/test_runtime.py b/tests/unit/cron/test_runtime.py index f580585..bf1b9b0 100644 --- a/tests/unit/cron/test_runtime.py +++ b/tests/unit/cron/test_runtime.py @@ -56,7 +56,11 @@ def test_due_interval_job_executes_and_reschedules(self) -> None: elephant_id="atlas", ) - due = runtime.due_jobs(now=base + timedelta(hours=2, minutes=1), profile_id="elephant:atlas", elephant_id="atlas") + due = runtime.due_jobs( + now=base + timedelta(hours=2, minutes=1), + profile_id="elephant:atlas", + elephant_id="atlas", + ) self.assertEqual(len(due), 1) self.assertEqual(due[0].job_id, job.job_id) @@ -115,7 +119,11 @@ def test_profile_scoped_due_jobs_include_global_jobs(self) -> None: payload={"prompt": "scan"}, ) - due = runtime.due_jobs(now=base + timedelta(hours=1, minutes=1), profile_id="elephant:atlas", elephant_id="atlas") + due = runtime.due_jobs( + now=base + timedelta(hours=1, minutes=1), + profile_id="elephant:atlas", + elephant_id="atlas", + ) self.assertEqual(tuple(item.job_id for item in due), (job.job_id,)) diff --git a/tests/unit/embeddings/test_runtime.py b/tests/unit/embeddings/test_runtime.py index 04d431b..ade2990 100644 --- a/tests/unit/embeddings/test_runtime.py +++ b/tests/unit/embeddings/test_runtime.py @@ -52,20 +52,38 @@ def test_provider_health_reflects_local_runtime_state(self) -> None: provider = SentenceTransformerEmbeddingProvider(model_root="/tmp/elephant-embed") with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=False), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=False), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=False, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=False, + ), ): self.assertEqual(provider.health().status, "pending") with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=False), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=False, + ), ): self.assertEqual(provider.health().status, "downloading") with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), ): health = provider.health() @@ -79,8 +97,14 @@ def test_provider_health_reports_runtime_steady_state_metadata(self) -> None: provider._steady_thread = steadying_thread with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), ): steadying = provider.health() @@ -89,8 +113,14 @@ def test_provider_health_reports_runtime_steady_state_metadata(self) -> None: provider._model = object() with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), ): loaded = provider.health() @@ -102,8 +132,14 @@ def test_default_service_uses_canonical_provider(self) -> None: service = DefaultEmbeddingService(registry=InMemoryEmbeddingModelRegistry((provider,))) with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), mock.patch.object(provider, "_encode_texts", return_value=(_unit_vector(64),)), ): vector = service.embed_text( @@ -150,8 +186,14 @@ def test_registry_selects_registered_provider(self) -> None: service = DefaultEmbeddingService(registry=registry) with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), mock.patch.object( provider, "_encode_texts", @@ -214,8 +256,14 @@ def test_local_embedding_loader_passes_tokenizer_regex_fix(self) -> None: fake_module = types.SimpleNamespace(SentenceTransformer=sentence_transformer) with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), mock.patch.dict(sys.modules, {"sentence_transformers": fake_module}), ): provider._load_model() @@ -226,7 +274,9 @@ def test_local_embedding_loader_passes_tokenizer_regex_fix(self) -> None: processor_kwargs={"fix_mistral_regex": True}, ) - def test_local_embedding_loader_suppresses_known_tokenizer_regex_warning(self) -> None: + def test_local_embedding_loader_suppresses_known_tokenizer_regex_warning( + self, + ) -> None: class _CaptureHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: records.append(record) @@ -272,16 +322,20 @@ def test_preload_and_background_backfill_fill_the_shared_cache(self) -> None: EmbeddingPreloadEntry(cache_key="memory-1", text="release evidence summary"), EmbeddingPreloadEntry(cache_key="memory-2", text="release checklist"), ) - backfill_entries = ( - EmbeddingPreloadEntry(cache_key="projection-1", text="projection anchor summary"), - ) + backfill_entries = (EmbeddingPreloadEntry(cache_key="projection-1", text="projection anchor summary"),) def _encode(texts: tuple[str, ...], *, dimensions: int) -> tuple[tuple[float, ...], ...]: return tuple(_unit_vector(dimensions, index=index) for index, _text in enumerate(texts)) with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), mock.patch.object(provider, "_encode_texts", side_effect=_encode), ): preload_state = provider.preload( @@ -313,7 +367,9 @@ def _encode(texts: tuple[str, ...], *, dimensions: int) -> tuple[tuple[float, .. time.sleep(0.01) self.assertIsNotNone(projection_cached) - def test_queue_backfill_skips_cached_entries_when_a_higher_dimension_vector_exists(self) -> None: + def test_queue_backfill_skips_cached_entries_when_a_higher_dimension_vector_exists( + self, + ) -> None: provider = SentenceTransformerEmbeddingProvider() entry = EmbeddingPreloadEntry(cache_key="memory-1", text="release evidence summary") @@ -321,8 +377,14 @@ def _encode(texts: tuple[str, ...], *, dimensions: int) -> tuple[tuple[float, .. return tuple(_unit_vector(dimensions, index=index) for index, _text in enumerate(texts)) with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), mock.patch.object(provider, "_encode_texts", side_effect=_encode), ): provider.preload( @@ -346,8 +408,14 @@ def test_pending_vector_reports_queued_backfill_without_embedding(self) -> None: entry = EmbeddingPreloadEntry(cache_key="memory-1", text="release evidence summary") with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), mock.patch.object(provider, "_spawn_backfill_worker") as spawn, ): state = provider.queue_backfill( @@ -366,8 +434,14 @@ def test_queue_backfill_respects_failure_cooldown_after_worker_errors(self) -> N entry = EmbeddingPreloadEntry(cache_key="memory-1", text="release evidence summary") with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), mock.patch.object(provider, "_encode_texts", side_effect=RuntimeError("boom")), ): provider.queue_backfill( @@ -404,8 +478,14 @@ def _load_model() -> object: return object() with ( - mock.patch("packages.embeddings.runtime.sentence_transformers_dependencies_ready", return_value=True), - mock.patch("packages.embeddings.runtime.embedding_root_is_healthy", return_value=True), + mock.patch( + "packages.embeddings.runtime.sentence_transformers_dependencies_ready", + return_value=True, + ), + mock.patch( + "packages.embeddings.runtime.embedding_root_is_healthy", + return_value=True, + ), mock.patch.object(provider, "_load_model", side_effect=_load_model), ): self.assertTrue(provider.steady_async()) diff --git a/tests/unit/evidence/test_crystallization_assets.py b/tests/unit/evidence/test_crystallization_assets.py index 517e5a1..051e8cc 100644 --- a/tests/unit/evidence/test_crystallization_assets.py +++ b/tests/unit/evidence/test_crystallization_assets.py @@ -15,7 +15,6 @@ from packages.evidence.crystallization_runtime_impl import ( _b64_encode, - _extract_asset_from_step, _materialize_assets_from_steps, ) @@ -64,7 +63,13 @@ def test_invalid_json_is_reported(self) -> None: self.assertIn("error", report) def test_missing_source_step_is_reported(self) -> None: - hints = [{"path": "scripts/run.sh", "source_step_id": "nope", "content_kind": "script"}] + hints = [ + { + "path": "scripts/run.sh", + "source_step_id": "nope", + "content_kind": "script", + } + ] materialized, report = _materialize_assets_from_steps( repository=_FakeRepo(), steps=(), @@ -74,14 +79,22 @@ def test_missing_source_step_is_reported(self) -> None: self.assertEqual(report["missing"][0]["reason"], "no_source_step") def test_extracts_from_tool_call_arguments(self) -> None: - record = _FakeRecord(payload={ - "tool_calls": [ - {"arguments": {"content": "#!/bin/bash\necho hi\n"}}, - ], - }) + record = _FakeRecord( + payload={ + "tool_calls": [ + {"arguments": {"content": "#!/bin/bash\necho hi\n"}}, + ], + } + ) step = _FakeStep(step_id="step:1", payload_refs=("ref:1",)) repo = _FakeRepo({"ref:1": record}) - hints = [{"path": "scripts/run.sh", "source_step_id": "step:1", "content_kind": "script"}] + hints = [ + { + "path": "scripts/run.sh", + "source_step_id": "step:1", + "content_kind": "script", + } + ] materialized, report = _materialize_assets_from_steps( repository=repo, steps=(step,), @@ -94,7 +107,13 @@ def test_extracts_from_stdout_fallback(self) -> None: record = _FakeRecord(payload={"stdout": "generated config content\n"}) step = _FakeStep(step_id="step:2", payload_refs=("ref:2",)) repo = _FakeRepo({"ref:2": record}) - hints = [{"path": "config/x.yaml", "source_step_id": "step:2", "content_kind": "config"}] + hints = [ + { + "path": "config/x.yaml", + "source_step_id": "step:2", + "content_kind": "config", + } + ] materialized, report = _materialize_assets_from_steps( repository=repo, steps=(step,), @@ -104,7 +123,13 @@ def test_extracts_from_stdout_fallback(self) -> None: def test_fallback_to_step_outcome(self) -> None: step = _FakeStep(step_id="step:3", outcome="raw outcome text") - hints = [{"path": "notes.txt", "source_step_id": "step:3", "content_kind": "reference"}] + hints = [ + { + "path": "notes.txt", + "source_step_id": "step:3", + "content_kind": "reference", + } + ] materialized, _ = _materialize_assets_from_steps( repository=_FakeRepo(), steps=(step,), @@ -117,7 +142,13 @@ def test_truncates_oversized_content(self) -> None: record = _FakeRecord(payload={"stdout": huge_content}) step = _FakeStep(step_id="step:4", payload_refs=("ref:4",)) repo = _FakeRepo({"ref:4": record}) - hints = [{"path": "huge.bin", "source_step_id": "step:4", "content_kind": "reference"}] + hints = [ + { + "path": "huge.bin", + "source_step_id": "step:4", + "content_kind": "reference", + } + ] materialized, report = _materialize_assets_from_steps( repository=repo, steps=(step,), diff --git a/tests/unit/evidence/test_locator_match.py b/tests/unit/evidence/test_locator_match.py index d95f714..0cbdda5 100644 --- a/tests/unit/evidence/test_locator_match.py +++ b/tests/unit/evidence/test_locator_match.py @@ -153,11 +153,7 @@ def embed_text(self, text: str): # noqa: ARG002 entries = (FakeEntry("alpha"),) # No lexical match, broken embedding → None, not exception. - self.assertIsNone( - find_entry_by_locator( - entries, "completely different", embedding_service=_BrokenEmbedding() - ) - ) + self.assertIsNone(find_entry_by_locator(entries, "completely different", embedding_service=_BrokenEmbedding())) if __name__ == "__main__": diff --git a/tests/unit/evidence/test_recall_planning.py b/tests/unit/evidence/test_recall_planning.py index 5aad611..bc926cb 100644 --- a/tests/unit/evidence/test_recall_planning.py +++ b/tests/unit/evidence/test_recall_planning.py @@ -2,7 +2,7 @@ from __future__ import annotations -from packages.evidence import normalize_recall_query, plan_recall_query +from packages.evidence import plan_recall_query def test_recent_chinese_recap_extracts_topic_core() -> None: diff --git a/tests/unit/evidence/test_recall_rerank.py b/tests/unit/evidence/test_recall_rerank.py index 324faae..7783cb6 100644 --- a/tests/unit/evidence/test_recall_rerank.py +++ b/tests/unit/evidence/test_recall_rerank.py @@ -4,7 +4,12 @@ from datetime import datetime, timedelta, timezone -from packages.evidence import RecallHit, plan_recall_query, rerank_recall_hits, score_recall_hit +from packages.evidence import ( + RecallHit, + plan_recall_query, + rerank_recall_hits, + score_recall_hit, +) _NOW = datetime(2026, 5, 8, tzinfo=timezone.utc) diff --git a/tests/unit/evidence/test_recall_support.py b/tests/unit/evidence/test_recall_support.py index a1f113f..d0db2f0 100644 --- a/tests/unit/evidence/test_recall_support.py +++ b/tests/unit/evidence/test_recall_support.py @@ -4,7 +4,6 @@ from datetime import datetime, timedelta, timezone -import pytest from packages.evidence import ( RecallCandidate, diff --git a/tests/unit/gateway/test_cron_service.py b/tests/unit/gateway/test_cron_service.py index aeb8cfb..ccdde71 100644 --- a/tests/unit/gateway/test_cron_service.py +++ b/tests/unit/gateway/test_cron_service.py @@ -5,7 +5,10 @@ from unittest import mock from apps.gateway import cron_service -from apps.gateway.cron_service import cron_execution_should_deliver, _try_deliver_cron_result +from apps.gateway.cron_service import ( + cron_execution_should_deliver, + _try_deliver_cron_result, +) from packages.cron import CronJob, CronJobExecution @@ -37,7 +40,10 @@ def test_learning_cron_enqueue_ack_is_not_delivered_to_im(self) -> None: recorded_at=datetime.now(timezone.utc), ) - _try_deliver_cron_result(lambda delivered_job, delivered_execution: calls.append((delivered_job, delivered_execution)), execution) + _try_deliver_cron_result( + lambda delivered_job, delivered_execution: calls.append((delivered_job, delivered_execution)), + execution, + ) self.assertEqual(calls, []) self.assertFalse(cron_execution_should_deliver(execution)) @@ -52,7 +58,10 @@ def test_prompt_cron_result_still_delivers(self) -> None: recorded_at=datetime.now(timezone.utc), ) - _try_deliver_cron_result(lambda delivered_job, delivered_execution: calls.append((delivered_job, delivered_execution)), execution) + _try_deliver_cron_result( + lambda delivered_job, delivered_execution: calls.append((delivered_job, delivered_execution)), + execution, + ) self.assertEqual(calls, [(job, execution)]) self.assertTrue(cron_execution_should_deliver(execution)) @@ -67,7 +76,10 @@ def test_silent_prompt_cron_result_is_not_delivered(self) -> None: recorded_at=datetime.now(timezone.utc), ) - _try_deliver_cron_result(lambda delivered_job, delivered_execution: calls.append((delivered_job, delivered_execution)), execution) + _try_deliver_cron_result( + lambda delivered_job, delivered_execution: calls.append((delivered_job, delivered_execution)), + execution, + ) self.assertEqual(calls, []) self.assertFalse(cron_execution_should_deliver(execution)) diff --git a/tests/unit/gateway/test_dingding_inbound_serialization.py b/tests/unit/gateway/test_dingding_inbound_serialization.py index 62adf0f..40d8314 100644 --- a/tests/unit/gateway/test_dingding_inbound_serialization.py +++ b/tests/unit/gateway/test_dingding_inbound_serialization.py @@ -1,4 +1,5 @@ """Unit tests for DingDing same-conversation inbound serialization.""" + from __future__ import annotations import asyncio @@ -7,7 +8,10 @@ from unittest import mock from apps.gateway.dingding_service import DingdingGatewayService -from apps.gateway.dingding_support import DingdingGatewayAccountConfig, DingdingResolvedAccount +from apps.gateway.dingding_support import ( + DingdingGatewayAccountConfig, + DingdingResolvedAccount, +) def _make_account(account_id: str = "ops-dingding") -> DingdingResolvedAccount: @@ -45,7 +49,9 @@ def __init__(self) -> None: self.runtime_calls: list[str] = [] self.loaded_profile = None self.state_dir = None - self.core = SimpleNamespace(route_inbound=lambda *a, **kw: SimpleNamespace(delivery=SimpleNamespace(outbound=None))) + self.core = SimpleNamespace( + route_inbound=lambda *a, **kw: SimpleNamespace(delivery=SimpleNamespace(outbound=None)) + ) self.loaded_profile = None def handle_message(self, inbound, **kwargs): @@ -137,7 +143,11 @@ async def send_stub(_self, delivery_request, **kw) -> None: else: second_started.set() - with mock.patch.object(type(app), "handle_message", side_effect=AssertionError("should not run")): + with mock.patch.object( + type(app), + "handle_message", + side_effect=AssertionError("should not run"), + ): with mock.patch.object(type(service), "_send_dingtalk_reply", new=send_stub): first_task = asyncio.create_task( service._on_dingtalk_message_safe( diff --git a/tests/unit/gateway/test_gateway_context_history.py b/tests/unit/gateway/test_gateway_context_history.py index 45e3ed8..9c2e55a 100644 --- a/tests/unit/gateway/test_gateway_context_history.py +++ b/tests/unit/gateway/test_gateway_context_history.py @@ -13,7 +13,14 @@ ) from packages.context.epoch_store import FileEpochStore, InMemoryEpochStore from packages.contracts.layers import Episode -from packages.contracts.runtime import ContextBundle, EventEnvelope, ExecutionResult, PersonalModelRuntimeState, PromptMessage, PromptEnvelope +from packages.contracts.runtime import ( + ContextBundle, + EventEnvelope, + ExecutionResult, + PersonalModelRuntimeState, + PromptMessage, + PromptEnvelope, +) from packages.state import CompanionSettings, render_user_profile_text from packages.state import write_elephant_identity_file from packages.state.projection import build_loaded_profile_from_state @@ -55,7 +62,10 @@ def test_gateway_context_uses_shared_session_epoch_projection(self) -> None: session_snapshot="SESSION SNAPSHOT", history_messages=( PromptMessage(role="user", content="工作好忙"), - PromptMessage(role="assistant", content="你是想倒倒苦水说一下在忙什么,还是就想有人知道你今天很忙?"), + PromptMessage( + role="assistant", + content="你是想倒倒苦水说一下在忙什么,还是就想有人知道你今天很忙?", + ), ), ), ) @@ -131,7 +141,9 @@ def test_session_epoch_persists_with_epoch_store(self) -> None: assert preflight_epoch is not None self.assertEqual(preflight_epoch.history_messages[0].content, "preflight") - def test_im_idle_gap_resets_projection_tail_before_appending_new_burst(self) -> None: + def test_im_idle_gap_resets_projection_tail_before_appending_new_burst( + self, + ) -> None: base = datetime(2026, 5, 7, 3, 0, tzinfo=timezone.utc) session = Episode( episode_id="episode:wx", @@ -151,7 +163,10 @@ def test_im_idle_gap_resets_projection_tail_before_appending_new_burst(self) -> PromptMessage( role="user", content="早上第一句", - metadata={"projection_surface": "im", "created_at": base.isoformat()}, + metadata={ + "projection_surface": "im", + "created_at": base.isoformat(), + }, ), ), ) @@ -165,7 +180,10 @@ def test_im_idle_gap_resets_projection_tail_before_appending_new_burst(self) -> event_type="turn.received", episode_id=session.episode_id, source="gateway:feishu", - payload={"content": "晚上新话题", "delivery_surface": "feishu-long-connection"}, + payload={ + "content": "晚上新话题", + "delivery_surface": "feishu-long-connection", + }, ), execution=ExecutionResult( execution_id="exec:2", @@ -182,10 +200,15 @@ def test_im_idle_gap_resets_projection_tail_before_appending_new_burst(self) -> ) self.assertEqual(updated.compacted_history_summary, "") - self.assertEqual(tuple(message.content for message in updated.history_messages), ("晚上新话题", "晚上新回复")) + self.assertEqual( + tuple(message.content for message in updated.history_messages), + ("晚上新话题", "晚上新回复"), + ) self.assertEqual(updated.history_messages[0].metadata["projection_surface"], "im") - def test_existing_session_epoch_does_not_refresh_frozen_prefix_on_normal_turn(self) -> None: + def test_existing_session_epoch_does_not_refresh_frozen_prefix_on_normal_turn( + self, + ) -> None: """Frozen prefix only refreshes on episode open, not on normal turns.""" session = Episode( episode_id="episode:wx", @@ -229,7 +252,10 @@ def test_existing_session_epoch_does_not_refresh_frozen_prefix_on_normal_turn(se # Frozen prefix does NOT refresh on normal turns (only on episode open) self.assertEqual(updated.frozen_prefix, "old PM facts") self.assertEqual(updated.compacted_history_summary, "older summary") - self.assertEqual(tuple(message.content for message in updated.history_messages), ("existing tail",)) + self.assertEqual( + tuple(message.content for message in updated.history_messages), + ("existing tail",), + ) def test_internal_proactive_prompt_is_not_appended_to_session_epoch(self) -> None: session = Episode( diff --git a/tests/unit/gateway/test_main.py b/tests/unit/gateway/test_main.py index 6e08e4a..f18b006 100644 --- a/tests/unit/gateway/test_main.py +++ b/tests/unit/gateway/test_main.py @@ -19,8 +19,16 @@ class GatewayWizardIntegrationTest(unittest.TestCase): def test_gateway_text_prompt_uses_shared_wizard_dialogs(self) -> None: with ( - mock.patch.object(gateway_wizard_ui, "_gateway_wizard_dialogs_supported", return_value=True), - mock.patch.object(gateway_wizard_ui, "_shared_wizard_text_prompt", return_value="demo-elephant") as shared_prompt, + mock.patch.object( + gateway_wizard_ui, + "_gateway_wizard_dialogs_supported", + return_value=True, + ), + mock.patch.object( + gateway_wizard_ui, + "_shared_wizard_text_prompt", + return_value="demo-elephant", + ) as shared_prompt, ): answer = gateway_main._gateway_wizard_text_prompt( "Default Elephant", @@ -41,12 +49,25 @@ def test_gateway_text_prompt_uses_shared_wizard_dialogs(self) -> None: def test_gateway_choice_prompt_uses_shared_wizard_dialogs(self) -> None: choices = ( - WizardChoice(value="long-connection", label="Long Connection", detail="Local bridge.", emoji="🛰️"), + WizardChoice( + value="long-connection", + label="Long Connection", + detail="Local bridge.", + emoji="🛰️", + ), WizardChoice(value="skip", label="Skip", detail="Stay local.", emoji="➖"), ) with ( - mock.patch.object(gateway_wizard_ui, "_gateway_wizard_dialogs_supported", return_value=True), - mock.patch.object(gateway_wizard_ui, "_shared_wizard_choice_prompt", return_value="long-connection") as shared_choice, + mock.patch.object( + gateway_wizard_ui, + "_gateway_wizard_dialogs_supported", + return_value=True, + ), + mock.patch.object( + gateway_wizard_ui, + "_shared_wizard_choice_prompt", + return_value="long-connection", + ) as shared_choice, ): answer = gateway_main._gateway_wizard_choice_prompt( "Ingress Transport", @@ -67,8 +88,16 @@ def test_gateway_choice_prompt_uses_shared_wizard_dialogs(self) -> None: def test_gateway_secret_prompt_uses_shared_password_dialog(self) -> None: with ( - mock.patch.object(gateway_wizard_ui, "_gateway_wizard_dialogs_supported", return_value=True), - mock.patch.object(gateway_wizard_ui, "_shared_wizard_text_prompt", return_value="secret-value") as shared_prompt, + mock.patch.object( + gateway_wizard_ui, + "_gateway_wizard_dialogs_supported", + return_value=True, + ), + mock.patch.object( + gateway_wizard_ui, + "_shared_wizard_text_prompt", + return_value="secret-value", + ) as shared_prompt, ): answer = gateway_main._gateway_wizard_secret_prompt( "Paste App Secret", @@ -84,11 +113,21 @@ def test_gateway_secret_prompt_uses_shared_password_dialog(self) -> None: password=True, ) - def test_gateway_choice_prompt_preserves_back_signal_from_shared_wizard(self) -> None: + def test_gateway_choice_prompt_preserves_back_signal_from_shared_wizard( + self, + ) -> None: choices = (WizardChoice(value="feishu", label="Feishu", detail="Wire Feishu.", emoji="🐦"),) with ( - mock.patch.object(gateway_wizard_ui, "_gateway_wizard_dialogs_supported", return_value=True), - mock.patch.object(gateway_wizard_ui, "_shared_wizard_choice_prompt", return_value=WIZARD_BACK), + mock.patch.object( + gateway_wizard_ui, + "_gateway_wizard_dialogs_supported", + return_value=True, + ), + mock.patch.object( + gateway_wizard_ui, + "_shared_wizard_choice_prompt", + return_value=WIZARD_BACK, + ), ): answer = gateway_main._gateway_wizard_choice_prompt( "💬 IM Setup", @@ -117,7 +156,9 @@ def test_run_im_setup_can_dispatch_discord_wizard(self) -> None: default_control_state_dir=Path("/tmp/state"), ) - def test_gateway_discord_wizard_intro_prints_setup_card_without_extra_confirmation(self) -> None: + def test_gateway_discord_wizard_intro_prints_setup_card_without_extra_confirmation( + self, + ) -> None: output = io.StringIO() with ( mock.patch.object(gateway_wizard_ui, "RICH_AVAILABLE", False), @@ -132,7 +173,9 @@ def test_gateway_discord_wizard_intro_prints_setup_card_without_extra_confirmati self.assertIn("Bring Discord into Elephant Agent Gateway.", rendered) self.assertIn("Discord portal checklist", rendered) - def test_gateway_feishu_wizard_intro_prints_setup_card_without_extra_confirmation(self) -> None: + def test_gateway_feishu_wizard_intro_prints_setup_card_without_extra_confirmation( + self, + ) -> None: output = io.StringIO() with ( mock.patch.object(gateway_wizard_ui, "RICH_AVAILABLE", False), @@ -152,7 +195,9 @@ def test_confirm_gateway_wizard_intro_auto_accepts_without_prompt(self) -> None: input_mock.assert_not_called() - def test_prompt_gateway_control_binding_uses_elephant_and_session_menus_when_runtime_is_ready(self) -> None: + def test_prompt_gateway_control_binding_uses_elephant_and_session_menus_when_runtime_is_ready( + self, + ) -> None: now = datetime.now(UTC) latest_session = Episode( episode_id="session-demo-latest", @@ -220,7 +265,9 @@ def inspect_session(self, session_id: str) -> Episode: self.assertEqual(session_choices[0].value, gateway_main._GATEWAY_FOLLOW_LATEST_SESSION) self.assertTrue(any(choice.value == "session-demo-root" for choice in session_choices)) - def test_prompt_gateway_control_binding_skips_session_menu_when_elephant_has_single_session(self) -> None: + def test_prompt_gateway_control_binding_skips_session_menu_when_elephant_has_single_session( + self, + ) -> None: now = datetime.now(UTC) only_session = Episode( episode_id="session-demo-only", @@ -320,7 +367,9 @@ def test_start_feishu_runtime_after_setup_fills_runtime_defaults(self) -> None: self.assertFalse(restart_args.force) self.assertEqual(run_restart.call_args.kwargs["service"], build_service.return_value) - def test_prompt_gateway_control_binding_falls_back_to_text_when_runtime_is_unavailable(self) -> None: + def test_prompt_gateway_control_binding_falls_back_to_text_when_runtime_is_unavailable( + self, + ) -> None: with mock.patch.object( gateway_wizard_binding, "_gateway_wizard_text_prompt", @@ -435,7 +484,10 @@ def describe(self) -> dict[str, object]: "pid": 123, "pid_active": True, "stale_pid": False, - "record": {"status": "running", "command": ("python", "-m", "apps.launcher")}, + "record": { + "status": "running", + "command": ("python", "-m", "apps.launcher"), + }, } output = io.StringIO() with ( @@ -484,13 +536,21 @@ def test_cron_scheduler_service_builds_managed_runtime_command(self) -> None: def test_cron_logs_do_not_require_account_id(self) -> None: service = _ManagedOnlyService() service.service_key = "cron" - args = mock.Mock(runtime_target="configured", account_id=None, account_id_flag=None, path=True) + args = mock.Mock( + runtime_target="configured", + account_id=None, + account_id_flag=None, + path=True, + ) runtime_state = { "status": "running", "pid": 123, "pid_active": True, "stale_pid": False, - "record": {"status": "running", "command": ("python", "-m", "apps.launcher")}, + "record": { + "status": "running", + "command": ("python", "-m", "apps.launcher"), + }, } output = io.StringIO() with ( @@ -569,10 +629,7 @@ def test_discord_doctor_lines_include_account_health_summary(self) -> None: lines, ) self.assertTrue( - any( - line.startswith("discord_account: shadow-discord · enabled=yes · startup=blocked") - for line in lines - ) + any(line.startswith("discord_account: shadow-discord · enabled=yes · startup=blocked") for line in lines) ) def test_doctor_services_lines_tolerate_non_mapping_runtime_payloads(self) -> None: diff --git a/tests/unit/gateway/test_outbound_delivery.py b/tests/unit/gateway/test_outbound_delivery.py index c444fc3..725257c 100644 --- a/tests/unit/gateway/test_outbound_delivery.py +++ b/tests/unit/gateway/test_outbound_delivery.py @@ -6,21 +6,20 @@ import unittest from packages.contracts.runtime import ExecutionResult -from packages.gateway_core.outbound_delivery import GatewayMessageDeliverySurface, _try_parse_session_route +from packages.gateway_core.outbound_delivery import ( + GatewayMessageDeliverySurface, + _try_parse_session_route, +) from packages.gateway_core.outbound_queue import GatewayOutboundQueue class ParseSessionRouteTest(unittest.TestCase): def test_valid_session_id(self): - result = _try_parse_session_route( - "session:messaging.feishu:bot123@im.bot:user456@im.feishu" - ) + result = _try_parse_session_route("session:messaging.feishu:bot123@im.bot:user456@im.feishu") self.assertEqual(result, ("messaging.feishu", "bot123@im.bot", "user456@im.feishu")) def test_conversation_id_with_colons(self): - result = _try_parse_session_route( - "session:messaging.weixin:bot@im.bot:conv:with:colons" - ) + result = _try_parse_session_route("session:messaging.weixin:bot@im.bot:conv:with:colons") self.assertEqual(result, ("messaging.weixin", "bot@im.bot", "conv:with:colons")) def test_invalid_prefix_returns_none(self): @@ -78,12 +77,14 @@ def test_send_via_gateway_session(self): def test_send_via_identity_store_fallback(self): """CLI session_id doesn't parse — falls back to identity store.""" - identity_store = _FakeIdentityStore([ - _FakeIdentityRecord( - key=_FakeIdentityKey("messaging.feishu", "bot@im.bot", "zoey@im.feishu"), - elephant_id="elephant-001", - ), - ]) + identity_store = _FakeIdentityStore( + [ + _FakeIdentityRecord( + key=_FakeIdentityKey("messaging.feishu", "bot@im.bot", "zoey@im.feishu"), + elephant_id="elephant-001", + ), + ] + ) surface = GatewayMessageDeliverySurface( outbound_queue=self.queue, identity_store=identity_store, @@ -101,16 +102,18 @@ def test_send_via_identity_store_fallback(self): def test_send_with_target_hint_filters_adapter(self): """Target hint selects the right adapter.""" - identity_store = _FakeIdentityStore([ - _FakeIdentityRecord( - key=_FakeIdentityKey("messaging.feishu", "bot@feishu", "user@feishu"), - elephant_id="elephant-001", - ), - _FakeIdentityRecord( - key=_FakeIdentityKey("messaging.weixin", "bot@wx", "user@wx"), - elephant_id="elephant-001", - ), - ]) + identity_store = _FakeIdentityStore( + [ + _FakeIdentityRecord( + key=_FakeIdentityKey("messaging.feishu", "bot@feishu", "user@feishu"), + elephant_id="elephant-001", + ), + _FakeIdentityRecord( + key=_FakeIdentityKey("messaging.weixin", "bot@wx", "user@wx"), + elephant_id="elephant-001", + ), + ] + ) surface = GatewayMessageDeliverySurface( outbound_queue=self.queue, identity_store=identity_store, diff --git a/tests/unit/gateway/test_outbound_queue.py b/tests/unit/gateway/test_outbound_queue.py index d43a0c0..1686ebf 100644 --- a/tests/unit/gateway/test_outbound_queue.py +++ b/tests/unit/gateway/test_outbound_queue.py @@ -16,7 +16,6 @@ from datetime import datetime, timedelta, timezone from pathlib import Path import unittest -from unittest.mock import patch from packages.gateway_core import GatewayOutboundQueue @@ -73,8 +72,18 @@ def test_enqueue_then_claim_then_complete_empties_queue(self) -> None: self.assertEqual(self.queue.list_rows(), ()) def test_claim_filters_by_adapter(self) -> None: - self.queue.enqueue(adapter_id="messaging.weixin", account_id="a", conversation_id="c1", body="w") - self.queue.enqueue(adapter_id="messaging.feishu", account_id="a", conversation_id="c2", body="f") + self.queue.enqueue( + adapter_id="messaging.weixin", + account_id="a", + conversation_id="c1", + body="w", + ) + self.queue.enqueue( + adapter_id="messaging.feishu", + account_id="a", + conversation_id="c2", + body="f", + ) weixin = self.queue.claim(adapter_id="messaging.weixin") self.assertEqual(len(weixin), 1) diff --git a/tests/unit/gateway/test_weixin_inbound_serialization.py b/tests/unit/gateway/test_weixin_inbound_serialization.py index d14d063..1deddc0 100644 --- a/tests/unit/gateway/test_weixin_inbound_serialization.py +++ b/tests/unit/gateway/test_weixin_inbound_serialization.py @@ -6,7 +6,11 @@ from unittest import mock from apps.gateway.weixin_service import MessageDeduplicator, WeixinGatewayService -from packages.gateway_core import GatewayAccountRef, GatewayConversationRef, GatewayOutboundMessage +from packages.gateway_core import ( + GatewayAccountRef, + GatewayConversationRef, + GatewayOutboundMessage, +) class _FakeGatewayApp: diff --git a/tests/unit/growth/test_runtime.py b/tests/unit/growth/test_runtime.py index 00019e6..8b55ffc 100644 --- a/tests/unit/growth/test_runtime.py +++ b/tests/unit/growth/test_runtime.py @@ -4,7 +4,11 @@ from types import SimpleNamespace import unittest -from packages.contracts import ExperienceRecord, ProcedureRecord, PersonalModelGrowthState +from packages.contracts import ( + ExperienceRecord, + ProcedureRecord, + PersonalModelGrowthState, +) from packages.growth import ( GrowthTurnSignals, ProgressionProjectionBuilder, @@ -148,7 +152,9 @@ def test_first_turn_boost_and_second_turn_promotion_are_guaranteed(self) -> None self.assertGreaterEqual(second_snapshot.state.growth_score, 100) self.assertEqual(second.reward_reasons[0].reason_id, "second-turn-promotion") - def test_personal_model_understanding_signals_outrank_token_heavy_flat_turns(self) -> None: + def test_personal_model_understanding_signals_outrank_token_heavy_flat_turns( + self, + ) -> None: now = datetime(2026, 4, 18, tzinfo=timezone.utc) current = self._mature_state(now=now) @@ -184,7 +190,12 @@ def test_personal_model_understanding_signals_outrank_token_heavy_flat_turns(sel artifact_ids=("artifact:patch-note",), promoted_procedure_ids=("procedure:resume-checklist",), personal_model_fact_count=7, - personal_model_lens_counts=(("identity", 2), ("world", 2), ("pulse", 2), ("journey", 1)), + personal_model_lens_counts=( + ("identity", 2), + ("world", 2), + ("pulse", 2), + ("journey", 1), + ), personal_model_topic_count=5, personal_model_new_fact_count=2, personal_model_updated_fact_count=1, @@ -219,8 +230,14 @@ def test_personal_model_understanding_signals_outrank_token_heavy_flat_turns(sel self.assertIn("understanding-grounding", meaningful_reasons) self.assertIn("continuity", meaningful_reasons) self.assertIn("tokens-support", flat_reasons) - self.assertLess(flat_reasons["tokens-support"].score, meaningful_reasons["understanding-freshness"].score) - self.assertLess(flat_reasons["tokens-support"].score, meaningful_reasons["understanding-grounding"].score) + self.assertLess( + flat_reasons["tokens-support"].score, + meaningful_reasons["understanding-freshness"].score, + ) + self.assertLess( + flat_reasons["tokens-support"].score, + meaningful_reasons["understanding-grounding"].score, + ) def test_reward_reasons_keep_pm_freshness_and_grounding_traceable(self) -> None: now = datetime(2026, 4, 18, tzinfo=timezone.utc) @@ -250,7 +267,12 @@ def test_reward_reasons_keep_pm_freshness_and_grounding_traceable(self) -> None: promoted_procedure_ids=("procedure:resume-checklist",), work_item_evidence_refs=("artifact:brief",), personal_model_fact_count=5, - personal_model_lens_counts=(("identity", 1), ("world", 2), ("pulse", 1), ("journey", 1)), + personal_model_lens_counts=( + ("identity", 1), + ("world", 2), + ("pulse", 1), + ("journey", 1), + ), personal_model_topic_count=4, personal_model_new_fact_count=1, personal_model_updated_fact_count=1, @@ -305,14 +327,19 @@ def test_default_progression_rollout_scorecard_certifies_shadow_pack(self) -> No self.assertTrue(all(gate.status == "pass" for gate in gates.values())) comparisons = {comparison.case_id: comparison for comparison in scorecard.comparisons} - self.assertGreater(comparisons["meaningful-a"].delta_score, comparisons["trivial-a"].delta_score) + self.assertGreater( + comparisons["meaningful-a"].delta_score, + comparisons["trivial-a"].delta_score, + ) self.assertLess(comparisons["trivial-b"].delta_score, comparisons["trivial-a"].delta_score) self.assertIn("token-heavy", comparisons["trivial-b"].anti_grind_flags) self.assertTrue( any(condition.startswith("fallback to baseline-snapshot mode") for condition in scorecard.stop_conditions) ) - def test_progression_rollout_scorecard_falls_back_when_ui_budget_regresses(self) -> None: + def test_progression_rollout_scorecard_falls_back_when_ui_budget_regresses( + self, + ) -> None: now = datetime(2026, 4, 18, tzinfo=timezone.utc) case = ProgressionReplayCase( case_id="meaningful-ui", diff --git a/tests/unit/harness/test_retry_policy.py b/tests/unit/harness/test_retry_policy.py index 13818ff..75bc8db 100644 --- a/tests/unit/harness/test_retry_policy.py +++ b/tests/unit/harness/test_retry_policy.py @@ -31,14 +31,16 @@ def _advance(seconds: float) -> None: def _http_error(status: int, *, headers: dict | None = None) -> urllib_error.HTTPError: - return urllib_error.HTTPError( - "http://provider", status, "err", headers or {}, BytesIO(b"") - ) + return urllib_error.HTTPError("http://provider", status, "err", headers or {}, BytesIO(b"")) class ClassifyErrorTest(unittest.TestCase): def test_network_errors_classify_as_network(self) -> None: - for exc in (ConnectionError("broken"), TimeoutError("slow"), urllib_error.URLError("dns")): + for exc in ( + ConnectionError("broken"), + TimeoutError("slow"), + urllib_error.URLError("dns"), + ): self.assertEqual(classify_error(exc), "network", msg=str(exc)) def test_http_429_separates_from_5xx(self) -> None: @@ -223,7 +225,12 @@ def always_bad() -> None: with self.assertRaises(ConnectionError): with_retry( always_bad, - policy=RetryPolicy(max_attempts=10, base_backoff_s=5.0, max_backoff_s=60.0, jitter_ratio=0), + policy=RetryPolicy( + max_attempts=10, + base_backoff_s=5.0, + max_backoff_s=60.0, + jitter_ratio=0, + ), deadline=now + timedelta(seconds=3), sleeper=sleeper, clock=clock, diff --git a/tests/unit/harness/test_supervisor.py b/tests/unit/harness/test_supervisor.py index 54916d1..955c00a 100644 --- a/tests/unit/harness/test_supervisor.py +++ b/tests/unit/harness/test_supervisor.py @@ -4,7 +4,6 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone -from typing import Sequence import unittest from packages.contracts.layers import Step @@ -14,7 +13,6 @@ WaitCondition, ) from packages.harness.supervisor import ( - SupervisorDecision, scan_once, ) @@ -242,7 +240,10 @@ def test_existing_completed_step_skips_tool_replay_on_reclaim(self) -> None: sequence=0, created_at=now, outcome="ok", - metadata={"tool_call_id": "call-A", "tool_name": "tool.shell.run"}, + metadata={ + "tool_call_id": "call-A", + "tool_name": "tool.shell.run", + }, ) ] }, diff --git a/tests/unit/kernel/test_context_compaction.py b/tests/unit/kernel/test_context_compaction.py index 2408f87..5094b17 100644 --- a/tests/unit/kernel/test_context_compaction.py +++ b/tests/unit/kernel/test_context_compaction.py @@ -62,12 +62,17 @@ def test_packet_and_step_metadata_include_compaction_audit_fields(self) -> None: self.assertEqual(metadata["protected_ranges"], "head:0-1, tail:77-79") self.assertEqual(metadata["selected_raw_ids"], "group:abc123, group:def456") self.assertEqual(metadata["summary_hash"], "deadbeefcafefeed") - self.assertIn("latest user query: continue the database migration", metadata["compaction_query"]) + self.assertIn( + "latest user query: continue the database migration", + metadata["compaction_query"], + ) self.assertIn("protected_ranges=head:0-1|tail:77-79", detail) self.assertIn("selected_raw=2", detail) self.assertIn("summary_hash=deadbeefcafefeed", detail) - def test_retry_context_after_provider_overflow_returns_continuity_outcome(self) -> None: + def test_retry_context_after_provider_overflow_returns_continuity_outcome( + self, + ) -> None: result = ContextProjectionCompactionResult( compacted=True, reason="provider-overflow", diff --git a/tests/unit/kernel/test_generation_context_projection.py b/tests/unit/kernel/test_generation_context_projection.py index 9b52eea..d1625b2 100644 --- a/tests/unit/kernel/test_generation_context_projection.py +++ b/tests/unit/kernel/test_generation_context_projection.py @@ -108,7 +108,14 @@ def _bundle() -> ContextBundle: ) -def _fact(*, text: str, field: str, lens: str = "knowledge", confidence: float = 1.0, extra_metadata: dict[str, str] | None = None) -> Any: +def _fact( + *, + text: str, + field: str, + lens: str = "knowledge", + confidence: float = 1.0, + extra_metadata: dict[str, str] | None = None, +) -> Any: topic_by_field = { "identity.name.preferred": "identity.anchor.name.preferred", "city": "world.places.city.current", @@ -124,12 +131,23 @@ def _fact(*, text: str, field: str, lens: str = "knowledge", confidence: float = "text": text, "lens": lens, "confidence": confidence, - "metadata": {"field": field, "topic": topic_by_field.get(field, field), "protected": "system", "projection_policy": "core_prompt", **(extra_metadata or {})}, + "metadata": { + "field": field, + "topic": topic_by_field.get(field, field), + "protected": "system", + "projection_policy": "core_prompt", + **(extra_metadata or {}), + }, }, )() -def _profile(*, preferences: tuple[str, ...] = (), user_profile_text: str = "", style_summary: str = "") -> Any: +def _profile( + *, + preferences: tuple[str, ...] = (), + user_profile_text: str = "", + style_summary: str = "", +) -> Any: return type( "_Profile", (), @@ -141,7 +159,11 @@ def _profile(*, preferences: tuple[str, ...] = (), user_profile_text: str = "", )() -def _question(*, text: str = "When things get messy, what helps?", sub_lens: str = "stress_response") -> Any: +def _question( + *, + text: str = "When things get messy, what helps?", + sub_lens: str = "stress_response", +) -> Any: return type( "_Question", (), @@ -198,10 +220,17 @@ def test_learning_agent_context_mode_gets_minimal_generation_context(self) -> No self.assertEqual(result.evidence_refs, ()) self.assertEqual(result.artifact_ids, ()) - def test_learning_agent_minimal_context_can_carry_dedicated_system_prompt(self) -> None: + def test_learning_agent_minimal_context_can_carry_dedicated_system_prompt( + self, + ) -> None: result = build_context_for_generation( dependencies=_FakeDependencies(_FakeStorage()), - request=_FakeRequest(source_payload={"context_mode": "learning_agent", "system_prompt": "SYSTEM ONLY"}), + request=_FakeRequest( + source_payload={ + "context_mode": "learning_agent", + "system_prompt": "SYSTEM ONLY", + } + ), profile=None, session=None, state_focus=None, @@ -244,15 +273,27 @@ def test_core_prompt_filters_internal_learning_artifact_facts(self) -> None: dependencies=_FakeDependencies( _FakeStorage( facts=( - _fact(text="Question-bank signal for feedback_preference: explicit", field="question.signal", lens="rapport"), - _fact(text="User explicitly shared autonomy_boundary: 的心事更在前面?", field="question.noise", lens="rapport"), + _fact( + text="Question-bank signal for feedback_preference: explicit", + field="question.signal", + lens="rapport", + ), + _fact( + text="User explicitly shared autonomy_boundary: 的心事更在前面?", + field="question.noise", + lens="rapport", + ), _fact( text="Synthetic live acceptance marker for init_bootstrap mode validation. run_tag=20260509135158.", field="validation.marker", lens="knowledge", extra_metadata={"recall_policy": "temporary"}, ), - _fact(text="第一语言:中文。", field="first_language", lens="rapport"), + _fact( + text="第一语言:中文。", + field="first_language", + lens="rapport", + ), ) ) ), @@ -274,7 +315,9 @@ def test_core_prompt_filters_internal_learning_artifact_facts(self) -> None: self.assertNotIn("User explicitly shared autonomy_boundary", prompt) self.assertNotIn("Synthetic live acceptance marker", prompt) - def test_gateway_state_projection_does_not_inject_previous_assistant_reply_as_ongoing_thread(self) -> None: + def test_gateway_state_projection_does_not_inject_previous_assistant_reply_as_ongoing_thread( + self, + ) -> None: state = type( "_State", (), @@ -310,11 +353,7 @@ def test_episode_resume_snapshot_stays_out_of_frozen_prefix(self) -> None: episode = type( "_Episode", (), - { - "metadata": { - "opening_resume_snapshot": "Use the live project handoff as the current Elephant context." - } - }, + {"metadata": {"opening_resume_snapshot": "Use the live project handoff as the current Elephant context."}}, )() result = build_context_for_generation( dependencies=_FakeDependencies(_FakeStorage(episode=episode)), @@ -349,7 +388,9 @@ def test_token_budget_is_not_injected_into_system_prompt(self) -> None: ), ) result = build_context_for_generation( - dependencies=_FakeDependencies(_FakeStorage(facts=(_fact(text="称呼:xunzhuo。", field="identity.name.preferred"),))), + dependencies=_FakeDependencies( + _FakeStorage(facts=(_fact(text="称呼:xunzhuo。", field="identity.name.preferred"),)) + ), request=_FakeRequest(), profile=None, session=None, @@ -369,7 +410,9 @@ def test_token_budget_is_not_injected_into_system_prompt(self) -> None: self.assertNotIn("Prompt budget", rendered) self.assertNotIn("204800", rendered) - def test_pm_facts_replace_raw_user_snapshot_and_skill_index_moves_late(self) -> None: + def test_pm_facts_replace_raw_user_snapshot_and_skill_index_moves_late( + self, + ) -> None: skill_block = "\n".join( ( "### Capability Disclosure", @@ -398,7 +441,11 @@ def test_pm_facts_replace_raw_user_snapshot_and_skill_index_moves_late(self) -> ) storage = _FakeStorage( facts=( - _fact(text="称呼:xunzhuo。", field="identity.name.preferred", lens="knowledge"), + _fact( + text="称呼:xunzhuo。", + field="identity.name.preferred", + lens="knowledge", + ), _fact(text="城市或时区语境:成都。", field="city", lens="knowledge"), ), profile=_profile(preferences=("first_language=zh",)), @@ -425,7 +472,10 @@ def test_pm_facts_replace_raw_user_snapshot_and_skill_index_moves_late(self) -> self.assertIn("### World — what is around them", prompt) self.assertIn("称呼:xunzhuo", prompt) self.assertIn("城市或时区语境:成都", prompt) - self.assertLess(prompt.index("### Identity — who they are"), prompt.index("### Capability Disclosure")) + self.assertLess( + prompt.index("### Identity — who they are"), + prompt.index("### Capability Disclosure"), + ) def test_pm_facts_replace_stale_frozen_personal_projection(self) -> None: context = ContextBundle( @@ -447,7 +497,15 @@ def test_pm_facts_replace_stale_frozen_personal_projection(self) -> None: ) result = build_context_for_generation( dependencies=_FakeDependencies( - _FakeStorage(facts=(_fact(text="称呼:zoey。", field="identity.name.preferred", lens="knowledge"),)) + _FakeStorage( + facts=( + _fact( + text="称呼:zoey。", + field="identity.name.preferred", + lens="knowledge", + ), + ) + ) ), request=_FakeRequest(), profile=None, @@ -468,9 +526,21 @@ def test_pm_facts_replace_stale_frozen_personal_projection(self) -> None: def test_style_guidance_is_behavioral_not_raw_database_summary(self) -> None: storage = _FakeStorage( facts=( - _fact(text="压力升起来时,常见反应是:先安静下来。", field="pressure_pattern", lens="trait"), - _fact(text="恢复精力时,较早有用的是:安静一会儿。", field="recovery_style", lens="trait"), - _fact(text="面对悬而未决的选择时,更靠近答案的方式是:写下取舍。", field="decision_compass", lens="trait"), + _fact( + text="压力升起来时,常见反应是:先安静下来。", + field="pressure_pattern", + lens="trait", + ), + _fact( + text="恢复精力时,较早有用的是:安静一会儿。", + field="recovery_style", + lens="trait", + ), + _fact( + text="面对悬而未决的选择时,更靠近答案的方式是:写下取舍。", + field="decision_compass", + lens="trait", + ), ), profile=_profile( preferences=("relationship_mode=安静、细腻、低压地陪在旁边",), @@ -501,13 +571,24 @@ def test_style_guidance_is_behavioral_not_raw_database_summary(self) -> None: def test_curiosity_hint_routes_question_selection_through_tool(self) -> None: storage = _FakeStorage( facts=( - _fact(text="第一语言:中文;除非用户另行要求,默认使用中文沟通。", field="first_language", lens="rapport"), - _fact(text="称呼:xunzhuo。", field="identity.name.preferred", lens="knowledge"), + _fact( + text="第一语言:中文;除非用户另行要求,默认使用中文沟通。", + field="first_language", + lens="rapport", + ), + _fact( + text="称呼:xunzhuo。", + field="identity.name.preferred", + lens="knowledge", + ), ), profile=_profile(), questions=( _question(text="when things get messy, do you want a checklist first?"), - _question(text="when your energy is low, do you need quiet space first?", sub_lens="energy_management"), + _question( + text="when your energy is low, do you need quiet space first?", + sub_lens="energy_management", + ), ), ) result = build_context_for_generation( @@ -537,6 +618,7 @@ def test_placeholder_user_names_are_suppressed(self) -> None: ever said their name. Render-time filter catches it until a real name arrives via `tool.personal_model.update`. """ + class _PersonalModel: display_name = "Hazel" status = "active" @@ -587,6 +669,7 @@ def test_reflexive_display_name_is_suppressed(self) -> None: """When display_name decayed to a reflexive pronoun like "You", we must not emit "You are You" absurdity. """ + class _PersonalModel: display_name = "You" status = "active" diff --git a/tests/unit/kernel/test_lifecycle_support.py b/tests/unit/kernel/test_lifecycle_support.py index 9651c82..e6d1090 100644 --- a/tests/unit/kernel/test_lifecycle_support.py +++ b/tests/unit/kernel/test_lifecycle_support.py @@ -7,13 +7,19 @@ import unittest from packages.contracts.layers import Episode -from packages.kernel.lifecycle_support import KernelRuntimeIdentity, close_episode_lifecycle, open_episode_lifecycle +from packages.kernel.lifecycle_support import ( + KernelRuntimeIdentity, + close_episode_lifecycle, + open_episode_lifecycle, +) from packages.kernel.runtime_support import KernelSourceRequest from packages.storage.repository_impl import RuntimeStorageRepository class KernelLifecycleSupportTests(unittest.TestCase): - def test_gateway_idle_reuse_closes_stale_episode_and_opens_new_episode(self) -> None: + def test_gateway_idle_reuse_closes_stale_episode_and_opens_new_episode( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: repository = RuntimeStorageRepository(Path(tmpdir) / "state" / "elephant.sqlite3") repository.bootstrap() @@ -25,7 +31,10 @@ def test_gateway_idle_reuse_closes_stale_episode_and_opens_new_episode(self) -> state_id="state-gateway", surface_bindings=("gateway:discord:room",), ) - state = replace(state, current_context_note="Resume the gateway handoff from the prior episode.") + state = replace( + state, + current_context_note="Resume the gateway handoff from the prior episode.", + ) repository.upsert_state(state) previous_at = datetime(2026, 4, 24, 10, tzinfo=timezone.utc) stale_episode = Episode( @@ -59,7 +68,10 @@ def test_gateway_idle_reuse_closes_stale_episode_and_opens_new_episode(self) -> stored_stale = repository.load_episode(stale_episode.episode_id) self.assertEqual(lifecycle.episode.episode_id, "episode:request-gateway-new") self.assertEqual(lifecycle.close_on_completion, False) - self.assertEqual(tuple(episode.episode_id for episode in lifecycle.idle_closed_episodes), ("episode:gateway-stale",)) + self.assertEqual( + tuple(episode.episode_id for episode in lifecycle.idle_closed_episodes), + ("episode:gateway-stale",), + ) self.assertIsNotNone(stored_stale) assert stored_stale is not None self.assertEqual(stored_stale.status, "closed") @@ -191,7 +203,9 @@ def test_explicit_closed_episode_cannot_be_reopened(self) -> None: assert stored is not None self.assertEqual(stored.status, "closed") - def test_close_episode_does_not_foreground_update_state_continuation_note(self) -> None: + def test_close_episode_does_not_foreground_update_state_continuation_note( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: repository = RuntimeStorageRepository(Path(tmpdir) / "state" / "elephant.sqlite3") repository.bootstrap() diff --git a/tests/unit/kernel/test_resume_support.py b/tests/unit/kernel/test_resume_support.py index c45d03e..1f82b54 100644 --- a/tests/unit/kernel/test_resume_support.py +++ b/tests/unit/kernel/test_resume_support.py @@ -18,7 +18,6 @@ WaitCondition, ) from packages.kernel.resume_support import ( - PendingToolReplayPlan, apply_resume_snapshot, plan_pending_tool_replay, snapshot_resume, diff --git a/tests/unit/kernel/test_runtime_support_budgets.py b/tests/unit/kernel/test_runtime_support_budgets.py index dd6fef4..4b95974 100644 --- a/tests/unit/kernel/test_runtime_support_budgets.py +++ b/tests/unit/kernel/test_runtime_support_budgets.py @@ -10,13 +10,14 @@ from packages.contracts.runtime import ( ContextBundle, ExecutionResult, - PersonalModelRuntimeState, PromptEnvelope, PromptMessage, ) from packages.kernel import KernelService, KernelSourceRequest -from packages.kernel.execution_support import execute_kernel_turn -from packages.kernel.loop_checkpoint_support import LoopCheckpointBudget, LoopCheckpointService +from packages.kernel.loop_checkpoint_support import ( + LoopCheckpointBudget, + LoopCheckpointService, +) from packages.kernel.runtime_support import ( _TextToolCall, _build_clock, @@ -57,7 +58,9 @@ def test_dashboard_maps_effective_user_query_as_visible_user_query(self) -> None self.assertEqual(_step_event_type(source_step), "source_input") self.assertEqual(_step_event_content(source_step, {}), "raw question") - def test_agent_loop_budget_defaults_extend_model_turns_and_use_large_result_budgets(self) -> None: + def test_agent_loop_budget_defaults_extend_model_turns_and_use_large_result_budgets( + self, + ) -> None: budget = LoopCheckpointBudget() self.assertEqual(budget.max_model_turns, 100) @@ -112,8 +115,7 @@ def test_recorded_provider_system_prompt_excludes_loop_context(self) -> None: ), ), rendered_prompt=( - "system prompt :: frozen_prefix\n\n" - "## Turn attachments\nruntime-paths: startup_cwd=/tmp/start" + "system prompt :: frozen_prefix\n\n## Turn attachments\nruntime-paths: startup_cwd=/tmp/start" ), ) @@ -245,7 +247,9 @@ def test_parallel_tool_batch_allows_background_sub_agent_start_calls(self) -> No self.assertTrue(_should_parallelize_tool_batch(calls)) - def test_parallel_tool_batch_rejects_blocking_sub_loop_checkpoint_calls(self) -> None: + def test_parallel_tool_batch_rejects_blocking_sub_loop_checkpoint_calls( + self, + ) -> None: calls = ( _TextToolCall("tool.sub_agents", {"action": "run", "task": "inspect core"}), _TextToolCall("tool.sub_agents", {"action": "run", "task": "inspect tools"}), @@ -315,7 +319,10 @@ def test_state_projection_does_not_mutate_continuation_note_per_turn(self) -> No current=datetime(2026, 4, 28, tzinfo=timezone.utc), ) - self.assertEqual(updated.current_context_note, "Resume the dashboard redesign from the prior episode.") + self.assertEqual( + updated.current_context_note, + "Resume the dashboard redesign from the prior episode.", + ) self.assertLessEqual(len(updated.summary), 480) self.assertNotIn("\n", updated.summary) diff --git a/tests/unit/personal_state/test_api_state_runtime.py b/tests/unit/personal_state/test_api_state_runtime.py index 18231d2..0c57cdd 100644 --- a/tests/unit/personal_state/test_api_state_runtime.py +++ b/tests/unit/personal_state/test_api_state_runtime.py @@ -84,7 +84,9 @@ def _canonical_facts( if str(fact.metadata.get("sync_source") or "") == sync_source ) - def test_ensure_personal_model_state_bootstrap_captures_governed_updates(self) -> None: + def test_ensure_personal_model_state_bootstrap_captures_governed_updates( + self, + ) -> None: tmpdir, repository, personal_model, state, episode, runtime = self._build_runtime() self.addCleanup(tmpdir.cleanup) @@ -118,7 +120,10 @@ def test_update_identity_state_does_not_capture_personal_model_fact(self) -> Non ) self.assertEqual(updated.display_name, "Elephant Agent Revised") - self.assertEqual(self._personal_model_fact_count(repository, personal_model.profile_id), before) + self.assertEqual( + self._personal_model_fact_count(repository, personal_model.profile_id), + before, + ) canonical_facts = self._canonical_facts( repository, personal_model_id=personal_model.profile_id, @@ -138,7 +143,10 @@ def test_update_user_state_adds_one_governed_memory_capture(self) -> None: ) self.assertEqual(updated.preferred_name, "Bit") - self.assertEqual(self._personal_model_fact_count(repository, personal_model.profile_id), before + 1) + self.assertEqual( + self._personal_model_fact_count(repository, personal_model.profile_id), + before + 1, + ) canonical_facts = self._canonical_facts( repository, personal_model_id=personal_model.profile_id, @@ -162,7 +170,10 @@ def test_update_relationship_state_adds_one_governed_memory_capture(self) -> Non ) self.assertIn("Protect focused work windows.", updated.continuity_notes) - self.assertEqual(self._personal_model_fact_count(repository, personal_model.profile_id), before + 1) + self.assertEqual( + self._personal_model_fact_count(repository, personal_model.profile_id), + before + 1, + ) canonical_facts = self._canonical_facts( repository, personal_model_id=personal_model.profile_id, diff --git a/tests/unit/personal_state/test_projection.py b/tests/unit/personal_state/test_projection.py index 64ac9f5..a4f8728 100644 --- a/tests/unit/personal_state/test_projection.py +++ b/tests/unit/personal_state/test_projection.py @@ -8,7 +8,9 @@ class PersonalStateProjectionTest(unittest.TestCase): - def test_build_loaded_profile_from_state_preserves_custom_personality_traits(self) -> None: + def test_build_loaded_profile_from_state_preserves_custom_personality_traits( + self, + ) -> None: loaded = build_loaded_profile_from_state( PersonalModelRuntimeState( profile_id="profile-companion", diff --git a/tests/unit/profile/test_canonical_state.py b/tests/unit/profile/test_canonical_state.py index a37f237..7c8e0b6 100644 --- a/tests/unit/profile/test_canonical_state.py +++ b/tests/unit/profile/test_canonical_state.py @@ -7,7 +7,10 @@ CompanionSettings, render_user_profile_text, ) -from packages.state.canonical import build_canonical_profile_state, canonical_profile_ids +from packages.state.canonical import ( + build_canonical_profile_state, + canonical_profile_ids, +) from packages.state.persistence import _relationship_capture_content from packages.state.rendered_views import RenderedRelationshipView from packages.state.projection import build_loaded_profile_from_state @@ -42,8 +45,7 @@ def _load_profile(self): profile_dir="", manifest_path=None, elephant_identity_text=( - "Protect continuity, stay exact, and keep the user oriented around " - "the next useful move." + "Protect continuity, stay exact, and keep the user oriented around the next useful move." ), user_profile_text=render_user_profile_text( preferred_name="Bit", @@ -62,7 +64,9 @@ def test_canonical_profile_ids_are_stable(self) -> None: self.assertEqual(ids.user_profile_id, "profile-companion:user-profile") self.assertEqual(ids.relationship_id, "profile-companion:relationship") - def test_build_canonical_profile_state_separates_user_and_relationship_truth(self) -> None: + def test_build_canonical_profile_state_separates_user_and_relationship_truth( + self, + ) -> None: loaded = self._load_profile() bundle = build_canonical_profile_state(loaded) @@ -76,18 +80,33 @@ def test_build_canonical_profile_state_separates_user_and_relationship_truth(sel self.assertEqual(bundle.user_profile.preferred_name, "Bit") self.assertEqual(bundle.user_profile.locale, "zh-CN") self.assertEqual(bundle.user_profile.timezone, "Asia/Shanghai") - self.assertEqual(bundle.user_profile.communication_preferences, ("tone:steady", "verbosity:concise")) - self.assertEqual(bundle.user_profile.shared_preferences, ("local-context:agentic-in/elephant",)) - self.assertIn("current_work:Building durable agent systems.", bundle.user_profile.biography_fragments) + self.assertEqual( + bundle.user_profile.communication_preferences, + ("tone:steady", "verbosity:concise"), + ) + self.assertEqual( + bundle.user_profile.shared_preferences, + ("local-context:agentic-in/elephant",), + ) + self.assertIn( + "current_work:Building durable agent systems.", + bundle.user_profile.biography_fragments, + ) self.assertIn("current_city:Shanghai", bundle.user_profile.biography_fragments) self.assertEqual(bundle.user_profile.boundaries, ("Prefer directness over fluff.",)) - self.assertEqual(bundle.user_profile.durable_notes, ("Carries research context across weeks.",)) + self.assertEqual( + bundle.user_profile.durable_notes, + ("Carries research context across weeks.",), + ) self.assertEqual(bundle.relationship.elephant_id, bundle.elephant_identity.elephant_id) self.assertEqual(bundle.relationship.user_profile_id, bundle.user_profile.user_profile_id) self.assertIn("initiative:proactive", bundle.relationship.expectations) self.assertIn("recover long arcs", bundle.relationship.continuity_notes) - self.assertNotIn("current_work:Building durable agent systems.", bundle.relationship.expectations) + self.assertNotIn( + "current_work:Building durable agent systems.", + bundle.relationship.expectations, + ) self.assertNotIn("Prefer directness over fluff.", bundle.relationship.continuity_notes) def test_relationship_capture_excludes_system_governance_defaults(self) -> None: diff --git a/tests/unit/profile/test_governance.py b/tests/unit/profile/test_governance.py index 2315f76..7a8e521 100644 --- a/tests/unit/profile/test_governance.py +++ b/tests/unit/profile/test_governance.py @@ -21,7 +21,13 @@ class UserProfileGovernanceTest(unittest.TestCase): - def _load_profile(self, root: Path, *, display_name: str = "Aeon", user_profile_text: str | None = None): + def _load_profile( + self, + root: Path, + *, + display_name: str = "Aeon", + user_profile_text: str | None = None, + ): """Build a LoadedProfile with explicit identity (no profile.json). ``root`` is accepted for signature parity with the old fixture; it's @@ -71,8 +77,14 @@ def test_missing_user_profile_fields_split_required_and_optional(self) -> None: required = missing_required_user_fields(loaded) optional = missing_optional_user_fields(loaded) - self.assertEqual(tuple(question.field_id for question in required), ("preferred_name", "current_work")) - self.assertEqual(tuple(question.field_id for question in optional[:2]), ("school", "current_city")) + self.assertEqual( + tuple(question.field_id for question in required), + ("preferred_name", "current_work"), + ) + self.assertEqual( + tuple(question.field_id for question in optional[:2]), + ("school", "current_city"), + ) def test_user_profile_updates_normalize_loose_field_labels(self) -> None: updates = user_profile_updates( @@ -139,7 +151,9 @@ def test_parse_user_profile_content_keeps_low_loss_durable_notes(self) -> None: ), ) - def test_build_companion_identity_state_prefers_elephant_state_display_name(self) -> None: + def test_build_companion_identity_state_prefers_elephant_state_display_name( + self, + ) -> None: """``State.elephant_name`` → ``profile.state.display_name`` is canonical. ELEPHANT.md is authoring-only; the parser is a write-path helper, not a @@ -182,7 +196,10 @@ def test_onboarding_state_is_ready_without_file_first_profile_gate(self) -> None self.assertEqual(onboarding.missing_fields, ()) self.assertEqual(onboarding.next_step, "continue-normal-conversation") self.assertIn("normal turns", onboarding.summary) - self.assertEqual(tuple(checkpoint.status for checkpoint in onboarding.checkpoints), ("ready", "ready", "ready")) + self.assertEqual( + tuple(checkpoint.status for checkpoint in onboarding.checkpoints), + ("ready", "ready", "ready"), + ) if __name__ == "__main__": diff --git a/tests/unit/profile/test_prompt_contract.py b/tests/unit/profile/test_prompt_contract.py index e635ff0..67f89fb 100644 --- a/tests/unit/profile/test_prompt_contract.py +++ b/tests/unit/profile/test_prompt_contract.py @@ -55,7 +55,9 @@ def _build_loaded_profile(self, *, display_name: str = "Aeon", elephant_identity ), ) - def test_full_prompt_contract_includes_canonical_identity_and_user_sections(self) -> None: + def test_full_prompt_contract_includes_canonical_identity_and_user_sections( + self, + ) -> None: loaded = self._build_loaded_profile() contract = build_prompt_contract(loaded, prompt_mode="full") @@ -117,12 +119,17 @@ def test_full_prompt_contract_includes_canonical_identity_and_user_sections(self self.assertNotIn("### What you know about the user", rendered) self.assertNotIn("- Preferred name: Bit", rendered) self.assertNotIn("Continuity reminders for this elephant: recover long arcs.", rendered) - self.assertIn("Use `tool.personal_model.questions` only when one timely question would improve future help.", rendered) + self.assertIn( + "Use `tool.personal_model.questions` only when one timely question would improve future help.", + rendered, + ) self.assertIn("Keep Personal Model writes small", rendered) self.assertNotIn("state-onboarding=", rendered) self.assertNotIn("grounding-policy=", rendered) - def test_minimal_prompt_contract_stays_compact_but_keeps_canonical_user_snapshot(self) -> None: + def test_minimal_prompt_contract_stays_compact_but_keeps_canonical_user_snapshot( + self, + ) -> None: loaded = self._build_loaded_profile() contract = build_prompt_contract(loaded, prompt_mode="minimal") diff --git a/tests/unit/reflect/test_dream_feature.py b/tests/unit/reflect/test_dream_feature.py index ce85fb0..eceb16b 100644 --- a/tests/unit/reflect/test_dream_feature.py +++ b/tests/unit/reflect/test_dream_feature.py @@ -13,33 +13,51 @@ class DreamFeatureTest(unittest.TestCase): def test_dream_trigger_resolves_to_single_nightly_bundle(self) -> None: features = resolve_features("dream") - self.assertEqual(tuple(feature.feature_id for feature in features), ("dream", "questions", "skills", "diary")) + self.assertEqual( + tuple(feature.feature_id for feature in features), + ("dream", "questions", "skills", "diary"), + ) def test_explicit_dream_drops_pm_learning_but_preserves_questions(self) -> None: features = resolve_features("manual", explicit_features=("pm", "questions", "dream", "recall")) self.assertEqual(tuple(feature.feature_id for feature in features), ("dream", "questions")) - def test_dream_trigger_with_legacy_explicit_metadata_adds_nightly_bundle(self) -> None: + def test_dream_trigger_with_legacy_explicit_metadata_adds_nightly_bundle( + self, + ) -> None: features = resolve_features("dream", explicit_features=("dream", "questions")) - self.assertEqual(tuple(feature.feature_id for feature in features), ("dream", "questions", "skills", "diary")) + self.assertEqual( + tuple(feature.feature_id for feature in features), + ("dream", "questions", "skills", "diary"), + ) def test_explicit_dream_alone_stays_dream_only(self) -> None: features = resolve_features("manual", explicit_features=("dream",)) self.assertEqual(tuple(feature.feature_id for feature in features), ("dream",)) - def test_episode_close_resolves_to_pm_questions_and_skills_without_conversation_search(self) -> None: + def test_episode_close_resolves_to_pm_questions_and_skills_without_conversation_search( + self, + ) -> None: features = resolve_features("episode_close") - self.assertEqual(tuple(feature.feature_id for feature in features), ("pm", "questions", "skills")) + self.assertEqual( + tuple(feature.feature_id for feature in features), + ("pm", "questions", "skills"), + ) self.assertNotIn("tool.conversation.search", _compose_tools(features)) - def test_init_profile_resolves_without_diary_and_uses_bootstrap_prompt_rules(self) -> None: + def test_init_profile_resolves_without_diary_and_uses_bootstrap_prompt_rules( + self, + ) -> None: features = resolve_features("init_profile") - self.assertEqual(tuple(feature.feature_id for feature in features), ("pm", "questions", "skills")) + self.assertEqual( + tuple(feature.feature_id for feature in features), + ("pm", "questions", "skills"), + ) prompt = _assemble_system_prompt(features, conservatism="low") @@ -104,7 +122,9 @@ def test_dream_prompt_requires_pm_consolidation_and_concise_claims(self) -> None self.assertIn("CLAIM TEXT RULE", prompt) self.assertIn("short, clear, explicit, and unambiguous", prompt) - def test_dream_evidence_omits_episode_close_packet_when_questions_are_present(self) -> None: + def test_dream_evidence_omits_episode_close_packet_when_questions_are_present( + self, + ) -> None: class Repository: def load_episode(self, episode_id: str) -> SimpleNamespace: return SimpleNamespace(exit_summary="episode close summary should not appear") diff --git a/tests/unit/reflect/test_reflect_runner.py b/tests/unit/reflect/test_reflect_runner.py index 9ad72cb..8e10a4d 100644 --- a/tests/unit/reflect/test_reflect_runner.py +++ b/tests/unit/reflect/test_reflect_runner.py @@ -46,10 +46,16 @@ def write_learning_job_result(self, job_id: str, *_: object, **__: object) -> No class ReflectRunnerTest(unittest.TestCase): - def test_unpersisted_reflect_invocation_can_return_summary_without_learning_job_row(self) -> None: + def test_unpersisted_reflect_invocation_can_return_summary_without_learning_job_row( + self, + ) -> None: runtime = SimpleNamespace( repository=MissingJobRepository(), - run_sub_agent=lambda **_: {"summary": "Compressed summary", "status": "completed", "side_effects": ()}, + run_sub_agent=lambda **_: { + "summary": "Compressed summary", + "status": "completed", + "side_effects": (), + }, ) result = run_reflect_agent( @@ -65,11 +71,19 @@ def test_unpersisted_reflect_invocation_can_return_summary_without_learning_job_ def test_missing_non_sync_job_still_raises(self) -> None: runtime = SimpleNamespace( repository=MissingJobRepository(), - run_sub_agent=lambda **_: {"summary": "Summary", "status": "completed", "side_effects": ()}, + run_sub_agent=lambda **_: { + "summary": "Summary", + "status": "completed", + "side_effects": (), + }, ) with self.assertRaises(KeyError): - run_reflect_agent(runtime, _learning_job("learning-job:missing"), explicit_features=("compress",)) + run_reflect_agent( + runtime, + _learning_job("learning-job:missing"), + explicit_features=("compress",), + ) if __name__ == "__main__": diff --git a/tests/unit/semantic_index/test_backend.py b/tests/unit/semantic_index/test_backend.py index 1960a3a..ebc4238 100644 --- a/tests/unit/semantic_index/test_backend.py +++ b/tests/unit/semantic_index/test_backend.py @@ -42,7 +42,10 @@ def test_backend_indexes_searches_restarts_and_deletes_vectors(self) -> None: deleted = restarted.delete(SemanticIndexDeleteRequest(("entry-alpha",), dimensions=4)) remaining = restarted.search(SemanticIndexVectorQuery(4, (1.0, 0.0, 0.0, 0.0), limit=2)) - self.assertEqual(tuple(match.semantic_index_entry_id for match in matches), ("entry-alpha", "entry-beta")) + self.assertEqual( + tuple(match.semantic_index_entry_id for match in matches), + ("entry-alpha", "entry-beta"), + ) self.assertEqual(matches[0].distance, 0.0) self.assertEqual(deleted.accepted, 1) self.assertEqual(tuple(match.semantic_index_entry_id for match in remaining), ("entry-beta",)) diff --git a/tests/unit/semantic_index/test_sqlite_vec.py b/tests/unit/semantic_index/test_sqlite_vec.py index cd26bf2..4453db5 100644 --- a/tests/unit/semantic_index/test_sqlite_vec.py +++ b/tests/unit/semantic_index/test_sqlite_vec.py @@ -10,7 +10,11 @@ if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) -from packages.semantic_index import SQLITE_VEC_VERSION, load_sqlite_vec_extension, sqlite_vec_dependency_state +from packages.semantic_index import ( + SQLITE_VEC_VERSION, + load_sqlite_vec_extension, + sqlite_vec_dependency_state, +) class _FakeConnection: @@ -53,8 +57,14 @@ def test_load_extension_uses_sqlite_vec_loader_and_smokes_runtime(self) -> None: fake_connection = _FakeConnection() with ( - mock.patch("packages.semantic_index.sqlite_vec.metadata.version", return_value=SQLITE_VEC_VERSION), - mock.patch("packages.semantic_index.sqlite_vec.import_module", return_value=fake_module), + mock.patch( + "packages.semantic_index.sqlite_vec.metadata.version", + return_value=SQLITE_VEC_VERSION, + ), + mock.patch( + "packages.semantic_index.sqlite_vec.import_module", + return_value=fake_module, + ), ): state = load_sqlite_vec_extension(fake_connection) @@ -69,13 +79,20 @@ def test_load_extension_degrades_when_loader_fails(self) -> None: fake_module.load.side_effect = RuntimeError("boom") with ( - mock.patch("packages.semantic_index.sqlite_vec.metadata.version", return_value=SQLITE_VEC_VERSION), - mock.patch("packages.semantic_index.sqlite_vec.import_module", return_value=fake_module), + mock.patch( + "packages.semantic_index.sqlite_vec.metadata.version", + return_value=SQLITE_VEC_VERSION, + ), + mock.patch( + "packages.semantic_index.sqlite_vec.import_module", + return_value=fake_module, + ), ): state = load_sqlite_vec_extension(_FakeConnection()) self.assertEqual(state.status, "degraded") self.assertEqual(state.metadata["reason"], "RuntimeError") + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_browser_backend.py b/tests/unit/test_browser_backend.py index 87608ae..9e3d4f3 100644 --- a/tests/unit/test_browser_backend.py +++ b/tests/unit/test_browser_backend.py @@ -8,7 +8,11 @@ from unittest import mock from packages.tools import browser_backend as browser_backend_module -from packages.tools.browser_backend import BrowserBackendConfig, PlaywrightBrowserBackend, _is_private_url +from packages.tools.browser_backend import ( + BrowserBackendConfig, + PlaywrightBrowserBackend, + _is_private_url, +) from packages.tools.builtins import builtin_tool_definitions from packages.tools.handlers_network import run_browser_action from packages.tools.runtime import ToolInvocation @@ -88,12 +92,30 @@ def evaluate(self, script: str, arg=None): # type: ignore[no-untyped-def] "text": "Welcome Sign in Email", "elementCount": 2, "elements": [ - {"ref": "@e1", "role": "button", "label": "Sign in", "disabled": False}, - {"ref": "@e2", "role": "input", "label": "Email", "disabled": False}, + { + "ref": "@e1", + "role": "button", + "label": "Sign in", + "disabled": False, + }, + { + "ref": "@e2", + "role": "input", + "label": "Email", + "disabled": False, + }, ], } if "document.images" in script: - return [{"index": 1, "src": "https://example.com/a.png", "alt": "A", "width": 10, "height": 20}] + return [ + { + "index": 1, + "src": "https://example.com/a.png", + "alt": "A", + "width": 10, + "height": 20, + } + ] if "data-elephant-browser-annotation" in script: return 2 if script == "document.title": @@ -263,21 +285,36 @@ def test_navigate_returns_ref_snapshot_and_click_uses_ref(self) -> None: def test_type_images_console_and_vision_payloads(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: backend = self._backend(Path(tmpdir)) - backend.invoke("navigate", self._invoke("tool.browser.navigate", {"url": "https://example.com"})) + backend.invoke( + "navigate", + self._invoke("tool.browser.navigate", {"url": "https://example.com"}), + ) page = backend._sessions["session-1"].page page.handlers["console"](_FakeMessage()) - typed = backend.invoke("type", self._invoke("tool.browser.type", {"ref": "e2", "text": "a@example.com"})) + typed = backend.invoke( + "type", + self._invoke("tool.browser.type", {"ref": "e2", "text": "a@example.com"}), + ) images = backend.invoke("images", self._invoke("tool.browser.images", {})) - console = backend.invoke("console", self._invoke("tool.browser.console", {"expression": "document.title"})) + console = backend.invoke( + "console", + self._invoke("tool.browser.console", {"expression": "document.title"}), + ) analyzer = _FakeVisionAnalyzer() vision = backend.invoke( "vision", - self._invoke("tool.browser.vision", {"question": "what is visible?", "annotate": True}), + self._invoke( + "tool.browser.vision", + {"question": "what is visible?", "annotate": True}, + ), vision_analyzer=analyzer, ) - self.assertEqual(json.loads(typed["summary"])["target"], '[data-elephant-browser-ref="@e2"]') + self.assertEqual( + json.loads(typed["summary"])["target"], + '[data-elephant-browser-ref="@e2"]', + ) self.assertEqual(json.loads(images["summary"])["count"], 1) self.assertEqual(json.loads(console["summary"])["result"], "Example") vision_payload = json.loads(vision["summary"]) @@ -292,13 +329,22 @@ def test_type_images_console_and_vision_payloads(self) -> None: def test_browser_vision_without_analyzer_returns_setup_hint(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: backend = self._backend(Path(tmpdir)) - backend.invoke("navigate", self._invoke("tool.browser.navigate", {"url": "https://example.com"})) + backend.invoke( + "navigate", + self._invoke("tool.browser.navigate", {"url": "https://example.com"}), + ) - vision = backend.invoke("vision", self._invoke("tool.browser.vision", {"question": "what is visible?"})) + vision = backend.invoke( + "vision", + self._invoke("tool.browser.vision", {"question": "what is visible?"}), + ) vision_payload = json.loads(vision["summary"]) self.assertFalse(vision_payload["vision_analyzer_configured"]) - self.assertIn("Configure a browser vision analyzer", vision_payload["vision_setup_hint"]) + self.assertIn( + "Configure a browser vision analyzer", + vision_payload["vision_setup_hint"], + ) def test_browser_schema_is_ref_first_with_selector_fallback(self) -> None: definitions = { @@ -329,7 +375,10 @@ def test_browser_vision_is_unavailable_without_analyzer(self) -> None: self.assertTrue(definitions["tool.browser.navigate"].available) self.assertFalse(definitions["tool.browser.vision"].available) - self.assertIn("vision analyzer", definitions["tool.browser.vision"].availability.reason or "") + self.assertIn( + "vision analyzer", + definitions["tool.browser.vision"].availability.reason or "", + ) def test_browser_vision_is_available_with_analyzer(self) -> None: definitions = { @@ -359,7 +408,10 @@ def test_playwright_operations_run_on_dedicated_worker_thread(self) -> None: def _run_navigation() -> None: caller_thread_ids.append(threading.get_ident()) try: - backend.invoke("navigate", self._invoke("tool.browser.navigate", {"url": "example.com"})) + backend.invoke( + "navigate", + self._invoke("tool.browser.navigate", {"url": "example.com"}), + ) except BaseException as exc: errors.append(exc) diff --git a/tests/unit/test_builtin_catalog_docs.py b/tests/unit/test_builtin_catalog_docs.py index fcd4c21..4f052ed 100644 --- a/tests/unit/test_builtin_catalog_docs.py +++ b/tests/unit/test_builtin_catalog_docs.py @@ -8,7 +8,7 @@ if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) -from packages.tools import render_builtin_tool_reference_markdown, render_builtin_tool_summary_markdown +from packages.tools import render_builtin_tool_summary_markdown def _extract_between_markers(text: str, begin: str, end: str) -> str: @@ -29,7 +29,9 @@ def test_site_tools_doc_stays_in_sync_with_runtime_summary(self) -> None: ) self.assertEqual(actual, rendered) - def test_cli_reference_builtin_tool_summary_stays_in_sync_with_runtime_catalog(self) -> None: + def test_cli_reference_builtin_tool_summary_stays_in_sync_with_runtime_catalog( + self, + ) -> None: docs_path = ROOT / "apps" / "site" / "docs" / "reference" / "cli.md" rendered = render_builtin_tool_summary_markdown().strip() actual = _extract_between_markers( diff --git a/tests/unit/test_builtin_tools_v2.py b/tests/unit/test_builtin_tools_v2.py index 155b7da..73a89fd 100644 --- a/tests/unit/test_builtin_tools_v2.py +++ b/tests/unit/test_builtin_tools_v2.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections.abc import Mapping -from datetime import datetime, timezone from email.message import Message import json import os @@ -19,7 +18,10 @@ from packages.cron import CronRuntime from packages.tools import handlers_code_execution from packages.tools.builtins import builtin_tool_definitions -from packages.tools.adapters import DeliveryMessageSurfaceAdapter, StructuredClarifySurface +from packages.tools.adapters import ( + DeliveryMessageSurfaceAdapter, + StructuredClarifySurface, +) from packages.tools import ( BuiltinToolDependencies, CallableApprovalGateway, @@ -87,7 +89,12 @@ def run_sub_agent( name: str | None = None, skills: tuple[str, ...] = (), ): - self.single = {"session_id": session_id, "task": task, "name": name, "skills": skills} + self.single = { + "session_id": session_id, + "task": task, + "name": name, + "skills": skills, + } if task == "fail": return {"summary": "sub-agent failed", "status": "failed"} return {"summary": "single sub-agent finished"} @@ -99,7 +106,11 @@ def run_sub_agents( tasks, max_concurrency: int = 3, ): - self.batch = {"session_id": session_id, "tasks": tasks, "max_concurrency": max_concurrency} + self.batch = { + "session_id": session_id, + "tasks": tasks, + "max_concurrency": max_concurrency, + } return {"summary": "sub-agent pool finished"} def start_sub_agents( @@ -109,7 +120,11 @@ def start_sub_agents( tasks, max_concurrency: int = 3, ): - self.started = {"session_id": session_id, "tasks": tasks, "max_concurrency": max_concurrency} + self.started = { + "session_id": session_id, + "tasks": tasks, + "max_concurrency": max_concurrency, + } return { "summary": "sub_agent_run_id: subrun-test\nstatus: running", "run_id": "subrun-test", @@ -175,7 +190,10 @@ def write_diary_entry(self, **kwargs): # type: ignore[no-untyped-def] return {"entry_date": kwargs["entry_date"]} def list_diary_entries(self, **kwargs): # type: ignore[no-untyped-def] - return {"entries": ({"entry_date": "2026-05-14", "content": "Today note"},), "count": 1} + return { + "entries": ({"entry_date": "2026-05-14", "content": "Today note"},), + "count": 1, + } class BuiltinToolsV2Test(unittest.TestCase): @@ -239,7 +257,9 @@ def test_runtime_filters_model_visible_available_tools(self) -> None: handler=lambda invocation: {"summary": invocation.tool_id}, ) - model_visible = {tool.tool_id for tool in runtime.list_tools(audience="model", enabled_only=True, available_only=True)} + model_visible = { + tool.tool_id for tool in runtime.list_tools(audience="model", enabled_only=True, available_only=True) + } operator_visible = {tool.tool_id for tool in runtime.list_tools(audience="operator", enabled_only=True)} self.assertIn("tool.file.read", model_visible) @@ -443,7 +463,10 @@ def test_model_skill_list_and_view_include_external_shelves(self) -> None: self.assertIn("personal-journal | Personal Journal | source=agents", listed.summary) self.assertIn("reference=agents:personal-journal", listed.summary) self.assertIn("skill_id: personal-journal", viewed.summary) - self.assertIn("Use this skill when the user asks to review personal journal notes.", viewed.summary) + self.assertIn( + "Use this skill when the user asks to review personal journal notes.", + viewed.summary, + ) def test_model_visible_action_tools_expose_constrained_action_enums(self) -> None: definitions = { @@ -456,7 +479,10 @@ def test_model_visible_action_tools_expose_constrained_action_enums(self) -> Non todo_action = definitions["tool.todo.manage"].schema["properties"]["action"]["enum"] todo_properties = definitions["tool.todo.manage"].schema["properties"] - self.assertEqual(tuple(process_action), ("list", "ls", "poll", "inspect", "wait", "write", "kill")) + self.assertEqual( + tuple(process_action), + ("list", "ls", "poll", "inspect", "wait", "write", "kill"), + ) self.assertEqual( tuple(cron_action), ("list", "ls", "create", "inspect", "pause", "resume", "remove", "delete"), @@ -467,7 +493,9 @@ def test_model_visible_action_tools_expose_constrained_action_enums(self) -> Non self.assertNotIn("noop", tuple(cron_action)) self.assertNotIn("noop", tuple(todo_action)) - def test_builtin_model_schema_carries_cron_description_and_action_guidance(self) -> None: + def test_builtin_model_schema_carries_cron_description_and_action_guidance( + self, + ) -> None: definitions = { definition.tool_id: definition for definition in builtin_tool_definitions({}, dependencies=BuiltinToolDependencies(cwd=Path("/tmp"))) @@ -486,9 +514,18 @@ def test_builtin_model_schema_carries_cron_description_and_action_guidance(self) self.assertIn("inspect|pause|resume|remove|delete", action["description"]) self.assertNotIn("job_kind", parameters["properties"]) self.assertIn("5-field cron", parameters["properties"]["schedule"]["description"]) - self.assertEqual(parameters["properties"]["prompt"]["description"], "Prompt payload for the scheduled prompt job when action=create.") - self.assertEqual(parameters["properties"]["profile_id"]["description"], "Optional profile scope filter for listing or creating jobs.") - self.assertEqual(parameters["properties"]["elephant_id"]["description"], "Optional elephant scope filter for listing or creating jobs.") + self.assertEqual( + parameters["properties"]["prompt"]["description"], + "Prompt payload for the scheduled prompt job when action=create.", + ) + self.assertEqual( + parameters["properties"]["profile_id"]["description"], + "Optional profile scope filter for listing or creating jobs.", + ) + self.assertEqual( + parameters["properties"]["elephant_id"]["description"], + "Optional elephant scope filter for listing or creating jobs.", + ) self.assertNotIn("message", parameters["properties"]) self.assertNotIn("query", parameters["properties"]) @@ -538,15 +575,24 @@ def test_personal_model_tool_schemas_replace_legacy_memory_tools(self) -> None: self.assertNotIn("tool.personal_model.verify", definitions) self.assertNotIn("tool.personal_model.audit", definitions) self.assertNotIn("tool.personal_model.inspect", definitions) - self.assertEqual(search_properties["status"]["enum"], ["active", "retired", "disputed", "all"]) + self.assertEqual( + search_properties["status"]["enum"], + ["active", "retired", "disputed", "all"], + ) self.assertIn("ref", search_properties) - self.assertIn("remember", update_properties["action"]["description"].lower() + " " + update["description"].lower()) + self.assertIn( + "remember", + update_properties["action"]["description"].lower() + " " + update["description"].lower(), + ) self.assertIn("restore", update_properties["action"]["enum"]) self.assertIn("delete", update_properties["action"]["enum"]) self.assertIn("identity={anchor", update_properties["topic"]["description"]) self.assertIn("Required for delete/restore", update_properties["ref"]["description"]) self.assertIn("recall_policy", update_properties) - self.assertEqual(update_properties["recall_policy"]["enum"], ["stable", "current", "temporary", "review"]) + self.assertEqual( + update_properties["recall_policy"]["enum"], + ["stable", "current", "temporary", "review"], + ) self.assertIn("text", question_properties) self.assertNotIn("question", question_properties) self.assertIn("copy", code_properties["code"]["description"]) @@ -561,7 +607,10 @@ def test_personal_model_tool_schemas_replace_legacy_memory_tools(self) -> None: self.assertIn("One concise question", clarify_properties["question"]["description"]) self.assertIn("mode=choice", clarify_properties["choices"]["description"]) self.assertIn("buffered stdout/stderr", process_properties["action"]["description"]) - self.assertIn("background tool.terminal.exec", process_properties["process_id"]["description"]) + self.assertIn( + "background tool.terminal.exec", + process_properties["process_id"]["description"], + ) self.assertIn("public-web information", web_search_properties["query"]["description"]) self.assertIn("query_variants", web_search_properties) self.assertIn("search results to summarize", web_search_properties["limit"]["description"]) @@ -573,7 +622,9 @@ def test_personal_model_tool_schemas_replace_legacy_memory_tools(self) -> None: self.assertNotIn("tool.procedure.inspect", definitions) self.assertNotIn("tool.procedure.manage", definitions) - def test_tool_fallback_prompt_routes_durable_personal_facts_to_personal_model_update(self) -> None: + def test_tool_fallback_prompt_routes_durable_personal_facts_to_personal_model_update( + self, + ) -> None: definitions = tuple( definition for definition in builtin_tool_definitions({}, dependencies=BuiltinToolDependencies(cwd=Path("/tmp"))) @@ -687,7 +738,10 @@ def test_sub_agents_accepts_skills_object_flags(self) -> None: { "name": "core", "task": "inspect core architecture", - "skills": {"codebase-inspection": True, "disabled-skill": False}, + "skills": { + "codebase-inspection": True, + "disabled-skill": False, + }, } ], "max_concurrency": 1, @@ -809,14 +863,20 @@ def test_file_tools_can_write_to_posix_tmp_by_default(self) -> None: self.assertIn("1|tmp ok", read.summary) def test_default_local_allowed_roots_include_posix_tmp(self) -> None: - with mock.patch("packages.tools.local_roots.tempfile.gettempdir", return_value="/var/folders/example/T"): + with mock.patch( + "packages.tools.local_roots.tempfile.gettempdir", + return_value="/var/folders/example/T", + ): roots = default_local_allowed_roots() self.assertIn(Path("/tmp").resolve(), roots) self.assertIn(Path("/var/folders/example/T").resolve(), roots) def test_file_tools_can_access_configured_roots_outside_primary_root(self) -> None: - with tempfile.TemporaryDirectory() as local_tmpdir, tempfile.TemporaryDirectory() as external_tmpdir: + with ( + tempfile.TemporaryDirectory() as local_tmpdir, + tempfile.TemporaryDirectory() as external_tmpdir, + ): local_root = Path(local_tmpdir) external = Path(external_tmpdir) shared = external / "shared.txt" @@ -874,7 +934,10 @@ def test_file_tools_can_access_configured_roots_outside_primary_root(self) -> No self.assertIn(str(external), terminal.summary) def test_file_and_terminal_tools_default_to_session_cwd(self) -> None: - with tempfile.TemporaryDirectory() as root_tmpdir, tempfile.TemporaryDirectory() as fallback_tmpdir: + with ( + tempfile.TemporaryDirectory() as root_tmpdir, + tempfile.TemporaryDirectory() as fallback_tmpdir, + ): root = Path(root_tmpdir) fallback = Path(fallback_tmpdir) roots = { @@ -992,7 +1055,10 @@ def test_file_write_blocks_sensitive_home_paths(self) -> None: with self.assertRaisesRegex(ValueError, "sensitive credential directory"): runtime.invoke( "tool.file.write", - {"path": str(Path.home() / ".ssh" / "config"), "content": "Host *\n"}, + { + "path": str(Path.home() / ".ssh" / "config"), + "content": "Host *\n", + }, session_id="session-sensitive-write", ) with self.assertRaisesRegex(ValueError, "VCS metadata"): @@ -1145,7 +1211,9 @@ def test_file_search_accepts_pattern_alias_for_query(self) -> None: self.assertIn("TestmemoryRecall", result.summary) self.assertNotIn("TestmemoryRecallIgnored", result.summary) - def test_file_search_allows_glob_only_file_listing_and_blocks_vcs_metadata(self) -> None: + def test_file_search_allows_glob_only_file_listing_and_blocks_vcs_metadata( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: cwd = Path(tmpdir) (cwd / "notes.md").write_text("hello\n", encoding="utf-8") @@ -1182,7 +1250,7 @@ def test_terminal_exec_background_processes_can_be_waited_on(self) -> None: started = runtime.invoke( "tool.terminal.exec", { - "command": 'python3 -c "import time; print(\'bg-finished\'); time.sleep(0.1)"', + "command": "python3 -c \"import time; print('bg-finished'); time.sleep(0.1)\"", "background": True, }, session_id="session-process", @@ -1246,7 +1314,7 @@ def test_terminal_exec_merges_env_overrides_with_parent_environment(self) -> Non "tool.terminal.exec", { "command": ( - "python3 -c \"import os; " + 'python3 -c "import os; ' "print(os.environ.get('ELEPHANT_TEST_ENV')); " "print(bool(os.environ.get('PATH')))\"" ), @@ -1321,7 +1389,9 @@ def test_code_execute_allows_safe_stdlib_imports(self) -> None: self.assertEqual(result.outcome, "success") self.assertIn('"a": 3', result.summary) - def test_code_execute_documents_and_allows_copy_pow_and_safe_dunder_name(self) -> None: + def test_code_execute_documents_and_allows_copy_pow_and_safe_dunder_name( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: runtime = self._make_builtin_runtime(cwd=Path(tmpdir)) @@ -1357,7 +1427,9 @@ def test_code_execute_schema_safe_imports_match_enforced_allowlist(self) -> None self.assertIn(blocked, description) self.assertIn("blocked", description) - def test_code_execute_runs_with_project_cwd_and_venv_python_by_default(self) -> None: + def test_code_execute_runs_with_project_cwd_and_venv_python_by_default( + self, + ) -> None: from packages.tools import handlers_code_execution with tempfile.TemporaryDirectory() as tmpdir: @@ -1390,10 +1462,18 @@ def test_code_execute_runs_with_project_cwd_and_venv_python_by_default(self) -> {str(python_path), sys.executable}, ) else: - self.assertEqual(handlers_code_execution._code_child_python(mode="project"), str(python_path)) - self.assertEqual(handlers_code_execution._code_child_python(mode="strict"), sys.executable) + self.assertEqual( + handlers_code_execution._code_child_python(mode="project"), + str(python_path), + ) + self.assertEqual( + handlers_code_execution._code_child_python(mode="strict"), + sys.executable, + ) - def test_code_execute_can_call_terminal_but_rejects_background_arguments(self) -> None: + def test_code_execute_can_call_terminal_but_rejects_background_arguments( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: runtime = self._make_builtin_runtime(cwd=Path(tmpdir)) @@ -1414,8 +1494,7 @@ def test_code_execute_can_call_terminal_but_rejects_background_arguments(self) - "tool.code.execute", { "code": ( - "result = tool('tool.terminal.exec', " - "{'command': 'printf blocked', 'background': True})" + "result = tool('tool.terminal.exec', {'command': 'printf blocked', 'background': True})" ), }, session_id="session-code-terminal-blocked", @@ -1468,10 +1547,15 @@ def test_code_execute_can_write_files_and_extract_web_sources(self) -> None: ) self.assertEqual(result.outcome, "success") - self.assertEqual((cwd / 'notes' / 'out.txt').read_text(encoding='utf-8'), "patched by code\n") + self.assertEqual( + (cwd / "notes" / "out.txt").read_text(encoding="utf-8"), + "patched by code\n", + ) self.assertIn("Alpha Doc", result.summary) - def test_code_execute_rejects_unsafe_imports_and_non_allowlisted_tool_rpc(self) -> None: + def test_code_execute_rejects_unsafe_imports_and_non_allowlisted_tool_rpc( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: runtime = self._make_builtin_runtime(cwd=Path(tmpdir)) diff --git a/tests/unit/test_daemon.py b/tests/unit/test_daemon.py index fee332a..b342c20 100644 --- a/tests/unit/test_daemon.py +++ b/tests/unit/test_daemon.py @@ -6,7 +6,6 @@ import json import os import signal -import sys import time import warnings from pathlib import Path @@ -60,14 +59,20 @@ def test_current_pid(self, tmp_path: Path) -> None: def test_healthz_state_identity_must_match(self, tmp_path: Path) -> None: from apps.daemon_command import _healthz_matches_state - assert _healthz_matches_state( - {"status": "running", "state_dir": str(tmp_path)}, - tmp_path, - ) is True - assert _healthz_matches_state( - {"status": "running", "state_dir": str(tmp_path / "other")}, - tmp_path, - ) is False + assert ( + _healthz_matches_state( + {"status": "running", "state_dir": str(tmp_path)}, + tmp_path, + ) + is True + ) + assert ( + _healthz_matches_state( + {"status": "running", "state_dir": str(tmp_path / "other")}, + tmp_path, + ) + is False + ) class TestStartDaemonDetached: @@ -92,7 +97,11 @@ def test_start_and_cleanup(self, tmp_path: Path) -> None: patch("apps.daemon_command.subprocess.Popen") as mock_popen, patch( "apps.daemon_command._daemon_healthz_payload", - return_value={"status": "running", "pid": 12345, "state_dir": str(tmp_path)}, + return_value={ + "status": "running", + "pid": 12345, + "state_dir": str(tmp_path), + }, ), ): mock_process = mock_popen.return_value @@ -130,10 +139,17 @@ def __del__(self) -> None: ) with ( - patch("apps.daemon_command.subprocess.Popen", side_effect=lambda *_args, **_kwargs: WarningProcess()), + patch( + "apps.daemon_command.subprocess.Popen", + side_effect=lambda *_args, **_kwargs: WarningProcess(), + ), patch( "apps.daemon_command._daemon_healthz_payload", - return_value={"status": "running", "pid": 12346, "state_dir": str(tmp_path)}, + return_value={ + "status": "running", + "pid": 12346, + "state_dir": str(tmp_path), + }, ), ): with warnings.catch_warnings(record=True) as caught: @@ -144,8 +160,7 @@ def __del__(self) -> None: assert not [ warning for warning in caught - if warning.category is ResourceWarning - and "subprocess 12346 is still running" in str(warning.message) + if warning.category is ResourceWarning and "subprocess 12346 is still running" in str(warning.message) ] def test_start_does_not_overwrite_child_ready_record_after_timeout(self, tmp_path: Path) -> None: @@ -168,7 +183,10 @@ def mark_child_ready(_state_dir: Path) -> None: with ( patch("apps.daemon_command.subprocess.Popen", return_value=FakeProcess()), patch("apps.daemon_command._DAEMON_STARTUP_WAIT_SECONDS", 0.0), - patch("apps.daemon_command._daemon_healthz_payload", side_effect=mark_child_ready), + patch( + "apps.daemon_command._daemon_healthz_payload", + side_effect=mark_child_ready, + ), ): result = start_daemon_detached(tmp_path, tmp_path) @@ -208,7 +226,10 @@ def test_stop_uses_healthz_pid_when_pid_file_is_missing(self, tmp_path: Path) -> from apps.daemon_command import stop_daemon record_path = tmp_path / "daemon.runtime.json" - record_path.write_text(json.dumps({"status": "running", "host": "127.0.0.1", "port": 9876}), encoding="utf-8") + record_path.write_text( + json.dumps({"status": "running", "host": "127.0.0.1", "port": 9876}), + encoding="utf-8", + ) running = {"value": True} def fake_is_running(pid: int | None) -> bool: @@ -255,9 +276,7 @@ class TestDaemonTaskGuard: def test_normal_completion(self) -> None: from apps.daemon import DaemonServiceStatus, _daemon_task_guard - statuses: dict[str, DaemonServiceStatus] = { - "test": DaemonServiceStatus(name="test", status="running") - } + statuses: dict[str, DaemonServiceStatus] = {"test": DaemonServiceStatus(name="test", status="running")} async def _inner(): pass # Complete normally @@ -273,9 +292,7 @@ async def _run(): def test_exception_updates_status(self) -> None: from apps.daemon import DaemonServiceStatus, _daemon_task_guard - statuses: dict[str, DaemonServiceStatus] = { - "test": DaemonServiceStatus(name="test", status="running") - } + statuses: dict[str, DaemonServiceStatus] = {"test": DaemonServiceStatus(name="test", status="running")} async def _inner(): raise RuntimeError("boom") @@ -292,9 +309,7 @@ def test_cancellation_cancels_inner(self) -> None: """When the guard is cancelled, the inner task should also be cancelled.""" from apps.daemon import DaemonServiceStatus, _daemon_task_guard - statuses: dict[str, DaemonServiceStatus] = { - "test": DaemonServiceStatus(name="test", status="running") - } + statuses: dict[str, DaemonServiceStatus] = {"test": DaemonServiceStatus(name="test", status="running")} inner_cancelled = False async def _inner(): @@ -393,9 +408,7 @@ def test_datetime_at_top(self) -> None: ] assert len(datetime_imports) >= 1, "datetime import should exist at module level" # Verify none at the bottom (after function defs) - last_func_line = max( - node.lineno for node in ast.walk(tree) if isinstance(node, ast.FunctionDef) - ) + last_func_line = max(node.lineno for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)) for imp in datetime_imports: assert imp.lineno < last_func_line, ( f"datetime import at line {imp.lineno} should be at the top, " @@ -406,7 +419,9 @@ def test_datetime_at_top(self) -> None: class TestLearningWorkerLoop: """Tests for daemon learning worker event-loop behavior.""" - def test_learning_worker_does_not_idle_exit_by_default(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + def test_learning_worker_does_not_idle_exit_by_default( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: from apps import daemon_tasks class FakeRepository: @@ -423,7 +438,10 @@ def fake_write_record(*_args: object, **_kwargs: object) -> dict[str, object]: return {} monkeypatch.setattr(daemon_tasks, "RuntimeStorageRepository", fake_repository_factory) - monkeypatch.setattr("apps.learning_worker_runtime._write_learning_worker_record", fake_write_record) + monkeypatch.setattr( + "apps.learning_worker_runtime._write_learning_worker_record", + fake_write_record, + ) running = True @@ -476,7 +494,10 @@ def fake_run_claimed_job(_state_dir: Path, _job_id: str, _worker_id: str) -> Non running = False monkeypatch.setattr(daemon_tasks, "RuntimeStorageRepository", fake_repository_factory) - monkeypatch.setattr("apps.learning_worker_runtime._write_learning_worker_record", fake_write_record) + monkeypatch.setattr( + "apps.learning_worker_runtime._write_learning_worker_record", + fake_write_record, + ) monkeypatch.setattr(daemon_tasks, "_run_claimed_learning_job", fake_run_claimed_job) tick_at = 0.0 diff --git a/tests/unit/test_daemon_adapter_lifecycle.py b/tests/unit/test_daemon_adapter_lifecycle.py index 98b2d2f..c4efa10 100644 --- a/tests/unit/test_daemon_adapter_lifecycle.py +++ b/tests/unit/test_daemon_adapter_lifecycle.py @@ -7,9 +7,8 @@ from __future__ import annotations import asyncio -import os from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch # ── has_credentials() tests ────────────────────────────────────── @@ -174,9 +173,7 @@ def test_with_token_and_real_account_id_returns_true(self) -> None: from apps.gateway.weixin_support import WeixinGatewayAccountConfig service = WeixinGatewayService.__new__(WeixinGatewayService) - service.account_configs = ( - WeixinGatewayAccountConfig(account_id="wx_real_account", token="some_token"), - ) + service.account_configs = (WeixinGatewayAccountConfig(account_id="wx_real_account", token="some_token"),) service.runtime_state_dir = Path("/tmp") assert service.has_credentials() is True @@ -185,9 +182,7 @@ def test_saved_token_in_local_storage_returns_true(self) -> None: from apps.gateway.weixin_support import WeixinGatewayAccountConfig service = WeixinGatewayService.__new__(WeixinGatewayService) - service.account_configs = ( - WeixinGatewayAccountConfig(account_id="wx_account", token=""), - ) + service.account_configs = (WeixinGatewayAccountConfig(account_id="wx_account", token=""),) service.runtime_state_dir = Path("/tmp") with patch( "apps.gateway.weixin_service.load_weixin_account", @@ -273,9 +268,7 @@ async def _run(): state_dir=Path("/tmp/test-daemon"), cli_state_dir=Path("/tmp/test-cli"), ) - daemon._service_statuses["discord"] = DaemonServiceStatus( - name="discord", status="running" - ) + daemon._service_statuses["discord"] = DaemonServiceStatus(name="discord", status="running") result = await daemon.start_adapter("discord") assert result["status"] == "already_running" @@ -306,9 +299,7 @@ async def _run(): state_dir=Path("/tmp/test-daemon"), cli_state_dir=Path("/tmp/test-cli"), ) - daemon._service_statuses["discord"] = DaemonServiceStatus( - name="discord", status="skipped" - ) + daemon._service_statuses["discord"] = DaemonServiceStatus(name="discord", status="skipped") result = await daemon.stop_adapter("discord") assert result["status"] == "not_running" diff --git a/tests/unit/test_model_metadata.py b/tests/unit/test_model_metadata.py index af9d299..05e773c 100644 --- a/tests/unit/test_model_metadata.py +++ b/tests/unit/test_model_metadata.py @@ -108,7 +108,9 @@ def test_resolves_openrouter_metadata_when_models_dev_misses(self) -> None: self.assertEqual(metadata, openrouter_entry) - def test_resolves_local_endpoint_model_detail_before_remote_registries(self) -> None: + def test_resolves_local_endpoint_model_detail_before_remote_registries( + self, + ) -> None: server = _MetadataStubServer( { "/v1/models/local-model": { @@ -169,7 +171,11 @@ def test_resolves_openai_compatible_provider_from_known_base_url(self) -> None: def test_persistent_context_length_cache_takes_precedence(self) -> None: with TemporaryDirectory() as tempdir: cache_path = Path(tempdir) / "context-length-cache.json" - with mock.patch.dict(os.environ, {"ELEPHANT_CONTEXT_LENGTH_CACHE_PATH": str(cache_path)}, clear=False): + with mock.patch.dict( + os.environ, + {"ELEPHANT_CONTEXT_LENGTH_CACHE_PATH": str(cache_path)}, + clear=False, + ): model_metadata.save_context_length( "custom-model", "https://example.test/v1", @@ -205,7 +211,11 @@ def test_local_probe_persists_detected_context_length(self) -> None: with TemporaryDirectory() as tempdir: cache_path = Path(tempdir) / "context-length-cache.json" - with mock.patch.dict(os.environ, {"ELEPHANT_CONTEXT_LENGTH_CACHE_PATH": str(cache_path)}, clear=False): + with mock.patch.dict( + os.environ, + {"ELEPHANT_CONTEXT_LENGTH_CACHE_PATH": str(cache_path)}, + clear=False, + ): metadata = model_metadata.resolve_provider_model_metadata( provider_id="vllm", model_id="local-model", diff --git a/tests/unit/test_observability.py b/tests/unit/test_observability.py index cb75ba4..b357308 100644 --- a/tests/unit/test_observability.py +++ b/tests/unit/test_observability.py @@ -8,7 +8,6 @@ from unittest import mock import unittest -from opentelemetry import trace, metrics from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.metrics import MeterProvider @@ -56,11 +55,13 @@ def test_update_context(self) -> None: self.assertEqual(ctx.trace_id, "t1") def test_filter_injects_fields(self) -> None: - set_context(TraceContext( - trace_id="trace123", - episode_id="ep456", - loop_id="loop789", - )) + set_context( + TraceContext( + trace_id="trace123", + episode_id="ep456", + loop_id="loop789", + ) + ) record = logging.LogRecord("test", logging.INFO, "", 0, "msg", (), None) f = TraceContextFilter() f.filter(record) @@ -119,6 +120,7 @@ def test_formats_as_json(self) -> None: class ConfigureLoggingTests(unittest.TestCase): def test_log_file_created(self) -> None: import packages.observability.logger as logger_mod + original = logger_mod._configured logger_mod._configured = False try: @@ -147,6 +149,7 @@ def __init__(self) -> None: def export(self, spans): self.spans.extend(spans) from opentelemetry.sdk.trace.export import SpanExportResult + return SpanExportResult.SUCCESS def shutdown(self) -> None: @@ -218,10 +221,26 @@ def setUp(self) -> None: meter = self.provider.get_meter("elephant-agent") self._patches = [ mock.patch.object(metrics_mod, "_meter", meter), - mock.patch.object(metrics_mod, "token_usage", meter.create_histogram("gen_ai.client.token.usage", unit="{token}")), - mock.patch.object(metrics_mod, "operation_duration", meter.create_histogram("gen_ai.client.operation.duration", unit="s")), - mock.patch.object(metrics_mod, "tool_duration", meter.create_histogram("elephant.tool.duration", unit="s")), - mock.patch.object(metrics_mod, "kernel_turn_duration", meter.create_histogram("elephant.kernel.turn.duration", unit="s")), + mock.patch.object( + metrics_mod, + "token_usage", + meter.create_histogram("gen_ai.client.token.usage", unit="{token}"), + ), + mock.patch.object( + metrics_mod, + "operation_duration", + meter.create_histogram("gen_ai.client.operation.duration", unit="s"), + ), + mock.patch.object( + metrics_mod, + "tool_duration", + meter.create_histogram("elephant.tool.duration", unit="s"), + ), + mock.patch.object( + metrics_mod, + "kernel_turn_duration", + meter.create_histogram("elephant.kernel.turn.duration", unit="s"), + ), ] for p in self._patches: p.start() @@ -235,7 +254,13 @@ def _metric_names(self) -> list[str]: return [m.name for rm in data.resource_metrics for sm in rm.scope_metrics for m in sm.metrics] def test_record_model_metrics(self) -> None: - metrics_mod.record_model_metrics(provider_id="openai", model_id="gpt-5", input_tokens=100, output_tokens=50, duration_s=1.5) + metrics_mod.record_model_metrics( + provider_id="openai", + model_id="gpt-5", + input_tokens=100, + output_tokens=50, + duration_s=1.5, + ) names = self._metric_names() self.assertIn("gen_ai.client.token.usage", names) self.assertIn("gen_ai.client.operation.duration", names) @@ -258,10 +283,12 @@ def test_duration_timer(self) -> None: class SetupTests(unittest.TestCase): def test_setup_is_idempotent(self) -> None: import packages.observability.setup as setup_mod + original = setup_mod._initialized setup_mod._initialized = False try: from packages.observability import setup_observability + setup_observability(service_name="test-1") setup_observability(service_name="test-2") finally: @@ -270,13 +297,14 @@ def test_setup_is_idempotent(self) -> None: class InstrumentorTests(unittest.TestCase): def test_instrument_and_uninstrument(self) -> None: - from packages.observability.instrumentor import instrument, uninstrument, _instrumented, _originals + from packages.observability.instrumentor import instrument, uninstrument import packages.observability.instrumentor as inst_mod inst_mod._instrumented = False inst_mod._originals.clear() try: from packages.kernel.runtime_impl import KernelService + original_run = KernelService.run instrument() @@ -291,16 +319,19 @@ def test_instrument_and_uninstrument(self) -> None: def test_instrument_is_idempotent(self) -> None: from packages.observability.instrumentor import instrument, uninstrument import packages.observability.instrumentor as inst_mod + inst_mod._instrumented = False inst_mod._originals.clear() try: from packages.kernel.runtime_impl import KernelService + instrument() run_after_first = KernelService.run instrument() self.assertIs(KernelService.run, run_after_first) finally: from packages.observability.instrumentor import uninstrument + uninstrument() inst_mod._instrumented = False inst_mod._originals.clear() diff --git a/tests/unit/test_personal_model_lifecycle.py b/tests/unit/test_personal_model_lifecycle.py index fa3cd8b..3f7d9ed 100644 --- a/tests/unit/test_personal_model_lifecycle.py +++ b/tests/unit/test_personal_model_lifecycle.py @@ -135,7 +135,15 @@ def ensure_default_personal_model(self, personal_model_id="you"): return None def load_episode_state(self, _session_id): - return type("_Episode", (), {"episode_id": "current-session", "personal_model_id": "you", "state_id": "state-1"})() + return type( + "_Episode", + (), + { + "episode_id": "current-session", + "personal_model_id": "you", + "state_id": "state-1", + }, + )() def current_state(self): return type("_State", (), {"state_id": "state-1"})() @@ -166,7 +174,15 @@ def ensure_default_personal_model(self, personal_model_id="you"): return None def load_episode_state(self, _session_id): - return type("_Episode", (), {"episode_id": "current-session", "personal_model_id": "you", "state_id": "state-1"})() + return type( + "_Episode", + (), + { + "episode_id": "current-session", + "personal_model_id": "you", + "state_id": "state-1", + }, + )() def current_state(self): return type("_State", (), {"state_id": "state-1"})() @@ -198,7 +214,15 @@ def ensure_default_personal_model(self, personal_model_id="you"): return None def load_episode_state(self, _session_id): - return type("_Episode", (), {"episode_id": "current-session", "personal_model_id": "you", "state_id": "state-1"})() + return type( + "_Episode", + (), + { + "episode_id": "current-session", + "personal_model_id": "you", + "state_id": "state-1", + }, + )() def current_state(self): return type("_State", (), {"state_id": "state-1"})() @@ -223,7 +247,10 @@ def list_steps(self, *, loop_id=None): sequence=1, created_at=now, summary="source item ingested", - metadata={"event_type": "turn.received", "user_query": "当前这轮也提到了家庭。"}, + metadata={ + "event_type": "turn.received", + "user_query": "当前这轮也提到了家庭。", + }, ), Step( step_id="step-old", @@ -237,14 +264,20 @@ def list_steps(self, *, loop_id=None): sequence=2, created_at=now, summary="source item ingested", - metadata={"event_type": "turn.received", "user_query": "昨晚我们聊了家庭边界。"}, + metadata={ + "event_type": "turn.received", + "user_query": "昨晚我们聊了家庭边界。", + }, ), ) result = PersonalModelUnderstandingSurface(repository=_Repo()).search_conversation( "current-session", query="家庭", - time_range={"start_at": "2026-05-09T00:00:00+00:00", "end_at": "2026-05-09T02:00:00+00:00"}, + time_range={ + "start_at": "2026-05-09T00:00:00+00:00", + "end_at": "2026-05-09T02:00:00+00:00", + }, mode="recall", ) @@ -252,7 +285,9 @@ def list_steps(self, *, loop_id=None): self.assertIn("昨晚我们聊了家庭边界", contents) self.assertNotIn("当前这轮", contents) - def test_conversation_discover_returns_copyable_range_and_user_anchor_first(self) -> None: + def test_conversation_discover_returns_copyable_range_and_user_anchor_first( + self, + ) -> None: now = datetime(2026, 5, 9, 1, 20, tzinfo=timezone.utc) class _Repo: @@ -260,7 +295,15 @@ def ensure_default_personal_model(self, personal_model_id="you"): return None def load_episode_state(self, _session_id): - return type("_Episode", (), {"episode_id": "current-session", "personal_model_id": "you", "state_id": "state-1"})() + return type( + "_Episode", + (), + { + "episode_id": "current-session", + "personal_model_id": "you", + "state_id": "state-1", + }, + )() def current_state(self): return type("_State", (), {"state_id": "state-1"})() @@ -299,14 +342,21 @@ def list_steps(self, *, loop_id=None): sequence=2, created_at=now, summary="source item ingested", - metadata={"event_type": "turn.received", "user_query": "我说了一个家庭边界的具体场景。"}, + metadata={ + "event_type": "turn.received", + "user_query": "我说了一个家庭边界的具体场景。", + }, ), ) result = PersonalModelUnderstandingSurface(repository=_Repo()).search_conversation( "current-session", query="家庭边界", - time_range={"start_at": "2026-05-09T00:00:00+00:00", "end_at": "2026-05-09T02:00:00+00:00", "timezone": "Asia/Shanghai"}, + time_range={ + "start_at": "2026-05-09T00:00:00+00:00", + "end_at": "2026-05-09T02:00:00+00:00", + "timezone": "Asia/Shanghai", + }, mode="discover", bucket="hour", ) @@ -350,7 +400,10 @@ def list_steps(self, *, loop_id=None): sequence=1, created_at=now, summary="tool result mentions family but is not conversation", - metadata={"tool_name": "tool.conversation.search", "tool_result": "家庭 recall test report"}, + metadata={ + "tool_name": "tool.conversation.search", + "tool_result": "家庭 recall test report", + }, ), Step( step_id="step-internal", @@ -364,7 +417,10 @@ def list_steps(self, *, loop_id=None): sequence=2, created_at=now, summary="source item ingested", - metadata={"event_type": "turn.internal", "user_query": "Write Iris's first message about family."}, + metadata={ + "event_type": "turn.internal", + "user_query": "Write Iris's first message about family.", + }, ), Step( step_id="step-user", @@ -378,14 +434,20 @@ def list_steps(self, *, loop_id=None): sequence=3, created_at=now, summary="source item ingested", - metadata={"event_type": "turn.received", "user_query": "我们昨晚聊了家庭权力结构。"}, + metadata={ + "event_type": "turn.received", + "user_query": "我们昨晚聊了家庭权力结构。", + }, ), ) result = PersonalModelUnderstandingSurface(repository=_Repo()).recall_personal_model( "session-life", query="家庭", - time_range={"start_at": "2026-05-08T18:00:00+00:00", "end_at": "2026-05-09T06:00:00+00:00"}, + time_range={ + "start_at": "2026-05-08T18:00:00+00:00", + "end_at": "2026-05-09T06:00:00+00:00", + }, ) contents = "\n".join(str(hit.get("content") or "") for hit in tuple(result.get("hits") or ())) @@ -393,7 +455,9 @@ def list_steps(self, *, loop_id=None): self.assertNotIn("tool result", contents) self.assertNotIn("Write Iris", contents) - def test_tool_update_adds_review_metadata_for_changeable_account_claim(self) -> None: + def test_tool_update_adds_review_metadata_for_changeable_account_claim( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: repository = RuntimeStorageRepository(Path(tmpdir) / "elephant.sqlite3") repository.bootstrap() @@ -755,8 +819,16 @@ def test_audit_surface_returns_topics_and_health(self) -> None: personal_model_id=state.personal_model_id, ) - topics = surface.audit_personal_model("session-life", action="topics", personal_model_id=state.personal_model_id) - health = surface.audit_personal_model("session-life", action="health", personal_model_id=state.personal_model_id) + topics = surface.audit_personal_model( + "session-life", + action="topics", + personal_model_id=state.personal_model_id, + ) + health = surface.audit_personal_model( + "session-life", + action="health", + personal_model_id=state.personal_model_id, + ) self.assertTrue(tuple(topics.get("topics") or ())) self.assertEqual(health["health_report"]["total_active_claims"], 1) @@ -790,7 +862,12 @@ def test_single_active_topic_remember_supersedes_existing_claim(self) -> None: ) active = repository.list_personal_model_facts(personal_model_id=state.personal_model_id, status="active") - retired = surface.search_personal_model("session-life", ref=first["ref"], status="retired", personal_model_id=state.personal_model_id) + retired = surface.search_personal_model( + "session-life", + ref=first["ref"], + status="retired", + personal_model_id=state.personal_model_id, + ) self.assertEqual(len(active), 1) self.assertEqual(second["retired"], (first["ref"],)) @@ -867,8 +944,16 @@ def test_health_and_related_reasons_with_clean_topics(self) -> None: personal_model_id=state.personal_model_id, ) - topics = surface.audit_personal_model("session-life", action="topics", personal_model_id=state.personal_model_id) - health = surface.audit_personal_model("session-life", action="health", personal_model_id=state.personal_model_id) + topics = surface.audit_personal_model( + "session-life", + action="topics", + personal_model_id=state.personal_model_id, + ) + health = surface.audit_personal_model( + "session-life", + action="health", + personal_model_id=state.personal_model_id, + ) related = surface.search_personal_model( "session-life", query="小红书账号", @@ -903,7 +988,9 @@ def test_update_rejects_bad_topic_key(self) -> None: personal_model_id=state.personal_model_id, ) - def test_forget_dispute_without_ref_does_not_report_retired_when_no_match(self) -> None: + def test_forget_dispute_without_ref_does_not_report_retired_when_no_match( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: repository = RuntimeStorageRepository(Path(tmpdir) / "elephant.sqlite3") repository.bootstrap() diff --git a/tests/unit/test_prefix_cache.py b/tests/unit/test_prefix_cache.py index e5159b1..2f89ac4 100644 --- a/tests/unit/test_prefix_cache.py +++ b/tests/unit/test_prefix_cache.py @@ -30,18 +30,24 @@ def test_list_stable_after_re_registration(self) -> None: class PrefixCacheHashTest(unittest.TestCase): def test_same_inputs_produce_same_hash(self) -> None: from packages.kernel.generation_context import _prefix_input_hash + h1 = _prefix_input_hash("prefix", ("fact1", "fact2"), ("resume",), "skills") h2 = _prefix_input_hash("prefix", ("fact1", "fact2"), ("resume",), "skills") self.assertEqual(h1, h2) def test_different_facts_produce_different_hash(self) -> None: from packages.kernel.generation_context import _prefix_input_hash + h1 = _prefix_input_hash("prefix", ("fact1",), ("resume",), "skills") h2 = _prefix_input_hash("prefix", ("fact1", "fact2"), ("resume",), "skills") self.assertNotEqual(h1, h2) def test_cache_invalidation(self) -> None: - from packages.kernel.generation_context import _prefix_cache, invalidate_prefix_cache + from packages.kernel.generation_context import ( + _prefix_cache, + invalidate_prefix_cache, + ) + _prefix_cache["test-ep"] = ("hash123", "cached prefix") invalidate_prefix_cache("test-ep") self.assertNotIn("test-ep", _prefix_cache) @@ -50,12 +56,23 @@ def test_cache_invalidation(self) -> None: class AnthropicCacheControlTest(unittest.TestCase): def test_official_anthropic_uses_content_blocks_with_cache_control(self) -> None: from packages.models.providers.anthropic import AnthropicMessagesRequest + req = AnthropicMessagesRequest( - request_id="r1", provider_id="anthropic", transport_id="anthropic_messages", - request_family="anthropic_messages", model_id="claude-4", base_url="https://api.anthropic.com/v1", - endpoint_path="/v1/messages", headers={}, system="test system prompt", - messages=(), max_tokens=1024, - tools=({"name": "tool_a", "input_schema": {}}, {"name": "tool_b", "input_schema": {}}), + request_id="r1", + provider_id="anthropic", + transport_id="anthropic_messages", + request_family="anthropic_messages", + model_id="claude-4", + base_url="https://api.anthropic.com/v1", + endpoint_path="/v1/messages", + headers={}, + system="test system prompt", + messages=(), + max_tokens=1024, + tools=( + {"name": "tool_a", "input_schema": {}}, + {"name": "tool_b", "input_schema": {}}, + ), ) payload = req.as_mapping() self.assertIsInstance(payload["system"], list) @@ -66,11 +83,19 @@ def test_official_anthropic_uses_content_blocks_with_cache_control(self) -> None def test_non_anthropic_provider_uses_plain_string(self) -> None: from packages.models.providers.anthropic import AnthropicMessagesRequest + req = AnthropicMessagesRequest( - request_id="r1", provider_id="minimax-cn", transport_id="anthropic_messages", - request_family="anthropic_messages", model_id="model-x", base_url="https://api.minimaxi.com/anthropic", - endpoint_path="/v1/messages", headers={}, system="test system prompt", - messages=(), max_tokens=1024, + request_id="r1", + provider_id="minimax-cn", + transport_id="anthropic_messages", + request_family="anthropic_messages", + model_id="model-x", + base_url="https://api.minimaxi.com/anthropic", + endpoint_path="/v1/messages", + headers={}, + system="test system prompt", + messages=(), + max_tokens=1024, tools=({"name": "tool_a", "input_schema": {}},), ) payload = req.as_mapping() diff --git a/tests/unit/test_provider_runtime_support.py b/tests/unit/test_provider_runtime_support.py index ed246dc..82951ba 100644 --- a/tests/unit/test_provider_runtime_support.py +++ b/tests/unit/test_provider_runtime_support.py @@ -19,7 +19,9 @@ class ProviderSelectionPayloadTest(unittest.TestCase): - def test_load_provider_profile_reads_provider_profile_from_config_yaml(self) -> None: + def test_load_provider_profile_reads_provider_profile_from_config_yaml( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: state_dir = Path(tmpdir) / "state" state_dir.mkdir(parents=True, exist_ok=True) @@ -78,7 +80,9 @@ def test_sentence_transformer_version_warning_is_filtered(self) -> None: class EmbeddingBootstrapStateTest(unittest.TestCase): - def test_resolve_embedding_bootstrap_state_uses_ready_when_root_is_healthy(self) -> None: + def test_resolve_embedding_bootstrap_state_uses_ready_when_root_is_healthy( + self, + ) -> None: with mock.patch.object(model_bootstrap, "embedding_root_is_healthy", return_value=True): state = provider_runtime_support.resolve_embedding_bootstrap_state( Path("/tmp/elephant-bootstrap-state"), @@ -89,11 +93,16 @@ def test_resolve_embedding_bootstrap_state_uses_ready_when_root_is_healthy(self) self.assertEqual(state.state_focus_mode, "embedded") self.assertIsNone(state.background_pid) - def test_resolve_embedding_bootstrap_state_uses_downloading_when_dependencies_exist(self) -> None: + def test_resolve_embedding_bootstrap_state_uses_downloading_when_dependencies_exist( + self, + ) -> None: with ( - mock.patch.object(model_bootstrap, "embedding_root_is_healthy", return_value=False), - mock.patch.object(model_bootstrap, "sentence_transformers_dependencies_ready", return_value=True), - + mock.patch.object(model_bootstrap, "embedding_root_is_healthy", return_value=False), + mock.patch.object( + model_bootstrap, + "sentence_transformers_dependencies_ready", + return_value=True, + ), ): state = provider_runtime_support.resolve_embedding_bootstrap_state( Path("/tmp/elephant-bootstrap-state"), @@ -103,13 +112,19 @@ def test_resolve_embedding_bootstrap_state_uses_downloading_when_dependencies_ex self.assertEqual(state.status, "downloading") self.assertIn("background model acquisition", state.summary) - def test_trigger_embedding_bootstrap_spawns_background_worker_for_pending_state(self) -> None: + def test_trigger_embedding_bootstrap_spawns_background_worker_for_pending_state( + self, + ) -> None: with tempfile.TemporaryDirectory() as tmpdir: state_dir = Path(tmpdir) fake_process = mock.Mock(pid=43210) with ( mock.patch.object(model_bootstrap, "embedding_root_is_healthy", return_value=False), - mock.patch.object(model_bootstrap, "sentence_transformers_dependencies_ready", return_value=False), + mock.patch.object( + model_bootstrap, + "sentence_transformers_dependencies_ready", + return_value=False, + ), mock.patch.object(model_bootstrap.subprocess, "Popen", return_value=fake_process) as popen, ): state = provider_runtime_support.trigger_embedding_bootstrap( @@ -143,7 +158,11 @@ def test_trigger_embedding_bootstrap_reuses_active_background_worker(self) -> No ) with ( mock.patch.object(model_bootstrap, "embedding_root_is_healthy", return_value=False), - mock.patch.object(model_bootstrap, "sentence_transformers_dependencies_ready", return_value=False), + mock.patch.object( + model_bootstrap, + "sentence_transformers_dependencies_ready", + return_value=False, + ), mock.patch.object(model_bootstrap.subprocess, "Popen") as popen, ): state = provider_runtime_support.trigger_embedding_bootstrap( @@ -160,8 +179,16 @@ def test_trigger_embedding_bootstrap_surfaces_spawn_failures(self) -> None: state_dir = Path(tmpdir) with ( mock.patch.object(model_bootstrap, "embedding_root_is_healthy", return_value=False), - mock.patch.object(model_bootstrap, "sentence_transformers_dependencies_ready", return_value=False), - mock.patch.object(model_bootstrap.subprocess, "Popen", side_effect=OSError("spawn failed")), + mock.patch.object( + model_bootstrap, + "sentence_transformers_dependencies_ready", + return_value=False, + ), + mock.patch.object( + model_bootstrap.subprocess, + "Popen", + side_effect=OSError("spawn failed"), + ), ): state = provider_runtime_support.trigger_embedding_bootstrap( state_dir, @@ -193,7 +220,11 @@ def test_trigger_embedding_bootstrap_retries_after_previous_failure(self) -> Non fake_process = mock.Mock(pid=54321) with ( mock.patch.object(model_bootstrap, "embedding_root_is_healthy", return_value=False), - mock.patch.object(model_bootstrap, "sentence_transformers_dependencies_ready", return_value=False), + mock.patch.object( + model_bootstrap, + "sentence_transformers_dependencies_ready", + return_value=False, + ), mock.patch.object(model_bootstrap.subprocess, "Popen", return_value=fake_process) as popen, ): state = provider_runtime_support.trigger_embedding_bootstrap( diff --git a/tests/unit/test_reasoning_parser.py b/tests/unit/test_reasoning_parser.py index 3be1dbc..d9e82a1 100644 --- a/tests/unit/test_reasoning_parser.py +++ b/tests/unit/test_reasoning_parser.py @@ -2,16 +2,23 @@ import unittest -from packages.models.reasoning_parser import combine_reasoning_text, stitch_text_fragments +from packages.models.reasoning_parser import ( + combine_reasoning_text, + stitch_text_fragments, +) class ReasoningParserTests(unittest.TestCase): - def test_stitch_text_fragments_collapses_whitespace_only_tokens_between_english_words(self) -> None: + def test_stitch_text_fragments_collapses_whitespace_only_tokens_between_english_words( + self, + ) -> None: stitched = stitch_text_fragments("Inspect", "\n", "the", "\n\n", "latest", " ", "release") self.assertEqual(stitched, "Inspect the latest release") - def test_stitch_text_fragments_prioritizes_spaces_over_guessing_subword_joins(self) -> None: + def test_stitch_text_fragments_prioritizes_spaces_over_guessing_subword_joins( + self, + ) -> None: stitched = stitch_text_fragments("3.", "X", "un", "zhuo", " ", "lives", " ", "in", " ", "Cheng", "du") self.assertEqual(stitched, "3. X un zhuo lives in Cheng du") @@ -26,12 +33,16 @@ def test_stitch_text_fragments_does_not_guess_camel_case_boundaries(self) -> Non self.assertEqual(stitched, "LoopStatePr ojection contains Retrieved Memory entries.") - def test_stitch_text_fragments_keeps_mixed_language_reasoning_readable(self) -> None: + def test_stitch_text_fragments_keeps_mixed_language_reasoning_readable( + self, + ) -> None: stitched = stitch_text_fragments("先看", "\n", "release", "\n", "notes", "。", "\n", "Then", "\n", "verify") self.assertEqual(stitched, "先看release notes。 Then verify") - def test_combine_reasoning_text_deduplicates_equivalent_multiline_reasoning(self) -> None: + def test_combine_reasoning_text_deduplicates_equivalent_multiline_reasoning( + self, + ) -> None: combined = combine_reasoning_text( "先看release notes。 Then verify", "先看\nrelease\nnotes。\nThen\nverify", diff --git a/tests/unit/test_runtime_config.py b/tests/unit/test_runtime_config.py index 5d0eb88..8b323ed 100644 --- a/tests/unit/test_runtime_config.py +++ b/tests/unit/test_runtime_config.py @@ -72,7 +72,9 @@ def test_yaml_round_trip_and_default_merge(self) -> None: self.assertNotIn("state_focus_mode", loaded["models"]) self.assertEqual(loaded["skills"]["external_dirs"], ["~/.agents/skills"]) - def test_default_global_config_from_gateway_state_uses_shared_install_root(self) -> None: + def test_default_global_config_from_gateway_state_uses_shared_install_root( + self, + ) -> None: """Gateway and CLI share the same state_dir. Passing ``.../gateway`` as an explicit override is honoured — the @@ -84,7 +86,10 @@ def test_default_global_config_from_gateway_state_uses_shared_install_root(self) self.assertEqual(defaults["gateway"]["state_dir"], "/tmp/elephant/gateway") def test_parse_json_or_simple_yaml_object(self) -> None: - self.assertEqual(parse_global_config_text('{"dashboard": {"port": 9999}}')["dashboard"]["port"], 9999) + self.assertEqual( + parse_global_config_text('{"dashboard": {"port": 9999}}')["dashboard"]["port"], + 9999, + ) parsed = parse_global_config_text("dashboard:\n host: 127.0.0.1\n port: 4174\n") self.assertEqual(parsed["dashboard"]["host"], "127.0.0.1") self.assertEqual(parsed["dashboard"]["port"], 4174) @@ -124,16 +129,18 @@ def test_external_skill_dirs_accept_string_payloads(self) -> None: def test_load_provider_from_config(self) -> None: self.assertIsNone(load_provider_from_config({})) self.assertIsNone(load_provider_from_config({"models": {}})) - provider = load_provider_from_config({ - "models": { - "provider": { - "profile_id": "provider-openai-compatible", - "provider_id": "openai-compatible", - "base_url": "https://api.example.com/v1", - "default_model": "gpt-4", + provider = load_provider_from_config( + { + "models": { + "provider": { + "profile_id": "provider-openai-compatible", + "provider_id": "openai-compatible", + "base_url": "https://api.example.com/v1", + "default_model": "gpt-4", + } } } - }) + ) self.assertIsNotNone(provider) self.assertEqual(provider["profile_id"], "provider-openai-compatible") diff --git a/tests/unit/test_runtime_layout.py b/tests/unit/test_runtime_layout.py index 602b2c2..874d88b 100644 --- a/tests/unit/test_runtime_layout.py +++ b/tests/unit/test_runtime_layout.py @@ -17,9 +17,15 @@ def test_top_level_runtime_dirs_default_under_elephant_home(self) -> None: environ = {"ELEPHANT_HOME": "/tmp/elephant-home"} self.assertEqual(default_cron_dir(environ=environ), Path("/tmp/elephant-home/cron")) - self.assertEqual(default_workspaces_dir(environ=environ), Path("/tmp/elephant-home/workspaces")) + self.assertEqual( + default_workspaces_dir(environ=environ), + Path("/tmp/elephant-home/workspaces"), + ) self.assertEqual(default_pairing_dir(environ=environ), Path("/tmp/elephant-home/pairing")) - self.assertEqual(default_builtin_skills_dir(environ=environ), Path("/tmp/elephant-home/skills/builtin")) + self.assertEqual( + default_builtin_skills_dir(environ=environ), + Path("/tmp/elephant-home/skills/builtin"), + ) def test_workspace_path_escapes_path_characters(self) -> None: path = elephant_file_path("team/atlas", environ={"ELEPHANT_HOME": "/tmp/elephant-home"}) diff --git a/tools/agent/scripts/agent_gate.py b/tools/agent/scripts/agent_gate.py index f96640d..4b7a9d7 100755 --- a/tools/agent/scripts/agent_gate.py +++ b/tools/agent/scripts/agent_gate.py @@ -30,15 +30,26 @@ PYTHON_LINE_LIMIT_PATTERNS = tuple(f"{surface}/**/*.py" for surface in PYTHON_LINE_LIMIT_SURFACES) PYTHON_LINE_LIMIT_ALLOWLIST_PATTERNS: tuple[str, ...] = ( "apps/api/api_runtime_console_ops.py", + "apps/api/api_runtime_http_methods.py", + "apps/api/api_runtime_internal_sections.py", "apps/cli/cli_main_impl.py", "apps/cli/runtime_extensions_surface.py", "apps/cli/shell_composer.py", "apps/cli/shell_methods_commands.py", + "apps/cli/shell_progress_trace.py", + "apps/cli/wizard.py", "apps/gateway/gateway_main_impl.py", + "apps/gateway/gateway_main_setup_impl.py", + "packages/context/projection.py", + "packages/context/runtime_support.py", "packages/evidence/runtime.py", "packages/learning/personal_model_evolution.py", "packages/models/providers/openai_compatible.py", "packages/storage/repository_system_methods.py", + "packages/tools/browser_backend.py", + "packages/tools/builtins.py", + "packages/tools/handlers_filesystem.py", + "packages/understanding/runtime.py", ) FRONTEND_TYPECHECKS: tuple[tuple[str, tuple[str, ...], tuple[str, ...]], ...] = ( ( @@ -73,10 +84,22 @@ "tools/agent/scripts/agent_gate.py", ) RESET_BANNED_TERMS: tuple[tuple[str, str], ...] = ( - (" ".join(("voice", "mode")), "speech-mode contract is removed from reset surfaces"), - (" ".join(("voice", "prompt")), "speech prompt contract is removed from reset surfaces"), - (" ".join(("goal", "graph")), "current-work graph wording is removed from reset surfaces"), - (" ".join(("activity", "graph")), "activity-tree wording is removed from reset surfaces"), + ( + " ".join(("voice", "mode")), + "speech-mode contract is removed from reset surfaces", + ), + ( + " ".join(("voice", "prompt")), + "speech prompt contract is removed from reset surfaces", + ), + ( + " ".join(("goal", "graph")), + "current-work graph wording is removed from reset surfaces", + ), + ( + " ".join(("activity", "graph")), + "activity-tree wording is removed from reset surfaces", + ), ("packages.goals", "goal package is removed from reset surfaces"), ("GoalNode", "goal-node contract is removed from reset surfaces"), ("WorklineSnapshot", "workline snapshot contract is removed from reset surfaces"), @@ -90,21 +113,48 @@ ("goal_snapshot", "legacy goal snapshot event type is removed from reset surfaces"), ("goal_refs", "work_item_refs replaces goal_refs in reset surfaces"), ("goal_ids", "work_item_ids replaces goal_ids in reset surfaces"), - ("focus_activity_ids", "focus_work_item_ids replaces focus_activity_ids in reset surfaces"), - ("activity_candidates", "work_item_candidates replaces activity_candidates in reset surfaces"), - ("build_activity_routing_section", "work routing replaces activity routing in reset surfaces"), + ( + "focus_activity_ids", + "focus_work_item_ids replaces focus_activity_ids in reset surfaces", + ), + ( + "activity_candidates", + "work_item_candidates replaces activity_candidates in reset surfaces", + ), + ( + "build_activity_routing_section", + "work routing replaces activity routing in reset surfaces", + ), ("tool.profile.manage", "memory.curate owns model-visible durable memory writes"), ("tool.memory.upload", "upload cannot represent capture semantics"), ("tool.procedure.inspect", "procedure inspection is not model-visible"), ("tool.procedure.manage", "direct procedure management is not model-visible"), - ("DeterministicEpisodeObserver", "Personal Model learning must not use keyword observer fallback"), - ("PatternClusterer", "skill crystallization must not use ExperienceRecord-first clustering"), - ("DerivedProcedureCandidateStore", "skill crystallization candidates come from trajectory metrics"), - ("list_pattern_clusters", "ExperienceRecord-first learning cluster APIs are removed"), - ("list_procedure_candidates", "procedure candidates are no longer ExperienceRecord-derived"), + ( + "DeterministicEpisodeObserver", + "Personal Model learning must not use keyword observer fallback", + ), + ( + "PatternClusterer", + "skill crystallization must not use ExperienceRecord-first clustering", + ), + ( + "DerivedProcedureCandidateStore", + "skill crystallization candidates come from trajectory metrics", + ), + ( + "list_pattern_clusters", + "ExperienceRecord-first learning cluster APIs are removed", + ), + ( + "list_procedure_candidates", + "procedure candidates are no longer ExperienceRecord-derived", + ), ("/goals", "session-era goal routes are removed from reset surfaces"), ("/procedure", "session-era procedure routes are removed from reset surfaces"), - (" ".join(("intent", "layer")), "intent routing wording is removed from reset surfaces"), + ( + " ".join(("intent", "layer")), + "intent routing wording is removed from reset surfaces", + ), ( "/".join(("strong", "weak")) + " " + "model selection", "strong-or-weak routing wording is removed from reset surfaces", @@ -317,8 +367,7 @@ def build_context_pack(changed_files: list[str], matches: list[RuleMatch]) -> Co defaults = context_map.get("defaults", {}) start_here = [ - SurfaceRef(path=entry["path"], reason=entry.get("reason", "")) - for entry in defaults.get("start_here", []) + SurfaceRef(path=entry["path"], reason=entry.get("reason", "")) for entry in defaults.get("start_here", []) ] primary = matches[0] if matches else None @@ -363,10 +412,7 @@ def build_context_pack(changed_files: list[str], matches: list[RuleMatch]) -> Co resolved_surface_paths: dict[str, tuple[str, ...]] = {} for surface_name in sorted(active_surfaces): refs = surfaces_section.get(surface_name, []) - resolved_surfaces[surface_name] = [ - SurfaceRef(path=ref["path"], reason=ref.get("reason", "")) - for ref in refs - ] + resolved_surfaces[surface_name] = [SurfaceRef(path=ref["path"], reason=ref.get("reason", "")) for ref in refs] resolved_surface_paths[surface_name] = surface_path_map.get(surface_name, ()) # Also add rule-specific context from context-map rules section @@ -438,7 +484,7 @@ def context_repair_prompt() -> str: "tools/agent/context-map.yaml for path/surface gaps; update " "tools/agent/task-matrix.yaml or tools/agent/skill-registry.yaml if " "the primary skill, validation ladder, or required context is wrong; " - "then rerun make agent-context-audit CHANGED_FILES=\"...\"." + 'then rerun make agent-context-audit CHANGED_FILES="...".' ) @@ -569,9 +615,7 @@ def scan_reset_banned_terms( line_lower = line.lower() for term, rationale in banned_terms: if term.lower() in line_lower: - errors.append( - f"reset banned term in {relative_path}:{line_number}: {term} ({rationale})" - ) + errors.append(f"reset banned term in {relative_path}:{line_number}: {term} ({rationale})") return errors @@ -705,7 +749,9 @@ def print_report( print(" - If this diff is one controlled atomic unit and validation is green, ship it with:") print(" - make agent-ship AGENT_COMMIT_MESSAGE='(): '") print(" - agent-ship reruns the PR gate, creates a signed commit, and pushes the current branch to origin.") - print(" - Leave changes unshipped only when publish was explicitly deferred, the diff still needs splitting, or a validation failure remains.") + print( + " - Leave changes unshipped only when publish was explicitly deferred, the diff still needs splitting, or a validation failure remains." + ) if audit and pack.audit_warnings: print() @@ -794,13 +840,13 @@ def lint_python_file_lengths(changed_files: list[str], *, root: Path = ROOT) -> continue line_count = _line_count(path) if line_count > MAX_PYTHON_FILE_LINES: - errors.append( - f"python file exceeds {MAX_PYTHON_FILE_LINES} lines: {relative_path} ({line_count} lines)" - ) + errors.append(f"python file exceeds {MAX_PYTHON_FILE_LINES} lines: {relative_path} ({line_count} lines)") return errors -def frontend_typecheck_commands(changed_files: list[str]) -> tuple[tuple[str, tuple[str, ...]], ...]: +def frontend_typecheck_commands( + changed_files: list[str], +) -> tuple[tuple[str, tuple[str, ...]], ...]: selected: list[tuple[str, tuple[str, ...]]] = [] for name, patterns, command in FRONTEND_TYPECHECKS: if any(match_any(path, patterns) for path in changed_files): @@ -873,7 +919,12 @@ def main() -> int: sub.add_argument("--changed-files-path", default="") if name == "report": sub.add_argument("--context-detail", choices=["compact", "full"], default="compact") - sub.add_argument("--format", choices=["text", "json"], default="text", dest="output_format") + sub.add_argument( + "--format", + choices=["text", "json"], + default="text", + dest="output_format", + ) sub.add_argument("--audit", action="store_true", default=False) args = parser.parse_args() diff --git a/tools/agent/scripts/commit_msg_lint.py b/tools/agent/scripts/commit_msg_lint.py index f785eca..71abc98 100755 --- a/tools/agent/scripts/commit_msg_lint.py +++ b/tools/agent/scripts/commit_msg_lint.py @@ -10,7 +10,18 @@ ROOT = Path(__file__).resolve().parents[3] -ALLOWED_TYPES = ("build", "chore", "ci", "docs", "feat", "fix", "perf", "refactor", "revert", "test") +ALLOWED_TYPES = ( + "build", + "chore", + "ci", + "docs", + "feat", + "fix", + "perf", + "refactor", + "revert", + "test", +) COMMIT_RE = re.compile( r"^(?Pbuild|chore|ci|docs|feat|fix|perf|refactor|revert|test)" r"\((?P[a-z0-9][a-z0-9/-]*)\)" diff --git a/tools/agent/scripts/ship.py b/tools/agent/scripts/ship.py index 9d62e57..1ddf243 100755 --- a/tools/agent/scripts/ship.py +++ b/tools/agent/scripts/ship.py @@ -75,7 +75,13 @@ def resolve_base_ref(explicit: str) -> str: def lint_commit_message(message: str) -> None: result = run( - [sys.executable, "tools/agent/scripts/commit_msg_lint.py", "message", "--subject", message], + [ + sys.executable, + "tools/agent/scripts/commit_msg_lint.py", + "message", + "--subject", + message, + ], check=False, ) if result.returncode != 0: @@ -93,7 +99,16 @@ def run_pr_gate(paths: list[str], base_ref: str) -> None: def run_soft_audit(paths: list[str]) -> None: csv_files = ",".join(paths) result = subprocess.run( - [sys.executable, "tools/agent/scripts/agent_gate.py", "report", "--changed-files", csv_files, "--audit", "--format", "json"], + [ + sys.executable, + "tools/agent/scripts/agent_gate.py", + "report", + "--changed-files", + csv_files, + "--audit", + "--format", + "json", + ], cwd=ROOT, text=True, capture_output=True, @@ -103,6 +118,7 @@ def run_soft_audit(paths: list[str]) -> None: return try: import json + data = json.loads(result.stdout) except (ValueError, ImportError): return diff --git a/tools/agent/scripts/wave_manager.py b/tools/agent/scripts/wave_manager.py index 8aae707..aab5e2f 100644 --- a/tools/agent/scripts/wave_manager.py +++ b/tools/agent/scripts/wave_manager.py @@ -39,23 +39,29 @@ def require_head() -> None: def branch_exists(branch: str) -> bool: - return subprocess.run( - ["git", "show-ref", "--verify", f"refs/heads/{branch}"], - cwd=ROOT, - text=True, - capture_output=True, - check=False, - ).returncode == 0 + return ( + subprocess.run( + ["git", "show-ref", "--verify", f"refs/heads/{branch}"], + cwd=ROOT, + text=True, + capture_output=True, + check=False, + ).returncode + == 0 + ) def remote_branch_exists(branch: str, remote: str) -> bool: - return subprocess.run( - ["git", "ls-remote", "--exit-code", "--heads", remote, branch], - cwd=ROOT, - text=True, - capture_output=True, - check=False, - ).returncode == 0 + return ( + subprocess.run( + ["git", "ls-remote", "--exit-code", "--heads", remote, branch], + cwd=ROOT, + text=True, + capture_output=True, + check=False, + ).returncode + == 0 + ) def parse_worktree_records(output: str) -> list[dict[str, str]]: @@ -133,7 +139,15 @@ def start_wave(wave_id: str, root: Path, base: str) -> int: if branch_exists(track["branch"]): command = ["git", "worktree", "add", str(target), track["branch"]] else: - command = ["git", "worktree", "add", "-b", track["branch"], str(target), base] + command = [ + "git", + "worktree", + "add", + "-b", + track["branch"], + str(target), + base, + ] result = subprocess.run(command, cwd=ROOT, check=False) if result.returncode != 0: return result.returncode diff --git a/tools/agent/scripts/worktree_manager.py b/tools/agent/scripts/worktree_manager.py index 788a4a3..49a2600 100755 --- a/tools/agent/scripts/worktree_manager.py +++ b/tools/agent/scripts/worktree_manager.py @@ -29,13 +29,16 @@ def add_worktree(name: str, branch: str, base: str, root: Path) -> int: if target.exists(): raise SystemExit(f"worktree path already exists: {target}") - branch_exists = subprocess.run( - ["git", "show-ref", "--verify", f"refs/heads/{branch}"], - cwd=ROOT, - text=True, - capture_output=True, - check=False, - ).returncode == 0 + branch_exists = ( + subprocess.run( + ["git", "show-ref", "--verify", f"refs/heads/{branch}"], + cwd=ROOT, + text=True, + capture_output=True, + check=False, + ).returncode + == 0 + ) if branch_exists: command = ["git", "worktree", "add", str(target), branch] diff --git a/tools/agent/task-matrix.yaml b/tools/agent/task-matrix.yaml index 3a2a0f1..f2a0863 100644 --- a/tools/agent/task-matrix.yaml +++ b/tools/agent/task-matrix.yaml @@ -81,6 +81,7 @@ "execution_plan_policy": "long_horizon", "paths": [ ".github/workflows/**", + ".pre-commit-config.yaml", "install.sh", ".github/*.md", ".python-version",