2 changes: 1 addition & 1 deletion pyiceberg/io/pyarrow.py
@@ -1784,7 +1784,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT

file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
schema = table_metadata.schema()
-    arrow_file_schema = schema_to_pyarrow(schema)
+    arrow_file_schema = schema.as_arrow()

fo = io.new_output(file_path)
row_group_size = PropertyUtil.property_as_int(
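For reference, `schema.as_arrow()` is assumed here to be a convenience equivalent to `schema_to_pyarrow(schema)`; a minimal usage sketch:

```python
# Minimal sketch (assumption: Schema.as_arrow() mirrors schema_to_pyarrow).
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType

schema = Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=False))
arrow_schema = schema.as_arrow()  # a pyarrow.Schema derived from the Iceberg schema
print(arrow_schema)  # foo: string
```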
18 changes: 15 additions & 3 deletions pyiceberg/table/__init__.py
@@ -145,7 +145,15 @@
_JAVA_LONG_MAX = 9223372036854775807


-def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
+    """
+    Check if the `table_schema` is compatible with `other_schema`.
+
+    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
+
+    Raises:
+        ValueError: If the schemas are not compatible.
+    """
from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema

name_mapping = table_schema.name_mapping
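To make the compatibility rule concrete (an illustrative sketch, assuming large_string converts to the Iceberg string type, which is the premise of this change):

```python
# Two Arrow schemas that differ physically can still be equal as Iceberg schemas.
import pyarrow as pa

from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids

small = pa.schema([pa.field("foo", pa.string(), nullable=True)])
large = pa.schema([pa.field("foo", pa.large_string(), nullable=True)])

assert small != large  # different Arrow types...
assert _pyarrow_to_schema_without_ids(small) == _pyarrow_to_schema_without_ids(large)  # ...same Iceberg schema
```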
@@ -1118,7 +1126,9 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

-        _check_schema(self.schema(), other_schema=df.schema)
+        _check_schema_compatible(self.schema(), other_schema=df.schema)
+        # The two schemas are compatible, so it is safe to cast.
+        df = df.cast(self.schema().as_arrow())

with self.transaction() as txn:
with txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
@@ -1156,7 +1166,9 @@ def overwrite(
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

-        _check_schema(self.schema(), other_schema=df.schema)
+        _check_schema_compatible(self.schema(), other_schema=df.schema)
+        # The two schemas are compatible, so it is safe to cast.
+        df = df.cast(self.schema().as_arrow())

Contributor:
Should _check_schema_compatible return a bool to indicate if the cast is needed? I'm not sure how costly the cast is. If we go from string to large_string then we might rewrite the Arrow buffers.
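A quick illustration of that cost concern (a sketch, not from the PR): large_string uses 64-bit offsets where string uses 32-bit ones, so such a cast cannot be zero-copy.

```python
# Minimal sketch: casting string -> large_string widens the offsets buffer
# from int32 to int64, so Arrow has to allocate and rewrite it.
import pyarrow as pa

arr = pa.array(["A", "B", "C"], type=pa.string())
cast = arr.cast(pa.large_string())

# buffers() returns [validity, offsets, data]; the offsets buffer doubles.
print(arr.buffers()[1].size)   # (3 + 1) * 4-byte int32 offsets
print(cast.buffers()[1].size)  # (3 + 1) * 8-byte int64 offsets
```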

Contributor Author:
Yeah, I like that idea: _check_schema_compatible returns a boolean should_cast (see the sketch after this list).

  • If the schemas are exactly the same, return False and skip the cast
  • If the schemas are "compatible" but not identical, return True and cast
  • If the schemas are not "compatible", raise an error
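A sketch of that proposed shape (hypothetical; this variant was discussed but, as noted below, not what was merged):

```python
from pyiceberg.schema import Schema


def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> bool:
    """Return True when a cast is needed, False when the schemas match exactly.

    Raises:
        ValueError: If the schemas are not compatible at all.
    """
    ...


# The caller side would then become:
# if _check_schema_compatible(self.schema(), other_schema=df.schema):
#     df = df.cast(self.schema().as_arrow())
```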

Contributor Author:
It turned out too complicated to have _check_schema_compatible both return a boolean and throw an error. I ended up comparing the Arrow schemas separately and casting only when necessary (see the sketch below).
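Roughly, the described approach has this shape (a sketch reconstructed from the comment above, using self and df as in the append()/overwrite() bodies; not copied from the merged diff):

```python
# Validate compatibility first, then compare the Arrow schemas directly
# and pay for the cast only when they actually differ.
_check_schema_compatible(self.schema(), other_schema=df.schema)
arrow_schema = self.schema().as_arrow()
if df.schema != arrow_schema:
    df = df.cast(arrow_schema)
```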


with self.transaction() as txn:
with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
33 changes: 33 additions & 0 deletions tests/catalog/test_sql.py
@@ -193,6 +193,39 @@ def test_create_table_with_pyarrow_schema(
catalog.drop_table(random_identifier)


@pytest.mark.parametrize(
'catalog',
[
lazy_fixture('catalog_memory'),
# lazy_fixture('catalog_sqlite'),
],
)
def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None:
import pyarrow as pa

pyarrow_table = pa.Table.from_arrays(
[
pa.array([None, "A", "B", "C"]), # 'foo' column
pa.array([1, 2, 3, 4]), # 'bar' column
pa.array([True, None, False, True]), # 'baz' column
pa.array([None, "A", "B", "C"]), # 'large' column
],
schema=pa.schema([
pa.field('foo', pa.string(), nullable=True),
pa.field('bar', pa.int32(), nullable=False),
pa.field('baz', pa.bool_(), nullable=True),
pa.field('large', pa.large_string(), nullable=True),
]),
)
database_name, _table_name = random_identifier
catalog.create_namespace(database_name)
table = catalog.create_table(random_identifier, pyarrow_table.schema)
table.overwrite(pyarrow_table)


@pytest.mark.parametrize(
'catalog',
[
24 changes: 19 additions & 5 deletions tests/table/test_init.py
@@ -63,7 +63,7 @@
TableIdentifier,
UpdateSchema,
_apply_table_update,
-    _check_schema,
+    _check_schema_compatible,
_match_deletes_to_data_file,
_TableMetadataUpdateContext,
update_table_metadata,
@@ -1033,7 +1033,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
"""

with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1054,7 +1054,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
"""

with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1074,7 +1074,7 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
"""

with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
@@ -1088,7 +1088,21 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."

with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_downcast(table_schema_simple: Schema) -> None:
# large_string type is compatible with string type
other_schema = pa.schema((
pa.field("foo", pa.large_string(), nullable=True),
pa.field("bar", pa.int32(), nullable=False),
pa.field("baz", pa.bool_(), nullable=True),
))

try:
_check_schema_compatible(table_schema_simple, other_schema)
except Exception:
pytest.fail("Unexpected Exception raised when calling `_check_schema`")


def test_table_properties(example_table_metadata_v2: Dict[str, Any]) -> None: