Skip to content

Commit 784b434

Browse files
authored
Avoid exception when search result is empty (#1459)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent afea667 commit 784b434

4 files changed

Lines changed: 155 additions & 92 deletions

File tree

sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,14 @@ public List<IDScore> getIDScore(int indexOfTarget) throws ParamException, Illega
237237
return idScores;
238238
}
239239

240+
/**
241+
* Gets how many nq are searched.
242+
* @return how many nq are searched
243+
*/
244+
public long getNumQueries() {
245+
return results.getNumQueries();
246+
}
247+
240248
@Getter
241249
private static final class Position {
242250
private final long offset;
@@ -250,11 +258,12 @@ public Position(long offset, long k) {
250258
private Position getOffsetByIndex(int indexOfTarget) {
251259
List<Long> kList = results.getTopksList();
252260

253-
// if the server didn't return separate topK, use same topK value
261+
// if the server didn't return separate topK, use same topK value "0"
262+
// will return an empty result for each nq instead of throwing an exception
254263
if (kList.isEmpty()) {
255264
kList = new ArrayList<>();
256265
for (long i = 0; i < results.getNumQueries(); ++i) {
257-
kList.add(results.getTopK());
266+
kList.add(0L);
258267
}
259268
}
260269

sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java

Lines changed: 79 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import io.milvus.common.utils.Float16Utils;
2828
import io.milvus.common.utils.GTsDict;
2929
import io.milvus.common.utils.JsonUtils;
30+
import io.milvus.exception.ParamException;
3031
import io.milvus.grpc.*;
3132
import io.milvus.orm.iterator.QueryIterator;
3233
import io.milvus.orm.iterator.SearchIterator;
@@ -63,6 +64,8 @@
6364
import java.util.*;
6465
import java.util.concurrent.ExecutionException;
6566
import java.util.concurrent.TimeUnit;
67+
import java.util.function.Function;
68+
6669

6770
@Testcontainers(disabledWithoutDocker = true)
6871
class MilvusClientDockerTest {
@@ -1345,18 +1348,6 @@ void testMultipleVectorFields() {
13451348
R<RpcStatus> createR = client.createCollection(createParam);
13461349
Assertions.assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
13471350

1348-
// insert data to multiple vector fields
1349-
int rowCount = 10000;
1350-
List<InsertParam.Field> fields = generateColumnsData(schema, rowCount, 0);
1351-
1352-
InsertParam insertParam = InsertParam.newBuilder()
1353-
.withCollectionName(randomCollectionName)
1354-
.withFields(fields)
1355-
.build();
1356-
1357-
R<MutationResult> insertR = client.insert(insertParam);
1358-
Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
1359-
13601351
// create indexes on multiple vector fields
13611352
CreateIndexParam indexParam = CreateIndexParam.newBuilder()
13621353
.withCollectionName(randomCollectionName)
@@ -1397,53 +1388,86 @@ void testMultipleVectorFields() {
13971388
.build());
13981389
Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
13991390

1400-
// search on multiple vector fields
1401-
AnnSearchParam param1 = AnnSearchParam.newBuilder()
1402-
.withVectorFieldName(DataType.FloatVector.name())
1403-
.withFloatVectors(utils.generateFloatVectors(1))
1404-
.withMetricType(MetricType.COSINE)
1405-
.withParams("{\"nprobe\": 32}")
1406-
.withLimit(10L)
1407-
.build();
1408-
1409-
AnnSearchParam param2 = AnnSearchParam.newBuilder()
1410-
.withVectorFieldName(DataType.BinaryVector.name())
1411-
.withBinaryVectors(utils.generateBinaryVectors(1))
1412-
.withMetricType(MetricType.HAMMING)
1413-
.withParams("{}")
1414-
.withLimit(5L)
1415-
.build();
1416-
1417-
AnnSearchParam param3 = AnnSearchParam.newBuilder()
1418-
.withVectorFieldName(DataType.SparseFloatVector.name())
1419-
.withSparseFloatVectors(utils.generateSparseVectors(1))
1420-
.withMetricType(MetricType.IP)
1421-
.withParams("{\"drop_ratio_search\":0.2}")
1422-
.withLimit(7L)
1423-
.build();
1391+
// prepare sub requests
1392+
int nq = 5;
1393+
long topk = 10L;
1394+
Function<Integer, HybridSearchParam> genRequestFunc =
1395+
sparseCount -> {
1396+
AnnSearchParam param1 = AnnSearchParam.newBuilder()
1397+
.withVectorFieldName(DataType.FloatVector.name())
1398+
.withFloatVectors(utils.generateFloatVectors(nq))
1399+
.withMetricType(MetricType.COSINE)
1400+
.withParams("{\"nprobe\": 32}")
1401+
.withLimit(15L)
1402+
.build();
1403+
1404+
AnnSearchParam param2 = AnnSearchParam.newBuilder()
1405+
.withVectorFieldName(DataType.BinaryVector.name())
1406+
.withBinaryVectors(utils.generateBinaryVectors(nq))
1407+
.withMetricType(MetricType.HAMMING)
1408+
.withParams("{}")
1409+
.withLimit(5L)
1410+
.build();
1411+
1412+
List<SortedMap<Long, Float>> sparseVEctors = sparseCount > 0 ?
1413+
utils.generateSparseVectors(sparseCount) : new ArrayList<>();
1414+
AnnSearchParam param3 = AnnSearchParam.newBuilder()
1415+
.withVectorFieldName(DataType.SparseFloatVector.name())
1416+
.withSparseFloatVectors(sparseVEctors)
1417+
.withMetricType(MetricType.IP)
1418+
.withParams("{\"drop_ratio_search\":0.2}")
1419+
.withLimit(7L)
1420+
.build();
1421+
1422+
// search with an empty nq, return error
1423+
return HybridSearchParam.newBuilder()
1424+
.withCollectionName(randomCollectionName)
1425+
.addOutField(DataType.SparseFloatVector.name())
1426+
.addSearchRequest(param1)
1427+
.addSearchRequest(param2)
1428+
.addSearchRequest(param3)
1429+
.withLimit(topk)
1430+
.withConsistencyLevel(ConsistencyLevelEnum.STRONG)
1431+
.withRanker(WeightedRanker.newBuilder()
1432+
.withWeights(Lists.newArrayList(0.5f, 0.5f, 1.0f))
1433+
.build())
1434+
.withOutFields(Collections.singletonList("*"))
1435+
.build();
1436+
};
1437+
1438+
// search with an empty nq, return error
1439+
Assertions.assertThrows(ParamException.class, ()->genRequestFunc.apply(0));
1440+
1441+
// unequal nq, return error
1442+
Assertions.assertThrows(ParamException.class, ()->genRequestFunc.apply(1));
1443+
1444+
// search on empty collection, no result returned
1445+
R<SearchResults> searchR = client.hybridSearch(genRequestFunc.apply(nq));
1446+
Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
1447+
SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
1448+
for (int i = 0; i < results.getNumQueries(); ++i) {
1449+
List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
1450+
Assertions.assertTrue(scores.isEmpty());
1451+
}
14241452

1425-
HybridSearchParam searchParam = HybridSearchParam.newBuilder()
1453+
// insert data to multiple vector fields
1454+
int rowCount = 10000;
1455+
List<InsertParam.Field> fields = generateColumnsData(schema, rowCount, 0);
1456+
InsertParam insertParam = InsertParam.newBuilder()
14261457
.withCollectionName(randomCollectionName)
1427-
.addOutField(DataType.SparseFloatVector.name())
1428-
.addSearchRequest(param1)
1429-
.addSearchRequest(param2)
1430-
.addSearchRequest(param3)
1431-
.withLimit(3L)
1432-
.withConsistencyLevel(ConsistencyLevelEnum.STRONG)
1433-
.withRanker(WeightedRanker.newBuilder()
1434-
.withWeights(Lists.newArrayList(0.5f, 0.5f, 1.0f))
1435-
.build())
1436-
.withOutFields(Collections.singletonList("*"))
1458+
.withFields(fields)
14371459
.build();
1460+
R<MutationResult> insertR = client.insert(insertParam);
1461+
Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
14381462

1439-
R<SearchResults> searchR = client.hybridSearch(searchParam);
1463+
// search on multiple vector fields
1464+
searchR = client.hybridSearch(genRequestFunc.apply(nq));
14401465
Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
14411466

1442-
// print search result
1443-
SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
1467+
// check search result
1468+
results = new SearchResultsWrapper(searchR.getData().getResults());
14441469
List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
14451470
for (SearchResultsWrapper.IDScore score : scores) {
1446-
System.out.println(score);
14471471
Object id = score.get("id");
14481472
Assertions.assertInstanceOf(Long.class, id);
14491473
Object fv = score.get(DataType.FloatVector.name());
@@ -1457,6 +1481,10 @@ void testMultipleVectorFields() {
14571481
Object sv = score.get(DataType.SparseFloatVector.name());
14581482
Assertions.assertInstanceOf(SortedMap.class, sv);
14591483
}
1484+
for (int i = 0; i < results.getNumQueries(); ++i) {
1485+
scores = results.getIDScore(i);
1486+
Assertions.assertEquals(topk, scores.size());
1487+
}
14601488

14611489
// drop collection
14621490
DropCollectionParam dropParam = DropCollectionParam.newBuilder()

sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2969,6 +2969,8 @@ void testSearchResultsWrapper() {
29692969
String fieldName = "test";
29702970
SearchResultData results = SearchResultData.newBuilder()
29712971
.setTopK(topK)
2972+
.addTopks(topK)
2973+
.addTopks(topK) // numQueries=2, the topks list must have 2 elements
29722974
.setNumQueries(numQueries)
29732975
.setIds(IDs.newBuilder()
29742976
.setIntId(LongArray.newBuilder()
@@ -2996,6 +2998,8 @@ void testSearchResultsWrapper() {
29962998
// for string id
29972999
results = SearchResultData.newBuilder()
29983000
.setTopK(topK)
3001+
.addTopks(topK)
3002+
.addTopks(topK) // numQueries=2, the topks list must have 2 elements
29993003
.setNumQueries(numQueries)
30003004
.setIds(IDs.newBuilder()
30013005
.setStrId(StringArray.newBuilder()

sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

Lines changed: 61 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import io.milvus.orm.iterator.SearchIterator;
3434
import io.milvus.orm.iterator.SearchIteratorV2;
3535
import io.milvus.param.Constant;
36+
import io.milvus.param.dml.HybridSearchParam;
3637
import io.milvus.pool.MilvusClientV2Pool;
3738
import io.milvus.pool.PoolConfig;
3839
import io.milvus.response.QueryResultsWrapper;
@@ -873,6 +874,63 @@ void testHybridSearch() {
873874
Assertions.assertEquals(16, descResp.getFieldNames().size());
874875
Assertions.assertEquals(3, descResp.getVectorFieldNames().size());
875876

877+
// prepare sub requests
878+
int nq = 5;
879+
int topk = 10;
880+
Function<Integer, HybridSearchReq> genRequestFunc =
881+
sparseCount -> {
882+
List<BaseVector> floatVectors = new ArrayList<>();
883+
List<BaseVector> binaryVectors = new ArrayList<>();
884+
List<BaseVector> sparseVectors = new ArrayList<>();
885+
for (int i = 0; i < nq; i++) {
886+
floatVectors.add(new FloatVec(utils.generateFloatVector()));
887+
binaryVectors.add(new BinaryVec(utils.generateBinaryVector()));
888+
}
889+
for (int i = 0; i < sparseCount; i++) {
890+
sparseVectors.add(new SparseFloatVec(utils.generateSparseVector()));
891+
}
892+
893+
List<AnnSearchReq> searchRequests = new ArrayList<>();
894+
searchRequests.add(AnnSearchReq.builder()
895+
.vectorFieldName("float_vector")
896+
.vectors(floatVectors)
897+
.params("{\"nprobe\": 10}")
898+
.limit(15)
899+
.build());
900+
searchRequests.add(AnnSearchReq.builder()
901+
.vectorFieldName("binary_vector")
902+
.vectors(binaryVectors)
903+
.limit(5)
904+
.build());
905+
searchRequests.add(AnnSearchReq.builder()
906+
.vectorFieldName("sparse_vector")
907+
.vectors(sparseVectors)
908+
.limit(7)
909+
.build());
910+
911+
return HybridSearchReq.builder()
912+
.collectionName(randomCollectionName)
913+
.searchRequests(searchRequests)
914+
.ranker(new RRFRanker(20))
915+
.limit(topk)
916+
.consistencyLevel(ConsistencyLevel.BOUNDED)
917+
.build();
918+
};
919+
920+
// search with an empty nq, return error
921+
Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(0)));
922+
923+
// unequal nq, return error
924+
Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(1)));
925+
926+
// search on empty collection, no result returned
927+
SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(nq));
928+
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
929+
Assertions.assertEquals(nq, searchResults.size());
930+
for (List<SearchResp.SearchResult> result : searchResults) {
931+
Assertions.assertTrue(result.isEmpty());
932+
}
933+
876934
// insert rows
877935
long count = 10000;
878936
List<JsonObject> data = generateRandomData(collectionSchema, count);
@@ -886,45 +944,9 @@ void testHybridSearch() {
886944
long rowCount = getRowCount(randomCollectionName);
887945
Assertions.assertEquals(count, rowCount);
888946

889-
// hybrid search in collection
890-
int nq = 5;
891-
int topk = 10;
892-
List<BaseVector> floatVectors = new ArrayList<>();
893-
List<BaseVector> binaryVectors = new ArrayList<>();
894-
List<BaseVector> sparseVectors = new ArrayList<>();
895-
for (int i = 0; i < nq; i++) {
896-
floatVectors.add(new FloatVec(utils.generateFloatVector()));
897-
binaryVectors.add(new BinaryVec(utils.generateBinaryVector()));
898-
sparseVectors.add(new SparseFloatVec(utils.generateSparseVector()));
899-
}
900-
901-
List<AnnSearchReq> searchRequests = new ArrayList<>();
902-
searchRequests.add(AnnSearchReq.builder()
903-
.vectorFieldName("float_vector")
904-
.vectors(floatVectors)
905-
.params("{\"nprobe\": 10}")
906-
.limit(10)
907-
.build());
908-
searchRequests.add(AnnSearchReq.builder()
909-
.vectorFieldName("binary_vector")
910-
.vectors(binaryVectors)
911-
.limit(50)
912-
.build());
913-
searchRequests.add(AnnSearchReq.builder()
914-
.vectorFieldName("sparse_vector")
915-
.vectors(sparseVectors)
916-
.limit(100)
917-
.build());
918-
919-
HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
920-
.collectionName(randomCollectionName)
921-
.searchRequests(searchRequests)
922-
.ranker(new RRFRanker(20))
923-
.limit(topk)
924-
.consistencyLevel(ConsistencyLevel.BOUNDED)
925-
.build();
926-
SearchResp searchResp = client.hybridSearch(hybridSearchReq);
927-
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
947+
// search again, there are results
948+
searchResp = client.hybridSearch(genRequestFunc.apply(nq));
949+
searchResults = searchResp.getSearchResults();
928950
Assertions.assertEquals(nq, searchResults.size());
929951
for (int i = 0; i < nq; i++) {
930952
List<SearchResp.SearchResult> results = searchResults.get(i);

0 commit comments

Comments
 (0)