Skip to content

Commit a077c73

Browse files
kdbhiggins and Kieran Higgins authored
Fix CommitTableRequest serialisation (#525)
* add failing test
* make requirements a discriminated union
* use discriminated type union
* add return type
* use type annotation
* have requirements inherit from ValidatableTableRequirement
* AddSortOrder filter by type
* lint

---------

Co-authored-by: Kieran Higgins <khiggins58@bloomberg.net>
1 parent 7f712fd commit a077c73

File tree

2 files changed

+90
-70
lines changed

2 files changed

+90
-70
lines changed

pyiceberg/table/__init__.py

Lines changed: 78 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
Union,
4444
)
4545

46-
from pydantic import Field, SerializeAsAny, field_validator
46+
from pydantic import Field, field_validator
4747
from sortedcontainers import SortedList
4848
from typing_extensions import Annotated
4949

@@ -383,77 +383,56 @@ def commit_transaction(self) -> Table:
383383
return self._table
384384

385385

386-
class TableUpdateAction(Enum):
387-
upgrade_format_version = "upgrade-format-version"
388-
add_schema = "add-schema"
389-
set_current_schema = "set-current-schema"
390-
add_spec = "add-spec"
391-
set_default_spec = "set-default-spec"
392-
add_sort_order = "add-sort-order"
393-
set_default_sort_order = "set-default-sort-order"
394-
add_snapshot = "add-snapshot"
395-
set_snapshot_ref = "set-snapshot-ref"
396-
remove_snapshots = "remove-snapshots"
397-
remove_snapshot_ref = "remove-snapshot-ref"
398-
set_location = "set-location"
399-
set_properties = "set-properties"
400-
remove_properties = "remove-properties"
401-
402-
403-
class TableUpdate(IcebergBaseModel):
404-
action: TableUpdateAction
405-
406-
407-
class UpgradeFormatVersionUpdate(TableUpdate):
408-
action: TableUpdateAction = TableUpdateAction.upgrade_format_version
386+
class UpgradeFormatVersionUpdate(IcebergBaseModel):
387+
action: Literal['upgrade-format-version'] = Field(default="upgrade-format-version")
409388
format_version: int = Field(alias="format-version")
410389

411390

412-
class AddSchemaUpdate(TableUpdate):
413-
action: TableUpdateAction = TableUpdateAction.add_schema
391+
class AddSchemaUpdate(IcebergBaseModel):
392+
action: Literal['add-schema'] = Field(default="add-schema")
414393
schema_: Schema = Field(alias="schema")
415394
# This field is required: https://github.com/apache/iceberg/pull/7445
416395
last_column_id: int = Field(alias="last-column-id")
417396

418397

419-
class SetCurrentSchemaUpdate(TableUpdate):
420-
action: TableUpdateAction = TableUpdateAction.set_current_schema
398+
class SetCurrentSchemaUpdate(IcebergBaseModel):
399+
action: Literal['set-current-schema'] = Field(default="set-current-schema")
421400
schema_id: int = Field(
422401
alias="schema-id", description="Schema ID to set as current, or -1 to set last added schema", default=-1
423402
)
424403

425404

426-
class AddPartitionSpecUpdate(TableUpdate):
427-
action: TableUpdateAction = TableUpdateAction.add_spec
405+
class AddPartitionSpecUpdate(IcebergBaseModel):
406+
action: Literal['add-spec'] = Field(default="add-spec")
428407
spec: PartitionSpec
429408

430409

431-
class SetDefaultSpecUpdate(TableUpdate):
432-
action: TableUpdateAction = TableUpdateAction.set_default_spec
410+
class SetDefaultSpecUpdate(IcebergBaseModel):
411+
action: Literal['set-default-spec'] = Field(default="set-default-spec")
433412
spec_id: int = Field(
434413
alias="spec-id", description="Partition spec ID to set as the default, or -1 to set last added spec", default=-1
435414
)
436415

437416

438-
class AddSortOrderUpdate(TableUpdate):
439-
action: TableUpdateAction = TableUpdateAction.add_sort_order
417+
class AddSortOrderUpdate(IcebergBaseModel):
418+
action: Literal['add-sort-order'] = Field(default="add-sort-order")
440419
sort_order: SortOrder = Field(alias="sort-order")
441420

442421

443-
class SetDefaultSortOrderUpdate(TableUpdate):
444-
action: TableUpdateAction = TableUpdateAction.set_default_sort_order
422+
class SetDefaultSortOrderUpdate(IcebergBaseModel):
423+
action: Literal['set-default-sort-order'] = Field(default="set-default-sort-order")
445424
sort_order_id: int = Field(
446425
alias="sort-order-id", description="Sort order ID to set as the default, or -1 to set last added sort order", default=-1
447426
)
448427

449428

450-
class AddSnapshotUpdate(TableUpdate):
451-
action: TableUpdateAction = TableUpdateAction.add_snapshot
429+
class AddSnapshotUpdate(IcebergBaseModel):
430+
action: Literal['add-snapshot'] = Field(default="add-snapshot")
452431
snapshot: Snapshot
453432

454433

455-
class SetSnapshotRefUpdate(TableUpdate):
456-
action: TableUpdateAction = TableUpdateAction.set_snapshot_ref
434+
class SetSnapshotRefUpdate(IcebergBaseModel):
435+
action: Literal['set-snapshot-ref'] = Field(default="set-snapshot-ref")
457436
ref_name: str = Field(alias="ref-name")
458437
type: Literal["tag", "branch"]
459438
snapshot_id: int = Field(alias="snapshot-id")
@@ -462,35 +441,56 @@ class SetSnapshotRefUpdate(TableUpdate):
462441
min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)]
463442

464443

465-
class RemoveSnapshotsUpdate(TableUpdate):
466-
action: TableUpdateAction = TableUpdateAction.remove_snapshots
444+
class RemoveSnapshotsUpdate(IcebergBaseModel):
445+
action: Literal['remove-snapshots'] = Field(default="remove-snapshots")
467446
snapshot_ids: List[int] = Field(alias="snapshot-ids")
468447

469448

470-
class RemoveSnapshotRefUpdate(TableUpdate):
471-
action: TableUpdateAction = TableUpdateAction.remove_snapshot_ref
449+
class RemoveSnapshotRefUpdate(IcebergBaseModel):
450+
action: Literal['remove-snapshot-ref'] = Field(default="remove-snapshot-ref")
472451
ref_name: str = Field(alias="ref-name")
473452

474453

475-
class SetLocationUpdate(TableUpdate):
476-
action: TableUpdateAction = TableUpdateAction.set_location
454+
class SetLocationUpdate(IcebergBaseModel):
455+
action: Literal['set-location'] = Field(default="set-location")
477456
location: str
478457

479458

480-
class SetPropertiesUpdate(TableUpdate):
481-
action: TableUpdateAction = TableUpdateAction.set_properties
459+
class SetPropertiesUpdate(IcebergBaseModel):
460+
action: Literal['set-properties'] = Field(default="set-properties")
482461
updates: Dict[str, str]
483462

484463
@field_validator('updates', mode='before')
485464
def transform_properties_dict_value_to_str(cls, properties: Properties) -> Dict[str, str]:
486465
return transform_dict_value_to_str(properties)
487466

488467

489-
class RemovePropertiesUpdate(TableUpdate):
490-
action: TableUpdateAction = TableUpdateAction.remove_properties
468+
class RemovePropertiesUpdate(IcebergBaseModel):
469+
action: Literal['remove-properties'] = Field(default="remove-properties")
491470
removals: List[str]
492471

493472

473+
TableUpdate = Annotated[
474+
Union[
475+
UpgradeFormatVersionUpdate,
476+
AddSchemaUpdate,
477+
SetCurrentSchemaUpdate,
478+
AddPartitionSpecUpdate,
479+
SetDefaultSpecUpdate,
480+
AddSortOrderUpdate,
481+
SetDefaultSortOrderUpdate,
482+
AddSnapshotUpdate,
483+
SetSnapshotRefUpdate,
484+
RemoveSnapshotsUpdate,
485+
RemoveSnapshotRefUpdate,
486+
SetLocationUpdate,
487+
SetPropertiesUpdate,
488+
RemovePropertiesUpdate,
489+
],
490+
Field(discriminator='action'),
491+
]
492+
493+
494494
class _TableMetadataUpdateContext:
495495
_updates: List[TableUpdate]
496496

@@ -502,21 +502,15 @@ def add_update(self, update: TableUpdate) -> None:
502502

503503
def is_added_snapshot(self, snapshot_id: int) -> bool:
504504
return any(
505-
update.snapshot.snapshot_id == snapshot_id
506-
for update in self._updates
507-
if update.action == TableUpdateAction.add_snapshot
505+
update.snapshot.snapshot_id == snapshot_id for update in self._updates if isinstance(update, AddSnapshotUpdate)
508506
)
509507

510508
def is_added_schema(self, schema_id: int) -> bool:
511-
return any(
512-
update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema
513-
)
509+
return any(update.schema_.schema_id == schema_id for update in self._updates if isinstance(update, AddSchemaUpdate))
514510

515511
def is_added_sort_order(self, sort_order_id: int) -> bool:
516512
return any(
517-
update.sort_order.order_id == sort_order_id
518-
for update in self._updates
519-
if update.action == TableUpdateAction.add_sort_order
513+
update.sort_order.order_id == sort_order_id for update in self._updates if isinstance(update, AddSortOrderUpdate)
520514
)
521515

522516

@@ -767,7 +761,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda
767761
return new_metadata.model_copy(deep=True)
768762

769763

770-
class TableRequirement(IcebergBaseModel):
764+
class ValidatableTableRequirement(IcebergBaseModel):
771765
type: str
772766

773767
@abstractmethod
@@ -783,7 +777,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
783777
...
784778

785779

786-
class AssertCreate(TableRequirement):
780+
class AssertCreate(ValidatableTableRequirement):
787781
"""The table must not already exist; used for create transactions."""
788782

789783
type: Literal["assert-create"] = Field(default="assert-create")
@@ -793,7 +787,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
793787
raise CommitFailedException("Table already exists")
794788

795789

796-
class AssertTableUUID(TableRequirement):
790+
class AssertTableUUID(ValidatableTableRequirement):
797791
"""The table UUID must match the requirement's `uuid`."""
798792

799793
type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid")
@@ -806,7 +800,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
806800
raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}")
807801

808802

809-
class AssertRefSnapshotId(TableRequirement):
803+
class AssertRefSnapshotId(ValidatableTableRequirement):
810804
"""The table branch or tag identified by the requirement's `ref` must reference the requirement's `snapshot-id`.
811805
812806
if `snapshot-id` is `null` or missing, the ref must not already exist.
@@ -831,7 +825,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
831825
raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}")
832826

833827

834-
class AssertLastAssignedFieldId(TableRequirement):
828+
class AssertLastAssignedFieldId(ValidatableTableRequirement):
835829
"""The table's last assigned column id must match the requirement's `last-assigned-field-id`."""
836830

837831
type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id")
@@ -846,7 +840,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
846840
)
847841

848842

849-
class AssertCurrentSchemaId(TableRequirement):
843+
class AssertCurrentSchemaId(ValidatableTableRequirement):
850844
"""The table's current schema id must match the requirement's `current-schema-id`."""
851845

852846
type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id")
@@ -861,7 +855,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
861855
)
862856

863857

864-
class AssertLastAssignedPartitionId(TableRequirement):
858+
class AssertLastAssignedPartitionId(ValidatableTableRequirement):
865859
"""The table's last assigned partition id must match the requirement's `last-assigned-partition-id`."""
866860

867861
type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id")
@@ -876,7 +870,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
876870
)
877871

878872

879-
class AssertDefaultSpecId(TableRequirement):
873+
class AssertDefaultSpecId(ValidatableTableRequirement):
880874
"""The table's default spec id must match the requirement's `default-spec-id`."""
881875

882876
type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id")
@@ -891,7 +885,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
891885
)
892886

893887

894-
class AssertDefaultSortOrderId(TableRequirement):
888+
class AssertDefaultSortOrderId(ValidatableTableRequirement):
895889
"""The table's default sort order id must match the requirement's `default-sort-order-id`."""
896890

897891
type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id")
@@ -906,6 +900,20 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
906900
)
907901

908902

903+
TableRequirement = Annotated[
904+
Union[
905+
AssertCreate,
906+
AssertTableUUID,
907+
AssertRefSnapshotId,
908+
AssertLastAssignedFieldId,
909+
AssertCurrentSchemaId,
910+
AssertLastAssignedPartitionId,
911+
AssertDefaultSpecId,
912+
AssertDefaultSortOrderId,
913+
],
914+
Field(discriminator='type'),
915+
]
916+
909917
UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]
910918

911919

@@ -927,8 +935,8 @@ class TableIdentifier(IcebergBaseModel):
927935

928936
class CommitTableRequest(IcebergBaseModel):
929937
identifier: TableIdentifier = Field()
930-
requirements: Tuple[SerializeAsAny[TableRequirement], ...] = Field(default_factory=tuple)
931-
updates: Tuple[SerializeAsAny[TableUpdate], ...] = Field(default_factory=tuple)
938+
requirements: Tuple[TableRequirement, ...] = Field(default_factory=tuple)
939+
updates: Tuple[TableUpdate, ...] = Field(default_factory=tuple)
932940

933941

934942
class CommitTableResponse(IcebergBaseModel):

tests/table/test_init.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,12 +53,14 @@
5353
AssertLastAssignedPartitionId,
5454
AssertRefSnapshotId,
5555
AssertTableUUID,
56+
CommitTableRequest,
5657
RemovePropertiesUpdate,
5758
SetDefaultSortOrderUpdate,
5859
SetPropertiesUpdate,
5960
SetSnapshotRefUpdate,
6061
StaticTable,
6162
Table,
63+
TableIdentifier,
6264
UpdateSchema,
6365
_apply_table_update,
6466
_check_schema,
@@ -1113,3 +1115,13 @@ def test_table_properties_raise_for_none_value(example_table_metadata_v2: Dict[s
11131115
with pytest.raises(ValidationError) as exc_info:
11141116
TableMetadataV2(**example_table_metadata_v2)
11151117
assert "None type is not a supported value in properties: property_name" in str(exc_info.value)
1118+
1119+
1120+
def test_serialize_commit_table_request() -> None:
1121+
request = CommitTableRequest(
1122+
requirements=(AssertTableUUID(uuid='4bfd18a3-74c6-478e-98b1-71c4c32f4163'),),
1123+
identifier=TableIdentifier(namespace=['a'], name='b'),
1124+
)
1125+
1126+
deserialized_request = CommitTableRequest.model_validate_json(request.model_dump_json())
1127+
assert request == deserialized_request

0 commit comments

Comments
 (0)