Skip to content

Commit b3a257b

Browse files
committed
Replace HashSets with sorted Vecs for memory efficiency
This reduces memory usage by ~5x for large graphs:
- neighbor_lists in nn_descent: HashSet -> sorted Vec
- reverse_neighbors in nn_descent: HashSet -> Vec (already sorted by construction)
- graph in find_components: HashSet -> sorted Vec
- node_to_component in refine_edges: HashMap -> Vec (O(1) indexed lookup)
- Removed component_sets HashSets entirely

For n=1 billion, k=20, this saves ~2.5 TB of memory.
1 parent e37f359 commit b3a257b

1 file changed

Lines changed: 76 additions & 66 deletions

File tree

crates/famst/src/lib.rs

Lines changed: 76 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
1414
use rand::seq::SliceRandom;
1515
use rand::Rng;
16-
use std::collections::{BinaryHeap, HashMap, HashSet};
16+
use std::collections::BinaryHeap;
1717

1818
/// An edge in the MST, represented as (node_a, node_b, distance)
1919
#[derive(Debug, Clone)]
@@ -264,9 +264,33 @@ where
264264
return AnnGraph::new(vec![vec![]; n], vec![vec![]; n]);
265265
}
266266

267+
// Helper: check if sorted vec contains value
268+
fn sorted_contains(v: &[usize], x: usize) -> bool {
269+
v.binary_search(&x).is_ok()
270+
}
271+
272+
// Helper: insert into sorted vec, returns true if inserted (was not present)
273+
fn sorted_insert(v: &mut Vec<usize>, x: usize) -> bool {
274+
match v.binary_search(&x) {
275+
Ok(_) => false,
276+
Err(pos) => {
277+
v.insert(pos, x);
278+
true
279+
}
280+
}
281+
}
282+
283+
// Helper: remove from sorted vec
284+
fn sorted_remove(v: &mut Vec<usize>, x: usize) {
285+
if let Ok(pos) = v.binary_search(&x) {
286+
v.remove(pos);
287+
}
288+
}
289+
267290
// Initialize with random neighbors using max-heap for each point
291+
// neighbor_lists[i] is kept sorted by index for O(log k) membership tests
268292
let mut heaps: Vec<BinaryHeap<NeighborEntry>> = Vec::with_capacity(n);
269-
let mut neighbor_sets: Vec<HashSet<usize>> = vec![HashSet::with_capacity(k); n];
293+
let mut neighbor_lists: Vec<Vec<usize>> = vec![Vec::with_capacity(k); n];
270294

271295
for i in 0..n {
272296
let mut heap = BinaryHeap::with_capacity(k);
@@ -281,10 +305,10 @@ where
281305
// Map j to actual index, skipping i
282306
let actual_j = if j >= i { j + 1 } else { j };
283307

284-
if !neighbor_sets[i].insert(actual_j) {
308+
if !sorted_insert(&mut neighbor_lists[i], actual_j) {
285309
// j was already selected, so add t instead
286310
let actual_t = if t >= i { t + 1 } else { t };
287-
neighbor_sets[i].insert(actual_t);
311+
sorted_insert(&mut neighbor_lists[i], actual_t);
288312
let d = distance_fn(&data[i], &data[actual_t]);
289313
heap.push(NeighborEntry {
290314
index: actual_t,
@@ -302,52 +326,52 @@ where
302326
}
303327

304328
// Build reverse neighbor lists (who has me as a neighbor)
305-
let build_reverse = |neighbor_sets: &[HashSet<usize>]| -> Vec<HashSet<usize>> {
306-
let mut reverse: Vec<HashSet<usize>> = vec![HashSet::new(); n];
307-
for (i, neighbors) in neighbor_sets.iter().enumerate() {
329+
// Returns sorted vecs for each point
330+
let build_reverse = |neighbor_lists: &[Vec<usize>]| -> Vec<Vec<usize>> {
331+
let mut reverse: Vec<Vec<usize>> = vec![Vec::new(); n];
332+
for (i, neighbors) in neighbor_lists.iter().enumerate() {
308333
for &j in neighbors {
309-
reverse[j].insert(i);
334+
reverse[j].push(i);
310335
}
311336
}
337+
// No explicit sort is needed: i is visited in increasing order and is pushed at most once per list, so each reverse list is already sorted
312338
reverse
313339
};
314340

315341
// NN-Descent iterations
316342
for _ in 0..config.nn_descent_iterations {
317343
let mut updates = 0;
318-
let reverse_neighbors = build_reverse(&neighbor_sets);
344+
let reverse_neighbors = build_reverse(&neighbor_lists);
319345

320346
// For each point, explore neighbors of neighbors
321347
for i in 0..n {
322348
// Collect candidates: neighbors and reverse neighbors
323349
let mut candidates: Vec<usize> = Vec::new();
324350

325351
// Sample from forward neighbors
326-
let forward: Vec<usize> = neighbor_sets[i].iter().copied().collect();
352+
let mut sampled_forward = neighbor_lists[i].clone();
327353
let sample_size =
328-
((forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1);
329-
let mut sampled_forward = forward.clone();
354+
((sampled_forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1);
330355
sampled_forward.shuffle(rng);
331356
sampled_forward.truncate(sample_size);
332357

333358
// Sample from reverse neighbors
334-
let reverse: Vec<usize> = reverse_neighbors[i].iter().copied().collect();
359+
let mut sampled_reverse = reverse_neighbors[i].clone();
335360
let sample_size =
336-
((reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1);
337-
let mut sampled_reverse = reverse.clone();
361+
((sampled_reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1);
338362
sampled_reverse.shuffle(rng);
339363
sampled_reverse.truncate(sample_size);
340364

341365
// Neighbors of neighbors
342366
for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) {
343-
for &nn in &neighbor_sets[neighbor] {
344-
if nn != i && !neighbor_sets[i].contains(&nn) {
367+
for &nn in &neighbor_lists[neighbor] {
368+
if nn != i && !sorted_contains(&neighbor_lists[i], nn) {
345369
candidates.push(nn);
346370
}
347371
}
348372
// Also check reverse neighbors of neighbors
349373
for &rn in &reverse_neighbors[neighbor] {
350-
if rn != i && !neighbor_sets[i].contains(&rn) {
374+
if rn != i && !sorted_contains(&neighbor_lists[i], rn) {
351375
candidates.push(rn);
352376
}
353377
}
@@ -366,13 +390,13 @@ where
366390
if d < worst.distance {
367391
// Remove worst and add new neighbor
368392
let removed = heaps[i].pop().unwrap();
369-
neighbor_sets[i].remove(&removed.index);
393+
sorted_remove(&mut neighbor_lists[i], removed.index);
370394

371395
heaps[i].push(NeighborEntry {
372396
index: c,
373397
distance: d,
374398
});
375-
neighbor_sets[i].insert(c);
399+
sorted_insert(&mut neighbor_lists[i], c);
376400
updates += 1;
377401
}
378402
}
@@ -403,18 +427,23 @@ where
403427
}
404428

405429
/// Find connected components in the ANN graph using DFS
406-
/// Returns the undirected graph adjacency list and component assignments
407-
fn find_components(ann_graph: &AnnGraph) -> (Vec<HashSet<usize>>, Vec<Vec<usize>>) {
430+
/// Returns the undirected graph adjacency list (sorted vecs) and component assignments
431+
fn find_components(ann_graph: &AnnGraph) -> (Vec<Vec<usize>>, Vec<Vec<usize>>) {
408432
let n = ann_graph.n();
409433

410-
// Build undirected graph from directed ANN graph
411-
let mut graph: Vec<HashSet<usize>> = vec![HashSet::new(); n];
434+
// Build undirected graph from directed ANN graph using sorted vecs
435+
let mut graph: Vec<Vec<usize>> = vec![Vec::new(); n];
412436
for (i, neighbors) in ann_graph.neighbors.iter().enumerate() {
413437
for &j in neighbors {
414-
graph[i].insert(j);
415-
graph[j].insert(i);
438+
graph[i].push(j);
439+
graph[j].push(i);
416440
}
417441
}
442+
// Sort and deduplicate each adjacency list
443+
for adj in &mut graph {
444+
adj.sort_unstable();
445+
adj.dedup();
446+
}
418447

419448
// DFS to find components
420449
let mut visited = vec![false; n];
@@ -494,7 +523,7 @@ where
494523
/// Refine inter-component edges (Algorithm 4 in the paper)
495524
fn refine_edges<T, D>(
496525
data: &[T],
497-
undirected_graph: &[HashSet<usize>],
526+
undirected_graph: &[Vec<usize>],
498527
components: &[Vec<usize>],
499528
edges: &[Edge],
500529
edge_components: &[(usize, usize)],
@@ -503,20 +532,16 @@ fn refine_edges<T, D>(
503532
where
504533
D: Fn(&T, &T) -> f64,
505534
{
506-
// Build component membership lookup
507-
let mut node_to_component: HashMap<usize, usize> = HashMap::new();
535+
let n = data.len();
536+
537+
// Build component membership lookup (simple vec, O(1) lookup)
538+
let mut node_to_component: Vec<usize> = vec![0; n];
508539
for (comp_idx, component) in components.iter().enumerate() {
509540
for &node in component {
510-
node_to_component.insert(node, comp_idx);
541+
node_to_component[node] = comp_idx;
511542
}
512543
}
513544

514-
// Build component node sets for quick lookup
515-
let component_sets: Vec<HashSet<usize>> = components
516-
.iter()
517-
.map(|c| c.iter().copied().collect())
518-
.collect();
519-
520545
let mut refined_edges = Vec::with_capacity(edges.len());
521546
let mut changes = 0;
522547

@@ -526,40 +551,25 @@ where
526551
let mut best_d = edge.distance;
527552

528553
// Get neighbors of u that are in component ci
529-
let neighbors_u: Vec<usize> = undirected_graph[edge.u]
530-
.iter()
531-
.filter(|&&n| component_sets[ci].contains(&n))
532-
.copied()
533-
.collect();
534-
535-
// Try to find better u from neighbors
536-
for u_prime in neighbors_u {
537-
if u_prime == edge.v {
538-
continue;
539-
}
540-
let d_prime = distance_fn(&data[u_prime], &data[best_v]);
541-
if d_prime < best_d {
542-
best_u = u_prime;
543-
best_d = d_prime;
554+
// (component membership is checked inline via node_to_component while iterating, so no temporary Vec of filtered neighbors is built)
555+
for &u_prime in &undirected_graph[edge.u] {
556+
if node_to_component[u_prime] == ci && u_prime != edge.v {
557+
let d_prime = distance_fn(&data[u_prime], &data[best_v]);
558+
if d_prime < best_d {
559+
best_u = u_prime;
560+
best_d = d_prime;
561+
}
544562
}
545563
}
546564

547565
// Get neighbors of v that are in component cj
548-
let neighbors_v: Vec<usize> = undirected_graph[edge.v]
549-
.iter()
550-
.filter(|&&n| component_sets[cj].contains(&n))
551-
.copied()
552-
.collect();
553-
554-
// Try to find better v from neighbors (using updated best_u)
555-
for v_prime in neighbors_v {
556-
if v_prime == edge.u {
557-
continue;
558-
}
559-
let d_prime = distance_fn(&data[best_u], &data[v_prime]);
560-
if d_prime < best_d {
561-
best_v = v_prime;
562-
best_d = d_prime;
566+
for &v_prime in &undirected_graph[edge.v] {
567+
if node_to_component[v_prime] == cj && v_prime != edge.u {
568+
let d_prime = distance_fn(&data[best_u], &data[v_prime]);
569+
if d_prime < best_d {
570+
best_v = v_prime;
571+
best_d = d_prime;
572+
}
563573
}
564574
}
565575

0 commit comments

Comments
 (0)