diff --git a/docs/introduction.rst b/docs/introduction.rst index 5ae485aa..dbb0ec65 100644 --- a/docs/introduction.rst +++ b/docs/introduction.rst @@ -50,7 +50,7 @@ PyAthena provides comprehensive support for Amazon Athena's data types and featu **Data Type Support:** - **STRUCT/ROW Types**: :ref:`Complete support ` for complex nested data structures - - **ARRAY Types**: Native handling of array data with automatic Python list conversion + - **ARRAY Types**: :ref:`Complete support ` for ordered collections with automatic Python list conversion - **MAP Types**: :ref:`Complete support ` for key-value dictionary-like data structures - **JSON Integration**: Seamless JSON data parsing and conversion - **Performance Optimized**: Smart format detection for efficient data processing diff --git a/docs/sqlalchemy.rst b/docs/sqlalchemy.rst index 92994fdf..25374639 100644 --- a/docs/sqlalchemy.rst +++ b/docs/sqlalchemy.rst @@ -606,3 +606,181 @@ Migration from Raw Strings result = cursor.execute("SELECT map_column FROM table").fetchone() map_data = result[0] # {"key1": "value1", "key2": "value2"} - automatically converted value = map_data['key1'] # Direct access + +ARRAY Type Support +~~~~~~~~~~~~~~~~~~ + +PyAthena provides comprehensive support for Amazon Athena's ARRAY data types, enabling you to work with ordered collections of data in your Python applications. + +Basic Usage +^^^^^^^^^^^ + +.. code:: python + + from sqlalchemy import Column, String, Integer, Table, MetaData + from pyathena.sqlalchemy.types import AthenaArray + + # Define a table with ARRAY columns + orders = Table('orders', metadata, + Column('id', Integer), + Column('item_ids', AthenaArray(Integer)), + Column('tags', AthenaArray(String)), + Column('categories', AthenaArray(String)) + ) + +This creates a table definition equivalent to: + +.. code:: sql + + CREATE TABLE orders ( + id INTEGER, + item_ids ARRAY, + tags ARRAY, + categories ARRAY + ) + +Querying ARRAY Data +^^^^^^^^^^^^^^^^^^^ + +PyAthena automatically converts ARRAY data between different formats: + +.. code:: python + + from sqlalchemy import create_engine, select + + # Query ARRAY data using ARRAY constructor + result = connection.execute( + select().from_statement( + text("SELECT ARRAY[1, 2, 3, 4, 5] as item_ids") + ) + ).fetchone() + + # Access ARRAY data as Python list + item_ids = result.item_ids # [1, 2, 3, 4, 5] + +Complex ARRAY Operations +^^^^^^^^^^^^^^^^^^^^^^^^ + +For arrays containing complex data types: + +.. code:: python + + # Arrays with STRUCT elements + result = connection.execute( + select().from_statement( + text("SELECT ARRAY[ROW('Alice', 25), ROW('Bob', 30)] as users") + ) + ).fetchone() + + users = result.users # [{"0": "Alice", "1": 25}, {"0": "Bob", "1": 30}] + + # Using CAST AS JSON for complex ARRAY operations + result = connection.execute( + select().from_statement( + text("SELECT CAST(ARRAY[1, 2, 3] AS JSON) as data") + ) + ).fetchone() + + # Parse JSON result + import json + if isinstance(result.data, str): + array_data = json.loads(result.data) # [1, 2, 3] + else: + array_data = result.data # Already converted to list + +Data Format Support +^^^^^^^^^^^^^^^^^^^ + +PyAthena supports multiple ARRAY data formats: + +**Athena Native Format:** + +.. code:: python + + # Input: '[1, 2, 3]' + # Output: [1, 2, 3] + + # Input: '[apple, banana, cherry]' + # Output: ["apple", "banana", "cherry"] + +**JSON Format:** + +.. code:: python + + # Input: '[1, 2, 3]' + # Output: [1, 2, 3] + + # Input: '["apple", "banana", "cherry"]' + # Output: ["apple", "banana", "cherry"] + +**Complex Nested Arrays:** + +.. code:: python + + # Input: '[{name=John, age=30}, {name=Jane, age=25}]' + # Output: [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] + +Type Definitions +^^^^^^^^^^^^^^^^ + +AthenaArray supports various item types: + +.. code:: python + + from pyathena.sqlalchemy.types import AthenaArray, AthenaStruct, AthenaMap + + # Simple arrays + AthenaArray(String) # ARRAY + AthenaArray(Integer) # ARRAY + + # Arrays of complex types + AthenaArray(AthenaStruct(...)) # ARRAY> + AthenaArray(AthenaMap(...)) # ARRAY> + + # Nested arrays + AthenaArray(AthenaArray(Integer)) # ARRAY> + +Best Practices +^^^^^^^^^^^^^^ + +1. **Use appropriate item types** in AthenaArray definitions: + + .. code:: python + + AthenaArray(Integer) # For numeric arrays + AthenaArray(String) # For string arrays + AthenaArray(AthenaStruct(...)) # For arrays of structs + +2. **Use CAST AS JSON** for complex array operations: + + .. code:: sql + + SELECT CAST(complex_array AS JSON) FROM table_name + +3. **Handle NULL values** appropriately in your application logic: + + .. code:: python + + if result.array_column is not None: + # Process array data + first_item = result.array_column[0] if result.array_column else None + +Migration from Raw Strings +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +**Before (raw string handling):** + +.. code:: python + + result = cursor.execute("SELECT array_column FROM table").fetchone() + raw_data = result[0] # "[1, 2, 3]" + import json + parsed_data = json.loads(raw_data) + +**After (automatic conversion):** + +.. code:: python + + result = cursor.execute("SELECT array_column FROM table").fetchone() + array_data = result[0] # [1, 2, 3] - automatically converted + first_item = array_data[0] # Direct access diff --git a/pyathena/converter.py b/pyathena/converter.py index d68aa505..712f9c71 100644 --- a/pyathena/converter.py +++ b/pyathena/converter.py @@ -8,7 +8,7 @@ from copy import deepcopy from datetime import date, datetime, time from decimal import Decimal -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type from dateutil.tz import gettz @@ -78,6 +78,52 @@ def _to_json(varchar_value: Optional[str]) -> Optional[Any]: return json.loads(varchar_value) +def _to_array(varchar_value: Optional[str]) -> Optional[List[Any]]: + """Convert array data to Python list. + + Supports two formats: + 1. JSON format: '[1, 2, 3]' or '["a", "b", "c"]' (recommended) + 2. Athena native format: '[1, 2, 3]' (basic cases only) + + For complex arrays, use CAST(array_column AS JSON) in your SQL query. + + Args: + varchar_value: String representation of array data + + Returns: + List representation of array, or None if parsing fails + """ + if varchar_value is None: + return None + + # Quick check: if it doesn't look like an array, return None + if not (varchar_value.startswith("[") and varchar_value.endswith("]")): + return None + + # Optimize: Try JSON parsing first (most reliable) + try: + result = json.loads(varchar_value) + if isinstance(result, list): + return result + except json.JSONDecodeError: + # If JSON parsing fails, fall back to basic parsing for simple cases + pass + + inner = varchar_value[1:-1].strip() + if not inner: + return [] + + try: + # For nested arrays, too complex for basic parsing + if "[" in inner: + # Contains nested arrays - too complex for basic parsing + return None + # Try native parsing (including struct arrays) + return _parse_array_native(inner) + except Exception: + return None + + def _to_map(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]: """Convert map data to Python dictionary. @@ -179,6 +225,81 @@ def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]: return None +def _parse_array_native(inner: str) -> Optional[List[Any]]: + """Parse array native format: 1, 2, 3 or {a, b}, {c, d}. + + Args: + inner: Interior content of array without brackets. + + Returns: + List with parsed values, or None if no valid values found. + """ + result = [] + + # Smart split by comma - respect brace groupings + items = _split_array_items(inner) + + for item in items: + if not item: + continue + + # Handle struct (ROW) values in format {a, b, c} or {key=value, ...} + if item.strip().startswith("{") and item.strip().endswith("}"): + # This is a struct value - parse it as a struct + struct_value = _to_struct(item.strip()) + if struct_value is not None: + result.append(struct_value) + continue + + # Skip items with nested arrays or complex quoting (safety check) + if any(char in item for char in '[]="'): + continue + + # Convert item to appropriate type + converted_item = _convert_value(item) + result.append(converted_item) + + return result if result else None + + +def _split_array_items(inner: str) -> List[str]: + """Split array items by comma, respecting brace and bracket groupings. + + Args: + inner: Interior content of array without brackets. + + Returns: + List of item strings. + """ + items = [] + current_item = "" + brace_depth = 0 + bracket_depth = 0 + + for char in inner: + if char == "{": + brace_depth += 1 + elif char == "}": + brace_depth -= 1 + elif char == "[": + bracket_depth += 1 + elif char == "]": + bracket_depth -= 1 + elif char == "," and brace_depth == 0 and bracket_depth == 0: + # Top-level comma - end current item + items.append(current_item.strip()) + current_item = "" + continue + + current_item += char + + # Add the last item + if current_item.strip(): + items.append(current_item.strip()) + + return items + + def _parse_map_native(inner: str) -> Optional[Dict[str, Any]]: """Parse map native format: key1=value1, key2=value2. @@ -302,7 +423,7 @@ def _to_default(varchar_value: Optional[str]) -> Optional[str]: "date": _to_date, "time": _to_time, "varbinary": _to_binary, - "array": _to_default, + "array": _to_array, "map": _to_map, "row": _to_struct, "decimal": _to_decimal, diff --git a/pyathena/sqlalchemy/compiler.py b/pyathena/sqlalchemy/compiler.py index da95dae1..7fb9296c 100644 --- a/pyathena/sqlalchemy/compiler.py +++ b/pyathena/sqlalchemy/compiler.py @@ -17,7 +17,7 @@ AthenaPartitionTransform, AthenaRowFormatSerde, ) -from pyathena.sqlalchemy.types import AthenaMap, AthenaStruct +from pyathena.sqlalchemy.types import AthenaArray, AthenaMap, AthenaStruct if TYPE_CHECKING: from sqlalchemy import ( @@ -164,6 +164,15 @@ def visit_map(self, type_, **kw): # noqa: N802 def visit_MAP(self, type_, **kw): # noqa: N802 return self.visit_map(type_, **kw) + def visit_array(self, type_, **kw): # noqa: N802 + if isinstance(type_, AthenaArray): + item_type_str = self.process(type_.item_type, **kw) + return f"ARRAY<{item_type_str}>" + return "ARRAY" + + def visit_ARRAY(self, type_, **kw): # noqa: N802 + return self.visit_array(type_, **kw) + class AthenaStatementCompiler(SQLCompiler): def visit_char_length_func(self, fn: "FunctionElement[Any]", **kw): diff --git a/pyathena/sqlalchemy/types.py b/pyathena/sqlalchemy/types.py index 77588c15..ee798d77 100644 --- a/pyathena/sqlalchemy/types.py +++ b/pyathena/sqlalchemy/types.py @@ -2,7 +2,7 @@ from __future__ import annotations from datetime import date, datetime -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from sqlalchemy.sql import sqltypes from sqlalchemy.sql.type_api import TypeEngine @@ -106,3 +106,24 @@ def python_type(self) -> type: class MAP(AthenaMap): __visit_name__ = "MAP" + + +class AthenaArray(TypeEngine[List[Any]]): + __visit_name__ = "array" + + def __init__(self, item_type: Any = None) -> None: + if item_type is None: + self.item_type: TypeEngine[Any] = sqltypes.String() + elif isinstance(item_type, TypeEngine): + self.item_type = item_type + else: + # Assume it's a SQLAlchemy type class and instantiate it + self.item_type = item_type() + + @property + def python_type(self) -> type: + return list + + +class ARRAY(AthenaArray): + __visit_name__ = "ARRAY" diff --git a/tests/pyathena/pandas/test_util.py b/tests/pyathena/pandas/test_util.py index 922e3f5f..972dd192 100644 --- a/tests/pyathena/pandas/test_util.py +++ b/tests/pyathena/pandas/test_util.py @@ -115,7 +115,7 @@ def test_as_pandas(cursor): datetime(2017, 1, 1, 0, 0, 0).time(), date(2017, 1, 2), b"123", - "[1, 2]", + [1, 2], [1, 2], {"1": 2, "3": 4}, {"1": 2, "3": 4}, diff --git a/tests/pyathena/sqlalchemy/test_base.py b/tests/pyathena/sqlalchemy/test_base.py index 448292c8..7013591b 100644 --- a/tests/pyathena/sqlalchemy/test_base.py +++ b/tests/pyathena/sqlalchemy/test_base.py @@ -254,7 +254,7 @@ def test_reflect_select(self, engine): datetime(2017, 1, 1, 0, 0, 0), date(2017, 1, 2), b"123", - "[1, 2]", + [1, 2], {"1": 2, "3": 4}, # map type now converted to dict {"a": 1, "b": 2}, # row type now converted to dict Decimal("0.1"), diff --git a/tests/pyathena/sqlalchemy/test_compiler.py b/tests/pyathena/sqlalchemy/test_compiler.py index 3c7120e2..207d9373 100644 --- a/tests/pyathena/sqlalchemy/test_compiler.py +++ b/tests/pyathena/sqlalchemy/test_compiler.py @@ -5,7 +5,7 @@ from sqlalchemy import Integer, String from pyathena.sqlalchemy.compiler import AthenaTypeCompiler -from pyathena.sqlalchemy.types import MAP, STRUCT, AthenaMap, AthenaStruct +from pyathena.sqlalchemy.types import ARRAY, MAP, STRUCT, AthenaArray, AthenaMap, AthenaStruct class TestAthenaTypeCompiler: @@ -80,3 +80,32 @@ def test_visit_map_no_attributes(self): map_type = type("MockMap", (), {})() result = compiler.visit_map(map_type) assert result == "MAP" + + def test_visit_array_default(self): + dialect = Mock() + compiler = AthenaTypeCompiler(dialect) + array_type = AthenaArray() + result = compiler.visit_array(array_type) + assert result == "ARRAY" + + def test_visit_array_with_type(self): + dialect = Mock() + compiler = AthenaTypeCompiler(dialect) + array_type = AthenaArray(Integer) + result = compiler.visit_array(array_type) + assert result == "ARRAY" + + def test_visit_array_uppercase(self): + dialect = Mock() + compiler = AthenaTypeCompiler(dialect) + array_type = ARRAY(String) + result = compiler.visit_ARRAY(array_type) + assert result == "ARRAY" or result == "ARRAY" + + def test_visit_array_no_attributes(self): + # Test array type without item_type attribute + dialect = Mock() + compiler = AthenaTypeCompiler(dialect) + array_type = type("MockArray", (), {})() + result = compiler.visit_array(array_type) + assert result == "ARRAY" diff --git a/tests/pyathena/sqlalchemy/test_types.py b/tests/pyathena/sqlalchemy/test_types.py index df6cdf20..2de66df9 100644 --- a/tests/pyathena/sqlalchemy/test_types.py +++ b/tests/pyathena/sqlalchemy/test_types.py @@ -3,7 +3,7 @@ from sqlalchemy import Integer, String from sqlalchemy.sql import sqltypes -from pyathena.sqlalchemy.types import MAP, STRUCT, AthenaMap, AthenaStruct +from pyathena.sqlalchemy.types import ARRAY, MAP, STRUCT, AthenaArray, AthenaMap, AthenaStruct class TestAthenaStruct: @@ -98,3 +98,50 @@ def test_mixed_type_definitions(self): map_type = AthenaMap(String, Integer()) assert isinstance(map_type.key_type, sqltypes.String) assert isinstance(map_type.value_type, sqltypes.Integer) + + +class TestAthenaArray: + def test_creation_with_default(self): + array_type = AthenaArray() + assert isinstance(array_type.item_type, sqltypes.String) + + def test_creation_with_type_class(self): + array_type = AthenaArray(Integer) + assert isinstance(array_type.item_type, sqltypes.Integer) + + def test_creation_with_type_instance(self): + array_type = AthenaArray(Integer()) + assert isinstance(array_type.item_type, sqltypes.Integer) + + def test_creation_with_string_type(self): + array_type = AthenaArray(String) + assert isinstance(array_type.item_type, sqltypes.String) + + def test_python_type(self): + array_type = AthenaArray() + assert array_type.python_type is list + + def test_visit_name(self): + array_type = AthenaArray() + assert array_type.__visit_name__ == "array" + + def test_array_uppercase_visit_name(self): + array_type = ARRAY() + assert array_type.__visit_name__ == "ARRAY" + + def test_array_with_complex_type(self): + array_type = AthenaArray(AthenaStruct(("name", String), ("age", Integer))) + assert isinstance(array_type.item_type, AthenaStruct) + assert "name" in array_type.item_type.fields + assert "age" in array_type.item_type.fields + + def test_array_with_nested_array(self): + array_type = AthenaArray(AthenaArray(Integer)) + assert isinstance(array_type.item_type, AthenaArray) + assert isinstance(array_type.item_type.item_type, sqltypes.Integer) + + def test_array_with_map_type(self): + array_type = AthenaArray(AthenaMap(String, Integer)) + assert isinstance(array_type.item_type, AthenaMap) + assert isinstance(array_type.item_type.key_type, sqltypes.String) + assert isinstance(array_type.item_type.value_type, sqltypes.Integer) diff --git a/tests/pyathena/test_converter.py b/tests/pyathena/test_converter.py index 0f2658ef..24be2720 100644 --- a/tests/pyathena/test_converter.py +++ b/tests/pyathena/test_converter.py @@ -2,7 +2,7 @@ import pytest -from pyathena.converter import DefaultTypeConverter, _to_struct +from pyathena.converter import DefaultTypeConverter, _to_array, _to_struct @pytest.mark.parametrize( @@ -74,6 +74,38 @@ def test_to_map_athena_numeric_keys(): assert result == expected +def test_to_array_athena_numeric_elements(): + """Test Athena array with numeric elements""" + array_value = "[1, 2, 3, 4]" + result = _to_array(array_value) + expected = [1, 2, 3, 4] + assert result == expected + + +def test_to_array_athena_mixed_elements(): + """Test Athena array with mixed type elements""" + array_value = "[1, hello, true, null]" + result = _to_array(array_value) + expected = [1, "hello", True, None] + assert result == expected + + +def test_to_array_athena_struct_elements(): + """Test Athena array with struct elements""" + array_value = "[{name=John, age=30}, {name=Jane, age=25}]" + result = _to_array(array_value) + expected = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] + assert result == expected + + +def test_to_array_athena_unnamed_struct_elements(): + """Test Athena array with unnamed struct elements""" + array_value = "[{Alice, 25}, {Bob, 30}]" + result = _to_array(array_value) + expected = [{"0": "Alice", "1": 25}, {"0": "Bob", "1": 30}] + assert result == expected + + @pytest.mark.parametrize( "input_value", [ @@ -88,6 +120,101 @@ def test_to_struct_non_dict_json(input_value): assert result is None +@pytest.mark.parametrize( + "input_value,expected", + [ + (None, None), + ( + "[1, 2, 3, 4, 5]", + [1, 2, 3, 4, 5], + ), + ( + '["apple", "banana", "cherry"]', + ["apple", "banana", "cherry"], + ), + ( + "[true, false, null]", + [True, False, None], + ), + ( + '[{"name": "John", "age": 30}, {"name": "Jane", "age": 25}]', + [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}], + ), + ("not valid json", None), + ("", None), + ("[]", []), + ], +) +def test_to_array_json_formats(input_value, expected): + """Test ARRAY conversion for various JSON formats and edge cases.""" + result = _to_array(input_value) + assert result == expected + + +@pytest.mark.parametrize( + "input_value,expected", + [ + ("[1, 2, 3]", [1, 2, 3]), + ("[]", []), + ("[apple, banana, cherry]", ["apple", "banana", "cherry"]), + ("[{Alice, 25}, {Bob, 30}]", [{"0": "Alice", "1": 25}, {"0": "Bob", "1": 30}]), + ( + "[{name=John, age=30}, {name=Jane, age=25}]", + [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}], + ), + ("[true, false, null]", [True, False, None]), + ("[1, 2.5, hello]", [1, 2.5, "hello"]), + ], +) +def test_to_array_athena_native_formats(input_value, expected): + """Test ARRAY conversion for Athena native formats.""" + result = _to_array(input_value) + assert result == expected + + +@pytest.mark.parametrize( + "input_value,expected", + [ + ("[ARRAY[1, 2], ARRAY[3, 4]]", None), # Nested arrays (native format) + ("[[1, 2], [3, 4]]", [[1, 2], [3, 4]]), # Nested arrays (JSON format - parseable) + ("[MAP(ARRAY['key'], ARRAY['value'])]", None), # Complex nested structures + ], +) +def test_to_array_complex_nested_cases(input_value, expected): + """Test complex nested array cases behavior.""" + result = _to_array(input_value) + assert result == expected + + +@pytest.mark.parametrize( + "input_value", + [ + '"just a string"', # String JSON + "42", # Number JSON + '{"key": "value"}', # Object JSON + ], +) +def test_to_array_non_array_json(input_value): + """Test that non-array JSON formats return None.""" + result = _to_array(input_value) + assert result is None + + +@pytest.mark.parametrize( + "input_value", + [ + "not an array", # Not bracketed + "[unclosed array", # Malformed + "closed array]", # Malformed + "[{malformed struct}", # Malformed struct + ], +) +def test_to_array_invalid_formats(input_value): + """Test that invalid array formats return None.""" + result = _to_array(input_value) + assert result is None + + class TestDefaultTypeConverter: @pytest.mark.parametrize( "input_value,expected", @@ -104,3 +231,21 @@ def test_struct_conversion(self, input_value, expected): converter = DefaultTypeConverter() result = converter.convert("row", input_value) assert result == expected + + @pytest.mark.parametrize( + "input_value,expected", + [ + ("[1, 2, 3]", [1, 2, 3]), + ('["a", "b", "c"]', ["a", "b", "c"]), + (None, None), + ("", None), + ("invalid json", None), + ("[apple, banana]", ["apple", "banana"]), + ("[]", []), + ], + ) + def test_array_conversion(self, input_value, expected): + """Test DefaultTypeConverter ARRAY conversion for various input formats.""" + converter = DefaultTypeConverter() + result = converter.convert("array", input_value) + assert result == expected diff --git a/tests/pyathena/test_cursor.py b/tests/pyathena/test_cursor.py index 70af75a9..56d126a3 100644 --- a/tests/pyathena/test_cursor.py +++ b/tests/pyathena/test_cursor.py @@ -13,7 +13,7 @@ import pytest from pyathena import BINARY, BOOLEAN, DATE, DATETIME, JSON, NUMBER, STRING, TIME -from pyathena.converter import _to_map, _to_struct +from pyathena.converter import _to_array, _to_map, _to_struct from pyathena.cursor import Cursor from pyathena.error import DatabaseError, NotSupportedError, ProgrammingError from pyathena.model import AthenaQueryExecution @@ -515,7 +515,7 @@ def test_complex(self, cursor): datetime(2017, 1, 1, 0, 0, 0).time(), date(2017, 1, 2), b"123", - "[1, 2]", + [1, 2], [1, 2], {"1": 2, "3": 4}, {"1": 2, "3": 4}, @@ -939,3 +939,196 @@ def test_complex_combinations(self, cursor, query, description): # If it's not a string, it should still be a valid value (not None) assert complex_value is not None, f"Complex value should not be None for {description}" _logger.info(f"{description}: Complex value type {type(complex_value).__name__}") + + @pytest.mark.parametrize( + "query,description", + [ + ("SELECT ARRAY[1, 2, 3, 4, 5] AS simple_array", "simple_array"), + ("SELECT ARRAY['apple', 'banana', 'cherry'] AS string_array", "string_array"), + ("SELECT ARRAY[true, false, null] AS boolean_array", "boolean_array"), + ("SELECT ARRAY[1.5, 2.7, 3.14] AS float_array", "float_array"), + ("SELECT ARRAY[] AS empty_array", "empty_array"), + ("SELECT ARRAY[1, null, 3, null, 5] AS null_elements_array", "null_elements_array"), + ( + "SELECT ARRAY[CAST(1 AS VARCHAR), 'mixed', 'true', 'null', '2.5'] AS mixed_array", + "mixed_array", + ), + ( + "SELECT ARRAY[CAST(1.23 AS DECIMAL(10,2)), CAST(4.56 AS DECIMAL(10,2))] " + "AS decimal_array", + "decimal_array", + ), + ( + "SELECT ARRAY[DATE '2023-01-01', DATE '2023-12-31'] AS date_array", + "date_array", + ), + ( + "SELECT ARRAY[TIMESTAMP '2023-01-01 12:00:00'] AS timestamp_array", + "timestamp_array", + ), + ], + ) + def test_array_types_basic(self, cursor, query, description): + """Test basic ARRAY type scenarios.""" + _logger.info(f"=== ARRAY Type Test: {description} ===") + cursor.execute(query) + result = cursor.fetchone() + array_value = result[0] + _logger.info(f"{description}: {array_value!r} (type: {type(array_value).__name__})") + + # Validate array value + assert array_value is not None or description == "empty_array", ( + f"ARRAY value should not be None for {description}" + ) + if description == "empty_array": + assert array_value == [], f"Empty array should be [] for {description}" + else: + assert isinstance(array_value, list), f"ARRAY value should be list for {description}" + _logger.info(f"{description}: Array value type {type(array_value).__name__}") + + @pytest.mark.parametrize( + "query,description", + [ + ( + "SELECT ARRAY[ROW(1, 'Alice'), ROW(2, 'Bob'), ROW(3, 'Charlie')] AS struct_array", + "struct_array", + ), + ( + "SELECT ARRAY[ROW('name', 'John', 25), ROW('name', 'Jane', 30)] " + "AS unnamed_struct_array", + "unnamed_struct_array", + ), + ( + "SELECT ARRAY[CAST(ROW('John', 25) AS ROW(name VARCHAR, age INT)), " + "CAST(ROW('Jane', 30) AS ROW(name VARCHAR, age INT))] AS named_struct_array", + "named_struct_array", + ), + ], + ) + def test_array_types_with_structs(self, cursor, query, description): + """Test ARRAY types containing STRUCT elements.""" + _logger.info(f"=== ARRAY with STRUCT Test: {description} ===") + cursor.execute(query) + result = cursor.fetchone() + array_value = result[0] + _logger.info(f"{description}: {array_value!r} (type: {type(array_value).__name__})") + + # Validate array value + assert array_value is not None, f"ARRAY value should not be None for {description}" + assert isinstance(array_value, list), f"ARRAY value should be list for {description}" + assert len(array_value) > 0, f"ARRAY should not be empty for {description}" + + # Check first element is a dict (converted struct) + first_element = array_value[0] + assert isinstance(first_element, dict), ( + f"First array element should be dict (struct) for {description}" + ) + _logger.info(f"{description}: First element: {first_element!r}") + + @pytest.mark.parametrize( + "query,description", + [ + ( + "SELECT CAST(ARRAY[1, 2, 3] AS JSON) AS arr_json", + "arr_json", + ), + ( + "SELECT CAST(ARRAY['a', 'b', 'c'] AS JSON) AS str_arr_json", + "str_arr_json", + ), + ( + "SELECT CAST(ARRAY[ROW(1, 'Alice'), ROW(2, 'Bob')] AS JSON) AS struct_arr_json", + "struct_arr_json", + ), + ], + ) + def test_array_types_json_cast(self, cursor, query, description): + """Test ARRAY types with JSON casting.""" + _logger.info(f"=== ARRAY JSON Cast Test: {description} ===") + cursor.execute(query) + result = cursor.fetchone() + array_value = result[0] + _logger.info(f"{description}: {array_value!r} (type: {type(array_value).__name__})") + + # Validate array value + assert array_value is not None, f"ARRAY value should not be None for {description}" + assert isinstance(array_value, list), ( + f"JSON cast ARRAY value should be list for {description}" + ) + _logger.info(f"{description}: JSON cast array type {type(array_value).__name__}") + + @pytest.mark.parametrize( + "query,description", + [ + ("SELECT CARDINALITY(ARRAY[1, 2, 3, 4, 5]) AS array_size", "array_size"), + ("SELECT ARRAY[10, 20, 30, 40][2] AS array_element", "array_element"), + ("SELECT ARRAY[1, 2] || ARRAY[3, 4] AS array_concat", "array_concat"), + ("SELECT CONTAINS(ARRAY[1, 2, 3], 2) AS array_contains", "array_contains"), + ], + ) + def test_array_operations(self, cursor, query, description): + """Test ARRAY operations and functions.""" + _logger.info(f"=== ARRAY Operation Test: {description} ===") + cursor.execute(query) + result = cursor.fetchone() + operation_result = result[0] + _logger.info( + f"{description}: {operation_result!r} (type: {type(operation_result).__name__})" + ) + + # Validate operation result + assert operation_result is not None, ( + f"ARRAY operation result should not be None for {description}" + ) + + # Type-specific validations + if description == "array_size": + assert operation_result == 5, f"Array size should be 5 for {description}" + elif description == "array_element": + assert operation_result == 20, f"Array element [2] should be 20 for {description}" + elif description == "array_concat": + assert operation_result == [1, 2, 3, 4], ( + f"Array concat should be [1,2,3,4] for {description}" + ) + elif description == "array_contains": + assert operation_result is True, f"Array contains should be True for {description}" + + def test_array_converter_behavior(self, cursor): + """Test ARRAY converter behavior with different formats.""" + _logger.info("=== ARRAY Converter Behavior Test ===") + + # Test simple array conversion + cursor.execute("SELECT ARRAY[1, 2, 3] AS simple") + result = cursor.fetchone() + simple_array = result[0] + _logger.info(f"Simple array: {simple_array!r}") + assert simple_array == [1, 2, 3] + + # Test array with struct conversion + cursor.execute( + "SELECT ARRAY[CAST(ROW(1, 2) AS ROW(a INT, b INT)), " + "CAST(ROW(3, 4) AS ROW(a INT, b INT))] AS struct_array" + ) + result = cursor.fetchone() + struct_array = result[0] + _logger.info(f"Struct array: {struct_array!r}") + assert isinstance(struct_array, list) + assert len(struct_array) == 2 + assert isinstance(struct_array[0], dict) + + # Test converter function directly + test_cases = [ + ("[1, 2, 3]", [1, 2, 3]), + ('["a", "b", "c"]', ["a", "b", "c"]), + ("[{a=1, b=2}]", [{"a": 1, "b": 2}]), + ("[]", []), + (None, None), + ("invalid", None), + ] + + for test_input, expected in test_cases: + result = _to_array(test_input) + _logger.info(f"Converter test: {test_input!r} -> {result!r}") + assert result == expected, ( + f"Converter failed for {test_input}: expected {expected}, got {result}" + )