Commit a13bd3c

Merge pull request #665 from pyathena-dev/fix/managed-query-result-storage
Support Athena managed query result storage
2 parents 46310a7 + 8054871 commit a13bd3c

24 files changed, +398 -51 lines changed

.github/workflows/test.yaml

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ jobs:
       AWS_ATHENA_S3_STAGING_DIR: s3://laughingman7743-pyathena/github/
       AWS_ATHENA_WORKGROUP: pyathena
       AWS_ATHENA_SPARK_WORKGROUP: pyathena-spark
+      AWS_ATHENA_MANAGED_WORKGROUP: pyathena-managed

     strategy:
       fail-fast: false
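
The `AWS_ATHENA_MANAGED_WORKGROUP` variable added above is what lets the managed storage tests run or be skipped. A minimal sketch of that gating, assuming a plain `unittest` test case; the helper `managed_workgroup` and the class `ManagedStorageTest` are hypothetical names, and the project's real tests may use a different mechanism:

```python
import os
import unittest

# Variable name taken from the workflow above; everything else is illustrative.
ENV_MANAGED_WORKGROUP = "AWS_ATHENA_MANAGED_WORKGROUP"


def managed_workgroup(environ=None):
    """Return the configured managed workgroup, or None when unset or empty."""
    env = os.environ if environ is None else environ
    return env.get(ENV_MANAGED_WORKGROUP) or None


class ManagedStorageTest(unittest.TestCase):  # hypothetical test case
    def setUp(self):
        self.work_group = managed_workgroup()
        if self.work_group is None:
            # Skip instead of failing when no managed workgroup is configured.
            self.skipTest(f"{ENV_MANAGED_WORKGROUP} is not set")

    def test_work_group_is_configured(self):
        self.assertTrue(self.work_group)
```

Reading the variable through a small helper keeps the skip condition in one place and makes it testable without touching the real environment.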

docs/testing.md

Lines changed: 9 additions & 0 deletions
@@ -20,6 +20,15 @@ If primary is not available as the default workgroup, specify an alternative wor
 $ export AWS_ATHENA_DEFAULT_WORKGROUP=DEFAULT_WORKGROUP
 ```

+### Managed query result storage (optional)
+
+To test the managed query result storage feature, create a workgroup with managed storage enabled and set the `AWS_ATHENA_MANAGED_WORKGROUP` environment variable.
+If not set, managed storage tests will be skipped.
+
+```bash
+$ export AWS_ATHENA_MANAGED_WORKGROUP=pyathena-managed
+```
+
 ## Run test

 ```bash

docs/usage.md

Lines changed: 31 additions & 1 deletion
@@ -14,6 +14,36 @@ print(cursor.description)
 print(cursor.fetchall())
 ```

+## Managed query result storage
+
+When using a workgroup with [managed query result storage](https://docs.aws.amazon.com/athena/latest/ug/managed-results.html) enabled,
+you don't need to specify an S3 staging directory.
+
+```python
+from pyathena import connect
+
+cursor = connect(work_group="YOUR_MANAGED_WORK_GROUP",
+                 region_name="us-west-2").cursor()
+cursor.execute("SELECT * FROM one_row")
+print(cursor.fetchall())
+```
+
+If the ``AWS_ATHENA_S3_STAGING_DIR`` environment variable is set, pass ``s3_staging_dir=""``
+to explicitly disable the fallback. Otherwise the API will reject the request because
+``ResultConfiguration`` and ``ManagedQueryResultsConfiguration`` cannot be set together.
+
+```python
+cursor = connect(work_group="YOUR_MANAGED_WORK_GROUP",
+                 s3_staging_dir="",
+                 region_name="us-west-2").cursor()
+```
+
+```{note}
+With managed query result storage, query results are retrieved via the `GetQueryResults` API
+(1000 rows per request) instead of reading S3 files directly. This may be slower for large
+result sets. For large datasets, consider using customer-managed storage or the `UNLOAD` statement.
+```
+
 ## Cursor iteration

 ```python
@@ -366,7 +396,7 @@ Support [Boto3 environment variables](https://boto3.amazonaws.com/v1/documentati
 ### Additional environment variables

 AWS_ATHENA_S3_STAGING_DIR
-: The S3 location where Athena automatically stores the query results and metadata information. Required if you have not set up workgroups. Not required if a workgroup has been set up.
+: The S3 location where Athena automatically stores the query results and metadata information. Required if you have not set up workgroups. Not required if a workgroup has been set up. When connecting to a workgroup with [managed query result storage](https://docs.aws.amazon.com/athena/latest/ug/managed-results.html), pass ``s3_staging_dir=""`` to explicitly disable this environment variable fallback (see [Managed query result storage](#managed-query-result-storage)).

 AWS_ATHENA_WORK_GROUP
 : The setting of the workgroup to execute the query.
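
The note in the docs above says managed storage results come back through `GetQueryResults` at 1000 rows per request. A rough sketch of what that pagination looks like, assuming a boto3-style Athena client passed in by the caller; `iter_result_rows` is a hypothetical helper, not PyAthena's implementation, and it neither strips a header row nor converts types:

```python
from typing import Any, Dict, Iterator, List, Optional


def iter_result_rows(
    client: Any, query_execution_id: str, page_size: int = 1000
) -> Iterator[List[Optional[str]]]:
    """Yield every row of a finished query, one GetQueryResults page at a time."""
    next_token: Optional[str] = None
    while True:
        request: Dict[str, Any] = {
            "QueryExecutionId": query_execution_id,
            "MaxResults": page_size,
        }
        if next_token:
            request["NextToken"] = next_token
        response = client.get_query_results(**request)
        for row in response["ResultSet"]["Rows"]:
            # A NULL value has no "VarCharValue" key, so it becomes None here.
            yield [datum.get("VarCharValue") for datum in row["Data"]]
        next_token = response.get("NextToken")
        if not next_token:
            break
```

Each page costs one API round trip, which is why the docs suggest customer-managed storage or `UNLOAD` for large result sets.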

pyathena/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -85,6 +85,9 @@ def connect(*args, **kwargs) -> "Connection[Any]":
     Args:
         s3_staging_dir: S3 location to store query results. Required if not
             using workgroups or if the workgroup doesn't have a result location.
+            Pass an empty string to explicitly disable S3 staging and skip
+            the ``AWS_ATHENA_S3_STAGING_DIR`` environment variable fallback
+            (required for workgroups with managed query result storage).
         region_name: AWS region name. If not specified, uses the default region
             from your AWS configuration.
         schema_name: Athena database/schema name. Defaults to "default".
@@ -109,7 +112,7 @@ def connect(*args, **kwargs) -> "Connection[Any]":
         A Connection object that can be used to create cursors and execute queries.

     Raises:
-        AssertionError: If neither s3_staging_dir nor work_group is provided.
+        ProgrammingError: If neither s3_staging_dir nor work_group is provided.

     Example:
         >>> import pyathena

pyathena/arrow/async_cursor.py

Lines changed: 2 additions & 1 deletion
@@ -184,7 +184,8 @@ def execute(
     ) -> Tuple[str, "Future[Union[AthenaArrowResultSet, Any]]"]:
         if self._unload:
             s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
-            assert s3_staging_dir, "If the unload option is used, s3_staging_dir is required."
+            if not s3_staging_dir:
+                raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
             operation, unload_location = self._formatter.wrap_unload(
                 operation,
                 s3_staging_dir=s3_staging_dir,

pyathena/arrow/cursor.py

Lines changed: 2 additions & 1 deletion
@@ -209,7 +209,8 @@ def execute(
         self._reset_state()
         if self._unload:
             s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
-            assert s3_staging_dir, "If the unload option is used, s3_staging_dir is required."
+            if not s3_staging_dir:
+                raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
             operation, unload_location = self._formatter.wrap_unload(
                 operation,
                 s3_staging_dir=s3_staging_dir,

pyathena/arrow/result_set.py

Lines changed: 21 additions & 0 deletions
@@ -117,6 +117,8 @@ def __init__(
         self._fs = self.__s3_file_system()
         if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location:
             self._table = self._as_arrow()
+        elif self.state == AthenaQueryExecution.STATE_SUCCEEDED:
+            self._table = self._as_arrow_from_api()
         else:
             import pyarrow as pa

@@ -346,6 +348,25 @@ def _as_arrow(self) -> "Table":
             table = self._read_csv()
         return table

+    def _as_arrow_from_api(self, converter: Optional[Converter] = None) -> "Table":
+        """Build an Arrow Table from GetQueryResults API.
+
+        Used as a fallback when ``output_location`` is not available
+        (e.g. managed query result storage).
+
+        Args:
+            converter: Type converter for result values. Defaults to
+                ``DefaultTypeConverter`` if not specified.
+        """
+        import pyarrow as pa
+
+        rows = self._fetch_all_rows(converter)
+        if not rows:
+            return pa.Table.from_pydict({})
+        description = self.description if self.description else []
+        columns = [d[0] for d in description]
+        return pa.table(self._rows_to_columnar(rows, columns))
+
     def as_arrow(self) -> "Table":
         return self._table

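
The new `_as_arrow_from_api` method relies on `_rows_to_columnar`, whose body is not part of this diff. One plausible reading is a pivot from row tuples to a column-name mapping, which is a shape `pa.table` accepts. The function below is a hypothetical stand-in written under that assumption, not the library's code:

```python
from typing import Any, Dict, List, Sequence, Tuple


def rows_to_columnar(
    rows: Sequence[Tuple[Any, ...]], columns: Sequence[str]
) -> Dict[str, List[Any]]:
    """Pivot row tuples into a {column_name: [values, ...]} mapping."""
    # Column i of the result collects element i of every row, in row order.
    return {name: [row[i] for row in rows] for i, name in enumerate(columns)}
```

Note that a plain dict collapses duplicate column names, so a real implementation would need to handle queries that return the same name twice.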

pyathena/common.py

Lines changed: 6 additions & 4 deletions
@@ -210,15 +210,15 @@ def _build_start_query_execution_request(
         request: Dict[str, Any] = {
             "QueryString": query,
             "QueryExecutionContext": {},
-            "ResultConfiguration": {},
         }
         if self._schema_name:
             request["QueryExecutionContext"].update({"Database": self._schema_name})
         if self._catalog_name:
             request["QueryExecutionContext"].update({"Catalog": self._catalog_name})
+        result_configuration: Dict[str, Any] = {}
         if self._s3_staging_dir or s3_staging_dir:
-            request["ResultConfiguration"].update(
-                {"OutputLocation": s3_staging_dir if s3_staging_dir else self._s3_staging_dir}
+            result_configuration["OutputLocation"] = (
+                s3_staging_dir if s3_staging_dir else self._s3_staging_dir
             )
         if self._work_group or work_group:
             request.update({"WorkGroup": work_group if work_group else self._work_group})
@@ -228,7 +228,9 @@ def _build_start_query_execution_request(
             }
             if self._kms_key:
                 enc_conf.update({"KmsKey": self._kms_key})
-            request["ResultConfiguration"].update({"EncryptionConfiguration": enc_conf})
+            result_configuration["EncryptionConfiguration"] = enc_conf
+        if result_configuration:
+            request["ResultConfiguration"] = result_configuration
         if self._result_reuse_enable or result_reuse_enable:
             reuse_conf = {
                 "Enabled": result_reuse_enable
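
The effect of this change is that `ResultConfiguration` is sent only when it has something in it. A condensed, standalone sketch of the same idea follows; `build_request` and its parameters are illustrative simplifications of the method in the diff, not the method itself:

```python
from typing import Any, Dict, Optional


def build_request(
    query: str,
    s3_staging_dir: Optional[str] = None,
    work_group: Optional[str] = None,
    encryption_option: Optional[str] = None,
    kms_key: Optional[str] = None,
) -> Dict[str, Any]:
    """Build StartQueryExecution kwargs, omitting an empty ResultConfiguration."""
    request: Dict[str, Any] = {"QueryString": query, "QueryExecutionContext": {}}
    result_configuration: Dict[str, Any] = {}
    if s3_staging_dir:
        result_configuration["OutputLocation"] = s3_staging_dir
    if work_group:
        request["WorkGroup"] = work_group
    if encryption_option:
        enc_conf = {"EncryptionOption": encryption_option}
        if kms_key:
            enc_conf["KmsKey"] = kms_key
        result_configuration["EncryptionConfiguration"] = enc_conf
    # Only attach the key when there is something to send, so a managed
    # workgroup request carries no ResultConfiguration at all.
    if result_configuration:
        request["ResultConfiguration"] = result_configuration
    return request
```

With only a workgroup supplied, the returned dict has no `ResultConfiguration` key, which is the case the managed storage docs describe.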

pyathena/connection.py

Lines changed: 18 additions & 8 deletions
@@ -26,7 +26,7 @@
 from pyathena.common import BaseCursor, CursorIterator
 from pyathena.converter import Converter
 from pyathena.cursor import Cursor
-from pyathena.error import NotSupportedError
+from pyathena.error import NotSupportedError, ProgrammingError
 from pyathena.formatter import DefaultParameterFormatter, Formatter
 from pyathena.util import RetryConfig

@@ -77,7 +77,9 @@ class Connection(Generic[ConnectionCursor]):
     Note:
         Either s3_staging_dir or work_group must be specified. If using a
         workgroup, it must have a result location configured unless
-        s3_staging_dir is also provided.
+        s3_staging_dir is also provided. For workgroups with managed query
+        result storage, pass ``s3_staging_dir=""`` to skip the environment
+        variable fallback.
     """

     _ENV_S3_STAGING_DIR: str = "AWS_ATHENA_S3_STAGING_DIR"
@@ -198,6 +200,10 @@ def __init__(
         Args:
             s3_staging_dir: S3 location to store query results. Required if not
                 using workgroups or if workgroup doesn't have result location.
+                Pass an empty string to explicitly disable S3 staging and skip
+                the ``AWS_ATHENA_S3_STAGING_DIR`` environment variable fallback.
+                This is required when connecting to a workgroup with managed
+                query result storage enabled.
             region_name: AWS region name. Uses default region if not specified.
             schema_name: Default database/schema name. Defaults to "default".
             catalog_name: Data catalog name. Defaults to "awsdatacatalog".
@@ -226,12 +232,17 @@ def __init__(
             **kwargs: Additional arguments passed to boto3 Session and client.

         Raises:
-            AssertionError: If neither s3_staging_dir nor work_group is provided.
+            ProgrammingError: If neither s3_staging_dir nor work_group is provided.

         Note:
             Either s3_staging_dir or work_group must be specified. Environment
             variables AWS_ATHENA_S3_STAGING_DIR and AWS_ATHENA_WORK_GROUP are
             checked if parameters are not provided.
+
+            When using a workgroup with managed query result storage, pass
+            ``s3_staging_dir=""`` to prevent the environment variable fallback
+            from sending a ``ResultConfiguration`` that conflicts with
+            ``ManagedQueryResultsConfiguration``.
         """
         self._kwargs = {
             **kwargs,
@@ -241,8 +252,8 @@ def __init__(
             "serial_number": serial_number,
             "duration_seconds": duration_seconds,
         }
-        if s3_staging_dir:
-            self.s3_staging_dir: Optional[str] = s3_staging_dir
+        if s3_staging_dir is not None:
+            self.s3_staging_dir: Optional[str] = s3_staging_dir or None
         else:
             self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR)
         self.region_name = region_name
@@ -258,9 +269,8 @@ def __init__(
         self.profile_name = profile_name
         self.config: Optional[Config] = config if config else Config()

-        assert self.s3_staging_dir or self.work_group, (
-            "Required argument `s3_staging_dir` or `work_group` not found."
-        )
+        if not self.s3_staging_dir and not self.work_group:
+            raise ProgrammingError("Required argument `s3_staging_dir` or `work_group` not found.")

         if self.s3_staging_dir and not self.s3_staging_dir.endswith("/"):
             self.s3_staging_dir = f"{self.s3_staging_dir}/"
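
The key behavioral change in this file is that `None` and `""` now mean different things for `s3_staging_dir`: `None` falls back to the environment variable, while an empty string disables staging outright. The sketch below isolates that resolution logic. `resolve_s3_staging_dir` is a hypothetical free function, the local `ProgrammingError` is a stand-in for `pyathena.error.ProgrammingError`, and the workgroup's own environment fallback is left out for brevity:

```python
import os
from typing import Mapping, Optional

ENV_S3_STAGING_DIR = "AWS_ATHENA_S3_STAGING_DIR"


class ProgrammingError(Exception):
    """Stand-in for pyathena.error.ProgrammingError so the sketch runs alone."""


def resolve_s3_staging_dir(
    s3_staging_dir: Optional[str],
    work_group: Optional[str],
    environ: Optional[Mapping[str, str]] = None,
) -> Optional[str]:
    """Mirror the constructor's handling of None, "" and explicit values."""
    env = os.environ if environ is None else environ
    if s3_staging_dir is not None:
        # An explicit value wins; "" is normalized to None and skips the env lookup.
        resolved = s3_staging_dir or None
    else:
        resolved = env.get(ENV_S3_STAGING_DIR)
    if not resolved and not work_group:
        raise ProgrammingError(
            "Required argument `s3_staging_dir` or `work_group` not found."
        )
    if resolved and not resolved.endswith("/"):
        resolved = f"{resolved}/"
    return resolved
```

Passing `""` together with a workgroup therefore yields `None` even when the environment variable is set, which is the case the managed storage docs call out.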

pyathena/pandas/async_cursor.py

Lines changed: 2 additions & 1 deletion
@@ -161,7 +161,8 @@ def execute(
     ) -> Tuple[str, "Future[Union[AthenaPandasResultSet, Any]]"]:
         if self._unload:
             s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir
-            assert s3_staging_dir, "If the unload option is used, s3_staging_dir is required."
+            if not s3_staging_dir:
+                raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
             operation, unload_location = self._formatter.wrap_unload(
                 operation,
                 s3_staging_dir=s3_staging_dir,
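
The cursor changes all follow one pattern: an `assert` becomes an explicit `raise`. One practical difference is that Python drops `assert` statements when run with `-O`, whereas an explicit raise always executes. A minimal sketch of the check on its own, where `require_unload_staging_dir` is a hypothetical helper and `ProgrammingError` is a local stand-in for the library's exception:

```python
from typing import Optional


class ProgrammingError(Exception):
    """Stand-in for pyathena.error.ProgrammingError so the sketch runs alone."""


def require_unload_staging_dir(per_call: Optional[str], default: Optional[str]) -> str:
    """Pick the per-call staging dir, else the cursor default, else raise."""
    s3_staging_dir = per_call if per_call else default
    if not s3_staging_dir:
        # Unlike `assert`, this check still runs under `python -O`.
        raise ProgrammingError(
            "If the unload option is used, s3_staging_dir is required."
        )
    return s3_staging_dir
```

This matters more with managed storage, because a cursor can now legitimately have no staging directory and still be asked to run with the unload option.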
