Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion agentic_security/middleware/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ def setup_cors(app: FastAPI):
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
28 changes: 26 additions & 2 deletions agentic_security/routes/static.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from pathlib import Path

import requests
Expand All @@ -13,6 +14,11 @@
STATIC_DIR = Path(__file__).parent.parent / "static"
ICONS_DIR = STATIC_DIR / "icons"

# Strict allowlist for icon filenames: lowercase/uppercase alphanumerics, dots,
# underscores, and hyphens, must end in .png. Rejects path separators,
# percent-encoded characters, and non-PNG extensions before any FS/network I/O.
ICON_NAME_RE = re.compile(r"[A-Za-z0-9._-]+\.png")

# Configure templates with custom delimiters to avoid conflicts
templates = Jinja2Templates(directory=str(STATIC_DIR))
templates.env = Environment(
Expand Down Expand Up @@ -96,8 +102,26 @@ async def favicon() -> FileResponse:

@router.get("/icons/{icon_name}")
async def serve_icon(icon_name: str) -> FileResponse:
"""Serve an icon from the icons directory."""
icon_path = ICONS_DIR / icon_name
r"""Serve an icon from the icons directory.

``icon_name`` is validated against a strict allowlist before any filesystem
or outbound-HTTP access:

* Must match ``^[A-Za-z0-9._-]+\.png$`` — rejects path separators,
percent-encoded characters, non-PNG extensions, and empty names.
* The resolved path must stay inside ``ICONS_DIR`` — defense-in-depth
against any future URL-handling change that could decode ``%2F``.

Mitigates CWE-22 (path traversal) on both the local write and the
upstream npmmirror fetch.
"""
if not ICON_NAME_RE.fullmatch(icon_name):
raise HTTPException(status_code=400, detail="Invalid icon name")

icon_path = (ICONS_DIR / icon_name).resolve()
if not icon_path.is_relative_to(ICONS_DIR.resolve()):
raise HTTPException(status_code=400, detail="Invalid icon name")

if not icon_path.exists():
# Fetch the icon from the external URL and cache it
url = f"https://registry.npmmirror.com/@lobehub/icons-static-png/latest/files/dark/{icon_name}"
Expand Down
109 changes: 109 additions & 0 deletions tests/integration/routes/test_static_icon_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Tests for the /icons/{icon_name} allowlist and path-containment guard (CWE-22)."""

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from agentic_security.routes.static import ICON_NAME_RE, router

_app = FastAPI()
_app.include_router(router)
_client = TestClient(_app, raise_server_exceptions=False)


# ---------------------------------------------------------------------------
# ICON_NAME_RE unit tests — no HTTP round-trip needed
# ---------------------------------------------------------------------------


class TestIconNameRegex:
"""ICON_NAME_RE rejects all names that could cause path traversal or SSRF."""

@pytest.mark.parametrize(
"name",
[
"openai.png",
"claude-3.png",
"gpt_4.png",
"Anthropic.png",
"my.icon.png",
"test-icon-123.png",
"A.png",
],
)
def test_valid_names_match(self, name: str):
assert ICON_NAME_RE.fullmatch(name), f"Expected {name!r} to be accepted"

@pytest.mark.parametrize(
"name",
[
"../etc/passwd", # classic path traversal
"../../secret.png", # multi-level traversal
"foo%2Fbar.png", # percent-encoded slash
"foo/bar.png", # literal slash
"icon.PNG", # uppercase extension (not in the dataset)
"icon.jpg", # wrong extension
"icon.gif", # wrong extension
".png", # empty stem
"", # empty string
"icon.png.sh", # double extension — shell script suffix
"\x00icon.png", # null byte
"icon.png\n", # trailing newline (defeats $-anchored match)
],
)
def test_invalid_names_rejected(self, name: str):
assert not ICON_NAME_RE.fullmatch(name), f"Expected {name!r} to be rejected"


# ---------------------------------------------------------------------------
# Integration tests — HTTP-level validation via TestClient
# ---------------------------------------------------------------------------


class TestServeIconValidation:
"""serve_icon returns 400 for names that fail the allowlist check."""

@pytest.mark.parametrize(
"bad_name",
[
"no-extension",
"icon.jpg",
"icon.PNG",
".png",
"icon.png.sh",
],
)
def test_invalid_name_returns_400(self, bad_name: str):
"""Names that fail ICON_NAME_RE get a 400 before any FS/network I/O."""
response = _client.get(f"/icons/{bad_name}")
assert (
response.status_code == 400
), f"Expected 400 for {bad_name!r}, got {response.status_code}"
assert response.json().get("detail") == "Invalid icon name"

def test_valid_name_does_not_return_400(self, tmp_path, mocker):
"""A well-formed name passes validation; 404 is acceptable when the
upstream fetch also fails, but 400 must not be returned."""
mocker.patch("agentic_security.routes.static.ICONS_DIR", tmp_path)

# Stub the external HTTP call so the test is hermetic
mock_resp = mocker.MagicMock()
mock_resp.status_code = 404
mocker.patch(
"agentic_security.routes.static.requests.get", return_value=mock_resp
)

response = _client.get("/icons/openai.png")
assert (
response.status_code != 400
), "A valid icon name should not be rejected by the allowlist check"

def test_valid_name_served_from_cache(self, tmp_path, mocker):
"""When the icon file already exists locally it is served directly."""
mocker.patch("agentic_security.routes.static.ICONS_DIR", tmp_path)
icon_file = tmp_path / "openai.png"
icon_file.write_bytes(b"\x89PNG\r\n\x1a\n") # minimal PNG magic bytes

response = _client.get("/icons/openai.png")
assert response.status_code == 200
assert response.headers.get("content-type", "").startswith("image/png")
94 changes: 94 additions & 0 deletions tests/unit/test_cors_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Unit tests for CORS middleware configuration.

Verifies that the wildcard-origins + allow_credentials=True spec violation
(CORS spec §3.2, Fetch §4.7) has been removed. Browsers silently strip
credentials when the response carries Access-Control-Allow-Origin: * paired
with Access-Control-Allow-Credentials: true, so the old config was both
broken and misleading.
"""

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient

from agentic_security.middleware.cors import setup_cors


def _get_cors_options(app: FastAPI) -> dict:
"""Extract CORS middleware options from the app's middleware stack."""
for middleware in app.user_middleware:
if middleware.cls is CORSMiddleware:
return middleware.kwargs
return {}


class TestCorsSetup:
"""CORS middleware is configured correctly."""

def test_cors_middleware_is_registered(self):
"""setup_cors adds CORSMiddleware to the app."""
app = FastAPI()
setup_cors(app)
cls_names = [m.cls.__name__ for m in app.user_middleware]
assert "CORSMiddleware" in cls_names

def test_wildcard_origins_without_credentials(self):
"""allow_origins=['*'] must not be paired with allow_credentials=True.

The combination is forbidden by the CORS spec and causes browsers to
silently drop credentials on every cross-origin request.
"""
app = FastAPI()
setup_cors(app)
opts = _get_cors_options(app)
allow_origins = opts.get("allow_origins", [])
allow_credentials = opts.get("allow_credentials", False)

if "*" in allow_origins or allow_origins == ["*"]:
assert not allow_credentials, (
"allow_origins=['*'] with allow_credentials=True is invalid per "
"the CORS spec — browsers reject it and credentials are silently dropped"
)

def test_cors_allows_cross_origin_requests(self):
"""Cross-origin preflight requests return a 200 with CORS headers."""
app = FastAPI()

@app.get("/probe")
async def probe():
return {"ok": True}

setup_cors(app)
client = TestClient(app, raise_server_exceptions=True)

response = client.options(
"/probe",
headers={
"Origin": "http://localhost:3000",
"Access-Control-Request-Method": "GET",
},
)
assert response.status_code == 200
assert "access-control-allow-origin" in response.headers

def test_cors_no_credentials_header_with_wildcard(self):
"""With wildcard origins, the response must not include
Access-Control-Allow-Credentials: true."""
app = FastAPI()

@app.get("/probe")
async def probe():
return {"ok": True}

setup_cors(app)
client = TestClient(app)
response = client.get("/probe", headers={"Origin": "http://evil.example.com"})

acao = response.headers.get("access-control-allow-origin", "")
acac = response.headers.get("access-control-allow-credentials", "false")

if acao == "*":
assert acac.lower() != "true", (
"Wildcard ACAO + ACAC:true is a spec violation (RFC 6454 §7.2, "
"Fetch §4.7) and silently breaks credentialed cross-origin requests"
)