1313
1414use rand:: seq:: SliceRandom ;
1515use rand:: Rng ;
16- use std:: collections:: { BinaryHeap , HashMap , HashSet } ;
16+ use std:: collections:: BinaryHeap ;
1717
1818/// An edge in the MST, represented as (node_a, node_b, distance)
1919#[ derive( Debug , Clone ) ]
@@ -264,9 +264,33 @@ where
264264 return AnnGraph :: new ( vec ! [ vec![ ] ; n] , vec ! [ vec![ ] ; n] ) ;
265265 }
266266
267+ // Helper: check if sorted vec contains value
268+ fn sorted_contains ( v : & [ usize ] , x : usize ) -> bool {
269+ v. binary_search ( & x) . is_ok ( )
270+ }
271+
272+ // Helper: insert into sorted vec, returns true if inserted (was not present)
273+ fn sorted_insert ( v : & mut Vec < usize > , x : usize ) -> bool {
274+ match v. binary_search ( & x) {
275+ Ok ( _) => false ,
276+ Err ( pos) => {
277+ v. insert ( pos, x) ;
278+ true
279+ }
280+ }
281+ }
282+
283+ // Helper: remove from sorted vec
284+ fn sorted_remove ( v : & mut Vec < usize > , x : usize ) {
285+ if let Ok ( pos) = v. binary_search ( & x) {
286+ v. remove ( pos) ;
287+ }
288+ }
289+
267290 // Initialize with random neighbors using max-heap for each point
291+ // neighbor_lists[i] is kept sorted by index for O(log k) membership tests
268292 let mut heaps: Vec < BinaryHeap < NeighborEntry > > = Vec :: with_capacity ( n) ;
269- let mut neighbor_sets : Vec < HashSet < usize > > = vec ! [ HashSet :: with_capacity( k) ; n] ;
293+ let mut neighbor_lists : Vec < Vec < usize > > = vec ! [ Vec :: with_capacity( k) ; n] ;
270294
271295 for i in 0 ..n {
272296 let mut heap = BinaryHeap :: with_capacity ( k) ;
@@ -281,10 +305,10 @@ where
281305 // Map j to actual index, skipping i
282306 let actual_j = if j >= i { j + 1 } else { j } ;
283307
284- if !neighbor_sets [ i] . insert ( actual_j) {
308+ if !sorted_insert ( & mut neighbor_lists [ i] , actual_j) {
285309 // j was already selected, so add t instead
286310 let actual_t = if t >= i { t + 1 } else { t } ;
287- neighbor_sets [ i] . insert ( actual_t) ;
311+ sorted_insert ( & mut neighbor_lists [ i] , actual_t) ;
288312 let d = distance_fn ( & data[ i] , & data[ actual_t] ) ;
289313 heap. push ( NeighborEntry {
290314 index : actual_t,
@@ -302,52 +326,52 @@ where
302326 }
303327
304328 // Build reverse neighbor lists (who has me as a neighbor)
305- let build_reverse = |neighbor_sets : & [ HashSet < usize > ] | -> Vec < HashSet < usize > > {
306- let mut reverse: Vec < HashSet < usize > > = vec ! [ HashSet :: new( ) ; n] ;
307- for ( i, neighbors) in neighbor_sets. iter ( ) . enumerate ( ) {
329+ // Returns sorted vecs for each point
330+ let build_reverse = |neighbor_lists : & [ Vec < usize > ] | -> Vec < Vec < usize > > {
331+ let mut reverse: Vec < Vec < usize > > = vec ! [ Vec :: new( ) ; n] ;
332+ for ( i, neighbors) in neighbor_lists. iter ( ) . enumerate ( ) {
308333 for & j in neighbors {
309- reverse[ j] . insert ( i) ;
334+ reverse[ j] . push ( i) ;
310335 }
311336 }
337+ // Sort each reverse list (they're built in order of i, so already sorted)
312338 reverse
313339 } ;
314340
315341 // NN-Descent iterations
316342 for _ in 0 ..config. nn_descent_iterations {
317343 let mut updates = 0 ;
318- let reverse_neighbors = build_reverse ( & neighbor_sets ) ;
344+ let reverse_neighbors = build_reverse ( & neighbor_lists ) ;
319345
320346 // For each point, explore neighbors of neighbors
321347 for i in 0 ..n {
322348 // Collect candidates: neighbors and reverse neighbors
323349 let mut candidates: Vec < usize > = Vec :: new ( ) ;
324350
325351 // Sample from forward neighbors
326- let forward : Vec < usize > = neighbor_sets [ i] . iter ( ) . copied ( ) . collect ( ) ;
352+ let mut sampled_forward = neighbor_lists [ i] . clone ( ) ;
327353 let sample_size =
328- ( ( forward. len ( ) as f64 * config. nn_descent_sample_rate ) . ceil ( ) as usize ) . max ( 1 ) ;
329- let mut sampled_forward = forward. clone ( ) ;
354+ ( ( sampled_forward. len ( ) as f64 * config. nn_descent_sample_rate ) . ceil ( ) as usize ) . max ( 1 ) ;
330355 sampled_forward. shuffle ( rng) ;
331356 sampled_forward. truncate ( sample_size) ;
332357
333358 // Sample from reverse neighbors
334- let reverse : Vec < usize > = reverse_neighbors[ i] . iter ( ) . copied ( ) . collect ( ) ;
359+ let mut sampled_reverse = reverse_neighbors[ i] . clone ( ) ;
335360 let sample_size =
336- ( ( reverse. len ( ) as f64 * config. nn_descent_sample_rate ) . ceil ( ) as usize ) . max ( 1 ) ;
337- let mut sampled_reverse = reverse. clone ( ) ;
361+ ( ( sampled_reverse. len ( ) as f64 * config. nn_descent_sample_rate ) . ceil ( ) as usize ) . max ( 1 ) ;
338362 sampled_reverse. shuffle ( rng) ;
339363 sampled_reverse. truncate ( sample_size) ;
340364
341365 // Neighbors of neighbors
342366 for & neighbor in sampled_forward. iter ( ) . chain ( sampled_reverse. iter ( ) ) {
343- for & nn in & neighbor_sets [ neighbor] {
344- if nn != i && !neighbor_sets [ i] . contains ( & nn) {
367+ for & nn in & neighbor_lists [ neighbor] {
368+ if nn != i && !sorted_contains ( & neighbor_lists [ i] , nn) {
345369 candidates. push ( nn) ;
346370 }
347371 }
348372 // Also check reverse neighbors of neighbors
349373 for & rn in & reverse_neighbors[ neighbor] {
350- if rn != i && !neighbor_sets [ i] . contains ( & rn) {
374+ if rn != i && !sorted_contains ( & neighbor_lists [ i] , rn) {
351375 candidates. push ( rn) ;
352376 }
353377 }
@@ -366,13 +390,13 @@ where
366390 if d < worst. distance {
367391 // Remove worst and add new neighbor
368392 let removed = heaps[ i] . pop ( ) . unwrap ( ) ;
369- neighbor_sets [ i] . remove ( & removed. index ) ;
393+ sorted_remove ( & mut neighbor_lists [ i] , removed. index ) ;
370394
371395 heaps[ i] . push ( NeighborEntry {
372396 index : c,
373397 distance : d,
374398 } ) ;
375- neighbor_sets [ i] . insert ( c) ;
399+ sorted_insert ( & mut neighbor_lists [ i] , c) ;
376400 updates += 1 ;
377401 }
378402 }
@@ -403,18 +427,23 @@ where
403427}
404428
405429/// Find connected components in the ANN graph using DFS
406- /// Returns the undirected graph adjacency list and component assignments
407- fn find_components ( ann_graph : & AnnGraph ) -> ( Vec < HashSet < usize > > , Vec < Vec < usize > > ) {
430+ /// Returns the undirected graph adjacency list (sorted vecs) and component assignments
431+ fn find_components ( ann_graph : & AnnGraph ) -> ( Vec < Vec < usize > > , Vec < Vec < usize > > ) {
408432 let n = ann_graph. n ( ) ;
409433
410- // Build undirected graph from directed ANN graph
411- let mut graph: Vec < HashSet < usize > > = vec ! [ HashSet :: new( ) ; n] ;
434+ // Build undirected graph from directed ANN graph using sorted vecs
435+ let mut graph: Vec < Vec < usize > > = vec ! [ Vec :: new( ) ; n] ;
412436 for ( i, neighbors) in ann_graph. neighbors . iter ( ) . enumerate ( ) {
413437 for & j in neighbors {
414- graph[ i] . insert ( j) ;
415- graph[ j] . insert ( i) ;
438+ graph[ i] . push ( j) ;
439+ graph[ j] . push ( i) ;
416440 }
417441 }
442+ // Sort and deduplicate each adjacency list
443+ for adj in & mut graph {
444+ adj. sort_unstable ( ) ;
445+ adj. dedup ( ) ;
446+ }
418447
419448 // DFS to find components
420449 let mut visited = vec ! [ false ; n] ;
@@ -494,7 +523,7 @@ where
494523/// Refine inter-component edges (Algorithm 4 in the paper)
495524fn refine_edges < T , D > (
496525 data : & [ T ] ,
497- undirected_graph : & [ HashSet < usize > ] ,
526+ undirected_graph : & [ Vec < usize > ] ,
498527 components : & [ Vec < usize > ] ,
499528 edges : & [ Edge ] ,
500529 edge_components : & [ ( usize , usize ) ] ,
@@ -503,20 +532,16 @@ fn refine_edges<T, D>(
503532where
504533 D : Fn ( & T , & T ) -> f64 ,
505534{
506- // Build component membership lookup
507- let mut node_to_component: HashMap < usize , usize > = HashMap :: new ( ) ;
535+ let n = data. len ( ) ;
536+
537+ // Build component membership lookup (simple vec, O(1) lookup)
538+ let mut node_to_component: Vec < usize > = vec ! [ 0 ; n] ;
508539 for ( comp_idx, component) in components. iter ( ) . enumerate ( ) {
509540 for & node in component {
510- node_to_component. insert ( node, comp_idx) ;
541+ node_to_component[ node] = comp_idx;
511542 }
512543 }
513544
514- // Build component node sets for quick lookup
515- let component_sets: Vec < HashSet < usize > > = components
516- . iter ( )
517- . map ( |c| c. iter ( ) . copied ( ) . collect ( ) )
518- . collect ( ) ;
519-
520545 let mut refined_edges = Vec :: with_capacity ( edges. len ( ) ) ;
521546 let mut changes = 0 ;
522547
@@ -526,40 +551,25 @@ where
526551 let mut best_d = edge. distance ;
527552
528553 // Get neighbors of u that are in component ci
529- let neighbors_u: Vec < usize > = undirected_graph[ edge. u ]
530- . iter ( )
531- . filter ( |& & n| component_sets[ ci] . contains ( & n) )
532- . copied ( )
533- . collect ( ) ;
534-
535- // Try to find better u from neighbors
536- for u_prime in neighbors_u {
537- if u_prime == edge. v {
538- continue ;
539- }
540- let d_prime = distance_fn ( & data[ u_prime] , & data[ best_v] ) ;
541- if d_prime < best_d {
542- best_u = u_prime;
543- best_d = d_prime;
554+ // (undirected_graph is sorted, so we can iterate directly)
555+ for & u_prime in & undirected_graph[ edge. u ] {
556+ if node_to_component[ u_prime] == ci && u_prime != edge. v {
557+ let d_prime = distance_fn ( & data[ u_prime] , & data[ best_v] ) ;
558+ if d_prime < best_d {
559+ best_u = u_prime;
560+ best_d = d_prime;
561+ }
544562 }
545563 }
546564
547565 // Get neighbors of v that are in component cj
548- let neighbors_v: Vec < usize > = undirected_graph[ edge. v ]
549- . iter ( )
550- . filter ( |& & n| component_sets[ cj] . contains ( & n) )
551- . copied ( )
552- . collect ( ) ;
553-
554- // Try to find better v from neighbors (using updated best_u)
555- for v_prime in neighbors_v {
556- if v_prime == edge. u {
557- continue ;
558- }
559- let d_prime = distance_fn ( & data[ best_u] , & data[ v_prime] ) ;
560- if d_prime < best_d {
561- best_v = v_prime;
562- best_d = d_prime;
566+ for & v_prime in & undirected_graph[ edge. v ] {
567+ if node_to_component[ v_prime] == cj && v_prime != edge. u {
568+ let d_prime = distance_fn ( & data[ best_u] , & data[ v_prime] ) ;
569+ if d_prime < best_d {
570+ best_v = v_prime;
571+ best_d = d_prime;
572+ }
563573 }
564574 }
565575
0 commit comments