Skip to content

Commit 015226d

Browse files
authored
Make the snapshot creation part of the Transaction (#446)
* Make the snapshot creation part of the `Transaction`. This is also how it is done in Java, and I really like it since it allows you to easily queue up updates in a transaction — for example, an update to the schema. * Extend the API.
1 parent c23c24d commit 015226d

File tree

3 files changed

+213
-91
lines changed

3 files changed

+213
-91
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,7 +1714,7 @@ def fill_parquet_file_metadata(
17141714
data_file.split_offsets = split_offsets
17151715

17161716

1717-
def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
1717+
def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[Schema] = None) -> Iterator[DataFile]:
17181718
task = next(tasks)
17191719

17201720
try:
@@ -1727,7 +1727,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
17271727
parquet_writer_kwargs = _get_parquet_writer_kwargs(table.properties)
17281728

17291729
file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
1730-
file_schema = schema_to_pyarrow(table.schema())
1730+
file_schema = file_schema or table.schema()
1731+
arrow_file_schema = schema_to_pyarrow(file_schema)
17311732

17321733
fo = table.io.new_output(file_path)
17331734
row_group_size = PropertyUtil.property_as_int(
@@ -1736,7 +1737,7 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
17361737
default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
17371738
)
17381739
with fo.create(overwrite=True) as fos:
1739-
with pq.ParquetWriter(fos, schema=file_schema, **parquet_writer_kwargs) as writer:
1740+
with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
17401741
writer.write_table(task.df, row_group_size=row_group_size)
17411742

17421743
data_file = DataFile(
@@ -1758,8 +1759,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
17581759
fill_parquet_file_metadata(
17591760
data_file=data_file,
17601761
parquet_metadata=writer.writer.metadata,
1761-
stats_columns=compute_statistics_plan(table.schema(), table.properties),
1762-
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
1762+
stats_columns=compute_statistics_plan(file_schema, table.properties),
1763+
parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
17631764
)
17641765
return iter([data_file])
17651766

pyiceberg/table/__init__.py

Lines changed: 165 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,14 @@ def update_schema(self) -> UpdateSchema:
329329
"""
330330
return UpdateSchema(self._table, self)
331331

332+
def update_snapshot(self) -> UpdateSnapshot:
333+
"""Create a new UpdateSnapshot to produce a new snapshot for the table.
334+
335+
Returns:
336+
A new UpdateSnapshot.
337+
"""
338+
return UpdateSnapshot(self._table, self)
339+
332340
def remove_properties(self, *removals: str) -> Transaction:
333341
"""Remove properties.
334342
@@ -351,6 +359,12 @@ def update_location(self, location: str) -> Transaction:
351359
"""
352360
raise NotImplementedError("Not yet implemented")
353361

362+
def schema(self) -> Schema:
363+
try:
364+
return next(update for update in self._updates if isinstance(update, AddSchemaUpdate)).schema_
365+
except StopIteration:
366+
return self._table.schema()
367+
354368
def commit_transaction(self) -> Table:
355369
"""Commit the changes to the catalog.
356370
@@ -965,8 +979,21 @@ def history(self) -> List[SnapshotLogEntry]:
965979
return self.metadata.snapshot_log
966980

967981
def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
982+
"""Create a new UpdateSchema to alter the columns of this table.
983+
984+
Returns:
985+
A new UpdateSchema.
986+
"""
968987
return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive)
969988

989+
def update_snapshot(self) -> UpdateSnapshot:
990+
"""Create a new UpdateSnapshot to produce a new snapshot for the table.
991+
992+
Returns:
993+
A new UpdateSnapshot.
994+
"""
995+
return UpdateSnapshot(self)
996+
970997
def name_mapping(self) -> NameMapping:
971998
"""Return the table's field-id NameMapping."""
972999
if name_mapping_json := self.properties.get(TableProperties.DEFAULT_NAME_MAPPING):
@@ -976,7 +1003,7 @@ def name_mapping(self) -> NameMapping:
9761003

9771004
def append(self, df: pa.Table) -> None:
9781005
"""
979-
Append data to the table.
1006+
Shorthand API for appending a PyArrow table to the table.
9801007
9811008
Args:
9821009
df: The Arrow dataframe that will be appended to overwrite the table
@@ -992,19 +1019,16 @@ def append(self, df: pa.Table) -> None:
9921019
if len(self.spec().fields) > 0:
9931020
raise ValueError("Cannot write to partitioned tables")
9941021

995-
merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)
996-
997-
# skip writing data files if the dataframe is empty
998-
if df.shape[0] > 0:
999-
data_files = _dataframe_to_data_files(self, df=df)
1000-
for data_file in data_files:
1001-
merge.append_data_file(data_file)
1002-
1003-
merge.commit()
1022+
with self.update_snapshot().fast_append() as update_snapshot:
1023+
# skip writing data files if the dataframe is empty
1024+
if df.shape[0] > 0:
1025+
data_files = _dataframe_to_data_files(self, df=df)
1026+
for data_file in data_files:
1027+
update_snapshot.append_data_file(data_file)
10041028

10051029
def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None:
10061030
"""
1007-
Overwrite all the data in the table.
1031+
Shorthand for overwriting the table with a PyArrow table.
10081032
10091033
Args:
10101034
df: The Arrow dataframe that will be used to overwrite the table
@@ -1025,18 +1049,12 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
10251049
if len(self.spec().fields) > 0:
10261050
raise ValueError("Cannot write to partitioned tables")
10271051

1028-
merge = _MergingSnapshotProducer(
1029-
operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
1030-
table=self,
1031-
)
1032-
1033-
# skip writing data files if the dataframe is empty
1034-
if df.shape[0] > 0:
1035-
data_files = _dataframe_to_data_files(self, df=df)
1036-
for data_file in data_files:
1037-
merge.append_data_file(data_file)
1038-
1039-
merge.commit()
1052+
with self.update_snapshot().overwrite() as update_snapshot:
1053+
# skip writing data files if the dataframe is empty
1054+
if df.shape[0] > 0:
1055+
data_files = _dataframe_to_data_files(self, df=df)
1056+
for data_file in data_files:
1057+
update_snapshot.append_data_file(data_file)
10401058

10411059
def refs(self) -> Dict[str, SnapshotRef]:
10421060
"""Return the snapshot references in the table."""
@@ -2331,7 +2349,12 @@ def _generate_manifest_list_path(location: str, snapshot_id: int, attempt: int,
23312349
return f'{location}/metadata/snap-{snapshot_id}-{attempt}-{commit_uuid}.avro'
23322350

23332351

2334-
def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
2352+
def _dataframe_to_data_files(table: Table, df: pa.Table, file_schema: Optional[Schema] = None) -> Iterable[DataFile]:
2353+
"""Convert a PyArrow table into a DataFile.
2354+
2355+
Returns:
2356+
An iterable that supplies datafiles that represent the table.
2357+
"""
23352358
from pyiceberg.io.pyarrow import write_file
23362359

23372360
if len(table.spec().fields) > 0:
@@ -2342,7 +2365,7 @@ def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
23422365

23432366
# This is an iter, so we don't have to materialize everything every time
23442367
# This will be more relevant when we start doing partitioned writes
2345-
yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]))
2368+
yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]), file_schema=file_schema)
23462369

23472370

23482371
class _MergingSnapshotProducer:
@@ -2352,55 +2375,35 @@ class _MergingSnapshotProducer:
23522375
_parent_snapshot_id: Optional[int]
23532376
_added_data_files: List[DataFile]
23542377
_commit_uuid: uuid.UUID
2378+
_transaction: Optional[Transaction]
23552379

2356-
def __init__(self, operation: Operation, table: Table) -> None:
2380+
def __init__(self, operation: Operation, table: Table, transaction: Optional[Transaction] = None) -> None:
23572381
self._operation = operation
23582382
self._table = table
23592383
self._snapshot_id = table.new_snapshot_id()
23602384
# Since we only support the main branch for now
23612385
self._parent_snapshot_id = snapshot.snapshot_id if (snapshot := self._table.current_snapshot()) else None
23622386
self._added_data_files = []
23632387
self._commit_uuid = uuid.uuid4()
2388+
self._transaction = transaction
2389+
2390+
def __enter__(self) -> _MergingSnapshotProducer:
2391+
"""Start a transaction to update the table."""
2392+
return self
2393+
2394+
def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
2395+
"""Close and commit the transaction."""
2396+
self.commit()
23642397

23652398
def append_data_file(self, data_file: DataFile) -> _MergingSnapshotProducer:
23662399
self._added_data_files.append(data_file)
23672400
return self
23682401

2369-
def _deleted_entries(self) -> List[ManifestEntry]:
2370-
"""To determine if we need to record any deleted entries.
2371-
2372-
With partial overwrites we have to use the predicate to evaluate
2373-
which entries are affected.
2374-
"""
2375-
if self._operation == Operation.OVERWRITE:
2376-
if self._parent_snapshot_id is not None:
2377-
previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
2378-
if previous_snapshot is None:
2379-
# This should never happen since you cannot overwrite an empty table
2380-
raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")
2381-
2382-
executor = ExecutorFactory.get_or_create()
2383-
2384-
def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
2385-
return [
2386-
ManifestEntry(
2387-
status=ManifestEntryStatus.DELETED,
2388-
snapshot_id=entry.snapshot_id,
2389-
data_sequence_number=entry.data_sequence_number,
2390-
file_sequence_number=entry.file_sequence_number,
2391-
data_file=entry.data_file,
2392-
)
2393-
for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True)
2394-
if entry.data_file.content == DataFileContent.DATA
2395-
]
2402+
@abstractmethod
2403+
def _deleted_entries(self) -> List[ManifestEntry]: ...
23962404

2397-
list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io))
2398-
return list(chain(*list_of_entries))
2399-
return []
2400-
elif self._operation == Operation.APPEND:
2401-
return []
2402-
else:
2403-
raise ValueError(f"Not implemented for: {self._operation}")
2405+
@abstractmethod
2406+
def _existing_manifests(self) -> List[ManifestFile]: ...
24042407

24052408
def _manifests(self) -> List[ManifestFile]:
24062409
def _write_added_manifest() -> List[ManifestFile]:
@@ -2430,7 +2433,7 @@ def _write_added_manifest() -> List[ManifestFile]:
24302433
def _write_delete_manifest() -> List[ManifestFile]:
24312434
# Check if we need to mark the files as deleted
24322435
deleted_entries = self._deleted_entries()
2433-
if deleted_entries:
2436+
if len(deleted_entries) > 0:
24342437
output_file_location = _new_manifest_path(location=self._table.location(), num=1, commit_uuid=self._commit_uuid)
24352438
with write_manifest(
24362439
format_version=self._table.format_version,
@@ -2445,32 +2448,11 @@ def _write_delete_manifest() -> List[ManifestFile]:
24452448
else:
24462449
return []
24472450

2448-
def _fetch_existing_manifests() -> List[ManifestFile]:
2449-
existing_manifests = []
2450-
2451-
# Add existing manifests
2452-
if self._operation == Operation.APPEND and self._parent_snapshot_id is not None:
2453-
# In case we want to append, just add the existing manifests
2454-
previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
2455-
2456-
if previous_snapshot is None:
2457-
raise ValueError(f"Snapshot could not be found: {self._parent_snapshot_id}")
2458-
2459-
for manifest in previous_snapshot.manifests(io=self._table.io):
2460-
if (
2461-
manifest.has_added_files()
2462-
or manifest.has_existing_files()
2463-
or manifest.added_snapshot_id == self._snapshot_id
2464-
):
2465-
existing_manifests.append(manifest)
2466-
2467-
return existing_manifests
2468-
24692451
executor = ExecutorFactory.get_or_create()
24702452

24712453
added_manifests = executor.submit(_write_added_manifest)
24722454
delete_manifests = executor.submit(_write_delete_manifest)
2473-
existing_manifests = executor.submit(_fetch_existing_manifests)
2455+
existing_manifests = executor.submit(self._existing_manifests)
24742456

24752457
return added_manifests.result() + delete_manifests.result() + existing_manifests.result()
24762458

@@ -2515,10 +2497,107 @@ def commit(self) -> Snapshot:
25152497
schema_id=self._table.schema().schema_id,
25162498
)
25172499

2518-
with self._table.transaction() as tx:
2519-
tx.add_snapshot(snapshot=snapshot)
2520-
tx.set_ref_snapshot(
2500+
if self._transaction is not None:
2501+
self._transaction.add_snapshot(snapshot=snapshot)
2502+
self._transaction.set_ref_snapshot(
25212503
snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
25222504
)
2505+
else:
2506+
with self._table.transaction() as tx:
2507+
tx.add_snapshot(snapshot=snapshot)
2508+
tx.set_ref_snapshot(
2509+
snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
2510+
)
25232511

25242512
return snapshot
2513+
2514+
2515+
class FastAppendFiles(_MergingSnapshotProducer):
2516+
def _existing_manifests(self) -> List[ManifestFile]:
2517+
"""To determine if there are any existing manifest files.
2518+
2519+
A fast append will add another ManifestFile to the ManifestList.
2520+
All the existing manifest files are considered existing.
2521+
"""
2522+
existing_manifests = []
2523+
2524+
if self._parent_snapshot_id is not None:
2525+
previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
2526+
2527+
if previous_snapshot is None:
2528+
raise ValueError(f"Snapshot could not be found: {self._parent_snapshot_id}")
2529+
2530+
for manifest in previous_snapshot.manifests(io=self._table.io):
2531+
if manifest.has_added_files() or manifest.has_existing_files() or manifest.added_snapshot_id == self._snapshot_id:
2532+
existing_manifests.append(manifest)
2533+
2534+
return existing_manifests
2535+
2536+
def _deleted_entries(self) -> List[ManifestEntry]:
2537+
"""To determine if we need to record any deleted manifest entries.
2538+
2539+
In case of an append, nothing is deleted.
2540+
"""
2541+
return []
2542+
2543+
2544+
class OverwriteFiles(_MergingSnapshotProducer):
2545+
def _existing_manifests(self) -> List[ManifestFile]:
2546+
"""To determine if there are any existing manifest files.
2547+
2548+
In the case of a full overwrite, all the previous manifests are
2549+
considered deleted.
2550+
"""
2551+
return []
2552+
2553+
def _deleted_entries(self) -> List[ManifestEntry]:
2554+
"""To determine if we need to record any deleted entries.
2555+
2556+
With a full overwrite all the entries are considered deleted.
2557+
With partial overwrites we have to use the predicate to evaluate
2558+
which entries are affected.
2559+
"""
2560+
if self._parent_snapshot_id is not None:
2561+
previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
2562+
if previous_snapshot is None:
2563+
# This should never happen since you cannot overwrite an empty table
2564+
raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")
2565+
2566+
executor = ExecutorFactory.get_or_create()
2567+
2568+
def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
2569+
return [
2570+
ManifestEntry(
2571+
status=ManifestEntryStatus.DELETED,
2572+
snapshot_id=entry.snapshot_id,
2573+
data_sequence_number=entry.data_sequence_number,
2574+
file_sequence_number=entry.file_sequence_number,
2575+
data_file=entry.data_file,
2576+
)
2577+
for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True)
2578+
if entry.data_file.content == DataFileContent.DATA
2579+
]
2580+
2581+
list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io))
2582+
return list(chain(*list_of_entries))
2583+
else:
2584+
return []
2585+
2586+
2587+
class UpdateSnapshot:
2588+
_table: Table
2589+
_transaction: Optional[Transaction]
2590+
2591+
def __init__(self, table: Table, transaction: Optional[Transaction] = None) -> None:
2592+
self._table = table
2593+
self._transaction = transaction
2594+
2595+
def fast_append(self) -> FastAppendFiles:
2596+
return FastAppendFiles(table=self._table, operation=Operation.APPEND, transaction=self._transaction)
2597+
2598+
def overwrite(self) -> OverwriteFiles:
2599+
return OverwriteFiles(
2600+
table=self._table,
2601+
operation=Operation.OVERWRITE if self._table.current_snapshot() is not None else Operation.APPEND,
2602+
transaction=self._transaction,
2603+
)

0 commit comments

Comments
 (0)