Skip to content

Commit 90e67bc

Browse files
Merge pull request #690 from pyathena-dev/feat/result-set-type-hints
Add result_set_type_hints for precise complex type conversion
2 parents 5409ecd + add20b4 commit 90e67bc

33 files changed

+1644
-291
lines changed

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
## WHAT
2+
<!-- (Write the change being made with this pull request) -->
3+
4+
## WHY
5+
<!-- (Write the motivation why you submit this pull request) -->

CLAUDE.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,19 @@ export $(cat .env | xargs) && uv run pytest tests/pyathena/test_file.py -v
4242
- Use pytest fixtures from `conftest.py`
4343
- New features require tests; changes to SQLAlchemy dialects must pass `make test-sqla`
4444

45+
#### Test Conventions
46+
- **Class-based tests** for integration tests that use fixtures (cursors, engines): `class TestCursor:` with methods like `def test_fetchone(self, cursor):`
47+
- **Standalone functions** for unit tests of pure logic (converters, parsers, utils): `def test_to_struct_json_formats(input_value, expected):`
48+
- Test file naming mirrors source: `pyathena/parser.py``tests/pyathena/test_parser.py`
49+
- **Fixtures**: Cursor/engine fixtures are defined in `conftest.py` and injected by name (e.g., `cursor`, `engine`, `async_cursor`). Use `indirect=True` parametrization to pass connection options:
50+
```python
51+
@pytest.mark.parametrize("engine", [{"driver": "rest"}], indirect=True)
52+
def test_query(self, engine):
53+
engine, conn = engine
54+
```
55+
- **Parametrize** with `@pytest.mark.parametrize(("input", "expected"), [...])` for data-driven tests
56+
- **Integration tests** (need AWS) use cursor/engine fixtures with real Athena queries; **unit tests** (no AWS) call functions directly with test data
57+
4558
## Architecture — Key Design Decisions
4659

4760
These are non-obvious conventions that can't be discovered by reading code alone.

docs/usage.md

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,112 @@ The `on_start_query_execution` callback is supported by the following cursor typ
389389
Note: `AsyncCursor` and its variants do not support this callback as they already
390390
return the query ID immediately through their different execution model.
391391

392+
## Type hints for complex types
393+
394+
*New in version 3.30.0.*
395+
396+
The Athena API does not return element-level type information for complex types
397+
(array, map, row/struct). PyAthena parses the string representation returned by
398+
Athena, but without type metadata the converter can only apply heuristics — which
399+
may produce incorrect Python types for nested values (e.g. integers left as strings
400+
inside a struct).
401+
402+
The `result_set_type_hints` parameter solves this by letting you provide Athena DDL
403+
type signatures for specific columns. The converter then uses precise, recursive
404+
type-aware conversion instead of heuristics.
405+
406+
```python
407+
from pyathena import connect
408+
409+
cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
410+
region_name="us-west-2").cursor()
411+
cursor.execute(
412+
"SELECT col_array, col_map, col_struct FROM one_row_complex",
413+
result_set_type_hints={
414+
"col_array": "array(integer)",
415+
"col_map": "map(integer, integer)",
416+
"col_struct": "row(a integer, b integer)",
417+
},
418+
)
419+
row = cursor.fetchone()
420+
# col_struct values are now integers, not strings:
421+
# {"a": 1, "b": 2} instead of {"a": "1", "b": "2"}
422+
```
423+
424+
Column name matching is case-insensitive. Type hints support arbitrarily nested types:
425+
426+
```python
427+
cursor.execute(
428+
"""
429+
SELECT CAST(
430+
ROW(ROW('2024-01-01', 123), 4.736, 0.583)
431+
AS ROW(header ROW(stamp VARCHAR, seq INTEGER), x DOUBLE, y DOUBLE)
432+
) AS positions
433+
""",
434+
result_set_type_hints={
435+
"positions": "row(header row(stamp varchar, seq integer), x double, y double)",
436+
},
437+
)
438+
row = cursor.fetchone()
439+
positions = row[0]
440+
# positions["header"]["seq"] == 123 (int, not "123")
441+
# positions["x"] == 4.736 (float, not "4.736")
442+
```
443+
444+
### Hive-style syntax
445+
446+
You can paste type signatures from Hive DDL or ``DESCRIBE TABLE`` output directly.
447+
Hive-style angle brackets and colons are automatically converted to Trino-style syntax:
448+
449+
```python
450+
# Both are equivalent:
451+
result_set_type_hints={"col": "array(struct(a integer, b varchar))"} # Trino
452+
result_set_type_hints={"col": "array<struct<a:int,b:varchar>>"} # Hive
453+
```
454+
455+
The ``int`` alias is also supported and resolves to ``integer``.
456+
457+
### Index-based hints for duplicate column names
458+
459+
When a query produces columns with the same alias (e.g. ``SELECT a AS x, b AS x``),
460+
name-based hints cannot distinguish between them. Use integer keys to specify hints
461+
by zero-based column position:
462+
463+
```python
464+
cursor.execute(
465+
"SELECT a AS x, b AS x FROM my_table",
466+
result_set_type_hints={
467+
0: "array(integer)", # first "x" column
468+
1: "map(varchar, integer)", # second "x" column
469+
},
470+
)
471+
```
472+
473+
Integer (index-based) hints take priority over string (name-based) hints for the same
474+
column. You can mix both styles in the same dictionary.
475+
476+
### Constraints
477+
478+
* **Nested arrays in native format** — Athena's native (non-JSON) string representation
479+
does not clearly delimit nested arrays. If your query returns nested arrays
480+
(e.g. `array(array(integer))`), use `CAST(... AS JSON)` in your query to get
481+
JSON-formatted output, which is parsed reliably.
482+
* **Arrow, Pandas, and Polars cursors** — These cursors accept `result_set_type_hints`
483+
but their converters do not currently use the hints because they rely on their own
484+
type systems. The parameter is passed through for forward compatibility and for
485+
result sets that fall back to the default conversion path.
486+
487+
### Breaking change in 3.30.0
488+
489+
Prior to 3.30.0, PyAthena attempted to infer Python types for scalar values inside
490+
complex types using heuristics (e.g. `"123"``123`). Starting with 3.30.0, values
491+
inside complex types are **kept as strings** unless `result_set_type_hints` is provided.
492+
This change avoids silent misconversion but means existing code that relied on the
493+
heuristic behavior may see string values where it previously saw integers or floats.
494+
495+
To restore typed conversion, pass `result_set_type_hints` with the appropriate type
496+
signatures for the affected columns.
497+
392498
## Environment variables
393499

394500
Support [Boto3 environment variables](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables).

pyathena/aio/cursor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ async def execute( # type: ignore[override]
7979
result_reuse_enable: bool | None = None,
8080
result_reuse_minutes: int | None = None,
8181
paramstyle: str | None = None,
82+
result_set_type_hints: dict[str | int, str] | None = None,
8283
**kwargs,
8384
) -> AioCursor:
8485
"""Execute a SQL query asynchronously.
@@ -93,6 +94,9 @@ async def execute( # type: ignore[override]
9394
result_reuse_enable: Enable result reuse (optional).
9495
result_reuse_minutes: Result reuse duration in minutes (optional).
9596
paramstyle: Parameter style to use (optional).
97+
result_set_type_hints: Optional dictionary mapping column names to
98+
Athena DDL type signatures for precise type conversion within
99+
complex types.
96100
**kwargs: Additional execution parameters.
97101
98102
Returns:
@@ -119,6 +123,7 @@ async def execute( # type: ignore[override]
119123
query_execution,
120124
self.arraysize,
121125
self._retry_config,
126+
result_set_type_hints=result_set_type_hints,
122127
)
123128
else:
124129
raise OperationalError(query_execution.state_change_reason)

pyathena/aio/result_set.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
query_execution: AthenaQueryExecution,
3636
arraysize: int,
3737
retry_config: RetryConfig,
38+
result_set_type_hints: dict[str | int, str] | None = None,
3839
) -> None:
3940
super().__init__(
4041
connection=connection,
@@ -43,6 +44,7 @@ def __init__(
4344
arraysize=arraysize,
4445
retry_config=retry_config,
4546
_pre_fetch=False,
47+
result_set_type_hints=result_set_type_hints,
4648
)
4749

4850
@classmethod
@@ -53,6 +55,7 @@ async def create(
5355
query_execution: AthenaQueryExecution,
5456
arraysize: int,
5557
retry_config: RetryConfig,
58+
result_set_type_hints: dict[str | int, str] | None = None,
5659
) -> AthenaAioResultSet:
5760
"""Async factory method.
5861
@@ -64,11 +67,20 @@ async def create(
6467
query_execution: Query execution metadata.
6568
arraysize: Number of rows to fetch per request.
6669
retry_config: Retry configuration for API calls.
70+
result_set_type_hints: Optional dictionary mapping column names to
71+
Athena DDL type signatures for precise type conversion.
6772
6873
Returns:
6974
A fully initialized ``AthenaAioResultSet``.
7075
"""
71-
result_set = cls(connection, converter, query_execution, arraysize, retry_config)
76+
result_set = cls(
77+
connection,
78+
converter,
79+
query_execution,
80+
arraysize,
81+
retry_config,
82+
result_set_type_hints=result_set_type_hints,
83+
)
7284
if result_set.state == AthenaQueryExecution.STATE_SUCCEEDED:
7385
await result_set._async_pre_fetch()
7486
return result_set

pyathena/arrow/async_cursor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def arraysize(self, value: int) -> None:
149149
def _collect_result_set(
150150
self,
151151
query_id: str,
152+
result_set_type_hints: dict[str | int, str] | None = None,
152153
unload_location: str | None = None,
153154
kwargs: dict[str, Any] | None = None,
154155
) -> AthenaArrowResultSet:
@@ -165,6 +166,7 @@ def _collect_result_set(
165166
unload_location=unload_location,
166167
connect_timeout=self._connect_timeout,
167168
request_timeout=self._request_timeout,
169+
result_set_type_hints=result_set_type_hints,
168170
**kwargs,
169171
)
170172

@@ -179,6 +181,7 @@ def execute(
179181
result_reuse_enable: bool | None = None,
180182
result_reuse_minutes: int | None = None,
181183
paramstyle: str | None = None,
184+
result_set_type_hints: dict[str | int, str] | None = None,
182185
**kwargs,
183186
) -> tuple[str, Future[AthenaArrowResultSet | Any]]:
184187
operation, unload_location = self._prepare_unload(operation, s3_staging_dir)
@@ -198,6 +201,7 @@ def execute(
198201
self._executor.submit(
199202
self._collect_result_set,
200203
query_id,
204+
result_set_type_hints,
201205
unload_location,
202206
kwargs,
203207
),

pyathena/arrow/converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _dtypes(self) -> dict[str, type[Any]]:
9090
}
9191
return self.__dtypes
9292

93-
def convert(self, type_: str, value: str | None) -> Any | None:
93+
def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None:
9494
converter = self.get(type_)
9595
return converter(value)
9696

@@ -114,5 +114,5 @@ def __init__(self) -> None:
114114
default=_to_default,
115115
)
116116

117-
def convert(self, type_: str, value: str | None) -> Any | None:
117+
def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None:
118118
pass

pyathena/arrow/cursor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def execute(
137137
result_reuse_minutes: int | None = None,
138138
paramstyle: str | None = None,
139139
on_start_query_execution: Callable[[str], None] | None = None,
140+
result_set_type_hints: dict[str | int, str] | None = None,
140141
**kwargs,
141142
) -> ArrowCursor:
142143
"""Execute a SQL query and return results as Apache Arrow Tables.
@@ -156,6 +157,9 @@ def execute(
156157
result_reuse_minutes: Minutes to reuse cached results.
157158
paramstyle: Parameter style ('qmark' or 'pyformat').
158159
on_start_query_execution: Callback called when query starts.
160+
result_set_type_hints: Optional dictionary mapping column names to
161+
Athena DDL type signatures for precise type conversion within
162+
complex types.
159163
**kwargs: Additional execution parameters.
160164
161165
Returns:
@@ -197,6 +201,7 @@ def execute(
197201
unload_location=unload_location,
198202
connect_timeout=self._connect_timeout,
199203
request_timeout=self._request_timeout,
204+
result_set_type_hints=result_set_type_hints,
200205
**kwargs,
201206
)
202207
else:

pyathena/arrow/result_set.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def __init__(
9191
unload_location: str | None = None,
9292
connect_timeout: float | None = None,
9393
request_timeout: float | None = None,
94+
result_set_type_hints: dict[str | int, str] | None = None,
9495
**kwargs,
9596
) -> None:
9697
super().__init__(
@@ -99,6 +100,7 @@ def __init__(
99100
query_execution=query_execution,
100101
arraysize=1, # Fetch one row to retrieve metadata
101102
retry_config=retry_config,
103+
result_set_type_hints=result_set_type_hints,
102104
)
103105
self._rows.clear() # Clear pre_fetch data
104106
self._arraysize = arraysize

pyathena/async_cursor.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,19 @@ def poll(self, query_id: str) -> Future[AthenaQueryExecution]:
144144
"""
145145
return cast("Future[AthenaQueryExecution]", self._executor.submit(self._poll, query_id))
146146

147-
def _collect_result_set(self, query_id: str) -> AthenaResultSet:
147+
def _collect_result_set(
148+
self,
149+
query_id: str,
150+
result_set_type_hints: dict[str | int, str] | None = None,
151+
) -> AthenaResultSet:
148152
query_execution = cast(AthenaQueryExecution, self._poll(query_id))
149153
return self._result_set_class(
150154
connection=self._connection,
151155
converter=self._converter,
152156
query_execution=query_execution,
153157
arraysize=self._arraysize,
154158
retry_config=self._retry_config,
159+
result_set_type_hints=result_set_type_hints,
155160
)
156161

157162
def execute(
@@ -165,6 +170,7 @@ def execute(
165170
result_reuse_enable: bool | None = None,
166171
result_reuse_minutes: int | None = None,
167172
paramstyle: str | None = None,
173+
result_set_type_hints: dict[str | int, str] | None = None,
168174
**kwargs,
169175
) -> tuple[str, Future[AthenaResultSet | Any]]:
170176
"""Execute a SQL query asynchronously.
@@ -183,6 +189,9 @@ def execute(
183189
result_reuse_enable: Enable result reuse for identical queries (optional).
184190
result_reuse_minutes: Result reuse duration in minutes (optional).
185191
paramstyle: Parameter style to use (optional).
192+
result_set_type_hints: Optional dictionary mapping column names to
193+
Athena DDL type signatures for precise type conversion within
194+
complex types.
186195
**kwargs: Additional execution parameters.
187196
188197
Returns:
@@ -207,7 +216,9 @@ def execute(
207216
result_reuse_minutes=result_reuse_minutes,
208217
paramstyle=paramstyle,
209218
)
210-
return query_id, self._executor.submit(self._collect_result_set, query_id)
219+
return query_id, self._executor.submit(
220+
self._collect_result_set, query_id, result_set_type_hints
221+
)
211222

212223
def executemany(
213224
self,

0 commit comments

Comments
 (0)