diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 38773bb9f..ece6300d9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,8 +1,6 @@ name: ci -on: - pull_request: {} - push: { branches: [main] } +on: workflow_dispatch jobs: build-test: diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml index 17d54f214..0048fcf8d 100644 --- a/.github/workflows/cla.yml +++ b/.github/workflows/cla.yml @@ -1,9 +1,5 @@ name: CLA Assistant -on: - issue_comment: - types: [created] - pull_request_target: - types: [opened, closed, synchronize] +on: workflow_dispatch permissions: actions: write diff --git a/.github/workflows/close-stale-contributor-prs.yml b/.github/workflows/close-stale-contributor-prs.yml index e01bc3881..d7fcc1c73 100644 --- a/.github/workflows/close-stale-contributor-prs.yml +++ b/.github/workflows/close-stale-contributor-prs.yml @@ -1,9 +1,6 @@ name: Close stale contributor PRs -on: - workflow_dispatch: - schedule: - - cron: "0 6 * * *" +on: workflow_dispatch permissions: contents: read diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml index c03658132..45c853a6a 100644 --- a/.github/workflows/codespell.yml +++ b/.github/workflows/codespell.yml @@ -2,11 +2,7 @@ --- name: Codespell -on: - push: - branches: [main] - pull_request: - branches: [main] +on: workflow_dispatch permissions: contents: read diff --git a/.github/workflows/issue-deduplicator.yml b/.github/workflows/issue-deduplicator.yml index 579b6a368..be823ff3b 100644 --- a/.github/workflows/issue-deduplicator.yml +++ b/.github/workflows/issue-deduplicator.yml @@ -1,10 +1,6 @@ name: Issue Deduplicator -on: - issues: - types: - - opened - - labeled +on: workflow_dispatch jobs: gather-duplicates: @@ -46,7 +42,6 @@ jobs: with: openai-api-key: ${{ secrets.CODEX_OPENAI_API_KEY }} allow-users: "*" - model: gpt-5.1 prompt: | You are an assistant that triages new GitHub issues by identifying potential duplicates. 
diff --git a/.github/workflows/issue-labeler.yml b/.github/workflows/issue-labeler.yml index 39f9d47f1..5e1a6de25 100644 --- a/.github/workflows/issue-labeler.yml +++ b/.github/workflows/issue-labeler.yml @@ -1,10 +1,6 @@ name: Issue Labeler -on: - issues: - types: - - opened - - labeled +on: workflow_dispatch jobs: gather-labels: diff --git a/.github/workflows/rust-release.yml b/.github/workflows/rust-release.yml index 6f27fbf54..ce54140c2 100644 --- a/.github/workflows/rust-release.yml +++ b/.github/workflows/rust-release.yml @@ -6,10 +6,7 @@ # ``` name: rust-release -on: - push: - tags: - - "rust-v*.*.*" +on: workflow_dispatch concurrency: group: ${{ github.workflow }} @@ -371,8 +368,20 @@ jobs: path: | codex-rs/dist/${{ matrix.target }}/* + shell-tool-mcp: + name: shell-tool-mcp + needs: tag-check + uses: ./.github/workflows/shell-tool-mcp.yml + with: + release-tag: ${{ github.ref_name }} + # We are not ready to publish yet. + publish: false + secrets: inherit + release: - needs: build + needs: + - build + - shell-tool-mcp name: release runs-on: ubuntu-latest permissions: @@ -395,6 +404,14 @@ jobs: - name: List run: ls -R dist/ + # This is a temporary fix: we should modify shell-tool-mcp.yml so these + # files do not end up in dist/ in the first place. 
+ - name: Delete entries from dist/ that should not go in the release + run: | + rm -rf dist/shell-tool-mcp* + + ls -R dist/ + - name: Define release name id: release_name run: | diff --git a/.github/workflows/sdk.yml b/.github/workflows/sdk.yml index 0f3a7a194..39c0c3c07 100644 --- a/.github/workflows/sdk.yml +++ b/.github/workflows/sdk.yml @@ -1,9 +1,6 @@ name: sdk -on: - push: - branches: [main] - pull_request: {} +on: workflow_dispatch jobs: sdks: diff --git a/.github/workflows/shell-tool-mcp-ci.yml b/.github/workflows/shell-tool-mcp-ci.yml new file mode 100644 index 000000000..e088c7088 --- /dev/null +++ b/.github/workflows/shell-tool-mcp-ci.yml @@ -0,0 +1,36 @@ +name: shell-tool-mcp CI + +on: workflow_dispatch + +env: + NODE_VERSION: 22 + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + run_install: false + + - name: Setup Node.js + uses: actions/setup-node@v5 + with: + node-version: ${{ env.NODE_VERSION }} + cache: "pnpm" + + - name: Install dependencies + run: pnpm install --frozen-lockfile + + - name: Format check + run: pnpm --filter @openai/codex-shell-tool-mcp run format + + - name: Run tests + run: pnpm --filter @openai/codex-shell-tool-mcp test + + - name: Build + run: pnpm --filter @openai/codex-shell-tool-mcp run build diff --git a/.github/workflows/shell-tool-mcp.yml b/.github/workflows/shell-tool-mcp.yml new file mode 100644 index 000000000..633a8f943 --- /dev/null +++ b/.github/workflows/shell-tool-mcp.yml @@ -0,0 +1,397 @@ +name: shell-tool-mcp + +on: workflow_dispatch + +env: + NODE_VERSION: 22 + +jobs: + metadata: + runs-on: ubuntu-latest + outputs: + version: ${{ steps.compute.outputs.version }} + release_tag: ${{ steps.compute.outputs.release_tag }} + should_publish: ${{ steps.compute.outputs.should_publish }} + npm_tag: ${{ steps.compute.outputs.npm_tag }} + steps: + - name: Compute version and tags + id: compute + run: | 
+ set -euo pipefail + + version="${{ inputs.release-version }}" + release_tag="${{ inputs.release-tag }}" + + if [[ -z "$version" ]]; then + if [[ -n "$release_tag" && "$release_tag" =~ ^rust-v.+ ]]; then + version="${release_tag#rust-v}" + elif [[ "${GITHUB_REF_NAME:-}" =~ ^rust-v.+ ]]; then + version="${GITHUB_REF_NAME#rust-v}" + release_tag="${GITHUB_REF_NAME}" + else + echo "release-version is required when GITHUB_REF_NAME is not a rust-v tag." + exit 1 + fi + fi + + if [[ -z "$release_tag" ]]; then + release_tag="rust-v${version}" + fi + + npm_tag="" + should_publish="false" + if [[ "$version" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + should_publish="true" + elif [[ "$version" =~ ^[0-9]+\.[0-9]+\.[0-9]+-alpha\.[0-9]+$ ]]; then + should_publish="true" + npm_tag="alpha" + fi + + echo "version=${version}" >> "$GITHUB_OUTPUT" + echo "release_tag=${release_tag}" >> "$GITHUB_OUTPUT" + echo "npm_tag=${npm_tag}" >> "$GITHUB_OUTPUT" + echo "should_publish=${should_publish}" >> "$GITHUB_OUTPUT" + + rust-binaries: + name: Build Rust - ${{ matrix.target }} + needs: metadata + runs-on: ${{ matrix.runner }} + timeout-minutes: 30 + defaults: + run: + working-directory: codex-rs + strategy: + fail-fast: false + matrix: + include: + - runner: macos-15-xlarge + target: aarch64-apple-darwin + - runner: macos-15-xlarge + target: x86_64-apple-darwin + - runner: ubuntu-24.04 + target: x86_64-unknown-linux-musl + install_musl: true + - runner: ubuntu-24.04-arm + target: aarch64-unknown-linux-musl + install_musl: true + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - uses: dtolnay/rust-toolchain@1.90 + with: + targets: ${{ matrix.target }} + + - if: ${{ matrix.install_musl }} + name: Install musl build dependencies + run: | + sudo apt-get update + sudo apt-get install -y musl-tools pkg-config + + - name: Build exec server binaries + run: cargo build --release --target ${{ matrix.target }} --bin codex-exec-mcp-server --bin codex-execve-wrapper + + - name: Stage 
exec server binaries + run: | + dest="${GITHUB_WORKSPACE}/artifacts/vendor/${{ matrix.target }}" + mkdir -p "$dest" + cp "target/${{ matrix.target }}/release/codex-exec-mcp-server" "$dest/" + cp "target/${{ matrix.target }}/release/codex-execve-wrapper" "$dest/" + + - uses: actions/upload-artifact@v4 + with: + name: shell-tool-mcp-rust-${{ matrix.target }} + path: artifacts/** + if-no-files-found: error + + bash-linux: + name: Build Bash (Linux) - ${{ matrix.variant }} - ${{ matrix.target }} + needs: metadata + runs-on: ${{ matrix.runner }} + timeout-minutes: 30 + container: + image: ${{ matrix.image }} + strategy: + fail-fast: false + matrix: + include: + - runner: ubuntu-24.04 + target: x86_64-unknown-linux-musl + variant: ubuntu-24.04 + image: ubuntu:24.04 + - runner: ubuntu-24.04 + target: x86_64-unknown-linux-musl + variant: ubuntu-22.04 + image: ubuntu:22.04 + - runner: ubuntu-24.04 + target: x86_64-unknown-linux-musl + variant: ubuntu-20.04 + image: ubuntu:20.04 + - runner: ubuntu-24.04 + target: x86_64-unknown-linux-musl + variant: debian-12 + image: debian:12 + - runner: ubuntu-24.04 + target: x86_64-unknown-linux-musl + variant: debian-11 + image: debian:11 + - runner: ubuntu-24.04 + target: x86_64-unknown-linux-musl + variant: centos-9 + image: quay.io/centos/centos:stream9 + - runner: ubuntu-24.04-arm + target: aarch64-unknown-linux-musl + variant: ubuntu-24.04 + image: arm64v8/ubuntu:24.04 + - runner: ubuntu-24.04-arm + target: aarch64-unknown-linux-musl + variant: ubuntu-22.04 + image: arm64v8/ubuntu:22.04 + - runner: ubuntu-24.04-arm + target: aarch64-unknown-linux-musl + variant: ubuntu-20.04 + image: arm64v8/ubuntu:20.04 + - runner: ubuntu-24.04-arm + target: aarch64-unknown-linux-musl + variant: debian-12 + image: arm64v8/debian:12 + - runner: ubuntu-24.04-arm + target: aarch64-unknown-linux-musl + variant: debian-11 + image: arm64v8/debian:11 + - runner: ubuntu-24.04-arm + target: aarch64-unknown-linux-musl + variant: centos-9 + image: 
quay.io/centos/centos:stream9 + steps: + - name: Install build prerequisites + shell: bash + run: | + set -euo pipefail + if command -v apt-get >/dev/null 2>&1; then + apt-get update + DEBIAN_FRONTEND=noninteractive apt-get install -y git build-essential bison autoconf gettext + elif command -v dnf >/dev/null 2>&1; then + dnf install -y git gcc gcc-c++ make bison autoconf gettext + elif command -v yum >/dev/null 2>&1; then + yum install -y git gcc gcc-c++ make bison autoconf gettext + else + echo "Unsupported package manager in container" + exit 1 + fi + + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Build patched Bash + shell: bash + run: | + set -euo pipefail + git clone --depth 1 https://github.com/bminor/bash /tmp/bash + cd /tmp/bash + git fetch --depth 1 origin a8a1c2fac029404d3f42cd39f5a20f24b6e4fe4b + git checkout a8a1c2fac029404d3f42cd39f5a20f24b6e4fe4b + git apply "${GITHUB_WORKSPACE}/shell-tool-mcp/patches/bash-exec-wrapper.patch" + ./configure --without-bash-malloc + cores="$(command -v nproc >/dev/null 2>&1 && nproc || getconf _NPROCESSORS_ONLN)" + make -j"${cores}" + + dest="${GITHUB_WORKSPACE}/artifacts/vendor/${{ matrix.target }}/bash/${{ matrix.variant }}" + mkdir -p "$dest" + cp bash "$dest/bash" + + - uses: actions/upload-artifact@v4 + with: + name: shell-tool-mcp-bash-${{ matrix.target }}-${{ matrix.variant }} + path: artifacts/** + if-no-files-found: error + + bash-darwin: + name: Build Bash (macOS) - ${{ matrix.variant }} - ${{ matrix.target }} + needs: metadata + runs-on: ${{ matrix.runner }} + timeout-minutes: 30 + strategy: + fail-fast: false + matrix: + include: + - runner: macos-15-xlarge + target: aarch64-apple-darwin + variant: macos-15 + - runner: macos-14 + target: aarch64-apple-darwin + variant: macos-14 + - runner: macos-13 + target: x86_64-apple-darwin + variant: macos-13 + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Build patched Bash + shell: bash + run: | + set -euo pipefail 
+ git clone --depth 1 https://github.com/bminor/bash /tmp/bash + cd /tmp/bash + git fetch --depth 1 origin a8a1c2fac029404d3f42cd39f5a20f24b6e4fe4b + git checkout a8a1c2fac029404d3f42cd39f5a20f24b6e4fe4b + git apply "${GITHUB_WORKSPACE}/shell-tool-mcp/patches/bash-exec-wrapper.patch" + ./configure --without-bash-malloc + cores="$(getconf _NPROCESSORS_ONLN)" + make -j"${cores}" + + dest="${GITHUB_WORKSPACE}/artifacts/vendor/${{ matrix.target }}/bash/${{ matrix.variant }}" + mkdir -p "$dest" + cp bash "$dest/bash" + + - uses: actions/upload-artifact@v4 + with: + name: shell-tool-mcp-bash-${{ matrix.target }}-${{ matrix.variant }} + path: artifacts/** + if-no-files-found: error + + package: + name: Package npm module + needs: + - metadata + - rust-binaries + - bash-linux + - bash-darwin + runs-on: ubuntu-latest + env: + PACKAGE_VERSION: ${{ needs.metadata.outputs.version }} + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 10.8.1 + run_install: false + + - name: Setup Node.js + uses: actions/setup-node@v5 + with: + node-version: ${{ env.NODE_VERSION }} + + - name: Install JavaScript dependencies + run: pnpm install --frozen-lockfile + + - name: Build (shell-tool-mcp) + run: pnpm --filter @openai/codex-shell-tool-mcp run build + + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts + + - name: Assemble staging directory + id: staging + shell: bash + run: | + set -euo pipefail + staging="${STAGING_DIR}" + mkdir -p "$staging" "$staging/vendor" + cp shell-tool-mcp/README.md "$staging/" + cp shell-tool-mcp/package.json "$staging/" + cp -R shell-tool-mcp/bin "$staging/" + + found_vendor="false" + shopt -s nullglob + for vendor_dir in artifacts/*/vendor; do + rsync -av "$vendor_dir/" "$staging/vendor/" + found_vendor="true" + done + if [[ "$found_vendor" == "false" ]]; then + echo "No vendor payloads were downloaded." 
+ exit 1 + fi + + node - <<'NODE' + import fs from "node:fs"; + import path from "node:path"; + + const stagingDir = process.env.STAGING_DIR; + const version = process.env.PACKAGE_VERSION; + const pkgPath = path.join(stagingDir, "package.json"); + const pkg = JSON.parse(fs.readFileSync(pkgPath, "utf8")); + pkg.version = version; + fs.writeFileSync(pkgPath, JSON.stringify(pkg, null, 2) + "\n"); + NODE + + echo "dir=$staging" >> "$GITHUB_OUTPUT" + env: + STAGING_DIR: ${{ runner.temp }}/shell-tool-mcp + + - name: Ensure binaries are executable + run: | + set -euo pipefail + staging="${{ steps.staging.outputs.dir }}" + chmod +x \ + "$staging"/vendor/*/codex-exec-mcp-server \ + "$staging"/vendor/*/codex-execve-wrapper \ + "$staging"/vendor/*/bash/*/bash + + - name: Create npm tarball + shell: bash + run: | + set -euo pipefail + mkdir -p dist/npm + staging="${{ steps.staging.outputs.dir }}" + pack_info=$(cd "$staging" && npm pack --ignore-scripts --json --pack-destination "${GITHUB_WORKSPACE}/dist/npm") + filename=$(PACK_INFO="$pack_info" node -e 'const data = JSON.parse(process.env.PACK_INFO); console.log(data[0].filename);') + mv "dist/npm/${filename}" "dist/npm/codex-shell-tool-mcp-npm-${PACKAGE_VERSION}.tgz" + + - uses: actions/upload-artifact@v4 + with: + name: codex-shell-tool-mcp-npm + path: dist/npm/codex-shell-tool-mcp-npm-${{ env.PACKAGE_VERSION }}.tgz + if-no-files-found: error + + publish: + name: Publish npm package + needs: + - metadata + - package + if: ${{ inputs.publish && needs.metadata.outputs.should_publish == 'true' }} + runs-on: ubuntu-latest + permissions: + id-token: write + contents: read + steps: + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 10.8.1 + run_install: false + + - name: Setup Node.js + uses: actions/setup-node@v5 + with: + node-version: ${{ env.NODE_VERSION }} + registry-url: https://registry.npmjs.org + scope: "@openai" + + - name: Update npm + run: npm install -g npm@latest + + - name: Download npm tarball + 
uses: actions/download-artifact@v4 + with: + name: codex-shell-tool-mcp-npm + path: dist/npm + + - name: Publish to npm + env: + NPM_TAG: ${{ needs.metadata.outputs.npm_tag }} + VERSION: ${{ needs.metadata.outputs.version }} + shell: bash + run: | + set -euo pipefail + tag_args=() + if [[ -n "${NPM_TAG}" ]]; then + tag_args+=(--tag "${NPM_TAG}") + fi + npm publish "dist/npm/codex-shell-tool-mcp-npm-${VERSION}.tgz" "${tag_args[@]}" diff --git a/README.md b/README.md index 814161003..b90e6d6d7 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,38 @@ Codex can access MCP servers. To configure them, refer to the [config docs](./do Codex CLI supports a rich set of configuration options, with preferences stored in `~/.codex/config.toml`. For full configuration options, see [Configuration](./docs/config.md). ---- +### Execpolicy Quickstart + +Codex can enforce your own rules-based execution policy before it runs shell commands. + +1. Create a policy directory: `mkdir -p ~/.codex/policy`. +2. Create one or more `.codexpolicy` files in that folder. Codex automatically loads every `.codexpolicy` file in there on startup. +3. Write `prefix_rule` entries to describe the commands you want to allow, prompt, or block: + +```starlark +prefix_rule( + pattern = ["git", ["push", "fetch"]], + decision = "prompt", # allow | prompt | forbidden + match = [["git", "push", "origin", "main"]], # examples that must match + not_match = [["git", "status"]], # examples that must not match +) +``` + +- `pattern` is a list of shell tokens, evaluated from left to right; wrap tokens in a nested list to express alternatives (e.g., match both `push` and `fetch`). +- `decision` sets the severity; Codex picks the strictest decision when multiple rules match (forbidden > prompt > allow). +- `match` and `not_match` act as (optional) unit tests. Codex validates them when it loads your policy, so you get feedback if an example has unexpected behavior. 
+ +In this example rule, if Codex wants to run commands with the prefix `git push` or `git fetch`, it will first ask for user approval. + +Use the `codex execpolicy check` subcommand to preview decisions before you save a rule (see the [`codex-execpolicy` README](./codex-rs/execpolicy/README.md) for syntax details): + +```shell +codex execpolicy check --policy ~/.codex/policy/default.codexpolicy git push origin main +``` + +Pass multiple `--policy` flags to test how several files combine, and use `--pretty` for formatted JSON output. See the [`codex-rs/execpolicy` README](./codex-rs/execpolicy/README.md) for a more detailed walkthrough of the available syntax. + +## Note: `execpolicy` commands are still in preview. The API may have breaking changes in the future. ### Docs & FAQ diff --git a/codex-rs/.config/nextest.toml b/codex-rs/.config/nextest.toml index 3ca7cfe50..f432af88e 100644 --- a/codex-rs/.config/nextest.toml +++ b/codex-rs/.config/nextest.toml @@ -7,3 +7,7 @@ slow-timeout = { period = "15s", terminate-after = 2 } # Do not add new tests here filter = 'test(rmcp_client) | test(humanlike_typing_1000_chars_appears_live_no_placeholder)' slow-timeout = { period = "1m", terminate-after = 4 } + +[[profile.default.overrides]] +filter = 'test(approval_matrix_covers_all_modes)' +slow-timeout = { period = "30s", terminate-after = 2 } diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 0ed45ddb2..c891e76fa 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -187,8 +187,10 @@ dependencies = [ "codex-app-server-protocol", "codex-core", "codex-protocol", + "core_test_support", "serde", "serde_json", + "shlex", "tokio", "uuid", "wiremock", @@ -260,7 +262,7 @@ dependencies = [ "memchr", "proc-macro2", "quote", - "rustc-hash 2.1.1", + "rustc-hash", "serde", "serde_derive", "syn 2.0.104", @@ -726,6 +728,17 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chardetng" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14b8f0b65b7b08ae3c8187e8d77174de20cb6777864c6b832d8ad365999cf1ea" +dependencies = [ + "cfg-if", + "encoding_rs", + "memchr", +] + [[package]] name = "chrono" version = "0.4.42" @@ -857,6 +870,7 @@ dependencies = [ "serde", "serde_json", "serial_test", + "shlex", "tempfile", "tokio", "toml", @@ -879,6 +893,7 @@ dependencies = [ "serde", "serde_json", "strum_macros 0.27.2", + "thiserror 2.0.17", "ts-rs", "uuid", ] @@ -989,6 +1004,7 @@ dependencies = [ "codex-common", "codex-core", "codex-exec", + "codex-execpolicy", "codex-login", "codex-mcp-server", "codex-process-hardening", @@ -1080,11 +1096,13 @@ dependencies = [ "async-trait", "base64", "bytes", + "chardetng", "chrono", "codex-app-server-protocol", "codex-apply-patch", "codex-arg0", "codex-async-utils", + "codex-execpolicy", "codex-file-search", "codex-git", "codex-keyring-store", @@ -1094,13 +1112,13 @@ dependencies = [ "codex-utils-pty", "codex-utils-readiness", "codex-utils-string", - "codex-utils-tokenizer", "codex-windows-sandbox", "core-foundation 0.9.4", "core_test_support", "ctor 0.5.0", "dirs", "dunce", + "encoding_rs", "env-flags", "escargot", "eventsource-stream", @@ -1183,27 +1201,30 @@ dependencies = [ ] [[package]] -name = "codex-execpolicy" +name = "codex-exec-server" version = "0.0.0" dependencies = [ - "allocative", "anyhow", + "async-trait", "clap", - "derive_more 2.0.1", - "env_logger", - "log", - "multimap", + "codex-core", + "libc", "path-absolutize", - "regex-lite", + "pretty_assertions", + "rmcp", "serde", "serde_json", - "serde_with", - "starlark", + "shlex", + "socket2 0.6.0", "tempfile", + "tokio", + "tokio-util", + "tracing", + "tracing-subscriber", ] [[package]] -name = "codex-execpolicy2" +name = "codex-execpolicy" version = "0.0.0" dependencies = [ "anyhow", @@ -1217,6 +1238,26 @@ 
dependencies = [ "thiserror 2.0.17", ] +[[package]] +name = "codex-execpolicy-legacy" +version = "0.0.0" +dependencies = [ + "allocative", + "anyhow", + "clap", + "derive_more 2.0.1", + "env_logger", + "log", + "multimap", + "path-absolutize", + "regex-lite", + "serde", + "serde_json", + "serde_with", + "starlark", + "tempfile", +] + [[package]] name = "codex-feedback" version = "0.0.0" @@ -1366,6 +1407,7 @@ dependencies = [ "codex-app-server-protocol", "codex-protocol", "eventsource-stream", + "http", "opentelemetry", "opentelemetry-otlp", "opentelemetry-semantic-conventions", @@ -1399,6 +1441,7 @@ dependencies = [ "icu_provider", "mcp-types", "mime_guess", + "pretty_assertions", "schemars 0.8.22", "serde", "serde_json", @@ -1589,23 +1632,12 @@ dependencies = [ name = "codex-utils-string" version = "0.0.0" -[[package]] -name = "codex-utils-tokenizer" -version = "0.0.0" -dependencies = [ - "anyhow", - "codex-utils-cache", - "pretty_assertions", - "thiserror 2.0.17", - "tiktoken-rs", - "tokio", -] - [[package]] name = "codex-windows-sandbox" version = "0.1.0" dependencies = [ "anyhow", + "codex-protocol", "dirs-next", "dunce", "rand 0.8.5", @@ -1747,6 +1779,7 @@ dependencies = [ "notify", "regex-lite", "serde_json", + "shlex", "tempfile", "tokio", "walkdir", @@ -2421,17 +2454,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "fancy-regex" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" -dependencies = [ - "bit-set", - "regex-automata", - "regex-syntax 0.8.5", -] - [[package]] name = "fastrand" version = "2.3.0" @@ -3719,11 +3741,13 @@ dependencies = [ "assert_cmd", "codex-core", "codex-mcp-server", + "core_test_support", "mcp-types", "os_info", "pretty_assertions", "serde", "serde_json", + "shlex", "tokio", "wiremock", ] @@ -4756,7 +4780,7 @@ dependencies = [ "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.1.1", + "rustc-hash", 
"rustls", "socket2 0.6.0", "thiserror 2.0.17", @@ -4776,7 +4800,7 @@ dependencies = [ "lru-slab", "rand 0.9.2", "ring", - "rustc-hash 2.1.1", + "rustc-hash", "rustls", "rustls-pki-types", "slab", @@ -5121,12 +5145,6 @@ version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f" -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.1.1" @@ -5174,6 +5192,7 @@ version = "0.23.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", @@ -6346,21 +6365,6 @@ dependencies = [ "zune-jpeg", ] -[[package]] -name = "tiktoken-rs" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a19830747d9034cd9da43a60eaa8e552dfda7712424aebf187b7a60126bae0d" -dependencies = [ - "anyhow", - "base64", - "bstr", - "fancy-regex", - "lazy_static", - "regex", - "rustc-hash 1.1.0", -] - [[package]] name = "time" version = "0.3.44" @@ -6601,8 +6605,10 @@ dependencies = [ "percent-encoding", "pin-project", "prost", + "rustls-native-certs", "socket2 0.5.10", "tokio", + "tokio-rustls", "tokio-stream", "tower", "tower-layer", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index b19bf7660..69f7b2d07 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -16,8 +16,9 @@ members = [ "common", "core", "exec", + "exec-server", "execpolicy", - "execpolicy2", + "execpolicy-legacy", "keyring-store", "file-search", "linux-sandbox", @@ -40,12 +41,11 @@ members = [ "utils/pty", "utils/readiness", "utils/string", - "utils/tokenizer", ] resolver = "2" [workspace.package] -version = "0.0.0" +version = "0.63.0" # 
Track the edition for all workspace crates in one place. Individual # crates can still override this value, but keeping it here means new # crates created with `cargo new -w ...` automatically inherit the 2024 @@ -66,6 +66,7 @@ codex-chatgpt = { path = "chatgpt" } codex-common = { path = "common" } codex-core = { path = "core" } codex-exec = { path = "exec" } +codex-execpolicy = { path = "execpolicy" } codex-feedback = { path = "feedback" } codex-file-search = { path = "file-search" } codex-git = { path = "utils/git" } @@ -88,7 +89,6 @@ codex-utils-json-to-toml = { path = "utils/json-to-toml" } codex-utils-pty = { path = "utils/pty" } codex-utils-readiness = { path = "utils/readiness" } codex-utils-string = { path = "utils/string" } -codex-utils-tokenizer = { path = "utils/tokenizer" } codex-windows-sandbox = { path = "windows-sandbox-rs" } core_test_support = { path = "core/tests/common" } mcp-types = { path = "mcp-types" } @@ -109,6 +109,7 @@ axum = { version = "0.8", default-features = false } base64 = "0.22.1" bytes = "1.10.1" chrono = "0.4.42" +chardetng = "0.1.17" clap = "4" clap_complete = "4" color-eyre = "0.6.3" @@ -121,6 +122,7 @@ dotenvy = "0.15.7" dunce = "1.0.4" env-flags = "0.1.1" env_logger = "0.11.5" +encoding_rs = "0.8.35" escargot = "0.5" eventsource-stream = "0.2.3" futures = { version = "0.3", default-features = false } @@ -167,7 +169,6 @@ reqwest = "0.12" rmcp = { version = "0.8.5", default-features = false } schemars = "0.8.22" seccompiler = "0.5.0" -sentry = "0.34.0" serde = "1" serde_json = "1" serde_with = "3.14" @@ -176,6 +177,7 @@ sha1 = "0.10.6" sha2 = "0.10" shlex = "1.3.0" similar = "2.7.0" +socket2 = "0.6.0" starlark = "0.13.0" strum = "0.27.2" strum_macros = "0.27.2" @@ -185,7 +187,6 @@ tempfile = "3.23.0" test-log = "0.2.18" textwrap = "0.16.2" thiserror = "2.0.17" -tiktoken-rs = "0.9" time = "0.3" tiny_http = "0.12" tokio = "1" @@ -263,7 +264,6 @@ ignored = [ "icu_provider", "openssl-sys", "codex-utils-readiness", - 
"codex-utils-tokenizer", ] [profile.release] diff --git a/codex-rs/app-server-protocol/Cargo.toml b/codex-rs/app-server-protocol/Cargo.toml index 4d1afadaa..47753b344 100644 --- a/codex-rs/app-server-protocol/Cargo.toml +++ b/codex-rs/app-server-protocol/Cargo.toml @@ -19,6 +19,7 @@ schemars = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } strum_macros = { workspace = true } +thiserror = { workspace = true } ts-rs = { workspace = true } uuid = { workspace = true, features = ["serde", "v7"] } diff --git a/codex-rs/app-server-protocol/src/export.rs b/codex-rs/app-server-protocol/src/export.rs index 11296e8e5..4de66dfb3 100644 --- a/codex-rs/app-server-protocol/src/export.rs +++ b/codex-rs/app-server-protocol/src/export.rs @@ -61,7 +61,32 @@ pub fn generate_types(out_dir: &Path, prettier: Option<&Path>) -> Result<()> { Ok(()) } +#[derive(Clone, Copy, Debug)] +pub struct GenerateTsOptions { + pub generate_indices: bool, + pub ensure_headers: bool, + pub run_prettier: bool, +} + +impl Default for GenerateTsOptions { + fn default() -> Self { + Self { + generate_indices: true, + ensure_headers: true, + run_prettier: true, + } + } +} + pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> { + generate_ts_with_options(out_dir, prettier, GenerateTsOptions::default()) +} + +pub fn generate_ts_with_options( + out_dir: &Path, + prettier: Option<&Path>, + options: GenerateTsOptions, +) -> Result<()> { let v2_out_dir = out_dir.join("v2"); ensure_dir(out_dir)?; ensure_dir(&v2_out_dir)?; @@ -74,17 +99,28 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> { export_server_responses(out_dir)?; ServerNotification::export_all_to(out_dir)?; - generate_index_ts(out_dir)?; - generate_index_ts(&v2_out_dir)?; + if options.generate_indices { + generate_index_ts(out_dir)?; + generate_index_ts(&v2_out_dir)?; + } // Ensure our header is present on all TS files (root + subdirs like v2/). 
- let ts_files = ts_files_in_recursive(out_dir)?; - for file in &ts_files { - prepend_header_if_missing(file)?; + let mut ts_files = Vec::new(); + let should_collect_ts_files = + options.ensure_headers || (options.run_prettier && prettier.is_some()); + if should_collect_ts_files { + ts_files = ts_files_in_recursive(out_dir)?; + } + + if options.ensure_headers { + for file in &ts_files { + prepend_header_if_missing(file)?; + } } // Optionally run Prettier on all generated TS files. - if let Some(prettier_bin) = prettier + if options.run_prettier + && let Some(prettier_bin) = prettier && !ts_files.is_empty() { let status = Command::new(prettier_bin) @@ -723,7 +759,13 @@ mod tests { let _guard = TempDirGuard(output_dir.clone()); - generate_ts(&output_dir, None)?; + // Avoid doing more work than necessary to keep the test from timing out. + let options = GenerateTsOptions { + generate_indices: false, + ensure_headers: false, + run_prettier: false, + }; + generate_ts_with_options(&output_dir, None, options)?; let mut undefined_offenders = Vec::new(); let mut optional_nullable_offenders = BTreeSet::new(); diff --git a/codex-rs/app-server-protocol/src/lib.rs b/codex-rs/app-server-protocol/src/lib.rs index 9c02ea924..06102083f 100644 --- a/codex-rs/app-server-protocol/src/lib.rs +++ b/codex-rs/app-server-protocol/src/lib.rs @@ -7,5 +7,6 @@ pub use export::generate_ts; pub use export::generate_types; pub use jsonrpc_lite::*; pub use protocol::common::*; +pub use protocol::thread_history::*; pub use protocol::v1::*; pub use protocol::v2::*; diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index db9bed111..fecdc5b71 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -129,6 +129,10 @@ client_request_definitions! 
{ params: v2::TurnInterruptParams, response: v2::TurnInterruptResponse, }, + ReviewStart => "review/start" { + params: v2::ReviewStartParams, + response: v2::TurnStartResponse, + }, ModelList => "model/list" { params: v2::ModelListParams, @@ -374,7 +378,7 @@ macro_rules! server_notification_definitions { impl TryFrom for ServerNotification { type Error = serde_json::Error; - fn try_from(value: JSONRPCNotification) -> Result { + fn try_from(value: JSONRPCNotification) -> Result { serde_json::from_value(serde_json::to_value(value)?) } } @@ -434,6 +438,13 @@ server_request_definitions! { response: v2::CommandExecutionRequestApprovalResponse, }, + /// Sent when approval is requested for a specific file change. + /// This request is used for Turns started via turn/start. + FileChangeRequestApproval => "item/fileChange/requestApproval" { + params: v2::FileChangeRequestApprovalParams, + response: v2::FileChangeRequestApprovalResponse, + }, + /// DEPRECATED APIs below /// Request to approve a patch. /// This request is used for Turns started via the legacy APIs (i.e. SendUserTurn, SendUserMessage). @@ -476,6 +487,7 @@ pub struct FuzzyFileSearchResponse { server_notification_definitions! { /// NEW NOTIFICATIONS + Error => "error" (v2::ErrorNotification), ThreadStarted => "thread/started" (v2::ThreadStartedNotification), TurnStarted => "turn/started" (v2::TurnStartedNotification), TurnCompleted => "turn/completed" (v2::TurnCompletedNotification), @@ -490,6 +502,9 @@ server_notification_definitions! { ReasoningSummaryPartAdded => "item/reasoning/summaryPartAdded" (v2::ReasoningSummaryPartAddedNotification), ReasoningTextDelta => "item/reasoning/textDelta" (v2::ReasoningTextDeltaNotification), + /// Notifies the user of world-writable directories on Windows, which cannot be protected by the sandbox. 
+ WindowsWorldWritableWarning => "windows/worldWritableWarning" (v2::WindowsWorldWritableWarningNotification), + #[serde(rename = "account/login/completed")] #[ts(rename = "account/login/completed")] #[strum(serialize = "account/login/completed")] @@ -524,7 +539,7 @@ mod tests { let request = ClientRequest::NewConversation { request_id: RequestId::Integer(42), params: v1::NewConversationParams { - model: Some("gpt-5.1-codex".to_string()), + model: Some("gpt-5.1-codex-max".to_string()), model_provider: None, profile: None, cwd: None, @@ -542,7 +557,7 @@ mod tests { "method": "newConversation", "id": 42, "params": { - "model": "gpt-5.1-codex", + "model": "gpt-5.1-codex-max", "modelProvider": null, "profile": null, "cwd": null, diff --git a/codex-rs/app-server-protocol/src/protocol/mod.rs b/codex-rs/app-server-protocol/src/protocol/mod.rs index 11edf04cc..8e2d63e06 100644 --- a/codex-rs/app-server-protocol/src/protocol/mod.rs +++ b/codex-rs/app-server-protocol/src/protocol/mod.rs @@ -2,5 +2,6 @@ // Exposes protocol pieces used by `lib.rs` via `pub use protocol::common::*;`. pub mod common; +pub mod thread_history; pub mod v1; pub mod v2; diff --git a/codex-rs/app-server-protocol/src/protocol/thread_history.rs b/codex-rs/app-server-protocol/src/protocol/thread_history.rs new file mode 100644 index 000000000..04cd1190b --- /dev/null +++ b/codex-rs/app-server-protocol/src/protocol/thread_history.rs @@ -0,0 +1,409 @@ +use crate::protocol::v2::ThreadItem; +use crate::protocol::v2::Turn; +use crate::protocol::v2::TurnStatus; +use crate::protocol::v2::UserInput; +use codex_protocol::protocol::AgentReasoningEvent; +use codex_protocol::protocol::AgentReasoningRawContentEvent; +use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::TurnAbortedEvent; +use codex_protocol::protocol::UserMessageEvent; + +/// Convert persisted [`EventMsg`] entries into a sequence of [`Turn`] values. 
+/// +/// The purpose of this is to convert the EventMsgs persisted in a rollout file +/// into a sequence of Turns and ThreadItems, which allows the client to render +/// the historical messages when resuming a thread. +pub fn build_turns_from_event_msgs(events: &[EventMsg]) -> Vec { + let mut builder = ThreadHistoryBuilder::new(); + for event in events { + builder.handle_event(event); + } + builder.finish() +} + +struct ThreadHistoryBuilder { + turns: Vec, + current_turn: Option, + next_turn_index: i64, + next_item_index: i64, +} + +impl ThreadHistoryBuilder { + fn new() -> Self { + Self { + turns: Vec::new(), + current_turn: None, + next_turn_index: 1, + next_item_index: 1, + } + } + + fn finish(mut self) -> Vec { + self.finish_current_turn(); + self.turns + } + + /// This function should handle all EventMsg variants that can be persisted in a rollout file. + /// See `should_persist_event_msg` in `codex-rs/core/rollout/policy.rs`. + fn handle_event(&mut self, event: &EventMsg) { + match event { + EventMsg::UserMessage(payload) => self.handle_user_message(payload), + EventMsg::AgentMessage(payload) => self.handle_agent_message(payload.message.clone()), + EventMsg::AgentReasoning(payload) => self.handle_agent_reasoning(payload), + EventMsg::AgentReasoningRawContent(payload) => { + self.handle_agent_reasoning_raw_content(payload) + } + EventMsg::TokenCount(_) => {} + EventMsg::EnteredReviewMode(_) => {} + EventMsg::ExitedReviewMode(_) => {} + EventMsg::UndoCompleted(_) => {} + EventMsg::TurnAborted(payload) => self.handle_turn_aborted(payload), + _ => {} + } + } + + fn handle_user_message(&mut self, payload: &UserMessageEvent) { + self.finish_current_turn(); + let mut turn = self.new_turn(); + let id = self.next_item_id(); + let content = self.build_user_inputs(payload); + turn.items.push(ThreadItem::UserMessage { id, content }); + self.current_turn = Some(turn); + } + + fn handle_agent_message(&mut self, text: String) { + if text.is_empty() { + return; + } + + let 
id = self.next_item_id(); + self.ensure_turn() + .items + .push(ThreadItem::AgentMessage { id, text }); + } + + fn handle_agent_reasoning(&mut self, payload: &AgentReasoningEvent) { + if payload.text.is_empty() { + return; + } + + // If the last item is a reasoning item, add the new text to the summary. + if let Some(ThreadItem::Reasoning { summary, .. }) = self.ensure_turn().items.last_mut() { + summary.push(payload.text.clone()); + return; + } + + // Otherwise, create a new reasoning item. + let id = self.next_item_id(); + self.ensure_turn().items.push(ThreadItem::Reasoning { + id, + summary: vec![payload.text.clone()], + content: Vec::new(), + }); + } + + fn handle_agent_reasoning_raw_content(&mut self, payload: &AgentReasoningRawContentEvent) { + if payload.text.is_empty() { + return; + } + + // If the last item is a reasoning item, add the new text to the content. + if let Some(ThreadItem::Reasoning { content, .. }) = self.ensure_turn().items.last_mut() { + content.push(payload.text.clone()); + return; + } + + // Otherwise, create a new reasoning item. 
+ let id = self.next_item_id(); + self.ensure_turn().items.push(ThreadItem::Reasoning { + id, + summary: Vec::new(), + content: vec![payload.text.clone()], + }); + } + + fn handle_turn_aborted(&mut self, _payload: &TurnAbortedEvent) { + let Some(turn) = self.current_turn.as_mut() else { + return; + }; + turn.status = TurnStatus::Interrupted; + } + + fn finish_current_turn(&mut self) { + if let Some(turn) = self.current_turn.take() { + if turn.items.is_empty() { + return; + } + self.turns.push(turn.into()); + } + } + + fn new_turn(&mut self) -> PendingTurn { + PendingTurn { + id: self.next_turn_id(), + items: Vec::new(), + status: TurnStatus::Completed, + } + } + + fn ensure_turn(&mut self) -> &mut PendingTurn { + if self.current_turn.is_none() { + let turn = self.new_turn(); + return self.current_turn.insert(turn); + } + + if let Some(turn) = self.current_turn.as_mut() { + return turn; + } + + unreachable!("current turn must exist after initialization"); + } + + fn next_turn_id(&mut self) -> String { + let id = format!("turn-{}", self.next_turn_index); + self.next_turn_index += 1; + id + } + + fn next_item_id(&mut self) -> String { + let id = format!("item-{}", self.next_item_index); + self.next_item_index += 1; + id + } + + fn build_user_inputs(&self, payload: &UserMessageEvent) -> Vec { + let mut content = Vec::new(); + if !payload.message.trim().is_empty() { + content.push(UserInput::Text { + text: payload.message.clone(), + }); + } + if let Some(images) = &payload.images { + for image in images { + content.push(UserInput::Image { url: image.clone() }); + } + } + content + } +} + +struct PendingTurn { + id: String, + items: Vec, + status: TurnStatus, +} + +impl From for Turn { + fn from(value: PendingTurn) -> Self { + Self { + id: value.id, + items: value.items, + status: value.status, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use codex_protocol::protocol::AgentMessageEvent; + use codex_protocol::protocol::AgentReasoningEvent; + use 
codex_protocol::protocol::AgentReasoningRawContentEvent; + use codex_protocol::protocol::TurnAbortReason; + use codex_protocol::protocol::TurnAbortedEvent; + use codex_protocol::protocol::UserMessageEvent; + use pretty_assertions::assert_eq; + + #[test] + fn builds_multiple_turns_with_reasoning_items() { + let events = vec![ + EventMsg::UserMessage(UserMessageEvent { + message: "First turn".into(), + images: Some(vec!["https://example.com/one.png".into()]), + }), + EventMsg::AgentMessage(AgentMessageEvent { + message: "Hi there".into(), + }), + EventMsg::AgentReasoning(AgentReasoningEvent { + text: "thinking".into(), + }), + EventMsg::AgentReasoningRawContent(AgentReasoningRawContentEvent { + text: "full reasoning".into(), + }), + EventMsg::UserMessage(UserMessageEvent { + message: "Second turn".into(), + images: None, + }), + EventMsg::AgentMessage(AgentMessageEvent { + message: "Reply two".into(), + }), + ]; + + let turns = build_turns_from_event_msgs(&events); + assert_eq!(turns.len(), 2); + + let first = &turns[0]; + assert_eq!(first.id, "turn-1"); + assert_eq!(first.status, TurnStatus::Completed); + assert_eq!(first.items.len(), 3); + assert_eq!( + first.items[0], + ThreadItem::UserMessage { + id: "item-1".into(), + content: vec![ + UserInput::Text { + text: "First turn".into(), + }, + UserInput::Image { + url: "https://example.com/one.png".into(), + } + ], + } + ); + assert_eq!( + first.items[1], + ThreadItem::AgentMessage { + id: "item-2".into(), + text: "Hi there".into(), + } + ); + assert_eq!( + first.items[2], + ThreadItem::Reasoning { + id: "item-3".into(), + summary: vec!["thinking".into()], + content: vec!["full reasoning".into()], + } + ); + + let second = &turns[1]; + assert_eq!(second.id, "turn-2"); + assert_eq!(second.items.len(), 2); + assert_eq!( + second.items[0], + ThreadItem::UserMessage { + id: "item-4".into(), + content: vec![UserInput::Text { + text: "Second turn".into() + }], + } + ); + assert_eq!( + second.items[1], + 
ThreadItem::AgentMessage { + id: "item-5".into(), + text: "Reply two".into(), + } + ); + } + + #[test] + fn splits_reasoning_when_interleaved() { + let events = vec![ + EventMsg::UserMessage(UserMessageEvent { + message: "Turn start".into(), + images: None, + }), + EventMsg::AgentReasoning(AgentReasoningEvent { + text: "first summary".into(), + }), + EventMsg::AgentReasoningRawContent(AgentReasoningRawContentEvent { + text: "first content".into(), + }), + EventMsg::AgentMessage(AgentMessageEvent { + message: "interlude".into(), + }), + EventMsg::AgentReasoning(AgentReasoningEvent { + text: "second summary".into(), + }), + ]; + + let turns = build_turns_from_event_msgs(&events); + assert_eq!(turns.len(), 1); + let turn = &turns[0]; + assert_eq!(turn.items.len(), 4); + + assert_eq!( + turn.items[1], + ThreadItem::Reasoning { + id: "item-2".into(), + summary: vec!["first summary".into()], + content: vec!["first content".into()], + } + ); + assert_eq!( + turn.items[3], + ThreadItem::Reasoning { + id: "item-4".into(), + summary: vec!["second summary".into()], + content: Vec::new(), + } + ); + } + + #[test] + fn marks_turn_as_interrupted_when_aborted() { + let events = vec![ + EventMsg::UserMessage(UserMessageEvent { + message: "Please do the thing".into(), + images: None, + }), + EventMsg::AgentMessage(AgentMessageEvent { + message: "Working...".into(), + }), + EventMsg::TurnAborted(TurnAbortedEvent { + reason: TurnAbortReason::Replaced, + }), + EventMsg::UserMessage(UserMessageEvent { + message: "Let's try again".into(), + images: None, + }), + EventMsg::AgentMessage(AgentMessageEvent { + message: "Second attempt complete.".into(), + }), + ]; + + let turns = build_turns_from_event_msgs(&events); + assert_eq!(turns.len(), 2); + + let first_turn = &turns[0]; + assert_eq!(first_turn.status, TurnStatus::Interrupted); + assert_eq!(first_turn.items.len(), 2); + assert_eq!( + first_turn.items[0], + ThreadItem::UserMessage { + id: "item-1".into(), + content: 
vec![UserInput::Text { + text: "Please do the thing".into() + }], + } + ); + assert_eq!( + first_turn.items[1], + ThreadItem::AgentMessage { + id: "item-2".into(), + text: "Working...".into(), + } + ); + + let second_turn = &turns[1]; + assert_eq!(second_turn.status, TurnStatus::Completed); + assert_eq!(second_turn.items.len(), 2); + assert_eq!( + second_turn.items[0], + ThreadItem::UserMessage { + id: "item-3".into(), + content: vec![UserInput::Text { + text: "Let's try again".into() + }], + } + ); + assert_eq!( + second_turn.items[1], + ThreadItem::AgentMessage { + id: "item-4".into(), + text: "Second attempt complete.".into(), + } + ); + } +} diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index a2b9cee3f..46d248e22 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -11,6 +11,8 @@ use codex_protocol::items::AgentMessageContent as CoreAgentMessageContent; use codex_protocol::items::TurnItem as CoreTurnItem; use codex_protocol::models::ResponseItem; use codex_protocol::parse_command::ParsedCommand as CoreParsedCommand; +use codex_protocol::protocol::CodexErrorInfo as CoreCodexErrorInfo; +use codex_protocol::protocol::CreditsSnapshot as CoreCreditsSnapshot; use codex_protocol::protocol::RateLimitSnapshot as CoreRateLimitSnapshot; use codex_protocol::protocol::RateLimitWindow as CoreRateLimitWindow; use codex_protocol::user_input::UserInput as CoreUserInput; @@ -19,6 +21,7 @@ use schemars::JsonSchema; use serde::Deserialize; use serde::Serialize; use serde_json::Value as JsonValue; +use thiserror::Error; use ts_rs::TS; // Macro to declare a camelCased API v2 enum mirroring a core enum which @@ -46,6 +49,72 @@ macro_rules! v2_enum_from_core { }; } +/// This translation layer makes sure that we expose codex error codes in camel case.
+/// +/// When an upstream HTTP status is available (for example, from the Responses API or a provider), +/// it is forwarded in `httpStatusCode` on the relevant `codexErrorInfo` variant. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub enum CodexErrorInfo { + ContextWindowExceeded, + UsageLimitExceeded, + HttpConnectionFailed { + #[serde(rename = "httpStatusCode")] + #[ts(rename = "httpStatusCode")] + http_status_code: Option, + }, + /// Failed to connect to the response SSE stream. + ResponseStreamConnectionFailed { + #[serde(rename = "httpStatusCode")] + #[ts(rename = "httpStatusCode")] + http_status_code: Option, + }, + InternalServerError, + Unauthorized, + BadRequest, + SandboxError, + /// The response SSE stream disconnected in the middle of a turn before completion. + ResponseStreamDisconnected { + #[serde(rename = "httpStatusCode")] + #[ts(rename = "httpStatusCode")] + http_status_code: Option, + }, + /// Reached the retry limit for responses. 
+ ResponseTooManyFailedAttempts { + #[serde(rename = "httpStatusCode")] + #[ts(rename = "httpStatusCode")] + http_status_code: Option, + }, + Other, +} + +impl From for CodexErrorInfo { + fn from(value: CoreCodexErrorInfo) -> Self { + match value { + CoreCodexErrorInfo::ContextWindowExceeded => CodexErrorInfo::ContextWindowExceeded, + CoreCodexErrorInfo::UsageLimitExceeded => CodexErrorInfo::UsageLimitExceeded, + CoreCodexErrorInfo::HttpConnectionFailed { http_status_code } => { + CodexErrorInfo::HttpConnectionFailed { http_status_code } + } + CoreCodexErrorInfo::ResponseStreamConnectionFailed { http_status_code } => { + CodexErrorInfo::ResponseStreamConnectionFailed { http_status_code } + } + CoreCodexErrorInfo::InternalServerError => CodexErrorInfo::InternalServerError, + CoreCodexErrorInfo::Unauthorized => CodexErrorInfo::Unauthorized, + CoreCodexErrorInfo::BadRequest => CodexErrorInfo::BadRequest, + CoreCodexErrorInfo::SandboxError => CodexErrorInfo::SandboxError, + CoreCodexErrorInfo::ResponseStreamDisconnected { http_status_code } => { + CodexErrorInfo::ResponseStreamDisconnected { http_status_code } + } + CoreCodexErrorInfo::ResponseTooManyFailedAttempts { http_status_code } => { + CodexErrorInfo::ResponseTooManyFailedAttempts { http_status_code } + } + CoreCodexErrorInfo::Other => CodexErrorInfo::Other, + } + } +} + v2_enum_from_core!( pub enum AskForApproval from codex_protocol::protocol::AskForApproval { UnlessTrusted, OnFailure, OnRequest, Never @@ -402,6 +471,12 @@ pub struct ThreadStartParams { #[ts(export_to = "v2/")] pub struct ThreadStartResponse { pub thread: Thread, + pub model: String, + pub model_provider: String, + pub cwd: PathBuf, + pub approval_policy: AskForApproval, + pub sandbox: SandboxPolicy, + pub reasoning_effort: Option, } #[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq, JsonSchema, TS)] @@ -444,6 +519,12 @@ pub struct ThreadResumeParams { #[ts(export_to = "v2/")] pub struct ThreadResumeResponse { pub thread: 
Thread, + pub model: String, + pub model_provider: String, + pub cwd: PathBuf, + pub approval_policy: AskForApproval, + pub sandbox: SandboxPolicy, + pub reasoning_effort: Option, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] @@ -505,6 +586,10 @@ pub struct Thread { pub created_at: i64, /// [UNSTABLE] Path to the thread on disk. pub path: PathBuf, + /// Only populated on a `thread/resume` response. + /// For all other responses and notifications returning a Thread, + /// the turns field will be an empty list. + pub turns: Vec, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] @@ -519,25 +604,37 @@ pub struct AccountUpdatedNotification { #[ts(export_to = "v2/")] pub struct Turn { pub id: String, + /// Only populated on a `thread/resume` response. + /// For all other responses and notifications returning a Turn, + /// the items field will be an empty list. pub items: Vec, + #[serde(flatten)] pub status: TurnStatus, - pub error: Option, } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS, Error)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] +#[error("{message}")] pub struct TurnError { pub message: String, + pub codex_error_info: Option, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] +pub struct ErrorNotification { + pub error: TurnError, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(tag = "status", rename_all = "camelCase")] +#[ts(tag = "status", export_to = "v2/")] pub enum TurnStatus { Completed, Interrupted, - Failed, + Failed { error: TurnError }, InProgress, } @@ -562,6 +659,45 @@ pub struct TurnStartParams { pub summary: Option, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = 
"v2/")] +pub struct ReviewStartParams { + pub thread_id: String, + pub target: ReviewTarget, + + /// When true, also append the final review message to the original thread. + #[serde(default)] + pub append_to_original_thread: bool, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(tag = "type", rename_all = "camelCase")] +#[ts(tag = "type", export_to = "v2/")] +pub enum ReviewTarget { + /// Review the working tree: staged, unstaged, and untracked files. + UncommittedChanges, + + /// Review changes between the current branch and the given base branch. + #[serde(rename_all = "camelCase")] + #[ts(rename_all = "camelCase")] + BaseBranch { branch: String }, + + /// Review the changes introduced by a specific commit. + #[serde(rename_all = "camelCase")] + #[ts(rename_all = "camelCase")] + Commit { + sha: String, + /// Optional human-readable label (e.g., commit subject) for UIs. + title: Option, + }, + + /// Arbitrary instructions, equivalent to the old free-form prompt. 
+ #[serde(rename_all = "camelCase")] + #[ts(rename_all = "camelCase")] + Custom { instructions: String }, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] @@ -723,6 +859,7 @@ pub enum CommandExecutionStatus { InProgress, Completed, Failed, + Declined, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] @@ -735,20 +872,23 @@ pub struct FileUpdateChange { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] -#[serde(rename_all = "camelCase")] +#[serde(tag = "type", rename_all = "camelCase")] +#[ts(tag = "type")] #[ts(export_to = "v2/")] pub enum PatchChangeKind { Add, Delete, - Update, + Update { move_path: Option }, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] pub enum PatchApplyStatus { + InProgress, Completed, Failed, + Declined, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] @@ -814,8 +954,6 @@ pub struct Usage { #[ts(export_to = "v2/")] pub struct TurnCompletedNotification { pub turn: Turn, - // TODO: should usage be stored on the Turn object, and we return that instead? 
- pub usage: Usage, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] @@ -883,6 +1021,15 @@ pub struct McpToolCallProgressNotification { pub message: String, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct WindowsWorldWritableWarningNotification { + pub sample_paths: Vec, + pub extra_count: usize, + pub failed_scan: bool, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] @@ -916,6 +1063,26 @@ pub struct CommandExecutionRequestApprovalResponse { pub accept_settings: Option, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct FileChangeRequestApprovalParams { + pub thread_id: String, + pub turn_id: String, + pub item_id: String, + /// Optional explanatory reason (e.g. request for extra write access). + pub reason: Option, + /// [UNSTABLE] When set, the agent is asking the user to allow writes under this root + /// for the remainder of the session (unclear if this is honored today). 
+ pub grant_root: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[ts(export_to = "v2/")] +pub struct FileChangeRequestApprovalResponse { + pub decision: ApprovalDecision, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] @@ -929,6 +1096,7 @@ pub struct AccountRateLimitsUpdatedNotification { pub struct RateLimitSnapshot { pub primary: Option, pub secondary: Option, + pub credits: Option, } impl From for RateLimitSnapshot { @@ -936,6 +1104,7 @@ impl From for RateLimitSnapshot { Self { primary: value.primary.map(RateLimitWindow::from), secondary: value.secondary.map(RateLimitWindow::from), + credits: value.credits.map(CreditsSnapshot::from), } } } @@ -959,6 +1128,25 @@ impl From for RateLimitWindow { } } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct CreditsSnapshot { + pub has_credits: bool, + pub unlimited: bool, + pub balance: Option, +} + +impl From for CreditsSnapshot { + fn from(value: CoreCreditsSnapshot) -> Self { + Self { + has_credits: value.has_credits, + unlimited: value.unlimited, + balance: value.balance, + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] @@ -981,6 +1169,7 @@ mod tests { use codex_protocol::items::WebSearchItem; use codex_protocol::user_input::UserInput as CoreUserInput; use pretty_assertions::assert_eq; + use serde_json::json; use std::path::PathBuf; #[test] @@ -1066,4 +1255,20 @@ mod tests { } ); } + + #[test] + fn codex_error_info_serializes_http_status_code_in_camel_case() { + let value = CodexErrorInfo::ResponseTooManyFailedAttempts { + http_status_code: Some(401), + }; + + assert_eq!( + serde_json::to_value(value).unwrap(), + json!({ + "responseTooManyFailedAttempts": { + "httpStatusCode": 401 + } + }) + 
); + } } diff --git a/codex-rs/app-server-test-client/src/main.rs b/codex-rs/app-server-test-client/src/main.rs index a243937b2..7f2b75bd2 100644 --- a/codex-rs/app-server-test-client/src/main.rs +++ b/codex-rs/app-server-test-client/src/main.rs @@ -17,15 +17,22 @@ use clap::Parser; use clap::Subcommand; use codex_app_server_protocol::AddConversationListenerParams; use codex_app_server_protocol::AddConversationSubscriptionResponse; +use codex_app_server_protocol::ApprovalDecision; use codex_app_server_protocol::AskForApproval; use codex_app_server_protocol::ClientInfo; use codex_app_server_protocol::ClientRequest; +use codex_app_server_protocol::CommandExecutionRequestAcceptSettings; +use codex_app_server_protocol::CommandExecutionRequestApprovalParams; +use codex_app_server_protocol::CommandExecutionRequestApprovalResponse; +use codex_app_server_protocol::FileChangeRequestApprovalParams; +use codex_app_server_protocol::FileChangeRequestApprovalResponse; use codex_app_server_protocol::GetAccountRateLimitsResponse; use codex_app_server_protocol::InitializeParams; use codex_app_server_protocol::InitializeResponse; use codex_app_server_protocol::InputItem; use codex_app_server_protocol::JSONRPCMessage; use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCRequest; use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::LoginChatGptCompleteNotification; use codex_app_server_protocol::LoginChatGptResponse; @@ -36,14 +43,17 @@ use codex_app_server_protocol::SandboxPolicy; use codex_app_server_protocol::SendUserMessageParams; use codex_app_server_protocol::SendUserMessageResponse; use codex_app_server_protocol::ServerNotification; +use codex_app_server_protocol::ServerRequest; use codex_app_server_protocol::ThreadStartParams; use codex_app_server_protocol::ThreadStartResponse; use codex_app_server_protocol::TurnStartParams; use codex_app_server_protocol::TurnStartResponse; +use 
codex_app_server_protocol::TurnStatus; use codex_app_server_protocol::UserInput as V2UserInput; use codex_protocol::ConversationId; use codex_protocol::protocol::Event; use codex_protocol::protocol::EventMsg; +use serde::Serialize; use serde::de::DeserializeOwned; use serde_json::Value; use uuid::Uuid; @@ -502,10 +512,9 @@ impl CodexClient { ServerNotification::TurnCompleted(payload) => { if payload.turn.id == turn_id { println!("\n< turn/completed notification: {:?}", payload.turn.status); - if let Some(error) = payload.turn.error { + if let TurnStatus::Failed { error } = &payload.turn.status { println!("[turn error] {}", error.message); } - println!("< usage: {:?}", payload.usage); break; } } @@ -603,8 +612,8 @@ impl CodexClient { JSONRPCMessage::Notification(notification) => { self.pending_notifications.push_back(notification); } - JSONRPCMessage::Request(_) => { - bail!("unexpected request from codex app-server"); + JSONRPCMessage::Request(request) => { + self.handle_server_request(request)?; } } } @@ -624,8 +633,8 @@ impl CodexClient { // No outstanding requests, so ignore stray responses/errors for now. 
continue; } - JSONRPCMessage::Request(_) => { - bail!("unexpected request from codex app-server"); + JSONRPCMessage::Request(request) => { + self.handle_server_request(request)?; } } } @@ -661,6 +670,115 @@ impl CodexClient { fn request_id(&self) -> RequestId { RequestId::String(Uuid::new_v4().to_string()) } + + fn handle_server_request(&mut self, request: JSONRPCRequest) -> Result<()> { + let server_request = ServerRequest::try_from(request) + .context("failed to deserialize ServerRequest from JSONRPCRequest")?; + + match server_request { + ServerRequest::CommandExecutionRequestApproval { request_id, params } => { + self.handle_command_execution_request_approval(request_id, params)?; + } + ServerRequest::FileChangeRequestApproval { request_id, params } => { + self.approve_file_change_request(request_id, params)?; + } + other => { + bail!("received unsupported server request: {other:?}"); + } + } + + Ok(()) + } + + fn handle_command_execution_request_approval( + &mut self, + request_id: RequestId, + params: CommandExecutionRequestApprovalParams, + ) -> Result<()> { + let CommandExecutionRequestApprovalParams { + thread_id, + turn_id, + item_id, + reason, + risk, + } = params; + + println!( + "\n< commandExecution approval requested for thread {thread_id}, turn {turn_id}, item {item_id}" + ); + if let Some(reason) = reason.as_deref() { + println!("< reason: {reason}"); + } + if let Some(risk) = risk.as_ref() { + println!("< risk assessment: {risk:?}"); + } + + let response = CommandExecutionRequestApprovalResponse { + decision: ApprovalDecision::Accept, + accept_settings: Some(CommandExecutionRequestAcceptSettings { for_session: false }), + }; + self.send_server_request_response(request_id, &response)?; + println!("< approved commandExecution request for item {item_id}"); + Ok(()) + } + + fn approve_file_change_request( + &mut self, + request_id: RequestId, + params: FileChangeRequestApprovalParams, + ) -> Result<()> { + let FileChangeRequestApprovalParams { + 
thread_id, + turn_id, + item_id, + reason, + grant_root, + } = params; + + println!( + "\n< fileChange approval requested for thread {thread_id}, turn {turn_id}, item {item_id}" + ); + if let Some(reason) = reason.as_deref() { + println!("< reason: {reason}"); + } + if let Some(grant_root) = grant_root.as_deref() { + println!("< grant root: {}", grant_root.display()); + } + + let response = FileChangeRequestApprovalResponse { + decision: ApprovalDecision::Accept, + }; + self.send_server_request_response(request_id, &response)?; + println!("< approved fileChange request for item {item_id}"); + Ok(()) + } + + fn send_server_request_response(&mut self, request_id: RequestId, response: &T) -> Result<()> + where + T: Serialize, + { + let message = JSONRPCMessage::Response(JSONRPCResponse { + id: request_id, + result: serde_json::to_value(response)?, + }); + self.write_jsonrpc_message(message) + } + + fn write_jsonrpc_message(&mut self, message: JSONRPCMessage) -> Result<()> { + let payload = serde_json::to_string(&message)?; + let pretty = serde_json::to_string_pretty(&message)?; + print_multiline_with_prefix("> ", &pretty); + + if let Some(stdin) = self.stdin.as_mut() { + writeln!(stdin, "{payload}")?; + stdin + .flush() + .context("failed to flush response to codex app-server")?; + return Ok(()); + } + + bail!("codex app-server stdin closed") + } } fn print_multiline_with_prefix(prefix: &str, payload: &str) { diff --git a/codex-rs/app-server/Cargo.toml b/codex-rs/app-server/Cargo.toml index 96f64afdf..0e7849b16 100644 --- a/codex-rs/app-server/Cargo.toml +++ b/codex-rs/app-server/Cargo.toml @@ -53,3 +53,4 @@ serial_test = { workspace = true } tempfile = { workspace = true } toml = { workspace = true } wiremock = { workspace = true } +shlex = { workspace = true } diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 5f9b87458..d5ac8a57d 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -65,6 +65,7 @@ The 
JSON-RPC API exposes dedicated methods for managing Codex conversations. Thr - `thread/archive` — move a thread’s rollout file into the archived directory; returns `{}` on success. - `turn/start` — add user input to a thread and begin Codex generation; responds with the initial `turn` object and streams `turn/started`, `item/*`, and `turn/completed` notifications. - `turn/interrupt` — request cancellation of an in-flight turn by `(thread_id, turn_id)`; success is an empty `{}` response and the turn finishes with `status: "interrupted"`. +- `review/start` — kick off Codex’s automated reviewer for a thread; responds like `turn/start` and emits an `item/completed` notification with a `codeReview` item when results are ready. ### 1) Start or resume a thread @@ -181,6 +182,58 @@ You can cancel a running Turn with `turn/interrupt`. The server requests cancellations for running subprocesses, then emits a `turn/completed` event with `status: "interrupted"`. Rely on the `turn/completed` to know when Codex-side cleanup is done. +### 6) Request a code review + +Use `review/start` to run Codex’s reviewer on the currently checked-out project. The request takes the thread id plus a `target` describing what should be reviewed: + +- `{"type":"uncommittedChanges"}` — staged, unstaged, and untracked files. +- `{"type":"baseBranch","branch":"main"}` — diff against the provided branch’s upstream (see prompt for the exact `git merge-base`/`git diff` instructions Codex will run). +- `{"type":"commit","sha":"abc1234","title":"Optional subject"}` — review a specific commit. +- `{"type":"custom","instructions":"Free-form reviewer instructions"}` — fallback prompt equivalent to the legacy manual review request. +- `appendToOriginalThread` (bool, default `false`) — when `true`, Codex also records a final assistant-style message with the review summary in the original thread.
When `false`, only the `codeReview` item is emitted for the review run and no extra message is added to the original thread. + +Example request/response: + +```json +{ "method": "review/start", "id": 40, "params": { + "threadId": "thr_123", + "appendToOriginalThread": true, + "target": { "type": "commit", "sha": "1234567deadbeef", "title": "Polish tui colors" } +} } +{ "id": 40, "result": { "turn": { + "id": "turn_900", + "status": "inProgress", + "items": [ + { "type": "userMessage", "id": "turn_900", "content": [ { "type": "text", "text": "Review commit 1234567: Polish tui colors" } ] } + ], + "error": null +} } } +``` + +Codex streams the usual `turn/started` notification followed by an `item/started` +with the same `codeReview` item id so clients can show progress: + +```json +{ "method": "item/started", "params": { "item": { + "type": "codeReview", + "id": "turn_900", + "review": "current changes" +} } } +``` + +When the reviewer finishes, the server emits `item/completed` containing the same +`codeReview` item with the final review text: + +```json +{ "method": "item/completed", "params": { "item": { + "type": "codeReview", + "id": "turn_900", + "review": "Looks solid overall...\n\n- Prefer Stylize helpers — app.rs:10-20\n ..." +} } } +``` + +The `review` string is plain text that already bundles the overall explanation plus a bullet list for each structured finding (matching `ThreadItem::CodeReview` in the generated schema). Use this notification to render the reviewer output in your client. + ## Auth endpoints The JSON-RPC auth/account surface exposes request/response methods plus server-initiated notifications (no `id`). Use these to determine auth state, start or cancel logins, logout, and inspect ChatGPT rate limits. @@ -286,6 +339,29 @@ Event notifications are the server-initiated event stream for thread lifecycles, The app-server streams JSON-RPC notifications while a turn is running. 
Each turn starts with `turn/started` (initial `turn`) and ends with `turn/completed` (final `turn` plus token `usage`), and clients subscribe to the events they care about, rendering each item incrementally as updates arrive. The per-item lifecycle is always: `item/started` → zero or more item-specific deltas → `item/completed`. +- `turn/started` — `{ turn }` with the turn id, empty `items`, and `status: "inProgress"`. +- `turn/completed` — `{ turn }` where `turn.status` is `completed`, `interrupted`, or `failed`; failures carry `{ error: { message, codexErrorInfo? } }`. + +Today both notifications carry an empty `items` array even when item events were streamed; rely on `item/*` notifications for the canonical item list until this is fixed. + +#### Errors +An `error` event is emitted whenever the server hits an error mid-turn (for example, upstream model errors or quota limits). It carries the same `{ error: { message, codexErrorInfo? } }` payload as `turn.status: "failed"` and may precede that terminal notification. + + `codexErrorInfo` maps to the `CodexErrorInfo` enum. Common values: + - `ContextWindowExceeded` + - `UsageLimitExceeded` + - `HttpConnectionFailed { httpStatusCode? }`: upstream HTTP failures including 4xx/5xx + - `ResponseStreamConnectionFailed { httpStatusCode? }`: failure to connect to the response SSE stream + - `ResponseStreamDisconnected { httpStatusCode? }`: disconnect of the response SSE stream in the middle of a turn before completion + - `ResponseTooManyFailedAttempts { httpStatusCode? }` + - `BadRequest` + - `Unauthorized` + - `SandboxError` + - `InternalServerError` + - `Other`: all unclassified errors + +When an upstream HTTP status is available (for example, from the Responses API or a provider), it is forwarded in `httpStatusCode` on the relevant `codexErrorInfo` variant. + #### Thread items `ThreadItem` is the tagged union carried in turn responses and `item/*` notifications.
Currently we support events for the following items: diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index 8ed343f03..0c2445d85 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -1,24 +1,33 @@ use crate::codex_message_processor::ApiVersion; use crate::codex_message_processor::PendingInterrupts; +use crate::codex_message_processor::TurnSummary; +use crate::codex_message_processor::TurnSummaryStore; use crate::outgoing_message::OutgoingMessageSender; use codex_app_server_protocol::AccountRateLimitsUpdatedNotification; use codex_app_server_protocol::AgentMessageDeltaNotification; use codex_app_server_protocol::ApplyPatchApprovalParams; use codex_app_server_protocol::ApplyPatchApprovalResponse; use codex_app_server_protocol::ApprovalDecision; +use codex_app_server_protocol::CodexErrorInfo as V2CodexErrorInfo; use codex_app_server_protocol::CommandAction as V2ParsedCommand; use codex_app_server_protocol::CommandExecutionOutputDeltaNotification; use codex_app_server_protocol::CommandExecutionRequestApprovalParams; use codex_app_server_protocol::CommandExecutionRequestApprovalResponse; use codex_app_server_protocol::CommandExecutionStatus; +use codex_app_server_protocol::ErrorNotification; use codex_app_server_protocol::ExecCommandApprovalParams; use codex_app_server_protocol::ExecCommandApprovalResponse; +use codex_app_server_protocol::FileChangeRequestApprovalParams; +use codex_app_server_protocol::FileChangeRequestApprovalResponse; +use codex_app_server_protocol::FileUpdateChange; use codex_app_server_protocol::InterruptConversationResponse; use codex_app_server_protocol::ItemCompletedNotification; use codex_app_server_protocol::ItemStartedNotification; use codex_app_server_protocol::McpToolCallError; use codex_app_server_protocol::McpToolCallResult; use codex_app_server_protocol::McpToolCallStatus; +use 
codex_app_server_protocol::PatchApplyStatus; +use codex_app_server_protocol::PatchChangeKind as V2PatchChangeKind; use codex_app_server_protocol::ReasoningSummaryPartAddedNotification; use codex_app_server_protocol::ReasoningSummaryTextDeltaNotification; use codex_app_server_protocol::ReasoningTextDeltaNotification; @@ -26,7 +35,11 @@ use codex_app_server_protocol::SandboxCommandAssessment as V2SandboxCommandAsses use codex_app_server_protocol::ServerNotification; use codex_app_server_protocol::ServerRequestPayload; use codex_app_server_protocol::ThreadItem; +use codex_app_server_protocol::Turn; +use codex_app_server_protocol::TurnCompletedNotification; +use codex_app_server_protocol::TurnError; use codex_app_server_protocol::TurnInterruptResponse; +use codex_app_server_protocol::TurnStatus; use codex_core::CodexConversation; use codex_core::parse_command::shlex_join; use codex_core::protocol::ApplyPatchApprovalRequestEvent; @@ -34,12 +47,17 @@ use codex_core::protocol::Event; use codex_core::protocol::EventMsg; use codex_core::protocol::ExecApprovalRequestEvent; use codex_core::protocol::ExecCommandEndEvent; +use codex_core::protocol::FileChange as CoreFileChange; use codex_core::protocol::McpToolCallBeginEvent; use codex_core::protocol::McpToolCallEndEvent; use codex_core::protocol::Op; use codex_core::protocol::ReviewDecision; +use codex_core::review_format::format_review_findings_block; use codex_protocol::ConversationId; +use codex_protocol::protocol::ReviewOutputEvent; +use std::collections::HashMap; use std::convert::TryFrom; +use std::path::PathBuf; use std::sync::Arc; use tokio::sync::oneshot; use tracing::error; @@ -52,30 +70,84 @@ pub(crate) async fn apply_bespoke_event_handling( conversation: Arc, outgoing: Arc, pending_interrupts: PendingInterrupts, + turn_summary_store: TurnSummaryStore, api_version: ApiVersion, ) { let Event { id: event_id, msg } = event; match msg { + EventMsg::TaskComplete(_ev) => { + handle_turn_complete(conversation_id, event_id, 
&outgoing, &turn_summary_store).await; + } EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id, + turn_id, changes, reason, grant_root, - }) => { - let params = ApplyPatchApprovalParams { - conversation_id, - call_id, - file_changes: changes, - reason, - grant_root, - }; - let rx = outgoing - .send_request(ServerRequestPayload::ApplyPatchApproval(params)) - .await; - tokio::spawn(async move { - on_patch_approval_response(event_id, rx, conversation).await; - }); - } + }) => match api_version { + ApiVersion::V1 => { + let params = ApplyPatchApprovalParams { + conversation_id, + call_id, + file_changes: changes.clone(), + reason, + grant_root, + }; + let rx = outgoing + .send_request(ServerRequestPayload::ApplyPatchApproval(params)) + .await; + tokio::spawn(async move { + on_patch_approval_response(event_id, rx, conversation).await; + }); + } + ApiVersion::V2 => { + // Until we migrate the core to be aware of a first class FileChangeItem + // and emit the corresponding EventMsg, we repurpose the call_id as the item_id. 
+ let item_id = call_id.clone(); + let patch_changes = convert_patch_changes(&changes); + + let first_start = { + let mut map = turn_summary_store.lock().await; + let summary = map.entry(conversation_id).or_default(); + summary.file_change_started.insert(item_id.clone()) + }; + if first_start { + let item = ThreadItem::FileChange { + id: item_id.clone(), + changes: patch_changes.clone(), + status: PatchApplyStatus::InProgress, + }; + let notification = ItemStartedNotification { item }; + outgoing + .send_server_notification(ServerNotification::ItemStarted(notification)) + .await; + } + + let params = FileChangeRequestApprovalParams { + thread_id: conversation_id.to_string(), + turn_id: turn_id.clone(), + item_id: item_id.clone(), + reason, + grant_root, + }; + let rx = outgoing + .send_request(ServerRequestPayload::FileChangeRequestApproval(params)) + .await; + tokio::spawn(async move { + on_file_change_request_approval_response( + event_id, + conversation_id, + item_id, + patch_changes, + rx, + conversation, + outgoing, + turn_summary_store, + ) + .await; + }); + } + }, EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { call_id, turn_id, @@ -103,12 +175,20 @@ pub(crate) async fn apply_bespoke_event_handling( }); } ApiVersion::V2 => { + let item_id = call_id.clone(); + let command_actions = parsed_cmd + .iter() + .cloned() + .map(V2ParsedCommand::from) + .collect::>(); + let command_string = shlex_join(&command); + let params = CommandExecutionRequestApprovalParams { thread_id: conversation_id.to_string(), turn_id: turn_id.clone(), // Until we migrate the core to be aware of a first class CommandExecutionItem // and emit the corresponding EventMsg, we repurpose the call_id as the item_id. 
- item_id: call_id.clone(), + item_id: item_id.clone(), reason, risk: risk.map(V2SandboxCommandAssessment::from), }; @@ -118,8 +198,17 @@ pub(crate) async fn apply_bespoke_event_handling( )) .await; tokio::spawn(async move { - on_command_execution_request_approval_response(event_id, rx, conversation) - .await; + on_command_execution_request_approval_response( + event_id, + item_id, + command_string, + cwd, + command_actions, + rx, + conversation, + outgoing, + ) + .await; }); } }, @@ -189,6 +278,42 @@ pub(crate) async fn apply_bespoke_event_handling( .await; } } + EventMsg::Error(ev) => { + let turn_error = TurnError { + message: ev.message, + codex_error_info: ev.codex_error_info.map(V2CodexErrorInfo::from), + }; + handle_error(conversation_id, turn_error.clone(), &turn_summary_store).await; + outgoing + .send_server_notification(ServerNotification::Error(ErrorNotification { + error: turn_error, + })) + .await; + } + EventMsg::StreamError(ev) => { + // We don't need to update the turn summary store for stream errors as they are intermediate error states for retries, + // but we notify the client. 
+ let turn_error = TurnError { + message: ev.message, + codex_error_info: ev.codex_error_info.map(V2CodexErrorInfo::from), + }; + outgoing + .send_server_notification(ServerNotification::Error(ErrorNotification { + error: turn_error, + })) + .await; + } + EventMsg::EnteredReviewMode(review_request) => { + let notification = ItemStartedNotification { + item: ThreadItem::CodeReview { + id: event_id.clone(), + review: review_request.user_facing_hint, + }, + }; + outgoing + .send_server_notification(ServerNotification::ItemStarted(notification)) + .await; + } EventMsg::ItemStarted(item_started_event) => { let item: ThreadItem = item_started_event.item.clone().into(); let notification = ItemStartedNotification { item }; @@ -203,17 +328,80 @@ pub(crate) async fn apply_bespoke_event_handling( .send_server_notification(ServerNotification::ItemCompleted(notification)) .await; } + EventMsg::ExitedReviewMode(review_event) => { + let review_text = match review_event.review_output { + Some(output) => render_review_output_text(&output), + None => REVIEW_FALLBACK_MESSAGE.to_string(), + }; + let notification = ItemCompletedNotification { + item: ThreadItem::CodeReview { + id: event_id, + review: review_text, + }, + }; + outgoing + .send_server_notification(ServerNotification::ItemCompleted(notification)) + .await; + } + EventMsg::PatchApplyBegin(patch_begin_event) => { + // Until we migrate the core to be aware of a first class FileChangeItem + // and emit the corresponding EventMsg, we repurpose the call_id as the item_id. 
+ let item_id = patch_begin_event.call_id.clone(); + + let first_start = { + let mut map = turn_summary_store.lock().await; + let summary = map.entry(conversation_id).or_default(); + summary.file_change_started.insert(item_id.clone()) + }; + if first_start { + let item = ThreadItem::FileChange { + id: item_id.clone(), + changes: convert_patch_changes(&patch_begin_event.changes), + status: PatchApplyStatus::InProgress, + }; + let notification = ItemStartedNotification { item }; + outgoing + .send_server_notification(ServerNotification::ItemStarted(notification)) + .await; + } + } + EventMsg::PatchApplyEnd(patch_end_event) => { + // Until we migrate the core to be aware of a first class FileChangeItem + // and emit the corresponding EventMsg, we repurpose the call_id as the item_id. + let item_id = patch_end_event.call_id.clone(); + + let status = if patch_end_event.success { + PatchApplyStatus::Completed + } else { + PatchApplyStatus::Failed + }; + let changes = convert_patch_changes(&patch_end_event.changes); + complete_file_change_item( + conversation_id, + item_id, + changes, + status, + outgoing.as_ref(), + &turn_summary_store, + ) + .await; + } EventMsg::ExecCommandBegin(exec_command_begin_event) => { + let item_id = exec_command_begin_event.call_id.clone(); + let command_actions = exec_command_begin_event + .parsed_cmd + .into_iter() + .map(V2ParsedCommand::from) + .collect::>(); + let command = shlex_join(&exec_command_begin_event.command); + let cwd = exec_command_begin_event.cwd; + let item = ThreadItem::CommandExecution { - id: exec_command_begin_event.call_id.clone(), - command: shlex_join(&exec_command_begin_event.command), - cwd: exec_command_begin_event.cwd, + id: item_id, + command, + cwd, status: CommandExecutionStatus::InProgress, - command_actions: exec_command_begin_event - .parsed_cmd - .into_iter() - .map(V2ParsedCommand::from) - .collect(), + command_actions, aggregated_output: None, exit_code: None, duration_ms: None, @@ -251,6 +439,10 @@ 
pub(crate) async fn apply_bespoke_event_handling( } else { CommandExecutionStatus::Failed }; + let command_actions = parsed_cmd + .into_iter() + .map(V2ParsedCommand::from) + .collect::>(); let aggregated_output = if aggregated_output.is_empty() { None @@ -265,7 +457,7 @@ pub(crate) async fn apply_bespoke_event_handling( command: shlex_join(&command), cwd, status, - command_actions: parsed_cmd.into_iter().map(V2ParsedCommand::from).collect(), + command_actions, aggregated_output, exit_code: Some(exit_code), duration_ms: Some(duration_ms), @@ -298,12 +490,127 @@ pub(crate) async fn apply_bespoke_event_handling( } } } + + handle_turn_interrupted(conversation_id, event_id, &outgoing, &turn_summary_store) + .await; } _ => {} } } +async fn emit_turn_completed_with_status( + event_id: String, + status: TurnStatus, + outgoing: &OutgoingMessageSender, +) { + let notification = TurnCompletedNotification { + turn: Turn { + id: event_id, + items: vec![], + status, + }, + }; + outgoing + .send_server_notification(ServerNotification::TurnCompleted(notification)) + .await; +} + +async fn complete_file_change_item( + conversation_id: ConversationId, + item_id: String, + changes: Vec, + status: PatchApplyStatus, + outgoing: &OutgoingMessageSender, + turn_summary_store: &TurnSummaryStore, +) { + { + let mut map = turn_summary_store.lock().await; + if let Some(summary) = map.get_mut(&conversation_id) { + summary.file_change_started.remove(&item_id); + } + } + + let item = ThreadItem::FileChange { + id: item_id, + changes, + status, + }; + let notification = ItemCompletedNotification { item }; + outgoing + .send_server_notification(ServerNotification::ItemCompleted(notification)) + .await; +} + +async fn complete_command_execution_item( + item_id: String, + command: String, + cwd: PathBuf, + command_actions: Vec, + status: CommandExecutionStatus, + outgoing: &OutgoingMessageSender, +) { + let item = ThreadItem::CommandExecution { + id: item_id, + command, + cwd, + status, + 
command_actions, + aggregated_output: None, + exit_code: None, + duration_ms: None, + }; + let notification = ItemCompletedNotification { item }; + outgoing + .send_server_notification(ServerNotification::ItemCompleted(notification)) + .await; +} + +async fn find_and_remove_turn_summary( + conversation_id: ConversationId, + turn_summary_store: &TurnSummaryStore, +) -> TurnSummary { + let mut map = turn_summary_store.lock().await; + map.remove(&conversation_id).unwrap_or_default() +} + +async fn handle_turn_complete( + conversation_id: ConversationId, + event_id: String, + outgoing: &OutgoingMessageSender, + turn_summary_store: &TurnSummaryStore, +) { + let turn_summary = find_and_remove_turn_summary(conversation_id, turn_summary_store).await; + + let status = if let Some(error) = turn_summary.last_error { + TurnStatus::Failed { error } + } else { + TurnStatus::Completed + }; + + emit_turn_completed_with_status(event_id, status, outgoing).await; +} + +async fn handle_turn_interrupted( + conversation_id: ConversationId, + event_id: String, + outgoing: &OutgoingMessageSender, + turn_summary_store: &TurnSummaryStore, +) { + find_and_remove_turn_summary(conversation_id, turn_summary_store).await; + + emit_turn_completed_with_status(event_id, TurnStatus::Interrupted, outgoing).await; +} + +async fn handle_error( + conversation_id: ConversationId, + error: TurnError, + turn_summary_store: &TurnSummaryStore, +) { + let mut map = turn_summary_store.lock().await; + map.entry(conversation_id).or_default().last_error = Some(error); +} + async fn on_patch_approval_response( event_id: String, receiver: oneshot::Receiver, @@ -382,42 +689,194 @@ async fn on_exec_approval_response( } } -async fn on_command_execution_request_approval_response( +const REVIEW_FALLBACK_MESSAGE: &str = "Reviewer failed to output a response."; + +fn render_review_output_text(output: &ReviewOutputEvent) -> String { + let mut sections = Vec::new(); + let explanation = output.overall_explanation.trim(); + 
if !explanation.is_empty() { + sections.push(explanation.to_string()); + } + if !output.findings.is_empty() { + let findings = format_review_findings_block(&output.findings, None); + let trimmed = findings.trim(); + if !trimmed.is_empty() { + sections.push(trimmed.to_string()); + } + } + if sections.is_empty() { + REVIEW_FALLBACK_MESSAGE.to_string() + } else { + sections.join("\n\n") + } +} + +fn convert_patch_changes(changes: &HashMap) -> Vec { + let mut converted: Vec = changes + .iter() + .map(|(path, change)| FileUpdateChange { + path: path.to_string_lossy().into_owned(), + kind: map_patch_change_kind(change), + diff: format_file_change_diff(change), + }) + .collect(); + converted.sort_by(|a, b| a.path.cmp(&b.path)); + converted +} + +fn map_patch_change_kind(change: &CoreFileChange) -> V2PatchChangeKind { + match change { + CoreFileChange::Add { .. } => V2PatchChangeKind::Add, + CoreFileChange::Delete { .. } => V2PatchChangeKind::Delete, + CoreFileChange::Update { move_path, .. } => V2PatchChangeKind::Update { + move_path: move_path.clone(), + }, + } +} + +fn format_file_change_diff(change: &CoreFileChange) -> String { + match change { + CoreFileChange::Add { content } => content.clone(), + CoreFileChange::Delete { content } => content.clone(), + CoreFileChange::Update { + unified_diff, + move_path, + } => { + if let Some(path) = move_path { + format!("{unified_diff}\n\nMoved to: {}", path.display()) + } else { + unified_diff.clone() + } + } + } +} + +#[allow(clippy::too_many_arguments)] +async fn on_file_change_request_approval_response( event_id: String, + conversation_id: ConversationId, + item_id: String, + changes: Vec, receiver: oneshot::Receiver, - conversation: Arc, + codex: Arc, + outgoing: Arc, + turn_summary_store: TurnSummaryStore, ) { let response = receiver.await; - let value = match response { - Ok(value) => value, + let (decision, completion_status) = match response { + Ok(value) => { + let response = serde_json::from_value::(value) + 
.unwrap_or_else(|err| { + error!("failed to deserialize FileChangeRequestApprovalResponse: {err}"); + FileChangeRequestApprovalResponse { + decision: ApprovalDecision::Decline, + } + }); + + let (decision, completion_status) = match response.decision { + ApprovalDecision::Accept => (ReviewDecision::Approved, None), + ApprovalDecision::Decline => { + (ReviewDecision::Denied, Some(PatchApplyStatus::Declined)) + } + ApprovalDecision::Cancel => { + (ReviewDecision::Abort, Some(PatchApplyStatus::Declined)) + } + }; + // Allow EventMsg::PatchApplyEnd to emit ItemCompleted for accepted patches. + // Only short-circuit on declines/cancels/failures. + (decision, completion_status) + } Err(err) => { error!("request failed: {err:?}"); - return; + (ReviewDecision::Denied, Some(PatchApplyStatus::Failed)) } }; - let response = serde_json::from_value::(value) - .unwrap_or_else(|err| { - error!("failed to deserialize CommandExecutionRequestApprovalResponse: {err}"); - CommandExecutionRequestApprovalResponse { - decision: ApprovalDecision::Decline, - accept_settings: None, - } - }); + if let Some(status) = completion_status { + complete_file_change_item( + conversation_id, + item_id, + changes, + status, + outgoing.as_ref(), + &turn_summary_store, + ) + .await; + } - let CommandExecutionRequestApprovalResponse { - decision, - accept_settings, - } = response; + if let Err(err) = codex + .submit(Op::PatchApproval { + id: event_id, + decision, + }) + .await + { + error!("failed to submit PatchApproval: {err}"); + } +} - let decision = match (decision, accept_settings) { - (ApprovalDecision::Accept, Some(settings)) if settings.for_session => { - ReviewDecision::ApprovedForSession +#[allow(clippy::too_many_arguments)] +async fn on_command_execution_request_approval_response( + event_id: String, + item_id: String, + command: String, + cwd: PathBuf, + command_actions: Vec, + receiver: oneshot::Receiver, + conversation: Arc, + outgoing: Arc, +) { + let response = receiver.await; + let 
(decision, completion_status) = match response { + Ok(value) => { + let response = serde_json::from_value::(value) + .unwrap_or_else(|err| { + error!("failed to deserialize CommandExecutionRequestApprovalResponse: {err}"); + CommandExecutionRequestApprovalResponse { + decision: ApprovalDecision::Decline, + accept_settings: None, + } + }); + + let CommandExecutionRequestApprovalResponse { + decision, + accept_settings, + } = response; + + let (decision, completion_status) = match (decision, accept_settings) { + (ApprovalDecision::Accept, Some(settings)) if settings.for_session => { + (ReviewDecision::ApprovedForSession, None) + } + (ApprovalDecision::Accept, _) => (ReviewDecision::Approved, None), + (ApprovalDecision::Decline, _) => ( + ReviewDecision::Denied, + Some(CommandExecutionStatus::Declined), + ), + (ApprovalDecision::Cancel, _) => ( + ReviewDecision::Abort, + Some(CommandExecutionStatus::Declined), + ), + }; + (decision, completion_status) + } + Err(err) => { + error!("request failed: {err:?}"); + (ReviewDecision::Denied, Some(CommandExecutionStatus::Failed)) } - (ApprovalDecision::Accept, _) => ReviewDecision::Approved, - (ApprovalDecision::Decline, _) => ReviewDecision::Denied, - (ApprovalDecision::Cancel, _) => ReviewDecision::Abort, }; + + if let Some(status) = completion_status { + complete_command_execution_item( + item_id.clone(), + command.clone(), + cwd.clone(), + command_actions.clone(), + status, + outgoing.as_ref(), + ) + .await; + } + if let Err(err) = conversation .submit(Op::ExecApproval { id: event_id, @@ -486,13 +945,171 @@ async fn construct_mcp_tool_call_end_notification( #[cfg(test)] mod tests { use super::*; + use crate::CHANNEL_CAPACITY; + use crate::outgoing_message::OutgoingMessage; + use crate::outgoing_message::OutgoingMessageSender; + use anyhow::Result; + use anyhow::anyhow; + use anyhow::bail; use codex_core::protocol::McpInvocation; use mcp_types::CallToolResult; use mcp_types::ContentBlock; use mcp_types::TextContent; use 
pretty_assertions::assert_eq; use serde_json::Value as JsonValue; + use std::collections::HashMap; use std::time::Duration; + use tokio::sync::Mutex; + use tokio::sync::mpsc; + + fn new_turn_summary_store() -> TurnSummaryStore { + Arc::new(Mutex::new(HashMap::new())) + } + + #[tokio::test] + async fn test_handle_error_records_message() -> Result<()> { + let conversation_id = ConversationId::new(); + let turn_summary_store = new_turn_summary_store(); + + handle_error( + conversation_id, + TurnError { + message: "boom".to_string(), + codex_error_info: Some(V2CodexErrorInfo::InternalServerError), + }, + &turn_summary_store, + ) + .await; + + let turn_summary = find_and_remove_turn_summary(conversation_id, &turn_summary_store).await; + assert_eq!( + turn_summary.last_error, + Some(TurnError { + message: "boom".to_string(), + codex_error_info: Some(V2CodexErrorInfo::InternalServerError), + }) + ); + Ok(()) + } + + #[tokio::test] + async fn test_handle_turn_complete_emits_completed_without_error() -> Result<()> { + let conversation_id = ConversationId::new(); + let event_id = "complete1".to_string(); + let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); + let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + let turn_summary_store = new_turn_summary_store(); + + handle_turn_complete( + conversation_id, + event_id.clone(), + &outgoing, + &turn_summary_store, + ) + .await; + + let msg = rx + .recv() + .await + .ok_or_else(|| anyhow!("should send one notification"))?; + match msg { + OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { + assert_eq!(n.turn.id, event_id); + assert_eq!(n.turn.status, TurnStatus::Completed); + } + other => bail!("unexpected message: {other:?}"), + } + assert!(rx.try_recv().is_err(), "no extra messages expected"); + Ok(()) + } + + #[tokio::test] + async fn test_handle_turn_interrupted_emits_interrupted_with_error() -> Result<()> { + let conversation_id = ConversationId::new(); + let event_id = 
"interrupt1".to_string(); + let turn_summary_store = new_turn_summary_store(); + handle_error( + conversation_id, + TurnError { + message: "oops".to_string(), + codex_error_info: None, + }, + &turn_summary_store, + ) + .await; + let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); + let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + + handle_turn_interrupted( + conversation_id, + event_id.clone(), + &outgoing, + &turn_summary_store, + ) + .await; + + let msg = rx + .recv() + .await + .ok_or_else(|| anyhow!("should send one notification"))?; + match msg { + OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { + assert_eq!(n.turn.id, event_id); + assert_eq!(n.turn.status, TurnStatus::Interrupted); + } + other => bail!("unexpected message: {other:?}"), + } + assert!(rx.try_recv().is_err(), "no extra messages expected"); + Ok(()) + } + + #[tokio::test] + async fn test_handle_turn_complete_emits_failed_with_error() -> Result<()> { + let conversation_id = ConversationId::new(); + let event_id = "complete_err1".to_string(); + let turn_summary_store = new_turn_summary_store(); + handle_error( + conversation_id, + TurnError { + message: "bad".to_string(), + codex_error_info: Some(V2CodexErrorInfo::Other), + }, + &turn_summary_store, + ) + .await; + let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); + let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + + handle_turn_complete( + conversation_id, + event_id.clone(), + &outgoing, + &turn_summary_store, + ) + .await; + + let msg = rx + .recv() + .await + .ok_or_else(|| anyhow!("should send one notification"))?; + match msg { + OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { + assert_eq!(n.turn.id, event_id); + assert_eq!( + n.turn.status, + TurnStatus::Failed { + error: TurnError { + message: "bad".to_string(), + codex_error_info: Some(V2CodexErrorInfo::Other), + } + } + ); + } + other => bail!("unexpected message: {other:?}"), + } + 
assert!(rx.try_recv().is_err(), "no extra messages expected"); + Ok(()) + } #[tokio::test] async fn test_construct_mcp_tool_call_begin_notification_with_args() { @@ -522,6 +1139,123 @@ mod tests { assert_eq!(notification, expected); } + #[tokio::test] + async fn test_handle_turn_complete_emits_error_multiple_turns() -> Result<()> { + // Conversation A will have two turns; Conversation B will have one turn. + let conversation_a = ConversationId::new(); + let conversation_b = ConversationId::new(); + let turn_summary_store = new_turn_summary_store(); + + let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); + let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + + // Turn 1 on conversation A + let a_turn1 = "a_turn1".to_string(); + handle_error( + conversation_a, + TurnError { + message: "a1".to_string(), + codex_error_info: Some(V2CodexErrorInfo::BadRequest), + }, + &turn_summary_store, + ) + .await; + handle_turn_complete( + conversation_a, + a_turn1.clone(), + &outgoing, + &turn_summary_store, + ) + .await; + + // Turn 1 on conversation B + let b_turn1 = "b_turn1".to_string(); + handle_error( + conversation_b, + TurnError { + message: "b1".to_string(), + codex_error_info: None, + }, + &turn_summary_store, + ) + .await; + handle_turn_complete( + conversation_b, + b_turn1.clone(), + &outgoing, + &turn_summary_store, + ) + .await; + + // Turn 2 on conversation A + let a_turn2 = "a_turn2".to_string(); + handle_turn_complete( + conversation_a, + a_turn2.clone(), + &outgoing, + &turn_summary_store, + ) + .await; + + // Verify: A turn 1 + let msg = rx + .recv() + .await + .ok_or_else(|| anyhow!("should send first notification"))?; + match msg { + OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { + assert_eq!(n.turn.id, a_turn1); + assert_eq!( + n.turn.status, + TurnStatus::Failed { + error: TurnError { + message: "a1".to_string(), + codex_error_info: Some(V2CodexErrorInfo::BadRequest), + } + } + ); + } + other => bail!("unexpected 
message: {other:?}"), + } + + // Verify: B turn 1 + let msg = rx + .recv() + .await + .ok_or_else(|| anyhow!("should send second notification"))?; + match msg { + OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { + assert_eq!(n.turn.id, b_turn1); + assert_eq!( + n.turn.status, + TurnStatus::Failed { + error: TurnError { + message: "b1".to_string(), + codex_error_info: None, + } + } + ); + } + other => bail!("unexpected message: {other:?}"), + } + + // Verify: A turn 2 + let msg = rx + .recv() + .await + .ok_or_else(|| anyhow!("should send third notification"))?; + match msg { + OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { + assert_eq!(n.turn.id, a_turn2); + assert_eq!(n.turn.status, TurnStatus::Completed); + } + other => bail!("unexpected message: {other:?}"), + } + + assert!(rx.try_recv().is_err(), "no extra messages expected"); + Ok(()) + } + #[tokio::test] async fn test_construct_mcp_tool_call_begin_notification_without_args() { let begin_event = McpToolCallBeginEvent { diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index c5fa2a7fa..ae1bed31c 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -60,6 +60,8 @@ use codex_app_server_protocol::RemoveConversationSubscriptionResponse; use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ResumeConversationParams; use codex_app_server_protocol::ResumeConversationResponse; +use codex_app_server_protocol::ReviewStartParams; +use codex_app_server_protocol::ReviewTarget; use codex_app_server_protocol::SandboxMode; use codex_app_server_protocol::SendUserMessageParams; use codex_app_server_protocol::SendUserMessageResponse; @@ -81,6 +83,7 @@ use codex_app_server_protocol::ThreadStartParams; use codex_app_server_protocol::ThreadStartResponse; use codex_app_server_protocol::ThreadStartedNotification; 
use codex_app_server_protocol::Turn; +use codex_app_server_protocol::TurnError; use codex_app_server_protocol::TurnInterruptParams; use codex_app_server_protocol::TurnStartParams; use codex_app_server_protocol::TurnStartResponse; @@ -89,6 +92,7 @@ use codex_app_server_protocol::TurnStatus; use codex_app_server_protocol::UserInfoResponse; use codex_app_server_protocol::UserInput as V2UserInput; use codex_app_server_protocol::UserSavedConfig; +use codex_app_server_protocol::build_turns_from_event_msgs; use codex_backend_client::Client as BackendClient; use codex_core::AuthManager; use codex_core::CodexConversation; @@ -109,12 +113,15 @@ use codex_core::config_loader::load_config_as_toml; use codex_core::default_client::get_codex_user_agent; use codex_core::exec::ExecParams; use codex_core::exec_env::create_env; +use codex_core::features::Feature; use codex_core::find_conversation_path_by_id_str; use codex_core::get_platform_sandbox; use codex_core::git_info::git_diff_to_remote; use codex_core::parse_cursor; use codex_core::protocol::EventMsg; use codex_core::protocol::Op; +use codex_core::protocol::ReviewRequest; +use codex_core::protocol::SessionConfiguredEvent; use codex_core::read_head_for_summary; use codex_feedback::CodexFeedback; use codex_login::ServerOptions as LoginServerOptions; @@ -132,6 +139,7 @@ use codex_protocol::protocol::USER_MESSAGE_BEGIN; use codex_protocol::user_input::UserInput as CoreInputItem; use codex_utils_json_to_toml::json_to_toml; use std::collections::HashMap; +use std::collections::HashSet; use std::ffi::OsStr; use std::io::Error as IoError; use std::path::Path; @@ -151,6 +159,15 @@ use uuid::Uuid; type PendingInterruptQueue = Vec<(RequestId, ApiVersion)>; pub(crate) type PendingInterrupts = Arc>>; +/// Per-conversation accumulation of the latest states e.g. error message while a turn runs. 
+#[derive(Default, Clone)] +pub(crate) struct TurnSummary { + pub(crate) file_change_started: HashSet, + pub(crate) last_error: Option, +} + +pub(crate) type TurnSummaryStore = Arc>>; + // Duration before a ChatGPT login attempt is abandoned. const LOGIN_CHATGPT_TIMEOUT: Duration = Duration::from_secs(10 * 60); struct ActiveLogin { @@ -175,6 +192,7 @@ pub(crate) struct CodexMessageProcessor { active_login: Arc>>, // Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives. pending_interrupts: PendingInterrupts, + turn_summary_store: TurnSummaryStore, pending_fuzzy_searches: Arc>>>, feedback: CodexFeedback, } @@ -227,11 +245,97 @@ impl CodexMessageProcessor { conversation_listeners: HashMap::new(), active_login: Arc::new(Mutex::new(None)), pending_interrupts: Arc::new(Mutex::new(HashMap::new())), + turn_summary_store: Arc::new(Mutex::new(HashMap::new())), pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())), feedback, } } + fn review_request_from_target( + target: ReviewTarget, + append_to_original_thread: bool, + ) -> Result<(ReviewRequest, String), JSONRPCErrorError> { + fn invalid_request(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message, + data: None, + } + } + + match target { + // TODO(jif) those messages will be extracted in a follow-up PR. + ReviewTarget::UncommittedChanges => Ok(( + ReviewRequest { + prompt: "Review the current code changes (staged, unstaged, and untracked files) and provide prioritized findings.".to_string(), + user_facing_hint: "current changes".to_string(), + append_to_original_thread, + }, + "Review uncommitted changes".to_string(), + )), + ReviewTarget::BaseBranch { branch } => { + let branch = branch.trim().to_string(); + if branch.is_empty() { + return Err(invalid_request("branch must not be empty".to_string())); + } + let prompt = format!("Review the code changes against the base branch '{branch}'. 
Start by finding the merge diff between the current branch and {branch}'s upstream e.g. (`git merge-base HEAD \"$(git rev-parse --abbrev-ref \"{branch}@{{upstream}}\")\"`), then run `git diff` against that SHA to see what changes we would merge into the {branch} branch. Provide prioritized, actionable findings."); + let hint = format!("changes against '{branch}'"); + let display = format!("Review changes against base branch '{branch}'"); + Ok(( + ReviewRequest { + prompt, + user_facing_hint: hint, + append_to_original_thread, + }, + display, + )) + } + ReviewTarget::Commit { sha, title } => { + let sha = sha.trim().to_string(); + if sha.is_empty() { + return Err(invalid_request("sha must not be empty".to_string())); + } + let brief_title = title + .map(|t| t.trim().to_string()) + .filter(|t| !t.is_empty()); + let prompt = if let Some(title) = brief_title.clone() { + format!("Review the code changes introduced by commit {sha} (\"{title}\"). Provide prioritized, actionable findings.") + } else { + format!("Review the code changes introduced by commit {sha}. Provide prioritized, actionable findings.") + }; + let short_sha = sha.chars().take(7).collect::(); + let hint = format!("commit {short_sha}"); + let display = if let Some(title) = brief_title { + format!("Review commit {short_sha}: {title}") + } else { + format!("Review commit {short_sha}") + }; + Ok(( + ReviewRequest { + prompt, + user_facing_hint: hint, + append_to_original_thread, + }, + display, + )) + } + ReviewTarget::Custom { instructions } => { + let trimmed = instructions.trim().to_string(); + if trimmed.is_empty() { + return Err(invalid_request("instructions must not be empty".to_string())); + } + Ok(( + ReviewRequest { + prompt: trimmed.clone(), + user_facing_hint: trimmed.clone(), + append_to_original_thread, + }, + trimmed, + )) + } + } + } + pub async fn process_request(&mut self, request: ClientRequest) { match request { ClientRequest::Initialize { .. 
} => { @@ -263,6 +367,9 @@ impl CodexMessageProcessor { ClientRequest::TurnInterrupt { request_id, params } => { self.turn_interrupt(request_id, params).await; } + ClientRequest::ReviewStart { request_id, params } => { + self.review_start(request_id, params).await; + } ClientRequest::NewConversation { request_id, params } => { // Do not tokio::spawn() to process new_conversation() // asynchronously because we need to ensure the conversation is @@ -1063,7 +1170,7 @@ impl CodexMessageProcessor { let exec_params = ExecParams { command: params.command, cwd, - timeout_ms, + expiration: timeout_ms.into(), env, with_escalated_permissions: None, justification: None, @@ -1135,7 +1242,7 @@ impl CodexMessageProcessor { let overrides = ConfigOverrides { model, config_profile: profile, - cwd: cwd.map(PathBuf::from), + cwd: cwd.clone().map(PathBuf::from), approval_policy, sandbox_mode, model_provider, @@ -1147,7 +1254,17 @@ impl CodexMessageProcessor { ..Default::default() }; - let config = match derive_config_from_params(overrides, cli_overrides).await { + // Persist windows sandbox feature. + // TODO: persist default config in general. + let mut cli_overrides = cli_overrides.unwrap_or_default(); + if cfg!(windows) && self.config.features.enabled(Feature::WindowsSandbox) { + cli_overrides.insert( + "features.enable_experimental_windows_sandbox".to_string(), + serde_json::json!(true), + ); + } + + let config = match derive_config_from_params(overrides, Some(cli_overrides)).await { Ok(config) => config, Err(err) => { let error = JSONRPCErrorError { @@ -1212,8 +1329,12 @@ impl CodexMessageProcessor { match self.conversation_manager.new_conversation(config).await { Ok(new_conv) => { - let conversation_id = new_conv.conversation_id; - let rollout_path = new_conv.session_configured.rollout_path.clone(); + let NewConversation { + conversation_id, + session_configured, + .. 
+ } = new_conv; + let rollout_path = session_configured.rollout_path.clone(); let fallback_provider = self.config.model_provider_id.as_str(); // A bit hacky, but the summary contains a lot of useful information for the thread @@ -1238,8 +1359,22 @@ impl CodexMessageProcessor { } }; + let SessionConfiguredEvent { + model, + model_provider_id, + cwd, + approval_policy, + sandbox_policy, + .. + } = session_configured; let response = ThreadStartResponse { thread: thread.clone(), + model, + model_provider: model_provider_id, + cwd, + approval_policy: approval_policy.into(), + sandbox: sandbox_policy.into(), + reasoning_effort: session_configured.reasoning_effort, }; // Auto-attach a conversation listener when starting a thread. @@ -1521,6 +1656,11 @@ impl CodexMessageProcessor { session_configured, .. }) => { + let SessionConfiguredEvent { + rollout_path, + initial_messages, + .. + } = session_configured; // Auto-attach a conversation listener when resuming a thread. if let Err(err) = self .attach_conversation_listener(conversation_id, false, ApiVersion::V2) @@ -1533,8 +1673,8 @@ impl CodexMessageProcessor { ); } - let thread = match read_summary_from_rollout( - session_configured.rollout_path.as_path(), + let mut thread = match read_summary_from_rollout( + rollout_path.as_path(), fallback_model_provider.as_str(), ) .await @@ -1545,14 +1685,27 @@ impl CodexMessageProcessor { request_id, format!( "failed to load rollout `{}` for conversation {conversation_id}: {err}", - session_configured.rollout_path.display() + rollout_path.display() ), ) .await; return; } }; - let response = ThreadResumeResponse { thread }; + thread.turns = initial_messages + .as_deref() + .map_or_else(Vec::new, build_turns_from_event_msgs); + + let response = ThreadResumeResponse { + thread, + model: session_configured.model, + model_provider: session_configured.model_provider_id, + cwd: session_configured.cwd, + approval_policy: session_configured.approval_policy.into(), + sandbox: 
session_configured.sandbox_policy.into(), + reasoning_effort: session_configured.reasoning_effort, + }; + self.outgoing.send_response(request_id, response).await; } Err(err) => { @@ -1803,6 +1956,15 @@ impl CodexMessageProcessor { include_apply_patch_tool, } = overrides; + // Persist windows sandbox feature. + let mut cli_overrides = cli_overrides.unwrap_or_default(); + if cfg!(windows) && self.config.features.enabled(Feature::WindowsSandbox) { + cli_overrides.insert( + "features.enable_experimental_windows_sandbox".to_string(), + serde_json::json!(true), + ); + } + let overrides = ConfigOverrides { model, config_profile: profile, @@ -1818,7 +1980,7 @@ impl CodexMessageProcessor { ..Default::default() }; - derive_config_from_params(overrides, cli_overrides).await + derive_config_from_params(overrides, Some(cli_overrides)).await } None => Ok(self.config.as_ref().clone()), }; @@ -2272,9 +2434,6 @@ impl CodexMessageProcessor { } }; - // Keep a copy of v2 inputs for the notification payload. - let v2_inputs_for_notif = params.input.clone(); - // Map v2 input items to core input items. 
let mapped_items: Vec = params .input @@ -2314,12 +2473,8 @@ impl CodexMessageProcessor { Ok(turn_id) => { let turn = Turn { id: turn_id.clone(), - items: vec![ThreadItem::UserMessage { - id: turn_id, - content: v2_inputs_for_notif, - }], + items: vec![], status: TurnStatus::InProgress, - error: None, }; let response = TurnStartResponse { turn: turn.clone() }; @@ -2342,6 +2497,64 @@ impl CodexMessageProcessor { } } + async fn review_start(&self, request_id: RequestId, params: ReviewStartParams) { + let ReviewStartParams { + thread_id, + target, + append_to_original_thread, + } = params; + let (_, conversation) = match self.conversation_from_thread_id(&thread_id).await { + Ok(v) => v, + Err(error) => { + self.outgoing.send_error(request_id, error).await; + return; + } + }; + + let (review_request, display_text) = + match Self::review_request_from_target(target, append_to_original_thread) { + Ok(value) => value, + Err(err) => { + self.outgoing.send_error(request_id, err).await; + return; + } + }; + + let turn_id = conversation.submit(Op::Review { review_request }).await; + + match turn_id { + Ok(turn_id) => { + let mut items = Vec::new(); + if !display_text.is_empty() { + items.push(ThreadItem::UserMessage { + id: turn_id.clone(), + content: vec![V2UserInput::Text { text: display_text }], + }); + } + let turn = Turn { + id: turn_id.clone(), + items, + status: TurnStatus::InProgress, + }; + let response = TurnStartResponse { turn: turn.clone() }; + self.outgoing.send_response(request_id, response).await; + + let notif = TurnStartedNotification { turn }; + self.outgoing + .send_server_notification(ServerNotification::TurnStarted(notif)) + .await; + } + Err(err) => { + let error = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!("failed to start review: {err}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + } + } + } + async fn turn_interrupt(&mut self, request_id: RequestId, params: TurnInterruptParams) { let 
TurnInterruptParams { thread_id, .. } = params; @@ -2441,6 +2654,7 @@ impl CodexMessageProcessor { let outgoing_for_task = self.outgoing.clone(); let pending_interrupts = self.pending_interrupts.clone(); + let turn_summary_store = self.turn_summary_store.clone(); let api_version_for_task = api_version; tokio::spawn(async move { loop { @@ -2497,6 +2711,7 @@ impl CodexMessageProcessor { conversation.clone(), outgoing_for_task.clone(), pending_interrupts.clone(), + turn_summary_store.clone(), api_version_for_task, ) .await; @@ -2791,6 +3006,7 @@ fn summary_to_thread(summary: ConversationSummary) -> Thread { model_provider, created_at: created_at.map(|dt| dt.timestamp()).unwrap_or(0), path, + turns: Vec::new(), } } diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index a97b037be..403263b89 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -6,7 +6,6 @@ use crate::outgoing_message::OutgoingMessageSender; use codex_app_server_protocol::ClientInfo; use codex_app_server_protocol::ClientRequest; use codex_app_server_protocol::InitializeResponse; - use codex_app_server_protocol::JSONRPCError; use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::JSONRPCNotification; @@ -118,6 +117,7 @@ impl MessageProcessor { self.outgoing.send_response(request_id, response).await; self.initialized = true; + return; } } diff --git a/codex-rs/app-server/src/outgoing_message.rs b/codex-rs/app-server/src/outgoing_message.rs index 40260c8b9..b7f331c9d 100644 --- a/codex-rs/app-server/src/outgoing_message.rs +++ b/codex-rs/app-server/src/outgoing_message.rs @@ -229,6 +229,7 @@ mod tests { resets_at: Some(123), }), secondary: None, + credits: None, }, }); @@ -243,7 +244,8 @@ mod tests { "windowDurationMins": 15, "resetsAt": 123 }, - "secondary": null + "secondary": null, + "credits": null } }, }), diff --git a/codex-rs/app-server/tests/common/Cargo.toml 
b/codex-rs/app-server/tests/common/Cargo.toml index 6240f755e..f8e6529d4 100644 --- a/codex-rs/app-server/tests/common/Cargo.toml +++ b/codex-rs/app-server/tests/common/Cargo.toml @@ -24,3 +24,5 @@ tokio = { workspace = true, features = [ ] } uuid = { workspace = true } wiremock = { workspace = true } +core_test_support = { path = "../../../core/tests/common" } +shlex = { workspace = true } diff --git a/codex-rs/app-server/tests/common/lib.rs b/codex-rs/app-server/tests/common/lib.rs index dc3d24cca..7d3eb5388 100644 --- a/codex-rs/app-server/tests/common/lib.rs +++ b/codex-rs/app-server/tests/common/lib.rs @@ -9,12 +9,14 @@ pub use auth_fixtures::ChatGptIdTokenClaims; pub use auth_fixtures::encode_id_token; pub use auth_fixtures::write_chatgpt_auth; use codex_app_server_protocol::JSONRPCResponse; +pub use core_test_support::format_with_current_shell; +pub use core_test_support::format_with_current_shell_display; pub use mcp_process::McpProcess; pub use mock_model_server::create_mock_chat_completions_server; pub use mock_model_server::create_mock_chat_completions_server_unchecked; pub use responses::create_apply_patch_sse_response; pub use responses::create_final_assistant_message_sse_response; -pub use responses::create_shell_sse_response; +pub use responses::create_shell_command_sse_response; pub use rollout::create_fake_rollout; use serde::de::DeserializeOwned; diff --git a/codex-rs/app-server/tests/common/mcp_process.rs b/codex-rs/app-server/tests/common/mcp_process.rs index 75851eda2..920a6fa01 100644 --- a/codex-rs/app-server/tests/common/mcp_process.rs +++ b/codex-rs/app-server/tests/common/mcp_process.rs @@ -35,6 +35,7 @@ use codex_app_server_protocol::NewConversationParams; use codex_app_server_protocol::RemoveConversationListenerParams; use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ResumeConversationParams; +use codex_app_server_protocol::ReviewStartParams; use codex_app_server_protocol::SendUserMessageParams; use 
codex_app_server_protocol::SendUserTurnParams; use codex_app_server_protocol::ServerRequest; @@ -377,6 +378,15 @@ impl McpProcess { self.send_request("turn/interrupt", params).await } + /// Send a `review/start` JSON-RPC request (v2). + pub async fn send_review_start_request( + &mut self, + params: ReviewStartParams, + ) -> anyhow::Result { + let params = Some(serde_json::to_value(params)?); + self.send_request("review/start", params).await + } + /// Send a `cancelLoginChatGpt` JSON-RPC request. pub async fn send_cancel_login_chat_gpt_request( &mut self, diff --git a/codex-rs/app-server/tests/common/responses.rs b/codex-rs/app-server/tests/common/responses.rs index 9a827fb98..0a9183c04 100644 --- a/codex-rs/app-server/tests/common/responses.rs +++ b/codex-rs/app-server/tests/common/responses.rs @@ -1,17 +1,18 @@ use serde_json::json; use std::path::Path; -pub fn create_shell_sse_response( +pub fn create_shell_command_sse_response( command: Vec, workdir: Option<&Path>, timeout_ms: Option, call_id: &str, ) -> anyhow::Result { - // The `arguments`` for the `shell` tool is a serialized JSON object. + // The `arguments` for the `shell_command` tool is a serialized JSON object. 
+ let command_str = shlex::try_join(command.iter().map(String::as_str))?; let tool_call_arguments = serde_json::to_string(&json!({ - "command": command, + "command": command_str, "workdir": workdir.map(|w| w.to_string_lossy()), - "timeout": timeout_ms + "timeout_ms": timeout_ms }))?; let tool_call = json!({ "choices": [ @@ -21,7 +22,7 @@ pub fn create_shell_sse_response( { "id": call_id, "function": { - "name": "shell", + "name": "shell_command", "arguments": tool_call_arguments } } @@ -62,10 +63,10 @@ pub fn create_apply_patch_sse_response( patch_content: &str, call_id: &str, ) -> anyhow::Result { - // Use shell command to call apply_patch with heredoc format - let shell_command = format!("apply_patch <<'EOF'\n{patch_content}\nEOF"); + // Use shell_command to call apply_patch with heredoc format + let command = format!("apply_patch <<'EOF'\n{patch_content}\nEOF"); let tool_call_arguments = serde_json::to_string(&json!({ - "command": ["bash", "-lc", shell_command] + "command": command }))?; let tool_call = json!({ @@ -76,7 +77,7 @@ pub fn create_apply_patch_sse_response( { "id": call_id, "function": { - "name": "shell", + "name": "shell_command", "arguments": tool_call_arguments } } diff --git a/codex-rs/app-server/tests/suite/codex_message_processor_flow.rs b/codex-rs/app-server/tests/suite/codex_message_processor_flow.rs index 1feda4284..a64aca805 100644 --- a/codex-rs/app-server/tests/suite/codex_message_processor_flow.rs +++ b/codex-rs/app-server/tests/suite/codex_message_processor_flow.rs @@ -2,7 +2,8 @@ use anyhow::Result; use app_test_support::McpProcess; use app_test_support::create_final_assistant_message_sse_response; use app_test_support::create_mock_chat_completions_server; -use app_test_support::create_shell_sse_response; +use app_test_support::create_shell_command_sse_response; +use app_test_support::format_with_current_shell; use app_test_support::to_response; use codex_app_server_protocol::AddConversationListenerParams; use 
codex_app_server_protocol::AddConversationSubscriptionResponse; @@ -56,7 +57,7 @@ async fn test_codex_jsonrpc_conversation_flow() -> Result<()> { // Create a mock model server that immediately ends each turn. // Two turns are expected: initial session configure + one user message. let responses = vec![ - create_shell_sse_response( + create_shell_command_sse_response( vec!["ls".to_string()], Some(&working_directory), Some(5000), @@ -175,7 +176,7 @@ async fn test_send_user_turn_changes_approval_policy_behavior() -> Result<()> { // Mock server will request a python shell call for the first and second turn, then finish. let responses = vec![ - create_shell_sse_response( + create_shell_command_sse_response( vec![ "python3".to_string(), "-c".to_string(), @@ -186,7 +187,7 @@ async fn test_send_user_turn_changes_approval_policy_behavior() -> Result<()> { "call1", )?, create_final_assistant_message_sse_response("done 1")?, - create_shell_sse_response( + create_shell_command_sse_response( vec![ "python3".to_string(), "-c".to_string(), @@ -267,11 +268,7 @@ async fn test_send_user_turn_changes_approval_policy_behavior() -> Result<()> { ExecCommandApprovalParams { conversation_id, call_id: "call1".to_string(), - command: vec![ - "python3".to_string(), - "-c".to_string(), - "print(42)".to_string(), - ], + command: format_with_current_shell("python3 -c 'print(42)'"), cwd: working_directory.clone(), reason: None, risk: None, @@ -353,23 +350,15 @@ async fn test_send_user_turn_updates_sandbox_and_cwd_between_turns() -> Result<( std::fs::create_dir(&second_cwd)?; let responses = vec![ - create_shell_sse_response( - vec![ - "bash".to_string(), - "-lc".to_string(), - "echo first turn".to_string(), - ], + create_shell_command_sse_response( + vec!["echo".to_string(), "first".to_string(), "turn".to_string()], None, Some(5000), "call-first", )?, create_final_assistant_message_sse_response("done first")?, - create_shell_sse_response( - vec![ - "bash".to_string(), - "-lc".to_string(), - 
"echo second turn".to_string(), - ], + create_shell_command_sse_response( + vec!["echo".to_string(), "second".to_string(), "turn".to_string()], None, Some(5000), "call-second", @@ -481,13 +470,9 @@ async fn test_send_user_turn_updates_sandbox_and_cwd_between_turns() -> Result<( exec_begin.cwd, second_cwd, "exec turn should run from updated cwd" ); + let expected_command = format_with_current_shell("echo second turn"); assert_eq!( - exec_begin.command, - vec![ - "bash".to_string(), - "-lc".to_string(), - "echo second turn".to_string() - ], + exec_begin.command, expected_command, "exec turn should run expected command" ); diff --git a/codex-rs/app-server/tests/suite/config.rs b/codex-rs/app-server/tests/suite/config.rs index 281d54927..75dba5722 100644 --- a/codex-rs/app-server/tests/suite/config.rs +++ b/codex-rs/app-server/tests/suite/config.rs @@ -27,7 +27,7 @@ fn create_config_toml(codex_home: &Path) -> std::io::Result<()> { std::fs::write( config_toml, r#" -model = "gpt-5.1-codex" +model = "gpt-5.1-codex-max" approval_policy = "on-request" sandbox_mode = "workspace-write" model_reasoning_summary = "detailed" @@ -87,7 +87,7 @@ async fn get_config_toml_parses_all_fields() -> Result<()> { }), forced_chatgpt_workspace_id: Some("12345678-0000-0000-0000-000000000000".into()), forced_login_method: Some(ForcedLoginMethod::Chatgpt), - model: Some("gpt-5.1-codex".into()), + model: Some("gpt-5.1-codex-max".into()), model_reasoning_effort: Some(ReasoningEffort::High), model_reasoning_summary: Some(ReasoningSummary::Detailed), model_verbosity: Some(Verbosity::Medium), diff --git a/codex-rs/app-server/tests/suite/interrupt.rs b/codex-rs/app-server/tests/suite/interrupt.rs index 86b0a3f3f..d8e6182be 100644 --- a/codex-rs/app-server/tests/suite/interrupt.rs +++ b/codex-rs/app-server/tests/suite/interrupt.rs @@ -19,7 +19,7 @@ use tokio::time::timeout; use app_test_support::McpProcess; use app_test_support::create_mock_chat_completions_server; -use 
app_test_support::create_shell_sse_response; +use app_test_support::create_shell_command_sse_response; use app_test_support::to_response; const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); @@ -56,7 +56,7 @@ async fn shell_command_interruption() -> anyhow::Result<()> { std::fs::create_dir(&working_directory)?; // Create mock server with a single SSE response: the long sleep command - let server = create_mock_chat_completions_server(vec![create_shell_sse_response( + let server = create_mock_chat_completions_server(vec![create_shell_command_sse_response( shell_command.clone(), Some(&working_directory), Some(10_000), // 10 seconds timeout in ms diff --git a/codex-rs/app-server/tests/suite/set_default_model.rs b/codex-rs/app-server/tests/suite/set_default_model.rs index f3af141c0..b56c54dbd 100644 --- a/codex-rs/app-server/tests/suite/set_default_model.rs +++ b/codex-rs/app-server/tests/suite/set_default_model.rs @@ -57,7 +57,7 @@ fn create_config_toml(codex_home: &Path) -> std::io::Result<()> { std::fs::write( config_toml, r#" -model = "gpt-5.1-codex" +model = "gpt-5.1-codex-max" model_reasoning_effort = "medium" "#, ) diff --git a/codex-rs/app-server/tests/suite/v2/mod.rs b/codex-rs/app-server/tests/suite/v2/mod.rs index 587afef10..a8594e7ca 100644 --- a/codex-rs/app-server/tests/suite/v2/mod.rs +++ b/codex-rs/app-server/tests/suite/v2/mod.rs @@ -1,6 +1,7 @@ mod account; mod model_list; mod rate_limits; +mod review; mod thread_archive; mod thread_list; mod thread_resume; diff --git a/codex-rs/app-server/tests/suite/v2/model_list.rs b/codex-rs/app-server/tests/suite/v2/model_list.rs index 8b17185f4..3c4844fed 100644 --- a/codex-rs/app-server/tests/suite/v2/model_list.rs +++ b/codex-rs/app-server/tests/suite/v2/model_list.rs @@ -45,6 +45,33 @@ async fn list_models_returns_all_models_with_large_limit() -> Result<()> { } = to_response::(response)?; let expected_models = vec![ + Model { + id: "gpt-5.1-codex-max".to_string(), + model: 
"gpt-5.1-codex-max".to_string(), + display_name: "gpt-5.1-codex-max".to_string(), + description: "Latest Codex-optimized flagship for deep and fast reasoning.".to_string(), + supported_reasoning_efforts: vec![ + ReasoningEffortOption { + reasoning_effort: ReasoningEffort::Low, + description: "Fast responses with lighter reasoning".to_string(), + }, + ReasoningEffortOption { + reasoning_effort: ReasoningEffort::Medium, + description: "Balances speed and reasoning depth for everyday tasks" + .to_string(), + }, + ReasoningEffortOption { + reasoning_effort: ReasoningEffort::High, + description: "Maximizes reasoning depth for complex problems".to_string(), + }, + ReasoningEffortOption { + reasoning_effort: ReasoningEffort::XHigh, + description: "Extra high reasoning depth for complex problems".to_string(), + }, + ], + default_reasoning_effort: ReasoningEffort::Medium, + is_default: true, + }, Model { id: "gpt-5.1-codex".to_string(), model: "gpt-5.1-codex".to_string(), @@ -66,7 +93,7 @@ async fn list_models_returns_all_models_with_large_limit() -> Result<()> { }, ], default_reasoning_effort: ReasoningEffort::Medium, - is_default: true, + is_default: false, }, Model { id: "gpt-5.1-codex-mini".to_string(), @@ -147,7 +174,7 @@ async fn list_models_pagination_works() -> Result<()> { } = to_response::(first_response)?; assert_eq!(first_items.len(), 1); - assert_eq!(first_items[0].id, "gpt-5.1-codex"); + assert_eq!(first_items[0].id, "gpt-5.1-codex-max"); let next_cursor = first_cursor.ok_or_else(|| anyhow!("cursor for second page"))?; let second_request = mcp @@ -169,7 +196,7 @@ async fn list_models_pagination_works() -> Result<()> { } = to_response::(second_response)?; assert_eq!(second_items.len(), 1); - assert_eq!(second_items[0].id, "gpt-5.1-codex-mini"); + assert_eq!(second_items[0].id, "gpt-5.1-codex"); let third_cursor = second_cursor.ok_or_else(|| anyhow!("cursor for third page"))?; let third_request = mcp @@ -191,8 +218,30 @@ async fn list_models_pagination_works() 
-> Result<()> { } = to_response::(third_response)?; assert_eq!(third_items.len(), 1); - assert_eq!(third_items[0].id, "gpt-5.1"); - assert!(third_cursor.is_none()); + assert_eq!(third_items[0].id, "gpt-5.1-codex-mini"); + let fourth_cursor = third_cursor.ok_or_else(|| anyhow!("cursor for fourth page"))?; + + let fourth_request = mcp + .send_list_models_request(ModelListParams { + limit: Some(1), + cursor: Some(fourth_cursor.clone()), + }) + .await?; + + let fourth_response: JSONRPCResponse = timeout( + DEFAULT_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(fourth_request)), + ) + .await??; + + let ModelListResponse { + data: fourth_items, + next_cursor: fourth_cursor, + } = to_response::(fourth_response)?; + + assert_eq!(fourth_items.len(), 1); + assert_eq!(fourth_items[0].id, "gpt-5.1"); + assert!(fourth_cursor.is_none()); Ok(()) } diff --git a/codex-rs/app-server/tests/suite/v2/rate_limits.rs b/codex-rs/app-server/tests/suite/v2/rate_limits.rs index d0cba8366..7ddccf7a7 100644 --- a/codex-rs/app-server/tests/suite/v2/rate_limits.rs +++ b/codex-rs/app-server/tests/suite/v2/rate_limits.rs @@ -152,6 +152,7 @@ async fn get_account_rate_limits_returns_snapshot() -> Result<()> { window_duration_mins: Some(1440), resets_at: Some(secondary_reset_timestamp), }), + credits: None, }, }; assert_eq!(received, expected); diff --git a/codex-rs/app-server/tests/suite/v2/review.rs b/codex-rs/app-server/tests/suite/v2/review.rs new file mode 100644 index 000000000..cdb3acd08 --- /dev/null +++ b/codex-rs/app-server/tests/suite/v2/review.rs @@ -0,0 +1,279 @@ +use anyhow::Result; +use app_test_support::McpProcess; +use app_test_support::create_final_assistant_message_sse_response; +use app_test_support::create_mock_chat_completions_server_unchecked; +use app_test_support::to_response; +use codex_app_server_protocol::ItemCompletedNotification; +use codex_app_server_protocol::ItemStartedNotification; +use codex_app_server_protocol::JSONRPCError; +use 
codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ReviewStartParams; +use codex_app_server_protocol::ReviewTarget; +use codex_app_server_protocol::ThreadItem; +use codex_app_server_protocol::ThreadStartParams; +use codex_app_server_protocol::ThreadStartResponse; +use codex_app_server_protocol::TurnStartResponse; +use codex_app_server_protocol::TurnStatus; +use serde_json::json; +use tempfile::TempDir; +use tokio::time::timeout; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); +const INVALID_REQUEST_ERROR_CODE: i64 = -32600; + +#[tokio::test] +async fn review_start_runs_review_turn_and_emits_code_review_item() -> Result<()> { + let review_payload = json!({ + "findings": [ + { + "title": "Prefer Stylize helpers", + "body": "Use .dim()/.bold() chaining instead of manual Style.", + "confidence_score": 0.9, + "priority": 1, + "code_location": { + "absolute_file_path": "/tmp/file.rs", + "line_range": {"start": 10, "end": 20} + } + } + ], + "overall_correctness": "good", + "overall_explanation": "Looks solid overall with minor polish suggested.", + "overall_confidence_score": 0.75 + }) + .to_string(); + let responses = vec![create_final_assistant_message_sse_response( + &review_payload, + )?]; + let server = create_mock_chat_completions_server_unchecked(responses).await; + + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let thread_id = start_default_thread(&mut mcp).await?; + + let review_req = mcp + .send_review_start_request(ReviewStartParams { + thread_id: thread_id.clone(), + append_to_original_thread: true, + target: ReviewTarget::Commit { + sha: "1234567deadbeef".to_string(), + title: Some("Tidy UI colors".to_string()), + }, + }) + .await?; 
+ let review_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(review_req)), + ) + .await??; + let TurnStartResponse { turn } = to_response::(review_resp)?; + let turn_id = turn.id.clone(); + assert_eq!(turn.status, TurnStatus::InProgress); + assert_eq!(turn.items.len(), 1); + match &turn.items[0] { + ThreadItem::UserMessage { content, .. } => { + assert_eq!(content.len(), 1); + assert!(matches!( + &content[0], + codex_app_server_protocol::UserInput::Text { .. } + )); + } + other => panic!("expected user message, got {other:?}"), + } + + let _started: JSONRPCNotification = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("turn/started"), + ) + .await??; + let item_started: JSONRPCNotification = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("item/started"), + ) + .await??; + let started: ItemStartedNotification = + serde_json::from_value(item_started.params.expect("params must be present"))?; + match started.item { + ThreadItem::CodeReview { id, review } => { + assert_eq!(id, turn_id); + assert_eq!(review, "commit 1234567"); + } + other => panic!("expected code review item, got {other:?}"), + } + + let mut review_body: Option = None; + for _ in 0..5 { + let review_notif: JSONRPCNotification = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("item/completed"), + ) + .await??; + let completed: ItemCompletedNotification = + serde_json::from_value(review_notif.params.expect("params must be present"))?; + match completed.item { + ThreadItem::CodeReview { id, review } => { + assert_eq!(id, turn_id); + review_body = Some(review); + break; + } + ThreadItem::UserMessage { .. 
} => continue, + other => panic!("unexpected item/completed payload: {other:?}"), + } + } + + let review = review_body.expect("did not observe a code review item"); + assert!(review.contains("Prefer Stylize helpers")); + assert!(review.contains("/tmp/file.rs:10-20")); + + Ok(()) +} + +#[tokio::test] +async fn review_start_rejects_empty_base_branch() -> Result<()> { + let server = create_mock_chat_completions_server_unchecked(vec![]).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + let thread_id = start_default_thread(&mut mcp).await?; + + let request_id = mcp + .send_review_start_request(ReviewStartParams { + thread_id, + append_to_original_thread: true, + target: ReviewTarget::BaseBranch { + branch: " ".to_string(), + }, + }) + .await?; + let error: JSONRPCError = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_error_message(RequestId::Integer(request_id)), + ) + .await??; + assert_eq!(error.error.code, INVALID_REQUEST_ERROR_CODE); + assert!( + error.error.message.contains("branch must not be empty"), + "unexpected message: {}", + error.error.message + ); + + Ok(()) +} + +#[tokio::test] +async fn review_start_rejects_empty_commit_sha() -> Result<()> { + let server = create_mock_chat_completions_server_unchecked(vec![]).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + let thread_id = start_default_thread(&mut mcp).await?; + + let request_id = mcp + .send_review_start_request(ReviewStartParams { + thread_id, + append_to_original_thread: true, + target: ReviewTarget::Commit { + sha: "\t".to_string(), + title: None, + }, + }) + .await?; + let error: JSONRPCError = timeout( + DEFAULT_READ_TIMEOUT, + 
mcp.read_stream_until_error_message(RequestId::Integer(request_id)), + ) + .await??; + assert_eq!(error.error.code, INVALID_REQUEST_ERROR_CODE); + assert!( + error.error.message.contains("sha must not be empty"), + "unexpected message: {}", + error.error.message + ); + + Ok(()) +} + +#[tokio::test] +async fn review_start_rejects_empty_custom_instructions() -> Result<()> { + let server = create_mock_chat_completions_server_unchecked(vec![]).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + let thread_id = start_default_thread(&mut mcp).await?; + + let request_id = mcp + .send_review_start_request(ReviewStartParams { + thread_id, + append_to_original_thread: true, + target: ReviewTarget::Custom { + instructions: "\n\n".to_string(), + }, + }) + .await?; + let error: JSONRPCError = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_error_message(RequestId::Integer(request_id)), + ) + .await??; + assert_eq!(error.error.code, INVALID_REQUEST_ERROR_CODE); + assert!( + error + .error + .message + .contains("instructions must not be empty"), + "unexpected message: {}", + error.error.message + ); + + Ok(()) +} + +async fn start_default_thread(mcp: &mut McpProcess) -> Result { + let thread_req = mcp + .send_thread_start_request(ThreadStartParams { + model: Some("mock-model".to_string()), + ..Default::default() + }) + .await?; + let thread_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(thread_req)), + ) + .await??; + let ThreadStartResponse { thread, .. 
} = to_response::(thread_resp)?; + Ok(thread.id) +} + +fn create_config_toml(codex_home: &std::path::Path, server_uri: &str) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "read-only" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider" +base_url = "{server_uri}/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +} diff --git a/codex-rs/app-server/tests/suite/v2/thread_archive.rs b/codex-rs/app-server/tests/suite/v2/thread_archive.rs index 083f3da90..88891af77 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_archive.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_archive.rs @@ -35,7 +35,7 @@ async fn thread_archive_moves_rollout_into_archived_directory() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(start_id)), ) .await??; - let ThreadStartResponse { thread } = to_response::(start_resp)?; + let ThreadStartResponse { thread, .. } = to_response::(start_resp)?; assert!(!thread.id.is_empty()); // Locate the rollout path recorded for this thread id. 
diff --git a/codex-rs/app-server/tests/suite/v2/thread_resume.rs b/codex-rs/app-server/tests/suite/v2/thread_resume.rs index bda2d1417..e22b711ea 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_resume.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_resume.rs @@ -1,13 +1,17 @@ use anyhow::Result; use app_test_support::McpProcess; +use app_test_support::create_fake_rollout; use app_test_support::create_mock_chat_completions_server; use app_test_support::to_response; use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ThreadItem; use codex_app_server_protocol::ThreadResumeParams; use codex_app_server_protocol::ThreadResumeResponse; use codex_app_server_protocol::ThreadStartParams; use codex_app_server_protocol::ThreadStartResponse; +use codex_app_server_protocol::TurnStatus; +use codex_app_server_protocol::UserInput; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; use tempfile::TempDir; @@ -27,7 +31,7 @@ async fn thread_resume_returns_original_thread() -> Result<()> { // Start a thread. let start_id = mcp .send_thread_start_request(ThreadStartParams { - model: Some("gpt-5.1-codex".to_string()), + model: Some("gpt-5.1-codex-max".to_string()), ..Default::default() }) .await?; @@ -36,7 +40,7 @@ async fn thread_resume_returns_original_thread() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(start_id)), ) .await??; - let ThreadStartResponse { thread } = to_response::(start_resp)?; + let ThreadStartResponse { thread, .. } = to_response::(start_resp)?; // Resume it via v2 API. let resume_id = mcp @@ -50,13 +54,73 @@ async fn thread_resume_returns_original_thread() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(resume_id)), ) .await??; - let ThreadResumeResponse { thread: resumed } = - to_response::(resume_resp)?; + let ThreadResumeResponse { + thread: resumed, .. 
+ } = to_response::(resume_resp)?; assert_eq!(resumed, thread); Ok(()) } +#[tokio::test] +async fn thread_resume_returns_rollout_history() -> Result<()> { + let server = create_mock_chat_completions_server(vec![]).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + + let preview = "Saved user message"; + let conversation_id = create_fake_rollout( + codex_home.path(), + "2025-01-05T12-00-00", + "2025-01-05T12:00:00Z", + preview, + Some("mock_provider"), + )?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let resume_id = mcp + .send_thread_resume_request(ThreadResumeParams { + thread_id: conversation_id.clone(), + ..Default::default() + }) + .await?; + let resume_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(resume_id)), + ) + .await??; + let ThreadResumeResponse { thread, .. } = to_response::(resume_resp)?; + + assert_eq!(thread.id, conversation_id); + assert_eq!(thread.preview, preview); + assert_eq!(thread.model_provider, "mock_provider"); + assert!(thread.path.is_absolute()); + + assert_eq!( + thread.turns.len(), + 1, + "expected rollouts to include one turn" + ); + let turn = &thread.turns[0]; + assert_eq!(turn.status, TurnStatus::Completed); + assert_eq!(turn.items.len(), 1, "expected user message item"); + match &turn.items[0] { + ThreadItem::UserMessage { content, .. 
} => { + assert_eq!( + content, + &vec![UserInput::Text { + text: preview.to_string() + }] + ); + } + other => panic!("expected user message item, got {other:?}"), + } + + Ok(()) +} + #[tokio::test] async fn thread_resume_prefers_path_over_thread_id() -> Result<()> { let server = create_mock_chat_completions_server(vec![]).await; @@ -68,7 +132,7 @@ async fn thread_resume_prefers_path_over_thread_id() -> Result<()> { let start_id = mcp .send_thread_start_request(ThreadStartParams { - model: Some("gpt-5.1-codex".to_string()), + model: Some("gpt-5.1-codex-max".to_string()), ..Default::default() }) .await?; @@ -77,7 +141,7 @@ async fn thread_resume_prefers_path_over_thread_id() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(start_id)), ) .await??; - let ThreadStartResponse { thread } = to_response::(start_resp)?; + let ThreadStartResponse { thread, .. } = to_response::(start_resp)?; let thread_path = thread.path.clone(); let resume_id = mcp @@ -93,8 +157,9 @@ async fn thread_resume_prefers_path_over_thread_id() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(resume_id)), ) .await??; - let ThreadResumeResponse { thread: resumed } = - to_response::(resume_resp)?; + let ThreadResumeResponse { + thread: resumed, .. + } = to_response::(resume_resp)?; assert_eq!(resumed, thread); Ok(()) @@ -112,7 +177,7 @@ async fn thread_resume_supports_history_and_overrides() -> Result<()> { // Start a thread. let start_id = mcp .send_thread_start_request(ThreadStartParams { - model: Some("gpt-5.1-codex".to_string()), + model: Some("gpt-5.1-codex-max".to_string()), ..Default::default() }) .await?; @@ -121,7 +186,7 @@ async fn thread_resume_supports_history_and_overrides() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(start_id)), ) .await??; - let ThreadStartResponse { thread } = to_response::(start_resp)?; + let ThreadStartResponse { thread, .. 
} = to_response::(start_resp)?; let history_text = "Hello from history"; let history = vec![ResponseItem::Message { @@ -147,10 +212,13 @@ async fn thread_resume_supports_history_and_overrides() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(resume_id)), ) .await??; - let ThreadResumeResponse { thread: resumed } = - to_response::(resume_resp)?; + let ThreadResumeResponse { + thread: resumed, + model_provider, + .. + } = to_response::(resume_resp)?; assert!(!resumed.id.is_empty()); - assert_eq!(resumed.model_provider, "mock_provider"); + assert_eq!(model_provider, "mock_provider"); assert_eq!(resumed.preview, history_text); Ok(()) diff --git a/codex-rs/app-server/tests/suite/v2/thread_start.rs b/codex-rs/app-server/tests/suite/v2/thread_start.rs index a5e4c0d48..ad0949ba2 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_start.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_start.rs @@ -40,13 +40,17 @@ async fn thread_start_creates_thread_and_emits_started() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(req_id)), ) .await??; - let ThreadStartResponse { thread } = to_response::(resp)?; + let ThreadStartResponse { + thread, + model_provider, + .. 
+ } = to_response::(resp)?; assert!(!thread.id.is_empty(), "thread id should not be empty"); assert!( thread.preview.is_empty(), "new threads should start with an empty preview" ); - assert_eq!(thread.model_provider, "mock_provider"); + assert_eq!(model_provider, "mock_provider"); assert!( thread.created_at > 0, "created_at should be a positive UNIX timestamp" diff --git a/codex-rs/app-server/tests/suite/v2/turn_interrupt.rs b/codex-rs/app-server/tests/suite/v2/turn_interrupt.rs index d1deb6080..83389ed1d 100644 --- a/codex-rs/app-server/tests/suite/v2/turn_interrupt.rs +++ b/codex-rs/app-server/tests/suite/v2/turn_interrupt.rs @@ -3,16 +3,19 @@ use anyhow::Result; use app_test_support::McpProcess; use app_test_support::create_mock_chat_completions_server; -use app_test_support::create_shell_sse_response; +use app_test_support::create_shell_command_sse_response; use app_test_support::to_response; +use codex_app_server_protocol::JSONRPCNotification; use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ThreadStartParams; use codex_app_server_protocol::ThreadStartResponse; +use codex_app_server_protocol::TurnCompletedNotification; use codex_app_server_protocol::TurnInterruptParams; use codex_app_server_protocol::TurnInterruptResponse; use codex_app_server_protocol::TurnStartParams; use codex_app_server_protocol::TurnStartResponse; +use codex_app_server_protocol::TurnStatus; use codex_app_server_protocol::UserInput as V2UserInput; use tempfile::TempDir; use tokio::time::timeout; @@ -38,7 +41,7 @@ async fn turn_interrupt_aborts_running_turn() -> Result<()> { std::fs::create_dir(&working_directory)?; // Mock server: long-running shell command then (after abort) nothing else needed. 
- let server = create_mock_chat_completions_server(vec![create_shell_sse_response( + let server = create_mock_chat_completions_server(vec![create_shell_command_sse_response( shell_command.clone(), Some(&working_directory), Some(10_000), @@ -62,7 +65,7 @@ async fn turn_interrupt_aborts_running_turn() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(thread_req)), ) .await??; - let ThreadStartResponse { thread } = to_response::(thread_resp)?; + let ThreadStartResponse { thread, .. } = to_response::(thread_resp)?; // Start a turn that triggers a long-running command. let turn_req = mcp @@ -99,7 +102,18 @@ async fn turn_interrupt_aborts_running_turn() -> Result<()> { .await??; let _resp: TurnInterruptResponse = to_response::(interrupt_resp)?; - // No fields to assert on; successful deserialization confirms proper response shape. + let completed_notif: JSONRPCNotification = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + let completed: TurnCompletedNotification = serde_json::from_value( + completed_notif + .params + .expect("turn/completed params must be present"), + )?; + assert_eq!(completed.turn.status, TurnStatus::Interrupted); + Ok(()) } diff --git a/codex-rs/app-server/tests/suite/v2/turn_start.rs b/codex-rs/app-server/tests/suite/v2/turn_start.rs index 433c7b448..de4d1cd2c 100644 --- a/codex-rs/app-server/tests/suite/v2/turn_start.rs +++ b/codex-rs/app-server/tests/suite/v2/turn_start.rs @@ -1,22 +1,32 @@ use anyhow::Result; use app_test_support::McpProcess; +use app_test_support::create_apply_patch_sse_response; use app_test_support::create_final_assistant_message_sse_response; use app_test_support::create_mock_chat_completions_server; use app_test_support::create_mock_chat_completions_server_unchecked; -use app_test_support::create_shell_sse_response; +use app_test_support::create_shell_command_sse_response; +use app_test_support::format_with_current_shell_display; use 
app_test_support::to_response; +use codex_app_server_protocol::ApprovalDecision; +use codex_app_server_protocol::CommandExecutionRequestApprovalResponse; use codex_app_server_protocol::CommandExecutionStatus; +use codex_app_server_protocol::FileChangeRequestApprovalResponse; +use codex_app_server_protocol::ItemCompletedNotification; use codex_app_server_protocol::ItemStartedNotification; use codex_app_server_protocol::JSONRPCNotification; use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::PatchApplyStatus; +use codex_app_server_protocol::PatchChangeKind; use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ServerRequest; use codex_app_server_protocol::ThreadItem; use codex_app_server_protocol::ThreadStartParams; use codex_app_server_protocol::ThreadStartResponse; +use codex_app_server_protocol::TurnCompletedNotification; use codex_app_server_protocol::TurnStartParams; use codex_app_server_protocol::TurnStartResponse; use codex_app_server_protocol::TurnStartedNotification; +use codex_app_server_protocol::TurnStatus; use codex_app_server_protocol::UserInput as V2UserInput; use codex_core::protocol_config_types::ReasoningEffort; use codex_core::protocol_config_types::ReasoningSummary; @@ -57,7 +67,7 @@ async fn turn_start_emits_notifications_and_accepts_model_override() -> Result<( mcp.read_stream_until_response_message(RequestId::Integer(thread_req)), ) .await??; - let ThreadStartResponse { thread } = to_response::(thread_resp)?; + let ThreadStartResponse { thread, .. } = to_response::(thread_resp)?; // Start a turn with only input and thread_id set (no overrides). let turn_req = mcp @@ -118,13 +128,17 @@ async fn turn_start_emits_notifications_and_accepts_model_override() -> Result<( ) .await??; - // And we should ultimately get a task_complete without having to add a - // legacy conversation listener explicitly (auto-attached by thread/start). 
- let _task_complete: JSONRPCNotification = timeout( + let completed_notif: JSONRPCNotification = timeout( DEFAULT_READ_TIMEOUT, - mcp.read_stream_until_notification_message("codex/event/task_complete"), + mcp.read_stream_until_notification_message("turn/completed"), ) .await??; + let completed: TurnCompletedNotification = serde_json::from_value( + completed_notif + .params + .expect("turn/completed params must be present"), + )?; + assert_eq!(completed.turn.status, TurnStatus::Completed); Ok(()) } @@ -157,7 +171,7 @@ async fn turn_start_accepts_local_image_input() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(thread_req)), ) .await??; - let ThreadStartResponse { thread } = to_response::(thread_resp)?; + let ThreadStartResponse { thread, .. } = to_response::(thread_resp)?; let image_path = codex_home.path().join("image.png"); // No need to actually write the file; we just exercise the input path. @@ -191,7 +205,7 @@ async fn turn_start_exec_approval_toggle_v2() -> Result<()> { // Mock server: first turn requests a shell call (elicitation), then completes. // Second turn same, but we'll set approval_policy=never to avoid elicitation. let responses = vec![ - create_shell_sse_response( + create_shell_command_sse_response( vec![ "python3".to_string(), "-c".to_string(), @@ -202,7 +216,7 @@ async fn turn_start_exec_approval_toggle_v2() -> Result<()> { "call1", )?, create_final_assistant_message_sse_response("done 1")?, - create_shell_sse_response( + create_shell_command_sse_response( vec![ "python3".to_string(), "-c".to_string(), @@ -233,7 +247,7 @@ async fn turn_start_exec_approval_toggle_v2() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(start_id)), ) .await??; - let ThreadStartResponse { thread } = to_response::(start_resp)?; + let ThreadStartResponse { thread, .. 
} = to_response::(start_resp)?; // turn/start — expect CommandExecutionRequestApproval request from server let first_turn_id = mcp @@ -274,6 +288,11 @@ async fn turn_start_exec_approval_toggle_v2() -> Result<()> { mcp.read_stream_until_notification_message("codex/event/task_complete"), ) .await??; + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; // Second turn with approval_policy=never should not elicit approval let second_turn_id = mcp @@ -297,6 +316,150 @@ async fn turn_start_exec_approval_toggle_v2() -> Result<()> { .await??; // Ensure we do NOT receive a CommandExecutionRequestApproval request before task completes + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("codex/event/task_complete"), + ) + .await??; + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + + Ok(()) +} + +#[tokio::test] +async fn turn_start_exec_approval_decline_v2() -> Result<()> { + skip_if_no_network!(Ok(())); + + let tmp = TempDir::new()?; + let codex_home = tmp.path().to_path_buf(); + let workspace = tmp.path().join("workspace"); + std::fs::create_dir(&workspace)?; + + let responses = vec![ + create_shell_command_sse_response( + vec![ + "python3".to_string(), + "-c".to_string(), + "print(42)".to_string(), + ], + None, + Some(5000), + "call-decline", + )?, + create_final_assistant_message_sse_response("done")?, + ]; + let server = create_mock_chat_completions_server(responses).await; + create_config_toml(codex_home.as_path(), &server.uri(), "untrusted")?; + + let mut mcp = McpProcess::new(codex_home.as_path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let start_id = mcp + .send_thread_start_request(ThreadStartParams { + model: Some("mock-model".to_string()), + ..Default::default() + }) + .await?; + let start_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + 
mcp.read_stream_until_response_message(RequestId::Integer(start_id)), + ) + .await??; + let ThreadStartResponse { thread, .. } = to_response::(start_resp)?; + + let turn_id = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id.clone(), + input: vec![V2UserInput::Text { + text: "run python".to_string(), + }], + cwd: Some(workspace.clone()), + ..Default::default() + }) + .await?; + let turn_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(turn_id)), + ) + .await??; + let TurnStartResponse { turn } = to_response::(turn_resp)?; + + let started_command_execution = timeout(DEFAULT_READ_TIMEOUT, async { + loop { + let started_notif = mcp + .read_stream_until_notification_message("item/started") + .await?; + let started: ItemStartedNotification = + serde_json::from_value(started_notif.params.clone().expect("item/started params"))?; + if let ThreadItem::CommandExecution { .. } = started.item { + return Ok::(started.item); + } + } + }) + .await??; + let ThreadItem::CommandExecution { id, status, .. 
} = started_command_execution else { + unreachable!("loop ensures we break on command execution items"); + }; + assert_eq!(id, "call-decline"); + assert_eq!(status, CommandExecutionStatus::InProgress); + + let server_req = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_request_message(), + ) + .await??; + let ServerRequest::CommandExecutionRequestApproval { request_id, params } = server_req else { + panic!("expected CommandExecutionRequestApproval request") + }; + assert_eq!(params.item_id, "call-decline"); + assert_eq!(params.thread_id, thread.id); + assert_eq!(params.turn_id, turn.id); + + mcp.send_response( + request_id, + serde_json::to_value(CommandExecutionRequestApprovalResponse { + decision: ApprovalDecision::Decline, + accept_settings: None, + })?, + ) + .await?; + + let completed_command_execution = timeout(DEFAULT_READ_TIMEOUT, async { + loop { + let completed_notif = mcp + .read_stream_until_notification_message("item/completed") + .await?; + let completed: ItemCompletedNotification = serde_json::from_value( + completed_notif + .params + .clone() + .expect("item/completed params"), + )?; + if let ThreadItem::CommandExecution { .. } = completed.item { + return Ok::(completed.item); + } + } + }) + .await??; + let ThreadItem::CommandExecution { + id, + status, + exit_code, + aggregated_output, + .. 
+ } = completed_command_execution + else { + unreachable!("loop ensures we break on command execution items"); + }; + assert_eq!(id, "call-decline"); + assert_eq!(status, CommandExecutionStatus::Declined); + assert!(exit_code.is_none()); + assert!(aggregated_output.is_none()); + timeout( DEFAULT_READ_TIMEOUT, mcp.read_stream_until_notification_message("codex/event/task_complete"), @@ -321,23 +484,15 @@ async fn turn_start_updates_sandbox_and_cwd_between_turns_v2() -> Result<()> { std::fs::create_dir(&second_cwd)?; let responses = vec![ - create_shell_sse_response( - vec![ - "bash".to_string(), - "-lc".to_string(), - "echo first turn".to_string(), - ], + create_shell_command_sse_response( + vec!["echo".to_string(), "first".to_string(), "turn".to_string()], None, Some(5000), "call-first", )?, create_final_assistant_message_sse_response("done first")?, - create_shell_sse_response( - vec![ - "bash".to_string(), - "-lc".to_string(), - "echo second turn".to_string(), - ], + create_shell_command_sse_response( + vec!["echo".to_string(), "second".to_string(), "turn".to_string()], None, Some(5000), "call-second", @@ -362,7 +517,7 @@ async fn turn_start_updates_sandbox_and_cwd_between_turns_v2() -> Result<()> { mcp.read_stream_until_response_message(RequestId::Integer(start_id)), ) .await??; - let ThreadStartResponse { thread } = to_response::(start_resp)?; + let ThreadStartResponse { thread, .. 
} = to_response::(start_resp)?; // first turn with workspace-write sandbox and first_cwd let first_turn = mcp @@ -443,7 +598,8 @@ async fn turn_start_updates_sandbox_and_cwd_between_turns_v2() -> Result<()> { unreachable!("loop ensures we break on command execution items"); }; assert_eq!(cwd, second_cwd); - assert_eq!(command, "bash -lc 'echo second turn'"); + let expected_command = format_with_current_shell_display("echo second turn"); + assert_eq!(command, expected_command); assert_eq!(status, CommandExecutionStatus::InProgress); timeout( @@ -455,6 +611,308 @@ async fn turn_start_updates_sandbox_and_cwd_between_turns_v2() -> Result<()> { Ok(()) } +#[tokio::test] +async fn turn_start_file_change_approval_v2() -> Result<()> { + skip_if_no_network!(Ok(())); + if cfg!(windows) { + // TODO apply_patch approvals are not parsed from powershell commands yet + return Ok(()); + } + + let tmp = TempDir::new()?; + let codex_home = tmp.path().join("codex_home"); + std::fs::create_dir(&codex_home)?; + let workspace = tmp.path().join("workspace"); + std::fs::create_dir(&workspace)?; + + let patch = r#"*** Begin Patch +*** Add File: README.md ++new line +*** End Patch +"#; + let responses = vec![ + create_apply_patch_sse_response(patch, "patch-call")?, + create_final_assistant_message_sse_response("patch applied")?, + ]; + let server = create_mock_chat_completions_server(responses).await; + create_config_toml(&codex_home, &server.uri(), "untrusted")?; + + let mut mcp = McpProcess::new(&codex_home).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let start_req = mcp + .send_thread_start_request(ThreadStartParams { + model: Some("mock-model".to_string()), + cwd: Some(workspace.to_string_lossy().into_owned()), + ..Default::default() + }) + .await?; + let start_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(start_req)), + ) + .await??; + let ThreadStartResponse { thread, .. 
} = to_response::(start_resp)?; + + let turn_req = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id.clone(), + input: vec![V2UserInput::Text { + text: "apply patch".into(), + }], + cwd: Some(workspace.clone()), + ..Default::default() + }) + .await?; + let turn_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(turn_req)), + ) + .await??; + let TurnStartResponse { turn } = to_response::(turn_resp)?; + + let started_file_change = timeout(DEFAULT_READ_TIMEOUT, async { + loop { + let started_notif = mcp + .read_stream_until_notification_message("item/started") + .await?; + let started: ItemStartedNotification = + serde_json::from_value(started_notif.params.clone().expect("item/started params"))?; + if let ThreadItem::FileChange { .. } = started.item { + return Ok::(started.item); + } + } + }) + .await??; + let ThreadItem::FileChange { + ref id, + status, + ref changes, + } = started_file_change + else { + unreachable!("loop ensures we break on file change items"); + }; + assert_eq!(id, "patch-call"); + assert_eq!(status, PatchApplyStatus::InProgress); + let started_changes = changes.clone(); + + let server_req = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_request_message(), + ) + .await??; + let ServerRequest::FileChangeRequestApproval { request_id, params } = server_req else { + panic!("expected FileChangeRequestApproval request") + }; + assert_eq!(params.item_id, "patch-call"); + assert_eq!(params.thread_id, thread.id); + assert_eq!(params.turn_id, turn.id); + let expected_readme_path = workspace.join("README.md"); + let expected_readme_path = expected_readme_path.to_string_lossy().into_owned(); + pretty_assertions::assert_eq!( + started_changes, + vec![codex_app_server_protocol::FileUpdateChange { + path: expected_readme_path.clone(), + kind: PatchChangeKind::Add, + diff: "new line\n".to_string(), + }] + ); + + mcp.send_response( + request_id, + 
serde_json::to_value(FileChangeRequestApprovalResponse { + decision: ApprovalDecision::Accept, + })?, + ) + .await?; + + let completed_file_change = timeout(DEFAULT_READ_TIMEOUT, async { + loop { + let completed_notif = mcp + .read_stream_until_notification_message("item/completed") + .await?; + let completed: ItemCompletedNotification = serde_json::from_value( + completed_notif + .params + .clone() + .expect("item/completed params"), + )?; + if let ThreadItem::FileChange { .. } = completed.item { + return Ok::(completed.item); + } + } + }) + .await??; + let ThreadItem::FileChange { ref id, status, .. } = completed_file_change else { + unreachable!("loop ensures we break on file change items"); + }; + assert_eq!(id, "patch-call"); + assert_eq!(status, PatchApplyStatus::Completed); + + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("codex/event/task_complete"), + ) + .await??; + + let readme_contents = std::fs::read_to_string(expected_readme_path)?; + assert_eq!(readme_contents, "new line\n"); + + Ok(()) +} + +#[tokio::test] +async fn turn_start_file_change_approval_decline_v2() -> Result<()> { + skip_if_no_network!(Ok(())); + if cfg!(windows) { + // TODO apply_patch approvals are not parsed from powershell commands yet + return Ok(()); + } + + let tmp = TempDir::new()?; + let codex_home = tmp.path().join("codex_home"); + std::fs::create_dir(&codex_home)?; + let workspace = tmp.path().join("workspace"); + std::fs::create_dir(&workspace)?; + + let patch = r#"*** Begin Patch +*** Add File: README.md ++new line +*** End Patch +"#; + let responses = vec![ + create_apply_patch_sse_response(patch, "patch-call")?, + create_final_assistant_message_sse_response("patch declined")?, + ]; + let server = create_mock_chat_completions_server(responses).await; + create_config_toml(&codex_home, &server.uri(), "untrusted")?; + + let mut mcp = McpProcess::new(&codex_home).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let 
start_req = mcp + .send_thread_start_request(ThreadStartParams { + model: Some("mock-model".to_string()), + cwd: Some(workspace.to_string_lossy().into_owned()), + ..Default::default() + }) + .await?; + let start_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(start_req)), + ) + .await??; + let ThreadStartResponse { thread, .. } = to_response::(start_resp)?; + + let turn_req = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id.clone(), + input: vec![V2UserInput::Text { + text: "apply patch".into(), + }], + cwd: Some(workspace.clone()), + ..Default::default() + }) + .await?; + let turn_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(turn_req)), + ) + .await??; + let TurnStartResponse { turn } = to_response::(turn_resp)?; + + let started_file_change = timeout(DEFAULT_READ_TIMEOUT, async { + loop { + let started_notif = mcp + .read_stream_until_notification_message("item/started") + .await?; + let started: ItemStartedNotification = + serde_json::from_value(started_notif.params.clone().expect("item/started params"))?; + if let ThreadItem::FileChange { .. 
} = started.item { + return Ok::(started.item); + } + } + }) + .await??; + let ThreadItem::FileChange { + ref id, + status, + ref changes, + } = started_file_change + else { + unreachable!("loop ensures we break on file change items"); + }; + assert_eq!(id, "patch-call"); + assert_eq!(status, PatchApplyStatus::InProgress); + let started_changes = changes.clone(); + + let server_req = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_request_message(), + ) + .await??; + let ServerRequest::FileChangeRequestApproval { request_id, params } = server_req else { + panic!("expected FileChangeRequestApproval request") + }; + assert_eq!(params.item_id, "patch-call"); + assert_eq!(params.thread_id, thread.id); + assert_eq!(params.turn_id, turn.id); + let expected_readme_path = workspace.join("README.md"); + let expected_readme_path_str = expected_readme_path.to_string_lossy().into_owned(); + pretty_assertions::assert_eq!( + started_changes, + vec![codex_app_server_protocol::FileUpdateChange { + path: expected_readme_path_str.clone(), + kind: PatchChangeKind::Add, + diff: "new line\n".to_string(), + }] + ); + + mcp.send_response( + request_id, + serde_json::to_value(FileChangeRequestApprovalResponse { + decision: ApprovalDecision::Decline, + })?, + ) + .await?; + + let completed_file_change = timeout(DEFAULT_READ_TIMEOUT, async { + loop { + let completed_notif = mcp + .read_stream_until_notification_message("item/completed") + .await?; + let completed: ItemCompletedNotification = serde_json::from_value( + completed_notif + .params + .clone() + .expect("item/completed params"), + )?; + if let ThreadItem::FileChange { .. } = completed.item { + return Ok::(completed.item); + } + } + }) + .await??; + let ThreadItem::FileChange { ref id, status, .. 
} = completed_file_change else { + unreachable!("loop ensures we break on file change items"); + }; + assert_eq!(id, "patch-call"); + assert_eq!(status, PatchApplyStatus::Declined); + + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("codex/event/task_complete"), + ) + .await??; + + assert!( + !expected_readme_path.exists(), + "declined patch should not be applied" + ); + + Ok(()) +} + // Helper to create a config.toml pointing at the mock model server. fn create_config_toml( codex_home: &Path, diff --git a/codex-rs/apply-patch/src/lib.rs b/codex-rs/apply-patch/src/lib.rs index ac2f40979..a2bea0514 100644 --- a/codex-rs/apply-patch/src/lib.rs +++ b/codex-rs/apply-patch/src/lib.rs @@ -30,6 +30,7 @@ pub use standalone_executable::main; pub const APPLY_PATCH_TOOL_INSTRUCTIONS: &str = include_str!("../apply_patch_tool_instructions.md"); const APPLY_PATCH_COMMANDS: [&str; 2] = ["apply_patch", "applypatch"]; +const APPLY_PATCH_SHELLS: [&str; 3] = ["bash", "zsh", "sh"]; #[derive(Debug, Error, PartialEq)] pub enum ApplyPatchError { @@ -96,6 +97,13 @@ pub struct ApplyPatchArgs { pub workdir: Option, } +fn shell_supports_apply_patch(shell: &str) -> bool { + std::path::Path::new(shell) + .file_name() + .and_then(|name| name.to_str()) + .is_some_and(|name| APPLY_PATCH_SHELLS.contains(&name)) +} + pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch { match argv { // Direct invocation: apply_patch @@ -104,7 +112,7 @@ pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch { Err(e) => MaybeApplyPatch::PatchParseError(e), }, // Bash heredoc form: (optional `cd &&`) apply_patch <<'EOF' ... 
- [bash, flag, script] if bash == "bash" && flag == "-lc" => { + [shell, flag, script] if shell_supports_apply_patch(shell) && flag == "-lc" => { match extract_apply_patch_from_bash(script) { Ok((body, workdir)) => match parse_patch(&body) { Ok(mut source) => { @@ -224,12 +232,12 @@ pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApp ); } } - [bash, flag, script] if bash == "bash" && flag == "-lc" => { - if parse_patch(script).is_ok() { - return MaybeApplyPatchVerified::CorrectnessError( - ApplyPatchError::ImplicitInvocation, - ); - } + [shell, flag, script] + if shell_supports_apply_patch(shell) + && flag == "-lc" + && parse_patch(script).is_ok() => + { + return MaybeApplyPatchVerified::CorrectnessError(ApplyPatchError::ImplicitInvocation); } _ => {} } diff --git a/codex-rs/backend-client/src/client.rs b/codex-rs/backend-client/src/client.rs index 28a51598e..0fb627ef0 100644 --- a/codex-rs/backend-client/src/client.rs +++ b/codex-rs/backend-client/src/client.rs @@ -1,4 +1,5 @@ use crate::types::CodeTaskDetailsResponse; +use crate::types::CreditStatusDetails; use crate::types::PaginatedListTaskListItem; use crate::types::RateLimitStatusPayload; use crate::types::RateLimitWindowSnapshot; @@ -6,6 +7,7 @@ use crate::types::TurnAttemptsSiblingTurnsResponse; use anyhow::Result; use codex_core::auth::CodexAuth; use codex_core::default_client::get_codex_user_agent; +use codex_protocol::protocol::CreditsSnapshot; use codex_protocol::protocol::RateLimitSnapshot; use codex_protocol::protocol::RateLimitWindow; use reqwest::header::AUTHORIZATION; @@ -272,19 +274,23 @@ impl Client { // rate limit helpers fn rate_limit_snapshot_from_payload(payload: RateLimitStatusPayload) -> RateLimitSnapshot { - let Some(details) = payload + let rate_limit_details = payload .rate_limit - .and_then(|inner| inner.map(|boxed| *boxed)) - else { - return RateLimitSnapshot { - primary: None, - secondary: None, - }; + .and_then(|inner| inner.map(|boxed| *boxed)); + + let 
(primary, secondary) = if let Some(details) = rate_limit_details { + ( + Self::map_rate_limit_window(details.primary_window), + Self::map_rate_limit_window(details.secondary_window), + ) + } else { + (None, None) }; RateLimitSnapshot { - primary: Self::map_rate_limit_window(details.primary_window), - secondary: Self::map_rate_limit_window(details.secondary_window), + primary, + secondary, + credits: Self::map_credits(payload.credits), } } @@ -306,6 +312,19 @@ impl Client { }) } + fn map_credits(credits: Option>>) -> Option { + let details = match credits { + Some(Some(details)) => *details, + _ => return None, + }; + + Some(CreditsSnapshot { + has_credits: details.has_credits, + unlimited: details.unlimited, + balance: details.balance.and_then(|inner| inner), + }) + } + fn window_minutes_from_seconds(seconds: i32) -> Option { if seconds <= 0 { return None; diff --git a/codex-rs/backend-client/src/types.rs b/codex-rs/backend-client/src/types.rs index 9f196f9c2..afeb231a1 100644 --- a/codex-rs/backend-client/src/types.rs +++ b/codex-rs/backend-client/src/types.rs @@ -1,3 +1,4 @@ +pub use codex_backend_openapi_models::models::CreditStatusDetails; pub use codex_backend_openapi_models::models::PaginatedListTaskListItem; pub use codex_backend_openapi_models::models::PlanType; pub use codex_backend_openapi_models::models::RateLimitStatusDetails; diff --git a/codex-rs/cli/Cargo.toml b/codex-rs/cli/Cargo.toml index deddc068c..e7999b5ce 100644 --- a/codex-rs/cli/Cargo.toml +++ b/codex-rs/cli/Cargo.toml @@ -26,6 +26,7 @@ codex-cloud-tasks = { path = "../cloud-tasks" } codex-common = { workspace = true, features = ["cli"] } codex-core = { workspace = true } codex-exec = { workspace = true } +codex-execpolicy = { workspace = true } codex-login = { workspace = true } codex-mcp-server = { workspace = true } codex-process-hardening = { workspace = true } diff --git a/codex-rs/cli/src/debug_sandbox.rs b/codex-rs/cli/src/debug_sandbox.rs index df4c2e97c..26fecd55c 100644 --- 
a/codex-rs/cli/src/debug_sandbox.rs +++ b/codex-rs/cli/src/debug_sandbox.rs @@ -138,11 +138,7 @@ async fn run_command_under_sandbox( { use codex_windows_sandbox::run_windows_sandbox_capture; - let policy_str = match &config.sandbox_policy { - codex_core::protocol::SandboxPolicy::DangerFullAccess => "workspace-write", - codex_core::protocol::SandboxPolicy::ReadOnly => "read-only", - codex_core::protocol::SandboxPolicy::WorkspaceWrite { .. } => "workspace-write", - }; + let policy_str = serde_json::to_string(&config.sandbox_policy)?; let sandbox_cwd = sandbox_policy_cwd.clone(); let cwd_clone = cwd.clone(); @@ -153,7 +149,7 @@ async fn run_command_under_sandbox( // Preflight audit is invoked elsewhere at the appropriate times. let res = tokio::task::spawn_blocking(move || { run_windows_sandbox_capture( - policy_str, + policy_str.as_str(), &sandbox_cwd, base_dir.as_path(), command_vec, diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 6a3b24aa9..2b066197a 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -18,6 +18,7 @@ use codex_cli::login::run_logout; use codex_cloud_tasks::Cli as CloudTasksCli; use codex_common::CliConfigOverrides; use codex_exec::Cli as ExecCli; +use codex_execpolicy::ExecPolicyCheckCommand; use codex_responses_api_proxy::Args as ResponsesApiProxyArgs; use codex_tui::AppExitInfo; use codex_tui::Cli as TuiCli; @@ -93,6 +94,10 @@ enum Subcommand { #[clap(visible_alias = "debug")] Sandbox(SandboxArgs), + /// Execpolicy tooling. + #[clap(hide = true)] + Execpolicy(ExecpolicyCommand), + /// Apply the latest diff produced by Codex agent as a `git apply` to your local working tree. #[clap(visible_alias = "a")] Apply(ApplyCommand), @@ -134,6 +139,10 @@ struct ResumeCommand { #[arg(long = "last", default_value_t = false, conflicts_with = "session_id")] last: bool, + /// Show all sessions (disables cwd filtering and shows CWD column). 
+ #[arg(long = "all", default_value_t = false)] + all: bool, + #[clap(flatten)] config_overrides: TuiCli, } @@ -158,6 +167,19 @@ enum SandboxCommand { Windows(WindowsCommand), } +#[derive(Debug, Parser)] +struct ExecpolicyCommand { + #[command(subcommand)] + sub: ExecpolicySubcommand, +} + +#[derive(Debug, clap::Subcommand)] +enum ExecpolicySubcommand { + /// Check execpolicy files against a command. + #[clap(name = "check")] + Check(ExecPolicyCheckCommand), +} + #[derive(Debug, Parser)] struct LoginCommand { #[clap(skip)] @@ -323,6 +345,10 @@ fn run_update_action(action: UpdateAction) -> anyhow::Result<()> { Ok(()) } +fn run_execpolicycheck(cmd: ExecPolicyCheckCommand) -> anyhow::Result<()> { + cmd.run() +} + #[derive(Debug, Default, Parser, Clone)] struct FeatureToggles { /// Enable a feature (repeatable). Equivalent to `-c features.=true`. @@ -448,6 +474,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() Some(Subcommand::Resume(ResumeCommand { session_id, last, + all, config_overrides, })) => { interactive = finalize_resume_interactive( @@ -455,6 +482,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() root_config_overrides.clone(), session_id, last, + all, config_overrides, ); let exit_info = codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?; @@ -543,6 +571,9 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() .await?; } }, + Some(Subcommand::Execpolicy(ExecpolicyCommand { sub })) => match sub { + ExecpolicySubcommand::Check(cmd) => run_execpolicycheck(cmd)?, + }, Some(Subcommand::Apply(mut apply_cli)) => { prepend_config_flags( &mut apply_cli.config_overrides, @@ -611,6 +642,7 @@ fn finalize_resume_interactive( root_config_overrides: CliConfigOverrides, session_id: Option, last: bool, + show_all: bool, resume_cli: TuiCli, ) -> TuiCli { // Start with the parsed interactive CLI so resume shares the same @@ -619,6 +651,7 @@ fn finalize_resume_interactive( 
interactive.resume_picker = resume_session_id.is_none() && !last; interactive.resume_last = last; interactive.resume_session_id = resume_session_id; + interactive.resume_show_all = show_all; // Merge resume-scoped flags and overrides with highest precedence. merge_resume_cli_flags(&mut interactive, resume_cli); @@ -702,13 +735,21 @@ mod tests { let Subcommand::Resume(ResumeCommand { session_id, last, + all, config_overrides: resume_cli, }) = subcommand.expect("resume present") else { unreachable!() }; - finalize_resume_interactive(interactive, root_overrides, session_id, last, resume_cli) + finalize_resume_interactive( + interactive, + root_overrides, + session_id, + last, + all, + resume_cli, + ) } fn sample_exit_info(conversation: Option<&str>) -> AppExitInfo { @@ -775,6 +816,7 @@ mod tests { assert!(interactive.resume_picker); assert!(!interactive.resume_last); assert_eq!(interactive.resume_session_id, None); + assert!(!interactive.resume_show_all); } #[test] @@ -783,6 +825,7 @@ mod tests { assert!(!interactive.resume_picker); assert!(interactive.resume_last); assert_eq!(interactive.resume_session_id, None); + assert!(!interactive.resume_show_all); } #[test] @@ -791,6 +834,14 @@ mod tests { assert!(!interactive.resume_picker); assert!(!interactive.resume_last); assert_eq!(interactive.resume_session_id.as_deref(), Some("1234")); + assert!(!interactive.resume_show_all); + } + + #[test] + fn resume_all_flag_sets_show_all() { + let interactive = finalize_from_args(["codex", "resume", "--all"].as_ref()); + assert!(interactive.resume_picker); + assert!(interactive.resume_show_all); } #[test] diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index ec37c3a6b..93f22e705 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -79,6 +79,7 @@ pub struct GetArgs { } #[derive(Debug, clap::Parser)] +#[command(override_usage = "codex mcp add [OPTIONS] (--url | -- ...)")] pub struct AddArgs { /// Name for the MCP server configuration. 
pub name: String, diff --git a/codex-rs/cli/tests/execpolicy.rs b/codex-rs/cli/tests/execpolicy.rs new file mode 100644 index 000000000..c6bca85bc --- /dev/null +++ b/codex-rs/cli/tests/execpolicy.rs @@ -0,0 +1,58 @@ +use std::fs; + +use assert_cmd::Command; +use pretty_assertions::assert_eq; +use serde_json::json; +use tempfile::TempDir; + +#[test] +fn execpolicy_check_matches_expected_json() -> Result<(), Box> { + let codex_home = TempDir::new()?; + let policy_path = codex_home.path().join("policy.codexpolicy"); + fs::write( + &policy_path, + r#" +prefix_rule( + pattern = ["git", "push"], + decision = "forbidden", +) +"#, + )?; + + let output = Command::cargo_bin("codex")? + .env("CODEX_HOME", codex_home.path()) + .args([ + "execpolicy", + "check", + "--policy", + policy_path + .to_str() + .expect("policy path should be valid UTF-8"), + "git", + "push", + "origin", + "main", + ]) + .output()?; + + assert!(output.status.success()); + let result: serde_json::Value = serde_json::from_slice(&output.stdout)?; + assert_eq!( + result, + json!({ + "match": { + "decision": "forbidden", + "matchedRules": [ + { + "prefixRuleMatch": { + "matchedPrefix": ["git", "push"], + "decision": "forbidden" + } + } + ] + } + }) + ); + + Ok(()) +} diff --git a/codex-rs/codex-backend-openapi-models/src/models/credit_status_details.rs b/codex-rs/codex-backend-openapi-models/src/models/credit_status_details.rs new file mode 100644 index 000000000..b62b88d71 --- /dev/null +++ b/codex-rs/codex-backend-openapi-models/src/models/credit_status_details.rs @@ -0,0 +1,52 @@ +/* + * codex-backend + * + * codex-backend + * + * The version of the OpenAPI document: 0.0.1 + * + * Generated by: https://openapi-generator.tech + */ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] +pub struct CreditStatusDetails { + #[serde(rename = "has_credits")] + pub has_credits: bool, + #[serde(rename = "unlimited")] + pub unlimited: bool, + #[serde( 
+ rename = "balance", + default, + with = "::serde_with::rust::double_option", + skip_serializing_if = "Option::is_none" + )] + pub balance: Option>, + #[serde( + rename = "approx_local_messages", + default, + with = "::serde_with::rust::double_option", + skip_serializing_if = "Option::is_none" + )] + pub approx_local_messages: Option>>, + #[serde( + rename = "approx_cloud_messages", + default, + with = "::serde_with::rust::double_option", + skip_serializing_if = "Option::is_none" + )] + pub approx_cloud_messages: Option>>, +} + +impl CreditStatusDetails { + pub fn new(has_credits: bool, unlimited: bool) -> CreditStatusDetails { + CreditStatusDetails { + has_credits, + unlimited, + balance: None, + approx_local_messages: None, + approx_cloud_messages: None, + } + } +} diff --git a/codex-rs/codex-backend-openapi-models/src/models/mod.rs b/codex-rs/codex-backend-openapi-models/src/models/mod.rs index 96348d72c..d76715492 100644 --- a/codex-rs/codex-backend-openapi-models/src/models/mod.rs +++ b/codex-rs/codex-backend-openapi-models/src/models/mod.rs @@ -32,3 +32,6 @@ pub use self::rate_limit_status_details::RateLimitStatusDetails; pub mod rate_limit_window_snapshot; pub use self::rate_limit_window_snapshot::RateLimitWindowSnapshot; + +pub mod credit_status_details; +pub use self::credit_status_details::CreditStatusDetails; diff --git a/codex-rs/codex-backend-openapi-models/src/models/rate_limit_status_payload.rs b/codex-rs/codex-backend-openapi-models/src/models/rate_limit_status_payload.rs index d2af76f4d..0f5caf52f 100644 --- a/codex-rs/codex-backend-openapi-models/src/models/rate_limit_status_payload.rs +++ b/codex-rs/codex-backend-openapi-models/src/models/rate_limit_status_payload.rs @@ -23,6 +23,13 @@ pub struct RateLimitStatusPayload { skip_serializing_if = "Option::is_none" )] pub rate_limit: Option>>, + #[serde( + rename = "credits", + default, + with = "::serde_with::rust::double_option", + skip_serializing_if = "Option::is_none" + )] + pub credits: 
Option>>, } impl RateLimitStatusPayload { @@ -30,12 +37,15 @@ impl RateLimitStatusPayload { RateLimitStatusPayload { plan_type, rate_limit: None, + credits: None, } } } #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] pub enum PlanType { + #[serde(rename = "guest")] + Guest, #[serde(rename = "free")] Free, #[serde(rename = "go")] @@ -44,6 +54,8 @@ pub enum PlanType { Plus, #[serde(rename = "pro")] Pro, + #[serde(rename = "free_workspace")] + FreeWorkspace, #[serde(rename = "team")] Team, #[serde(rename = "business")] @@ -52,6 +64,8 @@ pub enum PlanType { Education, #[serde(rename = "quorum")] Quorum, + #[serde(rename = "k12")] + K12, #[serde(rename = "enterprise")] Enterprise, #[serde(rename = "edu")] @@ -60,6 +74,6 @@ pub enum PlanType { impl Default for PlanType { fn default() -> PlanType { - Self::Free + Self::Guest } } diff --git a/codex-rs/codex-backend-openapi-models/src/models/rate_limit_window_snapshot.rs b/codex-rs/codex-backend-openapi-models/src/models/rate_limit_window_snapshot.rs index 4fc04f4be..b2a6c0c22 100644 --- a/codex-rs/codex-backend-openapi-models/src/models/rate_limit_window_snapshot.rs +++ b/codex-rs/codex-backend-openapi-models/src/models/rate_limit_window_snapshot.rs @@ -7,7 +7,6 @@ * * Generated by: https://openapi-generator.tech */ - use serde::Deserialize; use serde::Serialize; diff --git a/codex-rs/common/src/approval_presets.rs b/codex-rs/common/src/approval_presets.rs index 6c3bf395a..1b673d1d9 100644 --- a/codex-rs/common/src/approval_presets.rs +++ b/codex-rs/common/src/approval_presets.rs @@ -24,21 +24,21 @@ pub fn builtin_approval_presets() -> Vec { ApprovalPreset { id: "read-only", label: "Read Only", - description: "Codex can read files and answer questions. 
Codex requires approval to make edits, run commands, or access network.", + description: "Requires approval to edit files and run commands.", approval: AskForApproval::OnRequest, sandbox: SandboxPolicy::ReadOnly, }, ApprovalPreset { id: "auto", - label: "Auto", - description: "Codex can read files, make edits, and run commands in the workspace. Codex requires approval to work outside the workspace or access network.", + label: "Agent", + description: "Read and edit files, and run commands.", approval: AskForApproval::OnRequest, sandbox: SandboxPolicy::new_workspace_write_policy(), }, ApprovalPreset { id: "full-access", - label: "Full Access", - description: "Codex can read files, make edits, and run commands with network access, without approval. Exercise caution.", + label: "Agent (full access)", + description: "Codex can edit files outside this workspace and run commands with network access. Exercise caution when using.", approval: AskForApproval::Never, sandbox: SandboxPolicy::DangerFullAccess, }, diff --git a/codex-rs/common/src/config_summary.rs b/codex-rs/common/src/config_summary.rs index dabc606ce..8fc1bb26f 100644 --- a/codex-rs/common/src/config_summary.rs +++ b/codex-rs/common/src/config_summary.rs @@ -15,13 +15,12 @@ pub fn create_config_summary_entries(config: &Config) -> Vec<(&'static str, Stri if config.model_provider.wire_api == WireApi::Responses && config.model_family.supports_reasoning_summaries { - entries.push(( - "reasoning effort", - config - .model_reasoning_effort - .map(|effort| effort.to_string()) - .unwrap_or_else(|| "none".to_string()), - )); + let reasoning_effort = config + .model_reasoning_effort + .or(config.model_family.default_reasoning_effort) + .map(|effort| effort.to_string()) + .unwrap_or_else(|| "none".to_string()); + entries.push(("reasoning effort", reasoning_effort)); entries.push(( "reasoning summaries", config.model_reasoning_summary.to_string(), diff --git a/codex-rs/common/src/model_presets.rs 
b/codex-rs/common/src/model_presets.rs index 9921f969a..a031f23b1 100644 --- a/codex-rs/common/src/model_presets.rs +++ b/codex-rs/common/src/model_presets.rs @@ -4,6 +4,10 @@ use codex_app_server_protocol::AuthMode; use codex_core::protocol_config_types::ReasoningEffort; use once_cell::sync::Lazy; +pub const HIDE_GPT5_1_MIGRATION_PROMPT_CONFIG: &str = "hide_gpt5_1_migration_prompt"; +pub const HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG: &str = + "hide_gpt-5.1-codex-max_migration_prompt"; + /// A reasoning effort option that can be surfaced for a model. #[derive(Debug, Clone, Copy)] pub struct ReasoningEffortPreset { @@ -17,6 +21,7 @@ pub struct ReasoningEffortPreset { pub struct ModelUpgrade { pub id: &'static str, pub reasoning_effort_mapping: Option>, + pub migration_config_key: &'static str, } /// Metadata describing a Codex-supported model. @@ -38,10 +43,40 @@ pub struct ModelPreset { pub is_default: bool, /// recommended upgrade model pub upgrade: Option, + /// Whether this preset should appear in the picker UI. 
+ pub show_in_picker: bool, } static PRESETS: Lazy> = Lazy::new(|| { vec![ + ModelPreset { + id: "gpt-5.1-codex-max", + model: "gpt-5.1-codex-max", + display_name: "gpt-5.1-codex-max", + description: "Latest Codex-optimized flagship for deep and fast reasoning.", + default_reasoning_effort: ReasoningEffort::Medium, + supported_reasoning_efforts: &[ + ReasoningEffortPreset { + effort: ReasoningEffort::Low, + description: "Fast responses with lighter reasoning", + }, + ReasoningEffortPreset { + effort: ReasoningEffort::Medium, + description: "Balances speed and reasoning depth for everyday tasks", + }, + ReasoningEffortPreset { + effort: ReasoningEffort::High, + description: "Maximizes reasoning depth for complex problems", + }, + ReasoningEffortPreset { + effort: ReasoningEffort::XHigh, + description: "Extra high reasoning depth for complex problems", + }, + ], + is_default: true, + upgrade: None, + show_in_picker: true, + }, ModelPreset { id: "gpt-5.1-codex", model: "gpt-5.1-codex", @@ -62,8 +97,13 @@ static PRESETS: Lazy> = Lazy::new(|| { description: "Maximizes reasoning depth for complex or ambiguous problems", }, ], - is_default: true, - upgrade: None, + is_default: false, + upgrade: Some(ModelUpgrade { + id: "gpt-5.1-codex-max", + reasoning_effort_mapping: None, + migration_config_key: HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG, + }), + show_in_picker: true, }, ModelPreset { id: "gpt-5.1-codex-mini", @@ -82,7 +122,12 @@ static PRESETS: Lazy> = Lazy::new(|| { }, ], is_default: false, - upgrade: None, + upgrade: Some(ModelUpgrade { + id: "gpt-5.1-codex-max", + reasoning_effort_mapping: None, + migration_config_key: HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG, + }), + show_in_picker: true, }, ModelPreset { id: "gpt-5.1", @@ -105,7 +150,12 @@ static PRESETS: Lazy> = Lazy::new(|| { }, ], is_default: false, - upgrade: None, + upgrade: Some(ModelUpgrade { + id: "gpt-5.1-codex-max", + reasoning_effort_mapping: None, + migration_config_key: 
HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG, + }), + show_in_picker: true, }, // Deprecated models. ModelPreset { @@ -130,9 +180,11 @@ static PRESETS: Lazy> = Lazy::new(|| { ], is_default: false, upgrade: Some(ModelUpgrade { - id: "gpt-5.1-codex", + id: "gpt-5.1-codex-max", reasoning_effort_mapping: None, + migration_config_key: HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG, }), + show_in_picker: false, }, ModelPreset { id: "gpt-5-codex-mini", @@ -154,7 +206,9 @@ static PRESETS: Lazy> = Lazy::new(|| { upgrade: Some(ModelUpgrade { id: "gpt-5.1-codex-mini", reasoning_effort_mapping: None, + migration_config_key: HIDE_GPT5_1_MIGRATION_PROMPT_CONFIG, }), + show_in_picker: false, }, ModelPreset { id: "gpt-5", @@ -182,21 +236,22 @@ static PRESETS: Lazy> = Lazy::new(|| { ], is_default: false, upgrade: Some(ModelUpgrade { - id: "gpt-5.1", - reasoning_effort_mapping: Some(HashMap::from([( - ReasoningEffort::Minimal, - ReasoningEffort::Low, - )])), + id: "gpt-5.1-codex-max", + reasoning_effort_mapping: None, + migration_config_key: HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG, }), + show_in_picker: false, }, ] }); -pub fn builtin_model_presets(_auth_mode: Option) -> Vec { - // leave auth mode for later use +pub fn builtin_model_presets(auth_mode: Option) -> Vec { PRESETS .iter() - .filter(|preset| preset.upgrade.is_none()) + .filter(|preset| match auth_mode { + Some(AuthMode::ApiKey) => preset.show_in_picker && preset.id != "gpt-5.1-codex-max", + _ => preset.show_in_picker, + }) .cloned() .collect() } @@ -208,10 +263,21 @@ pub fn all_model_presets() -> &'static Vec { #[cfg(test)] mod tests { use super::*; + use codex_app_server_protocol::AuthMode; #[test] fn only_one_default_model_is_configured() { let default_models = PRESETS.iter().filter(|preset| preset.is_default).count(); assert!(default_models == 1); } + + #[test] + fn gpt_5_1_codex_max_hidden_for_api_key_auth() { + let presets = builtin_model_presets(Some(AuthMode::ApiKey)); + assert!( + presets + .iter() + 
.all(|preset| preset.id != "gpt-5.1-codex-max") + ); + } } diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 4d8f43778..669a9a63f 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -19,9 +19,11 @@ async-trait = { workspace = true } base64 = { workspace = true } bytes = { workspace = true } chrono = { workspace = true, features = ["serde"] } +chardetng = { workspace = true } codex-app-server-protocol = { workspace = true } codex-apply-patch = { workspace = true } codex-async-utils = { workspace = true } +codex-execpolicy = { workspace = true } codex-file-search = { workspace = true } codex-git = { workspace = true } codex-keyring-store = { workspace = true } @@ -31,11 +33,11 @@ codex-rmcp-client = { workspace = true } codex-utils-pty = { workspace = true } codex-utils-readiness = { workspace = true } codex-utils-string = { workspace = true } -codex-utils-tokenizer = { workspace = true } codex-windows-sandbox = { package = "codex-windows-sandbox", path = "../windows-sandbox-rs" } dirs = { workspace = true } dunce = { workspace = true } env-flags = { workspace = true } +encoding_rs = { workspace = true } eventsource-stream = { workspace = true } futures = { workspace = true } http = { workspace = true } diff --git a/codex-rs/core/gpt-5.1-codex-max_prompt.md b/codex-rs/core/gpt-5.1-codex-max_prompt.md new file mode 100644 index 000000000..292e5d7d0 --- /dev/null +++ b/codex-rs/core/gpt-5.1-codex-max_prompt.md @@ -0,0 +1,117 @@ +You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer. + +## General + +- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.) + +## Editing constraints + +- Default to ASCII when editing or creating files. 
Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them. +- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare. +- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase). +- You may be in a dirty git worktree. + * NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user. + * If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes. + * If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them. + * If the changes are in unrelated files, just ignore them and don't revert them. +- Do not amend a commit unless explicitly requested to do so. +- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed. +- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user. + +## Plan tool + +When using the planning tool: +- Skip using the planning tool for straightforward tasks (roughly the easiest 25%). +- Do not make single-step plans. 
+- Once you have made a plan, update it after having performed one of the sub-tasks that you shared on the plan. + +## Codex CLI harness, sandboxing, and approvals + +The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from. + +Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are: +- **read-only**: The sandbox only permits reading files. +- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval. +- **danger-full-access**: No filesystem sandboxing - all commands are permitted. + +Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are: +- **restricted**: Requires approval +- **enabled**: No approval needed + +Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are: +- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands. +- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox. +- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.) +- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. 
If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding. + +When you are running with `approval_policy == on-request` and sandboxing enabled, here are scenarios where you'll need to request approval: +- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var) +- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files. +- You are running sandboxed and need to run a command that requires network access (e.g. installing packages) +- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `with_escalated_permissions` and `justification` parameters - do not message the user before requesting approval for the command. +- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for +- (for all of these, you should weigh alternative paths that do not require approval) + +When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read. + +You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure. + +Although approvals introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. 
If completing the task requires escalated permissions, do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to "never", in which case never ask for approvals. + +When requesting approval to execute a command that will require escalated privileges: + - Provide the `with_escalated_permissions` parameter with the boolean value true + - Include a short, 1 sentence explanation for why you need to enable `with_escalated_permissions` in the justification parameter + +## Special user requests + +- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so. +- If the user asks for a "review", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps. + +## Frontend tasks +When doing frontend design tasks, avoid collapsing into "AI slop" or safe, average-looking layouts. +Aim for interfaces that feel intentional, bold, and a bit surprising. +- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system). +- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias. +- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions. +- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere. +- Overall: Avoid boilerplate layouts and interchangeable UI patterns. 
Vary themes, type families, and visual languages across outputs. +- Ensure the page loads properly on both desktop and mobile + +Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language. + +## Presenting your work and final message + +You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value. + +- Default: be very concise; friendly coding teammate tone. +- Ask only when needed; suggest ideas; mirror the user's style. +- For substantial work, summarize clearly; follow final‑answer formatting. +- Skip heavy formatting for simple confirmations. +- Don't dump large files you've written; reference paths only. +- No "save/copy this file" - User is on the same machine. +- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something. +- For code changes: + * Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in. + * If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps. + * When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number. +- The user does not see command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result. + +### Final answer structure and style guidelines + +- Plain text; CLI handles styling. Use structure only when it helps scanability. 
+- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help. +- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent. +- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **. +- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible. +- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task. +- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no "above/below"; parallel wording. +- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers. +- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets. +- File References: When referencing files in your response follow the below rules: + * Use inline code to make file paths clickable. + * Each reference should have a stand alone path. Even if it's the same file. + * Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix. + * Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1). + * Do not use URIs like file://, vscode://, or https://. 
+ * Do not provide range of lines + * Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5 diff --git a/codex-rs/core/src/bash.rs b/codex-rs/core/src/bash.rs index 0ffb8e785..cb8248ec1 100644 --- a/codex-rs/core/src/bash.rs +++ b/codex-rs/core/src/bash.rs @@ -100,7 +100,7 @@ pub fn extract_bash_command(command: &[String]) -> Option<(&str, &str)> { if !matches!(flag.as_str(), "-lc" | "-c") || !matches!( detect_shell_type(&PathBuf::from(shell)), - Some(ShellType::Zsh) | Some(ShellType::Bash) + Some(ShellType::Zsh) | Some(ShellType::Bash) | Some(ShellType::Sh) ) { return None; diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 13c277a77..68bd30f4e 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -56,6 +56,7 @@ use crate::model_family::ModelFamily; use crate::model_provider_info::ModelProviderInfo; use crate::model_provider_info::WireApi; use crate::openai_model_info::get_model_info; +use crate::protocol::CreditsSnapshot; use crate::protocol::RateLimitSnapshot; use crate::protocol::RateLimitWindow; use crate::protocol::TokenUsage; @@ -726,7 +727,13 @@ fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option { "x-codex-secondary-reset-at", ); - Some(RateLimitSnapshot { primary, secondary }) + let credits = parse_credits_snapshot(headers); + + Some(RateLimitSnapshot { + primary, + secondary, + credits, + }) } fn parse_rate_limit_window( @@ -753,6 +760,20 @@ fn parse_rate_limit_window( }) } +fn parse_credits_snapshot(headers: &HeaderMap) -> Option { + let has_credits = parse_header_bool(headers, "x-codex-credits-has-credits")?; + let unlimited = parse_header_bool(headers, "x-codex-credits-unlimited")?; + let balance = parse_header_str(headers, "x-codex-credits-balance") + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(std::string::ToString::to_string); + Some(CreditsSnapshot { + has_credits, + unlimited, + balance, + }) +} + fn parse_header_f64(headers: &HeaderMap, 
name: &str) -> Option { parse_header_str(headers, name)? .parse::() @@ -764,6 +785,17 @@ fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option { parse_header_str(headers, name)?.parse::().ok() } +fn parse_header_bool(headers: &HeaderMap, name: &str) -> Option { + let raw = parse_header_str(headers, name)?; + if raw.eq_ignore_ascii_case("true") || raw == "1" { + Some(true) + } else if raw.eq_ignore_ascii_case("false") || raw == "0" { + Some(false) + } else { + None + } +} + fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> { headers.get(name)?.to_str().ok() } diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index 09cc76922..21f6fc657 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -136,7 +136,7 @@ fn reserialize_shell_outputs(items: &mut [ResponseItem]) { } fn is_shell_tool_name(name: &str) -> bool { - matches!(name, "shell" | "container.exec" | "shell_command") + matches!(name, "shell" | "container.exec") } #[derive(Deserialize)] @@ -165,11 +165,9 @@ fn build_structured_output(parsed: &ExecOutputJson) -> String { )); let mut output = parsed.output.clone(); - if let Some(total_lines) = extract_total_output_lines(&parsed.output) { + if let Some((stripped, total_lines)) = strip_total_output_header(&parsed.output) { sections.push(format!("Total output lines: {total_lines}")); - if let Some(stripped) = strip_total_output_header(&output) { - output = stripped.to_string(); - } + output = stripped.to_string(); } sections.push("Output:".to_string()); @@ -178,19 +176,12 @@ fn build_structured_output(parsed: &ExecOutputJson) -> String { sections.join("\n") } -fn extract_total_output_lines(output: &str) -> Option { - let marker_start = output.find("[... 
omitted ")?; - let marker = &output[marker_start..]; - let (_, after_of) = marker.split_once(" of ")?; - let (total_segment, _) = after_of.split_once(' ')?; - total_segment.parse::().ok() -} - -fn strip_total_output_header(output: &str) -> Option<&str> { +fn strip_total_output_header(output: &str) -> Option<(&str, u32)> { let after_prefix = output.strip_prefix("Total output lines: ")?; - let (_, remainder) = after_prefix.split_once('\n')?; + let (total_segment, remainder) = after_prefix.split_once('\n')?; + let total_lines = total_segment.parse::().ok()?; let remainder = remainder.strip_prefix('\n').unwrap_or(remainder); - Some(remainder) + Some((remainder, total_lines)) } #[derive(Debug)] @@ -431,7 +422,7 @@ mod tests { expects_apply_patch_instructions: false, }, InstructionsTestCase { - slug: "gpt-5.1-codex", + slug: "gpt-5.1-codex-max", expects_apply_patch_instructions: false, }, ]; diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 64d06d057..45e3d87ac 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -7,12 +7,16 @@ use std::sync::atomic::AtomicU64; use crate::AuthManager; use crate::client_common::REVIEW_PROMPT; use crate::compact; +use crate::compact::run_inline_auto_compact_task; +use crate::compact::should_use_remote_compact_task; +use crate::compact_remote::run_inline_remote_auto_compact_task; use crate::features::Feature; use crate::function_tool::FunctionCallError; use crate::parse_command::parse_command; use crate::parse_turn_item; use crate::response_processing::process_items; use crate::terminal; +use crate::truncate::TruncationPolicy; use crate::user_notification::UserNotifier; use crate::util::error_or_panic; use async_channel::Receiver; @@ -75,7 +79,6 @@ use crate::protocol::ApplyPatchApprovalRequestEvent; use crate::protocol::AskForApproval; use crate::protocol::BackgroundEventEvent; use crate::protocol::DeprecationNoticeEvent; -use crate::protocol::ErrorEvent; use crate::protocol::Event; use 
crate::protocol::EventMsg; use crate::protocol::ExecApprovalRequestEvent; @@ -117,6 +120,7 @@ use crate::user_instructions::UserInstructions; use crate::user_notification::UserNotification; use crate::util::backoff; use codex_async_utils::OrCancelExt; +use codex_execpolicy::Policy as ExecPolicy; use codex_otel::otel_event_manager::OtelEventManager; use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; @@ -124,11 +128,11 @@ use codex_protocol::models::ContentItem; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::CodexErrorInfo; use codex_protocol::protocol::InitialHistory; use codex_protocol::user_input::UserInput; use codex_utils_readiness::Readiness; use codex_utils_readiness::ReadinessFlag; -use codex_utils_tokenizer::warm_model_cache; /// The high-level interface to the Codex system. /// It operates as a queue pair where you send submissions and receive events. 
@@ -162,6 +166,10 @@ impl Codex { let user_instructions = get_user_instructions(&config).await; + let exec_policy = crate::exec_policy::exec_policy_for(&config.features, &config.codex_home) + .await + .map_err(|err| CodexErr::Fatal(format!("failed to load execpolicy: {err}")))?; + let config = Arc::new(config); let session_configuration = SessionConfiguration { @@ -178,6 +186,7 @@ impl Codex { cwd: config.cwd.clone(), original_config_do_not_use: Arc::clone(&config), features: config.features.clone(), + exec_policy, session_source, }; @@ -275,6 +284,8 @@ pub(crate) struct TurnContext { pub(crate) final_output_json_schema: Option, pub(crate) codex_linux_sandbox_exe: Option, pub(crate) tool_call_gate: Arc, + pub(crate) exec_policy: Arc, + pub(crate) truncation_policy: TruncationPolicy, } impl TurnContext { @@ -291,7 +302,6 @@ impl TurnContext { } } -#[allow(dead_code)] #[derive(Clone)] pub(crate) struct SessionConfiguration { /// Provider identifier ("openai", "openrouter", ...). @@ -331,6 +341,8 @@ pub(crate) struct SessionConfiguration { /// Set of feature flags for this session features: Features, + /// Execpolicy policy, applied only when enabled by feature flag. 
+ exec_policy: Arc, // TODO(pakrym): Remove config from here original_config_do_not_use: Arc, @@ -401,7 +413,7 @@ impl Session { ); let client = ModelClient::new( - Arc::new(per_turn_config), + Arc::new(per_turn_config.clone()), auth_manager, otel_event_manager, provider, @@ -431,6 +443,8 @@ impl Session { final_output_json_schema: None, codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(), tool_call_gate: Arc::new(ReadinessFlag::new()), + exec_policy: session_configuration.exec_policy.clone(), + truncation_policy: TruncationPolicy::new(&per_turn_config), } } @@ -478,7 +492,7 @@ impl Session { // - load history metadata let rollout_fut = RolloutRecorder::new(&config, rollout_params); - let default_shell_fut = shell::default_user_shell(); + let default_shell = shell::default_user_shell(); let history_meta_fut = crate::message_history::history_metadata(&config); let auth_statuses_fut = compute_auth_statuses( config.mcp_servers.iter(), @@ -486,12 +500,8 @@ impl Session { ); // Join all independent futures. - let (rollout_recorder, default_shell, (history_log_id, history_entry_count), auth_statuses) = tokio::join!( - rollout_fut, - default_shell_fut, - history_meta_fut, - auth_statuses_fut - ); + let (rollout_recorder, (history_log_id, history_entry_count), auth_statuses) = + tokio::join!(rollout_fut, history_meta_fut, auth_statuses_fut); let rollout_recorder = rollout_recorder.map_err(|e| { error!("failed to initialize rollout recorder: {e:#}"); @@ -533,7 +543,6 @@ impl Session { config.model_reasoning_effort, config.model_reasoning_summary, config.model_context_window, - config.model_max_output_tokens, config.model_auto_compact_token_limit, config.approval_policy, config.sandbox_policy.clone(), @@ -544,9 +553,6 @@ impl Session { // Create the mutable state for the Session. let state = SessionState::new(session_configuration.clone()); - // Warm the tokenizer cache for the session model without blocking startup. 
- warm_model_cache(&session_configuration.model); - let services = SessionServices { mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())), mcp_startup_cancellation_token: CancellationToken::new(), @@ -578,6 +584,10 @@ impl Session { msg: EventMsg::SessionConfigured(SessionConfiguredEvent { session_id: conversation_id, model: session_configuration.model.clone(), + model_provider_id: config.model_provider_id.clone(), + approval_policy: session_configuration.approval_policy, + sandbox_policy: session_configuration.sandbox_policy.clone(), + cwd: session_configuration.cwd.clone(), reasoning_effort: session_configuration.model_reasoning_effort, history_log_id, history_entry_count, @@ -678,7 +688,8 @@ impl Session { let reconstructed_history = self.reconstruct_history_from_rollout(&turn_context, &rollout_items); if !reconstructed_history.is_empty() { - self.record_into_history(&reconstructed_history).await; + self.record_into_history(&reconstructed_history, &turn_context) + .await; } // If persisting, persist all rollout items as-is (recorder filters) @@ -899,6 +910,7 @@ impl Session { let event = EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id, + turn_id: turn_context.sub_id.clone(), changes, reason, grant_root, @@ -935,7 +947,7 @@ impl Session { turn_context: &TurnContext, items: &[ResponseItem], ) { - self.record_into_history(items).await; + self.record_into_history(items, turn_context).await; self.persist_rollout_response_items(items).await; self.send_raw_response_items(turn_context, items).await; } @@ -949,7 +961,10 @@ impl Session { for item in rollout_items { match item { RolloutItem::ResponseItem(response_item) => { - history.record_items(std::iter::once(response_item)); + history.record_items( + std::iter::once(response_item), + turn_context.truncation_policy, + ); } RolloutItem::Compacted(compacted) => { let snapshot = history.get_history(); @@ -973,9 +988,13 @@ impl Session { } /// Append ResponseItems to 
the in-memory conversation history only. - pub(crate) async fn record_into_history(&self, items: &[ResponseItem]) { + pub(crate) async fn record_into_history( + &self, + items: &[ResponseItem], + turn_context: &TurnContext, + ) { let mut state = self.state.lock().await; - state.record_items(items.iter()); + state.record_items(items.iter(), turn_context.truncation_policy); } pub(crate) async fn replace_history(&self, items: Vec) { @@ -1029,7 +1048,7 @@ impl Session { Some(turn_context.cwd.clone()), Some(turn_context.approval_policy), Some(turn_context.sandbox_policy.clone()), - Some(self.user_shell().clone()), + self.user_shell().clone(), ))); items } @@ -1068,11 +1087,14 @@ impl Session { self.send_token_count_event(turn_context).await; } - pub(crate) async fn override_last_token_usage_estimate( - &self, - turn_context: &TurnContext, - estimated_total_tokens: i64, - ) { + pub(crate) async fn recompute_token_usage(&self, turn_context: &TurnContext) { + let Some(estimated_total_tokens) = self + .clone_history() + .await + .estimate_token_count(turn_context) + else { + return; + }; { let mut state = self.state.lock().await; let mut info = state.token_info().unwrap_or(TokenUsageInfo { @@ -1166,9 +1188,14 @@ impl Session { &self, turn_context: &TurnContext, message: impl Into, + codex_error: CodexErr, ) { + let codex_error_info = CodexErrorInfo::ResponseStreamDisconnected { + http_status_code: codex_error.http_status_code_value(), + }; let event = EventMsg::StreamError(StreamErrorEvent { message: message.into(), + codex_error_info: Some(codex_error_info), }); self.send_event(turn_context, event).await; } @@ -1315,7 +1342,10 @@ impl Session { } async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiver) { - let mut previous_context: Option> = None; + // Seed with context in case there is an OverrideTurnContext first. 
+ let mut previous_context: Option> = + Some(sess.new_turn(SessionSettingsUpdate::default()).await); + // To break out of this loop, send Op::Shutdown. while let Ok(sub) = rx_sub.recv().await { debug!(?sub, "Submission"); @@ -1411,6 +1441,7 @@ mod handlers { use crate::tasks::UndoTask; use crate::tasks::UserShellCommandTask; use codex_protocol::custom_prompts::CustomPrompt; + use codex_protocol::protocol::CodexErrorInfo; use codex_protocol::protocol::ErrorEvent; use codex_protocol::protocol::Event; use codex_protocol::protocol::EventMsg; @@ -1420,6 +1451,7 @@ mod handlers { use codex_protocol::protocol::ReviewRequest; use codex_protocol::protocol::TurnAbortReason; + use codex_protocol::user_input::UserInput; use std::sync::Arc; use tracing::info; use tracing::warn; @@ -1628,8 +1660,14 @@ mod handlers { .new_turn_with_sub_id(sub_id, SessionSettingsUpdate::default()) .await; - sess.spawn_task(Arc::clone(&turn_context), vec![], CompactTask) - .await; + sess.spawn_task( + Arc::clone(&turn_context), + vec![UserInput::Text { + text: turn_context.compact_prompt().to_string(), + }], + CompactTask, + ) + .await; } pub async fn shutdown(sess: &Arc, sub_id: String) -> bool { @@ -1650,6 +1688,7 @@ mod handlers { id: sub_id.clone(), msg: EventMsg::Error(ErrorEvent { message: "Failed to shutdown rollout recorder".to_string(), + codex_error_info: Some(CodexErrorInfo::Other), }), }; sess.send_event_raw(event).await; @@ -1755,6 +1794,8 @@ async fn spawn_review_thread( final_output_json_schema: None, codex_linux_sandbox_exe: parent_turn_context.codex_linux_sandbox_exe.clone(), tool_call_gate: Arc::new(ReadinessFlag::new()), + exec_policy: parent_turn_context.exec_policy.clone(), + truncation_policy: TruncationPolicy::new(&per_turn_config), }; // Seed the child task with the review prompt as the initial user message. 
@@ -1762,7 +1803,12 @@ async fn spawn_review_thread( text: review_prompt, }]; let tc = Arc::new(review_turn_context); - sess.spawn_task(tc.clone(), input, ReviewTask).await; + sess.spawn_task( + tc.clone(), + input, + ReviewTask::new(review_request.append_to_original_thread), + ) + .await; // Announce entering review mode so UIs can switch modes. sess.send_event(&tc, EventMsg::EnteredReviewMode(review_request)) @@ -1863,7 +1909,12 @@ pub(crate) async fn run_task( // as long as compaction works well in getting us way below the token limit, we shouldn't worry about being in an infinite loop. if token_limit_reached { - compact::run_inline_auto_compact_task(sess.clone(), turn_context.clone()).await; + if should_use_remote_compact_task(&sess).await { + run_inline_remote_auto_compact_task(sess.clone(), turn_context.clone()) + .await; + } else { + run_inline_auto_compact_task(sess.clone(), turn_context.clone()).await; + } continue; } @@ -1892,9 +1943,7 @@ pub(crate) async fn run_task( } Err(e) => { info!("Turn error: {e:#}"); - let event = EventMsg::Error(ErrorEvent { - message: e.to_string(), - }); + let event = EventMsg::Error(e.to_error_event(None)); sess.send_event(&turn_context, event).await; // let the user continue the conversation break; @@ -1947,12 +1996,14 @@ async fn run_turn( let mut base_instructions = turn_context.base_instructions.clone(); if parallel_tool_calls { static INSTRUCTIONS: &str = include_str!("../templates/parallel/instructions.md"); - static INSERTION_SPOT: &str = "## Editing constraints"; - base_instructions - .as_mut() - .map(|base| base.replace(INSERTION_SPOT, INSTRUCTIONS)); + if let Some(family) = + find_family_for_model(&sess.state.lock().await.session_configuration.model) + { + let mut new_instructions = base_instructions.unwrap_or(family.base_instructions); + new_instructions.push_str(INSTRUCTIONS); + base_instructions = Some(new_instructions); + } } - let prompt = Prompt { input, tools: router.specs(), @@ -2017,6 +2068,7 @@ async fn 
run_turn( sess.notify_stream_error( &turn_context, format!("Reconnecting... {retries}/{max_retries}"), + e, ) .await; @@ -2335,6 +2387,7 @@ mod tests { use crate::config::ConfigOverrides; use crate::config::ConfigToml; use crate::exec::ExecToolCallOutput; + use crate::shell::default_user_shell; use crate::tools::format_exec_output_str; use crate::protocol::CompactedItem; @@ -2444,8 +2497,9 @@ mod tests { duration: StdDuration::from_secs(1), timed_out: true, }; + let (_, turn_context) = make_session_and_context(); - let out = format_exec_output_str(&exec); + let out = format_exec_output_str(&exec, turn_context.truncation_policy); assert_eq!( out, @@ -2561,6 +2615,7 @@ mod tests { cwd: config.cwd.clone(), original_config_do_not_use: Arc::clone(&config), features: Features::default(), + exec_policy: Arc::new(ExecPolicy::empty()), session_source: SessionSource::Exec, }; @@ -2572,7 +2627,7 @@ mod tests { unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(None), rollout: Mutex::new(None), - user_shell: shell::Shell::Unknown, + user_shell: default_user_shell(), show_raw_agent_reasoning: config.show_raw_agent_reasoning, auth_manager: Arc::clone(&auth_manager), otel_event_manager: otel_event_manager.clone(), @@ -2638,6 +2693,7 @@ mod tests { cwd: config.cwd.clone(), original_config_do_not_use: Arc::clone(&config), features: Features::default(), + exec_policy: Arc::new(ExecPolicy::empty()), session_source: SessionSource::Exec, }; @@ -2649,7 +2705,7 @@ mod tests { unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(None), rollout: Mutex::new(None), - user_shell: shell::Shell::Unknown, + user_shell: default_user_shell(), show_raw_agent_reasoning: config.show_raw_agent_reasoning, auth_manager: Arc::clone(&auth_manager), otel_event_manager: otel_event_manager.clone(), @@ -2768,7 +2824,8 @@ mod tests { let input = vec![UserInput::Text { text: "start review".to_string(), }]; - sess.spawn_task(Arc::clone(&tc), 
input, ReviewTask).await; + sess.spawn_task(Arc::clone(&tc), input, ReviewTask::new(true)) + .await; sess.abort_all_tasks(TurnAbortReason::Interrupted).await; @@ -2886,7 +2943,7 @@ mod tests { for item in &initial_context { rollout_items.push(RolloutItem::ResponseItem(item.clone())); } - live_history.record_items(initial_context.iter()); + live_history.record_items(initial_context.iter(), turn_context.truncation_policy); let user1 = ResponseItem::Message { id: None, @@ -2895,7 +2952,7 @@ mod tests { text: "first user".to_string(), }], }; - live_history.record_items(std::iter::once(&user1)); + live_history.record_items(std::iter::once(&user1), turn_context.truncation_policy); rollout_items.push(RolloutItem::ResponseItem(user1.clone())); let assistant1 = ResponseItem::Message { @@ -2905,7 +2962,7 @@ mod tests { text: "assistant reply one".to_string(), }], }; - live_history.record_items(std::iter::once(&assistant1)); + live_history.record_items(std::iter::once(&assistant1), turn_context.truncation_policy); rollout_items.push(RolloutItem::ResponseItem(assistant1.clone())); let summary1 = "summary one"; @@ -2929,7 +2986,7 @@ mod tests { text: "second user".to_string(), }], }; - live_history.record_items(std::iter::once(&user2)); + live_history.record_items(std::iter::once(&user2), turn_context.truncation_policy); rollout_items.push(RolloutItem::ResponseItem(user2.clone())); let assistant2 = ResponseItem::Message { @@ -2939,7 +2996,7 @@ mod tests { text: "assistant reply two".to_string(), }], }; - live_history.record_items(std::iter::once(&assistant2)); + live_history.record_items(std::iter::once(&assistant2), turn_context.truncation_policy); rollout_items.push(RolloutItem::ResponseItem(assistant2.clone())); let summary2 = "summary two"; @@ -2963,7 +3020,7 @@ mod tests { text: "third user".to_string(), }], }; - live_history.record_items(std::iter::once(&user3)); + live_history.record_items(std::iter::once(&user3), turn_context.truncation_policy); 
rollout_items.push(RolloutItem::ResponseItem(user3.clone())); let assistant3 = ResponseItem::Message { @@ -2973,7 +3030,7 @@ mod tests { text: "assistant reply three".to_string(), }], }; - live_history.record_items(std::iter::once(&assistant3)); + live_history.record_items(std::iter::once(&assistant3), turn_context.truncation_policy); rollout_items.push(RolloutItem::ResponseItem(assistant3.clone())); (rollout_items, live_history.get_history()) @@ -2993,6 +3050,7 @@ mod tests { let session = Arc::new(session); let mut turn_context = Arc::new(turn_context_raw); + let timeout_ms = 1000; let params = ExecParams { command: if cfg!(windows) { vec![ @@ -3008,7 +3066,7 @@ mod tests { ] }, cwd: turn_context.cwd.clone(), - timeout_ms: Some(1000), + expiration: timeout_ms.into(), env: HashMap::new(), with_escalated_permissions: Some(true), justification: Some("test".to_string()), @@ -3017,7 +3075,12 @@ mod tests { let params2 = ExecParams { with_escalated_permissions: Some(false), - ..params.clone() + command: params.command.clone(), + cwd: params.cwd.clone(), + expiration: timeout_ms.into(), + env: HashMap::new(), + justification: params.justification.clone(), + arg0: None, }; let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); @@ -3037,7 +3100,7 @@ mod tests { arguments: serde_json::json!({ "command": params.command.clone(), "workdir": Some(turn_context.cwd.to_string_lossy().to_string()), - "timeout_ms": params.timeout_ms, + "timeout_ms": params.expiration.timeout_ms(), "with_escalated_permissions": params.with_escalated_permissions, "justification": params.justification.clone(), }) @@ -3074,7 +3137,7 @@ mod tests { arguments: serde_json::json!({ "command": params2.command.clone(), "workdir": Some(turn_context.cwd.to_string_lossy().to_string()), - "timeout_ms": params2.timeout_ms, + "timeout_ms": params2.expiration.timeout_ms(), "with_escalated_permissions": params2.with_escalated_permissions, "justification": params2.justification.clone(), }) 
diff --git a/codex-rs/core/src/command_safety/is_dangerous_command.rs b/codex-rs/core/src/command_safety/is_dangerous_command.rs index 09594bb1c..5df2023f0 100644 --- a/codex-rs/core/src/command_safety/is_dangerous_command.rs +++ b/codex-rs/core/src/command_safety/is_dangerous_command.rs @@ -1,6 +1,8 @@ use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::SandboxPolicy; +use crate::sandboxing::SandboxPermissions; + use crate::bash::parse_shell_lc_plain_commands; use crate::is_safe_command::is_known_safe_command; @@ -8,7 +10,7 @@ pub fn requires_initial_appoval( policy: AskForApproval, sandbox_policy: &SandboxPolicy, command: &[String], - with_escalated_permissions: bool, + sandbox_permissions: SandboxPermissions, ) -> bool { if is_known_safe_command(command) { return false; @@ -24,8 +26,7 @@ pub fn requires_initial_appoval( // In restricted sandboxes (ReadOnly/WorkspaceWrite), do not prompt for // non‑escalated, non‑dangerous commands — let the sandbox enforce // restrictions (e.g., block network/write) without a user prompt. 
- let wants_escalation: bool = with_escalated_permissions; - if wants_escalation { + if sandbox_permissions.requires_escalated_permissions() { return true; } command_might_be_dangerous(command) diff --git a/codex-rs/core/src/command_safety/is_safe_command.rs b/codex-rs/core/src/command_safety/is_safe_command.rs index ab084c191..5e7a7034d 100644 --- a/codex-rs/core/src/command_safety/is_safe_command.rs +++ b/codex-rs/core/src/command_safety/is_safe_command.rs @@ -267,6 +267,20 @@ mod tests { } } + #[test] + fn windows_powershell_full_path_is_safe() { + if !cfg!(windows) { + // Windows only because on Linux path splitting doesn't handle `/` separators properly + return; + } + + assert!(is_known_safe_command(&vec_str(&[ + r"C:\Program Files\PowerShell\7\pwsh.exe", + "-Command", + "Get-Location", + ]))); + } + #[test] fn bash_lc_safe_examples() { assert!(is_known_safe_command(&vec_str(&["bash", "-lc", "ls"]))); diff --git a/codex-rs/core/src/command_safety/windows_safe_commands.rs b/codex-rs/core/src/command_safety/windows_safe_commands.rs index ff0a3d2e7..a1d3b297f 100644 --- a/codex-rs/core/src/command_safety/windows_safe_commands.rs +++ b/codex-rs/core/src/command_safety/windows_safe_commands.rs @@ -1,4 +1,5 @@ use shlex::split as shlex_split; +use std::path::Path; /// On Windows, we conservatively allow only clearly read-only PowerShell invocations /// that match a small safelist. Anything else (including direct CMD commands) is unsafe. @@ -131,8 +132,14 @@ fn split_into_commands(tokens: Vec) -> Option>> { /// Returns true when the executable name is one of the supported PowerShell binaries. 
fn is_powershell_executable(exe: &str) -> bool { + let executable_name = Path::new(exe) + .file_name() + .and_then(|osstr| osstr.to_str()) + .unwrap_or(exe) + .to_ascii_lowercase(); + matches!( - exe.to_ascii_lowercase().as_str(), + executable_name.as_str(), "powershell" | "powershell.exe" | "pwsh" | "pwsh.exe" ) } @@ -313,6 +320,27 @@ mod tests { ]))); } + #[test] + fn accepts_full_path_powershell_invocations() { + if !cfg!(windows) { + // Windows only because on Linux path splitting doesn't handle `/` separators properly + return; + } + + assert!(is_safe_command_windows(&vec_str(&[ + r"C:\Program Files\PowerShell\7\pwsh.exe", + "-NoProfile", + "-Command", + "Get-ChildItem -Path .", + ]))); + + assert!(is_safe_command_windows(&vec_str(&[ + r"C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe", + "-Command", + "Get-Content Cargo.toml", + ]))); + } + #[test] fn allows_read_only_pipelines_and_git_usage() { assert!(is_safe_command_windows(&vec_str(&[ diff --git a/codex-rs/core/src/compact.rs b/codex-rs/core/src/compact.rs index 0495c161d..1b3937b9f 100644 --- a/codex-rs/core/src/compact.rs +++ b/codex-rs/core/src/compact.rs @@ -7,15 +7,18 @@ use crate::codex::TurnContext; use crate::codex::get_last_assistant_message_from_turn; use crate::error::CodexErr; use crate::error::Result as CodexResult; +use crate::features::Feature; use crate::protocol::AgentMessageEvent; use crate::protocol::CompactedItem; -use crate::protocol::ErrorEvent; use crate::protocol::EventMsg; use crate::protocol::TaskStartedEvent; use crate::protocol::TurnContextItem; use crate::protocol::WarningEvent; -use crate::truncate::truncate_middle; +use crate::truncate::TruncationPolicy; +use crate::truncate::approx_token_count; +use crate::truncate::truncate_text; use crate::util::backoff; +use codex_app_server_protocol::AuthMode; use codex_protocol::items::TurnItem; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseInputItem; @@ -29,12 +32,22 @@ pub const 
SUMMARIZATION_PROMPT: &str = include_str!("../templates/compact/prompt pub const SUMMARY_PREFIX: &str = include_str!("../templates/compact/summary_prefix.md"); const COMPACT_USER_MESSAGE_MAX_TOKENS: usize = 20_000; +pub(crate) async fn should_use_remote_compact_task(session: &Session) -> bool { + session + .services + .auth_manager + .auth() + .is_some_and(|auth| auth.mode == AuthMode::ChatGPT) + && session.enabled(Feature::RemoteCompaction).await +} + pub(crate) async fn run_inline_auto_compact_task( sess: Arc, turn_context: Arc, ) { let prompt = turn_context.compact_prompt().to_string(); let input = vec![UserInput::Text { text: prompt }]; + run_compact_task_inner(sess, turn_context, input).await; } @@ -42,13 +55,12 @@ pub(crate) async fn run_compact_task( sess: Arc, turn_context: Arc, input: Vec, -) -> Option { +) { let start_event = EventMsg::TaskStarted(TaskStartedEvent { model_context_window: turn_context.client.get_model_context_window(), }); sess.send_event(&turn_context, start_event).await; run_compact_task_inner(sess.clone(), turn_context, input).await; - None } async fn run_compact_task_inner( @@ -59,7 +71,10 @@ async fn run_compact_task_inner( let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input); let mut history = sess.clone_history().await; - history.record_items(&[initial_input_for_turn.into()]); + history.record_items( + &[initial_input_for_turn.into()], + turn_context.truncation_policy, + ); let mut truncated_count = 0usize; @@ -112,9 +127,7 @@ async fn run_compact_task_inner( continue; } sess.set_total_tokens_full(turn_context.as_ref()).await; - let event = EventMsg::Error(ErrorEvent { - message: e.to_string(), - }); + let event = EventMsg::Error(e.to_error_event(None)); sess.send_event(&turn_context, event).await; return; } @@ -125,14 +138,13 @@ async fn run_compact_task_inner( sess.notify_stream_error( turn_context.as_ref(), format!("Reconnecting... 
{retries}/{max_retries}"), + e, ) .await; tokio::time::sleep(delay).await; continue; } else { - let event = EventMsg::Error(ErrorEvent { - message: e.to_string(), - }); + let event = EventMsg::Error(e.to_error_event(None)); sess.send_event(&turn_context, event).await; return; } @@ -155,15 +167,7 @@ async fn run_compact_task_inner( .collect(); new_history.extend(ghost_snapshots); sess.replace_history(new_history).await; - - if let Some(estimated_tokens) = sess - .clone_history() - .await - .estimate_token_count(&turn_context) - { - sess.override_last_token_usage_estimate(&turn_context, estimated_tokens) - .await; - } + sess.recompute_token_usage(&turn_context).await; let rollout_item = RolloutItem::Compacted(CompactedItem { message: summary_text.clone(), @@ -230,7 +234,7 @@ pub(crate) fn build_compacted_history( initial_context, user_messages, summary_text, - COMPACT_USER_MESSAGE_MAX_TOKENS * 4, + COMPACT_USER_MESSAGE_MAX_TOKENS, ) } @@ -238,20 +242,21 @@ fn build_compacted_history_with_limit( mut history: Vec, user_messages: &[String], summary_text: &str, - max_bytes: usize, + max_tokens: usize, ) -> Vec { let mut selected_messages: Vec = Vec::new(); - if max_bytes > 0 { - let mut remaining = max_bytes; + if max_tokens > 0 { + let mut remaining = max_tokens; for message in user_messages.iter().rev() { if remaining == 0 { break; } - if message.len() <= remaining { + let tokens = approx_token_count(message); + if tokens <= remaining { selected_messages.push(message.clone()); - remaining = remaining.saturating_sub(message.len()); + remaining = remaining.saturating_sub(tokens); } else { - let (truncated, _) = truncate_middle(message, remaining); + let truncated = truncate_text(message, TruncationPolicy::Tokens(remaining)); selected_messages.push(truncated); break; } @@ -300,7 +305,8 @@ async fn drain_to_completed( }; match event { Ok(ResponseEvent::OutputItemDone(item)) => { - sess.record_into_history(std::slice::from_ref(&item)).await; + 
sess.record_into_history(std::slice::from_ref(&item), turn_context) + .await; } Ok(ResponseEvent::RateLimits(snapshot)) => { sess.update_rate_limits(turn_context, snapshot).await; @@ -318,6 +324,7 @@ async fn drain_to_completed( #[cfg(test)] mod tests { + use super::*; use pretty_assertions::assert_eq; @@ -409,16 +416,16 @@ mod tests { } #[test] - fn build_compacted_history_truncates_overlong_user_messages() { + fn build_token_limited_compacted_history_truncates_overlong_user_messages() { // Use a small truncation limit so the test remains fast while still validating // that oversized user content is truncated. - let max_bytes = 128; - let big = "X".repeat(max_bytes + 50); + let max_tokens = 16; + let big = "word ".repeat(200); let history = super::build_compacted_history_with_limit( Vec::new(), std::slice::from_ref(&big), "SUMMARY", - max_bytes, + max_tokens, ); assert_eq!(history.len(), 2); @@ -451,7 +458,7 @@ mod tests { } #[test] - fn build_compacted_history_appends_summary_message() { + fn build_token_limited_compacted_history_appends_summary_message() { let initial_context: Vec = Vec::new(); let user_messages = vec!["first user message".to_string()]; let summary_text = "summary text"; diff --git a/codex-rs/core/src/compact_remote.rs b/codex-rs/core/src/compact_remote.rs index 2c7d57eff..534d794f0 100644 --- a/codex-rs/core/src/compact_remote.rs +++ b/codex-rs/core/src/compact_remote.rs @@ -6,53 +6,41 @@ use crate::codex::TurnContext; use crate::error::Result as CodexResult; use crate::protocol::AgentMessageEvent; use crate::protocol::CompactedItem; -use crate::protocol::ErrorEvent; use crate::protocol::EventMsg; use crate::protocol::RolloutItem; use crate::protocol::TaskStartedEvent; -use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; -use codex_protocol::user_input::UserInput; -pub(crate) async fn run_remote_compact_task( +pub(crate) async fn run_inline_remote_auto_compact_task( sess: Arc, turn_context: Arc, - input: 
Vec, -) -> Option { +) { + run_remote_compact_task_inner(&sess, &turn_context).await; +} + +pub(crate) async fn run_remote_compact_task(sess: Arc, turn_context: Arc) { let start_event = EventMsg::TaskStarted(TaskStartedEvent { model_context_window: turn_context.client.get_model_context_window(), }); sess.send_event(&turn_context, start_event).await; - match run_remote_compact_task_inner(&sess, &turn_context, input).await { - Ok(()) => { - let event = EventMsg::AgentMessage(AgentMessageEvent { - message: "Compact task completed".to_string(), - }); - sess.send_event(&turn_context, event).await; - } - Err(err) => { - let event = EventMsg::Error(ErrorEvent { - message: err.to_string(), - }); - sess.send_event(&turn_context, event).await; - } - } + run_remote_compact_task_inner(&sess, &turn_context).await; +} - None +async fn run_remote_compact_task_inner(sess: &Arc, turn_context: &Arc) { + if let Err(err) = run_remote_compact_task_inner_impl(sess, turn_context).await { + let event = EventMsg::Error( + err.to_error_event(Some("Error running remote compact task".to_string())), + ); + sess.send_event(turn_context, event).await; + } } -async fn run_remote_compact_task_inner( +async fn run_remote_compact_task_inner_impl( sess: &Arc, turn_context: &Arc, - input: Vec, ) -> CodexResult<()> { let mut history = sess.clone_history().await; - if !input.is_empty() { - let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input); - history.record_items(&[initial_input_for_turn.into()]); - } - let prompt = Prompt { input: history.get_history_for_prompt(), tools: vec![], @@ -77,15 +65,7 @@ async fn run_remote_compact_task_inner( new_history.extend(ghost_snapshots); } sess.replace_history(new_history.clone()).await; - - if let Some(estimated_tokens) = sess - .clone_history() - .await - .estimate_token_count(turn_context.as_ref()) - { - sess.override_last_token_usage_estimate(turn_context.as_ref(), estimated_tokens) - .await; - } + 
sess.recompute_token_usage(turn_context).await; let compacted_item = CompactedItem { message: String::new(), @@ -93,5 +73,11 @@ async fn run_remote_compact_task_inner( }; sess.persist_rollout_items(&[RolloutItem::Compacted(compacted_item)]) .await; + + let event = EventMsg::AgentMessage(AgentMessageEvent { + message: "Compact task completed".to_string(), + }); + sess.send_event(turn_context, event).await; + Ok(()) } diff --git a/codex-rs/core/src/config/edit.rs b/codex-rs/core/src/config/edit.rs index 2d58ffc60..b8862fa5c 100644 --- a/codex-rs/core/src/config/edit.rs +++ b/codex-rs/core/src/config/edit.rs @@ -4,7 +4,6 @@ use crate::config::types::Notice; use anyhow::Context; use codex_protocol::config_types::ReasoningEffort; use codex_protocol::config_types::TrustLevel; -use codex_utils_tokenizer::warm_model_cache; use std::collections::BTreeMap; use std::path::Path; use std::path::PathBuf; @@ -231,9 +230,6 @@ impl ConfigDocument { fn apply(&mut self, edit: &ConfigEdit) -> anyhow::Result { match edit { ConfigEdit::SetModel { model, effort } => Ok({ - if let Some(model) = &model { - warm_model_cache(model) - } let mut mutated = false; mutated |= self.write_profile_value( &["model"], @@ -550,6 +546,15 @@ impl ConfigEditsBuilder { self } + /// Enable or disable a feature flag by key under the `[features]` table. + pub fn set_feature_enabled(mut self, key: &str, enabled: bool) -> Self { + self.edits.push(ConfigEdit::SetPath { + segments: vec!["features".to_string(), key.to_string()], + value: value(enabled), + }); + self + } + /// Apply edits on a blocking thread. 
pub fn apply_blocking(self) -> anyhow::Result<()> { apply_blocking(&self.codex_home, self.profile.as_deref(), &self.edits) @@ -836,6 +841,36 @@ hide_gpt5_1_migration_prompt = true assert_eq!(contents, expected); } + #[test] + fn blocking_set_hide_gpt_5_1_codex_max_migration_prompt_preserves_table() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[notice] +existing = "value" +"#, + ) + .expect("seed"); + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetNoticeHideModelMigrationPrompt( + "hide_gpt-5.1-codex-max_migration_prompt".to_string(), + true, + )], + ) + .expect("persist"); + + let contents = + std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[notice] +existing = "value" +"hide_gpt-5.1-codex-max_migration_prompt" = true +"#; + assert_eq!(contents, expected); + } + #[test] fn blocking_replace_mcp_servers_round_trips() { let tmp = tempdir().expect("tmpdir"); diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index 5b57d4dc0..3d5f25763 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -61,9 +61,6 @@ pub mod edit; pub mod profile; pub mod types; -#[cfg(target_os = "windows")] -pub const OPENAI_DEFAULT_MODEL: &str = "gpt-5.1"; -#[cfg(not(target_os = "windows"))] pub const OPENAI_DEFAULT_MODEL: &str = "gpt-5.1-codex"; const OPENAI_DEFAULT_REVIEW_MODEL: &str = "gpt-5.1-codex"; pub const GPT_5_CODEX_MEDIUM_MODEL: &str = "gpt-5.1-codex"; @@ -81,7 +78,7 @@ pub struct Config { /// Optional override of model selection. pub model: String, - /// Model used specifically for review sessions. Defaults to "gpt-5.1-codex". + /// Model used specifically for review sessions. Defaults to "gpt-5.1-codex-max". pub review_model: String, pub model_family: ModelFamily, @@ -89,9 +86,6 @@ pub struct Config { /// Size of the context window for the model, in tokens. 
pub model_context_window: Option, - /// Maximum number of output tokens. - pub model_max_output_tokens: Option, - /// Token usage threshold triggering auto-compaction of conversation history. pub model_auto_compact_token_limit: Option, @@ -163,6 +157,9 @@ pub struct Config { /// and turn completions when not focused. pub tui_notifications: Notifications, + /// Enable ASCII animations and shimmer effects in the TUI. + pub animations: bool, + /// The directory that should be treated as the current working directory /// for the session. All relative paths inside the business-logic layer are /// resolved against this path. @@ -195,6 +192,9 @@ pub struct Config { /// Additional filenames to try when looking for project-level docs. pub project_doc_fallback_filenames: Vec, + /// Token budget applied when storing tool/function outputs in the context manager. + pub tool_output_token_limit: Option, + /// Directory containing all Codex state (defaults to `~/.codex` but can be /// overridden by the `CODEX_HOME` environment variable). pub codex_home: PathBuf, @@ -567,9 +567,6 @@ pub struct ConfigToml { /// Size of the context window for the model, in tokens. pub model_context_window: Option, - /// Maximum number of output tokens. - pub model_max_output_tokens: Option, - /// Token usage threshold triggering auto-compaction of conversation history. pub model_auto_compact_token_limit: Option, @@ -636,6 +633,9 @@ pub struct ConfigToml { /// Ordered list of fallback filenames to look for when AGENTS.md is missing. pub project_doc_fallback_filenames: Option>, + /// Token budget applied when storing tool/function outputs in the context manager. + pub tool_output_token_limit: Option, + /// Profile to use from the `profiles` map. 
pub profile: Option, @@ -1116,11 +1116,6 @@ impl Config { let model_context_window = cfg .model_context_window .or_else(|| openai_model_info.as_ref().map(|info| info.context_window)); - let model_max_output_tokens = cfg.model_max_output_tokens.or_else(|| { - openai_model_info - .as_ref() - .map(|info| info.max_output_tokens) - }); let model_auto_compact_token_limit = cfg.model_auto_compact_token_limit.or_else(|| { openai_model_info .as_ref() @@ -1172,7 +1167,6 @@ impl Config { review_model, model_family, model_context_window, - model_max_output_tokens, model_auto_compact_token_limit, model_provider_id, model_provider, @@ -1209,6 +1203,7 @@ impl Config { } }) .collect(), + tool_output_token_limit: cfg.tool_output_token_limit, codex_home, history, file_opener: cfg.file_opener.unwrap_or(UriBasedFileOpener::VsCode), @@ -1249,6 +1244,7 @@ impl Config { .as_ref() .map(|t| t.notifications.clone()) .unwrap_or_default(), + animations: cfg.tui.as_ref().map(|t| t.animations).unwrap_or(true), otel: { let t: OtelConfigToml = cfg.otel.unwrap_or_default(); let log_user_prompt = t.log_user_prompt.unwrap_or(false); @@ -1313,6 +1309,16 @@ impl Config { Ok(Some(s)) } } + + pub fn set_windows_sandbox_globally(&mut self, value: bool) { + crate::safety::set_windows_sandbox_enabled(value); + if value { + self.features.enable(Feature::WindowsSandbox); + } else { + self.features.disable(Feature::WindowsSandbox); + } + self.forced_auto_mode_downgraded_on_windows = !value; + } } fn default_model() -> String { @@ -2943,7 +2949,6 @@ model_verbosity = "high" review_model: OPENAI_DEFAULT_REVIEW_MODEL.to_string(), model_family: find_family_for_model("o3").expect("known model slug"), model_context_window: Some(200_000), - model_max_output_tokens: Some(100_000), model_auto_compact_token_limit: Some(180_000), model_provider_id: "openai".to_string(), model_provider: fixture.openai_provider.clone(), @@ -2961,6 +2966,7 @@ model_verbosity = "high" model_providers: fixture.model_provider_map.clone(), 
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), + tool_output_token_limit: None, codex_home: fixture.codex_home(), history: History::default(), file_opener: UriBasedFileOpener::VsCode, @@ -2988,6 +2994,7 @@ model_verbosity = "high" notices: Default::default(), disable_paste_burst: false, tui_notifications: Default::default(), + animations: true, otel: OtelConfig::default(), }, o3_profile_config @@ -3014,7 +3021,6 @@ model_verbosity = "high" review_model: OPENAI_DEFAULT_REVIEW_MODEL.to_string(), model_family: find_family_for_model("gpt-3.5-turbo").expect("known model slug"), model_context_window: Some(16_385), - model_max_output_tokens: Some(4_096), model_auto_compact_token_limit: Some(14_746), model_provider_id: "openai-chat-completions".to_string(), model_provider: fixture.openai_chat_completions_provider.clone(), @@ -3032,6 +3038,7 @@ model_verbosity = "high" model_providers: fixture.model_provider_map.clone(), project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), + tool_output_token_limit: None, codex_home: fixture.codex_home(), history: History::default(), file_opener: UriBasedFileOpener::VsCode, @@ -3059,6 +3066,7 @@ model_verbosity = "high" notices: Default::default(), disable_paste_burst: false, tui_notifications: Default::default(), + animations: true, otel: OtelConfig::default(), }; @@ -3100,7 +3108,6 @@ model_verbosity = "high" review_model: OPENAI_DEFAULT_REVIEW_MODEL.to_string(), model_family: find_family_for_model("o3").expect("known model slug"), model_context_window: Some(200_000), - model_max_output_tokens: Some(100_000), model_auto_compact_token_limit: Some(180_000), model_provider_id: "openai".to_string(), model_provider: fixture.openai_provider.clone(), @@ -3118,6 +3125,7 @@ model_verbosity = "high" model_providers: fixture.model_provider_map.clone(), project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), + tool_output_token_limit: 
None, codex_home: fixture.codex_home(), history: History::default(), file_opener: UriBasedFileOpener::VsCode, @@ -3145,6 +3153,7 @@ model_verbosity = "high" notices: Default::default(), disable_paste_burst: false, tui_notifications: Default::default(), + animations: true, otel: OtelConfig::default(), }; @@ -3172,7 +3181,6 @@ model_verbosity = "high" review_model: OPENAI_DEFAULT_REVIEW_MODEL.to_string(), model_family: find_family_for_model("gpt-5.1").expect("known model slug"), model_context_window: Some(272_000), - model_max_output_tokens: Some(128_000), model_auto_compact_token_limit: Some(244_800), model_provider_id: "openai".to_string(), model_provider: fixture.openai_provider.clone(), @@ -3190,6 +3198,7 @@ model_verbosity = "high" model_providers: fixture.model_provider_map.clone(), project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), + tool_output_token_limit: None, codex_home: fixture.codex_home(), history: History::default(), file_opener: UriBasedFileOpener::VsCode, @@ -3217,6 +3226,7 @@ model_verbosity = "high" notices: Default::default(), disable_paste_burst: false, tui_notifications: Default::default(), + animations: true, otel: OtelConfig::default(), }; diff --git a/codex-rs/core/src/config/types.rs b/codex-rs/core/src/config/types.rs index 9c2cd03d1..869ec8297 100644 --- a/codex-rs/core/src/config/types.rs +++ b/codex-rs/core/src/config/types.rs @@ -282,6 +282,14 @@ pub enum OtelHttpProtocol { Json, } +#[derive(Deserialize, Debug, Clone, PartialEq, Default)] +#[serde(rename_all = "kebab-case")] +pub struct OtelTlsConfig { + pub ca_certificate: Option, + pub client_certificate: Option, + pub client_private_key: Option, +} + /// Which OTEL exporter to use. 
#[derive(Deserialize, Debug, Clone, PartialEq)] #[serde(rename_all = "kebab-case")] @@ -289,12 +297,18 @@ pub enum OtelExporterKind { None, OtlpHttp { endpoint: String, + #[serde(default)] headers: HashMap, protocol: OtelHttpProtocol, + #[serde(default)] + tls: Option, }, OtlpGrpc { endpoint: String, + #[serde(default)] headers: HashMap, + #[serde(default)] + tls: Option, }, } @@ -349,6 +363,15 @@ pub struct Tui { /// Defaults to `true`. #[serde(default)] pub notifications: Notifications, + + /// Enable animations (welcome screen, shimmer effects, spinners). + /// Defaults to `true`. + #[serde(default = "default_true")] + pub animations: bool, +} + +const fn default_true() -> bool { + true } /// Settings for notices we display to users via the tui and app-server clients @@ -364,6 +387,9 @@ pub struct Notice { pub hide_rate_limit_model_nudge: Option, /// Tracks whether the user has seen the model migration prompt pub hide_gpt5_1_migration_prompt: Option, + /// Tracks whether the user has seen the gpt-5.1-codex-max migration prompt + #[serde(rename = "hide_gpt-5.1-codex-max_migration_prompt")] + pub hide_gpt_5_1_codex_max_migration_prompt: Option, } impl Notice { diff --git a/codex-rs/core/src/context_manager/history.rs b/codex-rs/core/src/context_manager/history.rs index 50e3a8bc9..8eefcbf85 100644 --- a/codex-rs/core/src/context_manager/history.rs +++ b/codex-rs/core/src/context_manager/history.rs @@ -1,21 +1,15 @@ use crate::codex::TurnContext; use crate::context_manager::normalize; -use crate::truncate; -use crate::truncate::format_output_for_model_body; -use crate::truncate::globally_truncate_function_output_items; +use crate::truncate::TruncationPolicy; +use crate::truncate::approx_token_count; +use crate::truncate::truncate_function_output_items_with_policy; +use crate::truncate::truncate_text; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::TokenUsage; use 
codex_protocol::protocol::TokenUsageInfo; -use codex_utils_tokenizer::Tokenizer; use std::ops::Deref; -const CONTEXT_WINDOW_HARD_LIMIT_FACTOR: f64 = 1.1; -const CONTEXT_WINDOW_HARD_LIMIT_BYTES: usize = - (truncate::MODEL_FORMAT_MAX_BYTES as f64 * CONTEXT_WINDOW_HARD_LIMIT_FACTOR) as usize; -const CONTEXT_WINDOW_HARD_LIMIT_LINES: usize = - (truncate::MODEL_FORMAT_MAX_LINES as f64 * CONTEXT_WINDOW_HARD_LIMIT_FACTOR) as usize; - /// Transcript of conversation history #[derive(Debug, Clone, Default)] pub(crate) struct ContextManager { @@ -50,7 +44,7 @@ impl ContextManager { } /// `items` is ordered from oldest to newest. - pub(crate) fn record_items(&mut self, items: I) + pub(crate) fn record_items(&mut self, items: I, policy: TruncationPolicy) where I: IntoIterator, I::Item: std::ops::Deref, @@ -62,7 +56,7 @@ impl ContextManager { continue; } - let processed = Self::process_item(&item); + let processed = self.process_item(item_ref, policy); self.items.push(processed); } } @@ -80,26 +74,21 @@ impl ContextManager { history } - // Estimate the number of tokens in the history. Return None if no tokenizer - // is available. This does not consider the reasoning traces. - // /!\ The value is a lower bound estimate and does not represent the exact - // context length. + // Estimate token usage using byte-based heuristics from the truncation helpers. + // This is a coarse lower bound, not a tokenizer-accurate count. 
pub(crate) fn estimate_token_count(&self, turn_context: &TurnContext) -> Option { - let model = turn_context.client.get_model(); - let tokenizer = Tokenizer::for_model(model.as_str()).ok()?; let model_family = turn_context.client.get_model_family(); + let base_tokens = + i64::try_from(approx_token_count(model_family.base_instructions.as_str())) + .unwrap_or(i64::MAX); + + let items_tokens = self.items.iter().fold(0i64, |acc, item| { + let serialized = serde_json::to_string(item).unwrap_or_default(); + let item_tokens = i64::try_from(approx_token_count(&serialized)).unwrap_or(i64::MAX); + acc.saturating_add(item_tokens) + }); - Some( - self.items - .iter() - .map(|item| { - serde_json::to_string(&item) - .map(|item| tokenizer.count(&item)) - .unwrap_or_default() - }) - .sum::() - + tokenizer.count(model_family.base_instructions.as_str()), - ) + Some(base_tokens.saturating_add(items_tokens)) } pub(crate) fn remove_first_item(&mut self) { @@ -150,18 +139,18 @@ impl ContextManager { items.retain(|item| !matches!(item, ResponseItem::GhostSnapshot { .. 
})); } - fn process_item(item: &ResponseItem) -> ResponseItem { + fn process_item(&self, item: &ResponseItem, policy: TruncationPolicy) -> ResponseItem { + let policy_with_serialization_budget = policy.mul(1.2); match item { ResponseItem::FunctionCallOutput { call_id, output } => { - let truncated = format_output_for_model_body( - output.content.as_str(), - CONTEXT_WINDOW_HARD_LIMIT_BYTES, - CONTEXT_WINDOW_HARD_LIMIT_LINES, - ); - let truncated_items = output - .content_items - .as_ref() - .map(|items| globally_truncate_function_output_items(items)); + let truncated = + truncate_text(output.content.as_str(), policy_with_serialization_budget); + let truncated_items = output.content_items.as_ref().map(|items| { + truncate_function_output_items_with_policy( + items, + policy_with_serialization_budget, + ) + }); ResponseItem::FunctionCallOutput { call_id: call_id.clone(), output: FunctionCallOutputPayload { @@ -172,11 +161,7 @@ impl ContextManager { } } ResponseItem::CustomToolCallOutput { call_id, output } => { - let truncated = format_output_for_model_body( - output, - CONTEXT_WINDOW_HARD_LIMIT_BYTES, - CONTEXT_WINDOW_HARD_LIMIT_LINES, - ); + let truncated = truncate_text(output, policy_with_serialization_budget); ResponseItem::CustomToolCallOutput { call_id: call_id.clone(), output: truncated, diff --git a/codex-rs/core/src/context_manager/history_tests.rs b/codex-rs/core/src/context_manager/history_tests.rs index c81749c2c..1a01604a7 100644 --- a/codex-rs/core/src/context_manager/history_tests.rs +++ b/codex-rs/core/src/context_manager/history_tests.rs @@ -1,9 +1,8 @@ use super::*; -use crate::context_manager::MODEL_FORMAT_MAX_LINES; use crate::truncate; +use crate::truncate::TruncationPolicy; use codex_git::GhostCommit; use codex_protocol::models::ContentItem; -use codex_protocol::models::FunctionCallOutputContentItem; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::LocalShellAction; use 
codex_protocol::models::LocalShellExecAction; @@ -13,6 +12,9 @@ use codex_protocol::models::ReasoningItemReasoningSummary; use pretty_assertions::assert_eq; use regex_lite::Regex; +const EXEC_FORMAT_MAX_BYTES: usize = 10_000; +const EXEC_FORMAT_MAX_TOKENS: usize = 2_500; + fn assistant_msg(text: &str) -> ResponseItem { ResponseItem::Message { id: None, @@ -25,7 +27,9 @@ fn assistant_msg(text: &str) -> ResponseItem { fn create_history_with_items(items: Vec) -> ContextManager { let mut h = ContextManager::new(); - h.record_items(items.iter()); + // Use a generous but fixed token budget; tests only rely on truncation + // behavior, not on a specific model's token limit. + h.record_items(items.iter(), TruncationPolicy::Tokens(10_000)); h } @@ -52,9 +56,14 @@ fn reasoning_msg(text: &str) -> ResponseItem { } } +fn truncate_exec_output(content: &str) -> String { + truncate::truncate_text(content, TruncationPolicy::Tokens(EXEC_FORMAT_MAX_TOKENS)) +} + #[test] fn filters_non_api_messages() { let mut h = ContextManager::default(); + let policy = TruncationPolicy::Tokens(10_000); // System message is not API messages; Other is ignored. let system = ResponseItem::Message { id: None, @@ -64,12 +73,12 @@ fn filters_non_api_messages() { }], }; let reasoning = reasoning_msg("thinking..."); - h.record_items([&system, &reasoning, &ResponseItem::Other]); + h.record_items([&system, &reasoning, &ResponseItem::Other], policy); // User and assistant should be retained. 
let u = user_msg("hi"); let a = assistant_msg("hello"); - h.record_items([&u, &a]); + h.record_items([&u, &a], policy); let items = h.contents(); assert_eq!( @@ -223,7 +232,7 @@ fn normalization_retains_local_shell_outputs() { ResponseItem::FunctionCallOutput { call_id: "shell-1".to_string(), output: FunctionCallOutputPayload { - content: "ok".to_string(), + content: "Total output lines: 1\n\nok".to_string(), ..Default::default() }, }, @@ -237,6 +246,9 @@ fn normalization_retains_local_shell_outputs() { #[test] fn record_items_truncates_function_call_output_content() { let mut history = ContextManager::new(); + // Any reasonably small token budget works; the test only cares that + // truncation happens and the marker is present. + let policy = TruncationPolicy::Tokens(1_000); let long_line = "a very long line to trigger truncation\n"; let long_output = long_line.repeat(2_500); let item = ResponseItem::FunctionCallOutput { @@ -248,15 +260,20 @@ fn record_items_truncates_function_call_output_content() { }, }; - history.record_items([&item]); + history.record_items([&item], policy); assert_eq!(history.items.len(), 1); match &history.items[0] { ResponseItem::FunctionCallOutput { output, .. 
} => { assert_ne!(output.content, long_output); assert!( - output.content.starts_with("Total output lines:"), - "expected truncated summary, got {}", + output.content.contains("tokens truncated"), + "expected token-based truncation marker, got {}", + output.content + ); + assert!( + output.content.contains("tokens truncated"), + "expected truncation marker, got {}", output.content ); } @@ -267,6 +284,7 @@ fn record_items_truncates_function_call_output_content() { #[test] fn record_items_truncates_custom_tool_call_output_content() { let mut history = ContextManager::new(); + let policy = TruncationPolicy::Tokens(1_000); let line = "custom output that is very long\n"; let long_output = line.repeat(2_500); let item = ResponseItem::CustomToolCallOutput { @@ -274,23 +292,50 @@ fn record_items_truncates_custom_tool_call_output_content() { output: long_output.clone(), }; - history.record_items([&item]); + history.record_items([&item], policy); assert_eq!(history.items.len(), 1); match &history.items[0] { ResponseItem::CustomToolCallOutput { output, .. 
} => { assert_ne!(output, &long_output); assert!( - output.starts_with("Total output lines:"), - "expected truncated summary, got {output}" + output.contains("tokens truncated"), + "expected token-based truncation marker, got {output}" + ); + assert!( + output.contains("tokens truncated") || output.contains("bytes truncated"), + "expected truncation marker, got {output}" ); } other => panic!("unexpected history item: {other:?}"), } } -fn assert_truncated_message_matches(message: &str, line: &str, total_lines: usize) { - let pattern = truncated_message_pattern(line, total_lines); +#[test] +fn record_items_respects_custom_token_limit() { + let mut history = ContextManager::new(); + let policy = TruncationPolicy::Tokens(10); + let long_output = "tokenized content repeated many times ".repeat(200); + let item = ResponseItem::FunctionCallOutput { + call_id: "call-custom-limit".to_string(), + output: FunctionCallOutputPayload { + content: long_output, + success: Some(true), + ..Default::default() + }, + }; + + history.record_items([&item], policy); + + let stored = match &history.items[0] { + ResponseItem::FunctionCallOutput { output, .. 
} => output, + other => panic!("unexpected history item: {other:?}"), + }; + assert!(stored.content.contains("tokens truncated")); +} + +fn assert_truncated_message_matches(message: &str, line: &str, expected_removed: usize) { + let pattern = truncated_message_pattern(line); let regex = Regex::new(&pattern).unwrap_or_else(|err| { panic!("failed to compile regex {pattern}: {err}"); }); @@ -302,28 +347,22 @@ fn assert_truncated_message_matches(message: &str, line: &str, total_lines: usiz .expect("missing body capture") .as_str(); assert!( - body.len() <= truncate::MODEL_FORMAT_MAX_BYTES, + body.len() <= EXEC_FORMAT_MAX_BYTES, "body exceeds byte limit: {} bytes", body.len() ); + let removed: usize = captures + .name("removed") + .expect("missing removed capture") + .as_str() + .parse() + .unwrap_or_else(|err| panic!("invalid removed tokens: {err}")); + assert_eq!(removed, expected_removed, "mismatched removed token count"); } -fn truncated_message_pattern(line: &str, total_lines: usize) -> String { - let head_lines = MODEL_FORMAT_MAX_LINES / 2; - let tail_lines = MODEL_FORMAT_MAX_LINES - head_lines; - let head_take = head_lines.min(total_lines); - let tail_take = tail_lines.min(total_lines.saturating_sub(head_take)); - let omitted = total_lines.saturating_sub(head_take + tail_take); +fn truncated_message_pattern(line: &str) -> String { let escaped_line = regex_lite::escape(line); - if omitted == 0 { - return format!( - r"(?s)^Total output lines: {total_lines}\n\n(?P{escaped_line}.*\n\[\.{{3}} output truncated to fit {max_bytes} bytes \.{{3}}]\n\n.*)$", - max_bytes = truncate::MODEL_FORMAT_MAX_BYTES, - ); - } - format!( - r"(?s)^Total output lines: {total_lines}\n\n(?P{escaped_line}.*\n\[\.{{3}} omitted {omitted} of {total_lines} lines \.{{3}}]\n\n.*)$", - ) + format!(r"(?s)^(?P{escaped_line}.*?)(?:\r?)?…(?P\d+) tokens truncated…(?:.*)?$") } #[test] @@ -331,35 +370,18 @@ fn format_exec_output_truncates_large_error() { let line = "very long execution error line that 
should trigger truncation\n"; let large_error = line.repeat(2_500); // way beyond both byte and line limits - let truncated = truncate::format_output_for_model_body( - &large_error, - truncate::MODEL_FORMAT_MAX_BYTES, - truncate::MODEL_FORMAT_MAX_LINES, - ); + let truncated = truncate_exec_output(&large_error); - let total_lines = large_error.lines().count(); - assert_truncated_message_matches(&truncated, line, total_lines); + assert_truncated_message_matches(&truncated, line, 36250); assert_ne!(truncated, large_error); } #[test] fn format_exec_output_marks_byte_truncation_without_omitted_lines() { - let long_line = "a".repeat(truncate::MODEL_FORMAT_MAX_BYTES + 50); - let truncated = truncate::format_output_for_model_body( - &long_line, - truncate::MODEL_FORMAT_MAX_BYTES, - truncate::MODEL_FORMAT_MAX_LINES, - ); - + let long_line = "a".repeat(EXEC_FORMAT_MAX_BYTES + 10000); + let truncated = truncate_exec_output(&long_line); assert_ne!(truncated, long_line); - let marker_line = format!( - "[... 
output truncated to fit {} bytes ...]", - truncate::MODEL_FORMAT_MAX_BYTES - ); - assert!( - truncated.contains(&marker_line), - "missing byte truncation marker: {truncated}" - ); + assert_truncated_message_matches(&truncated, "a", 2500); assert!( !truncated.contains("omitted"), "line omission marker should not appear when no lines were dropped: {truncated}" @@ -369,42 +391,25 @@ fn format_exec_output_marks_byte_truncation_without_omitted_lines() { #[test] fn format_exec_output_returns_original_when_within_limits() { let content = "example output\n".repeat(10); - - assert_eq!( - truncate::format_output_for_model_body( - &content, - truncate::MODEL_FORMAT_MAX_BYTES, - truncate::MODEL_FORMAT_MAX_LINES - ), - content - ); + assert_eq!(truncate_exec_output(&content), content); } #[test] fn format_exec_output_reports_omitted_lines_and_keeps_head_and_tail() { - let total_lines = truncate::MODEL_FORMAT_MAX_LINES + 100; + let total_lines = 2_000; + let filler = "x".repeat(64); let content: String = (0..total_lines) - .map(|idx| format!("line-{idx}\n")) + .map(|idx| format!("line-{idx}-{filler}\n")) .collect(); - let truncated = truncate::format_output_for_model_body( - &content, - truncate::MODEL_FORMAT_MAX_BYTES, - truncate::MODEL_FORMAT_MAX_LINES, - ); - let omitted = total_lines - truncate::MODEL_FORMAT_MAX_LINES; - let expected_marker = format!("[... 
omitted {omitted} of {total_lines} lines ...]"); - - assert!( - truncated.contains(&expected_marker), - "missing omitted marker: {truncated}" - ); + let truncated = truncate_exec_output(&content); + assert_truncated_message_matches(&truncated, "line-0-", 34_723); assert!( - truncated.contains("line-0\n"), + truncated.contains("line-0-"), "expected head line to remain: {truncated}" ); - let last_line = format!("line-{}\n", total_lines - 1); + let last_line = format!("line-{}-", total_lines - 1); assert!( truncated.contains(&last_line), "expected tail line to remain: {truncated}" @@ -413,101 +418,15 @@ fn format_exec_output_reports_omitted_lines_and_keeps_head_and_tail() { #[test] fn format_exec_output_prefers_line_marker_when_both_limits_exceeded() { - let total_lines = truncate::MODEL_FORMAT_MAX_LINES + 42; + let total_lines = 300; let long_line = "x".repeat(256); let content: String = (0..total_lines) .map(|idx| format!("line-{idx}-{long_line}\n")) .collect(); - let truncated = truncate::format_output_for_model_body( - &content, - truncate::MODEL_FORMAT_MAX_BYTES, - truncate::MODEL_FORMAT_MAX_LINES, - ); - - assert!( - truncated.contains("[... omitted 42 of 298 lines ...]"), - "expected omitted marker when line count exceeds limit: {truncated}" - ); - assert!( - !truncated.contains("output truncated to fit"), - "line omission marker should take precedence over byte marker: {truncated}" - ); -} - -#[test] -fn truncates_across_multiple_under_limit_texts_and_reports_omitted() { - // Arrange: several text items, none exceeding per-item limit, but total exceeds budget. 
- let budget = truncate::MODEL_FORMAT_MAX_BYTES; - let t1_len = (budget / 2).saturating_sub(10); - let t2_len = (budget / 2).saturating_sub(10); - let remaining_after_t1_t2 = budget.saturating_sub(t1_len + t2_len); - let t3_len = 50; // gets truncated to remaining_after_t1_t2 - let t4_len = 5; // omitted - let t5_len = 7; // omitted - - let t1 = "a".repeat(t1_len); - let t2 = "b".repeat(t2_len); - let t3 = "c".repeat(t3_len); - let t4 = "d".repeat(t4_len); - let t5 = "e".repeat(t5_len); - - let item = ResponseItem::FunctionCallOutput { - call_id: "call-omit".to_string(), - output: FunctionCallOutputPayload { - content: "irrelevant".to_string(), - content_items: Some(vec![ - FunctionCallOutputContentItem::InputText { text: t1 }, - FunctionCallOutputContentItem::InputText { text: t2 }, - FunctionCallOutputContentItem::InputImage { - image_url: "img:mid".to_string(), - }, - FunctionCallOutputContentItem::InputText { text: t3 }, - FunctionCallOutputContentItem::InputText { text: t4 }, - FunctionCallOutputContentItem::InputText { text: t5 }, - ]), - success: Some(true), - }, - }; - - let mut history = ContextManager::new(); - history.record_items([&item]); - assert_eq!(history.items.len(), 1); - let json = serde_json::to_value(&history.items[0]).expect("serialize to json"); - - let output = json - .get("output") - .expect("output field") - .as_array() - .expect("array output"); - - // Expect: t1 (full), t2 (full), image, t3 (truncated), summary mentioning 2 omitted. 
- assert_eq!(output.len(), 5); - - let first = output[0].as_object().expect("first obj"); - assert_eq!(first.get("type").unwrap(), "input_text"); - let first_text = first.get("text").unwrap().as_str().unwrap(); - assert_eq!(first_text.len(), t1_len); - - let second = output[1].as_object().expect("second obj"); - assert_eq!(second.get("type").unwrap(), "input_text"); - let second_text = second.get("text").unwrap().as_str().unwrap(); - assert_eq!(second_text.len(), t2_len); - - assert_eq!( - output[2], - serde_json::json!({"type": "input_image", "image_url": "img:mid"}) - ); - - let fourth = output[3].as_object().expect("fourth obj"); - assert_eq!(fourth.get("type").unwrap(), "input_text"); - let fourth_text = fourth.get("text").unwrap().as_str().unwrap(); - assert_eq!(fourth_text.len(), remaining_after_t1_t2); + let truncated = truncate_exec_output(&content); - let summary = output[4].as_object().expect("summary obj"); - assert_eq!(summary.get("type").unwrap(), "input_text"); - let summary_text = summary.get("text").unwrap().as_str().unwrap(); - assert!(summary_text.contains("omitted 2 text items")); + assert_truncated_message_matches(&truncated, "line-0-", 17_423); } //TODO(aibrahim): run CI in release mode. 
diff --git a/codex-rs/core/src/context_manager/mod.rs b/codex-rs/core/src/context_manager/mod.rs index ab0d2e816..5089b5e8b 100644 --- a/codex-rs/core/src/context_manager/mod.rs +++ b/codex-rs/core/src/context_manager/mod.rs @@ -1,7 +1,4 @@ mod history; mod normalize; -pub(crate) use crate::truncate::MODEL_FORMAT_MAX_BYTES; -pub(crate) use crate::truncate::MODEL_FORMAT_MAX_LINES; -pub(crate) use crate::truncate::format_output_for_model_body; pub(crate) use history::ContextManager; diff --git a/codex-rs/core/src/environment_context.rs b/codex-rs/core/src/environment_context.rs index 6db2bb8f5..56e7f6cad 100644 --- a/codex-rs/core/src/environment_context.rs +++ b/codex-rs/core/src/environment_context.rs @@ -6,6 +6,7 @@ use crate::codex::TurnContext; use crate::protocol::AskForApproval; use crate::protocol::SandboxPolicy; use crate::shell::Shell; +use crate::shell::default_user_shell; use codex_protocol::config_types::SandboxMode; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; @@ -28,7 +29,7 @@ pub(crate) struct EnvironmentContext { pub sandbox_mode: Option, pub network_access: Option, pub writable_roots: Option>, - pub shell: Option, + pub shell: Shell, } impl EnvironmentContext { @@ -36,7 +37,7 @@ impl EnvironmentContext { cwd: Option, approval_policy: Option, sandbox_policy: Option, - shell: Option, + shell: Shell, ) -> Self { Self { cwd, @@ -110,7 +111,7 @@ impl EnvironmentContext { } else { None }; - EnvironmentContext::new(cwd, approval_policy, sandbox_policy, None) + EnvironmentContext::new(cwd, approval_policy, sandbox_policy, default_user_shell()) } } @@ -121,7 +122,7 @@ impl From<&TurnContext> for EnvironmentContext { Some(turn_context.approval_policy), Some(turn_context.sandbox_policy.clone()), // Shell is not configurable from turn to turn - None, + default_user_shell(), ) } } @@ -169,11 +170,9 @@ impl EnvironmentContext { } lines.push(" ".to_string()); } - if let Some(shell) = self.shell - && let Some(shell_name) = 
shell.name() - { - lines.push(format!(" {shell_name}")); - } + + let shell_name = self.shell.name(); + lines.push(format!(" {shell_name}")); lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string()); lines.join("\n") } @@ -193,12 +192,18 @@ impl From for ResponseItem { #[cfg(test)] mod tests { - use crate::shell::BashShell; - use crate::shell::ZshShell; + use crate::shell::ShellType; use super::*; use pretty_assertions::assert_eq; + fn fake_shell() -> Shell { + Shell { + shell_type: ShellType::Bash, + shell_path: PathBuf::from("/bin/bash"), + } + } + fn workspace_write_policy(writable_roots: Vec<&str>, network_access: bool) -> SandboxPolicy { SandboxPolicy::WorkspaceWrite { writable_roots: writable_roots.into_iter().map(PathBuf::from).collect(), @@ -214,7 +219,7 @@ mod tests { Some(PathBuf::from("/repo")), Some(AskForApproval::OnRequest), Some(workspace_write_policy(vec!["/repo", "/tmp"], false)), - None, + fake_shell(), ); let expected = r#" @@ -226,6 +231,7 @@ mod tests { /repo /tmp + bash "#; assert_eq!(context.serialize_to_xml(), expected); @@ -237,13 +243,14 @@ mod tests { None, Some(AskForApproval::Never), Some(SandboxPolicy::ReadOnly), - None, + fake_shell(), ); let expected = r#" never read-only restricted + bash "#; assert_eq!(context.serialize_to_xml(), expected); @@ -255,13 +262,14 @@ mod tests { None, Some(AskForApproval::OnFailure), Some(SandboxPolicy::DangerFullAccess), - None, + fake_shell(), ); let expected = r#" on-failure danger-full-access enabled + bash "#; assert_eq!(context.serialize_to_xml(), expected); @@ -274,13 +282,13 @@ mod tests { Some(PathBuf::from("/repo")), Some(AskForApproval::OnRequest), Some(workspace_write_policy(vec!["/repo"], false)), - None, + fake_shell(), ); let context2 = EnvironmentContext::new( Some(PathBuf::from("/repo")), Some(AskForApproval::Never), Some(workspace_write_policy(vec!["/repo"], true)), - None, + fake_shell(), ); assert!(!context1.equals_except_shell(&context2)); } @@ -291,13 +299,13 @@ mod tests { 
Some(PathBuf::from("/repo")), Some(AskForApproval::OnRequest), Some(SandboxPolicy::new_read_only_policy()), - None, + fake_shell(), ); let context2 = EnvironmentContext::new( Some(PathBuf::from("/repo")), Some(AskForApproval::OnRequest), Some(SandboxPolicy::new_workspace_write_policy()), - None, + fake_shell(), ); assert!(!context1.equals_except_shell(&context2)); @@ -309,13 +317,13 @@ mod tests { Some(PathBuf::from("/repo")), Some(AskForApproval::OnRequest), Some(workspace_write_policy(vec!["/repo", "/tmp", "/var"], false)), - None, + fake_shell(), ); let context2 = EnvironmentContext::new( Some(PathBuf::from("/repo")), Some(AskForApproval::OnRequest), Some(workspace_write_policy(vec!["/repo", "/tmp"], true)), - None, + fake_shell(), ); assert!(!context1.equals_except_shell(&context2)); @@ -327,17 +335,19 @@ mod tests { Some(PathBuf::from("/repo")), Some(AskForApproval::OnRequest), Some(workspace_write_policy(vec!["/repo"], false)), - Some(Shell::Bash(BashShell { + Shell { + shell_type: ShellType::Bash, shell_path: "/bin/bash".into(), - })), + }, ); let context2 = EnvironmentContext::new( Some(PathBuf::from("/repo")), Some(AskForApproval::OnRequest), Some(workspace_write_policy(vec!["/repo"], false)), - Some(Shell::Zsh(ZshShell { + Shell { + shell_type: ShellType::Zsh, shell_path: "/bin/zsh".into(), - })), + }, ); assert!(context1.equals_except_shell(&context2)); diff --git a/codex-rs/core/src/error.rs b/codex-rs/core/src/error.rs index 64ba8df84..9130b40e1 100644 --- a/codex-rs/core/src/error.rs +++ b/codex-rs/core/src/error.rs @@ -2,13 +2,16 @@ use crate::codex::ProcessedResponseItem; use crate::exec::ExecToolCallOutput; use crate::token_data::KnownPlan; use crate::token_data::PlanType; -use crate::truncate::truncate_middle; +use crate::truncate::TruncationPolicy; +use crate::truncate::truncate_text; use chrono::DateTime; use chrono::Datelike; use chrono::Local; use chrono::Utc; use codex_async_utils::CancelErr; use codex_protocol::ConversationId; +use 
codex_protocol::protocol::CodexErrorInfo; +use codex_protocol::protocol::ErrorEvent; use codex_protocol::protocol::RateLimitSnapshot; use reqwest::StatusCode; use serde_json; @@ -429,6 +432,57 @@ impl CodexErr { pub fn downcast_ref(&self) -> Option<&T> { (self as &dyn std::any::Any).downcast_ref::() } + + /// Translate core error to client-facing protocol error. + pub fn to_codex_protocol_error(&self) -> CodexErrorInfo { + match self { + CodexErr::ContextWindowExceeded => CodexErrorInfo::ContextWindowExceeded, + CodexErr::UsageLimitReached(_) + | CodexErr::QuotaExceeded + | CodexErr::UsageNotIncluded => CodexErrorInfo::UsageLimitExceeded, + CodexErr::RetryLimit(_) => CodexErrorInfo::ResponseTooManyFailedAttempts { + http_status_code: self.http_status_code_value(), + }, + CodexErr::ConnectionFailed(_) => CodexErrorInfo::HttpConnectionFailed { + http_status_code: self.http_status_code_value(), + }, + CodexErr::ResponseStreamFailed(_) => CodexErrorInfo::ResponseStreamConnectionFailed { + http_status_code: self.http_status_code_value(), + }, + CodexErr::RefreshTokenFailed(_) => CodexErrorInfo::Unauthorized, + CodexErr::SessionConfiguredNotFirstEvent + | CodexErr::InternalServerError + | CodexErr::InternalAgentDied => CodexErrorInfo::InternalServerError, + CodexErr::UnsupportedOperation(_) | CodexErr::ConversationNotFound(_) => { + CodexErrorInfo::BadRequest + } + CodexErr::Sandbox(_) => CodexErrorInfo::SandboxError, + _ => CodexErrorInfo::Other, + } + } + + pub fn to_error_event(&self, message_prefix: Option) -> ErrorEvent { + let error_message = self.to_string(); + let message: String = match message_prefix { + Some(prefix) => format!("{prefix}: {error_message}"), + None => error_message, + }; + ErrorEvent { + message, + codex_error_info: Some(self.to_codex_protocol_error()), + } + } + + pub fn http_status_code_value(&self) -> Option { + let http_status_code = match self { + CodexErr::RetryLimit(err) => Some(err.status), + CodexErr::UnexpectedStatus(err) => 
Some(err.status), + CodexErr::ConnectionFailed(err) => err.source.status(), + CodexErr::ResponseStreamFailed(err) => err.source.status(), + _ => None, + }; + http_status_code.as_ref().map(StatusCode::as_u16) + } } pub fn get_error_message_ui(e: &CodexErr) -> String { @@ -461,7 +515,10 @@ pub fn get_error_message_ui(e: &CodexErr) -> String { _ => e.to_string(), }; - truncate_middle(&message, ERROR_MESSAGE_UI_MAX_BYTES).0 + truncate_text( + &message, + TruncationPolicy::Bytes(ERROR_MESSAGE_UI_MAX_BYTES), + ) } #[cfg(test)] @@ -474,6 +531,10 @@ mod tests { use chrono::Utc; use codex_protocol::protocol::RateLimitWindow; use pretty_assertions::assert_eq; + use reqwest::Response; + use reqwest::ResponseBuilderExt; + use reqwest::StatusCode; + use reqwest::Url; fn rate_limit_snapshot() -> RateLimitSnapshot { let primary_reset_at = Utc @@ -495,6 +556,7 @@ mod tests { window_minutes: Some(120), resets_at: Some(secondary_reset_at), }), + credits: None, } } @@ -568,6 +630,33 @@ mod tests { assert_eq!(get_error_message_ui(&err), "stdout only"); } + #[test] + fn to_error_event_handles_response_stream_failed() { + let response = http::Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .url(Url::parse("http://example.com").unwrap()) + .body("") + .unwrap(); + let source = Response::from(response).error_for_status_ref().unwrap_err(); + let err = CodexErr::ResponseStreamFailed(ResponseStreamFailed { + source, + request_id: Some("req-123".to_string()), + }); + + let event = err.to_error_event(Some("prefix".to_string())); + + assert_eq!( + event.message, + "prefix: Error while reading the server response: HTTP status client error (429 Too Many Requests) for url (http://example.com/), request id: req-123" + ); + assert_eq!( + event.codex_error_info, + Some(CodexErrorInfo::ResponseStreamConnectionFailed { + http_status_code: Some(429) + }) + ); + } + #[test] fn sandbox_denied_reports_exit_code_when_no_output_available() { let output = ExecToolCallOutput { diff --git 
a/codex-rs/core/src/event_mapping.rs b/codex-rs/core/src/event_mapping.rs index d0aa1d818..6b4bed4db 100644 --- a/codex-rs/core/src/event_mapping.rs +++ b/codex-rs/core/src/event_mapping.rs @@ -117,7 +117,7 @@ pub fn parse_turn_item(item: &ResponseItem) -> Option { .. } => Some(TurnItem::WebSearch(WebSearchItem { id: id.clone().unwrap_or_default(), - query: query.clone(), + query: query.clone().unwrap_or_default(), })), _ => None, } @@ -306,7 +306,7 @@ mod tests { id: Some("ws_1".to_string()), status: Some("completed".to_string()), action: WebSearchAction::Search { - query: "weather".to_string(), + query: Some("weather".to_string()), }, }; diff --git a/codex-rs/core/src/exec.rs b/codex-rs/core/src/exec.rs index b44a40bdb..f45ecdce7 100644 --- a/codex-rs/core/src/exec.rs +++ b/codex-rs/core/src/exec.rs @@ -14,6 +14,7 @@ use tokio::io::AsyncRead; use tokio::io::AsyncReadExt; use tokio::io::BufReader; use tokio::process::Child; +use tokio_util::sync::CancellationToken; use crate::error::CodexErr; use crate::error::Result; @@ -28,8 +29,9 @@ use crate::sandboxing::ExecEnv; use crate::sandboxing::SandboxManager; use crate::spawn::StdioPolicy; use crate::spawn::spawn_child_async; +use crate::text_encoding::bytes_to_string_smart; -const DEFAULT_TIMEOUT_MS: u64 = 10_000; +pub const DEFAULT_EXEC_COMMAND_TIMEOUT_MS: u64 = 10_000; // Hardcode these since it does not seem worth including the libc crate just // for these. @@ -46,20 +48,59 @@ const AGGREGATE_BUFFER_INITIAL_CAPACITY: usize = 8 * 1024; // 8 KiB /// Aggregation still collects full output; only the live event stream is capped. 
pub(crate) const MAX_EXEC_OUTPUT_DELTAS_PER_CALL: usize = 10_000; -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct ExecParams { pub command: Vec, pub cwd: PathBuf, - pub timeout_ms: Option, + pub expiration: ExecExpiration, pub env: HashMap, pub with_escalated_permissions: Option, pub justification: Option, pub arg0: Option, } -impl ExecParams { - pub fn timeout_duration(&self) -> Duration { - Duration::from_millis(self.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS)) +/// Mechanism to terminate an exec invocation before it finishes naturally. +#[derive(Debug)] +pub enum ExecExpiration { + Timeout(Duration), + DefaultTimeout, + Cancellation(CancellationToken), +} + +impl From> for ExecExpiration { + fn from(timeout_ms: Option) -> Self { + timeout_ms.map_or(ExecExpiration::DefaultTimeout, |timeout_ms| { + ExecExpiration::Timeout(Duration::from_millis(timeout_ms)) + }) + } +} + +impl From for ExecExpiration { + fn from(timeout_ms: u64) -> Self { + ExecExpiration::Timeout(Duration::from_millis(timeout_ms)) + } +} + +impl ExecExpiration { + async fn wait(self) { + match self { + ExecExpiration::Timeout(duration) => tokio::time::sleep(duration).await, + ExecExpiration::DefaultTimeout => { + tokio::time::sleep(Duration::from_millis(DEFAULT_EXEC_COMMAND_TIMEOUT_MS)).await + } + ExecExpiration::Cancellation(cancel) => { + cancel.cancelled().await; + } + } + } + + /// If ExecExpiration is a timeout, returns the timeout in milliseconds. 
+ pub(crate) fn timeout_ms(&self) -> Option { + match self { + ExecExpiration::Timeout(duration) => Some(duration.as_millis() as u64), + ExecExpiration::DefaultTimeout => Some(DEFAULT_EXEC_COMMAND_TIMEOUT_MS), + ExecExpiration::Cancellation(_) => None, + } } } @@ -95,7 +136,7 @@ pub async fn process_exec_tool_call( let ExecParams { command, cwd, - timeout_ms, + expiration, env, with_escalated_permissions, justification, @@ -114,7 +155,7 @@ pub async fn process_exec_tool_call( args: args.to_vec(), cwd, env, - timeout_ms, + expiration, with_escalated_permissions, justification, }; @@ -122,7 +163,7 @@ pub async fn process_exec_tool_call( let manager = SandboxManager::new(); let exec_env = manager .transform( - &spec, + spec, sandbox_policy, sandbox_type, sandbox_cwd, @@ -131,7 +172,7 @@ pub async fn process_exec_tool_call( .map_err(CodexErr::from)?; // Route through the sandboxing module for a single, unified execution path. - crate::sandboxing::execute_env(&exec_env, sandbox_policy, stdout_stream).await + crate::sandboxing::execute_env(exec_env, sandbox_policy, stdout_stream).await } pub(crate) async fn execute_exec_env( @@ -143,7 +184,7 @@ pub(crate) async fn execute_exec_env( command, cwd, env, - timeout_ms, + expiration, sandbox, with_escalated_permissions, justification, @@ -153,7 +194,7 @@ pub(crate) async fn execute_exec_env( let params = ExecParams { command, cwd, - timeout_ms, + expiration, env, with_escalated_permissions, justification, @@ -178,16 +219,18 @@ async fn exec_windows_sandbox( command, cwd, env, - timeout_ms, + expiration, .. } = params; + // TODO(iceweasel-oai): run_windows_sandbox_capture should support all + // variants of ExecExpiration, not just timeout. + let timeout_ms = expiration.timeout_ms(); - let policy_str = match sandbox_policy { - SandboxPolicy::DangerFullAccess => "workspace-write", - SandboxPolicy::ReadOnly => "read-only", - SandboxPolicy::WorkspaceWrite { .. 
} => "workspace-write", - }; - + let policy_str = serde_json::to_string(sandbox_policy).map_err(|err| { + CodexErr::Io(io::Error::other(format!( + "failed to serialize Windows sandbox policy: {err}" + ))) + })?; let sandbox_cwd = cwd.clone(); let codex_home = find_codex_home().map_err(|err| { CodexErr::Io(io::Error::other(format!( @@ -196,7 +239,7 @@ async fn exec_windows_sandbox( })?; let spawn_res = tokio::task::spawn_blocking(move || { run_windows_sandbox_capture( - policy_str, + policy_str.as_str(), &sandbox_cwd, codex_home.as_ref(), command, @@ -415,7 +458,7 @@ impl StreamOutput { impl StreamOutput> { pub fn from_utf8_lossy(&self) -> StreamOutput { StreamOutput { - text: String::from_utf8_lossy(&self.text).to_string(), + text: bytes_to_string_smart(&self.text), truncated_after_lines: self.truncated_after_lines, } } @@ -444,15 +487,17 @@ async fn exec( stdout_stream: Option, ) -> Result { #[cfg(target_os = "windows")] - if sandbox == SandboxType::WindowsRestrictedToken { + if sandbox == SandboxType::WindowsRestrictedToken + && !matches!(sandbox_policy, SandboxPolicy::DangerFullAccess) + { return exec_windows_sandbox(params, sandbox_policy).await; } - let timeout = params.timeout_duration(); let ExecParams { command, cwd, env, arg0, + expiration, .. } = params; @@ -473,14 +518,14 @@ async fn exec( env, ) .await?; - consume_truncated_output(child, timeout, stdout_stream).await + consume_truncated_output(child, expiration, stdout_stream).await } /// Consumes the output of a child process, truncating it so it is suitable for /// use as the output of a `shell` tool call. Also enforces specified timeout. async fn consume_truncated_output( mut child: Child, - timeout: Duration, + expiration: ExecExpiration, stdout_stream: Option, ) -> Result { // Both stdout and stderr were configured with `Stdio::piped()` @@ -514,20 +559,14 @@ async fn consume_truncated_output( )); let (exit_status, timed_out) = tokio::select! 
{ - result = tokio::time::timeout(timeout, child.wait()) => { - match result { - Ok(status_result) => { - let exit_status = status_result?; - (exit_status, false) - } - Err(_) => { - // timeout - kill_child_process_group(&mut child)?; - child.start_kill()?; - // Debatable whether `child.wait().await` should be called here. - (synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + TIMEOUT_CODE), true) - } - } + status_result = child.wait() => { + let exit_status = status_result?; + (exit_status, false) + } + _ = expiration.wait() => { + kill_child_process_group(&mut child)?; + child.start_kill()?; + (synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + TIMEOUT_CODE), true) } _ = tokio::signal::ctrl_c() => { kill_child_process_group(&mut child)?; @@ -779,6 +818,15 @@ mod tests { #[cfg(unix)] #[tokio::test] async fn kill_child_process_group_kills_grandchildren_on_timeout() -> Result<()> { + // On Linux/macOS, /bin/bash is typically present; on FreeBSD/OpenBSD, + // prefer /bin/sh to avoid NotFound errors. 
+ #[cfg(any(target_os = "freebsd", target_os = "openbsd"))] + let command = vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "sleep 60 & echo $!; sleep 60".to_string(), + ]; + #[cfg(all(unix, not(any(target_os = "freebsd", target_os = "openbsd"))))] let command = vec![ "/bin/bash".to_string(), "-c".to_string(), @@ -788,7 +836,7 @@ mod tests { let params = ExecParams { command, cwd: std::env::current_dir()?, - timeout_ms: Some(500), + expiration: 500.into(), env, with_escalated_permissions: None, justification: None, @@ -822,4 +870,62 @@ mod tests { assert!(killed, "grandchild process with pid {pid} is still alive"); Ok(()) } + + #[tokio::test] + async fn process_exec_tool_call_respects_cancellation_token() -> Result<()> { + let command = long_running_command(); + let cwd = std::env::current_dir()?; + let env: HashMap = std::env::vars().collect(); + let cancel_token = CancellationToken::new(); + let cancel_tx = cancel_token.clone(); + let params = ExecParams { + command, + cwd: cwd.clone(), + expiration: ExecExpiration::Cancellation(cancel_token), + env, + with_escalated_permissions: None, + justification: None, + arg0: None, + }; + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(1_000)).await; + cancel_tx.cancel(); + }); + let result = process_exec_tool_call( + params, + SandboxType::None, + &SandboxPolicy::DangerFullAccess, + cwd.as_path(), + &None, + None, + ) + .await; + let output = match result { + Err(CodexErr::Sandbox(SandboxErr::Timeout { output })) => output, + other => panic!("expected timeout error, got {other:?}"), + }; + assert!(output.timed_out); + assert_eq!(output.exit_code, EXEC_TIMEOUT_EXIT_CODE); + Ok(()) + } + + #[cfg(unix)] + fn long_running_command() -> Vec { + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "sleep 30".to_string(), + ] + } + + #[cfg(windows)] + fn long_running_command() -> Vec { + vec![ + "powershell.exe".to_string(), + "-NonInteractive".to_string(), + "-NoLogo".to_string(), + 
"-Command".to_string(), + "Start-Sleep -Seconds 30".to_string(), + ] + } } diff --git a/codex-rs/core/src/exec_policy.rs b/codex-rs/core/src/exec_policy.rs new file mode 100644 index 000000000..2a5d3904e --- /dev/null +++ b/codex-rs/core/src/exec_policy.rs @@ -0,0 +1,365 @@ +use std::io::ErrorKind; +use std::path::Path; +use std::path::PathBuf; +use std::sync::Arc; + +use crate::command_safety::is_dangerous_command::requires_initial_appoval; +use codex_execpolicy::Decision; +use codex_execpolicy::Evaluation; +use codex_execpolicy::Policy; +use codex_execpolicy::PolicyParser; +use codex_protocol::protocol::AskForApproval; +use codex_protocol::protocol::SandboxPolicy; +use thiserror::Error; +use tokio::fs; + +use crate::bash::parse_shell_lc_plain_commands; +use crate::features::Feature; +use crate::features::Features; +use crate::sandboxing::SandboxPermissions; +use crate::tools::sandboxing::ApprovalRequirement; + +const FORBIDDEN_REASON: &str = "execpolicy forbids this command"; +const PROMPT_REASON: &str = "execpolicy requires approval for this command"; +const POLICY_DIR_NAME: &str = "policy"; +const POLICY_EXTENSION: &str = "codexpolicy"; + +#[derive(Debug, Error)] +pub enum ExecPolicyError { + #[error("failed to read execpolicy files from {dir}: {source}")] + ReadDir { + dir: PathBuf, + source: std::io::Error, + }, + + #[error("failed to read execpolicy file {path}: {source}")] + ReadFile { + path: PathBuf, + source: std::io::Error, + }, + + #[error("failed to parse execpolicy file {path}: {source}")] + ParsePolicy { + path: String, + source: codex_execpolicy::Error, + }, +} + +pub(crate) async fn exec_policy_for( + features: &Features, + codex_home: &Path, +) -> Result, ExecPolicyError> { + if !features.enabled(Feature::ExecPolicy) { + return Ok(Arc::new(Policy::empty())); + } + + let policy_dir = codex_home.join(POLICY_DIR_NAME); + let policy_paths = collect_policy_files(&policy_dir).await?; + + let mut parser = PolicyParser::new(); + for policy_path in 
&policy_paths { + let contents = + fs::read_to_string(policy_path) + .await + .map_err(|source| ExecPolicyError::ReadFile { + path: policy_path.clone(), + source, + })?; + let identifier = policy_path.to_string_lossy().to_string(); + parser + .parse(&identifier, &contents) + .map_err(|source| ExecPolicyError::ParsePolicy { + path: identifier, + source, + })?; + } + + let policy = Arc::new(parser.build()); + tracing::debug!( + "loaded execpolicy from {} files in {}", + policy_paths.len(), + policy_dir.display() + ); + + Ok(policy) +} + +fn evaluate_with_policy( + policy: &Policy, + command: &[String], + approval_policy: AskForApproval, +) -> Option { + let commands = parse_shell_lc_plain_commands(command).unwrap_or_else(|| vec![command.to_vec()]); + let evaluation = policy.check_multiple(commands.iter()); + + match evaluation { + Evaluation::Match { decision, .. } => match decision { + Decision::Forbidden => Some(ApprovalRequirement::Forbidden { + reason: FORBIDDEN_REASON.to_string(), + }), + Decision::Prompt => { + let reason = PROMPT_REASON.to_string(); + if matches!(approval_policy, AskForApproval::Never) { + Some(ApprovalRequirement::Forbidden { reason }) + } else { + Some(ApprovalRequirement::NeedsApproval { + reason: Some(reason), + }) + } + } + Decision::Allow => Some(ApprovalRequirement::Skip), + }, + Evaluation::NoMatch { .. 
} => None, + } +} + +pub(crate) fn create_approval_requirement_for_command( + policy: &Policy, + command: &[String], + approval_policy: AskForApproval, + sandbox_policy: &SandboxPolicy, + sandbox_permissions: SandboxPermissions, +) -> ApprovalRequirement { + if let Some(requirement) = evaluate_with_policy(policy, command, approval_policy) { + return requirement; + } + + if requires_initial_appoval( + approval_policy, + sandbox_policy, + command, + sandbox_permissions, + ) { + ApprovalRequirement::NeedsApproval { reason: None } + } else { + ApprovalRequirement::Skip + } +} + +async fn collect_policy_files(dir: &Path) -> Result, ExecPolicyError> { + let mut read_dir = match fs::read_dir(dir).await { + Ok(read_dir) => read_dir, + Err(err) if err.kind() == ErrorKind::NotFound => return Ok(Vec::new()), + Err(source) => { + return Err(ExecPolicyError::ReadDir { + dir: dir.to_path_buf(), + source, + }); + } + }; + + let mut policy_paths = Vec::new(); + while let Some(entry) = + read_dir + .next_entry() + .await + .map_err(|source| ExecPolicyError::ReadDir { + dir: dir.to_path_buf(), + source, + })? 
+ { + let path = entry.path(); + let file_type = entry + .file_type() + .await + .map_err(|source| ExecPolicyError::ReadDir { + dir: dir.to_path_buf(), + source, + })?; + + if path + .extension() + .and_then(|ext| ext.to_str()) + .is_some_and(|ext| ext == POLICY_EXTENSION) + && file_type.is_file() + { + policy_paths.push(path); + } + } + + policy_paths.sort(); + + Ok(policy_paths) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::features::Feature; + use crate::features::Features; + use codex_protocol::protocol::AskForApproval; + use codex_protocol::protocol::SandboxPolicy; + use pretty_assertions::assert_eq; + use std::fs; + use tempfile::tempdir; + + #[tokio::test] + async fn returns_empty_policy_when_feature_disabled() { + let mut features = Features::with_defaults(); + features.disable(Feature::ExecPolicy); + let temp_dir = tempdir().expect("create temp dir"); + + let policy = exec_policy_for(&features, temp_dir.path()) + .await + .expect("policy result"); + + let commands = [vec!["rm".to_string()]]; + assert!(matches!( + policy.check_multiple(commands.iter()), + Evaluation::NoMatch { .. 
} + )); + assert!(!temp_dir.path().join(POLICY_DIR_NAME).exists()); + } + + #[tokio::test] + async fn collect_policy_files_returns_empty_when_dir_missing() { + let temp_dir = tempdir().expect("create temp dir"); + + let policy_dir = temp_dir.path().join(POLICY_DIR_NAME); + let files = collect_policy_files(&policy_dir) + .await + .expect("collect policy files"); + + assert!(files.is_empty()); + } + + #[tokio::test] + async fn loads_policies_from_policy_subdirectory() { + let temp_dir = tempdir().expect("create temp dir"); + let policy_dir = temp_dir.path().join(POLICY_DIR_NAME); + fs::create_dir_all(&policy_dir).expect("create policy dir"); + fs::write( + policy_dir.join("deny.codexpolicy"), + r#"prefix_rule(pattern=["rm"], decision="forbidden")"#, + ) + .expect("write policy file"); + + let policy = exec_policy_for(&Features::with_defaults(), temp_dir.path()) + .await + .expect("policy result"); + let command = [vec!["rm".to_string()]]; + assert!(matches!( + policy.check_multiple(command.iter()), + Evaluation::Match { .. } + )); + } + + #[tokio::test] + async fn ignores_policies_outside_policy_dir() { + let temp_dir = tempdir().expect("create temp dir"); + fs::write( + temp_dir.path().join("root.codexpolicy"), + r#"prefix_rule(pattern=["ls"], decision="prompt")"#, + ) + .expect("write policy file"); + + let policy = exec_policy_for(&Features::with_defaults(), temp_dir.path()) + .await + .expect("policy result"); + let command = [vec!["ls".to_string()]]; + assert!(matches!( + policy.check_multiple(command.iter()), + Evaluation::NoMatch { .. 
} + )); + } + + #[test] + fn evaluates_bash_lc_inner_commands() { + let policy_src = r#" +prefix_rule(pattern=["rm"], decision="forbidden") +"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.codexpolicy", policy_src) + .expect("parse policy"); + let policy = parser.build(); + + let forbidden_script = vec![ + "bash".to_string(), + "-lc".to_string(), + "rm -rf /tmp".to_string(), + ]; + + let requirement = + evaluate_with_policy(&policy, &forbidden_script, AskForApproval::OnRequest) + .expect("expected match for forbidden command"); + + assert_eq!( + requirement, + ApprovalRequirement::Forbidden { + reason: FORBIDDEN_REASON.to_string() + } + ); + } + + #[test] + fn approval_requirement_prefers_execpolicy_match() { + let policy_src = r#"prefix_rule(pattern=["rm"], decision="prompt")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.codexpolicy", policy_src) + .expect("parse policy"); + let policy = parser.build(); + let command = vec!["rm".to_string()]; + + let requirement = create_approval_requirement_for_command( + &policy, + &command, + AskForApproval::OnRequest, + &SandboxPolicy::DangerFullAccess, + SandboxPermissions::UseDefault, + ); + + assert_eq!( + requirement, + ApprovalRequirement::NeedsApproval { + reason: Some(PROMPT_REASON.to_string()) + } + ); + } + + #[test] + fn approval_requirement_respects_approval_policy() { + let policy_src = r#"prefix_rule(pattern=["rm"], decision="prompt")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.codexpolicy", policy_src) + .expect("parse policy"); + let policy = parser.build(); + let command = vec!["rm".to_string()]; + + let requirement = create_approval_requirement_for_command( + &policy, + &command, + AskForApproval::Never, + &SandboxPolicy::DangerFullAccess, + SandboxPermissions::UseDefault, + ); + + assert_eq!( + requirement, + ApprovalRequirement::Forbidden { + reason: PROMPT_REASON.to_string() + } + ); + } + + #[test] + fn 
approval_requirement_falls_back_to_heuristics() { + let command = vec!["python".to_string()]; + + let empty_policy = Policy::empty(); + let requirement = create_approval_requirement_for_command( + &empty_policy, + &command, + AskForApproval::UnlessTrusted, + &SandboxPolicy::ReadOnly, + SandboxPermissions::UseDefault, + ); + + assert_eq!( + requirement, + ApprovalRequirement::NeedsApproval { reason: None } + ); + } +} diff --git a/codex-rs/core/src/features.rs b/codex-rs/core/src/features.rs index 5f579defd..0c67c9ff5 100644 --- a/codex-rs/core/src/features.rs +++ b/codex-rs/core/src/features.rs @@ -31,9 +31,6 @@ pub enum Feature { GhostCommit, /// Use the single unified PTY-backed exec tool. UnifiedExec, - /// Use the shell command tool that takes `command` as a single string of - /// shell instead of an array of args passed to `execvp(3)`. - ShellCommandTool, /// Enable experimental RMCP features such as OAuth login. RmcpClient, /// Include the freeform apply_patch tool. @@ -42,6 +39,8 @@ pub enum Feature { ViewImageTool, /// Allow the model to request web searches. WebSearchRequest, + /// Gate the execpolicy enforcement for shell/unified exec. + ExecPolicy, /// Enable the model-based risk assessments for sandboxed commands. SandboxCommandAssessment, /// Enable Windows sandbox (restricted token) on Windows. @@ -260,6 +259,12 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::Stable, default_enabled: true, }, + FeatureSpec { + id: Feature::ViewImageTool, + key: "view_image_tool", + stage: Stage::Stable, + default_enabled: true, + }, // Unstable features. 
FeatureSpec { id: Feature::UnifiedExec, @@ -267,12 +272,6 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::Experimental, default_enabled: false, }, - FeatureSpec { - id: Feature::ShellCommandTool, - key: "shell_command_tool", - stage: Stage::Experimental, - default_enabled: false, - }, FeatureSpec { id: Feature::RmcpClient, key: "rmcp_client", @@ -285,18 +284,18 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::Beta, default_enabled: false, }, - FeatureSpec { - id: Feature::ViewImageTool, - key: "view_image_tool", - stage: Stage::Stable, - default_enabled: true, - }, FeatureSpec { id: Feature::WebSearchRequest, key: "web_search_request", stage: Stage::Stable, default_enabled: false, }, + FeatureSpec { + id: Feature::ExecPolicy, + key: "exec_policy", + stage: Stage::Experimental, + default_enabled: true, + }, FeatureSpec { id: Feature::SandboxCommandAssessment, key: "experimental_sandbox_command_assessment", @@ -313,7 +312,7 @@ pub const FEATURES: &[FeatureSpec] = &[ id: Feature::RemoteCompaction, key: "remote_compaction", stage: Stage::Experimental, - default_enabled: false, + default_enabled: true, }, FeatureSpec { id: Feature::ParallelToolCalls, diff --git a/codex-rs/core/src/git_info.rs b/codex-rs/core/src/git_info.rs index 387e9a682..34e0afc72 100644 --- a/codex-rs/core/src/git_info.rs +++ b/codex-rs/core/src/git_info.rs @@ -825,11 +825,21 @@ mod tests { .await .expect("Should collect git info from repo"); + let remote_url_output = Command::new("git") + .args(["remote", "get-url", "origin"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to read remote url"); + // Some dev environments rewrite remotes (e.g., force SSH), so compare against + // whatever URL Git reports instead of a fixed placeholder. 
+ let expected_remote = String::from_utf8(remote_url_output.stdout) + .unwrap() + .trim() + .to_string(); + // Should have repository URL - assert_eq!( - git_info.repository_url, - Some("https://github.com/example/repo.git".to_string()) - ); + assert_eq!(git_info.repository_url, Some(expected_remote)); } #[tokio::test] diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 3e7463345..6906489e7 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -25,6 +25,7 @@ mod environment_context; pub mod error; pub mod exec; pub mod exec_env; +mod exec_policy; pub mod features; mod flags; pub mod git_info; @@ -38,6 +39,7 @@ pub mod parse_command; pub mod powershell; mod response_processing; pub mod sandboxing; +mod text_encoding; pub mod token_data; mod truncate; mod unified_exec; diff --git a/codex-rs/core/src/model_family.rs b/codex-rs/core/src/model_family.rs index 150420fec..ef54a9584 100644 --- a/codex-rs/core/src/model_family.rs +++ b/codex-rs/core/src/model_family.rs @@ -4,6 +4,7 @@ use codex_protocol::config_types::Verbosity; use crate::config::types::ReasoningSummaryFormat; use crate::tools::handlers::apply_patch::ApplyPatchToolType; use crate::tools::spec::ConfigShellToolType; +use crate::truncate::TruncationPolicy; /// The `instructions` field in the payload sent to a model should always start /// with this content. @@ -11,6 +12,7 @@ const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md"); const GPT_5_CODEX_INSTRUCTIONS: &str = include_str!("../gpt_5_codex_prompt.md"); const GPT_5_1_INSTRUCTIONS: &str = include_str!("../gpt_5_1_prompt.md"); +const GPT_5_1_CODEX_MAX_INSTRUCTIONS: &str = include_str!("../gpt-5.1-codex-max_prompt.md"); /// A model family is a group of models that share certain characteristics. #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -66,6 +68,8 @@ pub struct ModelFamily { /// Preferred shell tool type for this model family when features do not override it. 
pub shell_type: ConfigShellToolType, + + pub truncation_policy: TruncationPolicy, } macro_rules! model_family { @@ -89,6 +93,7 @@ macro_rules! model_family { shell_type: ConfigShellToolType::Default, default_verbosity: None, default_reasoning_effort: None, + truncation_policy: TruncationPolicy::Bytes(10_000), }; // apply overrides @@ -145,7 +150,9 @@ pub fn find_family_for_model(slug: &str) -> Option { "test_sync_tool".to_string(), ], supports_parallel_tool_calls: true, + shell_type: ConfigShellToolType::ShellCommand, support_verbosity: true, + truncation_policy: TruncationPolicy::Tokens(10_000), ) // Internal models. @@ -161,12 +168,25 @@ pub fn find_family_for_model(slug: &str) -> Option { "list_dir".to_string(), "read_file".to_string(), ], - shell_type: if cfg!(windows) { ConfigShellToolType::ShellCommand } else { ConfigShellToolType::Default }, + shell_type: ConfigShellToolType::ShellCommand, supports_parallel_tool_calls: true, support_verbosity: true, + truncation_policy: TruncationPolicy::Tokens(10_000), ) // Production models. 
+ } else if slug.starts_with("gpt-5.1-codex-max") { + model_family!( + slug, slug, + supports_reasoning_summaries: true, + reasoning_summary_format: ReasoningSummaryFormat::Experimental, + base_instructions: GPT_5_1_CODEX_MAX_INSTRUCTIONS.to_string(), + apply_patch_tool_type: Some(ApplyPatchToolType::Freeform), + shell_type: ConfigShellToolType::ShellCommand, + supports_parallel_tool_calls: true, + support_verbosity: false, + truncation_policy: TruncationPolicy::Tokens(10_000), + ) } else if slug.starts_with("gpt-5-codex") || slug.starts_with("gpt-5.1-codex") || slug.starts_with("codex-") @@ -177,9 +197,10 @@ pub fn find_family_for_model(slug: &str) -> Option { reasoning_summary_format: ReasoningSummaryFormat::Experimental, base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(), apply_patch_tool_type: Some(ApplyPatchToolType::Freeform), - shell_type: if cfg!(windows) { ConfigShellToolType::ShellCommand } else { ConfigShellToolType::Default }, + shell_type: ConfigShellToolType::ShellCommand, supports_parallel_tool_calls: true, support_verbosity: false, + truncation_policy: TruncationPolicy::Tokens(10_000), ) } else if slug.starts_with("gpt-5.1") { model_family!( @@ -190,6 +211,8 @@ pub fn find_family_for_model(slug: &str) -> Option { default_verbosity: Some(Verbosity::Low), base_instructions: GPT_5_1_INSTRUCTIONS.to_string(), default_reasoning_effort: Some(ReasoningEffort::Medium), + truncation_policy: TruncationPolicy::Bytes(10_000), + shell_type: ConfigShellToolType::ShellCommand, supports_parallel_tool_calls: true, ) } else if slug.starts_with("gpt-5") { @@ -197,7 +220,9 @@ pub fn find_family_for_model(slug: &str) -> Option { slug, "gpt-5", supports_reasoning_summaries: true, needs_special_apply_patch_instructions: true, + shell_type: ConfigShellToolType::Default, support_verbosity: true, + truncation_policy: TruncationPolicy::Bytes(10_000), ) } else { None @@ -220,5 +245,6 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily { shell_type: 
ConfigShellToolType::Default, default_verbosity: None, default_reasoning_effort: None, + truncation_policy: TruncationPolicy::Bytes(10_000), } } diff --git a/codex-rs/core/src/openai_model_info.rs b/codex-rs/core/src/openai_model_info.rs index 9d3703281..0ae3f03eb 100644 --- a/codex-rs/core/src/openai_model_info.rs +++ b/codex-rs/core/src/openai_model_info.rs @@ -2,7 +2,6 @@ use crate::model_family::ModelFamily; // Shared constants for commonly used window/token sizes. pub(crate) const CONTEXT_WINDOW_272K: i64 = 272_000; -pub(crate) const MAX_OUTPUT_TOKENS_128K: i64 = 128_000; /// Metadata about a model, particularly OpenAI models. /// We may want to consider including details like the pricing for @@ -14,19 +13,15 @@ pub(crate) struct ModelInfo { /// Size of the context window in tokens. This is the maximum size of the input context. pub(crate) context_window: i64, - /// Maximum number of output tokens that can be generated for the model. - pub(crate) max_output_tokens: i64, - /// Token threshold where we should automatically compact conversation history. This considers /// input tokens + output tokens of this turn. pub(crate) auto_compact_token_limit: Option, } impl ModelInfo { - const fn new(context_window: i64, max_output_tokens: i64) -> Self { + const fn new(context_window: i64) -> Self { Self { context_window, - max_output_tokens, auto_compact_token_limit: Some(Self::default_auto_compact_limit(context_window)), } } @@ -42,45 +37,44 @@ pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option { // OSS models have a 128k shared token pool. // Arbitrarily splitting it: 3/4 input context, 1/4 output. 
// https://openai.com/index/gpt-oss-model-card/ - "gpt-oss-20b" => Some(ModelInfo::new(96_000, 32_000)), - "gpt-oss-120b" => Some(ModelInfo::new(96_000, 32_000)), + "gpt-oss-20b" => Some(ModelInfo::new(96_000)), + "gpt-oss-120b" => Some(ModelInfo::new(96_000)), // https://platform.openai.com/docs/models/o3 - "o3" => Some(ModelInfo::new(200_000, 100_000)), + "o3" => Some(ModelInfo::new(200_000)), // https://platform.openai.com/docs/models/o4-mini - "o4-mini" => Some(ModelInfo::new(200_000, 100_000)), + "o4-mini" => Some(ModelInfo::new(200_000)), // https://platform.openai.com/docs/models/codex-mini-latest - "codex-mini-latest" => Some(ModelInfo::new(200_000, 100_000)), + "codex-mini-latest" => Some(ModelInfo::new(200_000)), // As of Jun 25, 2025, gpt-4.1 defaults to gpt-4.1-2025-04-14. // https://platform.openai.com/docs/models/gpt-4.1 - "gpt-4.1" | "gpt-4.1-2025-04-14" => Some(ModelInfo::new(1_047_576, 32_768)), + "gpt-4.1" | "gpt-4.1-2025-04-14" => Some(ModelInfo::new(1_047_576)), // As of Jun 25, 2025, gpt-4o defaults to gpt-4o-2024-08-06. 
// https://platform.openai.com/docs/models/gpt-4o - "gpt-4o" | "gpt-4o-2024-08-06" => Some(ModelInfo::new(128_000, 16_384)), + "gpt-4o" | "gpt-4o-2024-08-06" => Some(ModelInfo::new(128_000)), // https://platform.openai.com/docs/models/gpt-4o?snapshot=gpt-4o-2024-05-13 - "gpt-4o-2024-05-13" => Some(ModelInfo::new(128_000, 4_096)), + "gpt-4o-2024-05-13" => Some(ModelInfo::new(128_000)), // https://platform.openai.com/docs/models/gpt-4o?snapshot=gpt-4o-2024-11-20 - "gpt-4o-2024-11-20" => Some(ModelInfo::new(128_000, 16_384)), + "gpt-4o-2024-11-20" => Some(ModelInfo::new(128_000)), // https://platform.openai.com/docs/models/gpt-3.5-turbo - "gpt-3.5-turbo" => Some(ModelInfo::new(16_385, 4_096)), + "gpt-3.5-turbo" => Some(ModelInfo::new(16_385)), - _ if slug.starts_with("gpt-5-codex") || slug.starts_with("gpt-5.1-codex") => { - Some(ModelInfo::new(CONTEXT_WINDOW_272K, MAX_OUTPUT_TOKENS_128K)) + _ if slug.starts_with("gpt-5-codex") + || slug.starts_with("gpt-5.1-codex") + || slug.starts_with("gpt-5.1-codex-max") => + { + Some(ModelInfo::new(CONTEXT_WINDOW_272K)) } - _ if slug.starts_with("gpt-5") => { - Some(ModelInfo::new(CONTEXT_WINDOW_272K, MAX_OUTPUT_TOKENS_128K)) - } + _ if slug.starts_with("gpt-5") => Some(ModelInfo::new(CONTEXT_WINDOW_272K)), - _ if slug.starts_with("codex-") => { - Some(ModelInfo::new(CONTEXT_WINDOW_272K, MAX_OUTPUT_TOKENS_128K)) - } + _ if slug.starts_with("codex-") => Some(ModelInfo::new(CONTEXT_WINDOW_272K)), _ => None, } diff --git a/codex-rs/core/src/otel_init.rs b/codex-rs/core/src/otel_init.rs index 5931d7caf..5900c9b4a 100644 --- a/codex-rs/core/src/otel_init.rs +++ b/codex-rs/core/src/otel_init.rs @@ -5,6 +5,7 @@ use crate::default_client::originator; use codex_otel::config::OtelExporter; use codex_otel::config::OtelHttpProtocol; use codex_otel::config::OtelSettings; +use codex_otel::config::OtelTlsConfig as OtelTlsSettings; use codex_otel::otel_provider::OtelProvider; use std::error::Error; @@ -21,6 +22,7 @@ pub fn build_provider( 
endpoint, headers, protocol, + tls, } => { let protocol = match protocol { Protocol::Json => OtelHttpProtocol::Json, @@ -34,14 +36,28 @@ pub fn build_provider( .map(|(k, v)| (k.clone(), v.clone())) .collect(), protocol, + tls: tls.as_ref().map(|config| OtelTlsSettings { + ca_certificate: config.ca_certificate.clone(), + client_certificate: config.client_certificate.clone(), + client_private_key: config.client_private_key.clone(), + }), } } - Kind::OtlpGrpc { endpoint, headers } => OtelExporter::OtlpGrpc { + Kind::OtlpGrpc { + endpoint, + headers, + tls, + } => OtelExporter::OtlpGrpc { endpoint: endpoint.clone(), headers: headers .iter() .map(|(k, v)| (k.clone(), v.clone())) .collect(), + tls: tls.as_ref().map(|config| OtelTlsSettings { + ca_certificate: config.ca_certificate.clone(), + client_certificate: config.client_certificate.clone(), + client_private_key: config.client_private_key.clone(), + }), }, }; diff --git a/codex-rs/core/src/sandboxing/mod.rs b/codex-rs/core/src/sandboxing/mod.rs index 5e564f510..d43646021 100644 --- a/codex-rs/core/src/sandboxing/mod.rs +++ b/codex-rs/core/src/sandboxing/mod.rs @@ -8,6 +8,7 @@ ready‑to‑spawn environment. 
pub mod assessment; +use crate::exec::ExecExpiration; use crate::exec::ExecToolCallOutput; use crate::exec::SandboxType; use crate::exec::StdoutStream; @@ -26,23 +27,45 @@ use std::collections::HashMap; use std::path::Path; use std::path::PathBuf; -#[derive(Clone, Debug)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum SandboxPermissions { + UseDefault, + RequireEscalated, +} + +impl SandboxPermissions { + pub fn requires_escalated_permissions(self) -> bool { + matches!(self, SandboxPermissions::RequireEscalated) + } +} + +impl From for SandboxPermissions { + fn from(with_escalated_permissions: bool) -> Self { + if with_escalated_permissions { + SandboxPermissions::RequireEscalated + } else { + SandboxPermissions::UseDefault + } + } +} + +#[derive(Debug)] pub struct CommandSpec { pub program: String, pub args: Vec, pub cwd: PathBuf, pub env: HashMap, - pub timeout_ms: Option, + pub expiration: ExecExpiration, pub with_escalated_permissions: Option, pub justification: Option, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct ExecEnv { pub command: Vec, pub cwd: PathBuf, pub env: HashMap, - pub timeout_ms: Option, + pub expiration: ExecExpiration, pub sandbox: SandboxType, pub with_escalated_permissions: Option, pub justification: Option, @@ -93,13 +116,13 @@ impl SandboxManager { pub(crate) fn transform( &self, - spec: &CommandSpec, + mut spec: CommandSpec, policy: &SandboxPolicy, sandbox: SandboxType, sandbox_policy_cwd: &Path, codex_linux_sandbox_exe: Option<&PathBuf>, ) -> Result { - let mut env = spec.env.clone(); + let mut env = spec.env; if !policy.has_full_network_access() { env.insert( CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR.to_string(), @@ -108,8 +131,8 @@ impl SandboxManager { } let mut command = Vec::with_capacity(1 + spec.args.len()); - command.push(spec.program.clone()); - command.extend(spec.args.iter().cloned()); + command.push(spec.program); + command.append(&mut spec.args); let (command, sandbox_env, arg0_override) = match sandbox { 
SandboxType::None => (command, HashMap::new(), None), @@ -154,12 +177,12 @@ impl SandboxManager { Ok(ExecEnv { command, - cwd: spec.cwd.clone(), + cwd: spec.cwd, env, - timeout_ms: spec.timeout_ms, + expiration: spec.expiration, sandbox, with_escalated_permissions: spec.with_escalated_permissions, - justification: spec.justification.clone(), + justification: spec.justification, arg0: arg0_override, }) } @@ -170,9 +193,9 @@ impl SandboxManager { } pub async fn execute_env( - env: &ExecEnv, + env: ExecEnv, policy: &SandboxPolicy, stdout_stream: Option, ) -> crate::error::Result { - execute_exec_env(env.clone(), policy, stdout_stream).await + execute_exec_env(env, policy, stdout_stream).await } diff --git a/codex-rs/core/src/shell.rs b/codex-rs/core/src/shell.rs index 7bfec089c..ac115facb 100644 --- a/codex-rs/core/src/shell.rs +++ b/codex-rs/core/src/shell.rs @@ -7,61 +7,41 @@ pub enum ShellType { Zsh, Bash, PowerShell, + Sh, + Cmd, } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -pub struct ZshShell { +pub struct Shell { + pub(crate) shell_type: ShellType, pub(crate) shell_path: PathBuf, } -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -pub struct BashShell { - pub(crate) shell_path: PathBuf, -} - -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -pub struct PowerShellConfig { - pub(crate) shell_path: PathBuf, // Executable name or path, e.g. "pwsh" or "powershell.exe". -} - -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -pub enum Shell { - Zsh(ZshShell), - Bash(BashShell), - PowerShell(PowerShellConfig), - Unknown, -} - impl Shell { - pub fn name(&self) -> Option { - match self { - Shell::Zsh(ZshShell { shell_path, .. }) | Shell::Bash(BashShell { shell_path, .. 
}) => { - std::path::Path::new(shell_path) - .file_name() - .map(|s| s.to_string_lossy().to_string()) - } - Shell::PowerShell(ps) => ps - .shell_path - .file_stem() - .map(|s| s.to_string_lossy().to_string()), - Shell::Unknown => None, + pub fn name(&self) -> &'static str { + match self.shell_type { + ShellType::Zsh => "zsh", + ShellType::Bash => "bash", + ShellType::PowerShell => "powershell", + ShellType::Sh => "sh", + ShellType::Cmd => "cmd", } } /// Takes a string of shell and returns the full list of command args to /// use with `exec()` to run the shell command. pub fn derive_exec_args(&self, command: &str, use_login_shell: bool) -> Vec { - match self { - Shell::Zsh(ZshShell { shell_path, .. }) | Shell::Bash(BashShell { shell_path, .. }) => { + match self.shell_type { + ShellType::Zsh | ShellType::Bash | ShellType::Sh => { let arg = if use_login_shell { "-lc" } else { "-c" }; vec![ - shell_path.to_string_lossy().to_string(), + self.shell_path.to_string_lossy().to_string(), arg.to_string(), command.to_string(), ] } - Shell::PowerShell(ps) => { - let mut args = vec![ps.shell_path.to_string_lossy().to_string()]; + ShellType::PowerShell => { + let mut args = vec![self.shell_path.to_string_lossy().to_string()]; if !use_login_shell { args.push("-NoProfile".to_string()); } @@ -70,7 +50,12 @@ impl Shell { args.push(command.to_string()); args } - Shell::Unknown => shlex::split(command).unwrap_or_else(|| vec![command.to_string()]), + ShellType::Cmd => { + let mut args = vec![self.shell_path.to_string_lossy().to_string()]; + args.push("/c".to_string()); + args.push(command.to_string()); + args + } } } } @@ -143,19 +128,34 @@ fn get_shell_path( None } -fn get_zsh_shell(path: Option<&PathBuf>) -> Option { +fn get_zsh_shell(path: Option<&PathBuf>) -> Option { let shell_path = get_shell_path(ShellType::Zsh, path, "zsh", vec!["/bin/zsh"]); - shell_path.map(|shell_path| ZshShell { shell_path }) + shell_path.map(|shell_path| Shell { + shell_type: ShellType::Zsh, + shell_path, 
+ }) } -fn get_bash_shell(path: Option<&PathBuf>) -> Option { +fn get_bash_shell(path: Option<&PathBuf>) -> Option { let shell_path = get_shell_path(ShellType::Bash, path, "bash", vec!["/bin/bash"]); - shell_path.map(|shell_path| BashShell { shell_path }) + shell_path.map(|shell_path| Shell { + shell_type: ShellType::Bash, + shell_path, + }) +} + +fn get_sh_shell(path: Option<&PathBuf>) -> Option { + let shell_path = get_shell_path(ShellType::Sh, path, "sh", vec!["/bin/sh"]); + + shell_path.map(|shell_path| Shell { + shell_type: ShellType::Sh, + shell_path, + }) } -fn get_powershell_shell(path: Option<&PathBuf>) -> Option { +fn get_powershell_shell(path: Option<&PathBuf>) -> Option { let shell_path = get_shell_path( ShellType::PowerShell, path, @@ -164,26 +164,56 @@ fn get_powershell_shell(path: Option<&PathBuf>) -> Option { ) .or_else(|| get_shell_path(ShellType::PowerShell, path, "powershell", vec![])); - shell_path.map(|shell_path| PowerShellConfig { shell_path }) + shell_path.map(|shell_path| Shell { + shell_type: ShellType::PowerShell, + shell_path, + }) +} + +fn get_cmd_shell(path: Option<&PathBuf>) -> Option { + let shell_path = get_shell_path(ShellType::Cmd, path, "cmd", vec![]); + + shell_path.map(|shell_path| Shell { + shell_type: ShellType::Cmd, + shell_path, + }) +} + +fn ultimate_fallback_shell() -> Shell { + if cfg!(windows) { + Shell { + shell_type: ShellType::Cmd, + shell_path: PathBuf::from("cmd.exe"), + } + } else { + Shell { + shell_type: ShellType::Sh, + shell_path: PathBuf::from("/bin/sh"), + } + } } pub fn get_shell_by_model_provided_path(shell_path: &PathBuf) -> Shell { detect_shell_type(shell_path) .and_then(|shell_type| get_shell(shell_type, Some(shell_path))) - .unwrap_or(Shell::Unknown) + .unwrap_or(ultimate_fallback_shell()) } pub fn get_shell(shell_type: ShellType, path: Option<&PathBuf>) -> Option { match shell_type { - ShellType::Zsh => get_zsh_shell(path).map(Shell::Zsh), - ShellType::Bash => get_bash_shell(path).map(Shell::Bash), - 
ShellType::PowerShell => get_powershell_shell(path).map(Shell::PowerShell), + ShellType::Zsh => get_zsh_shell(path), + ShellType::Bash => get_bash_shell(path), + ShellType::PowerShell => get_powershell_shell(path), + ShellType::Sh => get_sh_shell(path), + ShellType::Cmd => get_cmd_shell(path), } } pub fn detect_shell_type(shell_path: &PathBuf) -> Option { match shell_path.as_os_str().to_str() { Some("zsh") => Some(ShellType::Zsh), + Some("sh") => Some(ShellType::Sh), + Some("cmd") => Some(ShellType::Cmd), Some("bash") => Some(ShellType::Bash), Some("pwsh") => Some(ShellType::PowerShell), Some("powershell") => Some(ShellType::PowerShell), @@ -200,14 +230,29 @@ pub fn detect_shell_type(shell_path: &PathBuf) -> Option { } } -pub async fn default_user_shell() -> Shell { +pub fn default_user_shell() -> Shell { + default_user_shell_from_path(get_user_shell_path()) +} + +fn default_user_shell_from_path(user_shell_path: Option) -> Shell { if cfg!(windows) { - get_shell(ShellType::PowerShell, None).unwrap_or(Shell::Unknown) + get_shell(ShellType::PowerShell, None).unwrap_or(ultimate_fallback_shell()) } else { - get_user_shell_path() + let user_default_shell = user_shell_path .and_then(|shell| detect_shell_type(&shell)) - .and_then(|shell_type| get_shell(shell_type, None)) - .unwrap_or(Shell::Unknown) + .and_then(|shell_type| get_shell(shell_type, None)); + + let shell_with_fallback = if cfg!(target_os = "macos") { + user_default_shell + .or_else(|| get_shell(ShellType::Zsh, None)) + .or_else(|| get_shell(ShellType::Bash, None)) + } else { + user_default_shell + .or_else(|| get_shell(ShellType::Bash, None)) + .or_else(|| get_shell(ShellType::Zsh, None)) + }; + + shell_with_fallback.unwrap_or(ultimate_fallback_shell()) } } @@ -263,6 +308,19 @@ mod detect_shell_type_tests { detect_shell_type(&PathBuf::from("/usr/local/bin/pwsh")), Some(ShellType::PowerShell) ); + assert_eq!( + detect_shell_type(&PathBuf::from("/bin/sh")), + Some(ShellType::Sh) + ); + 
assert_eq!(detect_shell_type(&PathBuf::from("sh")), Some(ShellType::Sh)); + assert_eq!( + detect_shell_type(&PathBuf::from("cmd")), + Some(ShellType::Cmd) + ); + assert_eq!( + detect_shell_type(&PathBuf::from("cmd.exe")), + Some(ShellType::Cmd) + ); } } @@ -278,10 +336,17 @@ mod tests { fn detects_zsh() { let zsh_shell = get_shell(ShellType::Zsh, None).unwrap(); - let ZshShell { shell_path } = match zsh_shell { - Shell::Zsh(zsh_shell) => zsh_shell, - _ => panic!("expected zsh shell"), - }; + let shell_path = zsh_shell.shell_path; + + assert_eq!(shell_path, PathBuf::from("/bin/zsh")); + } + + #[test] + #[cfg(target_os = "macos")] + fn fish_fallback_to_zsh() { + let zsh_shell = default_user_shell_from_path(Some(PathBuf::from("/bin/fish"))); + + let shell_path = zsh_shell.shell_path; assert_eq!(shell_path, PathBuf::from("/bin/zsh")); } @@ -289,18 +354,60 @@ mod tests { #[test] fn detects_bash() { let bash_shell = get_shell(ShellType::Bash, None).unwrap(); - let BashShell { shell_path } = match bash_shell { - Shell::Bash(bash_shell) => bash_shell, - _ => panic!("expected bash shell"), - }; + let shell_path = bash_shell.shell_path; assert!( shell_path == PathBuf::from("/bin/bash") - || shell_path == PathBuf::from("/usr/bin/bash"), + || shell_path == PathBuf::from("/usr/bin/bash") + || shell_path == PathBuf::from("/usr/local/bin/bash"), "shell path: {shell_path:?}", ); } + #[test] + fn detects_sh() { + let sh_shell = get_shell(ShellType::Sh, None).unwrap(); + let shell_path = sh_shell.shell_path; + assert!( + shell_path == PathBuf::from("/bin/sh") || shell_path == PathBuf::from("/usr/bin/sh"), + "shell path: {shell_path:?}", + ); + } + + #[test] + fn can_run_on_shell_test() { + let cmd = "echo \"Works\""; + if cfg!(windows) { + assert!(shell_works( + get_shell(ShellType::PowerShell, None), + "Out-String 'Works'", + true, + )); + assert!(shell_works(get_shell(ShellType::Cmd, None), cmd, true,)); + assert!(shell_works(Some(ultimate_fallback_shell()), cmd, true)); + } else 
{ + assert!(shell_works(Some(ultimate_fallback_shell()), cmd, true)); + assert!(shell_works(get_shell(ShellType::Zsh, None), cmd, false)); + assert!(shell_works(get_shell(ShellType::Bash, None), cmd, true)); + assert!(shell_works(get_shell(ShellType::Sh, None), cmd, true)); + } + } + + fn shell_works(shell: Option, command: &str, required: bool) -> bool { + if let Some(shell) = shell { + let args = shell.derive_exec_args(command, false); + let output = Command::new(args[0].clone()) + .args(&args[1..]) + .output() + .unwrap(); + assert!(output.status.success()); + assert!(String::from_utf8_lossy(&output.stdout).contains("Works")); + true + } else { + !required + } + } + #[tokio::test] async fn test_current_shell_detects_zsh() { let shell = Command::new("sh") @@ -312,10 +419,11 @@ mod tests { let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string(); if shell_path.ends_with("/zsh") { assert_eq!( - default_user_shell().await, - Shell::Zsh(ZshShell { + default_user_shell(), + Shell { + shell_type: ShellType::Zsh, shell_path: PathBuf::from(shell_path), - }) + } ); } } @@ -326,11 +434,8 @@ mod tests { return; } - let powershell_shell = default_user_shell().await; - let PowerShellConfig { shell_path } = match powershell_shell { - Shell::PowerShell(powershell_shell) => powershell_shell, - _ => panic!("expected powershell shell"), - }; + let powershell_shell = default_user_shell(); + let shell_path = powershell_shell.shell_path; assert!(shell_path.ends_with("pwsh.exe") || shell_path.ends_with("powershell.exe")); } @@ -342,10 +447,7 @@ mod tests { } let powershell_shell = get_shell(ShellType::PowerShell, None).unwrap(); - let PowerShellConfig { shell_path } = match powershell_shell { - Shell::PowerShell(powershell_shell) => powershell_shell, - _ => panic!("expected powershell shell"), - }; + let shell_path = powershell_shell.shell_path; assert!(shell_path.ends_with("pwsh.exe") || shell_path.ends_with("powershell.exe")); } diff --git 
a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index 5b630d5ce..2dfa5199f 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -7,6 +7,7 @@ use crate::context_manager::ContextManager; use crate::protocol::RateLimitSnapshot; use crate::protocol::TokenUsage; use crate::protocol::TokenUsageInfo; +use crate::truncate::TruncationPolicy; /// Persistent, session-scoped state previously stored directly on `Session`. pub(crate) struct SessionState { @@ -18,20 +19,21 @@ pub(crate) struct SessionState { impl SessionState { /// Create a new session state mirroring previous `State::default()` semantics. pub(crate) fn new(session_configuration: SessionConfiguration) -> Self { + let history = ContextManager::new(); Self { session_configuration, - history: ContextManager::new(), + history, latest_rate_limits: None, } } // History helpers - pub(crate) fn record_items(&mut self, items: I) + pub(crate) fn record_items(&mut self, items: I, policy: TruncationPolicy) where I: IntoIterator, I::Item: std::ops::Deref, { - self.history.record_items(items) + self.history.record_items(items, policy); } pub(crate) fn clone_history(&self) -> ContextManager { diff --git a/codex-rs/core/src/tasks/compact.rs b/codex-rs/core/src/tasks/compact.rs index e2e5625b5..893c0c476 100644 --- a/codex-rs/core/src/tasks/compact.rs +++ b/codex-rs/core/src/tasks/compact.rs @@ -3,10 +3,8 @@ use std::sync::Arc; use super::SessionTask; use super::SessionTaskContext; use crate::codex::TurnContext; -use crate::features::Feature; use crate::state::TaskKind; use async_trait::async_trait; -use codex_app_server_protocol::AuthMode; use codex_protocol::user_input::UserInput; use tokio_util::sync::CancellationToken; @@ -27,16 +25,12 @@ impl SessionTask for CompactTask { _cancellation_token: CancellationToken, ) -> Option { let session = session.clone_session(); - if session - .services - .auth_manager - .auth() - .is_some_and(|auth| auth.mode == 
AuthMode::ChatGPT) - && session.enabled(Feature::RemoteCompaction).await - { - crate::compact_remote::run_remote_compact_task(session, ctx, input).await + if crate::compact::should_use_remote_compact_task(&session).await { + crate::compact_remote::run_remote_compact_task(session, ctx).await } else { crate::compact::run_compact_task(session, ctx, input).await } + + None } } diff --git a/codex-rs/core/src/tasks/ghost_snapshot.rs b/codex-rs/core/src/tasks/ghost_snapshot.rs index 93830a305..ef5d42a28 100644 --- a/codex-rs/core/src/tasks/ghost_snapshot.rs +++ b/codex-rs/core/src/tasks/ghost_snapshot.rs @@ -1,10 +1,14 @@ use crate::codex::TurnContext; +use crate::protocol::EventMsg; +use crate::protocol::WarningEvent; use crate::state::TaskKind; use crate::tasks::SessionTask; use crate::tasks::SessionTaskContext; use async_trait::async_trait; use codex_git::CreateGhostCommitOptions; +use codex_git::GhostSnapshotReport; use codex_git::GitToolingError; +use codex_git::capture_ghost_snapshot_report; use codex_git::create_ghost_commit; use codex_protocol::models::ResponseItem; use codex_protocol::user_input::UserInput; @@ -39,6 +43,27 @@ impl SessionTask for GhostSnapshotTask { _ = cancellation_token.cancelled() => true, _ = async { let repo_path = ctx_for_task.cwd.clone(); + // First, compute a snapshot report so we can warn about + // large untracked directories before running the heavier + // snapshot logic. + if let Ok(Ok(report)) = tokio::task::spawn_blocking({ + let repo_path = repo_path.clone(); + move || { + let options = CreateGhostCommitOptions::new(&repo_path); + capture_ghost_snapshot_report(&options) + } + }) + .await + && let Some(message) = format_large_untracked_warning(&report) { + session + .session + .send_event( + &ctx_for_task, + EventMsg::Warning(WarningEvent { message }), + ) + .await; + } + // Required to run in a dedicated blocking pool. 
match tokio::task::spawn_blocking(move || { let options = CreateGhostCommitOptions::new(&repo_path); @@ -56,23 +81,18 @@ impl SessionTask for GhostSnapshotTask { .await; info!("ghost commit captured: {}", ghost_commit.id()); } - Ok(Err(err)) => { - warn!( + Ok(Err(err)) => match err { + GitToolingError::NotAGitRepository { .. } => info!( sub_id = ctx_for_task.sub_id.as_str(), - "failed to capture ghost snapshot: {err}" - ); - let message = match err { - GitToolingError::NotAGitRepository { .. } => { - "Snapshots disabled: current directory is not a Git repository." - .to_string() - } - _ => format!("Snapshots disabled after ghost snapshot error: {err}."), - }; - session - .session - .notify_background_event(&ctx_for_task, message) - .await; - } + "skipping ghost snapshot because current directory is not a Git repository" + ), + _ => { + warn!( + sub_id = ctx_for_task.sub_id.as_str(), + "failed to capture ghost snapshot: {err}" + ); + } + }, Err(err) => { warn!( sub_id = ctx_for_task.sub_id.as_str(), @@ -108,3 +128,22 @@ impl GhostSnapshotTask { Self { token } } } + +fn format_large_untracked_warning(report: &GhostSnapshotReport) -> Option { + if report.large_untracked_dirs.is_empty() { + return None; + } + const MAX_DIRS: usize = 3; + let mut parts: Vec = Vec::new(); + for dir in report.large_untracked_dirs.iter().take(MAX_DIRS) { + parts.push(format!("{} ({} files)", dir.path.display(), dir.file_count)); + } + if report.large_untracked_dirs.len() > MAX_DIRS { + let remaining = report.large_untracked_dirs.len() - MAX_DIRS; + parts.push(format!("{remaining} more")); + } + Some(format!( + "Repository snapshot encountered large untracked directories: {}. 
This can slow Codex; consider adding these paths to .gitignore or disabling undo in your config.", + parts.join(", ") + )) +} diff --git a/codex-rs/core/src/tasks/review.rs b/codex-rs/core/src/tasks/review.rs index e0bb7d4e9..14a95dba5 100644 --- a/codex-rs/core/src/tasks/review.rs +++ b/codex-rs/core/src/tasks/review.rs @@ -23,8 +23,18 @@ use codex_protocol::user_input::UserInput; use super::SessionTask; use super::SessionTaskContext; -#[derive(Clone, Copy, Default)] -pub(crate) struct ReviewTask; +#[derive(Clone, Copy)] +pub(crate) struct ReviewTask { + append_to_original_thread: bool, +} + +impl ReviewTask { + pub(crate) fn new(append_to_original_thread: bool) -> Self { + Self { + append_to_original_thread, + } + } +} #[async_trait] impl SessionTask for ReviewTask { @@ -52,13 +62,25 @@ impl SessionTask for ReviewTask { None => None, }; if !cancellation_token.is_cancelled() { - exit_review_mode(session.clone_session(), output.clone(), ctx.clone()).await; + exit_review_mode( + session.clone_session(), + output.clone(), + ctx.clone(), + self.append_to_original_thread, + ) + .await; } None } async fn abort(&self, session: Arc, ctx: Arc) { - exit_review_mode(session.clone_session(), None, ctx).await; + exit_review_mode( + session.clone_session(), + None, + ctx, + self.append_to_original_thread, + ) + .await; } } @@ -175,32 +197,35 @@ pub(crate) async fn exit_review_mode( session: Arc, review_output: Option, ctx: Arc, + append_to_original_thread: bool, ) { - let user_message = if let Some(out) = review_output.clone() { - let mut findings_str = String::new(); - let text = out.overall_explanation.trim(); - if !text.is_empty() { - findings_str.push_str(text); - } - if !out.findings.is_empty() { - let block = format_review_findings_block(&out.findings, None); - findings_str.push_str(&format!("\n{block}")); - } - crate::client_common::REVIEW_EXIT_SUCCESS_TMPL.replace("{results}", &findings_str) - } else { - crate::client_common::REVIEW_EXIT_INTERRUPTED_TMPL.to_string() - 
}; + if append_to_original_thread { + let user_message = if let Some(out) = review_output.clone() { + let mut findings_str = String::new(); + let text = out.overall_explanation.trim(); + if !text.is_empty() { + findings_str.push_str(text); + } + if !out.findings.is_empty() { + let block = format_review_findings_block(&out.findings, None); + findings_str.push_str(&format!("\n{block}")); + } + crate::client_common::REVIEW_EXIT_SUCCESS_TMPL.replace("{results}", &findings_str) + } else { + crate::client_common::REVIEW_EXIT_INTERRUPTED_TMPL.to_string() + }; - session - .record_conversation_items( - &ctx, - &[ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { text: user_message }], - }], - ) - .await; + session + .record_conversation_items( + &ctx, + &[ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { text: user_message }], + }], + ) + .await; + } session .send_event( ctx.as_ref(), diff --git a/codex-rs/core/src/tasks/user_shell.rs b/codex-rs/core/src/tasks/user_shell.rs index 465517adb..32e8a2596 100644 --- a/codex-rs/core/src/tasks/user_shell.rs +++ b/codex-rs/core/src/tasks/user_shell.rs @@ -31,6 +31,8 @@ use crate::user_shell_command::user_shell_command_record_item; use super::SessionTask; use super::SessionTaskContext; +const USER_SHELL_TIMEOUT_MS: u64 = 60 * 60 * 1000; // 1 hour + #[derive(Clone)] pub(crate) struct UserShellCommandTask { command: String, @@ -93,7 +95,9 @@ impl SessionTask for UserShellCommandTask { command: command.clone(), cwd: cwd.clone(), env: create_env(&turn_context.shell_environment_policy), - timeout_ms: None, + // TODO(zhao-oai): Now that we have ExecExpiration::Cancellation, we + // should use that instead of an "arbitrarily large" timeout here. 
+ expiration: USER_SHELL_TIMEOUT_MS.into(), sandbox: SandboxType::None, with_escalated_permissions: None, justification: None, @@ -122,7 +126,11 @@ impl SessionTask for UserShellCommandTask { duration: Duration::ZERO, timed_out: false, }; - let output_items = [user_shell_command_record_item(&raw_command, &exec_output)]; + let output_items = [user_shell_command_record_item( + &raw_command, + &exec_output, + &turn_context, + )]; session .record_conversation_items(turn_context.as_ref(), &output_items) .await; @@ -164,12 +172,19 @@ impl SessionTask for UserShellCommandTask { aggregated_output: output.aggregated_output.text.clone(), exit_code: output.exit_code, duration: output.duration, - formatted_output: format_exec_output_str(&output), + formatted_output: format_exec_output_str( + &output, + turn_context.truncation_policy, + ), }), ) .await; - let output_items = [user_shell_command_record_item(&raw_command, &output)]; + let output_items = [user_shell_command_record_item( + &raw_command, + &output, + &turn_context, + )]; session .record_conversation_items(turn_context.as_ref(), &output_items) .await; @@ -201,11 +216,18 @@ impl SessionTask for UserShellCommandTask { aggregated_output: exec_output.aggregated_output.text.clone(), exit_code: exec_output.exit_code, duration: exec_output.duration, - formatted_output: format_exec_output_str(&exec_output), + formatted_output: format_exec_output_str( + &exec_output, + turn_context.truncation_policy, + ), }), ) .await; - let output_items = [user_shell_command_record_item(&raw_command, &exec_output)]; + let output_items = [user_shell_command_record_item( + &raw_command, + &exec_output, + &turn_context, + )]; session .record_conversation_items(turn_context.as_ref(), &output_items) .await; diff --git a/codex-rs/core/src/text_encoding.rs b/codex-rs/core/src/text_encoding.rs new file mode 100644 index 000000000..fde44c419 --- /dev/null +++ b/codex-rs/core/src/text_encoding.rs @@ -0,0 +1,461 @@ +//! 
Text encoding detection and conversion utilities for shell output. +//! +//! Windows users frequently run into code pages such as CP1251 or CP866 when invoking commands +//! through VS Code. Those bytes show up as invalid UTF-8 and used to be replaced with the standard +//! Unicode replacement character. We now lean on `chardetng` and `encoding_rs` so we can +//! automatically detect and decode the vast majority of legacy encodings before falling back to +//! lossy UTF-8 decoding. + +use chardetng::EncodingDetector; +use encoding_rs::Encoding; +use encoding_rs::IBM866; +use encoding_rs::WINDOWS_1252; + +/// Attempts to convert arbitrary bytes to UTF-8 with best-effort encoding detection. +pub fn bytes_to_string_smart(bytes: &[u8]) -> String { + if bytes.is_empty() { + return String::new(); + } + + if let Ok(utf8_str) = std::str::from_utf8(bytes) { + return utf8_str.to_owned(); + } + + let encoding = detect_encoding(bytes); + decode_bytes(bytes, encoding) +} + +// Windows-1252 reassigns a handful of 0x80-0x9F slots to smart punctuation (curly quotes, dashes, +// ™). CP866 uses those *same byte values* for uppercase Cyrillic letters. When chardetng sees shell +// snippets that mix these bytes with ASCII it sometimes guesses IBM866, so “smart quotes” render as +// Cyrillic garbage (“УФЦ”) in VS Code. However, CP866 uppercase tokens are perfectly valid output +// (e.g., `ПРИ test`) so we cannot flip every 0x80-0x9F byte to Windows-1252 either. The compromise +// is to only coerce IBM866 to Windows-1252 when (a) the high bytes are exclusively the punctuation +// values listed below and (b) we spot adjacent ASCII. This targets the real failure case without +// clobbering legitimate Cyrillic text. If another code page has a similar collision, introduce a +// dedicated allowlist (like this one) plus unit tests that capture the actual shell output we want +// to preserve. Windows-1252 byte values for smart punctuation. 
+const WINDOWS_1252_PUNCT_BYTES: [u8; 8] = [ + 0x91, // ‘ (left single quotation mark) + 0x92, // ’ (right single quotation mark) + 0x93, // “ (left double quotation mark) + 0x94, // ” (right double quotation mark) + 0x95, // • (bullet) + 0x96, // – (en dash) + 0x97, // — (em dash) + 0x99, // ™ (trade mark sign) +]; + +fn detect_encoding(bytes: &[u8]) -> &'static Encoding { + let mut detector = EncodingDetector::new(); + detector.feed(bytes, true); + let (encoding, _is_confident) = detector.guess_assess(None, true); + + // chardetng occasionally reports IBM866 for short strings that only contain Windows-1252 “smart + // punctuation” bytes (0x80-0x9F) because that range maps to Cyrillic letters in IBM866. When + // those bytes show up alongside an ASCII word (typical shell output: `"“`test), we know the + // intent was likely CP1252 quotes/dashes. Prefer WINDOWS_1252 in that specific situation so we + // render the characters users expect instead of Cyrillic junk. References: + // - Windows-1252 reserving 0x80-0x9F for curly quotes/dashes: + // https://en.wikipedia.org/wiki/Windows-1252 + // - CP866 mapping 0x93/0x94/0x96 to Cyrillic letters, so the same bytes show up as “УФЦ” when + // mis-decoded: https://www.unicode.org/Public/MAPPINGS/VENDORS/MICSFT/PC/CP866.TXT + if encoding == IBM866 && looks_like_windows_1252_punctuation(bytes) { + return WINDOWS_1252; + } + + encoding +} + +fn decode_bytes(bytes: &[u8], encoding: &'static Encoding) -> String { + let (decoded, _, had_errors) = encoding.decode(bytes); + + if had_errors { + return String::from_utf8_lossy(bytes).into_owned(); + } + + decoded.into_owned() +} + +/// Detect whether the byte stream looks like Windows-1252 “smart punctuation” wrapped around +/// otherwise-ASCII text. +/// +/// Context: IBM866 and Windows-1252 share the 0x80-0x9F slot range. In IBM866 these bytes decode to +/// Cyrillic letters, whereas Windows-1252 maps them to curly quotes and dashes. 
chardetng can guess +/// IBM866 for short snippets that only contain those bytes, which turns shell output such as +/// `“test”` into unreadable Cyrillic. To avoid that, we treat inputs comprising a handful of bytes +/// from the problematic range plus ASCII letters as CP1252 punctuation. We deliberately do *not* +/// cap how many of those punctuation bytes we accept: VS Code frequently prints several quoted +/// phrases (e.g., `"foo" – "bar"`), and truncating the count would once again mis-decode those as +/// Cyrillic. If we discover additional encodings with overlapping byte ranges, prefer adding +/// encoding-specific byte allowlists like `WINDOWS_1252_PUNCT_BYTES` and tests that exercise real-world +/// shell snippets. +fn looks_like_windows_1252_punctuation(bytes: &[u8]) -> bool { + let mut saw_extended_punctuation = false; + let mut saw_ascii_word = false; + + for &byte in bytes { + if byte >= 0xA0 { + return false; + } + if (0x80..=0x9F).contains(&byte) { + if !is_windows_1252_punct(byte) { + return false; + } + saw_extended_punctuation = true; + } + if byte.is_ascii_alphabetic() { + saw_ascii_word = true; + } + } + + saw_extended_punctuation && saw_ascii_word +} + +fn is_windows_1252_punct(byte: u8) -> bool { + WINDOWS_1252_PUNCT_BYTES.contains(&byte) +} + +#[cfg(test)] +mod tests { + use super::*; + use encoding_rs::BIG5; + use encoding_rs::EUC_KR; + use encoding_rs::GBK; + use encoding_rs::ISO_8859_2; + use encoding_rs::ISO_8859_3; + use encoding_rs::ISO_8859_4; + use encoding_rs::ISO_8859_5; + use encoding_rs::ISO_8859_6; + use encoding_rs::ISO_8859_7; + use encoding_rs::ISO_8859_8; + use encoding_rs::ISO_8859_10; + use encoding_rs::ISO_8859_13; + use encoding_rs::SHIFT_JIS; + use encoding_rs::WINDOWS_874; + use encoding_rs::WINDOWS_1250; + use encoding_rs::WINDOWS_1251; + use encoding_rs::WINDOWS_1253; + use encoding_rs::WINDOWS_1254; + use encoding_rs::WINDOWS_1255; + use encoding_rs::WINDOWS_1256; + use encoding_rs::WINDOWS_1257; + use
encoding_rs::WINDOWS_1258; + use pretty_assertions::assert_eq; + + #[test] + fn test_utf8_passthrough() { + // Fast path: valid UTF-8 input skips encoding detection and is returned unchanged. + let utf8_text = "Hello, мир! 世界"; + let bytes = utf8_text.as_bytes(); + assert_eq!(bytes_to_string_smart(bytes), utf8_text); + } + + #[test] + fn test_cp1251_russian_text() { + // Cyrillic text emitted by PowerShell/WSL in CP1251 should decode cleanly. + let bytes = b"\xEF\xF0\xE8\xEC\xE5\xF0"; // "пример" encoded with Windows-1251 + assert_eq!(bytes_to_string_smart(bytes), "пример"); + } + + #[test] + fn test_cp1251_privet_word() { + // Regression: CP1251 words like "Привет" must not be mis-identified as Windows-1252. + let bytes = b"\xCF\xF0\xE8\xE2\xE5\xF2"; // "Привет" encoded with Windows-1251 + assert_eq!(bytes_to_string_smart(bytes), "Привет"); + } + + #[test] + fn test_koi8_r_privet_word() { + // KOI8-R output should decode to the original Cyrillic as well. + let bytes = b"\xF0\xD2\xC9\xD7\xC5\xD4"; // "Привет" encoded with KOI8-R + assert_eq!(bytes_to_string_smart(bytes), "Привет"); + } + + #[test] + fn test_cp866_russian_text() { + // Legacy consoles (cmd.exe) commonly emit CP866 bytes for Cyrillic content. + let bytes = b"\xAF\xE0\xA8\xAC\xA5\xE0"; // "пример" encoded with CP866 + assert_eq!(bytes_to_string_smart(bytes), "пример"); + } + + #[test] + fn test_cp866_uppercase_text() { + // Ensure the IBM866 heuristic still returns IBM866 for uppercase-only words. + let bytes = b"\x8F\x90\x88"; // "ПРИ" encoded with CP866 uppercase letters + assert_eq!(bytes_to_string_smart(bytes), "ПРИ"); + } + + #[test] + fn test_cp866_uppercase_followed_by_ascii() { + // Regression test: uppercase CP866 tokens next to ASCII text should not be treated as + // CP1252.
+ let bytes = b"\x8F\x90\x88 test"; // "ПРИ test" encoded with CP866 uppercase letters followed by ASCII + assert_eq!(bytes_to_string_smart(bytes), "ПРИ test"); + } + + #[test] + fn test_windows_1252_quotes() { + // Smart detection should map Windows-1252 punctuation into proper Unicode. + let bytes = b"\x93\x94test"; + assert_eq!(bytes_to_string_smart(bytes), "\u{201C}\u{201D}test"); + } + + #[test] + fn test_windows_1252_multiple_quotes() { + // Longer snippets of punctuation (e.g., “foo” – “bar”) should still flip to CP1252. + let bytes = b"\x93foo\x94 \x96 \x93bar\x94"; + assert_eq!( + bytes_to_string_smart(bytes), + "\u{201C}foo\u{201D} \u{2013} \u{201C}bar\u{201D}" + ); + } + + #[test] + fn test_windows_1252_privet_gibberish_is_preserved() { + // Windows-1252 cannot encode Cyrillic; if the input literally contains "ПÑ..." we should not "fix" it. + let bytes = "Привет".as_bytes(); + assert_eq!(bytes_to_string_smart(bytes), "Привет"); + } + + #[test] + fn test_iso8859_1_latin_text() { + // ISO-8859-1 (code page 28591) is the Latin segment used by LatArCyrHeb. + // encoding_rs unifies ISO-8859-1 with Windows-1252, so reuse that constant here. + let (encoded, _, had_errors) = WINDOWS_1252.encode("Hello"); + assert!(!had_errors, "failed to encode Latin sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Hello"); + } + + #[test] + fn test_iso8859_2_central_european_text() { + // ISO-8859-2 (code page 28592) covers additional Central European glyphs. + let (encoded, _, had_errors) = ISO_8859_2.encode("Příliš žluťoučký kůň"); + assert!(!had_errors, "failed to encode ISO-8859-2 sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "Příliš žluťoučký kůň" + ); + } + + #[test] + fn test_iso8859_3_south_europe_text() { + // ISO-8859-3 (code page 28593) adds support for Maltese/Esperanto letters. 
+ // chardetng rarely distinguishes ISO-8859-3 from neighboring Latin code pages, so we rely on + // an ASCII-only sample to ensure round-tripping still succeeds. + let (encoded, _, had_errors) = ISO_8859_3.encode("Esperanto and Maltese"); + assert!(!had_errors, "failed to encode ISO-8859-3 sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "Esperanto and Maltese" + ); + } + + #[test] + fn test_iso8859_4_baltic_text() { + // ISO-8859-4 (code page 28594) targets the Baltic/Nordic repertoire. + let sample = "Šis ir rakstzīmju kodēšanas tests. Dažās valodās, kurās tiek \ + izmantotas latīņu valodas burti, lēmuma pieņemšanai mums ir nepieciešams \ + vairāk ieguldījuma."; + let (encoded, _, had_errors) = ISO_8859_4.encode(sample); + assert!(!had_errors, "failed to encode ISO-8859-4 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), sample); + } + + #[test] + fn test_iso8859_5_cyrillic_text() { + // ISO-8859-5 (code page 28595) covers the Cyrillic portion. + let (encoded, _, had_errors) = ISO_8859_5.encode("Привет"); + assert!(!had_errors, "failed to encode Cyrillic sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Привет"); + } + + #[test] + fn test_iso8859_6_arabic_text() { + // ISO-8859-6 (code page 28596) covers the Arabic glyphs. + let (encoded, _, had_errors) = ISO_8859_6.encode("مرحبا"); + assert!(!had_errors, "failed to encode Arabic sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "مرحبا"); + } + + #[test] + fn test_iso8859_7_greek_text() { + // ISO-8859-7 (code page 28597) is used for Greek locales. + let (encoded, _, had_errors) = ISO_8859_7.encode("Καλημέρα"); + assert!(!had_errors, "failed to encode ISO-8859-7 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Καλημέρα"); + } + + #[test] + fn test_iso8859_8_hebrew_text() { + // ISO-8859-8 (code page 28598) covers the Hebrew glyphs. 
+ let (encoded, _, had_errors) = ISO_8859_8.encode("שלום"); + assert!(!had_errors, "failed to encode Hebrew sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "שלום"); + } + + #[test] + fn test_iso8859_9_turkish_text() { + // ISO-8859-9 (code page 28599) mirrors Latin-1 but inserts Turkish letters. + // encoding_rs exposes the equivalent Windows-1254 mapping. + let (encoded, _, had_errors) = WINDOWS_1254.encode("İstanbul"); + assert!(!had_errors, "failed to encode ISO-8859-9 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "İstanbul"); + } + + #[test] + fn test_iso8859_10_nordic_text() { + // ISO-8859-10 (code page 28600) adds additional Nordic letters. + let sample = "Þetta er prófun fyrir Ægir og Øystein."; + let (encoded, _, had_errors) = ISO_8859_10.encode(sample); + assert!(!had_errors, "failed to encode ISO-8859-10 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), sample); + } + + #[test] + fn test_iso8859_11_thai_text() { + // ISO-8859-11 (code page 28601) mirrors TIS-620 / Windows-874 for Thai. + let sample = "ภาษาไทยสำหรับการทดสอบ ISO-8859-11"; + // encoding_rs exposes the equivalent Windows-874 encoding, so use that constant. + let (encoded, _, had_errors) = WINDOWS_874.encode(sample); + assert!(!had_errors, "failed to encode ISO-8859-11 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), sample); + } + + // ISO-8859-12 was never standardized, and encodings 14–16 cannot be distinguished reliably + // without the heuristics we removed (chardetng generally reports neighboring Latin pages), so + // we intentionally omit coverage for those slots until the detector can identify them. + + #[test] + fn test_iso8859_13_baltic_text() { + // ISO-8859-13 (code page 28603) is common across Baltic languages. 
+ let (encoded, _, had_errors) = ISO_8859_13.encode("Sveiki"); + assert!(!had_errors, "failed to encode ISO-8859-13 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Sveiki"); + } + + #[test] + fn test_windows_1250_central_european_text() { + let (encoded, _, had_errors) = WINDOWS_1250.encode("Příliš žluťoučký kůň"); + assert!(!had_errors, "failed to encode Central European sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "Příliš žluťoučký kůň" + ); + } + + #[test] + fn test_windows_1251_encoded_text() { + let (encoded, _, had_errors) = WINDOWS_1251.encode("Привет из Windows-1251"); + assert!(!had_errors, "failed to encode Windows-1251 sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "Привет из Windows-1251" + ); + } + + #[test] + fn test_windows_1253_greek_text() { + let (encoded, _, had_errors) = WINDOWS_1253.encode("Γειά σου"); + assert!(!had_errors, "failed to encode Greek sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Γειά σου"); + } + + #[test] + fn test_windows_1254_turkish_text() { + let (encoded, _, had_errors) = WINDOWS_1254.encode("İstanbul"); + assert!(!had_errors, "failed to encode Turkish sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "İstanbul"); + } + + #[test] + fn test_windows_1255_hebrew_text() { + let (encoded, _, had_errors) = WINDOWS_1255.encode("שלום"); + assert!(!had_errors, "failed to encode Windows-1255 Hebrew sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "שלום"); + } + + #[test] + fn test_windows_1256_arabic_text() { + let (encoded, _, had_errors) = WINDOWS_1256.encode("مرحبا"); + assert!(!had_errors, "failed to encode Windows-1256 Arabic sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "مرحبا"); + } + + #[test] + fn test_windows_1257_baltic_text() { + let (encoded, _, had_errors) = WINDOWS_1257.encode("Pērkons"); + assert!(!had_errors, "failed to encode Baltic sample"); + 
assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Pērkons"); + } + + #[test] + fn test_windows_1258_vietnamese_text() { + let (encoded, _, had_errors) = WINDOWS_1258.encode("Xin chào"); + assert!(!had_errors, "failed to encode Vietnamese sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Xin chào"); + } + + #[test] + fn test_windows_874_thai_text() { + let (encoded, _, had_errors) = WINDOWS_874.encode("สวัสดีครับ นี่คือการทดสอบภาษาไทย"); + assert!(!had_errors, "failed to encode Thai sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "สวัสดีครับ นี่คือการทดสอบภาษาไทย" + ); + } + + #[test] + fn test_windows_932_shift_jis_text() { + let (encoded, _, had_errors) = SHIFT_JIS.encode("こんにちは"); + assert!(!had_errors, "failed to encode Shift-JIS sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "こんにちは"); + } + + #[test] + fn test_windows_936_gbk_text() { + let (encoded, _, had_errors) = GBK.encode("你好,世界,这是一个测试"); + assert!(!had_errors, "failed to encode GBK sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "你好,世界,这是一个测试" + ); + } + + #[test] + fn test_windows_949_korean_text() { + let (encoded, _, had_errors) = EUC_KR.encode("안녕하세요"); + assert!(!had_errors, "failed to encode Korean sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "안녕하세요"); + } + + #[test] + fn test_windows_950_big5_text() { + let (encoded, _, had_errors) = BIG5.encode("繁體"); + assert!(!had_errors, "failed to encode Big5 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "繁體"); + } + + #[test] + fn test_latin1_cafe() { + // Latin-1 bytes remain common in Western-European locales; decode them directly. + let bytes = b"caf\xE9"; // codespell:ignore caf + assert_eq!(bytes_to_string_smart(bytes), "café"); + } + + #[test] + fn test_preserves_ansi_sequences() { + // ANSI escape sequences should survive regardless of the detected encoding. 
+ let bytes = b"\x1b[31mred\x1b[0m"; + assert_eq!(bytes_to_string_smart(bytes), "\x1b[31mred\x1b[0m"); + } + + #[test] + fn test_fallback_to_lossy() { + // Completely invalid sequences fall back to the old lossy behavior. + let invalid_bytes = [0xFF, 0xFE, 0xFD]; + let result = bytes_to_string_smart(&invalid_bytes); + assert_eq!(result, String::from_utf8_lossy(&invalid_bytes)); + } +} diff --git a/codex-rs/core/src/tools/events.rs b/codex-rs/core/src/tools/events.rs index 26dc397dc..37df12d4b 100644 --- a/codex-rs/core/src/tools/events.rs +++ b/codex-rs/core/src/tools/events.rs @@ -88,6 +88,7 @@ pub(crate) enum ToolEmitter { cwd: PathBuf, source: ExecCommandSource, parsed_cmd: Vec, + freeform: bool, }, ApplyPatch { changes: HashMap, @@ -103,13 +104,19 @@ pub(crate) enum ToolEmitter { } impl ToolEmitter { - pub fn shell(command: Vec, cwd: PathBuf, source: ExecCommandSource) -> Self { + pub fn shell( + command: Vec, + cwd: PathBuf, + source: ExecCommandSource, + freeform: bool, + ) -> Self { let parsed_cmd = parse_command(&command); Self::Shell { command, cwd, source, parsed_cmd, + freeform, } } @@ -144,6 +151,7 @@ impl ToolEmitter { cwd, source, parsed_cmd, + .. }, stage, ) => { @@ -171,15 +179,17 @@ impl ToolEmitter { ctx.turn, EventMsg::PatchApplyBegin(PatchApplyBeginEvent { call_id: ctx.call_id.to_string(), + turn_id: ctx.turn.sub_id.clone(), auto_approved: *auto_approved, changes: changes.clone(), }), ) .await; } - (Self::ApplyPatch { .. }, ToolEventStage::Success(output)) => { + (Self::ApplyPatch { changes, .. }, ToolEventStage::Success(output)) => { emit_patch_end( ctx, + changes.clone(), output.stdout.text.clone(), output.stderr.text.clone(), output.exit_code == 0, @@ -187,11 +197,12 @@ impl ToolEmitter { .await; } ( - Self::ApplyPatch { .. }, + Self::ApplyPatch { changes, .. 
}, ToolEventStage::Failure(ToolEventFailure::Output(output)), ) => { emit_patch_end( ctx, + changes.clone(), output.stdout.text.clone(), output.stderr.text.clone(), output.exit_code == 0, @@ -199,10 +210,17 @@ impl ToolEmitter { .await; } ( - Self::ApplyPatch { .. }, + Self::ApplyPatch { changes, .. }, ToolEventStage::Failure(ToolEventFailure::Message(message)), ) => { - emit_patch_end(ctx, String::new(), (*message).to_string(), false).await; + emit_patch_end( + ctx, + changes.clone(), + String::new(), + (*message).to_string(), + false, + ) + .await; } ( Self::UnifiedExec { @@ -234,6 +252,19 @@ impl ToolEmitter { self.emit(ctx, ToolEventStage::Begin).await; } + fn format_exec_output_for_model( + &self, + output: &ExecToolCallOutput, + ctx: ToolEventCtx<'_>, + ) -> String { + match self { + Self::Shell { freeform: true, .. } => { + super::format_exec_output_for_model_freeform(output, ctx.turn.truncation_policy) + } + _ => super::format_exec_output_for_model_structured(output, ctx.turn.truncation_policy), + } + } + pub async fn finish( &self, ctx: ToolEventCtx<'_>, @@ -241,7 +272,7 @@ impl ToolEmitter { ) -> Result { let (event, result) = match out { Ok(output) => { - let content = super::format_exec_output_for_model(&output); + let content = self.format_exec_output_for_model(&output, ctx); let exit_code = output.exit_code; let event = ToolEventStage::Success(output); let result = if exit_code == 0 { @@ -253,7 +284,7 @@ impl ToolEmitter { } Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output }))) | Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Denied { output }))) => { - let response = super::format_exec_output_for_model(&output); + let response = self.format_exec_output_for_model(&output, ctx); let event = ToolEventStage::Failure(ToolEventFailure::Output(*output)); let result = Err(FunctionCallError::RespondToModel(response)); (event, result) @@ -342,7 +373,7 @@ async fn emit_exec_stage( aggregated_output: output.aggregated_output.text.clone(), 
exit_code: output.exit_code, duration: output.duration, - formatted_output: format_exec_output_str(&output), + formatted_output: format_exec_output_str(&output, ctx.turn.truncation_policy), }; emit_exec_end(ctx, exec_input, exec_result).await; } @@ -388,15 +419,23 @@ async fn emit_exec_end( .await; } -async fn emit_patch_end(ctx: ToolEventCtx<'_>, stdout: String, stderr: String, success: bool) { +async fn emit_patch_end( + ctx: ToolEventCtx<'_>, + changes: HashMap, + stdout: String, + stderr: String, + success: bool, +) { ctx.session .send_event( ctx.turn, EventMsg::PatchApplyEnd(PatchApplyEndEvent { call_id: ctx.call_id.to_string(), + turn_id: ctx.turn.sub_id.clone(), stdout, stderr, success, + changes, }), ) .await; diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index 43b2bb129..99c822fa5 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -9,9 +9,11 @@ use crate::apply_patch::convert_apply_patch_to_protocol; use crate::codex::TurnContext; use crate::exec::ExecParams; use crate::exec_env::create_env; +use crate::exec_policy::create_approval_requirement_for_command; use crate::function_tool::FunctionCallError; use crate::is_safe_command::is_known_safe_command; use crate::protocol::ExecCommandSource; +use crate::sandboxing::SandboxPermissions; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; @@ -35,7 +37,7 @@ impl ShellHandler { ExecParams { command: params.command, cwd: turn_context.resolve_path(params.workdir.clone()), - timeout_ms: params.timeout_ms, + expiration: params.timeout_ms.into(), env: create_env(&turn_context.shell_environment_policy), with_escalated_permissions: params.with_escalated_permissions, justification: params.justification, @@ -57,7 +59,7 @@ impl ShellCommandHandler { ExecParams { command, cwd: turn_context.resolve_path(params.workdir.clone()), - timeout_ms: 
params.timeout_ms, + expiration: params.timeout_ms.into(), env: create_env(&turn_context.shell_environment_policy), with_escalated_permissions: params.with_escalated_permissions, justification: params.justification, @@ -117,6 +119,7 @@ impl ToolHandler for ShellHandler { turn, tracker, call_id, + false, ) .await } @@ -129,6 +132,7 @@ impl ToolHandler for ShellHandler { turn, tracker, call_id, + false, ) .await } @@ -176,6 +180,7 @@ impl ToolHandler for ShellCommandHandler { turn, tracker, call_id, + true, ) .await } @@ -189,6 +194,7 @@ impl ShellHandler { turn: Arc, tracker: crate::tools::context::SharedTurnDiffTracker, call_id: String, + freeform: bool, ) -> Result { // Approval policy guard for explicit escalation in non-OnRequest modes. if exec_params.with_escalated_permissions.unwrap_or(false) @@ -237,7 +243,7 @@ impl ShellHandler { let req = ApplyPatchRequest { patch: apply.action.patch.clone(), cwd: apply.action.cwd.clone(), - timeout_ms: exec_params.timeout_ms, + timeout_ms: exec_params.expiration.timeout_ms(), user_explicitly_approved: apply.user_explicitly_approved_this_action, codex_exe: turn.codex_linux_sandbox_exe.clone(), }; @@ -282,18 +288,29 @@ impl ShellHandler { } let source = ExecCommandSource::Agent; - let emitter = - ToolEmitter::shell(exec_params.command.clone(), exec_params.cwd.clone(), source); + let emitter = ToolEmitter::shell( + exec_params.command.clone(), + exec_params.cwd.clone(), + source, + freeform, + ); let event_ctx = ToolEventCtx::new(session.as_ref(), turn.as_ref(), &call_id, None); emitter.begin(event_ctx).await; let req = ShellRequest { command: exec_params.command.clone(), cwd: exec_params.cwd.clone(), - timeout_ms: exec_params.timeout_ms, + timeout_ms: exec_params.expiration.timeout_ms(), env: exec_params.env.clone(), with_escalated_permissions: exec_params.with_escalated_permissions, justification: exec_params.justification.clone(), + approval_requirement: create_approval_requirement_for_command( + &turn.exec_policy, + 
&exec_params.command, + turn.approval_policy, + &turn.sandbox_policy, + SandboxPermissions::from(exec_params.with_escalated_permissions.unwrap_or(false)), + ), }; let mut orchestrator = ToolOrchestrator::new(); let mut runtime = ShellRuntime::new(); @@ -321,29 +338,30 @@ mod tests { use std::path::PathBuf; use crate::is_safe_command::is_known_safe_command; - use crate::shell::BashShell; - use crate::shell::PowerShellConfig; use crate::shell::Shell; - use crate::shell::ZshShell; + use crate::shell::ShellType; /// The logic for is_known_safe_command() has heuristics for known shells, /// so we must ensure the commands generated by [ShellCommandHandler] can be /// recognized as safe if the `command` is safe. #[test] fn commands_generated_by_shell_command_handler_can_be_matched_by_is_known_safe_command() { - let bash_shell = Shell::Bash(BashShell { + let bash_shell = Shell { + shell_type: ShellType::Bash, shell_path: PathBuf::from("/bin/bash"), - }); + }; assert_safe(&bash_shell, "ls -la"); - let zsh_shell = Shell::Zsh(ZshShell { + let zsh_shell = Shell { + shell_type: ShellType::Zsh, shell_path: PathBuf::from("/bin/zsh"), - }); + }; assert_safe(&zsh_shell, "ls -la"); - let powershell = Shell::PowerShell(PowerShellConfig { + let powershell = Shell { + shell_type: ShellType::PowerShell, shell_path: PathBuf::from("pwsh.exe"), - }); + }; assert_safe(&powershell, "ls -Name"); } diff --git a/codex-rs/core/src/tools/mod.rs b/codex-rs/core/src/tools/mod.rs index c94a7c28d..c1ef916d7 100644 --- a/codex-rs/core/src/tools/mod.rs +++ b/codex-rs/core/src/tools/mod.rs @@ -9,10 +9,10 @@ pub mod runtimes; pub mod sandboxing; pub mod spec; -use crate::context_manager::MODEL_FORMAT_MAX_BYTES; -use crate::context_manager::MODEL_FORMAT_MAX_LINES; -use crate::context_manager::format_output_for_model_body; use crate::exec::ExecToolCallOutput; +use crate::truncate::TruncationPolicy; +use crate::truncate::formatted_truncate_text; +use crate::truncate::truncate_text; pub use 
router::ToolRouter; use serde::Serialize; @@ -24,7 +24,10 @@ pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str = /// Format the combined exec output for sending back to the model. /// Includes exit code and duration metadata; truncates large bodies safely. -pub fn format_exec_output_for_model(exec_output: &ExecToolCallOutput) -> String { +pub fn format_exec_output_for_model_structured( + exec_output: &ExecToolCallOutput, + truncation_policy: TruncationPolicy, +) -> String { let ExecToolCallOutput { exit_code, duration, @@ -46,7 +49,7 @@ pub fn format_exec_output_for_model(exec_output: &ExecToolCallOutput) -> String // round to 1 decimal place let duration_seconds = ((duration.as_secs_f32()) * 10.0).round() / 10.0; - let formatted_output = format_exec_output_str(exec_output); + let formatted_output = format_exec_output_str(exec_output, truncation_policy); let payload = ExecOutput { output: &formatted_output, @@ -60,7 +63,35 @@ pub fn format_exec_output_for_model(exec_output: &ExecToolCallOutput) -> String serde_json::to_string(&payload).expect("serialize ExecOutput") } -pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String { +pub fn format_exec_output_for_model_freeform( + exec_output: &ExecToolCallOutput, + truncation_policy: TruncationPolicy, +) -> String { + // round to 1 decimal place + let duration_seconds = ((exec_output.duration.as_secs_f32()) * 10.0).round() / 10.0; + + let total_lines = exec_output.aggregated_output.text.lines().count(); + + let formatted_output = truncate_text(&exec_output.aggregated_output.text, truncation_policy); + + let mut sections = Vec::new(); + + sections.push(format!("Exit code: {}", exec_output.exit_code)); + sections.push(format!("Wall time: {duration_seconds} seconds")); + if total_lines != formatted_output.lines().count() { + sections.push(format!("Total output lines: {total_lines}")); + } + + sections.push("Output:".to_string()); + sections.push(formatted_output); + + sections.join("\n") +} + +pub 
fn format_exec_output_str( + exec_output: &ExecToolCallOutput, + truncation_policy: TruncationPolicy, +) -> String { let ExecToolCallOutput { aggregated_output, .. } = exec_output; @@ -77,5 +108,5 @@ pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String { }; // Truncate for model consumption before serialization. - format_output_for_model_body(&body, MODEL_FORMAT_MAX_BYTES, MODEL_FORMAT_MAX_LINES) + formatted_truncate_text(&body, truncation_policy) } diff --git a/codex-rs/core/src/tools/orchestrator.rs b/codex-rs/core/src/tools/orchestrator.rs index 878e48e8b..7e8e152f6 100644 --- a/codex-rs/core/src/tools/orchestrator.rs +++ b/codex-rs/core/src/tools/orchestrator.rs @@ -11,11 +11,13 @@ use crate::error::get_error_message_ui; use crate::exec::ExecToolCallOutput; use crate::sandboxing::SandboxManager; use crate::tools::sandboxing::ApprovalCtx; +use crate::tools::sandboxing::ApprovalRequirement; use crate::tools::sandboxing::ProvidesSandboxRetryData; use crate::tools::sandboxing::SandboxAttempt; use crate::tools::sandboxing::ToolCtx; use crate::tools::sandboxing::ToolError; use crate::tools::sandboxing::ToolRuntime; +use crate::tools::sandboxing::default_approval_requirement; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::ReviewDecision; @@ -49,40 +51,52 @@ impl ToolOrchestrator { let otel_cfg = codex_otel::otel_event_manager::ToolDecisionSource::Config; // 1) Approval - let needs_initial_approval = - tool.wants_initial_approval(req, approval_policy, &turn_ctx.sandbox_policy); let mut already_approved = false; - if needs_initial_approval { - let mut risk = None; - - if let Some(metadata) = req.sandbox_retry_data() { - risk = tool_ctx - .session - .assess_sandbox_command(turn_ctx, &tool_ctx.call_id, &metadata.command, None) - .await; + let requirement = tool.approval_requirement(req).unwrap_or_else(|| { + default_approval_requirement(approval_policy, &turn_ctx.sandbox_policy) + }); + match requirement { + 
ApprovalRequirement::Skip => { + otel.tool_decision(otel_tn, otel_ci, ReviewDecision::Approved, otel_cfg); + } + ApprovalRequirement::Forbidden { reason } => { + return Err(ToolError::Rejected(reason)); } + ApprovalRequirement::NeedsApproval { reason } => { + let mut risk = None; + + if let Some(metadata) = req.sandbox_retry_data() { + risk = tool_ctx + .session + .assess_sandbox_command( + turn_ctx, + &tool_ctx.call_id, + &metadata.command, + None, + ) + .await; + } - let approval_ctx = ApprovalCtx { - session: tool_ctx.session, - turn: turn_ctx, - call_id: &tool_ctx.call_id, - retry_reason: None, - risk, - }; - let decision = tool.start_approval_async(req, approval_ctx).await; + let approval_ctx = ApprovalCtx { + session: tool_ctx.session, + turn: turn_ctx, + call_id: &tool_ctx.call_id, + retry_reason: reason, + risk, + }; + let decision = tool.start_approval_async(req, approval_ctx).await; - otel.tool_decision(otel_tn, otel_ci, decision, otel_user.clone()); + otel.tool_decision(otel_tn, otel_ci, decision, otel_user.clone()); - match decision { - ReviewDecision::Denied | ReviewDecision::Abort => { - return Err(ToolError::Rejected("rejected by user".to_string())); + match decision { + ReviewDecision::Denied | ReviewDecision::Abort => { + return Err(ToolError::Rejected("rejected by user".to_string())); + } + ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {} } - ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {} + already_approved = true; } - already_approved = true; - } else { - otel.tool_decision(otel_tn, otel_ci, ReviewDecision::Approved, otel_cfg); } // 2) First attempt under the selected sandbox. 
diff --git a/codex-rs/core/src/tools/runtimes/apply_patch.rs b/codex-rs/core/src/tools/runtimes/apply_patch.rs index 0cdddd508..2334f1e71 100644 --- a/codex-rs/core/src/tools/runtimes/apply_patch.rs +++ b/codex-rs/core/src/tools/runtimes/apply_patch.rs @@ -67,7 +67,7 @@ impl ApplyPatchRuntime { program, args: vec![CODEX_APPLY_PATCH_ARG1.to_string(), req.patch.clone()], cwd: req.cwd.clone(), - timeout_ms: req.timeout_ms, + expiration: req.timeout_ms.into(), // Run apply_patch with a minimal environment for determinism and to avoid leaks. env: HashMap::new(), with_escalated_permissions: None, @@ -153,9 +153,9 @@ impl ToolRuntime for ApplyPatchRuntime { ) -> Result { let spec = Self::build_command_spec(req)?; let env = attempt - .env_for(&spec) + .env_for(spec) .map_err(|err| ToolError::Codex(err.into()))?; - let out = execute_env(&env, attempt.policy, Self::stdout_stream(ctx)) + let out = execute_env(env, attempt.policy, Self::stdout_stream(ctx)) .await .map_err(ToolError::Codex)?; Ok(out) diff --git a/codex-rs/core/src/tools/runtimes/mod.rs b/codex-rs/core/src/tools/runtimes/mod.rs index 212163d72..437f4af42 100644 --- a/codex-rs/core/src/tools/runtimes/mod.rs +++ b/codex-rs/core/src/tools/runtimes/mod.rs @@ -4,6 +4,7 @@ Module: runtimes Concrete ToolRuntime implementations for specific tools. Each runtime stays small and focused and reuses the orchestrator for approvals + sandbox + retry. 
*/ +use crate::exec::ExecExpiration; use crate::sandboxing::CommandSpec; use crate::tools::sandboxing::ToolError; use std::collections::HashMap; @@ -19,7 +20,7 @@ pub(crate) fn build_command_spec( command: &[String], cwd: &Path, env: &HashMap, - timeout_ms: Option, + expiration: ExecExpiration, with_escalated_permissions: Option, justification: Option, ) -> Result { @@ -31,7 +32,7 @@ pub(crate) fn build_command_spec( args: args.to_vec(), cwd: cwd.to_path_buf(), env: env.clone(), - timeout_ms, + expiration, with_escalated_permissions, justification, }) diff --git a/codex-rs/core/src/tools/runtimes/shell.rs b/codex-rs/core/src/tools/runtimes/shell.rs index bf7ae7fa3..b46f72b48 100644 --- a/codex-rs/core/src/tools/runtimes/shell.rs +++ b/codex-rs/core/src/tools/runtimes/shell.rs @@ -4,13 +4,12 @@ Runtime: shell Executes shell requests under the orchestrator: asks for approval when needed, builds a CommandSpec, and runs it under the current SandboxAttempt. */ -use crate::command_safety::is_dangerous_command::requires_initial_appoval; use crate::exec::ExecToolCallOutput; -use crate::protocol::SandboxPolicy; use crate::sandboxing::execute_env; use crate::tools::runtimes::build_command_spec; use crate::tools::sandboxing::Approvable; use crate::tools::sandboxing::ApprovalCtx; +use crate::tools::sandboxing::ApprovalRequirement; use crate::tools::sandboxing::ProvidesSandboxRetryData; use crate::tools::sandboxing::SandboxAttempt; use crate::tools::sandboxing::SandboxRetryData; @@ -20,7 +19,6 @@ use crate::tools::sandboxing::ToolCtx; use crate::tools::sandboxing::ToolError; use crate::tools::sandboxing::ToolRuntime; use crate::tools::sandboxing::with_cached_approval; -use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::ReviewDecision; use futures::future::BoxFuture; use std::path::PathBuf; @@ -33,6 +31,7 @@ pub struct ShellRequest { pub env: std::collections::HashMap, pub with_escalated_permissions: Option, pub justification: Option, + pub 
approval_requirement: ApprovalRequirement, } impl ProvidesSandboxRetryData for ShellRequest { @@ -114,18 +113,8 @@ impl Approvable for ShellRuntime { }) } - fn wants_initial_approval( - &self, - req: &ShellRequest, - policy: AskForApproval, - sandbox_policy: &SandboxPolicy, - ) -> bool { - requires_initial_appoval( - policy, - sandbox_policy, - &req.command, - req.with_escalated_permissions.unwrap_or(false), - ) + fn approval_requirement(&self, req: &ShellRequest) -> Option { + Some(req.approval_requirement.clone()) } fn wants_escalated_first_attempt(&self, req: &ShellRequest) -> bool { @@ -144,14 +133,14 @@ impl ToolRuntime for ShellRuntime { &req.command, &req.cwd, &req.env, - req.timeout_ms, + req.timeout_ms.into(), req.with_escalated_permissions, req.justification.clone(), )?; let env = attempt - .env_for(&spec) + .env_for(spec) .map_err(|err| ToolError::Codex(err.into()))?; - let out = execute_env(&env, attempt.policy, Self::stdout_stream(ctx)) + let out = execute_env(env, attempt.policy, Self::stdout_stream(ctx)) .await .map_err(ToolError::Codex)?; Ok(out) diff --git a/codex-rs/core/src/tools/runtimes/unified_exec.rs b/codex-rs/core/src/tools/runtimes/unified_exec.rs index cddac1924..3f0362259 100644 --- a/codex-rs/core/src/tools/runtimes/unified_exec.rs +++ b/codex-rs/core/src/tools/runtimes/unified_exec.rs @@ -1,4 +1,3 @@ -use crate::command_safety::is_dangerous_command::requires_initial_appoval; /* Runtime: unified exec @@ -7,9 +6,11 @@ the session manager to spawn PTYs once an ExecEnv is prepared. 
*/ use crate::error::CodexErr; use crate::error::SandboxErr; +use crate::exec::ExecExpiration; use crate::tools::runtimes::build_command_spec; use crate::tools::sandboxing::Approvable; use crate::tools::sandboxing::ApprovalCtx; +use crate::tools::sandboxing::ApprovalRequirement; use crate::tools::sandboxing::ProvidesSandboxRetryData; use crate::tools::sandboxing::SandboxAttempt; use crate::tools::sandboxing::SandboxRetryData; @@ -22,9 +23,7 @@ use crate::tools::sandboxing::with_cached_approval; use crate::unified_exec::UnifiedExecError; use crate::unified_exec::UnifiedExecSession; use crate::unified_exec::UnifiedExecSessionManager; -use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::ReviewDecision; -use codex_protocol::protocol::SandboxPolicy; use futures::future::BoxFuture; use std::collections::HashMap; use std::path::PathBuf; @@ -36,6 +35,7 @@ pub struct UnifiedExecRequest { pub env: HashMap, pub with_escalated_permissions: Option, pub justification: Option, + pub approval_requirement: ApprovalRequirement, } impl ProvidesSandboxRetryData for UnifiedExecRequest { @@ -65,6 +65,7 @@ impl UnifiedExecRequest { env: HashMap, with_escalated_permissions: Option, justification: Option, + approval_requirement: ApprovalRequirement, ) -> Self { Self { command, @@ -72,6 +73,7 @@ impl UnifiedExecRequest { env, with_escalated_permissions, justification, + approval_requirement, } } } @@ -129,18 +131,8 @@ impl Approvable for UnifiedExecRuntime<'_> { }) } - fn wants_initial_approval( - &self, - req: &UnifiedExecRequest, - policy: AskForApproval, - sandbox_policy: &SandboxPolicy, - ) -> bool { - requires_initial_appoval( - policy, - sandbox_policy, - &req.command, - req.with_escalated_permissions.unwrap_or(false), - ) + fn approval_requirement(&self, req: &UnifiedExecRequest) -> Option { + Some(req.approval_requirement.clone()) } fn wants_escalated_first_attempt(&self, req: &UnifiedExecRequest) -> bool { @@ -159,13 +151,13 @@ impl<'a> ToolRuntime for 
UnifiedExecRunt &req.command, &req.cwd, &req.env, - None, + ExecExpiration::DefaultTimeout, req.with_escalated_permissions, req.justification.clone(), ) .map_err(|_| ToolError::Rejected("missing command line for PTY".to_string()))?; let exec_env = attempt - .env_for(&spec) + .env_for(spec) .map_err(|err| ToolError::Codex(err.into()))?; self.manager .open_session_with_exec_env(&exec_env) diff --git a/codex-rs/core/src/tools/sandboxing.rs b/codex-rs/core/src/tools/sandboxing.rs index da1c22b54..f9e3e20ea 100644 --- a/codex-rs/core/src/tools/sandboxing.rs +++ b/codex-rs/core/src/tools/sandboxing.rs @@ -86,6 +86,37 @@ pub(crate) struct ApprovalCtx<'a> { pub risk: Option, } +// Specifies what tool orchestrator should do with a given tool call. +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) enum ApprovalRequirement { + /// No approval required for this tool call + Skip, + /// Approval required for this tool call + NeedsApproval { reason: Option }, + /// Execution forbidden for this tool call + Forbidden { reason: String }, +} + +/// - Never, OnFailure: do not ask +/// - OnRequest: ask unless sandbox policy is DangerFullAccess +/// - UnlessTrusted: always ask +pub(crate) fn default_approval_requirement( + policy: AskForApproval, + sandbox_policy: &SandboxPolicy, +) -> ApprovalRequirement { + let needs_approval = match policy { + AskForApproval::Never | AskForApproval::OnFailure => false, + AskForApproval::OnRequest => !matches!(sandbox_policy, SandboxPolicy::DangerFullAccess), + AskForApproval::UnlessTrusted => true, + }; + + if needs_approval { + ApprovalRequirement::NeedsApproval { reason: None } + } else { + ApprovalRequirement::Skip + } +} + pub(crate) trait Approvable { type ApprovalKey: Hash + Eq + Clone + Debug + Serialize; @@ -106,22 +137,11 @@ pub(crate) trait Approvable { matches!(policy, AskForApproval::Never) } - /// Decide whether an initial user approval should be requested before the - /// first attempt. 
Defaults to the orchestrator's behavior (pre‑refactor): - /// - Never, OnFailure: do not ask - /// - OnRequest: ask unless sandbox policy is DangerFullAccess - /// - UnlessTrusted: always ask - fn wants_initial_approval( - &self, - _req: &Req, - policy: AskForApproval, - sandbox_policy: &SandboxPolicy, - ) -> bool { - match policy { - AskForApproval::Never | AskForApproval::OnFailure => false, - AskForApproval::OnRequest => !matches!(sandbox_policy, SandboxPolicy::DangerFullAccess), - AskForApproval::UnlessTrusted => true, - } + /// Override the default approval requirement. Return `Some(_)` to specify + /// a custom requirement, or `None` to fall back to + /// policy-based default. + fn approval_requirement(&self, _req: &Req) -> Option { + None } /// Decide we can request an approval for no-sandbox execution. @@ -196,7 +216,7 @@ pub(crate) struct SandboxAttempt<'a> { impl<'a> SandboxAttempt<'a> { pub fn env_for( &self, - spec: &CommandSpec, + spec: CommandSpec, ) -> Result { self.manager.transform( spec, diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index 88ddbf551..b3d0330f8 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -57,8 +57,6 @@ impl ToolsConfig { ConfigShellToolType::Disabled } else if features.enabled(Feature::UnifiedExec) { ConfigShellToolType::UnifiedExec - } else if features.enabled(Feature::ShellCommandTool) { - ConfigShellToolType::ShellCommand } else { model_family.shell_type.clone() }; @@ -1292,11 +1290,7 @@ mod tests { "gpt-5-codex", &Features::with_defaults(), &[ - if cfg!(windows) { - "shell_command" - } else { - "shell" - }, + "shell_command", "list_mcp_resources", "list_mcp_resource_templates", "read_mcp_resource", @@ -1313,11 +1307,7 @@ mod tests { "gpt-5.1-codex", &Features::with_defaults(), &[ - if cfg!(windows) { - "shell_command" - } else { - "shell" - }, + "shell_command", "list_mcp_resources", "list_mcp_resource_templates", "read_mcp_resource", @@ -1392,11 +1382,7 
@@ mod tests { "gpt-5.1-codex-mini", &Features::with_defaults(), &[ - if cfg!(windows) { - "shell_command" - } else { - "shell" - }, + "shell_command", "list_mcp_resources", "list_mcp_resource_templates", "read_mcp_resource", @@ -1407,13 +1393,29 @@ mod tests { ); } + #[test] + fn test_gpt_5_defaults() { + assert_model_tools( + "gpt-5", + &Features::with_defaults(), + &[ + "shell", + "list_mcp_resources", + "list_mcp_resource_templates", + "read_mcp_resource", + "update_plan", + "view_image", + ], + ); + } + #[test] fn test_gpt_5_1_defaults() { assert_model_tools( "gpt-5.1", &Features::with_defaults(), &[ - "shell", + "shell_command", "list_mcp_resources", "list_mcp_resource_templates", "read_mcp_resource", @@ -1464,22 +1466,6 @@ mod tests { assert_contains_tool_names(&tools, &subset); } - #[test] - fn test_build_specs_shell_command_present() { - assert_model_tools( - "codex-mini-latest", - Features::with_defaults().enable(Feature::ShellCommandTool), - &[ - "shell_command", - "list_mcp_resources", - "list_mcp_resource_templates", - "read_mcp_resource", - "update_plan", - "view_image", - ], - ); - } - #[test] #[ignore] fn test_parallel_support_flags() { diff --git a/codex-rs/core/src/truncate.rs b/codex-rs/core/src/truncate.rs index 42d6a967d..bf883c061 100644 --- a/codex-rs/core/src/truncate.rs +++ b/codex-rs/core/src/truncate.rs @@ -2,48 +2,140 @@ //! and suffix on UTF-8 boundaries, and helpers for line/token‑based truncation //! used across the core crate. +use crate::config::Config; use codex_protocol::models::FunctionCallOutputContentItem; -use codex_utils_string::take_bytes_at_char_boundary; -use codex_utils_string::take_last_bytes_at_char_boundary; -use codex_utils_tokenizer::Tokenizer; - -/// Model-formatting limits: clients get full streams; only content sent to the model is truncated. 
-pub const MODEL_FORMAT_MAX_BYTES: usize = 10 * 1024; // 10 KiB -pub const MODEL_FORMAT_MAX_LINES: usize = 256; // lines - -/// Globally truncate function output items to fit within `MODEL_FORMAT_MAX_BYTES` -/// by preserving as many text/image items as possible and appending a summary -/// for any omitted text items. -pub(crate) fn globally_truncate_function_output_items( + +const APPROX_BYTES_PER_TOKEN: usize = 4; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum TruncationPolicy { + Bytes(usize), + Tokens(usize), +} + +impl TruncationPolicy { + /// Scale the underlying budget by `multiplier`, rounding up to avoid under-budgeting. + pub fn mul(self, multiplier: f64) -> Self { + match self { + TruncationPolicy::Bytes(bytes) => { + TruncationPolicy::Bytes((bytes as f64 * multiplier).ceil() as usize) + } + TruncationPolicy::Tokens(tokens) => { + TruncationPolicy::Tokens((tokens as f64 * multiplier).ceil() as usize) + } + } + } + + pub fn new(config: &Config) -> Self { + let config_token_limit = config.tool_output_token_limit; + + match config.model_family.truncation_policy { + TruncationPolicy::Bytes(family_bytes) => { + if let Some(token_limit) = config_token_limit { + Self::Bytes(approx_bytes_for_tokens(token_limit)) + } else { + Self::Bytes(family_bytes) + } + } + TruncationPolicy::Tokens(family_tokens) => { + if let Some(token_limit) = config_token_limit { + Self::Tokens(token_limit) + } else { + Self::Tokens(family_tokens) + } + } + } + } + + /// Returns a token budget derived from this policy. + /// + /// - For `Tokens`, this is the explicit token limit. + /// - For `Bytes`, this is an approximate token budget using the global + /// bytes-per-token heuristic. + pub fn token_budget(&self) -> usize { + match self { + TruncationPolicy::Bytes(bytes) => { + usize::try_from(approx_tokens_from_byte_count(*bytes)).unwrap_or(usize::MAX) + } + TruncationPolicy::Tokens(tokens) => *tokens, + } + } + + /// Returns a byte budget derived from this policy. 
+ /// + /// - For `Bytes`, this is the explicit byte limit. + /// - For `Tokens`, this is an approximate byte budget using the global + /// bytes-per-token heuristic. + pub fn byte_budget(&self) -> usize { + match self { + TruncationPolicy::Bytes(bytes) => *bytes, + TruncationPolicy::Tokens(tokens) => approx_bytes_for_tokens(*tokens), + } + } +} + +pub(crate) fn formatted_truncate_text(content: &str, policy: TruncationPolicy) -> String { + if content.len() <= policy.byte_budget() { + return content.to_string(); + } + let total_lines = content.lines().count(); + let result = truncate_text(content, policy); + format!("Total output lines: {total_lines}\n\n{result}") +} + +pub(crate) fn truncate_text(content: &str, policy: TruncationPolicy) -> String { + match policy { + TruncationPolicy::Bytes(_) => truncate_with_byte_estimate(content, policy), + TruncationPolicy::Tokens(_) => { + let (truncated, _) = truncate_with_token_budget(content, policy); + truncated + } + } +} +/// Globally truncate function output items to fit within the given +/// truncation policy's budget, preserving as many text/image items as +/// possible and appending a summary for any omitted text items. 
+pub(crate) fn truncate_function_output_items_with_policy( items: &[FunctionCallOutputContentItem], + policy: TruncationPolicy, ) -> Vec { let mut out: Vec = Vec::with_capacity(items.len()); - let mut remaining = MODEL_FORMAT_MAX_BYTES; + let mut remaining_budget = match policy { + TruncationPolicy::Bytes(_) => policy.byte_budget(), + TruncationPolicy::Tokens(_) => policy.token_budget(), + }; let mut omitted_text_items = 0usize; for it in items { match it { FunctionCallOutputContentItem::InputText { text } => { - if remaining == 0 { + if remaining_budget == 0 { omitted_text_items += 1; continue; } - let len = text.len(); - if len <= remaining { + let cost = match policy { + TruncationPolicy::Bytes(_) => text.len(), + TruncationPolicy::Tokens(_) => approx_token_count(text), + }; + + if cost <= remaining_budget { out.push(FunctionCallOutputContentItem::InputText { text: text.clone() }); - remaining -= len; + remaining_budget = remaining_budget.saturating_sub(cost); } else { - let slice = take_bytes_at_char_boundary(text, remaining); - if !slice.is_empty() { - out.push(FunctionCallOutputContentItem::InputText { - text: slice.to_string(), - }); + let snippet_policy = match policy { + TruncationPolicy::Bytes(_) => TruncationPolicy::Bytes(remaining_budget), + TruncationPolicy::Tokens(_) => TruncationPolicy::Tokens(remaining_budget), + }; + let snippet = truncate_text(text, snippet_policy); + if snippet.is_empty() { + omitted_text_items += 1; + } else { + out.push(FunctionCallOutputContentItem::InputText { text: snippet }); } - remaining = 0; + remaining_budget = 0; } } - // todo(aibrahim): handle input images; resize FunctionCallOutputContentItem::InputImage { image_url } => { out.push(FunctionCallOutputContentItem::InputImage { image_url: image_url.clone(), @@ -61,513 +153,332 @@ pub(crate) fn globally_truncate_function_output_items( out } -/// Format a block of exec/tool output for model consumption, truncating by -/// lines and bytes while preserving head and tail 
segments. -pub(crate) fn format_output_for_model_body( - content: &str, - limit_bytes: usize, - limit_lines: usize, -) -> String { - // Head+tail truncation for the model: show the beginning and end with an elision. - // Clients still receive full streams; only this formatted summary is capped. - let total_lines = content.lines().count(); - if content.len() <= limit_bytes && total_lines <= limit_lines { - return content.to_string(); +/// Truncate the middle of a UTF-8 string to at most `max_tokens` tokens, +/// preserving the beginning and the end. Returns the possibly truncated string +/// and `Some(original_token_count)` if truncation occurred; otherwise returns +/// the original string and `None`. +fn truncate_with_token_budget(s: &str, policy: TruncationPolicy) -> (String, Option) { + if s.is_empty() { + return (String::new(), None); } - let output = truncate_formatted_exec_output(content, total_lines, limit_bytes, limit_lines); - format!("Total output lines: {total_lines}\n\n{output}") -} - -fn truncate_formatted_exec_output( - content: &str, - total_lines: usize, - limit_bytes: usize, - limit_lines: usize, -) -> String { - error_on_double_truncation(content); - let head_lines: usize = limit_lines / 2; - let tail_lines: usize = limit_lines - head_lines; // 128 - let head_bytes: usize = limit_bytes / 2; - let segments: Vec<&str> = content.split_inclusive('\n').collect(); - let head_take = head_lines.min(segments.len()); - let tail_take = tail_lines.min(segments.len().saturating_sub(head_take)); - let omitted = segments.len().saturating_sub(head_take + tail_take); - - let head_slice_end: usize = segments - .iter() - .take(head_take) - .map(|segment| segment.len()) - .sum(); - let tail_slice_start: usize = if tail_take == 0 { - content.len() - } else { - content.len() - - segments - .iter() - .rev() - .take(tail_take) - .map(|segment| segment.len()) - .sum::() - }; - let head_slice = &content[..head_slice_end]; - let tail_slice = &content[tail_slice_start..]; - 
let truncated_by_bytes = content.len() > limit_bytes; - // this is a bit wrong. We are counting metadata lines and not just shell output lines. - let marker = if omitted > 0 { - Some(format!( - "\n[... omitted {omitted} of {total_lines} lines ...]\n\n" - )) - } else if truncated_by_bytes { - Some(format!( - "\n[... output truncated to fit {limit_bytes} bytes ...]\n\n" - )) - } else { - None - }; - - let marker_len = marker.as_ref().map_or(0, String::len); - let base_head_budget = head_bytes.min(limit_bytes); - let head_budget = base_head_budget.min(limit_bytes.saturating_sub(marker_len)); - let head_part = take_bytes_at_char_boundary(head_slice, head_budget); - let mut result = String::with_capacity(limit_bytes.min(content.len())); + let max_tokens = policy.token_budget(); - result.push_str(head_part); - if let Some(marker_text) = marker.as_ref() { - result.push_str(marker_text); + let byte_len = s.len(); + if max_tokens > 0 && byte_len <= approx_bytes_for_tokens(max_tokens) { + return (s.to_string(), None); } - let remaining = limit_bytes.saturating_sub(result.len()); - if remaining == 0 { - return result; + let truncated = truncate_with_byte_estimate(s, policy); + let approx_total_usize = approx_token_count(s); + let approx_total = u64::try_from(approx_total_usize).unwrap_or(u64::MAX); + if truncated == s { + (truncated, None) + } else { + (truncated, Some(approx_total)) } +} - let tail_part = take_last_bytes_at_char_boundary(tail_slice, remaining); - result.push_str(tail_part); +/// Truncate a string using a byte budget derived from the token budget, without +/// performing any real tokenization. This keeps the logic purely byte-based and +/// uses a bytes placeholder in the truncated output. 
+fn truncate_with_byte_estimate(s: &str, policy: TruncationPolicy) -> String { + if s.is_empty() { + return String::new(); + } - result -} + let total_chars = s.chars().count(); + let max_bytes = policy.byte_budget(); -fn error_on_double_truncation(content: &str) { - if content.contains("Total output lines:") && content.contains("omitted") { - tracing::error!( - "FunctionCallOutput content was already truncated before ContextManager::record_items; this would cause double truncation {content}" + if max_bytes == 0 { + // No budget to show content; just report that everything was truncated. + let marker = format_truncation_marker( + policy, + removed_units_for_source(policy, s.len(), total_chars), ); + return marker; } -} - -/// Truncate an output string to a maximum number of “tokens”, where tokens are -/// approximated as individual `char`s. Preserves a prefix and suffix with an -/// elision marker describing how many tokens were omitted. -pub(crate) fn truncate_output_to_tokens( - output: &str, - max_tokens: usize, -) -> (String, Option) { - if max_tokens == 0 { - let total_tokens = output.chars().count(); - let message = format!("…{total_tokens} tokens truncated…"); - return (message, Some(total_tokens)); - } - - let tokens: Vec = output.chars().collect(); - let total_tokens = tokens.len(); - if total_tokens <= max_tokens { - return (output.to_string(), None); - } - - let half = max_tokens / 2; - if half == 0 { - let truncated = total_tokens.saturating_sub(max_tokens); - let message = format!("…{truncated} tokens truncated…"); - return (message, Some(total_tokens)); - } - - let truncated = total_tokens.saturating_sub(half * 2); - let mut truncated_output = String::new(); - truncated_output.extend(&tokens[..half]); - truncated_output.push_str(&format!("…{truncated} tokens truncated…")); - truncated_output.extend(&tokens[total_tokens - half..]); - (truncated_output, Some(total_tokens)) -} -/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes, -/// 
preserving the beginning and the end. Returns the possibly truncated -/// string and `Some(original_token_count)` (counted with the local tokenizer; -/// falls back to a 4-bytes-per-token estimate if the tokenizer cannot load) -/// if truncation occurred; otherwise returns the original string and `None`. -pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option) { if s.len() <= max_bytes { - return (s.to_string(), None); + return s.to_string(); } - // Build a tokenizer for counting (default to o200k_base; fall back to cl100k_base). - // If both fail, fall back to a 4-bytes-per-token estimate. - let tok = Tokenizer::try_default().ok(); - let token_count = |text: &str| -> u64 { - if let Some(ref t) = tok { - t.count(text) as u64 - } else { - (text.len() as u64).div_ceil(4) - } - }; + let total_bytes = s.len(); - let total_tokens = token_count(s); - if max_bytes == 0 { - return ( - format!("…{total_tokens} tokens truncated…"), - Some(total_tokens), - ); + let (left_budget, right_budget) = split_budget(max_bytes); + + let (removed_chars, left, right) = split_string(s, left_budget, right_budget); + + let marker = format_truncation_marker( + policy, + removed_units_for_source(policy, total_bytes.saturating_sub(max_bytes), removed_chars), + ); + + assemble_truncated_output(left, right, &marker) +} + +fn split_string(s: &str, beginning_bytes: usize, end_bytes: usize) -> (usize, &str, &str) { + if s.is_empty() { + return (0, "", ""); } - fn truncate_on_boundary(input: &str, max_len: usize) -> &str { - if input.len() <= max_len { - return input; + let len = s.len(); + let tail_start_target = len.saturating_sub(end_bytes); + let mut prefix_end = 0usize; + let mut suffix_start = len; + let mut removed_chars = 0usize; + let mut suffix_started = false; + + for (idx, ch) in s.char_indices() { + let char_end = idx + ch.len_utf8(); + if char_end <= beginning_bytes { + prefix_end = char_end; + continue; } - let mut end = max_len; - while end > 0 && 
!input.is_char_boundary(end) { - end -= 1; + + if idx >= tail_start_target { + if !suffix_started { + suffix_start = idx; + suffix_started = true; + } + continue; } - &input[..end] + + removed_chars = removed_chars.saturating_add(1); } - fn pick_prefix_end(s: &str, left_budget: usize) -> usize { - if let Some(head) = s.get(..left_budget) - && let Some(i) = head.rfind('\n') - { - return i + 1; - } - truncate_on_boundary(s, left_budget).len() + if suffix_start < prefix_end { + suffix_start = prefix_end; } - fn pick_suffix_start(s: &str, right_budget: usize) -> usize { - let start_tail = s.len().saturating_sub(right_budget); - if let Some(tail) = s.get(start_tail..) - && let Some(i) = tail.find('\n') - { - return start_tail + i + 1; - } + let before = &s[..prefix_end]; + let after = &s[suffix_start..]; - let mut idx = start_tail.min(s.len()); - while idx < s.len() && !s.is_char_boundary(idx) { - idx += 1; - } - idx - } - - // Iterate to stabilize marker length → keep budget → boundaries. - let mut guess_tokens: u64 = 1; - for _ in 0..4 { - let marker = format!("…{guess_tokens} tokens truncated…"); - let marker_len = marker.len(); - let keep_budget = max_bytes.saturating_sub(marker_len); - if keep_budget == 0 { - return ( - format!("…{total_tokens} tokens truncated…"), - Some(total_tokens), - ); - } + (removed_chars, before, after) +} - let left_budget = keep_budget / 2; - let right_budget = keep_budget - left_budget; - let prefix_end = pick_prefix_end(s, left_budget); - let mut suffix_start = pick_suffix_start(s, right_budget); - if suffix_start < prefix_end { - suffix_start = prefix_end; - } +fn format_truncation_marker(policy: TruncationPolicy, removed_count: u64) -> String { + match policy { + TruncationPolicy::Tokens(_) => format!("…{removed_count} tokens truncated…"), + TruncationPolicy::Bytes(_) => format!("…{removed_count} chars truncated…"), + } +} - // Tokens actually removed (middle slice) using the real tokenizer. 
- let removed_tokens = token_count(&s[prefix_end..suffix_start]); - - // If the number of digits in the token count does not change the marker length, - // we can finalize output. - let final_marker = format!("…{removed_tokens} tokens truncated…"); - if final_marker.len() == marker_len { - let kept_content_bytes = prefix_end + (s.len() - suffix_start); - let mut out = String::with_capacity(final_marker.len() + kept_content_bytes + 1); - out.push_str(&s[..prefix_end]); - out.push_str(&final_marker); - out.push('\n'); - out.push_str(&s[suffix_start..]); - return (out, Some(total_tokens)); - } +fn split_budget(budget: usize) -> (usize, usize) { + let left = budget / 2; + (left, budget - left) +} - guess_tokens = removed_tokens; +fn removed_units_for_source( + policy: TruncationPolicy, + removed_bytes: usize, + removed_chars: usize, +) -> u64 { + match policy { + TruncationPolicy::Tokens(_) => approx_tokens_from_byte_count(removed_bytes), + TruncationPolicy::Bytes(_) => u64::try_from(removed_chars).unwrap_or(u64::MAX), } +} - // Fallback build after iterations: compute with the last guess. 
- let marker = format!("…{guess_tokens} tokens truncated…"); - let marker_len = marker.len(); - let keep_budget = max_bytes.saturating_sub(marker_len); - if keep_budget == 0 { - return ( - format!("…{total_tokens} tokens truncated…"), - Some(total_tokens), - ); - } +fn assemble_truncated_output(prefix: &str, suffix: &str, marker: &str) -> String { + let mut out = String::with_capacity(prefix.len() + marker.len() + suffix.len() + 1); + out.push_str(prefix); + out.push_str(marker); + out.push_str(suffix); + out +} - let left_budget = keep_budget / 2; - let right_budget = keep_budget - left_budget; - let prefix_end = pick_prefix_end(s, left_budget); - let mut suffix_start = pick_suffix_start(s, right_budget); - if suffix_start < prefix_end { - suffix_start = prefix_end; - } +pub(crate) fn approx_token_count(text: &str) -> usize { + let len = text.len(); + len.saturating_add(APPROX_BYTES_PER_TOKEN.saturating_sub(1)) / APPROX_BYTES_PER_TOKEN +} + +fn approx_bytes_for_tokens(tokens: usize) -> usize { + tokens.saturating_mul(APPROX_BYTES_PER_TOKEN) +} - let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1); - out.push_str(&s[..prefix_end]); - out.push_str(&marker); - out.push('\n'); - out.push_str(&s[suffix_start..]); - (out, Some(total_tokens)) +fn approx_tokens_from_byte_count(bytes: usize) -> u64 { + let bytes_u64 = bytes as u64; + bytes_u64.saturating_add((APPROX_BYTES_PER_TOKEN as u64).saturating_sub(1)) + / (APPROX_BYTES_PER_TOKEN as u64) } #[cfg(test)] mod tests { - use super::MODEL_FORMAT_MAX_BYTES; - use super::MODEL_FORMAT_MAX_LINES; - use super::format_output_for_model_body; - use super::globally_truncate_function_output_items; - use super::truncate_middle; - use super::truncate_output_to_tokens; + + use super::TruncationPolicy; + use super::approx_token_count; + use super::formatted_truncate_text; + use super::split_string; + use super::truncate_function_output_items_with_policy; + use super::truncate_text; + use 
super::truncate_with_token_budget; use codex_protocol::models::FunctionCallOutputContentItem; - use codex_utils_tokenizer::Tokenizer; use pretty_assertions::assert_eq; - use regex_lite::Regex; - - fn truncated_message_pattern(line: &str, total_lines: usize) -> String { - let head_lines = MODEL_FORMAT_MAX_LINES / 2; - let tail_lines = MODEL_FORMAT_MAX_LINES - head_lines; - let head_take = head_lines.min(total_lines); - let tail_take = tail_lines.min(total_lines.saturating_sub(head_take)); - let omitted = total_lines.saturating_sub(head_take + tail_take); - let escaped_line = regex_lite::escape(line); - if omitted == 0 { - return format!( - r"(?s)^Total output lines: {total_lines}\n\n(?P{escaped_line}.*\n\[\.{{3}} output truncated to fit {MODEL_FORMAT_MAX_BYTES} bytes \.{{3}}]\n\n.*)$", - ); - } - format!( - r"(?s)^Total output lines: {total_lines}\n\n(?P{escaped_line}.*\n\[\.{{3}} omitted {omitted} of {total_lines} lines \.{{3}}]\n\n.*)$", - ) - } #[test] - fn truncate_middle_no_newlines_fallback() { - let tok = Tokenizer::try_default().expect("load tokenizer"); - let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*"; - let max_bytes = 32; - let (out, original) = truncate_middle(s, max_bytes); - assert!(out.starts_with("abc")); - assert!(out.contains("tokens truncated")); - assert!(out.ends_with("XYZ*")); - assert_eq!(original, Some(tok.count(s) as u64)); + fn split_string_works() { + assert_eq!(split_string("hello world", 5, 5), (1, "hello", "world")); + assert_eq!(split_string("abc", 0, 0), (3, "", "")); } #[test] - fn truncate_middle_prefers_newline_boundaries() { - let tok = Tokenizer::try_default().expect("load tokenizer"); - let mut s = String::new(); - for i in 1..=20 { - s.push_str(&format!("{i:03}\n")); - } - assert_eq!(s.len(), 80); + fn split_string_handles_empty_string() { + assert_eq!(split_string("", 4, 4), (0, "", "")); + } - let max_bytes = 64; - let (out, tokens) = 
truncate_middle(&s, max_bytes); - assert!(out.starts_with("001\n002\n003\n004\n")); - assert!(out.contains("tokens truncated")); - assert!(out.ends_with("017\n018\n019\n020\n")); - assert_eq!(tokens, Some(tok.count(&s) as u64)); + #[test] + fn split_string_only_keeps_prefix_when_tail_budget_is_zero() { + assert_eq!(split_string("abcdef", 3, 0), (3, "abc", "")); } #[test] - fn truncate_middle_handles_utf8_content() { - let tok = Tokenizer::try_default().expect("load tokenizer"); - let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n"; - let max_bytes = 32; - let (out, tokens) = truncate_middle(s, max_bytes); + fn split_string_only_keeps_suffix_when_prefix_budget_is_zero() { + assert_eq!(split_string("abcdef", 0, 3), (3, "", "def")); + } - assert!(out.contains("tokens truncated")); - assert!(!out.contains('\u{fffd}')); - assert_eq!(tokens, Some(tok.count(s) as u64)); + #[test] + fn split_string_handles_overlapping_budgets_without_removal() { + assert_eq!(split_string("abcdef", 4, 4), (0, "abcd", "ef")); } #[test] - fn truncate_middle_prefers_newline_boundaries_2() { - let tok = Tokenizer::try_default().expect("load tokenizer"); - // Build a multi-line string of 20 numbered lines (each "NNN\n"). 
- let mut s = String::new(); - for i in 1..=20 { - s.push_str(&format!("{i:03}\n")); - } - assert_eq!(s.len(), 80); + fn split_string_respects_utf8_boundaries() { + assert_eq!(split_string("😀abc😀", 5, 5), (1, "😀a", "c😀")); - let max_bytes = 64; - let (out, total) = truncate_middle(&s, max_bytes); - assert!(out.starts_with("001\n002\n003\n004\n")); - assert!(out.contains("tokens truncated")); - assert!(out.ends_with("017\n018\n019\n020\n")); - assert_eq!(total, Some(tok.count(&s) as u64)); + assert_eq!(split_string("😀😀😀😀😀", 1, 1), (5, "", "")); + assert_eq!(split_string("😀😀😀😀😀", 7, 7), (3, "😀", "😀")); + assert_eq!(split_string("😀😀😀😀😀", 8, 8), (1, "😀😀", "😀😀")); } #[test] - fn truncate_output_to_tokens_returns_original_when_under_limit() { - let s = "short output"; - let (truncated, original) = truncate_output_to_tokens(s, 100); - assert_eq!(truncated, s); - assert_eq!(original, None); + fn truncate_bytes_less_than_placeholder_returns_placeholder() { + let content = "example output"; + + assert_eq!( + "Total output lines: 1\n\n…13 chars truncated…t", + formatted_truncate_text(content, TruncationPolicy::Bytes(1)), + ); } #[test] - fn truncate_output_to_tokens_reports_truncation_at_zero_limit() { - let s = "abcdef"; - let (truncated, original) = truncate_output_to_tokens(s, 0); - assert!(truncated.contains("tokens truncated")); - assert_eq!(original, Some(s.chars().count())); + fn truncate_tokens_less_than_placeholder_returns_placeholder() { + let content = "example output"; + + assert_eq!( + "Total output lines: 1\n\nex…3 tokens truncated…ut", + formatted_truncate_text(content, TruncationPolicy::Tokens(1)), + ); } #[test] - fn truncate_output_to_tokens_preserves_prefix_and_suffix() { - let s = "abcdefghijklmnopqrstuvwxyz"; - let max_tokens = 10; - let (truncated, original) = truncate_output_to_tokens(s, max_tokens); - assert!(truncated.starts_with("abcde")); - assert!(truncated.ends_with("vwxyz")); - assert_eq!(original, Some(s.chars().count())); + fn 
truncate_tokens_under_limit_returns_original() { + let content = "example output"; + + assert_eq!( + content, + formatted_truncate_text(content, TruncationPolicy::Tokens(10)), + ); } #[test] - fn format_exec_output_truncates_large_error() { - let line = "very long execution error line that should trigger truncation\n"; - let large_error = line.repeat(2_500); // way beyond both byte and line limits - - let truncated = format_output_for_model_body( - &large_error, - MODEL_FORMAT_MAX_BYTES, - MODEL_FORMAT_MAX_LINES, - ); + fn truncate_bytes_under_limit_returns_original() { + let content = "example output"; - let total_lines = large_error.lines().count(); - let pattern = truncated_message_pattern(line, total_lines); - let regex = Regex::new(&pattern).unwrap_or_else(|err| { - panic!("failed to compile regex {pattern}: {err}"); - }); - let captures = regex - .captures(&truncated) - .unwrap_or_else(|| panic!("message failed to match pattern {pattern}: {truncated}")); - let body = captures - .name("body") - .expect("missing body capture") - .as_str(); - assert!( - body.len() <= MODEL_FORMAT_MAX_BYTES, - "body exceeds byte limit: {} bytes", - body.len() + assert_eq!( + content, + formatted_truncate_text(content, TruncationPolicy::Bytes(20)), ); - assert_ne!(truncated, large_error); } #[test] - fn format_exec_output_marks_byte_truncation_without_omitted_lines() { - let long_line = "a".repeat(MODEL_FORMAT_MAX_BYTES + 50); - let truncated = format_output_for_model_body( - &long_line, - MODEL_FORMAT_MAX_BYTES, - MODEL_FORMAT_MAX_LINES, - ); + fn truncate_tokens_over_limit_returns_truncated() { + let content = "this is an example of a long output that should be truncated"; - assert_ne!(truncated, long_line); - let marker_line = - format!("[... 
output truncated to fit {MODEL_FORMAT_MAX_BYTES} bytes ...]"); - assert!( - truncated.contains(&marker_line), - "missing byte truncation marker: {truncated}" - ); - assert!( - !truncated.contains("omitted"), - "line omission marker should not appear when no lines were dropped: {truncated}" + assert_eq!( + "Total output lines: 1\n\nthis is an…10 tokens truncated… truncated", + formatted_truncate_text(content, TruncationPolicy::Tokens(5)), ); } #[test] - fn format_exec_output_returns_original_when_within_limits() { - let content = "example output\n".repeat(10); + fn truncate_bytes_over_limit_returns_truncated() { + let content = "this is an example of a long output that should be truncated"; assert_eq!( - format_output_for_model_body(&content, MODEL_FORMAT_MAX_BYTES, MODEL_FORMAT_MAX_LINES), - content + "Total output lines: 1\n\nthis is an exam…30 chars truncated…ld be truncated", + formatted_truncate_text(content, TruncationPolicy::Bytes(30)), ); } #[test] - fn format_exec_output_reports_omitted_lines_and_keeps_head_and_tail() { - let total_lines = MODEL_FORMAT_MAX_LINES + 100; - let content: String = (0..total_lines) - .map(|idx| format!("line-{idx}\n")) - .collect(); + fn truncate_bytes_reports_original_line_count_when_truncated() { + let content = + "this is an example of a long output that should be truncated\nalso some other line"; - let truncated = - format_output_for_model_body(&content, MODEL_FORMAT_MAX_BYTES, MODEL_FORMAT_MAX_LINES); + assert_eq!( + "Total output lines: 2\n\nthis is an exam…51 chars truncated…some other line", + formatted_truncate_text(content, TruncationPolicy::Bytes(30)), + ); + } - let omitted = total_lines - MODEL_FORMAT_MAX_LINES; - let expected_marker = format!("[... 
omitted {omitted} of {total_lines} lines ...]"); + #[test] + fn truncate_tokens_reports_original_line_count_when_truncated() { + let content = + "this is an example of a long output that should be truncated\nalso some other line"; - assert!( - truncated.contains(&expected_marker), - "missing omitted marker: {truncated}" - ); - assert!( - truncated.contains("line-0\n"), - "expected head line to remain: {truncated}" + assert_eq!( + "Total output lines: 2\n\nthis is an example o…11 tokens truncated…also some other line", + formatted_truncate_text(content, TruncationPolicy::Tokens(10)), ); + } - let last_line = format!("line-{}\n", total_lines - 1); - assert!( - truncated.contains(&last_line), - "expected tail line to remain: {truncated}" - ); + #[test] + fn truncate_with_token_budget_returns_original_when_under_limit() { + let s = "short output"; + let limit = 100; + let (out, original) = truncate_with_token_budget(s, TruncationPolicy::Tokens(limit)); + assert_eq!(out, s); + assert_eq!(original, None); } #[test] - fn format_exec_output_prefers_line_marker_when_both_limits_exceeded() { - let total_lines = MODEL_FORMAT_MAX_LINES + 42; - let long_line = "x".repeat(256); - let content: String = (0..total_lines) - .map(|idx| format!("line-{idx}-{long_line}\n")) - .collect(); + fn truncate_with_token_budget_reports_truncation_at_zero_limit() { + let s = "abcdef"; + let (out, original) = truncate_with_token_budget(s, TruncationPolicy::Tokens(0)); + assert_eq!(out, "…2 tokens truncated…"); + assert_eq!(original, Some(2)); + } - let truncated = - format_output_for_model_body(&content, MODEL_FORMAT_MAX_BYTES, MODEL_FORMAT_MAX_LINES); + #[test] + fn truncate_middle_tokens_handles_utf8_content() { + let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with text\n"; + let (out, tokens) = truncate_with_token_budget(s, TruncationPolicy::Tokens(8)); + assert_eq!(out, "😀😀😀😀…8 tokens truncated… line with text\n"); + assert_eq!(tokens, Some(16)); + } - assert!( - truncated.contains("[... 
omitted 42 of 298 lines ...]"), - "expected omitted marker when line count exceeds limit: {truncated}" - ); - assert!( - !truncated.contains("output truncated to fit"), - "line omission marker should take precedence over byte marker: {truncated}" - ); + #[test] + fn truncate_middle_bytes_handles_utf8_content() { + let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with text\n"; + let out = truncate_text(s, TruncationPolicy::Bytes(20)); + assert_eq!(out, "😀😀…21 chars truncated…with text\n"); } #[test] fn truncates_across_multiple_under_limit_texts_and_reports_omitted() { - // Arrange: several text items, none exceeding per-item limit, but total exceeds budget. - let budget = MODEL_FORMAT_MAX_BYTES; - let t1_len = (budget / 2).saturating_sub(10); - let t2_len = (budget / 2).saturating_sub(10); - let remaining_after_t1_t2 = budget.saturating_sub(t1_len + t2_len); - let t3_len = 50; // gets truncated to remaining_after_t1_t2 - let t4_len = 5; // omitted - let t5_len = 7; // omitted - - let t1 = "a".repeat(t1_len); - let t2 = "b".repeat(t2_len); - let t3 = "c".repeat(t3_len); - let t4 = "d".repeat(t4_len); - let t5 = "e".repeat(t5_len); + let chunk = "alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega.\n"; + let chunk_tokens = approx_token_count(chunk); + assert!(chunk_tokens > 0, "chunk must consume tokens"); + let limit = chunk_tokens * 3; + let t1 = chunk.to_string(); + let t2 = chunk.to_string(); + let t3 = chunk.repeat(10); + let t4 = chunk.to_string(); + let t5 = chunk.to_string(); let items = vec![ - FunctionCallOutputContentItem::InputText { text: t1 }, - FunctionCallOutputContentItem::InputText { text: t2 }, + FunctionCallOutputContentItem::InputText { text: t1.clone() }, + FunctionCallOutputContentItem::InputText { text: t2.clone() }, FunctionCallOutputContentItem::InputImage { image_url: "img:mid".to_string(), }, @@ -576,7 +487,8 @@ mod tests { FunctionCallOutputContentItem::InputText { text: t5 }, ]; - 
let output = globally_truncate_function_output_items(&items); + let output = + truncate_function_output_items_with_policy(&items, TruncationPolicy::Tokens(limit)); // Expect: t1 (full), t2 (full), image, t3 (truncated), summary mentioning 2 omitted. assert_eq!(output.len(), 5); @@ -585,13 +497,13 @@ mod tests { FunctionCallOutputContentItem::InputText { text } => text, other => panic!("unexpected first item: {other:?}"), }; - assert_eq!(first_text.len(), t1_len); + assert_eq!(first_text, &t1); let second_text = match &output[1] { FunctionCallOutputContentItem::InputText { text } => text, other => panic!("unexpected second item: {other:?}"), }; - assert_eq!(second_text.len(), t2_len); + assert_eq!(second_text, &t2); assert_eq!( output[2], @@ -604,7 +516,10 @@ mod tests { FunctionCallOutputContentItem::InputText { text } => text, other => panic!("unexpected fourth item: {other:?}"), }; - assert_eq!(fourth_text.len(), remaining_after_t1_t2); + assert!( + fourth_text.contains("tokens truncated"), + "expected marker in truncated snippet: {fourth_text}" + ); let summary_text = match &output[4] { FunctionCallOutputContentItem::InputText { text } => text, diff --git a/codex-rs/core/src/unified_exec/mod.rs b/codex-rs/core/src/unified_exec/mod.rs index f77b74449..390401d78 100644 --- a/codex-rs/core/src/unified_exec/mod.rs +++ b/codex-rs/core/src/unified_exec/mod.rs @@ -45,6 +45,7 @@ pub(crate) const MIN_YIELD_TIME_MS: u64 = 250; pub(crate) const MAX_YIELD_TIME_MS: u64 = 30_000; pub(crate) const DEFAULT_MAX_OUTPUT_TOKENS: usize = 10_000; pub(crate) const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 1024 * 1024; // 1 MiB +pub(crate) const UNIFIED_EXEC_OUTPUT_MAX_TOKENS: usize = UNIFIED_EXEC_OUTPUT_MAX_BYTES / 4; pub(crate) struct UnifiedExecContext { pub session: Arc, diff --git a/codex-rs/core/src/unified_exec/session.rs b/codex-rs/core/src/unified_exec/session.rs index bdb935f17..710334c80 100644 --- a/codex-rs/core/src/unified_exec/session.rs +++ 
b/codex-rs/core/src/unified_exec/session.rs @@ -2,23 +2,25 @@ use std::collections::VecDeque; use std::sync::Arc; - use tokio::sync::Mutex; use tokio::sync::Notify; use tokio::sync::mpsc; use tokio::sync::oneshot::error::TryRecvError; use tokio::task::JoinHandle; use tokio::time::Duration; +use tokio_util::sync::CancellationToken; use crate::exec::ExecToolCallOutput; use crate::exec::SandboxType; use crate::exec::StreamOutput; use crate::exec::is_likely_sandbox_denied; -use crate::truncate::truncate_middle; +use crate::truncate::TruncationPolicy; +use crate::truncate::formatted_truncate_text; use codex_utils_pty::ExecCommandSession; use codex_utils_pty::SpawnedPty; use super::UNIFIED_EXEC_OUTPUT_MAX_BYTES; +use super::UNIFIED_EXEC_OUTPUT_MAX_TOKENS; use super::UnifiedExecError; #[derive(Debug, Default)] @@ -65,13 +67,18 @@ impl OutputBufferState { } pub(crate) type OutputBuffer = Arc>; -pub(crate) type OutputHandles = (OutputBuffer, Arc); +pub(crate) struct OutputHandles { + pub(crate) output_buffer: OutputBuffer, + pub(crate) output_notify: Arc, + pub(crate) cancellation_token: CancellationToken, +} #[derive(Debug)] pub(crate) struct UnifiedExecSession { session: ExecCommandSession, output_buffer: OutputBuffer, output_notify: Arc, + cancellation_token: CancellationToken, output_task: JoinHandle<()>, sandbox_type: SandboxType, } @@ -84,9 +91,11 @@ impl UnifiedExecSession { ) -> Self { let output_buffer = Arc::new(Mutex::new(OutputBufferState::default())); let output_notify = Arc::new(Notify::new()); + let cancellation_token = CancellationToken::new(); let mut receiver = initial_output_rx; let buffer_clone = Arc::clone(&output_buffer); let notify_clone = Arc::clone(&output_notify); + let cancellation_token_clone = cancellation_token.clone(); let output_task = tokio::spawn(async move { loop { match receiver.recv().await { @@ -97,7 +106,10 @@ impl UnifiedExecSession { notify_clone.notify_waiters(); } Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => 
continue, - Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + cancellation_token_clone.cancel(); + break; + } } } }); @@ -106,6 +118,7 @@ impl UnifiedExecSession { session, output_buffer, output_notify, + cancellation_token, output_task, sandbox_type, } @@ -116,10 +129,11 @@ impl UnifiedExecSession { } pub(super) fn output_handles(&self) -> OutputHandles { - ( - Arc::clone(&self.output_buffer), - Arc::clone(&self.output_notify), - ) + OutputHandles { + output_buffer: Arc::clone(&self.output_buffer), + output_notify: Arc::clone(&self.output_notify), + cancellation_token: self.cancellation_token.clone(), + } } pub(super) fn has_exited(&self) -> bool { @@ -165,7 +179,10 @@ impl UnifiedExecSession { }; if is_likely_sandbox_denied(self.sandbox_type(), &exec_output) { - let (snippet, _) = truncate_middle(&aggregated_text, UNIFIED_EXEC_OUTPUT_MAX_BYTES); + let snippet = formatted_truncate_text( + &aggregated_text, + TruncationPolicy::Tokens(UNIFIED_EXEC_OUTPUT_MAX_TOKENS), + ); let message = if snippet.is_empty() { format!("exit code {exit_code}") } else { @@ -194,20 +211,34 @@ impl UnifiedExecSession { }; if exit_ready { + managed.signal_exit(); managed.check_for_sandbox_denial().await?; return Ok(managed); } - tokio::pin!(exit_rx); if tokio::time::timeout(Duration::from_millis(50), &mut exit_rx) .await .is_ok() { + managed.signal_exit(); managed.check_for_sandbox_denial().await?; + return Ok(managed); } + tokio::spawn({ + let cancellation_token = managed.cancellation_token.clone(); + async move { + let _ = exit_rx.await; + cancellation_token.cancel(); + } + }); + Ok(managed) } + + fn signal_exit(&self) { + self.cancellation_token.cancel(); + } } impl Drop for UnifiedExecSession { diff --git a/codex-rs/core/src/unified_exec/session_manager.rs b/codex-rs/core/src/unified_exec/session_manager.rs index fee46df8b..d9f99b9ea 100644 --- a/codex-rs/core/src/unified_exec/session_manager.rs +++ 
b/codex-rs/core/src/unified_exec/session_manager.rs @@ -5,16 +5,19 @@ use tokio::sync::Notify; use tokio::sync::mpsc; use tokio::time::Duration; use tokio::time::Instant; +use tokio_util::sync::CancellationToken; use crate::codex::Session; use crate::codex::TurnContext; use crate::exec::ExecToolCallOutput; use crate::exec::StreamOutput; use crate::exec_env::create_env; +use crate::exec_policy::create_approval_requirement_for_command; use crate::protocol::BackgroundEventEvent; use crate::protocol::EventMsg; use crate::protocol::ExecCommandSource; use crate::sandboxing::ExecEnv; +use crate::sandboxing::SandboxPermissions; use crate::tools::events::ToolEmitter; use crate::tools::events::ToolEventCtx; use crate::tools::events::ToolEventFailure; @@ -23,6 +26,9 @@ use crate::tools::orchestrator::ToolOrchestrator; use crate::tools::runtimes::unified_exec::UnifiedExecRequest as UnifiedExecToolRequest; use crate::tools::runtimes::unified_exec::UnifiedExecRuntime; use crate::tools::sandboxing::ToolCtx; +use crate::truncate::TruncationPolicy; +use crate::truncate::approx_token_count; +use crate::truncate::formatted_truncate_text; use super::ExecCommandRequest; use super::SessionEntry; @@ -35,8 +41,19 @@ use super::clamp_yield_time; use super::generate_chunk_id; use super::resolve_max_tokens; use super::session::OutputBuffer; +use super::session::OutputHandles; use super::session::UnifiedExecSession; -use crate::truncate::truncate_output_to_tokens; + +struct PreparedSessionHandles { + writer_tx: mpsc::Sender>, + output_buffer: OutputBuffer, + output_notify: Arc, + cancellation_token: CancellationToken, + session_ref: Arc, + turn_ref: Arc, + command: Vec, + cwd: PathBuf, +} impl UnifiedExecSessionManager { pub(crate) async fn exec_command( @@ -63,14 +80,23 @@ impl UnifiedExecSessionManager { let yield_time_ms = clamp_yield_time(request.yield_time_ms); let start = Instant::now(); - let (output_buffer, output_notify) = session.output_handles(); + let OutputHandles { + 
output_buffer, + output_notify, + cancellation_token, + } = session.output_handles(); let deadline = start + Duration::from_millis(yield_time_ms); - let collected = - Self::collect_output_until_deadline(&output_buffer, &output_notify, deadline).await; + let collected = Self::collect_output_until_deadline( + &output_buffer, + &output_notify, + &cancellation_token, + deadline, + ) + .await; let wall_time = Instant::now().saturating_duration_since(start); let text = String::from_utf8_lossy(&collected).to_string(); - let (output, original_token_count) = truncate_output_to_tokens(&text, max_tokens); + let output = formatted_truncate_text(&text, TruncationPolicy::Tokens(max_tokens)); let chunk_id = generate_chunk_id(); let has_exited = session.has_exited(); let stored_id = self @@ -85,6 +111,8 @@ impl UnifiedExecSessionManager { // Only include a session_id in the response if the process is still alive. let session_id = if has_exited { None } else { Some(stored_id) }; + let original_token_count = approx_token_count(&text); + let response = UnifiedExecResponse { event_call_id: context.call_id.clone(), chunk_id, @@ -92,7 +120,7 @@ impl UnifiedExecSessionManager { output, session_id, exit_code: exit_code.flatten(), - original_token_count, + original_token_count: Some(original_token_count), session_command: Some(request.command.clone()), }; @@ -123,15 +151,16 @@ impl UnifiedExecSessionManager { ) -> Result { let session_id = request.session_id; - let ( + let PreparedSessionHandles { writer_tx, output_buffer, output_notify, + cancellation_token, session_ref, turn_ref, - session_command, - session_cwd, - ) = self.prepare_session_handles(session_id).await?; + command: session_command, + cwd: session_cwd, + } = self.prepare_session_handles(session_id).await?; let interaction_emitter = ToolEmitter::unified_exec( &session_command, @@ -170,12 +199,18 @@ impl UnifiedExecSessionManager { let yield_time_ms = clamp_yield_time(request.yield_time_ms); let start = Instant::now(); let 
deadline = start + Duration::from_millis(yield_time_ms); - let collected = - Self::collect_output_until_deadline(&output_buffer, &output_notify, deadline).await; + let collected = Self::collect_output_until_deadline( + &output_buffer, + &output_notify, + &cancellation_token, + deadline, + ) + .await; let wall_time = Instant::now().saturating_duration_since(start); let text = String::from_utf8_lossy(&collected).to_string(); - let (output, original_token_count) = truncate_output_to_tokens(&text, max_tokens); + let output = formatted_truncate_text(&text, TruncationPolicy::Tokens(max_tokens)); + let original_token_count = approx_token_count(&text); let chunk_id = generate_chunk_id(); let status = self.refresh_session_state(session_id).await; @@ -199,7 +234,7 @@ impl UnifiedExecSessionManager { output, session_id, exit_code, - original_token_count, + original_token_count: Some(original_token_count), session_command: Some(session_command.clone()), }; @@ -258,44 +293,27 @@ impl UnifiedExecSessionManager { async fn prepare_session_handles( &self, session_id: i32, - ) -> Result< - ( - mpsc::Sender>, - OutputBuffer, - Arc, - Arc, - Arc, - Vec, - PathBuf, - ), - UnifiedExecError, - > { + ) -> Result { let sessions = self.sessions.lock().await; - let (output_buffer, output_notify, writer_tx, session, turn, command, cwd) = - if let Some(entry) = sessions.get(&session_id) { - let (buffer, notify) = entry.session.output_handles(); - ( - buffer, - notify, - entry.session.writer_sender(), - Arc::clone(&entry.session_ref), - Arc::clone(&entry.turn_ref), - entry.command.clone(), - entry.cwd.clone(), - ) - } else { - return Err(UnifiedExecError::UnknownSessionId { session_id }); - }; + let entry = sessions + .get(&session_id) + .ok_or(UnifiedExecError::UnknownSessionId { session_id })?; + let OutputHandles { + output_buffer, + output_notify, + cancellation_token, + } = entry.session.output_handles(); - Ok(( - writer_tx, + Ok(PreparedSessionHandles { + writer_tx: 
entry.session.writer_sender(), output_buffer, output_notify, - session, - turn, - command, - cwd, - )) + cancellation_token, + session_ref: Arc::clone(&entry.session_ref), + turn_ref: Arc::clone(&entry.turn_ref), + command: entry.command.clone(), + cwd: entry.cwd.clone(), + }) } async fn send_input( @@ -444,6 +462,13 @@ impl UnifiedExecSessionManager { create_env(&context.turn.shell_environment_policy), with_escalated_permissions, justification, + create_approval_requirement_for_command( + &context.turn.exec_policy, + command, + context.turn.approval_policy, + &context.turn.sandbox_policy, + SandboxPermissions::from(with_escalated_permissions.unwrap_or(false)), + ), ); let tool_ctx = ToolCtx { session: context.session.as_ref(), @@ -466,9 +491,13 @@ impl UnifiedExecSessionManager { pub(super) async fn collect_output_until_deadline( output_buffer: &OutputBuffer, output_notify: &Arc, + cancellation_token: &CancellationToken, deadline: Instant, ) -> Vec { + const POST_EXIT_OUTPUT_GRACE: Duration = Duration::from_millis(25); + let mut collected: Vec = Vec::with_capacity(4096); + let mut exit_signal_received = cancellation_token.is_cancelled(); loop { let drained_chunks; let mut wait_for_output = None; @@ -481,15 +510,27 @@ impl UnifiedExecSessionManager { } if drained_chunks.is_empty() { + exit_signal_received |= cancellation_token.is_cancelled(); let remaining = deadline.saturating_duration_since(Instant::now()); if remaining == Duration::ZERO { break; } let notified = wait_for_output.unwrap_or_else(|| output_notify.notified()); + if exit_signal_received { + let grace = remaining.min(POST_EXIT_OUTPUT_GRACE); + if tokio::time::timeout(grace, notified).await.is_err() { + break; + } + continue; + } + tokio::pin!(notified); + let exit_notified = cancellation_token.cancelled(); + tokio::pin!(exit_notified); tokio::select! 
{ _ = &mut notified => {} + _ = &mut exit_notified => exit_signal_received = true, _ = tokio::time::sleep(remaining) => break, } continue; @@ -499,6 +540,7 @@ impl UnifiedExecSessionManager { collected.extend_from_slice(&chunk); } + exit_signal_received |= cancellation_token.is_cancelled(); if Instant::now() >= deadline { break; } diff --git a/codex-rs/core/src/user_shell_command.rs b/codex-rs/core/src/user_shell_command.rs index 7f0731c96..857e01c06 100644 --- a/codex-rs/core/src/user_shell_command.rs +++ b/codex-rs/core/src/user_shell_command.rs @@ -3,6 +3,7 @@ use std::time::Duration; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; +use crate::codex::TurnContext; use crate::exec::ExecToolCallOutput; use crate::tools::format_exec_output_str; @@ -20,7 +21,11 @@ fn format_duration_line(duration: Duration) -> String { format!("Duration: {duration_seconds:.4} seconds") } -fn format_user_shell_command_body(command: &str, exec_output: &ExecToolCallOutput) -> String { +fn format_user_shell_command_body( + command: &str, + exec_output: &ExecToolCallOutput, + turn_context: &TurnContext, +) -> String { let mut sections = Vec::new(); sections.push("".to_string()); sections.push(command.to_string()); @@ -29,25 +34,33 @@ fn format_user_shell_command_body(command: &str, exec_output: &ExecToolCallOutpu sections.push(format!("Exit code: {}", exec_output.exit_code)); sections.push(format_duration_line(exec_output.duration)); sections.push("Output:".to_string()); - sections.push(format_exec_output_str(exec_output)); + sections.push(format_exec_output_str( + exec_output, + turn_context.truncation_policy, + )); sections.push("".to_string()); sections.join("\n") } -pub fn format_user_shell_command_record(command: &str, exec_output: &ExecToolCallOutput) -> String { - let body = format_user_shell_command_body(command, exec_output); +pub fn format_user_shell_command_record( + command: &str, + exec_output: &ExecToolCallOutput, + turn_context: 
&TurnContext, +) -> String { + let body = format_user_shell_command_body(command, exec_output, turn_context); format!("{USER_SHELL_COMMAND_OPEN}\n{body}\n{USER_SHELL_COMMAND_CLOSE}") } pub fn user_shell_command_record_item( command: &str, exec_output: &ExecToolCallOutput, + turn_context: &TurnContext, ) -> ResponseItem { ResponseItem::Message { id: None, role: "user".to_string(), content: vec![ContentItem::InputText { - text: format_user_shell_command_record(command, exec_output), + text: format_user_shell_command_record(command, exec_output, turn_context), }], } } @@ -55,6 +68,7 @@ pub fn user_shell_command_record_item( #[cfg(test)] mod tests { use super::*; + use crate::codex::make_session_and_context; use crate::exec::StreamOutput; use pretty_assertions::assert_eq; @@ -76,7 +90,8 @@ mod tests { duration: Duration::from_secs(1), timed_out: false, }; - let item = user_shell_command_record_item("echo hi", &exec_output); + let (_, turn_context) = make_session_and_context(); + let item = user_shell_command_record_item("echo hi", &exec_output, &turn_context); let ResponseItem::Message { content, .. } = item else { panic!("expected message"); }; @@ -99,7 +114,8 @@ mod tests { duration: Duration::from_millis(120), timed_out: false, }; - let record = format_user_shell_command_record("false", &exec_output); + let (_, turn_context) = make_session_and_context(); + let record = format_user_shell_command_record("false", &exec_output, &turn_context); assert_eq!( record, "\n\nfalse\n\n\nExit code: 42\nDuration: 0.1200 seconds\nOutput:\ncombined output wins\n\n" diff --git a/codex-rs/core/templates/parallel/instructions.md b/codex-rs/core/templates/parallel/instructions.md index d690501af..292d585e4 100644 --- a/codex-rs/core/templates/parallel/instructions.md +++ b/codex-rs/core/templates/parallel/instructions.md @@ -1,3 +1,4 @@ + ## Exploration and reading files - **Think first.** Before any tool call, decide ALL files/resources you will need. 
@@ -10,5 +11,3 @@ * Always maximize parallelism. Never read files one-by-one unless logically unavoidable. * This concern every read/list/search operations including, but not only, `cat`, `rg`, `sed`, `ls`, `git show`, `nl`, `wc`, ... * Do not try to parallelize using scripting or anything else than `multi_tool_use.parallel`. - -## Editing constraints \ No newline at end of file diff --git a/codex-rs/core/tests/common/Cargo.toml b/codex-rs/core/tests/common/Cargo.toml index 65abe23c6..094f33a26 100644 --- a/codex-rs/core/tests/common/Cargo.toml +++ b/codex-rs/core/tests/common/Cargo.toml @@ -18,3 +18,4 @@ tempfile = { workspace = true } tokio = { workspace = true, features = ["time"] } walkdir = { workspace = true } wiremock = { workspace = true } +shlex = { workspace = true } diff --git a/codex-rs/core/tests/common/lib.rs b/codex-rs/core/tests/common/lib.rs index 3f75ed181..e7b1e71ef 100644 --- a/codex-rs/core/tests/common/lib.rs +++ b/codex-rs/core/tests/common/lib.rs @@ -172,6 +172,15 @@ pub fn sandbox_network_env_var() -> &'static str { codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR } +pub fn format_with_current_shell(command: &str) -> Vec { + codex_core::shell::default_user_shell().derive_exec_args(command, true) +} + +pub fn format_with_current_shell_display(command: &str) -> String { + let args = format_with_current_shell(command); + shlex::try_join(args.iter().map(String::as_str)).expect("serialize current shell command") +} + pub mod fs_wait { use anyhow::Result; use anyhow::anyhow; diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index 3ebb28355..b84e96639 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -460,6 +460,16 @@ pub fn ev_apply_patch_function_call(call_id: &str, patch: &str) -> Value { }) } +pub fn ev_shell_command_call(call_id: &str, command: &str) -> Value { + let args = serde_json::json!({ "command": command }); + 
ev_shell_command_call_with_args(call_id, &args) +} + +pub fn ev_shell_command_call_with_args(call_id: &str, args: &serde_json::Value) -> Value { + let arguments = serde_json::to_string(args).expect("serialize shell command arguments"); + ev_function_call(call_id, "shell_command", &arguments) +} + pub fn ev_apply_patch_shell_call(call_id: &str, patch: &str) -> Value { let args = serde_json::json!({ "command": ["apply_patch", patch] }); let arguments = serde_json::to_string(&args).expect("serialize apply_patch arguments"); diff --git a/codex-rs/core/tests/suite/abort_tasks.rs b/codex-rs/core/tests/suite/abort_tasks.rs index 177461984..0d4a807a3 100644 --- a/codex-rs/core/tests/suite/abort_tasks.rs +++ b/codex-rs/core/tests/suite/abort_tasks.rs @@ -17,15 +17,11 @@ use core_test_support::wait_for_event; use regex_lite::Regex; use serde_json::json; -/// Integration test: spawn a long‑running shell tool via a mocked Responses SSE +/// Integration test: spawn a long‑running shell_command tool via a mocked Responses SSE /// function call, then interrupt the session and expect TurnAborted. #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn interrupt_long_running_tool_emits_turn_aborted() { - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "sleep 60".to_string(), - ]; + let command = "sleep 60"; let args = json!({ "command": command, @@ -33,14 +29,19 @@ async fn interrupt_long_running_tool_emits_turn_aborted() { }) .to_string(); let body = sse(vec![ - ev_function_call("call_sleep", "shell", &args), + ev_function_call("call_sleep", "shell_command", &args), ev_completed("done"), ]); let server = start_mock_server().await; mount_sse_once(&server, body).await; - let codex = test_codex().build(&server).await.unwrap().codex; + let codex = test_codex() + .with_model("gpt-5.1") + .build(&server) + .await + .unwrap() + .codex; // Kick off a turn that triggers the function call. 
codex @@ -67,11 +68,7 @@ async fn interrupt_long_running_tool_emits_turn_aborted() { /// responses server, and ensures the model receives the synthesized abort. #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn interrupt_tool_records_history_entries() { - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "sleep 60".to_string(), - ]; + let command = "sleep 60"; let call_id = "call-history"; let args = json!({ @@ -81,7 +78,7 @@ async fn interrupt_tool_records_history_entries() { .to_string(); let first_body = sse(vec![ ev_response_created("resp-history"), - ev_function_call(call_id, "shell", &args), + ev_function_call(call_id, "shell_command", &args), ev_completed("resp-history"), ]); let follow_up_body = sse(vec![ @@ -92,7 +89,11 @@ async fn interrupt_tool_records_history_entries() { let server = start_mock_server().await; let response_mock = mount_sse_sequence(&server, vec![first_body, follow_up_body]).await; - let fixture = test_codex().build(&server).await.unwrap(); + let fixture = test_codex() + .with_model("gpt-5.1") + .build(&server) + .await + .unwrap(); let codex = Arc::clone(&fixture.codex); codex diff --git a/codex-rs/core/tests/suite/apply_patch_cli.rs b/codex-rs/core/tests/suite/apply_patch_cli.rs index 59ecf421e..596db602c 100644 --- a/codex-rs/core/tests/suite/apply_patch_cli.rs +++ b/codex-rs/core/tests/suite/apply_patch_cli.rs @@ -667,7 +667,7 @@ async fn apply_patch_cli_verification_failure_has_no_side_effects( } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn apply_patch_shell_heredoc_with_cd_updates_relative_workdir() -> Result<()> { +async fn apply_patch_shell_command_heredoc_with_cd_updates_relative_workdir() -> Result<()> { skip_if_no_network!(Ok(())); let harness = apply_patch_harness_with(|config| { @@ -684,14 +684,11 @@ async fn apply_patch_shell_heredoc_with_cd_updates_relative_workdir() -> Result< let script = "cd sub && apply_patch <<'EOF'\n*** Begin Patch\n*** Update File: 
in_sub.txt\n@@\n-before\n+after\n*** End Patch\nEOF\n"; let call_id = "shell-heredoc-cd"; - let args = json!({ - "command": ["bash", "-lc", script], - "timeout_ms": 5_000, - }); + let args = json!({ "command": script, "timeout_ms": 5_000 }); let bodies = vec![ sse(vec![ ev_response_created("resp-1"), - ev_function_call(call_id, "shell", &serde_json::to_string(&args)?), + ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), ev_completed("resp-1"), ]), sse(vec![ @@ -706,14 +703,14 @@ async fn apply_patch_shell_heredoc_with_cd_updates_relative_workdir() -> Result< let out = harness.function_call_stdout(call_id).await; assert!( out.contains("Success."), - "expected successful apply_patch invocation via shell: {out}" + "expected successful apply_patch invocation via shell_command: {out}" ); assert_eq!(fs::read_to_string(&target)?, "after\n"); Ok(()) } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn apply_patch_shell_failure_propagates_error_and_skips_diff() -> Result<()> { +async fn apply_patch_shell_command_failure_propagates_error_and_skips_diff() -> Result<()> { skip_if_no_network!(Ok(())); let harness = apply_patch_harness_with(|config| { @@ -730,14 +727,11 @@ async fn apply_patch_shell_failure_propagates_error_and_skips_diff() -> Result<( let script = "apply_patch <<'EOF'\n*** Begin Patch\n*** Update File: invalid.txt\n@@\n-nope\n+changed\n*** End Patch\nEOF\n"; let call_id = "shell-apply-failure"; - let args = json!({ - "command": ["bash", "-lc", script], - "timeout_ms": 5_000, - }); + let args = json!({ "command": script, "timeout_ms": 5_000 }); let bodies = vec![ sse(vec![ ev_response_created("resp-1"), - ev_function_call(call_id, "shell", &serde_json::to_string(&args)?), + ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), ev_completed("resp-1"), ]), sse(vec![ @@ -780,10 +774,6 @@ async fn apply_patch_shell_failure_propagates_error_and_skips_diff() -> Result<( ); let out = 
harness.function_call_stdout(call_id).await; - assert!( - out.contains("apply_patch verification failed"), - "expected verification failure message" - ); assert!( out.contains("Failed to find expected lines in"), "expected failure diagnostics: {out}" diff --git a/codex-rs/core/tests/suite/approvals.rs b/codex-rs/core/tests/suite/approvals.rs index a106d2eae..3eacdc335 100644 --- a/codex-rs/core/tests/suite/approvals.rs +++ b/codex-rs/core/tests/suite/approvals.rs @@ -71,7 +71,7 @@ enum ActionKind { response_body: &'static str, }, RunCommand { - command: &'static [&'static str], + command: &'static str, }, RunUnifiedExecCommand { command: &'static str, @@ -97,20 +97,12 @@ impl ActionKind { server: &MockServer, call_id: &str, with_escalated_permissions: bool, - ) -> Result<(Value, Option>)> { + ) -> Result<(Value, Option)> { match self { ActionKind::WriteFile { target, content } => { let (path, _) = target.resolve_for_patch(test); let _ = fs::remove_file(&path); - let command = vec![ - "/bin/sh".to_string(), - "-c".to_string(), - format!( - "printf {content:?} > {path:?} && cat {path:?}", - content = content, - path = path - ), - ]; + let command = format!("printf {content:?} > {path:?} && cat {path:?}"); let event = shell_event(call_id, &command, 1_000, with_escalated_permissions)?; Ok((event, Some(command))) } @@ -127,21 +119,18 @@ impl ActionKind { .await; let url = format!("{}{}", server.uri(), endpoint); + let escaped_url = url.replace('\'', "\\'"); let script = format!( - "import sys\nimport urllib.request\nurl = {url:?}\ntry:\n data = urllib.request.urlopen(url, timeout=2).read().decode()\n print('OK:' + data.strip())\nexcept Exception as exc:\n print('ERR:' + exc.__class__.__name__)\n sys.exit(1)", + "import sys\nimport urllib.request\nurl = '{escaped_url}'\ntry:\n data = urllib.request.urlopen(url, timeout=2).read().decode()\n print('OK:' + data.strip())\nexcept Exception as exc:\n print('ERR:' + exc.__class__.__name__)\n sys.exit(1)", ); - let command = 
vec!["python3".to_string(), "-c".to_string(), script]; + let command = format!("python3 -c \"{script}\""); let event = shell_event(call_id, &command, 1_000, with_escalated_permissions)?; Ok((event, Some(command))) } ActionKind::RunCommand { command } => { - let command: Vec = command - .iter() - .map(std::string::ToString::to_string) - .collect(); - let event = shell_event(call_id, &command, 1_000, with_escalated_permissions)?; - Ok((event, Some(command))) + let event = shell_event(call_id, command, 1_000, with_escalated_permissions)?; + Ok((event, Some(command.to_string()))) } ActionKind::RunUnifiedExecCommand { command, @@ -154,14 +143,7 @@ impl ActionKind { with_escalated_permissions, *justification, )?; - Ok(( - event, - Some(vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - command.to_string(), - ]), - )) + Ok((event, Some(command.to_string()))) } ActionKind::ApplyPatchFunction { target, content } => { let (path, patch_path) = target.resolve_for_patch(test); @@ -185,19 +167,19 @@ fn build_add_file_patch(patch_path: &str, content: &str) -> String { format!("*** Begin Patch\n*** Add File: {patch_path}\n+{content}\n*** End Patch\n") } -fn shell_apply_patch_command(patch: &str) -> Vec { +fn shell_apply_patch_command(patch: &str) -> String { let mut script = String::from("apply_patch <<'PATCH'\n"); script.push_str(patch); if !patch.ends_with('\n') { script.push('\n'); } script.push_str("PATCH\n"); - vec!["bash".to_string(), "-lc".to_string(), script] + script } fn shell_event( call_id: &str, - command: &[String], + command: &str, timeout_ms: u64, with_escalated_permissions: bool, ) -> Result { @@ -209,7 +191,7 @@ fn shell_event( args["with_escalated_permissions"] = json!(true); } let args_str = serde_json::to_string(&args)?; - Ok(ev_function_call(call_id, "shell", &args_str)) + Ok(ev_function_call(call_id, "shell_command", &args_str)) } fn exec_command_event( @@ -296,7 +278,10 @@ impl Expectation { } Expectation::FileCreatedNoExitCode { target, content } => { 
let (path, _) = target.resolve_for_patch(test); - assert_eq!(result.exit_code, None, "expected no exit code for {path:?}"); + assert!( + result.exit_code.is_none() || result.exit_code == Some(0), + "expected no exit code for {path:?}", + ); assert!( result.stdout.contains(content), "stdout missing {content:?}: {}", @@ -385,8 +370,8 @@ impl Expectation { ); } Expectation::NetworkSuccessNoExitCode { body_contains } => { - assert_eq!( - result.exit_code, None, + assert!( + result.exit_code.is_none() || result.exit_code == Some(0), "expected no exit code for successful network call: {}", result.stdout ); @@ -433,8 +418,8 @@ impl Expectation { ); } Expectation::CommandSuccessNoExitCode { stdout_contains } => { - assert_eq!( - result.exit_code, None, + assert!( + result.exit_code.is_none() || result.exit_code == Some(0), "expected no exit code for trusted command: {}", result.stdout ); @@ -531,10 +516,18 @@ fn parse_result(item: &Value) -> CommandResult { CommandResult { exit_code, stdout } } Err(_) => { + let structured = Regex::new(r"(?s)^Exit code:\s*(-?\d+).*?Output:\n(.*)$").unwrap(); let regex = Regex::new(r"(?s)^.*?Process exited with code (\d+)\n.*?Output:\n(.*)$").unwrap(); // parse freeform output - if let Some(captures) = regex.captures(output_str) { + if let Some(captures) = structured.captures(output_str) { + let exit_code = captures.get(1).unwrap().as_str().parse::().unwrap(); + let output = captures.get(2).unwrap().as_str(); + CommandResult { + exit_code: Some(exit_code), + stdout: output.to_string(), + } + } else if let Some(captures) = regex.captures(output_str) { let exit_code = captures.get(1).unwrap().as_str().parse::().unwrap(); let output = captures.get(2).unwrap().as_str(); CommandResult { @@ -553,7 +546,7 @@ fn parse_result(item: &Value) -> CommandResult { async fn expect_exec_approval( test: &TestCodex, - expected_command: &[String], + expected_command: &str, ) -> ExecApprovalRequestEvent { let event = wait_for_event(&test.codex, |event| { 
matches!( @@ -565,7 +558,12 @@ async fn expect_exec_approval( match event { EventMsg::ExecApprovalRequest(approval) => { - assert_eq!(approval.command, expected_command); + let last_arg = approval + .command + .last() + .map(std::string::String::as_str) + .unwrap_or_default(); + assert_eq!(last_arg, expected_command); approval } EventMsg::TaskComplete(_) => panic!("expected approval request before completion"), @@ -660,7 +658,7 @@ fn scenarios() -> Vec { features: vec![], model_override: Some("gpt-5.1"), outcome: Outcome::Auto, - expectation: Expectation::FileCreatedNoExitCode { + expectation: Expectation::FileCreated { target: TargetPath::OutsideWorkspace("dfa_on_request_5_1.txt"), content: "danger-on-request", }, @@ -702,7 +700,7 @@ fn scenarios() -> Vec { approval_policy: UnlessTrusted, sandbox_policy: SandboxPolicy::DangerFullAccess, action: ActionKind::RunCommand { - command: &["echo", "trusted-unless"], + command: "echo trusted-unless", }, with_escalated_permissions: false, features: vec![], @@ -717,7 +715,7 @@ fn scenarios() -> Vec { approval_policy: UnlessTrusted, sandbox_policy: SandboxPolicy::DangerFullAccess, action: ActionKind::RunCommand { - command: &["echo", "trusted-unless"], + command: "echo trusted-unless", }, with_escalated_permissions: false, features: vec![], @@ -880,7 +878,7 @@ fn scenarios() -> Vec { approval_policy: OnRequest, sandbox_policy: SandboxPolicy::ReadOnly, action: ActionKind::RunCommand { - command: &["echo", "trusted-read-only"], + command: "echo trusted-read-only", }, with_escalated_permissions: false, features: vec![], @@ -895,7 +893,7 @@ fn scenarios() -> Vec { approval_policy: OnRequest, sandbox_policy: SandboxPolicy::ReadOnly, action: ActionKind::RunCommand { - command: &["echo", "trusted-read-only"], + command: "echo trusted-read-only", }, with_escalated_permissions: false, features: vec![], @@ -1020,7 +1018,7 @@ fn scenarios() -> Vec { }, }, ScenarioSpec { - name: "apply_patch_shell_requires_patch_approval", + name: 
"apply_patch_shell_command_requires_patch_approval", approval_policy: UnlessTrusted, sandbox_policy: workspace_write(false), action: ActionKind::ApplyPatchShell { @@ -1114,7 +1112,7 @@ fn scenarios() -> Vec { }, }, ScenarioSpec { - name: "apply_patch_shell_outside_requires_patch_approval", + name: "apply_patch_shell_command_outside_requires_patch_approval", approval_policy: OnRequest, sandbox_policy: workspace_write(false), action: ActionKind::ApplyPatchShell { @@ -1229,7 +1227,10 @@ fn scenarios() -> Vec { message_contains: if cfg!(target_os = "linux") { &["Permission denied"] } else { - &["Permission denied|Operation not permitted|Read-only file system"] + &[ + "Permission denied|Operation not permitted|operation not permitted|\ + Read-only file system", + ] }, }, }, @@ -1238,7 +1239,7 @@ fn scenarios() -> Vec { approval_policy: Never, sandbox_policy: SandboxPolicy::ReadOnly, action: ActionKind::RunCommand { - command: &["echo", "trusted-never"], + command: "echo trusted-never", }, with_escalated_permissions: false, features: vec![], @@ -1373,7 +1374,10 @@ fn scenarios() -> Vec { message_contains: if cfg!(target_os = "linux") { &["Permission denied"] } else { - &["Permission denied|Operation not permitted|Read-only file system"] + &[ + "Permission denied|Operation not permitted|operation not permitted|\ + Read-only file system", + ] }, }, }, @@ -1509,7 +1513,7 @@ async fn run_scenario(scenario: &ScenarioSpec) -> Result<()> { expected_reason, } => { let command = expected_command - .as_ref() + .as_deref() .expect("exec approval requires shell command"); let approval = expect_exec_approval(&test, command).await; if let Some(expected_reason) = expected_reason { diff --git a/codex-rs/core/tests/suite/cli_stream.rs b/codex-rs/core/tests/suite/cli_stream.rs index b484eca22..d7f0fb983 100644 --- a/codex-rs/core/tests/suite/cli_stream.rs +++ b/codex-rs/core/tests/suite/cli_stream.rs @@ -499,9 +499,20 @@ async fn integration_git_info_unit_test() { "Git info should contain 
repository_url" ); let repo_url = git_info.repository_url.as_ref().unwrap(); + // Some hosts rewrite remotes (e.g., github.com → git@github.com), so assert against + // the actual remote reported by git instead of a static URL. + let expected_remote_url = std::process::Command::new("git") + .args(["remote", "get-url", "origin"]) + .current_dir(&git_repo) + .output() + .unwrap(); + let expected_remote_url = String::from_utf8(expected_remote_url.stdout) + .unwrap() + .trim() + .to_string(); assert_eq!( - repo_url, "https://github.com/example/integration-test.git", - "Repository URL should match what we configured" + repo_url, &expected_remote_url, + "Repository URL should match git remote get-url output" ); println!("✅ Git info collection test passed!"); diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index fefb18f32..71bcd5192 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -992,7 +992,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() { id: Some("web-search-id".into()), status: Some("completed".into()), action: WebSearchAction::Search { - query: "weather".into(), + query: Some("weather".into()), }, }); prompt.input.push(ResponseItem::FunctionCall { @@ -1121,7 +1121,8 @@ async fn token_count_includes_rate_limits_snapshot() { "used_percent": 40.0, "window_minutes": 60, "resets_at": 1704074400 - } + }, + "credits": null } }) ); @@ -1155,7 +1156,7 @@ async fn token_count_includes_rate_limits_snapshot() { "reasoning_output_tokens": 0, "total_tokens": 123 }, - // Default model is gpt-5.1-codex in tests → 95% usable context window + // Default model is gpt-5.1-codex-max in tests → 95% usable context window "model_context_window": 258400 }, "rate_limits": { @@ -1168,7 +1169,8 @@ async fn token_count_includes_rate_limits_snapshot() { "used_percent": 40.0, "window_minutes": 60, "resets_at": 1704074400 - } + }, + "credits": null } }) ); @@ -1238,7 +1240,8 @@ async fn 
usage_limit_error_emits_rate_limit_event() -> anyhow::Result<()> { "used_percent": 87.5, "window_minutes": 60, "resets_at": null - } + }, + "credits": null }); let submission_id = codex diff --git a/codex-rs/core/tests/suite/codex_delegate.rs b/codex-rs/core/tests/suite/codex_delegate.rs index c6ece7fe5..36cf0b865 100644 --- a/codex-rs/core/tests/suite/codex_delegate.rs +++ b/codex-rs/core/tests/suite/codex_delegate.rs @@ -1,3 +1,4 @@ +use codex_core::model_family::find_family_for_model; use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; use codex_core::protocol::Op; @@ -25,17 +26,17 @@ use pretty_assertions::assert_eq; async fn codex_delegate_forwards_exec_approval_and_proceeds_on_approval() { skip_if_no_network!(); - // Sub-agent turn 1: emit a shell function_call requiring approval, then complete. + // Sub-agent turn 1: emit a shell_command function_call requiring approval, then complete. let call_id = "call-exec-1"; let args = serde_json::json!({ - "command": ["bash", "-lc", "rm -rf delegated"], + "command": "rm -rf delegated", "timeout_ms": 1000, "with_escalated_permissions": true, }) .to_string(); let sse1 = sse(vec![ ev_response_created("resp-1"), - ev_function_call(call_id, "shell", &args), + ev_function_call(call_id, "shell_command", &args), ev_completed("resp-1"), ]); @@ -61,6 +62,8 @@ async fn codex_delegate_forwards_exec_approval_and_proceeds_on_approval() { let mut builder = test_codex().with_config(|config| { config.approval_policy = AskForApproval::OnRequest; config.sandbox_policy = SandboxPolicy::ReadOnly; + config.model = "gpt-5.1".to_string(); + config.model_family = find_family_for_model("gpt-5.1").expect("gpt-5.1 is a valid model"); }); let test = builder.build(&server).await.expect("build test codex"); @@ -70,6 +73,7 @@ async fn codex_delegate_forwards_exec_approval_and_proceeds_on_approval() { review_request: ReviewRequest { prompt: "Please review".to_string(), user_facing_hint: "review".to_string(), + 
append_to_original_thread: true, }, }) .await @@ -137,6 +141,8 @@ async fn codex_delegate_forwards_patch_approval_and_proceeds_on_decision() { // Use a restricted sandbox so patch approval is required config.sandbox_policy = SandboxPolicy::ReadOnly; config.include_apply_patch_tool = true; + config.model = "gpt-5.1".to_string(); + config.model_family = find_family_for_model("gpt-5.1").expect("gpt-5.1 is a valid model"); }); let test = builder.build(&server).await.expect("build test codex"); @@ -145,6 +151,7 @@ async fn codex_delegate_forwards_patch_approval_and_proceeds_on_decision() { review_request: ReviewRequest { prompt: "Please review".to_string(), user_facing_hint: "review".to_string(), + append_to_original_thread: true, }, }) .await @@ -199,6 +206,7 @@ async fn codex_delegate_ignores_legacy_deltas() { review_request: ReviewRequest { prompt: "Please review".to_string(), user_facing_hint: "review".to_string(), + append_to_original_thread: true, }, }) .await diff --git a/codex-rs/core/tests/suite/compact.rs b/codex-rs/core/tests/suite/compact.rs index e8b3813c0..1324d3edb 100644 --- a/codex-rs/core/tests/suite/compact.rs +++ b/codex-rs/core/tests/suite/compact.rs @@ -191,22 +191,19 @@ async fn summarize_context_three_requests_and_instructions() { let body2_str = body2.to_string(); let input2 = body2.get("input").and_then(|v| v.as_array()).unwrap(); let has_compact_prompt = body_contains_text(&body2_str, SUMMARIZATION_PROMPT); - if has_compact_prompt { - // The last item is the user message created from the injected input. 
- let last2 = input2.last().unwrap(); - assert_eq!(last2.get("type").unwrap().as_str().unwrap(), "message"); - assert_eq!(last2.get("role").unwrap().as_str().unwrap(), "user"); - let text2 = last2["content"][0]["text"].as_str().unwrap(); - assert_eq!( - text2, SUMMARIZATION_PROMPT, - "expected summarize trigger, got `{text2}`" - ); - } else { - assert!( - !has_compact_prompt, - "compaction request should not unexpectedly include the summarize trigger" - ); - } + assert!( + has_compact_prompt, + "compaction request should include the summarize trigger" + ); + // The last item is the user message created from the injected input. + let last2 = input2.last().unwrap(); + assert_eq!(last2.get("type").unwrap().as_str().unwrap(), "message"); + assert_eq!(last2.get("role").unwrap().as_str().unwrap(), "user"); + let text2 = last2["content"][0]["text"].as_str().unwrap(); + assert_eq!( + text2, SUMMARIZATION_PROMPT, + "expected summarize trigger, got `{text2}`" + ); // Third request must contain the refreshed instructions, compacted user history, and new user message. 
let input3 = body3.get("input").and_then(|v| v.as_array()).unwrap(); @@ -387,7 +384,7 @@ async fn manual_compact_uses_custom_prompt() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn manual_compact_emits_estimated_token_usage_event() { +async fn manual_compact_emits_api_and_local_token_usage_events() { skip_if_no_network!(); let server = start_mock_server().await; diff --git a/codex-rs/core/tests/suite/compact_remote.rs b/codex-rs/core/tests/suite/compact_remote.rs index 4bc1af9e1..dc88bc574 100644 --- a/codex-rs/core/tests/suite/compact_remote.rs +++ b/codex-rs/core/tests/suite/compact_remote.rs @@ -13,10 +13,13 @@ use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; use codex_protocol::user_input::UserInput; use core_test_support::responses; +use core_test_support::responses::mount_sse_once; +use core_test_support::responses::sse; use core_test_support::skip_if_no_network; use core_test_support::test_codex::TestCodexHarness; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; +use core_test_support::wait_for_event_match; use pretty_assertions::assert_eq; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -125,6 +128,72 @@ async fn remote_compact_replaces_history_for_followups() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn remote_compact_runs_automatically() -> Result<()> { + skip_if_no_network!(Ok(())); + + let harness = TestCodexHarness::with_builder( + test_codex() + .with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) + .with_config(|config| { + config.features.enable(Feature::RemoteCompaction); + }), + ) + .await?; + let codex = harness.test().codex.clone(); + + mount_sse_once( + harness.server(), + sse(vec![ + responses::ev_shell_command_call("m1", "echo 'hi'"), + responses::ev_completed_with_tokens("resp-1", 100000000), // over token limit + ]), + ) + .await; + let responses_mock = mount_sse_once( 
+ harness.server(), + responses::sse(vec![ + responses::ev_assistant_message("m2", "AFTER_COMPACT_REPLY"), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + let compacted_history = vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "REMOTE_COMPACTED_SUMMARY".to_string(), + }], + }]; + let compact_mock = responses::mount_compact_json_once( + harness.server(), + serde_json::json!({ "output": compacted_history.clone() }), + ) + .await; + + codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "hello remote compact".into(), + }], + }) + .await?; + let message = wait_for_event_match(&codex, |ev| match ev { + EventMsg::AgentMessage(ev) => Some(ev.message.clone()), + _ => None, + }) + .await; + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + assert_eq!(message, "Compact task completed"); + assert_eq!(compact_mock.requests().len(), 1); + let follow_up_body = responses_mock.single_request().body_json().to_string(); + assert!(follow_up_body.contains("REMOTE_COMPACTED_SUMMARY")); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn remote_compact_persists_replacement_history_in_rollout() -> Result<()> { skip_if_no_network!(Ok(())); diff --git a/codex-rs/core/tests/suite/exec.rs b/codex-rs/core/tests/suite/exec.rs index ea5ab8487..bb0f1bce0 100644 --- a/codex-rs/core/tests/suite/exec.rs +++ b/codex-rs/core/tests/suite/exec.rs @@ -32,7 +32,7 @@ async fn run_test_cmd(tmp: TempDir, cmd: Vec<&str>) -> Result Result<()> { + // TODO execpolicy doesn't parse powershell commands yet + if cfg!(windows) { + return Ok(()); + } + + let mut builder = test_codex().with_config(|config| { + let policy_path = config.codex_home.join("policy").join("policy.codexpolicy"); + fs::create_dir_all( + policy_path + .parent() + .expect("policy directory must have a parent"), + ) + .expect("create policy directory"); + fs::write( + &policy_path, + 
r#"prefix_rule(pattern=["echo"], decision="forbidden")"#, + ) + .expect("write policy file"); + config.model = "gpt-5.1".to_string(); + config.model_family = + find_family_for_model("gpt-5.1").expect("gpt-5.1 should have a model family"); + }); + let server = start_mock_server().await; + let test = builder.build(&server).await?; + + let call_id = "shell-forbidden"; + let args = json!({ + "command": "echo blocked", + "timeout_ms": 1_000, + }); + + mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + ) + .await; + mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ) + .await; + + let session_model = test.session_configured.model.clone(); + test.codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "run shell command".into(), + }], + final_output_json_schema: None, + cwd: test.cwd_path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let EventMsg::ExecCommandEnd(end) = wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::ExecCommandEnd(_)) + }) + .await + else { + unreachable!() + }; + wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::TaskComplete(_)) + }) + .await; + + assert!( + end.aggregated_output + .contains("execpolicy forbids this command"), + "unexpected output: {}", + end.aggregated_output + ); + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index ef248901e..b87766361 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -28,6 +28,7 @@ mod compact_remote; mod compact_resume_fork; mod deprecation_notice; mod exec; +mod exec_policy; mod fork_conversation; mod grep_files; mod items; @@ -48,6 
+49,7 @@ mod seatbelt; mod shell_serialization; mod stream_error_allows_next_turn; mod stream_no_completed; +mod text_encoding_fix; mod tool_harness; mod tool_parallelism; mod tools; diff --git a/codex-rs/core/tests/suite/model_tools.rs b/codex-rs/core/tests/suite/model_tools.rs index 28593a55b..e807ce7db 100644 --- a/codex-rs/core/tests/suite/model_tools.rs +++ b/codex-rs/core/tests/suite/model_tools.rs @@ -1,21 +1,10 @@ #![allow(clippy::unwrap_used)] -use codex_core::CodexAuth; -use codex_core::ConversationManager; -use codex_core::ModelProviderInfo; -use codex_core::built_in_model_providers; -use codex_core::features::Feature; -use codex_core::model_family::find_family_for_model; -use codex_core::protocol::EventMsg; -use codex_core::protocol::Op; -use codex_protocol::user_input::UserInput; -use core_test_support::load_default_config_for_test; use core_test_support::load_sse_fixture_with_id; use core_test_support::responses; +use core_test_support::responses::start_mock_server; use core_test_support::skip_if_no_network; -use core_test_support::wait_for_event; -use tempfile::TempDir; -use wiremock::MockServer; +use core_test_support::test_codex::test_codex; fn sse_completed(id: &str) -> String { load_sse_fixture_with_id("tests/fixtures/completed_template.json", id) @@ -39,46 +28,17 @@ fn tool_identifiers(body: &serde_json::Value) -> Vec { #[allow(clippy::expect_used)] async fn collect_tool_identifiers_for_model(model: &str) -> Vec { - let server = MockServer::start().await; - + let server = start_mock_server().await; let sse = sse_completed(model); let resp_mock = responses::mount_sse_once(&server, sse).await; - let model_provider = ModelProviderInfo { - base_url: Some(format!("{}/v1", server.uri())), - ..built_in_model_providers()["openai"].clone() - }; - - let cwd = TempDir::new().unwrap(); - let codex_home = TempDir::new().unwrap(); - let mut config = load_default_config_for_test(&codex_home); - config.cwd = cwd.path().to_path_buf(); - config.model_provider = 
model_provider; - config.model = model.to_string(); - config.model_family = - find_family_for_model(model).unwrap_or_else(|| panic!("unknown model family for {model}")); - config.features.disable(Feature::ApplyPatchFreeform); - config.features.disable(Feature::ViewImageTool); - config.features.disable(Feature::WebSearchRequest); - config.features.disable(Feature::UnifiedExec); - - let conversation_manager = - ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); - let codex = conversation_manager - .new_conversation(config) + let mut builder = test_codex().with_model(model); + let test = builder + .build(&server) .await - .expect("create new conversation") - .conversation; + .expect("create test Codex conversation"); - codex - .submit(Op::UserInput { - items: vec![UserInput::Text { - text: "hello tools".into(), - }], - }) - .await - .unwrap(); - wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + test.submit_turn("hello tools").await.expect("submit turn"); let body = resp_mock.single_request().body_json(); tool_identifiers(&body) @@ -97,72 +57,67 @@ async fn model_selects_expected_tools() { "list_mcp_resources".to_string(), "list_mcp_resource_templates".to_string(), "read_mcp_resource".to_string(), - "update_plan".to_string() + "update_plan".to_string(), + "view_image".to_string() ], "codex-mini-latest should expose the local shell tool", ); - let o3_tools = collect_tool_identifiers_for_model("o3").await; + let gpt5_codex_tools = collect_tool_identifiers_for_model("gpt-5-codex").await; assert_eq!( - o3_tools, + gpt5_codex_tools, vec![ - "shell".to_string(), + "shell_command".to_string(), "list_mcp_resources".to_string(), "list_mcp_resource_templates".to_string(), "read_mcp_resource".to_string(), - "update_plan".to_string() + "update_plan".to_string(), + "apply_patch".to_string(), + "view_image".to_string() ], - "o3 should expose the generic shell tool", + "gpt-5-codex should expose the apply_patch tool", ); - let 
gpt5_codex_tools = collect_tool_identifiers_for_model("gpt-5-codex").await; + let gpt51_codex_tools = collect_tool_identifiers_for_model("gpt-5.1-codex").await; assert_eq!( - gpt5_codex_tools, + gpt51_codex_tools, vec![ - if cfg!(windows) { - "shell_command" - } else { - "shell" - } - .to_string(), + "shell_command".to_string(), "list_mcp_resources".to_string(), "list_mcp_resource_templates".to_string(), "read_mcp_resource".to_string(), "update_plan".to_string(), - "apply_patch".to_string() + "apply_patch".to_string(), + "view_image".to_string() ], - "gpt-5-codex should expose the apply_patch tool", + "gpt-5.1-codex should expose the apply_patch tool", ); - let gpt51_codex_tools = collect_tool_identifiers_for_model("gpt-5.1-codex").await; + let gpt5_tools = collect_tool_identifiers_for_model("gpt-5").await; assert_eq!( - gpt51_codex_tools, + gpt5_tools, vec![ - if cfg!(windows) { - "shell_command" - } else { - "shell" - } - .to_string(), + "shell".to_string(), "list_mcp_resources".to_string(), "list_mcp_resource_templates".to_string(), "read_mcp_resource".to_string(), "update_plan".to_string(), - "apply_patch".to_string() + "view_image".to_string() ], - "gpt-5.1-codex should expose the apply_patch tool", + "gpt-5 should expose the apply_patch tool", ); let gpt51_tools = collect_tool_identifiers_for_model("gpt-5.1").await; assert_eq!( gpt51_tools, vec![ - "shell".to_string(), + "shell_command".to_string(), "list_mcp_resources".to_string(), "list_mcp_resource_templates".to_string(), "read_mcp_resource".to_string(), "update_plan".to_string(), - "apply_patch".to_string() + "apply_patch".to_string(), + "view_image".to_string() ], "gpt-5.1 should expose the apply_patch tool", ); diff --git a/codex-rs/core/tests/suite/prompt_caching.rs b/codex-rs/core/tests/suite/prompt_caching.rs index c858ea20f..f4455fd02 100644 --- a/codex-rs/core/tests/suite/prompt_caching.rs +++ b/codex-rs/core/tests/suite/prompt_caching.rs @@ -1,9 +1,9 @@ #![allow(clippy::unwrap_used)] -use 
codex_core::config::OPENAI_DEFAULT_MODEL; use codex_core::features::Feature; use codex_core::model_family::find_family_for_model; use codex_core::protocol::AskForApproval; +use codex_core::protocol::ENVIRONMENT_CONTEXT_OPEN_TAG; use codex_core::protocol::EventMsg; use codex_core::protocol::Op; use codex_core::protocol::SandboxPolicy; @@ -19,7 +19,6 @@ use core_test_support::skip_if_no_network; use core_test_support::test_codex::TestCodex; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; -use std::collections::HashMap; use tempfile::TempDir; fn text_user_input(text: String) -> serde_json::Value { @@ -31,18 +30,15 @@ fn text_user_input(text: String) -> serde_json::Value { } fn default_env_context_str(cwd: &str, shell: &Shell) -> String { + let shell_name = shell.name(); format!( r#" - {} + {cwd} on-request read-only restricted -{}"#, - cwd, - match shell.name() { - Some(name) => format!(" {name}\n"), - None => String::new(), - } + {shell_name} +"# ) } @@ -156,61 +152,15 @@ async fn prompt_tools_are_consistent_across_requests() -> anyhow::Result<()> { .await?; wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; - // our internal implementation is responsible for keeping tools in sync - // with the OpenAI schema, so we just verify the tool presence here - let tools_by_model: HashMap<&'static str, Vec<&'static str>> = HashMap::from([ - ( - "gpt-5.1", - vec![ - "shell", - "list_mcp_resources", - "list_mcp_resource_templates", - "read_mcp_resource", - "update_plan", - "view_image", - ], - ), - ( - "gpt-5.1", - vec![ - "shell", - "list_mcp_resources", - "list_mcp_resource_templates", - "read_mcp_resource", - "update_plan", - "apply_patch", - "view_image", - ], - ), - ( - "gpt-5.1-codex", - vec![ - "shell", - "list_mcp_resources", - "list_mcp_resource_templates", - "read_mcp_resource", - "update_plan", - "apply_patch", - "view_image", - ], - ), - ( - "gpt-5.1-codex", - vec![ - "shell", - "list_mcp_resources", - 
"list_mcp_resource_templates", - "read_mcp_resource", - "update_plan", - "apply_patch", - "view_image", - ], - ), - ]); - let expected_tools_names = tools_by_model - .get(OPENAI_DEFAULT_MODEL) - .unwrap_or_else(|| panic!("expected tools to be defined for model {OPENAI_DEFAULT_MODEL}")) - .as_slice(); + let expected_tools_names = vec![ + "shell_command", + "list_mcp_resources", + "list_mcp_resource_templates", + "read_mcp_resource", + "update_plan", + "apply_patch", + "view_image", + ]; let body0 = req1.single_request().body_json(); let expected_instructions = if expected_tools_names.contains(&"apply_patch") { @@ -227,14 +177,14 @@ async fn prompt_tools_are_consistent_across_requests() -> anyhow::Result<()> { body0["instructions"], serde_json::json!(expected_instructions), ); - assert_tool_names(&body0, expected_tools_names); + assert_tool_names(&body0, &expected_tools_names); let body1 = req2.single_request().body_json(); assert_eq!( body1["instructions"], serde_json::json!(expected_instructions), ); - assert_tool_names(&body1, expected_tools_names); + assert_tool_names(&body1, &expected_tools_names); Ok(()) } @@ -274,7 +224,7 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests .await?; wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; - let shell = default_user_shell().await; + let shell = default_user_shell(); let cwd_str = config.cwd.to_string_lossy(); let expected_env_text = default_env_context_str(&cwd_str, &shell); let expected_ui_text = format!( @@ -392,6 +342,7 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() -> an // After overriding the turn context, the environment context should be emitted again // reflecting the new approval policy and sandbox settings. Omit cwd because it did // not change. 
+ let shell = default_user_shell(); let expected_env_text_2 = format!( r#" never @@ -400,8 +351,10 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() -> an {} + {} "#, - writable.path().to_string_lossy(), + writable.path().display(), + shell.name() ); let expected_env_msg_2 = serde_json::json!({ "type": "message", @@ -420,6 +373,89 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() -> an Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn override_before_first_turn_emits_environment_context() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let req = mount_sse_once(&server, sse_completed("resp-1")).await; + + let TestCodex { codex, .. } = test_codex().build(&server).await?; + + codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: Some(AskForApproval::Never), + sandbox_policy: None, + model: None, + effort: None, + summary: None, + }) + .await?; + + codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "first message".into(), + }], + }) + .await?; + + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + let body = req.single_request().body_json(); + let input = body["input"] + .as_array() + .expect("input array must be present"); + assert!( + !input.is_empty(), + "expected at least environment context and user message" + ); + + let env_msg = &input[1]; + let env_text = env_msg["content"][0]["text"] + .as_str() + .expect("environment context text"); + assert!( + env_text.starts_with(ENVIRONMENT_CONTEXT_OPEN_TAG), + "second entry should be environment context, got: {env_text}" + ); + assert!( + env_text.contains("never"), + "environment context should reflect overridden approval policy: {env_text}" + ); + + let env_count = input + .iter() + .filter(|msg| { + msg["content"] + .as_array() + .and_then(|content| { + content.iter().find(|item| { + item["type"].as_str() == 
Some("input_text") + && item["text"] + .as_str() + .map(|text| text.starts_with(ENVIRONMENT_CONTEXT_OPEN_TAG)) + .unwrap_or(false) + }) + }) + .is_some() + }) + .count(); + assert_eq!( + env_count, 2, + "environment context should appear exactly twice, found {env_count}" + ); + + let user_msg = &input[2]; + let user_text = user_msg["content"][0]["text"] + .as_str() + .expect("user message text"); + assert_eq!(user_text, "first message"); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn per_turn_overrides_keep_cached_prefix_and_key_constant() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); @@ -486,6 +522,8 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() -> anyhow::Res "role": "user", "content": [ { "type": "input_text", "text": "hello 2" } ] }); + let shell = default_user_shell(); + let expected_env_text_2 = format!( r#" {} @@ -495,9 +533,11 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() -> anyhow::Res {} + {} "#, - new_cwd.path().to_string_lossy(), - writable.path().to_string_lossy(), + new_cwd.path().display(), + writable.path().display(), + shell.name(), ); let expected_env_msg_2 = serde_json::json!({ "type": "message", @@ -574,7 +614,7 @@ async fn send_user_turn_with_no_changes_does_not_send_environment_context() -> a let body1 = req1.single_request().body_json(); let body2 = req2.single_request().body_json(); - let shell = default_user_shell().await; + let shell = default_user_shell(); let default_cwd_lossy = default_cwd.to_string_lossy(); let expected_ui_text = format!( "# AGENTS.md instructions for {default_cwd_lossy}\n\n\nbe consistent and helpful\n" @@ -661,7 +701,7 @@ async fn send_user_turn_with_changes_sends_environment_context() -> anyhow::Resu let body1 = req1.single_request().body_json(); let body2 = req2.single_request().body_json(); - let shell = default_user_shell().await; + let shell = default_user_shell(); let expected_ui_text = format!( "# AGENTS.md instructions 
for {}\n\n\nbe consistent and helpful\n", default_cwd.to_string_lossy() @@ -681,14 +721,15 @@ async fn send_user_turn_with_changes_sends_environment_context() -> anyhow::Resu ]); assert_eq!(body1["input"], expected_input_1); - let expected_env_msg_2 = text_user_input( + let shell_name = shell.name(); + let expected_env_msg_2 = text_user_input(format!( r#" never danger-full-access enabled + {shell_name} "# - .to_string(), - ); + )); let expected_user_message_2 = text_user_input("hello 2".to_string()); let expected_input_2 = serde_json::Value::Array(vec![ expected_ui_msg, diff --git a/codex-rs/core/tests/suite/review.rs b/codex-rs/core/tests/suite/review.rs index 9c3f812c2..3904f18f8 100644 --- a/codex-rs/core/tests/suite/review.rs +++ b/codex-rs/core/tests/suite/review.rs @@ -82,6 +82,7 @@ async fn review_op_emits_lifecycle_and_review_output() { review_request: ReviewRequest { prompt: "Please review my changes".to_string(), user_facing_hint: "my changes".to_string(), + append_to_original_thread: true, }, }) .await @@ -178,6 +179,7 @@ async fn review_op_with_plain_text_emits_review_fallback() { review_request: ReviewRequest { prompt: "Plain text review".to_string(), user_facing_hint: "plain text review".to_string(), + append_to_original_thread: true, }, }) .await @@ -236,6 +238,7 @@ async fn review_filters_agent_message_related_events() { review_request: ReviewRequest { prompt: "Filter streaming events".to_string(), user_facing_hint: "Filter streaming events".to_string(), + append_to_original_thread: true, }, }) .await @@ -320,6 +323,7 @@ async fn review_does_not_emit_agent_message_on_structured_output() { review_request: ReviewRequest { prompt: "check structured".to_string(), user_facing_hint: "check structured".to_string(), + append_to_original_thread: true, }, }) .await @@ -373,6 +377,7 @@ async fn review_uses_custom_review_model_from_config() { review_request: ReviewRequest { prompt: "use custom model".to_string(), user_facing_hint: "use custom 
model".to_string(), + append_to_original_thread: true, }, }) .await @@ -490,6 +495,7 @@ async fn review_input_isolated_from_parent_history() { review_request: ReviewRequest { prompt: review_prompt.clone(), user_facing_hint: review_prompt.clone(), + append_to_original_thread: true, }, }) .await @@ -602,6 +608,7 @@ async fn review_history_does_not_leak_into_parent_session() { review_request: ReviewRequest { prompt: "Start a review".to_string(), user_facing_hint: "Start a review".to_string(), + append_to_original_thread: true, }, }) .await diff --git a/codex-rs/core/tests/suite/shell_serialization.rs b/codex-rs/core/tests/suite/shell_serialization.rs index 3b468ce2b..9c49e95f5 100644 --- a/codex-rs/core/tests/suite/shell_serialization.rs +++ b/codex-rs/core/tests/suite/shell_serialization.rs @@ -2,6 +2,7 @@ #![allow(clippy::expect_used)] use anyhow::Result; +use codex_core::config::Config; use codex_core::features::Feature; use codex_core::model_family::find_family_for_model; use codex_core::protocol::SandboxPolicy; @@ -40,6 +41,20 @@ const FIXTURE_JSON: &str = r#"{ } "#; +fn configure_shell_command_model(output_type: ShellModelOutput, config: &mut Config) { + if !matches!(output_type, ShellModelOutput::ShellCommand) { + return; + } + + if let Some(shell_command_family) = find_family_for_model("test-gpt-5-codex") { + if config.model_family.shell_type == shell_command_family.shell_type { + return; + } + config.model = shell_command_family.slug.clone(); + config.model_family = shell_command_family; + } +} + fn shell_responses( call_id: &str, command: Vec<&str>, @@ -101,7 +116,6 @@ fn shell_responses( #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[test_case(ShellModelOutput::Shell)] -#[test_case(ShellModelOutput::ShellCommand)] #[test_case(ShellModelOutput::LocalShell)] async fn shell_output_stays_json_without_freeform_apply_patch( output_type: ShellModelOutput, @@ -113,10 +127,7 @@ async fn shell_output_stays_json_without_freeform_apply_patch( 
config.features.disable(Feature::ApplyPatchFreeform); config.model = "gpt-5".to_string(); config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a model family"); - if matches!(output_type, ShellModelOutput::ShellCommand) { - config.features.enable(Feature::ShellCommandTool); - } - let _ = output_type; + configure_shell_command_model(output_type, config); }); let test = builder.build(&server).await?; @@ -171,10 +182,7 @@ async fn shell_output_is_structured_with_freeform_apply_patch( let server = start_mock_server().await; let mut builder = test_codex().with_config(move |config| { config.features.enable(Feature::ApplyPatchFreeform); - if matches!(output_type, ShellModelOutput::ShellCommand) { - config.features.enable(Feature::ShellCommandTool); - } - let _ = output_type; + configure_shell_command_model(output_type, config); }); let test = builder.build(&server).await?; @@ -213,7 +221,6 @@ freeform shell #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[test_case(ShellModelOutput::Shell)] -#[test_case(ShellModelOutput::ShellCommand)] #[test_case(ShellModelOutput::LocalShell)] async fn shell_output_preserves_fixture_json_without_serialization( output_type: ShellModelOutput, @@ -225,10 +232,7 @@ async fn shell_output_preserves_fixture_json_without_serialization( config.features.disable(Feature::ApplyPatchFreeform); config.model = "gpt-5".to_string(); config.model_family = find_family_for_model("gpt-5").expect("gpt-5 is a model family"); - if matches!(output_type, ShellModelOutput::ShellCommand) { - config.features.enable(Feature::ShellCommandTool); - } - let _ = output_type; + configure_shell_command_model(output_type, config); }); let test = builder.build(&server).await?; @@ -295,10 +299,7 @@ async fn shell_output_structures_fixture_with_serialization( let server = start_mock_server().await; let mut builder = test_codex().with_config(move |config| { config.features.enable(Feature::ApplyPatchFreeform); - if matches!(output_type, 
ShellModelOutput::ShellCommand) { - config.features.enable(Feature::ShellCommandTool); - } - let _ = output_type; + configure_shell_command_model(output_type, config); }); let test = builder.build(&server).await?; @@ -360,15 +361,12 @@ async fn shell_output_for_freeform_tool_records_duration( let server = start_mock_server().await; let mut builder = test_codex().with_config(move |config| { config.include_apply_patch_tool = true; - if matches!(output_type, ShellModelOutput::ShellCommand) { - config.features.enable(Feature::ShellCommandTool); - } - let _ = output_type; + configure_shell_command_model(output_type, config); }); let test = builder.build(&server).await?; let call_id = "shell-structured"; - let responses = shell_responses(call_id, vec!["/bin/bash", "-c", "sleep 1"], output_type)?; + let responses = shell_responses(call_id, vec!["/bin/sh", "-c", "sleep 1"], output_type)?; let mock = mount_sse_sequence(&server, responses).await; test.submit_turn_with_policy( @@ -409,7 +407,6 @@ $"#; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[test_case(ShellModelOutput::Shell)] -#[test_case(ShellModelOutput::ShellCommand)] #[test_case(ShellModelOutput::LocalShell)] async fn shell_output_reserializes_truncated_content(output_type: ShellModelOutput) -> Result<()> { skip_if_no_network!(Ok(())); @@ -419,10 +416,8 @@ async fn shell_output_reserializes_truncated_content(output_type: ShellModelOutp config.model = "gpt-5.1-codex".to_string(); config.model_family = find_family_for_model("gpt-5.1-codex").expect("gpt-5.1-codex is a model family"); - if matches!(output_type, ShellModelOutput::ShellCommand) { - config.features.enable(Feature::ShellCommandTool); - } - let _ = output_type; + config.tool_output_token_limit = Some(200); + configure_shell_command_model(output_type, config); }); let test = builder.build(&server).await?; @@ -459,10 +454,7 @@ Output: 4 5 6 -.* -\[\.{3} omitted \d+ of 400 lines \.{3}\] - -.* +.*…46 tokens truncated….* 396 397 398 @@ -727,9 
+719,7 @@ async fn shell_output_is_structured_for_nonzero_exit(output_type: ShellModelOutp config.model_family = find_family_for_model("gpt-5.1-codex").expect("gpt-5.1-codex is a model family"); config.include_apply_patch_tool = true; - if matches!(output_type, ShellModelOutput::ShellCommand) { - config.features.enable(Feature::ShellCommandTool); - } + configure_shell_command_model(output_type, config); }); let test = builder.build(&server).await?; @@ -760,12 +750,12 @@ Output: } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn shell_command_output_is_structured() -> Result<()> { +async fn shell_command_output_is_freeform() -> Result<()> { skip_if_no_network!(Ok(())); let server = start_mock_server().await; let mut builder = test_codex().with_config(move |config| { - config.features.enable(Feature::ShellCommandTool); + configure_shell_command_model(ShellModelOutput::ShellCommand, config); }); let test = builder.build(&server).await?; @@ -812,6 +802,106 @@ shell command Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn shell_command_output_is_not_truncated_under_10k_bytes() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex().with_model("gpt-5.1"); + let test = builder.build(&server).await?; + + let call_id = "shell-command"; + let args = json!({ + "command": "perl -e 'print \"1\" x 10000'", + "timeout_ms": 1000, + }); + let responses = vec![ + sse(vec![ + json!({"type": "response.created", "response": {"id": "resp-1"}}), + ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "shell_command done"), + ev_completed("resp-2"), + ]), + ]; + let mock = mount_sse_sequence(&server, responses).await; + + test.submit_turn_with_policy( + "run the shell_command script in the user's shell", + SandboxPolicy::DangerFullAccess, + ) + .await?; + + let req = mock + 
.last_request() + .expect("shell_command output request recorded"); + let output_item = req.function_call_output(call_id); + let output = output_item + .get("output") + .and_then(Value::as_str) + .expect("shell_command output string"); + + let expected_pattern = r"(?s)^Exit code: 0 +Wall time: [0-9]+(?:\.[0-9]+)? seconds +Output: +1{10000}$"; + assert_regex_match(expected_pattern, output); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn shell_command_output_is_not_truncated_over_10k_bytes() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex().with_model("gpt-5.1"); + let test = builder.build(&server).await?; + + let call_id = "shell-command"; + let args = json!({ + "command": "perl -e 'print \"1\" x 10001'", + "timeout_ms": 1000, + }); + let responses = vec![ + sse(vec![ + json!({"type": "response.created", "response": {"id": "resp-1"}}), + ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "shell_command done"), + ev_completed("resp-2"), + ]), + ]; + let mock = mount_sse_sequence(&server, responses).await; + + test.submit_turn_with_policy( + "run the shell_command script in the user's shell", + SandboxPolicy::DangerFullAccess, + ) + .await?; + + let req = mock + .last_request() + .expect("shell_command output request recorded"); + let output_item = req.function_call_output(call_id); + let output = output_item + .get("output") + .and_then(Value::as_str) + .expect("shell_command output string"); + + let expected_pattern = r"(?s)^Exit code: 0 +Wall time: [0-9]+(?:\.[0-9]+)? 
seconds +Output: +1*…1 chars truncated…1*$"; + assert_regex_match(expected_pattern, output); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn local_shell_call_output_is_structured() -> Result<()> { skip_if_no_network!(Ok(())); diff --git a/codex-rs/core/tests/suite/text_encoding_fix.rs b/codex-rs/core/tests/suite/text_encoding_fix.rs new file mode 100644 index 000000000..ecebb1e42 --- /dev/null +++ b/codex-rs/core/tests/suite/text_encoding_fix.rs @@ -0,0 +1,77 @@ +//! Integration test for the text encoding fix for issue #6178. +//! +//! These tests simulate VSCode's shell preview on Windows/WSL where the output +//! may be encoded with a legacy code page before it reaches Codex. + +use codex_core::exec::StreamOutput; +use pretty_assertions::assert_eq; + +#[test] +fn test_utf8_shell_output() { + // Baseline: UTF-8 output should bypass the detector and remain unchanged. + assert_eq!(decode_shell_output("пример".as_bytes()), "пример"); +} + +#[test] +fn test_cp1251_shell_output() { + // VS Code shells on Windows frequently surface CP1251 bytes for Cyrillic text. + assert_eq!(decode_shell_output(b"\xEF\xF0\xE8\xEC\xE5\xF0"), "пример"); +} + +#[test] +fn test_cp866_shell_output() { + // Native cmd.exe still defaults to CP866; make sure we recognize that too. + assert_eq!(decode_shell_output(b"\xAF\xE0\xA8\xAC\xA5\xE0"), "пример"); +} + +#[test] +fn test_windows_1252_smart_decoding() { + // Smart detection should turn fancy quotes/dashes into the proper Unicode glyphs. + assert_eq!( + decode_shell_output(b"\x93\x94 test \x96 dash"), + "\u{201C}\u{201D} test \u{2013} dash" + ); +} + +#[test] +fn test_smart_decoding_improves_over_lossy_utf8() { + // Regression guard: String::from_utf8_lossy() alone used to emit replacement chars here. 
+ let bytes = b"\x93\x94 test \x96 dash"; + assert!( + String::from_utf8_lossy(bytes).contains('\u{FFFD}'), + "lossy UTF-8 should inject replacement chars" + ); + assert_eq!( + decode_shell_output(bytes), + "\u{201C}\u{201D} test \u{2013} dash", + "smart decoding should keep curly quotes intact" + ); +} + +#[test] +fn test_mixed_ascii_and_legacy_encoding() { + // Commands tend to mix ASCII status text with Latin-1 bytes (e.g. café). + assert_eq!(decode_shell_output(b"Output: caf\xE9"), "Output: café"); // codespell:ignore caf +} + +#[test] +fn test_pure_latin1_shell_output() { + // Latin-1 by itself should still decode correctly (regression coverage for the older tests). + assert_eq!(decode_shell_output(b"caf\xE9"), "café"); // codespell:ignore caf +} + +#[test] +fn test_invalid_bytes_still_fall_back_to_lossy() { + // If detection fails, we still want the user to see replacement characters. + let bytes = b"\xFF\xFE\xFD"; + assert_eq!(decode_shell_output(bytes), String::from_utf8_lossy(bytes)); +} + +fn decode_shell_output(bytes: &[u8]) -> String { + StreamOutput { + text: bytes.to_vec(), + truncated_after_lines: None, + } + .from_utf8_lossy() + .text +} diff --git a/codex-rs/core/tests/suite/tool_parallelism.rs b/codex-rs/core/tests/suite/tool_parallelism.rs index 807bb9d82..96998f5d1 100644 --- a/codex-rs/core/tests/suite/tool_parallelism.rs +++ b/codex-rs/core/tests/suite/tool_parallelism.rs @@ -146,10 +146,11 @@ async fn non_parallel_tools_run_serially() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); let server = start_mock_server().await; - let test = test_codex().build(&server).await?; + let mut builder = test_codex().with_model("gpt-5.1"); + let test = builder.build(&server).await?; let shell_args = json!({ - "command": ["/bin/sh", "-c", "sleep 0.3"], + "command": "sleep 0.3", "timeout_ms": 1_000, }); let args_one = serde_json::to_string(&shell_args)?; @@ -157,8 +158,8 @@ async fn non_parallel_tools_run_serially() -> anyhow::Result<()> { let 
first_response = sse(vec![ json!({"type": "response.created", "response": {"id": "resp-1"}}), - ev_function_call("call-1", "shell", &args_one), - ev_function_call("call-2", "shell", &args_two), + ev_function_call("call-1", "shell_command", &args_one), + ev_function_call("call-2", "shell_command", &args_two), ev_completed("resp-1"), ]); let second_response = sse(vec![ @@ -167,7 +168,7 @@ async fn non_parallel_tools_run_serially() -> anyhow::Result<()> { ]); mount_sse_sequence(&server, vec![first_response, second_response]).await; - let duration = run_turn_and_measure(&test, "run shell twice").await?; + let duration = run_turn_and_measure(&test, "run shell_command twice").await?; assert_serial_duration(duration); Ok(()) @@ -185,14 +186,14 @@ async fn mixed_tools_fall_back_to_serial() -> anyhow::Result<()> { }) .to_string(); let shell_args = serde_json::to_string(&json!({ - "command": ["/bin/sh", "-c", "sleep 0.3"], + "command": "sleep 0.3", "timeout_ms": 1_000, }))?; let first_response = sse(vec![ json!({"type": "response.created", "response": {"id": "resp-1"}}), ev_function_call("call-1", "test_sync_tool", &sync_args), - ev_function_call("call-2", "shell", &shell_args), + ev_function_call("call-2", "shell_command", &shell_args), ev_completed("resp-1"), ]); let second_response = sse(vec![ @@ -215,7 +216,7 @@ async fn tool_results_grouped() -> anyhow::Result<()> { let test = build_codex_with_test_tool(&server).await?; let shell_args = serde_json::to_string(&json!({ - "command": ["/bin/sh", "-c", "echo 'shell output'"], + "command": "echo 'shell output'", "timeout_ms": 1_000, }))?; @@ -223,9 +224,9 @@ async fn tool_results_grouped() -> anyhow::Result<()> { &server, sse(vec![ json!({"type": "response.created", "response": {"id": "resp-1"}}), - ev_function_call("call-1", "shell", &shell_args), - ev_function_call("call-2", "shell", &shell_args), - ev_function_call("call-3", "shell", &shell_args), + ev_function_call("call-1", "shell_command", &shell_args), + 
ev_function_call("call-2", "shell_command", &shell_args), + ev_function_call("call-3", "shell_command", &shell_args), ev_completed("resp-1"), ]), ) diff --git a/codex-rs/core/tests/suite/truncation.rs b/codex-rs/core/tests/suite/truncation.rs index 200d27e88..f06694f62 100644 --- a/codex-rs/core/tests/suite/truncation.rs +++ b/codex-rs/core/tests/suite/truncation.rs @@ -27,7 +27,6 @@ use core_test_support::skip_if_no_network; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; use escargot::CargoBuild; -use regex_lite::Regex; use serde_json::Value; use serde_json::json; use std::collections::HashMap; @@ -48,7 +47,7 @@ async fn truncate_function_error_trims_respond_to_model() -> Result<()> { let test = builder.build(&server).await?; // Construct a very long, non-existent path to force a RespondToModel error with a large message - let long_path = "a".repeat(20_000); + let long_path = "long path text should trigger truncation".repeat(8_000); let call_id = "grep-huge-error"; let args = json!({ "pattern": "alpha", @@ -80,12 +79,16 @@ async fn truncate_function_error_trims_respond_to_model() -> Result<()> { tracing::debug!(output = %output, "truncated function error output"); - // Expect plaintext with byte-truncation marker and no omitted-lines marker + // Expect plaintext with token-based truncation marker and no omitted-lines marker assert!( serde_json::from_str::(&output).is_err(), "expected error output to be plain text", ); - let truncated_pattern = r#"(?s)^Total output lines: 1\s+.*\[\.\.\. 
output truncated to fit 11264 bytes \.\.\.\]\s*$"#; + assert!( + !output.contains("Total output lines:"), + "error output should not include line-based truncation header: {output}", + ); + let truncated_pattern = r"(?s)^unable to access `.*tokens truncated.*$"; assert_regex_match(truncated_pattern, &output); assert!( !output.contains("omitted"), @@ -95,7 +98,157 @@ async fn truncate_function_error_trims_respond_to_model() -> Result<()> { Ok(()) } -// Verifies that a standard tool call (shell) exceeding the model formatting +// Verifies that a standard tool call (shell_command) exceeding the model formatting +// limits is truncated before being sent back to the model. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn tool_call_output_configured_limit_chars_type() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + // Use a model that exposes the shell_command tool. + let mut builder = test_codex().with_model("gpt-5.1").with_config(|config| { + config.tool_output_token_limit = Some(100_000); + }); + + let fixture = builder.build(&server).await?; + + let call_id = "shell-too-large"; + let command = if cfg!(windows) { + "for ($i=1; $i -le 100000; $i++) { Write-Output $i }" + } else { + "seq 1 100000" + }; + let args = serde_json::json!({ + "command": command, + "timeout_ms": 5_000, + }); + + // First response: model tells us to run the tool; second: complete the turn. 
+ mount_sse_once( + &server, + sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), + responses::ev_completed("resp-1"), + ]), + ) + .await; + let mock2 = mount_sse_once( + &server, + sse(vec![ + responses::ev_assistant_message("msg-1", "done"), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + fixture + .submit_turn_with_policy("trigger big shell output", SandboxPolicy::DangerFullAccess) + .await?; + + // Inspect what we sent back to the model; it should contain a truncated + // function_call_output for the shell call. + let output = mock2 + .single_request() + .function_call_output_text(call_id) + .context("function_call_output present for shell call")?; + let output = output.replace("\r\n", "\n"); + + // Expect plain text (not JSON) containing the entire shell output. + assert!( + serde_json::from_str::(&output).is_err(), + "expected truncated shell output to be plain text" + ); + + assert!( + (400000..=401000).contains(&output.len()), + "we should be almost 100k tokens" + ); + + assert!( + !output.contains("tokens truncated"), + "shell output should not contain tokens truncated marker: {output}" + ); + + Ok(()) +} + +// Verifies that a standard tool call (shell_command) exceeding the model formatting +// limits is truncated before being sent back to the model. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn tool_call_output_exceeds_limit_truncated_chars_limit() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + // Use a model that exposes the shell_command tool. 
+ let mut builder = test_codex().with_model("gpt-5.1"); + + let fixture = builder.build(&server).await?; + + let call_id = "shell-too-large"; + let command = if cfg!(windows) { + "for ($i=1; $i -le 100000; $i++) { Write-Output $i }" + } else { + "seq 1 100000" + }; + let args = serde_json::json!({ + "command": command, + "timeout_ms": 5_000, + }); + + // First response: model tells us to run the tool; second: complete the turn. + mount_sse_once( + &server, + sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), + responses::ev_completed("resp-1"), + ]), + ) + .await; + let mock2 = mount_sse_once( + &server, + sse(vec![ + responses::ev_assistant_message("msg-1", "done"), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + fixture + .submit_turn_with_policy("trigger big shell output", SandboxPolicy::DangerFullAccess) + .await?; + + // Inspect what we sent back to the model; it should contain a truncated + // function_call_output for the shell call. + let output = mock2 + .single_request() + .function_call_output_text(call_id) + .context("function_call_output present for shell call")?; + let output = output.replace("\r\n", "\n"); + + // Expect plain text (not JSON) containing the entire shell output. + assert!( + serde_json::from_str::(&output).is_err(), + "expected truncated shell output to be plain text" + ); + + let truncated_pattern = r#"(?s)^Exit code: 0\nWall time: [0-9]+(?:\.[0-9]+)? seconds\nTotal output lines: 100000\nOutput:\n.*?…\d+ chars truncated….*$"#; + + assert_regex_match(truncated_pattern, &output); + + let len = output.len(); + assert!( + (9_900..=10_100).contains(&len), + "expected ~10k chars after truncation, got {len}" + ); + + Ok(()) +} + +// Verifies that a standard tool call (shell_command) exceeding the model formatting // limits is truncated before being sent back to the model. 
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn tool_call_output_exceeds_limit_truncated_for_model() -> Result<()> { @@ -103,7 +256,7 @@ async fn tool_call_output_exceeds_limit_truncated_for_model() -> Result<()> { let server = start_mock_server().await; - // Use a model that exposes the generic shell tool. + // Use a model that exposes the shell_command tool. let mut builder = test_codex().with_config(|config| { config.model = "gpt-5.1-codex".to_string(); config.model_family = @@ -112,28 +265,22 @@ async fn tool_call_output_exceeds_limit_truncated_for_model() -> Result<()> { let fixture = builder.build(&server).await?; let call_id = "shell-too-large"; - let args = if cfg!(windows) { - serde_json::json!({ - "command": [ - "powershell", - "-Command", - "for ($i=1; $i -le 400; $i++) { Write-Output $i }" - ], - "timeout_ms": 5_000, - }) + let command = if cfg!(windows) { + "for ($i=1; $i -le 100000; $i++) { Write-Output $i }" } else { - serde_json::json!({ - "command": ["/bin/sh", "-c", "seq 1 400"], - "timeout_ms": 5_000, - }) + "seq 1 100000" }; + let args = serde_json::json!({ + "command": command, + "timeout_ms": 5_000, + }); // First response: model tells us to run the tool; second: complete the turn. mount_sse_once( &server, sse(vec![ responses::ev_response_created("resp-1"), - responses::ev_function_call(call_id, "shell", &serde_json::to_string(&args)?), + responses::ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), responses::ev_completed("resp-1"), ]), ) @@ -159,14 +306,14 @@ async fn tool_call_output_exceeds_limit_truncated_for_model() -> Result<()> { .context("function_call_output present for shell call")?; let output = output.replace("\r\n", "\n"); - // Expect plain text (not JSON) with truncation markers and line elision. + // Expect plain text (not JSON) containing the entire shell output. 
assert!( serde_json::from_str::(&output).is_err(), "expected truncated shell output to be plain text" ); let truncated_pattern = r#"(?s)^Exit code: 0 -Wall time: .* seconds -Total output lines: 400 +Wall time: [0-9]+(?:\.[0-9]+)? seconds +Total output lines: 100000 Output: 1 2 @@ -174,22 +321,16 @@ Output: 4 5 6 -.* -\[\.{3} omitted 144 of 400 lines \.{3}\] - -.* -396 -397 -398 -399 -400 +.*…137224 tokens truncated.* +99999 +100000 $"#; assert_regex_match(truncated_pattern, &output); Ok(()) } -// Ensures shell tool outputs that exceed the line limit are truncated only once. +// Ensures shell_command outputs that exceed the line limit are truncated only once. #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn tool_call_output_truncated_only_once() -> Result<()> { skip_if_no_network!(Ok(())); @@ -203,27 +344,21 @@ async fn tool_call_output_truncated_only_once() -> Result<()> { }); let fixture = builder.build(&server).await?; let call_id = "shell-single-truncation"; - let args = if cfg!(windows) { - serde_json::json!({ - "command": [ - "powershell", - "-Command", - "for ($i=1; $i -le 2000; $i++) { Write-Output $i }" - ], - "timeout_ms": 5_000, - }) + let command = if cfg!(windows) { + "for ($i=1; $i -le 10000; $i++) { Write-Output $i }" } else { - serde_json::json!({ - "command": ["/bin/sh", "-c", "seq 1 2000"], - "timeout_ms": 5_000, - }) + "seq 1 10000" }; + let args = serde_json::json!({ + "command": command, + "timeout_ms": 5_000, + }); mount_sse_once( &server, sse(vec![ responses::ev_response_created("resp-1"), - responses::ev_function_call(call_id, "shell", &serde_json::to_string(&args)?), + responses::ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), responses::ev_completed("resp-1"), ]), ) @@ -246,11 +381,11 @@ async fn tool_call_output_truncated_only_once() -> Result<()> { .function_call_output_text(call_id) .context("function_call_output present for shell call")?; - let total_line_headers = output.matches("Total 
output lines:").count(); + let truncation_markers = output.matches("tokens truncated").count(); assert_eq!( - total_line_headers, 1, - "shell output should carry only one truncation header: {output}" + truncation_markers, 1, + "shell output should carry only one truncation marker: {output}" ); Ok(()) @@ -269,7 +404,7 @@ async fn mcp_tool_call_output_exceeds_limit_truncated_for_model() -> Result<()> let tool_name = format!("mcp__{server_name}__echo"); // Build a very large message to exceed 10KiB once serialized. - let large_msg = "long-message-with-newlines-".repeat(600); + let large_msg = "long-message-with-newlines-".repeat(6000); let args_json = serde_json::json!({ "message": large_msg }); mount_sse_once( @@ -318,6 +453,7 @@ async fn mcp_tool_call_output_exceeds_limit_truncated_for_model() -> Result<()> disabled_tools: None, }, ); + config.tool_output_token_limit = Some(500); }); let fixture = builder.build(&server).await?; @@ -334,22 +470,14 @@ async fn mcp_tool_call_output_exceeds_limit_truncated_for_model() -> Result<()> .function_call_output_text(call_id) .context("function_call_output present for rmcp call")?; - // Expect plain text with byte-based truncation marker. assert!( - serde_json::from_str::(&output).is_err(), - "expected truncated MCP output to be plain text" - ); - assert!( - output.starts_with("Total output lines: 1\n\n{"), - "expected total line header and JSON head, got: {output}" + !output.contains("Total output lines:"), + "MCP output should not include line-based truncation header: {output}" ); - let byte_marker = Regex::new(r"\[\.\.\. 
output truncated to fit 11264 bytes \.\.\.\]") - .expect("compile regex"); - assert!( - byte_marker.is_match(&output), - "expected byte truncation marker, got: {output}" - ); + let truncated_pattern = r#"(?s)^\{"echo":\s*"ECHOING: long-message-with-newlines-.*tokens truncated.*long-message-with-newlines-.*$"#; + assert_regex_match(truncated_pattern, &output); + assert!(output.len() < 2500, "{}", output.len()); Ok(()) } @@ -453,3 +581,271 @@ async fn mcp_image_output_preserves_image_and_no_text_summary() -> Result<()> { Ok(()) } + +// Token-based policy should report token counts even when truncation is byte-estimated. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn token_policy_marker_reports_tokens() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex().with_config(|config| { + config.model = "gpt-5.1-codex".to_string(); // token policy + config.model_family = + find_family_for_model("gpt-5.1-codex").expect("model family for gpt-5.1-codex"); + config.tool_output_token_limit = Some(50); // small budget to force truncation + }); + let fixture = builder.build(&server).await?; + + let call_id = "shell-token-marker"; + let args = json!({ + "command": "seq 1 150", + "timeout_ms": 5_000, + }); + + mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + ) + .await; + let done_mock = mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ) + .await; + + fixture + .submit_turn_with_policy("run the shell tool", SandboxPolicy::DangerFullAccess) + .await?; + + let output = done_mock + .single_request() + .function_call_output_text(call_id) + .context("shell output present")?; + + let pattern = r"(?s)^Exit code: 0\nWall time: [0-9]+(?:\.[0-9]+)? 
seconds\nTotal output lines: 150\nOutput:\n1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n13\n14\n15\n16\n17\n18\n19.*tokens truncated.*129\n130\n131\n132\n133\n134\n135\n136\n137\n138\n139\n140\n141\n142\n143\n144\n145\n146\n147\n148\n149\n150\n$"; + + assert_regex_match(pattern, &output); + + Ok(()) +} + +// Byte-based policy should report bytes removed. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn byte_policy_marker_reports_bytes() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex().with_config(|config| { + config.model = "gpt-5.1".to_string(); // byte policy + config.model_family = find_family_for_model("gpt-5.1").expect("model family for gpt-5.1"); + config.tool_output_token_limit = Some(50); // ~200 byte cap + }); + let fixture = builder.build(&server).await?; + + let call_id = "shell-byte-marker"; + let args = json!({ + "command": "seq 1 150", + "timeout_ms": 5_000, + }); + + mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + ) + .await; + let done_mock = mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ) + .await; + + fixture + .submit_turn_with_policy("run the shell tool", SandboxPolicy::DangerFullAccess) + .await?; + + let output = done_mock + .single_request() + .function_call_output_text(call_id) + .context("shell output present")?; + + let pattern = r"(?s)^Exit code: 0\nWall time: [0-9]+(?:\.[0-9]+)? 
seconds\nTotal output lines: 150\nOutput:\n1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n13\n14\n15\n16\n17\n18\n19.*chars truncated.*129\n130\n131\n132\n133\n134\n135\n136\n137\n138\n139\n140\n141\n142\n143\n144\n145\n146\n147\n148\n149\n150\n$"; + + assert_regex_match(pattern, &output); + + Ok(()) +} + +// shell_command output should remain intact when the config opts into a large token budget. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn shell_command_output_not_truncated_with_custom_limit() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex().with_config(|config| { + config.model = "gpt-5.1-codex".to_string(); + config.model_family = + find_family_for_model("gpt-5.1-codex").expect("model family for gpt-5.1-codex"); + config.tool_output_token_limit = Some(50_000); // ample budget + }); + let fixture = builder.build(&server).await?; + + let call_id = "shell-no-trunc"; + let args = json!({ + "command": "seq 1 1000", + "timeout_ms": 5_000, + }); + let expected_body: String = (1..=1000).map(|i| format!("{i}\n")).collect(); + + mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + ) + .await; + let done_mock = mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ) + .await; + + fixture + .submit_turn_with_policy( + "run big output without truncation", + SandboxPolicy::DangerFullAccess, + ) + .await?; + + let output = done_mock + .single_request() + .function_call_output_text(call_id) + .context("shell output present")?; + + assert!( + output.ends_with(&expected_body), + "expected entire shell output when budget increased: {output}" + ); + assert!( + !output.contains("truncated"), + "output should remain untruncated with ample budget" + ); + + Ok(()) +} + +// MCP server output should also 
remain intact when the config increases the token limit. +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn mcp_tool_call_output_not_truncated_with_custom_limit() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + let call_id = "rmcp-untruncated"; + let server_name = "rmcp"; + let tool_name = format!("mcp__{server_name}__echo"); + let large_msg = "a".repeat(80_000); + let args_json = serde_json::json!({ "message": large_msg }); + + mount_sse_once( + &server, + sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_function_call(call_id, &tool_name, &args_json.to_string()), + responses::ev_completed("resp-1"), + ]), + ) + .await; + let mock2 = mount_sse_once( + &server, + sse(vec![ + responses::ev_assistant_message("msg-1", "rmcp echo tool completed."), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + let rmcp_test_server_bin = CargoBuild::new() + .package("codex-rmcp-client") + .bin("test_stdio_server") + .run()? 
+ .path() + .to_string_lossy() + .into_owned(); + + let mut builder = test_codex().with_config(move |config| { + config.features.enable(Feature::RmcpClient); + config.tool_output_token_limit = Some(50_000); + config.mcp_servers.insert( + server_name.to_string(), + codex_core::config::types::McpServerConfig { + transport: codex_core::config::types::McpServerTransportConfig::Stdio { + command: rmcp_test_server_bin, + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: None, + }, + enabled: true, + startup_timeout_sec: Some(std::time::Duration::from_secs(10)), + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + }, + ); + }); + let fixture = builder.build(&server).await?; + + fixture + .submit_turn_with_policy( + "call the rmcp echo tool with a very large message", + SandboxPolicy::ReadOnly, + ) + .await?; + + let output = mock2 + .single_request() + .function_call_output_text(call_id) + .context("function_call_output present for rmcp call")?; + + let parsed: Value = serde_json::from_str(&output)?; + assert_eq!( + output.len(), + 80031, + "parsed MCP output should retain its serialized length" + ); + let expected_echo = format!("ECHOING: {large_msg}"); + let echo_str = parsed["echo"] + .as_str() + .context("echo field should be a string in rmcp echo output")?; + assert_eq!( + echo_str.len(), + expected_echo.len(), + "echo length should match" + ); + assert_eq!(echo_str, expected_echo); + assert!( + !output.contains("truncated"), + "output should not include truncation markers when limit is raised: {output}" + ); + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/unified_exec.rs b/codex-rs/core/tests/suite/unified_exec.rs index 970be5277..aed4cecef 100644 --- a/codex-rs/core/tests/suite/unified_exec.rs +++ b/codex-rs/core/tests/suite/unified_exec.rs @@ -26,9 +26,11 @@ use core_test_support::test_codex::TestCodex; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; use 
core_test_support::wait_for_event_match; +use core_test_support::wait_for_event_with_timeout; use regex_lite::Regex; use serde_json::Value; use serde_json::json; +use tokio::time::Duration; fn extract_output_text(item: &Value) -> Option<&str> { item.get("output").and_then(|value| match value { @@ -814,7 +816,7 @@ async fn exec_command_reports_chunk_and_exit_metadata() -> Result<()> { let call_id = "uexec-metadata"; let args = serde_json::json!({ - "cmd": "printf 'abcdefghijklmnopqrstuvwxyz'", + "cmd": "printf 'token one token two token three token four token five token six token seven'", "yield_time_ms": 500, "max_output_tokens": 6, }); @@ -902,6 +904,98 @@ async fn exec_command_reports_chunk_and_exit_metadata() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_respects_early_exit_notifications() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::UnifiedExec); + }); + let TestCodex { + codex, + cwd, + session_configured, + .. 
+ } = builder.build(&server).await?; + + let call_id = "uexec-early-exit"; + let args = serde_json::json!({ + "cmd": "sleep 0.05", + "yield_time_ms": 31415, + }); + + let responses = vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "exec_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "watch early exit timing".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await; + + let requests = server.received_requests().await.expect("recorded requests"); + assert!(!requests.is_empty(), "expected at least one POST request"); + + let bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let outputs = collect_tool_outputs(&bodies)?; + let output = outputs + .get(call_id) + .expect("missing early exit unified_exec output"); + + assert!( + output.session_id.is_none(), + "short-lived process should not keep a session alive" + ); + assert_eq!( + output.exit_code, + Some(0), + "short-lived process should exit successfully" + ); + + let wall_time = output.wall_time_seconds; + assert!( + wall_time < 0.75, + "wall_time should reflect early exit rather than the full yield time; got {wall_time}" + ); + assert!( + output.output.is_empty(), + "sleep command should not emit output, got {:?}", + output.output + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn 
write_stdin_returns_exit_metadata_and_clears_session() -> Result<()> { skip_if_no_network!(Ok(())); @@ -1295,7 +1389,7 @@ async fn unified_exec_streams_after_lagged_output() -> Result<()> { import sys import time -chunk = b'x' * (1 << 20) +chunk = b'long content here to trigger truncation' * (1 << 10) for _ in range(4): sys.stdout.buffer.write(chunk) sys.stdout.flush() @@ -1365,8 +1459,13 @@ PY summary: ReasoningSummary::Auto, }) .await?; - - wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await; + // This is a worst case scenario for the truncate logic. + wait_for_event_with_timeout( + &codex, + |event| matches!(event, EventMsg::TaskComplete(_)), + Duration::from_secs(10), + ) + .await; let requests = server.received_requests().await.expect("recorded requests"); assert!(!requests.is_empty(), "expected at least one POST request"); @@ -1523,14 +1622,15 @@ async fn unified_exec_formats_large_output_summary() -> Result<()> { } = builder.build(&server).await?; let script = r#"python3 - <<'PY' -for i in range(300): - print(f"line-{i}") +import sys +sys.stdout.write("token token \n" * 5000) PY "#; let call_id = "uexec-large-output"; let args = serde_json::json!({ "cmd": script, + "max_output_tokens": 100, "yield_time_ms": 500, }); @@ -1577,15 +1677,14 @@ PY let outputs = collect_tool_outputs(&bodies)?; let large_output = outputs.get(call_id).expect("missing large output summary"); - assert_regex_match( - concat!( - r"(?s)", - r"line-0.*?", - r"\[\.{3} omitted \d+ of \d+ lines \.{3}\].*?", - r"line-299", - ), - &large_output.output, - ); + let output_text = large_output.output.replace("\r\n", "\n"); + let truncated_pattern = r"(?s)^Total output lines: \d+\n\n(token token \n){5,}.*…\d+ tokens truncated….*(token token \n){5,}$"; + assert_regex_match(truncated_pattern, &output_text); + + let original_tokens = large_output + .original_token_count + .expect("missing original_token_count for large output summary"); + assert!(original_tokens > 0); 
Ok(()) } diff --git a/codex-rs/core/tests/suite/user_shell_cmd.rs b/codex-rs/core/tests/suite/user_shell_cmd.rs index 0d42c45c1..f6e53b95e 100644 --- a/codex-rs/core/tests/suite/user_shell_cmd.rs +++ b/codex-rs/core/tests/suite/user_shell_cmd.rs @@ -207,10 +207,16 @@ async fn user_shell_command_history_is_persisted_and_shared_with_model() -> anyh } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[cfg(not(target_os = "windows"))] // TODO: unignore on windows async fn user_shell_command_output_is_truncated_in_history() -> anyhow::Result<()> { let server = responses::start_mock_server().await; - let mut builder = core_test_support::test_codex::test_codex(); - let test = builder.build(&server).await?; + let builder = core_test_support::test_codex::test_codex(); + let test = builder + .with_config(|config| { + config.tool_output_token_limit = Some(100); + }) + .build(&server) + .await?; #[cfg(windows)] let command = r#"for ($i=1; $i -le 400; $i++) { Write-Output $i }"#.to_string(); @@ -249,10 +255,10 @@ async fn user_shell_command_output_is_truncated_in_history() -> anyhow::Result<( .expect("command message recorded in request"); let command_message = command_message.replace("\r\n", "\n"); - let head = (1..=128).map(|i| format!("{i}\n")).collect::(); - let tail = (273..=400).map(|i| format!("{i}\n")).collect::(); + let head = (1..=69).map(|i| format!("{i}\n")).collect::(); + let tail = (352..=400).map(|i| format!("{i}\n")).collect::(); let truncated_body = - format!("Total output lines: 400\n\n{head}\n[... 
omitted 144 of 400 lines ...]\n\n{tail}"); + format!("Total output lines: 400\n\n{head}70…273 tokens truncated…351\n{tail}"); let escaped_command = escape(&command); let escaped_truncated_body = escape(&truncated_body); let expected_pattern = format!( @@ -270,6 +276,7 @@ async fn user_shell_command_is_truncated_only_once() -> anyhow::Result<()> { let server = start_mock_server().await; let mut builder = test_codex().with_config(|config| { + config.tool_output_token_limit = Some(100); config.model = "gpt-5.1-codex".to_string(); config.model_family = find_family_for_model("gpt-5.1-codex").expect("gpt-5.1-codex is a model family"); @@ -279,16 +286,12 @@ async fn user_shell_command_is_truncated_only_once() -> anyhow::Result<()> { let call_id = "user-shell-double-truncation"; let args = if cfg!(windows) { serde_json::json!({ - "command": [ - "powershell", - "-Command", - "for ($i=1; $i -le 2000; $i++) { Write-Output $i }" - ], + "command": "for ($i=1; $i -le 2000; $i++) { Write-Output $i }", "timeout_ms": 5_000, }) } else { serde_json::json!({ - "command": ["/bin/sh", "-c", "seq 1 2000"], + "command": "seq 1 2000", "timeout_ms": 5_000, }) }; @@ -297,7 +300,7 @@ async fn user_shell_command_is_truncated_only_once() -> anyhow::Result<()> { &server, sse(vec![ ev_response_created("resp-1"), - ev_function_call(call_id, "shell", &serde_json::to_string(&args)?), + ev_function_call(call_id, "shell_command", &serde_json::to_string(&args)?), ev_completed("resp-1"), ]), ) @@ -312,19 +315,22 @@ async fn user_shell_command_is_truncated_only_once() -> anyhow::Result<()> { .await; fixture - .submit_turn_with_policy("trigger big shell output", SandboxPolicy::DangerFullAccess) + .submit_turn_with_policy( + "trigger big shell_command output", + SandboxPolicy::DangerFullAccess, + ) .await?; let output = mock2 .single_request() .function_call_output_text(call_id) - .context("function_call_output present for shell call")?; + .context("function_call_output present for shell_command call")?; 
let truncation_headers = output.matches("Total output lines:").count(); assert_eq!( truncation_headers, 1, - "shell output should carry only one truncation header: {output}" + "shell_command output should carry only one truncation header: {output}" ); Ok(()) diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml new file mode 100644 index 000000000..24c13e0e2 --- /dev/null +++ b/codex-rs/exec-server/Cargo.toml @@ -0,0 +1,58 @@ +[package] +edition = "2024" +name = "codex-exec-server" +version = { workspace = true } + +[[bin]] +name = "codex-execve-wrapper" +path = "src/bin/main_execve_wrapper.rs" + +[[bin]] +name = "codex-exec-mcp-server" +path = "src/bin/main_mcp_server.rs" + +[lib] +name = "codex_exec_server" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +clap = { workspace = true, features = ["derive"] } +codex-core = { workspace = true } +libc = { workspace = true } +path-absolutize = { workspace = true } +rmcp = { workspace = true, default-features = false, features = [ + "auth", + "elicitation", + "base64", + "client", + "macros", + "schemars", + "server", + "transport-child-process", + "transport-streamable-http-client-reqwest", + "transport-streamable-http-server", + "transport-io", +] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +shlex = { workspace = true } +socket2 = { workspace = true } +tokio = { workspace = true, features = [ + "io-std", + "macros", + "process", + "rt-multi-thread", + "signal", +] } +tokio-util = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } + +[dev-dependencies] +pretty_assertions = { workspace = true } +tempfile = { workspace = true } diff --git a/codex-rs/exec-server/src/bin/main_execve_wrapper.rs b/codex-rs/exec-server/src/bin/main_execve_wrapper.rs new file mode 100644 index 
000000000..3ab346e8e --- /dev/null +++ b/codex-rs/exec-server/src/bin/main_execve_wrapper.rs @@ -0,0 +1,8 @@ +#[cfg(not(unix))] +fn main() { + eprintln!("codex-execve-wrapper is only implemented for UNIX"); + std::process::exit(1); +} + +#[cfg(unix)] +pub use codex_exec_server::main_execve_wrapper as main; diff --git a/codex-rs/exec-server/src/bin/main_mcp_server.rs b/codex-rs/exec-server/src/bin/main_mcp_server.rs new file mode 100644 index 000000000..6c75ae423 --- /dev/null +++ b/codex-rs/exec-server/src/bin/main_mcp_server.rs @@ -0,0 +1,8 @@ +#[cfg(not(unix))] +fn main() { + eprintln!("codex-exec-mcp-server is only implemented for UNIX"); + std::process::exit(1); +} + +#[cfg(unix)] +pub use codex_exec_server::main_mcp_server as main; diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs new file mode 100644 index 000000000..adec09d4d --- /dev/null +++ b/codex-rs/exec-server/src/lib.rs @@ -0,0 +1,8 @@ +#[cfg(unix)] +mod posix; + +#[cfg(unix)] +pub use posix::main_execve_wrapper; + +#[cfg(unix)] +pub use posix::main_mcp_server; diff --git a/codex-rs/exec-server/src/posix.rs b/codex-rs/exec-server/src/posix.rs new file mode 100644 index 000000000..b4dd0fbf4 --- /dev/null +++ b/codex-rs/exec-server/src/posix.rs @@ -0,0 +1,170 @@ +//! This is an MCP that implements an alternative `shell` tool with fine-grained privilege +//! escalation based on a per-exec() policy. +//! +//! We spawn Bash process inside a sandbox. The Bash we spawn is patched to allow us to intercept +//! every exec() call it makes by invoking a wrapper program and passing in the arguments it would +//! have passed to exec(). The Bash process (and its descendants) inherit a communication socket +//! from us, and we give its fd number in the CODEX_ESCALATE_SOCKET environment variable. +//! +//! When we intercept an exec() call, we send a message over the socket back to the main +//! MCP process. The MCP process can then decide whether to allow the exec() call to proceed +//! 
or to escalate privileges and run the requested command with elevated permissions. In the +//! latter case, we send a message back to the child requesting that it forward its open FDs to us. +//! We then execute the requested command on its behalf, patching in the forwarded FDs. +//! +//! +//! ### The privilege escalation flow +//! +//! Child MCP Bash Escalate Helper +//! | +//! o----->o +//! | | +//! | o--(exec)-->o +//! | | | +//! |o<-(EscalateReq)--o +//! || | | +//! |o--(Escalate)---->o +//! || | | +//! |o<---------(fds)--o +//! || | | +//! o<-----o | | +//! | || | | +//! x----->o | | +//! || | | +//! |x--(exit code)--->o +//! | | | +//! | o<--(exit)--x +//! | | +//! o<-----x +//! +//! ### The non-escalation flow +//! +//! MCP Bash Escalate Helper Child +//! | +//! o----->o +//! | | +//! | o--(exec)-->o +//! | | | +//! |o<-(EscalateReq)--o +//! || | | +//! |o-(Run)---------->o +//! | | | +//! | | x--(exec)-->o +//! | | | +//! | o<--------------(exit)--x +//! | | +//! o<-----x +//! +use std::path::Path; +use std::path::PathBuf; + +use clap::Parser; +use tracing_subscriber::EnvFilter; +use tracing_subscriber::{self}; + +use crate::posix::mcp_escalation_policy::ExecPolicyOutcome; + +mod escalate_client; +mod escalate_protocol; +mod escalate_server; +mod escalation_policy; +mod mcp; +mod mcp_escalation_policy; +mod socket; +mod stopwatch; + +/// Default value of --execve option relative to the current executable. +/// Note this must match the name of the binary as specified in Cargo.toml. +const CODEX_EXECVE_WRAPPER_EXE_NAME: &str = "codex-execve-wrapper"; + +#[derive(Parser)] +struct McpServerCli { + /// Executable to delegate execve(2) calls to in Bash. + #[arg(long = "execve")] + execve_wrapper: Option, + + /// Path to Bash that has been patched to support execve() wrapping. 
+ #[arg(long = "bash")] + bash_path: Option, +} + +#[tokio::main] +pub async fn main_mcp_server() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_writer(std::io::stderr) + .with_ansi(false) + .init(); + + let cli = McpServerCli::parse(); + let execve_wrapper = match cli.execve_wrapper { + Some(path) => path, + None => { + let cwd = std::env::current_exe()?; + cwd.parent() + .map(|p| p.join(CODEX_EXECVE_WRAPPER_EXE_NAME)) + .ok_or_else(|| { + anyhow::anyhow!("failed to determine execve wrapper path from current exe") + })? + } + }; + let bash_path = match cli.bash_path { + Some(path) => path, + None => mcp::get_bash_path()?, + }; + + tracing::info!("Starting MCP server"); + let service = mcp::serve(bash_path, execve_wrapper, dummy_exec_policy) + .await + .inspect_err(|e| { + tracing::error!("serving error: {:?}", e); + })?; + + service.waiting().await?; + Ok(()) +} + +#[derive(Parser)] +pub struct ExecveWrapperCli { + file: String, + + #[arg(trailing_var_arg = true)] + argv: Vec, +} + +#[tokio::main] +pub async fn main_execve_wrapper() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_writer(std::io::stderr) + .with_ansi(false) + .init(); + + let ExecveWrapperCli { file, argv } = ExecveWrapperCli::parse(); + let exit_code = escalate_client::run(file, argv).await?; + std::process::exit(exit_code); +} + +// TODO: replace with execpolicy + +fn dummy_exec_policy(file: &Path, argv: &[String], _workdir: &Path) -> ExecPolicyOutcome { + if file.ends_with("rm") { + ExecPolicyOutcome::Forbidden + } else if file.ends_with("git") { + ExecPolicyOutcome::Prompt { + run_with_escalated_permissions: false, + } + } else if file == Path::new("/opt/homebrew/bin/gh") + && let [_, arg1, arg2, ..] 
= argv + && arg1 == "issue" + && arg2 == "list" + { + ExecPolicyOutcome::Allow { + run_with_escalated_permissions: true, + } + } else { + ExecPolicyOutcome::Allow { + run_with_escalated_permissions: false, + } + } +} diff --git a/codex-rs/exec-server/src/posix/escalate_client.rs b/codex-rs/exec-server/src/posix/escalate_client.rs new file mode 100644 index 000000000..bea4b6fa5 --- /dev/null +++ b/codex-rs/exec-server/src/posix/escalate_client.rs @@ -0,0 +1,109 @@ +use std::io; +use std::os::fd::AsRawFd; +use std::os::fd::FromRawFd as _; +use std::os::fd::OwnedFd; + +use anyhow::Context as _; + +use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR; +use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR; +use crate::posix::escalate_protocol::EscalateAction; +use crate::posix::escalate_protocol::EscalateRequest; +use crate::posix::escalate_protocol::EscalateResponse; +use crate::posix::escalate_protocol::SuperExecMessage; +use crate::posix::escalate_protocol::SuperExecResult; +use crate::posix::socket::AsyncDatagramSocket; +use crate::posix::socket::AsyncSocket; + +fn get_escalate_client() -> anyhow::Result { + // TODO: we should defensively require only calling this once, since AsyncSocket will take ownership of the fd. + let client_fd = std::env::var(ESCALATE_SOCKET_ENV_VAR)?.parse::()?; + if client_fd < 0 { + return Err(anyhow::anyhow!( + "{ESCALATE_SOCKET_ENV_VAR} is not a valid file descriptor: {client_fd}" + )); + } + Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?) 
+} + +pub(crate) async fn run(file: String, argv: Vec) -> anyhow::Result { + let handshake_client = get_escalate_client()?; + let (server, client) = AsyncSocket::pair()?; + const HANDSHAKE_MESSAGE: [u8; 1] = [0]; + handshake_client + .send_with_fds(&HANDSHAKE_MESSAGE, &[server.into_inner().into()]) + .await + .context("failed to send handshake datagram")?; + let env = std::env::vars() + .filter(|(k, _)| { + !matches!( + k.as_str(), + ESCALATE_SOCKET_ENV_VAR | BASH_EXEC_WRAPPER_ENV_VAR + ) + }) + .collect(); + client + .send(EscalateRequest { + file: file.clone().into(), + argv: argv.clone(), + workdir: std::env::current_dir()?, + env, + }) + .await + .context("failed to send EscalateRequest")?; + let message = client.receive::().await?; + match message.action { + EscalateAction::Escalate => { + // TODO: maybe we should send ALL open FDs (except the escalate client)? + let fds_to_send = [ + unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) }, + unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) }, + unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) }, + ]; + + // TODO: also forward signals over the super-exec socket + + client + .send_with_fds( + SuperExecMessage { + fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(), + }, + &fds_to_send, + ) + .await + .context("failed to send SuperExecMessage")?; + let SuperExecResult { exit_code } = client.receive::().await?; + Ok(exit_code) + } + EscalateAction::Run => { + // We avoid std::process::Command here because we want to be as transparent as + // possible. std::os::unix::process::CommandExt has .exec() but it does some funky + // stuff with signal masks and dup2() on its standard FDs, which we don't want. 
+ use std::ffi::CString; + let file = CString::new(file).context("NUL in file")?; + + let argv_cstrs: Vec = argv + .iter() + .map(|s| CString::new(s.as_str()).context("NUL in argv")) + .collect::, _>>()?; + + let mut argv: Vec<*const libc::c_char> = + argv_cstrs.iter().map(|s| s.as_ptr()).collect(); + argv.push(std::ptr::null()); + + let err = unsafe { + libc::execv(file.as_ptr(), argv.as_ptr()); + std::io::Error::last_os_error() + }; + + Err(err.into()) + } + EscalateAction::Deny { reason } => { + match reason { + Some(reason) => eprintln!("Execution denied: {reason}"), + None => eprintln!("Execution denied"), + } + Ok(1) + } + } +} diff --git a/codex-rs/exec-server/src/posix/escalate_protocol.rs b/codex-rs/exec-server/src/posix/escalate_protocol.rs new file mode 100644 index 000000000..e3fc27d07 --- /dev/null +++ b/codex-rs/exec-server/src/posix/escalate_protocol.rs @@ -0,0 +1,51 @@ +use std::collections::HashMap; +use std::os::fd::RawFd; +use std::path::PathBuf; + +use serde::Deserialize; +use serde::Serialize; + +/// 'exec-server escalate' reads this to find the inherited FD for the escalate socket. +pub(super) const ESCALATE_SOCKET_ENV_VAR: &str = "CODEX_ESCALATE_SOCKET"; + +/// The patched bash uses this to wrap exec() calls. +pub(super) const BASH_EXEC_WRAPPER_ENV_VAR: &str = "BASH_EXEC_WRAPPER"; + +/// The client sends this to the server to request an exec() call. +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)] +pub(super) struct EscalateRequest { + /// The absolute path to the executable to run, i.e. the first arg to exec. + pub(super) file: PathBuf, + /// The argv, including the program name (argv[0]). + pub(super) argv: Vec, + pub(super) workdir: PathBuf, + pub(super) env: HashMap, +} + +/// The server sends this to the client to respond to an exec() request. 
+#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)] +pub(super) struct EscalateResponse { + pub(super) action: EscalateAction, +} + +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)] +pub(super) enum EscalateAction { + /// The command should be run directly by the client. + Run, + /// The command should be escalated to the server for execution. + Escalate, + /// The command should not be executed. + Deny { reason: Option }, +} + +/// The client sends this to the server to forward its open FDs. +#[derive(Clone, Serialize, Deserialize, Debug)] +pub(super) struct SuperExecMessage { + pub(super) fds: Vec, +} + +/// The server responds when the exec()'d command has exited. +#[derive(Clone, Serialize, Deserialize, Debug)] +pub(super) struct SuperExecResult { + pub(super) exit_code: i32, +} diff --git a/codex-rs/exec-server/src/posix/escalate_server.rs b/codex-rs/exec-server/src/posix/escalate_server.rs new file mode 100644 index 000000000..784562f2f --- /dev/null +++ b/codex-rs/exec-server/src/posix/escalate_server.rs @@ -0,0 +1,319 @@ +use std::collections::HashMap; +use std::os::fd::AsRawFd; +use std::path::PathBuf; +use std::process::Stdio; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Context as _; +use path_absolutize::Absolutize as _; + +use codex_core::exec::SandboxType; +use codex_core::exec::process_exec_tool_call; +use codex_core::get_platform_sandbox; +use codex_core::protocol::SandboxPolicy; +use tokio::process::Command; +use tokio_util::sync::CancellationToken; + +use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR; +use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR; +use crate::posix::escalate_protocol::EscalateAction; +use crate::posix::escalate_protocol::EscalateRequest; +use crate::posix::escalate_protocol::EscalateResponse; +use crate::posix::escalate_protocol::SuperExecMessage; +use crate::posix::escalate_protocol::SuperExecResult; +use crate::posix::escalation_policy::EscalationPolicy; 
+use crate::posix::socket::AsyncDatagramSocket; +use crate::posix::socket::AsyncSocket; +use codex_core::exec::ExecExpiration; + +pub(crate) struct EscalateServer { + bash_path: PathBuf, + execve_wrapper: PathBuf, + policy: Arc, +} + +impl EscalateServer { + pub fn new

(bash_path: PathBuf, execve_wrapper: PathBuf, policy: P) -> Self + where + P: EscalationPolicy + Send + Sync + 'static, + { + Self { + bash_path, + execve_wrapper, + policy: Arc::new(policy), + } + } + + pub async fn exec( + &self, + command: String, + env: HashMap, + workdir: PathBuf, + cancel_rx: CancellationToken, + ) -> anyhow::Result { + let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?; + let client_socket = escalate_client.into_inner(); + client_socket.set_cloexec(false)?; + + let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy.clone())); + let mut env = env.clone(); + env.insert( + ESCALATE_SOCKET_ENV_VAR.to_string(), + client_socket.as_raw_fd().to_string(), + ); + env.insert( + BASH_EXEC_WRAPPER_ENV_VAR.to_string(), + self.execve_wrapper.to_string_lossy().to_string(), + ); + + // TODO: use the sandbox policy and cwd from the calling client. + // Note that sandbox_cwd is ignored for ReadOnly, but needs to be legit + // for `SandboxPolicy::WorkspaceWrite`. 
+ let sandbox_policy = SandboxPolicy::ReadOnly; + let sandbox_cwd = PathBuf::from("/__NONEXISTENT__"); + + let result = process_exec_tool_call( + codex_core::exec::ExecParams { + command: vec![ + self.bash_path.to_string_lossy().to_string(), + "-c".to_string(), + command, + ], + cwd: PathBuf::from(&workdir), + expiration: ExecExpiration::Cancellation(cancel_rx), + env, + with_escalated_permissions: None, + justification: None, + arg0: None, + }, + get_platform_sandbox().unwrap_or(SandboxType::None), + &sandbox_policy, + &sandbox_cwd, + &None, + None, + ) + .await?; + escalate_task.abort(); + let result = ExecResult { + exit_code: result.exit_code, + output: result.aggregated_output.text, + duration: result.duration, + timed_out: result.timed_out, + }; + Ok(result) + } +} + +async fn escalate_task( + socket: AsyncDatagramSocket, + policy: Arc, +) -> anyhow::Result<()> { + loop { + let (_, mut fds) = socket.receive_with_fds().await?; + if fds.len() != 1 { + tracing::error!("expected 1 fd in datagram handshake, got {}", fds.len()); + continue; + } + let stream_socket = AsyncSocket::from_fd(fds.remove(0))?; + let policy = policy.clone(); + tokio::spawn(async move { + if let Err(err) = handle_escalate_session_with_policy(stream_socket, policy).await { + tracing::error!("escalate session failed: {err:?}"); + } + }); + } +} + +#[derive(Debug)] +pub(crate) struct ExecResult { + pub(crate) exit_code: i32, + pub(crate) output: String, + pub(crate) duration: Duration, + pub(crate) timed_out: bool, +} + +async fn handle_escalate_session_with_policy( + socket: AsyncSocket, + policy: Arc, +) -> anyhow::Result<()> { + let EscalateRequest { + file, + argv, + workdir, + env, + } = socket.receive::().await?; + let file = PathBuf::from(&file).absolutize()?.into_owned(); + let workdir = PathBuf::from(&workdir).absolutize()?.into_owned(); + let action = policy + .determine_action(file.as_path(), &argv, &workdir) + .await?; + + tracing::debug!("decided {action:?} for {file:?} {argv:?} 
{workdir:?}"); + + match action { + EscalateAction::Run => { + socket + .send(EscalateResponse { + action: EscalateAction::Run, + }) + .await?; + } + EscalateAction::Escalate => { + socket + .send(EscalateResponse { + action: EscalateAction::Escalate, + }) + .await?; + let (msg, fds) = socket + .receive_with_fds::() + .await + .context("failed to receive SuperExecMessage")?; + if fds.len() != msg.fds.len() { + return Err(anyhow::anyhow!( + "mismatched number of fds in SuperExecMessage: {} in the message, {} from the control message", + msg.fds.len(), + fds.len() + )); + } + + if msg + .fds + .iter() + .any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd)) + { + return Err(anyhow::anyhow!( + "overlapping fds not yet supported in SuperExecMessage" + )); + } + + let mut command = Command::new(file); + command + .args(&argv[1..]) + .arg0(argv[0].clone()) + .envs(&env) + .current_dir(&workdir) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + unsafe { + command.pre_exec(move || { + for (dst_fd, src_fd) in msg.fds.iter().zip(&fds) { + libc::dup2(src_fd.as_raw_fd(), *dst_fd); + } + Ok(()) + }); + } + let mut child = command.spawn()?; + let exit_status = child.wait().await?; + socket + .send(SuperExecResult { + exit_code: exit_status.code().unwrap_or(127), + }) + .await?; + } + EscalateAction::Deny { reason } => { + socket + .send(EscalateResponse { + action: EscalateAction::Deny { reason }, + }) + .await?; + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use std::collections::HashMap; + use std::path::Path; + use std::path::PathBuf; + + struct DeterministicEscalationPolicy { + action: EscalateAction, + } + + #[async_trait::async_trait] + impl EscalationPolicy for DeterministicEscalationPolicy { + async fn determine_action( + &self, + _file: &Path, + _argv: &[String], + _workdir: &Path, + ) -> Result { + Ok(self.action.clone()) + } + } + + #[tokio::test] + async fn 
handle_escalate_session_respects_run_in_sandbox_decision() -> anyhow::Result<()> { + let (server, client) = AsyncSocket::pair()?; + let server_task = tokio::spawn(handle_escalate_session_with_policy( + server, + Arc::new(DeterministicEscalationPolicy { + action: EscalateAction::Run, + }), + )); + + client + .send(EscalateRequest { + file: PathBuf::from("/bin/echo"), + argv: vec!["echo".to_string()], + workdir: PathBuf::from("/tmp"), + env: HashMap::new(), + }) + .await?; + + let response = client.receive::().await?; + assert_eq!( + EscalateResponse { + action: EscalateAction::Run, + }, + response + ); + server_task.await? + } + + #[tokio::test] + async fn handle_escalate_session_executes_escalated_command() -> anyhow::Result<()> { + let (server, client) = AsyncSocket::pair()?; + let server_task = tokio::spawn(handle_escalate_session_with_policy( + server, + Arc::new(DeterministicEscalationPolicy { + action: EscalateAction::Escalate, + }), + )); + + client + .send(EscalateRequest { + file: PathBuf::from("/bin/sh"), + argv: vec![ + "sh".to_string(), + "-c".to_string(), + r#"if [ "$KEY" = VALUE ]; then exit 42; else exit 1; fi"#.to_string(), + ], + workdir: std::env::current_dir()?, + env: HashMap::from([("KEY".to_string(), "VALUE".to_string())]), + }) + .await?; + + let response = client.receive::().await?; + assert_eq!( + EscalateResponse { + action: EscalateAction::Escalate, + }, + response + ); + + client + .send_with_fds(SuperExecMessage { fds: Vec::new() }, &[]) + .await?; + + let result = client.receive::().await?; + assert_eq!(42, result.exit_code); + + server_task.await? 
+ } +} diff --git a/codex-rs/exec-server/src/posix/escalation_policy.rs b/codex-rs/exec-server/src/posix/escalation_policy.rs new file mode 100644 index 000000000..a7fcc4f62 --- /dev/null +++ b/codex-rs/exec-server/src/posix/escalation_policy.rs @@ -0,0 +1,14 @@ +use std::path::Path; + +use crate::posix::escalate_protocol::EscalateAction; + +/// Decides what action to take in response to an execve request from a client. +#[async_trait::async_trait] +pub(crate) trait EscalationPolicy: Send + Sync { + async fn determine_action( + &self, + file: &Path, + argv: &[String], + workdir: &Path, + ) -> Result; +} diff --git a/codex-rs/exec-server/src/posix/mcp.rs b/codex-rs/exec-server/src/posix/mcp.rs new file mode 100644 index 000000000..b2f9b6de4 --- /dev/null +++ b/codex-rs/exec-server/src/posix/mcp.rs @@ -0,0 +1,149 @@ +use std::path::PathBuf; +use std::time::Duration; + +use anyhow::Context as _; +use anyhow::Result; +use rmcp::ErrorData as McpError; +use rmcp::RoleServer; +use rmcp::ServerHandler; +use rmcp::ServiceExt; +use rmcp::handler::server::router::tool::ToolRouter; +use rmcp::handler::server::wrapper::Parameters; +use rmcp::model::*; +use rmcp::schemars; +use rmcp::service::RequestContext; +use rmcp::service::RunningService; +use rmcp::tool; +use rmcp::tool_handler; +use rmcp::tool_router; +use rmcp::transport::stdio; + +use crate::posix::escalate_server::EscalateServer; +use crate::posix::escalate_server::{self}; +use crate::posix::mcp_escalation_policy::ExecPolicy; +use crate::posix::mcp_escalation_policy::McpEscalationPolicy; +use crate::posix::stopwatch::Stopwatch; + +/// Path to our patched bash. +const CODEX_BASH_PATH_ENV_VAR: &str = "CODEX_BASH_PATH"; + +pub(crate) fn get_bash_path() -> Result { + std::env::var(CODEX_BASH_PATH_ENV_VAR) + .map(PathBuf::from) + .context(format!("{CODEX_BASH_PATH_ENV_VAR} must be set")) +} + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct ExecParams { + /// The bash string to execute. 
+ pub command: String, + /// The working directory to execute the command in. Must be an absolute path. + pub workdir: String, + /// The timeout for the command in milliseconds. + pub timeout_ms: Option, +} + +#[derive(Debug, serde::Serialize, schemars::JsonSchema)] +pub struct ExecResult { + pub exit_code: i32, + pub output: String, + pub duration: Duration, + pub timed_out: bool, +} + +impl From for ExecResult { + fn from(result: escalate_server::ExecResult) -> Self { + Self { + exit_code: result.exit_code, + output: result.output, + duration: result.duration, + timed_out: result.timed_out, + } + } +} + +#[derive(Clone)] +pub struct ExecTool { + tool_router: ToolRouter, + bash_path: PathBuf, + execve_wrapper: PathBuf, + policy: ExecPolicy, +} + +#[tool_router] +impl ExecTool { + pub fn new(bash_path: PathBuf, execve_wrapper: PathBuf, policy: ExecPolicy) -> Self { + Self { + tool_router: Self::tool_router(), + bash_path, + execve_wrapper, + policy, + } + } + + /// Runs a shell command and returns its output. You MUST provide the workdir as an absolute path. 
+ #[tool] + async fn shell( + &self, + context: RequestContext, + Parameters(params): Parameters, + ) -> Result { + let effective_timeout = Duration::from_millis( + params + .timeout_ms + .unwrap_or(codex_core::exec::DEFAULT_EXEC_COMMAND_TIMEOUT_MS), + ); + let stopwatch = Stopwatch::new(effective_timeout); + let cancel_token = stopwatch.cancellation_token(); + let escalate_server = EscalateServer::new( + self.bash_path.clone(), + self.execve_wrapper.clone(), + McpEscalationPolicy::new(self.policy, context, stopwatch.clone()), + ); + let result = escalate_server + .exec( + params.command, + // TODO: use ShellEnvironmentPolicy + std::env::vars().collect(), + PathBuf::from(¶ms.workdir), + cancel_token, + ) + .await + .map_err(|e| McpError::internal_error(e.to_string(), None))?; + Ok(CallToolResult::success(vec![Content::json( + ExecResult::from(result), + )?])) + } +} + +#[tool_handler] +impl ServerHandler for ExecTool { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::V_2025_06_18, + capabilities: ServerCapabilities::builder().enable_tools().build(), + server_info: Implementation::from_build_env(), + instructions: Some( + "This server provides a tool to execute shell commands and return their output." 
+ .to_string(), + ), + } + } + + async fn initialize( + &self, + _request: InitializeRequestParam, + _context: RequestContext, + ) -> Result { + Ok(self.get_info()) + } +} + +pub(crate) async fn serve( + bash_path: PathBuf, + execve_wrapper: PathBuf, + policy: ExecPolicy, +) -> Result, rmcp::service::ServerInitializeError> { + let tool = ExecTool::new(bash_path, execve_wrapper, policy); + tool.serve(stdio()).await +} diff --git a/codex-rs/exec-server/src/posix/mcp_escalation_policy.rs b/codex-rs/exec-server/src/posix/mcp_escalation_policy.rs new file mode 100644 index 000000000..9e059fdba --- /dev/null +++ b/codex-rs/exec-server/src/posix/mcp_escalation_policy.rs @@ -0,0 +1,146 @@ +use std::path::Path; + +use rmcp::ErrorData as McpError; +use rmcp::RoleServer; +use rmcp::model::CreateElicitationRequestParam; +use rmcp::model::CreateElicitationResult; +use rmcp::model::ElicitationAction; +use rmcp::model::ElicitationSchema; +use rmcp::service::RequestContext; + +use crate::posix::escalate_protocol::EscalateAction; +use crate::posix::escalation_policy::EscalationPolicy; +use crate::posix::stopwatch::Stopwatch; + +/// This is the policy which decides how to handle an exec() call. +/// +/// `file` is the absolute, canonical path to the executable to run, i.e. the first arg to exec. +/// `argv` is the argv, including the program name (`argv[0]`). +/// `workdir` is the absolute, canonical path to the working directory in which to execute the +/// command. +pub(crate) type ExecPolicy = fn(file: &Path, argv: &[String], workdir: &Path) -> ExecPolicyOutcome; + +pub(crate) enum ExecPolicyOutcome { + Allow { + run_with_escalated_permissions: bool, + }, + Prompt { + run_with_escalated_permissions: bool, + }, + Forbidden, +} + +/// ExecPolicy with access to the MCP RequestContext so that it can leverage +/// elicitations. 
+pub(crate) struct McpEscalationPolicy { + policy: ExecPolicy, + context: RequestContext, + stopwatch: Stopwatch, +} + +impl McpEscalationPolicy { + pub(crate) fn new( + policy: ExecPolicy, + context: RequestContext, + stopwatch: Stopwatch, + ) -> Self { + Self { + policy, + context, + stopwatch, + } + } + + async fn prompt( + &self, + file: &Path, + argv: &[String], + workdir: &Path, + context: RequestContext, + ) -> Result { + let args = shlex::try_join(argv.iter().skip(1).map(String::as_str)).unwrap_or_default(); + let command = if args.is_empty() { + file.display().to_string() + } else { + format!("{} {}", file.display(), args) + }; + self.stopwatch + .pause_for(async { + context + .peer + .create_elicitation(CreateElicitationRequestParam { + message: format!( + "Allow agent to run `{command}` in `{}`?", + workdir.display() + ), + requested_schema: ElicitationSchema::builder() + .title("Execution Permission Request") + .optional_string_with("reason", |schema| { + schema.description( + "Optional reason for allowing or denying execution", + ) + }) + .build() + .map_err(|e| { + McpError::internal_error( + format!("failed to build elicitation schema: {e}"), + None, + ) + })?, + }) + .await + .map_err(|e| McpError::internal_error(e.to_string(), None)) + }) + .await + } +} + +#[async_trait::async_trait] +impl EscalationPolicy for McpEscalationPolicy { + async fn determine_action( + &self, + file: &Path, + argv: &[String], + workdir: &Path, + ) -> Result { + let outcome = (self.policy)(file, argv, workdir); + let action = match outcome { + ExecPolicyOutcome::Allow { + run_with_escalated_permissions, + } => { + if run_with_escalated_permissions { + EscalateAction::Escalate + } else { + EscalateAction::Run + } + } + ExecPolicyOutcome::Prompt { + run_with_escalated_permissions, + } => { + let result = self + .prompt(file, argv, workdir, self.context.clone()) + .await?; + // TODO: Extract reason from `result.content`. 
+ match result.action { + ElicitationAction::Accept => { + if run_with_escalated_permissions { + EscalateAction::Escalate + } else { + EscalateAction::Run + } + } + ElicitationAction::Decline => EscalateAction::Deny { + reason: Some("User declined execution".to_string()), + }, + ElicitationAction::Cancel => EscalateAction::Deny { + reason: Some("User cancelled execution".to_string()), + }, + } + } + ExecPolicyOutcome::Forbidden => EscalateAction::Deny { + reason: Some("Execution forbidden by policy".to_string()), + }, + }; + Ok(action) + } +} diff --git a/codex-rs/exec-server/src/posix/socket.rs b/codex-rs/exec-server/src/posix/socket.rs new file mode 100644 index 000000000..92c93dcc7 --- /dev/null +++ b/codex-rs/exec-server/src/posix/socket.rs @@ -0,0 +1,486 @@ +use libc::c_uint; +use serde::Deserialize; +use serde::Serialize; +use socket2::Domain; +use socket2::MaybeUninitSlice; +use socket2::MsgHdr; +use socket2::MsgHdrMut; +use socket2::Socket; +use socket2::Type; +use std::io::IoSlice; +use std::mem::MaybeUninit; +use std::os::fd::AsRawFd; +use std::os::fd::FromRawFd; +use std::os::fd::OwnedFd; +use std::os::fd::RawFd; +use tokio::io::Interest; +use tokio::io::unix::AsyncFd; + +const MAX_FDS_PER_MESSAGE: usize = 16; +const LENGTH_PREFIX_SIZE: usize = size_of::(); +const MAX_DATAGRAM_SIZE: usize = 8192; + +/// Converts a slice of MaybeUninit to a slice of T. +/// +/// The caller guarantees that every element of `buf` is initialized. 
+fn assume_init(buf: &[MaybeUninit]) -> &[T] { + unsafe { std::slice::from_raw_parts(buf.as_ptr().cast(), buf.len()) } +} + +fn assume_init_slice(buf: &[MaybeUninit; N]) -> &[T; N] { + unsafe { &*(buf as *const [MaybeUninit; N] as *const [T; N]) } +} + +fn assume_init_vec(mut buf: Vec>) -> Vec { + unsafe { + let ptr = buf.as_mut_ptr() as *mut T; + let len = buf.len(); + let cap = buf.capacity(); + std::mem::forget(buf); + Vec::from_raw_parts(ptr, len, cap) + } +} + +fn control_space_for_fds(count: usize) -> usize { + unsafe { libc::CMSG_SPACE((count * size_of::()) as _) as usize } +} + +/// Extracts the FDs from a SCM_RIGHTS control message. +fn extract_fds(control: &[u8]) -> Vec { + let mut fds = Vec::new(); + let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() }; + hdr.msg_control = control.as_ptr() as *mut libc::c_void; + hdr.msg_controllen = control.len() as _; + let hdr = hdr; // drop mut + + let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&hdr) as *const libc::cmsghdr }; + while !cmsg.is_null() { + let level = unsafe { (*cmsg).cmsg_level }; + let ty = unsafe { (*cmsg).cmsg_type }; + if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS { + let data_ptr = unsafe { libc::CMSG_DATA(cmsg).cast::() }; + let fd_count: usize = { + let cmsg_data_len = + unsafe { (*cmsg).cmsg_len as usize } - unsafe { libc::CMSG_LEN(0) as usize }; + cmsg_data_len / size_of::() + }; + for i in 0..fd_count { + let fd = unsafe { data_ptr.add(i).read() }; + fds.push(unsafe { OwnedFd::from_raw_fd(fd) }); + } + } + cmsg = unsafe { libc::CMSG_NXTHDR(&hdr, cmsg) }; + } + fds +} + +/// Read a frame from a SOCK_STREAM socket. +/// +/// A frame is a message length prefix followed by a payload. FDs may be included in the control +/// message when receiving the frame header. 
+async fn read_frame(async_socket: &AsyncFd<Socket>) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
+    // The header read also collects any FDs, so it must happen before the payload read.
+    let (message_len, fds) = read_frame_header(async_socket).await?;
+    let payload = read_frame_payload(async_socket, message_len).await?;
+    Ok((payload, fds))
+}
+
+/// Read the frame header (i.e. length) and any FDs from a SOCK_STREAM socket.
+///
+/// Returns the payload length decoded from the little-endian `u32` prefix together with the FDs
+/// found in the control data of the first successful `recvmsg`. Fails with `UnexpectedEof` if the
+/// socket is closed before the whole prefix has been read.
+async fn read_frame_header(
+    async_socket: &AsyncFd<Socket>,
+) -> std::io::Result<(usize, Vec<OwnedFd>)> {
+    let mut header = [MaybeUninit::<u8>::uninit(); LENGTH_PREFIX_SIZE];
+    let mut filled = 0;
+    let mut control =
+        vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
+    let mut captured_control = false;
+
+    while filled < LENGTH_PREFIX_SIZE {
+        let mut guard = async_socket.readable().await?;
+        // The first read should come with a control message containing any FDs.
+        let result = if !captured_control {
+            guard.try_io(|inner| {
+                let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])];
+                let (read, control_len) = {
+                    let mut msg = MsgHdrMut::new()
+                        .with_buffers(&mut bufs)
+                        .with_control(&mut control);
+                    let read = inner.get_ref().recvmsg(&mut msg, 0)?;
+                    (read, msg.control_len())
+                };
+                // Keep only the bytes recvmsg reported, so that `assume_init` below never
+                // covers the untouched tail of the control buffer.
+                control.truncate(control_len);
+                // Only reached when recvmsg succeeded; an error returns early above and the
+                // next iteration retries with recvmsg.
+                captured_control = true;
+                Ok(read)
+            })
+        } else {
+            // Remaining header bytes are read with plain `recv`.
+            guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..]))
+        };
+        let Ok(result) = result else {
+            // Would block, try again.
+            continue;
+        };
+
+        let read = result?;
+        if read == 0 {
+            return Err(std::io::Error::new(
+                std::io::ErrorKind::UnexpectedEof,
+                "socket closed while receiving frame header",
+            ));
+        }
+
+        filled += read;
+        assert!(filled <= LENGTH_PREFIX_SIZE);
+        if filled == LENGTH_PREFIX_SIZE {
+            let len_bytes = assume_init_slice(&header);
+            let payload_len = u32::from_le_bytes(*len_bytes) as usize;
+            let fds = extract_fds(assume_init(&control));
+            return Ok((payload_len, fds));
+        }
+    }
+    unreachable!("header loop always returns")
+}
+
+/// Read `message_len` bytes from a SOCK_STREAM socket.
+async fn read_frame_payload( + async_socket: &AsyncFd, + message_len: usize, +) -> std::io::Result> { + if message_len == 0 { + return Ok(Vec::new()); + } + let mut payload = vec![MaybeUninit::::uninit(); message_len]; + let mut filled = 0; + while filled < message_len { + let mut guard = async_socket.readable().await?; + let result = guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..])); + let Ok(result) = result else { + // Would block, try again. + continue; + }; + let read = result?; + if read == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "socket closed while receiving frame payload", + )); + } + filled += read; + assert!(filled <= message_len); + if filled == message_len { + return Ok(assume_init_vec(payload)); + } + } + unreachable!("loop exits only after returning payload") +} + +fn send_message_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> { + if fds.len() > MAX_FDS_PER_MESSAGE { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("too many fds: {}", fds.len()), + )); + } + let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + data.len()); + frame.extend_from_slice(&encode_length(data.len())?); + frame.extend_from_slice(data); + + let mut control = vec![0u8; control_space_for_fds(fds.len())]; + unsafe { + let cmsg = control.as_mut_ptr().cast::(); + (*cmsg).cmsg_len = libc::CMSG_LEN(size_of::() as c_uint * fds.len() as c_uint) as _; + (*cmsg).cmsg_level = libc::SOL_SOCKET; + (*cmsg).cmsg_type = libc::SCM_RIGHTS; + let data_ptr = libc::CMSG_DATA(cmsg).cast::(); + for (i, fd) in fds.iter().enumerate() { + data_ptr.add(i).write(fd.as_raw_fd()); + } + } + + let payload = [IoSlice::new(&frame)]; + let msg = MsgHdr::new().with_buffers(&payload).with_control(&control); + let mut sent = socket.sendmsg(&msg, 0)?; + while sent < frame.len() { + let bytes = socket.send(&frame[sent..])?; + if bytes == 0 { + return Err(std::io::Error::new( + 
std::io::ErrorKind::WriteZero, + "socket closed while sending frame payload", + )); + } + sent += bytes; + } + Ok(()) +} + +fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> { + let len_u32 = u32::try_from(len).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("message too large: {len}"), + ) + })?; + Ok(len_u32.to_le_bytes()) +} + +pub(crate) fn send_json_message( + socket: &Socket, + msg: T, + fds: &[OwnedFd], +) -> std::io::Result<()> { + let data = serde_json::to_vec(&msg)?; + send_message_bytes(socket, &data, fds) +} + +fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> { + if fds.len() > MAX_FDS_PER_MESSAGE { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("too many fds: {}", fds.len()), + )); + } + let mut control = vec![0u8; control_space_for_fds(fds.len())]; + if !fds.is_empty() { + unsafe { + let cmsg = control.as_mut_ptr().cast::(); + (*cmsg).cmsg_len = + libc::CMSG_LEN(size_of::() as c_uint * fds.len() as c_uint) as _; + (*cmsg).cmsg_level = libc::SOL_SOCKET; + (*cmsg).cmsg_type = libc::SCM_RIGHTS; + let data_ptr = libc::CMSG_DATA(cmsg).cast::(); + for (i, fd) in fds.iter().enumerate() { + data_ptr.add(i).write(fd.as_raw_fd()); + } + } + } + let payload = [IoSlice::new(data)]; + let msg = MsgHdr::new().with_buffers(&payload).with_control(&control); + let written = socket.sendmsg(&msg, 0)?; + if written != data.len() { + return Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + format!( + "short datagram write: wrote {written} bytes out of {}", + data.len() + ), + )); + } + Ok(()) +} + +fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec, Vec)> { + let mut buffer = vec![MaybeUninit::::uninit(); MAX_DATAGRAM_SIZE]; + let mut control = vec![MaybeUninit::::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)]; + let (read, control_len) = { + let mut bufs = [MaybeUninitSlice::new(&mut buffer)]; + let mut 
msg = MsgHdrMut::new() + .with_buffers(&mut bufs) + .with_control(&mut control); + let read = socket.recvmsg(&mut msg, 0)?; + (read, msg.control_len()) + }; + let data = assume_init(&buffer[..read]).to_vec(); + let fds = extract_fds(assume_init(&control[..control_len])); + Ok((data, fds)) +} + +pub(crate) struct AsyncSocket { + inner: AsyncFd, +} + +impl AsyncSocket { + fn new(socket: Socket) -> std::io::Result { + socket.set_nonblocking(true)?; + let async_socket = AsyncFd::new(socket)?; + Ok(AsyncSocket { + inner: async_socket, + }) + } + + pub fn from_fd(fd: OwnedFd) -> std::io::Result { + AsyncSocket::new(Socket::from(fd)) + } + + pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> { + let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?; + Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?)) + } + + pub async fn send_with_fds( + &self, + msg: T, + fds: &[OwnedFd], + ) -> std::io::Result<()> { + self.inner + .async_io(Interest::WRITABLE, |socket| { + send_json_message(socket, &msg, fds) + }) + .await + } + + pub async fn receive_with_fds Deserialize<'de>>( + &self, + ) -> std::io::Result<(T, Vec)> { + let (payload, fds) = read_frame(&self.inner).await?; + let message: T = serde_json::from_slice(&payload)?; + Ok((message, fds)) + } + + pub async fn send(&self, msg: T) -> std::io::Result<()> + where + T: Serialize, + { + self.send_with_fds(&msg, &[]).await + } + + pub async fn receive Deserialize<'de>>(&self) -> std::io::Result { + let (msg, fds) = self.receive_with_fds().await?; + if !fds.is_empty() { + tracing::warn!("unexpected fds in receive: {}", fds.len()); + } + Ok(msg) + } + + pub fn into_inner(self) -> Socket { + self.inner.into_inner() + } +} + +pub(crate) struct AsyncDatagramSocket { + inner: AsyncFd, +} + +impl AsyncDatagramSocket { + fn new(socket: Socket) -> std::io::Result { + socket.set_nonblocking(true)?; + Ok(Self { + inner: AsyncFd::new(socket)?, + }) + } + + pub unsafe fn from_raw_fd(fd: RawFd) -> 
std::io::Result { + Self::new(unsafe { Socket::from_raw_fd(fd) }) + } + + pub fn pair() -> std::io::Result<(Self, Self)> { + let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?; + Ok((Self::new(server)?, Self::new(client)?)) + } + + pub async fn send_with_fds(&self, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> { + self.inner + .async_io(Interest::WRITABLE, |socket| { + send_datagram_bytes(socket, data, fds) + }) + .await + } + + pub async fn receive_with_fds(&self) -> std::io::Result<(Vec, Vec)> { + self.inner + .async_io(Interest::READABLE, receive_datagram_bytes) + .await + } + + pub fn into_inner(self) -> Socket { + self.inner.into_inner() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use serde::Deserialize; + use serde::Serialize; + use std::os::fd::AsFd; + use std::os::fd::AsRawFd; + use tempfile::NamedTempFile; + + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] + struct TestPayload { + id: i32, + label: String, + } + + fn fd_list(count: usize) -> std::io::Result> { + let file = NamedTempFile::new()?; + let mut fds = Vec::new(); + for _ in 0..count { + fds.push(file.as_fd().try_clone_to_owned()?); + } + Ok(fds) + } + + #[tokio::test] + async fn async_socket_round_trips_payload_and_fds() -> std::io::Result<()> { + let (server, client) = AsyncSocket::pair()?; + let payload = TestPayload { + id: 7, + label: "round-trip".to_string(), + }; + let send_fds = fd_list(1)?; + + let receive_task = + tokio::spawn(async move { server.receive_with_fds::().await }); + client.send_with_fds(payload.clone(), &send_fds).await?; + drop(send_fds); + + let (received_payload, received_fds) = receive_task.await.unwrap()?; + assert_eq!(payload, received_payload); + assert_eq!(1, received_fds.len()); + let fd_status = unsafe { libc::fcntl(received_fds[0].as_raw_fd(), libc::F_GETFD) }; + assert!( + fd_status >= 0, + "expected received file descriptor to be valid, but got {fd_status}", + ); + Ok(()) + } 
+ + #[tokio::test] + async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> { + let (server, client) = AsyncDatagramSocket::pair()?; + let data = b"datagram payload".to_vec(); + let send_fds = fd_list(1)?; + let receive_task = tokio::spawn(async move { server.receive_with_fds().await }); + + client.send_with_fds(&data, &send_fds).await?; + drop(send_fds); + + let (received_bytes, received_fds) = receive_task.await.unwrap()?; + assert_eq!(data, received_bytes); + assert_eq!(1, received_fds.len()); + Ok(()) + } + + #[test] + fn send_message_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> { + let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?; + let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?; + let err = send_message_bytes(&socket, b"hello", &fds).unwrap_err(); + assert_eq!(std::io::ErrorKind::InvalidInput, err.kind()); + Ok(()) + } + + #[test] + fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> { + let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?; + let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?; + let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err(); + assert_eq!(std::io::ErrorKind::InvalidInput, err.kind()); + Ok(()) + } + + #[test] + fn encode_length_errors_for_oversized_messages() { + let err = encode_length(usize::MAX).unwrap_err(); + assert_eq!(std::io::ErrorKind::InvalidInput, err.kind()); + } + + #[tokio::test] + async fn receive_fails_when_peer_closes_before_header() { + let (server, client) = AsyncSocket::pair().expect("failed to create socket pair"); + drop(client); + let err = server + .receive::() + .await + .expect_err("expected read failure"); + assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind()); + } +} diff --git a/codex-rs/exec-server/src/posix/stopwatch.rs b/codex-rs/exec-server/src/posix/stopwatch.rs new file mode 100644 index 000000000..de29a4568 --- /dev/null +++ b/codex-rs/exec-server/src/posix/stopwatch.rs @@ -0,0 +1,211 @@ +use 
std::future::Future; +use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; + +use tokio::sync::Mutex; +use tokio::sync::Notify; +use tokio_util::sync::CancellationToken; + +#[derive(Clone, Debug)] +pub(crate) struct Stopwatch { + limit: Duration, + inner: Arc>, + notify: Arc, +} + +#[derive(Debug)] +struct StopwatchState { + elapsed: Duration, + running_since: Option, + active_pauses: u32, +} + +impl Stopwatch { + pub(crate) fn new(limit: Duration) -> Self { + Self { + inner: Arc::new(Mutex::new(StopwatchState { + elapsed: Duration::ZERO, + running_since: Some(Instant::now()), + active_pauses: 0, + })), + notify: Arc::new(Notify::new()), + limit, + } + } + + pub(crate) fn cancellation_token(&self) -> CancellationToken { + let limit = self.limit; + let token = CancellationToken::new(); + let cancel = token.clone(); + let inner = Arc::clone(&self.inner); + let notify = Arc::clone(&self.notify); + tokio::spawn(async move { + loop { + let (remaining, running) = { + let guard = inner.lock().await; + let elapsed = guard.elapsed + + guard + .running_since + .map(|since| since.elapsed()) + .unwrap_or_default(); + if elapsed >= limit { + break; + } + (limit - elapsed, guard.running_since.is_some()) + }; + + if !running { + notify.notified().await; + continue; + } + + let sleep = tokio::time::sleep(remaining); + tokio::pin!(sleep); + tokio::select! { + _ = &mut sleep => { + break; + } + _ = notify.notified() => { + continue; + } + } + } + cancel.cancel(); + }); + token + } + + /// Runs `fut`, pausing the stopwatch while the future is pending. The clock + /// resumes automatically when the future completes. Nested/overlapping + /// calls are reference-counted so the stopwatch only resumes when every + /// pause is lifted. 
+ pub(crate) async fn pause_for(&self, fut: F) -> T + where + F: Future, + { + self.pause().await; + let result = fut.await; + self.resume().await; + result + } + + async fn pause(&self) { + let mut guard = self.inner.lock().await; + guard.active_pauses += 1; + if guard.active_pauses == 1 + && let Some(since) = guard.running_since.take() + { + guard.elapsed += since.elapsed(); + self.notify.notify_waiters(); + } + } + + async fn resume(&self) { + let mut guard = self.inner.lock().await; + if guard.active_pauses == 0 { + return; + } + guard.active_pauses -= 1; + if guard.active_pauses == 0 && guard.running_since.is_none() { + guard.running_since = Some(Instant::now()); + self.notify.notify_waiters(); + } + } +} + +#[cfg(test)] +mod tests { + use super::Stopwatch; + use tokio::time::Duration; + use tokio::time::Instant; + use tokio::time::sleep; + use tokio::time::timeout; + + #[tokio::test] + async fn cancellation_receiver_fires_after_limit() { + let stopwatch = Stopwatch::new(Duration::from_millis(50)); + let token = stopwatch.cancellation_token(); + let start = Instant::now(); + token.cancelled().await; + assert!(start.elapsed() >= Duration::from_millis(50)); + } + + #[tokio::test] + async fn pause_prevents_timeout_until_resumed() { + let stopwatch = Stopwatch::new(Duration::from_millis(50)); + let token = stopwatch.cancellation_token(); + + let pause_handle = tokio::spawn({ + let stopwatch = stopwatch.clone(); + async move { + stopwatch + .pause_for(async { + sleep(Duration::from_millis(100)).await; + }) + .await; + } + }); + + assert!( + timeout(Duration::from_millis(30), token.cancelled()) + .await + .is_err() + ); + + pause_handle.await.expect("pause task should finish"); + + token.cancelled().await; + } + + #[tokio::test] + async fn overlapping_pauses_only_resume_once() { + let stopwatch = Stopwatch::new(Duration::from_millis(50)); + let token = stopwatch.cancellation_token(); + + // First pause. 
+ let pause1 = { + let stopwatch = stopwatch.clone(); + tokio::spawn(async move { + stopwatch + .pause_for(async { + sleep(Duration::from_millis(80)).await; + }) + .await; + }) + }; + + // Overlapping pause that ends sooner. + let pause2 = { + let stopwatch = stopwatch.clone(); + tokio::spawn(async move { + stopwatch + .pause_for(async { + sleep(Duration::from_millis(30)).await; + }) + .await; + }) + }; + + // While both pauses are active, the cancellation should not fire. + assert!( + timeout(Duration::from_millis(40), token.cancelled()) + .await + .is_err() + ); + + pause2.await.expect("short pause should complete"); + + // Still paused because the long pause is active. + assert!( + timeout(Duration::from_millis(30), token.cancelled()) + .await + .is_err() + ); + + pause1.await.expect("long pause should complete"); + + // Now the stopwatch should resume and hit the limit shortly after. + token.cancelled().await; + } +} diff --git a/codex-rs/exec/src/cli.rs b/codex-rs/exec/src/cli.rs index ef20bf6fc..6866bc0ff 100644 --- a/codex-rs/exec/src/cli.rs +++ b/codex-rs/exec/src/cli.rs @@ -101,7 +101,7 @@ pub struct ResumeArgs { pub session_id: Option, /// Resume the most recent recorded session (newest) without specifying an id. - #[arg(long = "last", default_value_t = false, conflicts_with = "session_id")] + #[arg(long = "last", default_value_t = false)] pub last: bool, /// Prompt to send after resuming the session. If `-` is used, read from stdin. 
diff --git a/codex-rs/exec/src/event_processor_with_human_output.rs b/codex-rs/exec/src/event_processor_with_human_output.rs index 8c7bb6881..f0bb70720 100644 --- a/codex-rs/exec/src/event_processor_with_human_output.rs +++ b/codex-rs/exec/src/event_processor_with_human_output.rs @@ -161,7 +161,7 @@ impl EventProcessor for EventProcessorWithHumanOutput { fn process_event(&mut self, event: Event) -> CodexStatus { let Event { id: _, msg } = event; match msg { - EventMsg::Error(ErrorEvent { message }) => { + EventMsg::Error(ErrorEvent { message, .. }) => { let prefix = "ERROR:".style(self.red); ts_msg!(self, "{prefix} {message}"); } @@ -221,7 +221,7 @@ impl EventProcessor for EventProcessorWithHumanOutput { EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => { ts_msg!(self, "{}", message.style(self.dimmed)); } - EventMsg::StreamError(StreamErrorEvent { message }) => { + EventMsg::StreamError(StreamErrorEvent { message, .. }) => { ts_msg!(self, "{}", message.style(self.dimmed)); } EventMsg::TaskStarted(_) => { @@ -346,6 +346,7 @@ impl EventProcessor for EventProcessorWithHumanOutput { call_id, auto_approved, changes, + .. }) => { // Store metadata so we can calculate duration later when we // receive the corresponding PatchApplyEnd event. @@ -480,11 +481,7 @@ impl EventProcessor for EventProcessorWithHumanOutput { let SessionConfiguredEvent { session_id: conversation_id, model, - reasoning_effort: _, - history_log_id: _, - history_entry_count: _, - initial_messages: _, - rollout_path: _, + .. } = session_configured_event; ts_msg!( diff --git a/codex-rs/exec/src/exec_events.rs b/codex-rs/exec/src/exec_events.rs index 64288e138..f3726dad7 100644 --- a/codex-rs/exec/src/exec_events.rs +++ b/codex-rs/exec/src/exec_events.rs @@ -144,6 +144,7 @@ pub enum CommandExecutionStatus { InProgress, Completed, Failed, + Declined, } /// A command executed by the agent. 
@@ -166,6 +167,7 @@ pub struct FileUpdateChange { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)] #[serde(rename_all = "snake_case")] pub enum PatchApplyStatus { + InProgress, Completed, Failed, } diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index a003b4ff2..830bb841c 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -82,7 +82,21 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any let prompt_arg = match &command { // Allow prompt before the subcommand by falling back to the parent-level prompt // when the Resume subcommand did not provide its own prompt. - Some(ExecCommand::Resume(args)) => args.prompt.clone().or(prompt), + Some(ExecCommand::Resume(args)) => { + let resume_prompt = args + .prompt + .clone() + // When using `resume --last `, clap still parses the first positional + // as `session_id`. Reinterpret it as the prompt so the flag works with JSON mode. + .or_else(|| { + if args.last { + args.session_id.clone() + } else { + None + } + }); + resume_prompt.or(prompt) + } None => prompt, }; diff --git a/codex-rs/exec/tests/event_processor_with_json_output.rs b/codex-rs/exec/tests/event_processor_with_json_output.rs index 7a6245ae7..06f99b96b 100644 --- a/codex-rs/exec/tests/event_processor_with_json_output.rs +++ b/codex-rs/exec/tests/event_processor_with_json_output.rs @@ -1,5 +1,6 @@ use codex_core::protocol::AgentMessageEvent; use codex_core::protocol::AgentReasoningEvent; +use codex_core::protocol::AskForApproval; use codex_core::protocol::ErrorEvent; use codex_core::protocol::Event; use codex_core::protocol::EventMsg; @@ -12,6 +13,7 @@ use codex_core::protocol::McpToolCallBeginEvent; use codex_core::protocol::McpToolCallEndEvent; use codex_core::protocol::PatchApplyBeginEvent; use codex_core::protocol::PatchApplyEndEvent; +use codex_core::protocol::SandboxPolicy; use codex_core::protocol::SessionConfiguredEvent; use codex_core::protocol::WarningEvent; use 
codex_core::protocol::WebSearchEndEvent; @@ -45,6 +47,7 @@ use codex_exec::exec_events::WebSearchItem; use codex_protocol::plan_tool::PlanItemArg; use codex_protocol::plan_tool::StepStatus; use codex_protocol::plan_tool::UpdatePlanArgs; +use codex_protocol::protocol::CodexErrorInfo; use mcp_types::CallToolResult; use mcp_types::ContentBlock; use mcp_types::TextContent; @@ -72,6 +75,10 @@ fn session_configured_produces_thread_started_event() { EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, model: "codex-mini-latest".to_string(), + model_provider_id: "test-provider".to_string(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, + cwd: PathBuf::from("/home/user/project"), reasoning_effort: None, history_log_id: 0, history_entry_count: 0, @@ -533,6 +540,7 @@ fn error_event_produces_error() { "e1", EventMsg::Error(codex_core::protocol::ErrorEvent { message: "boom".to_string(), + codex_error_info: Some(CodexErrorInfo::Other), }), )); assert_eq!( @@ -572,6 +580,7 @@ fn stream_error_event_produces_error() { "e1", EventMsg::StreamError(codex_core::protocol::StreamErrorEvent { message: "retrying".to_string(), + codex_error_info: Some(CodexErrorInfo::Other), }), )); assert_eq!( @@ -590,6 +599,7 @@ fn error_followed_by_task_complete_produces_turn_failed() { "e1", EventMsg::Error(ErrorEvent { message: "boom".to_string(), + codex_error_info: Some(CodexErrorInfo::Other), }), ); assert_eq!( @@ -816,6 +826,7 @@ fn patch_apply_success_produces_item_completed_patchapply() { "p1", EventMsg::PatchApplyBegin(PatchApplyBeginEvent { call_id: "call-1".to_string(), + turn_id: "turn-1".to_string(), auto_approved: true, changes: changes.clone(), }), @@ -828,9 +839,11 @@ fn patch_apply_success_produces_item_completed_patchapply() { "p2", EventMsg::PatchApplyEnd(PatchApplyEndEvent { call_id: "call-1".to_string(), + turn_id: "turn-1".to_string(), stdout: "applied 3 changes".to_string(), stderr: String::new(), success: true, + changes: 
changes.clone(), }), ); let out_end = ep.collect_thread_events(&end); @@ -885,6 +898,7 @@ fn patch_apply_failure_produces_item_completed_patchapply_failed() { "p1", EventMsg::PatchApplyBegin(PatchApplyBeginEvent { call_id: "call-2".to_string(), + turn_id: "turn-2".to_string(), auto_approved: false, changes: changes.clone(), }), @@ -896,9 +910,11 @@ fn patch_apply_failure_produces_item_completed_patchapply_failed() { "p2", EventMsg::PatchApplyEnd(PatchApplyEndEvent { call_id: "call-2".to_string(), + turn_id: "turn-2".to_string(), stdout: String::new(), stderr: "failed to apply".to_string(), success: false, + changes: changes.clone(), }), ); let out_end = ep.collect_thread_events(&end); diff --git a/codex-rs/exec/tests/suite/resume.rs b/codex-rs/exec/tests/suite/resume.rs index 1a56edded..e37b38606 100644 --- a/codex-rs/exec/tests/suite/resume.rs +++ b/codex-rs/exec/tests/suite/resume.rs @@ -123,6 +123,60 @@ fn exec_resume_last_appends_to_existing_file() -> anyhow::Result<()> { Ok(()) } +#[test] +fn exec_resume_last_accepts_prompt_after_flag_in_json_mode() -> anyhow::Result<()> { + let test = test_codex_exec(); + let fixture = + Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/cli_responses_fixture.sse"); + + // 1) First run: create a session with a unique marker in the content. + let marker = format!("resume-last-json-{}", Uuid::new_v4()); + let prompt = format!("echo {marker}"); + + test.cmd() + .env("CODEX_RS_SSE_FIXTURE", &fixture) + .env("OPENAI_BASE_URL", "http://unused.local") + .arg("--skip-git-repo-check") + .arg("-C") + .arg(env!("CARGO_MANIFEST_DIR")) + .arg(&prompt) + .assert() + .success(); + + // Find the created session file containing the marker. + let sessions_dir = test.home_path().join("sessions"); + let path = find_session_file_containing_marker(&sessions_dir, &marker) + .expect("no session file found after first run"); + + // 2) Second run: resume the most recent file and pass the prompt after --last. 
+ let marker2 = format!("resume-last-json-2-{}", Uuid::new_v4()); + let prompt2 = format!("echo {marker2}"); + + test.cmd() + .env("CODEX_RS_SSE_FIXTURE", &fixture) + .env("OPENAI_BASE_URL", "http://unused.local") + .arg("--skip-git-repo-check") + .arg("-C") + .arg(env!("CARGO_MANIFEST_DIR")) + .arg("--json") + .arg("resume") + .arg("--last") + .arg(&prompt2) + .assert() + .success(); + + let resumed_path = find_session_file_containing_marker(&sessions_dir, &marker2) + .expect("no resumed session file containing marker2"); + assert_eq!( + resumed_path, path, + "resume --last should append to existing file" + ); + let content = std::fs::read_to_string(&resumed_path)?; + assert!(content.contains(&marker)); + assert!(content.contains(&marker2)); + Ok(()) +} + #[test] fn exec_resume_by_id_appends_to_existing_file() -> anyhow::Result<()> { let test = test_codex_exec(); diff --git a/codex-rs/execpolicy-legacy/Cargo.toml b/codex-rs/execpolicy-legacy/Cargo.toml new file mode 100644 index 000000000..89e6e43b4 --- /dev/null +++ b/codex-rs/execpolicy-legacy/Cargo.toml @@ -0,0 +1,34 @@ +[package] +edition = "2024" +name = "codex-execpolicy-legacy" +description = "Legacy exec policy engine for validating proposed exec calls." 
+version = { workspace = true } + +[[bin]] +name = "codex-execpolicy-legacy" +path = "src/main.rs" + +[lib] +name = "codex_execpolicy_legacy" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +allocative = { workspace = true } +anyhow = { workspace = true } +clap = { workspace = true, features = ["derive"] } +derive_more = { workspace = true, features = ["display"] } +env_logger = { workspace = true } +log = { workspace = true } +multimap = { workspace = true } +path-absolutize = { workspace = true } +regex-lite = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +serde_with = { workspace = true, features = ["macros"] } +starlark = { workspace = true } + +[dev-dependencies] +tempfile = { workspace = true } diff --git a/codex-rs/execpolicy-legacy/README.md b/codex-rs/execpolicy-legacy/README.md new file mode 100644 index 000000000..1351377e5 --- /dev/null +++ b/codex-rs/execpolicy-legacy/README.md @@ -0,0 +1,183 @@ +# codex-execpolicy-legacy + +This crate hosts the original execpolicy implementation. The newer prefix-rule +engine lives in `codex-execpolicy`. + +The goal of this library is to classify a proposed [`execv(3)`](https://linux.die.net/man/3/execv) command into one of the following states: + +- `safe` The command is safe to run (\*). +- `match` The command matched a rule in the policy, but the caller should decide whether it is safe to run based on the files it will write. +- `forbidden` The command is not allowed to be run. +- `unverified` The safety cannot be determined: make the user decide. + +(\*) Whether an `execv(3)` call should be considered "safe" often requires additional context beyond the arguments to `execv()` itself. 
For example, if you trust an autonomous software agent to write files in your source tree, then deciding whether `/bin/cp foo bar` is "safe" depends on `getcwd(3)` for the calling process as well as the `realpath` of `foo` and `bar` when resolved against `getcwd()`. +To that end, rather than returning a boolean, the validator returns a structured result that the client is expected to use to determine the "safety" of the proposed `execv()` call. + +For example, to check the command `ls -l foo`, the checker would be invoked as follows: + +```shell +cargo run -p codex-execpolicy-legacy -- check ls -l foo | jq +``` + +It will exit with `0` and print the following to stdout: + +```json +{ + "result": "safe", + "match": { + "program": "ls", + "flags": [ + { + "name": "-l" + } + ], + "opts": [], + "args": [ + { + "index": 1, + "type": "ReadableFile", + "value": "foo" + } + ], + "system_path": ["/bin/ls", "/usr/bin/ls"] + } +} +``` + +Of note: + +- `foo` is tagged as a `ReadableFile`, so the caller should resolve `foo` relative to `getcwd()` and `realpath` it (as it may be a symlink) to determine whether `foo` is safe to read. +- While the specified executable is `ls`, `"system_path"` offers `/bin/ls` and `/usr/bin/ls` as viable alternatives to avoid using whatever `ls` happens to appear first on the user's `$PATH`. If either exists on the host, it is recommended to use it as the first argument to `execv(3)` instead of `ls`. + +Further, "safety" in this system is not a guarantee that the command will execute successfully. As an example, `cat /Users/mbolin/code/codex/README.md` may be considered "safe" if the system has decided the agent is allowed to read anything under `/Users/mbolin/code/codex`, but it will fail at runtime if `README.md` does not exist. (Though this is "safe" in that the agent did not read any files that it was not authorized to read.) + +## Policy + +Currently, the default policy is defined in [`default.policy`](./src/default.policy) within the crate. 
+ +The system uses [Starlark](https://bazel.build/rules/language) as the file format because, unlike something like JSON or YAML, it supports "macros" without compromising on safety or reproducibility. (Under the hood, we use [`starlark-rust`](https://github.com/facebook/starlark-rust) as the specific Starlark implementation.) + +This policy contains "rules" such as: + +```python +define_program( + program="cp", + options=[ + flag("-r"), + flag("-R"), + flag("--recursive"), + ], + args=[ARG_RFILES, ARG_WFILE], + system_path=["/bin/cp", "/usr/bin/cp"], + should_match=[ + ["foo", "bar"], + ], + should_not_match=[ + ["foo"], + ], +) +``` + +This rule means that: + +- `cp` can be used with any of the following flags (where "flag" means "an option that does not take an argument"): `-r`, `-R`, `--recursive`. +- The initial `ARG_RFILES` passed to `args` means that it expects one or more arguments that correspond to "readable files." +- The final `ARG_WFILE` passed to `args` means that it expects exactly one argument that corresponds to a "writeable file." +- As a lightweight way of including a unit test alongside the definition, `should_match` is a list of examples of `execv(3)` args that should match the rule and `should_not_match` is a list of examples that should not match. These examples are verified when the `.policy` file is loaded. + +Note that the language of the `.policy` file is still evolving, as we have to continue to expand it so it is sufficiently expressive to accept all commands we want to consider "safe" without allowing unsafe commands to pass through. + +The integrity of `default.policy` is verified [via unit tests](./tests). + +Further, the CLI supports a `--policy` option to specify a custom `.policy` file for ad-hoc testing.
+ +## Output Type: `match` + +Going back to the `cp` example, because the rule matches an `ARG_WFILE`, it will return `match` instead of `safe`: + +```shell +cargo run -p codex-execpolicy-legacy -- check cp src1 src2 dest | jq +``` + +If the caller wants to consider allowing this command, it should parse the JSON to pick out the `WriteableFile` arguments and decide whether they are safe to write: + +```json +{ + "result": "match", + "match": { + "program": "cp", + "flags": [], + "opts": [], + "args": [ + { + "index": 0, + "type": "ReadableFile", + "value": "src1" + }, + { + "index": 1, + "type": "ReadableFile", + "value": "src2" + }, + { + "index": 2, + "type": "WriteableFile", + "value": "dest" + } + ], + "system_path": ["/bin/cp", "/usr/bin/cp"] + } +} +``` + +Note the exit code is still `0` for a `match` unless the `--require-safe` flag is specified, in which case the exit code is `12`. + +## Output Type: `forbidden` + +It is also possible to define a rule that, if it matches a command, should flag it as _forbidden_. For example, we do not want agents to be able to run `applied deploy` _ever_, so we define the following rule: + +```python +define_program( + program="applied", + args=["deploy"], + forbidden="Infrastructure Risk: command contains 'applied deploy'", + should_match=[ + ["deploy"], + ], + should_not_match=[ + ["lint"], + ], +) +``` + +Note that for a rule to be forbidden, the `forbidden` keyword arg must be specified as the reason the command is forbidden. 
This will be included in the output: + +```shell +cargo run -p codex-execpolicy-legacy -- check applied deploy | jq +``` + +```json +{ + "result": "forbidden", + "reason": "Infrastructure Risk: command contains 'applied deploy'", + "cause": { + "Exec": { + "exec": { + "program": "applied", + "flags": [], + "opts": [], + "args": [ + { + "index": 0, + "type": { + "Literal": "deploy" + }, + "value": "deploy" + } + ], + "system_path": [] + } + } + } +} +``` diff --git a/codex-rs/execpolicy/build.rs b/codex-rs/execpolicy-legacy/build.rs similarity index 100% rename from codex-rs/execpolicy/build.rs rename to codex-rs/execpolicy-legacy/build.rs diff --git a/codex-rs/execpolicy/src/arg_matcher.rs b/codex-rs/execpolicy-legacy/src/arg_matcher.rs similarity index 100% rename from codex-rs/execpolicy/src/arg_matcher.rs rename to codex-rs/execpolicy-legacy/src/arg_matcher.rs diff --git a/codex-rs/execpolicy/src/arg_resolver.rs b/codex-rs/execpolicy-legacy/src/arg_resolver.rs similarity index 100% rename from codex-rs/execpolicy/src/arg_resolver.rs rename to codex-rs/execpolicy-legacy/src/arg_resolver.rs diff --git a/codex-rs/execpolicy/src/arg_type.rs b/codex-rs/execpolicy-legacy/src/arg_type.rs similarity index 100% rename from codex-rs/execpolicy/src/arg_type.rs rename to codex-rs/execpolicy-legacy/src/arg_type.rs diff --git a/codex-rs/execpolicy/src/default.policy b/codex-rs/execpolicy-legacy/src/default.policy similarity index 100% rename from codex-rs/execpolicy/src/default.policy rename to codex-rs/execpolicy-legacy/src/default.policy diff --git a/codex-rs/execpolicy-legacy/src/error.rs b/codex-rs/execpolicy-legacy/src/error.rs new file mode 100644 index 000000000..e6443d69d --- /dev/null +++ b/codex-rs/execpolicy-legacy/src/error.rs @@ -0,0 +1,96 @@ +use std::path::PathBuf; + +use serde::Serialize; + +use crate::arg_matcher::ArgMatcher; +use crate::arg_resolver::PositionalArg; +use serde_with::DisplayFromStr; +use serde_with::serde_as; + +pub type Result = 
std::result::Result; + +#[serde_as] +#[derive(Debug, Eq, PartialEq, Serialize)] +#[serde(tag = "type")] +pub enum Error { + NoSpecForProgram { + program: String, + }, + OptionMissingValue { + program: String, + option: String, + }, + OptionFollowedByOptionInsteadOfValue { + program: String, + option: String, + value: String, + }, + UnknownOption { + program: String, + option: String, + }, + UnexpectedArguments { + program: String, + args: Vec, + }, + DoubleDashNotSupportedYet { + program: String, + }, + MultipleVarargPatterns { + program: String, + first: ArgMatcher, + second: ArgMatcher, + }, + RangeStartExceedsEnd { + start: usize, + end: usize, + }, + RangeEndOutOfBounds { + end: usize, + len: usize, + }, + PrefixOverlapsSuffix {}, + NotEnoughArgs { + program: String, + args: Vec, + arg_patterns: Vec, + }, + InternalInvariantViolation { + message: String, + }, + VarargMatcherDidNotMatchAnything { + program: String, + matcher: ArgMatcher, + }, + EmptyFileName {}, + LiteralValueDidNotMatch { + expected: String, + actual: String, + }, + InvalidPositiveInteger { + value: String, + }, + MissingRequiredOptions { + program: String, + options: Vec, + }, + SedCommandNotProvablySafe { + command: String, + }, + ReadablePathNotInReadableFolders { + file: PathBuf, + folders: Vec, + }, + WriteablePathNotInWriteableFolders { + file: PathBuf, + folders: Vec, + }, + CannotCheckRelativePath { + file: PathBuf, + }, + CannotCanonicalizePath { + file: String, + #[serde_as(as = "DisplayFromStr")] + error: std::io::ErrorKind, + }, +} diff --git a/codex-rs/execpolicy/src/exec_call.rs b/codex-rs/execpolicy-legacy/src/exec_call.rs similarity index 100% rename from codex-rs/execpolicy/src/exec_call.rs rename to codex-rs/execpolicy-legacy/src/exec_call.rs diff --git a/codex-rs/execpolicy/src/execv_checker.rs b/codex-rs/execpolicy-legacy/src/execv_checker.rs similarity index 100% rename from codex-rs/execpolicy/src/execv_checker.rs rename to codex-rs/execpolicy-legacy/src/execv_checker.rs 
diff --git a/codex-rs/execpolicy-legacy/src/lib.rs b/codex-rs/execpolicy-legacy/src/lib.rs new file mode 100644 index 000000000..6f1222598 --- /dev/null +++ b/codex-rs/execpolicy-legacy/src/lib.rs @@ -0,0 +1,45 @@ +#![allow(clippy::type_complexity)] +#![allow(clippy::too_many_arguments)] +#[macro_use] +extern crate starlark; + +mod arg_matcher; +mod arg_resolver; +mod arg_type; +mod error; +mod exec_call; +mod execv_checker; +mod opt; +mod policy; +mod policy_parser; +mod program; +mod sed_command; +mod valid_exec; + +pub use arg_matcher::ArgMatcher; +pub use arg_resolver::PositionalArg; +pub use arg_type::ArgType; +pub use error::Error; +pub use error::Result; +pub use exec_call::ExecCall; +pub use execv_checker::ExecvChecker; +pub use opt::Opt; +pub use policy::Policy; +pub use policy_parser::PolicyParser; +pub use program::Forbidden; +pub use program::MatchedExec; +pub use program::NegativeExamplePassedCheck; +pub use program::PositiveExampleFailedCheck; +pub use program::ProgramSpec; +pub use sed_command::parse_sed_command; +pub use valid_exec::MatchedArg; +pub use valid_exec::MatchedFlag; +pub use valid_exec::MatchedOpt; +pub use valid_exec::ValidExec; + +const DEFAULT_POLICY: &str = include_str!("default.policy"); + +pub fn get_default_policy() -> starlark::Result { + let parser = PolicyParser::new("#default", DEFAULT_POLICY); + parser.parse() +} diff --git a/codex-rs/execpolicy-legacy/src/main.rs b/codex-rs/execpolicy-legacy/src/main.rs new file mode 100644 index 000000000..f5b66dfe5 --- /dev/null +++ b/codex-rs/execpolicy-legacy/src/main.rs @@ -0,0 +1,169 @@ +use anyhow::Result; +use clap::Parser; +use clap::Subcommand; +use codex_execpolicy_legacy::ExecCall; +use codex_execpolicy_legacy::MatchedExec; +use codex_execpolicy_legacy::Policy; +use codex_execpolicy_legacy::PolicyParser; +use codex_execpolicy_legacy::ValidExec; +use codex_execpolicy_legacy::get_default_policy; +use serde::Deserialize; +use serde::Serialize; +use serde::de; +use starlark::Error as 
StarlarkError; +use std::path::PathBuf; +use std::str::FromStr; + +const MATCHED_BUT_WRITES_FILES_EXIT_CODE: i32 = 12; +const MIGHT_BE_SAFE_EXIT_CODE: i32 = 13; +const FORBIDDEN_EXIT_CODE: i32 = 14; + +#[derive(Parser, Deserialize, Debug)] +#[command(version, about, long_about = None)] +pub struct Args { + /// If the command fails the policy, exit with 13, but print parseable JSON + /// to stdout. + #[clap(long)] + pub require_safe: bool, + + /// Path to the policy file. + #[clap(long, short = 'p')] + pub policy: Option, + + #[command(subcommand)] + pub command: Command, +} + +#[derive(Clone, Debug, Deserialize, Subcommand)] +pub enum Command { + /// Checks the command as if the arguments were the inputs to execv(3). + Check { + #[arg(trailing_var_arg = true)] + command: Vec, + }, + + /// Checks the command encoded as a JSON object. + #[clap(name = "check-json")] + CheckJson { + /// JSON object with "program" (str) and "args" (list[str]) fields. + #[serde(deserialize_with = "deserialize_from_json")] + exec: ExecArg, + }, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct ExecArg { + pub program: String, + + #[serde(default)] + pub args: Vec, +} + +fn main() -> Result<()> { + env_logger::init(); + + let args = Args::parse(); + let policy = match args.policy { + Some(policy) => { + let policy_source = policy.to_string_lossy().to_string(); + let unparsed_policy = std::fs::read_to_string(policy)?; + let parser = PolicyParser::new(&policy_source, &unparsed_policy); + parser.parse() + } + None => get_default_policy(), + }; + let policy = policy.map_err(StarlarkError::into_anyhow)?; + + let exec = match args.command { + Command::Check { command } => match command.split_first() { + Some((first, rest)) => ExecArg { + program: first.to_string(), + args: rest.to_vec(), + }, + None => { + eprintln!("no command provided"); + std::process::exit(1); + } + }, + Command::CheckJson { exec } => exec, + }; + + let (output, exit_code) = check_command(&policy, exec, 
args.require_safe); + let json = serde_json::to_string(&output)?; + println!("{json}"); + std::process::exit(exit_code); +} + +fn check_command( + policy: &Policy, + ExecArg { program, args }: ExecArg, + check: bool, +) -> (Output, i32) { + let exec_call = ExecCall { program, args }; + match policy.check(&exec_call) { + Ok(MatchedExec::Match { exec }) => { + if exec.might_write_files() { + let exit_code = if check { + MATCHED_BUT_WRITES_FILES_EXIT_CODE + } else { + 0 + }; + (Output::Match { r#match: exec }, exit_code) + } else { + (Output::Safe { r#match: exec }, 0) + } + } + Ok(MatchedExec::Forbidden { reason, cause }) => { + let exit_code = if check { FORBIDDEN_EXIT_CODE } else { 0 }; + (Output::Forbidden { reason, cause }, exit_code) + } + Err(err) => { + let exit_code = if check { MIGHT_BE_SAFE_EXIT_CODE } else { 0 }; + (Output::Unverified { error: err }, exit_code) + } + } +} + +#[derive(Debug, Serialize)] +#[serde(tag = "result")] +pub enum Output { + /// The command is verified as safe. + #[serde(rename = "safe")] + Safe { r#match: ValidExec }, + + /// The command has matched a rule in the policy, but the caller should + /// decide whether it is "safe" given the files it wants to write. + #[serde(rename = "match")] + Match { r#match: ValidExec }, + + /// The user is forbidden from running the command. + #[serde(rename = "forbidden")] + Forbidden { + reason: String, + cause: codex_execpolicy_legacy::Forbidden, + }, + + /// The safety of the command could not be verified. 
+ #[serde(rename = "unverified")] + Unverified { + error: codex_execpolicy_legacy::Error, + }, +} + +fn deserialize_from_json<'de, D>(deserializer: D) -> Result +where + D: de::Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + let decoded = serde_json::from_str(&s) + .map_err(|e| serde::de::Error::custom(format!("JSON parse error: {e}")))?; + Ok(decoded) +} + +impl FromStr for ExecArg { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + serde_json::from_str(s).map_err(Into::into) + } +} diff --git a/codex-rs/execpolicy/src/opt.rs b/codex-rs/execpolicy-legacy/src/opt.rs similarity index 100% rename from codex-rs/execpolicy/src/opt.rs rename to codex-rs/execpolicy-legacy/src/opt.rs diff --git a/codex-rs/execpolicy-legacy/src/policy.rs b/codex-rs/execpolicy-legacy/src/policy.rs new file mode 100644 index 000000000..825d6164a --- /dev/null +++ b/codex-rs/execpolicy-legacy/src/policy.rs @@ -0,0 +1,103 @@ +use multimap::MultiMap; +use regex_lite::Error as RegexError; +use regex_lite::Regex; + +use crate::ExecCall; +use crate::Forbidden; +use crate::MatchedExec; +use crate::NegativeExamplePassedCheck; +use crate::ProgramSpec; +use crate::error::Error; +use crate::error::Result; +use crate::policy_parser::ForbiddenProgramRegex; +use crate::program::PositiveExampleFailedCheck; + +pub struct Policy { + programs: MultiMap, + forbidden_program_regexes: Vec, + forbidden_substrings_pattern: Option, +} + +impl Policy { + pub fn new( + programs: MultiMap, + forbidden_program_regexes: Vec, + forbidden_substrings: Vec, + ) -> std::result::Result { + let forbidden_substrings_pattern = if forbidden_substrings.is_empty() { + None + } else { + let escaped_substrings = forbidden_substrings + .iter() + .map(|s| regex_lite::escape(s)) + .collect::>() + .join("|"); + Some(Regex::new(&format!("({escaped_substrings})"))?) 
+ }; + Ok(Self { + programs, + forbidden_program_regexes, + forbidden_substrings_pattern, + }) + } + + pub fn check(&self, exec_call: &ExecCall) -> Result { + let ExecCall { program, args } = &exec_call; + for ForbiddenProgramRegex { regex, reason } in &self.forbidden_program_regexes { + if regex.is_match(program) { + return Ok(MatchedExec::Forbidden { + cause: Forbidden::Program { + program: program.clone(), + exec_call: exec_call.clone(), + }, + reason: reason.clone(), + }); + } + } + + for arg in args { + if let Some(regex) = &self.forbidden_substrings_pattern + && regex.is_match(arg) + { + return Ok(MatchedExec::Forbidden { + cause: Forbidden::Arg { + arg: arg.clone(), + exec_call: exec_call.clone(), + }, + reason: format!("arg `{arg}` contains forbidden substring"), + }); + } + } + + let mut last_err = Err(Error::NoSpecForProgram { + program: program.clone(), + }); + if let Some(spec_list) = self.programs.get_vec(program) { + for spec in spec_list { + match spec.check(exec_call) { + Ok(matched_exec) => return Ok(matched_exec), + Err(err) => { + last_err = Err(err); + } + } + } + } + last_err + } + + pub fn check_each_good_list_individually(&self) -> Vec { + let mut violations = Vec::new(); + for (_program, spec) in self.programs.flat_iter() { + violations.extend(spec.verify_should_match_list()); + } + violations + } + + pub fn check_each_bad_list_individually(&self) -> Vec { + let mut violations = Vec::new(); + for (_program, spec) in self.programs.flat_iter() { + violations.extend(spec.verify_should_not_match_list()); + } + violations + } +} diff --git a/codex-rs/execpolicy/src/policy_parser.rs b/codex-rs/execpolicy-legacy/src/policy_parser.rs similarity index 100% rename from codex-rs/execpolicy/src/policy_parser.rs rename to codex-rs/execpolicy-legacy/src/policy_parser.rs diff --git a/codex-rs/execpolicy/src/program.rs b/codex-rs/execpolicy-legacy/src/program.rs similarity index 100% rename from codex-rs/execpolicy/src/program.rs rename to 
codex-rs/execpolicy-legacy/src/program.rs diff --git a/codex-rs/execpolicy/src/sed_command.rs b/codex-rs/execpolicy-legacy/src/sed_command.rs similarity index 100% rename from codex-rs/execpolicy/src/sed_command.rs rename to codex-rs/execpolicy-legacy/src/sed_command.rs diff --git a/codex-rs/execpolicy/src/valid_exec.rs b/codex-rs/execpolicy-legacy/src/valid_exec.rs similarity index 100% rename from codex-rs/execpolicy/src/valid_exec.rs rename to codex-rs/execpolicy-legacy/src/valid_exec.rs diff --git a/codex-rs/execpolicy/tests/all.rs b/codex-rs/execpolicy-legacy/tests/all.rs similarity index 100% rename from codex-rs/execpolicy/tests/all.rs rename to codex-rs/execpolicy-legacy/tests/all.rs diff --git a/codex-rs/execpolicy/tests/suite/bad.rs b/codex-rs/execpolicy-legacy/tests/suite/bad.rs similarity index 72% rename from codex-rs/execpolicy/tests/suite/bad.rs rename to codex-rs/execpolicy-legacy/tests/suite/bad.rs index 8b6e195fb..e1f867533 100644 --- a/codex-rs/execpolicy/tests/suite/bad.rs +++ b/codex-rs/execpolicy-legacy/tests/suite/bad.rs @@ -1,5 +1,5 @@ -use codex_execpolicy::NegativeExamplePassedCheck; -use codex_execpolicy::get_default_policy; +use codex_execpolicy_legacy::NegativeExamplePassedCheck; +use codex_execpolicy_legacy::get_default_policy; #[test] fn verify_everything_in_bad_list_is_rejected() { diff --git a/codex-rs/execpolicy/tests/suite/cp.rs b/codex-rs/execpolicy-legacy/tests/suite/cp.rs similarity index 81% rename from codex-rs/execpolicy/tests/suite/cp.rs rename to codex-rs/execpolicy-legacy/tests/suite/cp.rs index aa19f0b5d..3cfc9ac5c 100644 --- a/codex-rs/execpolicy/tests/suite/cp.rs +++ b/codex-rs/execpolicy-legacy/tests/suite/cp.rs @@ -1,15 +1,15 @@ -extern crate codex_execpolicy; +extern crate codex_execpolicy_legacy; -use codex_execpolicy::ArgMatcher; -use codex_execpolicy::ArgType; -use codex_execpolicy::Error; -use codex_execpolicy::ExecCall; -use codex_execpolicy::MatchedArg; -use codex_execpolicy::MatchedExec; -use 
codex_execpolicy::Policy; -use codex_execpolicy::Result; -use codex_execpolicy::ValidExec; -use codex_execpolicy::get_default_policy; +use codex_execpolicy_legacy::ArgMatcher; +use codex_execpolicy_legacy::ArgType; +use codex_execpolicy_legacy::Error; +use codex_execpolicy_legacy::ExecCall; +use codex_execpolicy_legacy::MatchedArg; +use codex_execpolicy_legacy::MatchedExec; +use codex_execpolicy_legacy::Policy; +use codex_execpolicy_legacy::Result; +use codex_execpolicy_legacy::ValidExec; +use codex_execpolicy_legacy::get_default_policy; #[expect(clippy::expect_used)] fn setup() -> Policy { diff --git a/codex-rs/execpolicy/tests/suite/good.rs b/codex-rs/execpolicy-legacy/tests/suite/good.rs similarity index 72% rename from codex-rs/execpolicy/tests/suite/good.rs rename to codex-rs/execpolicy-legacy/tests/suite/good.rs index 3b7313a33..3c86c7acb 100644 --- a/codex-rs/execpolicy/tests/suite/good.rs +++ b/codex-rs/execpolicy-legacy/tests/suite/good.rs @@ -1,5 +1,5 @@ -use codex_execpolicy::PositiveExampleFailedCheck; -use codex_execpolicy::get_default_policy; +use codex_execpolicy_legacy::PositiveExampleFailedCheck; +use codex_execpolicy_legacy::get_default_policy; #[test] fn verify_everything_in_good_list_is_allowed() { diff --git a/codex-rs/execpolicy/tests/suite/head.rs b/codex-rs/execpolicy-legacy/tests/suite/head.rs similarity index 87% rename from codex-rs/execpolicy/tests/suite/head.rs rename to codex-rs/execpolicy-legacy/tests/suite/head.rs index 3c32ccfbf..390b4ddb3 100644 --- a/codex-rs/execpolicy/tests/suite/head.rs +++ b/codex-rs/execpolicy-legacy/tests/suite/head.rs @@ -1,16 +1,16 @@ -use codex_execpolicy::ArgMatcher; -use codex_execpolicy::ArgType; -use codex_execpolicy::Error; -use codex_execpolicy::ExecCall; -use codex_execpolicy::MatchedArg; -use codex_execpolicy::MatchedExec; -use codex_execpolicy::MatchedOpt; -use codex_execpolicy::Policy; -use codex_execpolicy::Result; -use codex_execpolicy::ValidExec; -use codex_execpolicy::get_default_policy; 
+use codex_execpolicy_legacy::ArgMatcher; +use codex_execpolicy_legacy::ArgType; +use codex_execpolicy_legacy::Error; +use codex_execpolicy_legacy::ExecCall; +use codex_execpolicy_legacy::MatchedArg; +use codex_execpolicy_legacy::MatchedExec; +use codex_execpolicy_legacy::MatchedOpt; +use codex_execpolicy_legacy::Policy; +use codex_execpolicy_legacy::Result; +use codex_execpolicy_legacy::ValidExec; +use codex_execpolicy_legacy::get_default_policy; -extern crate codex_execpolicy; +extern crate codex_execpolicy_legacy; #[expect(clippy::expect_used)] fn setup() -> Policy { diff --git a/codex-rs/execpolicy/tests/suite/literal.rs b/codex-rs/execpolicy-legacy/tests/suite/literal.rs similarity index 78% rename from codex-rs/execpolicy/tests/suite/literal.rs rename to codex-rs/execpolicy-legacy/tests/suite/literal.rs index d849371e3..296206dbc 100644 --- a/codex-rs/execpolicy/tests/suite/literal.rs +++ b/codex-rs/execpolicy-legacy/tests/suite/literal.rs @@ -1,13 +1,13 @@ -use codex_execpolicy::ArgType; -use codex_execpolicy::Error; -use codex_execpolicy::ExecCall; -use codex_execpolicy::MatchedArg; -use codex_execpolicy::MatchedExec; -use codex_execpolicy::PolicyParser; -use codex_execpolicy::Result; -use codex_execpolicy::ValidExec; +use codex_execpolicy_legacy::ArgType; +use codex_execpolicy_legacy::Error; +use codex_execpolicy_legacy::ExecCall; +use codex_execpolicy_legacy::MatchedArg; +use codex_execpolicy_legacy::MatchedExec; +use codex_execpolicy_legacy::PolicyParser; +use codex_execpolicy_legacy::Result; +use codex_execpolicy_legacy::ValidExec; -extern crate codex_execpolicy; +extern crate codex_execpolicy_legacy; #[test] fn test_invalid_subcommand() -> Result<()> { diff --git a/codex-rs/execpolicy/tests/suite/ls.rs b/codex-rs/execpolicy-legacy/tests/suite/ls.rs similarity index 91% rename from codex-rs/execpolicy/tests/suite/ls.rs rename to codex-rs/execpolicy-legacy/tests/suite/ls.rs index e52316c06..b4b27b0e2 100644 --- a/codex-rs/execpolicy/tests/suite/ls.rs +++ 
b/codex-rs/execpolicy-legacy/tests/suite/ls.rs @@ -1,15 +1,15 @@ -extern crate codex_execpolicy; - -use codex_execpolicy::ArgType; -use codex_execpolicy::Error; -use codex_execpolicy::ExecCall; -use codex_execpolicy::MatchedArg; -use codex_execpolicy::MatchedExec; -use codex_execpolicy::MatchedFlag; -use codex_execpolicy::Policy; -use codex_execpolicy::Result; -use codex_execpolicy::ValidExec; -use codex_execpolicy::get_default_policy; +extern crate codex_execpolicy_legacy; + +use codex_execpolicy_legacy::ArgType; +use codex_execpolicy_legacy::Error; +use codex_execpolicy_legacy::ExecCall; +use codex_execpolicy_legacy::MatchedArg; +use codex_execpolicy_legacy::MatchedExec; +use codex_execpolicy_legacy::MatchedFlag; +use codex_execpolicy_legacy::Policy; +use codex_execpolicy_legacy::Result; +use codex_execpolicy_legacy::ValidExec; +use codex_execpolicy_legacy::get_default_policy; #[expect(clippy::expect_used)] fn setup() -> Policy { diff --git a/codex-rs/execpolicy/tests/suite/mod.rs b/codex-rs/execpolicy-legacy/tests/suite/mod.rs similarity index 100% rename from codex-rs/execpolicy/tests/suite/mod.rs rename to codex-rs/execpolicy-legacy/tests/suite/mod.rs diff --git a/codex-rs/execpolicy/tests/suite/parse_sed_command.rs b/codex-rs/execpolicy-legacy/tests/suite/parse_sed_command.rs similarity index 84% rename from codex-rs/execpolicy/tests/suite/parse_sed_command.rs rename to codex-rs/execpolicy-legacy/tests/suite/parse_sed_command.rs index 20f5bbf30..f1da55d64 100644 --- a/codex-rs/execpolicy/tests/suite/parse_sed_command.rs +++ b/codex-rs/execpolicy-legacy/tests/suite/parse_sed_command.rs @@ -1,5 +1,5 @@ -use codex_execpolicy::Error; -use codex_execpolicy::parse_sed_command; +use codex_execpolicy_legacy::Error; +use codex_execpolicy_legacy::parse_sed_command; #[test] fn parses_simple_print_command() { diff --git a/codex-rs/execpolicy/tests/suite/pwd.rs b/codex-rs/execpolicy-legacy/tests/suite/pwd.rs similarity index 82% rename from 
codex-rs/execpolicy/tests/suite/pwd.rs rename to codex-rs/execpolicy-legacy/tests/suite/pwd.rs index fdf5a4f1a..73d1caada 100644 --- a/codex-rs/execpolicy/tests/suite/pwd.rs +++ b/codex-rs/execpolicy-legacy/tests/suite/pwd.rs @@ -1,15 +1,15 @@ -extern crate codex_execpolicy; +extern crate codex_execpolicy_legacy; use std::vec; -use codex_execpolicy::Error; -use codex_execpolicy::ExecCall; -use codex_execpolicy::MatchedExec; -use codex_execpolicy::MatchedFlag; -use codex_execpolicy::Policy; -use codex_execpolicy::PositionalArg; -use codex_execpolicy::ValidExec; -use codex_execpolicy::get_default_policy; +use codex_execpolicy_legacy::Error; +use codex_execpolicy_legacy::ExecCall; +use codex_execpolicy_legacy::MatchedExec; +use codex_execpolicy_legacy::MatchedFlag; +use codex_execpolicy_legacy::Policy; +use codex_execpolicy_legacy::PositionalArg; +use codex_execpolicy_legacy::ValidExec; +use codex_execpolicy_legacy::get_default_policy; #[expect(clippy::expect_used)] fn setup() -> Policy { diff --git a/codex-rs/execpolicy/tests/suite/sed.rs b/codex-rs/execpolicy-legacy/tests/suite/sed.rs similarity index 82% rename from codex-rs/execpolicy/tests/suite/sed.rs rename to codex-rs/execpolicy-legacy/tests/suite/sed.rs index bf35bf6d4..d732cc7c9 100644 --- a/codex-rs/execpolicy/tests/suite/sed.rs +++ b/codex-rs/execpolicy-legacy/tests/suite/sed.rs @@ -1,16 +1,16 @@ -extern crate codex_execpolicy; +extern crate codex_execpolicy_legacy; -use codex_execpolicy::ArgType; -use codex_execpolicy::Error; -use codex_execpolicy::ExecCall; -use codex_execpolicy::MatchedArg; -use codex_execpolicy::MatchedExec; -use codex_execpolicy::MatchedFlag; -use codex_execpolicy::MatchedOpt; -use codex_execpolicy::Policy; -use codex_execpolicy::Result; -use codex_execpolicy::ValidExec; -use codex_execpolicy::get_default_policy; +use codex_execpolicy_legacy::ArgType; +use codex_execpolicy_legacy::Error; +use codex_execpolicy_legacy::ExecCall; +use codex_execpolicy_legacy::MatchedArg; +use 
codex_execpolicy_legacy::MatchedExec; +use codex_execpolicy_legacy::MatchedFlag; +use codex_execpolicy_legacy::MatchedOpt; +use codex_execpolicy_legacy::Policy; +use codex_execpolicy_legacy::Result; +use codex_execpolicy_legacy::ValidExec; +use codex_execpolicy_legacy::get_default_policy; #[expect(clippy::expect_used)] fn setup() -> Policy { diff --git a/codex-rs/execpolicy/Cargo.toml b/codex-rs/execpolicy/Cargo.toml index 0fe7cd486..bececed4b 100644 --- a/codex-rs/execpolicy/Cargo.toml +++ b/codex-rs/execpolicy/Cargo.toml @@ -1,33 +1,29 @@ [package] -edition = "2024" name = "codex-execpolicy" version = { workspace = true } - -[[bin]] -name = "codex-execpolicy" -path = "src/main.rs" +edition = "2024" +description = "Codex exec policy: prefix-based Starlark rules for command decisions." [lib] name = "codex_execpolicy" path = "src/lib.rs" +[[bin]] +name = "codex-execpolicy" +path = "src/main.rs" + [lints] workspace = true [dependencies] -allocative = { workspace = true } anyhow = { workspace = true } clap = { workspace = true, features = ["derive"] } -derive_more = { workspace = true, features = ["display"] } -env_logger = { workspace = true } -log = { workspace = true } multimap = { workspace = true } -path-absolutize = { workspace = true } -regex-lite = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } -serde_with = { workspace = true, features = ["macros"] } +shlex = { workspace = true } starlark = { workspace = true } +thiserror = { workspace = true } [dev-dependencies] -tempfile = { workspace = true } +pretty_assertions = { workspace = true } diff --git a/codex-rs/execpolicy/README.md b/codex-rs/execpolicy/README.md index ca9582944..9fd9c6330 100644 --- a/codex-rs/execpolicy/README.md +++ b/codex-rs/execpolicy/README.md @@ -1,180 +1,62 @@ -# codex_execpolicy - -The goal of this library is to classify a proposed [`execv(3)`](https://linux.die.net/man/3/execv) command into one of the following states: - -- 
`safe` The command is safe to run (\*). -- `match` The command matched a rule in the policy, but the caller should decide whether it is safe to run based on the files it will write. -- `forbidden` The command is not allowed to be run. -- `unverified` The safety cannot be determined: make the user decide. - -(\*) Whether an `execv(3)` call should be considered "safe" often requires additional context beyond the arguments to `execv()` itself. For example, if you trust an autonomous software agent to write files in your source tree, then deciding whether `/bin/cp foo bar` is "safe" depends on `getcwd(3)` for the calling process as well as the `realpath` of `foo` and `bar` when resolved against `getcwd()`. -To that end, rather than returning a boolean, the validator returns a structured result that the client is expected to use to determine the "safety" of the proposed `execv()` call. - -For example, to check the command `ls -l foo`, the checker would be invoked as follows: - -```shell -cargo run -- check ls -l foo | jq -``` - -It will exit with `0` and print the following to stdout: - -```json -{ - "result": "safe", - "match": { - "program": "ls", - "flags": [ - { - "name": "-l" - } - ], - "opts": [], - "args": [ - { - "index": 1, - "type": "ReadableFile", - "value": "foo" - } - ], - "system_path": ["/bin/ls", "/usr/bin/ls"] - } -} -``` - -Of note: - -- `foo` is tagged as a `ReadableFile`, so the caller should resolve `foo` relative to `getcwd()` and `realpath` it (as it may be a symlink) to determine whether `foo` is safe to read. -- While the specified executable is `ls`, `"system_path"` offers `/bin/ls` and `/usr/bin/ls` as viable alternatives to avoid using whatever `ls` happens to appear first on the user's `$PATH`. If either exists on the host, it is recommended to use it as the first argument to `execv(3)` instead of `ls`. - -Further, "safety" in this system is not a guarantee that the command will execute successfully. 
As an example, `cat /Users/mbolin/code/codex/README.md` may be considered "safe" if the system has decided the agent is allowed to read anything under `/Users/mbolin/code/codex`, but it will fail at runtime if `README.md` does not exist. (Though this is "safe" in that the agent did not read any files that it was not authorized to read.) - -## Policy - -Currently, the default policy is defined in [`default.policy`](./src/default.policy) within the crate. - -The system uses [Starlark](https://bazel.build/rules/language) as the file format because, unlike something like JSON or YAML, it supports "macros" without compromising on safety or reproducibility. (Under the hood, we use [`starlark-rust`](https://github.com/facebook/starlark-rust) as the specific Starlark implementation.) - -This policy contains "rules" such as: - -```python -define_program( - program="cp", - options=[ - flag("-r"), - flag("-R"), - flag("--recursive"), - ], - args=[ARG_RFILES, ARG_WFILE], - system_path=["/bin/cp", "/usr/bin/cp"], - should_match=[ - ["foo", "bar"], - ], - should_not_match=[ - ["foo"], - ], +# codex-execpolicy + +## Overview +- Policy engine and CLI built around `prefix_rule(pattern=[...], decision?, match?, not_match?)`. +- This release covers the prefix-rule subset of the execpolicy language; a richer language will follow. +- Tokens are matched in order; any `pattern` element may be a list to denote alternatives. `decision` defaults to `allow`; valid values: `allow`, `prompt`, `forbidden`. +- `match` / `not_match` supply example invocations that are validated at load time (think of them as unit tests); examples can be token arrays or strings (strings are tokenized with `shlex`). +- The CLI always prints the JSON serialization of the evaluation result. +- The legacy rule matcher lives in `codex-execpolicy-legacy`. 
+ +## Policy shapes +- Prefix rules use Starlark syntax: +```starlark +prefix_rule( + pattern = ["cmd", ["alt1", "alt2"]], # ordered tokens; list entries denote alternatives + decision = "prompt", # allow | prompt | forbidden; defaults to allow + match = [["cmd", "alt1"], "cmd alt2"], # examples that must match this rule + not_match = [["cmd", "oops"], "cmd alt3"], # examples that must not match this rule ) ``` -This rule means that: - -- `cp` can be used with any of the following flags (where "flag" means "an option that does not take an argument"): `-r`, `-R`, `--recursive`. -- The initial `ARG_RFILES` passed to `args` means that it expects one or more arguments that correspond to "readable files" -- The final `ARG_WFILE` passed to `args` means that it expects exactly one argument that corresponds to a "writeable file." -- As a means of a lightweight way of including a unit test alongside the definition, the `should_match` list is a list of examples of `execv(3)` args that should match the rule and `should_not_match` is a list of examples that should not match. These examples are verified when the `.policy` file is loaded. - -Note that the language of the `.policy` file is still evolving, as we have to continue to expand it so it is sufficiently expressive to accept all commands we want to consider "safe" without allowing unsafe commands to pass through. - -The integrity of `default.policy` is verified [via unit tests](./tests). - -Further, the CLI supports a `--policy` option to specify a custom `.policy` file for ad-hoc testing. 
- -## Output Type: `match` - -Going back to the `cp` example, because the rule matches an `ARG_WFILE`, it will return `match` instead of `safe`: - -```shell -cargo run -- check cp src1 src2 dest | jq +## CLI +- From the Codex CLI, run `codex execpolicy check` subcommand with one or more policy files (for example `src/default.codexpolicy`) to check a command: +```bash +codex execpolicy check --policy path/to/policy.codexpolicy git status ``` +- Pass multiple `--policy` flags to merge rules, evaluated in the order provided, and use `--pretty` for formatted JSON. +- You can also run the standalone dev binary directly during development: +```bash +cargo run -p codex-execpolicy -- check --policy path/to/policy.codexpolicy git status +``` +- Example outcomes: + - Match: `{"match": { ... "decision": "allow" ... }}` + - No match: `{"noMatch": {}}` -If the caller wants to consider allowing this command, it should parse the JSON to pick out the `WriteableFile` arguments and decide whether they are safe to write: - +## Response shapes +- Match: ```json { - "result": "match", "match": { - "program": "cp", - "flags": [], - "opts": [], - "args": [ + "decision": "allow|prompt|forbidden", + "matchedRules": [ { - "index": 0, - "type": "ReadableFile", - "value": "src1" - }, - { - "index": 1, - "type": "ReadableFile", - "value": "src2" - }, - { - "index": 2, - "type": "WriteableFile", - "value": "dest" + "prefixRuleMatch": { + "matchedPrefix": ["", "..."], + "decision": "allow|prompt|forbidden" + } } - ], - "system_path": ["/bin/cp", "/usr/bin/cp"] + ] } } ``` -Note the exit code is still `0` for a `match` unless the `--require-safe` flag is specified, in which case the exit code is `12`. - -## Output Type: `forbidden` - -It is also possible to define a rule that, if it matches a command, should flag it as _forbidden_. 
For example, we do not want agents to be able to run `applied deploy` _ever_, so we define the following rule: - -```python -define_program( - program="applied", - args=["deploy"], - forbidden="Infrastructure Risk: command contains 'applied deploy'", - should_match=[ - ["deploy"], - ], - should_not_match=[ - ["lint"], - ], -) +- No match: +```json +{"noMatch": {}} ``` -Note that for a rule to be forbidden, the `forbidden` keyword arg must be specified as the reason the command is forbidden. This will be included in the output: - -```shell -cargo run -- check applied deploy | jq -``` +- `matchedRules` lists every rule whose prefix matched the command; `matchedPrefix` is the exact prefix that matched. +- The effective `decision` is the strictest severity across all matches (`forbidden` > `prompt` > `allow`). -```json -{ - "result": "forbidden", - "reason": "Infrastructure Risk: command contains 'applied deploy'", - "cause": { - "Exec": { - "exec": { - "program": "applied", - "flags": [], - "opts": [], - "args": [ - { - "index": 0, - "type": { - "Literal": "deploy" - }, - "value": "deploy" - } - ], - "system_path": [] - } - } - } -} -``` +Note: `execpolicy` commands are still in preview. The API may have breaking changes in the future. 
diff --git a/codex-rs/execpolicy2/examples/example.codexpolicy b/codex-rs/execpolicy/examples/example.codexpolicy similarity index 100% rename from codex-rs/execpolicy2/examples/example.codexpolicy rename to codex-rs/execpolicy/examples/example.codexpolicy diff --git a/codex-rs/execpolicy2/src/decision.rs b/codex-rs/execpolicy/src/decision.rs similarity index 100% rename from codex-rs/execpolicy2/src/decision.rs rename to codex-rs/execpolicy/src/decision.rs diff --git a/codex-rs/execpolicy/src/error.rs b/codex-rs/execpolicy/src/error.rs index e6443d69d..2f168a027 100644 --- a/codex-rs/execpolicy/src/error.rs +++ b/codex-rs/execpolicy/src/error.rs @@ -1,96 +1,26 @@ -use std::path::PathBuf; - -use serde::Serialize; - -use crate::arg_matcher::ArgMatcher; -use crate::arg_resolver::PositionalArg; -use serde_with::DisplayFromStr; -use serde_with::serde_as; +use starlark::Error as StarlarkError; +use thiserror::Error; pub type Result = std::result::Result; -#[serde_as] -#[derive(Debug, Eq, PartialEq, Serialize)] -#[serde(tag = "type")] +#[derive(Debug, Error)] pub enum Error { - NoSpecForProgram { - program: String, - }, - OptionMissingValue { - program: String, - option: String, - }, - OptionFollowedByOptionInsteadOfValue { - program: String, - option: String, - value: String, - }, - UnknownOption { - program: String, - option: String, - }, - UnexpectedArguments { - program: String, - args: Vec, - }, - DoubleDashNotSupportedYet { - program: String, - }, - MultipleVarargPatterns { - program: String, - first: ArgMatcher, - second: ArgMatcher, - }, - RangeStartExceedsEnd { - start: usize, - end: usize, - }, - RangeEndOutOfBounds { - end: usize, - len: usize, - }, - PrefixOverlapsSuffix {}, - NotEnoughArgs { - program: String, - args: Vec, - arg_patterns: Vec, - }, - InternalInvariantViolation { - message: String, - }, - VarargMatcherDidNotMatchAnything { - program: String, - matcher: ArgMatcher, - }, - EmptyFileName {}, - LiteralValueDidNotMatch { - expected: String, - 
actual: String, - }, - InvalidPositiveInteger { - value: String, - }, - MissingRequiredOptions { - program: String, - options: Vec, - }, - SedCommandNotProvablySafe { - command: String, - }, - ReadablePathNotInReadableFolders { - file: PathBuf, - folders: Vec, - }, - WriteablePathNotInWriteableFolders { - file: PathBuf, - folders: Vec, - }, - CannotCheckRelativePath { - file: PathBuf, - }, - CannotCanonicalizePath { - file: String, - #[serde_as(as = "DisplayFromStr")] - error: std::io::ErrorKind, - }, + #[error("invalid decision: {0}")] + InvalidDecision(String), + #[error("invalid pattern element: {0}")] + InvalidPattern(String), + #[error("invalid example: {0}")] + InvalidExample(String), + #[error( + "expected every example to match at least one rule. rules: {rules:?}; unmatched examples: \ + {examples:?}" + )] + ExampleDidNotMatch { + rules: Vec, + examples: Vec, + }, + #[error("expected example to not match rule `{rule}`: {example}")] + ExampleDidMatch { rule: String, example: String }, + #[error("starlark error: {0}")] + Starlark(StarlarkError), } diff --git a/codex-rs/execpolicy/src/execpolicycheck.rs b/codex-rs/execpolicy/src/execpolicycheck.rs new file mode 100644 index 000000000..0b5e0dcaf --- /dev/null +++ b/codex-rs/execpolicy/src/execpolicycheck.rs @@ -0,0 +1,67 @@ +use std::fs; +use std::path::PathBuf; + +use anyhow::Context; +use anyhow::Result; +use clap::Parser; + +use crate::Evaluation; +use crate::Policy; +use crate::PolicyParser; + +/// Arguments for evaluating a command against one or more execpolicy files. +#[derive(Debug, Parser, Clone)] +pub struct ExecPolicyCheckCommand { + /// Paths to execpolicy files to evaluate (repeatable). + #[arg(short = 'p', long = "policy", value_name = "PATH", required = true)] + pub policies: Vec, + + /// Pretty-print the JSON output. + #[arg(long)] + pub pretty: bool, + + /// Command tokens to check against the policy. 
+ #[arg( + value_name = "COMMAND", + required = true, + trailing_var_arg = true, + allow_hyphen_values = true + )] + pub command: Vec, +} + +impl ExecPolicyCheckCommand { + /// Load the policies for this command, evaluate the command, and render JSON output. + pub fn run(&self) -> Result<()> { + let policy = load_policies(&self.policies)?; + let evaluation = policy.check(&self.command); + + let json = format_evaluation_json(&evaluation, self.pretty)?; + println!("{json}"); + + Ok(()) + } +} + +pub fn format_evaluation_json(evaluation: &Evaluation, pretty: bool) -> Result { + if pretty { + serde_json::to_string_pretty(evaluation).map_err(Into::into) + } else { + serde_json::to_string(evaluation).map_err(Into::into) + } +} + +pub fn load_policies(policy_paths: &[PathBuf]) -> Result { + let mut parser = PolicyParser::new(); + + for policy_path in policy_paths { + let policy_file_contents = fs::read_to_string(policy_path) + .with_context(|| format!("failed to read policy at {}", policy_path.display()))?; + let policy_identifier = policy_path.to_string_lossy().to_string(); + parser + .parse(&policy_identifier, &policy_file_contents) + .with_context(|| format!("failed to parse policy at {}", policy_path.display()))?; + } + + Ok(parser.build()) +} diff --git a/codex-rs/execpolicy/src/lib.rs b/codex-rs/execpolicy/src/lib.rs index 6f1222598..b459caea1 100644 --- a/codex-rs/execpolicy/src/lib.rs +++ b/codex-rs/execpolicy/src/lib.rs @@ -1,45 +1,17 @@ -#![allow(clippy::type_complexity)] -#![allow(clippy::too_many_arguments)] -#[macro_use] -extern crate starlark; +pub mod decision; +pub mod error; +pub mod execpolicycheck; +pub mod parser; +pub mod policy; +pub mod rule; -mod arg_matcher; -mod arg_resolver; -mod arg_type; -mod error; -mod exec_call; -mod execv_checker; -mod opt; -mod policy; -mod policy_parser; -mod program; -mod sed_command; -mod valid_exec; - -pub use arg_matcher::ArgMatcher; -pub use arg_resolver::PositionalArg; -pub use arg_type::ArgType; +pub use 
decision::Decision; pub use error::Error; pub use error::Result; -pub use exec_call::ExecCall; -pub use execv_checker::ExecvChecker; -pub use opt::Opt; +pub use execpolicycheck::ExecPolicyCheckCommand; +pub use parser::PolicyParser; +pub use policy::Evaluation; pub use policy::Policy; -pub use policy_parser::PolicyParser; -pub use program::Forbidden; -pub use program::MatchedExec; -pub use program::NegativeExamplePassedCheck; -pub use program::PositiveExampleFailedCheck; -pub use program::ProgramSpec; -pub use sed_command::parse_sed_command; -pub use valid_exec::MatchedArg; -pub use valid_exec::MatchedFlag; -pub use valid_exec::MatchedOpt; -pub use valid_exec::ValidExec; - -const DEFAULT_POLICY: &str = include_str!("default.policy"); - -pub fn get_default_policy() -> starlark::Result { - let parser = PolicyParser::new("#default", DEFAULT_POLICY); - parser.parse() -} +pub use rule::Rule; +pub use rule::RuleMatch; +pub use rule::RuleRef; diff --git a/codex-rs/execpolicy/src/main.rs b/codex-rs/execpolicy/src/main.rs index 68a72c04f..e1373b6d1 100644 --- a/codex-rs/execpolicy/src/main.rs +++ b/codex-rs/execpolicy/src/main.rs @@ -1,167 +1,22 @@ use anyhow::Result; use clap::Parser; -use clap::Subcommand; -use codex_execpolicy::ExecCall; -use codex_execpolicy::MatchedExec; -use codex_execpolicy::Policy; -use codex_execpolicy::PolicyParser; -use codex_execpolicy::ValidExec; -use codex_execpolicy::get_default_policy; -use serde::Deserialize; -use serde::Serialize; -use serde::de; -use starlark::Error as StarlarkError; -use std::path::PathBuf; -use std::str::FromStr; +use codex_execpolicy::ExecPolicyCheckCommand; -const MATCHED_BUT_WRITES_FILES_EXIT_CODE: i32 = 12; -const MIGHT_BE_SAFE_EXIT_CODE: i32 = 13; -const FORBIDDEN_EXIT_CODE: i32 = 14; - -#[derive(Parser, Deserialize, Debug)] -#[command(version, about, long_about = None)] -pub struct Args { - /// If the command fails the policy, exit with 13, but print parseable JSON - /// to stdout. 
- #[clap(long)] - pub require_safe: bool, - - /// Path to the policy file. - #[clap(long, short = 'p')] - pub policy: Option, - - #[command(subcommand)] - pub command: Command, -} - -#[derive(Clone, Debug, Deserialize, Subcommand)] -pub enum Command { - /// Checks the command as if the arguments were the inputs to execv(3). - Check { - #[arg(trailing_var_arg = true)] - command: Vec, - }, - - /// Checks the command encoded as a JSON object. - #[clap(name = "check-json")] - CheckJson { - /// JSON object with "program" (str) and "args" (list[str]) fields. - #[serde(deserialize_with = "deserialize_from_json")] - exec: ExecArg, - }, -} - -#[derive(Clone, Debug, Deserialize)] -pub struct ExecArg { - pub program: String, - - #[serde(default)] - pub args: Vec, +/// CLI for evaluating exec policies +#[derive(Parser)] +#[command(name = "codex-execpolicy")] +enum Cli { + /// Evaluate a command against a policy. + Check(ExecPolicyCheckCommand), } fn main() -> Result<()> { - env_logger::init(); - - let args = Args::parse(); - let policy = match args.policy { - Some(policy) => { - let policy_source = policy.to_string_lossy().to_string(); - let unparsed_policy = std::fs::read_to_string(policy)?; - let parser = PolicyParser::new(&policy_source, &unparsed_policy); - parser.parse() - } - None => get_default_policy(), - }; - let policy = policy.map_err(StarlarkError::into_anyhow)?; - - let exec = match args.command { - Command::Check { command } => match command.split_first() { - Some((first, rest)) => ExecArg { - program: first.to_string(), - args: rest.to_vec(), - }, - None => { - eprintln!("no command provided"); - std::process::exit(1); - } - }, - Command::CheckJson { exec } => exec, - }; - - let (output, exit_code) = check_command(&policy, exec, args.require_safe); - let json = serde_json::to_string(&output)?; - println!("{json}"); - std::process::exit(exit_code); -} - -fn check_command( - policy: &Policy, - ExecArg { program, args }: ExecArg, - check: bool, -) -> (Output, i32) 
{ - let exec_call = ExecCall { program, args }; - match policy.check(&exec_call) { - Ok(MatchedExec::Match { exec }) => { - if exec.might_write_files() { - let exit_code = if check { - MATCHED_BUT_WRITES_FILES_EXIT_CODE - } else { - 0 - }; - (Output::Match { r#match: exec }, exit_code) - } else { - (Output::Safe { r#match: exec }, 0) - } - } - Ok(MatchedExec::Forbidden { reason, cause }) => { - let exit_code = if check { FORBIDDEN_EXIT_CODE } else { 0 }; - (Output::Forbidden { reason, cause }, exit_code) - } - Err(err) => { - let exit_code = if check { MIGHT_BE_SAFE_EXIT_CODE } else { 0 }; - (Output::Unverified { error: err }, exit_code) - } + let cli = Cli::parse(); + match cli { + Cli::Check(cmd) => cmd_check(cmd), } } -#[derive(Debug, Serialize)] -#[serde(tag = "result")] -pub enum Output { - /// The command is verified as safe. - #[serde(rename = "safe")] - Safe { r#match: ValidExec }, - - /// The command has matched a rule in the policy, but the caller should - /// decide whether it is "safe" given the files it wants to write. - #[serde(rename = "match")] - Match { r#match: ValidExec }, - - /// The user is forbidden from running the command. - #[serde(rename = "forbidden")] - Forbidden { - reason: String, - cause: codex_execpolicy::Forbidden, - }, - - /// The safety of the command could not be verified. 
- #[serde(rename = "unverified")] - Unverified { error: codex_execpolicy::Error }, -} - -fn deserialize_from_json<'de, D>(deserializer: D) -> Result -where - D: de::Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - let decoded = serde_json::from_str(&s) - .map_err(|e| serde::de::Error::custom(format!("JSON parse error: {e}")))?; - Ok(decoded) -} - -impl FromStr for ExecArg { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - serde_json::from_str(s).map_err(Into::into) - } +fn cmd_check(cmd: ExecPolicyCheckCommand) -> Result<()> { + cmd.run() } diff --git a/codex-rs/execpolicy2/src/parser.rs b/codex-rs/execpolicy/src/parser.rs similarity index 100% rename from codex-rs/execpolicy2/src/parser.rs rename to codex-rs/execpolicy/src/parser.rs diff --git a/codex-rs/execpolicy/src/policy.rs b/codex-rs/execpolicy/src/policy.rs index 825d6164a..e048fce1f 100644 --- a/codex-rs/execpolicy/src/policy.rs +++ b/codex-rs/execpolicy/src/policy.rs @@ -1,103 +1,84 @@ +use crate::decision::Decision; +use crate::rule::RuleMatch; +use crate::rule::RuleRef; use multimap::MultiMap; -use regex_lite::Error as RegexError; -use regex_lite::Regex; - -use crate::ExecCall; -use crate::Forbidden; -use crate::MatchedExec; -use crate::NegativeExamplePassedCheck; -use crate::ProgramSpec; -use crate::error::Error; -use crate::error::Result; -use crate::policy_parser::ForbiddenProgramRegex; -use crate::program::PositiveExampleFailedCheck; +use serde::Deserialize; +use serde::Serialize; +#[derive(Clone, Debug)] pub struct Policy { - programs: MultiMap, - forbidden_program_regexes: Vec, - forbidden_substrings_pattern: Option, + rules_by_program: MultiMap, } impl Policy { - pub fn new( - programs: MultiMap, - forbidden_program_regexes: Vec, - forbidden_substrings: Vec, - ) -> std::result::Result { - let forbidden_substrings_pattern = if forbidden_substrings.is_empty() { - None - } else { - let escaped_substrings = forbidden_substrings - .iter() - .map(|s| 
regex_lite::escape(s)) - .collect::>() - .join("|"); - Some(Regex::new(&format!("({escaped_substrings})"))?) - }; - Ok(Self { - programs, - forbidden_program_regexes, - forbidden_substrings_pattern, - }) + pub fn new(rules_by_program: MultiMap) -> Self { + Self { rules_by_program } } - pub fn check(&self, exec_call: &ExecCall) -> Result { - let ExecCall { program, args } = &exec_call; - for ForbiddenProgramRegex { regex, reason } in &self.forbidden_program_regexes { - if regex.is_match(program) { - return Ok(MatchedExec::Forbidden { - cause: Forbidden::Program { - program: program.clone(), - exec_call: exec_call.clone(), - }, - reason: reason.clone(), - }); - } - } + pub fn empty() -> Self { + Self::new(MultiMap::new()) + } - for arg in args { - if let Some(regex) = &self.forbidden_substrings_pattern - && regex.is_match(arg) - { - return Ok(MatchedExec::Forbidden { - cause: Forbidden::Arg { - arg: arg.clone(), - exec_call: exec_call.clone(), - }, - reason: format!("arg `{arg}` contains forbidden substring"), - }); - } - } + pub fn rules(&self) -> &MultiMap { + &self.rules_by_program + } - let mut last_err = Err(Error::NoSpecForProgram { - program: program.clone(), - }); - if let Some(spec_list) = self.programs.get_vec(program) { - for spec in spec_list { - match spec.check(exec_call) { - Ok(matched_exec) => return Ok(matched_exec), - Err(err) => { - last_err = Err(err); - } - } - } + pub fn check(&self, cmd: &[String]) -> Evaluation { + let rules = match cmd.first() { + Some(first) => match self.rules_by_program.get_vec(first) { + Some(rules) => rules, + None => return Evaluation::NoMatch {}, + }, + None => return Evaluation::NoMatch {}, + }; + + let matched_rules: Vec = + rules.iter().filter_map(|rule| rule.matches(cmd)).collect(); + match matched_rules.iter().map(RuleMatch::decision).max() { + Some(decision) => Evaluation::Match { + decision, + matched_rules, + }, + None => Evaluation::NoMatch {}, } - last_err } - pub fn check_each_good_list_individually(&self) 
-> Vec { - let mut violations = Vec::new(); - for (_program, spec) in self.programs.flat_iter() { - violations.extend(spec.verify_should_match_list()); + pub fn check_multiple(&self, commands: Commands) -> Evaluation + where + Commands: IntoIterator, + Commands::Item: AsRef<[String]>, + { + let matched_rules: Vec = commands + .into_iter() + .flat_map(|command| match self.check(command.as_ref()) { + Evaluation::Match { matched_rules, .. } => matched_rules, + Evaluation::NoMatch { .. } => Vec::new(), + }) + .collect(); + + match matched_rules.iter().map(RuleMatch::decision).max() { + Some(decision) => Evaluation::Match { + decision, + matched_rules, + }, + None => Evaluation::NoMatch {}, } - violations } +} - pub fn check_each_bad_list_individually(&self) -> Vec { - let mut violations = Vec::new(); - for (_program, spec) in self.programs.flat_iter() { - violations.extend(spec.verify_should_not_match_list()); - } - violations +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Evaluation { + NoMatch {}, + Match { + decision: Decision, + #[serde(rename = "matchedRules")] + matched_rules: Vec, + }, +} + +impl Evaluation { + pub fn is_match(&self) -> bool { + matches!(self, Self::Match { .. 
}) } } diff --git a/codex-rs/execpolicy2/src/rule.rs b/codex-rs/execpolicy/src/rule.rs similarity index 100% rename from codex-rs/execpolicy2/src/rule.rs rename to codex-rs/execpolicy/src/rule.rs diff --git a/codex-rs/execpolicy2/tests/basic.rs b/codex-rs/execpolicy/tests/basic.rs similarity index 96% rename from codex-rs/execpolicy2/tests/basic.rs rename to codex-rs/execpolicy/tests/basic.rs index 75921ecd5..e4189caa2 100644 --- a/codex-rs/execpolicy2/tests/basic.rs +++ b/codex-rs/execpolicy/tests/basic.rs @@ -1,14 +1,14 @@ use std::any::Any; use std::sync::Arc; -use codex_execpolicy2::Decision; -use codex_execpolicy2::Evaluation; -use codex_execpolicy2::PolicyParser; -use codex_execpolicy2::RuleMatch; -use codex_execpolicy2::RuleRef; -use codex_execpolicy2::rule::PatternToken; -use codex_execpolicy2::rule::PrefixPattern; -use codex_execpolicy2::rule::PrefixRule; +use codex_execpolicy::Decision; +use codex_execpolicy::Evaluation; +use codex_execpolicy::PolicyParser; +use codex_execpolicy::RuleMatch; +use codex_execpolicy::RuleRef; +use codex_execpolicy::rule::PatternToken; +use codex_execpolicy::rule::PrefixPattern; +use codex_execpolicy::rule::PrefixRule; use pretty_assertions::assert_eq; fn tokens(cmd: &[&str]) -> Vec { @@ -288,7 +288,7 @@ prefix_rule( "color.status=always", "status", ])); - assert_eq!(Evaluation::NoMatch, no_match_eval); + assert_eq!(Evaluation::NoMatch {}, no_match_eval); } #[test] diff --git a/codex-rs/execpolicy2/Cargo.toml b/codex-rs/execpolicy2/Cargo.toml deleted file mode 100644 index c86fb7ff7..000000000 --- a/codex-rs/execpolicy2/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -[package] -name = "codex-execpolicy2" -version = { workspace = true } -edition = "2024" -description = "Codex exec policy v2: prefix-based Starlark rules for command decisions." 
- -[lib] -name = "codex_execpolicy2" -path = "src/lib.rs" - -[[bin]] -name = "codex-execpolicy2" -path = "src/main.rs" - -[lints] -workspace = true - -[dependencies] -anyhow = { workspace = true } -clap = { workspace = true, features = ["derive"] } -multimap = { workspace = true } -serde = { workspace = true, features = ["derive"] } -serde_json = { workspace = true } -shlex = { workspace = true } -starlark = { workspace = true } -thiserror = { workspace = true } - -[dev-dependencies] -pretty_assertions = { workspace = true } diff --git a/codex-rs/execpolicy2/README.md b/codex-rs/execpolicy2/README.md deleted file mode 100644 index 8cf5302fb..000000000 --- a/codex-rs/execpolicy2/README.md +++ /dev/null @@ -1,59 +0,0 @@ -# codex-execpolicy2 - -## Overview -- Policy engine and CLI built around `prefix_rule(pattern=[...], decision?, match?, not_match?)`. -- This release covers only the prefix-rule subset of the planned execpolicy v2 language; a richer language will follow. -- Tokens are matched in order; any `pattern` element may be a list to denote alternatives. `decision` defaults to `allow`; valid values: `allow`, `prompt`, `forbidden`. -- `match` / `not_match` supply example invocations that are validated at load time (think of them as unit tests); examples can be token arrays or strings (strings are tokenized with `shlex`). -- The CLI always prints the JSON serialization of the evaluation result (whether a match or not). 
- -## Policy shapes -- Prefix rules use Starlark syntax: -```starlark -prefix_rule( - pattern = ["cmd", ["alt1", "alt2"]], # ordered tokens; list entries denote alternatives - decision = "prompt", # allow | prompt | forbidden; defaults to allow - match = [["cmd", "alt1"], "cmd alt2"], # examples that must match this rule - not_match = [["cmd", "oops"], "cmd alt3"], # examples that must not match this rule -) -``` - -## Response shapes -- Match: -```json -{ - "match": { - "decision": "allow|prompt|forbidden", - "matchedRules": [ - { - "prefixRuleMatch": { - "matchedPrefix": ["", "..."], - "decision": "allow|prompt|forbidden" - } - } - ] - } -} -``` - -- No match: -```json -"noMatch" -``` - -- `matchedRules` lists every rule whose prefix matched the command; `matchedPrefix` is the exact prefix that matched. -- The effective `decision` is the strictest severity across all matches (`forbidden` > `prompt` > `allow`). - -## CLI -- Provide one or more policy files (for example `src/default.codexpolicy`) to check a command: -```bash -cargo run -p codex-execpolicy2 -- check --policy path/to/policy.codexpolicy git status -``` -- Pass multiple `--policy` flags to merge rules, evaluated in the order provided: -```bash -cargo run -p codex-execpolicy2 -- check --policy base.codexpolicy --policy overrides.codexpolicy git status -``` -- Output is newline-delimited JSON by default; pass `--pretty` for pretty-printed JSON if desired. -- Example outcomes: - - Match: `{"match": { ... "decision": "allow" ... 
}}` - - No match: `"noMatch"` diff --git a/codex-rs/execpolicy2/src/error.rs b/codex-rs/execpolicy2/src/error.rs deleted file mode 100644 index 2f168a027..000000000 --- a/codex-rs/execpolicy2/src/error.rs +++ /dev/null @@ -1,26 +0,0 @@ -use starlark::Error as StarlarkError; -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Debug, Error)] -pub enum Error { - #[error("invalid decision: {0}")] - InvalidDecision(String), - #[error("invalid pattern element: {0}")] - InvalidPattern(String), - #[error("invalid example: {0}")] - InvalidExample(String), - #[error( - "expected every example to match at least one rule. rules: {rules:?}; unmatched examples: \ - {examples:?}" - )] - ExampleDidNotMatch { - rules: Vec, - examples: Vec, - }, - #[error("expected example to not match rule `{rule}`: {example}")] - ExampleDidMatch { rule: String, example: String }, - #[error("starlark error: {0}")] - Starlark(StarlarkError), -} diff --git a/codex-rs/execpolicy2/src/lib.rs b/codex-rs/execpolicy2/src/lib.rs deleted file mode 100644 index 1b789fd86..000000000 --- a/codex-rs/execpolicy2/src/lib.rs +++ /dev/null @@ -1,15 +0,0 @@ -pub mod decision; -pub mod error; -pub mod parser; -pub mod policy; -pub mod rule; - -pub use decision::Decision; -pub use error::Error; -pub use error::Result; -pub use parser::PolicyParser; -pub use policy::Evaluation; -pub use policy::Policy; -pub use rule::Rule; -pub use rule::RuleMatch; -pub use rule::RuleRef; diff --git a/codex-rs/execpolicy2/src/main.rs b/codex-rs/execpolicy2/src/main.rs deleted file mode 100644 index f654c5f84..000000000 --- a/codex-rs/execpolicy2/src/main.rs +++ /dev/null @@ -1,66 +0,0 @@ -use std::fs; -use std::path::PathBuf; - -use anyhow::Context; -use anyhow::Result; -use clap::Parser; -use codex_execpolicy2::PolicyParser; - -/// CLI for evaluating exec policies -#[derive(Parser)] -#[command(name = "codex-execpolicy2")] -enum Cli { - /// Evaluate a command against a policy. 
- Check { - #[arg(short, long = "policy", value_name = "PATH", required = true)] - policies: Vec, - - /// Pretty-print the JSON output. - #[arg(long)] - pretty: bool, - - /// Command tokens to check. - #[arg( - value_name = "COMMAND", - required = true, - trailing_var_arg = true, - allow_hyphen_values = true - )] - command: Vec, - }, -} - -fn main() -> Result<()> { - let cli = Cli::parse(); - match cli { - Cli::Check { - policies, - command, - pretty, - } => cmd_check(policies, command, pretty), - } -} - -fn cmd_check(policy_paths: Vec, args: Vec, pretty: bool) -> Result<()> { - let policy = load_policies(&policy_paths)?; - - let eval = policy.check(&args); - let json = if pretty { - serde_json::to_string_pretty(&eval)? - } else { - serde_json::to_string(&eval)? - }; - println!("{json}"); - Ok(()) -} - -fn load_policies(policy_paths: &[PathBuf]) -> Result { - let mut parser = PolicyParser::new(); - for policy_path in policy_paths { - let policy_file_contents = fs::read_to_string(policy_path) - .with_context(|| format!("failed to read policy at {}", policy_path.display()))?; - let policy_identifier = policy_path.to_string_lossy().to_string(); - parser.parse(&policy_identifier, &policy_file_contents)?; - } - Ok(parser.build()) -} diff --git a/codex-rs/execpolicy2/src/policy.rs b/codex-rs/execpolicy2/src/policy.rs deleted file mode 100644 index 12416b050..000000000 --- a/codex-rs/execpolicy2/src/policy.rs +++ /dev/null @@ -1,80 +0,0 @@ -use crate::decision::Decision; -use crate::rule::RuleMatch; -use crate::rule::RuleRef; -use multimap::MultiMap; -use serde::Deserialize; -use serde::Serialize; - -#[derive(Clone, Debug)] -pub struct Policy { - rules_by_program: MultiMap, -} - -impl Policy { - pub fn new(rules_by_program: MultiMap) -> Self { - Self { rules_by_program } - } - - pub fn rules(&self) -> &MultiMap { - &self.rules_by_program - } - - pub fn check(&self, cmd: &[String]) -> Evaluation { - let rules = match cmd.first() { - Some(first) => match 
self.rules_by_program.get_vec(first) { - Some(rules) => rules, - None => return Evaluation::NoMatch, - }, - None => return Evaluation::NoMatch, - }; - - let matched_rules: Vec = - rules.iter().filter_map(|rule| rule.matches(cmd)).collect(); - match matched_rules.iter().map(RuleMatch::decision).max() { - Some(decision) => Evaluation::Match { - decision, - matched_rules, - }, - None => Evaluation::NoMatch, - } - } - - pub fn check_multiple(&self, commands: Commands) -> Evaluation - where - Commands: IntoIterator, - Commands::Item: AsRef<[String]>, - { - let matched_rules: Vec = commands - .into_iter() - .flat_map(|command| match self.check(command.as_ref()) { - Evaluation::Match { matched_rules, .. } => matched_rules, - Evaluation::NoMatch => Vec::new(), - }) - .collect(); - - match matched_rules.iter().map(RuleMatch::decision).max() { - Some(decision) => Evaluation::Match { - decision, - matched_rules, - }, - None => Evaluation::NoMatch, - } - } -} - -#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub enum Evaluation { - NoMatch, - Match { - decision: Decision, - #[serde(rename = "matchedRules")] - matched_rules: Vec, - }, -} - -impl Evaluation { - pub fn is_match(&self) -> bool { - matches!(self, Self::Match { .. 
}) - } -} diff --git a/codex-rs/linux-sandbox/tests/suite/landlock.rs b/codex-rs/linux-sandbox/tests/suite/landlock.rs index 508dba08c..27099fb29 100644 --- a/codex-rs/linux-sandbox/tests/suite/landlock.rs +++ b/codex-rs/linux-sandbox/tests/suite/landlock.rs @@ -40,7 +40,7 @@ async fn run_cmd(cmd: &[&str], writable_roots: &[PathBuf], timeout_ms: u64) { let params = ExecParams { command: cmd.iter().copied().map(str::to_owned).collect(), cwd, - timeout_ms: Some(timeout_ms), + expiration: timeout_ms.into(), env: create_env_from_core_vars(), with_escalated_permissions: None, justification: None, @@ -143,7 +143,7 @@ async fn assert_network_blocked(cmd: &[&str]) { cwd, // Give the tool a generous 2-second timeout so even slow DNS timeouts // do not stall the suite. - timeout_ms: Some(NETWORK_TIMEOUT_MS), + expiration: NETWORK_TIMEOUT_MS.into(), env: create_env_from_core_vars(), with_escalated_permissions: None, justification: None, diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index 93dc7764d..8dccb5125 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -210,6 +210,7 @@ async fn run_codex_tool_session_inner( } EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id, + turn_id: _, reason, grant_root, changes, diff --git a/codex-rs/mcp-server/src/outgoing_message.rs b/codex-rs/mcp-server/src/outgoing_message.rs index 4b6782d8d..9e9d07930 100644 --- a/codex-rs/mcp-server/src/outgoing_message.rs +++ b/codex-rs/mcp-server/src/outgoing_message.rs @@ -231,8 +231,12 @@ pub(crate) struct OutgoingError { #[cfg(test)] mod tests { + use std::path::PathBuf; + use anyhow::Result; + use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; + use codex_core::protocol::SandboxPolicy; use codex_core::protocol::SessionConfiguredEvent; use codex_protocol::ConversationId; use codex_protocol::config_types::ReasoningEffort; @@ -254,6 
+258,10 @@ mod tests { msg: EventMsg::SessionConfigured(SessionConfiguredEvent { session_id: conversation_id, model: "gpt-4o".to_string(), + model_provider_id: "test-provider".to_string(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, + cwd: PathBuf::from("/home/user/project"), reasoning_effort: Some(ReasoningEffort::default()), history_log_id: 1, history_entry_count: 1000, @@ -289,6 +297,10 @@ mod tests { let session_configured_event = SessionConfiguredEvent { session_id: conversation_id, model: "gpt-4o".to_string(), + model_provider_id: "test-provider".to_string(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, + cwd: PathBuf::from("/home/user/project"), reasoning_effort: Some(ReasoningEffort::default()), history_log_id: 1, history_entry_count: 1000, @@ -318,12 +330,18 @@ mod tests { }, "id": "1", "msg": { + "type": "session_configured", "session_id": session_configured_event.session_id, - "model": session_configured_event.model, + "model": "gpt-4o", + "model_provider_id": "test-provider", + "approval_policy": "never", + "sandbox_policy": { + "type": "read-only" + }, + "cwd": "/home/user/project", "reasoning_effort": session_configured_event.reasoning_effort, "history_log_id": session_configured_event.history_log_id, "history_entry_count": session_configured_event.history_entry_count, - "type": "session_configured", "rollout_path": rollout_file.path().to_path_buf(), } }); diff --git a/codex-rs/mcp-server/tests/common/Cargo.toml b/codex-rs/mcp-server/tests/common/Cargo.toml index 7c2bc2266..286cd63c7 100644 --- a/codex-rs/mcp-server/tests/common/Cargo.toml +++ b/codex-rs/mcp-server/tests/common/Cargo.toml @@ -23,3 +23,5 @@ tokio = { workspace = true, features = [ "rt-multi-thread", ] } wiremock = { workspace = true } +core_test_support = { path = "../../../core/tests/common" } +shlex = { workspace = true } diff --git a/codex-rs/mcp-server/tests/common/lib.rs 
b/codex-rs/mcp-server/tests/common/lib.rs index d088b184e..d8d6b6738 100644 --- a/codex-rs/mcp-server/tests/common/lib.rs +++ b/codex-rs/mcp-server/tests/common/lib.rs @@ -2,12 +2,13 @@ mod mcp_process; mod mock_model_server; mod responses; +pub use core_test_support::format_with_current_shell; pub use mcp_process::McpProcess; use mcp_types::JSONRPCResponse; pub use mock_model_server::create_mock_chat_completions_server; pub use responses::create_apply_patch_sse_response; pub use responses::create_final_assistant_message_sse_response; -pub use responses::create_shell_sse_response; +pub use responses::create_shell_command_sse_response; use serde::de::DeserializeOwned; pub fn to_response(response: JSONRPCResponse) -> anyhow::Result { diff --git a/codex-rs/mcp-server/tests/common/responses.rs b/codex-rs/mcp-server/tests/common/responses.rs index 9a827fb98..0a9183c04 100644 --- a/codex-rs/mcp-server/tests/common/responses.rs +++ b/codex-rs/mcp-server/tests/common/responses.rs @@ -1,17 +1,18 @@ use serde_json::json; use std::path::Path; -pub fn create_shell_sse_response( +pub fn create_shell_command_sse_response( command: Vec, workdir: Option<&Path>, timeout_ms: Option, call_id: &str, ) -> anyhow::Result { - // The `arguments`` for the `shell` tool is a serialized JSON object. + // The `arguments` for the `shell_command` tool is a serialized JSON object. 
+ let command_str = shlex::try_join(command.iter().map(String::as_str))?; let tool_call_arguments = serde_json::to_string(&json!({ - "command": command, + "command": command_str, "workdir": workdir.map(|w| w.to_string_lossy()), - "timeout": timeout_ms + "timeout_ms": timeout_ms }))?; let tool_call = json!({ "choices": [ @@ -21,7 +22,7 @@ pub fn create_shell_sse_response( { "id": call_id, "function": { - "name": "shell", + "name": "shell_command", "arguments": tool_call_arguments } } @@ -62,10 +63,10 @@ pub fn create_apply_patch_sse_response( patch_content: &str, call_id: &str, ) -> anyhow::Result { - // Use shell command to call apply_patch with heredoc format - let shell_command = format!("apply_patch <<'EOF'\n{patch_content}\nEOF"); + // Use shell_command to call apply_patch with heredoc format + let command = format!("apply_patch <<'EOF'\n{patch_content}\nEOF"); let tool_call_arguments = serde_json::to_string(&json!({ - "command": ["bash", "-lc", shell_command] + "command": command }))?; let tool_call = json!({ @@ -76,7 +77,7 @@ pub fn create_apply_patch_sse_response( { "id": call_id, "function": { - "name": "shell", + "name": "shell_command", "arguments": tool_call_arguments } } diff --git a/codex-rs/mcp-server/tests/suite/codex_tool.rs b/codex-rs/mcp-server/tests/suite/codex_tool.rs index ae0c23f1d..f65495c47 100644 --- a/codex-rs/mcp-server/tests/suite/codex_tool.rs +++ b/codex-rs/mcp-server/tests/suite/codex_tool.rs @@ -30,7 +30,8 @@ use mcp_test_support::McpProcess; use mcp_test_support::create_apply_patch_sse_response; use mcp_test_support::create_final_assistant_message_sse_response; use mcp_test_support::create_mock_chat_completions_server; -use mcp_test_support::create_shell_sse_response; +use mcp_test_support::create_shell_command_sse_response; +use mcp_test_support::format_with_current_shell; // Allow ample time on slower CI or under load to avoid flakes. 
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); @@ -71,13 +72,16 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> { "-c".to_string(), format!("import pathlib; pathlib.Path('{created_filename}').touch()"), ]; + let expected_shell_command = format_with_current_shell(&format!( + "python3 -c \"import pathlib; pathlib.Path('{created_filename}').touch()\"" + )); let McpHandle { process: mut mcp_process, server: _server, dir: _dir, } = create_mcp_process(vec![ - create_shell_sse_response( + create_shell_command_sse_response( shell_command.clone(), Some(workdir_for_shell_function_call.path()), Some(5_000), @@ -111,7 +115,7 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> { )?; let expected_elicitation_request = create_expected_elicitation_request( elicitation_request_id.clone(), - shell_command.clone(), + expected_shell_command, workdir_for_shell_function_call.path(), codex_request_id.to_string(), params.codex_event_id.clone(), @@ -218,6 +222,12 @@ async fn test_patch_approval_triggers_elicitation() { } async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> { + if cfg!(windows) { + // powershell apply_patch shell calls are not parsed into apply patch approvals + + return Ok(()); + } + let cwd = TempDir::new()?; let test_file = cwd.path().join("destination_file.txt"); std::fs::write(&test_file, "original content\n")?; diff --git a/codex-rs/otel/Cargo.toml b/codex-rs/otel/Cargo.toml index ea518c2e4..2ab170c91 100644 --- a/codex-rs/otel/Cargo.toml +++ b/codex-rs/otel/Cargo.toml @@ -27,18 +27,26 @@ opentelemetry-otlp = { workspace = true, features = [ "grpc-tonic", "http-proto", "http-json", + "logs", "reqwest", "reqwest-rustls", + "tls", + "tls-roots", ], optional = true } opentelemetry-semantic-conventions = { workspace = true } opentelemetry_sdk = { workspace = true, features = [ "logs", "rt-tokio", ], optional = true } +http = { workspace = true } reqwest = { 
workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } strum_macros = { workspace = true } tokio = { workspace = true } -tonic = { workspace = true, optional = true } +tonic = { workspace = true, optional = true, features = [ + "transport", + "tls-native-roots", + "tls-ring", +] } tracing = { workspace = true } diff --git a/codex-rs/otel/src/config.rs b/codex-rs/otel/src/config.rs index 77063ed09..b6336b3a5 100644 --- a/codex-rs/otel/src/config.rs +++ b/codex-rs/otel/src/config.rs @@ -18,16 +18,25 @@ pub enum OtelHttpProtocol { Json, } +#[derive(Clone, Debug, Default)] +pub struct OtelTlsConfig { + pub ca_certificate: Option, + pub client_certificate: Option, + pub client_private_key: Option, +} + #[derive(Clone, Debug)] pub enum OtelExporter { None, OtlpGrpc { endpoint: String, headers: HashMap, + tls: Option, }, OtlpHttp { endpoint: String, headers: HashMap, protocol: OtelHttpProtocol, + tls: Option, }, } diff --git a/codex-rs/otel/src/otel_event_manager.rs b/codex-rs/otel/src/otel_event_manager.rs index 5d9cbd499..fde351cd6 100644 --- a/codex-rs/otel/src/otel_event_manager.rs +++ b/codex-rs/otel/src/otel_event_manager.rs @@ -88,7 +88,6 @@ impl OtelEventManager { reasoning_effort: Option, reasoning_summary: ReasoningSummary, context_window: Option, - max_output_tokens: Option, auto_compact_token_limit: Option, approval_policy: AskForApproval, sandbox_policy: SandboxPolicy, @@ -111,7 +110,6 @@ impl OtelEventManager { reasoning_effort = reasoning_effort.map(|e| e.to_string()), reasoning_summary = %reasoning_summary, context_window = context_window, - max_output_tokens = max_output_tokens, auto_compact_token_limit = auto_compact_token_limit, approval_policy = %approval_policy, sandbox_policy = %sandbox_policy, diff --git a/codex-rs/otel/src/otel_provider.rs b/codex-rs/otel/src/otel_provider.rs index 222322a2e..8be2431ea 100644 --- a/codex-rs/otel/src/otel_provider.rs +++ b/codex-rs/otel/src/otel_provider.rs @@ -1,8 
+1,13 @@ use crate::config::OtelExporter; use crate::config::OtelHttpProtocol; use crate::config::OtelSettings; +use crate::config::OtelTlsConfig; +use http::Uri; use opentelemetry::KeyValue; use opentelemetry_otlp::LogExporter; +use opentelemetry_otlp::OTEL_EXPORTER_OTLP_LOGS_TIMEOUT; +use opentelemetry_otlp::OTEL_EXPORTER_OTLP_TIMEOUT; +use opentelemetry_otlp::OTEL_EXPORTER_OTLP_TIMEOUT_DEFAULT; use opentelemetry_otlp::Protocol; use opentelemetry_otlp::WithExportConfig; use opentelemetry_otlp::WithHttpConfig; @@ -10,11 +15,23 @@ use opentelemetry_otlp::WithTonicConfig; use opentelemetry_sdk::Resource; use opentelemetry_sdk::logs::SdkLoggerProvider; use opentelemetry_semantic_conventions as semconv; +use reqwest::Certificate as ReqwestCertificate; +use reqwest::Identity as ReqwestIdentity; use reqwest::header::HeaderMap; use reqwest::header::HeaderName; use reqwest::header::HeaderValue; +use std::env; use std::error::Error; +use std::fs; +use std::io::ErrorKind; +use std::io::{self}; +use std::path::Path; +use std::path::PathBuf; +use std::time::Duration; use tonic::metadata::MetadataMap; +use tonic::transport::Certificate as TonicCertificate; +use tonic::transport::ClientTlsConfig; +use tonic::transport::Identity as TonicIdentity; use tracing::debug; const ENV_ATTRIBUTE: &str = "env"; @@ -47,8 +64,12 @@ impl OtelProvider { debug!("No exporter enabled in OTLP settings."); return Ok(None); } - OtelExporter::OtlpGrpc { endpoint, headers } => { - debug!("Using OTLP Grpc exporter: {}", endpoint); + OtelExporter::OtlpGrpc { + endpoint, + headers, + tls, + } => { + debug!("Using OTLP Grpc exporter: {endpoint}"); let mut header_map = HeaderMap::new(); for (key, value) in headers { @@ -59,10 +80,25 @@ impl OtelProvider { } } + let base_tls_config = ClientTlsConfig::new() + .with_enabled_roots() + .assume_http2(true); + + let tls_config = match tls.as_ref() { + Some(tls) => build_grpc_tls_config( + endpoint, + base_tls_config, + tls, + settings.codex_home.as_path(), + )?, 
+ None => base_tls_config, + }; + let exporter = LogExporter::builder() .with_tonic() .with_endpoint(endpoint) .with_metadata(MetadataMap::from_headers(header_map)) + .with_tls_config(tls_config) .build()?; builder = builder.with_batch_exporter(exporter); @@ -71,20 +107,27 @@ impl OtelProvider { endpoint, headers, protocol, + tls, } => { - debug!("Using OTLP Http exporter: {}", endpoint); + debug!("Using OTLP Http exporter: {endpoint}"); let protocol = match protocol { OtelHttpProtocol::Binary => Protocol::HttpBinary, OtelHttpProtocol::Json => Protocol::HttpJson, }; - let exporter = LogExporter::builder() + let mut exporter_builder = LogExporter::builder() .with_http() .with_endpoint(endpoint) .with_protocol(protocol) - .with_headers(headers.clone()) - .build()?; + .with_headers(headers.clone()); + + if let Some(tls) = tls.as_ref() { + let client = build_http_client(tls, settings.codex_home.as_path())?; + exporter_builder = exporter_builder.with_http_client(client); + } + + let exporter = exporter_builder.build()?; builder = builder.with_batch_exporter(exporter); } @@ -101,3 +144,127 @@ impl Drop for OtelProvider { let _ = self.logger.shutdown(); } } + +fn build_grpc_tls_config( + endpoint: &str, + tls_config: ClientTlsConfig, + tls: &OtelTlsConfig, + codex_home: &Path, +) -> Result> { + let uri: Uri = endpoint.parse()?; + let host = uri.host().ok_or_else(|| { + config_error(format!( + "OTLP gRPC endpoint {endpoint} does not include a host" + )) + })?; + + let mut config = tls_config.domain_name(host.to_owned()); + + if let Some(path) = tls.ca_certificate.as_ref() { + let (pem, _) = read_bytes(codex_home, path)?; + config = config.ca_certificate(TonicCertificate::from_pem(pem)); + } + + match (&tls.client_certificate, &tls.client_private_key) { + (Some(cert_path), Some(key_path)) => { + let (cert_pem, _) = read_bytes(codex_home, cert_path)?; + let (key_pem, _) = read_bytes(codex_home, key_path)?; + config = config.identity(TonicIdentity::from_pem(cert_pem, 
key_pem)); + } + (Some(_), None) | (None, Some(_)) => { + return Err(config_error( + "client_certificate and client_private_key must both be provided for mTLS", + )); + } + (None, None) => {} + } + + Ok(config) +} + +fn build_http_client( + tls: &OtelTlsConfig, + codex_home: &Path, +) -> Result> { + let mut builder = + reqwest::Client::builder().timeout(resolve_otlp_timeout(OTEL_EXPORTER_OTLP_LOGS_TIMEOUT)); + + if let Some(path) = tls.ca_certificate.as_ref() { + let (pem, location) = read_bytes(codex_home, path)?; + let certificate = ReqwestCertificate::from_pem(pem.as_slice()).map_err(|error| { + config_error(format!( + "failed to parse certificate {}: {error}", + location.display() + )) + })?; + builder = builder.add_root_certificate(certificate); + } + + match (&tls.client_certificate, &tls.client_private_key) { + (Some(cert_path), Some(key_path)) => { + let (mut cert_pem, cert_location) = read_bytes(codex_home, cert_path)?; + let (key_pem, key_location) = read_bytes(codex_home, key_path)?; + cert_pem.extend_from_slice(key_pem.as_slice()); + let identity = ReqwestIdentity::from_pem(cert_pem.as_slice()).map_err(|error| { + config_error(format!( + "failed to parse client identity using {} and {}: {error}", + cert_location.display(), + key_location.display() + )) + })?; + builder = builder.identity(identity); + } + (Some(_), None) | (None, Some(_)) => { + return Err(config_error( + "client_certificate and client_private_key must both be provided for mTLS", + )); + } + (None, None) => {} + } + + builder + .build() + .map_err(|error| Box::new(error) as Box) +} + +fn resolve_otlp_timeout(signal_var: &str) -> Duration { + if let Some(timeout) = read_timeout_env(signal_var) { + return timeout; + } + if let Some(timeout) = read_timeout_env(OTEL_EXPORTER_OTLP_TIMEOUT) { + return timeout; + } + OTEL_EXPORTER_OTLP_TIMEOUT_DEFAULT +} + +fn read_timeout_env(var: &str) -> Option { + let value = env::var(var).ok()?; + let parsed = value.parse::().ok()?; + if parsed < 0 { + 
return None; + } + Some(Duration::from_millis(parsed as u64)) +} + +fn read_bytes(base: &Path, provided: &PathBuf) -> Result<(Vec, PathBuf), Box> { + let resolved = resolve_config_path(base, provided); + match fs::read(&resolved) { + Ok(bytes) => Ok((bytes, resolved)), + Err(error) => Err(Box::new(io::Error::new( + error.kind(), + format!("failed to read {}: {error}", resolved.display()), + ))), + } +} + +fn resolve_config_path(base: &Path, provided: &PathBuf) -> PathBuf { + if provided.is_absolute() { + provided.clone() + } else { + base.join(provided) + } +} + +fn config_error(message: impl Into) -> Box { + Box::new(io::Error::new(ErrorKind::InvalidData, message.into())) +} diff --git a/codex-rs/protocol/Cargo.toml b/codex-rs/protocol/Cargo.toml index d3ec3af08..00ed100e0 100644 --- a/codex-rs/protocol/Cargo.toml +++ b/codex-rs/protocol/Cargo.toml @@ -20,10 +20,10 @@ icu_locale_core = { workspace = true } icu_provider = { workspace = true, features = ["sync"] } mcp-types = { workspace = true } mime_guess = { workspace = true } +schemars = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } serde_with = { workspace = true, features = ["macros", "base64"] } -schemars = { workspace = true } strum = { workspace = true } strum_macros = { workspace = true } sys-locale = { workspace = true } @@ -37,6 +37,7 @@ uuid = { workspace = true, features = ["serde", "v7", "v4"] } [dev-dependencies] anyhow = { workspace = true } +pretty_assertions = { workspace = true } tempfile = { workspace = true } [package.metadata.cargo-shear] diff --git a/codex-rs/protocol/src/approvals.rs b/codex-rs/protocol/src/approvals.rs index f7c5fc604..25f5e90e9 100644 --- a/codex-rs/protocol/src/approvals.rs +++ b/codex-rs/protocol/src/approvals.rs @@ -57,6 +57,10 @@ pub struct ExecApprovalRequestEvent { pub struct ApplyPatchApprovalRequestEvent { /// Responses API call id for the associated patch apply call, if available. 
pub call_id: String, + /// Turn ID that this patch belongs to. + /// Uses `#[serde(default)]` for backwards compatibility with older senders. + #[serde(default)] + pub turn_id: String, pub changes: HashMap, /// Optional explanatory reason (e.g. request for extra write access). #[serde(skip_serializing_if = "Option::is_none")] diff --git a/codex-rs/protocol/src/config_types.rs b/codex-rs/protocol/src/config_types.rs index 3881bde67..2ee6d3974 100644 --- a/codex-rs/protocol/src/config_types.rs +++ b/codex-rs/protocol/src/config_types.rs @@ -30,6 +30,7 @@ pub enum ReasoningEffort { #[default] Medium, High, + XHigh, } /// A summary of the reasoning performed by the model. This can be useful for diff --git a/codex-rs/protocol/src/models.rs b/codex-rs/protocol/src/models.rs index 2e623eefe..b6bd92c6e 100644 --- a/codex-rs/protocol/src/models.rs +++ b/codex-rs/protocol/src/models.rs @@ -230,8 +230,24 @@ pub struct LocalShellExecAction { #[serde(tag = "type", rename_all = "snake_case")] pub enum WebSearchAction { Search { - query: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + query: Option, }, + OpenPage { + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + url: Option, + }, + FindInPage { + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + url: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pattern: Option, + }, + #[serde(other)] Other, } @@ -634,6 +650,72 @@ mod tests { Ok(()) } + #[test] + fn roundtrips_web_search_call_actions() -> Result<()> { + let cases = vec![ + ( + r#"{ + "type": "web_search_call", + "status": "completed", + "action": { + "type": "search", + "query": "weather seattle" + } + }"#, + WebSearchAction::Search { + query: Some("weather seattle".into()), + }, + Some("completed".into()), + ), + ( + r#"{ + "type": "web_search_call", + "status": "open", + "action": { + "type": "open_page", + "url": 
"https://example.com" + } + }"#, + WebSearchAction::OpenPage { + url: Some("https://example.com".into()), + }, + Some("open".into()), + ), + ( + r#"{ + "type": "web_search_call", + "status": "in_progress", + "action": { + "type": "find_in_page", + "url": "https://example.com/docs", + "pattern": "installation" + } + }"#, + WebSearchAction::FindInPage { + url: Some("https://example.com/docs".into()), + pattern: Some("installation".into()), + }, + Some("in_progress".into()), + ), + ]; + + for (json_literal, expected_action, expected_status) in cases { + let parsed: ResponseItem = serde_json::from_str(json_literal)?; + let expected = ResponseItem::WebSearchCall { + id: None, + status: expected_status.clone(), + action: expected_action.clone(), + }; + assert_eq!(parsed, expected); + + let serialized = serde_json::to_value(&parsed)?; + let original_value: serde_json::Value = serde_json::from_str(json_literal)?; + assert_eq!(serialized, original_value); + } + + Ok(()) + } + #[test] fn deserialize_shell_tool_call_params() -> Result<()> { let json = r#"{ diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index 7f5e5228d..e0c2e4df3 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -562,6 +562,35 @@ pub enum EventMsg { ReasoningRawContentDelta(ReasoningRawContentDeltaEvent), } +/// Codex errors that we expose to clients. +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, JsonSchema, TS)] +#[serde(rename_all = "snake_case")] +#[ts(rename_all = "snake_case")] +pub enum CodexErrorInfo { + ContextWindowExceeded, + UsageLimitExceeded, + HttpConnectionFailed { + http_status_code: Option, + }, + /// Failed to connect to the response SSE stream. + ResponseStreamConnectionFailed { + http_status_code: Option, + }, + InternalServerError, + Unauthorized, + BadRequest, + SandboxError, + /// The response SSE stream disconnected in the middle of a turnbefore completion. 
+ ResponseStreamDisconnected { + http_status_code: Option, + }, + /// Reached the retry limit for responses. + ResponseTooManyFailedAttempts { + http_status_code: Option, + }, + Other, +} + #[derive(Debug, Clone, Deserialize, Serialize, TS, JsonSchema)] pub struct RawResponseItemEvent { pub item: ResponseItem, @@ -686,6 +715,8 @@ pub struct ExitedReviewModeEvent { #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] pub struct ErrorEvent { pub message: String, + #[serde(default)] + pub codex_error_info: Option, } #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] @@ -790,6 +821,7 @@ pub struct TokenCountEvent { pub struct RateLimitSnapshot { pub primary: Option, pub secondary: Option, + pub credits: Option, } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema, TS)] @@ -804,6 +836,13 @@ pub struct RateLimitWindow { pub resets_at: Option, } +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema, TS)] +pub struct CreditsSnapshot { + pub has_credits: bool, + pub unlimited: bool, + pub balance: Option, +} + // Includes prompts, tools and space to call compact. const BASELINE_TOKENS: i64 = 12000; @@ -1173,11 +1212,13 @@ pub struct GitInfo { pub repository_url: Option, } -/// Review request sent to the review session. #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, JsonSchema, TS)] +/// Review request sent to the review session. pub struct ReviewRequest { pub prompt: String, pub user_facing_hint: String, + #[serde(default)] + pub append_to_original_thread: bool, } /// Structured review result produced by a child review session. 
@@ -1353,6 +1394,8 @@ pub struct UndoCompletedEvent { #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] pub struct StreamErrorEvent { pub message: String, + #[serde(default)] + pub codex_error_info: Option, } #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] @@ -1364,6 +1407,10 @@ pub struct StreamInfoEvent { pub struct PatchApplyBeginEvent { /// Identifier so this can be paired with the PatchApplyEnd event. pub call_id: String, + /// Turn ID that this patch belongs to. + /// Uses `#[serde(default)]` for backwards compatibility. + #[serde(default)] + pub turn_id: String, /// If true, there was no ApplyPatchApprovalRequest for this patch. pub auto_approved: bool, /// The changes to be applied. @@ -1374,12 +1421,19 @@ pub struct PatchApplyBeginEvent { pub struct PatchApplyEndEvent { /// Identifier for the PatchApplyBegin that finished. pub call_id: String, + /// Turn ID that this patch belongs to. + /// Uses `#[serde(default)]` for backwards compatibility. + #[serde(default)] + pub turn_id: String, /// Captured stdout (summary printed by apply_patch). pub stdout: String, /// Captured stderr (parser errors, IO failures, etc.). pub stderr: String, /// Whether the patch was applied successfully. pub success: bool, + /// The changes that were applied (mirrors PatchApplyBeginEvent::changes). + #[serde(default)] + pub changes: HashMap, } #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] @@ -1467,7 +1521,7 @@ pub struct ListCustomPromptsResponseEvent { pub custom_prompts: Vec, } -#[derive(Debug, Default, Clone, Deserialize, Serialize, JsonSchema, TS)] +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] pub struct SessionConfiguredEvent { /// Name left as session_id instead of conversation_id for backwards compatibility. pub session_id: ConversationId, @@ -1475,6 +1529,18 @@ pub struct SessionConfiguredEvent { /// Tell the client what model is being queried. 
pub model: String, + pub model_provider_id: String, + + /// When to escalate for approval for execution + pub approval_policy: AskForApproval, + + /// How to sandbox commands executed in the system + pub sandbox_policy: SandboxPolicy, + + /// Working directory that should be treated as the *root* of the + /// session. + pub cwd: PathBuf, + /// The effort the model is putting into reasoning about the user's request. #[serde(skip_serializing_if = "Option::is_none")] pub reasoning_effort: Option, @@ -1560,6 +1626,7 @@ mod tests { use crate::items::UserMessageItem; use crate::items::WebSearchItem; use anyhow::Result; + use pretty_assertions::assert_eq; use serde_json::json; use tempfile::NamedTempFile; @@ -1604,6 +1671,10 @@ mod tests { msg: EventMsg::SessionConfigured(SessionConfiguredEvent { session_id: conversation_id, model: "codex-mini-latest".to_string(), + model_provider_id: "openai".to_string(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, + cwd: PathBuf::from("/home/user/project"), reasoning_effort: Some(ReasoningEffortConfig::default()), history_log_id: 0, history_entry_count: 0, @@ -1618,6 +1689,12 @@ mod tests { "type": "session_configured", "session_id": "67e55044-10b1-426f-9247-bb680e5fe0c8", "model": "codex-mini-latest", + "model_provider_id": "openai", + "approval_policy": "never", + "sandbox_policy": { + "type": "read-only" + }, + "cwd": "/home/user/project", "reasoning_effort": "medium", "history_log_id": 0, "history_entry_count": 0, diff --git a/codex-rs/tui/src/app.rs b/codex-rs/tui/src/app.rs index 9f996007f..b96a09f17 100644 --- a/codex-rs/tui/src/app.rs +++ b/codex-rs/tui/src/app.rs @@ -8,6 +8,7 @@ use crate::exec_command::strip_bash_lc_and_escape; use crate::file_search::FileSearchManager; use crate::history_cell::HistoryCell; use crate::model_migration::ModelMigrationOutcome; +use crate::model_migration::migration_copy_for_config; use crate::model_migration::run_model_migration_prompt; use 
crate::pager_overlay::Overlay; use crate::render::highlight::highlight_bash_to_lines; @@ -17,14 +18,21 @@ use crate::tui; use crate::tui::TuiEvent; use crate::update_action::UpdateAction; use codex_ansi_escape::ansi_escape_line; +use codex_app_server_protocol::AuthMode; +use codex_common::model_presets::HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG; +use codex_common::model_presets::HIDE_GPT5_1_MIGRATION_PROMPT_CONFIG; use codex_common::model_presets::ModelUpgrade; use codex_common::model_presets::all_model_presets; use codex_core::AuthManager; use codex_core::ConversationManager; use codex_core::config::Config; use codex_core::config::edit::ConfigEditsBuilder; +#[cfg(target_os = "windows")] +use codex_core::features::Feature; use codex_core::model_family::find_family_for_model; use codex_core::protocol::FinalOutput; +#[cfg(target_os = "windows")] +use codex_core::protocol::Op; use codex_core::protocol::SessionSource; use codex_core::protocol::TokenUsage; use codex_core::protocol_config_types::ReasoningEffort as ReasoningEffortConfig; @@ -48,6 +56,9 @@ use tokio::sync::mpsc::unbounded_channel; #[cfg(not(debug_assertions))] use crate::history_cell::UpdateAvailableHistoryCell; +const GPT_5_1_MIGRATION_AUTH_MODES: [AuthMode; 2] = [AuthMode::ChatGPT, AuthMode::ApiKey]; +const GPT_5_1_CODEX_MIGRATION_AUTH_MODES: [AuthMode; 1] = [AuthMode::ChatGPT]; + #[derive(Debug, Clone)] pub struct AppExitInfo { pub token_usage: TokenUsage, @@ -93,10 +104,21 @@ fn should_show_model_migration_prompt( .any(|preset| preset.model == current_model) } +fn migration_prompt_hidden(config: &Config, migration_config_key: &str) -> Option { + match migration_config_key { + HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG => { + config.notices.hide_gpt_5_1_codex_max_migration_prompt + } + HIDE_GPT5_1_MIGRATION_PROMPT_CONFIG => config.notices.hide_gpt5_1_migration_prompt, + _ => None, + } +} + async fn handle_model_migration_prompt_if_needed( tui: &mut tui::Tui, config: &mut Config, app_event_tx: 
&AppEventSender, + auth_mode: Option, ) -> Option { let upgrade = all_model_presets() .iter() @@ -106,18 +128,24 @@ async fn handle_model_migration_prompt_if_needed( if let Some(ModelUpgrade { id: target_model, reasoning_effort_mapping, + migration_config_key, }) = upgrade { + if !migration_prompt_allows_auth_mode(auth_mode, migration_config_key) { + return None; + } + let target_model = target_model.to_string(); - let hide_prompt_flag = config.notices.hide_gpt5_1_migration_prompt; + let hide_prompt_flag = migration_prompt_hidden(config, migration_config_key); if !should_show_model_migration_prompt(&config.model, &target_model, hide_prompt_flag) { return None; } - match run_model_migration_prompt(tui).await { + let prompt_copy = migration_copy_for_config(migration_config_key); + match run_model_migration_prompt(tui, prompt_copy).await { ModelMigrationOutcome::Accepted => { app_event_tx.send(AppEvent::PersistModelMigrationPromptAcknowledged { - migration_config: "hide_gpt5_1_migration_prompt".to_string(), + migration_config: migration_config_key.to_string(), }); config.model = target_model.to_string(); if let Some(family) = find_family_for_model(&target_model) { @@ -144,6 +172,11 @@ async fn handle_model_migration_prompt_if_needed( effort: mapped_effort, }); } + ModelMigrationOutcome::Rejected => { + app_event_tx.send(AppEvent::PersistModelMigrationPromptAcknowledged { + migration_config: migration_config_key.to_string(), + }); + } ModelMigrationOutcome::Exit => { return Some(AppExitInfo { token_usage: TokenUsage::default(), @@ -207,8 +240,10 @@ impl App { let (app_event_tx, mut app_event_rx) = unbounded_channel(); let app_event_tx = AppEventSender::new(app_event_tx); + let auth_mode = auth_manager.auth().map(|auth| auth.mode); let exit_info = - handle_model_migration_prompt_if_needed(tui, &mut config, &app_event_tx).await; + handle_model_migration_prompt_if_needed(tui, &mut config, &app_event_tx, auth_mode) + .await; if let Some(exit_info) = exit_info { return 
Ok(exit_info); } @@ -220,7 +255,7 @@ impl App { let enhanced_keys_supported = tui.enhanced_keys_supported(); - let chat_widget = match resume_selection { + let mut chat_widget = match resume_selection { ResumeSelection::StartFresh | ResumeSelection::Exit => { let init = crate::chatwidget::ChatWidgetInit { config: config.clone(), @@ -263,6 +298,8 @@ impl App { } }; + chat_widget.maybe_prompt_windows_sandbox_enable(); + let file_search = FileSearchManager::new(config.cwd.clone(), app_event_tx.clone()); #[cfg(not(debug_assertions))] let upgrade_version = crate::updates::get_upgrade_version(&config); @@ -287,7 +324,7 @@ impl App { skip_world_writable_scan_once: false, }; - // On startup, if Auto mode (workspace-write) or ReadOnly is active, warn about world-writable dirs on Windows. + // On startup, if Agent mode (workspace-write) or ReadOnly is active, warn about world-writable dirs on Windows. #[cfg(target_os = "windows")] { let should_check = codex_core::get_platform_sandbox().is_some() @@ -306,7 +343,8 @@ impl App { let env_map: std::collections::HashMap = std::env::vars().collect(); let tx = app.app_event_tx.clone(); let logs_base_dir = app.config.codex_home.clone(); - Self::spawn_world_writable_scan(cwd, env_map, logs_base_dir, tx); + let sandbox_policy = app.config.sandbox_policy.clone(); + Self::spawn_world_writable_scan(cwd, env_map, logs_base_dir, sandbox_policy, tx); } } @@ -537,8 +575,70 @@ impl App { AppEvent::OpenFeedbackConsent { category } => { self.chat_widget.open_feedback_consent(category); } - AppEvent::ShowWindowsAutoModeInstructions => { - self.chat_widget.open_windows_auto_mode_instructions(); + AppEvent::OpenWindowsSandboxEnablePrompt { preset } => { + self.chat_widget.open_windows_sandbox_enable_prompt(preset); + } + AppEvent::EnableWindowsSandboxForAgentMode { preset } => { + #[cfg(target_os = "windows")] + { + let profile = self.active_profile.as_deref(); + let feature_key = Feature::WindowsSandbox.key(); + match 
ConfigEditsBuilder::new(&self.config.codex_home) + .with_profile(profile) + .set_feature_enabled(feature_key, true) + .apply() + .await + { + Ok(()) => { + self.config.set_windows_sandbox_globally(true); + self.chat_widget.clear_forced_auto_mode_downgrade(); + if let Some((sample_paths, extra_count, failed_scan)) = + self.chat_widget.world_writable_warning_details() + { + self.app_event_tx.send( + AppEvent::OpenWorldWritableWarningConfirmation { + preset: Some(preset.clone()), + sample_paths, + extra_count, + failed_scan, + }, + ); + } else { + self.app_event_tx.send(AppEvent::CodexOp( + Op::OverrideTurnContext { + cwd: None, + approval_policy: Some(preset.approval), + sandbox_policy: Some(preset.sandbox.clone()), + model: None, + effort: None, + summary: None, + }, + )); + self.app_event_tx + .send(AppEvent::UpdateAskForApprovalPolicy(preset.approval)); + self.app_event_tx + .send(AppEvent::UpdateSandboxPolicy(preset.sandbox.clone())); + self.chat_widget.add_info_message( + "Enabled experimental Windows sandbox.".to_string(), + None, + ); + } + } + Err(err) => { + tracing::error!( + error = %err, + "failed to enable Windows sandbox feature" + ); + self.chat_widget.add_error_message(format!( + "Failed to enable the Windows sandbox feature: {err}" + )); + } + } + } + #[cfg(not(target_os = "windows"))] + { + let _ = preset; + } } AppEvent::PersistModelSelection { model, effort } => { let profile = self.active_profile.as_deref(); @@ -549,19 +649,17 @@ impl App { .await { Ok(()) => { - let effort_label = effort - .map(|eff| format!(" with {eff} reasoning")) - .unwrap_or_else(|| " with default reasoning".to_string()); + let reasoning_label = Self::reasoning_label(effort); if let Some(profile) = profile { self.chat_widget.add_info_message( format!( - "Model changed to {model}{effort_label} for {profile} profile" + "Model changed to {model} {reasoning_label} for {profile} profile" ), None, ); } else { self.chat_widget.add_info_message( - format!("Model changed to 
{model}{effort_label}"), + format!("Model changed to {model} {reasoning_label}"), None, ); } @@ -593,6 +691,13 @@ impl App { | codex_core::protocol::SandboxPolicy::ReadOnly ); + self.config.sandbox_policy = policy.clone(); + #[cfg(target_os = "windows")] + if !matches!(policy, codex_core::protocol::SandboxPolicy::ReadOnly) + || codex_core::get_platform_sandbox().is_some() + { + self.config.forced_auto_mode_downgraded_on_windows = false; + } self.chat_widget.set_sandbox_policy(policy); // If sandbox policy becomes workspace-write or read-only, run the Windows world-writable scan. @@ -613,7 +718,14 @@ impl App { std::env::vars().collect(); let tx = self.app_event_tx.clone(); let logs_base_dir = self.config.codex_home.clone(); - Self::spawn_world_writable_scan(cwd, env_map, logs_base_dir, tx); + let sandbox_policy = self.config.sandbox_policy.clone(); + Self::spawn_world_writable_scan( + cwd, + env_map, + logs_base_dir, + sandbox_policy, + tx, + ); } } } @@ -656,7 +768,7 @@ impl App { "failed to persist world-writable warning acknowledgement" ); self.chat_widget.add_error_message(format!( - "Failed to save Auto mode warning preference: {err}" + "Failed to save Agent mode warning preference: {err}" )); } } @@ -722,6 +834,17 @@ impl App { Ok(true) } + fn reasoning_label(reasoning_effort: Option) -> &'static str { + match reasoning_effort { + Some(ReasoningEffortConfig::Minimal) => "minimal", + Some(ReasoningEffortConfig::Low) => "low", + Some(ReasoningEffortConfig::Medium) => "medium", + Some(ReasoningEffortConfig::High) => "high", + Some(ReasoningEffortConfig::XHigh) => "xhigh", + None | Some(ReasoningEffortConfig::None) => "default", + } + } + pub(crate) fn token_usage(&self) -> codex_core::protocol::TokenUsage { self.chat_widget.token_usage() } @@ -796,6 +919,7 @@ impl App { cwd: PathBuf, env_map: std::collections::HashMap, logs_base_dir: PathBuf, + sandbox_policy: codex_core::protocol::SandboxPolicy, tx: AppEventSender, ) { #[inline] @@ -805,8 +929,10 @@ impl App { 
} tokio::task::spawn_blocking(move || { let result = codex_windows_sandbox::preflight_audit_everyone_writable( + &logs_base_dir, &cwd, &env_map, + &sandbox_policy, Some(logs_base_dir.as_path()), ); if let Ok(ref paths) = result @@ -844,6 +970,28 @@ impl App { } } +fn migration_prompt_allowed_auth_modes(migration_config_key: &str) -> Option<&'static [AuthMode]> { + match migration_config_key { + HIDE_GPT5_1_MIGRATION_PROMPT_CONFIG => Some(&GPT_5_1_MIGRATION_AUTH_MODES), + HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG => Some(&GPT_5_1_CODEX_MIGRATION_AUTH_MODES), + _ => None, + } +} + +fn migration_prompt_allows_auth_mode( + auth_mode: Option, + migration_config_key: &str, +) -> bool { + if let Some(allowed_modes) = migration_prompt_allowed_auth_modes(migration_config_key) { + match auth_mode { + None => true, + Some(mode) => allowed_modes.contains(&mode), + } + } else { + auth_mode != Some(AuthMode::ApiKey) + } +} + #[cfg(test)] mod tests { use super::*; @@ -858,6 +1006,8 @@ mod tests { use codex_core::AuthManager; use codex_core::CodexAuth; use codex_core::ConversationManager; + use codex_core::protocol::AskForApproval; + use codex_core::protocol::SandboxPolicy; use codex_core::protocol::SessionConfiguredEvent; use codex_protocol::ConversationId; use ratatui::prelude::Line; @@ -868,7 +1018,6 @@ mod tests { fn make_test_app() -> App { let (chat_widget, app_event_tx, _rx, _op_rx) = make_chatwidget_manual_with_sender(); let config = chat_widget.config_ref().clone(); - let server = Arc::new(ConversationManager::with_auth(CodexAuth::from_api_key( "Test API Key", ))); @@ -910,6 +1059,11 @@ mod tests { "gpt-5.1-codex-mini", None )); + assert!(should_show_model_migration_prompt( + "gpt-5.1-codex", + "gpt-5.1-codex-max", + None + )); assert!(!should_show_model_migration_prompt( "gpt-5.1-codex", "gpt-5.1-codex", @@ -968,6 +1122,10 @@ mod tests { let event = SessionConfiguredEvent { session_id: ConversationId::new(), model: "gpt-test".to_string(), + model_provider_id: 
"test-provider".to_string(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, + cwd: PathBuf::from("/home/user/project"), reasoning_effort: None, history_log_id: 0, history_entry_count: 0, @@ -1036,4 +1194,40 @@ mod tests { Some("codex resume 123e4567-e89b-12d3-a456-426614174000".to_string()) ); } + + #[test] + fn gpt5_migration_allows_api_key_and_chatgpt() { + assert!(migration_prompt_allows_auth_mode( + Some(AuthMode::ApiKey), + HIDE_GPT5_1_MIGRATION_PROMPT_CONFIG, + )); + assert!(migration_prompt_allows_auth_mode( + Some(AuthMode::ChatGPT), + HIDE_GPT5_1_MIGRATION_PROMPT_CONFIG, + )); + } + + #[test] + fn gpt_5_1_codex_max_migration_limits_to_chatgpt() { + assert!(migration_prompt_allows_auth_mode( + Some(AuthMode::ChatGPT), + HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG, + )); + assert!(!migration_prompt_allows_auth_mode( + Some(AuthMode::ApiKey), + HIDE_GPT_5_1_CODEX_MAX_MIGRATION_PROMPT_CONFIG, + )); + } + + #[test] + fn other_migrations_block_api_key() { + assert!(!migration_prompt_allows_auth_mode( + Some(AuthMode::ApiKey), + "unknown" + )); + assert!(migration_prompt_allows_auth_mode( + Some(AuthMode::ChatGPT), + "unknown" + )); + } } diff --git a/codex-rs/tui/src/app_event.rs b/codex-rs/tui/src/app_event.rs index 0c3033c5c..cf494f57d 100644 --- a/codex-rs/tui/src/app_event.rs +++ b/codex-rs/tui/src/app_event.rs @@ -91,9 +91,17 @@ pub(crate) enum AppEvent { failed_scan: bool, }, - /// Show Windows Subsystem for Linux setup instructions for auto mode. + /// Prompt to enable the Windows sandbox feature before using Agent mode. #[cfg_attr(not(target_os = "windows"), allow(dead_code))] - ShowWindowsAutoModeInstructions, + OpenWindowsSandboxEnablePrompt { + preset: ApprovalPreset, + }, + + /// Enable the Windows sandbox feature and switch to Agent mode. 
+ #[cfg_attr(not(target_os = "windows"), allow(dead_code))] + EnableWindowsSandboxForAgentMode { + preset: ApprovalPreset, + }, /// Update the current approval policy in the running app and widget. UpdateAskForApprovalPolicy(AskForApproval), diff --git a/codex-rs/tui/src/bottom_pane/chat_composer.rs b/codex-rs/tui/src/bottom_pane/chat_composer.rs index f64a824f4..30727c298 100644 --- a/codex-rs/tui/src/bottom_pane/chat_composer.rs +++ b/codex-rs/tui/src/bottom_pane/chat_composer.rs @@ -254,7 +254,7 @@ impl ChatComposer { true } Err(err) => { - tracing::info!("ERR: {err}"); + tracing::trace!("ERR: {err}"); false } } diff --git a/codex-rs/tui/src/bottom_pane/feedback_view.rs b/codex-rs/tui/src/bottom_pane/feedback_view.rs index a8476df0a..8a42e563c 100644 --- a/codex-rs/tui/src/bottom_pane/feedback_view.rs +++ b/codex-rs/tui/src/bottom_pane/feedback_view.rs @@ -26,7 +26,8 @@ use super::popup_consts::standard_popup_hint_line; use super::textarea::TextArea; use super::textarea::TextAreaState; -const BASE_ISSUE_URL: &str = "https://github.com/openai/codex/issues/new?template=2-bug-report.yml"; +const BASE_BUG_ISSUE_URL: &str = + "https://github.com/openai/codex/issues/new?template=2-bug-report.yml"; /// Minimal input overlay to collect an optional feedback note, then upload /// both logs and rollout with classification + metadata. @@ -88,26 +89,38 @@ impl FeedbackNoteView { match result { Ok(()) => { - let issue_url = format!("{BASE_ISSUE_URL}&steps=Uploaded%20thread:%20{thread_id}"); let prefix = if self.include_logs { "• Feedback uploaded." } else { "• Feedback recorded (no logs)." 
}; - self.app_event_tx.send(AppEvent::InsertHistoryCell(Box::new( - history_cell::PlainHistoryCell::new(vec![ - Line::from(format!( - "{prefix} Please open an issue using the following URL:" - )), + let issue_url = issue_url_for_category(self.category, &thread_id); + let mut lines = vec![Line::from(match issue_url.as_ref() { + Some(_) => format!("{prefix} Please open an issue using the following URL:"), + None => format!("{prefix} Thanks for the feedback!"), + })]; + if let Some(url) = issue_url { + lines.extend([ "".into(), - Line::from(vec![" ".into(), issue_url.cyan().underlined()]), + Line::from(vec![" ".into(), url.cyan().underlined()]), "".into(), Line::from(vec![ " Or mention your thread ID ".into(), std::mem::take(&mut thread_id).bold(), " in an existing issue.".into(), ]), - ]), + ]); + } else { + lines.extend([ + "".into(), + Line::from(vec![ + " Thread ID: ".into(), + std::mem::take(&mut thread_id).bold(), + ]), + ]); + } + self.app_event_tx.send(AppEvent::InsertHistoryCell(Box::new( + history_cell::PlainHistoryCell::new(lines), ))); } Err(e) => { @@ -320,6 +333,15 @@ fn feedback_classification(category: FeedbackCategory) -> &'static str { } } +fn issue_url_for_category(category: FeedbackCategory, thread_id: &str) -> Option { + match category { + FeedbackCategory::Bug | FeedbackCategory::BadResult | FeedbackCategory::Other => Some( + format!("{BASE_BUG_ISSUE_URL}&steps=Uploaded%20thread:%20{thread_id}"), + ), + FeedbackCategory::GoodResult => None, + } +} + // Build the selection popup params for feedback categories. 
pub(crate) fn feedback_selection_params( app_event_tx: AppEventSender, @@ -514,4 +536,22 @@ mod tests { let rendered = render(&view, 60); insta::assert_snapshot!("feedback_view_other", rendered); } + + #[test] + fn issue_url_available_for_bug_bad_result_and_other() { + let bug_url = issue_url_for_category(FeedbackCategory::Bug, "thread-1"); + assert!( + bug_url + .as_deref() + .is_some_and(|url| url.contains("template=2-bug-report")) + ); + + let bad_result_url = issue_url_for_category(FeedbackCategory::BadResult, "thread-2"); + assert!(bad_result_url.is_some()); + + let other_url = issue_url_for_category(FeedbackCategory::Other, "thread-3"); + assert!(other_url.is_some()); + + assert!(issue_url_for_category(FeedbackCategory::GoodResult, "t").is_none()); + } } diff --git a/codex-rs/tui/src/bottom_pane/list_selection_view.rs b/codex-rs/tui/src/bottom_pane/list_selection_view.rs index 88fbaf5b5..d294a4726 100644 --- a/codex-rs/tui/src/bottom_pane/list_selection_view.rs +++ b/codex-rs/tui/src/bottom_pane/list_selection_view.rs @@ -52,6 +52,7 @@ pub(crate) struct SelectionViewParams { pub is_searchable: bool, pub search_placeholder: Option, pub header: Box, + pub initial_selected_idx: Option, } impl Default for SelectionViewParams { @@ -64,6 +65,7 @@ impl Default for SelectionViewParams { is_searchable: false, search_placeholder: None, header: Box::new(()), + initial_selected_idx: None, } } } @@ -80,6 +82,7 @@ pub(crate) struct ListSelectionView { filtered_indices: Vec, last_selected_actual_idx: Option, header: Box, + initial_selected_idx: Option, } impl ListSelectionView { @@ -110,6 +113,7 @@ impl ListSelectionView { filtered_indices: Vec::new(), last_selected_actual_idx: None, header, + initial_selected_idx: params.initial_selected_idx, }; s.apply_filter(); s @@ -132,7 +136,8 @@ impl ListSelectionView { (!self.is_searchable) .then(|| self.items.iter().position(|item| item.is_current)) .flatten() - }); + }) + .or_else(|| self.initial_selected_idx.take()); if 
self.is_searchable && !self.search_query.is_empty() { let query_lower = self.search_query.to_lowercase(); diff --git a/codex-rs/tui/src/bottom_pane/mod.rs b/codex-rs/tui/src/bottom_pane/mod.rs index da2efb63c..5dbfb210b 100644 --- a/codex-rs/tui/src/bottom_pane/mod.rs +++ b/codex-rs/tui/src/bottom_pane/mod.rs @@ -69,6 +69,7 @@ pub(crate) struct BottomPane { is_task_running: bool, ctrl_c_quit_hint: bool, esc_backtrack_hint: bool, + animations_enabled: bool, /// Inline status indicator shown above the composer while a task is running. status: Option, @@ -84,28 +85,38 @@ pub(crate) struct BottomPaneParams { pub(crate) enhanced_keys_supported: bool, pub(crate) placeholder_text: String, pub(crate) disable_paste_burst: bool, + pub(crate) animations_enabled: bool, } impl BottomPane { pub fn new(params: BottomPaneParams) -> Self { - let enhanced_keys_supported = params.enhanced_keys_supported; + let BottomPaneParams { + app_event_tx, + frame_requester, + has_input_focus, + enhanced_keys_supported, + placeholder_text, + disable_paste_burst, + animations_enabled, + } = params; Self { composer: ChatComposer::new( - params.has_input_focus, - params.app_event_tx.clone(), + has_input_focus, + app_event_tx.clone(), enhanced_keys_supported, - params.placeholder_text, - params.disable_paste_burst, + placeholder_text, + disable_paste_burst, ), view_stack: Vec::new(), - app_event_tx: params.app_event_tx, - frame_requester: params.frame_requester, - has_input_focus: params.has_input_focus, + app_event_tx, + frame_requester, + has_input_focus, is_task_running: false, ctrl_c_quit_hint: false, status: None, queued_user_messages: QueuedUserMessages::new(), esc_backtrack_hint: false, + animations_enabled, context_window_percent: None, } } @@ -114,6 +125,11 @@ impl BottomPane { self.status.as_ref() } + #[cfg(test)] + pub(crate) fn context_window_percent(&self) -> Option { + self.context_window_percent + } + fn active_view(&self) -> Option<&dyn BottomPaneView> { 
self.view_stack.last().map(std::convert::AsRef::as_ref) } @@ -289,6 +305,7 @@ impl BottomPane { self.status = Some(StatusIndicatorWidget::new( self.app_event_tx.clone(), self.frame_requester.clone(), + self.animations_enabled, )); } if let Some(status) = self.status.as_mut() { @@ -314,6 +331,7 @@ impl BottomPane { self.status = Some(StatusIndicatorWidget::new( self.app_event_tx.clone(), self.frame_requester.clone(), + self.animations_enabled, )); self.request_redraw(); } @@ -549,6 +567,7 @@ mod tests { enhanced_keys_supported: false, placeholder_text: "Ask Codex to do anything".to_string(), disable_paste_burst: false, + animations_enabled: true, }); pane.push_approval_request(exec_request()); assert_eq!(CancellationEvent::Handled, pane.on_ctrl_c()); @@ -569,6 +588,7 @@ mod tests { enhanced_keys_supported: false, placeholder_text: "Ask Codex to do anything".to_string(), disable_paste_burst: false, + animations_enabled: true, }); // Create an approval modal (active view). @@ -600,6 +620,7 @@ mod tests { enhanced_keys_supported: false, placeholder_text: "Ask Codex to do anything".to_string(), disable_paste_burst: false, + animations_enabled: true, }); // Start a running task so the status indicator is active above the composer. @@ -665,6 +686,7 @@ mod tests { enhanced_keys_supported: false, placeholder_text: "Ask Codex to do anything".to_string(), disable_paste_burst: false, + animations_enabled: true, }); // Begin a task: show initial status. @@ -690,6 +712,7 @@ mod tests { enhanced_keys_supported: false, placeholder_text: "Ask Codex to do anything".to_string(), disable_paste_burst: false, + animations_enabled: true, }); // Activate spinner (status view replaces composer) with no live ring. 
@@ -719,6 +742,7 @@ mod tests { enhanced_keys_supported: false, placeholder_text: "Ask Codex to do anything".to_string(), disable_paste_burst: false, + animations_enabled: true, }); pane.set_task_running(true); @@ -745,6 +769,7 @@ mod tests { enhanced_keys_supported: false, placeholder_text: "Ask Codex to do anything".to_string(), disable_paste_burst: false, + animations_enabled: true, }); pane.set_task_running(true); diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index cdcc78aba..a371fa8cb 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -95,8 +95,6 @@ use crate::history_cell::HistoryCell; use crate::history_cell::McpToolCallCell; use crate::history_cell::PlainHistoryCell; use crate::markdown::append_markdown; -#[cfg(target_os = "windows")] -use crate::onboarding::WSL_INSTRUCTIONS; use crate::render::Insets; use crate::render::renderable::ColumnRenderable; use crate::render::renderable::FlexRenderable; @@ -292,6 +290,8 @@ pub(crate) struct ChatWidget { pending_notification: Option, // Simple review mode flag; used to adjust layout and banners. is_review_mode: bool, + // Snapshot of token usage to restore after review mode exits. 
+ pre_review_token_info: Option>, // Whether to add a final message separator after the last message needs_final_message_separator: bool, @@ -491,16 +491,39 @@ impl ChatWidget { } pub(crate) fn set_token_info(&mut self, info: Option) { - if let Some(info) = info { - let context_window = info - .model_context_window - .or(self.config.model_context_window); - let percent = context_window.map(|window| { + match info { + Some(info) => self.apply_token_info(info), + None => { + self.bottom_pane.set_context_window_percent(None); + self.token_info = None; + } + } + } + + fn apply_token_info(&mut self, info: TokenUsageInfo) { + let percent = self.context_remaining_percent(&info); + self.bottom_pane.set_context_window_percent(percent); + self.token_info = Some(info); + } + + fn context_remaining_percent(&self, info: &TokenUsageInfo) -> Option { + info.model_context_window + .or(self.config.model_context_window) + .map(|window| { info.last_token_usage .percent_of_context_window_remaining(window) - }); - self.bottom_pane.set_context_window_percent(percent); - self.token_info = Some(info); + }) + } + + fn restore_pre_review_token_info(&mut self) { + if let Some(saved) = self.pre_review_token_info.take() { + match saved { + Some(info) => self.apply_token_info(info), + None => { + self.bottom_pane.set_context_window_percent(None); + self.token_info = None; + } + } } } @@ -939,6 +962,7 @@ impl ChatWidget { parsed, source, None, + self.config.animations, ))); } @@ -1048,6 +1072,7 @@ impl ChatWidget { ev.parsed_cmd, ev.source, interaction_input, + self.config.animations, ))); } @@ -1060,6 +1085,7 @@ impl ChatWidget { self.active_cell = Some(Box::new(history_cell::new_active_mcp_tool_call( ev.call_id, ev.invocation, + self.config.animations, ))); self.request_redraw(); } @@ -1081,7 +1107,11 @@ impl ChatWidget { Some(cell) if cell.call_id() == call_id => cell.complete(duration, result), _ => { self.flush_active_cell(); - let mut cell = history_cell::new_active_mcp_tool_call(call_id, 
invocation); + let mut cell = history_cell::new_active_mcp_tool_call( + call_id, + invocation, + self.config.animations, + ); let extra_cell = cell.complete(duration, result); self.active_cell = Some(Box::new(cell)); extra_cell @@ -1123,6 +1153,7 @@ impl ChatWidget { enhanced_keys_supported, placeholder_text: placeholder, disable_paste_burst: config.disable_paste_burst, + animations_enabled: config.animations, }), active_cell: None, config: config.clone(), @@ -1152,6 +1183,7 @@ impl ChatWidget { suppress_session_configured_redraw: false, pending_notification: None, is_review_mode: false, + pre_review_token_info: None, needs_final_message_separator: false, last_rendered_width: std::cell::Cell::new(None), feedback, @@ -1196,6 +1228,7 @@ impl ChatWidget { enhanced_keys_supported, placeholder_text: placeholder, disable_paste_burst: config.disable_paste_burst, + animations_enabled: config.animations, }), active_cell: None, config: config.clone(), @@ -1225,6 +1258,7 @@ impl ChatWidget { suppress_session_configured_redraw: true, pending_notification: None, is_review_mode: false, + pre_review_token_info: None, needs_final_message_separator: false, last_rendered_width: std::cell::Cell::new(None), feedback, @@ -1439,6 +1473,7 @@ impl ChatWidget { // }), msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id: "1".to_string(), + turn_id: "turn-1".to_string(), changes: HashMap::from([ ( PathBuf::from("/tmp/test.txt"), @@ -1629,7 +1664,7 @@ impl ChatWidget { self.on_rate_limit_snapshot(ev.rate_limits); } EventMsg::Warning(WarningEvent { message }) => self.on_warning(message), - EventMsg::Error(ErrorEvent { message }) => self.on_error(message), + EventMsg::Error(ErrorEvent { message, .. 
}) => self.on_error(message), EventMsg::McpStartupUpdate(ev) => self.on_mcp_startup_update(ev), EventMsg::McpStartupComplete(ev) => self.on_mcp_startup_complete(ev), EventMsg::TurnAborted(ev) => match ev.reason { @@ -1672,7 +1707,9 @@ impl ChatWidget { } EventMsg::UndoStarted(ev) => self.on_undo_started(ev), EventMsg::UndoCompleted(ev) => self.on_undo_completed(ev), - EventMsg::StreamError(StreamErrorEvent { message }) => self.on_stream_error(message), + EventMsg::StreamError(StreamErrorEvent { message, .. }) => { + self.on_stream_error(message) + } EventMsg::UserMessage(ev) => { if from_replay { self.on_user_message_event(ev); @@ -1693,6 +1730,9 @@ impl ChatWidget { fn on_entered_review_mode(&mut self, review: ReviewRequest) { // Enter review mode and emit a concise banner + if self.pre_review_token_info.is_none() { + self.pre_review_token_info = Some(self.token_info.clone()); + } self.is_review_mode = true; let banner = format!(">> Code review started: {} <<", review.user_facing_hint); self.add_to_history(history_cell::new_review_status_line(banner)); @@ -1733,6 +1773,7 @@ impl ChatWidget { } self.is_review_mode = false; + self.restore_pre_review_token_info(); // Append a finishing banner at the end of this turn. 
self.add_to_history(history_cell::new_review_status_line( "<< Code review finished >>".to_string(), @@ -2015,6 +2056,26 @@ impl ChatWidget { let default_effort: ReasoningEffortConfig = preset.default_reasoning_effort; let supported = preset.supported_reasoning_efforts; + let warn_effort = if supported + .iter() + .any(|option| option.effort == ReasoningEffortConfig::XHigh) + { + Some(ReasoningEffortConfig::XHigh) + } else if supported + .iter() + .any(|option| option.effort == ReasoningEffortConfig::High) + { + Some(ReasoningEffortConfig::High) + } else { + None + }; + let warning_text = warn_effort.map(|effort| { + let effort_label = Self::reasoning_effort_label(effort); + format!("⚠ {effort_label} reasoning effort can quickly consume Plus plan rate limits.") + }); + let warn_for_model = preset.model.starts_with("gpt-5.1-codex") + || preset.model.starts_with("gpt-5.1-codex-max"); + struct EffortChoice { stored: Option, display: ReasoningEffortConfig, @@ -2059,13 +2120,18 @@ impl ChatWidget { } else { default_choice }; + let selection_choice = highlight_choice.or(default_choice); + let initial_selected_idx = choices + .iter() + .position(|choice| choice.stored == selection_choice) + .or_else(|| { + selection_choice + .and_then(|effort| choices.iter().position(|choice| choice.display == effort)) + }); let mut items: Vec = Vec::new(); for choice in choices.iter() { let effort = choice.display; - let mut effort_label = effort.to_string(); - if let Some(first) = effort_label.get_mut(0..1) { - first.make_ascii_uppercase(); - } + let mut effort_label = Self::reasoning_effort_label(effort).to_string(); if choice.stored == default_choice { effort_label.push_str(" (default)"); } @@ -2080,14 +2146,17 @@ impl ChatWidget { }) .filter(|text| !text.is_empty()); - let warning = "⚠ High reasoning effort can quickly consume Plus plan rate limits."; - let show_warning = - preset.model.starts_with("gpt-5.1-codex") && effort == ReasoningEffortConfig::High; - let selected_description = 
show_warning.then(|| { - description - .as_ref() - .map_or(warning.to_string(), |d| format!("{d}\n{warning}")) - }); + let show_warning = warn_for_model && warn_effort == Some(effort); + let selected_description = if show_warning { + warning_text.as_ref().map(|warning_message| { + description.as_ref().map_or_else( + || warning_message.clone(), + |d| format!("{d}\n{warning_message}"), + ) + }) + } else { + None + }; let model_for_action = model_slug.clone(); let effort_for_action = choice.stored; @@ -2135,10 +2204,22 @@ impl ChatWidget { header: Box::new(header), footer_hint: Some(standard_popup_hint_line()), items, + initial_selected_idx, ..Default::default() }); } + fn reasoning_effort_label(effort: ReasoningEffortConfig) -> &'static str { + match effort { + ReasoningEffortConfig::None => "None", + ReasoningEffortConfig::Minimal => "Minimal", + ReasoningEffortConfig::Low => "Low", + ReasoningEffortConfig::Medium => "Medium", + ReasoningEffortConfig::High => "High", + ReasoningEffortConfig::XHigh => "Extra high", + } + } + fn apply_model_and_effort(&self, model: String, effort: Option) { self.app_event_tx .send(AppEvent::CodexOp(Op::OverrideTurnContext { @@ -2171,41 +2252,12 @@ impl ChatWidget { let current_sandbox = self.config.sandbox_policy.clone(); let mut items: Vec = Vec::new(); let presets: Vec = builtin_approval_presets(); - #[cfg(target_os = "windows")] - let header_renderable: Box = if self - .config - .forced_auto_mode_downgraded_on_windows - { - use ratatui_macros::line; - - let mut header = ColumnRenderable::new(); - header.push(line![ - "Codex forced your settings back to Read Only on this Windows machine.".bold() - ]); - header.push(line![ - "To re-enable Auto mode, run Codex inside Windows Subsystem for Linux (WSL) or enable Full Access manually.".dim() - ]); - Box::new(header) - } else { - Box::new(()) - }; - #[cfg(not(target_os = "windows"))] - let header_renderable: Box = Box::new(()); for preset in presets.into_iter() { let is_current = 
current_approval == preset.approval && current_sandbox == preset.sandbox; let name = preset.label.to_string(); let description_text = preset.description; - let description = if cfg!(target_os = "windows") - && preset.id == "auto" - && codex_core::get_platform_sandbox().is_none() - { - Some(format!( - "{description_text}\nRequires Windows Subsystem for Linux (WSL). Show installation instructions..." - )) - } else { - Some(description_text.to_string()) - }; + let description = Some(description_text.to_string()); let requires_confirmation = preset.id == "full-access" && !self .config @@ -2223,53 +2275,16 @@ impl ChatWidget { #[cfg(target_os = "windows")] { if codex_core::get_platform_sandbox().is_none() { - vec![Box::new(|tx| { - tx.send(AppEvent::ShowWindowsAutoModeInstructions); + let preset_clone = preset.clone(); + vec![Box::new(move |tx| { + tx.send(AppEvent::OpenWindowsSandboxEnablePrompt { + preset: preset_clone.clone(), + }); })] - } else if !self - .config - .notices - .hide_world_writable_warning - .unwrap_or(false) - && self.windows_world_writable_flagged() + } else if let Some((sample_paths, extra_count, failed_scan)) = + self.world_writable_warning_details() { let preset_clone = preset.clone(); - // Compute sample paths for the warning popup. 
- let mut env_map: std::collections::HashMap = - std::collections::HashMap::new(); - for (k, v) in std::env::vars() { - env_map.insert(k, v); - } - let (sample_paths, extra_count, failed_scan) = - match codex_windows_sandbox::preflight_audit_everyone_writable( - &self.config.cwd, - &env_map, - Some(self.config.codex_home.as_path()), - ) { - Ok(paths) if !paths.is_empty() => { - fn normalize_windows_path_for_display( - p: &std::path::Path, - ) -> String { - let canon = dunce::canonicalize(p) - .unwrap_or_else(|_| p.to_path_buf()); - canon.display().to_string().replace('/', "\\") - } - let as_strings: Vec = paths - .iter() - .map(|p| normalize_windows_path_for_display(p)) - .collect(); - let samples: Vec = - as_strings.iter().take(3).cloned().collect(); - let extra = if as_strings.len() > samples.len() { - as_strings.len() - samples.len() - } else { - 0 - }; - (samples, extra, false) - } - Err(_) => (Vec::new(), 0, true), - _ => (Vec::new(), 0, false), - }; vec![Box::new(move |tx| { tx.send(AppEvent::OpenWorldWritableWarningConfirmation { preset: Some(preset_clone.clone()), @@ -2303,7 +2318,7 @@ impl ChatWidget { title: Some("Select Approval Mode".to_string()), footer_hint: Some(standard_popup_hint_line()), items, - header: header_renderable, + header: Box::new(()), ..Default::default() }); } @@ -2328,20 +2343,26 @@ impl ChatWidget { } #[cfg(target_os = "windows")] - fn windows_world_writable_flagged(&self) -> bool { - use std::collections::HashMap; - let mut env_map: HashMap = HashMap::new(); - for (k, v) in std::env::vars() { - env_map.insert(k, v); - } - match codex_windows_sandbox::preflight_audit_everyone_writable( - &self.config.cwd, - &env_map, - Some(self.config.codex_home.as_path()), - ) { - Ok(paths) => !paths.is_empty(), - Err(_) => true, + pub(crate) fn world_writable_warning_details(&self) -> Option<(Vec, usize, bool)> { + if self + .config + .notices + .hide_world_writable_warning + .unwrap_or(false) + { + return None; } + let cwd = match 
std::env::current_dir() { + Ok(cwd) => cwd, + Err(_) => return Some((Vec::new(), 0, true)), + }; + codex_windows_sandbox::world_writable_warning_details(self.config.codex_home.as_path(), cwd) + } + + #[cfg(not(target_os = "windows"))] + #[allow(dead_code)] + pub(crate) fn world_writable_warning_details(&self) -> Option<(Vec, usize, bool)> { + None } pub(crate) fn open_full_access_confirmation(&mut self, preset: ApprovalPreset) { @@ -2421,12 +2442,15 @@ impl ChatWidget { None => (None, None), }; let mut header_children: Vec> = Vec::new(); - let mode_label = match self.config.sandbox_policy { - SandboxPolicy::WorkspaceWrite { .. } => "Auto mode", + let describe_policy = |policy: &SandboxPolicy| match policy { + SandboxPolicy::WorkspaceWrite { .. } => "Agent mode", SandboxPolicy::ReadOnly => "Read-Only mode", - _ => "Auto mode", + _ => "Agent mode", }; - let title_line = Line::from("Unprotected directories found").bold(); + let mode_label = preset + .as_ref() + .map(|p| describe_policy(&p.sandbox)) + .unwrap_or_else(|| describe_policy(&self.config.sandbox_policy)); let info_line = if failed_scan { Line::from(vec![ "We couldn't complete the world-writable scan, so protections cannot be verified. " @@ -2436,14 +2460,10 @@ impl ChatWidget { ]) } else { Line::from(vec![ - "Some important directories on this system are world-writable. ".into(), - format!( - "The Windows sandbox cannot protect writes to these locations in {mode_label}." - ) - .fg(Color::Red), + "The Windows sandbox cannot protect writes to folders that are writable by Everyone.".into(), + " Consider removing write access for Everyone from the following folders:".into(), ]) }; - header_children.push(Box::new(title_line)); header_children.push(Box::new( Paragraph::new(vec![info_line]).wrap(Wrap { trim: false }), )); @@ -2451,9 +2471,9 @@ impl ChatWidget { if !sample_paths.is_empty() { // Show up to three examples and optionally an "and X more" line. 
let mut lines: Vec = Vec::new(); - lines.push(Line::from("Examples:").bold()); + lines.push(Line::from("")); for p in &sample_paths { - lines.push(Line::from(format!(" - {p}"))); + lines.push(Line::from(format!(" - {p}"))); } if extra_count > 0 { lines.push(Line::from(format!("and {extra_count} more"))); @@ -2521,29 +2541,43 @@ impl ChatWidget { } #[cfg(target_os = "windows")] - pub(crate) fn open_windows_auto_mode_instructions(&mut self) { + pub(crate) fn open_windows_sandbox_enable_prompt(&mut self, preset: ApprovalPreset) { use ratatui_macros::line; let mut header = ColumnRenderable::new(); - header.push(line![ - "Auto mode requires Windows Subsystem for Linux (WSL2).".bold() - ]); - header.push(line!["Run Codex inside WSL to enable sandboxed commands."]); - header.push(line![""]); - header.push(Paragraph::new(WSL_INSTRUCTIONS).wrap(Wrap { trim: false })); - - let items = vec![SelectionItem { - name: "Back".to_string(), - description: Some( - "Return to the approval mode list. Auto mode stays disabled outside WSL." 
- .to_string(), - ), - actions: vec![Box::new(|tx| { - tx.send(AppEvent::OpenApprovalsPopup); - })], - dismiss_on_select: true, - ..Default::default() - }]; + header.push(*Box::new( + Paragraph::new(vec![ + line!["Agent mode on Windows uses an experimental sandbox to limit network and filesystem access.".bold()], + line![ + "Learn more: https://developers.openai.com/codex/windows" + ], + ]) + .wrap(Wrap { trim: false }), + )); + + let preset_clone = preset; + let items = vec![ + SelectionItem { + name: "Enable experimental sandbox".to_string(), + description: None, + actions: vec![Box::new(move |tx| { + tx.send(AppEvent::EnableWindowsSandboxForAgentMode { + preset: preset_clone.clone(), + }); + })], + dismiss_on_select: true, + ..Default::default() + }, + SelectionItem { + name: "Go back".to_string(), + description: None, + actions: vec![Box::new(|tx| { + tx.send(AppEvent::OpenApprovalsPopup); + })], + dismiss_on_select: true, + ..Default::default() + }, + ]; self.bottom_pane.show_selection_view(SelectionViewParams { title: None, @@ -2555,7 +2589,31 @@ impl ChatWidget { } #[cfg(not(target_os = "windows"))] - pub(crate) fn open_windows_auto_mode_instructions(&mut self) {} + pub(crate) fn open_windows_sandbox_enable_prompt(&mut self, _preset: ApprovalPreset) {} + + #[cfg(target_os = "windows")] + pub(crate) fn maybe_prompt_windows_sandbox_enable(&mut self) { + if self.config.forced_auto_mode_downgraded_on_windows + && codex_core::get_platform_sandbox().is_none() + && let Some(preset) = builtin_approval_presets() + .into_iter() + .find(|preset| preset.id == "auto") + { + self.open_windows_sandbox_enable_prompt(preset); + } + } + + #[cfg(not(target_os = "windows"))] + pub(crate) fn maybe_prompt_windows_sandbox_enable(&mut self) {} + + #[cfg(target_os = "windows")] + pub(crate) fn clear_forced_auto_mode_downgrade(&mut self) { + self.config.forced_auto_mode_downgraded_on_windows = false; + } + + #[cfg(not(target_os = "windows"))] + #[allow(dead_code)] + pub(crate) fn 
clear_forced_auto_mode_downgrade(&mut self) {} /// Set the approval policy in the widget's config copy. pub(crate) fn set_approval_policy(&mut self, policy: AskForApproval) { @@ -2564,7 +2622,16 @@ impl ChatWidget { /// Set the sandbox policy in the widget's config copy. pub(crate) fn set_sandbox_policy(&mut self, policy: SandboxPolicy) { + #[cfg(target_os = "windows")] + let should_clear_downgrade = !matches!(policy, SandboxPolicy::ReadOnly) + || codex_core::get_platform_sandbox().is_some(); + self.config.sandbox_policy = policy; + + #[cfg(target_os = "windows")] + if should_clear_downgrade { + self.config.forced_auto_mode_downgraded_on_windows = false; + } } pub(crate) fn set_full_access_warning_acknowledged(&mut self, acknowledged: bool) { @@ -2721,6 +2788,7 @@ impl ChatWidget { review_request: ReviewRequest { prompt: "Review the current code changes (staged, unstaged, and untracked files) and provide prioritized findings.".to_string(), user_facing_hint: "current changes".to_string(), + append_to_original_thread: true, }, })); }, @@ -2777,6 +2845,7 @@ impl ChatWidget { "Review the code changes against the base branch '{branch}'. Start by finding the merge diff between the current branch and {branch}'s upstream e.g. (`git merge-base HEAD \"$(git rev-parse --abbrev-ref \"{branch}@{{upstream}}\")\"`), then run `git diff` against that SHA to see what changes we would merge into the {branch} branch. Provide prioritized, actionable findings." 
), user_facing_hint: format!("changes against '{branch}'"), + append_to_original_thread: true, }, })); })], @@ -2817,6 +2886,7 @@ impl ChatWidget { review_request: ReviewRequest { prompt, user_facing_hint: hint, + append_to_original_thread: true, }, })); })], @@ -2851,6 +2921,7 @@ impl ChatWidget { review_request: ReviewRequest { prompt: trimmed.clone(), user_facing_hint: trimmed, + append_to_original_thread: true, }, })); }), @@ -3061,6 +3132,7 @@ pub(crate) fn show_review_commit_picker_with_entries( review_request: ReviewRequest { prompt, user_facing_hint: hint, + append_to_original_thread: true, }, })); })], diff --git a/codex-rs/tui/src/chatwidget/agent.rs b/codex-rs/tui/src/chatwidget/agent.rs index bf15b6c4a..240972347 100644 --- a/codex-rs/tui/src/chatwidget/agent.rs +++ b/codex-rs/tui/src/chatwidget/agent.rs @@ -4,6 +4,8 @@ use codex_core::CodexConversation; use codex_core::ConversationManager; use codex_core::NewConversation; use codex_core::config::Config; +use codex_core::protocol::Event; +use codex_core::protocol::EventMsg; use codex_core::protocol::Op; use tokio::sync::mpsc::UnboundedSender; use tokio::sync::mpsc::unbounded_channel; @@ -28,9 +30,16 @@ pub(crate) fn spawn_agent( session_configured, } = match server.new_conversation(config).await { Ok(v) => v, - Err(e) => { - // TODO: surface this error to the user. 
- tracing::error!("failed to initialize codex: {e}"); + #[allow(clippy::print_stderr)] + Err(err) => { + let message = err.to_string(); + eprintln!("{message}"); + app_event_tx_clone.send(AppEvent::CodexEvent(Event { + id: "".to_string(), + msg: EventMsg::Error(err.to_error_event(None)), + })); + app_event_tx_clone.send(AppEvent::ExitRequest); + tracing::error!("failed to initialize codex: {err}"); return; } }; diff --git a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__approvals_selection_popup.snap b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__approvals_selection_popup.snap index 190594b1b..6758ec62c 100644 --- a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__approvals_selection_popup.snap +++ b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__approvals_selection_popup.snap @@ -4,14 +4,10 @@ expression: popup --- Select Approval Mode -› 1. Read Only (current) Codex can read files and answer questions. Codex - requires approval to make edits, run commands, or - access network. - 2. Auto Codex can read files, make edits, and run commands - in the workspace. Codex requires approval to work - outside the workspace or access network. - 3. Full Access Codex can read files, make edits, and run commands - with network access, without approval. Exercise - caution. +› 1. Read Only (current) Requires approval to edit files and run commands. + 2. Agent Read and edit files, and run commands. + 3. Agent (full access) Codex can edit files outside this workspace and run + commands with network access. Exercise caution when + using. 
Press enter to confirm or esc to go back diff --git a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__approvals_selection_popup@windows.snap b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__approvals_selection_popup@windows.snap index 7d16ad57b..6758ec62c 100644 --- a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__approvals_selection_popup@windows.snap +++ b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__approvals_selection_popup@windows.snap @@ -4,16 +4,10 @@ expression: popup --- Select Approval Mode -› 1. Read Only (current) Codex can read files and answer questions. Codex - requires approval to make edits, run commands, or - access network. - 2. Auto Codex can read files, make edits, and run commands - in the workspace. Codex requires approval to work - outside the workspace or access network. - Requires Windows Subsystem for Linux (WSL). Show - installation instructions... - 3. Full Access Codex can read files, make edits, and run commands - with network access, without approval. Exercise - caution. +› 1. Read Only (current) Requires approval to edit files and run commands. + 2. Agent Read and edit files, and run commands. + 3. Agent (full access) Codex can edit files outside this workspace and run + commands with network access. Exercise caution when + using. 
Press enter to confirm or esc to go back diff --git a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__model_reasoning_selection_popup.snap b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__model_reasoning_selection_popup.snap index 060d1f82a..b4b89736a 100644 --- a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__model_reasoning_selection_popup.snap +++ b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__model_reasoning_selection_popup.snap @@ -2,13 +2,11 @@ source: tui/src/chatwidget/tests.rs expression: popup --- - Select Reasoning Level for gpt-5.1-codex + Select Reasoning Level for gpt-5.1-codex-max - 1. Low Fastest responses with limited reasoning - 2. Medium (default) Dynamically adjusts reasoning based on the task -› 3. High (current) Maximizes reasoning depth for complex or ambiguous - problems - ⚠ High reasoning effort can quickly consume Plus plan - rate limits. + 1. Low Fast responses with lighter reasoning + 2. Medium (default) Balances speed and reasoning depth for everyday tasks +› 3. High (current) Maximizes reasoning depth for complex problems + 4. Extra high Extra high reasoning depth for complex problems Press enter to confirm or esc to go back diff --git a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__model_reasoning_selection_popup_extra_high_warning.snap b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__model_reasoning_selection_popup_extra_high_warning.snap new file mode 100644 index 000000000..c5332ff59 --- /dev/null +++ b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__model_reasoning_selection_popup_extra_high_warning.snap @@ -0,0 +1,16 @@ +--- +source: tui/src/chatwidget/tests.rs +assertion_line: 1548 +expression: popup +--- + Select Reasoning Level for gpt-5.1-codex-max + + 1. Low Fast responses with lighter reasoning + 2. Medium (default) Balances speed and reasoning depth for everyday + tasks + 3. 
High Maximizes reasoning depth for complex problems +› 4. Extra high (current) Extra high reasoning depth for complex problems + ⚠ Extra high reasoning effort can quickly consume + Plus plan rate limits. + + Press enter to confirm or esc to go back diff --git a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__user_shell_ls_output.snap b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__user_shell_ls_output.snap new file mode 100644 index 000000000..c67cd637d --- /dev/null +++ b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__user_shell_ls_output.snap @@ -0,0 +1,7 @@ +--- +source: tui/src/chatwidget/tests.rs +expression: blob +--- +• You ran ls + └ file1 + file2 diff --git a/codex-rs/tui/src/chatwidget/tests.rs b/codex-rs/tui/src/chatwidget/tests.rs index abd9a6123..61e23fb3f 100644 --- a/codex-rs/tui/src/chatwidget/tests.rs +++ b/codex-rs/tui/src/chatwidget/tests.rs @@ -38,6 +38,9 @@ use codex_core::protocol::ReviewRequest; use codex_core::protocol::StreamErrorEvent; use codex_core::protocol::TaskCompleteEvent; use codex_core::protocol::TaskStartedEvent; +use codex_core::protocol::TokenCountEvent; +use codex_core::protocol::TokenUsage; +use codex_core::protocol::TokenUsageInfo; use codex_core::protocol::UndoCompletedEvent; use codex_core::protocol::UndoStartedEvent; use codex_core::protocol::ViewImageToolCallEvent; @@ -47,6 +50,7 @@ use codex_protocol::parse_command::ParsedCommand; use codex_protocol::plan_tool::PlanItemArg; use codex_protocol::plan_tool::StepStatus; use codex_protocol::plan_tool::UpdatePlanArgs; +use codex_protocol::protocol::CodexErrorInfo; use crossterm::event::KeyCode; use crossterm::event::KeyEvent; use crossterm::event::KeyModifiers; @@ -58,6 +62,11 @@ use tempfile::tempdir; use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::mpsc::unbounded_channel; +#[cfg(target_os = "windows")] +fn set_windows_sandbox_enabled(enabled: bool) { + 
codex_core::set_windows_sandbox_enabled(enabled); +} + fn test_config() -> Config { // Use base defaults to avoid depending on host state. Config::load_from_base_config_with_overrides( @@ -76,6 +85,7 @@ fn snapshot(percent: f64) -> RateLimitSnapshot { resets_at: None, }), secondary: None, + credits: None, } } @@ -88,6 +98,10 @@ fn resumed_initial_messages_render_history() { let configured = codex_core::protocol::SessionConfiguredEvent { session_id: conversation_id, model: "test-model".to_string(), + model_provider_id: "test-provider".to_string(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, + cwd: PathBuf::from("/home/user/project"), reasoning_effort: Some(ReasoningEffortConfig::default()), history_log_id: 0, history_entry_count: 0, @@ -140,6 +154,7 @@ fn entered_review_mode_uses_request_hint() { msg: EventMsg::EnteredReviewMode(ReviewRequest { prompt: "Review the latest changes".to_string(), user_facing_hint: "feature branch".to_string(), + append_to_original_thread: true, }), }); @@ -159,6 +174,7 @@ fn entered_review_mode_defaults_to_current_changes_banner() { msg: EventMsg::EnteredReviewMode(ReviewRequest { prompt: "Review the current changes".to_string(), user_facing_hint: "current changes".to_string(), + append_to_original_thread: true, }), }); @@ -203,6 +219,81 @@ fn exited_review_mode_emits_results_and_finishes() { assert!(!chat.is_review_mode); } +/// Exiting review restores the pre-review context window indicator. +#[test] +fn review_restores_context_window_indicator() { + let (mut chat, mut rx, _ops) = make_chatwidget_manual(); + + let context_window = 13_000; + let pre_review_tokens = 12_700; // ~30% remaining after subtracting baseline. + let review_tokens = 12_030; // ~97% remaining after subtracting baseline. 
+ + chat.handle_codex_event(Event { + id: "token-before".into(), + msg: EventMsg::TokenCount(TokenCountEvent { + info: Some(make_token_info(pre_review_tokens, context_window)), + rate_limits: None, + }), + }); + assert_eq!(chat.bottom_pane.context_window_percent(), Some(30)); + + chat.handle_codex_event(Event { + id: "review-start".into(), + msg: EventMsg::EnteredReviewMode(ReviewRequest { + prompt: "Review the latest changes".to_string(), + user_facing_hint: "feature branch".to_string(), + append_to_original_thread: true, + }), + }); + + chat.handle_codex_event(Event { + id: "token-review".into(), + msg: EventMsg::TokenCount(TokenCountEvent { + info: Some(make_token_info(review_tokens, context_window)), + rate_limits: None, + }), + }); + assert_eq!(chat.bottom_pane.context_window_percent(), Some(97)); + + chat.handle_codex_event(Event { + id: "review-end".into(), + msg: EventMsg::ExitedReviewMode(ExitedReviewModeEvent { + review_output: None, + }), + }); + let _ = drain_insert_history(&mut rx); + + assert_eq!(chat.bottom_pane.context_window_percent(), Some(30)); + assert!(!chat.is_review_mode); +} + +/// Receiving a TokenCount event without usage clears the context indicator. 
+#[test] +fn token_count_none_resets_context_indicator() { + let (mut chat, _rx, _ops) = make_chatwidget_manual(); + + let context_window = 13_000; + let pre_compact_tokens = 12_700; + + chat.handle_codex_event(Event { + id: "token-before".into(), + msg: EventMsg::TokenCount(TokenCountEvent { + info: Some(make_token_info(pre_compact_tokens, context_window)), + rate_limits: None, + }), + }); + assert_eq!(chat.bottom_pane.context_window_percent(), Some(30)); + + chat.handle_codex_event(Event { + id: "token-cleared".into(), + msg: EventMsg::TokenCount(TokenCountEvent { + info: None, + rate_limits: None, + }), + }); + assert_eq!(chat.bottom_pane.context_window_percent(), None); +} + #[cfg_attr( target_os = "macos", ignore = "system configuration APIs are blocked under macOS seatbelt" @@ -248,6 +339,7 @@ fn make_chatwidget_manual() -> ( enhanced_keys_supported: false, placeholder_text: "Ask Codex to do anything".to_string(), disable_paste_burst: false, + animations_enabled: cfg.animations, }); let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("test")); let widget = ChatWidget { @@ -280,6 +372,7 @@ fn make_chatwidget_manual() -> ( suppress_session_configured_redraw: false, pending_notification: None, is_review_mode: false, + pre_review_token_info: None, needs_final_message_separator: false, last_rendered_width: std::cell::Cell::new(None), feedback: codex_feedback::CodexFeedback::new(), @@ -326,6 +419,21 @@ fn lines_to_single_string(lines: &[ratatui::text::Line<'static>]) -> String { s } +fn make_token_info(total_tokens: i64, context_window: i64) -> TokenUsageInfo { + fn usage(total_tokens: i64) -> TokenUsage { + TokenUsage { + total_tokens, + ..TokenUsage::default() + } + } + + TokenUsageInfo { + total_token_usage: usage(total_tokens), + last_token_usage: usage(total_tokens), + model_context_window: Some(context_window), + } +} + #[test] fn rate_limit_warnings_emit_thresholds() { let mut state = RateLimitWarningState::default(); @@ -1457,72 
+1565,117 @@ fn approvals_selection_popup_snapshot() { } #[test] -fn approvals_popup_includes_wsl_note_for_auto_mode() { +fn full_access_confirmation_popup_snapshot() { let (mut chat, _rx, _op_rx) = make_chatwidget_manual(); - if cfg!(target_os = "windows") { - chat.config.forced_auto_mode_downgraded_on_windows = true; - } - chat.open_approvals_popup(); + let preset = builtin_approval_presets() + .into_iter() + .find(|preset| preset.id == "full-access") + .expect("full access preset"); + chat.open_full_access_confirmation(preset); let popup = render_bottom_popup(&chat, 80); - assert_eq!( - popup.contains("Requires Windows Subsystem for Linux (WSL)"), - cfg!(target_os = "windows"), - "expected auto preset description to mention WSL requirement only on Windows, popup: {popup}" - ); - assert_eq!( - popup.contains("Codex forced your settings back to Read Only on this Windows machine."), - cfg!(target_os = "windows") && chat.config.forced_auto_mode_downgraded_on_windows, - "expected downgrade notice only when auto mode is forced off on Windows, popup: {popup}" - ); + assert_snapshot!("full_access_confirmation_popup", popup); } +#[cfg(target_os = "windows")] #[test] -fn full_access_confirmation_popup_snapshot() { +fn windows_auto_mode_prompt_requests_enabling_sandbox_feature() { let (mut chat, _rx, _op_rx) = make_chatwidget_manual(); let preset = builtin_approval_presets() .into_iter() - .find(|preset| preset.id == "full-access") - .expect("full access preset"); - chat.open_full_access_confirmation(preset); + .find(|preset| preset.id == "auto") + .expect("auto preset"); + chat.open_windows_sandbox_enable_prompt(preset); - let popup = render_bottom_popup(&chat, 80); - assert_snapshot!("full_access_confirmation_popup", popup); + let popup = render_bottom_popup(&chat, 120); + assert!( + popup.contains("Agent mode on Windows uses an experimental sandbox"), + "expected auto mode prompt to mention enabling the sandbox feature, popup: {popup}" + ); } #[cfg(target_os = 
"windows")] #[test] -fn windows_auto_mode_instructions_popup_lists_install_steps() { +fn startup_prompts_for_windows_sandbox_when_agent_requested() { let (mut chat, _rx, _op_rx) = make_chatwidget_manual(); - chat.open_windows_auto_mode_instructions(); + set_windows_sandbox_enabled(false); + chat.config.forced_auto_mode_downgraded_on_windows = true; + + chat.maybe_prompt_windows_sandbox_enable(); let popup = render_bottom_popup(&chat, 120); assert!( - popup.contains("wsl --install"), - "expected WSL instructions popup to include install command, popup: {popup}" + popup.contains("Agent mode on Windows uses an experimental sandbox"), + "expected startup prompt to explain sandbox: {popup}" + ); + assert!( + popup.contains("Enable experimental sandbox"), + "expected startup prompt to offer enabling the sandbox: {popup}" ); + + set_windows_sandbox_enabled(true); } #[test] fn model_reasoning_selection_popup_snapshot() { let (mut chat, _rx, _op_rx) = make_chatwidget_manual(); - chat.config.model = "gpt-5.1-codex".to_string(); + chat.config.model = "gpt-5.1-codex-max".to_string(); chat.config.model_reasoning_effort = Some(ReasoningEffortConfig::High); let preset = builtin_model_presets(None) .into_iter() - .find(|preset| preset.model == "gpt-5.1-codex") - .expect("gpt-5.1-codex preset"); + .find(|preset| preset.model == "gpt-5.1-codex-max") + .expect("gpt-5.1-codex-max preset"); chat.open_reasoning_popup(preset); let popup = render_bottom_popup(&chat, 80); assert_snapshot!("model_reasoning_selection_popup", popup); } +#[test] +fn model_reasoning_selection_popup_extra_high_warning_snapshot() { + let (mut chat, _rx, _op_rx) = make_chatwidget_manual(); + + chat.config.model = "gpt-5.1-codex-max".to_string(); + chat.config.model_reasoning_effort = Some(ReasoningEffortConfig::XHigh); + + let preset = builtin_model_presets(None) + .into_iter() + .find(|preset| preset.model == "gpt-5.1-codex-max") + .expect("gpt-5.1-codex-max preset"); + chat.open_reasoning_popup(preset); + + let 
popup = render_bottom_popup(&chat, 80); + assert_snapshot!("model_reasoning_selection_popup_extra_high_warning", popup); +} + +#[test] +fn reasoning_popup_shows_extra_high_with_space() { + let (mut chat, _rx, _op_rx) = make_chatwidget_manual(); + + chat.config.model = "gpt-5.1-codex-max".to_string(); + + let preset = builtin_model_presets(None) + .into_iter() + .find(|preset| preset.model == "gpt-5.1-codex-max") + .expect("gpt-5.1-codex-max preset"); + chat.open_reasoning_popup(preset); + + let popup = render_bottom_popup(&chat, 120); + assert!( + popup.contains("Extra high"), + "expected popup to include 'Extra high'; popup: {popup}" + ); + assert!( + !popup.contains("Extrahigh"), + "expected popup not to include 'Extrahigh'; popup: {popup}" + ); +} + #[test] fn single_reasoning_option_skips_selection() { let (mut chat, mut rx, _op_rx) = make_chatwidget_manual(); @@ -1540,6 +1693,7 @@ fn single_reasoning_option_skips_selection() { supported_reasoning_efforts: &SINGLE_EFFORT, is_default: false, upgrade: None, + show_in_picker: true, }; chat.open_reasoning_popup(preset); @@ -1638,6 +1792,28 @@ fn exec_history_extends_previous_when_consecutive() { assert_snapshot!("exploring_step6_finish_cat_bar", active_blob(&chat)); } +#[test] +fn user_shell_command_renders_output_not_exploring() { + let (mut chat, mut rx, _op_rx) = make_chatwidget_manual(); + + let begin_ls = begin_exec_with_source( + &mut chat, + "user-shell-ls", + "ls", + ExecCommandSource::UserShell, + ); + end_exec(&mut chat, begin_ls, "file1\nfile2\n", "", 0); + + let cells = drain_insert_history(&mut rx); + assert_eq!( + cells.len(), + 1, + "expected a single history cell for the user command" + ); + let blob = lines_to_single_string(cells.first().unwrap()); + assert_snapshot!("user_shell_ls_output", blob); +} + #[test] fn disabled_slash_command_while_task_running_snapshot() { // Build a chat widget and simulate an active task @@ -1760,6 +1936,7 @@ fn approval_modal_patch_snapshot() { ); let ev = 
ApplyPatchApprovalRequestEvent { call_id: "call-approve-patch".into(), + turn_id: "turn-approve-patch".into(), changes, reason: Some("The model wants to apply changes".into()), grant_root: Some(PathBuf::from("/tmp")), @@ -2012,6 +2189,7 @@ fn apply_patch_events_emit_history_cells() { ); let ev = ApplyPatchApprovalRequestEvent { call_id: "c1".into(), + turn_id: "turn-c1".into(), changes, reason: None, grant_root: None, @@ -2052,6 +2230,7 @@ fn apply_patch_events_emit_history_cells() { ); let begin = PatchApplyBeginEvent { call_id: "c1".into(), + turn_id: "turn-c1".into(), auto_approved: true, changes: changes2, }; @@ -2068,11 +2247,20 @@ fn apply_patch_events_emit_history_cells() { ); // 3) End apply success -> success cell + let mut end_changes = HashMap::new(); + end_changes.insert( + PathBuf::from("foo.txt"), + FileChange::Add { + content: "hello\n".to_string(), + }, + ); let end = PatchApplyEndEvent { call_id: "c1".into(), + turn_id: "turn-c1".into(), stdout: "ok\n".into(), stderr: String::new(), success: true, + changes: end_changes, }; chat.handle_codex_event(Event { id: "s1".into(), @@ -2100,6 +2288,7 @@ fn apply_patch_manual_approval_adjusts_header() { id: "s1".into(), msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id: "c1".into(), + turn_id: "turn-c1".into(), changes: proposed_changes, reason: None, grant_root: None, @@ -2118,6 +2307,7 @@ fn apply_patch_manual_approval_adjusts_header() { id: "s1".into(), msg: EventMsg::PatchApplyBegin(PatchApplyBeginEvent { call_id: "c1".into(), + turn_id: "turn-c1".into(), auto_approved: false, changes: apply_changes, }), @@ -2147,6 +2337,7 @@ fn apply_patch_manual_flow_snapshot() { id: "s1".into(), msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id: "c1".into(), + turn_id: "turn-c1".into(), changes: proposed_changes, reason: Some("Manual review required".into()), grant_root: None, @@ -2169,6 +2360,7 @@ fn apply_patch_manual_flow_snapshot() { id: "s1".into(), 
msg: EventMsg::PatchApplyBegin(PatchApplyBeginEvent { call_id: "c1".into(), + turn_id: "turn-c1".into(), auto_approved: false, changes: apply_changes, }), @@ -2196,6 +2388,7 @@ fn apply_patch_approval_sends_op_with_submission_id() { ); let ev = ApplyPatchApprovalRequestEvent { call_id: "call-999".into(), + turn_id: "turn-999".into(), changes, reason: None, grant_root: None, @@ -2235,6 +2428,7 @@ fn apply_patch_full_flow_integration_like() { id: "sub-xyz".into(), msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id: "call-1".into(), + turn_id: "turn-call-1".into(), changes, reason: None, grant_root: None, @@ -2275,17 +2469,25 @@ fn apply_patch_full_flow_integration_like() { id: "sub-xyz".into(), msg: EventMsg::PatchApplyBegin(PatchApplyBeginEvent { call_id: "call-1".into(), + turn_id: "turn-call-1".into(), auto_approved: false, changes: changes2, }), }); + let mut end_changes = HashMap::new(); + end_changes.insert( + PathBuf::from("pkg.rs"), + FileChange::Add { content: "".into() }, + ); chat.handle_codex_event(Event { id: "sub-xyz".into(), msg: EventMsg::PatchApplyEnd(PatchApplyEndEvent { call_id: "call-1".into(), + turn_id: "turn-call-1".into(), stdout: String::from("ok"), stderr: String::new(), success: true, + changes: end_changes, }), }); } @@ -2306,6 +2508,7 @@ fn apply_patch_untrusted_shows_approval_modal() { id: "sub-1".into(), msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id: "call-1".into(), + turn_id: "turn-call-1".into(), changes, reason: None, grant_root: None, @@ -2354,6 +2557,7 @@ fn apply_patch_request_shows_diff_summary() { id: "sub-apply".into(), msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id: "call-apply".into(), + turn_id: "turn-apply".into(), changes, reason: None, grant_root: None, @@ -2444,6 +2648,7 @@ fn stream_error_updates_status_indicator() { id: "sub-1".into(), msg: EventMsg::StreamError(StreamErrorEvent { message: msg.to_string(), + 
codex_error_info: Some(CodexErrorInfo::Other), }), }); diff --git a/codex-rs/tui/src/cli.rs b/codex-rs/tui/src/cli.rs index e7a0c945b..2b19b4c06 100644 --- a/codex-rs/tui/src/cli.rs +++ b/codex-rs/tui/src/cli.rs @@ -28,6 +28,10 @@ pub struct Cli { #[clap(skip)] pub resume_session_id: Option, + /// Internal: show all sessions (disables cwd filtering and shows CWD column). + #[clap(skip)] + pub resume_show_all: bool, + /// Model the agent should use. #[arg(long, short = 'm')] pub model: Option, diff --git a/codex-rs/tui/src/exec_cell/model.rs b/codex-rs/tui/src/exec_cell/model.rs index 943fb8365..76316968c 100644 --- a/codex-rs/tui/src/exec_cell/model.rs +++ b/codex-rs/tui/src/exec_cell/model.rs @@ -28,11 +28,15 @@ pub(crate) struct ExecCall { #[derive(Debug)] pub(crate) struct ExecCell { pub(crate) calls: Vec, + animations_enabled: bool, } impl ExecCell { - pub(crate) fn new(call: ExecCall) -> Self { - Self { calls: vec![call] } + pub(crate) fn new(call: ExecCall, animations_enabled: bool) -> Self { + Self { + calls: vec![call], + animations_enabled, + } } pub(crate) fn with_added_call( @@ -56,6 +60,7 @@ impl ExecCell { if self.is_exploring_cell() && Self::is_exploring_call(&call) { Some(Self { calls: [self.calls.clone(), vec![call]].concat(), + animations_enabled: self.animations_enabled, }) } else { None @@ -112,12 +117,17 @@ impl ExecCell { .and_then(|c| c.start_time) } + pub(crate) fn animations_enabled(&self) -> bool { + self.animations_enabled + } + pub(crate) fn iter_calls(&self) -> impl Iterator { self.calls.iter() } pub(super) fn is_exploring_call(call: &ExecCall) -> bool { - !call.parsed.is_empty() + !matches!(call.source, ExecCommandSource::UserShell) + && !call.parsed.is_empty() && call.parsed.iter().all(|p| { matches!( p, diff --git a/codex-rs/tui/src/exec_cell/render.rs b/codex-rs/tui/src/exec_cell/render.rs index 352a61476..3e434138d 100644 --- a/codex-rs/tui/src/exec_cell/render.rs +++ b/codex-rs/tui/src/exec_cell/render.rs @@ -40,17 +40,21 @@ 
pub(crate) fn new_active_exec_command( parsed: Vec, source: ExecCommandSource, interaction_input: Option, + animations_enabled: bool, ) -> ExecCell { - ExecCell::new(ExecCall { - call_id, - command, - parsed, - output: None, - source, - start_time: Some(Instant::now()), - duration: None, - interaction_input, - }) + ExecCell::new( + ExecCall { + call_id, + command, + parsed, + output: None, + source, + start_time: Some(Instant::now()), + duration: None, + interaction_input, + }, + animations_enabled, + ) } fn format_unified_exec_interaction(command: &[String], input: Option<&str>) -> String { @@ -168,7 +172,10 @@ pub(crate) fn output_lines( } } -pub(crate) fn spinner(start_time: Option) -> Span<'static> { +pub(crate) fn spinner(start_time: Option, animations_enabled: bool) -> Span<'static> { + if !animations_enabled { + return "•".dim(); + } let elapsed = start_time.map(|st| st.elapsed()).unwrap_or_default(); if supports_color::on_cached(supports_color::Stream::Stdout) .map(|level| level.has_16m) @@ -239,7 +246,7 @@ impl ExecCell { let mut out: Vec> = Vec::new(); out.push(Line::from(vec![ if self.is_active() { - spinner(self.active_start_time()) + spinner(self.active_start_time(), self.animations_enabled()) } else { "•".dim() }, @@ -347,7 +354,7 @@ impl ExecCell { let bullet = match success { Some(true) => "•".green().bold(), Some(false) => "•".red().bold(), - None => spinner(call.start_time), + None => spinner(call.start_time, self.animations_enabled()), }; let is_interaction = call.is_unified_exec_interaction(); let title = if is_interaction { diff --git a/codex-rs/tui/src/file_search.rs b/codex-rs/tui/src/file_search.rs index 61327cfd2..af4651264 100644 --- a/codex-rs/tui/src/file_search.rs +++ b/codex-rs/tui/src/file_search.rs @@ -31,7 +31,7 @@ use std::time::Duration; use crate::app_event::AppEvent; use crate::app_event_sender::AppEventSender; -const MAX_FILE_SEARCH_RESULTS: NonZeroUsize = NonZeroUsize::new(8).unwrap(); +const MAX_FILE_SEARCH_RESULTS: 
NonZeroUsize = NonZeroUsize::new(20).unwrap(); const NUM_FILE_SEARCH_THREADS: NonZeroUsize = NonZeroUsize::new(2).unwrap(); /// How long to wait after a keystroke before firing the first search when none diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index b026170f7..02ab0d243 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -584,11 +584,7 @@ pub(crate) fn new_session_info( let SessionConfiguredEvent { model, reasoning_effort, - session_id: _, - history_log_id: _, - history_entry_count: _, - initial_messages: _, - rollout_path: _, + .. } = event; SessionInfoCell(if is_first_event { // Header box rendered as history (so it appears at the very top) @@ -712,6 +708,7 @@ impl SessionHeaderHistoryCell { ReasoningEffortConfig::Low => "low", ReasoningEffortConfig::Medium => "medium", ReasoningEffortConfig::High => "high", + ReasoningEffortConfig::XHigh => "xhigh", ReasoningEffortConfig::None => "none", }) } @@ -809,16 +806,22 @@ pub(crate) struct McpToolCallCell { start_time: Instant, duration: Option, result: Option>, + animations_enabled: bool, } impl McpToolCallCell { - pub(crate) fn new(call_id: String, invocation: McpInvocation) -> Self { + pub(crate) fn new( + call_id: String, + invocation: McpInvocation, + animations_enabled: bool, + ) -> Self { Self { call_id, invocation, start_time: Instant::now(), duration: None, result: None, + animations_enabled, } } @@ -880,7 +883,7 @@ impl HistoryCell for McpToolCallCell { let bullet = match status { Some(true) => "•".green().bold(), Some(false) => "•".red().bold(), - None => spinner(Some(self.start_time)), + None => spinner(Some(self.start_time), self.animations_enabled), }; let header_text = if status.is_some() { "Called" @@ -968,8 +971,9 @@ impl HistoryCell for McpToolCallCell { pub(crate) fn new_active_mcp_tool_call( call_id: String, invocation: McpInvocation, + animations_enabled: bool, ) -> McpToolCallCell { - McpToolCallCell::new(call_id, 
invocation) + McpToolCallCell::new(call_id, invocation, animations_enabled) } pub(crate) fn new_web_search_call(query: String) -> PlainHistoryCell { @@ -1634,7 +1638,7 @@ mod tests { })), }; - let cell = new_active_mcp_tool_call("call-1".into(), invocation); + let cell = new_active_mcp_tool_call("call-1".into(), invocation, true); let rendered = render_lines(&cell.display_lines(80)).join("\n"); insta::assert_snapshot!(rendered); @@ -1661,7 +1665,7 @@ mod tests { structured_content: None, }; - let mut cell = new_active_mcp_tool_call("call-2".into(), invocation); + let mut cell = new_active_mcp_tool_call("call-2".into(), invocation, true); assert!( cell.complete(Duration::from_millis(1420), Ok(result)) .is_none() @@ -1683,7 +1687,7 @@ mod tests { })), }; - let mut cell = new_active_mcp_tool_call("call-3".into(), invocation); + let mut cell = new_active_mcp_tool_call("call-3".into(), invocation, true); assert!( cell.complete(Duration::from_secs(2), Err("network timeout".into())) .is_none() @@ -1727,7 +1731,7 @@ mod tests { structured_content: None, }; - let mut cell = new_active_mcp_tool_call("call-4".into(), invocation); + let mut cell = new_active_mcp_tool_call("call-4".into(), invocation, true); assert!( cell.complete(Duration::from_millis(640), Ok(result)) .is_none() @@ -1759,7 +1763,7 @@ mod tests { structured_content: None, }; - let mut cell = new_active_mcp_tool_call("call-5".into(), invocation); + let mut cell = new_active_mcp_tool_call("call-5".into(), invocation, true); assert!( cell.complete(Duration::from_millis(1280), Ok(result)) .is_none() @@ -1798,7 +1802,7 @@ mod tests { structured_content: None, }; - let mut cell = new_active_mcp_tool_call("call-6".into(), invocation); + let mut cell = new_active_mcp_tool_call("call-6".into(), invocation, true); assert!( cell.complete(Duration::from_millis(320), Ok(result)) .is_none() @@ -1856,32 +1860,35 @@ mod tests { fn coalesces_sequential_reads_within_one_call() { // Build one exec cell with a Search followed by 
two Reads let call_id = "c1".to_string(); - let mut cell = ExecCell::new(ExecCall { - call_id: call_id.clone(), - command: vec!["bash".into(), "-lc".into(), "echo".into()], - parsed: vec![ - ParsedCommand::Search { - query: Some("shimmer_spans".into()), - path: None, - cmd: "rg shimmer_spans".into(), - }, - ParsedCommand::Read { - name: "shimmer.rs".into(), - cmd: "cat shimmer.rs".into(), - path: "shimmer.rs".into(), - }, - ParsedCommand::Read { - name: "status_indicator_widget.rs".into(), - cmd: "cat status_indicator_widget.rs".into(), - path: "status_indicator_widget.rs".into(), - }, - ], - output: None, - source: ExecCommandSource::Agent, - start_time: Some(Instant::now()), - duration: None, - interaction_input: None, - }); + let mut cell = ExecCell::new( + ExecCall { + call_id: call_id.clone(), + command: vec!["bash".into(), "-lc".into(), "echo".into()], + parsed: vec![ + ParsedCommand::Search { + query: Some("shimmer_spans".into()), + path: None, + cmd: "rg shimmer_spans".into(), + }, + ParsedCommand::Read { + name: "shimmer.rs".into(), + cmd: "cat shimmer.rs".into(), + path: "shimmer.rs".into(), + }, + ParsedCommand::Read { + name: "status_indicator_widget.rs".into(), + cmd: "cat status_indicator_widget.rs".into(), + path: "status_indicator_widget.rs".into(), + }, + ], + output: None, + source: ExecCommandSource::Agent, + start_time: Some(Instant::now()), + duration: None, + interaction_input: None, + }, + true, + ); // Mark call complete so markers are ✓ cell.complete_call(&call_id, CommandOutput::default(), Duration::from_millis(1)); @@ -1892,20 +1899,23 @@ mod tests { #[test] fn coalesces_reads_across_multiple_calls() { - let mut cell = ExecCell::new(ExecCall { - call_id: "c1".to_string(), - command: vec!["bash".into(), "-lc".into(), "echo".into()], - parsed: vec![ParsedCommand::Search { - query: Some("shimmer_spans".into()), - path: None, - cmd: "rg shimmer_spans".into(), - }], - output: None, - source: ExecCommandSource::Agent, - start_time: 
Some(Instant::now()), - duration: None, - interaction_input: None, - }); + let mut cell = ExecCell::new( + ExecCall { + call_id: "c1".to_string(), + command: vec!["bash".into(), "-lc".into(), "echo".into()], + parsed: vec![ParsedCommand::Search { + query: Some("shimmer_spans".into()), + path: None, + cmd: "rg shimmer_spans".into(), + }], + output: None, + source: ExecCommandSource::Agent, + start_time: Some(Instant::now()), + duration: None, + interaction_input: None, + }, + true, + ); // Call 1: Search only cell.complete_call("c1", CommandOutput::default(), Duration::from_millis(1)); // Call 2: Read A @@ -1946,32 +1956,35 @@ mod tests { #[test] fn coalesced_reads_dedupe_names() { - let mut cell = ExecCell::new(ExecCall { - call_id: "c1".to_string(), - command: vec!["bash".into(), "-lc".into(), "echo".into()], - parsed: vec![ - ParsedCommand::Read { - name: "auth.rs".into(), - cmd: "cat auth.rs".into(), - path: "auth.rs".into(), - }, - ParsedCommand::Read { - name: "auth.rs".into(), - cmd: "cat auth.rs".into(), - path: "auth.rs".into(), - }, - ParsedCommand::Read { - name: "shimmer.rs".into(), - cmd: "cat shimmer.rs".into(), - path: "shimmer.rs".into(), - }, - ], - output: None, - source: ExecCommandSource::Agent, - start_time: Some(Instant::now()), - duration: None, - interaction_input: None, - }); + let mut cell = ExecCell::new( + ExecCall { + call_id: "c1".to_string(), + command: vec!["bash".into(), "-lc".into(), "echo".into()], + parsed: vec![ + ParsedCommand::Read { + name: "auth.rs".into(), + cmd: "cat auth.rs".into(), + path: "auth.rs".into(), + }, + ParsedCommand::Read { + name: "auth.rs".into(), + cmd: "cat auth.rs".into(), + path: "auth.rs".into(), + }, + ParsedCommand::Read { + name: "shimmer.rs".into(), + cmd: "cat shimmer.rs".into(), + path: "shimmer.rs".into(), + }, + ], + output: None, + source: ExecCommandSource::Agent, + start_time: Some(Instant::now()), + duration: None, + interaction_input: None, + }, + true, + ); cell.complete_call("c1", 
CommandOutput::default(), Duration::from_millis(1)); let lines = cell.display_lines(80); let rendered = render_lines(&lines).join("\n"); @@ -1983,16 +1996,19 @@ mod tests { // Create a completed exec cell with a multiline command let cmd = "set -o pipefail\ncargo test --all-features --quiet".to_string(); let call_id = "c1".to_string(); - let mut cell = ExecCell::new(ExecCall { - call_id: call_id.clone(), - command: vec!["bash".into(), "-lc".into(), cmd], - parsed: Vec::new(), - output: None, - source: ExecCommandSource::Agent, - start_time: Some(Instant::now()), - duration: None, - interaction_input: None, - }); + let mut cell = ExecCell::new( + ExecCall { + call_id: call_id.clone(), + command: vec!["bash".into(), "-lc".into(), cmd], + parsed: Vec::new(), + output: None, + source: ExecCommandSource::Agent, + start_time: Some(Instant::now()), + duration: None, + interaction_input: None, + }, + true, + ); // Mark call complete so it renders as "Ran" cell.complete_call(&call_id, CommandOutput::default(), Duration::from_millis(1)); @@ -2006,16 +2022,19 @@ mod tests { #[test] fn single_line_command_compact_when_fits() { let call_id = "c1".to_string(); - let mut cell = ExecCell::new(ExecCall { - call_id: call_id.clone(), - command: vec!["echo".into(), "ok".into()], - parsed: Vec::new(), - output: None, - source: ExecCommandSource::Agent, - start_time: Some(Instant::now()), - duration: None, - interaction_input: None, - }); + let mut cell = ExecCell::new( + ExecCall { + call_id: call_id.clone(), + command: vec!["echo".into(), "ok".into()], + parsed: Vec::new(), + output: None, + source: ExecCommandSource::Agent, + start_time: Some(Instant::now()), + duration: None, + interaction_input: None, + }, + true, + ); cell.complete_call(&call_id, CommandOutput::default(), Duration::from_millis(1)); // Wide enough that it fits inline let lines = cell.display_lines(80); @@ -2027,16 +2046,19 @@ mod tests { fn single_line_command_wraps_with_four_space_continuation() { let call_id = 
"c1".to_string(); let long = "a_very_long_token_without_spaces_to_force_wrapping".to_string(); - let mut cell = ExecCell::new(ExecCall { - call_id: call_id.clone(), - command: vec!["bash".into(), "-lc".into(), long], - parsed: Vec::new(), - output: None, - source: ExecCommandSource::Agent, - start_time: Some(Instant::now()), - duration: None, - interaction_input: None, - }); + let mut cell = ExecCell::new( + ExecCall { + call_id: call_id.clone(), + command: vec!["bash".into(), "-lc".into(), long], + parsed: Vec::new(), + output: None, + source: ExecCommandSource::Agent, + start_time: Some(Instant::now()), + duration: None, + interaction_input: None, + }, + true, + ); cell.complete_call(&call_id, CommandOutput::default(), Duration::from_millis(1)); let lines = cell.display_lines(24); let rendered = render_lines(&lines).join("\n"); @@ -2047,16 +2069,19 @@ mod tests { fn multiline_command_without_wrap_uses_branch_then_eight_spaces() { let call_id = "c1".to_string(); let cmd = "echo one\necho two".to_string(); - let mut cell = ExecCell::new(ExecCall { - call_id: call_id.clone(), - command: vec!["bash".into(), "-lc".into(), cmd], - parsed: Vec::new(), - output: None, - source: ExecCommandSource::Agent, - start_time: Some(Instant::now()), - duration: None, - interaction_input: None, - }); + let mut cell = ExecCell::new( + ExecCall { + call_id: call_id.clone(), + command: vec!["bash".into(), "-lc".into(), cmd], + parsed: Vec::new(), + output: None, + source: ExecCommandSource::Agent, + start_time: Some(Instant::now()), + duration: None, + interaction_input: None, + }, + true, + ); cell.complete_call(&call_id, CommandOutput::default(), Duration::from_millis(1)); let lines = cell.display_lines(80); let rendered = render_lines(&lines).join("\n"); @@ -2068,16 +2093,19 @@ mod tests { let call_id = "c1".to_string(); let cmd = "first_token_is_long_enough_to_wrap\nsecond_token_is_also_long_enough_to_wrap" .to_string(); - let mut cell = ExecCell::new(ExecCall { - call_id: 
call_id.clone(), - command: vec!["bash".into(), "-lc".into(), cmd], - parsed: Vec::new(), - output: None, - source: ExecCommandSource::Agent, - start_time: Some(Instant::now()), - duration: None, - interaction_input: None, - }); + let mut cell = ExecCell::new( + ExecCall { + call_id: call_id.clone(), + command: vec!["bash".into(), "-lc".into(), cmd], + parsed: Vec::new(), + output: None, + source: ExecCommandSource::Agent, + start_time: Some(Instant::now()), + duration: None, + interaction_input: None, + }, + true, + ); cell.complete_call(&call_id, CommandOutput::default(), Duration::from_millis(1)); let lines = cell.display_lines(28); let rendered = render_lines(&lines).join("\n"); @@ -2089,16 +2117,19 @@ mod tests { // Build an exec cell with a non-zero exit and 10 lines on stderr to exercise // the head/tail rendering and gutter prefixes. let call_id = "c_err".to_string(); - let mut cell = ExecCell::new(ExecCall { - call_id: call_id.clone(), - command: vec!["bash".into(), "-lc".into(), "seq 1 10 1>&2 && false".into()], - parsed: Vec::new(), - output: None, - source: ExecCommandSource::Agent, - start_time: Some(Instant::now()), - duration: None, - interaction_input: None, - }); + let mut cell = ExecCell::new( + ExecCall { + call_id: call_id.clone(), + command: vec!["bash".into(), "-lc".into(), "seq 1 10 1>&2 && false".into()], + parsed: Vec::new(), + output: None, + source: ExecCommandSource::Agent, + start_time: Some(Instant::now()), + duration: None, + interaction_input: None, + }, + true, + ); let stderr: String = (1..=10) .map(|n| n.to_string()) .collect::>() @@ -2136,16 +2167,19 @@ mod tests { let call_id = "c_wrap_err".to_string(); let long_cmd = "echo this_is_a_very_long_single_token_that_will_wrap_across_the_available_width"; - let mut cell = ExecCell::new(ExecCall { - call_id: call_id.clone(), - command: vec!["bash".into(), "-lc".into(), long_cmd.to_string()], - parsed: Vec::new(), - output: None, - source: ExecCommandSource::Agent, - start_time: 
Some(Instant::now()), - duration: None, - interaction_input: None, - }); + let mut cell = ExecCell::new( + ExecCall { + call_id: call_id.clone(), + command: vec!["bash".into(), "-lc".into(), long_cmd.to_string()], + parsed: Vec::new(), + output: None, + source: ExecCommandSource::Agent, + start_time: Some(Instant::now()), + duration: None, + interaction_input: None, + }, + true, + ); let stderr = "error: first line on stderr\nerror: second line on stderr".to_string(); cell.complete_call( diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index ca7cba9c2..33bd18c43 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -86,7 +86,7 @@ mod wrapping; #[cfg(test)] pub mod test_backend; -use crate::onboarding::WSL_INSTRUCTIONS; +use crate::onboarding::TrustDirectorySelection; use crate::onboarding::onboarding_screen::OnboardingScreenArgs; use crate::onboarding::onboarding_screen::run_onboarding_app; use crate::tui::Tui; @@ -389,20 +389,13 @@ async fn run_ratatui_app( ); let login_status = get_login_status(&initial_config); let should_show_trust_screen = should_show_trust_screen(&initial_config); - let should_show_windows_wsl_screen = - cfg!(target_os = "windows") && !initial_config.windows_wsl_setup_acknowledged; - let should_show_onboarding = should_show_onboarding( - login_status, - &initial_config, - should_show_trust_screen, - should_show_windows_wsl_screen, - ); + let should_show_onboarding = + should_show_onboarding(login_status, &initial_config, should_show_trust_screen); let config = if should_show_onboarding { let onboarding_result = run_onboarding_app( OnboardingScreenArgs { show_login_screen: should_show_login_screen(login_status, &initial_config), - show_windows_wsl_screen: should_show_windows_wsl_screen, show_trust_screen: should_show_trust_screen, login_status, auth_manager: auth_manager.clone(), @@ -421,21 +414,12 @@ async fn run_ratatui_app( update_action: None, }); } - if onboarding_result.windows_install_selected { - 
restore(); - session_log::log_session_end(); - let _ = tui.terminal.clear(); - if let Err(err) = writeln!(std::io::stdout(), "{WSL_INSTRUCTIONS}") { - tracing::error!("Failed to write WSL instructions: {err}"); - } - return Ok(AppExitInfo { - token_usage: codex_core::protocol::TokenUsage::default(), - conversation_id: None, - update_action: None, - }); - } - // if the user acknowledged windows or made any trust decision, reload the config accordingly - if should_show_windows_wsl_screen || onboarding_result.directory_trust_decision.is_some() { + // if the user acknowledged windows or made an explicit decision to trust the directory, reload the config accordingly + if onboarding_result + .directory_trust_decision + .map(|d| d == TrustDirectorySelection::Trust) + .unwrap_or(false) + { load_config_or_exit(cli_kv_overrides, overrides).await } else { initial_config @@ -490,6 +474,7 @@ async fn run_ratatui_app( &mut tui, &config.codex_home, &config.model_provider_id, + cli.resume_show_all, ) .await? { @@ -584,7 +569,7 @@ async fn load_config_or_exit( /// show the trust screen. fn should_show_trust_screen(config: &Config) -> bool { if cfg!(target_os = "windows") && get_platform_sandbox().is_none() { - // If the experimental sandbox is not enabled, Native Windows cannot enforce sandboxed write access without WSL; skip the trust prompt entirely. + // If the experimental sandbox is not enabled, Native Windows cannot enforce sandboxed write access; skip the trust prompt entirely. 
return false; } if config.did_user_set_custom_approval_policy_or_sandbox_mode { @@ -599,12 +584,7 @@ fn should_show_onboarding( login_status: LoginStatus, config: &Config, show_trust_screen: bool, - show_windows_wsl_screen: bool, ) -> bool { - if show_windows_wsl_screen { - return true; - } - if show_trust_screen { return true; } @@ -628,7 +608,6 @@ mod tests { use codex_core::config::ConfigOverrides; use codex_core::config::ConfigToml; use codex_core::config::ProjectConfig; - use codex_core::set_windows_sandbox_enabled; use serial_test::serial; use tempfile::TempDir; @@ -643,7 +622,7 @@ mod tests { )?; config.did_user_set_custom_approval_policy_or_sandbox_mode = false; config.active_project = ProjectConfig { trust_level: None }; - set_windows_sandbox_enabled(false); + config.set_windows_sandbox_globally(false); let should_show = should_show_trust_screen(&config); if cfg!(target_os = "windows") { @@ -670,7 +649,7 @@ mod tests { )?; config.did_user_set_custom_approval_policy_or_sandbox_mode = false; config.active_project = ProjectConfig { trust_level: None }; - set_windows_sandbox_enabled(true); + config.set_windows_sandbox_globally(true); let should_show = should_show_trust_screen(&config); if cfg!(target_os = "windows") { diff --git a/codex-rs/tui/src/markdown_render.rs b/codex-rs/tui/src/markdown_render.rs index 099d0860c..19cf94492 100644 --- a/codex-rs/tui/src/markdown_render.rs +++ b/codex-rs/tui/src/markdown_render.rs @@ -10,11 +10,50 @@ use pulldown_cmark::Parser; use pulldown_cmark::Tag; use pulldown_cmark::TagEnd; use ratatui::style::Style; -use ratatui::style::Stylize; use ratatui::text::Line; use ratatui::text::Span; use ratatui::text::Text; +struct MarkdownStyles { + h1: Style, + h2: Style, + h3: Style, + h4: Style, + h5: Style, + h6: Style, + code: Style, + emphasis: Style, + strong: Style, + strikethrough: Style, + ordered_list_marker: Style, + unordered_list_marker: Style, + link: Style, + blockquote: Style, +} + +impl Default for MarkdownStyles { + 
fn default() -> Self { + use ratatui::style::Stylize; + + Self { + h1: Style::new().bold().underlined(), + h2: Style::new().bold(), + h3: Style::new().bold().italic(), + h4: Style::new().italic(), + h5: Style::new().italic(), + h6: Style::new().italic(), + code: Style::new().cyan(), + emphasis: Style::new().italic(), + strong: Style::new().bold(), + strikethrough: Style::new().crossed_out(), + ordered_list_marker: Style::new().light_blue(), + unordered_list_marker: Style::new(), + link: Style::new().cyan().underlined(), + blockquote: Style::new().green(), + } + } +} + #[derive(Clone, Debug)] struct IndentContext { prefix: Vec>, @@ -51,6 +90,7 @@ where { iter: I, text: Text<'static>, + styles: MarkdownStyles, inline_styles: Vec