Skip to content

Commit e358b4f

Browse files
Refactor: add remove_expired_refs method
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 15f2635 commit e358b4f

2 files changed

Lines changed: 106 additions & 95 deletions

File tree

pyiceberg/table/update/snapshot.py

Lines changed: 57 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,110 +1025,71 @@ def _current_ancestors(self) -> set[int]:
10251025

10261026

10271027
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
1028-
"""Expire snapshots by ID.
1028+
"""Expire snapshots and refs.
10291029
10301030
Use table.expire_snapshots().<operation>().commit() to run a specific operation.
10311031
Use table.expire_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
1032-
Pending changes are applied on commit.
1032+
Pending changes are applied on commit. Call order does not affect the result.
10331033
"""
10341034

10351035
_updates: tuple[TableUpdate, ...]
10361036
_requirements: tuple[TableRequirement, ...]
10371037
_snapshot_ids_to_expire: set[int]
1038+
_ref_names_to_expire: set[str]
1039+
_expire_older_than_ms: int | None
10381040

10391041
def __init__(self, transaction: Transaction) -> None:
    """Create an expiration builder with nothing staged yet."""
    super().__init__(transaction)
    # Expiration selections; they are turned into table updates by _commit().
    self._expire_older_than_ms = None
    self._ref_names_to_expire = set()
    self._snapshot_ids_to_expire = set()
    # Updates and requirements handed back to the transaction on commit.
    self._requirements = ()
    self._updates = ()
10441048

10451049
def _commit(self) -> UpdatesAndRequirements:
    """
    Build the table updates for every staged expiration.

    Combines the explicitly staged snapshot IDs, the age threshold recorded by
    older_than(), and the refs staged by remove_expired_refs(). Snapshots that are
    still the target of a ref that is not being removed are always excluded.

    The staged state on this instance is only read, never written, so calling this
    method again yields the same result instead of appending duplicate updates.

    Returns:
        Tuple of updates and requirements to be committed,
        as required by the calling parent apply functions.
    """
    protected_ids = self._get_protected_snapshot_ids()

    # Work on a copy so the staged selection is not polluted by the age-based matches.
    candidate_ids = set(self._snapshot_ids_to_expire)
    cutoff_ms = self._expire_older_than_ms
    if cutoff_ms is not None:
        candidate_ids.update(
            snapshot.snapshot_id
            for snapshot in self._transaction.table_metadata.snapshots
            if snapshot.timestamp_ms < cutoff_ms
        )
    # Protection is applied once, to explicit and age-based candidates alike.
    snapshot_ids_to_expire = candidate_ids - protected_ids

    updates: list[TableUpdate] = list(self._updates)
    # Sets have no stable iteration order; sort so the emitted updates are deterministic.
    updates.extend(RemoveSnapshotRefUpdate(ref_name=ref_name) for ref_name in sorted(self._ref_names_to_expire))
    if snapshot_ids_to_expire:
        updates.append(RemoveSnapshotsUpdate(snapshot_ids=snapshot_ids_to_expire))
    return tuple(updates), self._requirements
10751078

1076-
def _compute_expired_refs(self, now_ms: int) -> set[str]:
1077-
"""
1078-
Compute the set of ref names (branches/tags) that should be expired.
1079-
1080-
A ref is expired when the age of its snapshot exceeds:
1081-
- the ref's own max_ref_age_ms, or
1082-
- the table property history.expire.max-ref-age-ms, if the ref has no per-ref setting.
1083-
The main branch is never expired. Refs with no effective max-ref-age configuration are skipped.
1084-
1085-
Args:
1086-
now_ms: Current time in milliseconds.
1087-
1088-
Returns:
1089-
Set of ref names to remove.
1090-
"""
1091-
from pyiceberg.table import TableProperties
1092-
1093-
props = self._transaction.table_metadata.properties
1094-
table_max_ref_age_ms: int | None = (
1095-
int(props[TableProperties.MAX_REF_AGE_MS]) if TableProperties.MAX_REF_AGE_MS in props else None
1096-
)
1097-
1098-
expired: set[str] = set()
1099-
for name, ref in self._transaction.table_metadata.refs.items():
1100-
if name == MAIN_BRANCH:
1101-
continue
1102-
effective_max_ref_age_ms = ref.max_ref_age_ms if ref.max_ref_age_ms is not None else table_max_ref_age_ms
1103-
if effective_max_ref_age_ms is None:
1104-
continue
1105-
snapshot = self._transaction.table_metadata.snapshot_by_id(ref.snapshot_id)
1106-
if snapshot is None:
1107-
expired.add(name)
1108-
continue
1109-
if (now_ms - snapshot.timestamp_ms) > effective_max_ref_age_ms:
1110-
expired.add(name)
1111-
return expired
1112-
1113-
def _get_protected_snapshot_ids(self) -> set[int]:
    """
    Collect the snapshot IDs that are still referenced and therefore must be kept.

    Every branch or tag in the table metadata protects the snapshot it points to,
    except refs that have been staged for removal via remove_expired_refs().

    Returns:
        Set of protected snapshot IDs to exclude from expiration.
    """
    doomed_refs = self._ref_names_to_expire
    protected: set[int] = set()
    for ref_name, snapshot_ref in self._transaction.table_metadata.refs.items():
        if ref_name in doomed_refs:
            continue
        protected.add(snapshot_ref.snapshot_id)
    return protected
11331094

11341095
def by_id(self, snapshot_id: int) -> ExpireSnapshots:
@@ -1169,17 +1130,46 @@ def by_ids(self, snapshot_ids: list[int]) -> ExpireSnapshots:
11691130

11701131
def older_than(self, dt: datetime) -> ExpireSnapshots:
    """
    Expire all unprotected snapshots with a timestamp older than the given value.

    Only the threshold is recorded here; the matching snapshots are selected in
    _commit(), so refs staged by remove_expired_refs() are taken into account
    regardless of which method was called first.

    When called more than once the largest threshold is kept. That equals the union
    of what each call asked for, so the order of the calls does not change the result.

    Args:
        dt (datetime): Only snapshots with timestamp < this value will be expired.

    Returns:
        This for method chaining.
    """
    expire_from = datetime_to_millis(dt)
    current = self._expire_older_than_ms
    # Keep the most inclusive cutoff instead of letting a later call silently narrow an earlier one.
    self._expire_older_than_ms = expire_from if current is None else max(current, expire_from)
    return self
1146+
1147+
def remove_expired_refs(self, default_max_ref_age_ms: int | None = None) -> ExpireSnapshots:
    """
    Stage stale branches and tags for removal.

    A ref counts as stale when the snapshot it points to is older than the ref's own
    max_ref_age_ms, or older than default_max_ref_age_ms when the ref has no setting of
    its own. A ref whose snapshot cannot be found in the table metadata is staged as
    well, provided an effective max age applies to it. The main branch is never staged.

    Refs staged here stop protecting their snapshots, so an age threshold given to
    older_than() can expire those snapshots at commit time, whichever method is called first.

    Args:
        default_max_ref_age_ms: Fallback max age in milliseconds for refs that have no
            per-ref max_ref_age_ms configured. If None, such refs are skipped.

    Returns:
        This for method chaining.
    """
    now_ms = int(datetime.now().timestamp() * 1000)
    metadata = self._transaction.table_metadata

    for ref_name, snapshot_ref in metadata.refs.items():
        if ref_name == MAIN_BRANCH:
            continue

        # The per-ref setting wins; the caller-supplied default only fills the gap.
        max_age_ms = snapshot_ref.max_ref_age_ms
        if max_age_ms is None:
            max_age_ms = default_max_ref_age_ms
        if max_age_ms is None:
            continue

        target = metadata.snapshot_by_id(snapshot_ref.snapshot_id)
        if target is not None and now_ms - target.timestamp_ms <= max_age_ms:
            continue  # still within its allowed age
        self._ref_names_to_expire.add(ref_name)

    return self

tests/table/test_expire_snapshots.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,8 @@ def _make_commit_response(table: Table) -> CommitTableResponse:
323323

324324

325325
def test_ref_expiration_removes_old_tag_and_snapshot(table_v2: Table) -> None:
326-
"""A tag whose snapshot age exceeds max_ref_age_ms is removed along with its snapshot."""
326+
"""A tag whose snapshot age exceeds max_ref_age_ms is removed; its orphaned snapshot
327+
is also expired when older_than() is combined."""
327328
table_v2.catalog = MagicMock()
328329
table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
329330

@@ -332,7 +333,8 @@ def test_ref_expiration_removes_old_tag_and_snapshot(table_v2: Table) -> None:
332333
assert "test" in table_v2.metadata.refs
333334
assert table_v2.metadata.refs["test"].snapshot_id == OLD_SNAPSHOT
334335

335-
table_v2.maintenance.expire_snapshots().commit()
336+
future = datetime.now() + timedelta(days=1)
337+
table_v2.maintenance.expire_snapshots().remove_expired_refs().older_than(future).commit()
336338

337339
args, _ = table_v2.catalog.commit_table.call_args
338340
updates = args[2]
@@ -354,17 +356,15 @@ def test_ref_expiration_removes_old_branch(table_v2: Table) -> None:
354356
table_v2.metadata = table_v2.metadata.model_copy(
355357
update={
356358
"refs": {
357-
"main": SnapshotRef(
358-
**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}
359-
),
359+
"main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
360360
"stale-branch": SnapshotRef(
361361
**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1}
362362
),
363363
}
364364
}
365365
)
366366

367-
table_v2.maintenance.expire_snapshots().commit()
367+
table_v2.maintenance.expire_snapshots().remove_expired_refs().commit()
368368

369369
args, _ = table_v2.catalog.commit_table.call_args
370370
updates = args[2]
@@ -376,9 +376,7 @@ def test_ref_expiration_removes_old_branch(table_v2: Table) -> None:
376376
def test_main_branch_never_expires(table_v2: Table) -> None:
377377
"""main branch is never removed regardless of age or max_ref_age_ms."""
378378
table_v2.catalog = MagicMock()
379-
table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
380379

381-
# Only main exists, with max_ref_age_ms=1 (would expire anything else immediately)
382380
table_v2.metadata = table_v2.metadata.model_copy(
383381
update={
384382
"refs": {
@@ -389,41 +387,33 @@ def test_main_branch_never_expires(table_v2: Table) -> None:
389387
}
390388
)
391389

392-
table_v2.maintenance.expire_snapshots().commit()
390+
table_v2.maintenance.expire_snapshots().remove_expired_refs().commit()
393391

394-
# commit_table should NOT be called — main is never expired, so nothing changes
395392
table_v2.catalog.commit_table.assert_not_called()
396393

397394

398-
def test_remove_expired_refs_default_max_ref_age_ms(table_v2: Table) -> None:
    """A ref without per-ref max_ref_age_ms uses the method parameter as fallback."""
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)

    table_v2.metadata = table_v2.metadata.model_copy(
        update={
            "refs": {
                "main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
                # no per-ref max_ref_age_ms
                "old-tag": SnapshotRef(**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG}),
            }
        }
    )

    # 1 ms default — old-tag's snapshot (from 2018) will exceed it
    table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()

    # Fail with a clear assertion if nothing was committed, rather than a TypeError
    # from unpacking call_args when it is None.
    table_v2.catalog.commit_table.assert_called_once()
    args, _ = table_v2.catalog.commit_table.call_args
    removed_ref_names = {u.ref_name for u in args[2] if isinstance(u, RemoveSnapshotRefUpdate)}
    assert "old-tag" in removed_ref_names
    assert "main" not in removed_ref_names
428418

429419

@@ -448,12 +438,43 @@ def test_young_ref_is_retained(table_v2: Table) -> None:
448438
}
449439
)
450440

451-
table_v2.maintenance.expire_snapshots().commit()
441+
table_v2.maintenance.expire_snapshots().remove_expired_refs().commit()
452442

453-
# stale-tag causes commit_table to be called; fresh-tag must not be in the expired list
454443
table_v2.catalog.commit_table.assert_called_once()
455444
args, _ = table_v2.catalog.commit_table.call_args
456445
updates = args[2]
457446
ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)]
458447
assert any(u.ref_name == "stale-tag" for u in ref_updates), "stale-tag should be expired"
459448
assert not any(u.ref_name == "fresh-tag" for u in ref_updates), "fresh-tag must not be expired"
449+
450+
451+
def test_remove_expired_refs_and_older_than_order_independent(table_v2: Table) -> None:
    """remove_expired_refs().older_than() and older_than().remove_expired_refs() produce the same result."""
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)

    staged_metadata = table_v2.metadata.model_copy(
        update={
            "refs": {
                "main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
                "stale-tag": SnapshotRef(
                    **{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 1}
                ),
            }
        }
    )

    future = datetime.now() + timedelta(days=1)

    def _get_expired_ids(order: str) -> set[int]:
        # NOTE(review): commit() may replace table_v2.metadata with the mocked response;
        # reset it so both runs start from identical refs — confirm against _make_commit_response.
        table_v2.metadata = staged_metadata
        table_v2.catalog.reset_mock()
        expire = table_v2.maintenance.expire_snapshots()
        if order == "refs_first":
            expire.remove_expired_refs().older_than(future).commit()
        else:
            expire.older_than(future).remove_expired_refs().commit()
        # Make a missing commit an explicit failure instead of a TypeError on call_args=None.
        table_v2.catalog.commit_table.assert_called_once()
        args, _ = table_v2.catalog.commit_table.call_args
        return {sid for u in args[2] if isinstance(u, RemoveSnapshotsUpdate) for sid in u.snapshot_ids}

    refs_first = _get_expired_ids("refs_first")
    snapshots_first = _get_expired_ids("snapshots_first")

    assert refs_first == snapshots_first
    # Equality alone would also hold for two empty results; pin the expected content too.
    assert OLD_SNAPSHOT in refs_first, "snapshot orphaned by the removed tag should be expired"
    assert CURRENT_SNAPSHOT not in refs_first, "head of main must stay protected"

0 commit comments

Comments
 (0)