Skip to content

Commit 421fcc1

Browse files
Merge pull request #590 from laughingman7743/feature/array-type-support
Add ARRAY type support
2 parents 2d2a66f + dd63c7b commit 421fcc1

11 files changed

Lines changed: 755 additions & 12 deletions

File tree

docs/introduction.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ PyAthena provides comprehensive support for Amazon Athena's data types and featu
5050

5151
**Data Type Support:**
5252
- **STRUCT/ROW Types**: :ref:`Complete support <sqlalchemy>` for complex nested data structures
53-
- **ARRAY Types**: Native handling of array data with automatic Python list conversion
53+
- **ARRAY Types**: :ref:`Complete support <sqlalchemy>` for ordered collections with automatic Python list conversion
5454
- **MAP Types**: :ref:`Complete support <sqlalchemy>` for key-value dictionary-like data structures
5555
- **JSON Integration**: Seamless JSON data parsing and conversion
5656
- **Performance Optimized**: Smart format detection for efficient data processing

docs/sqlalchemy.rst

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,3 +606,181 @@ Migration from Raw Strings
606606
result = cursor.execute("SELECT map_column FROM table").fetchone()
607607
map_data = result[0] # {"key1": "value1", "key2": "value2"} - automatically converted
608608
value = map_data['key1'] # Direct access
609+
610+
ARRAY Type Support
611+
~~~~~~~~~~~~~~~~~~
612+
613+
PyAthena provides comprehensive support for Amazon Athena's ARRAY data types, enabling you to work with ordered collections of data in your Python applications.
614+
615+
Basic Usage
616+
^^^^^^^^^^^
617+
618+
.. code:: python
619+
620+
from sqlalchemy import Column, String, Integer, Table, MetaData
621+
from pyathena.sqlalchemy.types import AthenaArray
622+
623+
# Define a table with ARRAY columns
624+
orders = Table('orders', metadata,
625+
Column('id', Integer),
626+
Column('item_ids', AthenaArray(Integer)),
627+
Column('tags', AthenaArray(String)),
628+
Column('categories', AthenaArray(String))
629+
)
630+
631+
This creates a table definition equivalent to:
632+
633+
.. code:: sql
634+
635+
CREATE TABLE orders (
636+
id INTEGER,
637+
item_ids ARRAY<INTEGER>,
638+
tags ARRAY<STRING>,
639+
categories ARRAY<STRING>
640+
)
641+
642+
Querying ARRAY Data
643+
^^^^^^^^^^^^^^^^^^^
644+
645+
PyAthena automatically converts ARRAY data between different formats:
646+
647+
.. code:: python
648+
649+
from sqlalchemy import create_engine, select
650+
651+
# Query ARRAY data using ARRAY constructor
652+
result = connection.execute(
653+
select().from_statement(
654+
text("SELECT ARRAY[1, 2, 3, 4, 5] as item_ids")
655+
)
656+
).fetchone()
657+
658+
# Access ARRAY data as Python list
659+
item_ids = result.item_ids # [1, 2, 3, 4, 5]
660+
661+
Complex ARRAY Operations
662+
^^^^^^^^^^^^^^^^^^^^^^^^
663+
664+
For arrays containing complex data types:
665+
666+
.. code:: python
667+
668+
# Arrays with STRUCT elements
669+
result = connection.execute(
670+
select().from_statement(
671+
text("SELECT ARRAY[ROW('Alice', 25), ROW('Bob', 30)] as users")
672+
)
673+
).fetchone()
674+
675+
users = result.users # [{"0": "Alice", "1": 25}, {"0": "Bob", "1": 30}]
676+
677+
# Using CAST AS JSON for complex ARRAY operations
678+
result = connection.execute(
679+
select().from_statement(
680+
text("SELECT CAST(ARRAY[1, 2, 3] AS JSON) as data")
681+
)
682+
).fetchone()
683+
684+
# Parse JSON result
685+
import json
686+
if isinstance(result.data, str):
687+
array_data = json.loads(result.data) # [1, 2, 3]
688+
else:
689+
array_data = result.data # Already converted to list
690+
691+
Data Format Support
692+
^^^^^^^^^^^^^^^^^^^
693+
694+
PyAthena supports multiple ARRAY data formats:
695+
696+
**Athena Native Format:**
697+
698+
.. code:: python
699+
700+
# Input: '[1, 2, 3]'
701+
# Output: [1, 2, 3]
702+
703+
# Input: '[apple, banana, cherry]'
704+
# Output: ["apple", "banana", "cherry"]
705+
706+
**JSON Format:**
707+
708+
.. code:: python
709+
710+
# Input: '[1, 2, 3]'
711+
# Output: [1, 2, 3]
712+
713+
# Input: '["apple", "banana", "cherry"]'
714+
# Output: ["apple", "banana", "cherry"]
715+
716+
**Complex Nested Arrays:**
717+
718+
.. code:: python
719+
720+
# Input: '[{name=John, age=30}, {name=Jane, age=25}]'
721+
# Output: [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}]
722+
723+
Type Definitions
724+
^^^^^^^^^^^^^^^^
725+
726+
AthenaArray supports various item types:
727+
728+
.. code:: python
729+
730+
from pyathena.sqlalchemy.types import AthenaArray, AthenaStruct, AthenaMap
731+
732+
# Simple arrays
733+
AthenaArray(String) # ARRAY<STRING>
734+
AthenaArray(Integer) # ARRAY<INTEGER>
735+
736+
# Arrays of complex types
737+
AthenaArray(AthenaStruct(...)) # ARRAY<STRUCT<...>>
738+
AthenaArray(AthenaMap(...)) # ARRAY<MAP<...>>
739+
740+
# Nested arrays
741+
AthenaArray(AthenaArray(Integer)) # ARRAY<ARRAY<INTEGER>>
742+
743+
Best Practices
744+
^^^^^^^^^^^^^^
745+
746+
1. **Use appropriate item types** in AthenaArray definitions:
747+
748+
.. code:: python
749+
750+
AthenaArray(Integer) # For numeric arrays
751+
AthenaArray(String) # For string arrays
752+
AthenaArray(AthenaStruct(...)) # For arrays of structs
753+
754+
2. **Use CAST AS JSON** for complex array operations:
755+
756+
.. code:: sql
757+
758+
SELECT CAST(complex_array AS JSON) FROM table_name
759+
760+
3. **Handle NULL values** appropriately in your application logic:
761+
762+
.. code:: python
763+
764+
if result.array_column is not None:
765+
# Process array data
766+
first_item = result.array_column[0] if result.array_column else None
767+
768+
Migration from Raw Strings
769+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
770+
771+
**Before (raw string handling):**
772+
773+
.. code:: python
774+
775+
result = cursor.execute("SELECT array_column FROM table").fetchone()
776+
raw_data = result[0] # "[1, 2, 3]"
777+
import json
778+
parsed_data = json.loads(raw_data)
779+
780+
**After (automatic conversion):**
781+
782+
.. code:: python
783+
784+
result = cursor.execute("SELECT array_column FROM table").fetchone()
785+
array_data = result[0] # [1, 2, 3] - automatically converted
786+
first_item = array_data[0] # Direct access

pyathena/converter.py

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from copy import deepcopy
99
from datetime import date, datetime, time
1010
from decimal import Decimal
11-
from typing import Any, Callable, Dict, Optional, Type
11+
from typing import Any, Callable, Dict, List, Optional, Type
1212

1313
from dateutil.tz import gettz
1414

@@ -78,6 +78,52 @@ def _to_json(varchar_value: Optional[str]) -> Optional[Any]:
7878
return json.loads(varchar_value)
7979

8080

81+
def _to_array(varchar_value: Optional[str]) -> Optional[List[Any]]:
82+
"""Convert array data to Python list.
83+
84+
Supports two formats:
85+
1. JSON format: '[1, 2, 3]' or '["a", "b", "c"]' (recommended)
86+
2. Athena native format: '[1, 2, 3]' (basic cases only)
87+
88+
For complex arrays, use CAST(array_column AS JSON) in your SQL query.
89+
90+
Args:
91+
varchar_value: String representation of array data
92+
93+
Returns:
94+
List representation of array, or None if parsing fails
95+
"""
96+
if varchar_value is None:
97+
return None
98+
99+
# Quick check: if it doesn't look like an array, return None
100+
if not (varchar_value.startswith("[") and varchar_value.endswith("]")):
101+
return None
102+
103+
# Optimize: Try JSON parsing first (most reliable)
104+
try:
105+
result = json.loads(varchar_value)
106+
if isinstance(result, list):
107+
return result
108+
except json.JSONDecodeError:
109+
# If JSON parsing fails, fall back to basic parsing for simple cases
110+
pass
111+
112+
inner = varchar_value[1:-1].strip()
113+
if not inner:
114+
return []
115+
116+
try:
117+
# For nested arrays, too complex for basic parsing
118+
if "[" in inner:
119+
# Contains nested arrays - too complex for basic parsing
120+
return None
121+
# Try native parsing (including struct arrays)
122+
return _parse_array_native(inner)
123+
except Exception:
124+
return None
125+
126+
81127
def _to_map(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
82128
"""Convert map data to Python dictionary.
83129
@@ -179,6 +225,81 @@ def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
179225
return None
180226

181227

228+
def _parse_array_native(inner: str) -> Optional[List[Any]]:
229+
"""Parse array native format: 1, 2, 3 or {a, b}, {c, d}.
230+
231+
Args:
232+
inner: Interior content of array without brackets.
233+
234+
Returns:
235+
List with parsed values, or None if no valid values found.
236+
"""
237+
result = []
238+
239+
# Smart split by comma - respect brace groupings
240+
items = _split_array_items(inner)
241+
242+
for item in items:
243+
if not item:
244+
continue
245+
246+
# Handle struct (ROW) values in format {a, b, c} or {key=value, ...}
247+
if item.strip().startswith("{") and item.strip().endswith("}"):
248+
# This is a struct value - parse it as a struct
249+
struct_value = _to_struct(item.strip())
250+
if struct_value is not None:
251+
result.append(struct_value)
252+
continue
253+
254+
# Skip items with nested arrays or complex quoting (safety check)
255+
if any(char in item for char in '[]="'):
256+
continue
257+
258+
# Convert item to appropriate type
259+
converted_item = _convert_value(item)
260+
result.append(converted_item)
261+
262+
return result if result else None
263+
264+
265+
def _split_array_items(inner: str) -> List[str]:
266+
"""Split array items by comma, respecting brace and bracket groupings.
267+
268+
Args:
269+
inner: Interior content of array without brackets.
270+
271+
Returns:
272+
List of item strings.
273+
"""
274+
items = []
275+
current_item = ""
276+
brace_depth = 0
277+
bracket_depth = 0
278+
279+
for char in inner:
280+
if char == "{":
281+
brace_depth += 1
282+
elif char == "}":
283+
brace_depth -= 1
284+
elif char == "[":
285+
bracket_depth += 1
286+
elif char == "]":
287+
bracket_depth -= 1
288+
elif char == "," and brace_depth == 0 and bracket_depth == 0:
289+
# Top-level comma - end current item
290+
items.append(current_item.strip())
291+
current_item = ""
292+
continue
293+
294+
current_item += char
295+
296+
# Add the last item
297+
if current_item.strip():
298+
items.append(current_item.strip())
299+
300+
return items
301+
302+
182303
def _parse_map_native(inner: str) -> Optional[Dict[str, Any]]:
183304
"""Parse map native format: key1=value1, key2=value2.
184305
@@ -302,7 +423,7 @@ def _to_default(varchar_value: Optional[str]) -> Optional[str]:
302423
"date": _to_date,
303424
"time": _to_time,
304425
"varbinary": _to_binary,
305-
"array": _to_default,
426+
"array": _to_array,
306427
"map": _to_map,
307428
"row": _to_struct,
308429
"decimal": _to_decimal,

pyathena/sqlalchemy/compiler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
AthenaPartitionTransform,
1818
AthenaRowFormatSerde,
1919
)
20-
from pyathena.sqlalchemy.types import AthenaMap, AthenaStruct
20+
from pyathena.sqlalchemy.types import AthenaArray, AthenaMap, AthenaStruct
2121

2222
if TYPE_CHECKING:
2323
from sqlalchemy import (
@@ -164,6 +164,15 @@ def visit_map(self, type_, **kw): # noqa: N802
164164
def visit_MAP(self, type_, **kw): # noqa: N802
165165
return self.visit_map(type_, **kw)
166166

167+
def visit_array(self, type_, **kw): # noqa: N802
168+
if isinstance(type_, AthenaArray):
169+
item_type_str = self.process(type_.item_type, **kw)
170+
return f"ARRAY<{item_type_str}>"
171+
return "ARRAY<STRING>"
172+
173+
def visit_ARRAY(self, type_, **kw): # noqa: N802
174+
return self.visit_array(type_, **kw)
175+
167176

168177
class AthenaStatementCompiler(SQLCompiler):
169178
def visit_char_length_func(self, fn: "FunctionElement[Any]", **kw):

0 commit comments

Comments
 (0)