diff --git a/src/flask_session/__init__.py b/src/flask_session/__init__.py index 5d4caee9..25ae63b4 100644 --- a/src/flask_session/__init__.py +++ b/src/flask_session/__init__.py @@ -101,6 +101,11 @@ def _get_interface(self, app): "SESSION_SQLALCHEMY_BIND_KEY", Defaults.SESSION_SQLALCHEMY_BIND_KEY ) + # SQLAlchemy-native settings + SESSION_SQLALCHEMY_ENGINE = config.get( + "SESSION_SQLALCHEMY_ENGINE", Defaults.SESSION_SQLALCHEMY_ENGINE + ) + # DynamoDB settings SESSION_DYNAMODB = config.get("SESSION_DYNAMODB", Defaults.SESSION_DYNAMODB) SESSION_DYNAMODB_TABLE = config.get( @@ -187,6 +192,18 @@ def _get_interface(self, app): bind_key=SESSION_SQLALCHEMY_BIND_KEY, cleanup_n_requests=SESSION_CLEANUP_N_REQUESTS, ) + elif SESSION_TYPE == "sqlalchemy_native": + from .sqlalchemy_native import NativeSqlAlchemySessionInterface + + session_interface = NativeSqlAlchemySessionInterface( + **common_params, + engine=SESSION_SQLALCHEMY_ENGINE, + table=SESSION_SQLALCHEMY_TABLE, + sequence=SESSION_SQLALCHEMY_SEQUENCE, + schema=SESSION_SQLALCHEMY_SCHEMA, + bind_key=SESSION_SQLALCHEMY_BIND_KEY, + cleanup_n_requests=SESSION_CLEANUP_N_REQUESTS, + ) elif SESSION_TYPE == "dynamodb": from .dynamodb import DynamoDBSessionInterface diff --git a/src/flask_session/defaults.py b/src/flask_session/defaults.py index f1bc1501..f332eb9d 100644 --- a/src/flask_session/defaults.py +++ b/src/flask_session/defaults.py @@ -40,6 +40,9 @@ class Defaults: SESSION_SQLALCHEMY_SCHEMA = None SESSION_SQLALCHEMY_BIND_KEY = None + # SQLAlchemy-native settings + SESSION_SQLALCHEMY_ENGINE = None + # DynamoDB settings SESSION_DYNAMODB = None SESSION_DYNAMODB_TABLE = "Sessions" diff --git a/src/flask_session/sqlalchemy_native/__init__.py b/src/flask_session/sqlalchemy_native/__init__.py new file mode 100644 index 00000000..a2562356 --- /dev/null +++ b/src/flask_session/sqlalchemy_native/__init__.py @@ -0,0 +1,4 @@ +from .sqlalchemy_native import ( # noqa: F401 + NativeSqlAlchemySession, + NativeSqlAlchemySessionInterface, +) diff --git a/src/flask_session/sqlalchemy_native/sqlalchemy_native.py b/src/flask_session/sqlalchemy_native/sqlalchemy_native.py new file mode 100644 index 00000000..40f71528 --- /dev/null +++ b/src/flask_session/sqlalchemy_native/sqlalchemy_native.py @@ -0,0 +1,178 @@ +from datetime import datetime +from datetime import timedelta as TimeDelta +from typing import Optional + +from flask import Flask +from itsdangerous import want_bytes +from sqlalchemy import ( + Column, + DateTime, + Engine, + Integer, + LargeBinary, + Sequence, + String, + delete, + select, +) +from sqlalchemy.orm import DeclarativeBase, Session + +from .._utils import retry_query +from ..base import ServerSideSession, ServerSideSessionInterface +from ..defaults import Defaults + + +class NativeSqlAlchemySession(ServerSideSession): + pass + + +class Base(DeclarativeBase): + pass + + +def create_session_model(table_name, schema=None, bind_key=None, sequence=None): + class Session(Base): + __tablename__ = table_name + __table_args__ = {"schema": schema} if schema else {} + __bind_key__ = bind_key + + id = ( + Column(Integer, Sequence(sequence), primary_key=True) + if sequence + else Column(Integer, primary_key=True) + ) + session_id = Column(String(255), unique=True) + data = Column(LargeBinary) + expiry = Column(DateTime) + + def __repr__(self): + return f"" + + return Session + + +class NativeSqlAlchemySessionInterface(ServerSideSessionInterface): + """Uses a SQLAlchemy engine as session storage. + + :param app: A Flask app instance. + :param engine: A SQLAlchemy engine instance. + :param key_prefix: A prefix that is added to all storage keys. + :param use_signer: Whether to sign the session id cookie or not. + :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. + :param serialization_format: The serialization format to use for the session data. + :param table: The table name you want to use. + :param sequence: The sequence to use for the primary key if needed. + :param schema: The db schema to use. + :param bind_key: The db bind key to use. + :param cleanup_n_requests: Delete expired sessions on average every N requests. + """ + + session_class = NativeSqlAlchemySession + ttl = False + + def __init__( + self, + app: Optional[Flask], + engine: Optional[Engine] = Defaults.SESSION_SQLALCHEMY_ENGINE, + key_prefix: str = Defaults.SESSION_KEY_PREFIX, + use_signer: bool = Defaults.SESSION_USE_SIGNER, + permanent: bool = Defaults.SESSION_PERMANENT, + sid_length: int = Defaults.SESSION_ID_LENGTH, + serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, + table: str = Defaults.SESSION_SQLALCHEMY_TABLE, + sequence: Optional[str] = Defaults.SESSION_SQLALCHEMY_SEQUENCE, + schema: Optional[str] = Defaults.SESSION_SQLALCHEMY_SCHEMA, + bind_key: Optional[str] = Defaults.SESSION_SQLALCHEMY_BIND_KEY, + cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS, + ): + self.app = app + + if engine is None or not isinstance(engine, Engine): + raise TypeError("No valid Engine instance provided.") + self.engine = engine + + # Create the session model + self.sql_session_model = create_session_model( + table, schema, bind_key, sequence + ) + # Create the table if it does not exist + self.sql_session_model.__table__.create(bind=engine, checkfirst=True) + + super().__init__( + app, + key_prefix, + use_signer, + permanent, + sid_length, + serialization_format, + cleanup_n_requests, + ) + + @retry_query() + def _delete_expired_sessions(self) -> None: + with Session(self.engine) as session: + session.execute( + delete(self.sql_session_model) + .where(self.sql_session_model.expiry <= datetime.utcnow()), + execution_options={"synchronize_session": False} + ) + session.commit() + + @retry_query() + def _retrieve_session_data(self, store_id: str) -> Optional[dict]: + # Get the saved session (record) from the database + with Session(self.engine) as session: + record = session.scalars( + select(self.sql_session_model) + .where(self.sql_session_model.session_id == store_id) + ).first() + + # "Delete the session record if it is expired as SQL has no TTL ability + if record and (record.expiry is None or record.expiry <= datetime.utcnow()): + with Session(self.engine) as session: + session.delete(record) + session.commit() + record = None + + if record: + serialized_session_data = want_bytes(record.data) + return self.serializer.loads(serialized_session_data) + return None + + @retry_query() + def _delete_session(self, store_id: str) -> None: + with Session(self.engine) as session: + session.execute( + delete(self.sql_session_model) + .where(self.sql_session_model.session_id == store_id) + ) + session.commit() + + @retry_query() + def _upsert_session( + self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str + ) -> None: + storage_expiration_datetime = datetime.utcnow() + session_lifetime + + # Serialize session data + serialized_session_data = self.serializer.dumps(dict(session)) + + # Update existing or create new session in the database + with Session(self.engine) as session: + record = session.scalars( + select(self.sql_session_model) + .where(self.sql_session_model.session_id == store_id) + ).first() + + if record: + record.data = serialized_session_data + record.expiry = storage_expiration_datetime + else: + record = self.sql_session_model( + session_id=store_id, + data=serialized_session_data, + expiry=storage_expiration_datetime, + ) + session.add(record) + session.commit() diff --git a/tests/test_sqlalchemy_native.py b/tests/test_sqlalchemy_native.py new file mode 100644 index 00000000..cdd2a7b4 --- /dev/null +++ b/tests/test_sqlalchemy_native.py @@ -0,0 +1,59 @@ +import json +from contextlib import contextmanager + +import flask +from flask_session.sqlalchemy_native import NativeSqlAlchemySession +from sqlalchemy import create_engine, select, text +from sqlalchemy.orm import Session + + +class TestNativeSQLAlchemy: + """This requires package: sqlalchemy""" + + @contextmanager + def setup_sqlalchemy(self, app): + try: + with Session(app.session_interface.engine) as session: + session.execute(text("DELETE FROM sessions")) + session.commit() + yield + finally: + with Session(app.session_interface.engine) as session: + session.execute(text("DELETE FROM sessions")) + session.close() + + def retrieve_stored_session(self, key, app): + with Session(app.session_interface.engine) as session: + session_model = session.scalars( + select(app.session_interface.sql_session_model) + .where(app.session_interface.sql_session_model.session_id == key) + ).first() + if session_model: + return session_model.data + return None + + def test_use_signer(self, app_utils): + engine = create_engine("sqlite:///") + app = app_utils.create_app( + { + "SESSION_TYPE": "sqlalchemy_native", + "SESSION_SQLALCHEMY_ENGINE": engine, + } + ) + with app.app_context() and self.setup_sqlalchemy( + app + ) and app.test_request_context(): + assert isinstance( + flask.session, + NativeSqlAlchemySession, + ) + app_utils.test_session(app) + + # Check if the session is stored in SQLAlchemy + cookie = app_utils.test_session_with_cookie(app) + session_id = cookie.split(";")[0].split("=")[1] + byte_string = self.retrieve_stored_session(f"session:{session_id}", app) + stored_session = ( + json.loads(byte_string.decode("utf-8")) if byte_string else {} + ) + assert stored_session.get("value") == "44"