Skip to content

Commit 0014f86

Browse files
committed
more simplifications
1 parent ea848b5 commit 0014f86

2 files changed

Lines changed: 108 additions & 95 deletions

File tree

crates/famst/src/lib.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ pub(crate) struct Neighbor {
4545

4646
impl PartialEq for Neighbor {
4747
fn eq(&self, other: &Self) -> bool {
48-
self.distance == other.distance
48+
self.distance == other.distance && self.index == other.index
4949
}
5050
}
5151

@@ -59,10 +59,11 @@ impl PartialOrd for Neighbor {
5959

6060
impl Ord for Neighbor {
6161
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
62-
// Max-heap: larger distances have higher priority
62+
// Total ordering by (distance, index)
6363
self.distance
6464
.partial_cmp(&other.distance)
6565
.unwrap_or(std::cmp::Ordering::Equal)
66+
.then_with(|| self.index.cmp(&other.index))
6667
}
6768
}
6869

@@ -78,20 +79,30 @@ pub(crate) struct AnnGraph {
7879
}
7980

8081
impl AnnGraph {
81-
fn new(n: usize, k: usize, data: Vec<Neighbor>) -> Self {
82+
pub(crate) fn new(n: usize, k: usize, data: Vec<Neighbor>) -> Self {
8283
assert_eq!(data.len(), n * k);
8384
AnnGraph { data, n, k }
8485
}
8586

86-
fn n(&self) -> usize {
87+
pub(crate) fn n(&self) -> usize {
8788
self.n
8889
}
8990

91+
pub(crate) fn k(&self) -> usize {
92+
self.k
93+
}
94+
9095
/// Get the neighbors of point i
91-
fn neighbors(&self, i: usize) -> &[Neighbor] {
96+
pub(crate) fn neighbors(&self, i: usize) -> &[Neighbor] {
9297
let start = i * self.k;
9398
&self.data[start..start + self.k]
9499
}
100+
101+
/// Get mutable access to neighbors of point i
102+
pub(crate) fn neighbors_mut(&mut self, i: usize) -> &mut [Neighbor] {
103+
let start = i * self.k;
104+
&mut self.data[start..start + self.k]
105+
}
95106
}
96107

97108
/// FAMST algorithm configuration

crates/famst/src/nn_descent.rs

Lines changed: 92 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -5,44 +5,94 @@
55
66
use rand::seq::SliceRandom;
77
use rand::Rng;
8-
use std::collections::BinaryHeap;
8+
use std::collections::HashSet;
99

1010
use crate::{AnnGraph, FamstConfig, Neighbor, NodeId};
1111

12-
/// Check if sorted vec contains value
13-
fn sorted_contains(v: &[NodeId], x: NodeId) -> bool {
14-
v.binary_search(&x).is_ok()
15-
}
16-
17-
/// Insert into sorted vec, returns true if inserted (was not present)
18-
fn sorted_insert(v: &mut Vec<NodeId>, x: NodeId) -> bool {
19-
match v.binary_search(&x) {
20-
Ok(_) => false,
21-
Err(pos) => {
22-
v.insert(pos, x);
23-
true
12+
/// Build reverse neighbor lists (who has me as a neighbor)
13+
fn build_reverse(graph: &AnnGraph) -> Vec<Vec<NodeId>> {
14+
let n = graph.n();
15+
let mut reverse: Vec<Vec<NodeId>> = vec![Vec::new(); n];
16+
for i in 0..n {
17+
for neighbor in graph.neighbors(i) {
18+
reverse[neighbor.index as usize].push(i as NodeId);
2419
}
2520
}
21+
reverse
2622
}
2723

28-
/// Remove from sorted vec
29-
fn sorted_remove(v: &mut Vec<NodeId>, x: NodeId) {
30-
if let Ok(pos) = v.binary_search(&x) {
31-
v.remove(pos);
24+
/// Initialize ANN graph with random neighbors
25+
fn init_random_graph<T, D, R>(data: &[T], k: usize, distance_fn: &D, rng: &mut R) -> AnnGraph
26+
where
27+
D: Fn(&T, &T) -> f32,
28+
R: Rng,
29+
{
30+
let n = data.len();
31+
let mut graph_data: Vec<Neighbor> = Vec::with_capacity(n * k);
32+
33+
for i in 0..n {
34+
let mut neighbors: Vec<Neighbor> = Vec::with_capacity(k);
35+
let mut seen: HashSet<NodeId> = HashSet::with_capacity(k);
36+
37+
// Sample k random neighbors using Floyd's algorithm - guaranteed O(k)
38+
let effective_n = n - 1; // exclude self
39+
let range_start = effective_n.saturating_sub(k);
40+
for t in range_start..effective_n {
41+
let j = rng.gen_range(0..=t);
42+
// Map j to actual index, skipping i
43+
let actual_j = (if j >= i { j + 1 } else { j }) as NodeId;
44+
45+
let selected = if seen.insert(actual_j) {
46+
actual_j
47+
} else {
48+
// j was already selected, so add t instead
49+
let actual_t = (if t >= i { t + 1 } else { t }) as NodeId;
50+
seen.insert(actual_t);
51+
actual_t
52+
};
53+
54+
let d = distance_fn(&data[i], &data[selected as usize]);
55+
neighbors.push(Neighbor {
56+
index: selected,
57+
distance: d,
58+
});
59+
}
60+
61+
// Sort by (distance, index) for total ordering
62+
neighbors.sort();
63+
graph_data.extend(neighbors);
3264
}
65+
66+
AnnGraph::new(n, k, graph_data)
3367
}
3468

35-
/// Build reverse neighbor lists (who has me as a neighbor)
36-
/// Returns sorted vecs for each point
37-
fn build_reverse(neighbor_lists: &[Vec<NodeId>], n: usize) -> Vec<Vec<NodeId>> {
38-
let mut reverse: Vec<Vec<NodeId>> = vec![Vec::new(); n];
39-
for (i, neighbors) in neighbor_lists.iter().enumerate() {
40-
for &j in neighbors {
41-
reverse[j as usize].push(i as NodeId);
69+
/// Try to insert a new neighbor into a sorted neighbor slice.
70+
/// Returns true if the neighbor was inserted (better than the worst).
71+
/// Assumes neighbors are sorted by (distance, index) for total ordering.
72+
fn insert_neighbor(neighbors: &mut [Neighbor], new_index: NodeId, new_distance: f32) -> bool {
73+
let new_neighbor = Neighbor {
74+
index: new_index,
75+
distance: new_distance,
76+
};
77+
78+
// Binary search using total ordering - also serves as existence check
79+
match neighbors.binary_search(&new_neighbor) {
80+
Ok(_) => false, // Already exists
81+
Err(insert_pos) => {
82+
// Check if better than worst (last element)
83+
if insert_pos >= neighbors.len() {
84+
return false;
85+
}
86+
87+
// Shift elements to make room (dropping the last/worst)
88+
for j in (insert_pos + 1..neighbors.len()).rev() {
89+
neighbors[j] = neighbors[j - 1];
90+
}
91+
92+
neighbors[insert_pos] = new_neighbor;
93+
true
4294
}
4395
}
44-
// Each reverse list is built in order of i, so already sorted
45-
reverse
4696
}
4797

4898
/// NN-Descent algorithm for approximate k-NN graph construction
@@ -66,56 +116,26 @@ where
66116
return AnnGraph::new(n, 0, vec![]);
67117
}
68118

69-
// Initialize with random neighbors using max-heap for each point
70-
// neighbor_lists[i] is kept sorted by index for O(log k) membership tests
71-
let mut heaps: Vec<BinaryHeap<Neighbor>> = Vec::with_capacity(n);
72-
let mut neighbor_lists: Vec<Vec<NodeId>> = vec![Vec::with_capacity(k); n];
73-
74-
for i in 0..n {
75-
let mut heap = BinaryHeap::with_capacity(k);
76-
77-
// Sample k random neighbors using Floyd's algorithm - guaranteed O(k)
78-
// https://fermatslibrary.com/s/a-sample-of-brilliance
79-
// This selects k distinct elements from 0..n, excluding i
80-
let effective_n = n - 1; // exclude self
81-
let range_start = effective_n.saturating_sub(k);
82-
for t in range_start..effective_n {
83-
let j = rng.gen_range(0..=t);
84-
// Map j to actual index, skipping i
85-
let actual_j = (if j >= i { j + 1 } else { j }) as NodeId;
86-
87-
if !sorted_insert(&mut neighbor_lists[i], actual_j) {
88-
// j was already selected, so add t instead
89-
let actual_t = (if t >= i { t + 1 } else { t }) as NodeId;
90-
sorted_insert(&mut neighbor_lists[i], actual_t);
91-
let d = distance_fn(&data[i], &data[actual_t as usize]);
92-
heap.push(Neighbor {
93-
index: actual_t,
94-
distance: d,
95-
});
96-
} else {
97-
let d = distance_fn(&data[i], &data[actual_j as usize]);
98-
heap.push(Neighbor {
99-
index: actual_j,
100-
distance: d,
101-
});
102-
}
103-
}
104-
heaps.push(heap);
105-
}
119+
// Initialize ANN graph with random neighbors
120+
let mut graph = init_random_graph(data, k, distance_fn, rng);
106121

107122
// NN-Descent iterations
108123
for _ in 0..config.nn_descent_iterations {
109124
let mut updates = 0;
110-
let reverse_neighbors = build_reverse(&neighbor_lists, n);
125+
let reverse_neighbors = build_reverse(&graph);
111126

112127
// For each point, explore neighbors of neighbors
113128
for i in 0..n {
129+
// Build set of current neighbors for O(1) lookup
130+
let current_neighbors: HashSet<NodeId> =
131+
graph.neighbors(i).iter().map(|nb| nb.index).collect();
132+
114133
// Collect candidates: neighbors and reverse neighbors
115134
let mut candidates: Vec<NodeId> = Vec::new();
116135

117136
// Sample from forward neighbors
118-
let mut sampled_forward = neighbor_lists[i].clone();
137+
let mut sampled_forward: Vec<NodeId> =
138+
graph.neighbors(i).iter().map(|nb| nb.index).collect();
119139
let sample_size =
120140
((sampled_forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize)
121141
.max(1);
@@ -133,14 +153,14 @@ where
133153
// Neighbors of neighbors
134154
let i_id = i as NodeId;
135155
for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) {
136-
for &nn in &neighbor_lists[neighbor as usize] {
137-
if nn != i_id && !sorted_contains(&neighbor_lists[i], nn) {
138-
candidates.push(nn);
156+
for nb in graph.neighbors(neighbor as usize) {
157+
if nb.index != i_id && !current_neighbors.contains(&nb.index) {
158+
candidates.push(nb.index);
139159
}
140160
}
141161
// Also check reverse neighbors of neighbors
142162
for &rn in &reverse_neighbors[neighbor as usize] {
143-
if rn != i_id && !sorted_contains(&neighbor_lists[i], rn) {
163+
if rn != i_id && !current_neighbors.contains(&rn) {
144164
candidates.push(rn);
145165
}
146166
}
@@ -154,17 +174,8 @@ where
154174
for c in candidates {
155175
let d = distance_fn(&data[i], &data[c as usize]);
156176

157-
// Check if this is better than the worst current neighbor
158-
if let Some(worst) = heaps[i].peek() {
159-
if d < worst.distance {
160-
// Remove worst and add new neighbor
161-
let removed = heaps[i].pop().unwrap();
162-
sorted_remove(&mut neighbor_lists[i], removed.index);
163-
164-
heaps[i].push(Neighbor { index: c, distance: d });
165-
sorted_insert(&mut neighbor_lists[i], c);
166-
updates += 1;
167-
}
177+
if insert_neighbor(graph.neighbors_mut(i), c, d) {
178+
updates += 1;
168179
}
169180
}
170181
}
@@ -175,14 +186,5 @@ where
175186
}
176187
}
177188

178-
// Convert heaps to flat neighbor array sorted by distance
179-
let mut result_data = Vec::with_capacity(n * k);
180-
181-
for heap in heaps {
182-
let mut entries: Vec<Neighbor> = heap.into_vec();
183-
entries.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
184-
result_data.extend(entries);
185-
}
186-
187-
AnnGraph::new(n, k, result_data)
189+
graph
188190
}

0 commit comments

Comments
 (0)