Skip to content

Commit d90d6f0

Browse files
committed
implement basic serialization and deserialization of diff counts as bytes
1 parent 6a72711 commit d90d6f0

2 files changed

Lines changed: 163 additions & 5 deletions

File tree

crates/geo_filters/src/diff_count.rs

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::borrow::Cow;
44
use std::cmp::Ordering;
55
use std::hash::BuildHasher as _;
66
use std::mem::{size_of, size_of_val};
7+
use std::ops::Deref;
78

89
use crate::config::{
910
count_ones_from_bitchunks, count_ones_from_msb_and_lsb, iter_bit_chunks, iter_ones,
@@ -77,7 +78,7 @@ impl<C: GeoConfig<Diff>> std::fmt::Debug for GeoDiffCount<'_, C> {
7778
}
7879
}
7980

80-
impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
81+
impl<'a, C: GeoConfig<Diff>> GeoDiffCount<'a, C> {
8182
pub fn new(config: C) -> Self {
8283
Self {
8384
config,
@@ -86,6 +87,49 @@ impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
8687
}
8788
}
8889

90+
pub fn from_bytes(c: C, buf: &'a [u8]) -> Self {
91+
if buf.is_empty() {
92+
return Self::new(c);
93+
}
94+
95+
// The number of most significant bits stores in the MSB sparse repr
96+
let msb_len = (buf.len() / size_of::<C::BucketType>()).min(c.max_msb_len());
97+
98+
let msb =
99+
unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const C::BucketType, msb_len) };
100+
101+
// The number of bytes representing the MSB - this is how many bytes we need to
102+
// skip over to reach the LSB
103+
let msb_bytes_len = msb_len * size_of::<C::BucketType>();
104+
105+
Self {
106+
config: c,
107+
msb: Cow::Borrowed(msb),
108+
lsb: BitVec::from_bytes(&buf[msb_bytes_len..]),
109+
}
110+
}
111+
112+
pub fn write<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
113+
if self.msb.is_empty() {
114+
return Ok(0);
115+
}
116+
117+
let msb_buckets = self.msb.deref();
118+
let msb_bytes = unsafe {
119+
std::slice::from_raw_parts(
120+
msb_buckets.as_ptr() as *const u8,
121+
msb_buckets.len() * size_of::<C::BucketType>(),
122+
)
123+
};
124+
writer.write_all(msb_bytes)?;
125+
126+
let mut bytes_written = msb_bytes.len();
127+
128+
bytes_written += self.lsb.write(writer)?;
129+
130+
Ok(bytes_written)
131+
}
132+
89133
/// `BitChunk`s can be processed much more efficiently than individual one bits!
90134
/// This function makes it possible to construct a GeoDiffCount instance directly from
91135
/// `BitChunk`s. It will extract the most significant bits first and then put the remainder
@@ -208,16 +252,23 @@ impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
208252
/// that makes the cost of the else case negligible.
209253
fn xor_bit(&mut self, bucket: C::BucketType) {
210254
if bucket.into_usize() < self.lsb.num_bits() {
255+
// The bit being toggled is within our LSB bit vector
256+
// so toggle it directly.
211257
self.lsb.toggle(bucket.into_usize());
212258
} else {
213259
let msb = self.msb.to_mut();
214260
match msb.binary_search_by(|k| bucket.cmp(k)) {
215261
Ok(idx) => {
262+
// The bit is already set in the MSB sparse bitset, remove it (XOR)
216263
msb.remove(idx);
264+
265+
// We have removed a value from our MSB, move a value in the
266+
// LSB into the MSB
217267
let (first, second) = {
218268
let mut lsb = iter_ones(self.lsb.bit_chunks().peekable());
219269
(lsb.next(), lsb.next())
220270
};
271+
221272
let new_smallest = if let Some(smallest) = first {
222273
msb.push(C::BucketType::from_usize(smallest));
223274
second.map(|_| smallest).unwrap_or(0)
@@ -229,15 +280,19 @@ impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
229280
Err(idx) => {
230281
msb.insert(idx, bucket);
231282
if msb.len() > self.config.max_msb_len() {
283+
// We have too many values in the MSB sparse index vector,
284+
// let's move the smallest MSB value into the LSB bit vector
232285
let smallest = msb
233286
.pop()
234287
.expect("we should have at least one element!")
235288
.into_usize();
236-
// ensure vector covers smallest
289+
237290
let new_smallest = msb
238291
.last()
239292
.expect("should have at least one element")
240293
.into_usize();
294+
295+
// ensure LSB bit vector has the space for `smallest`
241296
self.lsb.resize(new_smallest);
242297
self.lsb.toggle(smallest);
243298
}
@@ -360,7 +415,10 @@ impl<C: GeoConfig<Diff>> Count<Diff> for GeoDiffCount<'_, C> {
360415
#[cfg(test)]
361416
mod tests {
362417
use itertools::Itertools;
363-
use rand::{RngCore, SeedableRng};
418+
use rand::{
419+
seq::{IndexedRandom, IteratorRandom},
420+
RngCore, SeedableRng,
421+
};
364422

365423
use crate::{
366424
build_hasher::UnstableDefaultBuildHasher,
@@ -580,4 +638,44 @@ mod tests {
580638
iter_ones(self.bit_chunks().peekable()).map(C::BucketType::from_usize)
581639
}
582640
}
641+
642+
/// An empty counter serializes to zero bytes and deserializes back to an
/// equal counter.
#[test]
fn test_serialization_empty() {
    let original = GeoDiffCount7::default();

    let mut bytes: Vec<u8> = Vec::new();
    original.write(&mut bytes).unwrap();
    assert_eq!(bytes.len(), 0);

    let restored = GeoDiffCount7::from_bytes(original.config.clone(), &bytes);
    assert_eq!(original, restored);
}
655+
656+
/// Writing a randomly filled counter and reading it back yields an equal counter.
#[test]
fn test_serialization_round_trip() {
    let mut rng = rand::rngs::StdRng::from_os_rng();

    // 100 independent trials, each with a fresh counter.
    for _ in 0..100 {
        let mut original = GeoDiffCount7::default();

        // Insert between 1 and 999 random hashes.
        let item_count = (1..1000).choose(&mut rng).unwrap();
        (0..item_count).for_each(|_| original.push_hash(rng.next_u64()));

        let mut bytes: Vec<u8> = Vec::new();
        original.write(&mut bytes).unwrap();

        let restored = GeoDiffCount7::from_bytes(original.config.clone(), &bytes);
        assert_eq!(original, restored);
    }
}
583681
}

crates/geo_filters/src/diff_count/bitvec.rs

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ use std::borrow::Cow;
22
use std::cmp::Ordering;
33
use std::iter::Peekable;
44
use std::mem::{size_of, size_of_val};
5-
use std::ops::{Index, Range};
5+
use std::ops::{Deref as _, Index, Range};
66

7-
use crate::config::BitChunk;
87
use crate::config::IsBucketType;
98
use crate::config::BITS_PER_BLOCK;
9+
use crate::config::{BitChunk, BYTES_PER_BLOCK};
1010

1111
/// A bit vector where every bit occupies exactly one bit (in contrast to `Vec<bool>` where each
1212
/// bit consumes 1 byte). It only implements the minimum number of operations that we need for our
@@ -34,6 +34,62 @@ impl PartialOrd for BitVec<'_> {
3434
}
3535

3636
impl BitVec<'_> {
37+
pub fn from_bytes(mut buf: &[u8]) -> Self {
38+
if buf.is_empty() {
39+
return Self::default();
40+
}
41+
42+
// The first byte of the serialized BitVec is used to indicate how many
43+
// of the bits in the left-most byte are *unoccupied*.
44+
// See [`BitVec::write`] implementation for how this is done.
45+
assert!(
46+
buf[0] < 64,
47+
"Number of unoccupied bits should be <64, got {}",
48+
buf[0]
49+
);
50+
51+
let num_bits = (buf.len() - 1) * 8 - buf[0] as usize;
52+
buf = &buf[1..];
53+
54+
assert_eq!(
55+
buf.len() % BYTES_PER_BLOCK,
56+
0,
57+
"buffer should be a multiple of 8 bytes, got {}",
58+
buf.len()
59+
);
60+
61+
let blocks = unsafe {
62+
std::mem::transmute(std::slice::from_raw_parts(
63+
buf.as_ptr(),
64+
buf.len() / BYTES_PER_BLOCK,
65+
))
66+
};
67+
let blocks = Cow::Borrowed(blocks);
68+
69+
Self { num_bits, blocks }
70+
}
71+
72+
pub fn write<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
73+
if self.is_empty() {
74+
return Ok(0);
75+
}
76+
77+
// First serialize the number of unoccupied bits in the last block as one byte.
78+
let unoccupied_bits = 63 - ((self.num_bits - 1) % 64) as u8;
79+
80+
writer.write_all(&[unoccupied_bits])?;
81+
82+
let blocks = self.blocks.deref();
83+
84+
let block_bytes = unsafe {
85+
std::slice::from_raw_parts(blocks.as_ptr() as *const u8, blocks.len() * BYTES_PER_BLOCK)
86+
};
87+
88+
writer.write_all(block_bytes)?;
89+
90+
Ok(block_bytes.len() + 1)
91+
}
92+
3793
/// Takes an iterator of `BitChunk` items as input and returns the corresponding `BitVec`.
3894
/// The order of `BitChunk`s doesn't matter for this function and `BitChunk` may be hitting
3995
/// the same block. In this case, the function will simply xor them together.
@@ -81,6 +137,10 @@ impl BitVec<'_> {
81137
self.num_bits
82138
}
83139

140+
pub fn is_empty(&self) -> bool {
141+
self.num_bits() == 0
142+
}
143+
84144
/// Tests the bit specified by the provided zero-based bit position.
85145
pub fn test_bit(&self, index: usize) -> bool {
86146
assert!(index < self.num_bits);

0 commit comments

Comments
 (0)