Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions pyiceberg/table/update/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from abc import ABC, abstractmethod
from datetime import datetime
from functools import singledispatch
from typing import TYPE_CHECKING, Annotated, Any, Dict, Generic, List, Literal, Optional, Set, Tuple, TypeVar, Union, cast
from typing import TYPE_CHECKING, Annotated, Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union, cast

from pydantic import Field, field_validator, model_validator
from pydantic import Field, field_validator, model_validator, model_serializer

from pyiceberg.exceptions import CommitFailedException
from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec
Expand Down Expand Up @@ -727,6 +727,14 @@ class AssertRefSnapshotId(ValidatableTableRequirement):
ref: str = Field(...)
snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id")

@model_serializer
def ser_model(self) -> dict[str, Any]:
Comment thread
ox marked this conversation as resolved.
Outdated
return {
"type": self.type,
"ref": self.ref,
"snapshot-id": self.snapshot_id,
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is there a way to call super() or the default serializer? we just want to explicitly override for snapshot-id.

otherwise, we'd have to remember to change to this function everytime we add a new variable

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes! Added the mode='wrap' param that allows that


def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
Expand All @@ -745,13 +753,6 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
elif self.snapshot_id is not None:
raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}")

# override the override method, allowing None to serialize to `null` instead of being omitted.
def model_dump_json(
self, exclude_none: bool = False, exclude: Optional[Set[str]] = None, by_alias: bool = True, **kwargs: Any
) -> str:
# `snapshot-id` is required in json response, even if null
return super().model_dump_json(exclude_none=False)


class AssertLastAssignedFieldId(ValidatableTableRequirement):
"""The table's last assigned column id must match the requirement's `last-assigned-field-id`."""
Expand Down
14 changes: 13 additions & 1 deletion tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
import json
import os
import uuid
from typing import Any, Dict
from typing import Any, Dict, Tuple

import pytest
from pytest_mock import MockFixture

from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import StaticTable
from pyiceberg.table.metadata import TableMetadataV1
from pyiceberg.table.update import AssertRefSnapshotId, TableRequirement
from pyiceberg.typedef import IcebergBaseModel


def test_legacy_current_snapshot_id(
Expand All @@ -48,3 +50,13 @@ def test_legacy_current_snapshot_id(
backwards_compatible_static_table = StaticTable.from_metadata(metadata_location)
assert backwards_compatible_static_table.metadata.current_snapshot_id is None
assert backwards_compatible_static_table.metadata == static_table.metadata


def test_null_serializer_field() -> None:
class ExampleRequest(IcebergBaseModel):
requirements: Tuple[TableRequirement, ...]

request = ExampleRequest(requirements=(AssertRefSnapshotId(ref="main", snapshot_id=None),))
dumped_json = request.model_dump_json()
expected_json = """{"type":"assert-ref-snapshot-id","ref":"main","snapshot-id":null}"""
assert expected_json in dumped_json
Comment thread
ox marked this conversation as resolved.
Loading