Skip to content

Commit 4391919

Browse files
authored
Set Glue Table Information when creating/updating tables (#288)

* Set Glue Table Information when creating/updating tables
* Add integration tests for Glue/Athena
1 parent 4cf1f35 commit 4391919

File tree

3 files changed

+312
-8
lines changed

3 files changed

+312
-8
lines changed

pyiceberg/catalog/glue.py

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from typing import (
2020
Any,
21+
Dict,
2122
List,
2223
Optional,
2324
Set,
@@ -28,6 +29,7 @@
2829
import boto3
2930
from mypy_boto3_glue.client import GlueClient
3031
from mypy_boto3_glue.type_defs import (
32+
ColumnTypeDef,
3133
DatabaseInputTypeDef,
3234
DatabaseTypeDef,
3335
StorageDescriptorTypeDef,
@@ -59,12 +61,32 @@
5961
)
6062
from pyiceberg.io import load_file_io
6163
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
62-
from pyiceberg.schema import Schema
64+
from pyiceberg.schema import Schema, SchemaVisitor, visit
6365
from pyiceberg.serializers import FromInputFile
6466
from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
65-
from pyiceberg.table.metadata import new_table_metadata
67+
from pyiceberg.table.metadata import TableMetadata, new_table_metadata
6668
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
6769
from pyiceberg.typedef import EMPTY_DICT
70+
from pyiceberg.types import (
71+
BinaryType,
72+
BooleanType,
73+
DateType,
74+
DecimalType,
75+
DoubleType,
76+
FixedType,
77+
FloatType,
78+
IntegerType,
79+
ListType,
80+
LongType,
81+
MapType,
82+
NestedField,
83+
PrimitiveType,
84+
StringType,
85+
StructType,
86+
TimestampType,
87+
TimeType,
88+
UUIDType,
89+
)
6890

6991
# If Glue should skip archiving an old table version when creating a new version in a commit. By
7092
# default, Glue archives all old table versions after an UpdateTable call, but Glue has a default
@@ -73,6 +95,10 @@
7395
GLUE_SKIP_ARCHIVE = "glue.skip-archive"
7496
GLUE_SKIP_ARCHIVE_DEFAULT = True
7597

98+
ICEBERG_FIELD_ID = "iceberg.field.id"
99+
ICEBERG_FIELD_OPTIONAL = "iceberg.field.optional"
100+
ICEBERG_FIELD_CURRENT = "iceberg.field.current"
101+
76102

77103
def _construct_parameters(
78104
metadata_location: str, glue_table: Optional[TableTypeDef] = None, prev_metadata_location: Optional[str] = None
@@ -84,17 +110,97 @@ def _construct_parameters(
84110
return new_parameters
85111

86112

113+
GLUE_PRIMITIVE_TYPES = {
114+
BooleanType: "boolean",
115+
IntegerType: "int",
116+
LongType: "bigint",
117+
FloatType: "float",
118+
DoubleType: "double",
119+
DateType: "date",
120+
TimeType: "string",
121+
StringType: "string",
122+
UUIDType: "string",
123+
TimestampType: "timestamp",
124+
FixedType: "binary",
125+
BinaryType: "binary",
126+
}
127+
128+
129+
class _IcebergSchemaToGlueType(SchemaVisitor[str]):
130+
def schema(self, schema: Schema, struct_result: str) -> str:
131+
return struct_result
132+
133+
def struct(self, struct: StructType, field_results: List[str]) -> str:
134+
return f"struct<{','.join(field_results)}>"
135+
136+
def field(self, field: NestedField, field_result: str) -> str:
137+
return f"{field.name}:{field_result}"
138+
139+
def list(self, list_type: ListType, element_result: str) -> str:
140+
return f"array<{element_result}>"
141+
142+
def map(self, map_type: MapType, key_result: str, value_result: str) -> str:
143+
return f"map<{key_result},{value_result}>"
144+
145+
def primitive(self, primitive: PrimitiveType) -> str:
146+
if isinstance(primitive, DecimalType):
147+
return f"decimal({primitive.precision},{primitive.scale})"
148+
if (primitive_type := type(primitive)) not in GLUE_PRIMITIVE_TYPES:
149+
raise ValueError(f"Unknown primitive type: {primitive}")
150+
return GLUE_PRIMITIVE_TYPES[primitive_type]
151+
152+
153+
def _to_columns(metadata: TableMetadata) -> List[ColumnTypeDef]:
154+
results: Dict[str, ColumnTypeDef] = {}
155+
156+
def _append_to_results(field: NestedField, is_current: bool) -> None:
157+
if field.name in results:
158+
return
159+
160+
results[field.name] = cast(
161+
ColumnTypeDef,
162+
{
163+
"Name": field.name,
164+
"Type": visit(field.field_type, _IcebergSchemaToGlueType()),
165+
"Parameters": {
166+
ICEBERG_FIELD_ID: str(field.field_id),
167+
ICEBERG_FIELD_OPTIONAL: str(field.optional).lower(),
168+
ICEBERG_FIELD_CURRENT: str(is_current).lower(),
169+
},
170+
},
171+
)
172+
if field.doc:
173+
results[field.name]["Comment"] = field.doc
174+
175+
if current_schema := metadata.schema_by_id(metadata.current_schema_id):
176+
for field in current_schema.columns:
177+
_append_to_results(field, True)
178+
179+
for schema in metadata.schemas:
180+
if schema.schema_id == metadata.current_schema_id:
181+
continue
182+
for field in schema.columns:
183+
_append_to_results(field, False)
184+
185+
return list(results.values())
186+
187+
87188
def _construct_table_input(
88189
table_name: str,
89190
metadata_location: str,
90191
properties: Properties,
192+
metadata: TableMetadata,
91193
glue_table: Optional[TableTypeDef] = None,
92194
prev_metadata_location: Optional[str] = None,
93195
) -> TableInputTypeDef:
94196
table_input: TableInputTypeDef = {
95197
"Name": table_name,
96198
"TableType": EXTERNAL_TABLE,
97199
"Parameters": _construct_parameters(metadata_location, glue_table, prev_metadata_location),
200+
"StorageDescriptor": {
201+
"Columns": _to_columns(metadata),
202+
"Location": metadata.location,
203+
},
98204
}
99205

100206
if "Description" in properties:
@@ -258,7 +364,7 @@ def create_table(
258364
io = load_file_io(properties=self.properties, location=metadata_location)
259365
self._write_metadata(metadata, io, metadata_location)
260366

261-
table_input = _construct_table_input(table_name, metadata_location, properties)
367+
table_input = _construct_table_input(table_name, metadata_location, properties, metadata)
262368
database_name, table_name = self.identifier_to_database_and_table(identifier)
263369
self._create_glue_table(database_name=database_name, table_name=table_name, table_input=table_input)
264370

@@ -322,6 +428,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
322428
table_name=table_name,
323429
metadata_location=new_metadata_location,
324430
properties=current_table.properties,
431+
metadata=updated_metadata,
325432
glue_table=current_glue_table,
326433
prev_metadata_location=current_table.metadata_location,
327434
)

tests/catalog/integration_test_glue.py

Lines changed: 159 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import Generator, List
18+
import time
19+
from typing import Any, Dict, Generator, List
20+
from uuid import uuid4
1921

2022
import boto3
23+
import pyarrow as pa
2124
import pytest
2225
from botocore.exceptions import ClientError
2326

@@ -30,6 +33,7 @@
3033
NoSuchTableError,
3134
TableAlreadyExistsError,
3235
)
36+
from pyiceberg.io.pyarrow import schema_to_pyarrow
3337
from pyiceberg.schema import Schema
3438
from pyiceberg.types import IntegerType
3539
from tests.conftest import clean_up, get_bucket_name, get_s3_path
@@ -52,8 +56,62 @@ def fixture_test_catalog() -> Generator[Catalog, None, None]:
5256
clean_up(test_catalog)
5357

5458

59+
class AthenaQueryHelper:
60+
_athena_client: boto3.client
61+
_s3_resource: boto3.resource
62+
_output_bucket: str
63+
_output_path: str
64+
65+
def __init__(self) -> None:
66+
self._s3_resource = boto3.resource("s3")
67+
self._athena_client = boto3.client("athena")
68+
self._output_bucket = get_bucket_name()
69+
self._output_path = f"athena_results_{uuid4()}"
70+
71+
def get_query_results(self, query: str) -> List[Dict[str, Any]]:
72+
query_execution_id = self._athena_client.start_query_execution(
73+
QueryString=query, ResultConfiguration={"OutputLocation": f"s3://{self._output_bucket}/{self._output_path}"}
74+
)["QueryExecutionId"]
75+
76+
while True:
77+
result = self._athena_client.get_query_execution(QueryExecutionId=query_execution_id)["QueryExecution"]["Status"]
78+
query_status = result["State"]
79+
assert query_status not in [
80+
"FAILED",
81+
"CANCELLED",
82+
], f"""
83+
Athena query with the string failed or was cancelled:
84+
Query: {query}
85+
Status: {query_status}
86+
Reason: {result["StateChangeReason"]}"""
87+
88+
if query_status not in ["QUEUED", "RUNNING"]:
89+
break
90+
time.sleep(0.5)
91+
92+
# No pagination for now, assume that we are not doing large queries
93+
return self._athena_client.get_query_results(QueryExecutionId=query_execution_id)["ResultSet"]["Rows"]
94+
95+
def clean_up(self) -> None:
96+
bucket = self._s3_resource.Bucket(self._output_bucket)
97+
for obj in bucket.objects.filter(Prefix=f"{self._output_path}/"):
98+
self._s3_resource.Object(bucket.name, obj.key).delete()
99+
100+
101+
@pytest.fixture(name="athena", scope="module")
102+
def fixture_athena_helper() -> Generator[AthenaQueryHelper, None, None]:
103+
query_helper = AthenaQueryHelper()
104+
yield query_helper
105+
query_helper.clean_up()
106+
107+
55108
def test_create_table(
56-
test_catalog: Catalog, s3: boto3.client, table_schema_nested: Schema, table_name: str, database_name: str
109+
test_catalog: Catalog,
110+
s3: boto3.client,
111+
table_schema_nested: Schema,
112+
table_name: str,
113+
database_name: str,
114+
athena: AthenaQueryHelper,
57115
) -> None:
58116
identifier = (database_name, table_name)
59117
test_catalog.create_namespace(database_name)
@@ -64,6 +122,48 @@ def test_create_table(
64122
s3.head_object(Bucket=get_bucket_name(), Key=metadata_location)
65123
assert test_catalog._parse_metadata_version(table.metadata_location) == 0
66124

125+
table.append(
126+
pa.Table.from_pylist(
127+
[
128+
{
129+
"foo": "foo_val",
130+
"bar": 1,
131+
"baz": False,
132+
"qux": ["x", "y"],
133+
"quux": {"key": {"subkey": 2}},
134+
"location": [{"latitude": 1.1}],
135+
"person": {"name": "some_name", "age": 23},
136+
}
137+
],
138+
schema=schema_to_pyarrow(table.schema()),
139+
),
140+
)
141+
142+
assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
143+
{
144+
"Data": [
145+
{"VarCharValue": "foo"},
146+
{"VarCharValue": "bar"},
147+
{"VarCharValue": "baz"},
148+
{"VarCharValue": "qux"},
149+
{"VarCharValue": "quux"},
150+
{"VarCharValue": "location"},
151+
{"VarCharValue": "person"},
152+
]
153+
},
154+
{
155+
"Data": [
156+
{"VarCharValue": "foo_val"},
157+
{"VarCharValue": "1"},
158+
{"VarCharValue": "false"},
159+
{"VarCharValue": "[x, y]"},
160+
{"VarCharValue": "{key={subkey=2}}"},
161+
{"VarCharValue": "[{latitude=1.1, longitude=null}]"},
162+
{"VarCharValue": "{name=some_name, age=23}"},
163+
]
164+
},
165+
]
166+
67167

68168
def test_create_table_with_invalid_location(table_schema_nested: Schema, table_name: str, database_name: str) -> None:
69169
identifier = (database_name, table_name)
@@ -269,7 +369,7 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str)
269369

270370

271371
def test_commit_table_update_schema(
272-
test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str
372+
test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str, athena: AthenaQueryHelper
273373
) -> None:
274374
identifier = (database_name, table_name)
275375
test_catalog.create_namespace(namespace=database_name)
@@ -279,6 +379,20 @@ def test_commit_table_update_schema(
279379
assert test_catalog._parse_metadata_version(table.metadata_location) == 0
280380
assert original_table_metadata.current_schema_id == 0
281381

382+
assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
383+
{
384+
"Data": [
385+
{"VarCharValue": "foo"},
386+
{"VarCharValue": "bar"},
387+
{"VarCharValue": "baz"},
388+
{"VarCharValue": "qux"},
389+
{"VarCharValue": "quux"},
390+
{"VarCharValue": "location"},
391+
{"VarCharValue": "person"},
392+
]
393+
}
394+
]
395+
282396
transaction = table.transaction()
283397
update = transaction.update_schema()
284398
update.add_column(path="b", field_type=IntegerType())
@@ -295,6 +409,48 @@ def test_commit_table_update_schema(
295409
assert new_schema == update._apply()
296410
assert new_schema.find_field("b").field_type == IntegerType()
297411

412+
table.append(
413+
pa.Table.from_pylist(
414+
[
415+
{
416+
"foo": "foo_val",
417+
"bar": 1,
418+
"location": [{"latitude": 1.1}],
419+
"person": {"name": "some_name", "age": 23},
420+
"b": 2,
421+
}
422+
],
423+
schema=schema_to_pyarrow(new_schema),
424+
),
425+
)
426+
427+
assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
428+
{
429+
"Data": [
430+
{"VarCharValue": "foo"},
431+
{"VarCharValue": "bar"},
432+
{"VarCharValue": "baz"},
433+
{"VarCharValue": "qux"},
434+
{"VarCharValue": "quux"},
435+
{"VarCharValue": "location"},
436+
{"VarCharValue": "person"},
437+
{"VarCharValue": "b"},
438+
]
439+
},
440+
{
441+
"Data": [
442+
{"VarCharValue": "foo_val"},
443+
{"VarCharValue": "1"},
444+
{},
445+
{"VarCharValue": "[]"},
446+
{"VarCharValue": "{}"},
447+
{"VarCharValue": "[{latitude=1.1, longitude=null}]"},
448+
{"VarCharValue": "{name=some_name, age=23}"},
449+
{"VarCharValue": "2"},
450+
]
451+
},
452+
]
453+
298454

299455
def test_commit_table_properties(test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str) -> None:
300456
identifier = (database_name, table_name)

0 commit comments

Comments (0)