Skip to content

Commit a9a74c4

Browse files
committed
[SPARK-56660][SQL] Decompose struct equality into field-level predicates for filter pushdown
1 parent af55029 commit a9a74c4

5 files changed

Lines changed: 512 additions & 0 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
141141
BooleanSimplification,
142142
SimplifyConditionals,
143143
PushFoldableIntoBranches,
144+
DecomposeStructComparison,
144145
SimplifyBinaryComparison,
145146
ReplaceNullWithFalseInPredicate,
146147
PruneFilters,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,46 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
564564
}
565565

566566

567+
/**
 * Decomposes struct-level equality comparisons in [[Filter]] conditions into conjunctions of
 * field-level comparisons. This enables filter pushdown for individual struct fields.
 *
 * Struct equality treats two null fields as equal, e.g. `struct(1, null) = struct(1, null)`
 * evaluates to true. Fields are therefore always compared with the null-safe `<=>`: comparing
 * them with `=` would turn such matches into null and silently drop rows. Nullability of the
 * structs themselves is handled by explicit `IS NOT NULL` guards. For example, for a nullable
 * `struct_col`, the filter `struct_col = struct(1, 'a')` becomes
 * `struct_col IS NOT NULL AND struct_col.field1 <=> 1 AND struct_col.field2 <=> 'a'`.
 *
 * The guarded form of `=` evaluates to false where the original evaluates to null. That is only
 * equivalent for top-level conjuncts of a filter condition, so everywhere else (under `NOT`,
 * `OR`, `CASE`, ...) only rewrites that preserve the exact three-valued result are applied.
 * A `<=>` whose operands are both nullable is left untouched, because its exact rewrite is a
 * disjunction that does not help pushdown.
 */
object DecomposeStructComparison extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
    _.containsPattern(FILTER), ruleId) {
    case f @ Filter(condition, _) =>
      f.copy(condition = decomposeCondition(condition))
  }

  /**
   * Rewrites the top-level conjuncts of a filter condition. In this position a null result and
   * a false result both reject the row, so `=` on nullable structs can use the guarded form.
   * The original instance is returned when nothing changes, to keep the tree shape stable.
   */
  private def decomposeCondition(expr: Expression): Expression = expr match {
    case and @ And(left, right) =>
      val newLeft = decomposeCondition(left)
      val newRight = decomposeCondition(right)
      if ((newLeft eq left) && (newRight eq right)) and else And(newLeft, newRight)
    case EqualTo(left, right) if canDecompose(left, right) =>
      decompose(left, right, left.nullable, right.nullable)
    case other =>
      decomposeExact(other)
  }

  /**
   * Rewrites comparisons anywhere inside `expr`, restricted to the cases where the field-level
   * form yields exactly the same value as the original for every input:
   *  - `=` when neither struct can be null (the result is then never null);
   *  - `<=>` when at most one struct can be null (a single `IS NOT NULL` guard is exact).
   */
  private def decomposeExact(expr: Expression): Expression = expr.transformWithPruning(
    _.containsPattern(BINARY_COMPARISON)) {
    case EqualTo(left, right)
        if canDecompose(left, right) && !left.nullable && !right.nullable =>
      decompose(left, right, leftNullable = false, rightNullable = false)
    case EqualNullSafe(left, right)
        if canDecompose(left, right) && !(left.nullable && right.nullable) =>
      decompose(left, right, left.nullable, right.nullable)
  }

  /**
   * Returns true if both sides are non-empty structs with the same number of fields and are
   * deterministic. Determinism is required because each side is referenced once per field.
   */
  private def canDecompose(left: Expression, right: Expression): Boolean = {
    (left.dataType, right.dataType) match {
      case (l: StructType, r: StructType) =>
        l.length > 0 && l.length == r.length && left.deterministic && right.deterministic
      case _ => false
    }
  }

  /**
   * Builds `[left IS NOT NULL AND] [right IS NOT NULL AND] left.f1 <=> right.f1 AND ...`.
   *
   * @param leftNullable whether `left` may be null in the context where the result is used
   * @param rightNullable whether `right` may be null in the context where the result is used
   *
   * Nested struct fields are decomposed recursively. The guards emitted here make the parent
   * structs non-null whenever the field predicates matter, so the nullability of a nested
   * field is taken from the field declaration, not from `GetStructField.nullable`.
   */
  private def decompose(
      left: Expression,
      right: Expression,
      leftNullable: Boolean,
      rightNullable: Boolean): Expression = {
    (left.dataType, right.dataType) match {
      case (leftType: StructType, rightType: StructType) =>
        val guards: Seq[Expression] =
          Seq(left -> leftNullable, right -> rightNullable).collect {
            case (side, true) => IsNotNull(side)
          }
        val fieldPredicates: Seq[Expression] = leftType.fields.indices.map { i =>
          val leftField = GetStructField(left, i)
          val rightField = GetStructField(right, i)
          val leftFieldNullable = leftType.fields(i).nullable
          val rightFieldNullable = rightType.fields(i).nullable
          if (canDecompose(leftField, rightField) &&
              !(leftFieldNullable && rightFieldNullable)) {
            decompose(leftField, rightField, leftFieldNullable, rightFieldNullable)
          } else {
            EqualNullSafe(leftField, rightField)
          }
        }
        (guards ++ fieldPredicates).reduceLeft(And)
      case _ =>
        // Not reachable when callers check canDecompose first; kept as a safe fallback.
        EqualNullSafe(left, right)
    }
  }
}
606+
567607
/**
568608
* Simplifies binary comparisons with semantically-equal expressions:
569609
* 1) Replace '<=>' with 'true' literal.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ object RuleIdCollection {
182182
"org.apache.spark.sql.catalyst.optimizer.RewriteExceptAll" ::
183183
"org.apache.spark.sql.catalyst.optimizer.RewriteIntersectAll" ::
184184
"org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" ::
185+
"org.apache.spark.sql.catalyst.optimizer.DecomposeStructComparison" ::
185186
"org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" ::
186187
"org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" ::
187188
"org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" ::

0 commit comments

Comments
 (0)