Skip to content

Commit 12b9458

Browse files
committed
Support CreateTableTransaction for SQL
1 parent 3910e5e commit 12b9458

File tree

2 files changed

+129
-35
lines changed

2 files changed

+129
-35
lines changed

pyiceberg/catalog/sql.py

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -376,55 +376,89 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
376376
identifier_tuple = self.identifier_to_tuple_without_catalog(
377377
tuple(table_request.identifier.namespace.root + [table_request.identifier.name])
378378
)
379-
current_table = self.load_table(identifier_tuple)
380379
database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
381-
base_metadata = current_table.metadata
380+
381+
current_table: Optional[Table]
382+
try:
383+
current_table = self.load_table(identifier_tuple)
384+
except NoSuchTableError:
385+
current_table = None
386+
382387
for requirement in table_request.requirements:
383-
requirement.validate(base_metadata)
388+
requirement.validate(current_table.metadata if current_table else None)
384389

385-
updated_metadata = update_table_metadata(base_metadata, table_request.updates)
386-
if updated_metadata == base_metadata:
390+
updated_metadata = update_table_metadata(
391+
base_metadata=current_table.metadata if current_table else self._empty_table_metadata(),
392+
updates=table_request.updates,
393+
enforce_validation=current_table is None,
394+
)
395+
if current_table and updated_metadata == current_table.metadata:
387396
# no changes, do nothing
388-
return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location)
397+
return CommitTableResponse(metadata=current_table.metadata, metadata_location=current_table.metadata_location)
389398

390399
# write new metadata
391-
new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1
392-
new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version)
393-
self._write_metadata(updated_metadata, current_table.io, new_metadata_location)
400+
new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1 if current_table else 0
401+
new_metadata_location = self._get_metadata_location(updated_metadata.location, new_metadata_version)
402+
self._write_metadata(
403+
metadata=updated_metadata,
404+
io=self._load_file_io(updated_metadata.properties, new_metadata_location),
405+
metadata_path=new_metadata_location,
406+
)
394407

395408
with Session(self.engine) as session:
396-
if self.engine.dialect.supports_sane_rowcount:
397-
stmt = (
398-
update(IcebergTables)
399-
.where(
400-
IcebergTables.catalog_name == self.name,
401-
IcebergTables.table_namespace == database_name,
402-
IcebergTables.table_name == table_name,
403-
IcebergTables.metadata_location == current_table.metadata_location,
404-
)
405-
.values(metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location)
406-
)
407-
result = session.execute(stmt)
408-
if result.rowcount < 1:
409-
raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
410-
else:
411-
try:
412-
tbl = (
413-
session.query(IcebergTables)
414-
.with_for_update(of=IcebergTables)
415-
.filter(
409+
if current_table:
410+
# table exists, update it
411+
if self.engine.dialect.supports_sane_rowcount:
412+
stmt = (
413+
update(IcebergTables)
414+
.where(
416415
IcebergTables.catalog_name == self.name,
417416
IcebergTables.table_namespace == database_name,
418417
IcebergTables.table_name == table_name,
419418
IcebergTables.metadata_location == current_table.metadata_location,
420419
)
421-
.one()
420+
.values(
421+
metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location
422+
)
422423
)
423-
tbl.metadata_location = new_metadata_location
424-
tbl.previous_metadata_location = current_table.metadata_location
425-
except NoResultFound as e:
426-
raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}") from e
427-
session.commit()
424+
result = session.execute(stmt)
425+
if result.rowcount < 1:
426+
raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
427+
else:
428+
try:
429+
tbl = (
430+
session.query(IcebergTables)
431+
.with_for_update(of=IcebergTables)
432+
.filter(
433+
IcebergTables.catalog_name == self.name,
434+
IcebergTables.table_namespace == database_name,
435+
IcebergTables.table_name == table_name,
436+
IcebergTables.metadata_location == current_table.metadata_location,
437+
)
438+
.one()
439+
)
440+
tbl.metadata_location = new_metadata_location
441+
tbl.previous_metadata_location = current_table.metadata_location
442+
except NoResultFound as e:
443+
raise CommitFailedException(
444+
f"Table has been updated by another process: {database_name}.{table_name}"
445+
) from e
446+
session.commit()
447+
else:
448+
# table does not exist, create it
449+
try:
450+
session.add(
451+
IcebergTables(
452+
catalog_name=self.name,
453+
table_namespace=database_name,
454+
table_name=table_name,
455+
metadata_location=new_metadata_location,
456+
previous_metadata_location=None,
457+
)
458+
)
459+
session.commit()
460+
except IntegrityError as e:
461+
raise TableAlreadyExistsError(f"Table {database_name}.{table_name} already exists") from e
428462

429463
return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location)
430464

tests/catalog/test_sql.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,66 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
948948
snapshot_update.append_data_file(data_file)
949949

950950

951+
@pytest.mark.parametrize(
952+
'catalog',
953+
[
954+
lazy_fixture('catalog_memory'),
955+
lazy_fixture('catalog_sqlite'),
956+
lazy_fixture('catalog_sqlite_without_rowcount'),
957+
],
958+
)
959+
@pytest.mark.parametrize("format_version", [1, 2])
960+
def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> None:
961+
identifier = f"default.arrow_create_table_transaction_{catalog.name}_{format_version}"
962+
try:
963+
catalog.create_namespace("default")
964+
except NamespaceAlreadyExistsError:
965+
pass
966+
967+
try:
968+
catalog.drop_table(identifier=identifier)
969+
except NoSuchTableError:
970+
pass
971+
972+
pa_table = pa.Table.from_pydict(
973+
{
974+
'foo': ['a', None, 'z'],
975+
},
976+
schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
977+
)
978+
979+
pa_table_with_column = pa.Table.from_pydict(
980+
{
981+
'foo': ['a', None, 'z'],
982+
'bar': [19, None, 25],
983+
},
984+
schema=pa.schema([
985+
pa.field("foo", pa.string(), nullable=True),
986+
pa.field("bar", pa.int32(), nullable=True),
987+
]),
988+
)
989+
990+
with catalog.create_table_transaction(
991+
identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)}
992+
) as txn:
993+
with txn.update_snapshot().fast_append() as snapshot_update:
994+
for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=txn._table.io):
995+
snapshot_update.append_data_file(data_file)
996+
997+
with txn.update_schema() as schema_txn:
998+
schema_txn.union_by_name(pa_table_with_column.schema)
999+
1000+
with txn.update_snapshot().fast_append() as snapshot_update:
1001+
for data_file in _dataframe_to_data_files(
1002+
table_metadata=txn.table_metadata, df=pa_table_with_column, io=txn._table.io
1003+
):
1004+
snapshot_update.append_data_file(data_file)
1005+
1006+
tbl = catalog.load_table(identifier=identifier)
1007+
assert tbl.format_version == format_version
1008+
assert len(tbl.scan().to_arrow()) == 6
1009+
1010+
9511011
@pytest.mark.parametrize(
9521012
'catalog',
9531013
[

0 commit comments

Comments (0)