Skip to content

Commit acc934f

Browse files
authored
Check the types when writing (#313)
1 parent 9e03949 commit acc934f

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

pyiceberg/table/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,14 @@ def append(self, df: pa.Table) -> None:
932932
Args:
933933
df: The Arrow dataframe that will be appended to the table
934934
"""
935+
try:
936+
import pyarrow as pa
937+
except ModuleNotFoundError as e:
938+
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
939+
940+
if not isinstance(df, pa.Table):
941+
raise ValueError(f"Expected PyArrow table, got: {df}")
942+
935943
if len(self.spec().fields) > 0:
936944
raise ValueError("Cannot write to partitioned tables")
937945

@@ -954,6 +962,14 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
954962
overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
955963
or a boolean expression in case of a partial overwrite
956964
"""
965+
try:
966+
import pyarrow as pa
967+
except ModuleNotFoundError as e:
968+
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
969+
970+
if not isinstance(df, pa.Table):
971+
raise ValueError(f"Expected PyArrow table, got: {df}")
972+
957973
if overwrite_filter != AlwaysTrue():
958974
raise NotImplementedError("Cannot overwrite a subset of a table")
959975

tests/integration/test_writes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,21 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
391391
assert [row.added_data_files_count for row in rows] == [1, 1, 0, 1, 1]
392392
assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0]
393393
assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0]
394+
395+
396+
@pytest.mark.integration
def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    """Passing a non-Arrow payload to append/overwrite must raise ValueError.

    Exercises the type guard added to ``Table.append`` and ``Table.overwrite``:
    anything that is not a ``pa.Table`` is rejected before any write happens.
    """
    identifier = "default.arrow_data_files"

    # Start from a clean slate — the table may linger from a previous run.
    try:
        session_catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass

    tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'})

    # Both write paths share the same error message; check overwrite first,
    # then append, each fed a plain string instead of a pa.Table.
    expected_message = "Expected PyArrow table, got: not a df"

    with pytest.raises(ValueError, match=expected_message):
        tbl.overwrite("not a df")

    with pytest.raises(ValueError, match=expected_message):
        tbl.append("not a df")

0 commit comments

Comments
 (0)