55
66use rand:: seq:: SliceRandom ;
77use rand:: Rng ;
8- use std:: collections:: BinaryHeap ;
8+ use std:: collections:: HashSet ;
99
1010use crate :: { AnnGraph , FamstConfig , Neighbor , NodeId } ;
1111
12- /// Check if sorted vec contains value
13- fn sorted_contains ( v : & [ NodeId ] , x : NodeId ) -> bool {
14- v. binary_search ( & x) . is_ok ( )
15- }
16-
17- /// Insert into sorted vec, returns true if inserted (was not present)
18- fn sorted_insert ( v : & mut Vec < NodeId > , x : NodeId ) -> bool {
19- match v. binary_search ( & x) {
20- Ok ( _) => false ,
21- Err ( pos) => {
22- v. insert ( pos, x) ;
23- true
12+ /// Build reverse neighbor lists (who has me as a neighbor)
13+ fn build_reverse ( graph : & AnnGraph ) -> Vec < Vec < NodeId > > {
14+ let n = graph. n ( ) ;
15+ let mut reverse: Vec < Vec < NodeId > > = vec ! [ Vec :: new( ) ; n] ;
16+ for i in 0 ..n {
17+ for neighbor in graph. neighbors ( i) {
18+ reverse[ neighbor. index as usize ] . push ( i as NodeId ) ;
2419 }
2520 }
21+ reverse
2622}
2723
28- /// Remove from sorted vec
29- fn sorted_remove ( v : & mut Vec < NodeId > , x : NodeId ) {
30- if let Ok ( pos) = v. binary_search ( & x) {
31- v. remove ( pos) ;
24+ /// Initialize ANN graph with random neighbors
25+ fn init_random_graph < T , D , R > ( data : & [ T ] , k : usize , distance_fn : & D , rng : & mut R ) -> AnnGraph
26+ where
27+ D : Fn ( & T , & T ) -> f32 ,
28+ R : Rng ,
29+ {
30+ let n = data. len ( ) ;
31+ let mut graph_data: Vec < Neighbor > = Vec :: with_capacity ( n * k) ;
32+
33+ for i in 0 ..n {
34+ let mut neighbors: Vec < Neighbor > = Vec :: with_capacity ( k) ;
35+ let mut seen: HashSet < NodeId > = HashSet :: with_capacity ( k) ;
36+
37+ // Sample k random neighbors using Floyd's algorithm - guaranteed O(k)
38+ let effective_n = n - 1 ; // exclude self
39+ let range_start = effective_n. saturating_sub ( k) ;
40+ for t in range_start..effective_n {
41+ let j = rng. gen_range ( 0 ..=t) ;
42+ // Map j to actual index, skipping i
43+ let actual_j = ( if j >= i { j + 1 } else { j } ) as NodeId ;
44+
45+ let selected = if seen. insert ( actual_j) {
46+ actual_j
47+ } else {
48+ // j was already selected, so add t instead
49+ let actual_t = ( if t >= i { t + 1 } else { t } ) as NodeId ;
50+ seen. insert ( actual_t) ;
51+ actual_t
52+ } ;
53+
54+ let d = distance_fn ( & data[ i] , & data[ selected as usize ] ) ;
55+ neighbors. push ( Neighbor {
56+ index : selected,
57+ distance : d,
58+ } ) ;
59+ }
60+
61+ // Sort by (distance, index) for total ordering
62+ neighbors. sort ( ) ;
63+ graph_data. extend ( neighbors) ;
3264 }
65+
66+ AnnGraph :: new ( n, k, graph_data)
3367}
3468
35- /// Build reverse neighbor lists (who has me as a neighbor)
36- /// Returns sorted vecs for each point
37- fn build_reverse ( neighbor_lists : & [ Vec < NodeId > ] , n : usize ) -> Vec < Vec < NodeId > > {
38- let mut reverse: Vec < Vec < NodeId > > = vec ! [ Vec :: new( ) ; n] ;
39- for ( i, neighbors) in neighbor_lists. iter ( ) . enumerate ( ) {
40- for & j in neighbors {
41- reverse[ j as usize ] . push ( i as NodeId ) ;
69+ /// Try to insert a new neighbor into a sorted neighbor slice.
70+ /// Returns true if the neighbor was inserted (better than the worst).
71+ /// Assumes neighbors are sorted by (distance, index) for total ordering.
72+ fn insert_neighbor ( neighbors : & mut [ Neighbor ] , new_index : NodeId , new_distance : f32 ) -> bool {
73+ let new_neighbor = Neighbor {
74+ index : new_index,
75+ distance : new_distance,
76+ } ;
77+
78+ // Binary search using total ordering - also serves as existence check
79+ match neighbors. binary_search ( & new_neighbor) {
80+ Ok ( _) => false , // Already exists
81+ Err ( insert_pos) => {
82+ // Check if better than worst (last element)
83+ if insert_pos >= neighbors. len ( ) {
84+ return false ;
85+ }
86+
87+ // Shift elements to make room (dropping the last/worst)
88+ for j in ( insert_pos + 1 ..neighbors. len ( ) ) . rev ( ) {
89+ neighbors[ j] = neighbors[ j - 1 ] ;
90+ }
91+
92+ neighbors[ insert_pos] = new_neighbor;
93+ true
4294 }
4395 }
44- // Each reverse list is built in order of i, so already sorted
45- reverse
4696}
4797
4898/// NN-Descent algorithm for approximate k-NN graph construction
@@ -66,56 +116,26 @@ where
66116 return AnnGraph :: new ( n, 0 , vec ! [ ] ) ;
67117 }
68118
69- // Initialize with random neighbors using max-heap for each point
70- // neighbor_lists[i] is kept sorted by index for O(log k) membership tests
71- let mut heaps: Vec < BinaryHeap < Neighbor > > = Vec :: with_capacity ( n) ;
72- let mut neighbor_lists: Vec < Vec < NodeId > > = vec ! [ Vec :: with_capacity( k) ; n] ;
73-
74- for i in 0 ..n {
75- let mut heap = BinaryHeap :: with_capacity ( k) ;
76-
77- // Sample k random neighbors using Floyd's algorithm - guaranteed O(k)
78- // https://fermatslibrary.com/s/a-sample-of-brilliance
79- // This selects k distinct elements from 0..n, excluding i
80- let effective_n = n - 1 ; // exclude self
81- let range_start = effective_n. saturating_sub ( k) ;
82- for t in range_start..effective_n {
83- let j = rng. gen_range ( 0 ..=t) ;
84- // Map j to actual index, skipping i
85- let actual_j = ( if j >= i { j + 1 } else { j } ) as NodeId ;
86-
87- if !sorted_insert ( & mut neighbor_lists[ i] , actual_j) {
88- // j was already selected, so add t instead
89- let actual_t = ( if t >= i { t + 1 } else { t } ) as NodeId ;
90- sorted_insert ( & mut neighbor_lists[ i] , actual_t) ;
91- let d = distance_fn ( & data[ i] , & data[ actual_t as usize ] ) ;
92- heap. push ( Neighbor {
93- index : actual_t,
94- distance : d,
95- } ) ;
96- } else {
97- let d = distance_fn ( & data[ i] , & data[ actual_j as usize ] ) ;
98- heap. push ( Neighbor {
99- index : actual_j,
100- distance : d,
101- } ) ;
102- }
103- }
104- heaps. push ( heap) ;
105- }
119+ // Initialize ANN graph with random neighbors
120+ let mut graph = init_random_graph ( data, k, distance_fn, rng) ;
106121
107122 // NN-Descent iterations
108123 for _ in 0 ..config. nn_descent_iterations {
109124 let mut updates = 0 ;
110- let reverse_neighbors = build_reverse ( & neighbor_lists , n ) ;
125+ let reverse_neighbors = build_reverse ( & graph ) ;
111126
112127 // For each point, explore neighbors of neighbors
113128 for i in 0 ..n {
129+ // Build set of current neighbors for O(1) lookup
130+ let current_neighbors: HashSet < NodeId > =
131+ graph. neighbors ( i) . iter ( ) . map ( |nb| nb. index ) . collect ( ) ;
132+
114133 // Collect candidates: neighbors and reverse neighbors
115134 let mut candidates: Vec < NodeId > = Vec :: new ( ) ;
116135
117136 // Sample from forward neighbors
118- let mut sampled_forward = neighbor_lists[ i] . clone ( ) ;
137+ let mut sampled_forward: Vec < NodeId > =
138+ graph. neighbors ( i) . iter ( ) . map ( |nb| nb. index ) . collect ( ) ;
119139 let sample_size =
120140 ( ( sampled_forward. len ( ) as f64 * config. nn_descent_sample_rate ) . ceil ( ) as usize )
121141 . max ( 1 ) ;
@@ -133,14 +153,14 @@ where
133153 // Neighbors of neighbors
134154 let i_id = i as NodeId ;
135155 for & neighbor in sampled_forward. iter ( ) . chain ( sampled_reverse. iter ( ) ) {
136- for & nn in & neighbor_lists [ neighbor as usize ] {
137- if nn != i_id && !sorted_contains ( & neighbor_lists [ i ] , nn ) {
138- candidates. push ( nn ) ;
156+ for nb in graph . neighbors ( neighbor as usize ) {
157+ if nb . index != i_id && !current_neighbors . contains ( & nb . index ) {
158+ candidates. push ( nb . index ) ;
139159 }
140160 }
141161 // Also check reverse neighbors of neighbors
142162 for & rn in & reverse_neighbors[ neighbor as usize ] {
143- if rn != i_id && !sorted_contains ( & neighbor_lists [ i ] , rn) {
163+ if rn != i_id && !current_neighbors . contains ( & rn) {
144164 candidates. push ( rn) ;
145165 }
146166 }
@@ -154,17 +174,8 @@ where
154174 for c in candidates {
155175 let d = distance_fn ( & data[ i] , & data[ c as usize ] ) ;
156176
157- // Check if this is better than the worst current neighbor
158- if let Some ( worst) = heaps[ i] . peek ( ) {
159- if d < worst. distance {
160- // Remove worst and add new neighbor
161- let removed = heaps[ i] . pop ( ) . unwrap ( ) ;
162- sorted_remove ( & mut neighbor_lists[ i] , removed. index ) ;
163-
164- heaps[ i] . push ( Neighbor { index : c, distance : d } ) ;
165- sorted_insert ( & mut neighbor_lists[ i] , c) ;
166- updates += 1 ;
167- }
177+ if insert_neighbor ( graph. neighbors_mut ( i) , c, d) {
178+ updates += 1 ;
168179 }
169180 }
170181 }
@@ -175,14 +186,5 @@ where
175186 }
176187 }
177188
178- // Convert heaps to flat neighbor array sorted by distance
179- let mut result_data = Vec :: with_capacity ( n * k) ;
180-
181- for heap in heaps {
182- let mut entries: Vec < Neighbor > = heap. into_vec ( ) ;
183- entries. sort_by ( |a, b| a. distance . partial_cmp ( & b. distance ) . unwrap ( ) ) ;
184- result_data. extend ( entries) ;
185- }
186-
187- AnnGraph :: new ( n, k, result_data)
189+ graph
188190}
0 commit comments