Skip to content

Commit 17fad7a

Browse files
author
Eugene Shershen
committed
add CompatibleNullPool to handle pool sizing arguments and associated tests
1 parent 90c9368 commit 17fad7a

4 files changed

Lines changed: 100 additions & 4 deletions

File tree

psqlpy_sqlalchemy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22

33
PsqlpyDialect = PSQLPyAsyncDialect
44

5-
__version__ = "0.1.0a4"
5+
__version__ = "0.1.0a5"
66
__all__ = ["PsqlpyDialect", "PSQLPyAsyncDialect"]

psqlpy_sqlalchemy/dialect.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,30 @@
66
from sqlalchemy import URL, util
77
from sqlalchemy.dialects.postgresql.base import INTERVAL, PGDialect
88
from sqlalchemy.dialects.postgresql.json import JSONPathType
9-
from sqlalchemy.pool import AsyncAdaptedQueuePool
9+
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
1010
from sqlalchemy.sql import operators, sqltypes
1111
from sqlalchemy.sql.functions import GenericFunction
1212

1313
from .connection import AsyncAdapt_psqlpy_connection, PGExecutionContext_psqlpy
1414
from .dbapi import PSQLPyAdaptDBAPI
1515

1616

17+
class CompatibleNullPool(NullPool):
18+
"""
19+
A NullPool wrapper that accepts but ignores pool sizing arguments.
20+
21+
This class is used to maintain compatibility with middleware that passes
22+
pool_size and max_overflow arguments, which are not valid for NullPool
23+
but are commonly passed by frameworks like FastAPI with fastapi_async_sqlalchemy.
24+
"""
25+
26+
def __init__(self, creator, pool_size=None, max_overflow=None, **kw):
27+
# Filter out pool sizing arguments that NullPool doesn't accept
28+
filtered_kw = {k: v for k, v in kw.items()
29+
if k not in ('pool_size', 'max_overflow')}
30+
super().__init__(creator, **filtered_kw)
31+
32+
1733
# JSONB aggregation functions
1834
class jsonb_agg(GenericFunction):
1935
"""JSONB aggregation function"""
@@ -262,6 +278,7 @@ def _isolation_lookup(self) -> Dict[str, Any]:
262278
"SERIALIZABLE": psqlpy.IsolationLevel.Serializable,
263279
}
264280

281+
265282
def create_connect_args(
266283
self,
267284
url: URL,
@@ -305,3 +322,6 @@ def get_deferrable(self, connection):
305322

306323
# Backward compatibility alias for entry point system
307324
PsqlpyDialect = PSQLPyAsyncDialect
325+
326+
# Export the compatible pool class for users who need it
327+
__all__ = ['PSQLPyAsyncDialect', 'PsqlpyDialect', 'CompatibleNullPool']

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "psqlpy-sqlalchemy"
7-
version = "0.1.0a4"
7+
version = "0.1.0a5"
88
description = "SQLAlchemy dialect for psqlpy PostgreSQL driver"
99
readme = "README.md"
1010
license = {text = "MIT"}

tests/test_poolclass_compatibility.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sqlalchemy.ext.asyncio import create_async_engine
1313
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool, QueuePool
1414

15-
from psqlpy_sqlalchemy.dialect import PSQLPyAsyncDialect, PsqlpyDialect
15+
from psqlpy_sqlalchemy.dialect import PSQLPyAsyncDialect, PsqlpyDialect, CompatibleNullPool
1616

1717

1818
class TestPoolclassCompatibility(unittest.TestCase):
@@ -208,6 +208,82 @@ async def _test_explicit_nullpool():
208208
finally:
209209
loop.close()
210210

211+
def test_compatible_nullpool_with_pool_args_sync(self):
212+
"""Test that CompatibleNullPool works with pool arguments in sync engines"""
213+
try:
214+
self.engine = create_engine(
215+
"postgresql+psqlpy://user:password@localhost/test",
216+
poolclass=CompatibleNullPool,
217+
pool_size=5,
218+
max_overflow=10,
219+
)
220+
self.assertIsNotNone(self.engine.dialect)
221+
self.assertEqual(self.engine.dialect.driver, "psqlpy")
222+
self.assertEqual(self.engine.pool.__class__, CompatibleNullPool)
223+
except Exception as e:
224+
# Connection errors are acceptable, we're testing pool creation
225+
if "Invalid argument(s)" in str(e):
226+
self.fail(f"CompatibleNullPool should accept pool arguments: {e}")
227+
228+
def test_compatible_nullpool_with_pool_args_async(self):
229+
"""Test that CompatibleNullPool works with pool arguments in async engines"""
230+
231+
async def _test_compatible_nullpool():
232+
try:
233+
self.async_engine = create_async_engine(
234+
"postgresql+psqlpy://user:password@localhost/test",
235+
poolclass=CompatibleNullPool,
236+
pool_size=5,
237+
max_overflow=10,
238+
)
239+
self.assertIsNotNone(self.async_engine)
240+
self.assertEqual(
241+
self.async_engine.sync_engine.pool.__class__, CompatibleNullPool
242+
)
243+
return True
244+
except Exception as e:
245+
# Connection errors are acceptable, we're testing pool creation
246+
if "Invalid argument(s)" in str(e):
247+
self.fail(f"CompatibleNullPool should accept pool arguments: {e}")
248+
return True
249+
250+
loop = asyncio.new_event_loop()
251+
asyncio.set_event_loop(loop)
252+
try:
253+
result = loop.run_until_complete(_test_compatible_nullpool())
254+
self.assertTrue(result)
255+
finally:
256+
loop.close()
257+
258+
def test_compatible_nullpool_ignores_pool_args(self):
259+
"""Test that CompatibleNullPool ignores pool sizing arguments"""
260+
try:
261+
self.engine = create_engine(
262+
"postgresql+psqlpy://user:password@localhost/test",
263+
poolclass=CompatibleNullPool,
264+
pool_size=100, # Should be ignored
265+
max_overflow=200, # Should be ignored
266+
)
267+
# If we get here, the arguments were successfully ignored
268+
self.assertIsNotNone(self.engine.dialect)
269+
self.assertEqual(self.engine.pool.__class__, CompatibleNullPool)
270+
except Exception as e:
271+
# Connection errors are acceptable
272+
if "Invalid argument(s)" in str(e):
273+
self.fail(f"CompatibleNullPool should ignore pool arguments: {e}")
274+
275+
def test_regular_nullpool_still_fails_with_pool_args(self):
276+
"""Test that regular NullPool still fails with pool arguments (regression test)"""
277+
with self.assertRaises(TypeError) as context:
278+
self.engine = create_engine(
279+
"postgresql+psqlpy://user:password@localhost/test",
280+
poolclass=NullPool,
281+
pool_size=5,
282+
max_overflow=10,
283+
)
284+
285+
self.assertIn("Invalid argument(s) 'pool_size','max_overflow'", str(context.exception))
286+
211287

212288
class TestFastAPIMiddlewareCompatibility(unittest.TestCase):
213289
"""Test cases for FastAPI middleware compatibility"""

0 commit comments

Comments
 (0)