Skip to content

Commit 855e22a

Browse files
committed
chore: update snapshot.py for code cleanup and organization
1 parent 2c6eb0b commit 855e22a

1 file changed

Lines changed: 37 additions & 55 deletions

File tree

pyiceberg/table/update/snapshot.py

Lines changed: 37 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ def _get_protected_snapshot_ids(self) -> Set[int]:
953953
if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]
954954
}
955955

956-
def by_id(self, snapshot_id: int) -> "ExpireSnapshots":
956+
def by_id(self, snapshot_id: int) -> ExpireSnapshots:
957957
"""
958958
Expire a snapshot by its ID.
959959
@@ -1008,7 +1008,7 @@ def older_than(self, dt: datetime) -> "ExpireSnapshots":
10081008

10091009
def older_than_with_retention(
10101010
self, timestamp_ms: int, retain_last_n: Optional[int] = None, min_snapshots_to_keep: Optional[int] = None
1011-
) -> "ExpireSnapshots":
1011+
) -> ExpireSnapshots:
10121012
"""Expire all unprotected snapshots with a timestamp older than a given value, with retention strategies.
10131013
10141014
Args:
@@ -1027,7 +1027,7 @@ def older_than_with_retention(
10271027

10281028
def with_retention_policy(
10291029
self, timestamp_ms: Optional[int] = None, retain_last_n: Optional[int] = None, min_snapshots_to_keep: Optional[int] = None
1030-
) -> "ExpireSnapshots":
1030+
) -> ExpireSnapshots:
10311031
"""Comprehensive snapshot expiration with multiple retention strategies.
10321032
10331033
This method provides a unified interface for snapshot expiration with various
@@ -1091,7 +1091,7 @@ def with_retention_policy(
10911091
self._snapshot_ids_to_expire.update(snapshots_to_expire)
10921092
return self
10931093

1094-
def retain_last_n(self, n: int) -> "ExpireSnapshots":
1094+
def retain_last_n(self, n: int) -> ExpireSnapshots:
10951095
"""Keep only the last N snapshots, expiring all others.
10961096
10971097
Args:
@@ -1106,28 +1106,31 @@ def retain_last_n(self, n: int) -> "ExpireSnapshots":
11061106
if n < 1:
11071107
raise ValueError("Number of snapshots to retain must be at least 1")
11081108

1109-
protected_ids = self._get_protected_snapshot_ids()
1109+
snapshots_to_keep = self._get_snapshots_to_keep(retain_last_n=n)
1110+
snapshots_to_expire = [
1111+
id for snapshot in self._transaction.table_metadata.snapshots if (id := snapshot.snapshot_id) not in snapshots_to_keep
1112+
]
11101113

1111-
# Sort snapshots by timestamp (most recent first)
1112-
sorted_snapshots = sorted(self._transaction.table_metadata.snapshots, key=lambda s: s.timestamp_ms, reverse=True)
1114+
self._snapshot_ids_to_expire.update(snapshots_to_expire)
1115+
return self
11131116

1114-
# Keep the last N snapshots and all protected ones
1115-
snapshots_to_keep = set()
1116-
snapshots_to_keep.update(protected_ids)
1117+
def _get_snapshots_to_keep(self, retain_last_n: Optional[int] = None) -> Set[int]:
1118+
"""Get set of snapshot IDs that should be kept based on protection and retention rules.
11171119
1118-
# Add the N most recent snapshots
1119-
for i, snapshot in enumerate(sorted_snapshots):
1120-
if i < n:
1121-
snapshots_to_keep.add(snapshot.snapshot_id)
1120+
Args:
1121+
retain_last_n: Number of most recent snapshots to keep.
11221122
1123-
# Find snapshots to expire
1124-
snapshots_to_expire = []
1125-
for snapshot in self._transaction.table_metadata.snapshots:
1126-
if snapshot.snapshot_id not in snapshots_to_keep:
1127-
snapshots_to_expire.append(snapshot.snapshot_id)
1123+
Returns:
1124+
Set of snapshot IDs to keep.
1125+
"""
1126+
snapshots_to_keep = self._get_protected_snapshot_ids()
11281127

1129-
self._snapshot_ids_to_expire.update(snapshots_to_expire)
1130-
return self
1128+
if retain_last_n is not None:
1129+
# Sort snapshots by timestamp (most recent first), and get most recent N
1130+
sorted_snapshots = sorted(self._transaction.table_metadata.snapshots, key=lambda s: s.timestamp_ms, reverse=True)
1131+
snapshots_to_keep.update(snapshot.snapshot_id for snapshot in sorted_snapshots[:retain_last_n])
1132+
1133+
return snapshots_to_keep
11311134

11321135
def _get_snapshots_to_expire_with_retention(
11331136
self, timestamp_ms: Optional[int] = None, retain_last_n: Optional[int] = None, min_snapshots_to_keep: Optional[int] = None
@@ -1142,22 +1145,10 @@ def _get_snapshots_to_expire_with_retention(
11421145
Returns:
11431146
List of snapshot IDs to expire.
11441147
"""
1145-
protected_ids = self._get_protected_snapshot_ids()
1146-
1147-
# Sort snapshots by timestamp (most recent first)
1148-
sorted_snapshots = sorted(self._transaction.table_metadata.snapshots, key=lambda s: s.timestamp_ms, reverse=True)
1149-
1150-
# Start with all snapshots that could be expired
1151-
candidates_for_expiration = []
1152-
snapshots_to_keep = set(protected_ids)
1153-
1154-
# Apply retain_last_n constraint
1155-
if retain_last_n is not None:
1156-
for i, snapshot in enumerate(sorted_snapshots):
1157-
if i < retain_last_n:
1158-
snapshots_to_keep.add(snapshot.snapshot_id)
1148+
snapshots_to_keep = self._get_snapshots_to_keep(retain_last_n=retain_last_n)
11591149

11601150
# Apply timestamp constraint
1151+
candidates_for_expiration = []
11611152
for snapshot in self._transaction.table_metadata.snapshots:
11621153
if snapshot.snapshot_id not in snapshots_to_keep and (timestamp_ms is None or snapshot.timestamp_ms < timestamp_ms):
11631154
candidates_for_expiration.append(snapshot)
@@ -1166,18 +1157,12 @@ def _get_snapshots_to_expire_with_retention(
11661157
candidates_for_expiration.sort(key=lambda s: s.timestamp_ms)
11671158

11681159
# Apply min_snapshots_to_keep constraint
1169-
total_snapshots = len(self._transaction.table_metadata.snapshots)
1170-
snapshots_to_expire: List[int] = []
1171-
1172-
for candidate in candidates_for_expiration:
1173-
# Check if expiring this snapshot would violate min_snapshots_to_keep
1174-
remaining_after_expiration = total_snapshots - len(snapshots_to_expire) - 1
1175-
1176-
if min_snapshots_to_keep is None or remaining_after_expiration >= min_snapshots_to_keep:
1177-
snapshots_to_expire.append(candidate.snapshot_id)
1178-
else:
1179-
# Stop expiring to maintain minimum count
1180-
break
1160+
if min_snapshots_to_keep is not None:
1161+
total_snapshots = len(self._transaction.table_metadata.snapshots)
1162+
max_to_expire = total_snapshots - min_snapshots_to_keep
1163+
snapshots_to_expire = [candidate.snapshot_id for candidate in candidates_for_expiration[:max_to_expire]]
1164+
else:
1165+
snapshots_to_expire = [candidate.snapshot_id for candidate in candidates_for_expiration]
11811166

11821167
return snapshots_to_expire
11831168

@@ -1187,15 +1172,12 @@ def _get_expiration_properties(self) -> tuple[Optional[int], Optional[int], Opti
11871172
Returns:
11881173
Tuple of (max_snapshot_age_ms, min_snapshots_to_keep, max_ref_age_ms)
11891174
"""
1190-
properties = self._transaction.table_metadata.properties
1191-
1192-
max_snapshot_age_ms = properties.get("history.expire.max-snapshot-age-ms")
1193-
max_snapshot_age = int(max_snapshot_age_ms) if max_snapshot_age_ms is not None else None
1175+
from pyiceberg.table import TableProperties
11941176

1195-
min_snapshots = properties.get("history.expire.min-snapshots-to-keep")
1196-
min_snapshots_to_keep = int(min_snapshots) if min_snapshots is not None else None
1177+
properties = self._transaction.table_metadata.properties
11971178

1198-
max_ref_age = properties.get("history.expire.max-ref-age-ms")
1199-
max_ref_age_ms = int(max_ref_age) if max_ref_age is not None else None
1179+
max_snapshot_age = property_as_int(properties, TableProperties.MAX_SNAPSHOT_AGE_MS)
1180+
min_snapshots_to_keep = property_as_int(properties, TableProperties.MIN_SNAPSHOTS_TO_KEEP)
1181+
max_ref_age_ms = property_as_int(properties, "history.expire.max-ref-age-ms")
12001182

12011183
return max_snapshot_age, min_snapshots_to_keep, max_ref_age_ms

0 commit comments

Comments
 (0)