-
Notifications
You must be signed in to change notification settings - Fork 324
fix: make BloomFilter intermediate buffer Spark-compatible #4390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
f7fa33c
f2a8207
9826403
753a9a5
6ae483d
53405f6
9e2c25a
f53e3c1
3285485
671afa6
4322852
12018c3
43e0c0b
4bbfe74
56e5da6
08b3924
64575f2
8db42b0
36cf0e8
9bab432
2406272
16299be
d51c9d6
264510c
dbca4b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -272,16 +272,59 @@ impl SparkBloomFilter { | |
| } | ||
|
|
||
| pub fn state_as_bytes(&self) -> Vec<u8> { | ||
| self.bits.to_bytes() | ||
| self.spark_serialization() | ||
| } | ||
|
|
||
| pub fn merge_filter(&mut self, other: &[u8]) { | ||
| let mut offset = 0; | ||
|
|
||
| let version_int = read_num_be_bytes!(i32, 4, other[offset..]); | ||
| offset += 4; | ||
| assert_eq!( | ||
| version_int, | ||
| self.version.to_int(), | ||
| "BloomFilter merge: version mismatch (got {}, expected {})", | ||
| version_int, | ||
| self.version.to_int(), | ||
| ); | ||
|
|
||
| let num_hash = read_num_be_bytes!(i32, 4, other[offset..]) as u32; | ||
| offset += 4; | ||
| assert_eq!( | ||
| num_hash, self.num_hash_functions, | ||
| "BloomFilter merge: num_hash_functions mismatch (got {}, expected {})", | ||
| num_hash, self.num_hash_functions, | ||
| ); | ||
|
|
||
| if let SparkBloomFilterVersion::V2 = self.version { | ||
| let seed = read_num_be_bytes!(i32, 4, other[offset..]); | ||
| offset += 4; | ||
| assert_eq!( | ||
| seed, self.seed, | ||
| "BloomFilter merge: seed mismatch (got {}, expected {})", | ||
| seed, self.seed, | ||
| ); | ||
| } | ||
|
|
||
| let num_words = read_num_be_bytes!(i32, 4, other[offset..]) as usize; | ||
| offset += 4; | ||
| assert_eq!( | ||
| other.len(), | ||
| self.bits.byte_size(), | ||
| "Cannot merge SparkBloomFilters with different lengths." | ||
| num_words, | ||
| self.bits.word_size(), | ||
| "BloomFilter merge: num_words mismatch (got {}, expected {})", | ||
| num_words, | ||
| self.bits.word_size(), | ||
| ); | ||
| self.bits.merge_bits(other); | ||
|
|
||
| let words = self.bits.data(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This clones |
||
| let mut merged = Vec::with_capacity(words.len()); | ||
| for word in words.into_iter() { | ||
| let incoming = read_num_be_bytes!(i64, 8, other[offset..]) as u64; | ||
| offset += 8; | ||
| merged.push(word | incoming); | ||
| } | ||
| // `SparkBitArray::new` recomputes `bit_count` from the words. | ||
| self.bits = SparkBitArray::new(merged); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -396,4 +439,96 @@ mod tests { | |
| buf.extend_from_slice(&[0u8; 32]); // 4 words * 8 bytes | ||
| let _ = SparkBloomFilter::from(buf.as_slice()); | ||
| } | ||
|
|
||
| /// Two V1 filters with identical parameters. Populate the first, serialize via | ||
| /// state_as_bytes, merge into the empty second, and verify the second contains | ||
| /// everything the first did. Exercises the aggregator state → merge_batch path. | ||
| #[test] | ||
| fn state_round_trip_v1_merge() { | ||
| let num_bits = 1024; | ||
| let num_hash = optimal_num_hash_functions(100, num_bits); | ||
| let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); | ||
| for v in [1_i64, 7, 42, 99, -3, i64::MAX] { | ||
| a.put_long(v); | ||
| } | ||
|
|
||
| let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); | ||
| b.merge_filter(&a.state_as_bytes()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice that these go through |
||
|
|
||
| for v in [1_i64, 7, 42, 99, -3, i64::MAX] { | ||
| assert!(b.might_contain_long(v), "missing {v} after merge"); | ||
| } | ||
| } | ||
|
|
||
| /// V2 default seed (0) round-trip through state_as_bytes → merge_filter. | ||
| #[test] | ||
| fn state_round_trip_v2_default_seed() { | ||
| let num_bits = 1024; | ||
| let num_hash = optimal_num_hash_functions(100, num_bits); | ||
| let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0); | ||
| for v in [11_i64, 222, 3333] { | ||
| a.put_long(v); | ||
| } | ||
|
|
||
| let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0); | ||
| b.merge_filter(&a.state_as_bytes()); | ||
|
|
||
| for v in [11_i64, 222, 3333] { | ||
| assert!(b.might_contain_long(v)); | ||
| } | ||
| } | ||
|
|
||
| /// V2 non-zero seed round-trip; verifies the seed field is parsed and that | ||
| /// both filters use the same seed-dependent hash scattering. | ||
| #[test] | ||
| fn state_round_trip_v2_nonzero_seed() { | ||
| let num_bits = 1024; | ||
| let num_hash = optimal_num_hash_functions(100, num_bits); | ||
| let seed = 0x5eed_5eed_u32 as i32; | ||
| let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed); | ||
| a.put_long(123); | ||
|
|
||
| let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed); | ||
| b.merge_filter(&a.state_as_bytes()); | ||
|
|
||
| assert!(b.might_contain_long(123)); | ||
| } | ||
|
|
||
| #[test] | ||
| #[should_panic(expected = "version mismatch")] | ||
| fn merge_rejects_version_mismatch() { | ||
| let num_bits = 1024; | ||
| let num_hash = optimal_num_hash_functions(100, num_bits); | ||
| let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0); | ||
| let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); | ||
| b.merge_filter(&a.state_as_bytes()); | ||
| } | ||
|
|
||
| #[test] | ||
| #[should_panic(expected = "num_hash_functions mismatch")] | ||
| fn merge_rejects_num_hash_mismatch() { | ||
| let num_bits = 1024; | ||
| let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 5, num_bits, 0); | ||
| let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 7, num_bits, 0); | ||
| b.merge_filter(&a.state_as_bytes()); | ||
| } | ||
|
|
||
| #[test] | ||
| #[should_panic(expected = "seed mismatch")] | ||
| fn merge_rejects_seed_mismatch_v2() { | ||
| let num_bits = 1024; | ||
| let num_hash = optimal_num_hash_functions(100, num_bits); | ||
| let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 1); | ||
| let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 2); | ||
| b.merge_filter(&a.state_as_bytes()); | ||
| } | ||
|
|
||
| #[test] | ||
| #[should_panic(expected = "num_words mismatch")] | ||
| fn merge_rejects_num_words_mismatch() { | ||
| let num_hash = 5; | ||
| let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 512, 0); | ||
| let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 1024, 0); | ||
| b.merge_filter(&a.state_as_bytes()); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1697,6 +1697,20 @@ object CometObjectHashAggregateExec | |
| override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( | ||
| CometConf.COMET_EXEC_AGGREGATE_ENABLED) | ||
|
|
||
| override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The body here is identical to |
||
| // Mirror the same test-knobs as CometHashAggregateExec so that mixed-execution | ||
| // unit tests can selectively disable partial or final ObjectHashAggregateExec conversion. | ||
| if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) && | ||
| op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) { | ||
| return Unsupported(Some("Partial aggregates disabled via test config")) | ||
| } | ||
| if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) && | ||
| op.aggregateExpressions.exists(_.mode == Final)) { | ||
| return Unsupported(Some("Final aggregates disabled via test config")) | ||
| } | ||
| Compatible() | ||
| } | ||
|
|
||
| override def convert( | ||
| aggregate: ObjectHashAggregateExec, | ||
| builder: Operator.Builder, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`merge_filter` is `pub fn ... -> ()` and panics via `assert_eq!` on every header mismatch. Its only caller is `Accumulator::merge_batch` in `bloom_filter_agg.rs:176`, which already returns `Result`. Threading these through as `DataFusionError::Internal` would let a corrupt or truncated intermediate buffer surface as a query failure rather than crashing the executor process.