88import unittest
99from contextvars import ContextVar
1010from typing import Dict , Optional , Union
11- from unittest .mock import Mock
1211
1312try :
1413 from fastapi import FastAPI
15- from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
14+ from starlette .middleware .base import (
15+ BaseHTTPMiddleware ,
16+ RequestResponseEndpoint ,
17+ )
1618 from starlette .requests import Request
1719 from starlette .types import ASGIApp
20+
1821 FASTAPI_AVAILABLE = True
1922except ImportError :
2023 FASTAPI_AVAILABLE = False
24+
2125 # Create mock classes for when FastAPI is not available
2226 class FastAPI :
2327 def __init__ (self , * args , ** kwargs ):
2428 pass
29+
2530 def add_middleware (self , * args , ** kwargs ):
2631 pass
32+
2733 def get (self , * args , ** kwargs ):
2834 def decorator (func ):
2935 return func
36+
3037 return decorator
31-
38+
3239 class BaseHTTPMiddleware :
3340 def __init__ (self , * args , ** kwargs ):
3441 pass
3542
43+
3644from sqlalchemy .engine import Engine
3745from sqlalchemy .engine .url import URL
38- from sqlalchemy .ext .asyncio import AsyncSession , async_sessionmaker , create_async_engine
46+ from sqlalchemy .ext .asyncio import (
47+ AsyncSession ,
48+ async_sessionmaker ,
49+ create_async_engine ,
50+ )
51+
3952
4053# Mock settings for testing
4154class MockSettings :
@@ -45,28 +58,40 @@ class MockSettings:
4558 DB_POOL_SIZE_MAX = 10
4659 SERVICE_NAME = "test_service"
4760
61+
4862settings = MockSettings ()
4963
64+
5065# Mock SSLMode for testing
5166class SSLMode :
5267 disable = "disable"
5368
69+
5470# Custom exceptions for the middleware
5571class MissingSessionError (Exception ):
5672 """Raised when no session is available in the current context."""
73+
5774 pass
5875
76+
5977class SessionNotInitialisedError (Exception ):
6078 """Raised when the session factory is not initialized."""
79+
6180 pass
6281
6382
6483def create_middleware_and_session_proxy ():
6584 """Create the custom middleware and session proxy as provided in the issue."""
6685 _Session : Optional [async_sessionmaker ] = None
67- _session : ContextVar [Optional [AsyncSession ]] = ContextVar ("_session" , default = None )
68- _multi_sessions_ctx : ContextVar [bool ] = ContextVar ("_multi_sessions_context" , default = False )
69- _commit_on_exit_ctx : ContextVar [bool ] = ContextVar ("_commit_on_exit_ctx" , default = False )
86+ _session : ContextVar [Optional [AsyncSession ]] = ContextVar (
87+ "_session" , default = None
88+ )
89+ _multi_sessions_ctx : ContextVar [bool ] = ContextVar (
90+ "_multi_sessions_context" , default = False
91+ )
92+ _commit_on_exit_ctx : ContextVar [bool ] = ContextVar (
93+ "_commit_on_exit_ctx" , default = False
94+ )
7095
7196 class SQLAlchemyMiddleware (BaseHTTPMiddleware ):
7297 def __init__ (
@@ -84,18 +109,25 @@ def __init__(
84109 session_args = session_args or {}
85110
86111 if not custom_engine and not db_url :
87- raise ValueError ("You need to pass a db_url or a custom_engine parameter." )
112+ raise ValueError (
113+ "You need to pass a db_url or a custom_engine parameter."
114+ )
88115 if not custom_engine :
89116 engine = create_async_engine (db_url , ** engine_args )
90117 else :
91118 engine = custom_engine
92119
93120 nonlocal _Session
94121 _Session = async_sessionmaker (
95- engine , class_ = AsyncSession , expire_on_commit = False , ** session_args
122+ engine ,
123+ class_ = AsyncSession ,
124+ expire_on_commit = False ,
125+ ** session_args ,
96126 )
97127
98- async def dispatch (self , request : Request , call_next : RequestResponseEndpoint ):
128+ async def dispatch (
129+ self , request : Request , call_next : RequestResponseEndpoint
130+ ):
99131 async with DBSession (commit_on_exit = self .commit_on_exit ):
100132 return await call_next (request )
101133
@@ -124,7 +156,9 @@ async def cleanup():
124156
125157 task = asyncio .current_task ()
126158 if task is not None :
127- task .add_done_callback (lambda t : asyncio .create_task (cleanup ()))
159+ task .add_done_callback (
160+ lambda t : asyncio .create_task (cleanup ())
161+ )
128162 return session
129163 else :
130164 session = _session .get ()
@@ -151,7 +185,9 @@ async def __aenter__(self):
151185
152186 if self .multi_sessions :
153187 self .multi_sessions_token = _multi_sessions_ctx .set (True )
154- self .commit_on_exit_token = _commit_on_exit_ctx .set (self .commit_on_exit )
188+ self .commit_on_exit_token = _commit_on_exit_ctx .set (
189+ self .commit_on_exit
190+ )
155191 else :
156192 self .token = _session .set (_Session (** self .session_args ))
157193 return type (self )
@@ -185,10 +221,10 @@ def setUp(self):
185221 def test_middleware_creation_with_db_url (self ):
186222 """Test creating middleware with database URL"""
187223 SQLAlchemyMiddleware , DBSession = create_middleware_and_session_proxy ()
188-
224+
189225 # Create a mock FastAPI app
190226 app = FastAPI (title = "Test App" )
191-
227+
192228 # Test middleware creation with basic settings
193229 try :
194230 app .add_middleware (
@@ -201,19 +237,24 @@ def test_middleware_creation_with_db_url(self):
201237 "max_overflow" : settings .DB_POOL_SIZE_MAX ,
202238 "pool_recycle" : 900 ,
203239 "connect_args" : {
204- "server_settings" : {"jit" : "off" , "application_name" : settings .SERVICE_NAME },
240+ "server_settings" : {
241+ "jit" : "off" ,
242+ "application_name" : settings .SERVICE_NAME ,
243+ },
205244 "ssl" : SSLMode .disable ,
206245 },
207246 },
208247 )
209- self .assertTrue (True , "Custom FastAPI middleware created successfully" )
248+ self .assertTrue (
249+ True , "Custom FastAPI middleware created successfully"
250+ )
210251 except Exception as e :
211252 self .fail (f"Failed to create custom FastAPI middleware: { e } " )
212253
213254 def test_middleware_creation_with_custom_engine (self ):
214255 """Test creating middleware with custom engine"""
215256 SQLAlchemyMiddleware , DBSession = create_middleware_and_session_proxy ()
216-
257+
217258 # Create a custom engine
218259 engine = create_async_engine (
219260 settings .DB_URL ,
@@ -223,46 +264,57 @@ def test_middleware_creation_with_custom_engine(self):
223264 max_overflow = settings .DB_POOL_SIZE_MAX ,
224265 pool_recycle = 900 ,
225266 connect_args = {
226- "server_settings" : {"jit" : "off" , "application_name" : settings .SERVICE_NAME },
267+ "server_settings" : {
268+ "jit" : "off" ,
269+ "application_name" : settings .SERVICE_NAME ,
270+ },
227271 "ssl" : SSLMode .disable ,
228272 },
229273 )
230-
274+
231275 # Create a mock FastAPI app
232276 app = FastAPI (title = "Test App with Custom Engine" )
233-
277+
234278 try :
235279 app .add_middleware (
236280 SQLAlchemyMiddleware ,
237281 custom_engine = engine ,
238282 commit_on_exit = True ,
239283 )
240- self .assertTrue (True , "Custom FastAPI middleware with custom engine created successfully" )
284+ self .assertTrue (
285+ True ,
286+ "Custom FastAPI middleware with custom engine created successfully" ,
287+ )
241288 except Exception as e :
242- self .fail (f"Failed to create custom FastAPI middleware with custom engine: { e } " )
289+ self .fail (
290+ f"Failed to create custom FastAPI middleware with custom engine: { e } "
291+ )
243292
244293 def test_middleware_validation_errors (self ):
245294 """Test middleware validation for required parameters"""
246295 SQLAlchemyMiddleware , DBSession = create_middleware_and_session_proxy ()
247-
296+
248297 app = FastAPI (title = "Test App" )
249-
298+
250299 # Test missing both db_url and custom_engine by calling the constructor directly
251300 with self .assertRaises (ValueError ) as context :
252301 # This should raise ValueError when called directly
253- middleware = SQLAlchemyMiddleware (app )
254-
255- self .assertIn ("You need to pass a db_url or a custom_engine parameter" , str (context .exception ))
302+ SQLAlchemyMiddleware (app )
303+
304+ self .assertIn (
305+ "You need to pass a db_url or a custom_engine parameter" ,
306+ str (context .exception ),
307+ )
256308
257309 def test_db_session_context_manager (self ):
258310 """Test DBSession context manager functionality"""
259311 SQLAlchemyMiddleware , DBSession = create_middleware_and_session_proxy ()
260-
312+
261313 # Test DBSession creation
262314 db_session = DBSession (commit_on_exit = True )
263315 self .assertIsInstance (db_session , DBSession )
264316 self .assertTrue (db_session .commit_on_exit )
265-
317+
266318 # Test multi-sessions mode
267319 multi_db_session = DBSession (multi_sessions = True , commit_on_exit = False )
268320 self .assertIsInstance (multi_db_session , DBSession )
@@ -272,26 +324,30 @@ def test_db_session_context_manager(self):
272324 def test_session_proxy_creation (self ):
273325 """Test that the session proxy is created correctly"""
274326 SQLAlchemyMiddleware , DBSession = create_middleware_and_session_proxy ()
275-
327+
276328 # Verify that we get the expected classes
277329 if FASTAPI_AVAILABLE :
278- self .assertTrue (issubclass (SQLAlchemyMiddleware , BaseHTTPMiddleware ))
279-
330+ self .assertTrue (
331+ issubclass (SQLAlchemyMiddleware , BaseHTTPMiddleware )
332+ )
333+
280334 # Test that DBSession has the expected metaclass behavior
281- self .assertTrue (hasattr (DBSession , ' __aenter__' ))
282- self .assertTrue (hasattr (DBSession , ' __aexit__' ))
283-
335+ self .assertTrue (hasattr (DBSession , " __aenter__" ))
336+ self .assertTrue (hasattr (DBSession , " __aexit__" ))
337+
284338 # Test that the session property exists on the class (without accessing it)
285339 # We check the metaclass has the session property
286- self .assertTrue (hasattr (type (DBSession ), 'session' ))
287- self .assertTrue (isinstance (getattr (type (DBSession ), 'session' ), property ))
340+ self .assertTrue (hasattr (type (DBSession ), "session" ))
341+ self .assertTrue (
342+ isinstance (getattr (type (DBSession ), "session" ), property )
343+ )
288344
289345 def test_middleware_with_session_args (self ):
290346 """Test middleware creation with session arguments"""
291347 SQLAlchemyMiddleware , DBSession = create_middleware_and_session_proxy ()
292-
348+
293349 app = FastAPI (title = "Test App with Session Args" )
294-
350+
295351 try :
296352 app .add_middleware (
297353 SQLAlchemyMiddleware ,
@@ -306,17 +362,22 @@ def test_middleware_with_session_args(self):
306362 },
307363 commit_on_exit = True ,
308364 )
309- self .assertTrue (True , "Custom FastAPI middleware with session args created successfully" )
365+ self .assertTrue (
366+ True ,
367+ "Custom FastAPI middleware with session args created successfully" ,
368+ )
310369 except Exception as e :
311- self .fail (f"Failed to create custom FastAPI middleware with session args: { e } " )
370+ self .fail (
371+ f"Failed to create custom FastAPI middleware with session args: { e } "
372+ )
312373
313374 @unittest .skipUnless (FASTAPI_AVAILABLE , "FastAPI not available" )
314375 def test_full_integration_mock (self ):
315376 """Test full integration with mocked database operations"""
316377 SQLAlchemyMiddleware , DBSession = create_middleware_and_session_proxy ()
317-
378+
318379 app = FastAPI (title = "Full Integration Test" )
319-
380+
320381 # Add the middleware
321382 app .add_middleware (
322383 SQLAlchemyMiddleware ,
@@ -328,27 +389,33 @@ def test_full_integration_mock(self):
328389 "max_overflow" : settings .DB_POOL_SIZE_MAX ,
329390 "pool_recycle" : 900 ,
330391 "connect_args" : {
331- "server_settings" : {"jit" : "off" , "application_name" : settings .SERVICE_NAME },
392+ "server_settings" : {
393+ "jit" : "off" ,
394+ "application_name" : settings .SERVICE_NAME ,
395+ },
332396 "ssl" : SSLMode .disable ,
333397 },
334398 },
335399 commit_on_exit = True ,
336400 )
337-
401+
338402 # Add a test route
339403 @app .get ("/test" )
340404 async def test_route ():
341405 return {"message" : "Test successful" }
342-
406+
343407 # Test that the app and middleware were set up successfully
344408 # We don't need to actually make HTTP requests, just verify setup
345409 self .assertIsNotNone (app )
346- self .assertTrue (True , "Full integration test completed - middleware setup successful" )
410+ self .assertTrue (
411+ True ,
412+ "Full integration test completed - middleware setup successful" ,
413+ )
347414
348415
349416if __name__ == "__main__" :
350417 if FASTAPI_AVAILABLE :
351418 unittest .main ()
352419 else :
353420 print ("FastAPI not available. Skipping tests." )
354- print ("To run these tests, install FastAPI: pip install fastapi" )
421+ print ("To run these tests, install FastAPI: pip install fastapi" )
0 commit comments