diff --git a/pyproject.toml b/pyproject.toml index bdb710358a..1a162bb977 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,11 +57,11 @@ rocklet = [ # Rocklet execution environment dependencies "fastapi", "uvicorn", - "pexpect", - "bashlex", + "pexpect; sys_platform != 'win32'", + "bashlex; sys_platform != 'win32'", "psutil", - "twisted", - "gem-llm>=0.1.0", + "twisted; sys_platform != 'win32'", + "gem-llm>=0.1.0; sys_platform != 'win32'", "numpy<=2.2.6", ] diff --git a/rock/rocklet/exceptions.py b/rock/rocklet/exceptions.py index af6d7de651..5f6c6b27b6 100644 --- a/rock/rocklet/exceptions.py +++ b/rock/rocklet/exceptions.py @@ -25,6 +25,21 @@ def __init__(self, message: str, *, extra_info: dict[str, Any] = None): self.extra_info = extra_info +class PowerShellError(RockletException, RuntimeError): + """Raised when a PowerShell session encounters an error.""" + + def __init__(self, message: str, *, extra_info: dict[str, Any] = None): + super().__init__(message) + if extra_info is None: + extra_info = {} + self.extra_info = extra_info + + +class PowerShellNotFoundError(RockletException, FileNotFoundError): + """Raised when PowerShell executable is not found on the system.""" + pass + + class CommandTimeoutError(RockletException, RuntimeError, TimeoutError): ... diff --git a/rock/rocklet/local_sandbox.py b/rock/rocklet/local_sandbox.py index 1bc1c11905..15aa76a9ca 100644 --- a/rock/rocklet/local_sandbox.py +++ b/rock/rocklet/local_sandbox.py @@ -1,8 +1,11 @@ import asyncio import os +import queue as queue_module import re import shutil import subprocess +import sys +import threading import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -10,13 +13,16 @@ from pathlib import Path from typing import Any -import bashlex -import bashlex.ast -import gem -import pexpect import psutil from typing_extensions import Self +IS_WINDOWS = sys.platform == "win32" + +if not IS_WINDOWS: + import bashlex + import bashlex.ast + import pexpect + from rock.actions import ( AbstractSandbox, BashObservation, @@ -55,75 +61,77 @@ CommandTimeoutError, NoExitCodeError, NonZeroExitCodeError, + PowerShellNotFoundError, SessionDoesNotExistError, SessionExistsError, SessionNotInitializedError, ) from rock.utils import get_executor -__all__ = ["LocalSandboxRuntime", "BashSession"] +__all__ = ["LocalSandboxRuntime", "BashSession", "PowerShellSession"] logger = init_logger("rock.actions.local") -def _split_bash_command(inpt: str) -> list[str]: - r"""Split a bash command with linebreaks, escaped newlines, and heredocs into a list of - individual commands. +def _strip_control_chars(s: str) -> str: + ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") + return ansi_escape.sub("", s) - Args: - inpt: The input string to split into commands. - Returns: - A list of commands as strings. - Examples: +if not IS_WINDOWS: - "cmd1\ncmd2" are two commands - "cmd1\\\n asdf" is one command (because the linebreak is escaped) - "cmd1< tuple[int, int]: - start = cmd.pos[0] # type: ignore - end = cmd.pos[1] # type: ignore - for part in getattr(cmd, "parts", []): - part_start, part_end = find_range(part) - start = min(start, part_start) - end = max(end, part_end) - return start, end - - for cmd in parsed: - start, end = find_range(cmd) - cmd_strings.append(inpt[start:end]) - return cmd_strings + def _split_bash_command(inpt: str) -> list[str]: + r"""Split a bash command with linebreaks, escaped newlines, and heredocs into a list of + individual commands. + Args: + inpt: The input string to split into commands. + Returns: + A list of commands as strings. -def _strip_control_chars(s: str) -> str: - ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") - return ansi_escape.sub("", s) - + Examples: -def _check_bash_command(command: str) -> None: - """Check if a bash command is valid. Raises BashIncorrectSyntaxError if it's not.""" - _unique_string = "SOUNIQUEEOF" - cmd = f"/bin/bash -n << '{_unique_string}'\n{command}\n{_unique_string}" - result = subprocess.run(cmd, shell=True, capture_output=True) - if result.returncode == 0: - return - stdout = result.stdout.decode(errors="backslashreplace") - stderr = result.stderr.decode(errors="backslashreplace") - msg = ( - f"Error (exit code {result.returncode}) while checking bash command \n{command!r}\n" - f"---- Stderr ----\n{stderr}\n---- Stdout ----\n{stdout}" - ) - exc = BashIncorrectSyntaxError(msg, extra_info={"bash_stdout": stdout, "bash_stderr": stderr}) - raise exc + "cmd1\ncmd2" are two commands + "cmd1\\\n asdf" is one command (because the linebreak is escaped) + "cmd1< tuple[int, int]: + start = cmd.pos[0] # type: ignore + end = cmd.pos[1] # type: ignore + for part in getattr(cmd, "parts", []): + part_start, part_end = find_range(part) + start = min(start, part_start) + end = max(end, part_end) + return start, end + + for cmd in parsed: + start, end = find_range(cmd) + cmd_strings.append(inpt[start:end]) + return cmd_strings + + def _check_bash_command(command: str) -> None: + """Check if a bash command is valid. Raises BashIncorrectSyntaxError if it's not.""" + _unique_string = "SOUNIQUEEOF" + cmd = f"/bin/bash -n << '{_unique_string}'\n{command}\n{_unique_string}" + result = subprocess.run(cmd, shell=True, capture_output=True) + if result.returncode == 0: + return + stdout = result.stdout.decode(errors="backslashreplace") + stderr = result.stderr.decode(errors="backslashreplace") + msg = ( + f"Error (exit code {result.returncode}) while checking bash command \n{command!r}\n" + f"---- Stderr ----\n{stderr}\n---- Stdout ----\n{stdout}" + ) + exc = BashIncorrectSyntaxError(msg, extra_info={"bash_stdout": stdout, "bash_stderr": stderr}) + raise exc class Session(ABC): @@ -140,244 +148,458 @@ async def close(self) -> CloseSessionResponse: ... -class BashSession(Session): - _UNIQUE_STRING = "UNIQUESTRING29234" +if not IS_WINDOWS: + + class BashSession(Session): + _UNIQUE_STRING = "UNIQUESTRING29234" + + def __init__(self, request: CreateBashSessionRequest): + """This basically represents one REPL that we control. + + It's pretty similar to a `pexpect.REPLWrapper`. + """ + self.request = request + self._ps1 = "SHELLPS1PREFIX" + self._shell: pexpect.spawn | None = None + self._executor = get_executor() + self._system_user = "root" + + @property + def shell(self) -> pexpect.spawn: + if self._shell is None: + msg = "shell not initialized" + raise RuntimeError(msg) + return self._shell + + def _get_reset_commands(self) -> list[str]: + """Commands to reset the PS1, PS2, and PS0 variables to their default values.""" + return [ + "unset PROMPT_COMMAND", + f"export PS1='{self._ps1}'", + "export PS2=''", + "export PS0=''", + ] + + async def start(self) -> CreateBashSessionResponse: + """Spawn the session, source any startupfiles and set the PS1.""" + if self.request.env_enable: + env = os.environ.copy() + else: + env = {} + env.update({"PS1": self._ps1, "PS2": "", "PS0": ""}) + if self.request.env is not None: + env.update(self.request.env) + logger.info(f"env:{env}") + command = "/bin/bash" + if self.request.remote_user is not None and self.request.remote_user != self._system_user: + command = f"su {self.request.remote_user} -c {command}" + self._shell = pexpect.spawn( + command, + encoding="utf-8", + codec_errors="backslashreplace", + echo=False, + env=env, # type: ignore + maxread=self.request.max_read_size, + ) + time.sleep(0.3) + cmds = [] + if self.request.startup_source: + cmds += [f"source {path}" for path in self.request.startup_source] + ["sleep 0.3"] + cmds += self._get_reset_commands() + cmd = " ; ".join(cmds) + self.shell.sendline(cmd) + self.shell.expect(self._ps1, timeout=self.request.startup_timeout) + output = _strip_control_chars(self.shell.before) # type: ignore + self.refresh_shell() + return CreateBashSessionResponse(output=output) - def __init__(self, request: CreateBashSessionRequest): - """This basically represents one REPL that we control. + def _eat_following_output(self, timeout: float = 0.5) -> str: + """Return all output that happens in the next `timeout` seconds.""" + time.sleep(timeout) + try: + output = self.shell.read_nonblocking(timeout=0.1) + except pexpect.TIMEOUT: + return "" + return _strip_control_chars(output) + + async def run(self, action: BashAction) -> BashObservation: + """Run a bash action. + + Raises: + SessionNotInitializedError: If the shell is not initialized. + CommandTimeoutError: If the command times out. + NonZeroExitCodeError: If the command has a non-zero exit code and `action.check` is True. + NoExitCodeError: If we cannot get the exit code of the command. + + Returns: + BashObservation: The observation of the command. + """ + if self.shell is None: + msg = "shell not initialized" + raise SessionNotInitializedError(msg) + if action.is_interactive_command or action.is_interactive_quit: + return await self._aync_run_interactive(action) + r = await self._async_run_normal(action) + if action.check == "raise" and r.exit_code != 0: + msg = f"Command {action.command!r} failed with exit code {r.exit_code}. Here is the output:\n{r.output!r}" + if action.error_msg: + msg = f"{action.error_msg}: {msg}" + raise NonZeroExitCodeError(msg) + return r + + async def _aync_run_interactive(self, action: BashAction) -> BashObservation: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self._executor, self._run_interactive, action) + + def _run_interactive(self, action: BashAction) -> BashObservation: + """Run an interactive action. This is different because we don't seek to + the PS1 and don't attempt to get the exit code. + """ + assert self.shell is not None + self.shell.sendline(action.command) + expect_strings = action.expect + [self._ps1] + try: + self.shell.expect(expect_strings, timeout=action.timeout) # type: ignore + # matched_expect_string = expect_strings[expect_index] + except pexpect.TIMEOUT as e: + msg = f"timeout after {action.timeout} seconds while running command {action.command!r}" + raise CommandTimeoutError(msg) from e + finally: + self.refresh_shell() + output: str = _strip_control_chars(self.shell.before).strip() # type: ignore + if action.is_interactive_quit: + assert not action.is_interactive_command + self.shell.setecho(False) + self.shell.waitnoecho() + self.shell.sendline(f"stty -echo; echo '{self._UNIQUE_STRING}'") + # Might need two expects for some reason + self.shell.expect(self._UNIQUE_STRING, timeout=1) + self.shell.expect(self._ps1, timeout=1) + else: + # Interactive command. + # For some reason, this often times enables echo mode within the shell. + output = output.lstrip().removeprefix(action.command).strip() + return BashObservation(output=output, exit_code=0, expect_string=self.shell.after) + + def refresh_shell(self): + logger.debug(f"before refresh before_content: {self.shell._before.getvalue()}") + logger.debug(f"before refresh buffer_content: {self.shell._buffer.getvalue()}") + self.shell._before.seek(0) + self.shell._before.truncate(0) + self.shell._buffer.seek(0) + self.shell._buffer.truncate(0) + logger.debug(f"after refresh before_content: {self.shell._before.getvalue()}") + logger.debug(f"after refresh buffer_content: {self.shell._buffer.getvalue()}") + + async def _async_run_normal(self, action: BashAction) -> BashObservation: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self._executor, self._run_normal, action) + + def _run_normal(self, action: BashAction) -> BashObservation: + """Run a normal action. This is the default mode. + + There are three steps to this: + + 1. Check if the command is valid + 2. Execute the command + 3. Get the exit code + """ + action = deepcopy(action) + + assert self.shell is not None + _check_bash_command(action.command) + + # Part 2: Execute the command + + fallback_terminator = False + # Running multiple interactive commands by sending them with linebreaks would break things + # because we get multiple PS1s back to back. Instead we just join them with ; + # However, sometimes bashlex errors and we can't do this. In this case + # we add a unique string to the end of the command and then seek to that + # (which is also somewhat brittle, so we don't do this by default). + try: + individual_commands = _split_bash_command(action.command) + except Exception as e: + # Bashlex is very buggy and can throw a variety of errors, including + # ParsingErrors, NotImplementedErrors, TypeErrors, possibly more. So we catch them all + logger.error("Bashlex fail: %s", e) + action.command += f"\n TMPEXITCODE=$? ; sleep 0.1; echo '{self._UNIQUE_STRING}' ; (exit $TMPEXITCODE)" + fallback_terminator = True + else: + action.command = " ; ".join(individual_commands) + self.shell.sendline(action.command) + if not fallback_terminator: + expect_strings = action.expect + [self._ps1] + else: + expect_strings = [self._UNIQUE_STRING] + try: + expect_index = self.shell.expect(expect_strings, timeout=action.timeout) # type: ignore + matched_expect_string = expect_strings[expect_index] + except pexpect.TIMEOUT as e: + msg = f"timeout after {action.timeout} seconds while running command {action.command!r}" + raise CommandTimeoutError(msg) from e + finally: + self.refresh_shell() + output: str = _strip_control_chars(self.shell.before).strip() # type: ignore + + # Part 3: Get the exit code + if action.check == "ignore": + return BashObservation(output=output, exit_code=None, expect_string=matched_expect_string) - It's pretty similar to a `pexpect.REPLWrapper`. - """ + try: + _exit_code_prefix = "EXITCODESTART" + _exit_code_suffix = "EXITCODEEND" + self.shell.sendline(f"\necho {_exit_code_prefix}$?{_exit_code_suffix}") + try: + self.shell.expect(_exit_code_suffix, timeout=1) + except pexpect.TIMEOUT: + msg = "timeout while getting exit code" + raise NoExitCodeError(msg) + exit_code_raw: str = _strip_control_chars(self.shell.before).strip() # type: ignore + exit_code = re.findall(f"{_exit_code_prefix}([0-9]+)", exit_code_raw) + if len(exit_code) != 1: + msg = f"failed to parse exit code from output {exit_code_raw!r} (command: {action.command!r}, matches: {exit_code})" + raise NoExitCodeError(msg) + output += exit_code_raw.split(_exit_code_prefix)[0] + exit_code = int(exit_code[0]) + # We get at least one more PS1 here. + try: + self.shell.expect(self._ps1, timeout=0.1) + except pexpect.TIMEOUT: + msg = "Timeout while getting PS1 after exit code extraction" + raise CommandTimeoutError(msg) + output = output.replace(self._UNIQUE_STRING, "").replace(self._ps1, "") + except Exception: + # Ignore all exceptions if check == 'silent' + if action.check == "raise": + raise + exit_code = None + finally: + self.refresh_shell() + return BashObservation(output=output, exit_code=exit_code, expect_string=matched_expect_string) + + async def close(self) -> CloseSessionResponse: + if self._shell is None: + return CloseBashSessionResponse() + self.shell.close() + self._shell = None + return CloseBashSessionResponse() + + def interact(self) -> None: + """Enter interactive mode.""" + self.shell.interact() + + +class PowerShellSession(Session): + """A session that runs PowerShell commands on Windows. + + Uses subprocess.Popen with a background reader thread to interact + with PowerShell interactively via stdin/stdout pipes. + Unique marker strings are used to delimit command output boundaries + and extract exit codes. + """ + + _BEGIN_MARKER = "ROCKLET_PS_BEGIN_29234" + _END_MARKER = "ROCKLET_PS_END_29234" + _EXIT_MARKER = "ROCKLET_PS_EXIT_29234:" + + def __init__(self, request: CreateBashSessionRequest): self.request = request - self._ps1 = "SHELLPS1PREFIX" - self._shell: pexpect.spawn | None = None + self._process: subprocess.Popen | None = None self._executor = get_executor() - self._system_user = "root" + self._output_queue: queue_module.Queue = queue_module.Queue() + self._reader_thread: threading.Thread | None = None + self._lock = threading.Lock() + + @staticmethod + def _find_powershell() -> str: + """Find the PowerShell executable. Prefers pwsh (PowerShell Core) over powershell.""" + for cmd in ["pwsh", "powershell"]: + if shutil.which(cmd): + return cmd + raise PowerShellNotFoundError( + "PowerShell executable not found. " + "Install PowerShell Core ('pwsh') or ensure 'powershell.exe' is in PATH." + ) - @property - def shell(self) -> pexpect.spawn: - if self._shell is None: - msg = "shell not initialized" - raise RuntimeError(msg) - return self._shell - - def _get_reset_commands(self) -> list[str]: - """Commands to reset the PS1, PS2, and PS0 variables to their default values.""" - return [ - "unset PROMPT_COMMAND", - f"export PS1='{self._ps1}'", - "export PS2=''", - "export PS0=''", - ] + def _read_stdout(self) -> None: + """Background thread that continuously reads stdout line by line.""" + try: + assert self._process is not None and self._process.stdout is not None + for line in iter(self._process.stdout.readline, ""): + if not line: + break + self._output_queue.put(line) + except Exception: + pass + + def _drain_queue(self, timeout: float = 0.1) -> str: + """Drain all currently available lines from the output queue.""" + lines = [] + deadline = time.time() + timeout + while time.time() < deadline: + try: + line = self._output_queue.get(timeout=0.05) + lines.append(line.rstrip("\r\n")) + except queue_module.Empty: + if lines: + break + continue + return "\n".join(lines) + + def _send_line(self, text: str) -> None: + """Send a line of text to PowerShell's stdin.""" + assert self._process is not None and self._process.stdin is not None + self._process.stdin.write(text + "\n") + self._process.stdin.flush() async def start(self) -> CreateBashSessionResponse: - """Spawn the session, source any startupfiles and set the PS1.""" - if self.request.env_enable: - env = os.environ.copy() - else: - env = {} - env.update({"PS1": self._ps1, "PS2": "", "PS0": ""}) + """Start the PowerShell session.""" + env = os.environ.copy() if self.request.env is not None: env.update(self.request.env) - logger.info(f"env:{env}") - command = "/bin/bash" - if self.request.remote_user is not None and self.request.remote_user != self._system_user: - command = f"su {self.request.remote_user} -c {command}" - self._shell = pexpect.spawn( - command, + + ps_cmd = self._find_powershell() + logger.info(f"Starting PowerShell session with: {ps_cmd}") + + self._process = subprocess.Popen( + [ps_cmd, "-NoLogo", "-NoProfile", "-NoExit"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, encoding="utf-8", - codec_errors="backslashreplace", - echo=False, - env=env, # type: ignore - maxread=self.request.max_read_size, + bufsize=1, + env=env, ) - time.sleep(0.3) - cmds = [] - if self.request.startup_source: - cmds += [f"source {path}" for path in self.request.startup_source] + ["sleep 0.3"] - cmds += self._get_reset_commands() - cmd = " ; ".join(cmds) - self.shell.sendline(cmd) - self.shell.expect(self._ps1, timeout=self.request.startup_timeout) - output = _strip_control_chars(self.shell.before) # type: ignore - self.refresh_shell() - return CreateBashSessionResponse(output=output) - - def _eat_following_output(self, timeout: float = 0.5) -> str: - """Return all output that happens in the next `timeout` seconds.""" - time.sleep(timeout) - try: - output = self.shell.read_nonblocking(timeout=0.1) - except pexpect.TIMEOUT: - return "" - return _strip_control_chars(output) + + self._reader_thread = threading.Thread(target=self._read_stdout, daemon=True) + self._reader_thread.start() + + # Wait for PowerShell to initialize + time.sleep(0.5) + + # Drain startup output + startup_output = self._drain_queue(timeout=0.5) + logger.info(f"PowerShell session started, startup output: {startup_output[:200]}") + return CreateBashSessionResponse(output=startup_output) async def run(self, action: BashAction) -> BashObservation: - """Run a bash action. + """Run a command in the PowerShell session. Raises: - SessionNotInitializedError: If the shell is not initialized. + SessionNotInitializedError: If the PowerShell process is not running. CommandTimeoutError: If the command times out. - NonZeroExitCodeError: If the command has a non-zero exit code and `action.check` is True. - NoExitCodeError: If we cannot get the exit code of the command. - - Returns: - BashObservation: The observation of the command. + NonZeroExitCodeError: If the command has a non-zero exit code and action.check is 'raise'. """ - if self.shell is None: - msg = "shell not initialized" - raise SessionNotInitializedError(msg) - if action.is_interactive_command or action.is_interactive_quit: - return await self._aync_run_interactive(action) - r = await self._async_run_normal(action) - if action.check == "raise" and r.exit_code != 0: - msg = f"Command {action.command!r} failed with exit code {r.exit_code}. Here is the output:\n{r.output!r}" - if action.error_msg: - msg = f"{action.error_msg}: {msg}" - raise NonZeroExitCodeError(msg) - return r + if self._process is None or self._process.poll() is not None: + raise SessionNotInitializedError("PowerShell session not initialized or has terminated") - async def _aync_run_interactive(self, action: BashAction) -> BashObservation: loop = asyncio.get_running_loop() - return await loop.run_in_executor(self._executor, self._run_interactive, action) + result = await loop.run_in_executor(self._executor, self._run_command, action) - def _run_interactive(self, action: BashAction) -> BashObservation: - """Run an interactive action. This is different because we don't seek to - the PS1 and don't attempt to get the exit code. - """ - assert self.shell is not None - self.shell.sendline(action.command) - expect_strings = action.expect + [self._ps1] - try: - self.shell.expect(expect_strings, timeout=action.timeout) # type: ignore - # matched_expect_string = expect_strings[expect_index] - except pexpect.TIMEOUT as e: - msg = f"timeout after {action.timeout} seconds while running command {action.command!r}" - raise CommandTimeoutError(msg) from e - finally: - self.refresh_shell() - output: str = _strip_control_chars(self.shell.before).strip() # type: ignore - if action.is_interactive_quit: - assert not action.is_interactive_command - self.shell.setecho(False) - self.shell.waitnoecho() - self.shell.sendline(f"stty -echo; echo '{self._UNIQUE_STRING}'") - # Might need two expects for some reason - self.shell.expect(self._UNIQUE_STRING, timeout=1) - self.shell.expect(self._ps1, timeout=1) - else: - # Interactive command. - # For some reason, this often times enables echo mode within the shell. - output = output.lstrip().removeprefix(action.command).strip() - return BashObservation(output=output, exit_code=0, expect_string=self.shell.after) - - def refresh_shell(self): - logger.debug(f"before refresh before_content: {self.shell._before.getvalue()}") - logger.debug(f"before refresh buffer_content: {self.shell._buffer.getvalue()}") - self.shell._before.seek(0) - self.shell._before.truncate(0) - self.shell._buffer.seek(0) - self.shell._buffer.truncate(0) - logger.debug(f"after refresh before_content: {self.shell._before.getvalue()}") - logger.debug(f"after refresh buffer_content: {self.shell._buffer.getvalue()}") - - async def _async_run_normal(self, action: BashAction) -> BashObservation: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self._executor, self._run_normal, action) + if action.check == "raise" and result.exit_code is not None and result.exit_code != 0: + msg = ( + f"Command {action.command!r} failed with exit code {result.exit_code}. " + f"Here is the output:\n{result.output!r}" + ) + if hasattr(action, "error_msg") and action.error_msg: + msg = f"{action.error_msg}: {msg}" + raise NonZeroExitCodeError(msg) - def _run_normal(self, action: BashAction) -> BashObservation: - """Run a normal action. This is the default mode. + return result - There are three steps to this: + def _run_command(self, action: BashAction) -> BashObservation: + """Execute a single command in the PowerShell session (blocking). - 1. Check if the command is valid - 2. Execute the command - 3. Get the exit code + Wraps the command with unique markers to reliably detect output boundaries + and extract the exit code ($LASTEXITCODE). """ - action = deepcopy(action) - - assert self.shell is not None - _check_bash_command(action.command) - - # Part 2: Execute the command - - fallback_terminator = False - # Running multiple interactive commands by sending them with linebreaks would break things - # because we get multiple PS1s back to back. Instead we just join them with ; - # However, sometimes bashlex errors and we can't do this. In this case - # we add a unique string to the end of the command and then seek to that - # (which is also somewhat brittle, so we don't do this by default). - try: - individual_commands = _split_bash_command(action.command) - except Exception as e: - # Bashlex is very buggy and can throw a variety of errors, including - # ParsingErrors, NotImplementedErrors, TypeErrors, possibly more. So we catch them all - logger.error("Bashlex fail: %s", e) - action.command += f"\n TMPEXITCODE=$? ; sleep 0.1; echo '{self._UNIQUE_STRING}' ; (exit $TMPEXITCODE)" - fallback_terminator = True - else: - action.command = " ; ".join(individual_commands) - self.shell.sendline(action.command) - if not fallback_terminator: - expect_strings = action.expect + [self._ps1] - else: - expect_strings = [self._UNIQUE_STRING] - try: - expect_index = self.shell.expect(expect_strings, timeout=action.timeout) # type: ignore - matched_expect_string = expect_strings[expect_index] - except pexpect.TIMEOUT as e: - msg = f"timeout after {action.timeout} seconds while running command {action.command!r}" - raise CommandTimeoutError(msg) from e - finally: - self.refresh_shell() - output: str = _strip_control_chars(self.shell.before).strip() # type: ignore - - # Part 3: Get the exit code - if action.check == "ignore": - return BashObservation(output=output, exit_code=None, expect_string=matched_expect_string) + with self._lock: + # Drain any leftover output from previous commands + self._drain_queue(timeout=0.1) + + # Send command wrapped with markers + self._send_line(f"Write-Host '{self._BEGIN_MARKER}'") + # Send command lines + for line in action.command.splitlines(): + self._send_line(line) + # Capture exit code: use $LASTEXITCODE for native commands, + # but also check $? for PowerShell cmdlet errors + self._send_line( + f"if ($LASTEXITCODE -ne $null -and $LASTEXITCODE -ne 0) " + f"{{ Write-Host '{self._EXIT_MARKER}'$LASTEXITCODE }} " + f"elseif (-not $?) {{ Write-Host '{self._EXIT_MARKER}1' }} " + f"else {{ Write-Host '{self._EXIT_MARKER}0' }}" + ) + self._send_line(f"Write-Host '{self._END_MARKER}'") + + # Collect output until end marker + output_lines: list[str] = [] + started = False + exit_code: int | None = None + timeout = action.timeout if action.timeout is not None else 1200 + deadline = time.time() + timeout + + while True: + remaining = deadline - time.time() + if remaining <= 0: + raise CommandTimeoutError( + f"timeout after {timeout} seconds while running command {action.command!r}" + ) + try: + line = self._output_queue.get(timeout=min(remaining, 1.0)) + line = line.rstrip("\r\n") + except queue_module.Empty: + # Check if process is still alive + if self._process is not None and self._process.poll() is not None: + raise SessionNotInitializedError("PowerShell process terminated unexpectedly") + continue + + if self._BEGIN_MARKER in line: + started = True + continue + if self._EXIT_MARKER in line: + after_marker = line.split(self._EXIT_MARKER)[-1].strip() + try: + exit_code = int(after_marker) if after_marker else 0 + except ValueError: + exit_code = 0 + continue + if self._END_MARKER in line: + break + if started: + output_lines.append(line) + + output = "\n".join(output_lines).strip() + output = _strip_control_chars(output) + + if action.check == "ignore": + return BashObservation(output=output, exit_code=None) + + return BashObservation(output=output, exit_code=exit_code) - try: - _exit_code_prefix = "EXITCODESTART" - _exit_code_suffix = "EXITCODEEND" - self.shell.sendline(f"\necho {_exit_code_prefix}$?{_exit_code_suffix}") + async def close(self) -> CloseSessionResponse: + """Close the PowerShell session.""" + if self._process is not None: try: - self.shell.expect(_exit_code_suffix, timeout=1) - except pexpect.TIMEOUT: - msg = "timeout while getting exit code" - raise NoExitCodeError(msg) - exit_code_raw: str = _strip_control_chars(self.shell.before).strip() # type: ignore - exit_code = re.findall(f"{_exit_code_prefix}([0-9]+)", exit_code_raw) - if len(exit_code) != 1: - msg = f"failed to parse exit code from output {exit_code_raw!r} (command: {action.command!r}, matches: {exit_code})" - raise NoExitCodeError(msg) - output += exit_code_raw.split(_exit_code_prefix)[0] - exit_code = int(exit_code[0]) - # We get at least one more PS1 here. + self._send_line("exit") + except Exception: + pass try: - self.shell.expect(self._ps1, timeout=0.1) - except pexpect.TIMEOUT: - msg = "Timeout while getting PS1 after exit code extraction" - raise CommandTimeoutError(msg) - output = output.replace(self._UNIQUE_STRING, "").replace(self._ps1, "") - except Exception: - # Ignore all exceptions if check == 'silent' - if action.check == "raise": - raise - exit_code = None - finally: - self.refresh_shell() - return BashObservation(output=output, exit_code=exit_code, expect_string=matched_expect_string) - - async def close(self) -> CloseSessionResponse: - if self._shell is None: - return CloseBashSessionResponse() - self.shell.close() - self._shell = None + self._process.terminate() + self._process.wait(timeout=5) + except Exception: + try: + self._process.kill() + except Exception: + pass + self._process = None return CloseBashSessionResponse() def interact(self) -> None: - """Enter interactive mode.""" - self.shell.interact() + """Interactive mode is not supported for PowerShell sessions.""" + raise NotImplementedError("Interactive mode is not supported for PowerShell sessions") class LocalSandboxRuntime(AbstractSandbox): @@ -394,7 +616,7 @@ def __init__(self, *, executor: ThreadPoolExecutor | None = None, **kwargs: Any) # Set up logger self.command_logger = init_logger("command", "command.log") self._executor = executor - self._gem_envs: dict[str, gem.Env] = {} + self._gem_envs: dict[str, Any] = {} @classmethod def from_config(cls, config: LocalSandboxRuntimeConfig) -> Self: @@ -409,17 +631,24 @@ async def is_alive(self, *, timeout: float | None = None) -> IsAliveResponse: return IsAliveResponse(is_alive=True) async def create_session(self, request: CreateSessionRequest) -> CreateSessionResponse: - """Creates a new session.""" + """Creates a new session. + + On Linux, creates a BashSession using pexpect. + On Windows, creates a PowerShellSession using subprocess. + """ if request.session in self.sessions: msg = f"session {request.session} already exists" raise SessionExistsError(msg) if isinstance(request, CreateBashSessionRequest): - session = BashSession(request) + if IS_WINDOWS: + session = PowerShellSession(request) + else: + session = BashSession(request) else: msg = f"unknown session type: {request!r}" raise ValueError(msg) self.sessions[request.session] = session - self.command_logger.info(f"[create_session]:{request.session}") + self.command_logger.info(f"[create_session]:{request.session} (platform={'windows' if IS_WINDOWS else 'linux'})") return await session.start() async def run_in_session(self, action: Action) -> Observation: @@ -526,7 +755,8 @@ async def close(self) -> CloseResponse: async def get_statistics(self): cpu_percent: float = psutil.cpu_percent() mem_percent: float = psutil.virtual_memory().percent - disk_percent: float = psutil.disk_usage("/").percent + disk_path = "C:\\" if IS_WINDOWS else "/" + disk_percent: float = psutil.disk_usage(disk_path).percent net_io: int = psutil.net_io_counters().bytes_recv + psutil.net_io_counters().bytes_sent return { "cpu": cpu_percent, @@ -539,6 +769,8 @@ def env_make(self, env_id: str, sandbox_id: str) -> EnvMakeResponse: """ Make gem env """ + import gem + env = gem.make(env_id) self._gem_envs[sandbox_id] = env return EnvMakeResponse(sandbox_id=sandbox_id)