diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 0bdc02a790..12e710d44e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -21,6 +21,7 @@ package org.apache.comet.serde import java.util.concurrent.atomic.AtomicLong +import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import org.apache.spark.internal.Logging @@ -832,6 +833,81 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { } } + /** + * Serialize an associative boolean chain (`And` / `Or`) as a BALANCED `BinaryExpr` tree of + * depth `O(log n)` instead of the natural left-deep `O(n)`. A query with many ANDed/ORed + * predicates otherwise builds a proto nested deeper than protobuf's default recursion limit + * (100), which overflows when the serialized plan is re-parsed -- on the JVM + * (`OperatorOuterClass.Operator.parseFrom`, e.g. `findShuffleScanIndices` / explain) and in the + * Rust prost decoder. Comet evaluates `And`/`Or` vectorially (both sides always evaluated, no + * row-level short-circuit), so rebalancing the associative chain is semantically identical -- + * it only changes the proto's shape. + * + * `operands` are the flattened leaves of the chain (see [[flattenAssociative]]); `wrap` tags + * each combined `BinaryExpr` as `And` or `Or`. + */ + def createBalancedBinaryExpr( + expr: Expression, + operands: Seq[Expression], + inputs: Seq[Attribute], + binding: Boolean, + wrap: ( + ExprOuterClass.Expr.Builder, + ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + val protos = operands.map(exprToProtoInternal(_, inputs, binding)) + if (protos.exists(_.isEmpty)) { + withFallbackReason(expr, operands: _*) + None + } else { + val leaves = protos.map(_.get).toIndexedSeq + def build(slice: IndexedSeq[ExprOuterClass.Expr]): ExprOuterClass.Expr = { + if (slice.length == 1) slice.head + else { + val mid = slice.length / 2 + val inner = ExprOuterClass.BinaryExpr + .newBuilder() + .setLeft(build(slice.slice(0, mid))) + .setRight(build(slice.slice(mid, slice.length))) + .build() + wrap(ExprOuterClass.Expr.newBuilder(), inner).build() + } + } + Some(build(leaves)) + } + } + + /** + * Flatten an associative binary chain into its leaf operands, in left-to-right order. `matches` + * identifies the same operator (e.g. `case _: And => true`) and `children` extracts its two + * operands. Used to rebalance deep `And`/`Or` chains before serialization (see + * [[createBalancedBinaryExpr]]). + * + * Implemented with an explicit work stack and an accumulating buffer rather than recursion: the + * chains that trigger this are left-deep and `O(n)` deep, so a recursive walk could itself + * overflow the JVM stack, and `++`-accumulating the results would be `O(n^2)`. + */ + def flattenAssociative( + expr: Expression, + matches: Expression => Boolean, + children: Expression => (Expression, Expression)): Seq[Expression] = { + val operands = ArrayBuffer.empty[Expression] + var stack: List[Expression] = expr :: Nil + while (stack.nonEmpty) { + val current = stack.head + stack = stack.tail + if (matches(current)) { + val (l, r) = children(current) + // Push right before left so the left subtree is popped (and emitted) first, preserving + // the original left-to-right operand order. + stack = l :: r :: stack + } else { + operands += current + } + } + operands.toSeq + } + def scalarFunctionExprToProtoWithReturnType( funcName: String, returnType: DataType, diff --git a/spark/src/main/scala/org/apache/comet/serde/predicates.scala b/spark/src/main/scala/org/apache/comet/serde/predicates.scala index 7abe40823e..63b64fbcf2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/predicates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/predicates.scala @@ -69,10 +69,16 @@ object CometAnd extends CometExpressionSerde[And] { expr: And, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - createBinaryExpr( + // Rebalance the (associative) AND chain so deep `a AND b AND ...` predicates produce a + // shallow proto instead of a left-deep one that overflows protobuf's recursion limit when + // the plan is re-parsed (see createBalancedBinaryExpr). + val operands = flattenAssociative( expr, - expr.left, - expr.right, + { case _: And => true; case _ => false }, + { case a: And => (a.left, a.right) }) + createBalancedBinaryExpr( + expr, + operands, inputs, binding, (builder, binaryExpr) => builder.setAnd(binaryExpr)) @@ -84,10 +90,13 @@ object CometOr extends CometExpressionSerde[Or] { expr: Or, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - createBinaryExpr( + val operands = flattenAssociative( expr, - expr.left, - expr.right, + { case _: Or => true; case _ => false }, + { case o: Or => (o.left, o.right) }) + createBalancedBinaryExpr( + expr, + operands, inputs, binding, (builder, binaryExpr) => builder.setOr(binaryExpr)) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 819b1ba051..8dc84dc9e7 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -24,7 +24,7 @@ import java.time.{Duration, Period} import scala.util.Random import org.apache.hadoop.fs.Path -import org.apache.spark.sql.{CometTestBase, DataFrame, Row} +import org.apache.spark.sql.{Column, CometTestBase, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, StructsToJson, TruncDate, TruncTimestamp} import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps import org.apache.spark.sql.comet.CometProjectExec @@ -3096,4 +3096,39 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("deep AND/OR predicate chains do not overflow the protobuf recursion limit") { + // A left-deep chain of N associative boolean operands serializes to a proto nested N + // levels deep. With N > protobuf's default recursion limit (100), the message overflows + // when the serialized plan is re-parsed (JVM Operator.parseFrom and the Rust prost + // decoder), failing an otherwise-supported query. Comet evaluates AND/OR vectorially with + // no short-circuit, so the chain is fully associative and safe to rebalance. + val n = 200 + // `_2` is nullable (every 7th row is null) so the rebalanced chain is exercised under SQL + // three-valued logic, not just true/false operands. + withParquetTable( + (0 until 100).map(i => (i, if (i % 7 == 0) None else Some(i.toLong))), + "tbl") { + // Build a chain that mixes the non-nullable `_1` with the nullable `_2` so null operands + // flow through the rebalanced tree. + def operand(i: Int): Column = + if (i % 2 == 0) col("_2") > lit(-i) else col("_1") > lit(-i) + + // Project the chains as boolean columns rather than filtering: a top-level filter AND is + // split by Spark's splitConjunctivePredicates into many shallow pushed predicates, which + // would hide the deep-nesting. A projected expression survives intact. Distinct literals + // keep the optimizer from folding the chain; `>`/`<` (not `=`) keeps OptimizeIn from + // collapsing the OR chain into a single In. + val andChain = (1 to n).map(operand).reduce(_ && _) + checkSparkAnswerAndOperator(spark.table("tbl").select(andChain.as("a"))) + + val orChain = (1 to n).map(i => col("_1") < lit(i) || col("_2") < lit(i)).reduce(_ || _) + checkSparkAnswerAndOperator(spark.table("tbl").select(orChain.as("o"))) + + // A deep OR is a common real-world WHERE clause and, unlike a top-level AND, is NOT split + // by Spark -- it stays intact as a single deeply-nested predicate, so exercise that path + // directly. + checkSparkAnswerAndOperator(spark.table("tbl").where(orChain)) + } + } + }