Skip to content

Commit f8b46ae

Browse files
committed
GH-117 Add test cases on radius
1 parent b83ea91 commit f8b46ae

3 files changed

Lines changed: 275 additions & 70 deletions

File tree

src/test/java/ru/rt/restream/reindexer/connector/FloatVectorBfTest.java

Lines changed: 93 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,15 @@ public abstract class FloatVectorBfTest extends DbBaseTest {
5050

5151
private final String namespaceName = "items";
5252
private Namespace<VectorItem> vectorNs;
53+
private List<VectorItem> testItems;
5354

5455
@BeforeEach
5556
public void setUp() {
5657
vectorNs = db.openNamespace(namespaceName, NamespaceOptions.defaultOptions(), VectorItem.class);
58+
testItems = getTestVectorItems();
59+
for (VectorItem item : testItems) {
60+
db.insert(namespaceName, item);
61+
}
5762
}
5863

5964
@Test
@@ -91,16 +96,11 @@ public void testInsertWithWrongVectorSize_isException() {
9196
}
9297

9398
@Test
94-
public void testSearchWithBaseParams_isOk() {
95-
List<VectorItem> testItems = getTestVectorItems();
96-
for (VectorItem item : testItems) {
97-
db.insert(namespaceName, item);
98-
}
99-
99+
public void testSearchWithBaseParamK_isOk() {
100100
List<VectorItem> list = db.query(namespaceName, VectorItem.class)
101101
.selectAllFields()
102102
.whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
103-
KnnParams.base(2))
103+
KnnParams.k(2))
104104
.toList();
105105

106106
assertThat(list.size(), is(2));
@@ -111,12 +111,51 @@ public void testSearchWithBaseParams_isOk() {
111111
}
112112

113113
@Test
114-
public void testSearchWithVecBfParams_isOk() {
115-
List<VectorItem> testItems = getTestVectorItems();
116-
for (VectorItem item : testItems) {
117-
db.insert(namespaceName, item);
118-
}
114+
public void testSearchWithBaseParamRadius_isOk() {
115+
List<VectorItem> list = db.query(namespaceName, VectorItem.class)
116+
.selectAllFields()
117+
.whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
118+
KnnParams.radius(0.7f))
119+
.toList();
120+
121+
assertThat(list.size(), is(4));
122+
assertThat(list.get(0).getId(), is(18));
123+
assertThat(list.get(1).getId(), is(6));
124+
assertThat(list.get(2).getId(), is(7));
125+
assertThat(list.get(3).getId(), is(8));
126+
}
119127

128+
@Test
129+
public void testSearchWithBaseParamsKAndRadius_isOk() {
130+
List<VectorItem> list = db.query(namespaceName, VectorItem.class)
131+
.selectAllFields()
132+
.whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
133+
KnnParams.base(3, 0.7f))
134+
.toList();
135+
136+
// by k (3 records) + by radius (4 records) = 3 records
137+
assertThat(list.size(), is(3));
138+
assertThat(list.get(0).getId(), is(18));
139+
assertThat(list.get(1).getId(), is(6));
140+
assertThat(list.get(2).getId(), is(7));
141+
142+
list = db.query(namespaceName, VectorItem.class)
143+
.selectAllFields()
144+
.whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
145+
KnnParams.base(5, 0.7f))
146+
.toList();
147+
148+
// by k (5 records) + by radius (4 records) = 4 records
149+
assertThat(list.size(), is(4));
150+
assertThat(list.get(0).getId(), is(18));
151+
assertThat(list.get(1).getId(), is(6));
152+
assertThat(list.get(2).getId(), is(7));
153+
assertThat(list.get(3).getId(), is(8));
154+
}
155+
156+
@Test
157+
public void testSearchWithVecBfParams_isOk() {
158+
// only k - 3 records
120159
List<VectorItem> list = db.query(namespaceName, VectorItem.class)
121160
.selectAllFields()
122161
.whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
@@ -130,15 +169,42 @@ public void testSearchWithVecBfParams_isOk() {
130169
assertThat(list.get(1).getVector(), is(testItems.get(18).getVector()));
131170
assertThat(list.get(2).getId(), is(19));
132171
assertThat(list.get(2).getVector(), is(testItems.get(19).getVector()));
172+
173+
// only radius 0.7 - 5 records
174+
list = db.query(namespaceName, VectorItem.class)
175+
.selectAllFields()
176+
.whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
177+
KnnParams.bf(KnnParams.radius(0.7f)))
178+
.toList();
179+
180+
assertThat(list.size(), is(5));
181+
assertThat(list.get(0).getId(), is(8));
182+
assertThat(list.get(1).getId(), is(18));
183+
assertThat(list.get(2).getId(), is(19));
184+
assertThat(list.get(3).getId(), is(1));
185+
assertThat(list.get(4).getId(), is(2));
186+
187+
// by k (3 records) + by radius (5 records) = 3 records
188+
list = db.query(namespaceName, VectorItem.class)
189+
.selectAllFields()
190+
.whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
191+
KnnParams.bf(KnnParams.base(3, 0.7f)))
192+
.toList();
193+
194+
assertThat(list.size(), is(3));
195+
196+
// by k (6 records) + by radius (5 records) = 5 records
197+
list = db.query(namespaceName, VectorItem.class)
198+
.selectAllFields()
199+
.whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
200+
KnnParams.bf(KnnParams.base(6, 0.7f)))
201+
.toList();
202+
203+
assertThat(list.size(), is(5));
133204
}
134205

135206
@Test
136207
public void testSearchWithIncorrectVecBfParams_isException() {
137-
List<VectorItem> testItems = getTestVectorItems();
138-
for (VectorItem item : testItems) {
139-
db.insert(namespaceName, item);
140-
}
141-
142208
assertThrows(IllegalArgumentException.class,
143209
() -> db.query(namespaceName, VectorItem.class)
144210
.selectAllFields()
@@ -149,12 +215,17 @@ public void testSearchWithIncorrectVecBfParams_isException() {
149215
}
150216

151217
@Test
152-
public void testSearchWithNotVecBfNorBaseParams_isException() {
153-
List<VectorItem> testItems = getTestVectorItems();
154-
for (VectorItem item : testItems) {
155-
db.insert(namespaceName, item);
156-
}
218+
public void testTrySearchWithVecBfParamFromNullBaseParam_isException() {
219+
assertThrows(NullPointerException.class,
220+
() -> db.query(namespaceName, VectorItem.class)
221+
.selectAllFields()
222+
.whereKnn("vector", new float[]{0.5f, 0.0f, 0.6f},
223+
KnnParams.bf(null))
224+
.toList());
225+
}
157226

227+
@Test
228+
public void testSearchWithNotVecBfNorBaseParams_isException() {
158229
assertThrows(RuntimeException.class,
159230
() -> db.query(namespaceName, VectorItem.class)
160231
.selectAllFields()

src/test/java/ru/rt/restream/reindexer/connector/FloatVectorHnswTest.java

Lines changed: 89 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,15 @@ public abstract class FloatVectorHnswTest extends DbBaseTest {
5050

5151
private final String namespaceName = "items";
5252
private Namespace<VectorItem> vectorNs;
53+
private List<VectorItem> testItems;
5354

5455
@BeforeEach
5556
public void setUp() {
5657
vectorNs = db.openNamespace(namespaceName, NamespaceOptions.defaultOptions(), VectorItem.class);
58+
testItems = getTestVectorItems();
59+
for (VectorItem item : testItems) {
60+
db.insert(namespaceName, item);
61+
}
5762
}
5863

5964
@Test
@@ -91,16 +96,11 @@ public void testInsertWithWrongVectorSize_isException() {
9196
}
9297

9398
@Test
94-
public void testSearchWithBaseParams_isOk() {
95-
List<VectorItem> testItems = getTestVectorItems();
96-
for (VectorItem item : testItems) {
97-
db.insert(namespaceName, item);
98-
}
99-
99+
public void testSearchWithBaseParamK_isOk() {
100100
List<VectorItem> list = db.query(namespaceName, VectorItem.class)
101101
.selectAllFields()
102102
.whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f},
103-
KnnParams.base(2))
103+
KnnParams.k(2))
104104
.toList();
105105

106106
assertThat(list.size(), is(2));
@@ -111,32 +111,103 @@ public void testSearchWithBaseParams_isOk() {
111111
}
112112

113113
@Test
114-
public void testSearchWithHnswParams_isOk() {
115-
List<VectorItem> testItems = getTestVectorItems();
116-
for (VectorItem item : testItems) {
117-
db.insert(namespaceName, item);
118-
}
114+
public void testSearchWithBaseParamRadius_isOk() {
115+
List<VectorItem> list = db.query(namespaceName, VectorItem.class)
116+
.selectAllFields()
117+
.whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f},
118+
KnnParams.radius(0.4f))
119+
.toList();
120+
121+
assertThat(list.size(), is(4));
122+
assertThat(list.get(0).getId(), is(1));
123+
assertThat(list.get(1).getId(), is(2));
124+
assertThat(list.get(2).getId(), is(0));
125+
assertThat(list.get(3).getId(), is(3));
126+
}
119127

128+
@Test
129+
public void testSearchWithBaseParamsKAndRadius_isOk() {
130+
// by k (3 records) + by radius (4 records) = 3 records
131+
List<VectorItem> list = db.query(namespaceName, VectorItem.class)
132+
.selectAllFields()
133+
.whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f},
134+
KnnParams.base(3, 0.4f))
135+
.toList();
136+
137+
assertThat(list.size(), is(3));
138+
assertThat(list.get(0).getId(), is(1));
139+
assertThat(list.get(1).getId(), is(2));
140+
assertThat(list.get(2).getId(), is(0));
141+
142+
// by k (5 records) + by radius (4 records) = 4 records
143+
list = db.query(namespaceName, VectorItem.class)
144+
.selectAllFields()
145+
.whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f},
146+
KnnParams.base(5, 0.4f))
147+
.toList();
148+
149+
assertThat(list.size(), is(4));
150+
assertThat(list.get(0).getId(), is(1));
151+
assertThat(list.get(1).getId(), is(2));
152+
assertThat(list.get(2).getId(), is(0));
153+
assertThat(list.get(3).getId(), is(3));
154+
}
155+
156+
@Test
157+
public void testSearchWithHnswParams_isOk() {
158+
// k - 2 records
120159
List<VectorItem> list = db.query(namespaceName, VectorItem.class)
121160
.selectAllFields()
122161
.whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f},
123-
KnnParams.hnsw(2, 2))
162+
KnnParams.hnsw(2, 5))
124163
.toList();
125164

126165
assertThat(list.size(), is(2));
127166
assertThat(list.get(0).getId(), is(2));
128167
assertThat(list.get(0).getVector(), is(testItems.get(2).getVector()));
129168
assertThat(list.get(1).getId(), is(3));
130169
assertThat(list.get(1).getVector(), is(testItems.get(3).getVector()));
170+
171+
// radius 0.4 - 4 records
172+
list = db.query(namespaceName, VectorItem.class)
173+
.selectAllFields()
174+
.whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f},
175+
KnnParams.hnsw(KnnParams.radius(0.4f), 5))
176+
.toList();
177+
178+
assertThat(list.size(), is(4));
179+
180+
// by k (3 records) + by radius (4 records) = 3 records
181+
list = db.query(namespaceName, VectorItem.class)
182+
.selectAllFields()
183+
.whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f},
184+
KnnParams.hnsw(KnnParams.base(3, 0.4f), 5))
185+
.toList();
186+
187+
assertThat(list.size(), is(3));
188+
189+
// by k (5 records) + by radius (4 records) = 4 records
190+
list = db.query(namespaceName, VectorItem.class)
191+
.selectAllFields()
192+
.whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f},
193+
KnnParams.hnsw(KnnParams.base(5, 0.4f), 5))
194+
.toList();
195+
196+
assertThat(list.size(), is(4));
131197
}
132198

133199
@Test
134-
public void testSearchWithIncorrectHnswParams_isException() {
135-
List<VectorItem> testItems = getTestVectorItems();
136-
for (VectorItem item : testItems) {
137-
db.insert(namespaceName, item);
138-
}
200+
public void testTrySearchWithHnswParamFromNullBaseParam_isException() {
201+
assertThrows(NullPointerException.class,
202+
() -> db.query(namespaceName, VectorItem.class)
203+
.selectAllFields()
204+
.whereKnn("vector", new float[]{0.5f, 0.0f, 0.6f},
205+
KnnParams.hnsw(null, 5))
206+
.toList());
207+
}
139208

209+
@Test
210+
public void testSearchWithIncorrectHnswParams_isException() {
140211
assertThrows(IllegalArgumentException.class,
141212
() -> db.query(namespaceName, VectorItem.class)
142213
.selectAllFields()
@@ -156,11 +227,6 @@ public void testSearchWithIncorrectHnswParams_isException() {
156227

157228
@Test
158229
public void testSearchWithNotHnswNorBaseParams_isException() {
159-
List<VectorItem> testItems = getTestVectorItems();
160-
for (VectorItem item : testItems) {
161-
db.insert(namespaceName, item);
162-
}
163-
164230
assertThrows(RuntimeException.class,
165231
() -> db.query(namespaceName, VectorItem.class)
166232
.selectAllFields()

0 commit comments

Comments
 (0)