Skip to content

Commit 264510c

Browse files
committed
fix: cap bloom_filter_agg numItems/numBits and skip null inputs
The Spark 4.0 BloomFilterAggregateQuerySuite CI job aborted the executor with a multi-exabyte native allocation, and the Spark 3.4 CometExecRuleSuite job failed analysis. Three bloom-filter issues surfaced once this branch let bloom_filter_agg execute natively: - Spark's BloomFilterAggregate caps numItems/numBits at maxNumItems/ maxNumBits, but CometBloomFilterAggregate forwarded the raw literals. Comet's native aggregate stores them as i32, so an oversized Long (e.g. the Long.MaxValue cases in BloomFilterAggregateQuerySuite) wrapped to a negative size and triggered a 2^61-byte allocation. Apply the same cap in the serde so the native side receives Spark-equivalent values. - update_batch hit `unreachable!()` on a null input value. Spark's BloomFilterAggregate.update ignores nulls; skip them, and return an error rather than panicking on a genuinely unexpected type. - The new CometExecRuleSuite BloomFilter cases used an int column, which Spark 3.4's bloom_filter_agg rejects (it only accepts a long-typed first argument); cast to bigint. Adds a CometExec3_4PlusSuite regression test covering oversized numItems/numBits with a null-containing input.
1 parent d51c9d6 commit 264510c

4 files changed

Lines changed: 65 additions & 8 deletions

File tree

native/spark-expr/src/bloom_filter/bloom_filter_agg.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::bloom_filter::spark_bloom_filter::{SparkBloomFilter, SparkBloomFilter
2525
use arrow::array::ArrayRef;
2626
use arrow::array::BinaryArray;
2727
use datafusion::common::{downcast_value, ScalarValue};
28-
use datafusion::error::Result;
28+
use datafusion::error::{DataFusionError, Result};
2929
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
3030
use datafusion::logical_expr::{AggregateUDFImpl, Signature};
3131
use datafusion::physical_expr::expressions::Literal;
@@ -141,8 +141,16 @@ impl Accumulator for SparkBloomFilter {
141141
ScalarValue::Utf8(Some(value)) => {
142142
self.put_binary(value.as_bytes());
143143
}
144-
_ => {
145-
unreachable!()
144+
// Spark's BloomFilterAggregate.update ignores null inputs.
145+
ScalarValue::Int8(None)
146+
| ScalarValue::Int16(None)
147+
| ScalarValue::Int32(None)
148+
| ScalarValue::Int64(None)
149+
| ScalarValue::Utf8(None) => {}
150+
other => {
151+
return Err(DataFusionError::Internal(format!(
152+
"bloom_filter_agg received an unsupported input type: {other:?}"
153+
)));
146154
}
147155
}
148156
Ok(())

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import scala.jdk.CollectionConverters._
2323

24-
import org.apache.spark.sql.catalyst.expressions.Attribute
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
2626
import org.apache.spark.sql.internal.SQLConf
2727
import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
@@ -649,8 +649,20 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
649649
// We ignore mutableAggBufferOffset and inputAggBufferOffset because they are
650650
// implementation details for Spark's ObjectHashAggregate.
651651
val childExpr = exprToProto(bloomFilter.child, inputs, binding)
652-
val numItemsExpr = exprToProto(bloomFilter.estimatedNumItemsExpression, inputs, binding)
653-
val numBitsExpr = exprToProto(bloomFilter.numBitsExpression, inputs, binding)
652+
// Spark's BloomFilterAggregate caps numItems / numBits at the configured maxima
653+
// (its `estimatedNumItems` / `numBits` lazy vals). Comet's native aggregate stores
654+
// these as i32, so an uncapped Long literal (e.g. the Long.MaxValue cases in
655+
// BloomFilterAggregateQuerySuite) would wrap to a bogus negative size and abort the
656+
// executor with a multi-exabyte allocation. Apply the same cap here so the native
657+
// side always receives a sane, Spark-equivalent value.
658+
val numItems = math.min(
659+
bloomFilter.estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue,
660+
conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))
661+
val numBits = math.min(
662+
bloomFilter.numBitsExpression.eval().asInstanceOf[Number].longValue,
663+
conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
664+
val numItemsExpr = exprToProto(Literal(numItems, LongType), inputs, binding)
665+
val numBitsExpr = exprToProto(Literal(numBits, LongType), inputs, binding)
654666
val dataType = serializeDataType(bloomFilter.dataType)
655667

656668
if (childExpr.isDefined &&

spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.scalatest.Tag
2929
import org.apache.spark.sql.CometTestBase
3030
import org.apache.spark.sql.catalyst.FunctionIdentifier
3131
import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, Expression, ExpressionInfo}
32+
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
3233
import org.apache.spark.sql.functions.{col, lit}
3334
import org.apache.spark.util.sketch.BloomFilter
3435

@@ -42,6 +43,7 @@ class CometExec3_4PlusSuite extends CometTestBase {
4243
import testImplicits._
4344

4445
val func_might_contain = new FunctionIdentifier("might_contain")
46+
val func_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
4547

4648
override def beforeAll(): Unit = {
4749
super.beforeAll()
@@ -51,12 +53,23 @@ class CometExec3_4PlusSuite extends CometTestBase {
5153
func_might_contain,
5254
new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
5355
(children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
56+
// Register 'bloom_filter_agg' to builtin.
57+
spark.sessionState.functionRegistry.registerFunction(
58+
func_bloom_filter_agg,
59+
new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
60+
(children: Seq[Expression]) =>
61+
children.size match {
62+
case 1 => new BloomFilterAggregate(children.head)
63+
case 2 => new BloomFilterAggregate(children.head, children(1))
64+
case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
65+
})
5466
}
5567
}
5668

5769
override def afterAll(): Unit = {
5870
if (!isSpark42Plus) {
5971
spark.sessionState.functionRegistry.dropFunction(func_might_contain)
72+
spark.sessionState.functionRegistry.dropFunction(func_bloom_filter_agg)
6073
}
6174
super.afterAll()
6275
}
@@ -185,6 +198,24 @@ class CometExec3_4PlusSuite extends CometTestBase {
185198
}
186199
}
187200

201+
test("bloom_filter_agg caps oversized numItems / numBits like Spark") {
202+
assume(!isSpark42Plus, "https://github.com/apache/datafusion-comet/issues/4142")
203+
val table = "test"
204+
withTable(table) {
205+
sql(s"create table $table(col1 long) using parquet")
206+
sql(s"insert into $table values (1), (2), (3), (201), (null)")
207+
// numItems / numBits exceed the Int range. Spark's BloomFilterAggregate caps
208+
// them at maxNumItems / maxNumBits; Comet must apply the same cap, otherwise the
209+
// oversized values truncate to a negative i32 and abort the executor with a
210+
// multi-exabyte allocation.
211+
checkSparkAnswerAndOperator(s"""
212+
|SELECT bloom_filter_agg(col1,
213+
| cast(9223372036854775807 as long),
214+
| cast(9223372036854775807 as long)) FROM $table
215+
|""".stripMargin)
216+
}
217+
}
218+
188219
private def bloomFilterFromRandomInput(
189220
expectedItems: Long,
190221
expectedBits: Long): (Seq[Long], Array[Byte]) = {

spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,10 @@ class CometExecRuleSuite extends CometTestBase {
247247
withTempView("test_data") {
248248
createTestDataFrame.createOrReplaceTempView("test_data")
249249

250-
val sparkPlan = createSparkPlan(spark, "SELECT bloom_filter_agg(id) FROM test_data")
250+
// Cast to bigint: Spark 3.4's bloom_filter_agg only accepts a long-typed first
251+
// argument; later versions widened it to any integral type.
252+
val sparkPlan =
253+
createSparkPlan(spark, "SELECT bloom_filter_agg(CAST(id AS BIGINT)) FROM test_data")
251254

252255
val originalObjectAggCount = countOperators(sparkPlan, classOf[ObjectHashAggregateExec])
253256
assert(originalObjectAggCount == 2)
@@ -283,7 +286,10 @@ class CometExecRuleSuite extends CometTestBase {
283286
withTempView("test_data") {
284287
createTestDataFrame.createOrReplaceTempView("test_data")
285288

286-
val sparkPlan = createSparkPlan(spark, "SELECT bloom_filter_agg(id) FROM test_data")
289+
// Cast to bigint: Spark 3.4's bloom_filter_agg only accepts a long-typed first
290+
// argument; later versions widened it to any integral type.
291+
val sparkPlan =
292+
createSparkPlan(spark, "SELECT bloom_filter_agg(CAST(id AS BIGINT)) FROM test_data")
287293

288294
val originalObjectAggCount = countOperators(sparkPlan, classOf[ObjectHashAggregateExec])
289295
assert(originalObjectAggCount == 2)

0 commit comments

Comments
 (0)