Skip to content

Commit 3601644

Browse files
Rename AioCursorBase to WithAsyncFetch and consolidate into common.py
Move the shared SQL cursor mixin from aio/base.py into aio/common.py and rename from AioCursorBase to WithAsyncFetch to follow the existing WithXXX naming convention (WithResultSet, WithCalculationExecution) where XXX describes the functionality provided. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f5a89d6 commit 3601644

File tree

7 files changed

+166
-171
lines changed

7 files changed

+166
-171
lines changed

pyathena/aio/arrow/cursor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast
77

8-
from pyathena.aio.base import AioCursorBase
8+
from pyathena.aio.common import WithAsyncFetch
99
from pyathena.arrow.converter import (
1010
DefaultArrowTypeConverter,
1111
DefaultArrowUnloadTypeConverter,
@@ -22,7 +22,7 @@
2222
_logger = logging.getLogger(__name__) # type: ignore
2323

2424

25-
class AioArrowCursor(AioCursorBase):
25+
class AioArrowCursor(WithAsyncFetch):
2626
"""Native asyncio cursor that returns results as Apache Arrow Tables.
2727
2828
Uses ``asyncio.to_thread()`` to create the result set off the event loop.

pyathena/aio/base.py

Lines changed: 0 additions & 158 deletions
This file was deleted.

pyathena/aio/common.py

Lines changed: 156 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
import logging
66
import sys
77
from datetime import datetime, timedelta, timezone
8-
from typing import Any, Dict, List, Optional, Tuple, Union
8+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
99

1010
from pyathena.aio.util import async_retry_api_call
11-
from pyathena.common import BaseCursor
12-
from pyathena.error import DatabaseError, OperationalError
11+
from pyathena.common import BaseCursor, CursorIterator
12+
from pyathena.error import DatabaseError, OperationalError, ProgrammingError
1313
from pyathena.model import AthenaDatabase, AthenaQueryExecution, AthenaTableMetadata
14+
from pyathena.result_set import AthenaResultSet, WithResultSet
1415

1516
_logger = logging.getLogger(__name__) # type: ignore
1617

@@ -346,3 +347,155 @@ async def list_table_metadata( # type: ignore[override]
346347
if not next_token:
347348
break
348349
return metadata
350+
351+
352+
class WithAsyncFetch(AioBaseCursor, CursorIterator, WithResultSet):
353+
"""Mixin providing shared fetch, lifecycle, and async protocol for SQL cursors.
354+
355+
Provides properties (``arraysize``, ``result_set``, ``query_id``,
356+
``rownumber``, ``rowcount``), lifecycle methods (``close``, ``executemany``,
357+
``cancel``), default sync fetch (for cursors whose result sets load all
358+
data eagerly in ``__init__``), and the async iteration protocol.
359+
360+
Subclasses override ``execute()`` and optionally ``__init__`` and
361+
format-specific helpers.
362+
"""
363+
364+
def __init__(self, **kwargs) -> None:
365+
super().__init__(**kwargs)
366+
self._query_id: Optional[str] = None
367+
self._result_set: Optional[AthenaResultSet] = None
368+
self._on_start_query_execution = kwargs.get("on_start_query_execution")
369+
370+
@property
371+
def arraysize(self) -> int:
372+
return self._arraysize
373+
374+
@arraysize.setter
375+
def arraysize(self, value: int) -> None:
376+
if value <= 0:
377+
raise ProgrammingError("arraysize must be a positive integer value.")
378+
self._arraysize = value
379+
380+
@property # type: ignore
381+
def result_set(self) -> Optional[AthenaResultSet]:
382+
return self._result_set
383+
384+
@result_set.setter
385+
def result_set(self, val) -> None:
386+
self._result_set = val
387+
388+
@property
389+
def query_id(self) -> Optional[str]:
390+
return self._query_id
391+
392+
@query_id.setter
393+
def query_id(self, val) -> None:
394+
self._query_id = val
395+
396+
@property
397+
def rownumber(self) -> Optional[int]:
398+
return self.result_set.rownumber if self.result_set else None
399+
400+
@property
401+
def rowcount(self) -> int:
402+
return self.result_set.rowcount if self.result_set else -1
403+
404+
def close(self) -> None:
405+
"""Close the cursor and release associated resources."""
406+
if self.result_set and not self.result_set.is_closed:
407+
self.result_set.close()
408+
409+
async def executemany( # type: ignore[override]
410+
self,
411+
operation: str,
412+
seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
413+
**kwargs,
414+
) -> None:
415+
"""Execute a SQL query multiple times with different parameters.
416+
417+
Args:
418+
operation: SQL query string to execute.
419+
seq_of_parameters: Sequence of parameter sets, one per execution.
420+
**kwargs: Additional keyword arguments passed to each ``execute()``.
421+
"""
422+
for parameters in seq_of_parameters:
423+
await self.execute(operation, parameters, **kwargs)
424+
# Operations that have result sets are not allowed with executemany.
425+
self._reset_state()
426+
427+
async def cancel(self) -> None:
428+
"""Cancel the currently executing query.
429+
430+
Raises:
431+
ProgrammingError: If no query is currently executing.
432+
"""
433+
if not self.query_id:
434+
raise ProgrammingError("QueryExecutionId is none or empty.")
435+
await self._cancel(self.query_id)
436+
437+
def fetchone(
438+
self,
439+
) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
440+
"""Fetch the next row of the result set.
441+
442+
Returns:
443+
A tuple representing the next row, or None if no more rows.
444+
445+
Raises:
446+
ProgrammingError: If no result set is available.
447+
"""
448+
if not self.has_result_set:
449+
raise ProgrammingError("No result set.")
450+
result_set = cast(AthenaResultSet, self.result_set)
451+
return result_set.fetchone()
452+
453+
def fetchmany(
454+
self, size: Optional[int] = None
455+
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
456+
"""Fetch multiple rows from the result set.
457+
458+
Args:
459+
size: Maximum number of rows to fetch. Defaults to arraysize.
460+
461+
Returns:
462+
List of tuples representing the fetched rows.
463+
464+
Raises:
465+
ProgrammingError: If no result set is available.
466+
"""
467+
if not self.has_result_set:
468+
raise ProgrammingError("No result set.")
469+
result_set = cast(AthenaResultSet, self.result_set)
470+
return result_set.fetchmany(size)
471+
472+
def fetchall(
473+
self,
474+
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
475+
"""Fetch all remaining rows from the result set.
476+
477+
Returns:
478+
List of tuples representing all remaining rows.
479+
480+
Raises:
481+
ProgrammingError: If no result set is available.
482+
"""
483+
if not self.has_result_set:
484+
raise ProgrammingError("No result set.")
485+
result_set = cast(AthenaResultSet, self.result_set)
486+
return result_set.fetchall()
487+
488+
def __aiter__(self):
489+
return self
490+
491+
async def __anext__(self):
492+
row = self.fetchone()
493+
if row is None:
494+
raise StopAsyncIteration
495+
return row
496+
497+
async def __aenter__(self):
498+
return self
499+
500+
async def __aexit__(self, exc_type, exc_val, exc_tb):
501+
self.close()

pyathena/aio/cursor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
from typing import Any, Callable, Dict, List, Optional, Union, cast
66

7-
from pyathena.aio.base import AioCursorBase
7+
from pyathena.aio.common import WithAsyncFetch
88
from pyathena.aio.result_set import AthenaAioDictResultSet, AthenaAioResultSet
99
from pyathena.common import CursorIterator
1010
from pyathena.error import OperationalError, ProgrammingError
@@ -13,7 +13,7 @@
1313
_logger = logging.getLogger(__name__) # type: ignore
1414

1515

16-
class AioCursor(AioCursorBase):
16+
class AioCursor(WithAsyncFetch):
1717
"""Native asyncio cursor for Amazon Athena.
1818
1919
Unlike ``AsyncCursor`` (which uses ``ThreadPoolExecutor``), this cursor

pyathena/aio/pandas/cursor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
cast,
1717
)
1818

19-
from pyathena.aio.base import AioCursorBase
19+
from pyathena.aio.common import WithAsyncFetch
2020
from pyathena.common import CursorIterator
2121
from pyathena.error import OperationalError, ProgrammingError
2222
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
@@ -32,7 +32,7 @@
3232
_logger = logging.getLogger(__name__) # type: ignore
3333

3434

35-
class AioPandasCursor(AioCursorBase):
35+
class AioPandasCursor(WithAsyncFetch):
3636
"""Native asyncio cursor that returns results as pandas DataFrames.
3737
3838
Uses ``asyncio.to_thread()`` to create the result set off the event loop.

0 commit comments

Comments
 (0)