Skip to content

Commit a368bd9

Browse files
authored
Use Pydantic's model_copy for model modification (#182)
* Implement table metadata updater first draft * fix updater error and add tests * implement apply_metadata_update which is simpler * remove old implementation * re-organize method place * fix nit * fix test * add another test * clear TODO * add a combined test * Fix merge conflict * remove table requirement validation for PR simplification * make context private and solve elif issue * remove private field access * push snapshot ref validation to its builder using pydantic * fix comment * remove unnecessary code for AddSchemaUpdate update * replace if with elif * switch to model_copy() * enhance the set current schema update implementation and some other changes * make apply_table_update private * fix lint after merge * add validation * add test for isolation of illegal updates * fix nit * remove unnecessary flag * change to model_copy(deep=True)
1 parent 34b18e4 commit a368bd9

File tree

2 files changed

+78
-30
lines changed

2 files changed

+78
-30
lines changed

pyiceberg/table/__init__.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -417,12 +417,13 @@ def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMeta
417417
if update.last_column_id < base_metadata.last_column_id:
418418
raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}")
419419

420-
updated_metadata_data = copy(base_metadata.model_dump())
421-
updated_metadata_data["last-column-id"] = update.last_column_id
422-
updated_metadata_data["schemas"].append(update.schema_.model_dump())
423-
424420
context.add_update(update)
425-
return TableMetadataUtil.parse_obj(updated_metadata_data)
421+
return base_metadata.model_copy(
422+
update={
423+
"last_column_id": update.last_column_id,
424+
"schemas": base_metadata.schemas + [update.schema_],
425+
}
426+
)
426427

427428

428429
@_apply_table_update.register(SetCurrentSchemaUpdate)
@@ -441,11 +442,8 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta
441442
if schema is None:
442443
raise ValueError(f"Schema with id {new_schema_id} does not exist")
443444

444-
updated_metadata_data = copy(base_metadata.model_dump())
445-
updated_metadata_data["current-schema-id"] = new_schema_id
446-
447445
context.add_update(update)
448-
return TableMetadataUtil.parse_obj(updated_metadata_data)
446+
return base_metadata.model_copy(update={"current_schema_id": new_schema_id})
449447

450448

451449
@_apply_table_update.register(AddSnapshotUpdate)
@@ -469,12 +467,14 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMe
469467
f"older than last sequence number {base_metadata.last_sequence_number}"
470468
)
471469

472-
updated_metadata_data = copy(base_metadata.model_dump())
473-
updated_metadata_data["last-updated-ms"] = update.snapshot.timestamp_ms
474-
updated_metadata_data["last-sequence-number"] = update.snapshot.sequence_number
475-
updated_metadata_data["snapshots"].append(update.snapshot.model_dump())
476470
context.add_update(update)
477-
return TableMetadataUtil.parse_obj(updated_metadata_data)
471+
return base_metadata.model_copy(
472+
update={
473+
"last_updated_ms": update.snapshot.timestamp_ms,
474+
"last_sequence_number": update.snapshot.sequence_number,
475+
"snapshots": base_metadata.snapshots + [update.snapshot],
476+
}
477+
)
478478

479479

480480
@_apply_table_update.register(SetSnapshotRefUpdate)
@@ -493,28 +493,27 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl
493493

494494
snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id)
495495
if snapshot is None:
496-
raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}")
496+
raise ValueError(f"Cannot set {update.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}")
497497

498-
update_metadata_data = copy(base_metadata.model_dump())
499-
update_last_updated_ms = True
498+
metadata_updates: Dict[str, Any] = {}
500499
if context.is_added_snapshot(snapshot_ref.snapshot_id):
501-
update_metadata_data["last-updated-ms"] = snapshot.timestamp_ms
502-
update_last_updated_ms = False
500+
metadata_updates["last_updated_ms"] = snapshot.timestamp_ms
503501

504502
if update.ref_name == MAIN_BRANCH:
505-
update_metadata_data["current-snapshot-id"] = snapshot_ref.snapshot_id
506-
if update_last_updated_ms:
507-
update_metadata_data["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone())
508-
update_metadata_data["snapshot-log"].append(
503+
metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
504+
if "last_updated_ms" not in metadata_updates:
505+
metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.datetime.now().astimezone())
506+
507+
metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
509508
SnapshotLogEntry(
510509
snapshot_id=snapshot_ref.snapshot_id,
511-
timestamp_ms=update_metadata_data["last-updated-ms"],
512-
).model_dump()
513-
)
510+
timestamp_ms=metadata_updates["last_updated_ms"],
511+
)
512+
]
514513

515-
update_metadata_data["refs"][update.ref_name] = snapshot_ref.model_dump()
514+
metadata_updates["refs"] = {**base_metadata.refs, update.ref_name: snapshot_ref}
516515
context.add_update(update)
517-
return TableMetadataUtil.parse_obj(update_metadata_data)
516+
return base_metadata.model_copy(update=metadata_updates)
518517

519518

520519
def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata:
@@ -533,7 +532,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda
533532
for update in updates:
534533
new_metadata = _apply_table_update(update, new_metadata, context)
535534

536-
return new_metadata
535+
return new_metadata.model_copy(deep=True)
537536

538537

539538
class TableRequirement(IcebergBaseModel):

tests/table/test_init.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint:disable=redefined-outer-name
18+
from copy import copy
1819
from typing import Dict
1920

2021
import pytest
@@ -50,7 +51,7 @@
5051
_TableMetadataUpdateContext,
5152
update_table_metadata,
5253
)
53-
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER
54+
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2
5455
from pyiceberg.table.snapshots import (
5556
Operation,
5657
Snapshot,
@@ -640,9 +641,12 @@ def test_update_metadata_with_multiple_updates(table_v1: Table) -> None:
640641
)
641642

642643
new_metadata = update_table_metadata(base_metadata, test_updates)
644+
# rebuild the metadata to trigger validation
645+
new_metadata = TableMetadataUtil.parse_obj(copy(new_metadata.model_dump()))
643646

644647
# UpgradeFormatVersionUpdate
645648
assert new_metadata.format_version == 2
649+
assert isinstance(new_metadata, TableMetadataV2)
646650

647651
# UpdateSchema
648652
assert len(new_metadata.schemas) == 2
@@ -669,6 +673,51 @@ def test_update_metadata_with_multiple_updates(table_v1: Table) -> None:
669673
)
670674

671675

676+
def test_metadata_isolation_from_illegal_updates(table_v1: Table) -> None:
677+
base_metadata = table_v1.metadata
678+
base_metadata_backup = base_metadata.model_copy(deep=True)
679+
680+
# Apply legal updates on the table metadata
681+
transaction = table_v1.transaction()
682+
schema_update_1 = transaction.update_schema()
683+
schema_update_1.add_column(path="b", field_type=IntegerType())
684+
schema_update_1.commit()
685+
test_updates = transaction._updates # pylint: disable=W0212
686+
new_snapshot = Snapshot(
687+
snapshot_id=25,
688+
parent_snapshot_id=19,
689+
sequence_number=200,
690+
timestamp_ms=1602638573590,
691+
manifest_list="s3:/a/b/c.avro",
692+
summary=Summary(Operation.APPEND),
693+
schema_id=3,
694+
)
695+
test_updates += (
696+
AddSnapshotUpdate(snapshot=new_snapshot),
697+
SetSnapshotRefUpdate(
698+
ref_name="main",
699+
type="branch",
700+
snapshot_id=25,
701+
max_ref_age_ms=123123123,
702+
max_snapshot_age_ms=12312312312,
703+
min_snapshots_to_keep=1,
704+
),
705+
)
706+
new_metadata = update_table_metadata(base_metadata, test_updates)
707+
708+
# Check that the original metadata is not modified
709+
assert base_metadata == base_metadata_backup
710+
711+
# Perform illegal update on the new metadata:
712+
# TableMetadata should be immutable, but the pydantic's frozen config cannot prevent
713+
# operations such as list append.
714+
new_metadata.partition_specs.append(PartitionSpec(spec_id=0))
715+
assert len(new_metadata.partition_specs) == 2
716+
717+
# The original metadata should not be affected by the illegal update on the new metadata
718+
assert len(base_metadata.partition_specs) == 1
719+
720+
672721
def test_generate_snapshot_id(table_v2: Table) -> None:
673722
assert isinstance(_generate_snapshot_id(), int)
674723
assert isinstance(table_v2.new_snapshot_id(), int)

0 commit comments

Comments
 (0)