Skip to content

Commit 90c9368

Browse files
author
Eugene Shershen
committed
tests
1 parent bc0f126 commit 90c9368

3 files changed

Lines changed: 210 additions & 95 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ repos:
1414
rev: v0.12.4
1515
hooks:
1616
- id: ruff
17-
args: [--fix]
17+
args: [--fix, --unsafe-fixes]
1818
- id: ruff-format
1919

2020
- repo: https://github.com/pre-commit/mirrors-mypy

tests/test_custom_fastapi_middleware.py

Lines changed: 115 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,47 @@
88
import unittest
99
from contextvars import ContextVar
1010
from typing import Dict, Optional, Union
11-
from unittest.mock import Mock
1211

1312
try:
1413
from fastapi import FastAPI
15-
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
14+
from starlette.middleware.base import (
15+
BaseHTTPMiddleware,
16+
RequestResponseEndpoint,
17+
)
1618
from starlette.requests import Request
1719
from starlette.types import ASGIApp
20+
1821
FASTAPI_AVAILABLE = True
1922
except ImportError:
2023
FASTAPI_AVAILABLE = False
24+
2125
# Create mock classes for when FastAPI is not available
2226
class FastAPI:
2327
def __init__(self, *args, **kwargs):
2428
pass
29+
2530
def add_middleware(self, *args, **kwargs):
2631
pass
32+
2733
def get(self, *args, **kwargs):
2834
def decorator(func):
2935
return func
36+
3037
return decorator
31-
38+
3239
class BaseHTTPMiddleware:
3340
def __init__(self, *args, **kwargs):
3441
pass
3542

43+
3644
from sqlalchemy.engine import Engine
3745
from sqlalchemy.engine.url import URL
38-
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
46+
from sqlalchemy.ext.asyncio import (
47+
AsyncSession,
48+
async_sessionmaker,
49+
create_async_engine,
50+
)
51+
3952

4053
# Mock settings for testing
4154
class MockSettings:
@@ -45,28 +58,40 @@ class MockSettings:
4558
DB_POOL_SIZE_MAX = 10
4659
SERVICE_NAME = "test_service"
4760

61+
4862
settings = MockSettings()
4963

64+
5065
# Mock SSLMode for testing
5166
class SSLMode:
5267
disable = "disable"
5368

69+
5470
# Custom exceptions for the middleware
5571
class MissingSessionError(Exception):
5672
"""Raised when no session is available in the current context."""
73+
5774
pass
5875

76+
5977
class SessionNotInitialisedError(Exception):
6078
"""Raised when the session factory is not initialized."""
79+
6180
pass
6281

6382

6483
def create_middleware_and_session_proxy():
6584
"""Create the custom middleware and session proxy as provided in the issue."""
6685
_Session: Optional[async_sessionmaker] = None
67-
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
68-
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
69-
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
86+
_session: ContextVar[Optional[AsyncSession]] = ContextVar(
87+
"_session", default=None
88+
)
89+
_multi_sessions_ctx: ContextVar[bool] = ContextVar(
90+
"_multi_sessions_context", default=False
91+
)
92+
_commit_on_exit_ctx: ContextVar[bool] = ContextVar(
93+
"_commit_on_exit_ctx", default=False
94+
)
7095

7196
class SQLAlchemyMiddleware(BaseHTTPMiddleware):
7297
def __init__(
@@ -84,18 +109,25 @@ def __init__(
84109
session_args = session_args or {}
85110

86111
if not custom_engine and not db_url:
87-
raise ValueError("You need to pass a db_url or a custom_engine parameter.")
112+
raise ValueError(
113+
"You need to pass a db_url or a custom_engine parameter."
114+
)
88115
if not custom_engine:
89116
engine = create_async_engine(db_url, **engine_args)
90117
else:
91118
engine = custom_engine
92119

93120
nonlocal _Session
94121
_Session = async_sessionmaker(
95-
engine, class_=AsyncSession, expire_on_commit=False, **session_args
122+
engine,
123+
class_=AsyncSession,
124+
expire_on_commit=False,
125+
**session_args,
96126
)
97127

98-
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
128+
async def dispatch(
129+
self, request: Request, call_next: RequestResponseEndpoint
130+
):
99131
async with DBSession(commit_on_exit=self.commit_on_exit):
100132
return await call_next(request)
101133

@@ -124,7 +156,9 @@ async def cleanup():
124156

125157
task = asyncio.current_task()
126158
if task is not None:
127-
task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
159+
task.add_done_callback(
160+
lambda t: asyncio.create_task(cleanup())
161+
)
128162
return session
129163
else:
130164
session = _session.get()
@@ -151,7 +185,9 @@ async def __aenter__(self):
151185

152186
if self.multi_sessions:
153187
self.multi_sessions_token = _multi_sessions_ctx.set(True)
154-
self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)
188+
self.commit_on_exit_token = _commit_on_exit_ctx.set(
189+
self.commit_on_exit
190+
)
155191
else:
156192
self.token = _session.set(_Session(**self.session_args))
157193
return type(self)
@@ -185,10 +221,10 @@ def setUp(self):
185221
def test_middleware_creation_with_db_url(self):
186222
"""Test creating middleware with database URL"""
187223
SQLAlchemyMiddleware, DBSession = create_middleware_and_session_proxy()
188-
224+
189225
# Create a mock FastAPI app
190226
app = FastAPI(title="Test App")
191-
227+
192228
# Test middleware creation with basic settings
193229
try:
194230
app.add_middleware(
@@ -201,19 +237,24 @@ def test_middleware_creation_with_db_url(self):
201237
"max_overflow": settings.DB_POOL_SIZE_MAX,
202238
"pool_recycle": 900,
203239
"connect_args": {
204-
"server_settings": {"jit": "off", "application_name": settings.SERVICE_NAME},
240+
"server_settings": {
241+
"jit": "off",
242+
"application_name": settings.SERVICE_NAME,
243+
},
205244
"ssl": SSLMode.disable,
206245
},
207246
},
208247
)
209-
self.assertTrue(True, "Custom FastAPI middleware created successfully")
248+
self.assertTrue(
249+
True, "Custom FastAPI middleware created successfully"
250+
)
210251
except Exception as e:
211252
self.fail(f"Failed to create custom FastAPI middleware: {e}")
212253

213254
def test_middleware_creation_with_custom_engine(self):
214255
"""Test creating middleware with custom engine"""
215256
SQLAlchemyMiddleware, DBSession = create_middleware_and_session_proxy()
216-
257+
217258
# Create a custom engine
218259
engine = create_async_engine(
219260
settings.DB_URL,
@@ -223,46 +264,57 @@ def test_middleware_creation_with_custom_engine(self):
223264
max_overflow=settings.DB_POOL_SIZE_MAX,
224265
pool_recycle=900,
225266
connect_args={
226-
"server_settings": {"jit": "off", "application_name": settings.SERVICE_NAME},
267+
"server_settings": {
268+
"jit": "off",
269+
"application_name": settings.SERVICE_NAME,
270+
},
227271
"ssl": SSLMode.disable,
228272
},
229273
)
230-
274+
231275
# Create a mock FastAPI app
232276
app = FastAPI(title="Test App with Custom Engine")
233-
277+
234278
try:
235279
app.add_middleware(
236280
SQLAlchemyMiddleware,
237281
custom_engine=engine,
238282
commit_on_exit=True,
239283
)
240-
self.assertTrue(True, "Custom FastAPI middleware with custom engine created successfully")
284+
self.assertTrue(
285+
True,
286+
"Custom FastAPI middleware with custom engine created successfully",
287+
)
241288
except Exception as e:
242-
self.fail(f"Failed to create custom FastAPI middleware with custom engine: {e}")
289+
self.fail(
290+
f"Failed to create custom FastAPI middleware with custom engine: {e}"
291+
)
243292

244293
def test_middleware_validation_errors(self):
245294
"""Test middleware validation for required parameters"""
246295
SQLAlchemyMiddleware, DBSession = create_middleware_and_session_proxy()
247-
296+
248297
app = FastAPI(title="Test App")
249-
298+
250299
# Test missing both db_url and custom_engine by calling the constructor directly
251300
with self.assertRaises(ValueError) as context:
252301
# This should raise ValueError when called directly
253-
middleware = SQLAlchemyMiddleware(app)
254-
255-
self.assertIn("You need to pass a db_url or a custom_engine parameter", str(context.exception))
302+
SQLAlchemyMiddleware(app)
303+
304+
self.assertIn(
305+
"You need to pass a db_url or a custom_engine parameter",
306+
str(context.exception),
307+
)
256308

257309
def test_db_session_context_manager(self):
258310
"""Test DBSession context manager functionality"""
259311
SQLAlchemyMiddleware, DBSession = create_middleware_and_session_proxy()
260-
312+
261313
# Test DBSession creation
262314
db_session = DBSession(commit_on_exit=True)
263315
self.assertIsInstance(db_session, DBSession)
264316
self.assertTrue(db_session.commit_on_exit)
265-
317+
266318
# Test multi-sessions mode
267319
multi_db_session = DBSession(multi_sessions=True, commit_on_exit=False)
268320
self.assertIsInstance(multi_db_session, DBSession)
@@ -272,26 +324,30 @@ def test_db_session_context_manager(self):
272324
def test_session_proxy_creation(self):
273325
"""Test that the session proxy is created correctly"""
274326
SQLAlchemyMiddleware, DBSession = create_middleware_and_session_proxy()
275-
327+
276328
# Verify that we get the expected classes
277329
if FASTAPI_AVAILABLE:
278-
self.assertTrue(issubclass(SQLAlchemyMiddleware, BaseHTTPMiddleware))
279-
330+
self.assertTrue(
331+
issubclass(SQLAlchemyMiddleware, BaseHTTPMiddleware)
332+
)
333+
280334
# Test that DBSession has the expected metaclass behavior
281-
self.assertTrue(hasattr(DBSession, '__aenter__'))
282-
self.assertTrue(hasattr(DBSession, '__aexit__'))
283-
335+
self.assertTrue(hasattr(DBSession, "__aenter__"))
336+
self.assertTrue(hasattr(DBSession, "__aexit__"))
337+
284338
# Test that the session property exists on the class (without accessing it)
285339
# We check the metaclass has the session property
286-
self.assertTrue(hasattr(type(DBSession), 'session'))
287-
self.assertTrue(isinstance(getattr(type(DBSession), 'session'), property))
340+
self.assertTrue(hasattr(type(DBSession), "session"))
341+
self.assertTrue(
342+
isinstance(getattr(type(DBSession), "session"), property)
343+
)
288344

289345
def test_middleware_with_session_args(self):
290346
"""Test middleware creation with session arguments"""
291347
SQLAlchemyMiddleware, DBSession = create_middleware_and_session_proxy()
292-
348+
293349
app = FastAPI(title="Test App with Session Args")
294-
350+
295351
try:
296352
app.add_middleware(
297353
SQLAlchemyMiddleware,
@@ -306,17 +362,22 @@ def test_middleware_with_session_args(self):
306362
},
307363
commit_on_exit=True,
308364
)
309-
self.assertTrue(True, "Custom FastAPI middleware with session args created successfully")
365+
self.assertTrue(
366+
True,
367+
"Custom FastAPI middleware with session args created successfully",
368+
)
310369
except Exception as e:
311-
self.fail(f"Failed to create custom FastAPI middleware with session args: {e}")
370+
self.fail(
371+
f"Failed to create custom FastAPI middleware with session args: {e}"
372+
)
312373

313374
@unittest.skipUnless(FASTAPI_AVAILABLE, "FastAPI not available")
314375
def test_full_integration_mock(self):
315376
"""Test full integration with mocked database operations"""
316377
SQLAlchemyMiddleware, DBSession = create_middleware_and_session_proxy()
317-
378+
318379
app = FastAPI(title="Full Integration Test")
319-
380+
320381
# Add the middleware
321382
app.add_middleware(
322383
SQLAlchemyMiddleware,
@@ -328,27 +389,33 @@ def test_full_integration_mock(self):
328389
"max_overflow": settings.DB_POOL_SIZE_MAX,
329390
"pool_recycle": 900,
330391
"connect_args": {
331-
"server_settings": {"jit": "off", "application_name": settings.SERVICE_NAME},
392+
"server_settings": {
393+
"jit": "off",
394+
"application_name": settings.SERVICE_NAME,
395+
},
332396
"ssl": SSLMode.disable,
333397
},
334398
},
335399
commit_on_exit=True,
336400
)
337-
401+
338402
# Add a test route
339403
@app.get("/test")
340404
async def test_route():
341405
return {"message": "Test successful"}
342-
406+
343407
# Test that the app and middleware were set up successfully
344408
# We don't need to actually make HTTP requests, just verify setup
345409
self.assertIsNotNone(app)
346-
self.assertTrue(True, "Full integration test completed - middleware setup successful")
410+
self.assertTrue(
411+
True,
412+
"Full integration test completed - middleware setup successful",
413+
)
347414

348415

349416
if __name__ == "__main__":
350417
if FASTAPI_AVAILABLE:
351418
unittest.main()
352419
else:
353420
print("FastAPI not available. Skipping tests.")
354-
print("To run these tests, install FastAPI: pip install fastapi")
421+
print("To run these tests, install FastAPI: pip install fastapi")

0 commit comments

Comments
 (0)