Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
pre_order_visit,
promote,
prune_columns,
sanitize_column_names,
visit,
visit_with_partner,
)
Expand Down Expand Up @@ -1016,7 +1017,6 @@ def _task_to_table(

if len(arrow_table) < 1:
return None

return to_requested_schema(projected_schema, file_project_schema, arrow_table)


Expand Down Expand Up @@ -1769,8 +1769,9 @@ def data_file_statistics_from_parquet_metadata(


def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
schema = table_metadata.schema()
arrow_file_schema = schema.as_arrow()
iceberg_table_schema = table_metadata.schema()
parquet_schema = sanitize_column_names(iceberg_table_schema)
arrow_file_schema = parquet_schema.as_arrow()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I realize we have many names, but that might be confusing. Parquet-schema is appropriate today since we only support parquet, but we might also support ORC and Avro later.

Suggested change
parquet_schema = sanitize_column_names(iceberg_table_schema)
arrow_file_schema = parquet_schema.as_arrow()
arrow_file_schema = sanitize_column_names(iceberg_table_schema).as_arrow()

parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)

row_group_size = PropertyUtil.property_as_int(
Expand All @@ -1780,16 +1781,17 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
)

def write_parquet(task: WriteTask) -> DataFile:
arrow_table = pa.Table.from_batches(task.record_batches)
df = to_requested_schema(requested_schema=parquet_schema, file_schema=iceberg_table_schema, table=arrow_table)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we know if from_arrays in the ArrowProjectionVisitor is no-op?

The I'm asking is that we're introducing quite a bit of logic here, and I think the rewrites are only applicable for Avro: https://avro.apache.org/docs/1.8.1/spec.html#names

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick check, depending on how long the from_arrays take, it doesn't seem to copy anything:

python3.9
Python 3.9.18 (main, Aug 24 2023, 18:16:58) 
[Clang 15.0.0 (clang-1500.1.0.2.5)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import pyarrow as pa
>>> numbers = pa.array(range(100000000))
>>> pa.Table.from_arrays([numbers], names=['abc'])
pyarrow.Table
abc: int64
----
abc: [[0,1,2,3,4,...,99999995,99999996,99999997,99999998,99999999]]

file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
writer.write(pa.Table.from_batches(task.record_batches), row_group_size=row_group_size)

writer.write(df, row_group_size=row_group_size)
statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=writer.writer.metadata,
stats_columns=compute_statistics_plan(schema, table_metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(schema),
stats_columns=compute_statistics_plan(parquet_schema, table_metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(parquet_schema),
)
data_file = DataFile(
content=DataFileContent.DATA,
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,6 +2060,7 @@ def spark() -> "SparkSession":
.config("spark.sql.catalog.hive.warehouse", "s3://warehouse/hive/")
.config("spark.sql.catalog.hive.s3.endpoint", "http://localhost:9000")
.config("spark.sql.catalog.hive.s3.path-style-access", "true")
.config("spark.sql.execution.arrow.pyspark.enabled", "true")
.getOrCreate()
)

Expand Down
3 changes: 0 additions & 3 deletions tests/integration/test_inspect_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> Non
for column in df.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
if column == 'data_file':
right = right.asDict(recursive=True)
for df_column in left.keys():
if df_column == 'partition':
# Spark leaves out the partition if the table is unpartitioned
Expand All @@ -185,8 +184,6 @@ def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> Non

assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
elif column == 'readable_metrics':
right = right.asDict(recursive=True)

assert list(left.keys()) == [
'bool',
'string',
Expand Down
20 changes: 19 additions & 1 deletion tests/integration/test_writes/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,27 @@ def test_python_writes_special_character_column_with_spark_reads(
column_name_with_special_character = "letter/abc"
TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN = {
column_name_with_special_character: ['a', None, 'z'],
'id': [1, 2, 3],
'name': ['AB', 'CD', 'EF'],
'address': [
{'street': '123', 'city': 'SFO', 'zip': 12345, column_name_with_special_character: 'a'},
{'street': '456', 'city': 'SW', 'zip': 67890, column_name_with_special_character: 'b'},
{'street': '789', 'city': 'Random', 'zip': 10112, column_name_with_special_character: 'c'},
],
}
pa_schema = pa.schema([
(column_name_with_special_character, pa.string()),
pa.field(column_name_with_special_character, pa.string()),
pa.field('id', pa.int32()),
pa.field('name', pa.string()),
pa.field(
'address',
pa.struct([
pa.field('street', pa.string()),
pa.field('city', pa.string()),
pa.field('zip', pa.int32()),
pa.field(column_name_with_special_character, pa.string()),
]),
),
])
arrow_table_with_special_character_column = pa.Table.from_pydict(TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN, schema=pa_schema)
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema)
Expand Down