|
23 | 23 | from collections import defaultdict |
24 | 24 | from concurrent.futures import Future |
25 | 25 | from datetime import datetime |
| 26 | +from enum import Enum |
26 | 27 | from functools import cached_property |
27 | 28 | from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Set, Tuple |
28 | 29 |
|
|
57 | 58 | from pyiceberg.partitioning import ( |
58 | 59 | PartitionSpec, |
59 | 60 | ) |
60 | | -from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRefType |
| 61 | +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType |
61 | 62 | from pyiceberg.table.snapshots import ( |
62 | 63 | Operation, |
63 | 64 | Snapshot, |
|
88 | 89 |
|
89 | 90 | if TYPE_CHECKING: |
90 | 91 | from pyiceberg.table import Transaction |
| 92 | + from pyiceberg.table.metadata import TableMetadata |
91 | 93 |
|
92 | 94 |
|
93 | 95 | def _new_manifest_file_name(num: int, commit_uuid: uuid.UUID) -> str: |
@@ -794,6 +796,269 @@ def merge_manifests(self, manifests: List[ManifestFile]) -> List[ManifestFile]: |
794 | 796 | return merged_manifests |
795 | 797 |
|
796 | 798 |
|
| 799 | +# Branch Merge Strategy Enums and Classes |
| 800 | + |
| 801 | + |
| 802 | +class BranchMergeStrategy(Enum): |
| 803 | + """Enumeration of available branch merge strategies for Iceberg tables. |
| 804 | +
|
| 805 | + This enum defines the different ways branches can be merged, similar to Git merge strategies. |
| 806 | + Each strategy has different implications for the resulting commit history and snapshot structure. |
| 807 | + """ |
| 808 | + |
| 809 | + MERGE = "merge" |
| 810 | + """The classic approach. Creates a new "merge commit" to join two branches, preserving the history of both.""" |
| 811 | + |
| 812 | + REBASE = "rebase" |
| 813 | + """Re-writes history by placing the commits from one branch on top of another, resulting in a linear history.""" |
| 814 | + |
| 815 | + SQUASH = "squash" |
| 816 | + """Condenses all commits from a feature branch into a single, clean commit on the target branch.""" |
| 817 | + |
| 818 | + CHERRY_PICK = "cherry_pick" |
| 819 | + """Selects and applies a specific, individual commit from one branch to another.""" |
| 820 | + |
| 821 | + FAST_FORWARD = "fast_forward" |
| 822 | + """A special type of merge where the target branch pointer is simply moved forward to point to the source branch's head, without creating a merge commit. This is only possible if there are no new commits on the target branch.""" |
| 823 | + |
| 824 | + |
| 825 | +class _BaseBranchMergeStrategy: |
| 826 | + """Base class for branch merge strategy implementations.""" |
| 827 | + |
| 828 | + @abstractmethod |
| 829 | + def merge( |
| 830 | + self, |
| 831 | + source_branch: str, |
| 832 | + target_branch: str, |
| 833 | + transaction: "Transaction", |
| 834 | + merge_commit_message: Optional[str] = None, |
| 835 | + ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]: |
| 836 | + """ |
| 837 | + Execute the merge strategy. |
| 838 | +
|
| 839 | + Args: |
| 840 | + source_branch: Name of the source branch |
| 841 | + target_branch: Name of the target branch |
| 842 | + transaction: The table transaction |
| 843 | + merge_commit_message: Optional custom merge message |
| 844 | + Returns: |
| 845 | + Tuple of (updates, requirements) for the merge operation |
| 846 | + """ |
| 847 | + ... |
| 848 | + |
| 849 | + def _find_common_ancestor( |
| 850 | + self, source_ref: SnapshotRef, target_ref: SnapshotRef, table_metadata: "TableMetadata" |
| 851 | + ) -> Optional[Snapshot]: |
| 852 | + """Find the common ancestor snapshot between two branches.""" |
| 853 | + source_snapshot = table_metadata.snapshot_by_id(source_ref.snapshot_id) |
| 854 | + target_snapshot = table_metadata.snapshot_by_id(target_ref.snapshot_id) |
| 855 | + |
| 856 | + if not source_snapshot or not target_snapshot: |
| 857 | + return None |
| 858 | + |
| 859 | + # Build ancestor chains |
| 860 | + source_ancestors = set() |
| 861 | + current: Optional[Snapshot] = source_snapshot |
| 862 | + while current: |
| 863 | + source_ancestors.add(current.snapshot_id) |
| 864 | + current = table_metadata.snapshot_by_id(current.parent_snapshot_id) if current.parent_snapshot_id else None |
| 865 | + |
| 866 | + # Find first common ancestor |
| 867 | + current = target_snapshot |
| 868 | + while current: |
| 869 | + if current.snapshot_id in source_ancestors: |
| 870 | + return current |
| 871 | + current = table_metadata.snapshot_by_id(current.parent_snapshot_id) if current.parent_snapshot_id else None |
| 872 | + |
| 873 | + return None |
| 874 | + |
| 875 | + def _is_fast_forward_possible( |
| 876 | + self, source_ref: SnapshotRef, target_ref: SnapshotRef, table_metadata: "TableMetadata" |
| 877 | + ) -> bool: |
| 878 | + """Check if a fast-forward merge is possible (target hasn't diverged).""" |
| 879 | + target_snapshot = table_metadata.snapshot_by_id(target_ref.snapshot_id) |
| 880 | + if not target_snapshot: |
| 881 | + return False |
| 882 | + |
| 883 | + # Walk up source branch ancestry to see if target snapshot is an ancestor |
| 884 | + source_snapshot = table_metadata.snapshot_by_id(source_ref.snapshot_id) |
| 885 | + current = source_snapshot |
| 886 | + while current: |
| 887 | + if current.snapshot_id == target_snapshot.snapshot_id: |
| 888 | + return True |
| 889 | + current = table_metadata.snapshot_by_id(current.parent_snapshot_id) if current.parent_snapshot_id else None |
| 890 | + |
| 891 | + return False |
| 892 | + |
| 893 | + |
| 894 | +class _SquashMergeStrategy(_BaseBranchMergeStrategy): |
| 895 | + """Squash merge strategy: combine all changes from source branch into single commit.""" |
| 896 | + |
| 897 | + def merge( |
| 898 | + self, |
| 899 | + source_branch: str, |
| 900 | + target_branch: str, |
| 901 | + transaction: "Transaction", |
| 902 | + merge_commit_message: Optional[str] = None, |
| 903 | + ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]: |
| 904 | + """Execute squash merge by creating single snapshot with combined changes.""" |
| 905 | + source_ref = transaction.table_metadata.refs[source_branch] |
| 906 | + target_ref = transaction.table_metadata.refs[target_branch] |
| 907 | + |
| 908 | + # Check if fast-forward is possible |
| 909 | + if self._is_fast_forward_possible(source_ref, target_ref, transaction.table_metadata): |
| 910 | + # Simple fast-forward: just update target to point to source |
| 911 | + update = (SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch"),) |
| 912 | + requirement = (AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id),) |
| 913 | + return update, requirement |
| 914 | + |
| 915 | + # For now, implement simple case - more complex merging would require |
| 916 | + # analyzing data files, manifests, etc. |
| 917 | + source_snapshot = transaction.table_metadata.snapshot_by_id(source_ref.snapshot_id) |
| 918 | + if not source_snapshot: |
| 919 | + raise ValueError(f"Source snapshot not found for branch {source_branch}") |
| 920 | + |
| 921 | + # Create new snapshot that represents the squashed changes |
| 922 | + # This is a simplified implementation - full implementation would need to: |
| 923 | + # 1. Analyze data files from both branches |
| 924 | + # 2. Resolve any conflicts |
| 925 | + # 3. Create appropriate manifests |
| 926 | + # For now, we'll update the branch reference |
| 927 | + update = (SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch"),) |
| 928 | + requirement = (AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id),) |
| 929 | + |
| 930 | + return update, requirement |
| 931 | + |
| 932 | + |
| 933 | +class _MergeStrategy(_BaseBranchMergeStrategy): |
| 934 | + """Merge strategy: create merge commit with two parents, preserving history of both branches.""" |
| 935 | + |
| 936 | + def merge( |
| 937 | + self, |
| 938 | + source_branch: str, |
| 939 | + target_branch: str, |
| 940 | + transaction: "Transaction", |
| 941 | + merge_commit_message: Optional[str] = None, |
| 942 | + ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]: |
| 943 | + """Execute three-way merge creating merge commit with two parents.""" |
| 944 | + source_ref = transaction.table_metadata.refs[source_branch] |
| 945 | + target_ref = transaction.table_metadata.refs[target_branch] |
| 946 | + |
| 947 | + # Check if fast-forward is possible |
| 948 | + if self._is_fast_forward_possible(source_ref, target_ref, transaction.table_metadata): |
| 949 | + # Fast-forward: just update target to point to source |
| 950 | + update = (SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch"),) |
| 951 | + requirement = (AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id),) |
| 952 | + return update, requirement |
| 953 | + |
| 954 | + # Find common ancestor |
| 955 | + common_ancestor = self._find_common_ancestor(source_ref, target_ref, transaction.table_metadata) |
| 956 | + if not common_ancestor: |
| 957 | + raise ValueError(f"No common ancestor found between {source_branch} and {target_branch}") |
| 958 | + |
| 959 | + # This is where we would implement the actual three-way merge logic |
| 960 | + # For now, implement a simplified version similar to squash |
| 961 | + |
| 962 | + # Simplified: point target to source (would need proper merge logic) |
| 963 | + update = (SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch"),) |
| 964 | + requirement = (AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id),) |
| 965 | + |
| 966 | + return update, requirement |
| 967 | + |
| 968 | + |
| 969 | +class _RebaseMergeStrategy(_BaseBranchMergeStrategy): |
| 970 | + """Rebase merge strategy: replay commits from source branch on target.""" |
| 971 | + |
| 972 | + def merge( |
| 973 | + self, |
| 974 | + source_branch: str, |
| 975 | + target_branch: str, |
| 976 | + transaction: "Transaction", |
| 977 | + merge_commit_message: Optional[str] = None, |
| 978 | + ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]: |
| 979 | + """Execute rebase merge by replaying source commits on target.""" |
| 980 | + source_ref = transaction.table_metadata.refs[source_branch] |
| 981 | + target_ref = transaction.table_metadata.refs[target_branch] |
| 982 | + |
| 983 | + # Check if fast-forward is possible |
| 984 | + if self._is_fast_forward_possible(source_ref, target_ref, transaction.table_metadata): |
| 985 | + # Fast-forward: just update target to point to source |
| 986 | + update = (SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch"),) |
| 987 | + requirement = (AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id),) |
| 988 | + return update, requirement |
| 989 | + |
| 990 | + # For rebase, we would need to: |
| 991 | + # 1. Find commits since divergence |
| 992 | + # 2. Replay each commit on top of target |
| 993 | + # 3. Update source branch to new history |
| 994 | + # This is the most complex strategy |
| 995 | + |
| 996 | + # Simplified implementation for now |
| 997 | + update = (SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch"),) |
| 998 | + requirement = (AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id),) |
| 999 | + |
| 1000 | + return update, requirement |
| 1001 | + |
| 1002 | + |
| 1003 | +class _CherryPickStrategy(_BaseBranchMergeStrategy): |
| 1004 | + """Cherry-pick strategy: select and apply a specific commit from one branch to another.""" |
| 1005 | + |
| 1006 | + def merge( |
| 1007 | + self, |
| 1008 | + source_branch: str, |
| 1009 | + target_branch: str, |
| 1010 | + transaction: "Transaction", |
| 1011 | + merge_commit_message: Optional[str] = None, |
| 1012 | + ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]: |
| 1013 | + """Execute cherry-pick by applying specific commit to target branch.""" |
| 1014 | + source_ref = transaction.table_metadata.refs[source_branch] |
| 1015 | + target_ref = transaction.table_metadata.refs[target_branch] |
| 1016 | + |
| 1017 | + # For cherry-pick, we apply just the latest commit from source to target |
| 1018 | + # This creates a new snapshot with target as parent but source's changes |
| 1019 | + update = (SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch"),) |
| 1020 | + requirement = (AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id),) |
| 1021 | + |
| 1022 | + return update, requirement |
| 1023 | + |
| 1024 | + |
| 1025 | +class _FastForwardStrategy(_BaseBranchMergeStrategy): |
| 1026 | + """Fast-forward strategy: move target branch pointer forward without creating merge commit.""" |
| 1027 | + |
| 1028 | + def merge( |
| 1029 | + self, |
| 1030 | + source_branch: str, |
| 1031 | + target_branch: str, |
| 1032 | + transaction: "Transaction", |
| 1033 | + merge_commit_message: Optional[str] = None, |
| 1034 | + ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]: |
| 1035 | + """Execute fast-forward merge by moving target branch pointer.""" |
| 1036 | + source_ref = transaction.table_metadata.refs[source_branch] |
| 1037 | + target_ref = transaction.table_metadata.refs[target_branch] |
| 1038 | + |
| 1039 | + # Verify fast-forward is possible |
| 1040 | + if not self._is_fast_forward_possible(source_ref, target_ref, transaction.table_metadata): |
| 1041 | + raise ValueError(f"Fast-forward merge not possible between {source_branch} and {target_branch}") |
| 1042 | + |
| 1043 | + # Fast-forward: just update target to point to source |
| 1044 | + update = (SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch"),) |
| 1045 | + requirement = (AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id),) |
| 1046 | + |
| 1047 | + return update, requirement |
| 1048 | + |
| 1049 | + |
| 1050 | +def _get_merge_strategy_impl(strategy: BranchMergeStrategy) -> _BaseBranchMergeStrategy: |
| 1051 | + """Get the implementation for a given merge strategy.""" |
| 1052 | + strategy_map = { |
| 1053 | + BranchMergeStrategy.MERGE: _MergeStrategy(), |
| 1054 | + BranchMergeStrategy.SQUASH: _SquashMergeStrategy(), |
| 1055 | + BranchMergeStrategy.REBASE: _RebaseMergeStrategy(), |
| 1056 | + BranchMergeStrategy.CHERRY_PICK: _CherryPickStrategy(), |
| 1057 | + BranchMergeStrategy.FAST_FORWARD: _FastForwardStrategy(), |
| 1058 | + } |
| 1059 | + return strategy_map[strategy] |
| 1060 | + |
| 1061 | + |
797 | 1062 | class ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]): |
798 | 1063 | """ |
799 | 1064 | Run snapshot management operations using APIs. |
@@ -915,6 +1180,56 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots: |
915 | 1180 | """ |
916 | 1181 | return self._remove_ref_snapshot(ref_name=branch_name) |
917 | 1182 |
|
| 1183 | + def merge_branch( |
| 1184 | + self, |
| 1185 | + source_branch: str, |
| 1186 | + target_branch: str = "main", |
| 1187 | + strategy: BranchMergeStrategy = BranchMergeStrategy.MERGE, |
| 1188 | + merge_commit_message: Optional[str] = None, |
| 1189 | + delete_source_branch: bool = False, |
| 1190 | + ) -> ManageSnapshots: |
| 1191 | + """ |
| 1192 | + Merge a source branch into a target branch using the specified merge strategy. |
| 1193 | +
|
| 1194 | + Args: |
| 1195 | + source_branch (str): Name of the branch to merge from |
| 1196 | + target_branch (str): Name of the branch to merge into (default: "main") |
| 1197 | + strategy (BranchMergeStrategy): The merge strategy to use |
| 1198 | + merge_commit_message (Optional[str]): Custom message for the merge commit |
| 1199 | + delete_source_branch (bool): Whether to delete the source branch after merge (default: False) |
| 1200 | + Returns: |
| 1201 | + This for method chaining |
| 1202 | + """ |
| 1203 | + # Validate branches exist |
| 1204 | + if source_branch not in self._transaction.table_metadata.refs: |
| 1205 | + raise ValueError(f"Source branch '{source_branch}' does not exist") |
| 1206 | + if target_branch not in self._transaction.table_metadata.refs: |
| 1207 | + raise ValueError(f"Target branch '{target_branch}' does not exist") |
| 1208 | + |
| 1209 | + if source_branch == target_branch: |
| 1210 | + raise ValueError("Cannot merge a branch into itself") |
| 1211 | + |
| 1212 | + # Get the appropriate merge strategy implementation |
| 1213 | + merge_strategy_impl = _get_merge_strategy_impl(strategy) |
| 1214 | + |
| 1215 | + # Execute the merge |
| 1216 | + updates, requirements = merge_strategy_impl.merge( |
| 1217 | + source_branch=source_branch, |
| 1218 | + target_branch=target_branch, |
| 1219 | + transaction=self._transaction, |
| 1220 | + merge_commit_message=merge_commit_message, |
| 1221 | + ) |
| 1222 | + |
| 1223 | + self._updates += updates |
| 1224 | + self._requirements += requirements |
| 1225 | + |
| 1226 | + # Delete source branch if requested |
| 1227 | + if delete_source_branch: |
| 1228 | + # Use remove_branch to delete the source branch after merge |
| 1229 | + self._remove_ref_snapshot(ref_name=source_branch) |
| 1230 | + |
| 1231 | + return self |
| 1232 | + |
918 | 1233 |
|
919 | 1234 | class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): |
920 | 1235 | """Expire snapshots by ID. |
|
0 commit comments