Skip to content

Commit cafbbb0

Browse files
committed
Avoid exception when search result is empty
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent e0f0b66 commit cafbbb0

4 files changed

Lines changed: 158 additions & 92 deletions

File tree

sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,14 @@ public List<IDScore> getIDScore(int indexOfTarget) throws ParamException, Illega
237237
return idScores;
238238
}
239239

240+
/**
241+
* Gets how many nq are searched.
242+
* @return how many nq are searched
243+
*/
244+
public long getNumQueries() {
245+
return results.getNumQueries();
246+
}
247+
240248
@Getter
241249
private static final class Position {
242250
private final long offset;
@@ -250,11 +258,12 @@ public Position(long offset, long k) {
250258
private Position getOffsetByIndex(int indexOfTarget) {
251259
List<Long> kList = results.getTopksList();
252260

253-
// if the server didn't return separate topK, use same topK value
261+
// if the server didn't return separate topK, use same topK value "0"
262+
// will return an empty result for each nq instead of throwing an exception
254263
if (kList.isEmpty()) {
255264
kList = new ArrayList<>();
256265
for (long i = 0; i < results.getNumQueries(); ++i) {
257-
kList.add(results.getTopK());
266+
kList.add(0L);
258267
}
259268
}
260269

sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java

Lines changed: 82 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import io.milvus.common.utils.Float16Utils;
2828
import io.milvus.common.utils.GTsDict;
2929
import io.milvus.common.utils.JsonUtils;
30+
import io.milvus.exception.ParamException;
3031
import io.milvus.grpc.*;
3132
import io.milvus.orm.iterator.QueryIterator;
3233
import io.milvus.orm.iterator.SearchIterator;
@@ -49,6 +50,8 @@
4950
import io.milvus.pool.PoolConfig;
5051
import io.milvus.response.*;
5152

53+
import io.milvus.v2.exception.MilvusClientException;
54+
import io.milvus.v2.service.vector.response.SearchResp;
5255
import org.apache.commons.text.RandomStringGenerator;
5356

5457
import org.junit.jupiter.api.Assertions;
@@ -63,6 +66,9 @@
6366
import java.util.*;
6467
import java.util.concurrent.ExecutionException;
6568
import java.util.concurrent.TimeUnit;
69+
import java.util.function.Function;
70+
71+
import static org.junit.jupiter.api.Assertions.assertThrows;
6672

6773
@Testcontainers(disabledWithoutDocker = true)
6874
class MilvusClientDockerTest {
@@ -1345,18 +1351,6 @@ void testMultipleVectorFields() {
13451351
R<RpcStatus> createR = client.createCollection(createParam);
13461352
Assertions.assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
13471353

1348-
// insert data to multiple vector fields
1349-
int rowCount = 10000;
1350-
List<InsertParam.Field> fields = generateColumnsData(schema, rowCount, 0);
1351-
1352-
InsertParam insertParam = InsertParam.newBuilder()
1353-
.withCollectionName(randomCollectionName)
1354-
.withFields(fields)
1355-
.build();
1356-
1357-
R<MutationResult> insertR = client.insert(insertParam);
1358-
Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
1359-
13601354
// create indexes on multiple vector fields
13611355
CreateIndexParam indexParam = CreateIndexParam.newBuilder()
13621356
.withCollectionName(randomCollectionName)
@@ -1397,53 +1391,86 @@ void testMultipleVectorFields() {
13971391
.build());
13981392
Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
13991393

1400-
// search on multiple vector fields
1401-
AnnSearchParam param1 = AnnSearchParam.newBuilder()
1402-
.withVectorFieldName(DataType.FloatVector.name())
1403-
.withFloatVectors(utils.generateFloatVectors(1))
1404-
.withMetricType(MetricType.COSINE)
1405-
.withParams("{\"nprobe\": 32}")
1406-
.withLimit(10L)
1407-
.build();
1408-
1409-
AnnSearchParam param2 = AnnSearchParam.newBuilder()
1410-
.withVectorFieldName(DataType.BinaryVector.name())
1411-
.withBinaryVectors(utils.generateBinaryVectors(1))
1412-
.withMetricType(MetricType.HAMMING)
1413-
.withParams("{}")
1414-
.withLimit(5L)
1415-
.build();
1416-
1417-
AnnSearchParam param3 = AnnSearchParam.newBuilder()
1418-
.withVectorFieldName(DataType.SparseFloatVector.name())
1419-
.withSparseFloatVectors(utils.generateSparseVectors(1))
1420-
.withMetricType(MetricType.IP)
1421-
.withParams("{\"drop_ratio_search\":0.2}")
1422-
.withLimit(7L)
1423-
.build();
1394+
// prepare sub requests
1395+
int nq = 5;
1396+
long topk = 10L;
1397+
Function<Integer, HybridSearchParam> genRequestFunc =
1398+
sparseCount -> {
1399+
AnnSearchParam param1 = AnnSearchParam.newBuilder()
1400+
.withVectorFieldName(DataType.FloatVector.name())
1401+
.withFloatVectors(utils.generateFloatVectors(nq))
1402+
.withMetricType(MetricType.COSINE)
1403+
.withParams("{\"nprobe\": 32}")
1404+
.withLimit(15L)
1405+
.build();
1406+
1407+
AnnSearchParam param2 = AnnSearchParam.newBuilder()
1408+
.withVectorFieldName(DataType.BinaryVector.name())
1409+
.withBinaryVectors(utils.generateBinaryVectors(nq))
1410+
.withMetricType(MetricType.HAMMING)
1411+
.withParams("{}")
1412+
.withLimit(5L)
1413+
.build();
1414+
1415+
List<SortedMap<Long, Float>> sparseVEctors = sparseCount > 0 ?
1416+
utils.generateSparseVectors(sparseCount) : new ArrayList<>();
1417+
AnnSearchParam param3 = AnnSearchParam.newBuilder()
1418+
.withVectorFieldName(DataType.SparseFloatVector.name())
1419+
.withSparseFloatVectors(sparseVEctors)
1420+
.withMetricType(MetricType.IP)
1421+
.withParams("{\"drop_ratio_search\":0.2}")
1422+
.withLimit(7L)
1423+
.build();
1424+
1425+
// search with an empty nq, return error
1426+
return HybridSearchParam.newBuilder()
1427+
.withCollectionName(randomCollectionName)
1428+
.addOutField(DataType.SparseFloatVector.name())
1429+
.addSearchRequest(param1)
1430+
.addSearchRequest(param2)
1431+
.addSearchRequest(param3)
1432+
.withLimit(topk)
1433+
.withConsistencyLevel(ConsistencyLevelEnum.STRONG)
1434+
.withRanker(WeightedRanker.newBuilder()
1435+
.withWeights(Lists.newArrayList(0.5f, 0.5f, 1.0f))
1436+
.build())
1437+
.withOutFields(Collections.singletonList("*"))
1438+
.build();
1439+
};
1440+
1441+
// search with an empty nq, return error
1442+
Assertions.assertThrows(ParamException.class, ()->genRequestFunc.apply(0));
1443+
1444+
// unequal nq, return error
1445+
Assertions.assertThrows(ParamException.class, ()->genRequestFunc.apply(1));
1446+
1447+
// search on empty collection, no result returned
1448+
R<SearchResults> searchR = client.hybridSearch(genRequestFunc.apply(nq));
1449+
Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
1450+
SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
1451+
for (int i = 0; i < results.getNumQueries(); ++i) {
1452+
List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
1453+
Assertions.assertTrue(scores.isEmpty());
1454+
}
14241455

1425-
HybridSearchParam searchParam = HybridSearchParam.newBuilder()
1456+
// insert data to multiple vector fields
1457+
int rowCount = 10000;
1458+
List<InsertParam.Field> fields = generateColumnsData(schema, rowCount, 0);
1459+
InsertParam insertParam = InsertParam.newBuilder()
14261460
.withCollectionName(randomCollectionName)
1427-
.addOutField(DataType.SparseFloatVector.name())
1428-
.addSearchRequest(param1)
1429-
.addSearchRequest(param2)
1430-
.addSearchRequest(param3)
1431-
.withLimit(3L)
1432-
.withConsistencyLevel(ConsistencyLevelEnum.STRONG)
1433-
.withRanker(WeightedRanker.newBuilder()
1434-
.withWeights(Lists.newArrayList(0.5f, 0.5f, 1.0f))
1435-
.build())
1436-
.withOutFields(Collections.singletonList("*"))
1461+
.withFields(fields)
14371462
.build();
1463+
R<MutationResult> insertR = client.insert(insertParam);
1464+
Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
14381465

1439-
R<SearchResults> searchR = client.hybridSearch(searchParam);
1466+
// search on multiple vector fields
1467+
searchR = client.hybridSearch(genRequestFunc.apply(nq));
14401468
Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
14411469

1442-
// print search result
1443-
SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
1470+
// check search result
1471+
results = new SearchResultsWrapper(searchR.getData().getResults());
14441472
List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
14451473
for (SearchResultsWrapper.IDScore score : scores) {
1446-
System.out.println(score);
14471474
Object id = score.get("id");
14481475
Assertions.assertInstanceOf(Long.class, id);
14491476
Object fv = score.get(DataType.FloatVector.name());
@@ -1457,6 +1484,10 @@ void testMultipleVectorFields() {
14571484
Object sv = score.get(DataType.SparseFloatVector.name());
14581485
Assertions.assertInstanceOf(SortedMap.class, sv);
14591486
}
1487+
for (int i = 0; i < results.getNumQueries(); ++i) {
1488+
scores = results.getIDScore(i);
1489+
Assertions.assertEquals(topk, scores.size());
1490+
}
14601491

14611492
// drop collection
14621493
DropCollectionParam dropParam = DropCollectionParam.newBuilder()

sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2969,6 +2969,8 @@ void testSearchResultsWrapper() {
29692969
String fieldName = "test";
29702970
SearchResultData results = SearchResultData.newBuilder()
29712971
.setTopK(topK)
2972+
.addTopks(topK)
2973+
.addTopks(topK) // numQueries=2, the topks list must have 2 elements
29722974
.setNumQueries(numQueries)
29732975
.setIds(IDs.newBuilder()
29742976
.setIntId(LongArray.newBuilder()
@@ -2996,6 +2998,8 @@ void testSearchResultsWrapper() {
29962998
// for string id
29972999
results = SearchResultData.newBuilder()
29983000
.setTopK(topK)
3001+
.addTopks(topK)
3002+
.addTopks(topK) // numQueries=2, the topks list must have 2 elements
29993003
.setNumQueries(numQueries)
30003004
.setIds(IDs.newBuilder()
30013005
.setStrId(StringArray.newBuilder()

sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

Lines changed: 61 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import io.milvus.orm.iterator.SearchIterator;
3434
import io.milvus.orm.iterator.SearchIteratorV2;
3535
import io.milvus.param.Constant;
36+
import io.milvus.param.dml.HybridSearchParam;
3637
import io.milvus.pool.MilvusClientV2Pool;
3738
import io.milvus.pool.PoolConfig;
3839
import io.milvus.response.QueryResultsWrapper;
@@ -1010,6 +1011,63 @@ void testHybridSearch() {
10101011
Assertions.assertEquals(16, descResp.getFieldNames().size());
10111012
Assertions.assertEquals(3, descResp.getVectorFieldNames().size());
10121013

1014+
// prepare sub requests
1015+
int nq = 5;
1016+
int topk = 10;
1017+
Function<Integer, HybridSearchReq> genRequestFunc =
1018+
sparseCount -> {
1019+
List<BaseVector> floatVectors = new ArrayList<>();
1020+
List<BaseVector> binaryVectors = new ArrayList<>();
1021+
List<BaseVector> sparseVectors = new ArrayList<>();
1022+
for (int i = 0; i < nq; i++) {
1023+
floatVectors.add(new FloatVec(utils.generateFloatVector()));
1024+
binaryVectors.add(new BinaryVec(utils.generateBinaryVector()));
1025+
}
1026+
for (int i = 0; i < sparseCount; i++) {
1027+
sparseVectors.add(new SparseFloatVec(utils.generateSparseVector()));
1028+
}
1029+
1030+
List<AnnSearchReq> searchRequests = new ArrayList<>();
1031+
searchRequests.add(AnnSearchReq.builder()
1032+
.vectorFieldName("float_vector")
1033+
.vectors(floatVectors)
1034+
.params("{\"nprobe\": 10}")
1035+
.limit(15)
1036+
.build());
1037+
searchRequests.add(AnnSearchReq.builder()
1038+
.vectorFieldName("binary_vector")
1039+
.vectors(binaryVectors)
1040+
.limit(5)
1041+
.build());
1042+
searchRequests.add(AnnSearchReq.builder()
1043+
.vectorFieldName("sparse_vector")
1044+
.vectors(sparseVectors)
1045+
.limit(7)
1046+
.build());
1047+
1048+
return HybridSearchReq.builder()
1049+
.collectionName(randomCollectionName)
1050+
.searchRequests(searchRequests)
1051+
.ranker(new RRFRanker(20))
1052+
.limit(topk)
1053+
.consistencyLevel(ConsistencyLevel.BOUNDED)
1054+
.build();
1055+
};
1056+
1057+
// search with an empty nq, return error
1058+
Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(0)));
1059+
1060+
// unequal nq, return error
1061+
Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(1)));
1062+
1063+
// search on empty collection, no result returned
1064+
SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(nq));
1065+
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
1066+
Assertions.assertEquals(nq, searchResults.size());
1067+
for (List<SearchResp.SearchResult> result : searchResults) {
1068+
Assertions.assertTrue(result.isEmpty());
1069+
}
1070+
10131071
// insert rows
10141072
long count = 10000;
10151073
List<JsonObject> data = generateRandomData(collectionSchema, count);
@@ -1023,45 +1081,9 @@ void testHybridSearch() {
10231081
long rowCount = getRowCount(randomCollectionName);
10241082
Assertions.assertEquals(count, rowCount);
10251083

1026-
// hybrid search in collection
1027-
int nq = 5;
1028-
int topk = 10;
1029-
List<BaseVector> floatVectors = new ArrayList<>();
1030-
List<BaseVector> binaryVectors = new ArrayList<>();
1031-
List<BaseVector> sparseVectors = new ArrayList<>();
1032-
for (int i = 0; i < nq; i++) {
1033-
floatVectors.add(new FloatVec(utils.generateFloatVector()));
1034-
binaryVectors.add(new BinaryVec(utils.generateBinaryVector()));
1035-
sparseVectors.add(new SparseFloatVec(utils.generateSparseVector()));
1036-
}
1037-
1038-
List<AnnSearchReq> searchRequests = new ArrayList<>();
1039-
searchRequests.add(AnnSearchReq.builder()
1040-
.vectorFieldName("float_vector")
1041-
.vectors(floatVectors)
1042-
.params("{\"nprobe\": 10}")
1043-
.limit(10)
1044-
.build());
1045-
searchRequests.add(AnnSearchReq.builder()
1046-
.vectorFieldName("binary_vector")
1047-
.vectors(binaryVectors)
1048-
.limit(50)
1049-
.build());
1050-
searchRequests.add(AnnSearchReq.builder()
1051-
.vectorFieldName("sparse_vector")
1052-
.vectors(sparseVectors)
1053-
.limit(100)
1054-
.build());
1055-
1056-
HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
1057-
.collectionName(randomCollectionName)
1058-
.searchRequests(searchRequests)
1059-
.ranker(new RRFRanker(20))
1060-
.limit(topk)
1061-
.consistencyLevel(ConsistencyLevel.BOUNDED)
1062-
.build();
1063-
SearchResp searchResp = client.hybridSearch(hybridSearchReq);
1064-
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
1084+
// search again, there are results
1085+
searchResp = client.hybridSearch(genRequestFunc.apply(nq));
1086+
searchResults = searchResp.getSearchResults();
10651087
Assertions.assertEquals(nq, searchResults.size());
10661088
for (int i = 0; i < nq; i++) {
10671089
List<SearchResp.SearchResult> results = searchResults.get(i);

0 commit comments

Comments
 (0)