Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,8 +1010,7 @@ def _task_to_table(

if len(arrow_table) < 1:
return None

return to_requested_schema(projected_schema, file_project_schema, arrow_table)
return to_requested_schema(table=arrow_table, from_schema=file_project_schema, to_schema=projected_schema)


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
Expand Down Expand Up @@ -1122,12 +1121,12 @@ def project_table(
return result


def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
def to_requested_schema(table: pa.Table, from_schema: Schema, to_schema: Schema) -> pa.Table:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored the helper method

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can pull this refactor into a separate PR if it helps with the review.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a public method, so we're breaking the API here. I'm not sure a refactor justifies the breaking change. Also, the names file_schema and requested_schema are more informative to me.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't realize it's a public API. I reverted the refactor

struct_array = visit_with_partner(to_schema, table, ArrowProjectionVisitor(from_schema), ArrowAccessor(from_schema))

arrays = []
fields = []
for pos, field in enumerate(requested_schema.fields):
for pos, field in enumerate(to_schema.fields):
array = struct_array.field(pos)
arrays.append(array)
fields.append(pa.field(field.name, array.type, field.optional))
Expand Down Expand Up @@ -1761,8 +1760,9 @@ def data_file_statistics_from_parquet_metadata(


def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
schema = table_metadata.schema()
arrow_file_schema = schema.as_arrow()
iceberg_table_schema = table_metadata.schema()
parquet_schema = sanitize_column_names(table_metadata.schema())
arrow_file_schema = parquet_schema.as_arrow()
parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)

row_group_size = PropertyUtil.property_as_int(
Expand All @@ -1772,16 +1772,17 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
)

def write_parquet(task: WriteTask) -> DataFile:
arrow_table = pa.Table.from_batches(task.record_batches)
df = to_requested_schema(table=arrow_table, from_schema=iceberg_table_schema, to_schema=parquet_schema)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the read side (#83 & #597), this converts the Arrow table from the unsanitized Iceberg table schema to the sanitized Parquet schema.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, we batch the incoming dataframe first (in _dataframe_to_data_files) and then transform the schema for each batch.

We can optimize by transforming first and then batching.

I want the schema transformation to happen as close to the Parquet writing as possible, so I'm going with the first approach (batch first, then transform each batch) for now.

file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
writer.write(pa.Table.from_batches(task.record_batches), row_group_size=row_group_size)

writer.write(df, row_group_size=row_group_size)
statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=writer.writer.metadata,
stats_columns=compute_statistics_plan(schema, table_metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(schema),
stats_columns=compute_statistics_plan(parquet_schema, table_metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(parquet_schema),
)
data_file = DataFile(
content=DataFileContent.DATA,
Expand Down
42 changes: 42 additions & 0 deletions tests/integration/test_writes/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,48 @@ def get_current_snapshot_id(identifier: str) -> int:
assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_python_writes_special_character_column_with_spark_reads(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
identifier = "default.python_writes_special_character_column_with_spark_reads"
column_name_with_special_character = "letter/abc"
TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN = {
column_name_with_special_character: ['a', None, 'z'],
'id': [1, 2, 3],
'name': ['AB', 'CD', 'EF'],
'address': [
{'street': '123', 'city': 'SFO', 'zip': 12345, column_name_with_special_character: 'a'},
{'street': '456', 'city': 'SW', 'zip': 67890, column_name_with_special_character: 'b'},
{'street': '789', 'city': 'Random', 'zip': 10112, column_name_with_special_character: 'c'},
],
}
pa_schema = pa.schema([
pa.field(column_name_with_special_character, pa.string()),
pa.field('id', pa.int32()),
pa.field('name', pa.string()),
pa.field(
'address',
pa.struct([
pa.field('street', pa.string()),
pa.field('city', pa.string()),
pa.field('zip', pa.int32()),
pa.field(column_name_with_special_character, pa.string()),
]),
),
])
arrow_table_with_special_character_column = pa.Table.from_pydict(TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN, schema=pa_schema)
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema)

tbl.overwrite(arrow_table_with_special_character_column)
# PySpark toPandas() turns nested field into tuple by default, but returns the proper schema when Arrow is enabled
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we add this to the spark fixture in conftest.py? Since the fixture's scope is "session", if we change the config here, all tests before this line will not have the configuration and all after this line will have this enabled. Moving it to the initialization part can ensure we have a consistent set of spark configs throughout the integration tests. WDYT?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I didn't know about the fixture scope behavior. Moved to conftest.

spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas()
pyiceberg_df = tbl.scan().to_pandas()
assert spark_df.equals(pyiceberg_df)


@pytest.mark.integration
def test_write_bin_pack_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.write_bin_pack_data_files"
Expand Down
7 changes: 4 additions & 3 deletions tests/integration/test_writes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
from typing import List, Optional
from typing import List, Optional, Union

import pyarrow as pa

Expand Down Expand Up @@ -65,6 +65,7 @@ def _create_table(
properties: Properties,
data: Optional[List[pa.Table]] = None,
partition_spec: Optional[PartitionSpec] = None,
schema: Union[Schema, "pa.Schema"] = TABLE_SCHEMA,
) -> Table:
try:
session_catalog.drop_table(identifier=identifier)
Expand All @@ -73,10 +74,10 @@ def _create_table(

if partition_spec:
tbl = session_catalog.create_table(
identifier=identifier, schema=TABLE_SCHEMA, properties=properties, partition_spec=partition_spec
identifier=identifier, schema=schema, properties=properties, partition_spec=partition_spec
)
else:
tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties=properties)
tbl = session_catalog.create_table(identifier=identifier, schema=schema, properties=properties)

if data:
for d in data:
Expand Down