Skip to content

Commit 25d0853

Browse files
Fix in_col list-aggregate column collision and add edge-case tests
When folding a key column into an In(), the list aggregation column was named f"{in_col}_list", which silently clobbered a join column of the same name and fed a Python list into EqualTo (TypeError). Rename the aggregate to a collision-free sentinel by position instead. Also add coverage for create_match_filter edge cases raised in the #3509 review: single column, single-value collapse to EqualTo, empty input, three key columns, the column-name collision regression, and a large multi-column upsert (#3508) that must not overflow PyArrow's expression canonicalizer when a key column is low-cardinality. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 691d138 commit 25d0853

2 files changed

Lines changed: 111 additions & 3 deletions

File tree

pyiceberg/table/upsert_util.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,13 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre
5151
)
5252
prefix_cols = [c for c in join_cols if c != in_col]
5353

54-
grouped = unique_keys.group_by(prefix_cols).aggregate([(in_col, "list")])
55-
in_values_col = f"{in_col}_list"
54+
# The group keys come first (in prefix_cols order) followed by the list aggregate.
55+
# Rename the aggregate to a sentinel so it cannot collide with a join column that
56+
# happens to be named f"{in_col}_list".
57+
in_values_col = "__in_values"
58+
while in_values_col in prefix_cols:
59+
in_values_col += "_"
60+
grouped = unique_keys.group_by(prefix_cols).aggregate([(in_col, "list")]).rename_columns([*prefix_cols, in_values_col])
5661

5762
disjuncts: list[BooleanExpression] = []
5863
for row in grouped.to_pylist():

tests/table/test_upsert.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from pyiceberg.catalog import Catalog
2626
from pyiceberg.exceptions import NoSuchTableError
27-
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
27+
from pyiceberg.expressions import AlwaysFalse, AlwaysTrue, And, EqualTo, In, Reference
2828
from pyiceberg.expressions.literals import LongLiteral
2929
from pyiceberg.expressions.visitors import expression_evaluator
3030
from pyiceberg.io.pyarrow import schema_to_pyarrow
@@ -512,6 +512,109 @@ def test_create_match_filter_multiple_prefix_groups() -> None:
512512
_assert_match_filter_selects(data, ["order_id", "order_line_id"], schema)
513513

514514

515+
def test_create_match_filter_single_column() -> None:
516+
"""A single join column collapses to a single In() over the unique values."""
517+
schema = pa.schema([pa.field("order_id", pa.int32())])
518+
table = pa.Table.from_pylist([{"order_id": 1}, {"order_id": 2}, {"order_id": 2}], schema=schema)
519+
assert create_match_filter(table, ["order_id"]) == In("order_id", [1, 2])
520+
521+
522+
def test_create_match_filter_single_column_single_value() -> None:
523+
"""A single unique value collapses the In() down to an EqualTo()."""
524+
schema = pa.schema([pa.field("order_id", pa.int32())])
525+
table = pa.Table.from_pylist([{"order_id": 1}, {"order_id": 1}], schema=schema)
526+
assert create_match_filter(table, ["order_id"]) == EqualTo("order_id", 1)
527+
528+
529+
def test_create_match_filter_empty_input() -> None:
530+
"""An empty source matches nothing (AlwaysFalse), for both single and composite keys."""
531+
schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32())])
532+
empty = pa.Table.from_pylist([], schema=schema)
533+
assert create_match_filter(empty, ["order_id"]) == AlwaysFalse()
534+
assert create_match_filter(empty, ["order_id", "order_line_id"]) == AlwaysFalse()
535+
536+
537+
def test_create_match_filter_three_columns() -> None:
538+
"""
539+
Test create_match_filter with three key columns.
540+
541+
Exercises the multi-column prefix branch where the prefix predicate is an And of two
542+
EqualTo() conjuncts combined with an In() over the folded column.
543+
"""
544+
schema = Schema(
545+
NestedField(1, "a", IntegerType(), required=True),
546+
NestedField(2, "b", IntegerType(), required=True),
547+
NestedField(3, "c", IntegerType(), required=True),
548+
)
549+
data = [
550+
{"a": 1, "b": 1, "c": 1},
551+
{"a": 1, "b": 1, "c": 2},
552+
{"a": 1, "b": 1, "c": 3},
553+
{"a": 2, "b": 9, "c": 5},
554+
{"a": 2, "b": 9, "c": 6},
555+
]
556+
_assert_match_filter_selects(data, ["a", "b", "c"], schema)
557+
558+
559+
def test_create_match_filter_column_named_like_aggregate() -> None:
560+
"""
561+
Regression test for #3509 review feedback.
562+
563+
A join column named ``<in_col>_list`` must not collide with the internal list-aggregation
564+
column used to fold values into an In(). Before the fix this raised a TypeError.
565+
"""
566+
schema = Schema(
567+
NestedField(1, "a", IntegerType(), required=True),
568+
NestedField(2, "a_list", IntegerType(), required=True),
569+
)
570+
data = [
571+
{"a": 1, "a_list": 7},
572+
{"a": 2, "a_list": 7},
573+
{"a": 3, "a_list": 8},
574+
]
575+
_assert_match_filter_selects(data, ["a", "a_list"], schema)
576+
577+
578+
def test_upsert_large_composite_key_does_not_overflow(catalog: Catalog) -> None:
579+
"""
580+
Regression test for #3508: a large multi-column upsert must not overflow PyArrow's
581+
expression canonicalizer when at least one key column is low-cardinality (see #3509).
582+
"""
583+
identifier = "default.test_upsert_large_composite_key"
584+
_drop_table(catalog, identifier)
585+
586+
n = 20_000
587+
schema = pa.schema(
588+
[
589+
pa.field("order_id", pa.int64(), nullable=False),
590+
pa.field("region", pa.string(), nullable=False),
591+
pa.field("amount", pa.int64(), nullable=False),
592+
]
593+
)
594+
595+
def make(order_ids: range, amount: int) -> pa.Table:
596+
# region is intentionally low-cardinality (4 values) so the fix folds order_id into an In().
597+
return pa.Table.from_pylist(
598+
[{"order_id": oid, "region": "ABCD"[oid % 4], "amount": amount} for oid in order_ids],
599+
schema=schema,
600+
)
601+
602+
tbl = catalog.create_table(identifier, schema)
603+
tbl.append(make(range(1, n + 1), amount=1))
604+
605+
# Update the first half (amount changes) and insert a tenth of brand-new keys.
606+
source = pa.concat_tables(
607+
[
608+
make(range(1, n // 2 + 1), amount=2),
609+
make(range(n + 1, n + n // 10 + 1), amount=2),
610+
]
611+
)
612+
613+
res = tbl.upsert(source, join_cols=["order_id", "region"])
614+
assert res.rows_updated == n // 2
615+
assert res.rows_inserted == n // 10
616+
617+
515618
def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
516619
identifier = "default.test_upsert_with_duplicate_rows_in_table"
517620

0 commit comments

Comments
 (0)