Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 7417461

Browse files
authored
Merge pull request #124 from gvbgduh/aiopg
Aiopg
2 parents 647aea3 + 4198ae2 commit 7417461

File tree

8 files changed

+328
-24
lines changed

8 files changed

+328
-24
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ python:
99
- "3.7"
1010

1111
env:
12-
- TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db"
12+
- TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db, postgresql+aiopg://localhost/test_database"
1313

1414
services:
1515
- postgresql

databases/backends/aiopg.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
import getpass
2+
import json
3+
import logging
4+
import typing
5+
import uuid
6+
7+
import aiopg
8+
from aiopg.sa.engine import APGCompiler_psycopg2
9+
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
10+
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
11+
from sqlalchemy.engine.result import ResultMetaData, RowProxy
12+
from sqlalchemy.sql import ClauseElement
13+
from sqlalchemy.types import TypeEngine
14+
15+
from databases.core import DatabaseURL
16+
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
17+
18+
logger = logging.getLogger("databases")
19+
20+
21+
class AiopgBackend(DatabaseBackend):
22+
def __init__(
23+
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
24+
) -> None:
25+
self._database_url = DatabaseURL(database_url)
26+
self._options = options
27+
self._dialect = self._get_dialect()
28+
self._pool = None
29+
30+
def _get_dialect(self) -> Dialect:
31+
dialect = PGDialect_psycopg2(
32+
json_serializer=json.dumps, json_deserializer=lambda x: x
33+
)
34+
dialect.statement_compiler = APGCompiler_psycopg2
35+
dialect.implicit_returning = True
36+
dialect.supports_native_enum = True
37+
dialect.supports_smallserial = True # 9.2+
38+
dialect._backslash_escapes = False
39+
dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+
40+
dialect._has_native_hstore = True
41+
dialect.supports_native_decimal = True
42+
43+
return dialect
44+
45+
def _get_connection_kwargs(self) -> dict:
46+
url_options = self._database_url.options
47+
48+
kwargs = {}
49+
min_size = url_options.get("min_size")
50+
max_size = url_options.get("max_size")
51+
ssl = url_options.get("ssl")
52+
53+
if min_size is not None:
54+
kwargs["minsize"] = int(min_size)
55+
if max_size is not None:
56+
kwargs["maxsize"] = int(max_size)
57+
if ssl is not None:
58+
kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()]
59+
60+
for key, value in self._options.items():
61+
# Coerce 'min_size' and 'max_size' for consistency.
62+
if key == "min_size":
63+
key = "minsize"
64+
elif key == "max_size":
65+
key = "maxsize"
66+
kwargs[key] = value
67+
68+
return kwargs
69+
70+
async def connect(self) -> None:
71+
assert self._pool is None, "DatabaseBackend is already running"
72+
kwargs = self._get_connection_kwargs()
73+
self._pool = await aiopg.create_pool(
74+
host=self._database_url.hostname,
75+
port=self._database_url.port,
76+
user=self._database_url.username or getpass.getuser(),
77+
password=self._database_url.password,
78+
database=self._database_url.database,
79+
**kwargs,
80+
)
81+
82+
async def disconnect(self) -> None:
83+
assert self._pool is not None, "DatabaseBackend is not running"
84+
self._pool.close()
85+
await self._pool.wait_closed()
86+
self._pool = None
87+
88+
def connection(self) -> "AiopgConnection":
89+
return AiopgConnection(self, self._dialect)
90+
91+
92+
class CompilationContext:
93+
def __init__(self, context: ExecutionContext):
94+
self.context = context
95+
96+
97+
class AiopgConnection(ConnectionBackend):
98+
def __init__(self, database: AiopgBackend, dialect: Dialect):
99+
self._database = database
100+
self._dialect = dialect
101+
self._connection = None # type: typing.Optional[aiopg.Connection]
102+
103+
async def acquire(self) -> None:
104+
assert self._connection is None, "Connection is already acquired"
105+
assert self._database._pool is not None, "DatabaseBackend is not running"
106+
self._connection = await self._database._pool.acquire()
107+
108+
async def release(self) -> None:
109+
assert self._connection is not None, "Connection is not acquired"
110+
assert self._database._pool is not None, "DatabaseBackend is not running"
111+
await self._database._pool.release(self._connection)
112+
self._connection = None
113+
114+
async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
115+
assert self._connection is not None, "Connection is not acquired"
116+
query, args, context = self._compile(query)
117+
cursor = await self._connection.cursor()
118+
try:
119+
await cursor.execute(query, args)
120+
rows = await cursor.fetchall()
121+
metadata = ResultMetaData(context, cursor.description)
122+
return [
123+
RowProxy(metadata, row, metadata._processors, metadata._keymap)
124+
for row in rows
125+
]
126+
finally:
127+
cursor.close()
128+
129+
async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
130+
assert self._connection is not None, "Connection is not acquired"
131+
query, args, context = self._compile(query)
132+
cursor = await self._connection.cursor()
133+
try:
134+
await cursor.execute(query, args)
135+
row = await cursor.fetchone()
136+
if row is None:
137+
return None
138+
metadata = ResultMetaData(context, cursor.description)
139+
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
140+
finally:
141+
cursor.close()
142+
143+
async def execute(self, query: ClauseElement) -> typing.Any:
144+
assert self._connection is not None, "Connection is not acquired"
145+
query, args, context = self._compile(query)
146+
cursor = await self._connection.cursor()
147+
try:
148+
await cursor.execute(query, args)
149+
return cursor.lastrowid
150+
finally:
151+
cursor.close()
152+
153+
async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
154+
assert self._connection is not None, "Connection is not acquired"
155+
cursor = await self._connection.cursor()
156+
try:
157+
for single_query in queries:
158+
single_query, args, context = self._compile(single_query)
159+
await cursor.execute(single_query, args)
160+
finally:
161+
cursor.close()
162+
163+
async def iterate(
164+
self, query: ClauseElement
165+
) -> typing.AsyncGenerator[typing.Any, None]:
166+
assert self._connection is not None, "Connection is not acquired"
167+
query, args, context = self._compile(query)
168+
cursor = await self._connection.cursor()
169+
try:
170+
await cursor.execute(query, args)
171+
metadata = ResultMetaData(context, cursor.description)
172+
async for row in cursor:
173+
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
174+
finally:
175+
cursor.close()
176+
177+
def transaction(self) -> TransactionBackend:
178+
return AiopgTransaction(self)
179+
180+
def _compile(
181+
self, query: ClauseElement
182+
) -> typing.Tuple[str, dict, CompilationContext]:
183+
compiled = query.compile(dialect=self._dialect)
184+
args = compiled.construct_params()
185+
for key, val in args.items():
186+
if key in compiled._bind_processors:
187+
args[key] = compiled._bind_processors[key](val)
188+
189+
execution_context = self._dialect.execution_ctx_cls()
190+
execution_context.dialect = self._dialect
191+
execution_context.result_column_struct = (
192+
compiled._result_columns,
193+
compiled._ordered_columns,
194+
compiled._textual_ordered_columns,
195+
)
196+
197+
logger.debug("Query: %s\nArgs: %s", compiled.string, args)
198+
return compiled.string, args, CompilationContext(execution_context)
199+
200+
@property
201+
def raw_connection(self) -> aiopg.connection.Connection:
202+
assert self._connection is not None, "Connection is not acquired"
203+
return self._connection
204+
205+
206+
class AiopgTransaction(TransactionBackend):
207+
def __init__(self, connection: AiopgConnection):
208+
self._connection = connection
209+
self._is_root = False
210+
self._savepoint_name = ""
211+
212+
async def start(self, is_root: bool) -> None:
213+
assert self._connection._connection is not None, "Connection is not acquired"
214+
self._is_root = is_root
215+
cursor = await self._connection._connection.cursor()
216+
if self._is_root:
217+
await cursor.execute("BEGIN")
218+
else:
219+
id = str(uuid.uuid4()).replace("-", "_")
220+
self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}"
221+
try:
222+
await cursor.execute(f"SAVEPOINT {self._savepoint_name}")
223+
finally:
224+
cursor.close()
225+
226+
async def commit(self) -> None:
227+
assert self._connection._connection is not None, "Connection is not acquired"
228+
cursor = await self._connection._connection.cursor()
229+
if self._is_root:
230+
await cursor.execute("COMMIT")
231+
else:
232+
try:
233+
await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint_name}")
234+
finally:
235+
cursor.close()
236+
237+
async def rollback(self) -> None:
238+
assert self._connection._connection is not None, "Connection is not acquired"
239+
cursor = await self._connection._connection.cursor()
240+
if self._is_root:
241+
await cursor.execute("ROLLBACK")
242+
else:
243+
try:
244+
await cursor.execute(f"ROLLBACK TO SAVEPOINT {self._savepoint_name}")
245+
finally:
246+
cursor.close()

databases/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
class Database:
4343
SUPPORTED_BACKENDS = {
4444
"postgresql": "databases.backends.postgres:PostgresBackend",
45+
"postgresql+aiopg": "databases.backends.aiopg:AiopgBackend",
4546
"postgres": "databases.backends.postgres:PostgresBackend",
4647
"mysql": "databases.backends.mysql:MySQLBackend",
4748
"sqlite": "databases.backends.sqlite:SQLiteBackend",
@@ -60,7 +61,7 @@ def __init__(
6061

6162
self._force_rollback = force_rollback
6263

63-
backend_str = self.SUPPORTED_BACKENDS[self.url.dialect]
64+
backend_str = self.SUPPORTED_BACKENDS[self.url.scheme]
6465
backend_cls = import_from_string(backend_str)
6566
assert issubclass(backend_cls, DatabaseBackend)
6667
self._backend = backend_cls(self.url, **self.options)
@@ -367,6 +368,10 @@ def components(self) -> SplitResult:
367368
self._components = urlsplit(self._url)
368369
return self._components
369370

371+
@property
372+
def scheme(self) -> str:
373+
return self.components.scheme
374+
370375
@property
371376
def dialect(self) -> str:
372377
return self.components.scheme.split("+")[0]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ aiocontextvars;python_version<"3.7"
66

77
# Async database drivers
88
aiomysql
9+
aiopg
910
aiosqlite
1011
asyncpg
1112

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def get_packages(package):
5252
extras_require={
5353
"postgresql": ["asyncpg", "psycopg2-binary"],
5454
"mysql": ["aiomysql", "pymysql"],
55-
"sqlite": ["aiosqlite"]
55+
"sqlite": ["aiosqlite"],
56+
"postgresql+aiopg": ["aiopg"]
5657
},
5758
classifiers=[
5859
"Development Status :: 3 - Alpha",

tests/test_connection_options.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Unit tests for the backend connection arguments.
33
"""
44

5+
from databases.backends.aiopg import AiopgBackend
56
from databases.backends.mysql import MySQLBackend
67
from databases.backends.postgres import PostgresBackend
78

@@ -52,3 +53,31 @@ def test_mysql_explicit_ssl():
5253
backend = MySQLBackend("mysql://localhost/database", ssl=True)
5354
kwargs = backend._get_connection_kwargs()
5455
assert kwargs == {"ssl": True}
56+
57+
58+
def test_aiopg_pool_size():
59+
backend = AiopgBackend(
60+
"postgresql+aiopg://localhost/database?min_size=1&max_size=20"
61+
)
62+
kwargs = backend._get_connection_kwargs()
63+
assert kwargs == {"minsize": 1, "maxsize": 20}
64+
65+
66+
def test_aiopg_explicit_pool_size():
67+
backend = AiopgBackend(
68+
"postgresql+aiopg://localhost/database", min_size=1, max_size=20
69+
)
70+
kwargs = backend._get_connection_kwargs()
71+
assert kwargs == {"minsize": 1, "maxsize": 20}
72+
73+
74+
def test_aiopg_ssl():
75+
backend = AiopgBackend("postgresql+aiopg://localhost/database?ssl=true")
76+
kwargs = backend._get_connection_kwargs()
77+
assert kwargs == {"ssl": True}
78+
79+
80+
def test_aiopg_explicit_ssl():
81+
backend = AiopgBackend("postgresql+aiopg://localhost/database", ssl=True)
82+
kwargs = backend._get_connection_kwargs()
83+
assert kwargs == {"ssl": True}

0 commit comments

Comments
 (0)