Skip to content

Commit 540b308

Browse files
committed
sorting benchmark
1 parent f6f4c21 commit 540b308

3 files changed

Lines changed: 197 additions & 66 deletions

File tree

crates/hash-sorted-map/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ repository = "https://github.com/github/rust-gems"
88
license = "MIT"
99
keywords = ["hashmap", "sorted", "merge", "simd"]
1010
categories = ["algorithms", "data-structures"]
11+
12+
[dependencies]
13+
smallvec = "1"

crates/hash-sorted-map/benchmarks/performance.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::hash::BuildHasher;
2+
13
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
24
use hash_sorted_map::HashSortedMap;
35
use hash_sorted_map_benchmarks::{random_trigram_hashes, IdentityBuildHasher};
@@ -393,12 +395,46 @@ fn bench_iter(c: &mut Criterion) {
393395
group.finish();
394396
}
395397

398+
fn bench_sort(c: &mut Criterion) {
399+
let keys = random_trigram_hashes(100_000);
400+
let hasher = IdentityBuildHasher::default();
401+
let mut group = c.benchmark_group("sort_100000_trigrams");
402+
403+
group.bench_function("Vec::sort_unstable", |b| {
404+
b.iter(|| {
405+
let mut vec: Vec<_> = keys
406+
.iter()
407+
.enumerate()
408+
.map(|(i, &key)| (key, i))
409+
.collect();
410+
vec.sort_unstable_by_key(|&(key, _)| hasher.hash_one(key));
411+
vec
412+
});
413+
});
414+
415+
group.bench_function("HashSortedMap sort_by_hash", |b| {
416+
b.iter(|| {
417+
let mut map = HashSortedMap::with_capacity_and_hasher(
418+
keys.len(),
419+
IdentityBuildHasher::default(),
420+
);
421+
for (i, &key) in keys.iter().enumerate() {
422+
map.insert(key, i);
423+
}
424+
map.sort_by_hash()
425+
});
426+
});
427+
428+
group.finish();
429+
}
430+
396431
criterion_group!(
397432
benches,
398433
bench_insert,
399434
bench_reinsert,
400435
bench_grow,
401436
bench_count,
402-
bench_iter
437+
bench_iter,
438+
bench_sort
403439
);
404440
criterion_main!(benches);

crates/hash-sorted-map/src/hash_sorted_map.rs

Lines changed: 157 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use std::collections::hash_map::RandomState;
44
use std::hash::{BuildHasher, Hash};
55
use std::marker::PhantomData;
66

7+
use smallvec::SmallVec;
8+
79
use super::container::HashSortedContainer;
810
use super::group::Group;
911
use super::group_ops::{self, CTRL_EMPTY, GROUP_SIZE};
@@ -87,64 +89,79 @@ impl<K, V, S> HashSortedMap<K, V, S> {
8789
}
8890

8991
impl<K: Hash + Eq, V, S: BuildHasher> HashSortedMap<K, V, S> {
90-
/// Sort all entries within each primary group chain by their hash value.
92+
/// Sort all entries within each primary group chain by their hash value
93+
/// and return the underlying container.
9194
///
9295
/// After sorting, iteration visits entries in hash order within each
9396
/// primary group (and since primary groups are visited in group-index
9497
/// order, the overall iteration is in full hash order).
9598
///
96-
/// This is a one-time operation intended to be called before iteration
97-
/// or serialization. After sorting, lookups via `get()` won't work
98-
/// correctly because the preferred `slot_hint` position might now be empty
99-
/// breaking an invariant.
100-
pub fn sort_by_hash(&mut self) {
99+
/// Consumes the map because lookups via `get()` won't work correctly
100+
/// after sorting (the preferred `slot_hint` position might now be empty,
101+
/// breaking an invariant).
102+
///
103+
/// # Complexity
104+
///
105+
/// Each of `n` elements hashes uniformly into one of `m` primary groups,
106+
/// so chain lengths follow `X_i ~ Binomial(n, 1/m)` with `E[X_i] = n/m`.
107+
/// With a quadratic sort per chain the total expected cost is:
108+
///
109+
/// ```text
110+
/// Σ E[X_i²] = m · (Var[X_i] + E[X_i]²)
111+
/// = m · (n/m · (1 − 1/m) + n²/m²)
112+
/// = n · (1 − 1/m) + n²/m
113+
/// ```
114+
///
115+
/// Dividing by `n` gives the expected cost per element: `1 + n/m` (for
116+
/// `m ≫ 1`). Since `n/m` is the average chain length, bounded by
117+
/// `GROUP_SIZE / MAX_FILL ≈ 16`, the per-element cost stays constant.
118+
pub fn sort_by_hash(mut self) -> HashSortedContainer<K, V> {
101119
let num_primary = 1usize << self.container.n_bits;
102-
let mut buf: Vec<(u64, K, V)> = Vec::new();
120+
let mut chain: SmallVec<[u32; 4]> = SmallVec::new();
121+
let mut hashes: SmallVec<[u64; 16]> = SmallVec::new();
122+
103123
for primary_gi in 0..num_primary {
104-
buf.clear();
105-
// Extract all entries from this primary group's chain.
124+
chain.clear();
125+
hashes.clear();
126+
127+
// Collect group indices in this chain.
106128
let mut gi = primary_gi;
107129
loop {
108-
let group = &mut self.container.groups[gi];
109-
let mut full_mask = group_ops::match_full(&group.ctrl);
110-
while let Some(slot) = group_ops::next_match(&mut full_mask) {
111-
let key = unsafe { group.keys[slot].assume_init_read() };
112-
let value = unsafe { group.values[slot].assume_init_read() };
113-
let hash = self.hash_builder.hash_one(&key);
114-
buf.push((hash, key, value));
115-
group.ctrl[slot] = CTRL_EMPTY;
116-
}
117-
if group.overflow == NO_OVERFLOW {
130+
chain.push(gi as u32);
131+
let overflow = self.container.groups[gi].overflow;
132+
if overflow == NO_OVERFLOW {
118133
break;
119134
}
120-
gi = group.overflow as usize;
135+
gi = overflow as usize;
121136
}
122-
if buf.len() <= 1 {
123-
// 0 or 1 entry — write back to slot 0 if present (already extracted).
124-
if let Some((hash, key, value)) = buf.pop() {
125-
let group = &mut self.container.groups[primary_gi];
126-
group.ctrl[0] = tag(hash);
127-
group.keys[0] = MaybeUninit::new(key);
128-
group.values[0] = MaybeUninit::new(value);
137+
// All groups before the last are fully packed (overflow is only
138+
// allocated when the previous group is full). Compute hashes for
139+
// those directly.
140+
for &cgi in &chain[..chain.len() - 1] {
141+
let g = &self.container.groups[cgi as usize];
142+
for slot in 0..GROUP_SIZE {
143+
let hash = self
144+
.hash_builder
145+
.hash_one(unsafe { g.keys[slot].assume_init_ref() });
146+
hashes.push(hash);
129147
}
130-
continue;
131148
}
132-
buf.sort_unstable_by_key(|&(hash, _, _)| hash);
133-
// Write back in sorted order, filling slots linearly.
134-
let mut gi = primary_gi;
135-
let mut slot = 0;
136-
for (hash, key, value) in buf.drain(..) {
137-
if slot == GROUP_SIZE {
138-
slot = 0;
139-
gi = self.container.groups[gi].overflow as usize;
149+
// The last group may have gaps — compact it to the front.
150+
let last_gi = *chain.last().unwrap() as usize;
151+
compact_last_group(&mut self.container.groups[last_gi], &self.hash_builder, &mut hashes);
152+
let n = hashes.len();
153+
// Insertion sort by hash.
154+
for i in 1..n {
155+
let mut j = i;
156+
while j > 0 && hashes[j - 1] > hashes[j] {
157+
hashes.swap(j - 1, j);
158+
swap_chain_slots(&mut self.container.groups, &chain, j - 1, j);
159+
j -= 1;
140160
}
141-
let group = &mut self.container.groups[gi];
142-
group.ctrl[slot] = tag(hash);
143-
group.keys[slot] = MaybeUninit::new(key);
144-
group.values[slot] = MaybeUninit::new(value);
145-
slot += 1;
146161
}
162+
147163
}
164+
self.container
148165
}
149166

150167
pub fn insert(&mut self, key: K, value: V) -> Option<V> {
@@ -421,6 +438,70 @@ impl<K: Hash + Eq, V, S: BuildHasher> HashSortedMap<K, V, S> {
421438
}
422439
}
423440

441+
// ── Chain-slot helpers for sort_by_hash ─────────────────────────────────
442+
443+
/// Map a flat position (0..chain.len()*GROUP_SIZE) to a (group_index, slot).
444+
#[inline]
445+
fn chain_slot(chain: &[u32], pos: usize) -> (usize, usize) {
446+
(chain[pos / GROUP_SIZE] as usize, pos % GROUP_SIZE)
447+
}
448+
449+
/// Compact the last group in a chain: move all occupied entries to slots
450+
/// 0..n and clear the rest. Computes hashes for each occupied entry and
451+
/// appends them to `hashes`.
452+
fn compact_last_group<K: Hash, V, S: BuildHasher>(
453+
group: &mut Group<K, V>,
454+
hash_builder: &S,
455+
hashes: &mut SmallVec<[u64; 16]>,
456+
) {
457+
let mut write = 0usize;
458+
let mut full_mask = group_ops::match_full(&group.ctrl);
459+
while let Some(read) = group_ops::next_match(&mut full_mask) {
460+
let hash = hash_builder.hash_one(unsafe { group.keys[read].assume_init_ref() });
461+
hashes.push(hash);
462+
if read != write {
463+
unsafe {
464+
group.keys[write] = std::ptr::read(&group.keys[read]);
465+
group.values[write] = std::ptr::read(&group.values[read]);
466+
}
467+
}
468+
write += 1;
469+
}
470+
// Fix ctrl bytes: only the top bit matters (full vs empty).
471+
for slot in 0..write {
472+
group.ctrl[slot] = 0x80;
473+
}
474+
for slot in write..GROUP_SIZE {
475+
group.ctrl[slot] = CTRL_EMPTY;
476+
}
477+
}
478+
479+
/// Swap the ctrl byte, key, and value between two flat positions in a chain.
480+
fn swap_chain_slots<K, V>(
481+
groups: &mut [Group<K, V>],
482+
chain: &[u32],
483+
a: usize,
484+
b: usize,
485+
) {
486+
let (gi_a, slot_a) = chain_slot(chain, a);
487+
let (gi_b, slot_b) = chain_slot(chain, b);
488+
if gi_a == gi_b {
489+
let g = &mut groups[gi_a];
490+
g.keys.swap(slot_a, slot_b);
491+
g.values.swap(slot_a, slot_b);
492+
} else {
493+
let (ga, gb) = if gi_a < gi_b {
494+
let (left, right) = groups.split_at_mut(gi_b);
495+
(&mut left[gi_a], &mut right[0])
496+
} else {
497+
let (left, right) = groups.split_at_mut(gi_a);
498+
(&mut right[0], &mut left[gi_b])
499+
};
500+
std::mem::swap(&mut ga.keys[slot_a], &mut gb.keys[slot_b]);
501+
std::mem::swap(&mut ga.values[slot_a], &mut gb.values[slot_b]);
502+
}
503+
}
504+
424505
// ────────────────────────────────────────────────────────────────────────
425506
// Entry API
426507
// ────────────────────────────────────────────────────────────────────────
@@ -842,18 +923,19 @@ mod tests {
842923

843924
#[test]
844925
fn sort_by_hash_empty() {
845-
let mut map: HashSortedMap<u32, u32> = HashSortedMap::new();
846-
map.sort_by_hash(); // should not panic
847-
assert_eq!(map.len(), 0);
926+
let map: HashSortedMap<u32, u32> = HashSortedMap::new();
927+
let container = map.sort_by_hash();
928+
assert_eq!(container.len, 0);
848929
}
849930

850931
#[test]
851932
fn sort_by_hash_single() {
852933
let mut map = HashSortedMap::new();
853934
map.insert(42u32, "hello");
854-
map.sort_by_hash();
855-
assert_eq!(map.get(&42), Some(&"hello"));
856-
assert_eq!(map.len(), 1);
935+
let container = map.sort_by_hash();
936+
assert_eq!(container.len, 1);
937+
let entries: Vec<_> = container.into_iter().collect();
938+
assert_eq!(entries, vec![(42, "hello")]);
857939
}
858940

859941
#[test]
@@ -862,10 +944,12 @@ mod tests {
862944
for i in 0..200u32 {
863945
map.insert(i, i * 10);
864946
}
865-
map.sort_by_hash();
866-
assert_eq!(map.len(), 200);
947+
let container = map.sort_by_hash();
948+
assert_eq!(container.len, 200);
949+
let mut entries: Vec<_> = container.into_iter().collect();
950+
entries.sort_by_key(|&(k, _)| k);
867951
for i in 0..200u32 {
868-
assert_eq!(map.get(&i), Some(&(i * 10)), "missing key {i}");
952+
assert_eq!(entries[i as usize], (i, i * 10), "missing key {i}");
869953
}
870954
}
871955

@@ -878,14 +962,14 @@ mod tests {
878962
for i in 0..500u32 {
879963
map.insert(i, i);
880964
}
881-
map.sort_by_hash();
965+
let container = map.sort_by_hash();
882966
// Iteration should now yield entries in hash order.
883967
let mut prev_hash = 0u64;
884968
let mut first = true;
885-
for (&k, _) in &map {
969+
for (&k, _) in &container {
886970
let h = hasher.hash_one(&k);
887971
if !first {
888-
assert!(h >= prev_hash, "hash order violated: {prev_hash:#x} > {h:#x}");
972+
assert!(h > prev_hash, "hash order violated: {prev_hash:#x} > {h:#x}");
889973
}
890974
prev_hash = h;
891975
first = false;
@@ -899,27 +983,35 @@ mod tests {
899983
for i in 0..50u32 {
900984
map.insert(i, i);
901985
}
902-
map.sort_by_hash();
903-
assert_eq!(map.len(), 50);
986+
let container = map.sort_by_hash();
987+
assert_eq!(container.len, 50);
988+
let mut entries: Vec<_> = container.into_iter().collect();
989+
entries.sort_by_key(|&(k, _)| k);
904990
for i in 0..50u32 {
905-
assert_eq!(map.get(&i), Some(&i), "missing key {i}");
991+
assert_eq!(entries[i as usize], (i, i), "missing key {i}");
906992
}
907993
}
908994

909995
#[test]
910996
fn sort_by_hash_with_strings() {
911-
let mut map = HashSortedMap::new();
997+
use std::collections::hash_map::RandomState;
998+
999+
let hasher = RandomState::new();
1000+
let mut map = HashSortedMap::with_hasher(hasher.clone());
9121001
for i in 0..100u32 {
9131002
map.insert(format!("key-{i}"), format!("val-{i}"));
9141003
}
915-
map.sort_by_hash();
916-
assert_eq!(map.len(), 100);
917-
for i in 0..100u32 {
918-
assert_eq!(
919-
map.get(&format!("key-{i}")),
920-
Some(&format!("val-{i}")),
921-
"missing key-{i}"
922-
);
1004+
let container = map.sort_by_hash();
1005+
assert_eq!(container.len, 100);
1006+
let mut prev_hash = 0u64;
1007+
let mut first = true;
1008+
for (k, _) in &container {
1009+
let h = hasher.hash_one(k);
1010+
if !first {
1011+
assert!(h > prev_hash, "hash order violated: {prev_hash:#x} > {h:#x}");
1012+
}
1013+
prev_hash = h;
1014+
first = false;
9231015
}
9241016
}
9251017
}

0 commit comments

Comments
 (0)