Skip to content

Commit c1089a8

Browse files
Add remove_expired_refs to ExpireSnapshots
1 parent 721c5aa commit c1089a8

File tree

3 files changed

+190
-42
lines changed

3 files changed

+190
-42
lines changed

pyiceberg/table/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ class TableProperties:
205205
MIN_SNAPSHOTS_TO_KEEP = "history.expire.min-snapshots-to-keep"
206206
MIN_SNAPSHOTS_TO_KEEP_DEFAULT = 1
207207

208+
MAX_REF_AGE_MS = "history.expire.max-ref-age-ms"
209+
208210

209211
class Transaction:
210212
_table: Table

pyiceberg/table/update/snapshot.py

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,53 +1025,70 @@ def _current_ancestors(self) -> set[int]:
10251025

10261026

10271027
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
1028-
"""Expire snapshots by ID.
1028+
"""Expire snapshots and refs.
10291029
10301030
Use table.expire_snapshots().<operation>().commit() to run a specific operation.
10311031
Use table.expire_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
1032-
Pending changes are applied on commit.
1032+
Pending changes are applied on commit. Call order does not affect the result.
10331033
"""
10341034

10351035
_updates: tuple[TableUpdate, ...]
10361036
_requirements: tuple[TableRequirement, ...]
10371037
_snapshot_ids_to_expire: set[int]
1038+
_ref_names_to_expire: set[str]
1039+
_expire_older_than_ms: int | None
10381040

10391041
def __init__(self, transaction: Transaction) -> None:
10401042
super().__init__(transaction)
10411043
self._updates = ()
10421044
self._requirements = ()
10431045
self._snapshot_ids_to_expire = set()
1046+
self._ref_names_to_expire = set()
1047+
self._expire_older_than_ms = None
10441048

10451049
def _commit(self) -> UpdatesAndRequirements:
10461050
"""
10471051
Commit the staged updates and requirements.
10481052
1049-
This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads).
1053+
Applies all pending expirations: explicit snapshot IDs, age-based snapshot expiry,
1054+
and ref removals. Protected snapshots (branch/tag heads not being expired) are always
1055+
excluded.
10501056
10511057
Returns:
10521058
Tuple of updates and requirements to be committed,
10531059
as required by the calling parent apply functions.
10541060
"""
1055-
# Remove any protected snapshot IDs from the set to expire, just in case
10561061
protected_ids = self._get_protected_snapshot_ids()
1057-
self._snapshot_ids_to_expire -= protected_ids
1058-
update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire)
1059-
self._updates += (update,)
1062+
1063+
if self._expire_older_than_ms is not None:
1064+
for snapshot in self._transaction.table_metadata.snapshots:
1065+
if snapshot.timestamp_ms < self._expire_older_than_ms and snapshot.snapshot_id not in protected_ids:
1066+
self._snapshot_ids_to_expire.add(snapshot.snapshot_id)
1067+
1068+
snapshot_ids_to_expire = self._snapshot_ids_to_expire - protected_ids
1069+
1070+
updates: list[TableUpdate] = list(self._updates)
1071+
for ref_name in self._ref_names_to_expire:
1072+
updates.append(RemoveSnapshotRefUpdate(ref_name=ref_name))
1073+
if snapshot_ids_to_expire:
1074+
updates.append(RemoveSnapshotsUpdate(snapshot_ids=snapshot_ids_to_expire))
1075+
self._updates = tuple(updates)
10601076
return self._updates, self._requirements
10611077

10621078
def _get_protected_snapshot_ids(self) -> set[int]:
10631079
"""
1064-
Get the IDs of protected snapshots.
1080+
Get the IDs of snapshots that must not be expired.
10651081
1066-
These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration.
1082+
These are the HEAD snapshots of all branches and tags that are not
1083+
already marked for removal via remove_expired_refs().
10671084
10681085
Returns:
10691086
Set of protected snapshot IDs to exclude from expiration.
10701087
"""
10711088
return {
10721089
ref.snapshot_id
1073-
for ref in self._transaction.table_metadata.refs.values()
1074-
if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]
1090+
for name, ref in self._transaction.table_metadata.refs.items()
1091+
if name not in self._ref_names_to_expire
10751092
}
10761093

10771094
def by_id(self, snapshot_id: int) -> ExpireSnapshots:
@@ -1112,17 +1129,44 @@ def by_ids(self, snapshot_ids: list[int]) -> ExpireSnapshots:
11121129

11131130
def older_than(self, dt: datetime) -> ExpireSnapshots:
11141131
"""
1115-
Expire all unprotected snapshots with a timestamp older than a given value.
1132+
Expire all unprotected snapshots with a timestamp older than the given value.
1133+
1134+
The filter is evaluated at commit time so that snapshots left without a ref
1135+
by remove_expired_refs() are also considered, regardless of call order.
11161136
11171137
Args:
1118-
dt (datetime): Only snapshots with datetime < this value will be expired.
1138+
dt (datetime): Only snapshots with timestamp < this value will be expired.
11191139
11201140
Returns:
11211141
This for method chaining.
11221142
"""
1123-
protected_ids = self._get_protected_snapshot_ids()
1124-
expire_from = datetime_to_millis(dt)
1125-
for snapshot in self._transaction.table_metadata.snapshots:
1126-
if snapshot.timestamp_ms < expire_from and snapshot.snapshot_id not in protected_ids:
1127-
self._snapshot_ids_to_expire.add(snapshot.snapshot_id)
1143+
self._expire_older_than_ms = datetime_to_millis(dt)
1144+
return self
1145+
1146+
def remove_expired_refs(self, default_max_ref_age_ms: int) -> ExpireSnapshots:
1147+
"""
1148+
Mark stale branches and tags for removal.
1149+
1150+
A ref is expired when the age of its snapshot exceeds its own max_ref_age_ms.
1151+
If a ref has no per-ref max_ref_age_ms set, default_max_ref_age_ms is used as fallback.
1152+
The main branch is never removed.
1153+
1154+
Snapshots left without any live ref after this call are no longer protected,
1155+
so a subsequent older_than() will include them in age-based expiry.
1156+
1157+
Args:
1158+
default_max_ref_age_ms: Fallback max age in milliseconds for refs that have no
1159+
per-ref max_ref_age_ms configured.
1160+
1161+
Returns:
1162+
This for method chaining.
1163+
"""
1164+
now_ms = int(datetime.now().timestamp() * 1000)
1165+
for name, ref in self._transaction.table_metadata.refs.items():
1166+
if name == MAIN_BRANCH:
1167+
continue
1168+
effective_max_ref_age_ms = ref.max_ref_age_ms if ref.max_ref_age_ms is not None else default_max_ref_age_ms
1169+
snapshot = self._transaction.table_metadata.snapshot_by_id(ref.snapshot_id)
1170+
if snapshot is None or (now_ms - snapshot.timestamp_ms) > effective_max_ref_age_ms:
1171+
self._ref_names_to_expire.add(name)
11281172
return self

tests/table/test_expire_snapshots.py

Lines changed: 126 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import pytest
2323

2424
from pyiceberg.table import CommitTableResponse, Table
25-
from pyiceberg.table.update import RemoveSnapshotsUpdate, update_table_metadata
25+
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
26+
from pyiceberg.table.update import RemoveSnapshotRefUpdate, RemoveSnapshotsUpdate, update_table_metadata
2627
from pyiceberg.table.update.snapshot import ExpireSnapshots
2728

2829

@@ -92,8 +93,8 @@ def test_expire_unprotected_snapshot(table_v2: Table) -> None:
9293
table_v2.metadata = table_v2.metadata.model_copy(
9394
update={
9495
"refs": {
95-
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
96-
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
96+
"main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
97+
"tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}),
9798
}
9899
}
99100
)
@@ -134,8 +135,8 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
134135
table_v2.metadata = table_v2.metadata.model_copy(
135136
update={
136137
"refs": {
137-
"main": MagicMock(snapshot_id=HEAD_SNAPSHOT, snapshot_ref_type="branch"),
138-
"mytag": MagicMock(snapshot_id=TAGGED_SNAPSHOT, snapshot_ref_type="tag"),
138+
"main": SnapshotRef(**{"snapshot-id": HEAD_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
139+
"mytag": SnapshotRef(**{"snapshot-id": TAGGED_SNAPSHOT, "type": SnapshotRefType.TAG}),
139140
},
140141
"snapshots": [
141142
SimpleNamespace(snapshot_id=HEAD_SNAPSHOT, timestamp_ms=1, parent_snapshot_id=None),
@@ -165,13 +166,8 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
165166
assert HEAD_SNAPSHOT in remaining_ids
166167
assert TAGGED_SNAPSHOT in remaining_ids
167168

168-
# No snapshots should have been expired (commit_table called, but with empty snapshot_ids)
169-
args, kwargs = table_v2.catalog.commit_table.call_args
170-
updates = args[2] if len(args) > 2 else ()
171-
# Find RemoveSnapshotsUpdate in updates
172-
remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None)
173-
assert remove_update is not None
174-
assert remove_update.snapshot_ids == []
169+
# No snapshots expired and no refs expired — commit_table should not be called at all
170+
table_v2.catalog.commit_table.assert_not_called()
175171

176172

177173
def test_expire_snapshots_by_ids(table_v2: Table) -> None:
@@ -188,24 +184,14 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None:
188184
table_v2.catalog = MagicMock()
189185
table_v2.catalog.commit_table.return_value = mock_response
190186

191-
# Remove any refs that protect the snapshots to be expired
192-
table_v2.metadata = table_v2.metadata.model_copy(
193-
update={
194-
"refs": {
195-
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
196-
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
197-
}
198-
}
199-
)
200-
201187
# Add snapshots to metadata for multi-id test
202188
from types import SimpleNamespace
203189

204190
table_v2.metadata = table_v2.metadata.model_copy(
205191
update={
206192
"refs": {
207-
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
208-
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
193+
"main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
194+
"tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}),
209195
},
210196
"snapshots": [
211197
SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_1, timestamp_ms=1, parent_snapshot_id=None),
@@ -316,3 +302,119 @@ def test_update_remove_snapshots_with_statistics(table_v2_with_statistics: Table
316302
assert not any(stat.snapshot_id == REMOVE_SNAPSHOT for stat in new_metadata.statistics), (
317303
"Statistics for removed snapshot should be gone"
318304
)
305+
306+
307+
OLD_SNAPSHOT = 3051729675574597004
308+
CURRENT_SNAPSHOT = 3055729675574597004
309+
310+
311+
def _make_commit_response(table: Table) -> CommitTableResponse:
312+
return CommitTableResponse(
313+
metadata=table.metadata,
314+
metadata_location="mock://metadata/location",
315+
uuid=uuid4(),
316+
)
317+
318+
319+
def test_ref_expiration_removes_old_tag_and_snapshot(table_v2: Table) -> None:
320+
"""A tag whose snapshot age exceeds max_ref_age_ms is removed; its orphaned snapshot
321+
is also expired when older_than() is combined."""
322+
table_v2.catalog = MagicMock()
323+
table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
324+
325+
# "test" tag (fixture) points to OLD_SNAPSHOT with max-ref-age-ms=10000000 (~2.7 h).
326+
# OLD_SNAPSHOT timestamp is from 2018 — definitely older than 2.7 h.
327+
assert "test" in table_v2.metadata.refs
328+
assert table_v2.metadata.refs["test"].snapshot_id == OLD_SNAPSHOT
329+
330+
future = datetime.now() + timedelta(days=1)
331+
table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).older_than(future).commit()
332+
333+
args, _ = table_v2.catalog.commit_table.call_args
334+
updates = args[2]
335+
336+
ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)]
337+
snap_updates = [u for u in updates if isinstance(u, RemoveSnapshotsUpdate)]
338+
339+
assert any(u.ref_name == "test" for u in ref_updates), "Expected 'test' tag to be removed"
340+
assert any(OLD_SNAPSHOT in u.snapshot_ids for u in snap_updates), (
341+
"Expected OLD_SNAPSHOT to be removed since it is no longer referenced"
342+
)
343+
344+
345+
def test_ref_expiration_removes_old_branch(table_v2: Table) -> None:
346+
"""A non-main branch whose snapshot age exceeds max_ref_age_ms is removed."""
347+
table_v2.catalog = MagicMock()
348+
table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
349+
350+
table_v2.metadata = table_v2.metadata.model_copy(
351+
update={
352+
"refs": {
353+
"main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
354+
"stale-branch": SnapshotRef(
355+
**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1}
356+
),
357+
}
358+
}
359+
)
360+
361+
table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()
362+
363+
args, _ = table_v2.catalog.commit_table.call_args
364+
updates = args[2]
365+
ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)]
366+
assert any(u.ref_name == "stale-branch" for u in ref_updates)
367+
assert not any(u.ref_name == "main" for u in ref_updates)
368+
369+
370+
def test_main_branch_never_expires(table_v2: Table) -> None:
371+
"""main branch is never removed regardless of age or max_ref_age_ms."""
372+
table_v2.catalog = MagicMock()
373+
374+
table_v2.metadata = table_v2.metadata.model_copy(
375+
update={
376+
"refs": {
377+
"main": SnapshotRef(
378+
**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1}
379+
),
380+
}
381+
}
382+
)
383+
384+
table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()
385+
386+
table_v2.catalog.commit_table.assert_not_called()
387+
388+
389+
390+
def test_young_ref_is_retained(table_v2: Table) -> None:
391+
"""A ref whose snapshot is within max_ref_age_ms is not removed."""
392+
table_v2.catalog = MagicMock()
393+
table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
394+
395+
# fresh-tag has a huge max_ref_age_ms — it should never expire
396+
# stale-tag has max_ref_age_ms=1 — it will be expired (triggers a commit)
397+
table_v2.metadata = table_v2.metadata.model_copy(
398+
update={
399+
"refs": {
400+
"main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
401+
"fresh-tag": SnapshotRef(
402+
**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 9999999999999}
403+
),
404+
"stale-tag": SnapshotRef(
405+
**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 1}
406+
),
407+
}
408+
}
409+
)
410+
411+
table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()
412+
413+
table_v2.catalog.commit_table.assert_called_once()
414+
args, _ = table_v2.catalog.commit_table.call_args
415+
updates = args[2]
416+
ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)]
417+
assert any(u.ref_name == "stale-tag" for u in ref_updates), "stale-tag should be expired"
418+
assert not any(u.ref_name == "fresh-tag" for u in ref_updates), "fresh-tag must not be expired"
419+
420+

0 commit comments

Comments
 (0)