Skip to content

Commit fbadc91

Browse files
authored
fix: support Spark 4.1 BloomFilter V2 format and bit-scattering (#4196)
1 parent 7d5884f commit fbadc91

10 files changed

Lines changed: 328 additions & 63 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ use datafusion::{
6969
use datafusion_comet_spark_expr::{
7070
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
7171
BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc,
72-
SumInteger, ToCsv,
72+
SparkBloomFilterVersion, SumInteger, ToCsv,
7373
};
7474
use datafusion_spark::function::aggregate::collect::SparkCollectSet;
7575
use iceberg::expr::Bind;
@@ -2363,10 +2363,17 @@ impl PhysicalPlanner {
23632363
let num_bits =
23642364
self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
23652365
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
2366+
let version = match expr.version() {
2367+
spark_expression::BloomFilterVersion::V2 => SparkBloomFilterVersion::V2,
2368+
// Default (Unspecified or V1) preserves the pre-Spark-4.1 format that
2369+
// Comet has always emitted, keeping older Spark versions byte-equivalent.
2370+
_ => SparkBloomFilterVersion::V1,
2371+
};
23662372
let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
23672373
Arc::clone(&num_items),
23682374
Arc::clone(&num_bits),
23692375
datatype,
2376+
version,
23702377
));
23712378
Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
23722379
}

native/proto/src/proto/expr.proto

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,17 @@ message BloomFilterAgg {
248248
Expr numItems = 2;
249249
Expr numBits = 3;
250250
DataType datatype = 4;
251+
// Output serialization version. Spark 4.0 and earlier always wrote V1; Spark
252+
// 4.1+ defaults to V2 (different bit-scattering algorithm and a `seed` field
253+
// in the binary format). The JVM serde sets this to the matching version so
254+
// Comet's aggregate output is byte-equivalent with Spark's.
255+
BloomFilterVersion version = 5;
256+
}
257+
258+
enum BloomFilterVersion {
259+
BLOOM_FILTER_VERSION_UNSPECIFIED = 0;
260+
BLOOM_FILTER_VERSION_V1 = 1;
261+
BLOOM_FILTER_VERSION_V2 = 2;
251262
}
252263

253264
message CollectSet {

native/spark-expr/benches/bloom_filter_agg.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use datafusion::physical_expr::expressions::{Column, Literal};
3030
use datafusion::physical_expr::PhysicalExpr;
3131
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
3232
use datafusion::physical_plan::ExecutionPlan;
33-
use datafusion_comet_spark_expr::BloomFilterAgg;
33+
use datafusion_comet_spark_expr::{BloomFilterAgg, SparkBloomFilterVersion};
3434
use futures::StreamExt;
3535
use std::hint::black_box;
3636
use std::sync::Arc;
@@ -66,6 +66,7 @@ fn criterion_benchmark(c: &mut Criterion) {
6666
Arc::clone(&num_items),
6767
Arc::clone(&num_bits),
6868
DataType::Binary,
69+
SparkBloomFilterVersion::V1,
6970
)));
7071
b.to_async(&rt).iter(|| {
7172
black_box(agg_test(

native/spark-expr/src/bloom_filter/bloom_filter_agg.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
2020
use std::{any::Any, sync::Arc};
2121

2222
use crate::bloom_filter::spark_bloom_filter;
23-
use crate::bloom_filter::spark_bloom_filter::SparkBloomFilter;
23+
use crate::bloom_filter::spark_bloom_filter::{SparkBloomFilter, SparkBloomFilterVersion};
2424

2525
use arrow::array::ArrayRef;
2626
use arrow::array::BinaryArray;
@@ -37,6 +37,10 @@ pub struct BloomFilterAgg {
3737
signature: Signature,
3838
num_items: i32,
3939
num_bits: i32,
40+
/// Output serialization version. Spark <= 4.0 only knows V1; Spark 4.1+'s
41+
/// `BloomFilter.create` defaults to V2, so the JVM serde sets this to V2 on
42+
/// 4.1+ to keep `bloom_filter_agg` byte-equivalent with Spark's aggregator.
43+
version: SparkBloomFilterVersion,
4044
}
4145

4246
#[inline]
@@ -54,6 +58,7 @@ impl BloomFilterAgg {
5458
num_items: Arc<dyn PhysicalExpr>,
5559
num_bits: Arc<dyn PhysicalExpr>,
5660
data_type: DataType,
61+
version: SparkBloomFilterVersion,
5762
) -> Self {
5863
assert!(matches!(data_type, DataType::Binary));
5964
Self {
@@ -70,6 +75,7 @@ impl BloomFilterAgg {
7075
),
7176
num_items: extract_i32_from_literal(num_items),
7277
num_bits: extract_i32_from_literal(num_bits),
78+
version,
7379
}
7480
}
7581
}
@@ -92,10 +98,13 @@ impl AggregateUDFImpl for BloomFilterAgg {
9298
}
9399

94100
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
95-
Ok(Box::new(SparkBloomFilter::from((
101+
Ok(Box::new(SparkBloomFilter::new(
102+
self.version,
96103
spark_bloom_filter::optimal_num_hash_functions(self.num_items, self.num_bits),
97104
self.num_bits,
98-
))))
105+
// Spark's BloomFilterAggregate always uses BloomFilterImplV2.DEFAULT_SEED (= 0).
106+
0,
107+
)))
99108
}
100109

101110
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {

native/spark-expr/src/bloom_filter/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mod bit;
2020

2121
mod spark_bit_array;
2222
mod spark_bloom_filter;
23+
pub use spark_bloom_filter::SparkBloomFilterVersion;
2324

2425
pub mod bloom_filter_agg;
2526
pub use bloom_filter_might_contain::BloomFilterMightContain;

0 commit comments

Comments
 (0)