Skip to content

Commit 53f680c

Browse files
Gayathri Srividya RajavarapuGayathri Srividya Rajavarapu
authored andcommitted
fix: handle upsert after schema evolution
1 parent 6da06ad commit 53f680c

2 files changed

Lines changed: 59 additions & 8 deletions

File tree

pyiceberg/table/upsert_util.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,20 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
8585
f"DataFrames, and cannot be used as column names"
8686
) from None
8787

88-
# Step 1: Prepare source index with join keys and a marker index
89-
# Cast to target table schema, so we can do the join
90-
# See: https://github.com/apache/arrow/issues/37542
88+
# Step 1: Prepare source index with join keys and a marker index.
89+
# Cast only join columns to target join-column schema so schema evolution
90+
# (for example, newly added non-key columns) doesn't break the join setup.
9191
source_index = (
92-
source_table.cast(target_table.schema)
93-
.select(join_cols_set)
92+
source_table.select(join_cols)
93+
.cast(pa.schema([target_table.schema.field(col) for col in join_cols]))
9494
.append_column(SOURCE_INDEX_COLUMN_NAME, pa.array(range(len(source_table))))
9595
)
9696

9797
# Step 2: Prepare target index with join keys and a marker
98-
target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
98+
target_index = target_table.select(join_cols).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
9999

100100
# Step 3: Perform an inner join to find which rows from source exist in target
101-
matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
101+
matching_indices = source_index.join(target_index, keys=join_cols, join_type="inner")
102102

103103
# Step 4: Compare all rows using Python
104104
to_update_indices = []
@@ -112,7 +112,7 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
112112

113113
for key in non_key_cols:
114114
source_val = source_row.column(key)[0].as_py()
115-
target_val = target_row.column(key)[0].as_py()
115+
target_val = target_row.column(key)[0].as_py() if key in target_table.column_names else None
116116
if source_val != target_val:
117117
to_update_indices.append(source_idx)
118118
break

tests/table/test_upsert.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,57 @@ def test_upsert_with_nulls(catalog: Catalog) -> None:
714714
)
715715

716716

717+
def test_upsert_after_schema_add_column(catalog: Catalog) -> None:
718+
identifier = "default.test_upsert_after_schema_add_column"
719+
_drop_table(catalog, identifier)
720+
721+
schema = Schema(
722+
NestedField(1, "id", IntegerType(), required=True),
723+
NestedField(2, "name", StringType(), required=True),
724+
identifier_field_ids=[1],
725+
)
726+
727+
tbl = catalog.create_table(identifier, schema=schema)
728+
729+
initial = pa.Table.from_pylist(
730+
[{"id": 1, "name": "Alice"}],
731+
schema=pa.schema(
732+
[
733+
pa.field("id", pa.int32(), nullable=False),
734+
pa.field("name", pa.string(), nullable=False),
735+
]
736+
),
737+
)
738+
tbl.append(initial)
739+
740+
with tbl.update_schema() as update_schema:
741+
update_schema.add_column("country", StringType())
742+
tbl = tbl.refresh()
743+
744+
source = pa.Table.from_pylist(
745+
[
746+
{"id": 1, "name": "Alice", "country": "NL"},
747+
{"id": 2, "name": "Bob", "country": "US"},
748+
],
749+
schema=pa.schema(
750+
[
751+
pa.field("id", pa.int32(), nullable=False),
752+
pa.field("name", pa.string(), nullable=False),
753+
pa.field("country", pa.string(), nullable=True),
754+
]
755+
),
756+
)
757+
758+
upd = tbl.upsert(source, ["id"])
759+
760+
assert upd.rows_updated == 1
761+
assert upd.rows_inserted == 1
762+
assert sorted(tbl.scan().to_arrow().to_pylist(), key=lambda row: row["id"]) == [
763+
{"id": 1, "name": "Alice", "country": "NL"},
764+
{"id": 2, "name": "Bob", "country": "US"},
765+
]
766+
767+
717768
def test_transaction(catalog: Catalog) -> None:
718769
"""Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is
719770
rolled back."""

0 commit comments

Comments
 (0)