Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions agent/tools/sandbox_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@

_SANDBOX_SERVER = '''\
"""Minimal FastAPI server for sandbox operations."""
import os, subprocess, pathlib, signal, threading, re, tempfile
from fastapi import FastAPI
import os, subprocess, pathlib, signal, threading, re, tempfile, hmac
from fastapi import FastAPI, HTTPException, Header, Depends
from pydantic import BaseModel
from typing import Optional
import uvicorn
Expand Down Expand Up @@ -154,7 +154,19 @@ def _atomic_write(path: pathlib.Path, content: str):
except OSError:
pass

app = FastAPI()
_AUTH_TOKEN = os.environ.get("HF_TOKEN", "")

def require_auth(authorization: Optional[str] = Header(None)):
# Fail closed: if the sandbox secret isn't set, every request 401s
# rather than silently becoming open again.
if not _AUTH_TOKEN:
raise HTTPException(status_code=401, detail="unauthorized")
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="unauthorized")
if not hmac.compare_digest(authorization[len("Bearer "):], _AUTH_TOKEN):
raise HTTPException(status_code=401, detail="unauthorized")

app = FastAPI(dependencies=[Depends(require_auth)])

# Track active bash processes so they can be killed on cancel
_active_procs = {} # pid -> subprocess.Popen
Expand Down Expand Up @@ -516,7 +528,7 @@ def create(
name: str | None = None,
template: str = TEMPLATE_SPACE,
hardware: str = "cpu-basic",
private: bool = False,
private: bool = True,
sleep_time: int | None = None,
token: str | None = None,
secrets: dict[str, str] | None = None,
Expand Down Expand Up @@ -670,6 +682,17 @@ def _wait_for_api(self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], ob
if resp.status_code == 200:
log(f"API is responsive at {self._base_url}")
return
# A reachable server that rejects auth will keep returning
# 401/403 for the full timeout — fail fast with a clear message
# instead of looping until TimeoutError hides the real cause.
if resp.status_code in (401, 403):
raise RuntimeError(
f"Sandbox API at {self._base_url} rejected auth "
f"(HTTP {resp.status_code}). Check that HF_TOKEN is set "
f"as a Space secret and matches the client token."
)
except RuntimeError:
raise
except Exception as e:
last_err = e
time.sleep(3)
Expand Down
2 changes: 1 addition & 1 deletion agent/tools/sandbox_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ async def _watch_cancel():
},
"private": {
"type": "boolean",
"description": "If true, create a private Space",
"description": "Whether the Space is private (default: true). Set false to create a public Space.",
},
},
},
Expand Down
102 changes: 102 additions & 0 deletions tests/unit/test_sandbox_server_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Tests for the embedded sandbox FastAPI server's bearer-token auth (issue #78)."""

import importlib.util
import subprocess

from fastapi.testclient import TestClient

from agent.tools.sandbox_client import _SANDBOX_SERVER


def _load_server(tmp_path, monkeypatch, token):
"""Write the embedded server source to disk and importlib-load it.

Module-level `_AUTH_TOKEN` is bound at import time from `os.environ`, so
`monkeypatch.setenv` before import is what makes each test isolated.
"""
monkeypatch.setenv("HF_TOKEN", token)
path = tmp_path / "sandbox_server.py"
path.write_text(_SANDBOX_SERVER)
spec = importlib.util.spec_from_file_location("sandbox_server_under_test", str(path))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module


def test_missing_authorization_header_rejects(tmp_path, monkeypatch):
mod = _load_server(tmp_path, monkeypatch, "secret-xyz")
client = TestClient(mod.app)
assert client.get("/api/health").status_code == 401


def test_bearer_wrong_token_rejects(tmp_path, monkeypatch):
mod = _load_server(tmp_path, monkeypatch, "secret-xyz")
client = TestClient(mod.app)
r = client.get("/api/health", headers={"Authorization": "Bearer wrong"})
assert r.status_code == 401


def test_bearer_correct_token_passes(tmp_path, monkeypatch):
mod = _load_server(tmp_path, monkeypatch, "secret-xyz")
client = TestClient(mod.app)
r = client.get("/api/health", headers={"Authorization": "Bearer secret-xyz"})
assert r.status_code == 200
assert r.json() == {"status": "ok"}


def test_bash_unauthenticated_never_executes(tmp_path, monkeypatch):
"""/api/bash must 401 before subprocess.Popen is invoked."""
mod = _load_server(tmp_path, monkeypatch, "secret-xyz")

def _fail(*_a, **_kw):
raise AssertionError("subprocess.Popen invoked without auth")

monkeypatch.setattr(subprocess, "Popen", _fail)
client = TestClient(mod.app)
r = client.post(
"/api/bash",
headers={"Authorization": "Bearer wrong"},
json={"command": "id", "work_dir": "/app", "timeout": 10},
)
assert r.status_code == 401


def test_fail_closed_when_hf_token_unset(tmp_path, monkeypatch):
"""With no HF_TOKEN in the env, every request must 401 — including ones
that present an empty Bearer value."""
mod = _load_server(tmp_path, monkeypatch, "")
client = TestClient(mod.app)
assert client.get("/api/health").status_code == 401
r = client.get("/api/health", headers={"Authorization": "Bearer "})
assert r.status_code == 401


def test_write_endpoint_also_protected(tmp_path, monkeypatch):
"""Spot-check that POST routes beyond /api/bash are covered by the
app-wide dependency (write/edit/read/kill/exists all share it)."""
mod = _load_server(tmp_path, monkeypatch, "secret-xyz")
client = TestClient(mod.app)
target = tmp_path / "should_not_exist.txt"
r = client.post(
"/api/write",
headers={"Authorization": "Bearer wrong"},
json={"path": str(target), "content": "pwned"},
)
assert r.status_code == 401
assert not target.exists()


def test_bash_with_valid_auth_executes(tmp_path, monkeypatch):
"""Positive-path check: with the correct Bearer, /api/bash actually runs
the command and returns its output. Balances the auth-only negative tests."""
mod = _load_server(tmp_path, monkeypatch, "secret-xyz")
client = TestClient(mod.app)
r = client.post(
"/api/bash",
headers={"Authorization": "Bearer secret-xyz"},
json={"command": "echo hello-sandbox", "work_dir": str(tmp_path), "timeout": 10},
)
assert r.status_code == 200
payload = r.json()
assert payload["success"] is True
assert "hello-sandbox" in payload["output"]