|
14 | 14 | # KIND, either express or implied. See the License for the |
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
| 17 | +import itertools |
17 | 18 | from pathlib import PosixPath |
18 | 19 |
|
19 | 20 | import pyarrow as pa |
|
25 | 26 | from pyiceberg.exceptions import NoSuchTableError |
26 | 27 | from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference |
27 | 28 | from pyiceberg.expressions.literals import LongLiteral |
| 29 | +from pyiceberg.expressions.visitors import expression_evaluator |
28 | 30 | from pyiceberg.io.pyarrow import schema_to_pyarrow |
29 | 31 | from pyiceberg.schema import Schema |
30 | 32 | from pyiceberg.table import Table, UpsertResult |
31 | 33 | from pyiceberg.table.snapshots import Operation |
32 | 34 | from pyiceberg.table.upsert_util import create_match_filter |
| 35 | +from pyiceberg.typedef import Record |
33 | 36 | from pyiceberg.types import IntegerType, NestedField, StringType, StructType |
34 | 37 | from tests.catalog.test_base import InMemoryCatalog |
35 | 38 |
|
@@ -443,6 +446,72 @@ def test_create_match_filter_single_condition() -> None: |
443 | 446 | assert expr == And(op1, op2) or expr == And(op2, op1) |
444 | 447 |
|
445 | 448 |
|
| 449 | +def _assert_match_filter_selects(data: list[dict[str, int]], join_cols: list[str], schema: Schema) -> None: |
| 450 | + """Assert the filter from ``create_match_filter`` matches exactly the unique source keys. |
| 451 | +
|
| 452 | + Rather than asserting a specific expression tree (which is implementation-specific), |
| 453 | + this binds the filter and evaluates it against the full cross-product of the values |
| 454 | + observed per column. The filter must accept exactly the unique keys present in |
| 455 | + ``data`` and reject every other combination, so any over- or under-matching |
| 456 | + (e.g. a cross-product regression) is caught. This holds for any correct |
| 457 | + implementation of ``create_match_filter``. |
| 458 | + """ |
| 459 | + arrow_schema = schema_to_pyarrow(schema) |
| 460 | + table = pa.Table.from_pylist(data, schema=arrow_schema) |
| 461 | + expr = create_match_filter(table, join_cols) |
| 462 | + |
| 463 | + field_names = [field.name for field in schema.fields] |
| 464 | + expected_keys = {tuple(row[name] for name in field_names) for row in data} |
| 465 | + domains = [sorted({row[name] for row in data}) for name in field_names] |
| 466 | + |
| 467 | + evaluate = expression_evaluator(schema, expr, case_sensitive=True) |
| 468 | + for candidate in itertools.product(*domains): |
| 469 | + key = dict(zip(field_names, candidate, strict=True)) |
| 470 | + should_match = candidate in expected_keys |
| 471 | + verb = "rejected matching" if should_match else "matched non-matching" |
| 472 | + assert evaluate(Record(*candidate)) is should_match, f"Filter {expr} {verb} key {key}" |
| 473 | + |
| 474 | + |
| 475 | +def test_create_match_filter_single_prefix_group() -> None: |
| 476 | + """ |
| 477 | + Test create_match_filter with multiple key columns whose rows all share a single prefix combination. |
| 478 | +
|
| 479 | + The filter must match the (one order_id, many order_line_id) keys and nothing else. |
| 480 | + """ |
| 481 | + schema = Schema( |
| 482 | + NestedField(1, "order_id", IntegerType(), required=True), |
| 483 | + NestedField(2, "order_line_id", IntegerType(), required=True), |
| 484 | + ) |
| 485 | + data = [ |
| 486 | + {"order_id": 101, "order_line_id": 1}, |
| 487 | + {"order_id": 101, "order_line_id": 2}, |
| 488 | + {"order_id": 101, "order_line_id": 3}, |
| 489 | + {"order_id": 101, "order_line_id": 3}, # duplicate |
| 490 | + ] |
| 491 | + _assert_match_filter_selects(data, ["order_id", "order_line_id"], schema) |
| 492 | + |
| 493 | + |
| 494 | +def test_create_match_filter_multiple_prefix_groups() -> None: |
| 495 | + """ |
| 496 | + Test create_match_filter with multiple key columns that yield several distinct prefix combinations. |
| 497 | +
|
| 498 | + The filter must match exactly the listed composite keys and must NOT match cross-product |
| 499 | + combinations that never appear together (e.g. order_id 101 with order_line_id 2). |
| 500 | + """ |
| 501 | + schema = Schema( |
| 502 | + NestedField(1, "order_id", IntegerType(), required=True), |
| 503 | + NestedField(2, "order_line_id", IntegerType(), required=True), |
| 504 | + ) |
| 505 | + data = [ |
| 506 | + {"order_id": 101, "order_line_id": 1}, |
| 507 | + {"order_id": 102, "order_line_id": 1}, |
| 508 | + {"order_id": 103, "order_line_id": 1}, |
| 509 | + {"order_id": 201, "order_line_id": 2}, |
| 510 | + {"order_id": 202, "order_line_id": 2}, |
| 511 | + ] |
| 512 | + _assert_match_filter_selects(data, ["order_id", "order_line_id"], schema) |
| 513 | + |
| 514 | + |
446 | 515 | def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None: |
447 | 516 | identifier = "default.test_upsert_with_duplicate_rows_in_table" |
448 | 517 |
|
|
0 commit comments