|
19 | 19 |
|
20 | 20 | import pytest |
21 | 21 |
|
| 22 | +from pyiceberg.io import load_file_io |
22 | 23 | from pyiceberg.table import CommitTableResponse, Table |
23 | 24 | from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate |
24 | 25 |
|
@@ -177,3 +178,128 @@ def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None: |
177 | 178 | # The main branch should point to the same snapshot as the tag |
178 | 179 | main_update = next(u for u in set_ref_updates if u.ref_name == "main") |
179 | 180 | assert main_update.snapshot_id == snapshot_one |
| 181 | + |
| 182 | + |
def test_rollback_to_snapshot(table_v2: Table) -> None:
    """Rolling back to an ancestor snapshot issues one commit that sets the ``main`` branch to it."""
    ancestor_snapshot_id = 3051729675574597004

    # Replace the catalog with a mock so the commit is recorded and can be inspected.
    table_v2.catalog = MagicMock()
    catalog = table_v2.catalog
    catalog.commit_table.return_value = _mock_commit_response(table_v2)

    snapshot_manager = table_v2.manage_snapshots()
    snapshot_manager.rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()

    catalog.commit_table.assert_called_once()

    ref_updates = [
        candidate for candidate in _get_updates(catalog) if isinstance(candidate, SetSnapshotRefUpdate)
    ]
    assert len(ref_updates) == 1

    (main_update,) = ref_updates
    assert main_update.snapshot_id == ancestor_snapshot_id
    assert main_update.ref_name == "main"
    assert main_update.type == "branch"
| 201 | + |
| 202 | + |
def test_rollback_to_snapshot_unknown_id(table_v2: Table) -> None:
    """A rollback to a snapshot id the table is expected not to know raises ``ValueError`` without committing."""
    table_v2.catalog = MagicMock()
    invalid_snapshot_id = 1234567890000
    expected_message = "Cannot roll back to unknown snapshot id"

    with pytest.raises(ValueError, match=expected_message):
        table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=invalid_snapshot_id).commit()

    # The failed rollback must not have reached the catalog.
    table_v2.catalog.commit_table.assert_not_called()
| 211 | + |
| 212 | + |
def test_rollback_to_snapshot_not_ancestor(table_v2: Table) -> None:
    """Rolling back to an existing snapshot that is not an ancestor of the current one is rejected.

    The test builds its own table whose snapshot history forks::

        a <- b   (current)
         \\
          c      (sibling of b, not on b's ancestry line)

    Rolling back to ``c`` must raise ``ValueError`` and must not commit anything
    to the catalog.

    Note: ``table_v2`` is not used by the body; the parameter is kept so the
    test's signature is unchanged.
    """
    from pyiceberg.table.metadata import TableMetadataV2

    # create a table with a branching snapshot history:
    snapshot_a = 1
    snapshot_b = 2  # current
    snapshot_c = 3  # branch from a, not ancestor of b

    metadata_dict = {
        "format-version": 2,
        "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
        "location": "s3://bucket/test/location",
        "last-sequence-number": 3,
        "last-updated-ms": 1602638573590,
        "last-column-id": 1,
        "current-schema-id": 0,
        "schemas": [{"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}],
        "default-spec-id": 0,
        "partition-specs": [{"spec-id": 0, "fields": []}],
        "last-partition-id": 999,
        "default-sort-order-id": 0,
        "current-snapshot-id": snapshot_b,
        "snapshots": [
            {
                "snapshot-id": snapshot_a,
                "timestamp-ms": 1000,
                "sequence-number": 1,
                "manifest-list": "s3://a/1.avro",
            },
            {
                "snapshot-id": snapshot_b,
                "parent-snapshot-id": snapshot_a,
                "timestamp-ms": 2000,
                "sequence-number": 2,
                "manifest-list": "s3://a/2.avro",
            },
            {
                "snapshot-id": snapshot_c,
                "parent-snapshot-id": snapshot_a,
                "timestamp-ms": 3000,
                "sequence-number": 3,
                "manifest-list": "s3://a/3.avro",
            },
        ],
    }

    # Keep a handle on the mock catalog so the "nothing was committed" check
    # below inspects exactly the object handed to the table.
    catalog = MagicMock()

    # ``Table`` is already imported at module level; no local re-import is needed.
    branching_table = Table(
        identifier=("db", "table"),
        metadata=TableMetadataV2(**metadata_dict),
        metadata_location="s3://bucket/test/metadata.json",
        io=load_file_io(),
        catalog=catalog,
    )

    # snapshot_c exists but is not an ancestor of snapshot_b (current)
    with pytest.raises(ValueError, match="Cannot roll back to snapshot, not an ancestor of the current state"):
        branching_table.manage_snapshots().rollback_to_snapshot(snapshot_id=snapshot_c).commit()

    # Consistent with the unknown-id test: a rejected rollback must not commit.
    catalog.commit_table.assert_not_called()
| 272 | + |
| 273 | + |
def test_rollback_to_snapshot_chained_with_tag(table_v2: Table) -> None:
    """Chaining ``create_tag`` with ``rollback_to_snapshot`` produces one commit carrying both ref updates."""
    ancestor_snapshot_id = 3051729675574597004

    # Replace the catalog with a mock so the commit is recorded and can be inspected.
    table_v2.catalog = MagicMock()
    catalog = table_v2.catalog
    catalog.commit_table.return_value = _mock_commit_response(table_v2)

    with_tag = table_v2.manage_snapshots().create_tag(snapshot_id=ancestor_snapshot_id, tag_name="before-rollback")
    with_tag.rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()

    catalog.commit_table.assert_called_once()

    ref_updates = [
        candidate for candidate in _get_updates(catalog) if isinstance(candidate, SetSnapshotRefUpdate)
    ]
    assert len(ref_updates) == 2
    # One update for the new tag and one for the main branch.
    assert {ref_update.ref_name for ref_update in ref_updates} == {"before-rollback", "main"}
| 295 | + |
| 296 | + |
def test_rollback_to_current_snapshot(table_v2: Table) -> None:
    """Rolling back to the table's current snapshot is accepted and results in exactly one commit."""
    head_snapshot = table_v2.current_snapshot()
    assert head_snapshot is not None

    # Replace the catalog with a mock so the commit is recorded and can be inspected.
    table_v2.catalog = MagicMock()
    catalog = table_v2.catalog
    catalog.commit_table.return_value = _mock_commit_response(table_v2)

    snapshot_manager = table_v2.manage_snapshots()
    snapshot_manager.rollback_to_snapshot(snapshot_id=head_snapshot.snapshot_id).commit()

    catalog.commit_table.assert_called_once()
0 commit comments