Skip to content

Commit 2d2a66f

Browse files
Merge pull request #589 from laughingman7743/feature/map-type-support
Add comprehensive MAP type support for SQLAlchemy ORM
2 parents 7dd504e + 32b501d commit 2d2a66f

11 files changed

Lines changed: 540 additions & 270 deletions

File tree

docs/introduction.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ PyAthena provides comprehensive support for Amazon Athena's data types and featu
5151
**Data Type Support:**
5252
- **STRUCT/ROW Types**: :ref:`Complete support <sqlalchemy>` for complex nested data structures
5353
- **ARRAY Types**: Native handling of array data with automatic Python list conversion
54-
- **MAP Types**: Dictionary-like data structure support
54+
- **MAP Types**: :ref:`Complete support <sqlalchemy>` for key-value dictionary-like data structures
5555
- **JSON Integration**: Seamless JSON data parsing and conversion
5656
- **Performance Optimized**: Smart format detection for efficient data processing
5757

docs/sqlalchemy.rst

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,3 +467,142 @@ Migration from Raw Strings
467467
result = cursor.execute("SELECT struct_column FROM table").fetchone()
468468
struct_data = result[0] # {"name": "John", "age": 30} - automatically converted
469469
name = struct_data['name'] # Direct access
470+
471+
MAP Type Support
472+
~~~~~~~~~~~~~~~~
473+
474+
PyAthena provides comprehensive support for Amazon Athena's MAP data types, enabling you to work with key-value data structures in your Python applications.
475+
476+
Basic Usage
477+
^^^^^^^^^^^
478+
479+
.. code:: python
480+
481+
from sqlalchemy import Column, String, Integer, Table, MetaData
482+
from pyathena.sqlalchemy.types import AthenaMap
483+
484+
# Define a table with MAP columns
485+
products = Table('products', metadata,
486+
Column('id', Integer),
487+
Column('attributes', AthenaMap(String, String)),
488+
Column('metrics', AthenaMap(String, Integer)),
489+
Column('categories', AthenaMap(Integer, String))
490+
)
491+
492+
This generates the following SQL structure:
493+
494+
.. code:: sql
495+
496+
CREATE TABLE products (
497+
id INTEGER,
498+
attributes MAP<STRING, STRING>,
499+
metrics MAP<STRING, INTEGER>,
500+
categories MAP<INTEGER, STRING>
501+
)
502+
503+
Querying MAP Data
504+
^^^^^^^^^^^^^^^^^
505+
506+
PyAthena automatically converts MAP data between different formats:
507+
508+
.. code:: python
509+
510+
from sqlalchemy import create_engine, select
511+
512+
# Query MAP data using MAP constructor
513+
result = connection.execute(
514+
select().from_statement(
515+
text("SELECT MAP(ARRAY['name', 'category'], ARRAY['Laptop', 'Electronics']) as product_info")
516+
)
517+
).fetchone()
518+
519+
# Access MAP data as dictionary
520+
product_info = result.product_info # {"name": "Laptop", "category": "Electronics"}
521+
522+
Advanced MAP Operations
523+
^^^^^^^^^^^^^^^^^^^^^^^
524+
525+
For complex MAP operations, use JSON casting:
526+
527+
.. code:: python
528+
529+
# Using CAST AS JSON for complex MAP operations
530+
result = connection.execute(
531+
select().from_statement(
532+
text("SELECT CAST(MAP(ARRAY['price', 'rating'], ARRAY['999', '4.5']) AS JSON) as data")
533+
)
534+
).fetchone()
535+
536+
# Parse JSON result
537+
import json
538+
data = json.loads(result.data) # {"price": "999", "rating": "4.5"}
539+
540+
Data Format Support
541+
^^^^^^^^^^^^^^^^^^^
542+
543+
PyAthena supports multiple MAP data formats:
544+
545+
**Athena Native Format:**
546+
547+
.. code:: python
548+
549+
# Input: "{name=Laptop, category=Electronics}"
550+
# Output: {"name": "Laptop", "category": "Electronics"}
551+
552+
**JSON Format (Recommended):**
553+
554+
.. code:: python
555+
556+
# Input: '{"name": "Laptop", "category": "Electronics"}'
557+
# Output: {"name": "Laptop", "category": "Electronics"}
558+
559+
Performance Considerations
560+
^^^^^^^^^^^^^^^^^^^^^^^^^^
561+
562+
- **JSON Format**: Recommended for complex nested structures
563+
- **Native Format**: Optimized for simple key-value pairs
564+
- **Smart Detection**: PyAthena automatically detects the format to avoid unnecessary parsing overhead
565+
566+
Best Practices
567+
^^^^^^^^^^^^^^
568+
569+
1. **Use JSON casting** for complex nested structures:
570+
571+
.. code:: sql
572+
573+
SELECT CAST(complex_map AS JSON) FROM table_name
574+
575+
2. **Define clear key-value types** in AthenaMap definitions:
576+
577+
.. code:: python
578+
579+
AthenaMap(String, Integer) # String keys, Integer values
580+
AthenaMap(Integer, AthenaStruct(...)) # Integer keys, STRUCT values
581+
582+
3. **Handle NULL values** appropriately in your application logic:
583+
584+
.. code:: python
585+
586+
if result.map_column is not None:
587+
# Process map data
588+
value = result.map_column.get('key_name')
589+
590+
Migration from Raw Strings
591+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
592+
593+
**Before (raw string handling):**
594+
595+
.. code:: python
596+
597+
result = cursor.execute("SELECT map_column FROM table").fetchone()
598+
raw_data = result[0] # "{\"key1\": \"value1\", \"key2\": \"value2\"}"
599+
import json
600+
parsed_data = json.loads(raw_data)
601+
602+
**After (automatic conversion):**
603+
604+
.. code:: python
605+
606+
result = cursor.execute("SELECT map_column FROM table").fetchone()
607+
map_data = result[0] # {"key1": "value1", "key2": "value2"} - automatically converted
608+
value = map_data['key1'] # Direct access

pyathena/converter.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,57 @@ def _to_json(varchar_value: Optional[str]) -> Optional[Any]:
7878
return json.loads(varchar_value)
7979

8080

81+
def _to_map(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
82+
"""Convert map data to Python dictionary.
83+
84+
Supports two formats:
85+
1. JSON format: '{"key1": "value1", "key2": "value2"}' (recommended)
86+
2. Athena native format: '{key1=value1, key2=value2}' (basic cases only)
87+
88+
For complex maps, use CAST(map_column AS JSON) in your SQL query.
89+
90+
Args:
91+
varchar_value: String representation of map data
92+
93+
Returns:
94+
Dictionary representation of map, or None if parsing fails
95+
"""
96+
if varchar_value is None:
97+
return None
98+
99+
# Quick check: if it doesn't look like a map, return None
100+
if not (varchar_value.startswith("{") and varchar_value.endswith("}")):
101+
return None
102+
103+
# Optimize: Check if it looks like JSON vs Athena native format
104+
# JSON objects typically have quoted keys: {"key": value}
105+
# Athena native format has unquoted keys: {key=value}
106+
inner_preview = varchar_value[1:10] if len(varchar_value) > 10 else varchar_value[1:-1]
107+
108+
if '"' in inner_preview or varchar_value.startswith('{"'):
109+
# Likely JSON format - try JSON parsing
110+
try:
111+
result = json.loads(varchar_value)
112+
return result if isinstance(result, dict) else None
113+
except json.JSONDecodeError:
114+
# If JSON parsing fails, fall back to native format parsing
115+
pass
116+
117+
inner = varchar_value[1:-1].strip()
118+
if not inner:
119+
return {}
120+
121+
try:
122+
# MAP format is always key=value pairs
123+
# But for complex structures, return None to keep as string
124+
if any(char in inner for char in "()[]"):
125+
# Contains complex structures (arrays, structs), skip parsing
126+
return None
127+
return _parse_map_native(inner)
128+
except Exception:
129+
return None
130+
131+
81132
def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
82133
"""Convert struct data to Python dictionary.
83134
@@ -128,6 +179,41 @@ def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
128179
return None
129180

130181

182+
def _parse_map_native(inner: str) -> Optional[Dict[str, Any]]:
183+
"""Parse map native format: key1=value1, key2=value2.
184+
185+
Args:
186+
inner: Interior content of map without braces.
187+
188+
Returns:
189+
Dictionary with parsed key-value pairs, or None if no valid pairs found.
190+
"""
191+
result = {}
192+
193+
# Simple split by comma for basic cases
194+
pairs = [pair.strip() for pair in inner.split(",")]
195+
196+
for pair in pairs:
197+
if "=" not in pair:
198+
continue
199+
200+
key, value = pair.split("=", 1)
201+
key = key.strip()
202+
value = value.strip()
203+
204+
# Skip pairs with special characters (safety check)
205+
if any(char in key for char in '{}="') or any(char in value for char in '{}="'):
206+
continue
207+
208+
# Convert both key and value to appropriate types
209+
converted_key = _convert_value(key)
210+
converted_value = _convert_value(value)
211+
# Always use string keys for consistency with expected test behavior
212+
result[str(converted_key)] = converted_value
213+
214+
return result if result else None
215+
216+
131217
def _parse_named_struct(inner: str) -> Optional[Dict[str, Any]]:
132218
"""Parse named struct format: a=1, b=2.
133219
@@ -217,7 +303,7 @@ def _to_default(varchar_value: Optional[str]) -> Optional[str]:
217303
"time": _to_time,
218304
"varbinary": _to_binary,
219305
"array": _to_default,
220-
"map": _to_default,
306+
"map": _to_map,
221307
"row": _to_struct,
222308
"decimal": _to_decimal,
223309
"json": _to_json,

pyathena/sqlalchemy/compiler.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
AthenaPartitionTransform,
1818
AthenaRowFormatSerde,
1919
)
20+
from pyathena.sqlalchemy.types import AthenaMap, AthenaStruct
2021

2122
if TYPE_CHECKING:
2223
from sqlalchemy import (
@@ -140,17 +141,29 @@ def visit_enum(self, type_, **kw):
140141
return self.visit_string(type_, **kw)
141142

142143
def visit_struct(self, type_, **kw): # noqa: N802
143-
if hasattr(type_, "fields") and type_.fields:
144-
field_specs = []
145-
for field_name, field_type in type_.fields.items():
146-
field_type_str = self.process(field_type, **kw)
147-
field_specs.append(f"{field_name} {field_type_str}")
148-
return f"ROW({', '.join(field_specs)})"
144+
if isinstance(type_, AthenaStruct):
145+
if type_.fields:
146+
field_specs = []
147+
for field_name, field_type in type_.fields.items():
148+
field_type_str = self.process(field_type, **kw)
149+
field_specs.append(f"{field_name} {field_type_str}")
150+
return f"ROW({', '.join(field_specs)})"
151+
return "ROW()"
149152
return "ROW()"
150153

151154
def visit_STRUCT(self, type_, **kw): # noqa: N802
152155
return self.visit_struct(type_, **kw)
153156

157+
def visit_map(self, type_, **kw): # noqa: N802
158+
if isinstance(type_, AthenaMap):
159+
key_type_str = self.process(type_.key_type, **kw)
160+
value_type_str = self.process(type_.value_type, **kw)
161+
return f"MAP<{key_type_str}, {value_type_str}>"
162+
return "MAP<STRING, STRING>"
163+
164+
def visit_MAP(self, type_, **kw): # noqa: N802
165+
return self.visit_map(type_, **kw)
166+
154167

155168
class AthenaStatementCompiler(SQLCompiler):
156169
def visit_char_length_func(self, fn: "FunctionElement[Any]", **kw):

pyathena/sqlalchemy/types.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,32 @@ def python_type(self) -> type:
7777

7878
class STRUCT(AthenaStruct):
7979
__visit_name__ = "STRUCT"
80+
81+
82+
class AthenaMap(TypeEngine[Dict[str, Any]]):
83+
__visit_name__ = "map"
84+
85+
def __init__(self, key_type: Any = None, value_type: Any = None) -> None:
86+
if key_type is None:
87+
self.key_type: TypeEngine[Any] = sqltypes.String()
88+
elif isinstance(key_type, TypeEngine):
89+
self.key_type = key_type
90+
else:
91+
# Assume it's a SQLAlchemy type class and instantiate it
92+
self.key_type = key_type()
93+
94+
if value_type is None:
95+
self.value_type: TypeEngine[Any] = sqltypes.String()
96+
elif isinstance(value_type, TypeEngine):
97+
self.value_type = value_type
98+
else:
99+
# Assume it's a SQLAlchemy type class and instantiate it
100+
self.value_type = value_type()
101+
102+
@property
103+
def python_type(self) -> type:
104+
return dict
105+
106+
107+
class MAP(AthenaMap):
108+
__visit_name__ = "MAP"

tests/pyathena/pandas/test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_as_pandas(cursor):
117117
b"123",
118118
"[1, 2]",
119119
[1, 2],
120-
"{1=2, 3=4}",
120+
{"1": 2, "3": 4},
121121
{"1": 2, "3": 4},
122122
{"a": 1, "b": 2},
123123
Decimal("0.1"),

tests/pyathena/sqlalchemy/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def test_reflect_select(self, engine):
255255
date(2017, 1, 2),
256256
b"123",
257257
"[1, 2]",
258-
"{1=2, 3=4}", # map type remains as string
258+
{"1": 2, "3": 4}, # map type now converted to dict
259259
{"a": 1, "b": 2}, # row type now converted to dict
260260
Decimal("0.1"),
261261
]

tests/pyathena/sqlalchemy/test_compiler.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlalchemy import Integer, String
66

77
from pyathena.sqlalchemy.compiler import AthenaTypeCompiler
8-
from pyathena.sqlalchemy.types import STRUCT, AthenaStruct
8+
from pyathena.sqlalchemy.types import MAP, STRUCT, AthenaMap, AthenaStruct
99

1010

1111
class TestAthenaTypeCompiler:
@@ -51,3 +51,32 @@ def test_visit_struct_single_field(self):
5151
struct_type = AthenaStruct(("name", String))
5252
result = compiler.visit_struct(struct_type)
5353
assert result == "ROW(name STRING)" or result == "ROW(name VARCHAR)"
54+
55+
def test_visit_map_default(self):
56+
dialect = Mock()
57+
compiler = AthenaTypeCompiler(dialect)
58+
map_type = AthenaMap()
59+
result = compiler.visit_map(map_type)
60+
assert result == "MAP<STRING, STRING>"
61+
62+
def test_visit_map_with_types(self):
63+
dialect = Mock()
64+
compiler = AthenaTypeCompiler(dialect)
65+
map_type = AthenaMap(String, Integer)
66+
result = compiler.visit_map(map_type)
67+
assert result == "MAP<STRING, INTEGER>" or result == "MAP<VARCHAR, INTEGER>"
68+
69+
def test_visit_map_uppercase(self):
70+
dialect = Mock()
71+
compiler = AthenaTypeCompiler(dialect)
72+
map_type = MAP(Integer, String)
73+
result = compiler.visit_MAP(map_type)
74+
assert result == "MAP<INTEGER, STRING>" or result == "MAP<INTEGER, VARCHAR>"
75+
76+
def test_visit_map_no_attributes(self):
77+
# Test map type without key_type/value_type attributes
78+
dialect = Mock()
79+
compiler = AthenaTypeCompiler(dialect)
80+
map_type = type("MockMap", (), {})()
81+
result = compiler.visit_map(map_type)
82+
assert result == "MAP<STRING, STRING>"

0 commit comments

Comments
 (0)