|
58 | 58 | _real_stdin = sys.stdin |
59 | 59 | _real_stdout = sys.stdout |
60 | 60 |
|
61 | | -# Bash identifier rules: [A-Za-z_][A-Za-z0-9_]*. We refuse to wrap any tool |
62 | | -# whose name doesn't match — both for shell safety and because the user |
63 | | -# couldn't call it from bash anyway. |
| 61 | +# Bash identifier rules: [A-Za-z_][A-Za-z0-9_]*. Names that don't match |
| 62 | +# get normalized via `_normalize_bash_name` so the user can still call the |
| 63 | +# tool from bash — the SDK applies the same normalization client-side when |
| 64 | +# generating code. |
64 | 65 | _VALID_BASH_NAME = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") |
65 | 66 |
|
| 67 | +_BASH_RESERVED = frozenset( |
| 68 | + { |
| 69 | + "if", |
| 70 | + "then", |
| 71 | + "else", |
| 72 | + "elif", |
| 73 | + "fi", |
| 74 | + "case", |
| 75 | + "esac", |
| 76 | + "for", |
| 77 | + "while", |
| 78 | + "until", |
| 79 | + "do", |
| 80 | + "done", |
| 81 | + "in", |
| 82 | + "function", |
| 83 | + "select", |
| 84 | + "time", |
| 85 | + "coproc", |
| 86 | + "declare", |
| 87 | + "typeset", |
| 88 | + "local", |
| 89 | + "readonly", |
| 90 | + "export", |
| 91 | + "unset", |
| 92 | + } |
| 93 | +) |
| 94 | + |
| 95 | + |
| 96 | +def _normalize_bash_name(name: str) -> str: |
| 97 | + """Match SDK's normalizeToBashIdentifier so generated code can call functions.""" |
| 98 | + result = re.sub(r"[-\s.]", "_", name) |
| 99 | + result = re.sub(r"[^a-zA-Z0-9_]", "", result) |
| 100 | + if result and result[0].isdigit(): |
| 101 | + result = "_" + result |
| 102 | + if result in _BASH_RESERVED: |
| 103 | + result = result + "_tool" |
| 104 | + return result or "_unnamed" |
| 105 | + |
66 | 106 |
|
67 | 107 | def _write_message(msg: dict) -> None: |
68 | 108 | _real_stdout.write(json.dumps(msg) + DELIMITER) |
@@ -91,24 +131,25 @@ def _generate_rcfile(tools: list) -> str: |
91 | 131 | ] |
92 | 132 | for tool in tools: |
93 | 133 | name = tool.get("name", "") |
94 | | - if not _VALID_BASH_NAME.match(name): |
| 134 | + func_name = _normalize_bash_name(name) |
| 135 | + if not func_name or func_name == "_unnamed": |
95 | 136 | continue |
96 | 137 | lines.append( |
97 | | - f"{name}() {{\n" |
| 138 | + f"{func_name}() {{\n" |
98 | 139 | # Use an explicit conditional rather than ${1:-{}} — the brace-default |
99 | 140 | # form parses as ${1:-{} followed by a literal }, which appends a |
100 | 141 | # stray brace whenever $1 is set. |
101 | 142 | f' local input_json="$1"\n' |
102 | 143 | f' if [ -z "$input_json" ]; then input_json="{{}}"; fi\n' |
103 | 144 | f" local payload\n" |
104 | 145 | f" payload=$(jq -c -n --arg name {shlex.quote(name)} " |
105 | | - f'--argjson input "$input_json" \'{{name:$name,input:$input}}\' 2>/dev/null) || \\\n' |
| 146 | + f"--argjson input \"$input_json\" '{{name:$name,input:$input}}' 2>/dev/null) || \\\n" |
106 | 147 | f" payload=$(jq -c -n --arg name {shlex.quote(name)} " |
107 | | - f'--arg input "$input_json" \'{{name:$name,input:$input}}\')\n' |
| 148 | + f"--arg input \"$input_json\" '{{name:$name,input:$input}}')\n" |
108 | 149 | f' printf \'%s\\n\' "$payload" > "$PTC_CALL_FIFO"\n' |
109 | 150 | f" local result\n" |
110 | 151 | f' IFS= read -r result < "$PTC_RESULT_FIFO"\n' |
111 | | - f' printf \'%s\\n\' "$result"\n' |
| 152 | + f" printf '%s\\n' \"$result\"\n" |
112 | 153 | f"}}\n" |
113 | 154 | ) |
114 | 155 | return "\n".join(lines) |
@@ -216,9 +257,7 @@ def on_call_readable() -> None: |
216 | 257 | break |
217 | 258 |
|
218 | 259 | results = response.get("results", []) |
219 | | - target = next( |
220 | | - (r for r in results if r.get("call_id") == call_id), None |
221 | | - ) |
| 260 | + target = next((r for r in results if r.get("call_id") == call_id), None) |
222 | 261 | if target is None and results: |
223 | 262 | target = results[0] |
224 | 263 |
|
|
0 commit comments