|
| 1 | +"""Remote Device tool. |
| 2 | +
|
| 3 | +Run shell commands on a paired remote device via the DeviceBroker. |
| 4 | +""" |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import logging |
| 9 | +import time |
| 10 | +import uuid |
| 11 | +from datetime import datetime, timezone |
| 12 | +from typing import Any, Dict, Optional |
| 13 | + |
| 14 | +from application.agents.tools.base import Tool |
| 15 | +from application.devices.broker import get_broker |
| 16 | +from application.devices.denylist import check_denylist |
| 17 | +from application.devices.normalizer import normalize_command |
| 18 | +from application.storage.db.repositories.device_audit_log import ( |
| 19 | + DeviceAuditLogRepository, |
| 20 | +) |
| 21 | +from application.storage.db.repositories.device_auto_approve_patterns import ( |
| 22 | + DeviceAutoApprovePatternsRepository, |
| 23 | +) |
| 24 | +from application.storage.db.repositories.devices import DevicesRepository |
| 25 | +from application.storage.db.session import db_readonly, db_session |
| 26 | + |
| 27 | + |
| 28 | +logger = logging.getLogger(__name__) |
| 29 | + |
| 30 | + |
| 31 | +_DEFAULT_TIMEOUT_MS = 30_000 |
| 32 | +_MAX_TIMEOUT_MS = 600_000 |
| 33 | + |
| 34 | + |
| 35 | +class RemoteDeviceTool(Tool): |
| 36 | + """Remote Device |
| 37 | + Run shell commands on a paired remote machine via docsgpt-cli host. |
| 38 | + """ |
| 39 | + |
| 40 | + def __init__(self, config: Optional[dict] = None, user_id: Optional[str] = None): |
| 41 | + self.config = config or {} |
| 42 | + self.user_id = user_id |
| 43 | + self.device_id = self.config.get("device_id") or "" |
| 44 | + self._device: Optional[dict] = None |
| 45 | + if self.device_id and self.user_id: |
| 46 | + self._device = self._load_device() |
| 47 | + |
| 48 | + def _load_device(self) -> Optional[dict]: |
| 49 | + try: |
| 50 | + with db_readonly() as conn: |
| 51 | + return DevicesRepository(conn).get(self.device_id, user_id=self.user_id) |
| 52 | + except Exception: |
| 53 | + logger.exception("failed to load device %s", self.device_id) |
| 54 | + return None |
| 55 | + |
| 56 | + # ------------------------------------------------------------------ |
| 57 | + # Tool ABC |
| 58 | + # ------------------------------------------------------------------ |
| 59 | + def get_actions_metadata(self): |
| 60 | + device = self._device or {} |
| 61 | + device_name = device.get("name") or "remote device" |
| 62 | + description = device.get("description") or "" |
| 63 | + approval_mode = device.get("approval_mode") or "ask" |
| 64 | + return [ |
| 65 | + { |
| 66 | + "name": "run_command", |
| 67 | + "description": ( |
| 68 | + f"Execute a shell command on the remote device " |
| 69 | + f"'{device_name}'. {description}".strip() |
| 70 | + ), |
| 71 | + "active": True, |
| 72 | + "require_approval": approval_mode != "full", |
| 73 | + "parameters": { |
| 74 | + "type": "object", |
| 75 | + "properties": { |
| 76 | + "command": { |
| 77 | + "type": "string", |
| 78 | + "description": "Shell command to run.", |
| 79 | + "filled_by_llm": True, |
| 80 | + "value": "", |
| 81 | + }, |
| 82 | + "working_directory": { |
| 83 | + "type": "string", |
| 84 | + "description": "Working directory on the remote.", |
| 85 | + "filled_by_llm": True, |
| 86 | + "value": "", |
| 87 | + }, |
| 88 | + "timeout_ms": { |
| 89 | + "type": "integer", |
| 90 | + "description": "Timeout in milliseconds (max 600000).", |
| 91 | + "filled_by_llm": True, |
| 92 | + "value": "", |
| 93 | + }, |
| 94 | + }, |
| 95 | + "required": ["command"], |
| 96 | + }, |
| 97 | + } |
| 98 | + ] |
| 99 | + |
| 100 | + def get_config_requirements(self): |
| 101 | + return { |
| 102 | + "device_id": { |
| 103 | + "type": "string", |
| 104 | + "label": "Device", |
| 105 | + "description": "Paired remote device id.", |
| 106 | + "required": True, |
| 107 | + "source": "devices", |
| 108 | + } |
| 109 | + } |
| 110 | + |
| 111 | + def preview_requires_approval(self, action_name: str, params: dict) -> bool: |
| 112 | + """Live approval decision for a specific invocation. |
| 113 | +
|
| 114 | + The tool_executor gate calls this for ``remote_device`` so the |
| 115 | + decision considers the device's current ``approval_mode``, sticky |
| 116 | + patterns, and the denylist — rather than trusting the static |
| 117 | + ``user_tools.actions[].require_approval`` snapshot stored at pair |
| 118 | + time. Returns ``True`` when a prompt is required. |
| 119 | + """ |
| 120 | + requires_approval, _denylist_forced = self.preview_decision( |
| 121 | + action_name, params, |
| 122 | + ) |
| 123 | + return requires_approval |
| 124 | + |
| 125 | + def preview_decision( |
| 126 | + self, action_name: str, params: dict, |
| 127 | + ) -> tuple[bool, bool]: |
| 128 | + """Live approval decision plus whether it's a denylist-forced prompt. |
| 129 | +
|
| 130 | + Returns ``(requires_approval, denylist_forced)``. ``denylist_forced`` |
| 131 | + is True only when the prompt is mandated by the hard denylist, which |
| 132 | + a headless allowlist must never bypass. Unknown / inactive devices |
| 133 | + and missing commands require approval but are NOT denylist-forced. |
| 134 | + """ |
| 135 | + if action_name != "run_command": |
| 136 | + return True, False |
| 137 | + if not self.device_id or not self.user_id: |
| 138 | + return True, False |
| 139 | + if self._device is None: |
| 140 | + self._device = self._load_device() |
| 141 | + device = self._device |
| 142 | + if device is None or device.get("status") != "active": |
| 143 | + # Don't bypass the prompt for an unknown / inactive device; |
| 144 | + # execute_action will surface the error. |
| 145 | + return True, False |
| 146 | + command = ((params or {}).get("command") or "").strip() |
| 147 | + if not command: |
| 148 | + return True, False |
| 149 | + reason, effective_mode = self._decide_approval(device, command) |
| 150 | + denylist_forced = reason == "denylist_forced_prompt" |
| 151 | + return effective_mode != "full", denylist_forced |
| 152 | + |
| 153 | + def execute_action(self, action_name: str, **kwargs): |
| 154 | + if action_name != "run_command": |
| 155 | + return {"error": f"unknown action: {action_name}"} |
| 156 | + if not self.device_id or not self.user_id: |
| 157 | + return {"error": "device_id and user_id required"} |
| 158 | + if self._device is None: |
| 159 | + self._device = self._load_device() |
| 160 | + device = self._device |
| 161 | + if device is None: |
| 162 | + return {"error": "device not found"} |
| 163 | + if device.get("status") != "active": |
| 164 | + return {"error": f"device status: {device.get('status')}"} |
| 165 | + |
| 166 | + command = (kwargs.get("command") or "").strip() |
| 167 | + if not command: |
| 168 | + return {"error": "command is required"} |
| 169 | + working_directory = kwargs.get("working_directory") or "" |
| 170 | + timeout_ms = kwargs.get("timeout_ms") |
| 171 | + try: |
| 172 | + timeout_ms = int(timeout_ms) if timeout_ms else _DEFAULT_TIMEOUT_MS |
| 173 | + except (TypeError, ValueError): |
| 174 | + timeout_ms = _DEFAULT_TIMEOUT_MS |
| 175 | + timeout_ms = min(max(timeout_ms, 1), _MAX_TIMEOUT_MS) |
| 176 | + |
| 177 | + decision_reason, effective_mode = self._decide_approval(device, command) |
| 178 | + denied = self._denylist_label(command) |
| 179 | + |
| 180 | + envelope = { |
| 181 | + "invocation_id": "inv_" + uuid.uuid4().hex, |
| 182 | + "action": "run_command", |
| 183 | + "params": { |
| 184 | + "command": command, |
| 185 | + "working_directory": working_directory, |
| 186 | + "timeout_ms": timeout_ms, |
| 187 | + }, |
| 188 | + "approval_mode": effective_mode, |
| 189 | + "issued_at": datetime.now(timezone.utc).isoformat(), |
| 190 | + } |
| 191 | + broker = get_broker() |
| 192 | + inv = broker.dispatch_invocation(self.device_id, self.user_id, envelope) |
| 193 | + |
| 194 | + try: |
| 195 | + with db_session() as conn: |
| 196 | + DeviceAuditLogRepository(conn).record_dispatch( |
| 197 | + device_id=self.device_id, |
| 198 | + user_id=self.user_id, |
| 199 | + invocation_id=inv.invocation_id, |
| 200 | + command=command, |
| 201 | + working_dir=working_directory, |
| 202 | + approval_mode=effective_mode, |
| 203 | + decision="dispatched", |
| 204 | + decision_reason=decision_reason or ("denylist:" + denied if denied else None), |
| 205 | + issued_at=datetime.now(timezone.utc), |
| 206 | + ) |
| 207 | + except Exception: |
| 208 | + logger.exception("audit record_dispatch failed for %s", inv.invocation_id) |
| 209 | + |
| 210 | + return self._collect_result(broker, inv, device, timeout_ms) |
| 211 | + |
| 212 | + # ------------------------------------------------------------------ |
| 213 | + # Internals |
| 214 | + # ------------------------------------------------------------------ |
| 215 | + def _decide_approval(self, device: dict, command: str) -> tuple[Optional[str], str]: |
| 216 | + """Resolve the effective approval mode + a short audit reason. |
| 217 | +
|
| 218 | + Effective mode is ``full`` (auto-run, no prompt) or ``ask`` (prompt). |
| 219 | + """ |
| 220 | + mode = device.get("approval_mode") or "ask" |
| 221 | + # Denylist forces a prompt on every path — full access and the |
| 222 | + # ask-mode sticky auto-approve alike. |
| 223 | + if check_denylist(command): |
| 224 | + return ("denylist_forced_prompt", "ask") |
| 225 | + if mode == "full": |
| 226 | + return ("full_access_passthrough", "full") |
| 227 | + # mode == "ask" |
| 228 | + if self._matches_sticky(command): |
| 229 | + return ("sticky_auto_approve", "full") |
| 230 | + return ("user_approval_required", "ask") |
| 231 | + |
| 232 | + def _denylist_label(self, command: str) -> Optional[str]: |
| 233 | + return check_denylist(command) |
| 234 | + |
| 235 | + def _matches_sticky(self, command: str) -> bool: |
| 236 | + pattern = normalize_command(command) |
| 237 | + if not pattern: |
| 238 | + return False |
| 239 | + try: |
| 240 | + with db_readonly() as conn: |
| 241 | + return DeviceAutoApprovePatternsRepository(conn).has_pattern( |
| 242 | + self.device_id, self.user_id, pattern, |
| 243 | + ) |
| 244 | + except Exception: |
| 245 | + logger.exception("sticky lookup failed") |
| 246 | + return False |
| 247 | + |
| 248 | + def _collect_result(self, broker, inv, device: dict, timeout_ms: int) -> Dict[str, Any]: |
| 249 | + """Drain output from the broker until the control chunk arrives.""" |
| 250 | + deadline = time.time() + (timeout_ms / 1000.0) + 5.0 |
| 251 | + stdout = [] |
| 252 | + stderr = [] |
| 253 | + try: |
| 254 | + for chunk in broker.drain_output( |
| 255 | + inv.invocation_id, timeout=1.0, deadline=deadline |
| 256 | + ): |
| 257 | + if time.time() > deadline: |
| 258 | + break |
| 259 | + stream = chunk.get("stream") |
| 260 | + if stream == "stdout": |
| 261 | + stdout.append(chunk.get("chunk", "")) |
| 262 | + elif stream == "stderr": |
| 263 | + stderr.append(chunk.get("chunk", "")) |
| 264 | + elif stream == "control": |
| 265 | + # control chunks include exit_code; drain loop will stop next iter |
| 266 | + pass |
| 267 | + finally: |
| 268 | + broker.cleanup_invocation(inv.invocation_id) |
| 269 | + |
| 270 | + # Deadline hit with no control chunk: the device never connected or |
| 271 | + # never finished. Surface a clear timeout instead of empty success. |
| 272 | + if not inv.completed.is_set() and inv.exit_code is None and not inv.error: |
| 273 | + inv.error = "device did not respond (timed out)" |
| 274 | + |
| 275 | + return { |
| 276 | + "exit_code": inv.exit_code, |
| 277 | + "stdout": "".join(stdout) if stdout else "".join(inv.stdout_parts), |
| 278 | + "stderr": "".join(stderr) if stderr else "".join(inv.stderr_parts), |
| 279 | + "duration_ms": inv.duration_ms, |
| 280 | + "device_name": device.get("name"), |
| 281 | + "error": inv.error, |
| 282 | + } |
0 commit comments