Skip to content

Commit ea48360

Browse files
kdbhiggins and Kieran Higgins
authored and committed
Fix CommitTableRequest serialisation (#525)
* add failing test
* make requirements a discriminated union
* use discriminated type union
* add return type
* use type annotation
* have requirements inherit from ValidatableTableRequirement
* AddSortOrder filter by type
* lint

---------

Co-authored-by: Kieran Higgins <khiggins58@bloomberg.net>
1 parent cc44926 commit ea48360

File tree

1 file changed

+83
-65
lines changed

1 file changed

+83
-65
lines changed

pyiceberg/table/__init__.py

Lines changed: 83 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@
4141
Union,
4242
)
4343

44-
from pydantic import Field, SerializeAsAny
44+
from pydantic import Field
4545
from sortedcontainers import SortedList
4646
from typing_extensions import Annotated
4747

@@ -368,77 +368,56 @@ def commit_transaction(self) -> Table:
368368
return self._table
369369

370370

371-
class TableUpdateAction(Enum):
372-
upgrade_format_version = "upgrade-format-version"
373-
add_schema = "add-schema"
374-
set_current_schema = "set-current-schema"
375-
add_spec = "add-spec"
376-
set_default_spec = "set-default-spec"
377-
add_sort_order = "add-sort-order"
378-
set_default_sort_order = "set-default-sort-order"
379-
add_snapshot = "add-snapshot"
380-
set_snapshot_ref = "set-snapshot-ref"
381-
remove_snapshots = "remove-snapshots"
382-
remove_snapshot_ref = "remove-snapshot-ref"
383-
set_location = "set-location"
384-
set_properties = "set-properties"
385-
remove_properties = "remove-properties"
386-
387-
388-
class TableUpdate(IcebergBaseModel):
389-
action: TableUpdateAction
390-
391-
392-
class UpgradeFormatVersionUpdate(TableUpdate):
393-
action: TableUpdateAction = TableUpdateAction.upgrade_format_version
371+
class UpgradeFormatVersionUpdate(IcebergBaseModel):
372+
action: Literal['upgrade-format-version'] = Field(default="upgrade-format-version")
394373
format_version: int = Field(alias="format-version")
395374

396375

397-
class AddSchemaUpdate(TableUpdate):
398-
action: TableUpdateAction = TableUpdateAction.add_schema
376+
class AddSchemaUpdate(IcebergBaseModel):
377+
action: Literal['add-schema'] = Field(default="add-schema")
399378
schema_: Schema = Field(alias="schema")
400379
# This field is required: https://github.com/apache/iceberg/pull/7445
401380
last_column_id: int = Field(alias="last-column-id")
402381

403382

404-
class SetCurrentSchemaUpdate(TableUpdate):
405-
action: TableUpdateAction = TableUpdateAction.set_current_schema
383+
class SetCurrentSchemaUpdate(IcebergBaseModel):
384+
action: Literal['set-current-schema'] = Field(default="set-current-schema")
406385
schema_id: int = Field(
407386
alias="schema-id", description="Schema ID to set as current, or -1 to set last added schema", default=-1
408387
)
409388

410389

411-
class AddPartitionSpecUpdate(TableUpdate):
412-
action: TableUpdateAction = TableUpdateAction.add_spec
390+
class AddPartitionSpecUpdate(IcebergBaseModel):
391+
action: Literal['add-spec'] = Field(default="add-spec")
413392
spec: PartitionSpec
414393

415394

416-
class SetDefaultSpecUpdate(TableUpdate):
417-
action: TableUpdateAction = TableUpdateAction.set_default_spec
395+
class SetDefaultSpecUpdate(IcebergBaseModel):
396+
action: Literal['set-default-spec'] = Field(default="set-default-spec")
418397
spec_id: int = Field(
419398
alias="spec-id", description="Partition spec ID to set as the default, or -1 to set last added spec", default=-1
420399
)
421400

422401

423-
class AddSortOrderUpdate(TableUpdate):
424-
action: TableUpdateAction = TableUpdateAction.add_sort_order
402+
class AddSortOrderUpdate(IcebergBaseModel):
403+
action: Literal['add-sort-order'] = Field(default="add-sort-order")
425404
sort_order: SortOrder = Field(alias="sort-order")
426405

427406

428-
class SetDefaultSortOrderUpdate(TableUpdate):
429-
action: TableUpdateAction = TableUpdateAction.set_default_sort_order
407+
class SetDefaultSortOrderUpdate(IcebergBaseModel):
408+
action: Literal['set-default-sort-order'] = Field(default="set-default-sort-order")
430409
sort_order_id: int = Field(
431410
alias="sort-order-id", description="Sort order ID to set as the default, or -1 to set last added sort order", default=-1
432411
)
433412

434413

435-
class AddSnapshotUpdate(TableUpdate):
436-
action: TableUpdateAction = TableUpdateAction.add_snapshot
414+
class AddSnapshotUpdate(IcebergBaseModel):
415+
action: Literal['add-snapshot'] = Field(default="add-snapshot")
437416
snapshot: Snapshot
438417

439418

440-
class SetSnapshotRefUpdate(TableUpdate):
441-
action: TableUpdateAction = TableUpdateAction.set_snapshot_ref
419+
class SetSnapshotRefUpdate(IcebergBaseModel):
420+
action: Literal['set-snapshot-ref'] = Field(default="set-snapshot-ref")
442421
ref_name: str = Field(alias="ref-name")
443422
type: Literal["tag", "branch"]
444423
snapshot_id: int = Field(alias="snapshot-id")
@@ -447,31 +426,52 @@ class SetSnapshotRefUpdate(TableUpdate):
447426
min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)]
448427

449428

450-
class RemoveSnapshotsUpdate(TableUpdate):
451-
action: TableUpdateAction = TableUpdateAction.remove_snapshots
429+
class RemoveSnapshotsUpdate(IcebergBaseModel):
430+
action: Literal['remove-snapshots'] = Field(default="remove-snapshots")
452431
snapshot_ids: List[int] = Field(alias="snapshot-ids")
453432

454433

455-
class RemoveSnapshotRefUpdate(TableUpdate):
456-
action: TableUpdateAction = TableUpdateAction.remove_snapshot_ref
434+
class RemoveSnapshotRefUpdate(IcebergBaseModel):
435+
action: Literal['remove-snapshot-ref'] = Field(default="remove-snapshot-ref")
457436
ref_name: str = Field(alias="ref-name")
458437

459438

460-
class SetLocationUpdate(TableUpdate):
461-
action: TableUpdateAction = TableUpdateAction.set_location
439+
class SetLocationUpdate(IcebergBaseModel):
440+
action: Literal['set-location'] = Field(default="set-location")
462441
location: str
463442

464443

465-
class SetPropertiesUpdate(TableUpdate):
466-
action: TableUpdateAction = TableUpdateAction.set_properties
444+
class SetPropertiesUpdate(IcebergBaseModel):
445+
action: Literal['set-properties'] = Field(default="set-properties")
467446
updates: Dict[str, str]
468447

469448

470-
class RemovePropertiesUpdate(TableUpdate):
471-
action: TableUpdateAction = TableUpdateAction.remove_properties
449+
class RemovePropertiesUpdate(IcebergBaseModel):
450+
action: Literal['remove-properties'] = Field(default="remove-properties")
472451
removals: List[str]
473452

474453

454+
TableUpdate = Annotated[
455+
Union[
456+
UpgradeFormatVersionUpdate,
457+
AddSchemaUpdate,
458+
SetCurrentSchemaUpdate,
459+
AddPartitionSpecUpdate,
460+
SetDefaultSpecUpdate,
461+
AddSortOrderUpdate,
462+
SetDefaultSortOrderUpdate,
463+
AddSnapshotUpdate,
464+
SetSnapshotRefUpdate,
465+
RemoveSnapshotsUpdate,
466+
RemoveSnapshotRefUpdate,
467+
SetLocationUpdate,
468+
SetPropertiesUpdate,
469+
RemovePropertiesUpdate,
470+
],
471+
Field(discriminator='action'),
472+
]
473+
474+
475475
class _TableMetadataUpdateContext:
476476
_updates: List[TableUpdate]
477477

@@ -483,14 +483,15 @@ def add_update(self, update: TableUpdate) -> None:
483483

484484
def is_added_snapshot(self, snapshot_id: int) -> bool:
485485
return any(
486-
update.snapshot.snapshot_id == snapshot_id
487-
for update in self._updates
488-
if update.action == TableUpdateAction.add_snapshot
486+
update.snapshot.snapshot_id == snapshot_id for update in self._updates if isinstance(update, AddSnapshotUpdate)
489487
)
490488

491489
def is_added_schema(self, schema_id: int) -> bool:
490+
return any(update.schema_.schema_id == schema_id for update in self._updates if isinstance(update, AddSchemaUpdate))
491+
492+
def is_added_sort_order(self, sort_order_id: int) -> bool:
492493
return any(
493-
update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema
494+
update.sort_order.order_id == sort_order_id for update in self._updates if isinstance(update, AddSortOrderUpdate)
494495
)
495496

496497

@@ -674,7 +675,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda
674675
return new_metadata.model_copy(deep=True)
675676

676677

677-
class TableRequirement(IcebergBaseModel):
678+
class ValidatableTableRequirement(IcebergBaseModel):
678679
type: str
679680

680681
@abstractmethod
@@ -690,7 +691,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
690691
...
691692

692693

693-
class AssertCreate(TableRequirement):
694+
class AssertCreate(ValidatableTableRequirement):
694695
"""The table must not already exist; used for create transactions."""
695696

696697
type: Literal["assert-create"] = Field(default="assert-create")
@@ -700,7 +701,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
700701
raise CommitFailedException("Table already exists")
701702

702703

703-
class AssertTableUUID(TableRequirement):
704+
class AssertTableUUID(ValidatableTableRequirement):
704705
"""The table UUID must match the requirement's `uuid`."""
705706

706707
type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid")
@@ -713,7 +714,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
713714
raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}")
714715

715716

716-
class AssertRefSnapshotId(TableRequirement):
717+
class AssertRefSnapshotId(ValidatableTableRequirement):
717718
"""The table branch or tag identified by the requirement's `ref` must reference the requirement's `snapshot-id`.
718719
719720
if `snapshot-id` is `null` or missing, the ref must not already exist.
@@ -738,7 +739,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
738739
raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}")
739740

740741

741-
class AssertLastAssignedFieldId(TableRequirement):
742+
class AssertLastAssignedFieldId(ValidatableTableRequirement):
742743
"""The table's last assigned column id must match the requirement's `last-assigned-field-id`."""
743744

744745
type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id")
@@ -753,7 +754,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
753754
)
754755

755756

756-
class AssertCurrentSchemaId(TableRequirement):
757+
class AssertCurrentSchemaId(ValidatableTableRequirement):
757758
"""The table's current schema id must match the requirement's `current-schema-id`."""
758759

759760
type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id")
@@ -768,7 +769,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
768769
)
769770

770771

771-
class AssertLastAssignedPartitionId(TableRequirement):
772+
class AssertLastAssignedPartitionId(ValidatableTableRequirement):
772773
"""The table's last assigned partition id must match the requirement's `last-assigned-partition-id`."""
773774

774775
type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id")
@@ -783,7 +784,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
783784
)
784785

785786

786-
class AssertDefaultSpecId(TableRequirement):
787+
class AssertDefaultSpecId(ValidatableTableRequirement):
787788
"""The table's default spec id must match the requirement's `default-spec-id`."""
788789

789790
type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id")
@@ -798,7 +799,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
798799
)
799800

800801

801-
class AssertDefaultSortOrderId(TableRequirement):
802+
class AssertDefaultSortOrderId(ValidatableTableRequirement):
802803
"""The table's default sort order id must match the requirement's `default-sort-order-id`."""
803804

804805
type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id")
@@ -813,6 +814,23 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
813814
)
814815

815816

817+
TableRequirement = Annotated[
818+
Union[
819+
AssertCreate,
820+
AssertTableUUID,
821+
AssertRefSnapshotId,
822+
AssertLastAssignedFieldId,
823+
AssertCurrentSchemaId,
824+
AssertLastAssignedPartitionId,
825+
AssertDefaultSpecId,
826+
AssertDefaultSortOrderId,
827+
],
828+
Field(discriminator='type'),
829+
]
830+
831+
UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]
832+
833+
816834
class Namespace(IcebergRootModel[List[str]]):
817835
"""Reference to one or more levels of a namespace."""
818836

@@ -831,8 +849,8 @@ class TableIdentifier(IcebergBaseModel):
831849

832850
class CommitTableRequest(IcebergBaseModel):
833851
identifier: TableIdentifier = Field()
834-
requirements: Tuple[SerializeAsAny[TableRequirement], ...] = Field(default_factory=tuple)
835-
updates: Tuple[SerializeAsAny[TableUpdate], ...] = Field(default_factory=tuple)
852+
requirements: Tuple[TableRequirement, ...] = Field(default_factory=tuple)
853+
updates: Tuple[TableUpdate, ...] = Field(default_factory=tuple)
836854

837855

838856
class CommitTableResponse(IcebergBaseModel):

0 commit comments

Comments
 (0)