Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f7fa33c
fix: allow safe mixed Spark/Comet partial/final aggregate execution
andygrove Apr 21, 2026
f2a8207
fix: address review feedback on mixed partial/final aggregate guard
andygrove Apr 21, 2026
9826403
fix: skip partial aggregate tag when partial itself cannot be converted
andygrove Apr 21, 2026
753a9a5
fix: narrow partial aggregate tag lookup and regenerate TPC-DS golden…
andygrove Apr 21, 2026
6ae483d
fix: reject grouping on nested map types in hash aggregate conversion
andygrove Apr 21, 2026
53405f6
fix: remove COUNT from mixed-safe aggregates to fix AQE/count-bug reg…
andygrove Apr 22, 2026
9e2c25a
spotless
andygrove Apr 22, 2026
f53e3c1
test: ignore SPARK-33853 explain codegen subquery test under Comet
andygrove Apr 23, 2026
3285485
Merge remote-tracking branch 'apache/main' into fix/safe-mixed-partia…
andygrove Apr 25, 2026
671afa6
Merge remote-tracking branch 'apache/main' into fix/safe-mixed-partia…
andygrove May 6, 2026
4322852
test: regenerate Spark 4.2 TPC-DS golden files after merge from main
andygrove May 6, 2026
12018c3
Merge remote-tracking branch 'apache/main' into fix/safe-mixed-partia…
andygrove May 20, 2026
43e0c0b
fix: address review feedback on safe mixed aggregate guard
andygrove May 20, 2026
4bbfe74
fix: drop unused StructType import and regenerate TPC-DS golden files
andygrove May 20, 2026
56e5da6
chore: revert .gitignore change
andygrove May 20, 2026
08b3924
test: ignore SPARK-33853 explain codegen test on Spark 4.1.1
andygrove May 21, 2026
64575f2
test: use descriptive reason for SPARK-33853 IgnoreComet tag
andygrove May 21, 2026
8db42b0
fix: emit Spark-compatible BloomFilter intermediate buffer
andygrove May 21, 2026
36cf0e8
feat: enable BloomFilter for mixed Spark/Comet partial/final aggregate
andygrove May 21, 2026
9bab432
Merge remote-tracking branch 'apache/main' into feat/bloom-filter-int…
andygrove May 21, 2026
2406272
refactor: move SparkBitArray test-only methods into the test module
andygrove May 21, 2026
16299be
refactor: address review feedback on SparkBloomFilter::merge_filter
andygrove May 21, 2026
d51c9d6
Merge remote-tracking branch 'apache/main' into feat/bloom-filter-int…
andygrove May 21, 2026
264510c
fix: cap bloom_filter_agg numItems/numBits and skip null inputs
andygrove May 22, 2026
dbca4b2
fix: return NULL from bloom_filter_agg on empty input [skip ci]
andygrove May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 35 additions & 34 deletions native/spark-expr/src/bloom_filter/spark_bit_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::ToByteSlice;
use std::iter::zip;

/// A simple bit array implementation that simulates the behavior of Spark's BitArray which is
/// used in the BloomFilter implementation. Some methods are not implemented as they are not
/// required for the current use case.
Expand Down Expand Up @@ -61,44 +58,13 @@ impl SparkBitArray {
self.word_size() as u64 * 64
}

pub fn byte_size(&self) -> usize {
self.word_size() * 8
}

pub fn word_size(&self) -> usize {
self.data.len()
}

#[allow(dead_code)] // this is only called from tests
pub fn cardinality(&self) -> usize {
self.bit_count
}

pub fn to_bytes(&self) -> Vec<u8> {
Vec::from(self.data.to_byte_slice())
}

pub fn data(&self) -> Vec<u64> {
self.data.clone()
}

// Combines SparkBitArrays, however other is a &[u8] because we anticipate to come from an
// Arrow ScalarValue::Binary which is a byte vector underneath, rather than a word vector.
pub fn merge_bits(&mut self, other: &[u8]) {
assert_eq!(self.byte_size(), other.len());
let mut bit_count: usize = 0;
// For each word, merge the bits into self, and accumulate a new bit_count.
for i in zip(
self.data.iter_mut(),
other
.chunks(8)
.map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())),
) {
*i.0 |= i.1;
bit_count += i.0.count_ones() as usize;
}
self.bit_count = bit_count;
}
}

pub fn num_words(num_bits: usize) -> usize {
Expand All @@ -108,6 +74,41 @@ pub fn num_words(num_bits: usize) -> usize {
#[cfg(test)]
mod test {
use super::*;
use arrow::datatypes::ToByteSlice;
use std::iter::zip;

impl SparkBitArray {
fn byte_size(&self) -> usize {
self.word_size() * 8
}

fn cardinality(&self) -> usize {
self.bit_count
}

fn to_bytes(&self) -> Vec<u8> {
Vec::from(self.data.to_byte_slice())
}

/// Combines SparkBitArrays, however other is a &[u8] because we anticipate to come from
/// an Arrow ScalarValue::Binary which is a byte vector underneath, rather than a word
/// vector.
fn merge_bits(&mut self, other: &[u8]) {
assert_eq!(self.byte_size(), other.len());
let mut bit_count: usize = 0;
// For each word, merge the bits into self, and accumulate a new bit_count.
for i in zip(
self.data.iter_mut(),
other
.chunks(8)
.map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())),
) {
*i.0 |= i.1;
bit_count += i.0.count_ones() as usize;
}
self.bit_count = bit_count;
}
}

#[test]
fn test_spark_bit_array() {
Expand Down
145 changes: 140 additions & 5 deletions native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,59 @@ impl SparkBloomFilter {
}

pub fn state_as_bytes(&self) -> Vec<u8> {
self.bits.to_bytes()
self.spark_serialization()
}

pub fn merge_filter(&mut self, other: &[u8]) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge_filter is pub fn ... -> () and panics via assert_eq! on every header mismatch. Its only caller is Accumulator::merge_batch in bloom_filter_agg.rs:176, which already returns Result. Threading these through as DataFusionError::Internal would let a corrupt or truncated intermediate buffer surface as a query failure rather than crashing the executor process.

let mut offset = 0;

let version_int = read_num_be_bytes!(i32, 4, other[offset..]);
offset += 4;
assert_eq!(
version_int,
self.version.to_int(),
"BloomFilter merge: version mismatch (got {}, expected {})",
version_int,
self.version.to_int(),
);

let num_hash = read_num_be_bytes!(i32, 4, other[offset..]) as u32;
offset += 4;
assert_eq!(
num_hash, self.num_hash_functions,
"BloomFilter merge: num_hash_functions mismatch (got {}, expected {})",
num_hash, self.num_hash_functions,
);

if let SparkBloomFilterVersion::V2 = self.version {
let seed = read_num_be_bytes!(i32, 4, other[offset..]);
offset += 4;
assert_eq!(
seed, self.seed,
"BloomFilter merge: seed mismatch (got {}, expected {})",
seed, self.seed,
);
}

let num_words = read_num_be_bytes!(i32, 4, other[offset..]) as usize;
offset += 4;
assert_eq!(
other.len(),
self.bits.byte_size(),
"Cannot merge SparkBloomFilters with different lengths."
num_words,
self.bits.word_size(),
"BloomFilter merge: num_words mismatch (got {}, expected {})",
num_words,
self.bits.word_size(),
);
self.bits.merge_bits(other);

let words = self.bits.data();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This clones self.bits.data() into words, builds a new Vec of the same size, and then constructs a fresh SparkBitArray whose new re-scans the words to recompute bit_count. Two allocations and two passes per merge. An in-place variant that ORs into self.bits.data directly and accumulates bit_count in the same loop would avoid both.

let mut merged = Vec::with_capacity(words.len());
for word in words.into_iter() {
let incoming = read_num_be_bytes!(i64, 8, other[offset..]) as u64;
offset += 8;
merged.push(word | incoming);
}
// `SparkBitArray::new` recomputes `bit_count` from the words.
self.bits = SparkBitArray::new(merged);
}
}

Expand Down Expand Up @@ -396,4 +439,96 @@ mod tests {
buf.extend_from_slice(&[0u8; 32]); // 4 words * 8 bytes
let _ = SparkBloomFilter::from(buf.as_slice());
}

/// Two V1 filters with identical parameters. Populate the first, serialize via
/// state_as_bytes, merge into the empty second, and verify the second contains
/// everything the first did. Exercises the aggregator state → merge_batch path.
#[test]
fn state_round_trip_v1_merge() {
let num_bits = 1024;
let num_hash = optimal_num_hash_functions(100, num_bits);
let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
for v in [1_i64, 7, 42, 99, -3, i64::MAX] {
a.put_long(v);
}

let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
b.merge_filter(&a.state_as_bytes());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice that these go through state_as_bytes then merge_filter. The pre-existing v1_round_trip and v2_round_trip tests above use SparkBloomFilter::from, which was always header-aware, so they would not have caught the bug this PR fixes. The new cases are exactly the right round-trip path


for v in [1_i64, 7, 42, 99, -3, i64::MAX] {
assert!(b.might_contain_long(v), "missing {v} after merge");
}
}

/// V2 default seed (0) round-trip through state_as_bytes → merge_filter.
#[test]
fn state_round_trip_v2_default_seed() {
let num_bits = 1024;
let num_hash = optimal_num_hash_functions(100, num_bits);
let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0);
for v in [11_i64, 222, 3333] {
a.put_long(v);
}

let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0);
b.merge_filter(&a.state_as_bytes());

for v in [11_i64, 222, 3333] {
assert!(b.might_contain_long(v));
}
}

/// V2 non-zero seed round-trip; verifies the seed field is parsed and that
/// both filters use the same seed-dependent hash scattering.
#[test]
fn state_round_trip_v2_nonzero_seed() {
let num_bits = 1024;
let num_hash = optimal_num_hash_functions(100, num_bits);
let seed = 0x5eed_5eed_u32 as i32;
let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed);
a.put_long(123);

let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed);
b.merge_filter(&a.state_as_bytes());

assert!(b.might_contain_long(123));
}

#[test]
#[should_panic(expected = "version mismatch")]
fn merge_rejects_version_mismatch() {
let num_bits = 1024;
let num_hash = optimal_num_hash_functions(100, num_bits);
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0);
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
b.merge_filter(&a.state_as_bytes());
}

#[test]
#[should_panic(expected = "num_hash_functions mismatch")]
fn merge_rejects_num_hash_mismatch() {
let num_bits = 1024;
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 5, num_bits, 0);
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 7, num_bits, 0);
b.merge_filter(&a.state_as_bytes());
}

#[test]
#[should_panic(expected = "seed mismatch")]
fn merge_rejects_seed_mismatch_v2() {
let num_bits = 1024;
let num_hash = optimal_num_hash_functions(100, num_bits);
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 1);
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 2);
b.merge_filter(&a.state_as_bytes());
}

#[test]
#[should_panic(expected = "num_words mismatch")]
fn merge_rejects_num_words_mismatch() {
let num_hash = 5;
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 512, 0);
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 1024, 0);
b.merge_filter(&a.state_as_bytes());
}
}
2 changes: 2 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,8 @@ object CometCorr extends CometAggregateExpressionSerde[Corr] {

object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilterAggregate] {

override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bloomFilter: BloomFilterAggregate,
Expand Down
14 changes: 14 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,20 @@ object CometObjectHashAggregateExec
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
CometConf.COMET_EXEC_AGGREGATE_ENABLED)

override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The body here is identical to CometHashAggregateExec.getSupportLevel at operators.scala:1658-1670, including the conf names. That is fine for the test-knob purpose called out in the comment, but COMET_ENABLE_PARTIAL_HASH_AGGREGATE and COMET_ENABLE_FINAL_HASH_AGGREGATE now gate both HashAggregateExec and ObjectHashAggregateExec. As a follow-up, consider renaming to COMET_ENABLE_PARTIAL_AGGREGATE / COMET_ENABLE_FINAL_AGGREGATE so the conf names match the scope.

// Mirror the same test-knobs as CometHashAggregateExec so that mixed-execution
// unit tests can selectively disable partial or final ObjectHashAggregateExec conversion.
if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) &&
op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) {
return Unsupported(Some("Partial aggregates disabled via test config"))
}
if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) &&
op.aggregateExpressions.exists(_.mode == Final)) {
return Unsupported(Some("Final aggregates disabled via test config"))
}
Compatible()
}

override def convert(
aggregate: ObjectHashAggregateExec,
builder: Operator.Builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,19 @@ package org.apache.comet.rules
import scala.util.Random

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.QueryStageExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
import org.apache.comet.CometSparkSessionExtensions.{isSpark40Plus, isSpark42Plus}
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}

/**
Expand Down Expand Up @@ -228,6 +231,77 @@ class CometExecRuleSuite extends CometTestBase {
}
}

test("CometExecRule should allow BloomFilter mixed Comet partial and Spark final") {
assume(!isSpark42Plus, "https://github.com/apache/datafusion-comet/issues/4142")
val funcId = new FunctionIdentifier("bloom_filter_agg")
spark.sessionState.functionRegistry.registerFunction(
funcId,
new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
(children: Seq[Expression]) =>
children.size match {
case 1 => new BloomFilterAggregate(children.head)
case 2 => new BloomFilterAggregate(children.head, children(1))
case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
})
try {
withTempView("test_data") {
createTestDataFrame.createOrReplaceTempView("test_data")

val sparkPlan = createSparkPlan(spark, "SELECT bloom_filter_agg(id) FROM test_data")

val originalObjectAggCount = countOperators(sparkPlan, classOf[ObjectHashAggregateExec])
assert(originalObjectAggCount == 2)

withSQLConf(
CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false",
CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
val transformedPlan = applyCometExecRule(sparkPlan)

// BloomFilter is mixed-safe: partial converts to Comet, final stays Spark.
assert(countOperators(transformedPlan, classOf[ObjectHashAggregateExec]) == 1)
assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 1)
}
}
} finally {
spark.sessionState.functionRegistry.dropFunction(funcId)
}
}

test("CometExecRule should allow BloomFilter mixed Spark partial and Comet final") {
assume(!isSpark42Plus, "https://github.com/apache/datafusion-comet/issues/4142")
val funcId = new FunctionIdentifier("bloom_filter_agg")
spark.sessionState.functionRegistry.registerFunction(
funcId,
new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
(children: Seq[Expression]) =>
children.size match {
case 1 => new BloomFilterAggregate(children.head)
case 2 => new BloomFilterAggregate(children.head, children(1))
case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
})
try {
withTempView("test_data") {
createTestDataFrame.createOrReplaceTempView("test_data")

val sparkPlan = createSparkPlan(spark, "SELECT bloom_filter_agg(id) FROM test_data")

val originalObjectAggCount = countOperators(sparkPlan, classOf[ObjectHashAggregateExec])
assert(originalObjectAggCount == 2)

withSQLConf(
CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
val transformedPlan = applyCometExecRule(sparkPlan)

assert(countOperators(transformedPlan, classOf[ObjectHashAggregateExec]) == 1)
assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 1)
}
}
} finally {
spark.sessionState.functionRegistry.dropFunction(funcId)
}
}

test("CometExecRule should not convert hash aggregate when grouping key contains map type") {
// Spark 3.4/3.5 reject `array<map<...>>` as a grouping key in the analyzer (not orderable),
// so the plan never reaches CometExecRule on those versions. The guard we're exercising
Expand Down
Loading