|
5 | 5 | import logging |
6 | 6 | import sys |
7 | 7 | from datetime import datetime, timedelta, timezone |
8 | | -from typing import Any, Dict, List, Optional, Tuple, Union |
| 8 | +from typing import Any, Dict, List, Optional, Tuple, Union, cast |
9 | 9 |
|
10 | 10 | from pyathena.aio.util import async_retry_api_call |
11 | | -from pyathena.common import BaseCursor |
12 | | -from pyathena.error import DatabaseError, OperationalError |
| 11 | +from pyathena.common import BaseCursor, CursorIterator |
| 12 | +from pyathena.error import DatabaseError, OperationalError, ProgrammingError |
13 | 13 | from pyathena.model import AthenaDatabase, AthenaQueryExecution, AthenaTableMetadata |
| 14 | +from pyathena.result_set import AthenaResultSet, WithResultSet |
14 | 15 |
|
15 | 16 | _logger = logging.getLogger(__name__) # type: ignore |
16 | 17 |
|
@@ -346,3 +347,155 @@ async def list_table_metadata( # type: ignore[override] |
346 | 347 | if not next_token: |
347 | 348 | break |
348 | 349 | return metadata |
| 350 | + |
| 351 | + |
| 352 | +class WithAsyncFetch(AioBaseCursor, CursorIterator, WithResultSet): |
| 353 | + """Mixin providing shared fetch, lifecycle, and async protocol for SQL cursors. |
| 354 | +
|
| 355 | + Provides properties (``arraysize``, ``result_set``, ``query_id``, |
| 356 | + ``rownumber``, ``rowcount``), lifecycle methods (``close``, ``executemany``, |
| 357 | + ``cancel``), default sync fetch (for cursors whose result sets load all |
| 358 | + data eagerly in ``__init__``), and the async iteration protocol. |
| 359 | +
|
| 360 | + Subclasses override ``execute()`` and optionally ``__init__`` and |
| 361 | + format-specific helpers. |
| 362 | + """ |
| 363 | + |
| 364 | + def __init__(self, **kwargs) -> None: |
| 365 | + super().__init__(**kwargs) |
| 366 | + self._query_id: Optional[str] = None |
| 367 | + self._result_set: Optional[AthenaResultSet] = None |
| 368 | + self._on_start_query_execution = kwargs.get("on_start_query_execution") |
| 369 | + |
| 370 | + @property |
| 371 | + def arraysize(self) -> int: |
| 372 | + return self._arraysize |
| 373 | + |
| 374 | + @arraysize.setter |
| 375 | + def arraysize(self, value: int) -> None: |
| 376 | + if value <= 0: |
| 377 | + raise ProgrammingError("arraysize must be a positive integer value.") |
| 378 | + self._arraysize = value |
| 379 | + |
| 380 | + @property # type: ignore |
| 381 | + def result_set(self) -> Optional[AthenaResultSet]: |
| 382 | + return self._result_set |
| 383 | + |
| 384 | + @result_set.setter |
| 385 | + def result_set(self, val) -> None: |
| 386 | + self._result_set = val |
| 387 | + |
| 388 | + @property |
| 389 | + def query_id(self) -> Optional[str]: |
| 390 | + return self._query_id |
| 391 | + |
| 392 | + @query_id.setter |
| 393 | + def query_id(self, val) -> None: |
| 394 | + self._query_id = val |
| 395 | + |
| 396 | + @property |
| 397 | + def rownumber(self) -> Optional[int]: |
| 398 | + return self.result_set.rownumber if self.result_set else None |
| 399 | + |
| 400 | + @property |
| 401 | + def rowcount(self) -> int: |
| 402 | + return self.result_set.rowcount if self.result_set else -1 |
| 403 | + |
| 404 | + def close(self) -> None: |
| 405 | + """Close the cursor and release associated resources.""" |
| 406 | + if self.result_set and not self.result_set.is_closed: |
| 407 | + self.result_set.close() |
| 408 | + |
| 409 | + async def executemany( # type: ignore[override] |
| 410 | + self, |
| 411 | + operation: str, |
| 412 | + seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], |
| 413 | + **kwargs, |
| 414 | + ) -> None: |
| 415 | + """Execute a SQL query multiple times with different parameters. |
| 416 | +
|
| 417 | + Args: |
| 418 | + operation: SQL query string to execute. |
| 419 | + seq_of_parameters: Sequence of parameter sets, one per execution. |
| 420 | + **kwargs: Additional keyword arguments passed to each ``execute()``. |
| 421 | + """ |
| 422 | + for parameters in seq_of_parameters: |
| 423 | + await self.execute(operation, parameters, **kwargs) |
| 424 | + # Operations that have result sets are not allowed with executemany. |
| 425 | + self._reset_state() |
| 426 | + |
| 427 | + async def cancel(self) -> None: |
| 428 | + """Cancel the currently executing query. |
| 429 | +
|
| 430 | + Raises: |
| 431 | + ProgrammingError: If no query is currently executing. |
| 432 | + """ |
| 433 | + if not self.query_id: |
| 434 | + raise ProgrammingError("QueryExecutionId is none or empty.") |
| 435 | + await self._cancel(self.query_id) |
| 436 | + |
| 437 | + def fetchone( |
| 438 | + self, |
| 439 | + ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: |
| 440 | + """Fetch the next row of the result set. |
| 441 | +
|
| 442 | + Returns: |
| 443 | + A tuple representing the next row, or None if no more rows. |
| 444 | +
|
| 445 | + Raises: |
| 446 | + ProgrammingError: If no result set is available. |
| 447 | + """ |
| 448 | + if not self.has_result_set: |
| 449 | + raise ProgrammingError("No result set.") |
| 450 | + result_set = cast(AthenaResultSet, self.result_set) |
| 451 | + return result_set.fetchone() |
| 452 | + |
| 453 | + def fetchmany( |
| 454 | + self, size: Optional[int] = None |
| 455 | + ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: |
| 456 | + """Fetch multiple rows from the result set. |
| 457 | +
|
| 458 | + Args: |
| 459 | + size: Maximum number of rows to fetch. Defaults to arraysize. |
| 460 | +
|
| 461 | + Returns: |
| 462 | + List of tuples representing the fetched rows. |
| 463 | +
|
| 464 | + Raises: |
| 465 | + ProgrammingError: If no result set is available. |
| 466 | + """ |
| 467 | + if not self.has_result_set: |
| 468 | + raise ProgrammingError("No result set.") |
| 469 | + result_set = cast(AthenaResultSet, self.result_set) |
| 470 | + return result_set.fetchmany(size) |
| 471 | + |
| 472 | + def fetchall( |
| 473 | + self, |
| 474 | + ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: |
| 475 | + """Fetch all remaining rows from the result set. |
| 476 | +
|
| 477 | + Returns: |
| 478 | + List of tuples representing all remaining rows. |
| 479 | +
|
| 480 | + Raises: |
| 481 | + ProgrammingError: If no result set is available. |
| 482 | + """ |
| 483 | + if not self.has_result_set: |
| 484 | + raise ProgrammingError("No result set.") |
| 485 | + result_set = cast(AthenaResultSet, self.result_set) |
| 486 | + return result_set.fetchall() |
| 487 | + |
| 488 | + def __aiter__(self): |
| 489 | + return self |
| 490 | + |
| 491 | + async def __anext__(self): |
| 492 | + row = self.fetchone() |
| 493 | + if row is None: |
| 494 | + raise StopAsyncIteration |
| 495 | + return row |
| 496 | + |
| 497 | + async def __aenter__(self): |
| 498 | + return self |
| 499 | + |
| 500 | + async def __aexit__(self, exc_type, exc_val, exc_tb): |
| 501 | + self.close() |
0 commit comments