-
Notifications
You must be signed in to change notification settings - Fork 471
Add UnionByName functionality #296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -69,11 +69,14 @@ | |||||
| ) | ||||||
| from pyiceberg.partitioning import PartitionSpec | ||||||
| from pyiceberg.schema import ( | ||||||
| PartnerAccessor, | ||||||
| Schema, | ||||||
| SchemaVisitor, | ||||||
| SchemaWithPartnerVisitor, | ||||||
| assign_fresh_schema_ids, | ||||||
| promote, | ||||||
| visit, | ||||||
| visit_with_partner, | ||||||
| ) | ||||||
| from pyiceberg.table.metadata import ( | ||||||
| INITIAL_SEQUENCE_NUMBER, | ||||||
|
|
@@ -1379,7 +1382,7 @@ class Move: | |||||
|
|
||||||
|
|
||||||
| class UpdateSchema: | ||||||
| _table: Table | ||||||
| _table: Optional[Table] | ||||||
| _schema: Schema | ||||||
| _last_column_id: itertools.count[int] | ||||||
| _identifier_field_names: Set[str] | ||||||
|
|
@@ -1398,14 +1401,23 @@ class UpdateSchema: | |||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| table: Table, | ||||||
| table: Optional[Table], | ||||||
| transaction: Optional[Transaction] = None, | ||||||
| allow_incompatible_changes: bool = False, | ||||||
| case_sensitive: bool = True, | ||||||
| schema: Optional[Schema] = None, | ||||||
| ) -> None: | ||||||
| self._table = table | ||||||
| self._schema = table.schema() | ||||||
| self._last_column_id = itertools.count(table.metadata.last_column_id + 1) | ||||||
|
|
||||||
| if isinstance(schema, Schema): | ||||||
| self._schema = schema | ||||||
| self._last_column_id = itertools.count(1 + schema.highest_field_id) | ||||||
| elif table is not None: | ||||||
| self._schema = table.schema() | ||||||
| self._last_column_id = itertools.count(1 + table.metadata.last_column_id) | ||||||
| else: | ||||||
| raise ValueError("Either provide a table or a schema") | ||||||
|
|
||||||
| self._identifier_field_names = self._schema.identifier_field_names() | ||||||
|
|
||||||
| self._adds = {} | ||||||
|
|
@@ -1449,6 +1461,15 @@ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema: | |||||
| self._case_sensitive = case_sensitive | ||||||
| return self | ||||||
|
|
||||||
def union_by_name(self, new_schema: Schema) -> UpdateSchema:
    """Union ``new_schema`` into the schema being updated, pairing fields by name.

    ``new_schema`` is traversed with ``visit_with_partner``; for every node the
    ``PartnerIdByNameAccessor`` looks up the same-named field in the current
    schema and the ``UnionByNameVisitor`` reports the differences back to this
    ``UpdateSchema``.

    Args:
        new_schema: The schema whose fields should be unioned into the current one.

    Returns:
        This ``UpdateSchema`` instance, so that calls can be chained.
    """
    visitor = UnionByNameVisitor(update_schema=self, new_schema=self._schema, case_sensitive=self._case_sensitive)
    accessor = PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive)

    # -1 is the partner id both helpers use for the root struct of the current schema.
    visit_with_partner(new_schema, -1, visitor, accessor)  # type: ignore
    return self
|
|
||||||
| def add_column( | ||||||
| self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: Optional[str] = None, required: bool = False | ||||||
| ) -> UpdateSchema: | ||||||
|
|
@@ -1816,6 +1837,9 @@ def move_after(self, path: Union[str, Tuple[str, ...]], after_name: Union[str, T | |||||
|
|
||||||
| def commit(self) -> None: | ||||||
| """Apply the pending changes and commit.""" | ||||||
| if self._table is None: | ||||||
| raise ValueError("Requires a table to commit to") | ||||||
|
|
||||||
| new_schema = self._apply() | ||||||
|
|
||||||
| existing_schema_id = next((schema.schema_id for schema in self._table.metadata.schemas if schema == new_schema), None) | ||||||
|
|
@@ -1862,7 +1886,8 @@ def _apply(self) -> Schema: | |||||
|
|
||||||
| field_ids.add(field.field_id) | ||||||
|
|
||||||
| return Schema(*struct.fields, schema_id=1 + max(self._table.schemas().keys()), identifier_field_ids=field_ids) | ||||||
| next_schema_id = 1 + (max(self._table.schemas().keys()) if self._table is not None else self._schema.schema_id) | ||||||
| return Schema(*struct.fields, schema_id=next_schema_id, identifier_field_ids=field_ids) | ||||||
|
|
||||||
def assign_new_column_id(self) -> int:
    """Return the next value from the ``_last_column_id`` counter, advancing it."""
    column_id = next(self._last_column_id)
    return column_id
|
|
@@ -1995,6 +2020,156 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: | |||||
| return primitive | ||||||
|
|
||||||
|
|
||||||
class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]):
    """Visitor that reports the changes needed to union a visited schema into another one.

    The partner of every visited node is the field-id of the matching node in
    ``new_schema`` (``-1`` denotes the root struct), or ``None`` when there is
    no matching node.  Differences are reported to ``update_schema`` through
    ``add_column``, ``update_column`` and ``make_column_optional``.
    """

    update_schema: UpdateSchema
    # Schema that partner field-ids are resolved against (see _find_field_type).
    new_schema: Schema
    case_sensitive: bool

    def __init__(self, update_schema: UpdateSchema, new_schema: Schema, case_sensitive: bool) -> None:
        """Store the update to report to, the schema partner ids refer to, and the name-matching mode."""
        self.update_schema = update_schema
        self.new_schema = new_schema
        self.case_sensitive = case_sensitive

    def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool:
        """Pass the result of the root struct through unchanged."""
        return struct_result

    def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool:
        """Add the fields that have no partner and update the ones that do.

        Args:
            struct: The visited struct.
            partner_id: Field-id of the partner struct (``-1`` for the root), or ``None``.
            missing_positions: One flag per field of ``struct``; ``True`` means the
                field has no partner.

        Returns:
            ``True`` when the struct itself has no partner, ``False`` otherwise.

        Raises:
            ValueError: If the partner exists but is not a struct.
        """
        if partner_id is None:
            return True

        fields = struct.fields
        partner_struct = self._find_field_type(partner_id)

        # Fail with a clear message, like list() and map() do, instead of an
        # AttributeError on field_by_name further down.
        if not partner_struct.is_struct:
            raise ValueError(f"Expected struct-type, got: {partner_struct}")

        for pos, missing in enumerate(missing_positions):
            if missing:
                self._add_column(partner_id, fields[pos])
            else:
                field = fields[pos]
                if nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive):
                    self._update_column(field, nested_field)

        return False

    def _add_column(self, parent_id: int, field: NestedField) -> None:
        """Report ``field`` as a new column underneath the partner field ``parent_id``."""
        if parent_name := self.new_schema.find_column_name(parent_id):
            path: Tuple[str, ...] = (parent_name, field.name)
        else:
            # No name was found for the parent id, so the path is just the field name.
            path = (field.name,)

        self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc)

    def _update_column(self, field: NestedField, existing_field: NestedField) -> None:
        """Report the differences between ``field`` and its partner ``existing_field``.

        Raises:
            ValueError: If no column name can be found for ``existing_field``.
        """
        full_name = self.new_schema.find_column_name(existing_field.field_id)

        if full_name is None:
            raise ValueError(f"Could not find field: {existing_field}")

        if field.optional and existing_field.required:
            self.update_schema.make_column_optional(full_name)

        if field.field_type.is_primitive and field.field_type != existing_field.field_type:
            self.update_schema.update_column(full_name, field_type=field.field_type)

        # Only touch the doc when the visited field carries one that differs.
        if field.doc is not None and field.doc != existing_field.doc:
            self.update_schema.update_column(full_name, doc=field.doc)

    def _find_field_type(self, field_id: int) -> IcebergType:
        """Return the type behind a partner id; ``-1`` resolves to the root struct."""
        if field_id == -1:
            return self.new_schema.as_struct()
        else:
            return self.new_schema.find_field(field_id).field_type

    def field(self, field: NestedField, field_partner: Optional[int], field_result: bool) -> bool:
        """Return ``True`` when the visited field has no partner."""
        return field_partner is None

    def list(self, list_type: ListType, list_partner: Optional[int], element_missing: bool) -> bool:
        """Report changes to the element of a list that has a partner.

        Returns:
            Always ``False``.

        Raises:
            ValueError: If the element has no partner while the list does, or if
                the partner is not a list.
        """
        if list_partner is None:
            return False

        if element_missing:
            raise ValueError("Error traversing schemas: element is missing, but list is present")

        partner_list_type = self._find_field_type(list_partner)
        if not isinstance(partner_list_type, ListType):
            raise ValueError(f"Expected list-type, got: {partner_list_type}")

        self._update_column(list_type.element_field, partner_list_type.element_field)

        return False

    def map(self, map_type: MapType, map_partner: Optional[int], key_missing: bool, value_missing: bool) -> bool:
        """Report changes to the key and value of a map that has a partner.

        Returns:
            Always ``False``.

        Raises:
            ValueError: If the key or value has no partner while the map does, or
                if the partner is not a map.
        """
        if map_partner is None:
            return False

        if key_missing:
            raise ValueError("Error traversing schemas: key is missing, but map is present")

        if value_missing:
            raise ValueError("Error traversing schemas: value is missing, but map is present")

        partner_map_type = self._find_field_type(map_partner)
        if not isinstance(partner_map_type, MapType):
            raise ValueError(f"Expected map-type, got: {partner_map_type}")

        self._update_column(map_type.key_field, partner_map_type.key_field)
        self._update_column(map_type.value_field, partner_map_type.value_field)

        return False

    def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[int]) -> bool:
        """Return ``True`` when the visited primitive has no partner."""
        return primitive_partner is None
|
|
||||||
|
|
||||||
class PartnerIdByNameAccessor(PartnerAccessor[int]):
    """Partner accessor that resolves partners to field-ids of ``partner_schema`` by name.

    The id ``-1`` is used for the root struct of ``partner_schema``; ``None``
    means that no partner exists.
    """

    partner_schema: Schema
    case_sensitive: bool

    def __init__(self, partner_schema: Schema, case_sensitive: bool) -> None:
        """Keep the schema to look partners up in and the name-matching mode."""
        self.partner_schema = partner_schema
        self.case_sensitive = case_sensitive

    def schema_partner(self, partner: Optional[int]) -> Optional[int]:
        """Return ``-1``, the id used for the root of the partner schema."""
        return -1

    def field_partner(self, partner_field_id: Optional[int], field_id: int, field_name: str) -> Optional[int]:
        """Return the id of the field called ``field_name`` inside the partner struct.

        Raises:
            ValueError: If ``partner_field_id`` refers to a field that is not a struct.
        """
        if partner_field_id is None:
            return None

        if partner_field_id == -1:
            parent_type = self.partner_schema.as_struct()
        else:
            parent_type = self.partner_schema.find_field(partner_field_id).field_type
            if not parent_type.is_struct:
                raise ValueError(f"Expected StructType: {parent_type}")

        match = parent_type.field_by_name(name=field_name, case_sensitive=self.case_sensitive)
        return match.field_id if match else None

    def list_element_partner(self, partner_list: Optional[int]) -> Optional[int]:
        """Return the element field-id of the partner list.

        Raises:
            ValueError: If the partner field is not a list.
        """
        if partner_list is None:
            return None

        list_field = self.partner_schema.find_field(partner_list)
        if not list_field:
            return None

        if not isinstance(list_field.field_type, ListType):
            raise ValueError(f"Expected ListType: {list_field}")
        return list_field.field_type.element_field.field_id

    def map_key_partner(self, partner_map: Optional[int]) -> Optional[int]:
        """Return the key field-id of the partner map.

        Raises:
            ValueError: If the partner field is not a map.
        """
        if partner_map is None:
            return None

        map_field = self.partner_schema.find_field(partner_map)
        if not map_field:
            return None

        if not isinstance(map_field.field_type, MapType):
            raise ValueError(f"Expected MapType: {map_field}")
        return map_field.field_type.key_field.field_id

    def map_value_partner(self, partner_map: Optional[int]) -> Optional[int]:
        """Return the value field-id of the partner map.

        Raises:
            ValueError: If the partner field is not a map.
        """
        if partner_map is None:
            return None

        map_field = self.partner_schema.find_field(partner_map)
        if not map_field:
            return None

        if not isinstance(map_field.field_type, MapType):
            raise ValueError(f"Expected MapType: {map_field}")
        return map_field.field_type.value_field.field_id
|
|
||||||
|
|
||||||
| def _add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> Tuple[NestedField, ...]: | ||||||
| adds = adds or [] | ||||||
| return fields + tuple(adds) | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.