Skip to content

Commit f9ae8bc

Browse files
committed
feat: Enhance tool name normalization and file handling in execution
- Added normalization functions for Python and Bash tool names to ensure compatibility with SDK-generated code. - Updated file handling in execution services to support new metadata fields, including `inherited`, `modified_from`, and `entity_id`. - Introduced read-only file handling during uploads, allowing for better management of file permissions in sandbox environments. - Enhanced unit tests to cover new features and ensure robust validation of file and tool name handling.
1 parent 3b5794b commit f9ae8bc

16 files changed

Lines changed: 985 additions & 66 deletions

docker-compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ services:
1313
# nsjail requires these capabilities to create namespaces and cgroups
1414
cap_add:
1515
- SYS_ADMIN
16+
- NET_ADMIN
1617
security_opt:
1718
- apparmor:unconfined
1819
ports:

docker/ptc_bash_server.py

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,51 @@
5858
_real_stdin = sys.stdin
5959
_real_stdout = sys.stdout
6060

61-
# Bash identifier rules: [A-Za-z_][A-Za-z0-9_]*. We refuse to wrap any tool
62-
# whose name doesn't match — both for shell safety and because the user
63-
# couldn't call it from bash anyway.
61+
# Bash identifier rules: [A-Za-z_][A-Za-z0-9_]*. Names that don't match
62+
# get normalized via `_normalize_bash_name` so the user can still call the
63+
# tool from bash — the SDK applies the same normalization client-side when
64+
# generating code.
6465
_VALID_BASH_NAME = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
6566

67+
_BASH_RESERVED = frozenset(
68+
{
69+
"if",
70+
"then",
71+
"else",
72+
"elif",
73+
"fi",
74+
"case",
75+
"esac",
76+
"for",
77+
"while",
78+
"until",
79+
"do",
80+
"done",
81+
"in",
82+
"function",
83+
"select",
84+
"time",
85+
"coproc",
86+
"declare",
87+
"typeset",
88+
"local",
89+
"readonly",
90+
"export",
91+
"unset",
92+
}
93+
)
94+
95+
96+
def _normalize_bash_name(name: str) -> str:
97+
"""Match SDK's normalizeToBashIdentifier so generated code can call functions."""
98+
result = re.sub(r"[-\s.]", "_", name)
99+
result = re.sub(r"[^a-zA-Z0-9_]", "", result)
100+
if result and result[0].isdigit():
101+
result = "_" + result
102+
if result in _BASH_RESERVED:
103+
result = result + "_tool"
104+
return result or "_unnamed"
105+
66106

67107
def _write_message(msg: dict) -> None:
68108
_real_stdout.write(json.dumps(msg) + DELIMITER)
@@ -91,24 +131,25 @@ def _generate_rcfile(tools: list) -> str:
91131
]
92132
for tool in tools:
93133
name = tool.get("name", "")
94-
if not _VALID_BASH_NAME.match(name):
134+
func_name = _normalize_bash_name(name)
135+
if not func_name or func_name == "_unnamed":
95136
continue
96137
lines.append(
97-
f"{name}() {{\n"
138+
f"{func_name}() {{\n"
98139
# Use an explicit conditional rather than ${1:-{}} — the brace-default
99140
# form parses as ${1:-{} followed by a literal }, which appends a
100141
# stray brace whenever $1 is set.
101142
f' local input_json="$1"\n'
102143
f' if [ -z "$input_json" ]; then input_json="{{}}"; fi\n'
103144
f" local payload\n"
104145
f" payload=$(jq -c -n --arg name {shlex.quote(name)} "
105-
f'--argjson input "$input_json" \'{{name:$name,input:$input}}\' 2>/dev/null) || \\\n'
146+
f"--argjson input \"$input_json\" '{{name:$name,input:$input}}' 2>/dev/null) || \\\n"
106147
f" payload=$(jq -c -n --arg name {shlex.quote(name)} "
107-
f'--arg input "$input_json" \'{{name:$name,input:$input}}\')\n'
148+
f"--arg input \"$input_json\" '{{name:$name,input:$input}}')\n"
108149
f' printf \'%s\\n\' "$payload" > "$PTC_CALL_FIFO"\n'
109150
f" local result\n"
110151
f' IFS= read -r result < "$PTC_RESULT_FIFO"\n'
111-
f' printf \'%s\\n\' "$result"\n'
152+
f" printf '%s\\n' \"$result\"\n"
112153
f"}}\n"
113154
)
114155
return "\n".join(lines)
@@ -216,9 +257,7 @@ def on_call_readable() -> None:
216257
break
217258

218259
results = response.get("results", [])
219-
target = next(
220-
(r for r in results if r.get("call_id") == call_id), None
221-
)
260+
target = next((r for r in results if r.get("call_id") == call_id), None)
222261
if target is None and results:
223262
target = results[0]
224263

docker/ptc_server.py

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,66 @@
2626
import asyncio
2727
import json
2828
import os
29+
import re
2930
import sys
3031
import traceback
3132
import uuid
3233
from io import StringIO
3334

3435
DELIMITER = "\n---PTC_END---\n"
3536

37+
_PYTHON_KEYWORDS = frozenset(
38+
{
39+
"False",
40+
"None",
41+
"True",
42+
"and",
43+
"as",
44+
"assert",
45+
"async",
46+
"await",
47+
"break",
48+
"class",
49+
"continue",
50+
"def",
51+
"del",
52+
"elif",
53+
"else",
54+
"except",
55+
"finally",
56+
"for",
57+
"from",
58+
"global",
59+
"if",
60+
"import",
61+
"in",
62+
"is",
63+
"lambda",
64+
"nonlocal",
65+
"not",
66+
"or",
67+
"pass",
68+
"raise",
69+
"return",
70+
"try",
71+
"while",
72+
"with",
73+
"yield",
74+
}
75+
)
76+
77+
78+
def _normalize_python_name(name: str) -> str:
79+
"""Match SDK's normalizeToPythonIdentifier so generated code can call stubs."""
80+
result = re.sub(r"[-\s]", "_", name)
81+
result = re.sub(r"[^a-zA-Z0-9_]", "", result)
82+
if result and result[0].isdigit():
83+
result = "_" + result
84+
if result in _PYTHON_KEYWORDS:
85+
result = result + "_tool"
86+
return result or "_unnamed"
87+
88+
3689
# Keep references to the REAL stdin/stdout for protocol communication.
3790
# User code's print() will be redirected to a StringIO capture buffer.
3891
_real_stdin = sys.stdin
@@ -83,9 +136,7 @@ async def tool_stub(**kwargs):
83136

84137
result_info = _tool_results_map.pop(call_id)
85138
if result_info.get("is_error"):
86-
raise RuntimeError(
87-
result_info.get("error_message", "Tool call failed")
88-
)
139+
raise RuntimeError(result_info.get("error_message", "Tool call failed"))
89140
return result_info.get("result")
90141

91142
tool_stub.__name__ = tool_name
@@ -113,7 +164,8 @@ async def _execute_with_tools(
113164
pass
114165

115166
for tool in tools:
116-
namespace[tool["name"]] = _make_tool_stub(tool["name"])
167+
normalized = _normalize_python_name(tool["name"])
168+
namespace[normalized] = _make_tool_stub(tool["name"])
117169

118170
# Wrap user code in async function
119171
indented_code = "\n".join(" " + line for line in code.split("\n"))
@@ -137,10 +189,12 @@ async def _execute_with_tools(
137189
calls_to_send = list(_pending_calls)
138190
_pending_calls.clear()
139191

140-
_write_message({
141-
"type": "tool_calls",
142-
"calls": calls_to_send,
143-
})
192+
_write_message(
193+
{
194+
"type": "tool_calls",
195+
"calls": calls_to_send,
196+
}
197+
)
144198

145199
# Wait for results from host
146200
response = _read_message()
@@ -179,10 +233,12 @@ def main():
179233
try:
180234
request = _read_message()
181235
except Exception as e:
182-
_write_message({
183-
"type": "error",
184-
"error": f"Failed to read initial request: {e}",
185-
})
236+
_write_message(
237+
{
238+
"type": "error",
239+
"error": f"Failed to read initial request: {e}",
240+
}
241+
)
186242
return
187243

188244
code = request.get("code", "")
@@ -200,9 +256,7 @@ def main():
200256
sys.stderr = user_stderr
201257

202258
try:
203-
result = asyncio.run(
204-
_execute_with_tools(code, tools, user_stdout, user_stderr)
205-
)
259+
result = asyncio.run(_execute_with_tools(code, tools, user_stdout, user_stderr))
206260
except Exception as e:
207261
result = {
208262
"type": "error",

src/api/exec.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@
4040
_KEEPALIVE_INTERVAL = 3
4141

4242

43-
@router.post("/exec", responses={200: {"model": ExecResponse}})
43+
@router.post(
44+
"/exec",
45+
responses={200: {"model": ExecResponse}},
46+
response_model_exclude_none=True,
47+
)
4448
async def execute_code(
4549
request: ExecRequest,
4650
http_request: Request,
@@ -175,7 +179,7 @@ async def _stream_response():
175179
request_id=request_id,
176180
session_id=response.session_id,
177181
)
178-
yield response.model_dump_json().encode()
182+
yield response.model_dump_json(exclude_none=True).encode()
179183

180184
return StreamingResponse(
181185
_stream_response(),

src/api/files.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,13 @@ async def upload_files_batch(
250250
)
251251
is_agent_file = entity_id is not None
252252

253+
read_only_raw = form.get("read_only")
254+
is_read_only = isinstance(read_only_raw, str) and read_only_raw.lower() in (
255+
"1",
256+
"true",
257+
"yes",
258+
)
259+
253260
metadata = {"entity_id": entity_id} if entity_id else {}
254261
session = await session_service.create_session(SessionCreate(metadata=metadata))
255262
session_id = session.session_id
@@ -287,6 +294,7 @@ async def upload_files_batch(
287294
content=content,
288295
content_type=upload.content_type,
289296
is_agent_file=is_agent_file,
297+
is_read_only=is_read_only,
290298
)
291299

292300
results.append(

src/models/exec.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# Standard library imports
44
from datetime import datetime
5-
from typing import List, Optional, Any
5+
from typing import Dict, List, Optional, Any
66

77
# Third-party imports
88
from pydantic import BaseModel, Field
@@ -15,6 +15,9 @@ class FileRef(BaseModel):
1515
name: str
1616
path: Optional[str] = None # Make path optional
1717
session_id: Optional[str] = None # Session ID for cross-message file persistence
18+
inherited: Optional[bool] = None
19+
entity_id: Optional[str] = None
20+
modified_from: Optional[Dict[str, str]] = None
1821

1922

2023
class RequestFile(BaseModel):
@@ -23,6 +26,7 @@ class RequestFile(BaseModel):
2326
id: str
2427
session_id: str
2528
name: str
29+
entity_id: Optional[str] = None
2630

2731

2832
class ExecRequest(BaseModel):
@@ -55,6 +59,12 @@ class ExecRequest(BaseModel):
5559
default_factory=list,
5660
description="Array of file references to be used during execution",
5761
)
62+
timeout: Optional[int] = Field(
63+
default=None,
64+
ge=1000,
65+
le=300000,
66+
description="Execution timeout in milliseconds",
67+
)
5868

5969

6070
class ExecResponse(BaseModel):

src/models/execution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Standard library imports
44
from datetime import datetime
55
from enum import Enum
6-
from typing import List, Optional
6+
from typing import Any, Dict, List, Optional
77

88
# Third-party imports
99
from pydantic import BaseModel, Field
@@ -42,6 +42,7 @@ class ExecutionOutput(BaseModel):
4242
default=None, description="Size in bytes for file outputs"
4343
)
4444
timestamp: datetime = Field(default_factory=datetime.utcnow)
45+
metadata: Optional[Dict[str, Any]] = None
4546

4647

4748
class CodeExecution(BaseModel):

0 commit comments

Comments
 (0)