Skip to content

Commit a2d9c20

Browse files
Merge pull request #639 from laughingman7743/feature/pandas-cursor-alignment
Align PandasCursor with PolarsCursor and optimize DataFrame operations
2 parents 1651786 + b6d27f0 commit a2d9c20

File tree

9 files changed

+213
-60
lines changed

9 files changed

+213
-60
lines changed

README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ Extra packages:
6262
+---------------+---------------------------------------+------------------+
6363
| Arrow | ``pip install PyAthena[Arrow]`` | >=10.0.0 |
6464
+---------------+---------------------------------------+------------------+
65+
| Polars | ``pip install PyAthena[Polars]`` | >=1.0.0 |
66+
+---------------+---------------------------------------+------------------+
6567

6668
.. _usage:
6769

docs/api/pandas.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Pandas Result Set
2323
:members:
2424
:inherited-members:
2525

26-
.. autoclass:: pyathena.pandas.result_set.DataFrameIterator
26+
.. autoclass:: pyathena.pandas.result_set.PandasDataFrameIterator
2727
:members:
2828

2929
Pandas Data Converters

docs/api/polars.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ Polars Result Set
2323
:members:
2424
:inherited-members:
2525

26+
.. autoclass:: pyathena.polars.result_set.PolarsDataFrameIterator
27+
:members:
28+
2629
Polars Data Converters
2730
----------------------
2831

docs/pandas.rst

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ SQLAlchemy allows this option to be specified in the connection string.
381381
382382
awsathena+pandas://:@athena.{region_name}.amazonaws.com:443/{schema_name}?s3_staging_dir={s3_staging_dir}&chunksize=1000000...
383383
384-
When this option is used, the object returned by the as_pandas method is a ``DataFrameIterator`` object.
384+
When this option is used, the object returned by the as_pandas method is a ``PandasDataFrameIterator`` object.
385385
This object has exactly the same interface as the ``TextFileReader`` object and can be handled in the same way.
386386

387387
.. code:: python
@@ -418,7 +418,20 @@ PandasCursor provides an ``iter_chunks()`` method for convenient chunked process
418418
# Memory can be freed after each chunk
419419
del chunk
420420
421-
You can also concatenate them into a single `pandas.DataFrame object`_ using `pandas.concat`_.
421+
The ``PandasDataFrameIterator`` also has an ``as_pandas()`` method that collects all chunks into a single DataFrame:
422+
423+
.. code:: python
424+
425+
from pyathena import connect
426+
from pyathena.pandas.cursor import PandasCursor
427+
428+
cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
429+
region_name="us-west-2",
430+
cursor_class=PandasCursor).cursor()
431+
df_iter = cursor.execute("SELECT * FROM many_rows", chunksize=1_000_000).as_pandas()
432+
df = df_iter.as_pandas() # Collect all chunks into a single DataFrame
433+
434+
This is equivalent to using `pandas.concat`_:
422435

423436
.. code:: python
424437

docs/polars.rst

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,51 @@ The chunked iteration also works with the unload option:
334334
# Process Parquet data in chunks
335335
process_chunk(chunk)
336336
337+
When the chunksize option is used, the object returned by the ``as_polars`` method is a ``PolarsDataFrameIterator`` object.
338+
This object provides the same chunked iteration interface and can be used in the same way:
339+
340+
.. code:: python
341+
342+
from pyathena import connect
343+
from pyathena.polars.cursor import PolarsCursor
344+
345+
cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
346+
region_name="us-west-2",
347+
cursor_class=PolarsCursor).cursor(chunksize=50_000)
348+
df_iter = cursor.execute("SELECT * FROM many_rows").as_polars()
349+
for df in df_iter:
350+
print(df.describe())
351+
print(df.head())
352+
353+
The ``PolarsDataFrameIterator`` also has an ``as_polars()`` method that collects all chunks into a single DataFrame:
354+
355+
.. code:: python
356+
357+
from pyathena import connect
358+
from pyathena.polars.cursor import PolarsCursor
359+
360+
cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
361+
region_name="us-west-2",
362+
cursor_class=PolarsCursor).cursor(chunksize=50_000)
363+
df_iter = cursor.execute("SELECT * FROM many_rows").as_polars()
364+
df = df_iter.as_polars() # Collect all chunks into a single DataFrame
365+
366+
This is equivalent to using `polars.concat`_:
367+
368+
.. code:: python
369+
370+
import polars as pl
371+
from pyathena import connect
372+
from pyathena.polars.cursor import PolarsCursor
373+
374+
cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
375+
region_name="us-west-2",
376+
cursor_class=PolarsCursor).cursor(chunksize=50_000)
377+
df_iter = cursor.execute("SELECT * FROM many_rows").as_polars()
378+
df = pl.concat(list(df_iter))
379+
380+
.. _`polars.concat`: https://docs.pola.rs/api/python/stable/reference/api/polars.concat.html
381+
337382
.. _async-polars-cursor:
338383

339384
AsyncPolarsCursor

pyathena/pandas/cursor.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
DefaultPandasTypeConverter,
2626
DefaultPandasUnloadTypeConverter,
2727
)
28-
from pyathena.pandas.result_set import AthenaPandasResultSet, DataFrameIterator
28+
from pyathena.pandas.result_set import AthenaPandasResultSet, PandasDataFrameIterator
2929
from pyathena.result_set import WithResultSet
3030

3131
if TYPE_CHECKING:
@@ -331,11 +331,11 @@ def fetchall(
331331
result_set = cast(AthenaPandasResultSet, self.result_set)
332332
return result_set.fetchall()
333333

334-
def as_pandas(self) -> Union["DataFrame", DataFrameIterator]:
335-
"""Return DataFrame or DataFrameIterator based on chunksize setting.
334+
def as_pandas(self) -> Union["DataFrame", PandasDataFrameIterator]:
335+
"""Return DataFrame or PandasDataFrameIterator based on chunksize setting.
336336
337337
Returns:
338-
DataFrame when chunksize is None, DataFrameIterator when chunksize is set.
338+
DataFrame when chunksize is None, PandasDataFrameIterator when chunksize is set.
339339
"""
340340
if not self.has_result_set:
341341
raise ProgrammingError("No result set.")
@@ -380,18 +380,13 @@ def iter_chunks(self) -> Generator["DataFrame", None, None]:
380380
"""
381381
if not self.has_result_set:
382382
raise ProgrammingError("No result set.")
383+
result_set = cast(AthenaPandasResultSet, self.result_set)
383384

384-
result = self.as_pandas()
385-
if isinstance(result, DataFrameIterator):
386-
# It's an iterator (chunked mode)
387-
import gc
385+
import gc
388386

389-
for chunk_count, chunk in enumerate(result, 1):
390-
yield chunk
387+
for chunk_count, chunk in enumerate(result_set.iter_chunks(), 1):
388+
yield chunk
391389

392-
# Suggest garbage collection every 10 chunks for large datasets
393-
if chunk_count % 10 == 0:
394-
gc.collect()
395-
else:
396-
# Single DataFrame - yield as one chunk
397-
yield result
390+
# Suggest garbage collection every 10 chunks for large datasets
391+
if chunk_count % 10 == 0:
392+
gc.collect()

pyathena/pandas/result_set.py

Lines changed: 104 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _no_trunc_date(df: "DataFrame") -> "DataFrame":
3838
return df
3939

4040

41-
class DataFrameIterator(abc.Iterator): # type: ignore
41+
class PandasDataFrameIterator(abc.Iterator): # type: ignore
4242
"""Iterator for chunked DataFrame results from Athena queries.
4343
4444
This class wraps either a pandas TextFileReader (for chunked reading) or
@@ -68,6 +68,12 @@ def __init__(
6868
reader: Union["TextFileReader", "DataFrame"],
6969
trunc_date: Callable[["DataFrame"], "DataFrame"],
7070
) -> None:
71+
"""Initialize the iterator.
72+
73+
Args:
74+
reader: Either a TextFileReader (for chunked) or a single DataFrame.
75+
trunc_date: Function to apply date truncation to each chunk.
76+
"""
7177
from pandas import DataFrame
7278

7379
if isinstance(reader, DataFrame):
@@ -76,41 +82,88 @@ def __init__(
7682
self._reader = reader
7783
self._trunc_date = trunc_date
7884

79-
def __next__(self):
85+
def __next__(self) -> "DataFrame":
86+
"""Get the next DataFrame chunk.
87+
88+
Returns:
89+
The next pandas DataFrame chunk with date truncation applied.
90+
91+
Raises:
92+
StopIteration: When no more chunks are available.
93+
"""
8094
try:
8195
df = next(self._reader)
8296
return self._trunc_date(df)
8397
except StopIteration:
8498
self.close()
8599
raise
86100

87-
def __iter__(self):
101+
def __iter__(self) -> "PandasDataFrameIterator":
102+
"""Return self as iterator."""
88103
return self
89104

90-
def __enter__(self):
105+
def __enter__(self) -> "PandasDataFrameIterator":
106+
"""Context manager entry."""
91107
return self
92108

93-
def __exit__(self, exc_type, exc_value, traceback):
109+
def __exit__(self, exc_type, exc_value, traceback) -> None:
110+
"""Context manager exit."""
94111
self.close()
95112

96113
def close(self) -> None:
114+
"""Close the iterator and release resources."""
97115
from pandas.io.parsers import TextFileReader
98116

99117
if isinstance(self._reader, TextFileReader):
100118
self._reader.close()
101119

102-
def iterrows(self) -> Iterator[Any]:
120+
def iterrows(self) -> Iterator[Tuple[int, Dict[str, Any]]]:
121+
"""Iterate over rows as (index, row_dict) tuples.
122+
123+
Row indices are continuous across all chunks, starting from 0.
124+
125+
Yields:
126+
Tuple of (row_index, row_dict) for each row across all chunks.
127+
"""
128+
row_num = 0
103129
for df in self:
104-
for row in enumerate(df.to_dict("records")):
105-
yield row
130+
# Use itertuples for memory efficiency instead of to_dict("records")
131+
# which loads all rows into memory at once
132+
columns = df.columns.tolist()
133+
for row in df.itertuples(index=False):
134+
yield (row_num, dict(zip(columns, row, strict=True)))
135+
row_num += 1
136+
137+
def get_chunk(self, size: Optional[int] = None) -> "DataFrame":
138+
"""Get a chunk of specified size.
139+
140+
Args:
141+
size: Number of rows to retrieve. If None, returns entire chunk.
106142
107-
def get_chunk(self, size=None):
143+
Returns:
144+
DataFrame chunk.
145+
"""
108146
from pandas.io.parsers import TextFileReader
109147

110148
if isinstance(self._reader, TextFileReader):
111149
return self._reader.get_chunk(size)
112150
return next(self._reader)
113151

152+
def as_pandas(self) -> "DataFrame":
153+
"""Collect all chunks into a single DataFrame.
154+
155+
Returns:
156+
Single pandas DataFrame containing all data.
157+
"""
158+
import pandas as pd
159+
160+
dfs: List["DataFrame"] = list(self)
161+
if not dfs:
162+
return pd.DataFrame()
163+
if len(dfs) == 1:
164+
return dfs[0]
165+
return pd.concat(dfs, ignore_index=True)
166+
114167

115168
class AthenaPandasResultSet(AthenaResultSet):
116169
"""Result set that provides pandas DataFrame results with memory optimization.
@@ -232,14 +285,21 @@ def __init__(
232285
self._data_manifest: List[str] = []
233286
self._kwargs = kwargs
234287
self._fs = self.__s3_file_system()
288+
289+
# Cache time column names for efficient _trunc_date processing
290+
description = self.description if self.description else []
291+
self._time_columns: List[str] = [
292+
d[0] for d in description if d[1] in ("time", "time with time zone")
293+
]
294+
235295
if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location:
236296
df = self._as_pandas()
237297
trunc_date = _no_trunc_date if self.is_unload else self._trunc_date
238-
self._df_iter = DataFrameIterator(df, trunc_date)
298+
self._df_iter = PandasDataFrameIterator(df, trunc_date)
239299
else:
240300
import pandas as pd
241301

242-
self._df_iter = DataFrameIterator(pd.DataFrame(), _no_trunc_date)
302+
self._df_iter = PandasDataFrameIterator(pd.DataFrame(), _no_trunc_date)
243303
self._iterrows = self._df_iter.iterrows()
244304

245305
def _get_parquet_engine(self) -> str:
@@ -401,12 +461,10 @@ def parse_dates(self) -> List[Optional[Any]]:
401461
return [d[0] for d in description if d[1] in self._PARSE_DATES]
402462

403463
def _trunc_date(self, df: "DataFrame") -> "DataFrame":
404-
description = self.description if self.description else []
405-
times = [d[0] for d in description if d[1] in ("time", "time with time zone")]
406-
if times:
407-
truncated = df.loc[:, times].apply(lambda r: r.dt.time)
408-
for time in times:
409-
df.isetitem(df.columns.get_loc(time), truncated[time])
464+
if self._time_columns:
465+
truncated = df.loc[:, self._time_columns].apply(lambda r: r.dt.time)
466+
for time_col in self._time_columns:
467+
df.isetitem(df.columns.get_loc(time_col), truncated[time_col])
410468
return df
411469

412470
def fetchone(
@@ -620,15 +678,42 @@ def _as_pandas(self) -> Union["TextFileReader", "DataFrame"]:
620678
df = self._read_csv()
621679
return df
622680

623-
def as_pandas(self) -> Union[DataFrameIterator, "DataFrame"]:
681+
def as_pandas(self) -> Union[PandasDataFrameIterator, "DataFrame"]:
624682
if self._chunksize is None:
625683
return next(self._df_iter)
626684
return self._df_iter
627685

686+
def iter_chunks(self) -> PandasDataFrameIterator:
687+
"""Iterate over result chunks as pandas DataFrames.
688+
689+
This method provides an iterator interface for processing large result sets.
690+
When chunksize is specified, it yields DataFrames in chunks for memory-efficient
691+
processing. When chunksize is not specified, it yields the entire result as a
692+
single DataFrame.
693+
694+
Returns:
695+
PandasDataFrameIterator that yields pandas DataFrames for each chunk
696+
of rows, or the entire DataFrame if chunksize was not specified.
697+
698+
Example:
699+
>>> # With chunking for large datasets
700+
>>> cursor = connection.cursor(PandasCursor, chunksize=50000)
701+
>>> cursor.execute("SELECT * FROM large_table")
702+
>>> for chunk in cursor.iter_chunks():
703+
... process_chunk(chunk) # Each chunk is a pandas DataFrame
704+
>>>
705+
>>> # Without chunking - yields entire result as single chunk
706+
>>> cursor = connection.cursor(PandasCursor)
707+
>>> cursor.execute("SELECT * FROM small_table")
708+
>>> for df in cursor.iter_chunks():
709+
... process(df) # Single DataFrame with all data
710+
"""
711+
return self._df_iter
712+
628713
def close(self) -> None:
629714
import pandas as pd
630715

631716
super().close()
632-
self._df_iter = DataFrameIterator(pd.DataFrame(), _no_trunc_date)
717+
self._df_iter = PandasDataFrameIterator(pd.DataFrame(), _no_trunc_date)
633718
self._iterrows = enumerate([])
634719
self._data_manifest = []

0 commit comments

Comments
 (0)