Skip to content

Commit 8d082e8

Browse files
Add ARRAY type support with compiler and converter functionality
- Add _to_array() converter function supporting JSON and native formats - Create AthenaArray and ARRAY type classes in SQLAlchemy types - Implement visit_array() and visit_ARRAY() methods in compiler - Add comprehensive tests for ARRAY compiler functionality - Follow same architectural patterns as existing STRUCT and MAP types 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 2d2a66f commit 8d082e8

4 files changed

Lines changed: 138 additions & 5 deletions

File tree

pyathena/converter.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from copy import deepcopy
99
from datetime import date, datetime, time
1010
from decimal import Decimal
11-
from typing import Any, Callable, Dict, Optional, Type
11+
from typing import Any, Callable, Dict, List, Optional, Type
1212

1313
from dateutil.tz import gettz
1414

@@ -78,6 +78,51 @@ def _to_json(varchar_value: Optional[str]) -> Optional[Any]:
7878
return json.loads(varchar_value)
7979

8080

81+
def _to_array(varchar_value: Optional[str]) -> Optional[List[Any]]:
82+
"""Convert array data to Python list.
83+
84+
Supports two formats:
85+
1. JSON format: '[1, 2, 3]' or '["a", "b", "c"]' (recommended)
86+
2. Athena native format: '[1, 2, 3]' (basic cases only)
87+
88+
For complex arrays, use CAST(array_column AS JSON) in your SQL query.
89+
90+
Args:
91+
varchar_value: String representation of array data
92+
93+
Returns:
94+
List representation of array, or None if parsing fails
95+
"""
96+
if varchar_value is None:
97+
return None
98+
99+
# Quick check: if it doesn't look like an array, return None
100+
if not (varchar_value.startswith("[") and varchar_value.endswith("]")):
101+
return None
102+
103+
# Optimize: Try JSON parsing first (most reliable)
104+
try:
105+
result = json.loads(varchar_value)
106+
return result if isinstance(result, list) else None
107+
except json.JSONDecodeError:
108+
# If JSON parsing fails, fall back to basic parsing for simple cases
109+
pass
110+
111+
inner = varchar_value[1:-1].strip()
112+
if not inner:
113+
return []
114+
115+
try:
116+
# Simple array format: [1, 2, 3] or [a, b, c]
117+
# For complex structures, return None to keep as string
118+
if any(char in inner for char in "{}()[]"):
119+
# Contains complex structures, skip parsing
120+
return None
121+
return _parse_array_native(inner)
122+
except Exception:
123+
return None
124+
125+
81126
def _to_map(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
82127
"""Convert map data to Python dictionary.
83128
@@ -179,6 +224,35 @@ def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
179224
return None
180225

181226

227+
def _parse_array_native(inner: str) -> Optional[List[Any]]:
228+
"""Parse array native format: 1, 2, 3 or a, b, c.
229+
230+
Args:
231+
inner: Interior content of array without brackets.
232+
233+
Returns:
234+
List with parsed values, or None if no valid values found.
235+
"""
236+
result = []
237+
238+
# Simple split by comma for basic cases
239+
items = [item.strip() for item in inner.split(",")]
240+
241+
for item in items:
242+
if not item:
243+
continue
244+
245+
# Skip items with special characters (safety check)
246+
if any(char in item for char in '{}[]()="'):
247+
continue
248+
249+
# Convert item to appropriate type
250+
converted_item = _convert_value(item)
251+
result.append(converted_item)
252+
253+
return result if result else None
254+
255+
182256
def _parse_map_native(inner: str) -> Optional[Dict[str, Any]]:
183257
"""Parse map native format: key1=value1, key2=value2.
184258
@@ -302,7 +376,7 @@ def _to_default(varchar_value: Optional[str]) -> Optional[str]:
302376
"date": _to_date,
303377
"time": _to_time,
304378
"varbinary": _to_binary,
305-
"array": _to_default,
379+
"array": _to_array,
306380
"map": _to_map,
307381
"row": _to_struct,
308382
"decimal": _to_decimal,

pyathena/sqlalchemy/compiler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
AthenaPartitionTransform,
1818
AthenaRowFormatSerde,
1919
)
20-
from pyathena.sqlalchemy.types import AthenaMap, AthenaStruct
20+
from pyathena.sqlalchemy.types import AthenaArray, AthenaMap, AthenaStruct
2121

2222
if TYPE_CHECKING:
2323
from sqlalchemy import (
@@ -164,6 +164,15 @@ def visit_map(self, type_, **kw): # noqa: N802
164164
def visit_MAP(self, type_, **kw): # noqa: N802
165165
return self.visit_map(type_, **kw)
166166

167+
def visit_array(self, type_, **kw): # noqa: N802
168+
if isinstance(type_, AthenaArray):
169+
item_type_str = self.process(type_.item_type, **kw)
170+
return f"ARRAY<{item_type_str}>"
171+
return "ARRAY<STRING>"
172+
173+
def visit_ARRAY(self, type_, **kw): # noqa: N802
174+
return self.visit_array(type_, **kw)
175+
167176

168177
class AthenaStatementCompiler(SQLCompiler):
169178
def visit_char_length_func(self, fn: "FunctionElement[Any]", **kw):

pyathena/sqlalchemy/types.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
from datetime import date, datetime
5-
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
5+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
66

77
from sqlalchemy.sql import sqltypes
88
from sqlalchemy.sql.type_api import TypeEngine
@@ -106,3 +106,24 @@ def python_type(self) -> type:
106106

107107
class MAP(AthenaMap):
108108
__visit_name__ = "MAP"
109+
110+
111+
class AthenaArray(TypeEngine[List[Any]]):
112+
__visit_name__ = "array"
113+
114+
def __init__(self, item_type: Any = None) -> None:
115+
if item_type is None:
116+
self.item_type: TypeEngine[Any] = sqltypes.String()
117+
elif isinstance(item_type, TypeEngine):
118+
self.item_type = item_type
119+
else:
120+
# Assume it's a SQLAlchemy type class and instantiate it
121+
self.item_type = item_type()
122+
123+
@property
124+
def python_type(self) -> type:
125+
return list
126+
127+
128+
class ARRAY(AthenaArray):
129+
__visit_name__ = "ARRAY"

tests/pyathena/sqlalchemy/test_compiler.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlalchemy import Integer, String
66

77
from pyathena.sqlalchemy.compiler import AthenaTypeCompiler
8-
from pyathena.sqlalchemy.types import MAP, STRUCT, AthenaMap, AthenaStruct
8+
from pyathena.sqlalchemy.types import ARRAY, MAP, STRUCT, AthenaArray, AthenaMap, AthenaStruct
99

1010

1111
class TestAthenaTypeCompiler:
@@ -80,3 +80,32 @@ def test_visit_map_no_attributes(self):
8080
map_type = type("MockMap", (), {})()
8181
result = compiler.visit_map(map_type)
8282
assert result == "MAP<STRING, STRING>"
83+
84+
def test_visit_array_default(self):
85+
dialect = Mock()
86+
compiler = AthenaTypeCompiler(dialect)
87+
array_type = AthenaArray()
88+
result = compiler.visit_array(array_type)
89+
assert result == "ARRAY<STRING>"
90+
91+
def test_visit_array_with_type(self):
92+
dialect = Mock()
93+
compiler = AthenaTypeCompiler(dialect)
94+
array_type = AthenaArray(Integer)
95+
result = compiler.visit_array(array_type)
96+
assert result == "ARRAY<INTEGER>"
97+
98+
def test_visit_array_uppercase(self):
99+
dialect = Mock()
100+
compiler = AthenaTypeCompiler(dialect)
101+
array_type = ARRAY(String)
102+
result = compiler.visit_ARRAY(array_type)
103+
assert result == "ARRAY<STRING>" or result == "ARRAY<VARCHAR>"
104+
105+
def test_visit_array_no_attributes(self):
106+
# Test array type without item_type attribute
107+
dialect = Mock()
108+
compiler = AthenaTypeCompiler(dialect)
109+
array_type = type("MockArray", (), {})()
110+
result = compiler.visit_array(array_type)
111+
assert result == "ARRAY<STRING>"

0 commit comments

Comments
 (0)