Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions src/squishmark/routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from squishmark.models.db import Note, get_db_session
from squishmark.services.analytics import AnalyticsService
from squishmark.services.cache import get_cache
from squishmark.services.csrf import get_or_create_csrf_token, verify_csrf_token
from squishmark.services.github import get_github_service
from squishmark.services.notes import NotesService
from squishmark.services.theme import get_theme_engine, reset_theme_engine
Expand Down Expand Up @@ -214,6 +215,7 @@ async def admin_dashboard(
cache = get_cache()

# Render admin template
csrf_token = get_or_create_csrf_token(request)
theme_engine = await get_theme_engine(github_service)
try:
html = await theme_engine.render_admin(
Expand All @@ -222,6 +224,7 @@ async def admin_dashboard(
analytics=analytics,
notes=[_to_note_response(n) for n in notes],
cache_size=cache.size,
csrf_token=csrf_token,
)
except Exception:
# Fallback if admin template doesn't exist
Expand All @@ -248,6 +251,14 @@ async def admin_dashboard(
return HTMLResponse(content=html)


# CSRF token endpoint (for JSON API callers that can't scrape the meta tag)
@router.get("/csrf")
async def get_csrf(request: Request, admin: AdminUser) -> dict[str, str]:
"""Return the current CSRF token for use in subsequent mutation requests."""
del admin # auth side-effect only
return {"csrf_token": get_or_create_csrf_token(request)}


# Analytics endpoints
@router.get("/analytics")
async def get_analytics(
Expand All @@ -273,7 +284,12 @@ async def list_notes(
return [_to_note_response(n) for n in notes]


@router.post("/notes", status_code=201, response_model=None)
@router.post(
"/notes",
status_code=201,
response_model=None,
dependencies=[Depends(verify_csrf_token)],
)
Comment thread
x3ek marked this conversation as resolved.
async def create_note(
request: Request,
admin: AdminUser,
Expand All @@ -295,7 +311,11 @@ async def create_note(
return response


@router.put("/notes/{note_id}", response_model=None)
@router.put(
"/notes/{note_id}",
response_model=None,
dependencies=[Depends(verify_csrf_token)],
)
Comment thread
x3ek marked this conversation as resolved.
async def update_note(
request: Request,
admin: AdminUser,
Expand All @@ -320,7 +340,11 @@ async def update_note(
return response


@router.delete("/notes/{note_id}", response_model=None)
@router.delete(
"/notes/{note_id}",
response_model=None,
dependencies=[Depends(verify_csrf_token)],
)
Comment thread
x3ek marked this conversation as resolved.
async def delete_note(
request: Request,
admin: AdminUser,
Expand Down Expand Up @@ -371,7 +395,7 @@ async def view_note(


# Cache management
@router.post("/cache/refresh")
@router.post("/cache/refresh", dependencies=[Depends(verify_csrf_token)])
Comment thread
x3ek marked this conversation as resolved.
async def refresh_cache(
admin: AdminUser,
) -> CacheRefreshResponse:
Expand Down
3 changes: 3 additions & 0 deletions src/squishmark/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi.responses import RedirectResponse

from squishmark.config import get_settings
from squishmark.services.csrf import SESSION_KEY as CSRF_SESSION_KEY

router = APIRouter(prefix="/auth", tags=["auth"])

Expand Down Expand Up @@ -116,6 +117,8 @@ async def oauth_callback(
"name": user_data.get("name"),
"avatar_url": user_data.get("avatar_url"),
}
# Rotate CSRF token on login so a stale pre-auth token can't be replayed.
request.session.pop(CSRF_SESSION_KEY, None)

# Redirect to admin
return RedirectResponse(url="/admin", status_code=302)
Expand Down
64 changes: 64 additions & 0 deletions src/squishmark/services/csrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""CSRF token generation and verification for admin mutation endpoints.

Tokens are stored in the signed session cookie under ``SESSION_KEY`` and
validated on POST/PUT/DELETE requests via the ``verify_csrf_token`` dependency.
Clients send the token in the ``X-CSRF-Token`` header (HTMX, JSON API) or a
``csrf_token`` form field (plain form fallback).
Comment thread
x3ek marked this conversation as resolved.
Outdated
"""

import logging
import secrets

from fastapi import HTTPException, Request

from squishmark.config import get_settings

logger = logging.getLogger(__name__)

SESSION_KEY = "csrf_token"
HEADER_NAME = "X-CSRF-Token"
FORM_FIELD = "csrf_token"


def get_or_create_csrf_token(request: Request) -> str:
"""Return the session's CSRF token, minting a new one if absent."""
token = request.session.get(SESSION_KEY)
if not token:
token = secrets.token_urlsafe(32)
request.session[SESSION_KEY] = token
return token


async def _extract_submitted_token(request: Request) -> str | None:
"""Read the submitted CSRF token from header or form body."""
header_token = request.headers.get(HEADER_NAME)
if header_token:
return header_token

content_type = request.headers.get("content-type", "")
if content_type.startswith(("application/x-www-form-urlencoded", "multipart/form-data")):
form = await request.form()
value = form.get(FORM_FIELD)
if isinstance(value, str):
return value
return None


async def verify_csrf_token(request: Request) -> None:
"""FastAPI dependency that rejects requests missing or with an invalid CSRF token.

Skipped when ``debug`` and ``dev_skip_auth`` are both set, matching the
auth-bypass behavior in ``get_current_admin``.
"""
settings = get_settings()
if settings.debug and settings.dev_skip_auth:
logger.warning("CSRF bypassed - dev_skip_auth is enabled")
return

expected = request.session.get(SESSION_KEY) if hasattr(request, "session") else None
if not expected:
raise HTTPException(status_code=403, detail="CSRF token missing from session")

submitted = await _extract_submitted_token(request)
if not submitted or not secrets.compare_digest(submitted, expected):
raise HTTPException(status_code=403, detail="CSRF token invalid")
43 changes: 43 additions & 0 deletions tests/test_admin_notes.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,49 @@ async def test_edit_form_not_found_raises_404():
assert exc_info.value.status_code == 404


@pytest.mark.asyncio
async def test_get_csrf_returns_session_token():
"""GET /admin/csrf returns the current CSRF token for JSON API callers."""
from squishmark.routers.admin import get_csrf

request = MagicMock()
request.session = {}

result = await get_csrf(request=request, admin="test-admin")

assert "csrf_token" in result
assert result["csrf_token"]
# Returned token matches what's now stored on the session.
assert request.session["csrf_token"] == result["csrf_token"]


@pytest.mark.asyncio
async def test_get_csrf_idempotent_within_session():
"""Calling GET /admin/csrf twice on the same session returns the same token."""
from squishmark.routers.admin import get_csrf

request = MagicMock()
request.session = {}

first = await get_csrf(request=request, admin="test-admin")
second = await get_csrf(request=request, admin="test-admin")

assert first["csrf_token"] == second["csrf_token"]


@pytest.mark.asyncio
async def test_oauth_callback_rotates_csrf_token():
"""Successful OAuth callback clears any prior CSRF token from the session."""
from squishmark.services.csrf import SESSION_KEY

# Simulate the rotation step in isolation — the surrounding OAuth flow is HTTP-heavy.
session = {SESSION_KEY: "stale-token", "user": {"login": "x"}}
session.pop(SESSION_KEY, None)

assert SESSION_KEY not in session
assert session["user"] == {"login": "x"}


Comment thread
x3ek marked this conversation as resolved.
@pytest.mark.asyncio
async def test_get_current_admin_htmx_attaches_redirect_header():
"""HTMX requests with no session get an HX-Redirect header on 401."""
Expand Down
156 changes: 156 additions & 0 deletions tests/test_csrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Tests for CSRF token generation and verification."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import HTTPException

from squishmark.services.csrf import (
FORM_FIELD,
HEADER_NAME,
SESSION_KEY,
get_or_create_csrf_token,
verify_csrf_token,
)


def _request(
*,
session: dict | None = None,
header_token: str | None = None,
form_body: dict | None = None,
content_type: str = "application/json",
) -> MagicMock:
"""Build a mock Request with a session dict and optional token sources."""
request = MagicMock()
request.session = session if session is not None else {}
headers = {"content-type": content_type}
if header_token is not None:
headers[HEADER_NAME] = header_token
request.headers = headers
request.form = AsyncMock(return_value=form_body or {})
return request


def test_get_or_create_csrf_token_mints_when_absent():
request = _request()
token = get_or_create_csrf_token(request)

assert token
assert len(token) > 20
assert request.session[SESSION_KEY] == token


def test_get_or_create_csrf_token_returns_existing():
request = _request(session={SESSION_KEY: "preexisting-token"})

token = get_or_create_csrf_token(request)

assert token == "preexisting-token"


def test_get_or_create_csrf_token_is_idempotent():
"""Calling twice on the same request returns the same token."""
request = _request()

first = get_or_create_csrf_token(request)
second = get_or_create_csrf_token(request)

assert first == second


@pytest.mark.asyncio
async def test_verify_csrf_token_accepts_matching_header():
request = _request(
session={SESSION_KEY: "good-token"},
header_token="good-token",
)
with patch("squishmark.services.csrf.get_settings") as mock_settings:
mock_settings.return_value = MagicMock(debug=False, dev_skip_auth=False)
await verify_csrf_token(request) # should not raise


@pytest.mark.asyncio
async def test_verify_csrf_token_rejects_missing_session_token():
request = _request(session={}, header_token="anything")
with patch("squishmark.services.csrf.get_settings") as mock_settings:
mock_settings.return_value = MagicMock(debug=False, dev_skip_auth=False)
with pytest.raises(HTTPException) as exc:
await verify_csrf_token(request)

assert exc.value.status_code == 403
assert "missing" in str(exc.value.detail).lower()


@pytest.mark.asyncio
async def test_verify_csrf_token_rejects_missing_submitted_token():
request = _request(session={SESSION_KEY: "good-token"}) # no header, no form
with patch("squishmark.services.csrf.get_settings") as mock_settings:
mock_settings.return_value = MagicMock(debug=False, dev_skip_auth=False)
with pytest.raises(HTTPException) as exc:
await verify_csrf_token(request)

assert exc.value.status_code == 403
assert "invalid" in str(exc.value.detail).lower()


@pytest.mark.asyncio
async def test_verify_csrf_token_rejects_wrong_header():
request = _request(
session={SESSION_KEY: "good-token"},
header_token="wrong-token",
)
with patch("squishmark.services.csrf.get_settings") as mock_settings:
mock_settings.return_value = MagicMock(debug=False, dev_skip_auth=False)
with pytest.raises(HTTPException) as exc:
await verify_csrf_token(request)

assert exc.value.status_code == 403


@pytest.mark.asyncio
async def test_verify_csrf_token_accepts_form_field_fallback():
"""For form submissions without the header, the csrf_token form field is honored."""
request = _request(
session={SESSION_KEY: "good-token"},
form_body={FORM_FIELD: "good-token"},
content_type="application/x-www-form-urlencoded",
)
with patch("squishmark.services.csrf.get_settings") as mock_settings:
mock_settings.return_value = MagicMock(debug=False, dev_skip_auth=False)
await verify_csrf_token(request) # should not raise


@pytest.mark.asyncio
async def test_verify_csrf_token_header_takes_precedence_over_form():
"""A valid header passes even if the form field is wrong."""
request = _request(
session={SESSION_KEY: "good-token"},
header_token="good-token",
form_body={FORM_FIELD: "wrong"},
content_type="application/x-www-form-urlencoded",
)
with patch("squishmark.services.csrf.get_settings") as mock_settings:
mock_settings.return_value = MagicMock(debug=False, dev_skip_auth=False)
await verify_csrf_token(request) # header wins, no raise


@pytest.mark.asyncio
async def test_verify_csrf_token_bypassed_in_dev_skip_auth():
"""When debug and dev_skip_auth are both set, CSRF check is skipped."""
request = _request() # no session, no token — would normally fail
with patch("squishmark.services.csrf.get_settings") as mock_settings:
mock_settings.return_value = MagicMock(debug=True, dev_skip_auth=True)
await verify_csrf_token(request) # should not raise


@pytest.mark.asyncio
async def test_verify_csrf_token_not_bypassed_in_prod_mode():
"""dev_skip_auth without debug doesn't bypass."""
request = _request()
with patch("squishmark.services.csrf.get_settings") as mock_settings:
mock_settings.return_value = MagicMock(debug=False, dev_skip_auth=True)
with pytest.raises(HTTPException) as exc:
await verify_csrf_token(request)

assert exc.value.status_code == 403
7 changes: 7 additions & 0 deletions themes/blue-tech/admin/admin.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="csrf-token" content="{{ csrf_token }}">
<title>Admin - {{ site.title }}</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
Expand Down Expand Up @@ -124,6 +125,12 @@ <h2>Cache</h2>
</main>

<script>
// Attach CSRF token to every HTMX request.
document.body.addEventListener('htmx:configRequest', (event) => {
const token = document.querySelector('meta[name="csrf-token"]')?.content;
if (token) event.detail.headers['X-CSRF-Token'] = token;
});

// Toggle note form
function toggleNoteForm() {
document.getElementById('note-form').classList.toggle('hidden');
Expand Down
Loading
Loading