@@ -953,7 +953,7 @@ def _get_protected_snapshot_ids(self) -> Set[int]:
953953 if ref .snapshot_ref_type in [SnapshotRefType .TAG , SnapshotRefType .BRANCH ]
954954 }
955955
956- def by_id (self , snapshot_id : int ) -> " ExpireSnapshots" :
956+ def by_id (self , snapshot_id : int ) -> ExpireSnapshots :
957957 """
958958 Expire a snapshot by its ID.
959959
@@ -1008,7 +1008,7 @@ def older_than(self, dt: datetime) -> "ExpireSnapshots":
10081008
10091009 def older_than_with_retention (
10101010 self , timestamp_ms : int , retain_last_n : Optional [int ] = None , min_snapshots_to_keep : Optional [int ] = None
1011- ) -> " ExpireSnapshots" :
1011+ ) -> ExpireSnapshots :
10121012 """Expire all unprotected snapshots with a timestamp older than a given value, with retention strategies.
10131013
10141014 Args:
@@ -1027,7 +1027,7 @@ def older_than_with_retention(
10271027
10281028 def with_retention_policy (
10291029 self , timestamp_ms : Optional [int ] = None , retain_last_n : Optional [int ] = None , min_snapshots_to_keep : Optional [int ] = None
1030- ) -> " ExpireSnapshots" :
1030+ ) -> ExpireSnapshots :
10311031 """Comprehensive snapshot expiration with multiple retention strategies.
10321032
10331033 This method provides a unified interface for snapshot expiration with various
@@ -1091,7 +1091,7 @@ def with_retention_policy(
10911091 self ._snapshot_ids_to_expire .update (snapshots_to_expire )
10921092 return self
10931093
1094- def retain_last_n (self , n : int ) -> " ExpireSnapshots" :
1094+ def retain_last_n (self , n : int ) -> ExpireSnapshots :
10951095 """Keep only the last N snapshots, expiring all others.
10961096
10971097 Args:
@@ -1106,28 +1106,31 @@ def retain_last_n(self, n: int) -> "ExpireSnapshots":
11061106 if n < 1 :
11071107 raise ValueError ("Number of snapshots to retain must be at least 1" )
11081108
1109- protected_ids = self ._get_protected_snapshot_ids ()
1109+ snapshots_to_keep = self ._get_snapshots_to_keep (retain_last_n = n )
1110+ snapshots_to_expire = [
1111+ id for snapshot in self ._transaction .table_metadata .snapshots if (id := snapshot .snapshot_id ) not in snapshots_to_keep
1112+ ]
11101113
1111- # Sort snapshots by timestamp (most recent first )
1112- sorted_snapshots = sorted ( self . _transaction . table_metadata . snapshots , key = lambda s : s . timestamp_ms , reverse = True )
1114+ self . _snapshot_ids_to_expire . update ( snapshots_to_expire )
1115+ return self
11131116
1114- # Keep the last N snapshots and all protected ones
1115- snapshots_to_keep = set ()
1116- snapshots_to_keep .update (protected_ids )
1117+ def _get_snapshots_to_keep (self , retain_last_n : Optional [int ] = None ) -> Set [int ]:
1118+ """Get set of snapshot IDs that should be kept based on protection and retention rules.
11171119
1118- # Add the N most recent snapshots
1119- for i , snapshot in enumerate (sorted_snapshots ):
1120- if i < n :
1121- snapshots_to_keep .add (snapshot .snapshot_id )
1120+ Args:
1121+ retain_last_n: Number of most recent snapshots to keep.
11221122
1123- # Find snapshots to expire
1124- snapshots_to_expire = []
1125- for snapshot in self ._transaction .table_metadata .snapshots :
1126- if snapshot .snapshot_id not in snapshots_to_keep :
1127- snapshots_to_expire .append (snapshot .snapshot_id )
1123+ Returns:
1124+ Set of snapshot IDs to keep.
1125+ """
1126+ snapshots_to_keep = self ._get_protected_snapshot_ids ()
11281127
1129- self ._snapshot_ids_to_expire .update (snapshots_to_expire )
1130- return self
1128+ if retain_last_n is not None :
1129+ # Sort snapshots by timestamp (most recent first), and get most recent N
1130+ sorted_snapshots = sorted (self ._transaction .table_metadata .snapshots , key = lambda s : s .timestamp_ms , reverse = True )
1131+ snapshots_to_keep .update (snapshot .snapshot_id for snapshot in sorted_snapshots [:retain_last_n ])
1132+
1133+ return snapshots_to_keep
11311134
11321135 def _get_snapshots_to_expire_with_retention (
11331136 self , timestamp_ms : Optional [int ] = None , retain_last_n : Optional [int ] = None , min_snapshots_to_keep : Optional [int ] = None
@@ -1142,22 +1145,10 @@ def _get_snapshots_to_expire_with_retention(
11421145 Returns:
11431146 List of snapshot IDs to expire.
11441147 """
1145- protected_ids = self ._get_protected_snapshot_ids ()
1146-
1147- # Sort snapshots by timestamp (most recent first)
1148- sorted_snapshots = sorted (self ._transaction .table_metadata .snapshots , key = lambda s : s .timestamp_ms , reverse = True )
1149-
1150- # Start with all snapshots that could be expired
1151- candidates_for_expiration = []
1152- snapshots_to_keep = set (protected_ids )
1153-
1154- # Apply retain_last_n constraint
1155- if retain_last_n is not None :
1156- for i , snapshot in enumerate (sorted_snapshots ):
1157- if i < retain_last_n :
1158- snapshots_to_keep .add (snapshot .snapshot_id )
1148+ snapshots_to_keep = self ._get_snapshots_to_keep (retain_last_n = retain_last_n )
11591149
11601150 # Apply timestamp constraint
1151+ candidates_for_expiration = []
11611152 for snapshot in self ._transaction .table_metadata .snapshots :
11621153 if snapshot .snapshot_id not in snapshots_to_keep and (timestamp_ms is None or snapshot .timestamp_ms < timestamp_ms ):
11631154 candidates_for_expiration .append (snapshot )
@@ -1166,18 +1157,12 @@ def _get_snapshots_to_expire_with_retention(
11661157 candidates_for_expiration .sort (key = lambda s : s .timestamp_ms )
11671158
11681159 # Apply min_snapshots_to_keep constraint
1169- total_snapshots = len (self ._transaction .table_metadata .snapshots )
1170- snapshots_to_expire : List [int ] = []
1171-
1172- for candidate in candidates_for_expiration :
1173- # Check if expiring this snapshot would violate min_snapshots_to_keep
1174- remaining_after_expiration = total_snapshots - len (snapshots_to_expire ) - 1
1175-
1176- if min_snapshots_to_keep is None or remaining_after_expiration >= min_snapshots_to_keep :
1177- snapshots_to_expire .append (candidate .snapshot_id )
1178- else :
1179- # Stop expiring to maintain minimum count
1180- break
1160+ if min_snapshots_to_keep is not None :
1161+ total_snapshots = len (self ._transaction .table_metadata .snapshots )
1162+ max_to_expire = total_snapshots - min_snapshots_to_keep
1163+ snapshots_to_expire = [candidate .snapshot_id for candidate in candidates_for_expiration [:max_to_expire ]]
1164+ else :
1165+ snapshots_to_expire = [candidate .snapshot_id for candidate in candidates_for_expiration ]
11811166
11821167 return snapshots_to_expire
11831168
@@ -1187,15 +1172,12 @@ def _get_expiration_properties(self) -> tuple[Optional[int], Optional[int], Opti
11871172 Returns:
11881173 Tuple of (max_snapshot_age_ms, min_snapshots_to_keep, max_ref_age_ms)
11891174 """
1190- properties = self ._transaction .table_metadata .properties
1191-
1192- max_snapshot_age_ms = properties .get ("history.expire.max-snapshot-age-ms" )
1193- max_snapshot_age = int (max_snapshot_age_ms ) if max_snapshot_age_ms is not None else None
1175+ from pyiceberg .table import TableProperties
11941176
1195- min_snapshots = properties .get ("history.expire.min-snapshots-to-keep" )
1196- min_snapshots_to_keep = int (min_snapshots ) if min_snapshots is not None else None
1177+ properties = self ._transaction .table_metadata .properties
11971178
1198- max_ref_age = properties .get ("history.expire.max-ref-age-ms" )
1199- max_ref_age_ms = int (max_ref_age ) if max_ref_age is not None else None
1179+ max_snapshot_age = property_as_int (properties , TableProperties .MAX_SNAPSHOT_AGE_MS )
1180+ min_snapshots_to_keep = property_as_int (properties , TableProperties .MIN_SNAPSHOTS_TO_KEEP )
1181+ max_ref_age_ms = property_as_int (properties , "history.expire.max-ref-age-ms" )
12001182
12011183 return max_snapshot_age , min_snapshots_to_keep , max_ref_age_ms
0 commit comments