Skip to content

Commit 3dcc344

Browse files
committed
cast to pyarrow schema
1 parent 6989b92 commit 3dcc344

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

pyiceberg/table/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,8 @@ def overwrite(
11471147
except ModuleNotFoundError as e:
11481148
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
11491149

1150+
from pyiceberg.io.pyarrow import schema_to_pyarrow
1151+
11501152
if not isinstance(df, pa.Table):
11511153
raise ValueError(f"Expected PyArrow table, got: {df}")
11521154

@@ -1157,6 +1159,9 @@ def overwrite(
11571159
raise ValueError("Cannot write to partitioned tables")
11581160

11591161
_check_schema(self.schema(), other_schema=df.schema)
1162+
# safe to cast
1163+
pyarrow_schema = schema_to_pyarrow(self.schema())
1164+
df = df.cast(pyarrow_schema)
11601165

11611166
with self.transaction() as txn:
11621167
with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:

tests/catalog/test_sql.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,39 @@ def test_create_table_with_pyarrow_schema(
193193
catalog.drop_table(random_identifier)
194194

195195

196+
@pytest.mark.parametrize(
197+
'catalog',
198+
[
199+
lazy_fixture('catalog_memory'),
200+
# lazy_fixture('catalog_sqlite'),
201+
],
202+
)
203+
def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None:
204+
import pyarrow as pa
205+
206+
pyarrow_table = pa.Table.from_arrays(
207+
[
208+
pa.array([None, "A", "B", "C"]), # 'foo' column
209+
pa.array([1, 2, 3, 4]), # 'bar' column
210+
pa.array([True, None, False, True]), # 'baz' column
211+
pa.array([None, "A", "B", "C"]), # 'large' column
212+
],
213+
schema=pa.schema([
214+
pa.field('foo', pa.string(), nullable=True),
215+
pa.field('bar', pa.int32(), nullable=False),
216+
pa.field('baz', pa.bool_(), nullable=True),
217+
pa.field('large', pa.large_string(), nullable=True),
218+
]),
219+
)
220+
database_name, _table_name = random_identifier
221+
catalog.create_namespace(database_name)
222+
table = catalog.create_table(random_identifier, pyarrow_table.schema)
223+
print(pyarrow_table.schema)
224+
print(table.schema().as_struct())
225+
print()
226+
table.overwrite(pyarrow_table)
227+
228+
196229
@pytest.mark.parametrize(
197230
'catalog',
198231
[

0 commit comments

Comments
 (0)