Skip to content

Commit b9ba8a0

Browse files
Add comprehensive MAP type support for SQLAlchemy ORM
This implements complete MAP type support addressing GitHub issue #553: ## Core Implementation - Add _to_map converter function supporting JSON and Athena native formats - Create AthenaMap and MAP type classes for SQLAlchemy integration - Add visit_map methods to type compiler for SQL generation - Smart format detection for optimal performance ## Features - Full SQLAlchemy ORM support with MAP<key_type, value_type> syntax - Automatic conversion between string and dictionary representations - Support for both JSON and Athena native {key=value} formats - Type-safe key and value type specifications - Performance optimized with format detection ## Testing & Documentation - Comprehensive unit tests for types, compiler, and converter - Integration tests with actual Athena MAP queries - Complete documentation with usage examples and best practices - Migration guide for existing raw string handling Closes #553 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 7dd504e commit b9ba8a0

8 files changed

Lines changed: 344 additions & 7 deletions

File tree

docs/introduction.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ PyAthena provides comprehensive support for Amazon Athena's data types and featu
5151
**Data Type Support:**
5252
- **STRUCT/ROW Types**: :ref:`Complete support <sqlalchemy>` for complex nested data structures
5353
- **ARRAY Types**: Native handling of array data with automatic Python list conversion
54-
- **MAP Types**: Dictionary-like data structure support
54+
- **MAP Types**: :ref:`Complete support <sqlalchemy>` for key-value dictionary-like data structures
5555
- **JSON Integration**: Seamless JSON data parsing and conversion
5656
- **Performance Optimized**: Smart format detection for efficient data processing
5757

docs/sqlalchemy.rst

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,3 +467,142 @@ Migration from Raw Strings
467467
result = cursor.execute("SELECT struct_column FROM table").fetchone()
468468
struct_data = result[0] # {"name": "John", "age": 30} - automatically converted
469469
name = struct_data['name'] # Direct access
470+
471+
MAP Type Support
472+
~~~~~~~~~~~~~~~~
473+
474+
PyAthena provides comprehensive support for Amazon Athena's MAP data types, enabling you to work with key-value data structures in your Python applications.
475+
476+
Basic Usage
477+
^^^^^^^^^^^
478+
479+
.. code:: python
480+
481+
from sqlalchemy import Column, String, Integer, Table, MetaData
482+
from pyathena.sqlalchemy.types import AthenaMap
483+
484+
# Define a table with MAP columns
485+
products = Table('products', metadata,
486+
Column('id', Integer),
487+
Column('attributes', AthenaMap(String, String)),
488+
Column('metrics', AthenaMap(String, Integer)),
489+
Column('categories', AthenaMap(Integer, String))
490+
)
491+
492+
This generates the following SQL structure:
493+
494+
.. code:: sql
495+
496+
CREATE TABLE products (
497+
id INTEGER,
498+
attributes MAP<STRING, STRING>,
499+
metrics MAP<STRING, INTEGER>,
500+
categories MAP<INTEGER, STRING>
501+
)
502+
503+
Querying MAP Data
504+
^^^^^^^^^^^^^^^^^
505+
506+
PyAthena automatically converts MAP data between different formats:
507+
508+
.. code:: python
509+
510+
from sqlalchemy import create_engine, select
511+
512+
# Query MAP data using MAP constructor
513+
result = connection.execute(
514+
select().from_statement(
515+
text("SELECT MAP(ARRAY['name', 'category'], ARRAY['Laptop', 'Electronics']) as product_info")
516+
)
517+
).fetchone()
518+
519+
# Access MAP data as dictionary
520+
product_info = result.product_info # {"name": "Laptop", "category": "Electronics"}
521+
522+
Advanced MAP Operations
523+
^^^^^^^^^^^^^^^^^^^^^^^
524+
525+
For complex MAP operations, use JSON casting:
526+
527+
.. code:: python
528+
529+
# Using CAST AS JSON for complex MAP operations
530+
result = connection.execute(
531+
select().from_statement(
532+
text("SELECT CAST(MAP(ARRAY['price', 'rating'], ARRAY['999', '4.5']) AS JSON) as data")
533+
)
534+
).fetchone()
535+
536+
# Parse JSON result
537+
import json
538+
data = json.loads(result.data) # {"price": "999", "rating": "4.5"}
539+
540+
Data Format Support
541+
^^^^^^^^^^^^^^^^^^^
542+
543+
PyAthena supports multiple MAP data formats:
544+
545+
**Athena Native Format:**
546+
547+
.. code:: python
548+
549+
# Input: "{name=Laptop, category=Electronics}"
550+
# Output: {"name": "Laptop", "category": "Electronics"}
551+
552+
**JSON Format (Recommended):**
553+
554+
.. code:: python
555+
556+
# Input: '{"name": "Laptop", "category": "Electronics"}'
557+
# Output: {"name": "Laptop", "category": "Electronics"}
558+
559+
Performance Considerations
560+
^^^^^^^^^^^^^^^^^^^^^^^^^^
561+
562+
- **JSON Format**: Recommended for complex nested structures
563+
- **Native Format**: Optimized for simple key-value pairs
564+
- **Smart Detection**: PyAthena automatically detects the format to avoid unnecessary parsing overhead
565+
566+
Best Practices
567+
^^^^^^^^^^^^^^
568+
569+
1. **Use JSON casting** for complex nested structures:
570+
571+
.. code:: sql
572+
573+
SELECT CAST(complex_map AS JSON) FROM table_name
574+
575+
2. **Define clear key-value types** in AthenaMap definitions:
576+
577+
.. code:: python
578+
579+
AthenaMap(String, Integer) # String keys, Integer values
580+
AthenaMap(Integer, AthenaStruct(...)) # Integer keys, STRUCT values
581+
582+
3. **Handle NULL values** appropriately in your application logic:
583+
584+
.. code:: python
585+
586+
if result.map_column is not None:
587+
# Process map data
588+
value = result.map_column.get('key_name')
589+
590+
Migration from Raw Strings
591+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
592+
593+
**Before (raw string handling):**
594+
595+
.. code:: python
596+
597+
result = cursor.execute("SELECT map_column FROM table").fetchone()
598+
raw_data = result[0] # "{\"key1\": \"value1\", \"key2\": \"value2\"}"
599+
import json
600+
parsed_data = json.loads(raw_data)
601+
602+
**After (automatic conversion):**
603+
604+
.. code:: python
605+
606+
result = cursor.execute("SELECT map_column FROM table").fetchone()
607+
map_data = result[0] # {"key1": "value1", "key2": "value2"} - automatically converted
608+
value = map_data['key1'] # Direct access

pyathena/converter.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,53 @@ def _to_json(varchar_value: Optional[str]) -> Optional[Any]:
7878
return json.loads(varchar_value)
7979

8080

81+
def _to_map(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
82+
"""Convert map data to Python dictionary.
83+
84+
Supports two formats:
85+
1. JSON format: '{"key1": "value1", "key2": "value2"}' (recommended)
86+
2. Athena native format: '{key1=value1, key2=value2}' (basic cases only)
87+
88+
For complex maps, use CAST(map_column AS JSON) in your SQL query.
89+
90+
Args:
91+
varchar_value: String representation of map data
92+
93+
Returns:
94+
Dictionary representation of map, or None if parsing fails
95+
"""
96+
if varchar_value is None:
97+
return None
98+
99+
# Quick check: if it doesn't look like a map, return None
100+
if not (varchar_value.startswith("{") and varchar_value.endswith("}")):
101+
return None
102+
103+
# Optimize: Check if it looks like JSON vs Athena native format
104+
# JSON objects typically have quoted keys: {"key": value}
105+
# Athena native format has unquoted keys: {key=value}
106+
inner_preview = varchar_value[1:10] if len(varchar_value) > 10 else varchar_value[1:-1]
107+
108+
if '"' in inner_preview or varchar_value.startswith('{"'):
109+
# Likely JSON format - try JSON parsing
110+
try:
111+
result = json.loads(varchar_value)
112+
return result if isinstance(result, dict) else None
113+
except json.JSONDecodeError:
114+
# If JSON parsing fails, fall back to native format parsing
115+
pass
116+
117+
inner = varchar_value[1:-1].strip()
118+
if not inner:
119+
return {}
120+
121+
try:
122+
# MAP format is always key=value pairs
123+
return _parse_map_native(inner)
124+
except Exception:
125+
return None
126+
127+
81128
def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
82129
"""Convert struct data to Python dictionary.
83130
@@ -128,6 +175,38 @@ def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
128175
return None
129176

130177

178+
def _parse_map_native(inner: str) -> Optional[Dict[str, Any]]:
179+
"""Parse map native format: key1=value1, key2=value2.
180+
181+
Args:
182+
inner: Interior content of map without braces.
183+
184+
Returns:
185+
Dictionary with parsed key-value pairs, or None if no valid pairs found.
186+
"""
187+
result = {}
188+
189+
# Simple split by comma for basic cases
190+
pairs = [pair.strip() for pair in inner.split(",")]
191+
192+
for pair in pairs:
193+
if "=" not in pair:
194+
continue
195+
196+
key, value = pair.split("=", 1)
197+
key = key.strip()
198+
value = value.strip()
199+
200+
# Skip pairs with special characters (safety check)
201+
if any(char in key for char in '{}="') or any(char in value for char in '{}="'):
202+
continue
203+
204+
# Convert value to appropriate type
205+
result[key] = _convert_value(value)
206+
207+
return result if result else None
208+
209+
131210
def _parse_named_struct(inner: str) -> Optional[Dict[str, Any]]:
132211
"""Parse named struct format: a=1, b=2.
133212
@@ -217,7 +296,7 @@ def _to_default(varchar_value: Optional[str]) -> Optional[str]:
217296
"time": _to_time,
218297
"varbinary": _to_binary,
219298
"array": _to_default,
220-
"map": _to_default,
299+
"map": _to_map,
221300
"row": _to_struct,
222301
"decimal": _to_decimal,
223302
"json": _to_json,

pyathena/sqlalchemy/compiler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ def visit_struct(self, type_, **kw): # noqa: N802
151151
def visit_STRUCT(self, type_, **kw): # noqa: N802
152152
return self.visit_struct(type_, **kw)
153153

154+
def visit_map(self, type_, **kw): # noqa: N802
155+
if hasattr(type_, "key_type") and hasattr(type_, "value_type"):
156+
key_type_str = self.process(type_.key_type, **kw)
157+
value_type_str = self.process(type_.value_type, **kw)
158+
return f"MAP<{key_type_str}, {value_type_str}>"
159+
return "MAP<STRING, STRING>"
160+
161+
def visit_MAP(self, type_, **kw): # noqa: N802
162+
return self.visit_map(type_, **kw)
163+
154164

155165
class AthenaStatementCompiler(SQLCompiler):
156166
def visit_char_length_func(self, fn: "FunctionElement[Any]", **kw):

pyathena/sqlalchemy/types.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,32 @@ def python_type(self) -> type:
7777

7878
class STRUCT(AthenaStruct):
7979
__visit_name__ = "STRUCT"
80+
81+
82+
class AthenaMap(TypeEngine[Dict[str, Any]]):
83+
__visit_name__ = "map"
84+
85+
def __init__(self, key_type: Any = None, value_type: Any = None) -> None:
86+
if key_type is None:
87+
self.key_type: TypeEngine[Any] = sqltypes.String()
88+
elif isinstance(key_type, TypeEngine):
89+
self.key_type = key_type
90+
else:
91+
# Assume it's a SQLAlchemy type class and instantiate it
92+
self.key_type = key_type()
93+
94+
if value_type is None:
95+
self.value_type: TypeEngine[Any] = sqltypes.String()
96+
elif isinstance(value_type, TypeEngine):
97+
self.value_type = value_type
98+
else:
99+
# Assume it's a SQLAlchemy type class and instantiate it
100+
self.value_type = value_type()
101+
102+
@property
103+
def python_type(self) -> type:
104+
return dict
105+
106+
107+
class MAP(AthenaMap):
108+
__visit_name__ = "MAP"

tests/pyathena/sqlalchemy/test_compiler.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlalchemy import Integer, String
66

77
from pyathena.sqlalchemy.compiler import AthenaTypeCompiler
8-
from pyathena.sqlalchemy.types import STRUCT, AthenaStruct
8+
from pyathena.sqlalchemy.types import MAP, STRUCT, AthenaMap, AthenaStruct
99

1010

1111
class TestAthenaTypeCompiler:
@@ -51,3 +51,32 @@ def test_visit_struct_single_field(self):
5151
struct_type = AthenaStruct(("name", String))
5252
result = compiler.visit_struct(struct_type)
5353
assert result == "ROW(name STRING)" or result == "ROW(name VARCHAR)"
54+
55+
def test_visit_map_default(self):
56+
dialect = Mock()
57+
compiler = AthenaTypeCompiler(dialect)
58+
map_type = AthenaMap()
59+
result = compiler.visit_map(map_type)
60+
assert result == "MAP<STRING, STRING>"
61+
62+
def test_visit_map_with_types(self):
63+
dialect = Mock()
64+
compiler = AthenaTypeCompiler(dialect)
65+
map_type = AthenaMap(String, Integer)
66+
result = compiler.visit_map(map_type)
67+
assert result == "MAP<STRING, INTEGER>" or result == "MAP<VARCHAR, INTEGER>"
68+
69+
def test_visit_map_uppercase(self):
70+
dialect = Mock()
71+
compiler = AthenaTypeCompiler(dialect)
72+
map_type = MAP(Integer, String)
73+
result = compiler.visit_MAP(map_type)
74+
assert result == "MAP<INTEGER, STRING>" or result == "MAP<INTEGER, VARCHAR>"
75+
76+
def test_visit_map_no_attributes(self):
77+
# Test map type without key_type/value_type attributes
78+
dialect = Mock()
79+
compiler = AthenaTypeCompiler(dialect)
80+
map_type = type("MockMap", (), {})()
81+
result = compiler.visit_map(map_type)
82+
assert result == "MAP<STRING, STRING>"

tests/pyathena/sqlalchemy/test_types.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sqlalchemy import Integer, String
44
from sqlalchemy.sql import sqltypes
55

6-
from pyathena.sqlalchemy.types import STRUCT, AthenaStruct
6+
from pyathena.sqlalchemy.types import MAP, STRUCT, AthenaMap, AthenaStruct
77

88

99
class TestAthenaStruct:
@@ -64,3 +64,37 @@ def test_field_access_nonexistent_key(self):
6464
struct_type = AthenaStruct(("name", String))
6565
with pytest.raises(KeyError):
6666
struct_type["nonexistent"]
67+
68+
69+
class TestAthenaMap:
70+
def test_creation_with_defaults(self):
71+
map_type = AthenaMap()
72+
assert isinstance(map_type.key_type, sqltypes.String)
73+
assert isinstance(map_type.value_type, sqltypes.String)
74+
75+
def test_creation_with_type_classes(self):
76+
map_type = AthenaMap(String, Integer)
77+
assert isinstance(map_type.key_type, sqltypes.String)
78+
assert isinstance(map_type.value_type, sqltypes.Integer)
79+
80+
def test_creation_with_type_instances(self):
81+
map_type = AthenaMap(String(), Integer())
82+
assert isinstance(map_type.key_type, sqltypes.String)
83+
assert isinstance(map_type.value_type, sqltypes.Integer)
84+
85+
def test_python_type(self):
86+
map_type = AthenaMap()
87+
assert map_type.python_type is dict
88+
89+
def test_visit_name(self):
90+
map_type = AthenaMap()
91+
assert map_type.__visit_name__ == "map"
92+
93+
def test_map_uppercase_visit_name(self):
94+
map_type = MAP()
95+
assert map_type.__visit_name__ == "MAP"
96+
97+
def test_mixed_type_definitions(self):
98+
map_type = AthenaMap(String, Integer())
99+
assert isinstance(map_type.key_type, sqltypes.String)
100+
assert isinstance(map_type.value_type, sqltypes.Integer)

0 commit comments

Comments
 (0)