Skip to content

Commit ad0a92d

Browse files
Add remove_expired_refs to ExpireSnapshots
1 parent 721c5aa commit ad0a92d

File tree

3 files changed

+186
-41
lines changed

3 files changed

+186
-41
lines changed

pyiceberg/table/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ class TableProperties:
205205
MIN_SNAPSHOTS_TO_KEEP = "history.expire.min-snapshots-to-keep"
206206
MIN_SNAPSHOTS_TO_KEEP_DEFAULT = 1
207207

208+
MAX_REF_AGE_MS = "history.expire.max-ref-age-ms"
209+
208210

209211
class Transaction:
210212
_table: Table

pyiceberg/table/update/snapshot.py

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,53 +1025,70 @@ def _current_ancestors(self) -> set[int]:
10251025

10261026

10271027
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
1028-
"""Expire snapshots by ID.
1028+
"""Expire snapshots and refs.
10291029
10301030
Use table.expire_snapshots().<operation>().commit() to run a specific operation.
10311031
Use table.expire_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
1032-
Pending changes are applied on commit.
1032+
Pending changes are applied on commit. Call order does not affect the result.
10331033
"""
10341034

10351035
_updates: tuple[TableUpdate, ...]
10361036
_requirements: tuple[TableRequirement, ...]
10371037
_snapshot_ids_to_expire: set[int]
1038+
_ref_names_to_expire: set[str]
1039+
_expire_older_than_ms: int | None
10381040

10391041
def __init__(self, transaction: Transaction) -> None:
10401042
super().__init__(transaction)
10411043
self._updates = ()
10421044
self._requirements = ()
10431045
self._snapshot_ids_to_expire = set()
1046+
self._ref_names_to_expire = set()
1047+
self._expire_older_than_ms = None
10441048

10451049
def _commit(self) -> UpdatesAndRequirements:
10461050
"""
10471051
Commit the staged updates and requirements.
10481052
1049-
This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads).
1053+
Applies all pending expirations: explicit snapshot IDs, age-based snapshot expiry,
1054+
and ref removals. Protected snapshots (branch/tag heads not being expired) are always
1055+
excluded.
10501056
10511057
Returns:
10521058
Tuple of updates and requirements to be committed,
10531059
as required by the calling parent apply functions.
10541060
"""
1055-
# Remove any protected snapshot IDs from the set to expire, just in case
10561061
protected_ids = self._get_protected_snapshot_ids()
1057-
self._snapshot_ids_to_expire -= protected_ids
1058-
update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire)
1059-
self._updates += (update,)
1062+
1063+
if self._expire_older_than_ms is not None:
1064+
for snapshot in self._transaction.table_metadata.snapshots:
1065+
if snapshot.timestamp_ms < self._expire_older_than_ms and snapshot.snapshot_id not in protected_ids:
1066+
self._snapshot_ids_to_expire.add(snapshot.snapshot_id)
1067+
1068+
snapshot_ids_to_expire = self._snapshot_ids_to_expire - protected_ids
1069+
1070+
updates: list[TableUpdate] = list(self._updates)
1071+
for ref_name in self._ref_names_to_expire:
1072+
updates.append(RemoveSnapshotRefUpdate(ref_name=ref_name))
1073+
if snapshot_ids_to_expire:
1074+
updates.append(RemoveSnapshotsUpdate(snapshot_ids=snapshot_ids_to_expire))
1075+
self._updates = tuple(updates)
10601076
return self._updates, self._requirements
10611077

10621078
def _get_protected_snapshot_ids(self) -> set[int]:
10631079
"""
1064-
Get the IDs of protected snapshots.
1080+
Get the IDs of snapshots that must not be expired.
10651081
1066-
These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration.
1082+
These are the HEAD snapshots of all branches and tags that are not
1083+
already marked for removal via remove_expired_refs().
10671084
10681085
Returns:
10691086
Set of protected snapshot IDs to exclude from expiration.
10701087
"""
10711088
return {
10721089
ref.snapshot_id
1073-
for ref in self._transaction.table_metadata.refs.values()
1074-
if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]
1090+
for name, ref in self._transaction.table_metadata.refs.items()
1091+
if name not in self._ref_names_to_expire
10751092
}
10761093

10771094
def by_id(self, snapshot_id: int) -> ExpireSnapshots:
@@ -1112,17 +1129,44 @@ def by_ids(self, snapshot_ids: list[int]) -> ExpireSnapshots:
11121129

11131130
def older_than(self, dt: datetime) -> ExpireSnapshots:
11141131
"""
1115-
Expire all unprotected snapshots with a timestamp older than a given value.
1132+
Expire all unprotected snapshots with a timestamp older than the given value.
1133+
1134+
The filter is evaluated at commit time so that snapshots left without a ref
1135+
by remove_expired_refs() are also considered, regardless of call order.
11161136
11171137
Args:
11181138
dt (datetime): Only snapshots with datetime < this value will be expired.
11191139
11201140
Returns:
11211141
This for method chaining.
11221142
"""
1123-
protected_ids = self._get_protected_snapshot_ids()
1124-
expire_from = datetime_to_millis(dt)
1125-
for snapshot in self._transaction.table_metadata.snapshots:
1126-
if snapshot.timestamp_ms < expire_from and snapshot.snapshot_id not in protected_ids:
1127-
self._snapshot_ids_to_expire.add(snapshot.snapshot_id)
1143+
self._expire_older_than_ms = datetime_to_millis(dt)
1144+
return self
1145+
1146+
def remove_expired_refs(self, default_max_ref_age_ms: int) -> ExpireSnapshots:
1147+
"""
1148+
Mark stale branches and tags for removal.
1149+
1150+
A ref is expired when the age of its snapshot exceeds its own max_ref_age_ms.
1151+
If a ref has no per-ref max_ref_age_ms set, default_max_ref_age_ms is used as fallback.
1152+
The main branch is never removed.
1153+
1154+
Snapshots left without any live ref after this call are no longer protected,
1155+
so a subsequent older_than() will include them in age-based expiry.
1156+
1157+
Args:
1158+
default_max_ref_age_ms: Fallback max age in milliseconds for refs that have no
1159+
per-ref max_ref_age_ms configured.
1160+
1161+
Returns:
1162+
This for method chaining.
1163+
"""
1164+
now_ms = int(datetime.now().timestamp() * 1000)
1165+
for name, ref in self._transaction.table_metadata.refs.items():
1166+
if name == MAIN_BRANCH:
1167+
continue
1168+
effective_max_ref_age_ms = ref.max_ref_age_ms if ref.max_ref_age_ms is not None else default_max_ref_age_ms
1169+
snapshot = self._transaction.table_metadata.snapshot_by_id(ref.snapshot_id)
1170+
if snapshot is None or (now_ms - snapshot.timestamp_ms) > effective_max_ref_age_ms:
1171+
self._ref_names_to_expire.add(name)
11281172
return self

tests/table/test_expire_snapshots.py

Lines changed: 123 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import pytest
2323

2424
from pyiceberg.table import CommitTableResponse, Table
25-
from pyiceberg.table.update import RemoveSnapshotsUpdate, update_table_metadata
25+
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
26+
from pyiceberg.table.update import RemoveSnapshotRefUpdate, RemoveSnapshotsUpdate, update_table_metadata
2627
from pyiceberg.table.update.snapshot import ExpireSnapshots
2728

2829

@@ -92,8 +93,8 @@ def test_expire_unprotected_snapshot(table_v2: Table) -> None:
9293
table_v2.metadata = table_v2.metadata.model_copy(
9394
update={
9495
"refs": {
95-
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
96-
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
96+
"main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
97+
"tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}),
9798
}
9899
}
99100
)
@@ -134,8 +135,8 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
134135
table_v2.metadata = table_v2.metadata.model_copy(
135136
update={
136137
"refs": {
137-
"main": MagicMock(snapshot_id=HEAD_SNAPSHOT, snapshot_ref_type="branch"),
138-
"mytag": MagicMock(snapshot_id=TAGGED_SNAPSHOT, snapshot_ref_type="tag"),
138+
"main": SnapshotRef(**{"snapshot-id": HEAD_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
139+
"mytag": SnapshotRef(**{"snapshot-id": TAGGED_SNAPSHOT, "type": SnapshotRefType.TAG}),
139140
},
140141
"snapshots": [
141142
SimpleNamespace(snapshot_id=HEAD_SNAPSHOT, timestamp_ms=1, parent_snapshot_id=None),
@@ -165,13 +166,8 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
165166
assert HEAD_SNAPSHOT in remaining_ids
166167
assert TAGGED_SNAPSHOT in remaining_ids
167168

168-
# No snapshots should have been expired (commit_table called, but with empty snapshot_ids)
169-
args, kwargs = table_v2.catalog.commit_table.call_args
170-
updates = args[2] if len(args) > 2 else ()
171-
# Find RemoveSnapshotsUpdate in updates
172-
remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None)
173-
assert remove_update is not None
174-
assert remove_update.snapshot_ids == []
169+
# No snapshots expired and no refs expired — commit_table should not be called at all
170+
table_v2.catalog.commit_table.assert_not_called()
175171

176172

177173
def test_expire_snapshots_by_ids(table_v2: Table) -> None:
@@ -188,24 +184,14 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None:
188184
table_v2.catalog = MagicMock()
189185
table_v2.catalog.commit_table.return_value = mock_response
190186

191-
# Remove any refs that protect the snapshots to be expired
192-
table_v2.metadata = table_v2.metadata.model_copy(
193-
update={
194-
"refs": {
195-
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
196-
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
197-
}
198-
}
199-
)
200-
201187
# Add snapshots to metadata for multi-id test
202188
from types import SimpleNamespace
203189

204190
table_v2.metadata = table_v2.metadata.model_copy(
205191
update={
206192
"refs": {
207-
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
208-
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
193+
"main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
194+
"tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}),
209195
},
210196
"snapshots": [
211197
SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_1, timestamp_ms=1, parent_snapshot_id=None),
@@ -316,3 +302,116 @@ def test_update_remove_snapshots_with_statistics(table_v2_with_statistics: Table
316302
assert not any(stat.snapshot_id == REMOVE_SNAPSHOT for stat in new_metadata.statistics), (
317303
"Statistics for removed snapshot should be gone"
318304
)
305+
306+
307+
def _make_commit_response(table: Table) -> CommitTableResponse:
308+
return CommitTableResponse(
309+
metadata=table.metadata,
310+
metadata_location="mock://metadata/location",
311+
uuid=uuid4(),
312+
)
313+
314+
315+
def test_ref_expiration_removes_old_tag_and_snapshot(table_v2: Table) -> None:
316+
"""A tag whose snapshot age exceeds max_ref_age_ms is removed; its orphaned snapshot
317+
is also expired when older_than() is combined."""
318+
OLD_SNAPSHOT = 3051729675574597004
319+
320+
table_v2.catalog = MagicMock()
321+
table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
322+
323+
# "test" tag (fixture) points to OLD_SNAPSHOT with max-ref-age-ms=10000000 (~2.7 h).
324+
# OLD_SNAPSHOT timestamp is from 2018 — definitely older than 2.7 h.
325+
assert "test" in table_v2.metadata.refs
326+
assert table_v2.metadata.refs["test"].snapshot_id == OLD_SNAPSHOT
327+
328+
future = datetime.now() + timedelta(days=1)
329+
table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).older_than(future).commit()
330+
331+
args, _ = table_v2.catalog.commit_table.call_args
332+
updates = args[2]
333+
334+
ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)]
335+
snap_updates = [u for u in updates if isinstance(u, RemoveSnapshotsUpdate)]
336+
337+
assert any(u.ref_name == "test" for u in ref_updates), "Expected 'test' tag to be removed"
338+
assert any(OLD_SNAPSHOT in u.snapshot_ids for u in snap_updates), (
339+
"Expected OLD_SNAPSHOT to be removed since it is no longer referenced"
340+
)
341+
342+
343+
def test_ref_expiration_removes_old_branch(table_v2: Table) -> None:
344+
"""A non-main branch whose snapshot age exceeds max_ref_age_ms is removed."""
345+
OLD_SNAPSHOT = 3051729675574597004
346+
CURRENT_SNAPSHOT = 3055729675574597004
347+
348+
table_v2.catalog = MagicMock()
349+
table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
350+
351+
table_v2.metadata = table_v2.metadata.model_copy(
352+
update={
353+
"refs": {
354+
"main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
355+
"stale-branch": SnapshotRef(**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1}),
356+
}
357+
}
358+
)
359+
360+
table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()
361+
362+
args, _ = table_v2.catalog.commit_table.call_args
363+
updates = args[2]
364+
ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)]
365+
assert any(u.ref_name == "stale-branch" for u in ref_updates)
366+
assert not any(u.ref_name == "main" for u in ref_updates)
367+
368+
369+
def test_main_branch_never_expires(table_v2: Table) -> None:
370+
"""main branch is never removed regardless of age or max_ref_age_ms."""
371+
CURRENT_SNAPSHOT = 3055729675574597004
372+
373+
table_v2.catalog = MagicMock()
374+
375+
table_v2.metadata = table_v2.metadata.model_copy(
376+
update={
377+
"refs": {
378+
"main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1}),
379+
}
380+
}
381+
)
382+
383+
table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()
384+
385+
table_v2.catalog.commit_table.assert_not_called()
386+
387+
388+
def test_young_ref_is_retained(table_v2: Table) -> None:
389+
"""A ref whose snapshot is within max_ref_age_ms is not removed."""
390+
OLD_SNAPSHOT = 3051729675574597004
391+
CURRENT_SNAPSHOT = 3055729675574597004
392+
393+
table_v2.catalog = MagicMock()
394+
table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
395+
396+
# fresh-tag has a huge max_ref_age_ms — it should never expire
397+
# stale-tag has max_ref_age_ms=1 — it will be expired (triggers a commit)
398+
table_v2.metadata = table_v2.metadata.model_copy(
399+
update={
400+
"refs": {
401+
"main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
402+
"fresh-tag": SnapshotRef(
403+
**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 9999999999999}
404+
),
405+
"stale-tag": SnapshotRef(**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 1}),
406+
}
407+
}
408+
)
409+
410+
table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()
411+
412+
table_v2.catalog.commit_table.assert_called_once()
413+
args, _ = table_v2.catalog.commit_table.call_args
414+
updates = args[2]
415+
ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)]
416+
assert any(u.ref_name == "stale-tag" for u in ref_updates), "stale-tag should be expired"
417+
assert not any(u.ref_name == "fresh-tag" for u in ref_updates), "fresh-tag must not be expired"

0 commit comments

Comments
 (0)