Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions packages/prime-sandboxes/src/prime_sandboxes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class Sandbox(BaseModel):
gpu_type: Optional[str] = Field(None, alias="gpuType")
vm: bool = False
network_access: bool = Field(True, alias="networkAccess")
allowed_domains: List[str] = Field(default_factory=list, alias="allowedDomains")
blocked_domains: List[str] = Field(default_factory=list, alias="blockedDomains")
status: str
timeout_minutes: int = Field(..., alias="timeoutMinutes")
idle_timeout_minutes: Optional[int] = Field(None, alias="idleTimeoutMinutes")
Expand Down Expand Up @@ -91,6 +93,8 @@ class CreateSandboxRequest(BaseModel):
gpu_type: Optional[str] = None
vm: bool = False
network_access: bool = True
allowed_domains: List[str] = Field(default_factory=list)
blocked_domains: List[str] = Field(default_factory=list)
timeout_minutes: int = 60
idle_timeout_minutes: Optional[int] = None
environment_vars: Optional[Dict[str, str]] = None
Expand Down Expand Up @@ -119,6 +123,30 @@ def validate_guaranteed(self) -> "CreateSandboxRequest":
raise ValueError("guaranteed is not supported for VM sandboxes")
return self

@model_validator(mode="after")
def validate_allowed_domains(self) -> "CreateSandboxRequest":
if self.allowed_domains:
if self.network_access:
raise ValueError(
"allowed_domains requires network_access=false "
"(it is an egress allowlist for restricted sandboxes)"
)
if self.vm:
raise ValueError("allowed_domains is not supported for VM sandboxes")
return self

@model_validator(mode="after")
def validate_blocked_domains(self) -> "CreateSandboxRequest":
if self.blocked_domains:
if not self.network_access:
raise ValueError(
"blocked_domains requires network_access=true "
"(it is an egress blocklist for unrestricted sandboxes)"
)
if self.vm:
raise ValueError("blocked_domains is not supported for VM sandboxes")
return self

@model_validator(mode="after")
def validate_idle_timeout(self) -> "CreateSandboxRequest":
if self.idle_timeout_minutes is None:
Expand Down
79 changes: 74 additions & 5 deletions packages/prime/src/prime_cli/commands/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def _format_sandbox_for_details(sandbox: Sandbox) -> Dict[str, Any]:
"gpu_type": getattr(sandbox, "gpu_type", None),
"vm": sandbox.vm,
"network_access": sandbox.network_access,
"allowed_domains": getattr(sandbox, "allowed_domains", []) or [],
"blocked_domains": getattr(sandbox, "blocked_domains", []) or [],
"timeout_minutes": sandbox.timeout_minutes,
# Read with getattr so an older installed prime-sandboxes wheel
# (without these fields) still renders the details view instead of
Expand Down Expand Up @@ -413,6 +415,10 @@ def get(
style="green" if sandbox_data["network_access"] else "yellow",
)
table.add_row("Network Access", network_display)
if sandbox_data.get("allowed_domains"):
table.add_row("Allowed Domains", ", ".join(sandbox_data["allowed_domains"]))
if sandbox_data.get("blocked_domains"):
table.add_row("Blocked Domains", ", ".join(sandbox_data["blocked_domains"]))
table.add_row("Timeout (minutes)", str(sandbox_data["timeout_minutes"]))
if sandbox_data.get("idle_timeout_minutes") is not None:
table.add_row("Idle Timeout (minutes)", str(sandbox_data["idle_timeout_minutes"]))
Expand Down Expand Up @@ -503,6 +509,24 @@ def create(
"--network-access/--no-network-access",
help="Allow outbound internet access (enabled by default)",
),
allowed_domains: Optional[List[str]] = typer.Option(
None,
"--allowed-domain",
help=(
"Egress domain allowlist for a restricted sandbox. "
"Wildcards like '*.example.com' are allowed. Requires "
"--no-network-access. Can be specified multiple times."
),
),
blocked_domains: Optional[List[str]] = typer.Option(
None,
"--blocked-domain",
help=(
"Egress domain blocklist for an unrestricted sandbox. "
"Wildcards like '*.example.com' are allowed. Requires "
"--network-access (the default). Can be specified multiple times."
),
),
timeout_minutes: int = typer.Option(60, help="Timeout in minutes"),
idle_timeout_minutes: Optional[int] = typer.Option(
None,
Expand Down Expand Up @@ -600,6 +624,28 @@ def create(
)
raise typer.Exit(1)

if allowed_domains:
if network_access:
console.print(
"[red]--allowed-domain requires --no-network-access.[/red] "
"It is an egress allowlist for restricted sandboxes."
)
raise typer.Exit(1)
if vm:
console.print("[red]--allowed-domain is not supported for VM sandboxes.[/red]")
raise typer.Exit(1)

if blocked_domains:
if not network_access:
console.print(
"[red]--blocked-domain requires --network-access.[/red] "
"It is an egress blocklist for unrestricted sandboxes."
)
raise typer.Exit(1)
if vm:
console.print("[red]--blocked-domain is not supported for VM sandboxes.[/red]")
raise typer.Exit(1)

if idle_timeout_minutes is not None:
if idle_timeout_minutes < 1:
console.print("[red]--idle-timeout-minutes must be at least 1.[/red]")
Expand Down Expand Up @@ -638,13 +684,32 @@ def create(
suffix = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
name = f"{base_name}-{suffix}"

# Only forward idle_timeout_minutes if the installed prime-sandboxes
# SDK actually defines the field; older wheels would silently drop it
# via Pydantic's extra="ignore" default, hiding the misconfiguration
# from the user. SDK version floor is bumped in a follow-up release PR.
# Only forward fields if the installed prime-sandboxes SDK actually
# defines them
request_kwargs: Dict[str, Any] = {}
request_model_fields = CreateSandboxRequest.model_fields
if allowed_domains:
if "allowed_domains" not in request_model_fields:
console.print(
"[red]Installed prime-sandboxes SDK does not support "
"--allowed-domain.[/red] Upgrade prime-sandboxes before "
"using sandbox egress allowlists."
)
raise typer.Exit(1)
if blocked_domains:
if "blocked_domains" not in request_model_fields:
console.print(
"[red]Installed prime-sandboxes SDK does not support "
"--blocked-domain.[/red] Upgrade prime-sandboxes before "
"using sandbox egress blocklists."
)
raise typer.Exit(1)
if "allowed_domains" in request_model_fields:
request_kwargs["allowed_domains"] = allowed_domains if allowed_domains else []
if "blocked_domains" in request_model_fields:
request_kwargs["blocked_domains"] = blocked_domains if blocked_domains else []
if idle_timeout_minutes is not None:
if "idle_timeout_minutes" in CreateSandboxRequest.model_fields:
if "idle_timeout_minutes" in request_model_fields:
request_kwargs["idle_timeout_minutes"] = idle_timeout_minutes
else:
console.print(
Expand Down Expand Up @@ -687,6 +752,10 @@ def create(
console.print(f"GPUs: {gpu_type} x{gpu_count}")
network_status = "[green]Enabled[/green]" if network_access else "[yellow]Disabled[/yellow]"
console.print(f"Network Access: {network_status}")
if request_kwargs.get("allowed_domains"):
console.print(f"Allowed Domains: {', '.join(request_kwargs['allowed_domains'])}")
if request_kwargs.get("blocked_domains"):
console.print(f"Blocked Domains: {', '.join(request_kwargs['blocked_domains'])}")
console.print(f"Timeout: {timeout_minutes} minutes")
# Only show the idle timeout in the summary when the SDK actually
# accepted it; otherwise we'd display a value the backend never sees.
Expand Down
113 changes: 113 additions & 0 deletions packages/prime/tests/test_sandbox_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any

import pytest
from prime_cli.commands import sandbox as sandbox_cmd
from prime_cli.commands.sandbox import _format_sandbox_expiry
from prime_cli.main import app
from prime_cli.utils import strip_ansi
Expand Down Expand Up @@ -170,6 +171,118 @@ def mock_create(self: Any, request: Any) -> Any:
assert captured["request"].region == "eu-west"


def test_sandbox_create_forwards_allowed_domains(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("PRIME_API_KEY", "dummy")
monkeypatch.setenv("PRIME_DISABLE_VERSION_CHECK", "1")

captured: dict[str, Any] = {}

def mock_create(self: Any, request: Any) -> Any:
captured["request"] = request
return SimpleNamespace(id="sbx-allowed-domains")

monkeypatch.setattr("prime_cli.commands.sandbox.SandboxClient.create", mock_create)

result = runner.invoke(
app,
[
"sandbox",
"create",
"python:3.11-slim",
"--no-network-access",
"--allowed-domain",
"api.github.com",
"--allowed-domain",
"*.pypi.org",
"--yes",
],
)

output = strip_ansi(result.output)
assert result.exit_code == 0, f"Failed: {result.output}"
assert "Allowed Domains: api.github.com, *.pypi.org" in output
assert captured["request"].network_access is False
assert captured["request"].allowed_domains == ["api.github.com", "*.pypi.org"]


def test_sandbox_create_forwards_blocked_domains(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("PRIME_API_KEY", "dummy")
monkeypatch.setenv("PRIME_DISABLE_VERSION_CHECK", "1")

captured: dict[str, Any] = {}

def mock_create(self: Any, request: Any) -> Any:
captured["request"] = request
return SimpleNamespace(id="sbx-blocked-domains")

monkeypatch.setattr("prime_cli.commands.sandbox.SandboxClient.create", mock_create)

result = runner.invoke(
app,
[
"sandbox",
"create",
"python:3.11-slim",
"--blocked-domain",
"github.com",
"--blocked-domain",
"*.tracker.test",
"--yes",
],
)

output = strip_ansi(result.output)
assert result.exit_code == 0, f"Failed: {result.output}"
assert "Blocked Domains: github.com, *.tracker.test" in output
assert captured["request"].network_access is True
assert captured["request"].blocked_domains == ["github.com", "*.tracker.test"]


@pytest.mark.parametrize(
("flag_args", "missing_field", "unsupported_flag"),
[
(
["--no-network-access", "--allowed-domain", "api.github.com"],
"allowed_domains",
"--allowed-domain",
),
(["--blocked-domain", "github.com"], "blocked_domains", "--blocked-domain"),
],
)
def test_sandbox_create_domain_flags_require_sdk_field(
monkeypatch: pytest.MonkeyPatch,
flag_args: list[str],
missing_field: str,
unsupported_flag: str,
) -> None:
monkeypatch.setenv("PRIME_API_KEY", "dummy")
monkeypatch.setenv("PRIME_DISABLE_VERSION_CHECK", "1")

request_fields = dict(sandbox_cmd.CreateSandboxRequest.model_fields)
request_fields.pop(missing_field, None)
monkeypatch.setattr(sandbox_cmd.CreateSandboxRequest, "model_fields", request_fields)

called = False

def mock_create(self: Any, request: Any) -> Any:
nonlocal called
called = True
return SimpleNamespace(id="sbx-should-not-create")

monkeypatch.setattr("prime_cli.commands.sandbox.SandboxClient.create", mock_create)

result = runner.invoke(
app,
["sandbox", "create", "python:3.11-slim", *flag_args, "--yes"],
)

output = strip_ansi(result.output)
assert result.exit_code == 1
assert f"does not support {unsupported_flag}" in output
assert "Sandbox Configuration:" not in output
assert called is False


def test_sandbox_create_requires_gpu_type(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("PRIME_API_KEY", "dummy")
monkeypatch.setenv("PRIME_DISABLE_VERSION_CHECK", "1")
Expand Down
Loading