Commit d876865

Merge pull request #679 from pyathena-dev/feature/673-async-sqlalchemy-dialect
2 parents 5651427 + 0173d48 commit d876865

18 files changed: +778 -52 lines changed

Makefile

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@ test: chk
 test-sqla:
 	uv run pytest -n 8 --cov pyathena --cov-report html --cov-report term tests/sqlalchemy/
 
+.PHONY: test-sqla-async
+test-sqla-async:
+	uv run pytest -n 8 --cov pyathena --cov-report html --cov-report term tests/sqlalchemy/ --dburi async
+
 .PHONY: tox
 tox:
 	uvx tox@$(TOX_VERSION) -c pyproject.toml run

docs/sqlalchemy.md

Lines changed: 68 additions & 1 deletion
@@ -2,9 +2,14 @@
 
 # SQLAlchemy
 
-Install SQLAlchemy with `pip install "SQLAlchemy>=1.0.0"` or `pip install PyAthena[SQLAlchemy]`.
+Install SQLAlchemy with `pip install "SQLAlchemy>=1.0.0"` or `pip install PyAthena[sqlalchemy]`.
 Supported SQLAlchemy is 1.0.0 or higher.
 
+For async support (`create_async_engine`), install with `pip install PyAthena[aiosqlalchemy]`
+(requires SQLAlchemy 2.0+).
+
+### Sync
+
 ```python
 from sqlalchemy import func, select
 from sqlalchemy.engine import create_engine
@@ -24,6 +29,48 @@ with engine.connect() as connection:
     print(result.scalar())
 ```
 
+### Async
+
+```python
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import create_async_engine
+
+conn_str = "awsathena+aiorest://{aws_access_key_id}:{aws_secret_access_key}@athena.{region_name}.amazonaws.com:443/"\
+           "{schema_name}?s3_staging_dir={s3_staging_dir}"
+engine = create_async_engine(conn_str.format(
+    aws_access_key_id="YOUR_ACCESS_KEY_ID",
+    aws_secret_access_key="YOUR_SECRET_ACCESS_KEY",
+    region_name="us-west-2",
+    schema_name="default",
+    s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/"))
+
+async def main():
+    async with engine.connect() as connection:
+        result = await connection.execute(text("SELECT * FROM many_rows"))
+        print(result.fetchall())
+    await engine.dispose()
+```
+
+SQLAlchemy's reflection API (`Table(..., autoload_with=)`, `inspect()`) is synchronous
+internally, so it cannot be called directly on an async connection. Use `run_sync()` to
+bridge the gap:
+
+```python
+from sqlalchemy.sql.schema import Table, MetaData
+
+async with engine.connect() as connection:
+    # Table reflection
+    table = await connection.run_sync(
+        lambda sync_conn: Table("my_table", MetaData(), autoload_with=sync_conn)
+    )
+
+    # Schema inspection
+    import sqlalchemy
+    schemas = await connection.run_sync(
+        lambda sync_conn: sqlalchemy.inspect(sync_conn).get_schema_names()
+    )
+```
+
 ## Connection string
 
 The connection string has the following format:
@@ -38,8 +85,16 @@ If you do not specify `aws_access_key_id` and `aws_secret_access_key` using inst
 awsathena+rest://:@athena.{region_name}.amazonaws.com:443/{schema_name}?s3_staging_dir={s3_staging_dir}&...
 ```
 
+For async, replace the driver portion (e.g. `+rest` with `+aiorest`):
+
+```text
+awsathena+aiorest://:@athena.{region_name}.amazonaws.com:443/{schema_name}?s3_staging_dir={s3_staging_dir}&...
+```
+
 ## Dialect & driver
 
+### Sync
+
 | Dialect   | Driver | Schema           | Cursor                 |
 |-----------|--------|------------------|------------------------|
 | awsathena |        | awsathena        | DefaultCursor          |
@@ -49,6 +104,18 @@ awsathena+rest://:@athena.{region_name}.amazonaws.com:443/{schema_name}?s3_stagi
 | awsathena | polars | awsathena+polars | {ref}`polars-cursor`   |
 | awsathena | s3fs   | awsathena+s3fs   | {ref}`s3fs-cursor`     |
 
+### Async
+
+Requires `pip install PyAthena[aiosqlalchemy]` (SQLAlchemy 2.0+).
+
+| Dialect   | Driver    | Schema              | Cursor                       |
+|-----------|-----------|---------------------|------------------------------|
+| awsathena | aiorest   | awsathena+aiorest   | DefaultCursor (async)        |
+| awsathena | aiopandas | awsathena+aiopandas | {ref}`pandas-cursor` (async) |
+| awsathena | aioarrow  | awsathena+aioarrow  | {ref}`arrow-cursor` (async)  |
+| awsathena | aiopolars | awsathena+aiopolars | {ref}`polars-cursor` (async) |
+| awsathena | aios3fs   | awsathena+aios3fs   | {ref}`s3fs-cursor` (async)   |
+
 ## Dialect options
 
 ### Table options
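
Reviewer note: the async connection string documented in this diff embeds credentials and an S3 path directly in the URL, so values containing characters such as `/`, `+`, or `:` generally need URL-quoting first. The sketch below builds an `awsathena+aiorest` URL in the documented format using only the standard library. The helper name `build_async_athena_url` and its parameters are hypothetical, not part of PyAthena.

```python
from urllib.parse import quote_plus


def build_async_athena_url(
    region_name: str,
    schema_name: str,
    s3_staging_dir: str,
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    driver: str = "aiorest",
) -> str:
    """Hypothetical helper: format the documented URL with quoted values."""
    return (
        f"awsathena+{driver}://"
        f"{quote_plus(aws_access_key_id)}:{quote_plus(aws_secret_access_key)}"
        f"@athena.{region_name}.amazonaws.com:443/"
        f"{schema_name}?s3_staging_dir={quote_plus(s3_staging_dir)}"
    )


# Empty credentials produce the "://:@" form shown in the docs diff.
url = build_async_athena_url("us-west-2", "default", "s3://YOUR_S3_BUCKET/path/to/")
print(url)
```

The resulting string can be passed to `create_async_engine` as in the documented example. Swapping `driver` for `aiopandas`, `aioarrow`, `aiopolars`, or `aios3fs` selects the other async dialects listed in the table.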
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-

pyathena/aio/sqlalchemy/arrow.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+from typing import TYPE_CHECKING
+
+from pyathena.aio.sqlalchemy.base import AthenaAioDialect
+from pyathena.util import strtobool
+
+if TYPE_CHECKING:
+    from types import ModuleType
+
+
+class AthenaAioArrowDialect(AthenaAioDialect):
+    """Async SQLAlchemy dialect for Amazon Athena with Apache Arrow result format.
+
+    This dialect uses ``AioArrowCursor`` for native asyncio query execution
+    with Apache Arrow Table results.
+
+    Connection URL Format:
+        ``awsathena+aioarrow://{access_key}:{secret_key}@athena.{region}.amazonaws.com/{schema}``
+
+    Query Parameters:
+        In addition to the base dialect parameters:
+        - unload: If "true", use UNLOAD for Parquet output
+
+    Example:
+        >>> from sqlalchemy.ext.asyncio import create_async_engine
+        >>> engine = create_async_engine(
+        ...     "awsathena+aioarrow://:@athena.us-west-2.amazonaws.com/default"
+        ...     "?s3_staging_dir=s3://my-bucket/athena-results/"
+        ...     "&unload=true"
+        ... )
+
+    See Also:
+        :class:`~pyathena.aio.arrow.cursor.AioArrowCursor`: The underlying async cursor.
+        :class:`~pyathena.aio.sqlalchemy.base.AthenaAioDialect`: Base async dialect.
+    """
+
+    driver = "aioarrow"
+    supports_statement_cache = True
+
+    def create_connect_args(self, url):
+        from pyathena.aio.arrow.cursor import AioArrowCursor
+
+        opts = super()._create_connect_args(url)
+        opts.update({"cursor_class": AioArrowCursor})
+        cursor_kwargs = {}
+        if "unload" in opts:
+            cursor_kwargs.update({"unload": bool(strtobool(opts.pop("unload")))})
+        if cursor_kwargs:
+            opts.update({"cursor_kwargs": cursor_kwargs})
+        self._connect_options = opts
+        return [[], opts]
+
+    @classmethod
+    def import_dbapi(cls) -> "ModuleType":
+        return super().import_dbapi()
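
Reviewer note: `create_connect_args` above moves the `unload` query parameter out of the connection options and into `cursor_kwargs` as a boolean. The following is a minimal standalone sketch of that transformation. `parse_bool` is a hypothetical stand-in for `pyathena.util.strtobool`, whose exact accepted values are assumed here to follow the old distutils convention, and `split_cursor_kwargs` is likewise an illustrative name.

```python
def parse_bool(value: str) -> bool:
    """Hypothetical stand-in for strtobool (distutils-style values assumed)."""
    lowered = value.strip().lower()
    if lowered in {"y", "yes", "t", "true", "on", "1"}:
        return True
    if lowered in {"n", "no", "f", "false", "off", "0"}:
        return False
    raise ValueError(f"invalid truth value: {value!r}")


def split_cursor_kwargs(opts: dict) -> dict:
    """Move ``unload`` from the connect options into ``cursor_kwargs``."""
    opts = dict(opts)  # work on a copy so the caller's dict is untouched
    cursor_kwargs = {}
    if "unload" in opts:
        cursor_kwargs["unload"] = parse_bool(opts.pop("unload"))
    if cursor_kwargs:
        opts["cursor_kwargs"] = cursor_kwargs
    return opts


print(split_cursor_kwargs({"s3_staging_dir": "s3://b/", "unload": "true"}))
```

As in the diff, `cursor_kwargs` is only added when `unload` is present, so a URL without the parameter leaves the options unchanged.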

pyathena/aio/sqlalchemy/base.py

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+from collections import deque
+from typing import TYPE_CHECKING, Any, Dict, List, MutableMapping, Optional, Tuple, Union, cast
+
+from sqlalchemy import pool
+from sqlalchemy.engine import AdaptedConnection
+from sqlalchemy.util.concurrency import await_only
+
+import pyathena
+from pyathena.aio.connection import AioConnection
+from pyathena.error import (
+    DatabaseError,
+    DataError,
+    Error,
+    IntegrityError,
+    InterfaceError,
+    InternalError,
+    NotSupportedError,
+    OperationalError,
+    ProgrammingError,
+)
+from pyathena.sqlalchemy.base import AthenaDialect
+
+if TYPE_CHECKING:
+    from types import ModuleType
+
+    from sqlalchemy import URL
+
+
+class AsyncAdapt_pyathena_cursor:  # noqa: N801 - follows SQLAlchemy's internal async adapter naming convention (e.g. AsyncAdapt_asyncpg_dbapi)
+    """Wraps any async PyAthena cursor with a sync DBAPI interface.
+
+    SQLAlchemy's async engine uses greenlet-based ``await_only()`` to call
+    async methods from synchronous code running inside the greenlet context.
+    This adapter wraps an ``AioCursor`` (or variant) so that the dialect can
+    use a normal synchronous DBAPI interface while the underlying I/O is async.
+    """
+
+    server_side = False
+    __slots__ = ("_cursor", "_rows")
+
+    def __init__(self, cursor: Any) -> None:
+        self._cursor = cursor
+        self._rows: deque[Any] = deque()
+
+    @property
+    def description(self) -> Any:
+        return self._cursor.description
+
+    @property
+    def rowcount(self) -> int:
+        return self._cursor.rowcount  # type: ignore[no-any-return]
+
+    def close(self) -> None:
+        self._cursor.close()
+        self._rows.clear()
+
+    def execute(self, operation: str, parameters: Any = None, **kwargs: Any) -> Any:
+        result = await_only(self._cursor.execute(operation, parameters, **kwargs))
+        if self._cursor.description:
+            self._rows = deque(await_only(self._cursor.fetchall()))
+        else:
+            self._rows.clear()
+        return result
+
+    def executemany(
+        self,
+        operation: str,
+        seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
+        **kwargs: Any,
+    ) -> None:
+        for parameters in seq_of_parameters:
+            await_only(self._cursor.execute(operation, parameters, **kwargs))
+        self._rows.clear()
+
+    def fetchone(self) -> Any:
+        if self._rows:
+            return self._rows.popleft()
+        return None
+
+    def fetchmany(self, size: Optional[int] = None) -> Any:
+        if size is None:
+            size = self._cursor.arraysize if hasattr(self._cursor, "arraysize") else 1
+        return [self._rows.popleft() for _ in range(min(size, len(self._rows)))]
+
+    def fetchall(self) -> Any:
+        items = list(self._rows)
+        self._rows.clear()
+        return items
+
+    def setinputsizes(self, sizes: Any) -> None:
+        self._cursor.setinputsizes(sizes)
+
+    # PyAthena-specific methods used by AthenaDialect reflection
+    def list_databases(self, *args: Any, **kwargs: Any) -> Any:
+        return await_only(self._cursor.list_databases(*args, **kwargs))
+
+    def get_table_metadata(self, *args: Any, **kwargs: Any) -> Any:
+        return await_only(self._cursor.get_table_metadata(*args, **kwargs))
+
+    def list_table_metadata(self, *args: Any, **kwargs: Any) -> Any:
+        return await_only(self._cursor.list_table_metadata(*args, **kwargs))
+
+    def __enter__(self) -> "AsyncAdapt_pyathena_cursor":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self.close()
+
+
+class AsyncAdapt_pyathena_connection(AdaptedConnection):  # noqa: N801 - follows SQLAlchemy's internal async adapter naming convention (e.g. AsyncAdapt_asyncpg_dbapi)
+    """Wraps ``AioConnection`` with a sync DBAPI interface.
+
+    This adapted connection delegates ``cursor()`` to the underlying
+    ``AioConnection`` and wraps each returned async cursor with
+    ``AsyncAdapt_pyathena_cursor``.
+    """
+
+    __slots__ = ("dbapi", "_connection")
+
+    def __init__(self, dbapi: "AsyncAdapt_pyathena_dbapi", connection: AioConnection) -> None:
+        self.dbapi = dbapi
+        self._connection = connection
+
+    @property
+    def driver_connection(self) -> AioConnection:
+        return self._connection  # type: ignore[no-any-return]
+
+    @property
+    def catalog_name(self) -> Optional[str]:
+        return self._connection.catalog_name  # type: ignore[no-any-return]
+
+    @property
+    def schema_name(self) -> Optional[str]:
+        return self._connection.schema_name  # type: ignore[no-any-return]
+
+    def cursor(self) -> AsyncAdapt_pyathena_cursor:
+        raw_cursor = self._connection.cursor()
+        return AsyncAdapt_pyathena_cursor(raw_cursor)
+
+    def close(self) -> None:
+        self._connection.close()
+
+    def commit(self) -> None:
+        self._connection.commit()
+
+    def rollback(self) -> None:
+        pass
+
+
+class AsyncAdapt_pyathena_dbapi:  # noqa: N801 - follows SQLAlchemy's internal async adapter naming convention (e.g. AsyncAdapt_asyncpg_dbapi)
+    """Fake DBAPI module for the async SQLAlchemy engine.
+
+    SQLAlchemy expects ``import_dbapi()`` to return a module-like object
+    with ``connect()``, ``paramstyle``, and the standard DBAPI exception
+    hierarchy. This class fulfils that contract while routing connections
+    through ``AioConnection``.
+    """
+
+    paramstyle = "pyformat"
+
+    # DBAPI exception hierarchy
+    Error = Error
+    Warning = pyathena.Warning
+    InterfaceError = InterfaceError
+    DatabaseError = DatabaseError
+    InternalError = InternalError
+    OperationalError = OperationalError
+    ProgrammingError = ProgrammingError
+    IntegrityError = IntegrityError
+    DataError = DataError
+    NotSupportedError = NotSupportedError
+
+    def connect(self, **kwargs: Any) -> AsyncAdapt_pyathena_connection:
+        connection = await_only(AioConnection.create(**kwargs))
+        return AsyncAdapt_pyathena_connection(self, connection)
+
+
+class AthenaAioDialect(AthenaDialect):
+    """Base async SQLAlchemy dialect for Amazon Athena.
+
+    Extends the synchronous ``AthenaDialect`` with async capability
+    by setting ``is_async = True`` and providing an adapted DBAPI module
+    that wraps ``AioConnection`` and async cursors via greenlet-based
+    ``await_only()``.
+
+    Subclasses (e.g. ``AthenaAioRestDialect``, ``AthenaAioPandasDialect``)
+    register concrete ``awsathena+aio*`` drivers.
+
+    See Also:
+        :class:`~pyathena.sqlalchemy.base.AthenaDialect`: Synchronous base dialect.
+        :class:`~pyathena.aio.connection.AioConnection`: Native async connection.
+    """
+
+    is_async = True
+    supports_statement_cache = True
+
+    @classmethod
+    def get_pool_class(cls, url: "URL") -> type:
+        return pool.AsyncAdaptedQueuePool
+
+    @classmethod
+    def import_dbapi(cls) -> "ModuleType":
+        return AsyncAdapt_pyathena_dbapi()  # type: ignore[return-value]
+
+    @classmethod
+    def dbapi(cls) -> "ModuleType":  # type: ignore[override]
+        return AsyncAdapt_pyathena_dbapi()  # type: ignore[return-value]
+
+    def create_connect_args(self, url: "URL") -> Tuple[Tuple[str], MutableMapping[str, Any]]:
+        opts = self._create_connect_args(url)
+        self._connect_options = opts
+        return cast(Tuple[str], ()), opts
+
+    def get_driver_connection(self, connection: Any) -> Any:
+        return connection
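
Reviewer note: the cursor adapter above fetches every row eagerly during `execute()` and then serves `fetchone`, `fetchmany`, and `fetchall` from an in-memory `deque`. The sketch below reproduces that buffering pattern with the standard library only. `FakeAsyncCursor` and `BufferedSyncCursor` are hypothetical names, and `asyncio.run` merely stands in for SQLAlchemy's `await_only()`, which in the real dialect runs inside a greenlet context rather than starting a new event loop.

```python
import asyncio
from collections import deque


class FakeAsyncCursor:
    """Hypothetical async cursor standing in for a PyAthena AioCursor."""

    arraysize = 2

    def __init__(self, rows):
        self._rows = rows
        self.description = None

    async def execute(self, operation):
        self.description = [("n",)]  # pretend the query returned one column
        return self

    async def fetchall(self):
        return list(self._rows)


class BufferedSyncCursor:
    """Sync facade that buffers all rows at execute time, as the adapter does."""

    def __init__(self, cursor, run):
        self._cursor = cursor
        self._run = run  # plays the role of await_only in this sketch
        self._rows = deque()

    def execute(self, operation):
        result = self._run(self._cursor.execute(operation))
        if self._cursor.description:
            self._rows = deque(self._run(self._cursor.fetchall()))
        else:
            self._rows.clear()
        return result

    def fetchone(self):
        return self._rows.popleft() if self._rows else None

    def fetchmany(self, size=None):
        if size is None:
            size = getattr(self._cursor, "arraysize", 1)
        return [self._rows.popleft() for _ in range(min(size, len(self._rows)))]

    def fetchall(self):
        items = list(self._rows)
        self._rows.clear()
        return items


cursor = BufferedSyncCursor(FakeAsyncCursor([(1,), (2,), (3,), (4,)]), asyncio.run)
cursor.execute("SELECT n FROM t")
first = cursor.fetchone()
batch = cursor.fetchmany()
rest = cursor.fetchall()
print(first, batch, rest)
```

One consequence of this design is that the full result set is held in memory after `execute()`, which matches the adapter's `server_side = False` setting.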
