@@ -50,10 +50,15 @@ public abstract class FloatVectorHnswTest extends DbBaseTest {
5050
5151 private final String namespaceName = "items" ;
5252 private Namespace <VectorItem > vectorNs ;
53+ private List <VectorItem > testItems ;
5354
5455 @ BeforeEach
5556 public void setUp () {
5657 vectorNs = db .openNamespace (namespaceName , NamespaceOptions .defaultOptions (), VectorItem .class );
58+ testItems = getTestVectorItems ();
59+ for (VectorItem item : testItems ) {
60+ db .insert (namespaceName , item );
61+ }
5762 }
5863
5964 @ Test
@@ -91,16 +96,11 @@ public void testInsertWithWrongVectorSize_isException() {
9196 }
9297
9398 @ Test
94- public void testSearchWithBaseParams_isOk () {
95- List <VectorItem > testItems = getTestVectorItems ();
96- for (VectorItem item : testItems ) {
97- db .insert (namespaceName , item );
98- }
99-
99+ public void testSearchWithBaseParamK_isOk () {
100100 List <VectorItem > list = db .query (namespaceName , VectorItem .class )
101101 .selectAllFields ()
102102 .whereKnn ("vector" , new float []{0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f },
103- KnnParams .base (2 ))
103+ KnnParams .k (2 ))
104104 .toList ();
105105
106106 assertThat (list .size (), is (2 ));
@@ -111,32 +111,103 @@ public void testSearchWithBaseParams_isOk() {
111111 }
112112
113113 @ Test
114- public void testSearchWithHnswParams_isOk () {
115- List <VectorItem > testItems = getTestVectorItems ();
116- for (VectorItem item : testItems ) {
117- db .insert (namespaceName , item );
118- }
114+ public void testSearchWithBaseParamRadius_isOk () {
115+ List <VectorItem > list = db .query (namespaceName , VectorItem .class )
116+ .selectAllFields ()
117+ .whereKnn ("vector" , new float []{0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f },
118+ KnnParams .radius (0.4f ))
119+ .toList ();
120+
121+ assertThat (list .size (), is (4 ));
122+ assertThat (list .get (0 ).getId (), is (1 ));
123+ assertThat (list .get (1 ).getId (), is (2 ));
124+ assertThat (list .get (2 ).getId (), is (0 ));
125+ assertThat (list .get (3 ).getId (), is (3 ));
126+ }
119127
128+ @ Test
129+ public void testSearchWithBaseParamsKAndRadius_isOk () {
130+ // by k (3 records) + by radius (4 records) = 3 records
131+ List <VectorItem > list = db .query (namespaceName , VectorItem .class )
132+ .selectAllFields ()
133+ .whereKnn ("vector" , new float []{0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f },
134+ KnnParams .base (3 , 0.4f ))
135+ .toList ();
136+
137+ assertThat (list .size (), is (3 ));
138+ assertThat (list .get (0 ).getId (), is (1 ));
139+ assertThat (list .get (1 ).getId (), is (2 ));
140+ assertThat (list .get (2 ).getId (), is (0 ));
141+
142+ // by k (5 records) + by radius (4 records) = 4 records
143+ list = db .query (namespaceName , VectorItem .class )
144+ .selectAllFields ()
145+ .whereKnn ("vector" , new float []{0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f , 0.13f },
146+ KnnParams .base (5 , 0.4f ))
147+ .toList ();
148+
149+ assertThat (list .size (), is (4 ));
150+ assertThat (list .get (0 ).getId (), is (1 ));
151+ assertThat (list .get (1 ).getId (), is (2 ));
152+ assertThat (list .get (2 ).getId (), is (0 ));
153+ assertThat (list .get (3 ).getId (), is (3 ));
154+ }
155+
156+ @ Test
157+ public void testSearchWithHnswParams_isOk () {
158+ // k - 2 records
120159 List <VectorItem > list = db .query (namespaceName , VectorItem .class )
121160 .selectAllFields ()
122161 .whereKnn ("vector" , new float []{0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f },
123- KnnParams .hnsw (2 , 2 ))
162+ KnnParams .hnsw (2 , 5 ))
124163 .toList ();
125164
126165 assertThat (list .size (), is (2 ));
127166 assertThat (list .get (0 ).getId (), is (2 ));
128167 assertThat (list .get (0 ).getVector (), is (testItems .get (2 ).getVector ()));
129168 assertThat (list .get (1 ).getId (), is (3 ));
130169 assertThat (list .get (1 ).getVector (), is (testItems .get (3 ).getVector ()));
170+
171+ // radius 0.4 - 4 records
172+ list = db .query (namespaceName , VectorItem .class )
173+ .selectAllFields ()
174+ .whereKnn ("vector" , new float []{0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f },
175+ KnnParams .hnsw (KnnParams .radius (0.4f ), 5 ))
176+ .toList ();
177+
178+ assertThat (list .size (), is (4 ));
179+
180+ // by k (3 records) + by radius (4 records) = 3 records
181+ list = db .query (namespaceName , VectorItem .class )
182+ .selectAllFields ()
183+ .whereKnn ("vector" , new float []{0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f },
184+ KnnParams .hnsw (KnnParams .base (3 , 0.4f ), 5 ))
185+ .toList ();
186+
187+ assertThat (list .size (), is (3 ));
188+
189+ // by k (5 records) + by radius (4 records) = 4 records
190+ list = db .query (namespaceName , VectorItem .class )
191+ .selectAllFields ()
192+ .whereKnn ("vector" , new float []{0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f , 0.23f },
193+ KnnParams .hnsw (KnnParams .base (5 , 0.4f ), 5 ))
194+ .toList ();
195+
196+ assertThat (list .size (), is (4 ));
131197 }
132198
133199 @ Test
134- public void testSearchWithIncorrectHnswParams_isException () {
135- List <VectorItem > testItems = getTestVectorItems ();
136- for (VectorItem item : testItems ) {
137- db .insert (namespaceName , item );
138- }
200+ public void testTrySearchWithHnswParamFromNullBaseParam_isException () {
201+ assertThrows (NullPointerException .class ,
202+ () -> db .query (namespaceName , VectorItem .class )
203+ .selectAllFields ()
204+ .whereKnn ("vector" , new float []{0.5f , 0.0f , 0.6f },
205+ KnnParams .hnsw (null , 5 ))
206+ .toList ());
207+ }
139208
209+ @ Test
210+ public void testSearchWithIncorrectHnswParams_isException () {
140211 assertThrows (IllegalArgumentException .class ,
141212 () -> db .query (namespaceName , VectorItem .class )
142213 .selectAllFields ()
@@ -156,11 +227,6 @@ public void testSearchWithIncorrectHnswParams_isException() {
156227
157228 @ Test
158229 public void testSearchWithNotHnswNorBaseParams_isException () {
159- List <VectorItem > testItems = getTestVectorItems ();
160- for (VectorItem item : testItems ) {
161- db .insert (namespaceName , item );
162- }
163-
164230 assertThrows (RuntimeException .class ,
165231 () -> db .query (namespaceName , VectorItem .class )
166232 .selectAllFields ()
0 commit comments