Skip to content

Commit 805a791

Browse files
committed
Simplify AssertTableUUID and clean up manifests on abort
AssertTableUUID is constant across retries, so add it once before the loop instead of adding/removing on each iteration. Add _clean_all_uncommitted() that deletes both _uncommitted_manifests and _written_manifests on permanent failure, fixing orphaned manifests from the last attempt. Signed-off-by: Sotaro Hikita <bering1814@gmail.com>
1 parent ca5f652 commit 805a791

3 files changed

Lines changed: 78 additions & 21 deletions

File tree

pyiceberg/table/__init__.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -997,27 +997,32 @@ def commit_transaction(self) -> Table:
997997
total_timeout_val if total_timeout_val is not None else TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT
998998
)
999999
start_time = time.monotonic()
1000+
self._requirements += (AssertTableUUID(uuid=self.table_metadata.table_uuid),)
10001001

1001-
for attempt in range(num_retries + 1):
1002-
try:
1003-
self._requirements += (AssertTableUUID(uuid=self.table_metadata.table_uuid),)
1004-
self._table._do_commit( # pylint: disable=W0212
1005-
updates=self._updates,
1006-
requirements=self._requirements,
1007-
)
1008-
self._cleanup_uncommitted_manifests()
1009-
break
1010-
except CommitFailedException:
1011-
elapsed_ms = (time.monotonic() - start_time) * 1000
1012-
if attempt == num_retries or not self._snapshot_producers or elapsed_ms >= total_timeout_ms:
1013-
raise
1014-
1015-
wait = min(min_wait_ms * (2**attempt), max_wait_ms)
1016-
jitter = random.uniform(0, 0.25 * wait)
1017-
time.sleep((wait + jitter) / 1000.0)
1018-
1019-
self._table.refresh()
1020-
self._rebuild_snapshot_updates()
1002+
try:
1003+
for attempt in range(num_retries + 1):
1004+
try:
1005+
self._table._do_commit( # pylint: disable=W0212
1006+
updates=self._updates,
1007+
requirements=self._requirements,
1008+
)
1009+
self._cleanup_uncommitted_manifests()
1010+
break
1011+
except CommitFailedException:
1012+
elapsed_ms = (time.monotonic() - start_time) * 1000
1013+
if attempt == num_retries or not self._snapshot_producers or elapsed_ms >= total_timeout_ms:
1014+
raise
1015+
1016+
wait = min(min_wait_ms * (2**attempt), max_wait_ms)
1017+
jitter = random.uniform(0, 0.25 * wait)
1018+
time.sleep((wait + jitter) / 1000.0)
1019+
1020+
self._table.refresh()
1021+
self._rebuild_snapshot_updates()
1022+
except Exception:
1023+
for producer in self._snapshot_producers:
1024+
producer._clean_all_uncommitted()
1025+
raise
10211026

10221027
self._updates = ()
10231028
self._requirements = ()
@@ -1034,7 +1039,7 @@ def _rebuild_snapshot_updates(self) -> None:
10341039
from pyiceberg.table.update import AddSnapshotUpdate, AssertRefSnapshotId, SetSnapshotRefUpdate
10351040

10361041
self._updates = tuple(u for u in self._updates if not isinstance(u, (AddSnapshotUpdate, SetSnapshotRefUpdate)))
1037-
self._requirements = tuple(r for r in self._requirements if not isinstance(r, (AssertRefSnapshotId, AssertTableUUID)))
1042+
self._requirements = tuple(r for r in self._requirements if not isinstance(r, AssertRefSnapshotId))
10381043

10391044
for producer in self._snapshot_producers:
10401045
producer._refresh_for_retry()

pyiceberg/table/update/snapshot.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,16 @@ def _cleanup_uncommitted(self) -> None:
378378
logger.warning("Failed to delete uncommitted manifest: %s", path, exc_info=True)
379379
self._uncommitted_manifests.clear()
380380

381+
def _clean_all_uncommitted(self) -> None:
382+
"""Clean up all manifests written during this producer's lifecycle on abort."""
383+
for path in itertools.chain(self._uncommitted_manifests, self._written_manifests):
384+
try:
385+
self._io.delete(path)
386+
except Exception:
387+
logger.warning("Failed to delete uncommitted manifest: %s", path, exc_info=True)
388+
self._uncommitted_manifests.clear()
389+
self._written_manifests.clear()
390+
381391
def _refresh_for_retry(self) -> None:
382392
"""Reset state for a retry attempt with refreshed metadata."""
383393
self._uncommitted_manifests.extend(self._written_manifests)

tests/table/test_commit_retry.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,3 +556,45 @@ def test_overwrite_with_serializable_update_isolation_raises(catalog: Catalog) -
556556

557557
with pytest.raises(ValidationException):
558558
tbl2.overwrite(pa.table({"x": [10, 20, 30]}), overwrite_filter="x > 0")
559+
560+
561+
def test_clean_all_uncommitted_on_validation_exception(catalog: Catalog) -> None:
562+
"""Verify that all manifests are cleaned up when commit aborts with ValidationException."""
563+
catalog.create_namespace("default")
564+
schema = _test_schema()
565+
catalog.create_table("default.clean_abort_test", schema=schema)
566+
567+
import pyarrow as pa
568+
569+
df = pa.table({"x": [1, 2, 3]})
570+
571+
tbl = catalog.load_table("default.clean_abort_test")
572+
tbl.append(df)
573+
574+
tbl1 = catalog.load_table("default.clean_abort_test")
575+
tbl2 = catalog.load_table("default.clean_abort_test")
576+
577+
tbl1.delete("x == 1")
578+
579+
captured_producers: list = []
580+
581+
original_clean_all = None
582+
583+
def capturing_clean_all(self_producer: Any) -> None:
584+
captured_producers.append(self_producer)
585+
original_clean_all(self_producer)
586+
587+
from pyiceberg.table.update.snapshot import _SnapshotProducer
588+
589+
original_clean_all = _SnapshotProducer._clean_all_uncommitted
590+
591+
with patch.object(_SnapshotProducer, "_clean_all_uncommitted", capturing_clean_all):
592+
with pytest.raises(ValidationException):
593+
tbl2.delete("x == 1")
594+
595+
# _clean_all_uncommitted was called on abort
596+
assert len(captured_producers) > 0
597+
# All manifest lists should be cleared
598+
for producer in captured_producers:
599+
assert producer._written_manifests == []
600+
assert producer._uncommitted_manifests == []

0 commit comments

Comments
 (0)