diff --git a/tests/conftest.py b/tests/conftest.py index 7f9f058a..6faedbde 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -314,7 +314,7 @@ def wrapper(*args, **kwargs): # Create mock JWKS client mock_jwks = Mock(spec=PyJWKClient) mock_signing_key = Mock() - mock_signing_key.key = kwargs["TEST_CONSTANTS"]["PUBLIC_KEY"] + mock_signing_key.key = kwargs["session_constants"]["PUBLIC_KEY"] mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key # Apply the mock diff --git a/tests/test_session.py b/tests/test_session.py index 254c9cf0..b2fb654e 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone from tests.conftest import with_jwks_mock -from workos.session import Session +from workos.session import AsyncSession, Session from workos.types.user_management.authentication_response import ( RefreshTokenAuthenticationResponse, ) @@ -14,358 +14,482 @@ RefreshWithSessionCookieErrorResponse, RefreshWithSessionCookieSuccessResponse, ) -from workos.types.user_management.user import User from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa -@pytest.fixture(scope="session") -def TEST_CONSTANTS(): - # Generate RSA key pair for testing - private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) - - public_key = private_key.public_key() - - # Get the private key in PEM format - private_pem = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - - current_datetime = datetime.now(timezone.utc) - current_timestamp = str(current_datetime) - - token_claims = { - "sid": "session_123", - "org_id": "organization_123", - "role": "admin", - "permissions": ["read"], - "entitlements": ["feature_1"], - "exp": int(current_datetime.timestamp()) + 3600, - "iat": int(current_datetime.timestamp()), - } - - user_id = "user_123" - - return { - "COOKIE_PASSWORD": "pfSqwTFXUTGEBBD1RQh2kt/oNJYxBgaoZan4Z8sMrKU=", - "SESSION_DATA": "session_data", - "CLIENT_ID": "client_123", - "USER_ID": user_id, - "SESSION_ID": "session_123", - "ORGANIZATION_ID": "organization_123", - "CURRENT_DATETIME": current_datetime, - "CURRENT_TIMESTAMP": current_timestamp, - "PRIVATE_KEY": private_pem, - "PUBLIC_KEY": public_key, - "TEST_TOKEN": jwt.encode(token_claims, private_pem, algorithm="RS256"), - "TEST_TOKEN_CLAIMS": token_claims, - "TEST_USER": { - "object": "user", - "id": user_id, - "email": "user@example.com", - "first_name": "Test", - "last_name": "User", - "email_verified": True, - "created_at": current_timestamp, - "updated_at": current_timestamp, - }, - } - - -@pytest.fixture -def mock_user_management(): - mock = Mock() - mock.get_jwks_url.return_value = ( - "https://api.workos.com/user_management/sso/jwks/client_123" - ) - - return mock - - -@with_jwks_mock -def test_initialize_session_module(TEST_CONSTANTS, mock_user_management): - session = Session( - user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data=TEST_CONSTANTS["SESSION_DATA"], - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], - ) - - assert session.client_id == TEST_CONSTANTS["CLIENT_ID"] - assert session.cookie_password is not None - - -@with_jwks_mock -def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management): - with pytest.raises(ValueError, match="cookie_password is required"): - Session( +class SessionFixtures: + @pytest.fixture + def session_constants(self): + # Generate RSA key pair for testing + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + public_key = private_key.public_key() + + # Get the private key in PEM format + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + current_datetime = datetime.now(timezone.utc) + current_timestamp = str(current_datetime) + + token_claims = { + "sid": "session_123", + "org_id": "organization_123", + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + "exp": int(current_datetime.timestamp()) + 3600, + "iat": int(current_datetime.timestamp()), + } + + user_id = "user_123" + + return { + "COOKIE_PASSWORD": "pfSqwTFXUTGEBBD1RQh2kt/oNJYxBgaoZan4Z8sMrKU=", + "SESSION_DATA": "session_data", + "CLIENT_ID": "client_123", + "USER_ID": user_id, + "SESSION_ID": "session_123", + "ORGANIZATION_ID": "organization_123", + "CURRENT_DATETIME": current_datetime, + "CURRENT_TIMESTAMP": current_timestamp, + "PRIVATE_KEY": private_pem, + "PUBLIC_KEY": public_key, + "TEST_TOKEN": jwt.encode(token_claims, private_pem, algorithm="RS256"), + "TEST_TOKEN_CLAIMS": token_claims, + "TEST_USER": { + "object": "user", + "id": user_id, + "email": "user@example.com", + "first_name": "Test", + "last_name": "User", + "email_verified": True, + "created_at": current_timestamp, + "updated_at": current_timestamp, + }, + } + + @pytest.fixture + def mock_user_management(self): + mock = Mock() + mock.get_jwks_url.return_value = ( + "https://api.workos.com/user_management/sso/jwks/client_123" + ) + + return mock + + +class TestSessionBase(SessionFixtures): + @with_jwks_mock + def test_initialize_session_module(self, session_constants, mock_user_management): + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_constants["SESSION_DATA"], + cookie_password=session_constants["COOKIE_PASSWORD"], + ) + + assert session.client_id == session_constants["CLIENT_ID"] + assert session.cookie_password is not None + + @with_jwks_mock + def test_initialize_without_cookie_password( + self, session_constants, mock_user_management + ): + with pytest.raises(ValueError, match="cookie_password is required"): + Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_constants["SESSION_DATA"], + cookie_password="", + ) + + @with_jwks_mock + def test_authenticate_no_session_cookie_provided( + self, session_constants, mock_user_management + ): + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data="", + cookie_password=session_constants["COOKIE_PASSWORD"], + ) + + response = session.authenticate() + + assert response.authenticated is False + assert ( + response.reason + == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED + ) + + @with_jwks_mock + def test_authenticate_invalid_session_cookie( + self, session_constants, mock_user_management + ): + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data="invalid_session_data", + cookie_password=session_constants["COOKIE_PASSWORD"], + ) + + response = session.authenticate() + + assert response.authenticated is False + assert ( + response.reason + == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + ) + + @with_jwks_mock + def test_authenticate_invalid_jwt(self, session_constants, mock_user_management): + invalid_session_data = Session.seal_data( + {"access_token": "invalid_session_data"}, + session_constants["COOKIE_PASSWORD"], + ) + session = Session( user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data=TEST_CONSTANTS["SESSION_DATA"], - cookie_password="", + client_id=session_constants["CLIENT_ID"], + session_data=invalid_session_data, + cookie_password=session_constants["COOKIE_PASSWORD"], ) + response = session.authenticate() + assert response.authenticated is False + assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT -@with_jwks_mock -def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management): - session = Session( - user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data=None, - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], - ) - - response = session.authenticate() - - assert ( - response.reason - == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED - ) - - -@with_jwks_mock -def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): - session = Session( - user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data="invalid_session_data", - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], - ) - - response = session.authenticate() - - assert ( - response.reason - == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE - ) - - -@with_jwks_mock -def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management): - invalid_session_data = Session.seal_data( - {"access_token": "invalid_session_data"}, TEST_CONSTANTS["COOKIE_PASSWORD"] - ) - session = Session( - user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data=invalid_session_data, - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], - ) - - response = session.authenticate() - - assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT - - -@with_jwks_mock -def test_authenticate_jwt_with_aud_claim(TEST_CONSTANTS, mock_user_management): - access_token = jwt.encode( - {**TEST_CONSTANTS["TEST_TOKEN_CLAIMS"], **{"aud": TEST_CONSTANTS["CLIENT_ID"]}}, - TEST_CONSTANTS["PRIVATE_KEY"], - algorithm="RS256", - ) - - session_data = Session.seal_data( - {"access_token": access_token, "user": TEST_CONSTANTS["TEST_USER"]}, - TEST_CONSTANTS["COOKIE_PASSWORD"], - ) - session = Session( - user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data=session_data, - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], - ) - - response = session.authenticate() - - assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse) - - -@with_jwks_mock -def test_authenticate_success(TEST_CONSTANTS, mock_user_management): - session = Session( - user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data=TEST_CONSTANTS["SESSION_DATA"], - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], - ) - - # Mock the session data that would be unsealed - mock_session = { - "access_token": jwt.encode( + @with_jwks_mock + def test_authenticate_jwt_with_aud_claim( + self, session_constants, mock_user_management + ): + access_token = jwt.encode( { - "sid": TEST_CONSTANTS["SESSION_ID"], - "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], - "role": "admin", - "permissions": ["read"], - "entitlements": ["feature_1"], - "exp": int(datetime.now(timezone.utc).timestamp()) + 3600, - "iat": int(datetime.now(timezone.utc).timestamp()), + **session_constants["TEST_TOKEN_CLAIMS"], + **{"aud": session_constants["CLIENT_ID"]}, }, - TEST_CONSTANTS["PRIVATE_KEY"], + session_constants["PRIVATE_KEY"], algorithm="RS256", - ), - "user": { - "object": "user", - "id": TEST_CONSTANTS["USER_ID"], - "email": "user@example.com", - "email_verified": True, - "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], - "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], - }, - "impersonator": None, - } - - # Mock the JWT payload that would be decoded - mock_jwt_payload = { - "sid": TEST_CONSTANTS["SESSION_ID"], - "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], - "role": "admin", - "permissions": ["read"], - "entitlements": ["feature_1"], - } - - with patch.object(Session, "unseal_data", return_value=mock_session), patch.object( - session, "_is_valid_jwt", return_value=True - ), patch("jwt.decode", return_value=mock_jwt_payload), patch.object( - session.jwks, - "get_signing_key_from_jwt", - return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]), - ): + ) + + session_data = Session.seal_data( + {"access_token": access_token, "user": session_constants["TEST_USER"]}, + session_constants["COOKIE_PASSWORD"], + ) + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_data, + cookie_password=session_constants["COOKIE_PASSWORD"], + ) + response = session.authenticate() assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse) - assert response.authenticated is True - assert response.session_id == TEST_CONSTANTS["SESSION_ID"] - assert response.organization_id == TEST_CONSTANTS["ORGANIZATION_ID"] - assert response.role == "admin" - assert response.permissions == ["read"] - assert response.entitlements == ["feature_1"] - assert response.user.id == TEST_CONSTANTS["USER_ID"] - assert response.impersonator is None - - -@with_jwks_mock -def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): - session = Session( - user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data="invalid_session_data", - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], - ) - - response = session.refresh() - - assert isinstance(response, RefreshWithSessionCookieErrorResponse) - assert ( - response.reason - == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE - ) - - -@with_jwks_mock -def test_refresh_success(TEST_CONSTANTS, mock_user_management): - session_data = Session.seal_data( - {"refresh_token": "refresh_token_12345", "user": TEST_CONSTANTS["TEST_USER"]}, - TEST_CONSTANTS["COOKIE_PASSWORD"], - ) - - mock_response = { - "access_token": TEST_CONSTANTS["TEST_TOKEN"], - "refresh_token": "refresh_token_123", - "sealed_session": session_data, - "user": TEST_CONSTANTS["TEST_USER"], - } - - mock_user_management.authenticate_with_refresh_token.return_value = ( - RefreshTokenAuthenticationResponse(**mock_response) - ) - - session = Session( - user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data=session_data, - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], - ) - - with patch.object(session, "_is_valid_jwt", return_value=True) as _: - with patch( - "jwt.decode", - return_value={ - "sid": TEST_CONSTANTS["SESSION_ID"], - "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], - "role": "admin", - "permissions": ["read"], - "entitlements": ["feature_1"], + + @with_jwks_mock + def test_authenticate_success(self, session_constants, mock_user_management): + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_constants["SESSION_DATA"], + cookie_password=session_constants["COOKIE_PASSWORD"], + ) + + # Mock the session data that would be unsealed + mock_session = { + "access_token": jwt.encode( + { + "sid": session_constants["SESSION_ID"], + "org_id": session_constants["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + "exp": int(datetime.now(timezone.utc).timestamp()) + 3600, + "iat": int(datetime.now(timezone.utc).timestamp()), + }, + session_constants["PRIVATE_KEY"], + algorithm="RS256", + ), + "user": { + "object": "user", + "id": session_constants["USER_ID"], + "email": "user@example.com", + "email_verified": True, + "created_at": session_constants["CURRENT_TIMESTAMP"], + "updated_at": session_constants["CURRENT_TIMESTAMP"], }, + "impersonator": None, + } + + # Mock the JWT payload that would be decoded + mock_jwt_payload = { + "sid": session_constants["SESSION_ID"], + "org_id": session_constants["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + } + + with patch.object( + Session, "unseal_data", return_value=mock_session + ), patch.object(session, "_is_valid_jwt", return_value=True), patch( + "jwt.decode", return_value=mock_jwt_payload + ), patch.object( + session.jwks, + "get_signing_key_from_jwt", + return_value=Mock(key=session_constants["PUBLIC_KEY"]), ): - response = session.refresh() + response = session.authenticate() - assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse) assert response.authenticated is True - assert response.user.id == TEST_CONSTANTS["TEST_USER"]["id"] + assert response.session_id == session_constants["SESSION_ID"] + assert response.organization_id == session_constants["ORGANIZATION_ID"] + assert response.role == "admin" + assert response.permissions == ["read"] + assert response.entitlements == ["feature_1"] + assert response.user.id == session_constants["USER_ID"] + assert response.impersonator is None + + @with_jwks_mock + def test_refresh_invalid_session_cookie( + self, session_constants, mock_user_management + ): + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data="invalid_session_data", + cookie_password=session_constants["COOKIE_PASSWORD"], + ) + + response = session.refresh() + + assert isinstance(response, RefreshWithSessionCookieErrorResponse) + assert ( + response.reason + == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + ) + + def test_seal_data(self, session_constants): + test_data = {"test": "data"} + sealed = Session.seal_data(test_data, session_constants["COOKIE_PASSWORD"]) + assert isinstance(sealed, str) - # Verify the refresh token was used correctly - mock_user_management.authenticate_with_refresh_token.assert_called_once_with( - refresh_token="refresh_token_12345", - organization_id=None, - session={ - "seal_session": True, - "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"], - }, - ) + # Test unsealing + unsealed = Session.unseal_data(sealed, session_constants["COOKIE_PASSWORD"]) + assert unsealed == test_data -@with_jwks_mock -def test_refresh_success_with_aud_claim(TEST_CONSTANTS, mock_user_management): - session_data = Session.seal_data( - {"refresh_token": "refresh_token_12345", "user": TEST_CONSTANTS["TEST_USER"]}, - TEST_CONSTANTS["COOKIE_PASSWORD"], - ) + def test_unseal_invalid_data(self, session_constants): + with pytest.raises( + Exception + ): # Adjust exception type based on your implementation + Session.unseal_data( + "invalid_sealed_data", session_constants["COOKIE_PASSWORD"] + ) - access_token = jwt.encode( - {**TEST_CONSTANTS["TEST_TOKEN_CLAIMS"], **{"aud": TEST_CONSTANTS["CLIENT_ID"]}}, - TEST_CONSTANTS["PRIVATE_KEY"], - algorithm="RS256", - ) - mock_response = { - "access_token": access_token, - "refresh_token": "refresh_token_123", - "sealed_session": session_data, - "user": TEST_CONSTANTS["TEST_USER"], - } +class TestSession(SessionFixtures): + @with_jwks_mock + def test_refresh_success(self, session_constants, mock_user_management): + session_data = Session.seal_data( + { + "refresh_token": "refresh_token_12345", + "user": session_constants["TEST_USER"], + }, + session_constants["COOKIE_PASSWORD"], + ) + + mock_response = { + "access_token": session_constants["TEST_TOKEN"], + "refresh_token": "refresh_token_123", + "sealed_session": session_data, + "user": session_constants["TEST_USER"], + } + + mock_user_management.authenticate_with_refresh_token.return_value = ( + RefreshTokenAuthenticationResponse(**mock_response) + ) + + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_data, + cookie_password=session_constants["COOKIE_PASSWORD"], + ) + + with patch.object(session, "_is_valid_jwt", return_value=True) as _: + with patch( + "jwt.decode", + return_value={ + "sid": session_constants["SESSION_ID"], + "org_id": session_constants["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + }, + ): + response = session.refresh() + + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.user.id == session_constants["TEST_USER"]["id"] + + # Verify the refresh token was used correctly + mock_user_management.authenticate_with_refresh_token.assert_called_once_with( + refresh_token="refresh_token_12345", + organization_id=None, + session={ + "seal_session": True, + "cookie_password": session_constants["COOKIE_PASSWORD"], + }, + ) + + @with_jwks_mock + def test_refresh_success_with_aud_claim( + self, session_constants, mock_user_management + ): + session_data = Session.seal_data( + { + "refresh_token": "refresh_token_12345", + "user": session_constants["TEST_USER"], + }, + session_constants["COOKIE_PASSWORD"], + ) + + access_token = jwt.encode( + { + **session_constants["TEST_TOKEN_CLAIMS"], + **{"aud": session_constants["CLIENT_ID"]}, + }, + session_constants["PRIVATE_KEY"], + algorithm="RS256", + ) + + mock_response = { + "access_token": access_token, + "refresh_token": "refresh_token_123", + "sealed_session": session_data, + "user": session_constants["TEST_USER"], + } + + mock_user_management.authenticate_with_refresh_token.return_value = ( + RefreshTokenAuthenticationResponse(**mock_response) + ) + + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_data, + cookie_password=session_constants["COOKIE_PASSWORD"], + ) + + response = session.refresh() + + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) - mock_user_management.authenticate_with_refresh_token.return_value = ( - RefreshTokenAuthenticationResponse(**mock_response) - ) - session = Session( - user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data=session_data, - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], - ) +class TestAsyncSession(SessionFixtures): + @with_jwks_mock + async def test_refresh_success(self, session_constants, mock_user_management): + session_data = AsyncSession.seal_data( + { + "refresh_token": "refresh_token_12345", + "user": session_constants["TEST_USER"], + }, + session_constants["COOKIE_PASSWORD"], + ) + + mock_response = { + "access_token": session_constants["TEST_TOKEN"], + "refresh_token": "refresh_token_123", + "sealed_session": session_data, + "user": session_constants["TEST_USER"], + } + + mock_user_management.authenticate_with_refresh_token.return_value = ( + RefreshTokenAuthenticationResponse(**mock_response) + ) + + session = AsyncSession( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_data, + cookie_password=session_constants["COOKIE_PASSWORD"], + ) - response = session.refresh() + with patch.object(session, "_is_valid_jwt", return_value=True) as _: + with patch( + "jwt.decode", + return_value={ + "sid": session_constants["SESSION_ID"], + "org_id": session_constants["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + }, + ): + response = await session.refresh() + + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.user.id == session_constants["TEST_USER"]["id"] + + # Verify the refresh token was used correctly + mock_user_management.authenticate_with_refresh_token.assert_called_once_with( + refresh_token="refresh_token_12345", + organization_id=None, + session={ + "seal_session": True, + "cookie_password": session_constants["COOKIE_PASSWORD"], + }, + ) - assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + @with_jwks_mock + async def test_refresh_success_with_aud_claim( + self, session_constants, mock_user_management + ): + session_data = AsyncSession.seal_data( + { + "refresh_token": "refresh_token_12345", + "user": session_constants["TEST_USER"], + }, + session_constants["COOKIE_PASSWORD"], + ) + access_token = jwt.encode( + { + **session_constants["TEST_TOKEN_CLAIMS"], + **{"aud": session_constants["CLIENT_ID"]}, + }, + session_constants["PRIVATE_KEY"], + algorithm="RS256", + ) -def test_seal_data(TEST_CONSTANTS): - test_data = {"test": "data"} - sealed = Session.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"]) - assert isinstance(sealed, str) + mock_response = { + "access_token": access_token, + "refresh_token": "refresh_token_123", + "sealed_session": session_data, + "user": session_constants["TEST_USER"], + } - # Test unsealing - unsealed = Session.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"]) + mock_user_management.authenticate_with_refresh_token.return_value = ( + RefreshTokenAuthenticationResponse(**mock_response) + ) - assert unsealed == test_data + session = AsyncSession( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_data, + cookie_password=session_constants["COOKIE_PASSWORD"], + ) + response = await session.refresh() -def test_unseal_invalid_data(TEST_CONSTANTS): - with pytest.raises(Exception): # Adjust exception type based on your implementation - Session.unseal_data("invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"]) + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) diff --git a/workos/session.py b/workos/session.py index abee954c..3a105081 100644 --- a/workos/session.py +++ b/workos/session.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Protocol import json from typing import Any, Dict, Optional, Union, cast @@ -7,9 +7,6 @@ from jwt import PyJWKClient from cryptography.fernet import Fernet -from workos.types.user_management.authentication_response import ( - RefreshTokenAuthenticationResponse, -) from workos.types.user_management.session import ( AuthenticateWithSessionCookieFailureReason, AuthenticateWithSessionCookieSuccessResponse, @@ -17,12 +14,21 @@ RefreshWithSessionCookieErrorResponse, RefreshWithSessionCookieSuccessResponse, ) +from workos.typing.sync_or_async import SyncOrAsync if TYPE_CHECKING: from workos.user_management import UserManagementModule + from workos.user_management import AsyncUserManagement, UserManagement + +class SessionModule(Protocol): + user_management: "UserManagementModule" + client_id: str + session_data: str + cookie_password: str + jwks: PyJWKClient + jwk_algorithms: List[str] -class Session: def __init__( self, *, @@ -96,6 +102,86 @@ def authenticate( impersonator=session.get("impersonator", None), ) + def refresh( + self, + *, + organization_id: Optional[str] = None, + cookie_password: Optional[str] = None, + ) -> SyncOrAsync[ + Union[ + RefreshWithSessionCookieSuccessResponse, + RefreshWithSessionCookieErrorResponse, + ] + ]: ... + + def get_logout_url(self, return_to: Optional[str] = None) -> str: + auth_response = self.authenticate() + + if isinstance(auth_response, AuthenticateWithSessionCookieErrorResponse): + raise ValueError( + f"Failed to extract session ID for logout URL: {auth_response.reason}" + ) + + result = self.user_management.get_logout_url( + session_id=auth_response.session_id, + return_to=return_to, + ) + return str(result) + + def _is_valid_jwt(self, token: str) -> bool: + try: + signing_key = self.jwks.get_signing_key_from_jwt(token) + jwt.decode( + token, + signing_key.key, + algorithms=self.jwk_algorithms, + options={"verify_aud": False}, + ) + return True + except jwt.exceptions.InvalidTokenError: + return False + + @staticmethod + def seal_data(data: Dict[str, Any], key: str) -> str: + fernet = Fernet(key) + # Encrypt and convert bytes to string + encrypted_bytes = fernet.encrypt(json.dumps(data).encode()) + return encrypted_bytes.decode("utf-8") + + @staticmethod + def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]: + fernet = Fernet(key) + # Convert string back to bytes before decryption + encrypted_bytes = sealed_data.encode("utf-8") + decrypted_str = fernet.decrypt(encrypted_bytes).decode() + return cast(Dict[str, Any], json.loads(decrypted_str)) + + +class Session(SessionModule): + user_management: "UserManagement" + + def __init__( + self, + *, + user_management: "UserManagement", + client_id: str, + session_data: str, + cookie_password: str, + ) -> None: + # If the cookie password is not provided, throw an error + if cookie_password is None or cookie_password == "": + raise ValueError("cookie_password is required") + + self.user_management = user_management + self.client_id = client_id + self.session_data = session_data + self.cookie_password = cookie_password + + self.jwks = PyJWKClient(self.user_management.get_jwks_url()) + + # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm + self.jwk_algorithms = ["RS256"] + def refresh( self, *, @@ -124,13 +210,10 @@ def refresh( ) try: - auth_response = cast( - RefreshTokenAuthenticationResponse, - self.user_management.authenticate_with_refresh_token( - refresh_token=session["refresh_token"], - organization_id=organization_id, - session={"seal_session": True, "cookie_password": cookie_password}, - ), + auth_response = self.user_management.authenticate_with_refresh_token( + refresh_token=session["refresh_token"], + organization_id=organization_id, + session={"seal_session": True, "cookie_password": cookie_password}, ) self.session_data = str(auth_response.sealed_session) @@ -163,44 +246,92 @@ def refresh( authenticated=False, reason=str(e) ) - def get_logout_url(self, return_to: Optional[str] = None) -> str: - auth_response = self.authenticate() - if isinstance(auth_response, AuthenticateWithSessionCookieErrorResponse): - raise ValueError( - f"Failed to extract session ID for logout URL: {auth_response.reason}" - ) +class AsyncSession(SessionModule): + user_management: "AsyncUserManagement" - result = self.user_management.get_logout_url( - session_id=auth_response.session_id, - return_to=return_to, + def __init__( + self, + *, + user_management: "AsyncUserManagement", + client_id: str, + session_data: str, + cookie_password: str, + ) -> None: + # If the cookie password is not provided, throw an error + if cookie_password is None or cookie_password == "": + raise ValueError("cookie_password is required") + + self.user_management = user_management + self.client_id = client_id + self.session_data = session_data + self.cookie_password = cookie_password + + self.jwks = PyJWKClient(self.user_management.get_jwks_url()) + + # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm + self.jwk_algorithms = ["RS256"] + + async def refresh( + self, + *, + organization_id: Optional[str] = None, + cookie_password: Optional[str] = None, + ) -> Union[ + RefreshWithSessionCookieSuccessResponse, + RefreshWithSessionCookieErrorResponse, + ]: + cookie_password = ( + self.cookie_password if cookie_password is None else cookie_password ) - return str(result) - def _is_valid_jwt(self, token: str) -> bool: try: - signing_key = self.jwks.get_signing_key_from_jwt(token) - jwt.decode( - token, + session = self.unseal_data(self.session_data, cookie_password) + except Exception: + return RefreshWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, + ) + + if not session.get("refresh_token", None) or not session.get("user", None): + return RefreshWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, + ) + + try: + auth_response = await self.user_management.authenticate_with_refresh_token( + refresh_token=session["refresh_token"], + organization_id=organization_id, + session={"seal_session": True, "cookie_password": cookie_password}, + ) + + self.session_data = str(auth_response.sealed_session) + self.cookie_password = ( + cookie_password if cookie_password is not None else self.cookie_password + ) + + signing_key = self.jwks.get_signing_key_from_jwt(auth_response.access_token) + + decoded = jwt.decode( + auth_response.access_token, signing_key.key, algorithms=self.jwk_algorithms, options={"verify_aud": False}, ) - return True - except jwt.exceptions.InvalidTokenError: - return False - @staticmethod - def seal_data(data: Dict[str, Any], key: str) -> str: - fernet = Fernet(key) - # Encrypt and convert bytes to string - encrypted_bytes = fernet.encrypt(json.dumps(data).encode()) - return encrypted_bytes.decode("utf-8") - - @staticmethod - def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]: - fernet = Fernet(key) - # Convert string back to bytes before decryption - encrypted_bytes = sealed_data.encode("utf-8") - decrypted_str = fernet.decrypt(encrypted_bytes).decode() - return cast(Dict[str, Any], json.loads(decrypted_str)) + return RefreshWithSessionCookieSuccessResponse( + authenticated=True, + sealed_session=str(auth_response.sealed_session), + session_id=decoded["sid"], + organization_id=decoded.get("org_id", None), + role=decoded.get("role", None), + permissions=decoded.get("permissions", None), + entitlements=decoded.get("entitlements", None), + user=auth_response.user, + impersonator=auth_response.impersonator, + ) + except Exception as e: + return RefreshWithSessionCookieErrorResponse( + authenticated=False, reason=str(e) + ) diff --git a/workos/user_management.py b/workos/user_management.py index e67b032a..93ab5c8c 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -1,7 +1,7 @@ -from typing import Optional, Protocol, Sequence, Type, cast +from typing import Awaitable, Optional, Protocol, Sequence, Type, Union, cast from urllib.parse import urlencode from workos._client_configuration import ClientConfiguration -from workos.session import Session +from workos.session import AsyncSession, Session from workos.types.list_resource import ( ListArgs, ListMetadata, @@ -117,7 +117,7 @@ class UserManagementModule(Protocol): def load_sealed_session( self, *, sealed_session: str, cookie_password: str - ) -> SyncOrAsync[Session]: + ) -> Union[Session, Awaitable[AsyncSession]]: """Load a sealed session and return the session data. Args: @@ -1485,8 +1485,8 @@ def __init__( async def load_sealed_session( self, *, sealed_session: str, cookie_password: str - ) -> Session: - return Session( + ) -> AsyncSession: + return AsyncSession( user_management=self, client_id=self._http_client.client_id, session_data=sealed_session,