From ad0a92da35f8debc8136e39ade7d0667f275fdfa Mon Sep 17 00:00:00 2001 From: Alessandro Nori Date: Thu, 16 Apr 2026 10:32:13 +0200 Subject: [PATCH] Add remove_expired_refs to ExpireSnapshots --- pyiceberg/table/__init__.py | 2 + pyiceberg/table/update/snapshot.py | 78 ++++++++++---- tests/table/test_expire_snapshots.py | 147 ++++++++++++++++++++++----- 3 files changed, 186 insertions(+), 41 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index bb8765b651..0af59cbe46 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -205,6 +205,8 @@ class TableProperties: MIN_SNAPSHOTS_TO_KEEP = "history.expire.min-snapshots-to-keep" MIN_SNAPSHOTS_TO_KEEP_DEFAULT = 1 + MAX_REF_AGE_MS = "history.expire.max-ref-age-ms" + class Transaction: _table: Table diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 37d120969a..4c00ebc38a 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -1025,53 +1025,70 @@ def _current_ancestors(self) -> set[int]: class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): - """Expire snapshots by ID. + """Expire snapshots and refs. Use table.expire_snapshots().<operation>().commit() to run a specific operation. Use table.expire_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations. - Pending changes are applied on commit. + Pending changes are applied on commit. Call order does not affect the result. """ _updates: tuple[TableUpdate, ...] _requirements: tuple[TableRequirement, ...] _snapshot_ids_to_expire: set[int] + _ref_names_to_expire: set[str] + _expire_older_than_ms: int | None def __init__(self, transaction: Transaction) -> None: super().__init__(transaction) self._updates = () self._requirements = () self._snapshot_ids_to_expire = set() + self._ref_names_to_expire = set() + self._expire_older_than_ms = None def _commit(self) -> UpdatesAndRequirements: """ Commit the staged updates and requirements. 
- This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads). + Applies all pending expirations: explicit snapshot IDs, age-based snapshot expiry, + and ref removals. Protected snapshots (branch/tag heads not being expired) are always + excluded. Returns: Tuple of updates and requirements to be committed, as required by the calling parent apply functions. """ - # Remove any protected snapshot IDs from the set to expire, just in case protected_ids = self._get_protected_snapshot_ids() - self._snapshot_ids_to_expire -= protected_ids - update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire) - self._updates += (update,) + + if self._expire_older_than_ms is not None: + for snapshot in self._transaction.table_metadata.snapshots: + if snapshot.timestamp_ms < self._expire_older_than_ms and snapshot.snapshot_id not in protected_ids: + self._snapshot_ids_to_expire.add(snapshot.snapshot_id) + + snapshot_ids_to_expire = self._snapshot_ids_to_expire - protected_ids + + updates: list[TableUpdate] = list(self._updates) + for ref_name in self._ref_names_to_expire: + updates.append(RemoveSnapshotRefUpdate(ref_name=ref_name)) + if snapshot_ids_to_expire: + updates.append(RemoveSnapshotsUpdate(snapshot_ids=snapshot_ids_to_expire)) + self._updates = tuple(updates) return self._updates, self._requirements def _get_protected_snapshot_ids(self) -> set[int]: """ - Get the IDs of protected snapshots. + Get the IDs of snapshots that must not be expired. - These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration. + These are the HEAD snapshots of all branches and tags that are not + already marked for removal via remove_expired_refs(). Returns: Set of protected snapshot IDs to exclude from expiration. 
""" return { ref.snapshot_id - for ref in self._transaction.table_metadata.refs.values() - if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH] + for name, ref in self._transaction.table_metadata.refs.items() + if name not in self._ref_names_to_expire } def by_id(self, snapshot_id: int) -> ExpireSnapshots: @@ -1112,7 +1129,10 @@ def by_ids(self, snapshot_ids: list[int]) -> ExpireSnapshots: def older_than(self, dt: datetime) -> ExpireSnapshots: """ - Expire all unprotected snapshots with a timestamp older than a given value. + Expire all unprotected snapshots with a timestamp older than the given value. + + The filter is evaluated at commit time so that snapshots left without a ref + by remove_expired_refs() are also considered, regardless of call order. Args: dt (datetime): Only snapshots with datetime < this value will be expired. @@ -1120,9 +1140,33 @@ def older_than(self, dt: datetime) -> ExpireSnapshots: Returns: This for method chaining. """ - protected_ids = self._get_protected_snapshot_ids() - expire_from = datetime_to_millis(dt) - for snapshot in self._transaction.table_metadata.snapshots: - if snapshot.timestamp_ms < expire_from and snapshot.snapshot_id not in protected_ids: - self._snapshot_ids_to_expire.add(snapshot.snapshot_id) + self._expire_older_than_ms = datetime_to_millis(dt) + return self + + def remove_expired_refs(self, default_max_ref_age_ms: int) -> ExpireSnapshots: + """ + Mark stale branches and tags for removal. + + A ref is expired when the age of its snapshot exceeds its own max_ref_age_ms. + If a ref has no per-ref max_ref_age_ms set, default_max_ref_age_ms is used as fallback. + The main branch is never removed. + + Snapshots left without any live ref after this call are no longer protected, + so a subsequent older_than() will include them in age-based expiry. + + Args: + default_max_ref_age_ms: Fallback max age in milliseconds for refs that have no + per-ref max_ref_age_ms configured. 
+ + Returns: + This for method chaining. + """ + now_ms = int(datetime.now().timestamp() * 1000) + for name, ref in self._transaction.table_metadata.refs.items(): + if name == MAIN_BRANCH: + continue + effective_max_ref_age_ms = ref.max_ref_age_ms if ref.max_ref_age_ms is not None else default_max_ref_age_ms + snapshot = self._transaction.table_metadata.snapshot_by_id(ref.snapshot_id) + if snapshot is None or (now_ms - snapshot.timestamp_ms) > effective_max_ref_age_ms: + self._ref_names_to_expire.add(name) return self diff --git a/tests/table/test_expire_snapshots.py b/tests/table/test_expire_snapshots.py index 106e5b786c..610139a19f 100644 --- a/tests/table/test_expire_snapshots.py +++ b/tests/table/test_expire_snapshots.py @@ -22,7 +22,8 @@ import pytest from pyiceberg.table import CommitTableResponse, Table -from pyiceberg.table.update import RemoveSnapshotsUpdate, update_table_metadata +from pyiceberg.table.refs import SnapshotRef, SnapshotRefType +from pyiceberg.table.update import RemoveSnapshotRefUpdate, RemoveSnapshotsUpdate, update_table_metadata from pyiceberg.table.update.snapshot import ExpireSnapshots @@ -92,8 +93,8 @@ def test_expire_unprotected_snapshot(table_v2: Table) -> None: table_v2.metadata = table_v2.metadata.model_copy( update={ "refs": { - "main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"), - "tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"), + "main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}), + "tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}), } } ) @@ -134,8 +135,8 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None: table_v2.metadata = table_v2.metadata.model_copy( update={ "refs": { - "main": MagicMock(snapshot_id=HEAD_SNAPSHOT, snapshot_ref_type="branch"), - "mytag": MagicMock(snapshot_id=TAGGED_SNAPSHOT, snapshot_ref_type="tag"), + "main": SnapshotRef(**{"snapshot-id": HEAD_SNAPSHOT, "type": 
SnapshotRefType.BRANCH}), + "mytag": SnapshotRef(**{"snapshot-id": TAGGED_SNAPSHOT, "type": SnapshotRefType.TAG}), }, "snapshots": [ SimpleNamespace(snapshot_id=HEAD_SNAPSHOT, timestamp_ms=1, parent_snapshot_id=None), @@ -165,13 +166,8 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None: assert HEAD_SNAPSHOT in remaining_ids assert TAGGED_SNAPSHOT in remaining_ids - # No snapshots should have been expired (commit_table called, but with empty snapshot_ids) - args, kwargs = table_v2.catalog.commit_table.call_args - updates = args[2] if len(args) > 2 else () - # Find RemoveSnapshotsUpdate in updates - remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None) - assert remove_update is not None - assert remove_update.snapshot_ids == [] + # No snapshots expired and no refs expired — commit_table should not be called at all + table_v2.catalog.commit_table.assert_not_called() def test_expire_snapshots_by_ids(table_v2: Table) -> None: @@ -188,24 +184,14 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None: table_v2.catalog = MagicMock() table_v2.catalog.commit_table.return_value = mock_response - # Remove any refs that protect the snapshots to be expired - table_v2.metadata = table_v2.metadata.model_copy( - update={ - "refs": { - "main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"), - "tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"), - } - } - ) - # Add snapshots to metadata for multi-id test from types import SimpleNamespace table_v2.metadata = table_v2.metadata.model_copy( update={ "refs": { - "main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"), - "tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"), + "main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}), + "tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}), }, "snapshots": [ 
SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_1, timestamp_ms=1, parent_snapshot_id=None), @@ -316,3 +302,116 @@ def test_update_remove_snapshots_with_statistics(table_v2_with_statistics: Table assert not any(stat.snapshot_id == REMOVE_SNAPSHOT for stat in new_metadata.statistics), ( "Statistics for removed snapshot should be gone" ) + + +def _make_commit_response(table: Table) -> CommitTableResponse: + return CommitTableResponse( + metadata=table.metadata, + metadata_location="mock://metadata/location", + uuid=uuid4(), + ) + + +def test_ref_expiration_removes_old_tag_and_snapshot(table_v2: Table) -> None: + """A tag whose snapshot age exceeds max_ref_age_ms is removed; its orphaned snapshot + is also expired when older_than() is combined.""" + OLD_SNAPSHOT = 3051729675574597004 + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2) + + # "test" tag (fixture) points to OLD_SNAPSHOT with max-ref-age-ms=10000000 (~2.7 h). + # OLD_SNAPSHOT timestamp is from 2018 — definitely older than 2.7 h. 
+ assert "test" in table_v2.metadata.refs + assert table_v2.metadata.refs["test"].snapshot_id == OLD_SNAPSHOT + + future = datetime.now() + timedelta(days=1) + table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).older_than(future).commit() + + args, _ = table_v2.catalog.commit_table.call_args + updates = args[2] + + ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)] + snap_updates = [u for u in updates if isinstance(u, RemoveSnapshotsUpdate)] + + assert any(u.ref_name == "test" for u in ref_updates), "Expected 'test' tag to be removed" + assert any(OLD_SNAPSHOT in u.snapshot_ids for u in snap_updates), ( + "Expected OLD_SNAPSHOT to be removed since it is no longer referenced" + ) + + +def test_ref_expiration_removes_old_branch(table_v2: Table) -> None: + """A non-main branch whose snapshot age exceeds max_ref_age_ms is removed.""" + OLD_SNAPSHOT = 3051729675574597004 + CURRENT_SNAPSHOT = 3055729675574597004 + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2) + + table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + "main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}), + "stale-branch": SnapshotRef(**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1}), + } + } + ) + + table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit() + + args, _ = table_v2.catalog.commit_table.call_args + updates = args[2] + ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)] + assert any(u.ref_name == "stale-branch" for u in ref_updates) + assert not any(u.ref_name == "main" for u in ref_updates) + + +def test_main_branch_never_expires(table_v2: Table) -> None: + """main branch is never removed regardless of age or max_ref_age_ms.""" + CURRENT_SNAPSHOT = 3055729675574597004 + + table_v2.catalog = MagicMock() + + 
table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + "main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1}), + } + } + ) + + table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_young_ref_is_retained(table_v2: Table) -> None: + """A ref whose snapshot is within max_ref_age_ms is not removed.""" + OLD_SNAPSHOT = 3051729675574597004 + CURRENT_SNAPSHOT = 3055729675574597004 + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2) + + # fresh-tag has a huge max_ref_age_ms — it should never expire + # stale-tag has max_ref_age_ms=1 — it will be expired (triggers a commit) + table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + "main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}), + "fresh-tag": SnapshotRef( + **{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 9999999999999} + ), + "stale-tag": SnapshotRef(**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 1}), + } + } + ) + + table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit() + + table_v2.catalog.commit_table.assert_called_once() + args, _ = table_v2.catalog.commit_table.call_args + updates = args[2] + ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)] + assert any(u.ref_name == "stale-tag" for u in ref_updates), "stale-tag should be expired" + assert not any(u.ref_name == "fresh-tag" for u in ref_updates), "fresh-tag must not be expired"