Skip to content

Commit e83a839

Browse files
committed
[flink][spark] Fix partition pruning for non-string partition keys
Partition predicate pushdown converted both the predicate literals and the partition values to strings before evaluation, so range comparisons fell back to lexicographic string order. For an INT partition column with values 2 and 10, WHERE pt > 2 compared "10" against "2" lexicographically ("10" < "2") and incorrectly dropped partition 10. Add PartitionUtils.toPartitionRow and PartitionUtils.partitionRowType in fluss-common, and use them from SparkPartitionPredicate and FlinkSourceEnumerator; drop the stringify step in FlinkTableSource and delete StringifyPredicateVisitor. The stringifier was also hiding two latent gaps in LeafPredicate.get: BYTES had no case (UnsupportedOperationException), and TIMESTAMP_WITH_LOCAL_TIME_ZONE used getTimestampNtz instead of getTimestampLtz (ClassCastException). Both gaps are exercised by testStreamingReadAllPartitionTypePushDown and are fixed in LeafPredicate.java as part of this commit. Regression tests for the partition pruning bug, using an INT partition column and a range predicate, are added to SparkLogTableReadTest and FlinkTableSourceITCase. Closes #3292.
1 parent 170e95f commit e83a839

9 files changed

Lines changed: 98 additions & 139 deletions

File tree

fluss-common/src/main/java/org/apache/fluss/predicate/LeafPredicate.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ public static Object get(InternalRow internalRow, int pos, DataType fieldType) {
175175
return internalRow.getTimestampNtz(pos, timestampType.getPrecision());
176176
case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
177177
LocalZonedTimestampType lzTs = (LocalZonedTimestampType) fieldType;
178-
return internalRow.getTimestampNtz(pos, lzTs.getPrecision());
178+
return internalRow.getTimestampLtz(pos, lzTs.getPrecision());
179179
case FLOAT:
180180
return internalRow.getFloat(pos);
181181
case DOUBLE:
@@ -188,6 +188,7 @@ public static Object get(InternalRow internalRow, int pos, DataType fieldType) {
188188
return internalRow.getDecimal(
189189
pos, decimalType.getPrecision(), decimalType.getScale());
190190
case BINARY:
191+
case BYTES:
191192
return internalRow.getBytes(pos);
192193
default:
193194
throw new UnsupportedOperationException("Unsupported type: " + fieldType);

fluss-common/src/main/java/org/apache/fluss/utils/PartitionUtils.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,14 @@
2222
import org.apache.fluss.exception.InvalidPartitionException;
2323
import org.apache.fluss.metadata.PartitionSpec;
2424
import org.apache.fluss.metadata.ResolvedPartitionSpec;
25+
import org.apache.fluss.metadata.TableInfo;
2526
import org.apache.fluss.metadata.TablePath;
2627
import org.apache.fluss.row.BinaryString;
28+
import org.apache.fluss.row.GenericRow;
2729
import org.apache.fluss.row.TimestampLtz;
2830
import org.apache.fluss.row.TimestampNtz;
2931
import org.apache.fluss.types.DataTypeRoot;
32+
import org.apache.fluss.types.RowType;
3033

3134
import java.time.Instant;
3235
import java.time.ZonedDateTime;
@@ -345,4 +348,28 @@ public static String convertValueOfType(Object value, DataTypeRoot type) {
345348
}
346349
return stringPartitionKey;
347350
}
351+
352+
/** Projects {@code tableInfo}'s row type down to its partition key columns, in key order. */
353+
public static RowType partitionRowType(TableInfo tableInfo) {
354+
RowType schema = tableInfo.getRowType();
355+
List<String> fieldNames = schema.getFieldNames();
356+
int[] indexes =
357+
tableInfo.getPartitionKeys().stream().mapToInt(fieldNames::indexOf).toArray();
358+
return schema.project(indexes);
359+
}
360+
361+
/**
362+
* Builds a row of typed partition values by parsing each string with {@link
363+
* #parseValueOfType(String, DataTypeRoot)} for the column at that ordinal in {@code
364+
* partitionRowType}.
365+
*/
366+
public static GenericRow toPartitionRow(
367+
List<String> partitionValues, RowType partitionRowType) {
368+
GenericRow row = new GenericRow(partitionValues.size());
369+
for (int i = 0; i < partitionValues.size(); i++) {
370+
DataTypeRoot type = partitionRowType.getTypeAt(i).getTypeRoot();
371+
row.setField(i, parseValueOfType(partitionValues.get(i), type));
372+
}
373+
return row;
374+
}
348375
}

fluss-flink/fluss-flink-common/src/main/java/org/apache/fluss/flink/source/FlinkTableSource.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@
104104
import static org.apache.fluss.flink.utils.PredicateConverter.convertToFlussPredicate;
105105
import static org.apache.fluss.flink.utils.PushdownUtils.ValueConversion.FLINK_INTERNAL_VALUE;
106106
import static org.apache.fluss.flink.utils.PushdownUtils.extractFieldEquals;
107-
import static org.apache.fluss.flink.utils.StringifyPredicateVisitor.stringifyPartitionPredicate;
108107
import static org.apache.fluss.utils.Preconditions.checkNotNull;
109108

110109
/** Flink table source to scan Fluss data. */
@@ -627,8 +626,7 @@ && hasPrimaryKey()
627626
} else {
628627
acceptedFilters.add(filter);
629628
}
630-
// Convert literals in the predicate to partition string
631-
converted.add(stringifyPartitionPredicate(p));
629+
converted.add(p);
632630
} else {
633631
remainingFilters.add(filter);
634632
}

fluss-flink/fluss-flink-common/src/main/java/org/apache/fluss/flink/source/enumerator/FlinkSourceEnumerator.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@
5050
import org.apache.fluss.metadata.TableInfo;
5151
import org.apache.fluss.metadata.TablePath;
5252
import org.apache.fluss.predicate.Predicate;
53-
import org.apache.fluss.row.BinaryString;
54-
import org.apache.fluss.row.GenericRow;
5553
import org.apache.fluss.row.InternalRow;
5654
import org.apache.fluss.shaded.guava32.com.google.common.collect.Lists;
55+
import org.apache.fluss.types.RowType;
5756
import org.apache.fluss.utils.ExceptionUtils;
57+
import org.apache.fluss.utils.PartitionUtils;
5858

5959
import org.apache.flink.annotation.Internal;
6060
import org.apache.flink.annotation.VisibleForTesting;
@@ -558,9 +558,13 @@ private List<PartitionInfo> applyPartitionFilter(List<PartitionInfo> partitionIn
558558
return partitionInfos;
559559
} else {
560560
int originalSize = partitionInfos.size();
561+
RowType partitionRowType = PartitionUtils.partitionRowType(tableInfo);
561562
List<PartitionInfo> filteredPartitionInfos =
562563
partitionInfos.stream()
563-
.filter(partition -> partitionFilters.test(toInternalRow(partition)))
564+
.filter(
565+
partition ->
566+
partitionFilters.test(
567+
toInternalRow(partition, partitionRowType)))
564568
.collect(Collectors.toList());
565569

566570
int filteredSize = filteredPartitionInfos.size();
@@ -583,14 +587,10 @@ private List<PartitionInfo> applyPartitionFilter(List<PartitionInfo> partitionIn
583587
}
584588
}
585589

586-
private static InternalRow toInternalRow(PartitionInfo partitionInfo) {
587-
List<String> partitionValues =
588-
partitionInfo.getResolvedPartitionSpec().getPartitionValues();
589-
GenericRow genericRow = new GenericRow(partitionValues.size());
590-
for (int i = 0; i < partitionValues.size(); i++) {
591-
genericRow.setField(i, BinaryString.fromString(partitionValues.get(i)));
592-
}
593-
return genericRow;
590+
private static InternalRow toInternalRow(
591+
PartitionInfo partitionInfo, RowType partitionRowType) {
592+
return PartitionUtils.toPartitionRow(
593+
partitionInfo.getResolvedPartitionSpec().getPartitionValues(), partitionRowType);
594594
}
595595

596596
/** Init the splits for Fluss. */

fluss-flink/fluss-flink-common/src/main/java/org/apache/fluss/flink/utils/StringifyPredicateVisitor.java

Lines changed: 0 additions & 72 deletions
This file was deleted.

fluss-flink/fluss-flink-common/src/test/java/org/apache/fluss/flink/source/FlinkTableSourceITCase.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,32 @@ private List<String> writeRowsToTwoPartition(TablePath tablePath, Collection<Str
17561756
return expectedRowValues;
17571757
}
17581758

1759+
@Test
1760+
void testReadPartitionPushDownWithIntPartitionRangePredicate() throws Exception {
1761+
tEnv.executeSql(
1762+
"create table int_partitioned_table"
1763+
+ " (a int not null, b varchar, pt int, primary key (a, pt) NOT ENFORCED) partitioned by (pt) ");
1764+
TablePath tablePath = TablePath.of(DEFAULT_DB, "int_partitioned_table");
1765+
tEnv.executeSql("alter table int_partitioned_table add partition (pt=2)");
1766+
tEnv.executeSql("alter table int_partitioned_table add partition (pt=10)");
1767+
1768+
List<InternalRow> rows = new ArrayList<>();
1769+
for (int i = 0; i < 3; i++) {
1770+
rows.add(row(i, "v" + i, 2));
1771+
rows.add(row(i, "v" + i, 10));
1772+
}
1773+
writeRows(conn, tablePath, rows, false);
1774+
FLUSS_CLUSTER_EXTENSION.triggerAndWaitSnapshot(tablePath);
1775+
1776+
List<String> expected = new ArrayList<>();
1777+
for (int i = 0; i < 3; i++) {
1778+
expected.add(String.format("+I[%d, v%d, 10]", i, i));
1779+
}
1780+
CloseableIterator<Row> rowIter =
1781+
tEnv.executeSql("select * from int_partitioned_table where pt > 2").collect();
1782+
assertResultsIgnoreOrder(rowIter, expected, true);
1783+
}
1784+
17591785
@Test
17601786
void testStreamingReadPartitionPushDownWithInExpr() throws Exception {
17611787

fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,10 @@ class FlussAppendBatch(
119119

120120
if (tableInfo.isPartitioned) {
121121
val matching =
122-
SparkPartitionPredicate.filterPartitions(partitionInfos.asScala.toSeq, partitionPredicate)
122+
SparkPartitionPredicate.filterPartitions(
123+
tableInfo,
124+
partitionInfos.asScala.toSeq,
125+
partitionPredicate)
123126
matching
124127
.map {
125128
partitionInfo =>
@@ -240,7 +243,10 @@ class FlussUpsertBatch(
240243

241244
if (tableInfo.isPartitioned) {
242245
val matching =
243-
SparkPartitionPredicate.filterPartitions(partitionInfos.asScala.toSeq, partitionPredicate)
246+
SparkPartitionPredicate.filterPartitions(
247+
tableInfo,
248+
partitionInfos.asScala.toSeq,
249+
partitionPredicate)
244250
matching.flatMap {
245251
partitionInfo =>
246252
val partitionName = partitionInfo.getPartitionName

fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/utils/SparkPartitionPredicate.scala

Lines changed: 10 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
package org.apache.fluss.spark.utils
1919

2020
import org.apache.fluss.metadata.{PartitionInfo, TableInfo}
21-
import org.apache.fluss.predicate.{CompoundPredicate, LeafPredicate, PartitionPredicateVisitor, Predicate => FlussPredicate, PredicateBuilder, PredicateVisitor}
22-
import org.apache.fluss.row.{BinaryString, GenericRow}
23-
import org.apache.fluss.types.{DataTypes, RowType}
21+
import org.apache.fluss.predicate.{PartitionPredicateVisitor, Predicate => FlussPredicate, PredicateBuilder}
2422
import org.apache.fluss.utils.PartitionUtils
2523

2624
import org.apache.spark.sql.connector.expressions.filter.Predicate
@@ -34,15 +32,14 @@ object SparkPartitionPredicate {
3432
val partitionKeys = tableInfo.getPartitionKeys
3533
if (partitionKeys.isEmpty) return None
3634

37-
val rowType = partitionRowType(tableInfo)
35+
val rowType = PartitionUtils.partitionRowType(tableInfo)
3836
val onlyPartitionKeys = new PartitionPredicateVisitor(partitionKeys)
3937

4038
val converted = predicates.flatMap {
4139
sparkPredicate =>
4240
SparkPredicateConverter
4341
.convert(rowType, sparkPredicate)
4442
.filter(_.visit(onlyPartitionKeys))
45-
.map(stringifyLiterals)
4643
}
4744

4845
converted match {
@@ -53,54 +50,17 @@ object SparkPartitionPredicate {
5350
}
5451

5552
def filterPartitions(
53+
tableInfo: TableInfo,
5654
partitionInfos: Seq[PartitionInfo],
5755
partitionPredicate: Option[FlussPredicate]): Seq[PartitionInfo] =
5856
partitionPredicate match {
5957
case None => partitionInfos
60-
case Some(predicate) => partitionInfos.filter(p => predicate.test(toPartitionRow(p)))
58+
case Some(predicate) =>
59+
val rowType = PartitionUtils.partitionRowType(tableInfo)
60+
partitionInfos.filter {
61+
p =>
62+
predicate.test(
63+
PartitionUtils.toPartitionRow(p.getResolvedPartitionSpec.getPartitionValues, rowType))
64+
}
6165
}
62-
63-
private def partitionRowType(tableInfo: TableInfo): RowType = {
64-
val schemaRowType = tableInfo.getRowType
65-
val fieldNames = schemaRowType.getFieldNames
66-
val partitionFieldIndexes = tableInfo.getPartitionKeys.asScala.map(fieldNames.indexOf).toArray
67-
schemaRowType.project(partitionFieldIndexes)
68-
}
69-
70-
private def toPartitionRow(partitionInfo: PartitionInfo): GenericRow = {
71-
val values = partitionInfo.getResolvedPartitionSpec.getPartitionValues
72-
val row = new GenericRow(values.size)
73-
var i = 0
74-
while (i < values.size) {
75-
row.setField(i, BinaryString.fromString(values.get(i)))
76-
i += 1
77-
}
78-
row
79-
}
80-
81-
// Partition values are stored as strings; literals must be coerced before evaluation.
82-
private val stringifier: PredicateVisitor[FlussPredicate] = new PredicateVisitor[FlussPredicate] {
83-
override def visit(leaf: LeafPredicate): FlussPredicate = {
84-
val converted: Seq[Object] = leaf.literals.asScala.toSeq.map {
85-
case null => null
86-
case literal =>
87-
BinaryString.fromString(
88-
PartitionUtils.convertValueOfType(literal, leaf.`type`.getTypeRoot))
89-
}
90-
new LeafPredicate(
91-
leaf.function,
92-
DataTypes.STRING,
93-
leaf.index,
94-
leaf.fieldName,
95-
converted.asJava)
96-
}
97-
98-
override def visit(compound: CompoundPredicate): FlussPredicate = {
99-
val children = compound.children.asScala.map(_.visit(this)).asJava
100-
new CompoundPredicate(compound.function, children)
101-
}
102-
}
103-
104-
private def stringifyLiterals(predicate: FlussPredicate): FlussPredicate =
105-
predicate.visit(stringifier)
10666
}

fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,19 @@ class SparkLogTableReadTest extends FlussSparkTestBase {
498498
}
499499
}
500500

501+
test("Spark Read: partition pushdown — INT partition with range predicate") {
502+
withTable("t") {
503+
sql(s"""
504+
|CREATE TABLE $DEFAULT_DATABASE.t (id BIGINT, pt INT)
505+
|PARTITIONED BY (pt)""".stripMargin)
506+
sql(s"INSERT INTO $DEFAULT_DATABASE.t VALUES (1, 2), (2, 10)")
507+
508+
val query = sql(s"SELECT * FROM $DEFAULT_DATABASE.t WHERE pt > 2 ORDER BY id")
509+
checkAnswer(query, Row(2L, 10) :: Nil)
510+
assert(partitionPredicate(query).isDefined)
511+
}
512+
}
513+
501514
test("Spark Read: scan description surfaces partition filter when pushed") {
502515
withPartitionedTable {
503516
val withPart = sql(s"SELECT * FROM $DEFAULT_DATABASE.t WHERE dt = '2026-01-01'")

0 commit comments

Comments
 (0)