@@ -21,6 +21,7 @@ package org.apache.comet.serde
2121
2222import java .util .concurrent .atomic .AtomicLong
2323
24+ import scala .collection .mutable .ArrayBuffer
2425import scala .jdk .CollectionConverters ._
2526
2627import org .apache .spark .internal .Logging
@@ -832,6 +833,81 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
832833 }
833834 }
834835
/**
 * Serialize an associative boolean chain (`And` / `Or`) as a BALANCED `BinaryExpr` tree of
 * depth `O(log n)` instead of the natural left-deep `O(n)`. A query with many ANDed/ORed
 * predicates otherwise builds a proto nested deeper than protobuf's default recursion limit
 * (100), which overflows when the serialized plan is re-parsed -- on the JVM
 * (`OperatorOuterClass.Operator.parseFrom`, e.g. `findShuffleScanIndices` / explain) and in the
 * Rust prost decoder. Comet evaluates `And`/`Or` vectorially (both sides always evaluated, no
 * row-level short-circuit), so rebalancing the associative chain is semantically identical --
 * it only changes the proto's shape.
 *
 * Every operand is converted before the result is inspected, so a failure on one operand does
 * not stop the remaining operands from being converted.
 *
 * @param expr
 *   the original chain expression; passed to `withFallbackReason` when any operand cannot be
 *   converted
 * @param operands
 *   the flattened leaves of the chain, in left-to-right order (see [[flattenAssociative]])
 * @param inputs
 *   forwarded unchanged to `exprToProtoInternal` for every operand
 * @param binding
 *   forwarded unchanged to `exprToProtoInternal` for every operand
 * @param wrap
 *   tags each combined `BinaryExpr` as `And` or `Or` on the supplied builder
 * @return
 *   the root of the balanced tree, the single operand's proto unchanged when there is exactly
 *   one operand, or `None` when `operands` is empty or any operand fails to convert
 */
def createBalancedBinaryExpr(
    expr: Expression,
    operands: Seq[Expression],
    inputs: Seq[Attribute],
    binding: Boolean,
    wrap: (
        ExprOuterClass.Expr.Builder,
        ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder)
    : Option[ExprOuterClass.Expr] = {
  val protos = operands.map(exprToProtoInternal(_, inputs, binding))
  if (protos.isEmpty) {
    // An empty chain has no proto representation. Without this guard the recursive split
    // below never reaches its single-element base case and overflows the stack.
    None
  } else if (protos.exists(_.isEmpty)) {
    withFallbackReason(expr, operands: _*)
    None
  } else {
    // Every element is defined at this point, so `flatten` simply unwraps the options.
    val leaves = protos.flatten.toIndexedSeq

    // Builds the subtree for leaves in [from, until). Splitting at the midpoint keeps the
    // recursion (and therefore the proto nesting) at O(log n) depth; index bounds avoid
    // allocating a new sub-sequence at every level.
    def build(from: Int, until: Int): ExprOuterClass.Expr = {
      val size = until - from
      if (size == 1) {
        leaves(from)
      } else {
        val mid = from + size / 2
        val inner = ExprOuterClass.BinaryExpr
          .newBuilder()
          .setLeft(build(from, mid))
          .setRight(build(mid, until))
          .build()
        wrap(ExprOuterClass.Expr.newBuilder(), inner).build()
      }
    }

    Some(build(0, leaves.length))
  }
}
879+
/**
 * Collect the leaf operands of an associative binary chain in left-to-right order. Used to
 * rebalance deep `And`/`Or` chains before serialization (see [[createBalancedBinaryExpr]]).
 *
 * The walk keeps its pending nodes in an explicit list and appends leaves to a buffer, so it
 * uses constant JVM stack regardless of chain depth. The chains that trigger this are
 * left-deep and `O(n)` deep, so a non-tail-recursive walk could itself overflow the JVM stack,
 * and `++`-accumulating the results would be `O(n^2)`.
 *
 * @param expr
 *   root of the chain to flatten
 * @param matches
 *   returns true for nodes of the operator being flattened (e.g. `case _: And => true`)
 * @param children
 *   extracts the (left, right) operands of a node accepted by `matches`
 * @return
 *   the leaves of the chain, i.e. every visited node for which `matches` is false
 */
def flattenAssociative(
    expr: Expression,
    matches: Expression => Boolean,
    children: Expression => (Expression, Expression)): Seq[Expression] = {
  val leaves = ArrayBuffer.empty[Expression]

  @scala.annotation.tailrec
  def drain(pending: List[Expression]): Unit = pending match {
    case Nil => ()
    case node :: rest if matches(node) =>
      val (left, right) = children(node)
      // The left child goes to the front of the pending list so it is visited (and its
      // leaves emitted) before the right child, preserving left-to-right operand order.
      drain(left :: right :: rest)
    case leaf :: rest =>
      leaves += leaf
      drain(rest)
  }

  drain(expr :: Nil)
  leaves.toSeq
}
910+
835911 def scalarFunctionExprToProtoWithReturnType (
836912 funcName : String ,
837913 returnType : DataType ,
0 commit comments