Skip to content

Commit 8fdba32

Browse files
author
Eugene Shershen
committed
add CompatibleNullPool to handle pool sizing arguments and associated tests
1 parent 17fad7a commit 8fdba32

2 files changed

Lines changed: 29 additions & 13 deletions

File tree

psqlpy_sqlalchemy/dialect.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,19 @@
1717
class CompatibleNullPool(NullPool):
1818
"""
1919
A NullPool wrapper that accepts but ignores pool sizing arguments.
20-
20+
2121
This class is used to maintain compatibility with middleware that passes
2222
pool_size and max_overflow arguments, which are not valid for NullPool
2323
but are commonly passed by frameworks like FastAPI with fastapi_async_sqlalchemy.
2424
"""
25-
25+
2626
def __init__(self, creator, pool_size=None, max_overflow=None, **kw):
2727
# Filter out pool sizing arguments that NullPool doesn't accept
28-
filtered_kw = {k: v for k, v in kw.items()
29-
if k not in ('pool_size', 'max_overflow')}
28+
filtered_kw = {
29+
k: v
30+
for k, v in kw.items()
31+
if k not in ("pool_size", "max_overflow")
32+
}
3033
super().__init__(creator, **filtered_kw)
3134

3235

@@ -278,7 +281,6 @@ def _isolation_lookup(self) -> Dict[str, Any]:
278281
"SERIALIZABLE": psqlpy.IsolationLevel.Serializable,
279282
}
280283

281-
282284
def create_connect_args(
283285
self,
284286
url: URL,
@@ -324,4 +326,4 @@ def get_deferrable(self, connection):
324326
PsqlpyDialect = PSQLPyAsyncDialect
325327

326328
# Export the compatible pool class for users who need it
327-
__all__ = ['PSQLPyAsyncDialect', 'PsqlpyDialect', 'CompatibleNullPool']
329+
__all__ = ["PSQLPyAsyncDialect", "PsqlpyDialect", "CompatibleNullPool"]

tests/test_poolclass_compatibility.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
from sqlalchemy.ext.asyncio import create_async_engine
1313
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool, QueuePool
1414

15-
from psqlpy_sqlalchemy.dialect import PSQLPyAsyncDialect, PsqlpyDialect, CompatibleNullPool
15+
from psqlpy_sqlalchemy.dialect import (
16+
CompatibleNullPool,
17+
PSQLPyAsyncDialect,
18+
PsqlpyDialect,
19+
)
1620

1721

1822
class TestPoolclassCompatibility(unittest.TestCase):
@@ -223,7 +227,9 @@ def test_compatible_nullpool_with_pool_args_sync(self):
223227
except Exception as e:
224228
# Connection errors are acceptable, we're testing pool creation
225229
if "Invalid argument(s)" in str(e):
226-
self.fail(f"CompatibleNullPool should accept pool arguments: {e}")
230+
self.fail(
231+
f"CompatibleNullPool should accept pool arguments: {e}"
232+
)
227233

228234
def test_compatible_nullpool_with_pool_args_async(self):
229235
"""Test that CompatibleNullPool works with pool arguments in async engines"""
@@ -238,13 +244,16 @@ async def _test_compatible_nullpool():
238244
)
239245
self.assertIsNotNone(self.async_engine)
240246
self.assertEqual(
241-
self.async_engine.sync_engine.pool.__class__, CompatibleNullPool
247+
self.async_engine.sync_engine.pool.__class__,
248+
CompatibleNullPool,
242249
)
243250
return True
244251
except Exception as e:
245252
# Connection errors are acceptable, we're testing pool creation
246253
if "Invalid argument(s)" in str(e):
247-
self.fail(f"CompatibleNullPool should accept pool arguments: {e}")
254+
self.fail(
255+
f"CompatibleNullPool should accept pool arguments: {e}"
256+
)
248257
return True
249258

250259
loop = asyncio.new_event_loop()
@@ -270,7 +279,9 @@ def test_compatible_nullpool_ignores_pool_args(self):
270279
except Exception as e:
271280
# Connection errors are acceptable
272281
if "Invalid argument(s)" in str(e):
273-
self.fail(f"CompatibleNullPool should ignore pool arguments: {e}")
282+
self.fail(
283+
f"CompatibleNullPool should ignore pool arguments: {e}"
284+
)
274285

275286
def test_regular_nullpool_still_fails_with_pool_args(self):
276287
"""Test that regular NullPool still fails with pool arguments (regression test)"""
@@ -281,8 +292,11 @@ def test_regular_nullpool_still_fails_with_pool_args(self):
281292
pool_size=5,
282293
max_overflow=10,
283294
)
284-
285-
self.assertIn("Invalid argument(s) 'pool_size','max_overflow'", str(context.exception))
295+
296+
self.assertIn(
297+
"Invalid argument(s) 'pool_size','max_overflow'",
298+
str(context.exception),
299+
)
286300

287301

288302
class TestFastAPIMiddlewareCompatibility(unittest.TestCase):

0 commit comments

Comments
 (0)