Skip to content

Commit 5651427

Browse files
Merge pull request #678 from pyathena-dev/feature/670-prepare-unload
Extract _prepare_unload() helper into BaseCursor
2 parents: e75cd23 + 17f1b16 — commit 5651427

11 files changed

Lines changed: 48 additions & 119 deletions

File tree

pyathena/aio/arrow/cursor.py

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from pyathena.arrow.result_set import AthenaArrowResultSet
1414
from pyathena.common import CursorIterator
1515
from pyathena.error import OperationalError, ProgrammingError
16-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
16+
from pyathena.model import AthenaQueryExecution
1717

1818
if TYPE_CHECKING:
1919
import polars as pl
@@ -109,18 +109,7 @@ async def execute( # type: ignore[override]
109109
Self reference for method chaining.
110110
"""
111111
self._reset_state()
112-
if self._unload:
113-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
114-
if not s3_staging_dir:
115-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
116-
operation, unload_location = self._formatter.wrap_unload(
117-
operation,
118-
s3_staging_dir=s3_staging_dir,
119-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
120-
compression=AthenaCompression.COMPRESSION_SNAPPY,
121-
)
122-
else:
123-
unload_location = None
112+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
124113
self.query_id = await self._execute(
125114
operation,
126115
parameters=parameters,

pyathena/aio/pandas/cursor.py

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
from pyathena.aio.common import WithAsyncFetch
2020
from pyathena.common import CursorIterator
2121
from pyathena.error import OperationalError, ProgrammingError
22-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
22+
from pyathena.model import AthenaQueryExecution
2323
from pyathena.pandas.converter import (
2424
DefaultPandasTypeConverter,
2525
DefaultPandasUnloadTypeConverter,
@@ -134,18 +134,7 @@ async def execute( # type: ignore[override]
134134
Self reference for method chaining.
135135
"""
136136
self._reset_state()
137-
if self._unload:
138-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
139-
if not s3_staging_dir:
140-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
141-
operation, unload_location = self._formatter.wrap_unload(
142-
operation,
143-
s3_staging_dir=s3_staging_dir,
144-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
145-
compression=AthenaCompression.COMPRESSION_SNAPPY,
146-
)
147-
else:
148-
unload_location = None
137+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
149138
self.query_id = await self._execute(
150139
operation,
151140
parameters=parameters,

pyathena/aio/polars/cursor.py

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from pyathena.aio.common import WithAsyncFetch
1010
from pyathena.common import CursorIterator
1111
from pyathena.error import OperationalError, ProgrammingError
12-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
12+
from pyathena.model import AthenaQueryExecution
1313
from pyathena.polars.converter import (
1414
DefaultPolarsTypeConverter,
1515
DefaultPolarsUnloadTypeConverter,
@@ -115,18 +115,7 @@ async def execute( # type: ignore[override]
115115
Self reference for method chaining.
116116
"""
117117
self._reset_state()
118-
if self._unload:
119-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
120-
if not s3_staging_dir:
121-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
122-
operation, unload_location = self._formatter.wrap_unload(
123-
operation,
124-
s3_staging_dir=s3_staging_dir,
125-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
126-
compression=AthenaCompression.COMPRESSION_SNAPPY,
127-
)
128-
else:
129-
unload_location = None
118+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
130119
self.query_id = await self._execute(
131120
operation,
132121
parameters=parameters,

pyathena/arrow/async_cursor.py

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from pyathena.arrow.result_set import AthenaArrowResultSet
1515
from pyathena.async_cursor import AsyncCursor
1616
from pyathena.common import CursorIterator
17-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
17+
from pyathena.model import AthenaQueryExecution
1818

1919
_logger = logging.getLogger(__name__) # type: ignore
2020

@@ -182,18 +182,7 @@ def execute(
182182
paramstyle: Optional[str] = None,
183183
**kwargs,
184184
) -> Tuple[str, "Future[Union[AthenaArrowResultSet, Any]]"]:
185-
if self._unload:
186-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
187-
if not s3_staging_dir:
188-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
189-
operation, unload_location = self._formatter.wrap_unload(
190-
operation,
191-
s3_staging_dir=s3_staging_dir,
192-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
193-
compression=AthenaCompression.COMPRESSION_SNAPPY,
194-
)
195-
else:
196-
unload_location = None
185+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
197186
query_id = self._execute(
198187
operation,
199188
parameters=parameters,

pyathena/arrow/cursor.py

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from pyathena.arrow.result_set import AthenaArrowResultSet
1212
from pyathena.common import CursorIterator
1313
from pyathena.error import OperationalError, ProgrammingError
14-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
14+
from pyathena.model import AthenaQueryExecution
1515
from pyathena.result_set import WithFetch
1616

1717
if TYPE_CHECKING:
@@ -166,18 +166,7 @@ def execute(
166166
>>> table = cursor.as_arrow() # Returns Apache Arrow Table
167167
"""
168168
self._reset_state()
169-
if self._unload:
170-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
171-
if not s3_staging_dir:
172-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
173-
operation, unload_location = self._formatter.wrap_unload(
174-
operation,
175-
s3_staging_dir=s3_staging_dir,
176-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
177-
compression=AthenaCompression.COMPRESSION_SNAPPY,
178-
)
179-
else:
180-
unload_location = None
169+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
181170
self.query_id = self._execute(
182171
operation,
183172
parameters=parameters,

pyathena/common.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,9 @@
1515
from pyathena.model import (
1616
AthenaCalculationExecution,
1717
AthenaCalculationExecutionStatus,
18+
AthenaCompression,
1819
AthenaDatabase,
20+
AthenaFileFormat,
1921
AthenaQueryExecution,
2022
AthenaTableMetadata,
2123
)
@@ -652,6 +654,32 @@ def _prepare_query(
652654
_logger.debug(query)
653655
return query, execution_parameters
654656

657+
def _prepare_unload(
658+
self,
659+
operation: str,
660+
s3_staging_dir: Optional[str],
661+
) -> Tuple[str, Optional[str]]:
662+
"""Wrap operation with UNLOAD if enabled.
663+
664+
Args:
665+
operation: SQL query string.
666+
s3_staging_dir: S3 location for query results.
667+
668+
Returns:
669+
Tuple of (possibly-wrapped operation, unload_location or None).
670+
"""
671+
if not getattr(self, "_unload", False):
672+
return operation, None
673+
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
674+
if not s3_staging_dir:
675+
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
676+
return self._formatter.wrap_unload(
677+
operation,
678+
s3_staging_dir=s3_staging_dir,
679+
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
680+
compression=AthenaCompression.COMPRESSION_SNAPPY,
681+
)
682+
655683
def _execute(
656684
self,
657685
operation: str,

pyathena/formatter.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from copy import deepcopy
99
from datetime import date, datetime, timezone
1010
from decimal import Decimal
11-
from typing import Any, Callable, Dict, Optional, Type
11+
from typing import Any, Callable, Dict, Optional, Tuple, Type
1212

1313
from pyathena.error import ProgrammingError
1414
from pyathena.model import AthenaCompression, AthenaFileFormat
@@ -86,7 +86,7 @@ def wrap_unload(
8686
s3_staging_dir: str,
8787
format_: str = AthenaFileFormat.FILE_FORMAT_PARQUET,
8888
compression: str = AthenaCompression.COMPRESSION_SNAPPY,
89-
):
89+
) -> Tuple[str, Optional[str]]:
9090
"""Wrap a SELECT query with UNLOAD statement for high-performance result retrieval.
9191
9292
Transforms SELECT or WITH queries into UNLOAD statements that export results

pyathena/pandas/async_cursor.py

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from pyathena import ProgrammingError
1010
from pyathena.async_cursor import AsyncCursor
1111
from pyathena.common import CursorIterator
12-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
12+
from pyathena.model import AthenaQueryExecution
1313
from pyathena.pandas.converter import (
1414
DefaultPandasTypeConverter,
1515
DefaultPandasUnloadTypeConverter,
@@ -159,18 +159,7 @@ def execute(
159159
quoting: int = 1,
160160
**kwargs,
161161
) -> Tuple[str, "Future[Union[AthenaPandasResultSet, Any]]"]:
162-
if self._unload:
163-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
164-
if not s3_staging_dir:
165-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
166-
operation, unload_location = self._formatter.wrap_unload(
167-
operation,
168-
s3_staging_dir=s3_staging_dir,
169-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
170-
compression=AthenaCompression.COMPRESSION_SNAPPY,
171-
)
172-
else:
173-
unload_location = None
162+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
174163
query_id = self._execute(
175164
operation,
176165
parameters=parameters,

pyathena/pandas/cursor.py

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818

1919
from pyathena.common import CursorIterator
2020
from pyathena.error import OperationalError, ProgrammingError
21-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
21+
from pyathena.model import AthenaQueryExecution
2222
from pyathena.pandas.converter import (
2323
DefaultPandasTypeConverter,
2424
DefaultPandasUnloadTypeConverter,
@@ -193,18 +193,7 @@ def execute(
193193
>>> df = cursor.fetchall() # Returns pandas DataFrame
194194
"""
195195
self._reset_state()
196-
if self._unload:
197-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
198-
if not s3_staging_dir:
199-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
200-
operation, unload_location = self._formatter.wrap_unload(
201-
operation,
202-
s3_staging_dir=s3_staging_dir,
203-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
204-
compression=AthenaCompression.COMPRESSION_SNAPPY,
205-
)
206-
else:
207-
unload_location = None
196+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
208197
self.query_id = self._execute(
209198
operation,
210199
parameters=parameters,

pyathena/polars/async_cursor.py

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from pyathena import ProgrammingError
1010
from pyathena.async_cursor import AsyncCursor
1111
from pyathena.common import CursorIterator
12-
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
12+
from pyathena.model import AthenaQueryExecution
1313
from pyathena.polars.converter import (
1414
DefaultPolarsTypeConverter,
1515
DefaultPolarsUnloadTypeConverter,
@@ -221,18 +221,7 @@ def execute(
221221
>>> result_set = future.result()
222222
>>> df = result_set.as_polars() # Returns Polars DataFrame
223223
"""
224-
if self._unload:
225-
s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
226-
if not s3_staging_dir:
227-
raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
228-
operation, unload_location = self._formatter.wrap_unload(
229-
operation,
230-
s3_staging_dir=s3_staging_dir,
231-
format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
232-
compression=AthenaCompression.COMPRESSION_SNAPPY,
233-
)
234-
else:
235-
unload_location = None
224+
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
236225
query_id = self._execute(
237226
operation,
238227
parameters=parameters,

0 commit comments

Comments
 (0)