From e5a0e9e2e917d51e1e21e49f31cc981efa94fa88 Mon Sep 17 00:00:00 2001 From: Rafiuzzaman Khan Date: Fri, 5 Jun 2026 19:27:01 -0400 Subject: [PATCH] unify auth cookie persistence --- main.py | 20 ++-- routers/core/account.py | 125 +++++++++-------------- tests/routers/core/test_account.py | 158 ++++++++++++++++++++++++++++- tests/utils/test_auth.py | 46 +++++++++ tests/utils/test_dependencies.py | 36 ++++++- utils/core/auth.py | 81 +++++++++++++-- utils/core/dependencies.py | 3 +- 7 files changed, 365 insertions(+), 104 deletions(-) diff --git a/main.py b/main.py index 06a2921..cb53b2b 100644 --- a/main.py +++ b/main.py @@ -20,7 +20,7 @@ get_user_from_request, require_unauthenticated_client, ) -from utils.core.auth import COOKIE_SECURE +from utils.core.auth import refresh_token_is_persistent, set_auth_cookies from utils.core.htmx import ( is_htmx_request, toast_response, @@ -170,19 +170,11 @@ async def needs_new_tokens_handler(request: Request, exc: NeedsNewTokens): response = RedirectResponse( url=redirect_url, status_code=status.HTTP_307_TEMPORARY_REDIRECT ) - response.set_cookie( - key="access_token", - value=exc.access_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="strict", - ) - response.set_cookie( - key="refresh_token", - value=exc.refresh_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="strict", + set_auth_cookies( + response, + exc.access_token, + exc.refresh_token, + persistent=refresh_token_is_persistent(exc.refresh_token), ) return response diff --git a/routers/core/account.py b/routers/core/account.py index c9f9934..ee4f047 100644 --- a/routers/core/account.py +++ b/routers/core/account.py @@ -20,7 +20,6 @@ from utils.core.auth import ( HTML_PASSWORD_PATTERN, COMPILED_PASSWORD_PATTERN, - COOKIE_SECURE, MAX_EMAILS_PER_ACCOUNT, oauth2_scheme_cookie, get_password_hash, @@ -28,6 +27,8 @@ create_tracked_refresh_token, revoke_all_refresh_tokens, validate_token, + set_auth_cookies, + clear_auth_cookies, send_reset_email_task, send_email_verification, send_email_verified_notification, @@ -127,8 +128,7 @@ def logout( Log out a user by revoking their refresh token and clearing cookies. """ response = RedirectResponse(url="/", status_code=303) - response.delete_cookie("access_token") - response.delete_cookie("refresh_token") + clear_auth_cookies(response) _, refresh_token_value = tokens if refresh_token_value: @@ -382,7 +382,9 @@ async def register( # Create access token using the committed account's email access_token = create_access_token(data={"sub": account.email, "fresh": True}) - refresh_token = create_tracked_refresh_token(account.id, account.email, session) + refresh_token = create_tracked_refresh_token( + account.id, account.email, session, persistent=False + ) session.commit() # Set cookie — use HX-Redirect for HTMX, 303 for regular form submissions @@ -391,19 +393,11 @@ async def register( response.headers["HX-Redirect"] = str(redirect_url) else: response = RedirectResponse(url=str(redirect_url), status_code=303) - response.set_cookie( - key="access_token", - value=access_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="strict", - ) - response.set_cookie( - key="refresh_token", - value=refresh_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="strict", + set_auth_cookies( + response, + access_token, + refresh_token, + persistent=False, ) return response @@ -417,6 +411,7 @@ async def login( account_and_session: Tuple[Account, Session] = Depends( get_account_from_credentials ), + remember: Optional[str] = Form(None), invitation_token: Optional[str] = Form( None, title="Invitation token", @@ -507,8 +502,11 @@ async def login( # Create access token assert account.id is not None + persistent = remember == "on" access_token = create_access_token(data={"sub": account.email, "fresh": True}) - refresh_token = create_tracked_refresh_token(account.id, account.email, session) + refresh_token = create_tracked_refresh_token( + account.id, account.email, session, persistent=persistent + ) session.commit() # Set cookie — use HX-Redirect for HTMX, 303 for regular form submissions @@ -517,19 +515,11 @@ async def login( response.headers["HX-Redirect"] = str(redirect_url) else: response = RedirectResponse(url=str(redirect_url), status_code=303) - response.set_cookie( - key="access_token", - value=access_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="strict", - ) - response.set_cookie( - key="refresh_token", - value=refresh_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="strict", + set_auth_cookies( + response, + access_token, + refresh_token, + persistent=persistent, ) return response @@ -553,8 +543,7 @@ async def refresh_token( response = RedirectResponse( url=router.url_path_for("read_login"), status_code=303 ) - response.delete_cookie("access_token") - response.delete_cookie("refresh_token") + clear_auth_cookies(response) return response # Validate JTI server-side @@ -563,8 +552,7 @@ async def refresh_token( response = RedirectResponse( url=router.url_path_for("read_login"), status_code=303 ) - response.delete_cookie("access_token") - response.delete_cookie("refresh_token") + clear_auth_cookies(response) return response user_email = decoded_token.get("sub") @@ -591,32 +579,26 @@ async def refresh_token( response = RedirectResponse( url=router.url_path_for("read_login"), status_code=303 ) - response.delete_cookie("access_token") - response.delete_cookie("refresh_token") + clear_auth_cookies(response) return response # Revoke current token and issue new ones db_token.revoked = True + persistent = bool(decoded_token.get("persistent", False)) new_access_token = create_access_token(data={"sub": account.email, "fresh": False}) - new_refresh_token = create_tracked_refresh_token(account.id, account.email, session) + new_refresh_token = create_tracked_refresh_token( + account.id, account.email, session, persistent=persistent + ) session.commit() response = RedirectResponse( url=dashboard_router.url_path_for("read_dashboard"), status_code=303 ) - response.set_cookie( - key="access_token", - value=new_access_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="strict", - ) - response.set_cookie( - key="refresh_token", - value=new_refresh_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="strict", + set_auth_cookies( + response, + new_access_token, + new_refresh_token, + persistent=persistent, ) return response @@ -696,7 +678,7 @@ async def reset_password( data={"sub": authorized_account.email, "fresh": True} ) refresh_token = create_tracked_refresh_token( - authorized_account.id, authorized_account.email, session + authorized_account.id, authorized_account.email, session, persistent=False ) session.commit() @@ -709,19 +691,11 @@ async def reset_password( else: response = RedirectResponse(url=redirect_url, status_code=303) - response.set_cookie( - key="access_token", - value=access_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="strict", - ) - response.set_cookie( - key="refresh_token", - value=refresh_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="strict", + set_auth_cookies( + response, + access_token, + refresh_token, + persistent=False, ) set_flash_cookie(response, message) return response @@ -961,7 +935,9 @@ async def promote_email( # Issue new tokens with the new primary email access_token = create_access_token(data={"sub": account.email, "fresh": True}) - refresh_token = create_tracked_refresh_token(account.id, account.email, session) + refresh_token = create_tracked_refresh_token( + account.id, account.email, session, persistent=False + ) session.commit() # Create recovery token and send notification to the old primary @@ -979,18 +955,11 @@ async def promote_email( else: response = RedirectResponse(url=str(profile_path), status_code=303) set_flash_cookie(response, "Primary email address updated.") - response.set_cookie( - key="access_token", - value=access_token, - httponly=True, - secure=COOKIE_SECURE, - samesite="lax", - ) - response.set_cookie( - key="refresh_token", - value=refresh_token, - httponly=True, - secure=COOKIE_SECURE, + set_auth_cookies( + response, + access_token, + refresh_token, + persistent=False, samesite="lax", ) return response diff --git a/tests/routers/core/test_account.py b/tests/routers/core/test_account.py index f6284a4..003067a 100644 --- a/tests/routers/core/test_account.py +++ b/tests/routers/core/test_account.py @@ -131,6 +131,68 @@ def test_login_endpoint(unauth_client: TestClient, test_account: Account): assert "refresh_token" in cookies +def test_login_with_remember_me_sets_max_age( + unauth_client: TestClient, test_account: Account +) -> None: + response = unauth_client.post( + app.url_path_for("login"), + data={ + "email": test_account.email, + "password": "Test123!@#", + "remember": "on", + }, + ) + assert response.status_code == 303 + cookie_headers = response.headers.get_list("set-cookie") + auth_cookies = [ + header + for header in cookie_headers + if header.startswith("access_token=") or header.startswith("refresh_token=") + ] + assert len(auth_cookies) == 2 + assert all("Max-Age=" in header for header in auth_cookies) + + +def test_login_without_remember_me_uses_session_cookies( + unauth_client: TestClient, test_account: Account +) -> None: + response = unauth_client.post( + app.url_path_for("login"), + data={"email": test_account.email, "password": "Test123!@#"}, + ) + assert response.status_code == 303 + cookie_headers = response.headers.get_list("set-cookie") + auth_cookies = [ + header + for header in cookie_headers + if header.startswith("access_token=") or header.startswith("refresh_token=") + ] + assert len(auth_cookies) == 2 + assert all("Max-Age=" not in header for header in auth_cookies) + + +def test_login_with_non_on_remember_uses_session_cookies( + unauth_client: TestClient, test_account: Account +) -> None: + response = unauth_client.post( + app.url_path_for("login"), + data={ + "email": test_account.email, + "password": "Test123!@#", + "remember": "false", + }, + ) + assert response.status_code == 303 + cookie_headers = response.headers.get_list("set-cookie") + auth_cookies = [ + header + for header in cookie_headers + if header.startswith("access_token=") or header.startswith("refresh_token=") + ] + assert len(auth_cookies) == 2 + assert all("Max-Age=" not in header for header in auth_cookies) + + def test_refresh_token_endpoint(auth_client: TestClient, test_account: Account): # Override just the access token to be expired, keeping the valid refresh token expired_access_token = create_access_token( @@ -161,6 +223,60 @@ def test_refresh_token_endpoint(auth_client: TestClient, test_account: Account): assert decoded["sub"] == test_account.email +def test_refresh_token_endpoint_preserves_persistent_max_age( + session: Session, test_account: Account, test_user: User +) -> None: + refresh_jwt = create_tracked_refresh_token( + test_account.id, test_account.email, session, persistent=True + ) + session.commit() + + client = TestClient(app, follow_redirects=False) + expired_access_token = create_access_token( + {"sub": test_account.email}, timedelta(minutes=-10) + ) + client.cookies.set("access_token", expired_access_token) + client.cookies.set("refresh_token", refresh_jwt) + + response = client.post(app.url_path_for("refresh_token")) + assert response.status_code == 303 + + auth_cookies = [ + header + for header in response.headers.get_list("set-cookie") + if header.startswith("access_token=") or header.startswith("refresh_token=") + ] + assert len(auth_cookies) == 2 + assert all("Max-Age=" in header for header in auth_cookies) + + +def test_refresh_token_endpoint_preserves_session_cookies( + session: Session, test_account: Account, test_user: User +) -> None: + refresh_jwt = create_tracked_refresh_token( + test_account.id, test_account.email, session, persistent=False + ) + session.commit() + + client = TestClient(app, follow_redirects=False) + expired_access_token = create_access_token( + {"sub": test_account.email}, timedelta(minutes=-10) + ) + client.cookies.set("access_token", expired_access_token) + client.cookies.set("refresh_token", refresh_jwt) + + response = client.post(app.url_path_for("refresh_token")) + assert response.status_code == 303 + + auth_cookies = [ + header + for header in response.headers.get_list("set-cookie") + if header.startswith("access_token=") or header.startswith("refresh_token=") + ] + assert len(auth_cookies) == 2 + assert all("Max-Age=" not in header for header in auth_cookies) + + def test_password_reset_flow( unauth_client: TestClient, session: Session, test_account: Account, mock_resend_send ): @@ -799,7 +915,7 @@ def test_refresh_reuse_detection_revokes_all_tokens( """Replaying a revoked refresh token revokes ALL tokens for that account.""" # Create a tracked refresh token and immediately revoke it (simulating prior use) refresh_jwt = create_tracked_refresh_token( - test_account.id, test_account.email, session + test_account.id, test_account.email, session, persistent=True ) session.commit() @@ -866,7 +982,7 @@ def test_automatic_token_refresh_via_dependency( """When access token expires, the dependency auto-refreshes using the refresh token.""" # Create a tracked refresh token refresh_jwt = create_tracked_refresh_token( - test_account.id, test_account.email, session + test_account.id, test_account.email, session, persistent=True ) session.commit() @@ -886,8 +1002,13 @@ def test_automatic_token_refresh_via_dependency( assert response.status_code == 307 cookie_headers = response.headers.get_list("set-cookie") - assert any("access_token=" in c for c in cookie_headers) - assert any("refresh_token=" in c for c in cookie_headers) + auth_cookies = [ + header + for header in cookie_headers + if header.startswith("access_token=") or header.startswith("refresh_token=") + ] + assert len(auth_cookies) == 2 + assert all("Max-Age=" in header for header in auth_cookies) # Old refresh token should be revoked session.expire_all() @@ -901,6 +1022,35 @@ def test_automatic_token_refresh_via_dependency( assert len(active_tokens) == 1 +def test_automatic_token_refresh_preserves_session_cookies( + session: Session, test_account: Account, test_user: User +) -> None: + """Silent rotation should keep session cookies when the refresh token is not persistent.""" + refresh_jwt = create_tracked_refresh_token( + test_account.id, test_account.email, session, persistent=False + ) + session.commit() + + expired_access = create_access_token( + {"sub": test_account.email}, timedelta(minutes=-10) + ) + + client = TestClient(app, follow_redirects=False) + client.cookies.set("access_token", expired_access) + client.cookies.set("refresh_token", refresh_jwt) + + response = client.get(app.url_path_for("read_dashboard")) + assert response.status_code == 307 + + auth_cookies = [ + header + for header in response.headers.get_list("set-cookie") + if header.startswith("access_token=") or header.startswith("refresh_token=") + ] + assert len(auth_cookies) == 2 + assert all("Max-Age=" not in header for header in auth_cookies) + + # --- Add Email Tests --- diff --git a/tests/utils/test_auth.py b/tests/utils/test_auth.py index 9cc80eb..d890447 100644 --- a/tests/utils/test_auth.py +++ b/tests/utils/test_auth.py @@ -4,6 +4,7 @@ from datetime import timedelta from urllib.parse import urlparse, parse_qs from starlette.datastructures import URLPath +from starlette.responses import Response import uuid from main import app from utils.core.auth import ( @@ -15,6 +16,9 @@ generate_password_reset_url, COMPILED_PASSWORD_PATTERN, convert_python_regex_to_html, + auth_cookie_max_ages, + set_auth_cookies, + refresh_token_is_persistent, ) @@ -153,3 +157,45 @@ def test_password_pattern() -> None: # No special character password = "aA1" * 3 assert re.match(COMPILED_PASSWORD_PATTERN, password) is None + + +def test_auth_cookie_max_ages(env_vars) -> None: + session_access, session_refresh = auth_cookie_max_ages(persistent=False) + assert session_access is None + assert session_refresh is None + + persistent_access, persistent_refresh = auth_cookie_max_ages(persistent=True) + assert persistent_access == 30 * 60 + assert persistent_refresh == 30 * 24 * 60 * 60 + + +def test_set_auth_cookies_persistent(env_vars) -> None: + response = Response() + set_auth_cookies(response, "access", "refresh", persistent=True) + headers = response.headers.getlist("set-cookie") + assert len(headers) == 2 + assert all("Max-Age=" in header for header in headers) + + +def test_set_auth_cookies_session(env_vars) -> None: + response = Response() + set_auth_cookies(response, "access", "refresh", persistent=False) + headers = response.headers.getlist("set-cookie") + assert len(headers) == 2 + assert all("Max-Age=" not in header for header in headers) + + +def test_refresh_token_is_persistent(env_vars) -> None: + jti = str(uuid.uuid4()) + persistent_token = create_refresh_token( + {"sub": "test@example.com", "persistent": True}, + jti=jti, + expires_delta=timedelta(days=30), + ) + assert refresh_token_is_persistent(persistent_token) is True + + session_token = create_refresh_token( + {"sub": "test@example.com", "persistent": False}, + jti=str(uuid.uuid4()), + ) + assert refresh_token_is_persistent(session_token) is False diff --git a/tests/utils/test_dependencies.py b/tests/utils/test_dependencies.py index aeb40c6..d6f6f17 100644 --- a/tests/utils/test_dependencies.py +++ b/tests/utils/test_dependencies.py @@ -83,7 +83,41 @@ def test_validate_token_and_get_account() -> None: assert refresh_token == "new_refresh_token" assert mock_db_token.revoked is True mock_tracked_refresh.assert_called_once_with( - 1, "test@example.com", session + 1, "test@example.com", session, persistent=False + ) + + # Test refresh rotation preserves persistent=True from the old token + with patch("utils.core.dependencies.validate_token") as mock_validate: + with patch("utils.core.dependencies.create_access_token") as mock_access_token: + with patch( + "utils.core.dependencies.create_tracked_refresh_token" + ) as mock_tracked_refresh: + mock_validate.return_value = { + "sub": "test@example.com", + "type": "refresh", + "jti": "test-jti", + "persistent": True, + } + mock_access_token.return_value = "new_access_token" + mock_tracked_refresh.return_value = "new_refresh_token" + + mock_db_token = MagicMock() + mock_db_token.account_id = 1 + mock_db_token.revoked = False + + session.exec.return_value.first.side_effect = [ + mock_account, + mock_db_token, + ] + + account, access_token, refresh_token = validate_token_and_get_account( + "valid_token", "refresh", session + ) + assert account == mock_account + assert access_token == "new_access_token" + assert refresh_token == "new_refresh_token" + mock_tracked_refresh.assert_called_once_with( + 1, "test@example.com", session, persistent=True ) # Test with refresh token missing JTI (legacy token) diff --git a/utils/core/auth.py b/utils/core/auth.py index 1a6bbda..f631d81 100644 --- a/utils/core/auth.py +++ b/utils/core/auth.py @@ -8,10 +8,11 @@ from sqlmodel import Session, select from bcrypt import gensalt, hashpw, checkpw from datetime import UTC, datetime, timedelta -from typing import Optional +from typing import Literal, Optional from jinja2.environment import Template from fastapi.templating import Jinja2Templates from fastapi import Cookie +from starlette.responses import Response from utils.core.db import create_engine, get_connection_url from utils.core.models import ( AccountRecoveryToken, @@ -34,6 +35,10 @@ ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 REFRESH_TOKEN_EXPIRE_DAYS = 30 +SESSION_REFRESH_TOKEN_EXPIRE_HOURS = 12 + +ACCESS_TOKEN_COOKIE_NAME = "access_token" +REFRESH_TOKEN_COOKIE_NAME = "refresh_token" PASSWORD_PATTERN_COMPONENTS = [ r"(?=.*\d)", # At least one digit r"(?=.*[a-z])", # At least one lowercase letter @@ -89,12 +94,62 @@ def replacer(match: re.Match) -> str: # Define the oauth2 scheme to get the token from the cookie def oauth2_scheme_cookie( - access_token: Optional[str] = Cookie(None, alias="access_token"), - refresh_token: Optional[str] = Cookie(None, alias="refresh_token"), + access_token: Optional[str] = Cookie(None, alias=ACCESS_TOKEN_COOKIE_NAME), + refresh_token: Optional[str] = Cookie(None, alias=REFRESH_TOKEN_COOKIE_NAME), ) -> tuple[Optional[str], Optional[str]]: return access_token, refresh_token +def auth_cookie_max_ages(*, persistent: bool) -> tuple[Optional[int], Optional[int]]: + """Return (access_max_age, refresh_max_age) in seconds; None means session cookie.""" + if not persistent: + return None, None + return ( + ACCESS_TOKEN_EXPIRE_MINUTES * 60, + REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, + ) + + +def set_auth_cookies( + response: Response, + access_token: str, + refresh_token: str, + *, + persistent: bool, + samesite: Literal["lax", "strict", "none"] = "strict", +) -> None: + """Set httponly auth cookies with session or persistent lifetime.""" + access_max_age, refresh_max_age = auth_cookie_max_ages(persistent=persistent) + response.set_cookie( + key=ACCESS_TOKEN_COOKIE_NAME, + value=access_token, + httponly=True, + secure=COOKIE_SECURE, + samesite=samesite, + max_age=access_max_age, + ) + response.set_cookie( + key=REFRESH_TOKEN_COOKIE_NAME, + value=refresh_token, + httponly=True, + secure=COOKIE_SECURE, + samesite=samesite, + max_age=refresh_max_age, + ) + + +def clear_auth_cookies(response: Response) -> None: + response.delete_cookie(ACCESS_TOKEN_COOKIE_NAME) + response.delete_cookie(REFRESH_TOKEN_COOKIE_NAME) + + +def refresh_token_is_persistent(refresh_token: str) -> bool: + decoded = validate_token(refresh_token, token_type="refresh") + if decoded is None: + return False + return bool(decoded.get("persistent", False)) + + def get_password_hash(password: str) -> str: """ Hash a password using bcrypt with a random salt @@ -140,16 +195,30 @@ def create_refresh_token( return encoded_jwt -def create_tracked_refresh_token(account_id: int, email: str, session: Session) -> str: +def create_tracked_refresh_token( + account_id: int, + email: str, + session: Session, + *, + persistent: bool = False, +) -> str: jti = str(uuid.uuid4()) - expires_at = datetime.now(UTC) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + if persistent: + expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + else: + expires_delta = timedelta(hours=SESSION_REFRESH_TOKEN_EXPIRE_HOURS) + expires_at = datetime.now(UTC) + expires_delta db_token = RefreshToken( account_id=account_id, jti=jti, expires_at=expires_at, ) session.add(db_token) - token = create_refresh_token(data={"sub": email}, jti=jti) + token = create_refresh_token( + data={"sub": email, "persistent": persistent}, + jti=jti, + expires_delta=expires_delta, + ) return token diff --git a/utils/core/dependencies.py b/utils/core/dependencies.py index 17ec682..971c12e 100644 --- a/utils/core/dependencies.py +++ b/utils/core/dependencies.py @@ -97,9 +97,10 @@ def validate_token_and_get_account( # Revoke the current token and issue new ones db_token.revoked = True + persistent = bool(decoded_token.get("persistent", False)) new_access_token = create_access_token(data={"sub": account.email}) new_refresh_token = create_tracked_refresh_token( - account.id, account.email, session + account.id, account.email, session, persistent=persistent ) session.commit() return account, new_access_token, new_refresh_token