1+ /*
2+ * The MIT License
3+ *
4+ * Copyright 2016 Thibault Debatty.
5+ *
6+ * Permission is hereby granted, free of charge, to any person obtaining a copy
7+ * of this software and associated documentation files (the "Software"), to deal
8+ * in the Software without restriction, including without limitation the rights
9+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+ * copies of the Software, and to permit persons to whom the Software is
11+ * furnished to do so, subject to the following conditions:
12+ *
13+ * The above copyright notice and this permission notice shall be included in
14+ * all copies or substantial portions of the Software.
15+ *
16+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+ * THE SOFTWARE.
23+ */
124package info .debatty .java .graphs .build ;
225
26+ import info .debatty .java .graphs .Edge ;
327import info .debatty .java .graphs .Graph ;
428import info .debatty .java .graphs .Neighbor ;
529import info .debatty .java .graphs .NeighborList ;
630import java .security .InvalidParameterException ;
731import java .util .ArrayList ;
832import java .util .HashMap ;
33+ import java .util .HashSet ;
934import java .util .List ;
1035import java .util .Map ;
1136import java .util .Random ;
37+ import java .util .Set ;
1238
1339/**
1440 * Implementation of NN-Descent k-nn graph building algorithm. Based on the
@@ -30,7 +56,12 @@ public class NNDescent<T> extends GraphBuilder<T> {
3056 protected int iterations = 0 ;
3157 protected int c ;
3258
33- protected static final String IS_PROCESSED = "NNDescent_IS_PROCESSED_KEY" ;
59+ /**
60+ * Contains the neighbors that have already been processed. Because we use a
61+ * HashSet, we have to store edges (which contain a reference to the source
62+ * node) instead of neighbors in the concrete implementation.
63+ */
64+ protected Set <Edge > processed ;
3465
3566 /**
3667 * Get the number of edges modified at the last iteration
@@ -105,6 +136,7 @@ public void setMaxIterations(int max_iterations) {
105136 protected Graph <T > _computeGraph (List <T > nodes ) {
106137
107138 iterations = 0 ;
139+ processed = new HashSet <Edge >(nodes .size () * k );
108140
109141 if (nodes .size () <= (k + 1 )) {
110142 return MakeFullyLinked (nodes );
@@ -135,8 +167,8 @@ protected Graph<T> _computeGraph(List<T> nodes) {
135167 // Mark sampled items in B[v] as false;
136168 for (int i = 0 ; i < nodes .size (); i ++) {
137169 T v = nodes .get (i );
138- old_lists .put (v , PickFalses (neighborlists .getNeighbors (v )));
139- new_lists .put (v , PickTruesAndMark (neighborlists .getNeighbors (v )));
170+ old_lists .put (v , PickFalses (v , neighborlists .getNeighbors (v )));
171+ new_lists .put (v , PickTruesAndMark (v , neighborlists .getNeighbors (v )));
140172
141173 }
142174
@@ -242,10 +274,11 @@ protected NeighborList RandomNeighborList(List<T> nodes, T for_node) {
242274 return nl ;
243275 }
244276
245- protected ArrayList <T > PickFalses (NeighborList neighborList ) {
277+ protected ArrayList <T > PickFalses (T node , NeighborList neighborList ) {
246278 ArrayList <T > falses = new ArrayList <T >();
247279 for (Neighbor <T > n : neighborList ) {
248- if (n .getAttribute (IS_PROCESSED ) != null ) { // !n.is_new
280+ Edge edge = new Edge (node , n );
281+ if (processed .contains (edge )) {
249282 falses .add (n .node );
250283 }
251284 }
@@ -259,11 +292,12 @@ protected ArrayList<T> PickFalses(NeighborList neighborList) {
259292 * @param neighborList
260293 * @return
261294 */
262- protected ArrayList <T > PickTruesAndMark (NeighborList neighborList ) {
295+ protected ArrayList <T > PickTruesAndMark (T node , NeighborList neighborList ) {
263296 ArrayList <T > r = new ArrayList <T >();
264297 for (Neighbor <T > n : neighborList ) {
265- if (n .getAttribute (IS_PROCESSED ) == null && Math .random () < rho ) { // n.is_new
266- n .setAttribute (IS_PROCESSED , true ); // n.is_new = false;
298+ Edge <T > edge = new Edge <T >(node , n );
299+ if (!processed .contains (edge ) && Math .random () < rho ) {
300+ processed .add (edge );
267301 r .add (n .node );
268302 }
269303 }
0 commit comments