Skip to content

Commit cf7d4e2

Browse files
Fix struct converter validation and update test expectations
- Enhanced validation in `_to_struct` to properly reject complex cases with special characters (`=`, `"`, `,`)
- Updated SQLAlchemy and Pandas test expectations to match new struct conversion behavior
- Struct values now properly converted to dictionaries instead of remaining as strings
- Added proper imports for AthenaStruct in test files
- Replaced print statements with logging for better test debugging

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 3820ea7 commit cf7d4e2

4 files changed

Lines changed: 158 additions & 6 deletions

File tree

pyathena/converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]:
139139
value = inner[eq_pos + 1 : comma_pos].strip()
140140
current_pos = comma_pos + 1
141141

142-
# Basic validation: reject if key or value contains problematic chars
143-
if any(char in key for char in '{}=",') or any(char in value for char in '{}"'):
142+
# Basic validation: reject if key or value contains problematic chars
143+
if any(char in key for char in '{}=",') or any(char in value for char in '{}=",'):
144144
# Fall back to returning the original string for complex cases
145145
return None
146146

tests/pyathena/pandas/test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_as_pandas(cursor):
119119
[1, 2],
120120
"{1=2, 3=4}",
121121
{"1": 2, "3": 4},
122-
"{a=1, b=2}",
122+
{"a": 1, "b": 2},
123123
Decimal("0.1"),
124124
)
125125
]

tests/pyathena/sqlalchemy/test_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sqlalchemy.sql.schema import Column, MetaData, Table
1818
from sqlalchemy.sql.selectable import TextualSelect
1919

20-
from pyathena.sqlalchemy.types import TINYINT, Tinyint
20+
from pyathena.sqlalchemy.types import TINYINT, AthenaStruct, Tinyint
2121
from tests.pyathena.conftest import ENV
2222

2323

@@ -255,8 +255,8 @@ def test_reflect_select(self, engine):
255255
date(2017, 1, 2),
256256
b"123",
257257
"[1, 2]",
258-
"{1=2, 3=4}",
259-
"{a=1, b=2}",
258+
"{1=2, 3=4}", # map type remains as string
259+
{"a": 1, "b": 2}, # row type now converted to dict
260260
Decimal("0.1"),
261261
]
262262
assert isinstance(one_row_complex.c.col_boolean.type, types.BOOLEAN)

tests/pyathena/test_cursor.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
import contextlib
3+
import logging
34
import re
45
import time
56
from concurrent import futures
@@ -17,6 +18,8 @@
1718
from tests import ENV
1819
from tests.pyathena.conftest import connect
1920

21+
_logger = logging.getLogger(__name__)
22+
2023

2124
class TestCursor:
2225
def test_fetchone(self, cursor):
@@ -749,3 +752,152 @@ def test_fetchall(self, dict_cursor):
749752
assert dict_cursor.fetchall() == [{"number_of_rows": 1}]
750753
dict_cursor.execute("SELECT a FROM many_rows ORDER BY a")
751754
assert dict_cursor.fetchall() == [{"a": i} for i in range(10000)]
755+
756+
757+
class TestComplexDataTypes:
    """Test complex data types (STRUCT, ARRAY, MAP) with actual Athena queries.

    Every test runs a list of ``(query, description)`` pairs through the
    ``cursor`` fixture, logs the first column of the single fetched row, and
    asserts that the value is not ``None``.

    NOTE: single quotes inside a SQL string literal must be written doubled
    (``''``); a lone apostrophe terminates the literal and makes the query a
    syntax error.
    """

    def test_struct_types(self, cursor):
        """Test various STRUCT type scenarios to understand Athena's behavior.

        When the fetched value is still a string it is additionally passed to
        ``_to_struct``; a non-``None`` conversion result must be a ``dict``.
        """
        # Imported once per test rather than once per loop iteration.
        from pyathena.converter import _to_struct

        test_cases = [
            # Basic struct
            ("SELECT ROW('John', 30) AS simple_struct", "simple_struct"),
            # Named struct fields
            (
                "SELECT CAST(ROW('Alice', 25) AS ROW(name VARCHAR, age INTEGER)) AS named_struct",
                "named_struct",
            ),
            # Struct with special characters
            ("SELECT ROW('Hello, world', 'x=y+1') AS special_chars_struct", "special_chars_struct"),
            # Struct with quotes (the apostrophe is escaped by doubling it)
            (
                "SELECT ROW('He said \"hello\"', 'It''s working') AS quotes_struct",
                "quotes_struct",
            ),
            # Struct with NULL values
            ("SELECT ROW('Alice', NULL, 'active') AS null_struct", "null_struct"),
            # Nested struct
            (
                "SELECT ROW(ROW('John', 30), ROW('Engineer', 'Tech')) AS nested_struct",
                "nested_struct",
            ),
            # Struct as JSON (recommended for complex cases)
            ("SELECT CAST(ROW('Alice', 25, 'Hello, world') AS JSON) AS json_struct", "json_struct"),
        ]

        _logger.info("=== STRUCT Type Test Results ===")
        for query, description in test_cases:
            cursor.execute(query)
            result = cursor.fetchone()
            # Fail with a readable message instead of a TypeError on result[0].
            assert result is not None, f"No row returned for {description}"
            struct_value = result[0]
            _logger.info(
                "%s: %r (type: %s)", description, struct_value, type(struct_value).__name__
            )

            # Basic validation
            assert struct_value is not None, f"STRUCT value should not be None for {description}"

            # Test if our converter can handle it
            if isinstance(struct_value, str):
                converted = _to_struct(struct_value)
                _logger.info(" -> Converted: %r", converted)
                # None means the converter declined the input; anything else
                # must be a dictionary.
                if converted is not None:
                    assert isinstance(converted, dict), (
                        f"Converted value should be dict for {description}"
                    )

    def test_array_types(self, cursor):
        """Test various ARRAY type scenarios."""
        test_cases = [
            # Simple array
            ("SELECT ARRAY[1, 2, 3, 4, 5] AS simple_array", "simple_array"),
            # String array
            ("SELECT ARRAY['apple', 'banana', 'cherry'] AS string_array", "string_array"),
            # Array with special characters (the apostrophe is escaped by doubling it)
            (
                "SELECT ARRAY['Hello, world', 'x=y+1', 'It''s working'] AS special_array",
                "special_array",
            ),
            # Array of structs
            ("SELECT ARRAY[ROW('Alice', 25), ROW('Bob', 30)] AS struct_array", "struct_array"),
            # Nested arrays
            ("SELECT ARRAY[ARRAY[1, 2], ARRAY[3, 4]] AS nested_array", "nested_array"),
            # Array as JSON
            ("SELECT CAST(ARRAY['Alice', 'Bob', 'Charlie'] AS JSON) AS json_array", "json_array"),
        ]

        _logger.info("=== ARRAY Type Test Results ===")
        for query, description in test_cases:
            cursor.execute(query)
            result = cursor.fetchone()
            assert result is not None, f"No row returned for {description}"
            array_value = result[0]
            _logger.info(
                "%s: %r (type: %s)", description, array_value, type(array_value).__name__
            )

            # Basic validation
            assert array_value is not None, f"ARRAY value should not be None for {description}"

    def test_map_types(self, cursor):
        """Test various MAP type scenarios."""
        test_cases = [
            # Simple map
            (
                "SELECT MAP(ARRAY[1, 2, 3], ARRAY['one', 'two', 'three']) AS simple_map",
                "simple_map",
            ),
            # String key map
            (
                "SELECT MAP(ARRAY['name', 'age', 'city'], ARRAY['John', '30', 'Tokyo']) "
                "AS string_map",
                "string_map",
            ),
            # Map with special characters
            (
                "SELECT MAP(ARRAY['msg', 'formula'], ARRAY['Hello, world', 'x=y+1']) "
                "AS special_map",
                "special_map",
            ),
            # Map with struct values
            (
                "SELECT MAP(ARRAY['person1', 'person2'], ARRAY[ROW('Alice', 25), ROW('Bob', 30)]) "
                "AS struct_value_map",
                "struct_value_map",
            ),
            # Map as JSON
            (
                "SELECT CAST(MAP(ARRAY['name', 'age'], ARRAY['Alice', '25']) AS JSON) AS json_map",
                "json_map",
            ),
        ]

        _logger.info("=== MAP Type Test Results ===")
        for query, description in test_cases:
            cursor.execute(query)
            result = cursor.fetchone()
            assert result is not None, f"No row returned for {description}"
            map_value = result[0]
            _logger.info("%s: %r (type: %s)", description, map_value, type(map_value).__name__)

            # Basic validation
            assert map_value is not None, f"MAP value should not be None for {description}"

    def test_complex_combinations(self, cursor):
        """Test complex combinations of data types."""
        test_cases = [
            # Struct containing array and map
            (
                "SELECT ROW(ARRAY[1, 2, 3], MAP(ARRAY['a', 'b'], ARRAY[1, 2])) "
                "AS struct_with_collections",
                "struct_with_collections",
            ),
            # Array of maps
            (
                "SELECT ARRAY[MAP(ARRAY['name'], ARRAY['Alice']), "
                "MAP(ARRAY['name'], ARRAY['Bob'])] AS array_of_maps",
                "array_of_maps",
            ),
            # Map with array values
            (
                "SELECT MAP(ARRAY['numbers', 'letters'], "
                "ARRAY[ARRAY[1, 2, 3], ARRAY['a', 'b', 'c']]) AS map_with_arrays",
                "map_with_arrays",
            ),
        ]

        _logger.info("=== Complex Combinations Test Results ===")
        for query, description in test_cases:
            cursor.execute(query)
            result = cursor.fetchone()
            assert result is not None, f"No row returned for {description}"
            complex_value = result[0]
            _logger.info(
                "%s: %r (type: %s)", description, complex_value, type(complex_value).__name__
            )

            # Basic validation
            assert complex_value is not None, f"Complex value should not be None for {description}"

0 commit comments

Comments
 (0)