Skip to content

Commit 65fc183

Browse files
committed
feat: Add rollback_to_snapshot to ManageSnapshots API
1 parent b0880c8 commit 65fc183

File tree

3 files changed

+185
-0
lines changed

3 files changed

+185
-0
lines changed

pyiceberg/table/update/snapshot.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
Snapshot,
6565
SnapshotSummaryCollector,
6666
Summary,
67+
ancestors_of,
6768
update_snapshot_summaries,
6869
)
6970
from pyiceberg.table.update import (
@@ -985,6 +986,38 @@ def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | N
985986
self._transaction._stage(update, requirement)
986987
return self
987988

989+
def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
990+
"""Rollback the table to the given snapshot id.
991+
992+
The snapshot needs to be an ancestor of the current table state.
993+
994+
Args:
995+
snapshot_id (int): rollback to this snapshot_id that used to be current.
996+
Returns:
997+
This for method chaining
998+
Raises:
999+
ValueError: If the snapshot does not exist or is not an ancestor of the current table state.
1000+
"""
1001+
if not self._transaction.table_metadata.snapshot_by_id(snapshot_id):
1002+
raise ValueError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")
1003+
1004+
if not self._is_current_ancestor(snapshot_id):
1005+
raise ValueError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")
1006+
1007+
return self.set_current_snapshot(snapshot_id=snapshot_id)
1008+
1009+
def _is_current_ancestor(self, snapshot_id: int) -> bool:
1010+
return snapshot_id in self._current_ancestors()
1011+
1012+
def _current_ancestors(self) -> set[int]:
1013+
return {
1014+
a.snapshot_id
1015+
for a in ancestors_of(
1016+
self._transaction._table.current_snapshot(),
1017+
self._transaction.table_metadata,
1018+
)
1019+
}
1020+
9881021

9891022
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
9901023
"""Expire snapshots by ID.

tests/integration/test_snapshot_operations.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,29 @@ def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
160160
tbl = catalog.load_table(identifier)
161161
tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
162162
assert tbl.metadata.refs.get(tag_name, None) is None
163+
164+
165+
@pytest.mark.integration
166+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
167+
def test_rollback_to_snapshot(catalog: Catalog) -> None:
168+
identifier = "default.test_table_snapshot_operations"
169+
tbl = catalog.load_table(identifier)
170+
assert len(tbl.history()) > 2
171+
172+
# get the current snapshot and an ancestor
173+
current_snapshot_id = tbl.history()[-1].snapshot_id
174+
ancestor_snapshot_id = tbl.history()[-2].snapshot_id
175+
assert ancestor_snapshot_id != current_snapshot_id
176+
177+
# rollback to the ancestor snapshot
178+
tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()
179+
180+
tbl = catalog.load_table(identifier)
181+
updated_snapshot = tbl.current_snapshot()
182+
assert updated_snapshot and updated_snapshot.snapshot_id == ancestor_snapshot_id
183+
184+
# restore table
185+
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
186+
tbl = catalog.load_table(identifier)
187+
restored_snapshot = tbl.current_snapshot()
188+
assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id

tests/table/test_manage_snapshots.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import pytest
2121

22+
from pyiceberg.io import load_file_io
2223
from pyiceberg.table import CommitTableResponse, Table
2324
from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate
2425

@@ -177,3 +178,128 @@ def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None:
177178
# The main branch should point to the same snapshot as the tag
178179
main_update = next(u for u in set_ref_updates if u.ref_name == "main")
179180
assert main_update.snapshot_id == snapshot_one
181+
182+
183+
def test_rollback_to_snapshot(table_v2: Table) -> None:
184+
ancestor_snapshot_id = 3051729675574597004
185+
186+
table_v2.catalog = MagicMock()
187+
table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2)
188+
189+
table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()
190+
191+
table_v2.catalog.commit_table.assert_called_once()
192+
193+
updates = _get_updates(table_v2.catalog)
194+
set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)]
195+
196+
assert len(set_ref_updates) == 1
197+
update = set_ref_updates[0]
198+
assert update.snapshot_id == ancestor_snapshot_id
199+
assert update.ref_name == "main"
200+
assert update.type == "branch"
201+
202+
203+
def test_rollback_to_snapshot_unknown_id(table_v2: Table) -> None:
204+
invalid_snapshot_id = 1234567890000
205+
table_v2.catalog = MagicMock()
206+
207+
with pytest.raises(ValueError, match="Cannot roll back to unknown snapshot id"):
208+
table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=invalid_snapshot_id).commit()
209+
210+
table_v2.catalog.commit_table.assert_not_called()
211+
212+
213+
def test_rollback_to_snapshot_not_ancestor(table_v2: Table) -> None:
214+
from pyiceberg.table.metadata import TableMetadataV2
215+
216+
# create a table with a branching snapshot history:
217+
snapshot_a = 1
218+
snapshot_b = 2 # current
219+
snapshot_c = 3 # branch from a, not ancestor of b
220+
221+
metadata_dict = {
222+
"format-version": 2,
223+
"table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
224+
"location": "s3://bucket/test/location",
225+
"last-sequence-number": 3,
226+
"last-updated-ms": 1602638573590,
227+
"last-column-id": 1,
228+
"current-schema-id": 0,
229+
"schemas": [{"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}],
230+
"default-spec-id": 0,
231+
"partition-specs": [{"spec-id": 0, "fields": []}],
232+
"last-partition-id": 999,
233+
"default-sort-order-id": 0,
234+
"current-snapshot-id": snapshot_b,
235+
"snapshots": [
236+
{
237+
"snapshot-id": snapshot_a,
238+
"timestamp-ms": 1000,
239+
"sequence-number": 1,
240+
"manifest-list": "s3://a/1.avro",
241+
},
242+
{
243+
"snapshot-id": snapshot_b,
244+
"parent-snapshot-id": snapshot_a,
245+
"timestamp-ms": 2000,
246+
"sequence-number": 2,
247+
"manifest-list": "s3://a/2.avro",
248+
},
249+
{
250+
"snapshot-id": snapshot_c,
251+
"parent-snapshot-id": snapshot_a,
252+
"timestamp-ms": 3000,
253+
"sequence-number": 3,
254+
"manifest-list": "s3://a/3.avro",
255+
},
256+
],
257+
}
258+
259+
from pyiceberg.table import Table
260+
261+
branching_table = Table(
262+
identifier=("db", "table"),
263+
metadata=TableMetadataV2(**metadata_dict),
264+
metadata_location="s3://bucket/test/metadata.json",
265+
io=load_file_io(),
266+
catalog=MagicMock(),
267+
)
268+
269+
# snapshot_c exists but is not an ancestor of snapshot_b (current)
270+
with pytest.raises(ValueError, match="Cannot roll back to snapshot, not an ancestor of the current state"):
271+
branching_table.manage_snapshots().rollback_to_snapshot(snapshot_id=snapshot_c).commit()
272+
273+
274+
def test_rollback_to_snapshot_chained_with_tag(table_v2: Table) -> None:
275+
ancestor_snapshot_id = 3051729675574597004
276+
277+
table_v2.catalog = MagicMock()
278+
table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2)
279+
280+
(
281+
table_v2.manage_snapshots()
282+
.create_tag(snapshot_id=ancestor_snapshot_id, tag_name="before-rollback")
283+
.rollback_to_snapshot(snapshot_id=ancestor_snapshot_id)
284+
.commit()
285+
)
286+
287+
table_v2.catalog.commit_table.assert_called_once()
288+
289+
updates = _get_updates(table_v2.catalog)
290+
set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)]
291+
292+
assert len(set_ref_updates) == 2
293+
ref_names = {u.ref_name for u in set_ref_updates}
294+
assert ref_names == {"before-rollback", "main"}
295+
296+
297+
def test_rollback_to_current_snapshot(table_v2: Table) -> None:
298+
current_snapshot = table_v2.current_snapshot()
299+
assert current_snapshot is not None
300+
301+
table_v2.catalog = MagicMock()
302+
table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2)
303+
304+
table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=current_snapshot.snapshot_id).commit()
305+
table_v2.catalog.commit_table.assert_called_once()

0 commit comments

Comments
 (0)