Skip to content

Commit 6918ebe

Browse files
committed
fix: revert shell execution logic back to master implementation
1 parent 252c623 commit 6918ebe

1 file changed

Lines changed: 19 additions & 247 deletions

File tree

Lines changed: 19 additions & 247 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,17 @@
1-
"""
2-
ExecuteShellTool - subprocess-based shell execution with per-session state.
3-
4-
Replaces previous plumbum-based implementation with a subprocess-based,
5-
per-session state manager that tracks current working directory and
6-
per-session environment variables.
7-
8-
Behavior:
9-
- Each session has its own `cwd` and `env` stored in-memory.
10-
- `cd` commands are interpreted and update the session `cwd`.
11-
Supports constructs like `cd /path && ls` or `cd rel/path; echo hi`.
12-
- Foreground commands run to completion with a configurable timeout.
13-
- Background commands spawn a subprocess and return immediately with the pid.
14-
- Environment variables passed in `env` are merged with the session env.
15-
- Returns JSON string describing result to match existing tool contract.
16-
"""
17-
18-
from __future__ import annotations
19-
20-
import asyncio
211
import json
22-
import os
23-
import shlex
24-
import subprocess
252
from dataclasses import dataclass, field
26-
from typing import Any
273

4+
from astrbot.api import FunctionTool
285
from astrbot.core.agent.run_context import ContextWrapper
29-
from astrbot.core.agent.tool import FunctionTool, ToolExecResult
6+
from astrbot.core.agent.tool import ToolExecResult
307
from astrbot.core.astr_agent_context import AstrAgentContext
318

9+
from ..computer_client import get_booter, get_local_booter
3210
from .permissions import check_admin_permission
3311

3412

3513
@dataclass
3614
class ExecuteShellTool(FunctionTool):
37-
"""
38-
Stateful shell execution tool using subprocess.
39-
40-
Each agent session keeps its own working directory and environment mapping.
41-
"""
42-
4315
name: str = "astrbot_execute_shell"
4416
description: str = "Execute a command in the shell."
4517
parameters: dict = field(
@@ -48,7 +20,7 @@ class ExecuteShellTool(FunctionTool):
4820
"properties": {
4921
"command": {
5022
"type": "string",
51-
"description": "The shell command to execute in the current runtime shell (for example, cmd.exe on Windows). Equivalent to running 'cd {working_dir} && {your_command}'.",
23+
"description": "The shell command to execute in the current runtime shell (for example, cmd.exe on Windows). Equal to 'cd {working_dir} && {your_command}'.",
5224
},
5325
"background": {
5426
"type": "boolean",
@@ -57,7 +29,7 @@ class ExecuteShellTool(FunctionTool):
5729
},
5830
"env": {
5931
"type": "object",
60-
"description": "Optional environment variables to set for the command (merged with session env).",
32+
"description": "Optional environment variables to set for the file creation process.",
6133
"additionalProperties": {"type": "string"},
6234
"default": {},
6335
},
@@ -67,226 +39,26 @@ class ExecuteShellTool(FunctionTool):
6739
)
6840

6941
is_local: bool = False
70-
is_stateful: bool = True
71-
# session_id -> {"cwd": str, "env": dict}
72-
_sessions: dict[str, dict[str, Any]] = field(
73-
default_factory=dict, init=False, repr=False
74-
)
75-
76-
def _get_session_state(self, session_id: str) -> dict[str, Any]:
77-
"""
78-
Initialize or return the per-session state.
79-
State contains:
80-
- cwd: current working directory for session
81-
- env: environment variables dict for session
82-
"""
83-
if session_id not in self._sessions:
84-
# start from current process cwd and a copy of os.environ
85-
self._sessions[session_id] = {
86-
"cwd": os.getcwd(),
87-
"env": dict(os.environ),
88-
}
89-
return self._sessions[session_id]
90-
91-
def _get_framework_session_state(
92-
self, context: ContextWrapper[AstrAgentContext], session_id: str
93-
) -> dict[str, Any]:
94-
"""
95-
Get session state via the framework's ToolSessionManager if available.
96-
Falls back to the tool's own _sessions if session_manager is not set.
97-
"""
98-
session_mgr = getattr(context, "session_manager", None)
99-
if session_mgr is None:
100-
return self._get_session_state(session_id)
101-
return session_mgr.get_state(session_id, self.name)
10242

10343
async def call(
104-
self, context: ContextWrapper[AstrAgentContext], **kwargs: Any
44+
self,
45+
context: ContextWrapper[AstrAgentContext],
46+
command: str,
47+
background: bool = False,
48+
env: dict = {},
10549
) -> ToolExecResult:
106-
"""
107-
Execute a shell command for the session.
108-
109-
Parameters are accepted via kwargs for compatibility with FunctionTool.call:
110-
- command (str): the shell command to execute
111-
- background (bool): whether to run in background
112-
- env (dict): environment variables to merge for this execution
113-
"""
114-
# Use the context directly - already typed as ContextWrapper[AstrAgentContext]
115-
astr_ctx = context
116-
117-
# Permission check (use the cast wrapper)
118-
if permission_error := check_admin_permission(astr_ctx, "Shell execution"):
50+
if permission_error := check_admin_permission(context, "Shell execution"):
11951
return permission_error
12052

121-
# Extract parameters with defaults for backward compatibility
122-
command: str = kwargs.get("command", "")
123-
background: bool = bool(kwargs.get("background", False))
124-
env: dict | None = kwargs.get("env")
125-
126-
# Resolve session id and session state (use the cast wrapper)
127-
session_id = astr_ctx.context.event.unified_msg_origin
128-
# Use framework ToolSessionManager if available, otherwise fall back
129-
if astr_ctx.session_manager is not None:
130-
state = self._get_framework_session_state(astr_ctx, session_id)
53+
if self.is_local:
54+
sb = get_local_booter()
13155
else:
132-
state = self._get_session_state(session_id)
133-
session_cwd = state.get("cwd", os.getcwd())
134-
session_env = state.get("env", dict(os.environ)).copy()
135-
136-
# Merge provided env into execution env (do not mutate saved session env)
137-
if env:
138-
exec_env = session_env.copy()
139-
exec_env.update({k: str(v) for k, v in env.items()})
140-
else:
141-
exec_env = session_env
142-
143-
# Determine timeout from config (fall back to 30)
144-
config = astr_ctx.context.context.get_config(umo=session_id)
145-
provider_settings: dict = {}
146-
if isinstance(config, dict):
147-
provider_settings = config.get("provider_settings") or {}
148-
try:
149-
timeout = int(provider_settings.get("tool_call_timeout", 30))
150-
except (ValueError, TypeError):
151-
timeout = 30
152-
153-
# Single atomic try block for overall execution to satisfy anti-nested-try rule.
154-
try:
155-
# Quick handling for explicit `cd` constructs that should change session cwd.
156-
# We support leading cd followed by && or ;: e.g. "cd dir && ls", "cd dir; ls"
157-
cmd_str = command.strip()
158-
159-
# Helper to split by shell '&&' or ';' while preserving remainder.
160-
remainder_cmd = ""
161-
cd_handled = False
162-
# Handle forms like: cd <path> && rest OR cd <path>; rest
163-
for sep in ("&&", ";"):
164-
if sep in cmd_str:
165-
left, right = cmd_str.split(sep, 1)
166-
left_strip = left.strip()
167-
if left_strip.startswith("cd"):
168-
remainder_cmd = right.strip()
169-
cd_part = left_strip
170-
cd_handled = True
171-
break
172-
else:
173-
# No separator case, but single 'cd' command or just 'cd /path'
174-
if cmd_str.startswith("cd"):
175-
cd_part = cmd_str
176-
remainder_cmd = ""
177-
cd_handled = True
178-
179-
if cd_handled:
180-
# parse cd argument
181-
parts = shlex.split(cd_part)
182-
# cd with no args -> home
183-
if len(parts) == 1:
184-
target = await asyncio.to_thread(os.path.expanduser, "~")
185-
else:
186-
target_raw = parts[1]
187-
# expand ~ and variables
188-
target_raw = await asyncio.to_thread(os.path.expanduser, target_raw)
189-
if await asyncio.to_thread(os.path.isabs, target_raw):
190-
target = target_raw
191-
else:
192-
target = await asyncio.to_thread(
193-
os.path.normpath, os.path.join(session_cwd, target_raw)
194-
)
195-
196-
target_exists = await asyncio.to_thread(os.path.exists, target)
197-
target_isdir = await asyncio.to_thread(os.path.isdir, target)
198-
if not target_exists or not target_isdir:
199-
result = {
200-
"success": False,
201-
"exit_code": -1,
202-
"stdout": "",
203-
"stderr": f"cd: no such directory: {target}",
204-
"cwd": session_cwd,
205-
}
206-
return json.dumps(result)
207-
208-
# Update session cwd permanently
209-
state["cwd"] = target
210-
session_cwd = target
211-
212-
# If there is no remaining command, just return success and new cwd
213-
if not remainder_cmd:
214-
result = {
215-
"success": True,
216-
"exit_code": 0,
217-
"stdout": "",
218-
"stderr": "",
219-
"cwd": session_cwd,
220-
}
221-
return json.dumps(result)
222-
223-
# Otherwise we'll execute the remainder using the updated cwd
224-
# Use the remainder command as the command to run below
225-
command_to_run = remainder_cmd
226-
else:
227-
command_to_run = cmd_str
228-
229-
# Background execution: spawn process and return pid immediately.
230-
if background:
231-
# Start background process; do not wait. Use shell to support pipes/redirects.
232-
popen = await asyncio.to_thread(
233-
subprocess.Popen,
234-
["/bin/sh", "-c", command_to_run],
235-
cwd=session_cwd,
236-
env=exec_env,
237-
stdout=subprocess.PIPE,
238-
stderr=subprocess.PIPE,
239-
)
240-
result = {
241-
"success": True,
242-
"background": True,
243-
"pid": popen.pid,
244-
"cwd": session_cwd,
245-
}
246-
return json.dumps(result)
247-
248-
# Foreground execution: run to completion, capture output.
249-
completed = await asyncio.to_thread(
250-
subprocess.run,
251-
["/bin/sh", "-c", command_to_run],
252-
cwd=session_cwd,
253-
env=exec_env,
254-
timeout=timeout,
255-
capture_output=True,
256-
text=True,
56+
sb = await get_booter(
57+
context.context.context,
58+
context.context.event.unified_msg_origin,
25759
)
258-
259-
exit_code = completed.returncode
260-
stdout = completed.stdout if completed.stdout is not None else ""
261-
stderr = completed.stderr if completed.stderr is not None else ""
262-
263-
result = {
264-
"success": exit_code == 0,
265-
"exit_code": exit_code,
266-
"stdout": stdout,
267-
"stderr": stderr,
268-
"cwd": session_cwd,
269-
}
60+
try:
61+
result = await sb.shell.exec(command, background=background, env=env)
27062
return json.dumps(result)
271-
272-
except subprocess.TimeoutExpired as e:
273-
return json.dumps(
274-
{
275-
"success": False,
276-
"exit_code": -1,
277-
"stdout": e.stdout or "",
278-
"stderr": f"Command timed out after {timeout} seconds",
279-
"cwd": session_cwd,
280-
}
281-
)
28263
except Exception as e:
283-
# Do not silently swallow errors; return an explicit failure payload.
284-
return json.dumps(
285-
{
286-
"success": False,
287-
"exit_code": -1,
288-
"stdout": "",
289-
"stderr": f"Error executing command: {e!s}",
290-
"cwd": session_cwd,
291-
}
292-
)
64+
return f"Error executing command: {str(e)}"

0 commit comments

Comments
 (0)