Skip to content

Commit f9896ff

Browse files
committed
Add replace_tag to ManageSnapshots
1 parent d99e463 commit f9896ff

2 files changed

Lines changed: 80 additions & 1 deletion

File tree

pyiceberg/table/update/snapshot.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,36 @@ def remove_tag(self, tag_name: str) -> ManageSnapshots:
894894
"""
895895
return self._remove_ref_snapshot(ref_name=tag_name)
896896

897+
def replace_tag(self, name: str, snapshot_id: int) -> ManageSnapshots:
    """
    Move an existing tag so that it references a different snapshot.

    The tag keeps the ``max_ref_age_ms`` it already has; only the snapshot it
    points to is changed.

    Args:
        name (str): name of the tag to move; it must already exist and be a tag
        snapshot_id (int): id of the snapshot the tag should reference afterwards
    Returns:
        This ManageSnapshots instance, for method chaining
    Raises:
        ValueError: if no ref called ``name`` exists, or if that ref is not a tag
    """
    # Runs before the lookup below. NOTE(review): presumably this flushes ref
    # updates queued earlier on this builder so the metadata read here is
    # current -- confirm against _commit_if_ref_updates_exist.
    self._commit_if_ref_updates_exist()

    existing_ref = self._transaction.table_metadata.refs.get(name)
    if existing_ref is None:
        raise ValueError(f"Tag does not exist: {name}")
    if existing_ref.snapshot_ref_type != SnapshotRefType.TAG:
        raise ValueError(f"Ref {name} is not a tag")

    # Re-point the ref at the new snapshot while carrying over its max ref age.
    ref_update, ref_requirement = self._transaction._set_ref_snapshot(
        snapshot_id=snapshot_id,
        ref_name=name,
        type=SnapshotRefType.TAG,
        max_ref_age_ms=existing_ref.max_ref_age_ms,
    )
    self._updates += ref_update
    self._requirements += ref_requirement
    return self
926+
897927
def create_branch(
898928
self,
899929
snapshot_id: int,

tests/integration/test_snapshot_operations.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from pyiceberg.catalog import Catalog
2525
from pyiceberg.table import Table
26-
from pyiceberg.table.refs import SnapshotRef
26+
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
2727

2828

2929
@pytest.fixture
@@ -107,6 +107,55 @@ def test_remove_branch(catalog: Catalog) -> None:
107107
assert tbl.metadata.refs.get(branch_name, None) is None
108108

109109

110+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_replace_tag(catalog: Catalog) -> None:
    """Check that replace_tag moves a tag to another snapshot and keeps its max ref age.

    The tag created here is removed again at the end so the table, which other
    tests in this module load by the same identifier, does not keep a stray ref.
    """
    identifier = "default.test_table_snapshot_operations"
    tbl = catalog.load_table(identifier)

    # Read the history once; the last two entries are the snapshots the tag moves between.
    history = tbl.history()
    assert len(history) > 2

    current_snapshot_id = history[-1].snapshot_id
    older_snapshot_id = history[-2].snapshot_id

    name = "tag"
    # The trailing 1 is the tag's max_ref_age_ms, asserted below before and after the replace.
    tbl.manage_snapshots().create_tag(older_snapshot_id, name, 1).commit()
    try:
        tag = tbl.metadata.refs.get(name)
        assert tag is not None
        assert tag.snapshot_id == older_snapshot_id
        assert tag.snapshot_ref_type == SnapshotRefType.TAG
        assert tag.max_ref_age_ms == 1

        tbl.manage_snapshots().replace_tag(name=name, snapshot_id=current_snapshot_id).commit()

        tag = tbl.metadata.refs.get(name)
        assert tag is not None
        assert tag.snapshot_id == current_snapshot_id
        assert tag.snapshot_ref_type == SnapshotRefType.TAG
        assert tag.max_ref_age_ms == 1
    finally:
        # Clean up even when an assertion fails, so later tests see the table without this tag.
        tbl.manage_snapshots().remove_tag(tag_name=name).commit()
135+
136+
137+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_replace_missing_tag(catalog: Catalog) -> None:
    """Replacing a tag name that has no ref on the table must raise ValueError."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    latest_entry = table.history()[-1]
    snapshot_manager = table.manage_snapshots()

    # No ref named "test" is expected on this table, so replace_tag should reject the name.
    with pytest.raises(ValueError, match="Tag does not exist: test"):
        snapshot_manager.replace_tag(name="test", snapshot_id=latest_entry.snapshot_id).commit()
146+
147+
148+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_replace_branch(catalog: Catalog) -> None:
    """Calling replace_tag on the "main" ref must raise ValueError because it is not a tag."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    latest_entry = table.history()[-1]
    snapshot_manager = table.manage_snapshots()

    # "main" exists as a ref but is not of type TAG, so the type check should fire.
    with pytest.raises(ValueError, match="Ref main is not a tag"):
        snapshot_manager.replace_tag(name="main", snapshot_id=latest_entry.snapshot_id).commit()
157+
158+
110159
@pytest.mark.integration
111160
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
112161
def test_set_current_snapshot(catalog: Catalog) -> None:

0 commit comments

Comments
 (0)