Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions app/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,17 @@ async def get_database() -> AsyncGenerator[AsyncSession, Any]:
"""Return the database connection as a Generator."""
async with async_session() as session, session.begin():
yield session


async def get_database_manual() -> AsyncGenerator[AsyncSession, Any]:
"""Return a MANUAL database session as a generator.

This will need an explicit `db.commit` call for any route using it, but
should allow multiple database calls in the same route.
"""
async with async_session() as session:
try:
yield session
except Exception:
await session.rollback()
raise
188 changes: 188 additions & 0 deletions tests/unit/test_database.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Test functions in the app.database module."""

import contextlib
import os

import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker

from app.config.settings import Settings
Expand All @@ -24,6 +26,192 @@ async def test_get_database(self) -> None:
assert session is not None
assert isinstance(session, db.AsyncSession)

@pytest.mark.asyncio
async def test_get_database_manual_basic_functionality(self) -> None:
"""Test we get an async database session back from manual dependency."""
database = db.get_database_manual()
assert database is not None
assert isinstance(database, db.AsyncGenerator)

session = await database.__anext__()
assert session is not None
assert isinstance(session, db.AsyncSession)

@pytest.mark.asyncio
async def test_get_database_manual_session_type(self) -> None:
"""Test the yielded session has expected properties."""
database = db.get_database_manual()
session = await database.__anext__()

assert isinstance(session, db.AsyncSession)
assert hasattr(session, "commit")
assert hasattr(session, "rollback")
assert hasattr(session, "execute")

@pytest.mark.asyncio
async def test_get_database_manual_exception_rollback(self, mocker) -> None:
"""Test that get_database_manual handles exceptions and rolls back."""
# Mock the async_session to return a mock session
mock_session = mocker.AsyncMock()
mock_session.rollback = mocker.AsyncMock()

# Mock the context manager behavior
mock_session.__aenter__ = mocker.AsyncMock(return_value=mock_session)
mock_session.__aexit__ = mocker.AsyncMock(return_value=None)

mock_async_session = mocker.patch("app.database.db.async_session")
mock_async_session.return_value = mock_session

# Test that exception triggers rollback
database = db.get_database_manual()
session = await database.__anext__()

# Simulate an exception
with contextlib.suppress(Exception):
await database.athrow(Exception("Test exception"))

# Verify we got a valid session
assert session is not None

# Verify rollback was called
mock_session.rollback.assert_called_once()

@pytest.mark.asyncio
async def test_get_database_manual_no_auto_begin(self, mocker) -> None:
"""Test that manual session doesn't auto start a transaction."""
mock_session = mocker.AsyncMock()
mock_session.begin = mocker.AsyncMock()
mock_session.__aenter__ = mocker.AsyncMock(return_value=mock_session)
mock_session.__aexit__ = mocker.AsyncMock(return_value=None)

mock_async_session = mocker.patch("app.database.db.async_session")
mock_async_session.return_value = mock_session

database = db.get_database_manual()
session = await database.__anext__()

# Verify we got a valid session
assert session is not None

# Verify begin() was NOT called automatically
mock_session.begin.assert_not_called()

@pytest.mark.asyncio
async def test_get_database_manual_session_cleanup(self, mocker) -> None:
"""Test that the session is properly cleaned up after context exits."""
mock_session = mocker.AsyncMock()
mock_session.__aenter__ = mocker.AsyncMock(return_value=mock_session)
mock_session.__aexit__ = mocker.AsyncMock(return_value=None)

mock_async_session = mocker.patch("app.database.db.async_session")
mock_async_session.return_value = mock_session

database = db.get_database_manual()
session = await database.__anext__()

# Verify we got a valid session
assert session is not None

# Close the generator to trigger cleanup
await database.aclose()

# Verify the session context manager was properly exited
mock_session.__aexit__.assert_called_once()

@pytest.mark.asyncio
async def test_get_database_manual_multiple_operations(
self, mocker
) -> None:
"""Test that multiple database ops can be performed in same session."""
mock_session = mocker.AsyncMock()
mock_session.execute = mocker.AsyncMock()
mock_session.__aenter__ = mocker.AsyncMock(return_value=mock_session)
mock_session.__aexit__ = mocker.AsyncMock(return_value=None)

mock_async_session = mocker.patch("app.database.db.async_session")
mock_async_session.return_value = mock_session

database = db.get_database_manual()
session = await database.__anext__()

# Simulate multiple operations
await session.execute(text("SELECT 1"))
await session.execute(text("SELECT 2"))

# Verify both operations used the same session
assert mock_session.execute.call_count == 2 # noqa: PLR2004

@pytest.mark.asyncio
async def test_get_database_manual_explicit_commit(self, mocker) -> None:
"""Test that explicit commit works correctly."""
mock_session = mocker.AsyncMock()
mock_session.commit = mocker.AsyncMock()
mock_session.__aenter__ = mocker.AsyncMock(return_value=mock_session)
mock_session.__aexit__ = mocker.AsyncMock(return_value=None)

mock_async_session = mocker.patch("app.database.db.async_session")
mock_async_session.return_value = mock_session

database = db.get_database_manual()
session = await database.__anext__()

# Explicitly commit
await session.commit()

# Verify commit was called
mock_session.commit.assert_called_once()

@pytest.mark.asyncio
async def test_get_database_manual_explicit_rollback(self, mocker) -> None:
"""Test that explicit rollback works correctly."""
mock_session = mocker.AsyncMock()
mock_session.rollback = mocker.AsyncMock()
mock_session.__aenter__ = mocker.AsyncMock(return_value=mock_session)
mock_session.__aexit__ = mocker.AsyncMock(return_value=None)

mock_async_session = mocker.patch("app.database.db.async_session")
mock_async_session.return_value = mock_session

database = db.get_database_manual()
session = await database.__anext__()

# Explicitly rollback
await session.rollback()

# Verify rollback was called
mock_session.rollback.assert_called_once()

@pytest.mark.asyncio
async def test_get_database_manual_no_commit_data_lost(
self, mocker
) -> None:
"""Test that without manual commit, data changes are lost."""
# Mock a session that tracks if commit was called
mock_session = mocker.AsyncMock()
mock_session.commit = mocker.AsyncMock()
mock_session.execute = mocker.AsyncMock()
mock_session.__aenter__ = mocker.AsyncMock(return_value=mock_session)
mock_session.__aexit__ = mocker.AsyncMock(return_value=None)

mock_async_session = mocker.patch("app.database.db.async_session")
mock_async_session.return_value = mock_session

# Use the manual session without committing
database = db.get_database_manual()
session = await database.__anext__()

# Simulate data modification
await session.execute(text("INSERT INTO test_table VALUES (1)"))

# Close the session without committing
await database.aclose()

# Verify commit was never called (data would be lost)
mock_session.commit.assert_not_called()

# Verify the session was properly closed (triggering implicit rollback)
mock_session.__aexit__.assert_called_once()

def test_get_database_url_normal(self, mocker, monkeypatch) -> None:
"""Test the get_database_url function returns the correct URL."""
# enure GITHUB_ACTIONS is false, otherwise will fail in GitHub Actions
Expand Down
Loading