diff --git a/tests/pyathena/sqlalchemy/test_base.py b/tests/pyathena/sqlalchemy/test_base.py index 7013591b..7419aeda 100644 --- a/tests/pyathena/sqlalchemy/test_base.py +++ b/tests/pyathena/sqlalchemy/test_base.py @@ -17,7 +17,7 @@ from sqlalchemy.sql.schema import Column, MetaData, Table from sqlalchemy.sql.selectable import TextualSelect -from pyathena.sqlalchemy.types import TINYINT, AthenaStruct, Tinyint +from pyathena.sqlalchemy.types import TINYINT, AthenaArray, AthenaMap, AthenaStruct, Tinyint from tests.pyathena.conftest import ENV @@ -1997,3 +1997,145 @@ def test_compile_temporal_query_by_timestamp_with_hint(self, engine): f"SELECT count({ENV.schema}.{table_name}.col_1) AS count_1 \n" f"FROM {ENV.schema}.{table_name} FOR VERSION AS OF '{timestamp}'" ) + + def test_create_table_with_array_types(self, engine): + """Test DDL compilation for ARRAY types.""" + engine, conn = engine + table_name = "test_create_table_with_array_types" + table = Table( + table_name, + MetaData(schema=ENV.schema), + Column("id", types.Integer), + Column("tags", AthenaArray(types.String)), + Column("scores", AthenaArray(types.Integer)), + Column("nested_arrays", AthenaArray(AthenaArray(types.String))), + Column( + "struct_array", + AthenaArray(AthenaStruct(("name", types.String), ("age", types.Integer))), + ), + awsathena_location=f"{ENV.s3_staging_dir}{ENV.schema}/{table_name}/", + awsathena_file_format="PARQUET", + ) + + # Test DDL compilation + create_ddl = CreateTable(table).compile(dialect=engine.dialect) + ddl_string = str(create_ddl) + + # Verify ARRAY types are correctly compiled + assert "tags ARRAY" in ddl_string + assert "scores ARRAY" in ddl_string + assert "nested_arrays ARRAY>" in ddl_string + assert "struct_array ARRAY" in ddl_string + + def test_create_table_with_map_types(self, engine): + """Test DDL compilation for MAP types.""" + engine, conn = engine + table_name = "test_create_table_with_map_types" + table = Table( + table_name, + MetaData(schema=ENV.schema), + Column("id", types.Integer), + Column("attributes", AthenaMap(types.String, types.String)), + Column("metrics", AthenaMap(types.String, types.Integer)), + Column( + "complex_map", + AthenaMap( + types.String, AthenaStruct(("value", types.String), ("count", types.Integer)) + ), + ), + Column("nested_map", AthenaMap(types.String, AthenaArray(types.String))), + awsathena_location=f"{ENV.s3_staging_dir}{ENV.schema}/{table_name}/", + awsathena_file_format="PARQUET", + ) + + # Test DDL compilation + create_ddl = CreateTable(table).compile(dialect=engine.dialect) + ddl_string = str(create_ddl) + + # Verify MAP types are correctly compiled + assert "attributes MAP" in ddl_string + assert "metrics MAP" in ddl_string + assert "complex_map MAP" in ddl_string + assert "nested_map MAP>" in ddl_string + + def test_create_table_with_struct_types(self, engine): + """Test DDL compilation for STRUCT types.""" + engine, conn = engine + table_name = "test_create_table_with_struct_types" + table = Table( + table_name, + MetaData(schema=ENV.schema), + Column("id", types.Integer), + Column( + "user_info", + AthenaStruct( + ("name", types.String), ("age", types.Integer), ("email", types.String) + ), + ), + Column( + "nested_struct", + AthenaStruct( + ( + "personal", + AthenaStruct(("first_name", types.String), ("last_name", types.String)), + ), + ("preferences", AthenaMap(types.String, types.String)), + ), + ), + Column( + "struct_with_array", + AthenaStruct( + ("tags", AthenaArray(types.String)), ("scores", AthenaArray(types.Integer)) + ), + ), + awsathena_location=f"{ENV.s3_staging_dir}{ENV.schema}/{table_name}/", + awsathena_file_format="PARQUET", + ) + + # Test DDL compilation + create_ddl = CreateTable(table).compile(dialect=engine.dialect) + ddl_string = str(create_ddl) + + # Verify STRUCT types are correctly compiled + assert "user_info ROW(name STRING, age INTEGER, email STRING)" in ddl_string + assert ( + "nested_struct ROW(personal ROW(first_name STRING, last_name STRING), " + "preferences MAP)" in ddl_string + ) + assert "struct_with_array ROW(tags ARRAY, scores ARRAY)" in ddl_string + + def test_create_table_with_complex_nested_types(self, engine): + """Test DDL compilation for complex nested combinations of ARRAY, MAP, and STRUCT.""" + engine, conn = engine + table_name = "test_create_table_with_complex_nested_types" + table = Table( + table_name, + MetaData(schema=ENV.schema), + Column("id", types.Integer), + Column( + "data", + AthenaArray( + AthenaMap( + types.String, + AthenaStruct( + ("value", types.String), + ("metadata", AthenaMap(types.String, types.String)), + ("tags", AthenaArray(types.String)), + ), + ) + ), + ), + awsathena_location=f"{ENV.s3_staging_dir}{ENV.schema}/{table_name}/", + awsathena_file_format="PARQUET", + ) + + # Test DDL compilation + create_ddl = CreateTable(table).compile(dialect=engine.dialect) + ddl_string = str(create_ddl) + + # Verify complex nested type is correctly compiled + expected_type = ( + "data ARRAY, " + "tags ARRAY)>>" + ) + assert expected_type in ddl_string