Skip to content

Commit 8373d34

Browse files
Add remove_expired_refs to ExpireSnapshots
1 parent 721c5aa commit 8373d34

File tree

3 files changed

+250
-42
lines changed

3 files changed

+250
-42
lines changed

pyiceberg/table/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ class TableProperties:
205205
MIN_SNAPSHOTS_TO_KEEP = "history.expire.min-snapshots-to-keep"
206206
MIN_SNAPSHOTS_TO_KEEP_DEFAULT = 1
207207

208+
MAX_REF_AGE_MS = "history.expire.max-ref-age-ms"
209+
208210

209211
class Transaction:
210212
_table: Table

pyiceberg/table/update/snapshot.py

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,53 +1025,70 @@ def _current_ancestors(self) -> set[int]:
10251025

10261026

10271027
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
1028-
"""Expire snapshots by ID.
1028+
"""Expire snapshots and refs.
10291029
10301030
Use table.expire_snapshots().<operation>().commit() to run a specific operation.
10311031
Use table.expire_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
1032-
Pending changes are applied on commit.
1032+
Pending changes are applied on commit. Call order does not affect the result.
10331033
"""
10341034

10351035
_updates: tuple[TableUpdate, ...]
10361036
_requirements: tuple[TableRequirement, ...]
10371037
_snapshot_ids_to_expire: set[int]
1038+
_ref_names_to_expire: set[str]
1039+
_expire_older_than_ms: int | None
10381040

10391041
def __init__(self, transaction: Transaction) -> None:
    """
    Create an expiration builder with nothing staged yet.

    Args:
        transaction: Transaction passed on to the parent class initializer.
    """
    super().__init__(transaction)
    # Selections made through the builder methods: no age cutoff, no refs, no snapshot IDs.
    self._expire_older_than_ms = None
    self._ref_names_to_expire = set()
    self._snapshot_ids_to_expire = set()
    # Output handed back by _commit(); empty until then.
    self._requirements = ()
    self._updates = ()
10441048

10451049
def _commit(self) -> UpdatesAndRequirements:
    """
    Build the updates and requirements for every pending expiration.

    Combines the explicitly requested snapshot IDs, the age cutoff recorded by
    older_than(), and the refs marked by remove_expired_refs(). Snapshots that are
    still referenced by a ref which is not being removed are always excluded.

    Ref removals are emitted before the snapshot removal, and both are sorted so the
    produced update list is deterministic. The staged selections themselves are not
    modified, and an update that is already staged is not appended a second time,
    so calling this method again does not duplicate updates.

    Returns:
        Tuple of updates and requirements to be committed,
        as required by the calling parent apply functions.
    """
    protected_ids = self._get_protected_snapshot_ids()

    # Work on a copy so age-based matches are not folded permanently into the explicit ID set.
    candidate_ids = set(self._snapshot_ids_to_expire)
    cutoff_ms = self._expire_older_than_ms
    if cutoff_ms is not None:
        candidate_ids.update(
            snapshot.snapshot_id
            for snapshot in self._transaction.table_metadata.snapshots
            if snapshot.timestamp_ms < cutoff_ms
        )
    # Protection applies to explicit and age-based selections alike.
    candidate_ids -= protected_ids

    pending: list[TableUpdate] = [
        RemoveSnapshotRefUpdate(ref_name=ref_name) for ref_name in sorted(self._ref_names_to_expire)
    ]
    if candidate_ids:
        pending.append(RemoveSnapshotsUpdate(snapshot_ids=sorted(candidate_ids)))

    updates: list[TableUpdate] = list(self._updates)
    for update in pending:
        # Skip anything a previous _commit() call already staged.
        if update not in updates:
            updates.append(update)
    self._updates = tuple(updates)
    return self._updates, self._requirements
10611077

10621078
def _get_protected_snapshot_ids(self) -> set[int]:
10631079
"""
1064-
Get the IDs of protected snapshots.
1080+
Get the IDs of snapshots that must not be expired.
10651081
1066-
These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration.
1082+
These are the HEAD snapshots of all branches and tags that are not
1083+
already marked for removal via remove_expired_refs().
10671084
10681085
Returns:
10691086
Set of protected snapshot IDs to exclude from expiration.
10701087
"""
10711088
return {
10721089
ref.snapshot_id
1073-
for ref in self._transaction.table_metadata.refs.values()
1074-
if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]
1090+
for name, ref in self._transaction.table_metadata.refs.items()
1091+
if name not in self._ref_names_to_expire
10751092
}
10761093

10771094
def by_id(self, snapshot_id: int) -> ExpireSnapshots:
@@ -1112,17 +1129,44 @@ def by_ids(self, snapshot_ids: list[int]) -> ExpireSnapshots:
11121129

11131130
def older_than(self, dt: datetime) -> ExpireSnapshots:
    """
    Expire all unprotected snapshots with a timestamp older than the given value.

    Only the cutoff is recorded here; the matching snapshots are selected at commit
    time, so refs marked by remove_expired_refs() are taken into account regardless
    of call order.

    When called more than once, the largest cutoff is kept. A later call with an
    earlier datetime therefore cannot silently discard the snapshots selected by a
    previous call.

    Args:
        dt (datetime): Only snapshots with timestamp < this value will be expired.

    Returns:
        This for method chaining.
    """
    cutoff_ms = datetime_to_millis(dt)
    staged_ms = self._expire_older_than_ms
    # Keep the widest window so repeated calls combine instead of overwriting each other.
    if staged_ms is None or cutoff_ms > staged_ms:
        self._expire_older_than_ms = cutoff_ms
    return self
1145+
1146+
def remove_expired_refs(self, default_max_ref_age_ms: int) -> ExpireSnapshots:
    """
    Mark stale branches and tags for removal.

    Every ref except the main branch is checked. A ref is marked when the snapshot it
    points to cannot be found in the table metadata, or when that snapshot's age
    (current time minus the snapshot timestamp) is greater than the ref's own
    max_ref_age_ms. Refs without a per-ref max_ref_age_ms use default_max_ref_age_ms.

    Snapshots referenced only by marked refs are no longer protected, so they become
    eligible for expiry when the pending changes are committed.

    Args:
        default_max_ref_age_ms: Fallback max age in milliseconds for refs that have no
            per-ref max_ref_age_ms configured.

    Returns:
        This for method chaining.
    """
    now_ms = int(datetime.now().timestamp() * 1000)
    metadata = self._transaction.table_metadata
    for ref_name, snapshot_ref in metadata.refs.items():
        if ref_name != MAIN_BRANCH:
            max_age_ms = snapshot_ref.max_ref_age_ms
            if max_age_ms is None:
                max_age_ms = default_max_ref_age_ms
            target = metadata.snapshot_by_id(snapshot_ref.snapshot_id)
            # A ref whose snapshot is missing from the metadata is treated as expired too.
            is_expired = True if target is None else (now_ms - target.timestamp_ms) > max_age_ms
            if is_expired:
                self._ref_names_to_expire.add(ref_name)
    return self

tests/table/test_expire_snapshots.py

Lines changed: 186 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import pytest
2323

2424
from pyiceberg.table import CommitTableResponse, Table
25-
from pyiceberg.table.update import RemoveSnapshotsUpdate, update_table_metadata
25+
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
26+
from pyiceberg.table.update import RemoveSnapshotRefUpdate, RemoveSnapshotsUpdate, update_table_metadata
2627
from pyiceberg.table.update.snapshot import ExpireSnapshots
2728

2829

@@ -92,8 +93,8 @@ def test_expire_unprotected_snapshot(table_v2: Table) -> None:
9293
table_v2.metadata = table_v2.metadata.model_copy(
9394
update={
9495
"refs": {
95-
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
96-
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
96+
"main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
97+
"tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}),
9798
}
9899
}
99100
)
@@ -134,8 +135,8 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
134135
table_v2.metadata = table_v2.metadata.model_copy(
135136
update={
136137
"refs": {
137-
"main": MagicMock(snapshot_id=HEAD_SNAPSHOT, snapshot_ref_type="branch"),
138-
"mytag": MagicMock(snapshot_id=TAGGED_SNAPSHOT, snapshot_ref_type="tag"),
138+
"main": SnapshotRef(**{"snapshot-id": HEAD_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
139+
"mytag": SnapshotRef(**{"snapshot-id": TAGGED_SNAPSHOT, "type": SnapshotRefType.TAG}),
139140
},
140141
"snapshots": [
141142
SimpleNamespace(snapshot_id=HEAD_SNAPSHOT, timestamp_ms=1, parent_snapshot_id=None),
@@ -165,13 +166,8 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
165166
assert HEAD_SNAPSHOT in remaining_ids
166167
assert TAGGED_SNAPSHOT in remaining_ids
167168

168-
# No snapshots should have been expired (commit_table called, but with empty snapshot_ids)
169-
args, kwargs = table_v2.catalog.commit_table.call_args
170-
updates = args[2] if len(args) > 2 else ()
171-
# Find RemoveSnapshotsUpdate in updates
172-
remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None)
173-
assert remove_update is not None
174-
assert remove_update.snapshot_ids == []
169+
# No snapshots expired and no refs expired — commit_table should not be called at all
170+
table_v2.catalog.commit_table.assert_not_called()
175171

176172

177173
def test_expire_snapshots_by_ids(table_v2: Table) -> None:
@@ -188,24 +184,14 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None:
188184
table_v2.catalog = MagicMock()
189185
table_v2.catalog.commit_table.return_value = mock_response
190186

191-
# Remove any refs that protect the snapshots to be expired
192-
table_v2.metadata = table_v2.metadata.model_copy(
193-
update={
194-
"refs": {
195-
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
196-
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
197-
}
198-
}
199-
)
200-
201187
# Add snapshots to metadata for multi-id test
202188
from types import SimpleNamespace
203189

204190
table_v2.metadata = table_v2.metadata.model_copy(
205191
update={
206192
"refs": {
207-
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
208-
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
193+
"main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
194+
"tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}),
209195
},
210196
"snapshots": [
211197
SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_1, timestamp_ms=1, parent_snapshot_id=None),
@@ -316,3 +302,179 @@ def test_update_remove_snapshots_with_statistics(table_v2_with_statistics: Table
316302
assert not any(stat.snapshot_id == REMOVE_SNAPSHOT for stat in new_metadata.statistics), (
317303
"Statistics for removed snapshot should be gone"
318304
)
305+
306+
307+
# --- Ref expiration tests ---
308+
# The table_v2 fixture has two snapshots:
309+
# 3051729675574597004 (timestamp_ms=1515100955770, ~Jan 2018)
310+
# 3055729675574597004 (timestamp_ms=1555100955770, ~Apr 2019, current/main)
311+
# And a "test" tag pointing to 3051729675574597004 with max-ref-age-ms=10000000 (~2.7 h).
312+
313+
OLD_SNAPSHOT = 3051729675574597004
314+
CURRENT_SNAPSHOT = 3055729675574597004
315+
316+
317+
def _make_commit_response(table: Table) -> CommitTableResponse:
    """Build a CommitTableResponse that carries the table's current metadata, for use with a mocked catalog."""
    response_fields = {
        "metadata": table.metadata,
        "metadata_location": "mock://metadata/location",
        "uuid": uuid4(),
    }
    return CommitTableResponse(**response_fields)
323+
324+
325+
def test_ref_expiration_removes_old_tag_and_snapshot(table_v2: Table) -> None:
    """An aged-out tag is removed, and combining with older_than() also expires the snapshot it pointed to."""
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)

    # The fixture's "test" tag points at OLD_SNAPSHOT (from 2018) with max-ref-age-ms=10000000,
    # so the tag is far past its allowed age.
    fixture_refs = table_v2.metadata.refs
    assert "test" in fixture_refs
    assert fixture_refs["test"].snapshot_id == OLD_SNAPSHOT

    cutoff = datetime.now() + timedelta(days=1)
    expire = table_v2.maintenance.expire_snapshots()
    expire.remove_expired_refs(default_max_ref_age_ms=1).older_than(cutoff).commit()

    positional, _ = table_v2.catalog.commit_table.call_args
    committed_updates = positional[2]

    removed_ref_names = {u.ref_name for u in committed_updates if isinstance(u, RemoveSnapshotRefUpdate)}
    removed_snapshot_ids = {
        snapshot_id
        for u in committed_updates
        if isinstance(u, RemoveSnapshotsUpdate)
        for snapshot_id in u.snapshot_ids
    }

    assert "test" in removed_ref_names, "Expected 'test' tag to be removed"
    assert OLD_SNAPSHOT in removed_snapshot_ids, (
        "Expected OLD_SNAPSHOT to be removed since it is no longer referenced"
    )
349+
350+
351+
def test_ref_expiration_removes_old_branch(table_v2: Table) -> None:
    """A branch other than main is removed once its snapshot is older than max_ref_age_ms."""
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)

    main_ref = SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH})
    stale_branch_ref = SnapshotRef(
        **{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1}
    )
    table_v2.metadata = table_v2.metadata.model_copy(
        update={"refs": {"main": main_ref, "stale-branch": stale_branch_ref}}
    )

    table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()

    positional, _ = table_v2.catalog.commit_table.call_args
    removed_ref_names = {u.ref_name for u in positional[2] if isinstance(u, RemoveSnapshotRefUpdate)}
    assert "stale-branch" in removed_ref_names
    assert "main" not in removed_ref_names
374+
375+
376+
def test_main_branch_never_expires(table_v2: Table) -> None:
    """The main branch survives ref expiration even with a 1 ms max_ref_age_ms, so nothing is committed."""
    table_v2.catalog = MagicMock()

    only_main = {
        "main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1}),
    }
    table_v2.metadata = table_v2.metadata.model_copy(update={"refs": only_main})

    expire = table_v2.maintenance.expire_snapshots()
    expire.remove_expired_refs(default_max_ref_age_ms=1).commit()

    table_v2.catalog.commit_table.assert_not_called()
393+
394+
395+
def test_remove_expired_refs_default_max_ref_age_ms(table_v2: Table) -> None:
    """The default_max_ref_age_ms argument decides expiry for a ref that has no max_ref_age_ms of its own."""
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)

    main_ref = SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH})
    # Deliberately created without a per-ref max_ref_age_ms.
    old_tag_ref = SnapshotRef(**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG})
    table_v2.metadata = table_v2.metadata.model_copy(update={"refs": {"main": main_ref, "old-tag": old_tag_ref}})

    # With a 1 ms fallback, old-tag's 2018 snapshot is well past the limit.
    table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()

    positional, _ = table_v2.catalog.commit_table.call_args
    removed_ref_names = {u.ref_name for u in positional[2] if isinstance(u, RemoveSnapshotRefUpdate)}
    assert "old-tag" in removed_ref_names
    assert "main" not in removed_ref_names
418+
419+
420+
def test_young_ref_is_retained(table_v2: Table) -> None:
    """A ref still inside its max_ref_age_ms window is kept while a stale sibling is removed."""
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)

    main_ref = SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH})
    # Enormous max_ref_age_ms: this tag must survive.
    fresh_tag_ref = SnapshotRef(
        **{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 9999999999999}
    )
    # max_ref_age_ms of 1: this tag expires, which is what triggers the commit.
    stale_tag_ref = SnapshotRef(**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 1})
    table_v2.metadata = table_v2.metadata.model_copy(
        update={"refs": {"main": main_ref, "fresh-tag": fresh_tag_ref, "stale-tag": stale_tag_ref}}
    )

    table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()

    table_v2.catalog.commit_table.assert_called_once()
    positional, _ = table_v2.catalog.commit_table.call_args
    removed_ref_names = {u.ref_name for u in positional[2] if isinstance(u, RemoveSnapshotRefUpdate)}
    assert "stale-tag" in removed_ref_names, "stale-tag should be expired"
    assert "fresh-tag" not in removed_ref_names, "fresh-tag must not be expired"
449+
450+
451+
def test_remove_expired_refs_and_older_than_order_independent(table_v2: Table) -> None:
    """remove_expired_refs().older_than() and older_than().remove_expired_refs() produce the same result.

    Besides comparing the two call orders with each other, the result is checked against
    the expected snapshot so the test cannot pass when both orders expire nothing.
    """
    table_v2.metadata = table_v2.metadata.model_copy(
        update={
            "refs": {
                "main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
                "stale-tag": SnapshotRef(
                    **{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 1}
                ),
            }
        }
    )

    # Build the mocked response only after the refs are overridden: if committing replaces
    # the table metadata with the response metadata, the second run still starts from the
    # same refs as the first one.
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)

    future = datetime.now() + timedelta(days=1)

    def _get_expired_ids(order: str) -> set[int]:
        """Run one commit with the given call order and return the snapshot IDs it removes."""
        table_v2.catalog.reset_mock()
        expire = table_v2.maintenance.expire_snapshots()
        if order == "refs_first":
            expire.remove_expired_refs(default_max_ref_age_ms=1).older_than(future).commit()
        else:
            expire.older_than(future).remove_expired_refs(default_max_ref_age_ms=1).commit()
        args, _ = table_v2.catalog.commit_table.call_args
        snap_updates = [u for u in args[2] if isinstance(u, RemoveSnapshotsUpdate)]
        return {sid for u in snap_updates for sid in u.snapshot_ids}

    refs_first = _get_expired_ids("refs_first")
    snapshots_first = _get_expired_ids("snapshots_first")

    assert refs_first == snapshots_first
    # Equality alone would also hold for two empty sets; pin the expected outcome.
    assert OLD_SNAPSHOT in refs_first, "snapshot orphaned by the removed tag should be expired"
    assert CURRENT_SNAPSHOT not in refs_first, "snapshot referenced by main must stay protected"

0 commit comments

Comments
 (0)