2 changes: 1 addition & 1 deletion docs/introduction.rst
@@ -51,7 +51,7 @@ PyAthena provides comprehensive support for Amazon Athena's data types and featu
 **Data Type Support:**
 - **STRUCT/ROW Types**: :ref:`Complete support <sqlalchemy>` for complex nested data structures
 - **ARRAY Types**: Native handling of array data with automatic Python list conversion
-- **MAP Types**: Dictionary-like data structure support
+- **MAP Types**: :ref:`Complete support <sqlalchemy>` for key-value dictionary-like data structures
 - **JSON Integration**: Seamless JSON data parsing and conversion
 - **Performance Optimized**: Smart format detection for efficient data processing

139 changes: 139 additions & 0 deletions docs/sqlalchemy.rst
@@ -467,3 +467,142 @@ Migration from Raw Strings
result = cursor.execute("SELECT struct_column FROM table").fetchone()
struct_data = result[0] # {"name": "John", "age": 30} - automatically converted
name = struct_data['name'] # Direct access

MAP Type Support
~~~~~~~~~~~~~~~~

PyAthena provides comprehensive support for Amazon Athena's MAP data types, enabling you to work with key-value data structures in your Python applications.

Basic Usage
^^^^^^^^^^^

.. code:: python

    from sqlalchemy import Column, String, Integer, Table, MetaData
    from pyathena.sqlalchemy.types import AthenaMap

    metadata = MetaData()

    # Define a table with MAP columns
    products = Table('products', metadata,
        Column('id', Integer),
        Column('attributes', AthenaMap(String, String)),
        Column('metrics', AthenaMap(String, Integer)),
        Column('categories', AthenaMap(Integer, String))
    )

This generates the following SQL structure:

.. code:: sql

    CREATE TABLE products (
        id INTEGER,
        attributes MAP<STRING, STRING>,
        metrics MAP<STRING, INTEGER>,
        categories MAP<INTEGER, STRING>
    )

Querying MAP Data
^^^^^^^^^^^^^^^^^

PyAthena automatically converts MAP values returned by Athena into Python dictionaries:

.. code:: python

    from sqlalchemy import text

    # Query MAP data using the MAP constructor
    # (``connection`` is an open SQLAlchemy connection to Athena)
    result = connection.execute(
        text("SELECT MAP(ARRAY['name', 'category'], ARRAY['Laptop', 'Electronics']) AS product_info")
    ).fetchone()

    # Access MAP data as a dictionary
    product_info = result.product_info  # {"name": "Laptop", "category": "Electronics"}

Advanced MAP Operations
^^^^^^^^^^^^^^^^^^^^^^^

For maps with nested or otherwise complex values, cast the map to JSON in SQL:

.. code:: python

    import json

    from sqlalchemy import text

    # Using CAST AS JSON for complex MAP operations
    result = connection.execute(
        text("SELECT CAST(MAP(ARRAY['price', 'rating'], ARRAY['999', '4.5']) AS JSON) AS data")
    ).fetchone()

    # Parse JSON result
    data = json.loads(result.data)  # {"price": "999", "rating": "4.5"}
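
The values in the example above come back as strings because the map was built from string arrays. A minimal sketch of converting numeric-looking values afterwards, assuming the map has already been parsed into a dictionary. ``coerce_numeric`` is a hypothetical helper for illustration, not part of PyAthena:

```python
from decimal import Decimal, InvalidOperation


def coerce_numeric(data):
    """Return a copy of ``data`` with numeric-looking string values as Decimal.

    Hypothetical helper for illustration; not part of PyAthena.
    """
    converted = {}
    for key, value in data.items():
        if isinstance(value, str):
            try:
                converted[key] = Decimal(value)
                continue
            except InvalidOperation:
                pass  # not numeric: keep the original string
        converted[key] = value
    return converted


# Stand-in for the parsed ``data`` dictionary from the example above,
# plus one non-numeric entry added for illustration.
data = {"price": "999", "rating": "4.5", "name": "Laptop"}
typed = coerce_numeric(data)
```

``Decimal`` is used here so that values such as ``4.5`` stay exact; ``float`` would work equally well where approximate values are acceptable.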

Data Format Support
^^^^^^^^^^^^^^^^^^^

PyAthena supports multiple MAP data formats:

**Athena Native Format:**

.. code:: python

    # Input:  "{name=Laptop, category=Electronics}"
    # Output: {"name": "Laptop", "category": "Electronics"}

**JSON Format (Recommended):**

.. code:: python

    # Input:  '{"name": "Laptop", "category": "Electronics"}'
    # Output: {"name": "Laptop", "category": "Electronics"}

Performance Considerations
^^^^^^^^^^^^^^^^^^^^^^^^^^

- **JSON Format**: Recommended for complex nested structures
- **Native Format**: Parsed for simple key-value pairs only; maps whose values contain arrays or structs are not parsed
- **Smart Detection**: PyAthena inspects the start of the value to pick a parser, so JSON parsing is attempted only when the value looks like JSON
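
The detection described above can be sketched in a few lines. This is a simplified, standalone illustration and not the PyAthena converter itself: ``parse_map_string`` is a hypothetical name, and unlike the converter in this change, which also passes keys and values through ``_convert_value``, the sketch leaves values as strings.

```python
import json


def parse_map_string(value):
    """Parse a MAP string in JSON or simple native ``{k=v}`` form into a dict.

    Hypothetical helper for illustration; returns None when the value
    cannot be parsed.
    """
    if value is None or not (value.startswith("{") and value.endswith("}")):
        return None
    if '"' in value:
        # Quoted keys suggest JSON, so try the JSON parser first.
        try:
            parsed = json.loads(value)
            return parsed if isinstance(parsed, dict) else None
        except json.JSONDecodeError:
            pass  # fall through to the native parser
    inner = value[1:-1].strip()
    if not inner:
        return {}
    if any(ch in inner for ch in "()[]{}"):
        return None  # nested values: cast to JSON in SQL instead
    result = {}
    for pair in inner.split(","):
        key, sep, val = pair.partition("=")
        if sep:
            result[key.strip()] = val.strip()
    return result or None


json_style = parse_map_string('{"name": "Laptop", "category": "Electronics"}')
native_style = parse_map_string("{name=Laptop, category=Electronics}")
nested = parse_map_string("{tags=[a, b]}")
```

Both ``json_style`` and ``native_style`` end up as the same dictionary, while ``nested`` is ``None`` because a plain comma split cannot safely handle nested values.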

Best Practices
^^^^^^^^^^^^^^

1. **Use JSON casting** for complex nested structures:

   .. code:: sql

      SELECT CAST(complex_map AS JSON) FROM table_name

2. **Define clear key-value types** in AthenaMap definitions:

   .. code:: python

      AthenaMap(String, Integer)  # String keys, Integer values
      AthenaMap(Integer, AthenaStruct(...))  # Integer keys, STRUCT values

3. **Handle NULL values** appropriately in your application logic:

   .. code:: python

      if result.map_column is not None:
          # Process map data
          value = result.map_column.get('key_name')
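
The NULL check in the last item can be wrapped in a small helper so that call sites stay short. A minimal sketch; ``map_get`` is a hypothetical name, not a PyAthena function:

```python
def map_get(map_value, key, default=None):
    """Look up ``key`` in a MAP column value that may be None.

    Hypothetical helper for illustration; not part of PyAthena.
    """
    if map_value is None:
        return default
    return map_value.get(key, default)


# A NULL map and a missing key both fall back to the default.
from_null = map_get(None, "key_name", default="n/a")
from_missing = map_get({"other": 1}, "key_name", default="n/a")
from_present = map_get({"key_name": 42}, "key_name")
```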

Migration from Raw Strings
^^^^^^^^^^^^^^^^^^^^^^^^^^^

**Before (raw string handling):**

.. code:: python

    import json

    result = cursor.execute("SELECT map_column FROM table").fetchone()
    raw_data = result[0]  # "{\"key1\": \"value1\", \"key2\": \"value2\"}"
    parsed_data = json.loads(raw_data)

**After (automatic conversion):**

.. code:: python

    result = cursor.execute("SELECT map_column FROM table").fetchone()
    map_data = result[0]  # {"key1": "value1", "key2": "value2"} - automatically converted
    value = map_data['key1']  # Direct access
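
During a migration, the same code may need to run against both the old behaviour (raw strings) and the new one (dictionaries). A minimal sketch of a compatibility shim, assuming the raw strings are JSON as in the "Before" example. ``as_dict`` is a hypothetical helper, not part of PyAthena:

```python
import json


def as_dict(value):
    """Return ``value`` as a dict whether it arrives as a dict or a JSON string.

    Hypothetical helper for illustration; not part of PyAthena.
    """
    if value is None or isinstance(value, dict):
        return value
    return json.loads(value)


# Old behaviour: the column value is a JSON string.
before = as_dict('{"key1": "value1", "key2": "value2"}')
# New behaviour: the column value is already a dictionary.
after = as_dict({"key1": "value1", "key2": "value2"})
```

Once every caller runs on a version with automatic conversion, the shim can be removed and the dictionary used directly.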
88 changes: 87 additions & 1 deletion pyathena/converter.py
@@ -78,6 +78,57 @@ def _to_json(varchar_value: Optional[str]) -> Optional[Any]:
return json.loads(varchar_value)


def _to_map(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
    """Convert map data to Python dictionary.

    Supports two formats:
    1. JSON format: '{"key1": "value1", "key2": "value2"}' (recommended)
    2. Athena native format: '{key1=value1, key2=value2}' (basic cases only)

    For complex maps, use CAST(map_column AS JSON) in your SQL query.

    Args:
        varchar_value: String representation of map data

    Returns:
        Dictionary representation of map, or None if parsing fails
    """
    if varchar_value is None:
        return None

    # Quick check: if it doesn't look like a map, return None
    if not (varchar_value.startswith("{") and varchar_value.endswith("}")):
        return None

    # Optimize: Check if it looks like JSON vs Athena native format
    # JSON objects typically have quoted keys: {"key": value}
    # Athena native format has unquoted keys: {key=value}
    inner_preview = varchar_value[1:10] if len(varchar_value) > 10 else varchar_value[1:-1]

    if '"' in inner_preview or varchar_value.startswith('{"'):
        # Likely JSON format - try JSON parsing
        try:
            result = json.loads(varchar_value)
            return result if isinstance(result, dict) else None
        except json.JSONDecodeError:
            # If JSON parsing fails, fall back to native format parsing
            pass

    inner = varchar_value[1:-1].strip()
    if not inner:
        return {}

    try:
        # MAP format is always key=value pairs
        # But for complex structures, return None to keep as string
        if any(char in inner for char in "()[]"):
            # Contains complex structures (arrays, structs), skip parsing
            return None
        return _parse_map_native(inner)
    except Exception:
        return None


def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
"""Convert struct data to Python dictionary.

@@ -128,6 +179,41 @@ def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
return None


def _parse_map_native(inner: str) -> Optional[Dict[str, Any]]:
    """Parse map native format: key1=value1, key2=value2.

    Args:
        inner: Interior content of map without braces.

    Returns:
        Dictionary with parsed key-value pairs, or None if no valid pairs found.
    """
    result = {}

    # Simple split by comma for basic cases
    pairs = [pair.strip() for pair in inner.split(",")]

    for pair in pairs:
        if "=" not in pair:
            continue

        key, value = pair.split("=", 1)
        key = key.strip()
        value = value.strip()

        # Skip pairs with special characters (safety check)
        if any(char in key for char in '{}="') or any(char in value for char in '{}="'):
            continue

        # Convert both key and value to appropriate types
        converted_key = _convert_value(key)
        converted_value = _convert_value(value)
        # Always use string keys for consistency with expected test behavior
        result[str(converted_key)] = converted_value

    return result if result else None


def _parse_named_struct(inner: str) -> Optional[Dict[str, Any]]:
"""Parse named struct format: a=1, b=2.

@@ -217,7 +303,7 @@ def _to_default(varchar_value: Optional[str]) -> Optional[str]:
 "time": _to_time,
 "varbinary": _to_binary,
 "array": _to_default,
-"map": _to_default,
+"map": _to_map,
 "row": _to_struct,
 "decimal": _to_decimal,
 "json": _to_json,
25 changes: 19 additions & 6 deletions pyathena/sqlalchemy/compiler.py
@@ -17,6 +17,7 @@
AthenaPartitionTransform,
AthenaRowFormatSerde,
)
from pyathena.sqlalchemy.types import AthenaMap, AthenaStruct

if TYPE_CHECKING:
from sqlalchemy import (
Expand Down Expand Up @@ -140,17 +141,29 @@ def visit_enum(self, type_, **kw):
return self.visit_string(type_, **kw)

def visit_struct(self, type_, **kw): # noqa: N802
if hasattr(type_, "fields") and type_.fields:
field_specs = []
for field_name, field_type in type_.fields.items():
field_type_str = self.process(field_type, **kw)
field_specs.append(f"{field_name} {field_type_str}")
return f"ROW({', '.join(field_specs)})"
if isinstance(type_, AthenaStruct):
if type_.fields:
field_specs = []
for field_name, field_type in type_.fields.items():
field_type_str = self.process(field_type, **kw)
field_specs.append(f"{field_name} {field_type_str}")
return f"ROW({', '.join(field_specs)})"
return "ROW()"
return "ROW()"

def visit_STRUCT(self, type_, **kw): # noqa: N802
return self.visit_struct(type_, **kw)

def visit_map(self, type_, **kw): # noqa: N802
if isinstance(type_, AthenaMap):
key_type_str = self.process(type_.key_type, **kw)
value_type_str = self.process(type_.value_type, **kw)
return f"MAP<{key_type_str}, {value_type_str}>"
return "MAP<STRING, STRING>"

def visit_MAP(self, type_, **kw): # noqa: N802
return self.visit_map(type_, **kw)


class AthenaStatementCompiler(SQLCompiler):
def visit_char_length_func(self, fn: "FunctionElement[Any]", **kw):
29 changes: 29 additions & 0 deletions pyathena/sqlalchemy/types.py
@@ -77,3 +77,32 @@ def python_type(self) -> type:

class STRUCT(AthenaStruct):
__visit_name__ = "STRUCT"


class AthenaMap(TypeEngine[Dict[str, Any]]):
    __visit_name__ = "map"

    def __init__(self, key_type: Any = None, value_type: Any = None) -> None:
        if key_type is None:
            self.key_type: TypeEngine[Any] = sqltypes.String()
        elif isinstance(key_type, TypeEngine):
            self.key_type = key_type
        else:
            # Assume it's a SQLAlchemy type class and instantiate it
            self.key_type = key_type()

        if value_type is None:
            self.value_type: TypeEngine[Any] = sqltypes.String()
        elif isinstance(value_type, TypeEngine):
            self.value_type = value_type
        else:
            # Assume it's a SQLAlchemy type class and instantiate it
            self.value_type = value_type()

    @property
    def python_type(self) -> type:
        return dict


class MAP(AthenaMap):
    __visit_name__ = "MAP"
2 changes: 1 addition & 1 deletion tests/pyathena/pandas/test_util.py
@@ -117,7 +117,7 @@ def test_as_pandas(cursor):
 b"123",
 "[1, 2]",
 [1, 2],
-"{1=2, 3=4}",
+{"1": 2, "3": 4},
 {"1": 2, "3": 4},
 {"a": 1, "b": 2},
 Decimal("0.1"),
2 changes: 1 addition & 1 deletion tests/pyathena/sqlalchemy/test_base.py
@@ -255,7 +255,7 @@ def test_reflect_select(self, engine):
 date(2017, 1, 2),
 b"123",
 "[1, 2]",
-"{1=2, 3=4}",  # map type remains as string
+{"1": 2, "3": 4},  # map type now converted to dict
 {"a": 1, "b": 2},  # row type now converted to dict
 Decimal("0.1"),
 ]
31 changes: 30 additions & 1 deletion tests/pyathena/sqlalchemy/test_compiler.py
@@ -5,7 +5,7 @@
 from sqlalchemy import Integer, String

 from pyathena.sqlalchemy.compiler import AthenaTypeCompiler
-from pyathena.sqlalchemy.types import STRUCT, AthenaStruct
+from pyathena.sqlalchemy.types import MAP, STRUCT, AthenaMap, AthenaStruct


class TestAthenaTypeCompiler:
@@ -51,3 +51,32 @@ def test_visit_struct_single_field(self):
struct_type = AthenaStruct(("name", String))
result = compiler.visit_struct(struct_type)
assert result == "ROW(name STRING)" or result == "ROW(name VARCHAR)"

    def test_visit_map_default(self):
        dialect = Mock()
        compiler = AthenaTypeCompiler(dialect)
        map_type = AthenaMap()
        result = compiler.visit_map(map_type)
        assert result == "MAP<STRING, STRING>"

    def test_visit_map_with_types(self):
        dialect = Mock()
        compiler = AthenaTypeCompiler(dialect)
        map_type = AthenaMap(String, Integer)
        result = compiler.visit_map(map_type)
        assert result == "MAP<STRING, INTEGER>" or result == "MAP<VARCHAR, INTEGER>"

    def test_visit_map_uppercase(self):
        dialect = Mock()
        compiler = AthenaTypeCompiler(dialect)
        map_type = MAP(Integer, String)
        result = compiler.visit_MAP(map_type)
        assert result == "MAP<INTEGER, STRING>" or result == "MAP<INTEGER, VARCHAR>"

    def test_visit_map_no_attributes(self):
        # Test map type without key_type/value_type attributes
        dialect = Mock()
        compiler = AthenaTypeCompiler(dialect)
        map_type = type("MockMap", (), {})()
        result = compiler.visit_map(map_type)
        assert result == "MAP<STRING, STRING>"