4343 Union ,
4444)
4545
46- from pydantic import Field , SerializeAsAny , field_validator
46+ from pydantic import Field , field_validator
4747from sortedcontainers import SortedList
4848from typing_extensions import Annotated
4949
@@ -383,77 +383,56 @@ def commit_transaction(self) -> Table:
383383 return self ._table
384384
385385
386- class TableUpdateAction (Enum ):
387- upgrade_format_version = "upgrade-format-version"
388- add_schema = "add-schema"
389- set_current_schema = "set-current-schema"
390- add_spec = "add-spec"
391- set_default_spec = "set-default-spec"
392- add_sort_order = "add-sort-order"
393- set_default_sort_order = "set-default-sort-order"
394- add_snapshot = "add-snapshot"
395- set_snapshot_ref = "set-snapshot-ref"
396- remove_snapshots = "remove-snapshots"
397- remove_snapshot_ref = "remove-snapshot-ref"
398- set_location = "set-location"
399- set_properties = "set-properties"
400- remove_properties = "remove-properties"
401-
402-
403- class TableUpdate (IcebergBaseModel ):
404- action : TableUpdateAction
405-
406-
407- class UpgradeFormatVersionUpdate (TableUpdate ):
408- action : TableUpdateAction = TableUpdateAction .upgrade_format_version
386+ class UpgradeFormatVersionUpdate (IcebergBaseModel ):
387+ action : Literal ['upgrade-format-version' ] = Field (default = "upgrade-format-version" )
409388 format_version : int = Field (alias = "format-version" )
410389
411390
412- class AddSchemaUpdate (TableUpdate ):
413- action : TableUpdateAction = TableUpdateAction . add_schema
391+ class AddSchemaUpdate (IcebergBaseModel ):
392+ action : Literal [ 'add-schema' ] = Field ( default = "add-schema" )
414393 schema_ : Schema = Field (alias = "schema" )
415394 # This field is required: https://github.com/apache/iceberg/pull/7445
416395 last_column_id : int = Field (alias = "last-column-id" )
417396
418397
419- class SetCurrentSchemaUpdate (TableUpdate ):
420- action : TableUpdateAction = TableUpdateAction . set_current_schema
398+ class SetCurrentSchemaUpdate (IcebergBaseModel ):
399+ action : Literal [ 'set-current-schema' ] = Field ( default = "set-current-schema" )
421400 schema_id : int = Field (
422401 alias = "schema-id" , description = "Schema ID to set as current, or -1 to set last added schema" , default = - 1
423402 )
424403
425404
426- class AddPartitionSpecUpdate (TableUpdate ):
427- action : TableUpdateAction = TableUpdateAction . add_spec
405+ class AddPartitionSpecUpdate (IcebergBaseModel ):
406+ action : Literal [ 'add-spec' ] = Field ( default = "add-spec" )
428407 spec : PartitionSpec
429408
430409
431- class SetDefaultSpecUpdate (TableUpdate ):
432- action : TableUpdateAction = TableUpdateAction . set_default_spec
410+ class SetDefaultSpecUpdate (IcebergBaseModel ):
411+ action : Literal [ 'set-default-spec' ] = Field ( default = "set-default-spec" )
433412 spec_id : int = Field (
434413 alias = "spec-id" , description = "Partition spec ID to set as the default, or -1 to set last added spec" , default = - 1
435414 )
436415
437416
438- class AddSortOrderUpdate (TableUpdate ):
439- action : TableUpdateAction = TableUpdateAction . add_sort_order
417+ class AddSortOrderUpdate (IcebergBaseModel ):
418+ action : Literal [ 'add-sort-order' ] = Field ( default = "add-sort-order" )
440419 sort_order : SortOrder = Field (alias = "sort-order" )
441420
442421
443- class SetDefaultSortOrderUpdate (TableUpdate ):
444- action : TableUpdateAction = TableUpdateAction . set_default_sort_order
422+ class SetDefaultSortOrderUpdate (IcebergBaseModel ):
423+ action : Literal [ 'set-default-sort-order' ] = Field ( default = "set-default-sort-order" )
445424 sort_order_id : int = Field (
446425 alias = "sort-order-id" , description = "Sort order ID to set as the default, or -1 to set last added sort order" , default = - 1
447426 )
448427
449428
450- class AddSnapshotUpdate (TableUpdate ):
451- action : TableUpdateAction = TableUpdateAction . add_snapshot
429+ class AddSnapshotUpdate (IcebergBaseModel ):
430+ action : Literal [ 'add-snapshot' ] = Field ( default = "add-snapshot" )
452431 snapshot : Snapshot
453432
454433
455- class SetSnapshotRefUpdate (TableUpdate ):
456- action : TableUpdateAction = TableUpdateAction . set_snapshot_ref
434+ class SetSnapshotRefUpdate (IcebergBaseModel ):
435+ action : Literal [ 'set-snapshot-ref' ] = Field ( default = "set-snapshot-ref" )
457436 ref_name : str = Field (alias = "ref-name" )
458437 type : Literal ["tag" , "branch" ]
459438 snapshot_id : int = Field (alias = "snapshot-id" )
@@ -462,35 +441,56 @@ class SetSnapshotRefUpdate(TableUpdate):
462441 min_snapshots_to_keep : Annotated [Optional [int ], Field (alias = "min-snapshots-to-keep" , default = None )]
463442
464443
465- class RemoveSnapshotsUpdate (TableUpdate ):
466- action : TableUpdateAction = TableUpdateAction . remove_snapshots
444+ class RemoveSnapshotsUpdate (IcebergBaseModel ):
445+ action : Literal [ 'remove-snapshots' ] = Field ( default = "remove-snapshots" )
467446 snapshot_ids : List [int ] = Field (alias = "snapshot-ids" )
468447
469448
470- class RemoveSnapshotRefUpdate (TableUpdate ):
471- action : TableUpdateAction = TableUpdateAction . remove_snapshot_ref
449+ class RemoveSnapshotRefUpdate (IcebergBaseModel ):
450+ action : Literal [ 'remove-snapshot-ref' ] = Field ( default = "remove-snapshot-ref" )
472451 ref_name : str = Field (alias = "ref-name" )
473452
474453
475- class SetLocationUpdate (TableUpdate ):
476- action : TableUpdateAction = TableUpdateAction . set_location
454+ class SetLocationUpdate (IcebergBaseModel ):
455+ action : Literal [ 'set-location' ] = Field ( default = "set-location" )
477456 location : str
478457
479458
480- class SetPropertiesUpdate (TableUpdate ):
481- action : TableUpdateAction = TableUpdateAction . set_properties
459+ class SetPropertiesUpdate (IcebergBaseModel ):
460+ action : Literal [ 'set-properties' ] = Field ( default = "set-properties" )
482461 updates : Dict [str , str ]
483462
484463 @field_validator ('updates' , mode = 'before' )
485464 def transform_properties_dict_value_to_str (cls , properties : Properties ) -> Dict [str , str ]:
486465 return transform_dict_value_to_str (properties )
487466
488467
489- class RemovePropertiesUpdate (TableUpdate ):
490- action : TableUpdateAction = TableUpdateAction . remove_properties
468+ class RemovePropertiesUpdate (IcebergBaseModel ):
469+ action : Literal [ 'remove-properties' ] = Field ( default = "remove-properties" )
491470 removals : List [str ]
492471
493472
473+ TableUpdate = Annotated [
474+ Union [
475+ UpgradeFormatVersionUpdate ,
476+ AddSchemaUpdate ,
477+ SetCurrentSchemaUpdate ,
478+ AddPartitionSpecUpdate ,
479+ SetDefaultSpecUpdate ,
480+ AddSortOrderUpdate ,
481+ SetDefaultSortOrderUpdate ,
482+ AddSnapshotUpdate ,
483+ SetSnapshotRefUpdate ,
484+ RemoveSnapshotsUpdate ,
485+ RemoveSnapshotRefUpdate ,
486+ SetLocationUpdate ,
487+ SetPropertiesUpdate ,
488+ RemovePropertiesUpdate ,
489+ ],
490+ Field (discriminator = 'action' ),
491+ ]
492+
493+
494494class _TableMetadataUpdateContext :
495495 _updates : List [TableUpdate ]
496496
@@ -502,21 +502,15 @@ def add_update(self, update: TableUpdate) -> None:
502502
503503 def is_added_snapshot (self , snapshot_id : int ) -> bool :
504504 return any (
505- update .snapshot .snapshot_id == snapshot_id
506- for update in self ._updates
507- if update .action == TableUpdateAction .add_snapshot
505+ update .snapshot .snapshot_id == snapshot_id for update in self ._updates if isinstance (update , AddSnapshotUpdate )
508506 )
509507
510508 def is_added_schema (self , schema_id : int ) -> bool :
511- return any (
512- update .schema_ .schema_id == schema_id for update in self ._updates if update .action == TableUpdateAction .add_schema
513- )
509+ return any (update .schema_ .schema_id == schema_id for update in self ._updates if isinstance (update , AddSchemaUpdate ))
514510
515511 def is_added_sort_order (self , sort_order_id : int ) -> bool :
516512 return any (
517- update .sort_order .order_id == sort_order_id
518- for update in self ._updates
519- if update .action == TableUpdateAction .add_sort_order
513+ update .sort_order .order_id == sort_order_id for update in self ._updates if isinstance (update , AddSortOrderUpdate )
520514 )
521515
522516
@@ -767,7 +761,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda
767761 return new_metadata .model_copy (deep = True )
768762
769763
770- class TableRequirement (IcebergBaseModel ):
764+ class ValidatableTableRequirement (IcebergBaseModel ):
771765 type : str
772766
773767 @abstractmethod
@@ -783,7 +777,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
783777 ...
784778
785779
786- class AssertCreate (TableRequirement ):
780+ class AssertCreate (ValidatableTableRequirement ):
787781 """The table must not already exist; used for create transactions."""
788782
789783 type : Literal ["assert-create" ] = Field (default = "assert-create" )
@@ -793,7 +787,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
793787 raise CommitFailedException ("Table already exists" )
794788
795789
796- class AssertTableUUID (TableRequirement ):
790+ class AssertTableUUID (ValidatableTableRequirement ):
797791 """The table UUID must match the requirement's `uuid`."""
798792
799793 type : Literal ["assert-table-uuid" ] = Field (default = "assert-table-uuid" )
@@ -806,7 +800,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
806800 raise CommitFailedException (f"Table UUID does not match: { self .uuid } != { base_metadata .table_uuid } " )
807801
808802
809- class AssertRefSnapshotId (TableRequirement ):
803+ class AssertRefSnapshotId (ValidatableTableRequirement ):
810804 """The table branch or tag identified by the requirement's `ref` must reference the requirement's `snapshot-id`.
811805
812806 if `snapshot-id` is `null` or missing, the ref must not already exist.
@@ -831,7 +825,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
831825 raise CommitFailedException (f"Requirement failed: branch or tag { self .ref } is missing, expected { self .snapshot_id } " )
832826
833827
834- class AssertLastAssignedFieldId (TableRequirement ):
828+ class AssertLastAssignedFieldId (ValidatableTableRequirement ):
835829 """The table's last assigned column id must match the requirement's `last-assigned-field-id`."""
836830
837831 type : Literal ["assert-last-assigned-field-id" ] = Field (default = "assert-last-assigned-field-id" )
@@ -846,7 +840,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
846840 )
847841
848842
849- class AssertCurrentSchemaId (TableRequirement ):
843+ class AssertCurrentSchemaId (ValidatableTableRequirement ):
850844 """The table's current schema id must match the requirement's `current-schema-id`."""
851845
852846 type : Literal ["assert-current-schema-id" ] = Field (default = "assert-current-schema-id" )
@@ -861,7 +855,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
861855 )
862856
863857
864- class AssertLastAssignedPartitionId (TableRequirement ):
858+ class AssertLastAssignedPartitionId (ValidatableTableRequirement ):
865859 """The table's last assigned partition id must match the requirement's `last-assigned-partition-id`."""
866860
867861 type : Literal ["assert-last-assigned-partition-id" ] = Field (default = "assert-last-assigned-partition-id" )
@@ -876,7 +870,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
876870 )
877871
878872
879- class AssertDefaultSpecId (TableRequirement ):
873+ class AssertDefaultSpecId (ValidatableTableRequirement ):
880874 """The table's default spec id must match the requirement's `default-spec-id`."""
881875
882876 type : Literal ["assert-default-spec-id" ] = Field (default = "assert-default-spec-id" )
@@ -891,7 +885,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
891885 )
892886
893887
894- class AssertDefaultSortOrderId (TableRequirement ):
888+ class AssertDefaultSortOrderId (ValidatableTableRequirement ):
895889 """The table's default sort order id must match the requirement's `default-sort-order-id`."""
896890
897891 type : Literal ["assert-default-sort-order-id" ] = Field (default = "assert-default-sort-order-id" )
@@ -906,6 +900,20 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
906900 )
907901
908902
903+ TableRequirement = Annotated [
904+ Union [
905+ AssertCreate ,
906+ AssertTableUUID ,
907+ AssertRefSnapshotId ,
908+ AssertLastAssignedFieldId ,
909+ AssertCurrentSchemaId ,
910+ AssertLastAssignedPartitionId ,
911+ AssertDefaultSpecId ,
912+ AssertDefaultSortOrderId ,
913+ ],
914+ Field (discriminator = 'type' ),
915+ ]
916+
909917UpdatesAndRequirements = Tuple [Tuple [TableUpdate , ...], Tuple [TableRequirement , ...]]
910918
911919
@@ -927,8 +935,8 @@ class TableIdentifier(IcebergBaseModel):
927935
928936class CommitTableRequest (IcebergBaseModel ):
929937 identifier : TableIdentifier = Field ()
930- requirements : Tuple [SerializeAsAny [ TableRequirement ] , ...] = Field (default_factory = tuple )
931- updates : Tuple [SerializeAsAny [ TableUpdate ] , ...] = Field (default_factory = tuple )
938+ requirements : Tuple [TableRequirement , ...] = Field (default_factory = tuple )
939+ updates : Tuple [TableUpdate , ...] = Field (default_factory = tuple )
932940
933941
934942class CommitTableResponse (IcebergBaseModel ):
0 commit comments