Skip to content

Commit 6459266

Browse files
authored
fix: rebalance deep AND/OR chains to avoid protobuf recursion limit (#4531)
1 parent ddd08ee commit 6459266

3 files changed

Lines changed: 127 additions & 7 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ package org.apache.comet.serde
2121

2222
import java.util.concurrent.atomic.AtomicLong
2323

24+
import scala.collection.mutable.ArrayBuffer
2425
import scala.jdk.CollectionConverters._
2526

2627
import org.apache.spark.internal.Logging
@@ -832,6 +833,81 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
832833
}
833834
}
834835

836+
/**
 * Serialize an associative boolean chain (`And` / `Or`) as a BALANCED `BinaryExpr` tree of
 * depth `O(log n)` instead of the natural left-deep `O(n)`. A query with many ANDed/ORed
 * predicates otherwise builds a proto nested deeper than protobuf's default recursion limit
 * (100), which overflows when the serialized plan is re-parsed -- on the JVM
 * (`OperatorOuterClass.Operator.parseFrom`, e.g. `findShuffleScanIndices` / explain) and in the
 * Rust prost decoder. Comet evaluates `And`/`Or` vectorially (both sides always evaluated, no
 * row-level short-circuit), so rebalancing the associative chain is semantically identical --
 * it only changes the proto's shape.
 *
 * @param expr
 *   the root of the original chain; only passed to `withFallbackReason` when an operand cannot
 *   be serialized
 * @param operands
 *   the flattened leaves of the chain in left-to-right order (see [[flattenAssociative]])
 * @param inputs
 *   forwarded unchanged to `exprToProtoInternal` for every operand
 * @param binding
 *   forwarded unchanged to `exprToProtoInternal` for every operand
 * @param wrap
 *   tags each combined `BinaryExpr` as `And` or `Or` on a fresh `Expr` builder
 * @return
 *   the balanced proto, or `None` when `operands` is empty or at least one operand could not be
 *   serialized
 */
def createBalancedBinaryExpr(
    expr: Expression,
    operands: Seq[Expression],
    inputs: Seq[Attribute],
    binding: Boolean,
    wrap: (
        ExprOuterClass.Expr.Builder,
        ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder)
    : Option[ExprOuterClass.Expr] = {
  // Every operand is converted (no early exit on the first failure), exactly as before, so any
  // bookkeeping done by `exprToProtoInternal` for each operand still happens.
  val protos = operands.map(exprToProtoInternal(_, inputs, binding))
  val leaves = protos.flatten.toIndexedSeq
  if (leaves.length != protos.length) {
    // At least one operand has no proto representation: give up on the whole chain.
    withFallbackReason(expr, operands: _*)
    None
  } else if (leaves.isEmpty) {
    // There is nothing to combine. Without this guard the recursive split below would never
    // reach its single-leaf base case and would overflow the stack.
    None
  } else {
    // Combine leaves(from until until) into one expression. Working on index ranges keeps the
    // same split points as slicing (`mid` = half the range length) without copying a
    // sub-sequence at every level of the tree.
    def build(from: Int, until: Int): ExprOuterClass.Expr = {
      val size = until - from
      if (size == 1) {
        leaves(from)
      } else {
        val mid = from + size / 2
        val inner = ExprOuterClass.BinaryExpr
          .newBuilder()
          .setLeft(build(from, mid))
          .setRight(build(mid, until))
          .build()
        wrap(ExprOuterClass.Expr.newBuilder(), inner).build()
      }
    }
    Some(build(0, leaves.length))
  }
}
879+
880+
/**
881+
* Flatten an associative binary chain into its leaf operands, in left-to-right order. `matches`
882+
* identifies the same operator (e.g. `case _: And => true`) and `children` extracts its two
883+
* operands. Used to rebalance deep `And`/`Or` chains before serialization (see
884+
* [[createBalancedBinaryExpr]]).
885+
*
886+
* Implemented with an explicit work stack and an accumulating buffer rather than recursion: the
887+
* chains that trigger this are left-deep and `O(n)` deep, so a recursive walk could itself
888+
* overflow the JVM stack, and `++`-accumulating the results would be `O(n^2)`.
889+
*/
890+
def flattenAssociative(
891+
expr: Expression,
892+
matches: Expression => Boolean,
893+
children: Expression => (Expression, Expression)): Seq[Expression] = {
894+
val operands = ArrayBuffer.empty[Expression]
895+
var stack: List[Expression] = expr :: Nil
896+
while (stack.nonEmpty) {
897+
val current = stack.head
898+
stack = stack.tail
899+
if (matches(current)) {
900+
val (l, r) = children(current)
901+
// Push right before left so the left subtree is popped (and emitted) first, preserving
902+
// the original left-to-right operand order.
903+
stack = l :: r :: stack
904+
} else {
905+
operands += current
906+
}
907+
}
908+
operands.toSeq
909+
}
910+
835911
def scalarFunctionExprToProtoWithReturnType(
836912
funcName: String,
837913
returnType: DataType,

spark/src/main/scala/org/apache/comet/serde/predicates.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,16 @@ object CometAnd extends CometExpressionSerde[And] {
6969
expr: And,
7070
inputs: Seq[Attribute],
7171
binding: Boolean): Option[ExprOuterClass.Expr] = {
72-
createBinaryExpr(
72+
// Rebalance the (associative) AND chain so deep `a AND b AND ...` predicates produce a
73+
// shallow proto instead of a left-deep one that overflows protobuf's recursion limit when
74+
// the plan is re-parsed (see createBalancedBinaryExpr).
75+
val operands = flattenAssociative(
7376
expr,
74-
expr.left,
75-
expr.right,
77+
{ case _: And => true; case _ => false },
78+
{ case a: And => (a.left, a.right) })
79+
createBalancedBinaryExpr(
80+
expr,
81+
operands,
7682
inputs,
7783
binding,
7884
(builder, binaryExpr) => builder.setAnd(binaryExpr))
@@ -84,10 +90,13 @@ object CometOr extends CometExpressionSerde[Or] {
8490
expr: Or,
8591
inputs: Seq[Attribute],
8692
binding: Boolean): Option[ExprOuterClass.Expr] = {
87-
createBinaryExpr(
93+
val operands = flattenAssociative(
8894
expr,
89-
expr.left,
90-
expr.right,
95+
{ case _: Or => true; case _ => false },
96+
{ case o: Or => (o.left, o.right) })
97+
createBalancedBinaryExpr(
98+
expr,
99+
operands,
91100
inputs,
92101
binding,
93102
(builder, binaryExpr) => builder.setOr(binaryExpr))

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import java.time.{Duration, Period}
2424
import scala.util.Random
2525

2626
import org.apache.hadoop.fs.Path
27-
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
27+
import org.apache.spark.sql.{Column, CometTestBase, DataFrame, Row}
2828
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, StructsToJson, TruncDate, TruncTimestamp}
2929
import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
3030
import org.apache.spark.sql.comet.CometProjectExec
@@ -3127,4 +3127,39 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
31273127
}
31283128
}
31293129

3130+
test("deep AND/OR predicate chains do not overflow the protobuf recursion limit") {
  // N associative boolean operands chained left-deep serialize to a proto nested N levels deep.
  // Once N exceeds protobuf's default recursion limit (100), re-parsing the serialized plan
  // (JVM Operator.parseFrom and the Rust prost decoder) overflows and an otherwise-supported
  // query fails. Comet evaluates AND/OR vectorially with no short-circuit, so the chain is
  // fully associative and safe to rebalance.
  val chainLength = 200
  val indices = 1 to chainLength

  // `_2` is nullable (every 7th row is null), so the rebalanced chain is also checked under SQL
  // three-valued logic rather than only with true/false operands.
  val rows = (0 until 100).map { i =>
    (i, if (i % 7 == 0) None else Some(i.toLong))
  }

  withParquetTable(rows, "tbl") {
    // Alternate between the nullable `_2` and the non-nullable `_1` so null operands flow
    // through the rebalanced tree.
    def andOperand(i: Int): Column = {
      val column = if (i % 2 == 0) col("_2") else col("_1")
      column > lit(-i)
    }

    def orOperand(i: Int): Column =
      col("_1") < lit(i) || col("_2") < lit(i)

    // The chains are projected as boolean columns instead of being used as filters: Spark's
    // splitConjunctivePredicates breaks a top-level filter AND into many shallow pushed
    // predicates, which would hide the deep nesting, whereas a projected expression survives
    // intact. Distinct literals stop the optimizer from folding the chain, and using `>`/`<`
    // (not `=`) stops OptimizeIn from collapsing the OR chain into a single In.
    val deepAnd = indices.map(andOperand).reduce(_ && _)
    checkSparkAnswerAndOperator(spark.table("tbl").select(deepAnd.as("a")))

    val deepOr = indices.map(orOperand).reduce(_ || _)
    checkSparkAnswerAndOperator(spark.table("tbl").select(deepOr.as("o")))

    // A deep OR is a common real-world WHERE clause and, unlike a top-level AND, Spark does NOT
    // split it -- it remains one deeply-nested predicate, so that path is exercised directly.
    checkSparkAnswerAndOperator(spark.table("tbl").where(deepOr))
  }
}
3164+
31303165
}

0 commit comments

Comments
 (0)