1515# specific language governing permissions and limitations
1616# under the License.
1717
18- from typing import Generator , List
18+ import time
19+ from typing import Any , Dict , Generator , List
20+ from uuid import uuid4
1921
2022import boto3
23+ import pyarrow as pa
2124import pytest
2225from botocore .exceptions import ClientError
2326
3033 NoSuchTableError ,
3134 TableAlreadyExistsError ,
3235)
36+ from pyiceberg .io .pyarrow import schema_to_pyarrow
3337from pyiceberg .schema import Schema
3438from pyiceberg .types import IntegerType
3539from tests .conftest import clean_up , get_bucket_name , get_s3_path
@@ -52,8 +56,62 @@ def fixture_test_catalog() -> Generator[Catalog, None, None]:
5256 clean_up (test_catalog )
5357
5458
59+ class AthenaQueryHelper :
60+ _athena_client : boto3 .client
61+ _s3_resource : boto3 .resource
62+ _output_bucket : str
63+ _output_path : str
64+
65+ def __init__ (self ) -> None :
66+ self ._s3_resource = boto3 .resource ("s3" )
67+ self ._athena_client = boto3 .client ("athena" )
68+ self ._output_bucket = get_bucket_name ()
69+ self ._output_path = f"athena_results_{ uuid4 ()} "
70+
71+ def get_query_results (self , query : str ) -> List [Dict [str , Any ]]:
72+ query_execution_id = self ._athena_client .start_query_execution (
73+ QueryString = query , ResultConfiguration = {"OutputLocation" : f"s3://{ self ._output_bucket } /{ self ._output_path } " }
74+ )["QueryExecutionId" ]
75+
76+ while True :
77+ result = self ._athena_client .get_query_execution (QueryExecutionId = query_execution_id )["QueryExecution" ]["Status" ]
78+ query_status = result ["State" ]
79+ assert query_status not in [
80+ "FAILED" ,
81+ "CANCELLED" ,
82+ ], f"""
83+ Athena query with the string failed or was cancelled:
84+ Query: { query }
85+ Status: { query_status }
86+ Reason: { result ["StateChangeReason" ]} """
87+
88+ if query_status not in ["QUEUED" , "RUNNING" ]:
89+ break
90+ time .sleep (0.5 )
91+
92+ # No pagination for now, assume that we are not doing large queries
93+ return self ._athena_client .get_query_results (QueryExecutionId = query_execution_id )["ResultSet" ]["Rows" ]
94+
95+ def clean_up (self ) -> None :
96+ bucket = self ._s3_resource .Bucket (self ._output_bucket )
97+ for obj in bucket .objects .filter (Prefix = f"{ self ._output_path } /" ):
98+ self ._s3_resource .Object (bucket .name , obj .key ).delete ()
99+
100+
101+ @pytest .fixture (name = "athena" , scope = "module" )
102+ def fixture_athena_helper () -> Generator [AthenaQueryHelper , None , None ]:
103+ query_helper = AthenaQueryHelper ()
104+ yield query_helper
105+ query_helper .clean_up ()
106+
107+
55108def test_create_table (
56- test_catalog : Catalog , s3 : boto3 .client , table_schema_nested : Schema , table_name : str , database_name : str
109+ test_catalog : Catalog ,
110+ s3 : boto3 .client ,
111+ table_schema_nested : Schema ,
112+ table_name : str ,
113+ database_name : str ,
114+ athena : AthenaQueryHelper ,
57115) -> None :
58116 identifier = (database_name , table_name )
59117 test_catalog .create_namespace (database_name )
@@ -64,6 +122,48 @@ def test_create_table(
64122 s3 .head_object (Bucket = get_bucket_name (), Key = metadata_location )
65123 assert test_catalog ._parse_metadata_version (table .metadata_location ) == 0
66124
125+ table .append (
126+ pa .Table .from_pylist (
127+ [
128+ {
129+ "foo" : "foo_val" ,
130+ "bar" : 1 ,
131+ "baz" : False ,
132+ "qux" : ["x" , "y" ],
133+ "quux" : {"key" : {"subkey" : 2 }},
134+ "location" : [{"latitude" : 1.1 }],
135+ "person" : {"name" : "some_name" , "age" : 23 },
136+ }
137+ ],
138+ schema = schema_to_pyarrow (table .schema ()),
139+ ),
140+ )
141+
142+ assert athena .get_query_results (f'SELECT * FROM "{ database_name } "."{ table_name } "' ) == [
143+ {
144+ "Data" : [
145+ {"VarCharValue" : "foo" },
146+ {"VarCharValue" : "bar" },
147+ {"VarCharValue" : "baz" },
148+ {"VarCharValue" : "qux" },
149+ {"VarCharValue" : "quux" },
150+ {"VarCharValue" : "location" },
151+ {"VarCharValue" : "person" },
152+ ]
153+ },
154+ {
155+ "Data" : [
156+ {"VarCharValue" : "foo_val" },
157+ {"VarCharValue" : "1" },
158+ {"VarCharValue" : "false" },
159+ {"VarCharValue" : "[x, y]" },
160+ {"VarCharValue" : "{key={subkey=2}}" },
161+ {"VarCharValue" : "[{latitude=1.1, longitude=null}]" },
162+ {"VarCharValue" : "{name=some_name, age=23}" },
163+ ]
164+ },
165+ ]
166+
67167
68168def test_create_table_with_invalid_location (table_schema_nested : Schema , table_name : str , database_name : str ) -> None :
69169 identifier = (database_name , table_name )
@@ -269,7 +369,7 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str)
269369
270370
271371def test_commit_table_update_schema (
272- test_catalog : Catalog , table_schema_nested : Schema , database_name : str , table_name : str
372+ test_catalog : Catalog , table_schema_nested : Schema , database_name : str , table_name : str , athena : AthenaQueryHelper
273373) -> None :
274374 identifier = (database_name , table_name )
275375 test_catalog .create_namespace (namespace = database_name )
@@ -279,6 +379,20 @@ def test_commit_table_update_schema(
279379 assert test_catalog ._parse_metadata_version (table .metadata_location ) == 0
280380 assert original_table_metadata .current_schema_id == 0
281381
382+ assert athena .get_query_results (f'SELECT * FROM "{ database_name } "."{ table_name } "' ) == [
383+ {
384+ "Data" : [
385+ {"VarCharValue" : "foo" },
386+ {"VarCharValue" : "bar" },
387+ {"VarCharValue" : "baz" },
388+ {"VarCharValue" : "qux" },
389+ {"VarCharValue" : "quux" },
390+ {"VarCharValue" : "location" },
391+ {"VarCharValue" : "person" },
392+ ]
393+ }
394+ ]
395+
282396 transaction = table .transaction ()
283397 update = transaction .update_schema ()
284398 update .add_column (path = "b" , field_type = IntegerType ())
@@ -295,6 +409,48 @@ def test_commit_table_update_schema(
295409 assert new_schema == update ._apply ()
296410 assert new_schema .find_field ("b" ).field_type == IntegerType ()
297411
412+ table .append (
413+ pa .Table .from_pylist (
414+ [
415+ {
416+ "foo" : "foo_val" ,
417+ "bar" : 1 ,
418+ "location" : [{"latitude" : 1.1 }],
419+ "person" : {"name" : "some_name" , "age" : 23 },
420+ "b" : 2 ,
421+ }
422+ ],
423+ schema = schema_to_pyarrow (new_schema ),
424+ ),
425+ )
426+
427+ assert athena .get_query_results (f'SELECT * FROM "{ database_name } "."{ table_name } "' ) == [
428+ {
429+ "Data" : [
430+ {"VarCharValue" : "foo" },
431+ {"VarCharValue" : "bar" },
432+ {"VarCharValue" : "baz" },
433+ {"VarCharValue" : "qux" },
434+ {"VarCharValue" : "quux" },
435+ {"VarCharValue" : "location" },
436+ {"VarCharValue" : "person" },
437+ {"VarCharValue" : "b" },
438+ ]
439+ },
440+ {
441+ "Data" : [
442+ {"VarCharValue" : "foo_val" },
443+ {"VarCharValue" : "1" },
444+ {},
445+ {"VarCharValue" : "[]" },
446+ {"VarCharValue" : "{}" },
447+ {"VarCharValue" : "[{latitude=1.1, longitude=null}]" },
448+ {"VarCharValue" : "{name=some_name, age=23}" },
449+ {"VarCharValue" : "2" },
450+ ]
451+ },
452+ ]
453+
298454
299455def test_commit_table_properties (test_catalog : Catalog , table_schema_nested : Schema , database_name : str , table_name : str ) -> None :
300456 identifier = (database_name , table_name )
0 commit comments