Merged
2 changes: 1 addition & 1 deletion docs/introduction.rst
@@ -50,7 +50,7 @@ PyAthena provides comprehensive support for Amazon Athena's data types and featu

**Data Type Support:**
- **STRUCT/ROW Types**: :ref:`Complete support <sqlalchemy>` for complex nested data structures
-- **ARRAY Types**: Native handling of array data with automatic Python list conversion
+- **ARRAY Types**: :ref:`Complete support <sqlalchemy>` for ordered collections with automatic Python list conversion
- **MAP Types**: :ref:`Complete support <sqlalchemy>` for key-value dictionary-like data structures
- **JSON Integration**: Seamless JSON data parsing and conversion
- **Performance Optimized**: Smart format detection for efficient data processing
178 changes: 178 additions & 0 deletions docs/sqlalchemy.rst
@@ -606,3 +606,181 @@ Migration from Raw Strings
    result = cursor.execute("SELECT map_column FROM table").fetchone()
    map_data = result[0]  # {"key1": "value1", "key2": "value2"} - automatically converted
    value = map_data['key1']  # Direct access

ARRAY Type Support
~~~~~~~~~~~~~~~~~~

PyAthena supports Amazon Athena's ARRAY data type: ``AthenaArray`` declares ARRAY columns in SQLAlchemy table definitions, and ARRAY values in query results are converted to Python lists.

Basic Usage
^^^^^^^^^^^

.. code:: python

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from pyathena.sqlalchemy.types import AthenaArray

    metadata = MetaData()

    # Define a table with ARRAY columns
    orders = Table('orders', metadata,
        Column('id', Integer),
        Column('item_ids', AthenaArray(Integer)),
        Column('tags', AthenaArray(String)),
        Column('categories', AthenaArray(String))
    )

This creates a table definition equivalent to:

.. code:: sql

    CREATE TABLE orders (
        id INTEGER,
        item_ids ARRAY<INTEGER>,
        tags ARRAY<STRING>,
        categories ARRAY<STRING>
    )

Querying ARRAY Data
^^^^^^^^^^^^^^^^^^^

PyAthena automatically converts ARRAY values in query results to Python lists:

.. code:: python

    from sqlalchemy import create_engine, select, text

    # Assumes ``connection`` is an open connection from an Athena engine,
    # for example ``create_engine(...).connect()``

    # Query ARRAY data using ARRAY constructor
    result = connection.execute(
        select().from_statement(
            text("SELECT ARRAY[1, 2, 3, 4, 5] as item_ids")
        )
    ).fetchone()

    # Access ARRAY data as Python list
    item_ids = result.item_ids  # [1, 2, 3, 4, 5]

Complex ARRAY Operations
^^^^^^^^^^^^^^^^^^^^^^^^

For arrays containing complex data types:

.. code:: python

    import json

    # Arrays with STRUCT elements
    result = connection.execute(
        select().from_statement(
            text("SELECT ARRAY[ROW('Alice', 25), ROW('Bob', 30)] as users")
        )
    ).fetchone()

    users = result.users  # [{"0": "Alice", "1": 25}, {"0": "Bob", "1": 30}]

    # Using CAST AS JSON for complex ARRAY operations
    result = connection.execute(
        select().from_statement(
            text("SELECT CAST(ARRAY[1, 2, 3] AS JSON) as data")
        )
    ).fetchone()

    # Parse JSON result
    if isinstance(result.data, str):
        array_data = json.loads(result.data)  # [1, 2, 3]
    else:
        array_data = result.data  # Already converted to list
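
The ``isinstance`` check above can be folded into a small helper. This is a minimal sketch, assuming the column value arrives either as JSON text or as an already-converted list; `as_list` is a hypothetical name used for illustration and is not part of PyAthena.

```python
import json
from typing import Any, List, Optional, Union


def as_list(value: Union[str, List[Any], None]) -> Optional[List[Any]]:
    """Return a list whether the value is JSON text or already a list.

    Hypothetical helper for illustration; not part of PyAthena.
    """
    if value is None:
        # SQL NULL stays None
        return None
    if isinstance(value, str):
        parsed = json.loads(value)  # raises json.JSONDecodeError on bad input
        if not isinstance(parsed, list):
            raise ValueError(f"expected a JSON array, got {type(parsed).__name__}")
        return parsed
    # Already converted: return a shallow copy as a plain list
    return list(value)


from_text = as_list("[1, 2, 3]")
from_list = as_list([1, 2, 3])
print(from_text, from_list)
```

Rejecting non-array JSON explicitly keeps a mistyped column (for example a MAP cast to JSON) from being silently treated as an array.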

Data Format Support
^^^^^^^^^^^^^^^^^^^

PyAthena supports multiple ARRAY data formats:

**Athena Native Format:**

.. code:: python

    # Input: '[1, 2, 3]'
    # Output: [1, 2, 3]

    # Input: '[apple, banana, cherry]'
    # Output: ["apple", "banana", "cherry"]

**JSON Format:**

.. code:: python

    # Input: '[1, 2, 3]'
    # Output: [1, 2, 3]

    # Input: '["apple", "banana", "cherry"]'
    # Output: ["apple", "banana", "cherry"]

**Complex Nested Arrays:**

.. code:: python

    # Input: '[{name=John, age=30}, {name=Jane, age=25}]'
    # Output: [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}]

Type Definitions
^^^^^^^^^^^^^^^^

AthenaArray supports various item types:

.. code:: python

    from pyathena.sqlalchemy.types import AthenaArray, AthenaStruct, AthenaMap

    # Simple arrays
    AthenaArray(String)   # ARRAY<STRING>
    AthenaArray(Integer)  # ARRAY<INTEGER>

    # Arrays of complex types
    AthenaArray(AthenaStruct(...))  # ARRAY<STRUCT<...>>
    AthenaArray(AthenaMap(...))     # ARRAY<MAP<...>>

    # Nested arrays
    AthenaArray(AthenaArray(Integer))  # ARRAY<ARRAY<INTEGER>>

Best Practices
^^^^^^^^^^^^^^

1. **Use appropriate item types** in AthenaArray definitions:

   .. code:: python

       AthenaArray(Integer)  # For numeric arrays
       AthenaArray(String)   # For string arrays
       AthenaArray(AthenaStruct(...))  # For arrays of structs

2. **Use CAST AS JSON** for complex array operations:

   .. code:: sql

       SELECT CAST(complex_array AS JSON) FROM table_name

3. **Handle NULL values** appropriately in your application logic:

   .. code:: python

       if result.array_column is not None:
           # Process array data
           first_item = result.array_column[0] if result.array_column else None
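
The NULL check in the last item has to cover two distinct cases: a NULL column (``None``) and an empty array (``[]``). A minimal sketch of one way to handle both; `first_item` is a hypothetical helper name, not something PyAthena provides.

```python
from typing import Any, Optional, Sequence


def first_item(array_value: Optional[Sequence[Any]], default: Any = None) -> Any:
    """Return the first element, or `default` for NULL or empty arrays.

    Hypothetical helper for illustration; not part of PyAthena.
    """
    if not array_value:  # true for both None (SQL NULL) and an empty list
        return default
    return array_value[0]


print(first_item([10, 20]))   # first element of a populated array
print(first_item([]))         # empty array falls back to the default
print(first_item(None, -1))   # NULL column falls back to the default
```

If the application needs to tell NULL apart from an empty array, test ``array_value is None`` separately instead of relying on truthiness.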

Migration from Raw Strings
^^^^^^^^^^^^^^^^^^^^^^^^^^^

**Before (raw string handling):**

.. code:: python

    import json

    result = cursor.execute("SELECT array_column FROM table").fetchone()
    raw_data = result[0]  # "[1, 2, 3]"
    parsed_data = json.loads(raw_data)

**After (automatic conversion):**

.. code:: python

    result = cursor.execute("SELECT array_column FROM table").fetchone()
    array_data = result[0]  # [1, 2, 3] - automatically converted
    first_item = array_data[0]  # Direct access
125 changes: 123 additions & 2 deletions pyathena/converter.py
@@ -8,7 +8,7 @@
from copy import deepcopy
from datetime import date, datetime, time
from decimal import Decimal
-from typing import Any, Callable, Dict, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type

from dateutil.tz import gettz

@@ -78,6 +78,52 @@ def _to_json(varchar_value: Optional[str]) -> Optional[Any]:
    return json.loads(varchar_value)


def _to_array(varchar_value: Optional[str]) -> Optional[List[Any]]:
    """Convert array data to Python list.

    Supports two formats:
    1. JSON format: '[1, 2, 3]' or '["a", "b", "c"]' (recommended)
    2. Athena native format: '[1, 2, 3]' or '[apple, banana, cherry]'
       (basic cases only)

    For complex arrays, use CAST(array_column AS JSON) in your SQL query.

    Args:
        varchar_value: String representation of array data

    Returns:
        List representation of array, or None if parsing fails
    """
    if varchar_value is None:
        return None

    # Quick check: if it doesn't look like an array, return None
    if not (varchar_value.startswith("[") and varchar_value.endswith("]")):
        return None

    # Optimize: Try JSON parsing first (most reliable)
    try:
        result = json.loads(varchar_value)
        if isinstance(result, list):
            return result
    except json.JSONDecodeError:
        # If JSON parsing fails, fall back to basic parsing for simple cases
        pass

    inner = varchar_value[1:-1].strip()
    if not inner:
        return []

    try:
        # Nested arrays are too complex for basic parsing
        if "[" in inner:
            return None
        # Try native parsing (including struct arrays)
        return _parse_array_native(inner)
    except Exception:
        return None
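
To illustrate why `_to_array` needs a fallback after `json.loads`, the following stand-alone sketch checks which sample strings are valid JSON arrays. It uses only the standard library; `is_valid_json_array` and the sample strings are illustrative assumptions, not part of the PR.

```python
import json


def is_valid_json_array(raw: str) -> bool:
    """Report whether `raw` parses as a JSON array (illustrative helper)."""
    try:
        return isinstance(json.loads(raw), list)
    except json.JSONDecodeError:
        return False


samples = [
    "[1, 2, 3]",                # numbers look the same in both formats
    '["apple", "banana"]',      # quoted strings: valid JSON
    "[apple, banana]",          # bare words: not JSON, needs native parsing
    "[{name=John, age=30}]",    # key=value structs: not JSON either
]

for raw in samples:
    path = "json.loads" if is_valid_json_array(raw) else "native fallback"
    print(raw, "->", path)
```

Only the first two samples parse as JSON; the unquoted forms raise `json.JSONDecodeError`, which is the case the native parser is meant to cover.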


def _to_map(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
"""Convert map data to Python dictionary.

@@ -179,6 +225,81 @@ def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
return None


def _parse_array_native(inner: str) -> Optional[List[Any]]:
    """Parse array native format: 1, 2, 3 or {a, b}, {c, d}.

    Args:
        inner: Interior content of array without brackets.

    Returns:
        List with parsed values, or None if no valid values found.
    """
    result = []

    # Smart split by comma - respect brace groupings
    items = _split_array_items(inner)

    for item in items:
        if not item:
            continue

        # Handle struct (ROW) values in format {a, b, c} or {key=value, ...}
        if item.strip().startswith("{") and item.strip().endswith("}"):
            # This is a struct value - parse it as a struct
            struct_value = _to_struct(item.strip())
            if struct_value is not None:
                result.append(struct_value)
            continue

        # Skip items with nested arrays or complex quoting (safety check)
        if any(char in item for char in '[]="'):
            continue

        # Convert item to appropriate type
        converted_item = _convert_value(item)
        result.append(converted_item)

    return result if result else None


def _split_array_items(inner: str) -> List[str]:
    """Split array items by comma, respecting brace and bracket groupings.

    Args:
        inner: Interior content of array without brackets.

    Returns:
        List of item strings.
    """
    items = []
    current_item = ""
    brace_depth = 0
    bracket_depth = 0

    for char in inner:
        if char == "{":
            brace_depth += 1
        elif char == "}":
            brace_depth -= 1
        elif char == "[":
            bracket_depth += 1
        elif char == "]":
            bracket_depth -= 1
        elif char == "," and brace_depth == 0 and bracket_depth == 0:
            # Top-level comma - end current item
            items.append(current_item.strip())
            current_item = ""
            continue

        current_item += char

    # Add the last item
    if current_item.strip():
        items.append(current_item.strip())

    return items
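
To see what the depth counters buy over a plain `str.split`, here is a compact, self-contained variant of the same idea. `split_top_level` is a hypothetical name, and it uses one shared depth counter for braces and brackets, which is a simplification of the function above rather than a copy of it.

```python
from typing import List


def split_top_level(inner: str) -> List[str]:
    """Split on commas that are not nested inside {...} or [...]."""
    items: List[str] = []
    current: List[str] = []
    depth = 0  # single counter for both brace kinds (simplifying assumption)

    for char in inner:
        if char in "{[":
            depth += 1
        elif char in "}]":
            depth -= 1
        elif char == "," and depth == 0:
            # Top-level comma: close the current item and start a new one
            items.append("".join(current).strip())
            current = []
            continue
        current.append(char)

    last = "".join(current).strip()
    if last:
        items.append(last)
    return items


native = "{name=John, age=30}, {name=Jane, age=25}"
naive = native.split(", ")          # breaks each struct into two fragments
grouped = split_top_level(native)   # keeps each struct intact
print(len(naive), len(grouped))
```

Collecting characters in a list and joining once avoids repeated string concatenation, though for the short strings involved either approach should behave the same.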


def _parse_map_native(inner: str) -> Optional[Dict[str, Any]]:
"""Parse map native format: key1=value1, key2=value2.

@@ -302,7 +423,7 @@ def _to_default(varchar_value: Optional[str]) -> Optional[str]:
     "date": _to_date,
     "time": _to_time,
     "varbinary": _to_binary,
-    "array": _to_default,
+    "array": _to_array,
     "map": _to_map,
     "row": _to_struct,
     "decimal": _to_decimal,
11 changes: 10 additions & 1 deletion pyathena/sqlalchemy/compiler.py
@@ -17,7 +17,7 @@
AthenaPartitionTransform,
AthenaRowFormatSerde,
)
-from pyathena.sqlalchemy.types import AthenaMap, AthenaStruct
+from pyathena.sqlalchemy.types import AthenaArray, AthenaMap, AthenaStruct

if TYPE_CHECKING:
from sqlalchemy import (
@@ -164,6 +164,15 @@ def visit_map(self, type_, **kw): # noqa: N802
    def visit_MAP(self, type_, **kw):  # noqa: N802
        return self.visit_map(type_, **kw)

    def visit_array(self, type_, **kw):  # noqa: N802
        if isinstance(type_, AthenaArray):
            item_type_str = self.process(type_.item_type, **kw)
            return f"ARRAY<{item_type_str}>"
        return "ARRAY<STRING>"

    def visit_ARRAY(self, type_, **kw):  # noqa: N802
        return self.visit_array(type_, **kw)
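
`visit_array` hands the element type back to `self.process`, so nested arrays render recursively. A stand-alone sketch of that recursion follows; `Scalar`, `Array`, and `render` are hypothetical stand-ins for illustration and do not use SQLAlchemy's or PyAthena's real type objects.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Scalar:
    """Stand-in for a simple column type such as INTEGER or STRING."""
    name: str


@dataclass(frozen=True)
class Array:
    """Stand-in for an array type wrapping another type."""
    item_type: Union["Scalar", "Array"]


def render(type_: Union[Scalar, Array]) -> str:
    """Render a type to DDL text, recursing into array element types."""
    if isinstance(type_, Array):
        return f"ARRAY<{render(type_.item_type)}>"
    return type_.name


flat = render(Array(Scalar("STRING")))
nested = render(Array(Array(Scalar("INTEGER"))))
print(flat, nested)
```

Because each level only wraps whatever its element type renders to, arbitrary nesting depth works without any special casing.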


class AthenaStatementCompiler(SQLCompiler):
def visit_char_length_func(self, fn: "FunctionElement[Any]", **kw):