Skip to content

Commit f5a89d6

Browse files
Add AioCursorBase, AioS3FSCursor, and AioSparkCursor for Phase 3
Phase 3 of native asyncio cursor implementation: Part 1 - Boilerplate deduplication: - Extract shared properties, lifecycle methods, sync fetch, and async protocol into AioCursorBase (aio/base.py) - Refactor AioCursor, AioPandasCursor, AioArrowCursor, AioPolarsCursor to extend AioCursorBase, reducing ~520 lines of duplicated code Part 2 - AioS3FSCursor: - Lightweight async CSV cursor using S3FileSystem - Async fetch methods (via asyncio.to_thread) since S3FS uses lazy streaming from S3 Part 3 - AioSparkCursor: - AioSparkBaseCursor overrides post-init I/O with async equivalents (poll, cancel, terminate_session, read_s3_file) - AioSparkCursor for executing PySpark code asynchronously - Session init stays sync (wrapped in asyncio.to_thread at creation) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5238436 commit f5a89d6

File tree

15 files changed

+943
-520
lines changed

15 files changed

+943
-520
lines changed

pyathena/aio/arrow/cursor.py

Lines changed: 4 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
import asyncio
55
import logging
6-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union, cast
6+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast
77

8-
from pyathena.aio.common import AioBaseCursor
8+
from pyathena.aio.base import AioCursorBase
99
from pyathena.arrow.converter import (
1010
DefaultArrowTypeConverter,
1111
DefaultArrowUnloadTypeConverter,
@@ -14,7 +14,6 @@
1414
from pyathena.common import CursorIterator
1515
from pyathena.error import OperationalError, ProgrammingError
1616
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
17-
from pyathena.result_set import WithResultSet
1817

1918
if TYPE_CHECKING:
2019
import polars as pl
@@ -23,7 +22,7 @@
2322
_logger = logging.getLogger(__name__) # type: ignore
2423

2524

26-
class AioArrowCursor(AioBaseCursor, CursorIterator, WithResultSet):
25+
class AioArrowCursor(AioCursorBase):
2726
"""Native asyncio cursor that returns results as Apache Arrow Tables.
2827
2928
Uses ``asyncio.to_thread()`` to create the result set off the event loop.
@@ -66,13 +65,12 @@ def __init__(
6665
kill_on_interrupt=kill_on_interrupt,
6766
result_reuse_enable=result_reuse_enable,
6867
result_reuse_minutes=result_reuse_minutes,
68+
on_start_query_execution=on_start_query_execution,
6969
**kwargs,
7070
)
7171
self._unload = unload
72-
self._on_start_query_execution = on_start_query_execution
7372
self._connect_timeout = connect_timeout
7473
self._request_timeout = request_timeout
75-
self._query_id: Optional[str] = None
7674
self._result_set: Optional[AthenaArrowResultSet] = None
7775

7876
@staticmethod
@@ -83,45 +81,6 @@ def get_default_converter(
8381
return DefaultArrowUnloadTypeConverter()
8482
return DefaultArrowTypeConverter()
8583

86-
@property
87-
def arraysize(self) -> int:
88-
return self._arraysize
89-
90-
@arraysize.setter
91-
def arraysize(self, value: int) -> None:
92-
if value <= 0:
93-
raise ProgrammingError("arraysize must be a positive integer value.")
94-
self._arraysize = value
95-
96-
@property # type: ignore
97-
def result_set(self) -> Optional[AthenaArrowResultSet]:
98-
return self._result_set
99-
100-
@result_set.setter
101-
def result_set(self, val) -> None:
102-
self._result_set = val
103-
104-
@property
105-
def query_id(self) -> Optional[str]:
106-
return self._query_id
107-
108-
@query_id.setter
109-
def query_id(self, val) -> None:
110-
self._query_id = val
111-
112-
@property
113-
def rownumber(self) -> Optional[int]:
114-
return self.result_set.rownumber if self.result_set else None
115-
116-
@property
117-
def rowcount(self) -> int:
118-
return self.result_set.rowcount if self.result_set else -1
119-
120-
def close(self) -> None:
121-
"""Close the cursor and release associated resources."""
122-
if self.result_set and not self.result_set.is_closed:
123-
self.result_set.close()
124-
12584
async def execute( # type: ignore[override]
12685
self,
12786
operation: str,
@@ -202,84 +161,6 @@ async def execute( # type: ignore[override]
202161
raise OperationalError(query_execution.state_change_reason)
203162
return self
204163

205-
async def executemany( # type: ignore[override]
206-
self,
207-
operation: str,
208-
seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
209-
**kwargs,
210-
) -> None:
211-
"""Execute a SQL query multiple times with different parameters.
212-
213-
Args:
214-
operation: SQL query string to execute.
215-
seq_of_parameters: Sequence of parameter sets, one per execution.
216-
**kwargs: Additional keyword arguments passed to each ``execute()``.
217-
"""
218-
for parameters in seq_of_parameters:
219-
await self.execute(operation, parameters, **kwargs)
220-
self._reset_state()
221-
222-
async def cancel(self) -> None:
223-
"""Cancel the currently executing query.
224-
225-
Raises:
226-
ProgrammingError: If no query is currently executing.
227-
"""
228-
if not self.query_id:
229-
raise ProgrammingError("QueryExecutionId is none or empty.")
230-
await self._cancel(self.query_id)
231-
232-
def fetchone(
233-
self,
234-
) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
235-
"""Fetch the next row of the result set.
236-
237-
Returns:
238-
A tuple representing the next row, or None if no more rows.
239-
240-
Raises:
241-
ProgrammingError: If no result set is available.
242-
"""
243-
if not self.has_result_set:
244-
raise ProgrammingError("No result set.")
245-
result_set = cast(AthenaArrowResultSet, self.result_set)
246-
return result_set.fetchone()
247-
248-
def fetchmany(
249-
self, size: Optional[int] = None
250-
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
251-
"""Fetch multiple rows from the result set.
252-
253-
Args:
254-
size: Maximum number of rows to fetch. Defaults to arraysize.
255-
256-
Returns:
257-
List of tuples representing the fetched rows.
258-
259-
Raises:
260-
ProgrammingError: If no result set is available.
261-
"""
262-
if not self.has_result_set:
263-
raise ProgrammingError("No result set.")
264-
result_set = cast(AthenaArrowResultSet, self.result_set)
265-
return result_set.fetchmany(size)
266-
267-
def fetchall(
268-
self,
269-
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
270-
"""Fetch all remaining rows from the result set.
271-
272-
Returns:
273-
List of tuples representing all remaining rows.
274-
275-
Raises:
276-
ProgrammingError: If no result set is available.
277-
"""
278-
if not self.has_result_set:
279-
raise ProgrammingError("No result set.")
280-
result_set = cast(AthenaArrowResultSet, self.result_set)
281-
return result_set.fetchall()
282-
283164
def as_arrow(self) -> "Table":
284165
"""Return query results as an Apache Arrow Table.
285166
@@ -301,18 +182,3 @@ def as_polars(self) -> "pl.DataFrame":
301182
raise ProgrammingError("No result set.")
302183
result_set = cast(AthenaArrowResultSet, self.result_set)
303184
return result_set.as_polars()
304-
305-
def __aiter__(self):
306-
return self
307-
308-
async def __anext__(self):
309-
row = self.fetchone()
310-
if row is None:
311-
raise StopAsyncIteration
312-
return row
313-
314-
async def __aenter__(self):
315-
return self
316-
317-
async def __aexit__(self, exc_type, exc_val, exc_tb):
318-
self.close()

pyathena/aio/base.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import annotations
3+
4+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
5+
6+
from pyathena.aio.common import AioBaseCursor
7+
from pyathena.common import CursorIterator
8+
from pyathena.error import ProgrammingError
9+
from pyathena.result_set import AthenaResultSet, WithResultSet
10+
11+
12+
class AioCursorBase(AioBaseCursor, CursorIterator, WithResultSet):
13+
"""Base class for native asyncio SQL cursors.
14+
15+
Provides shared properties, lifecycle methods, and default sync fetch
16+
(for cursors whose result sets load all data eagerly in ``__init__``).
17+
Subclasses override ``execute()`` and optionally ``__init__`` and
18+
format-specific helpers.
19+
"""
20+
21+
def __init__(self, **kwargs) -> None:
22+
super().__init__(**kwargs)
23+
self._query_id: Optional[str] = None
24+
self._result_set: Optional[AthenaResultSet] = None
25+
self._on_start_query_execution = kwargs.get("on_start_query_execution")
26+
27+
@property
28+
def arraysize(self) -> int:
29+
return self._arraysize
30+
31+
@arraysize.setter
32+
def arraysize(self, value: int) -> None:
33+
if value <= 0:
34+
raise ProgrammingError("arraysize must be a positive integer value.")
35+
self._arraysize = value
36+
37+
@property # type: ignore
38+
def result_set(self) -> Optional[AthenaResultSet]:
39+
return self._result_set
40+
41+
@result_set.setter
42+
def result_set(self, val) -> None:
43+
self._result_set = val
44+
45+
@property
46+
def query_id(self) -> Optional[str]:
47+
return self._query_id
48+
49+
@query_id.setter
50+
def query_id(self, val) -> None:
51+
self._query_id = val
52+
53+
@property
54+
def rownumber(self) -> Optional[int]:
55+
return self.result_set.rownumber if self.result_set else None
56+
57+
@property
58+
def rowcount(self) -> int:
59+
return self.result_set.rowcount if self.result_set else -1
60+
61+
def close(self) -> None:
62+
"""Close the cursor and release associated resources."""
63+
if self.result_set and not self.result_set.is_closed:
64+
self.result_set.close()
65+
66+
async def executemany( # type: ignore[override]
67+
self,
68+
operation: str,
69+
seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
70+
**kwargs,
71+
) -> None:
72+
"""Execute a SQL query multiple times with different parameters.
73+
74+
Args:
75+
operation: SQL query string to execute.
76+
seq_of_parameters: Sequence of parameter sets, one per execution.
77+
**kwargs: Additional keyword arguments passed to each ``execute()``.
78+
"""
79+
for parameters in seq_of_parameters:
80+
await self.execute(operation, parameters, **kwargs)
81+
# Operations that have result sets are not allowed with executemany.
82+
self._reset_state()
83+
84+
async def cancel(self) -> None:
85+
"""Cancel the currently executing query.
86+
87+
Raises:
88+
ProgrammingError: If no query is currently executing.
89+
"""
90+
if not self.query_id:
91+
raise ProgrammingError("QueryExecutionId is none or empty.")
92+
await self._cancel(self.query_id)
93+
94+
def fetchone(
95+
self,
96+
) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
97+
"""Fetch the next row of the result set.
98+
99+
Returns:
100+
A tuple representing the next row, or None if no more rows.
101+
102+
Raises:
103+
ProgrammingError: If no result set is available.
104+
"""
105+
if not self.has_result_set:
106+
raise ProgrammingError("No result set.")
107+
result_set = cast(AthenaResultSet, self.result_set)
108+
return result_set.fetchone()
109+
110+
def fetchmany(
111+
self, size: Optional[int] = None
112+
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
113+
"""Fetch multiple rows from the result set.
114+
115+
Args:
116+
size: Maximum number of rows to fetch. Defaults to arraysize.
117+
118+
Returns:
119+
List of tuples representing the fetched rows.
120+
121+
Raises:
122+
ProgrammingError: If no result set is available.
123+
"""
124+
if not self.has_result_set:
125+
raise ProgrammingError("No result set.")
126+
result_set = cast(AthenaResultSet, self.result_set)
127+
return result_set.fetchmany(size)
128+
129+
def fetchall(
130+
self,
131+
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
132+
"""Fetch all remaining rows from the result set.
133+
134+
Returns:
135+
List of tuples representing all remaining rows.
136+
137+
Raises:
138+
ProgrammingError: If no result set is available.
139+
"""
140+
if not self.has_result_set:
141+
raise ProgrammingError("No result set.")
142+
result_set = cast(AthenaResultSet, self.result_set)
143+
return result_set.fetchall()
144+
145+
def __aiter__(self):
146+
return self
147+
148+
async def __anext__(self):
149+
row = self.fetchone()
150+
if row is None:
151+
raise StopAsyncIteration
152+
return row
153+
154+
async def __aenter__(self):
155+
return self
156+
157+
async def __aexit__(self, exc_type, exc_val, exc_tb):
158+
self.close()

0 commit comments

Comments
 (0)