diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py index cb2c4bde92..302ef80bbc 100644 --- a/src/crawlee/storage_clients/_sql/_storage_client.py +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -5,9 +5,10 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar +from sqlalchemy import event from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine -from sqlalchemy.sql import insert, select, text +from sqlalchemy.sql import insert, select from typing_extensions import override from crawlee._utils.docs import docs_group @@ -22,7 +23,9 @@ if TYPE_CHECKING: from types import TracebackType + from sqlalchemy.engine.interfaces import DBAPIConnection from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.pool import ConnectionPoolEntry logger = getLogger(__name__) @@ -72,8 +75,7 @@ def __init__( self._initialized = False self.session_maker: None | async_sessionmaker[AsyncSession] = None - # Flag needed to apply optimizations only for default database - self._default_flag = self._engine is None and self._connection_string is None + self._listeners_registered = False self._dialect_name: str | None = None # Call the notification only once @@ -115,9 +117,10 @@ async def initialize(self, configuration: Configuration) -> None: """ if not self._initialized: engine = self._get_or_create_engine(configuration) - async with engine.begin() as conn: - self._dialect_name = engine.dialect.name + self._dialect_name = engine.dialect.name + + async with engine.begin() as conn: if self._dialect_name not in self._SUPPORTED_DIALECTS: raise ValueError( f'Unsupported database dialect: {self._dialect_name}. Supported: ' @@ -128,16 +131,8 @@ async def initialize(self, configuration: Configuration) -> None: # Rollback the transaction when an exception occurs. # This is likely an attempt to create a database from several parallel processes. try: - # Set SQLite pragmas for performance and consistency - if self._default_flag: - await conn.execute(text('PRAGMA journal_mode=WAL')) # Better concurrency - await conn.execute(text('PRAGMA synchronous=NORMAL')) # Balanced safety/speed - await conn.execute(text('PRAGMA cache_size=100000')) # 100MB cache - await conn.execute(text('PRAGMA temp_store=MEMORY')) # Memory temp storage - await conn.execute(text('PRAGMA mmap_size=268435456')) # 256MB memory mapping - await conn.execute(text('PRAGMA foreign_keys=ON')) # Enforce constraints - await conn.execute(text('PRAGMA busy_timeout=30000')) # 30s busy timeout await conn.run_sync(Base.metadata.create_all, checkfirst=True) + from crawlee import __version__ # Noqa: PLC0415 db_version = (await conn.execute(select(VersionDb))).scalar_one_or_none() @@ -153,6 +148,7 @@ async def initialize(self, configuration: Configuration) -> None: ) elif not db_version: await conn.execute(insert(VersionDb).values(version=__version__)) + except (IntegrityError, OperationalError): await conn.rollback() @@ -161,6 +157,10 @@ async def initialize(self, configuration: Configuration) -> None: async def close(self) -> None: """Close the database connection pool.""" if self._engine is not None: + if self._listeners_registered: + event.remove(self._engine.sync_engine, 'connect', self._on_connect) + self._listeners_registered = False + await self._engine.dispose() self._engine = None @@ -285,4 +285,21 @@ def _get_or_create_engine(self, configuration: Configuration) -> AsyncEngine: connect_args=connect_args, **kwargs, ) + + event.listen(self._engine.sync_engine, 'connect', self._on_connect) + self._listeners_registered = True + return self._engine + + def _on_connect(self, dbapi_conn: DBAPIConnection, _connection_record: ConnectionPoolEntry) -> None: + """Event listener for new database connections to set pragmas.""" + if self._dialect_name == 'sqlite': + cursor = dbapi_conn.cursor() + cursor.execute('PRAGMA journal_mode=WAL') # Better concurrency + cursor.execute('PRAGMA synchronous=NORMAL') # Balanced safety/speed + cursor.execute('PRAGMA cache_size=100000') # 100MB cache + cursor.execute('PRAGMA temp_store=MEMORY') # Memory temp storage + cursor.execute('PRAGMA mmap_size=268435456') # 256MB memory mapping + cursor.execute('PRAGMA foreign_keys=ON') # Enforce constraints + cursor.execute('PRAGMA busy_timeout=30000') # 30s busy timeout + cursor.close() diff --git a/tests/unit/storage_clients/_sql/test_sql_storage_client.py b/tests/unit/storage_clients/_sql/test_sql_storage_client.py new file mode 100644 index 0000000000..93202b54b9 --- /dev/null +++ b/tests/unit/storage_clients/_sql/test_sql_storage_client.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine + +from crawlee.configuration import Configuration +from crawlee.storage_clients import SqlStorageClient + +if TYPE_CHECKING: + from pathlib import Path + + +async def test_sqlite_wal_mode_with_default_connection(tmp_path: Path) -> None: + """Test that WAL mode is applied for the default SQLite connection.""" + configuration = Configuration(storage_dir=str(tmp_path)) + + async with SqlStorageClient() as storage_client: + await storage_client.initialize(configuration) + + async with storage_client.engine.begin() as conn: + result = await conn.execute(text('PRAGMA journal_mode')) + assert result.scalar() == 'wal' + + +async def test_sqlite_wal_mode_with_connection_string(tmp_path: Path) -> None: + """Test that WAL mode is applied when using a custom SQLite connection string.""" + db_path = tmp_path / 'test.db' + configuration = Configuration(storage_dir=str(tmp_path)) + + async with SqlStorageClient(connection_string=f'sqlite+aiosqlite:///{db_path}') as storage_client: + await storage_client.initialize(configuration) + + async with storage_client.engine.begin() as conn: + result = await conn.execute(text('PRAGMA journal_mode')) + assert result.scalar() == 'wal' + + +async def test_sqlite_wal_mode_not_applied_with_custom_engine(tmp_path: Path) -> None: + """Test that WAL mode is not applied when using a user-provided engine.""" + db_path = tmp_path / 'test.db' + configuration = Configuration(storage_dir=str(tmp_path)) + engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', future=True) + + async with SqlStorageClient(engine=engine) as storage_client: + await storage_client.initialize(configuration) + + async with engine.begin() as conn: + result = await conn.execute(text('PRAGMA journal_mode')) + assert result.scalar() != 'wal' diff --git a/uv.lock b/uv.lock index dcffbb003a..fdf6c75858 100644 --- a/uv.lock +++ b/uv.lock @@ -8,6 +8,10 @@ resolution-markers = [ "python_full_version < '3.11'", ] +[options] +exclude-newer = "2026-04-04T22:02:40.053808872Z" +exclude-newer-span = "PT24H" + [[package]] name = "aiomysql" version = "0.3.2"