Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package org.apache.comet.serde

import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._

import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -832,6 +833,81 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
}
}

/**
 * Serialize an associative boolean chain (`And` / `Or`) as a balanced `BinaryExpr` tree of
 * depth `O(log n)` instead of the natural left-deep `O(n)`.
 *
 * A query with many ANDed/ORed predicates otherwise builds a proto nested deeper than
 * protobuf's default recursion limit (100), which overflows when the serialized plan is
 * re-parsed -- on the JVM (`OperatorOuterClass.Operator.parseFrom`, e.g.
 * `findShuffleScanIndices` / explain) and in the Rust prost decoder. Comet evaluates
 * `And`/`Or` vectorially (both sides always evaluated, no row-level short-circuit), so
 * rebalancing the associative chain is semantically identical -- it only changes the proto's
 * shape.
 *
 * The operands are split at the midpoint at every level, so their left-to-right order is
 * preserved in the resulting tree.
 *
 * @param expr
 *   the root of the original chain; handed to `withFallbackReason` together with the operands
 *   when the chain cannot be serialized
 * @param operands
 *   the flattened leaves of the chain, in left-to-right order (see [[flattenAssociative]])
 * @param inputs
 *   forwarded unchanged to `exprToProtoInternal` for every operand
 * @param binding
 *   forwarded unchanged to `exprToProtoInternal` for every operand
 * @param wrap
 *   tags each combined `BinaryExpr` as `And` or `Or` on the supplied `Expr` builder
 * @return
 *   the balanced proto, or `None` when `operands` is empty or any operand fails to serialize
 */
def createBalancedBinaryExpr(
    expr: Expression,
    operands: Seq[Expression],
    inputs: Seq[Attribute],
    binding: Boolean,
    wrap: (
        ExprOuterClass.Expr.Builder,
        ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder)
    : Option[ExprOuterClass.Expr] = {
  val protos = operands.map(exprToProtoInternal(_, inputs, binding))
  // An empty chain has no base case to bottom out on (the recursion below would never reach a
  // single leaf), so treat it like an unserializable operand and fall back instead of
  // overflowing the stack.
  if (protos.isEmpty || protos.exists(_.isEmpty)) {
    withFallbackReason(expr, operands: _*)
    None
  } else {
    // Every element is defined at this point, so `flatten` simply unwraps them in order.
    val leaves = protos.flatten.toIndexedSeq

    // Builds the subtree covering leaves [from, until). Recursing over index bounds avoids
    // allocating an intermediate slice per node; recursion depth is O(log n).
    def build(from: Int, until: Int): ExprOuterClass.Expr = {
      val size = until - from
      if (size == 1) {
        leaves(from)
      } else {
        val mid = from + size / 2
        val inner = ExprOuterClass.BinaryExpr
          .newBuilder()
          .setLeft(build(from, mid))
          .setRight(build(mid, until))
          .build()
        wrap(ExprOuterClass.Expr.newBuilder(), inner).build()
      }
    }

    Some(build(0, leaves.length))
  }
}

/**
* Flatten an associative binary chain into its leaf operands, in left-to-right order. `matches`
* identifies the same operator (e.g. `case _: And => true`) and `children` extracts its two
* operands. Used to rebalance deep `And`/`Or` chains before serialization (see
* [[createBalancedBinaryExpr]]).
*
* Implemented with an explicit work stack and an accumulating buffer rather than recursion: the
* chains that trigger this are left-deep and `O(n)` deep, so a recursive walk could itself
* overflow the JVM stack, and `++`-accumulating the results would be `O(n^2)`.
*/
def flattenAssociative(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small thought: this recurses O(n) deep and the ++ accumulation is O(n^2). It is totally fine for the depths that hit this bug, but since the motivation is deep chains, would an explicit accumulator be worth it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. Rewrote it with an explicit work stack and an accumulating buffer instead of recursion, so it's O(n) time and no longer recurses O(n) deep (which on these left-deep chains could itself overflow the stack). Left-to-right operand order is preserved.

expr: Expression,
matches: Expression => Boolean,
children: Expression => (Expression, Expression)): Seq[Expression] = {
val operands = ArrayBuffer.empty[Expression]
var stack: List[Expression] = expr :: Nil
while (stack.nonEmpty) {
val current = stack.head
stack = stack.tail
if (matches(current)) {
val (l, r) = children(current)
// Push right before left so the left subtree is popped (and emitted) first, preserving
// the original left-to-right operand order.
stack = l :: r :: stack
} else {
operands += current
}
}
operands.toSeq
}

def scalarFunctionExprToProtoWithReturnType(
funcName: String,
returnType: DataType,
Expand Down
21 changes: 15 additions & 6 deletions spark/src/main/scala/org/apache/comet/serde/predicates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,16 @@ object CometAnd extends CometExpressionSerde[And] {
expr: And,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
createBinaryExpr(
// Rebalance the (associative) AND chain so deep `a AND b AND ...` predicates produce a
// shallow proto instead of a left-deep one that overflows protobuf's recursion limit when
// the plan is re-parsed (see createBalancedBinaryExpr).
val operands = flattenAssociative(
expr,
expr.left,
expr.right,
{ case _: And => true; case _ => false },
{ case a: And => (a.left, a.right) })
createBalancedBinaryExpr(
expr,
operands,
inputs,
binding,
(builder, binaryExpr) => builder.setAnd(binaryExpr))
Expand All @@ -84,10 +90,13 @@ object CometOr extends CometExpressionSerde[Or] {
expr: Or,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
createBinaryExpr(
val operands = flattenAssociative(
expr,
expr.left,
expr.right,
{ case _: Or => true; case _ => false },
{ case o: Or => (o.left, o.right) })
createBalancedBinaryExpr(
expr,
operands,
inputs,
binding,
(builder, binaryExpr) => builder.setOr(binaryExpr))
Expand Down
37 changes: 36 additions & 1 deletion spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import java.time.{Duration, Period}
import scala.util.Random

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.{Column, CometTestBase, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, StructsToJson, TruncDate, TruncTimestamp}
import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
import org.apache.spark.sql.comet.CometProjectExec
Expand Down Expand Up @@ -3096,4 +3096,39 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("deep AND/OR predicate chains do not overflow the protobuf recursion limit") {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The operands here are never null, so the associativity-with-nulls case is not exercised. Would it be worth adding a nullable predicate into one of the chains to lock that in? A deep OR in a WHERE clause might also be worth a case, since that is a common trigger and stays intact rather than being split.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added both: the chains now mix in a nullable column so the rebalanced tree is exercised under three-valued logic, and there's a deep OR in a WHERE clause, which, unlike a top-level AND, Spark doesn't split, so it stays deeply nested.

// A left-deep chain of N associative boolean operands serializes to a proto nested N
// levels deep. With N > protobuf's default recursion limit (100), the message overflows
// when the serialized plan is re-parsed (JVM Operator.parseFrom and the Rust prost
// decoder), failing an otherwise-supported query. Comet evaluates AND/OR vectorially with
// no short-circuit, so the chain is fully associative and safe to rebalance.
val n = 200
// `_2` is nullable (every 7th row is null) so the rebalanced chain is exercised under SQL
// three-valued logic, not just true/false operands.
withParquetTable(
(0 until 100).map(i => (i, if (i % 7 == 0) None else Some(i.toLong))),
"tbl") {
// Build a chain that mixes the non-nullable `_1` with the nullable `_2` so null operands
// flow through the rebalanced tree.
def operand(i: Int): Column =
if (i % 2 == 0) col("_2") > lit(-i) else col("_1") > lit(-i)

// Project the chains as boolean columns rather than filtering: a top-level filter AND is
// split by Spark's splitConjunctivePredicates into many shallow pushed predicates, which
// would hide the deep nesting. A projected expression survives intact. Distinct literals
// keep the optimizer from folding the chain; `>`/`<` (not `=`) keeps OptimizeIn from
// collapsing the OR chain into a single In.
val andChain = (1 to n).map(operand).reduce(_ && _)
checkSparkAnswerAndOperator(spark.table("tbl").select(andChain.as("a")))

val orChain = (1 to n).map(i => col("_1") < lit(i) || col("_2") < lit(i)).reduce(_ || _)
checkSparkAnswerAndOperator(spark.table("tbl").select(orChain.as("o")))

// A deep OR is a common real-world WHERE clause and, unlike a top-level AND, is NOT split
// by Spark -- it stays intact as a single deeply-nested predicate, so exercise that path
// directly.
checkSparkAnswerAndOperator(spark.table("tbl").where(orChain))
}
}

}
Loading