59 changes: 41 additions & 18 deletions pyiceberg/io/pyarrow.py
@@ -28,12 +28,13 @@
 import logging
 import os
 import re
+import warnings
 from abc import ABC, abstractmethod
 from concurrent.futures import Future
 from dataclasses import dataclass
 from enum import Enum
 from functools import lru_cache, singledispatch
-from itertools import chain
+from itertools import chain, count
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -713,28 +714,50 @@ def primitive(self, primitive: pa.DataType) -> Optional[T]:
     """Visit a primitive type."""


-def _get_field_id(field: pa.Field) -> Optional[int]:
-    for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
-        if field_id_str := field.metadata.get(pyarrow_field_id_key):
-            return int(field_id_str.decode())
-    return None
+class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
+    counter: count[int]
+    missing_is_metadata: Optional[bool]
+
+    def __init__(self) -> None:
+        self.counter = count()
Contributor:

Shall we use `count(1)` here, since we start from 1 in `assign_fresh_schema_ids`?

class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]):
    """Traverses the schema and assigns monotonically increasing ids."""

    reserved_ids: Dict[int, int]

    def __init__(self, next_id_func: Optional[Callable[[], int]] = None) -> None:
        self.reserved_ids = {}
        counter = itertools.count(1)
        self.next_id_func = next_id_func if next_id_func is not None else lambda: next(counter)
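
A quick illustration of the off-by-one in question (illustration only, not part of the PR):

```python
from itertools import count

c0 = count()   # what _ConvertToIceberg uses in this revision: yields 0, 1, 2, ...
c1 = count(1)  # what _SetFreshIDs uses: yields 1, 2, 3, ...
print(next(c0), next(c1))  # prints: 0 1
```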

Contributor Author:

Sorry for the limited context here; it will skip fields that don't have an ID:

[screenshot]

Which is kind of awkward.
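
A minimal sketch of that skipping behavior on the pre-PR code path, assuming `pyarrow_to_schema` as the entry point (the column name and metadata are illustrative):

```python
import pyarrow as pa

from pyiceberg.io.pyarrow import pyarrow_to_schema

# Field metadata without a field-id key: the old module-level _get_field_id
# returned None, so _convert_fields silently dropped the field.
schema = pa.schema([pa.field("some_int", pa.int32(), metadata={b"doc": b"an int"})])
print(pyarrow_to_schema(schema))  # before this PR: an empty Iceberg schema
```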

Contributor Author (@Fokko), Dec 7, 2023:

I like the idea of using assign_fresh_schema_ids, since that one is a pre-order visitor and the default is post-order. I've updated the code, let me know what you think! Appreciate the review!
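
A small illustration of why pre-order matters for ID assignment, sketched with the existing helper (the schema here is made up):

```python
from pyiceberg.schema import Schema, assign_fresh_schema_ids
from pyiceberg.types import NestedField, StringType, StructType

# Pre-order visitation numbers a parent before the fields nested inside it,
# so the struct always gets a lower ID than its children.
schema = Schema(NestedField(10, "person", StructType(NestedField(11, "name", StringType()))))
print(assign_fresh_schema_ids(schema))
# expected: person -> field-id 1, person.name -> field-id 2
```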

Contributor:

I happen to have an iceberg table (migrated from delta lake) whose parquet files contain no field-id. With this change, I am now able to use pyiceberg to read its data. This is really great! Out of curiosity, are there any additional use-cases where this PR might be beneficial?

Contributor:

> I happen to have an iceberg table (migrated from delta lake) whose parquet files contain no field-id. With this change, I am now able to use pyiceberg to read its data.

There is already a way to assign field IDs when they are not in a data file, using a name mapping. All reads that need to infer field IDs must use a name mapping rather than assigning IDs per data file.
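
For context, the Iceberg spec stores such a name mapping as JSON in the `schema.name-mapping.default` table property; a sketch of its shape (the field names are illustrative):

```python
# Name mapping per the Iceberg spec: readers resolve field IDs by column
# name when the data files themselves carry no field-id metadata.
NAME_MAPPING_PROPERTY = "schema.name-mapping.default"

name_mapping_json = """
[
    {"field-id": 1, "names": ["some_int"]},
    {"field-id": 2, "names": ["some_string"]}
]
"""
```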

+        self.missing_is_metadata = None

-def _get_field_doc(field: pa.Field) -> Optional[str]:
-    for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
-        if doc_str := field.metadata.get(pyarrow_doc_key):
-            return doc_str.decode()
-    return None
+    def _get_field_id(self, field: pa.Field) -> int:
+        field_id: Optional[int] = None
+
+        for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
+            if field.metadata and (field_id_str := field.metadata.get(pyarrow_field_id_key)):
+                field_id = int(field_id_str.decode())
+
+        if field_id is None:
+            if self.missing_is_metadata is None:
+                warnings.warn("Field-ids are missing, generating new IDs")
+
+            field_id = next(self.counter)
+            missing_is_metadata = True
+        else:
+            missing_is_metadata = False
+
+        if self.missing_is_metadata is not None and self.missing_is_metadata != missing_is_metadata:
+            raise ValueError("Parquet file contains partial field-ids")
+        else:
+            self.missing_is_metadata = missing_is_metadata
+
+        return field_id
+
+    def _get_field_doc(self, field: pa.Field) -> Optional[str]:
+        for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
+            if field.metadata and (doc_str := field.metadata.get(pyarrow_doc_key)):
+                return doc_str.decode()
+        return None

-class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
     def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
         fields = []
         for i, field in enumerate(arrow_fields):
-            field_id = _get_field_id(field)
-            field_doc = _get_field_doc(field)
+            field_id = self._get_field_id(field)
+            field_doc = self._get_field_doc(field)
             field_type = field_results[i]
-            if field_type is not None and field_id is not None:
+            if field_type is not None:
                 fields.append(NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc))
         return fields

@@ -746,7 +769,7 @@ def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType

     def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
         element_field = list_type.value_field
-        element_id = _get_field_id(element_field)
+        element_id = self._get_field_id(element_field)
         if element_result is not None and element_id is not None:
             return ListType(element_id, element_result, element_required=not element_field.nullable)
         return None
@@ -755,9 +778,9 @@ def map(
         self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
     ) -> Optional[IcebergType]:
         key_field = map_type.key_field
-        key_id = _get_field_id(key_field)
+        key_id = self._get_field_id(key_field)
         value_field = map_type.item_field
-        value_id = _get_field_id(value_field)
+        value_id = self._get_field_id(value_field)
         if key_result is not None and value_result is not None and key_id is not None and value_id is not None:
             return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
         return None
30 changes: 30 additions & 0 deletions tests/io/test_pyarrow_visitor.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=protected-access,unused-argument,redefined-outer-name
 import re
+from unittest.mock import Mock, patch

 import pyarrow as pa
 import pytest
@@ -269,3 +270,32 @@ def test_round_schema_conversion_nested(table_schema_nested: Schema) -> None:
 15: person: optional struct<16: name: optional string, 17: age: required int>
 }"""
     assert actual == expected
+
+
+@patch("warnings.warn")
+def test_schema_to_pyarrow_schema_missing_ids(warn: Mock) -> None:
+    schema = pa.schema([pa.field('some_int', pa.int32(), nullable=True), pa.field('some_string', pa.string(), nullable=False)])
+    actual = pyarrow_to_schema(schema)
+
+    expected = Schema(
+        NestedField(field_id=0, name="some_int", field_type=IntegerType(), required=False),
+        NestedField(field_id=1, name="some_string", field_type=StringType(), required=True),
+    )
+
+    assert actual == expected
+    assert warn.called
+
+
+@patch("warnings.warn")
+def test_schema_to_pyarrow_schema_missing_id(warn: Mock) -> None:
+    schema = pa.schema(
+        [
+            pa.field('some_int', pa.int32(), nullable=True),
+            pa.field('some_string', pa.string(), nullable=False, metadata={b"field_id": "22"}),
+        ]
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        _ = pyarrow_to_schema(schema)
+    assert "Parquet file contains partial field-ids" in str(exc_info.value)
+    assert warn.called