forked from agent0ai/agent-zero
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtty_session.py
More file actions
429 lines (359 loc) · 14.1 KB
/
Copy pathtty_session.py
File metadata and controls
429 lines (359 loc) · 14.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
import asyncio, os, sys, platform, errno
_IS_WIN = platform.system() == "Windows"
if _IS_WIN:
import winpty # pip install pywinpty # type: ignore
import msvcrt
def _reconfigure_stream_errors(stream) -> None:
reconfigure = getattr(stream, "reconfigure", None)
if not callable(reconfigure):
return
reconfigure(errors="replace")
# Make stdin / stdout tolerant to broken UTF-8 so input() never aborts:
# undecodable/unencodable characters become replacement characters instead.
for _std_stream in (sys.stdin, sys.stdout):
    _reconfigure_stream_errors(_std_stream)
del _std_stream
# ──────────────────────────── PUBLIC CLASS ────────────────────────────
class TTYSession:
def __init__(self, cmd, *, cwd=None, env=None, encoding="utf-8", echo=False):
self.cmd = cmd if isinstance(cmd, str) else " ".join(cmd)
self.cwd = cwd
self.env = env or os.environ.copy()
self.encoding = encoding
self.echo = echo # ← store preference
self._proc = None
self._buf: asyncio.Queue = None # type: ignore
self._pump_task = None
self._pty_master = None
self._pty_master_ref = None
def __del__(self):
# Simple cleanup on object destruction
import nest_asyncio
nest_asyncio.apply()
if hasattr(self, "close"):
try:
asyncio.run(self.close())
except Exception:
pass
# ── user-facing coroutines ────────────────────────────────────────
async def start(self):
self._buf = asyncio.Queue()
if _IS_WIN:
self._proc = await _spawn_winpty(
self.cmd, self.cwd, self.env, self.echo
) # ← pass echo
else:
self._proc = await _spawn_posix_pty(
self.cmd, self.cwd, self.env, self.echo
) # ← pass echo
self._pty_master_ref = getattr(self._proc, "_pty_master_ref", None)
self._pty_master = (
self._pty_master_ref.get("fd")
if self._pty_master_ref is not None
else getattr(self._proc, "_pty_master", None)
)
self._pump_task = asyncio.create_task(self._pump_stdout())
async def close(self):
# Cancel the pump task if it exists
if self._pump_task:
self._pump_task.cancel()
try:
await self._pump_task
except asyncio.CancelledError:
pass
except Exception:
pass
# Terminate the process if it exists
if self._proc:
try:
if getattr(self._proc, "returncode", None) is None:
self._proc.terminate()
except ProcessLookupError:
pass
except Exception:
pass
try:
await self._proc.wait()
except Exception:
pass
self._release_pty_master()
self._proc = None
self._pump_task = None
def _release_pty_master(self):
"""Release the POSIX PTY master exactly once.
The fd number is invalidated before os.close() so that a concurrent or
later cleanup path cannot close the same integer after the OS has reused
it for another file/socket.
"""
ref = self._pty_master_ref
master = ref.get("fd") if ref is not None else self._pty_master
if master is None:
self._pty_master = None
return
if ref is not None:
ref["fd"] = None
self._pty_master = None
try:
loop = asyncio.get_running_loop()
loop.remove_reader(master)
except Exception:
pass
try:
os.close(master)
except OSError:
pass
self._pty_master_ref = None
async def send(self, data: str | bytes):
if self._proc is None:
raise RuntimeError("TTYSpawn is not started")
if not _IS_WIN:
master = (
self._pty_master_ref.get("fd")
if self._pty_master_ref is not None
else self._pty_master
)
if master is None:
raise RuntimeError("TTYSpawn PTY is closed")
if getattr(self._proc, "returncode", None) is not None:
raise RuntimeError("TTYSpawn process has exited")
if isinstance(data, str):
data = data.encode(self.encoding)
try:
self._proc.stdin.write(data) # type: ignore
await self._proc.stdin.drain() # type: ignore
except OSError as e:
if e.errno in (errno.EBADF, errno.EIO, errno.EINVAL):
self._release_pty_master()
raise RuntimeError("TTYSpawn PTY is closed") from e
raise
async def sendline(self, line: str):
await self.send(line + "\n")
async def wait(self):
if self._proc is None:
raise RuntimeError("TTYSpawn is not started")
return await self._proc.wait()
def kill(self):
"""Force-kill the running child process.
This is best-effort: if the process has already terminated (which can
happen if *close()* was called elsewhere or the child exited by
itself) we silently ignore the *ProcessLookupError* raised by
*asyncio.subprocess.Process.kill()*. This prevents race conditions
where multiple coroutines attempt to close the same session.
"""
if self._proc is None:
# Already closed or never started – nothing to do
return
# Only attempt to kill if the process is still running
if getattr(self._proc, "returncode", None) is None:
try:
self._proc.kill()
except ProcessLookupError:
# Child already gone – treat as successfully killed
pass
self._release_pty_master()
async def read(self, timeout=None):
# Return any decoded text the child produced, or None on timeout
try:
return await asyncio.wait_for(self._buf.get(), timeout)
except asyncio.TimeoutError:
return None
# backward-compat alias:
readline = read
async def read_full_until_idle(self, idle_timeout, total_timeout):
# Collect child output using iter_until_idle to avoid duplicate logic
return "".join(
[
chunk
async for chunk in self.read_chunks_until_idle(
idle_timeout, total_timeout
)
]
)
async def read_chunks_until_idle(self, idle_timeout, total_timeout):
# Yield each chunk as soon as it arrives until idle or total timeout
import time
start = time.monotonic()
while True:
if time.monotonic() - start > total_timeout:
break
chunk = await self.read(timeout=idle_timeout)
if chunk is None:
break
yield chunk
# ── internal: stream raw output into the queue ────────────────────
async def _pump_stdout(self):
if self._proc is None:
raise RuntimeError("TTYSpawn is not started")
reader = self._proc.stdout
while True:
chunk = await reader.read(4096) # grab whatever is ready # type: ignore
if not chunk:
break
self._buf.put_nowait(chunk.decode(self.encoding, "replace"))
# ──────────────────────────── POSIX IMPLEMENTATION ────────────────────
async def _spawn_posix_pty(cmd, cwd, env, echo):
import pty, asyncio, os, termios
master, slave = pty.openpty()
# ── Disable ECHO on the slave side if requested ──
if not echo:
attrs = termios.tcgetattr(slave)
attrs[3] &= ~termios.ECHO # lflag
termios.tcsetattr(slave, termios.TCSANOW, attrs)
proc = await asyncio.create_subprocess_shell(
cmd,
stdin=slave,
stdout=slave,
stderr=slave,
cwd=cwd,
env=env,
close_fds=True,
)
os.close(slave)
loop = asyncio.get_running_loop()
reader = asyncio.StreamReader()
master_ref = {"fd": master}
def _release_master_fd():
cur = master_ref.get("fd")
if cur is None:
return
# Invalidate before close so later cleanup cannot close a reused fd.
master_ref["fd"] = None
try:
proc._pty_master = None # type: ignore[attr-defined]
except Exception:
pass
try:
loop.remove_reader(cur)
except Exception:
pass
try:
os.close(cur)
except OSError:
pass
def _on_data():
cur = master_ref.get("fd")
if cur is None:
reader.feed_eof()
return
try:
data = os.read(cur, 1 << 16)
except OSError as e:
if e.errno != errno.EIO: # EIO == EOF on some systems
raise
data = b""
if data:
reader.feed_data(data)
else:
reader.feed_eof()
_release_master_fd()
loop.add_reader(master, _on_data)
class _Stdin:
def write(self, d):
cur = master_ref.get("fd")
if cur is None:
raise OSError(errno.EBADF, "PTY master closed")
os.write(cur, d)
async def drain(self):
await asyncio.sleep(0)
proc.stdin = _Stdin() # type: ignore
proc.stdout = reader
proc._pty_master = master # type: ignore[attr-defined]
proc._pty_master_ref = master_ref # type: ignore[attr-defined]
return proc
# ──────────────────────────── WINDOWS IMPLEMENTATION ──────────────────
async def _spawn_winpty(cmd, cwd, env, echo):
    """Start *cmd* under pywinpty and wrap it in an asyncio-Process-like object.

    The returned object provides ``stdin`` (``write``/``drain``), ``stdout``
    (an ``asyncio.StreamReader`` filled by a background task), ``pid``,
    ``returncode``, ``wait()``, ``terminate()`` and ``kill()``.  *echo* is
    accepted for signature parity with the POSIX helper and is not used.
    """
    import re

    # Clean PowerShell startup: no logo, no profile, bypass execution policy
    # for deterministic behavior.  The leading executable is matched
    # case-insensitively, with or without ".exe"; a case-sensitive
    # str.replace("powershell.exe", ...) missed "powershell"/"PowerShell.exe"
    # and could patch a later occurrence instead of the executable itself.
    if "-nolog" not in cmd.lower():
        cmd = re.sub(
            r"^(\s*powershell(?:\.exe)?)(?=\s|$)",
            r"\1 -NoLogo -NoProfile -ExecutionPolicy Bypass",
            cmd,
            count=1,
            flags=re.IGNORECASE,
        )
    cols, rows = 80, 25
    child = winpty.PtyProcess.spawn(cmd, dimensions=(rows, cols), cwd=cwd or os.getcwd(), env=env)  # type: ignore
    loop = asyncio.get_running_loop()
    reader = asyncio.StreamReader()

    async def _pump():
        """Copy the child's output into *reader* until it exits or hits EOF."""
        while child.isalive():
            try:
                # Run blocking read in executor to not block event loop
                data = await loop.run_in_executor(None, child.read, 1 << 16)
                if data:
                    reader.feed_data(data.encode("utf-8") if isinstance(data, str) else data)
            except EOFError:
                break
            except Exception:
                await asyncio.sleep(0.01)
        reader.feed_eof()

    # Start pumping output in background.  The event loop keeps only weak
    # references to tasks, so hold a strong one on the returned object.
    reader_task = asyncio.create_task(_pump())

    class _Stdin:
        def write(self, d):
            # Use winpty's write method (text), not os.write
            if isinstance(d, bytes):
                d = d.decode("utf-8", errors="replace")
            # Windows needs \r\n for proper line endings
            child.write(d.replace("\n", "\r\n"))

        async def drain(self):
            await asyncio.sleep(0.01)  # Give write time to complete

    class _Proc:
        def __init__(self):
            self.stdin = _Stdin()  # type: ignore
            self.stdout = reader
            self.pid = child.pid
            self.returncode = None
            self._reader_task = reader_task  # keeps the pump task alive

        async def wait(self):
            while child.isalive():
                await asyncio.sleep(0.2)
            # Prefer the real exit status when pywinpty exposes one (assumed
            # `exitstatus` attribute — TODO confirm); otherwise report 0.
            status = getattr(child, "exitstatus", None)
            self.returncode = status if isinstance(status, int) else 0
            return self.returncode

        def terminate(self):
            if child.isalive():
                child.terminate()

        def kill(self):
            if child.isalive():
                child.kill()

    return _Proc()
# ───────────────────────── INTERACTIVE DRIVER ─────────────────────────
if __name__ == "__main__":

    async def interactive_shell():
        """Tiny REPL for driving a shell through TTYSession by hand.

        Lines typed by the user are sent to the shell and its output is
        echoed until it goes idle.  ``/t=<seconds>`` changes the idle timeout;
        ``/exit`` (or EOF / Ctrl-C) ends the session.  The session is always
        closed on the way out.
        """
        shell_cmd, prompt_hint = ("powershell.exe", ">") if _IS_WIN else ("/bin/bash", "$")
        # echo=False (the default) → suppress the shell’s own echo of commands
        term = TTYSession(shell_cmd)
        await term.start()
        try:
            timeout = 1.0
            print(f"Connected to {shell_cmd}.")
            print("Type commands for the shell.")
            print("• /t=<seconds> → change idle timeout")
            print("• /exit → quit helper\n")
            await term.sendline(" ")
            print(await term.read_full_until_idle(timeout, timeout), end="", flush=True)
            while True:
                try:
                    user = input(f"(timeout={timeout}) {prompt_hint} ")
                except (EOFError, KeyboardInterrupt):
                    print("\nLeaving…")
                    break
                if user.lower() == "/exit":
                    break
                if user.startswith("/t="):
                    try:
                        timeout = float(user.split("=", 1)[1])
                        print(f"[helper] idle timeout set to {timeout}s")
                    except ValueError:
                        print("[helper] invalid number")
                    continue
                idle_timeout = timeout
                total_timeout = 10 * idle_timeout
                # An empty line only polls for pending output; nothing is sent.
                if user:
                    await term.sendline(user)
                async for chunk in term.read_chunks_until_idle(
                    idle_timeout, total_timeout
                ):
                    print(chunk, end="", flush=True)
            try:
                await term.sendline("exit")
                await asyncio.wait_for(term.wait(), 5)
            except asyncio.TimeoutError:
                print("[helper] shell did not exit; terminating it")
            except RuntimeError:
                pass  # PTY already closed: the shell exited on its own
        finally:
            # Always stop the pump, end the child and release the PTY.
            await term.close()

    asyncio.run(interactive_shell())