diff --git a/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java b/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java index 388fc9657..3d15d125e 100644 --- a/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java +++ b/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java @@ -237,6 +237,14 @@ public List getIDScore(int indexOfTarget) throws ParamException, Illega return idScores; } + /** + * Gets the number of queries (nq) contained in the search results. + * @return the number of queries (nq) + */ + public long getNumQueries() { + return results.getNumQueries(); + } + @Getter private static final class Position { private final long offset; @@ -250,11 +258,12 @@ public Position(long offset, long k) { private Position getOffsetByIndex(int indexOfTarget) { List kList = results.getTopksList(); - // if the server didn't return separate topK, use same topK value + // if the server didn't return a separate topK for each nq, use 0 as the topK of every nq, + // so that an empty result is returned for each nq instead of throwing an exception if (kList.isEmpty()) { kList = new ArrayList<>(); for (long i = 0; i < results.getNumQueries(); ++i) { - kList.add(results.getTopK()); + kList.add(0L); } } diff --git a/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java b/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java index 166491d11..2a4b16016 100644 --- a/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java +++ b/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java @@ -27,6 +27,7 @@ import io.milvus.common.utils.Float16Utils; import io.milvus.common.utils.GTsDict; import io.milvus.common.utils.JsonUtils; +import io.milvus.exception.ParamException; import io.milvus.grpc.*; import io.milvus.orm.iterator.QueryIterator; import io.milvus.orm.iterator.SearchIterator; @@ -63,6 +64,7 @@ import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.function.Function;
@Testcontainers(disabledWithoutDocker = true) class MilvusClientDockerTest { @@ -1345,18 +1347,6 @@ void testMultipleVectorFields() { R createR = client.createCollection(createParam); Assertions.assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue()); - // insert data to multiple vector fields - int rowCount = 10000; - List fields = generateColumnsData(schema, rowCount, 0); - - InsertParam insertParam = InsertParam.newBuilder() - .withCollectionName(randomCollectionName) - .withFields(fields) - .build(); - - R insertR = client.insert(insertParam); - Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue()); - // create indexes on multiple vector fields CreateIndexParam indexParam = CreateIndexParam.newBuilder() .withCollectionName(randomCollectionName) @@ -1397,53 +1387,86 @@ void testMultipleVectorFields() { .build()); Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue()); - // search on multiple vector fields - AnnSearchParam param1 = AnnSearchParam.newBuilder() - .withVectorFieldName(DataType.FloatVector.name()) - .withFloatVectors(utils.generateFloatVectors(1)) - .withMetricType(MetricType.COSINE) - .withParams("{\"nprobe\": 32}") - .withLimit(10L) - .build(); - - AnnSearchParam param2 = AnnSearchParam.newBuilder() - .withVectorFieldName(DataType.BinaryVector.name()) - .withBinaryVectors(utils.generateBinaryVectors(1)) - .withMetricType(MetricType.HAMMING) - .withParams("{}") - .withLimit(5L) - .build(); - - AnnSearchParam param3 = AnnSearchParam.newBuilder() - .withVectorFieldName(DataType.SparseFloatVector.name()) - .withSparseFloatVectors(utils.generateSparseVectors(1)) - .withMetricType(MetricType.IP) - .withParams("{\"drop_ratio_search\":0.2}") - .withLimit(7L) - .build(); + // prepare sub requests + int nq = 5; + long topk = 10L; + Function genRequestFunc = + sparseCount -> { + AnnSearchParam param1 = AnnSearchParam.newBuilder() + .withVectorFieldName(DataType.FloatVector.name()) + 
.withFloatVectors(utils.generateFloatVectors(nq)) + .withMetricType(MetricType.COSINE) + .withParams("{\"nprobe\": 32}") + .withLimit(15L) + .build(); + + AnnSearchParam param2 = AnnSearchParam.newBuilder() + .withVectorFieldName(DataType.BinaryVector.name()) + .withBinaryVectors(utils.generateBinaryVectors(nq)) + .withMetricType(MetricType.HAMMING) + .withParams("{}") + .withLimit(5L) + .build(); + + List> sparseVectors = sparseCount > 0 ? + utils.generateSparseVectors(sparseCount) : new ArrayList<>(); + AnnSearchParam param3 = AnnSearchParam.newBuilder() + .withVectorFieldName(DataType.SparseFloatVector.name()) + .withSparseFloatVectors(sparseVectors) + .withMetricType(MetricType.IP) + .withParams("{\"drop_ratio_search\":0.2}") + .withLimit(7L) + .build(); + + // assemble the hybrid search request from the three sub requests + return HybridSearchParam.newBuilder() + .withCollectionName(randomCollectionName) + .addOutField(DataType.SparseFloatVector.name()) + .addSearchRequest(param1) + .addSearchRequest(param2) + .addSearchRequest(param3) + .withLimit(topk) + .withConsistencyLevel(ConsistencyLevelEnum.STRONG) + .withRanker(WeightedRanker.newBuilder() + .withWeights(Lists.newArrayList(0.5f, 0.5f, 1.0f)) + .build()) + .withOutFields(Collections.singletonList("*")) + .build(); + }; + + // search with an empty nq, return error + Assertions.assertThrows(ParamException.class, ()->genRequestFunc.apply(0)); + + // unequal nq, return error + Assertions.assertThrows(ParamException.class, ()->genRequestFunc.apply(1)); + + // search on empty collection, no result returned + R searchR = client.hybridSearch(genRequestFunc.apply(nq)); + Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue()); + SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults()); + for (int i = 0; i < results.getNumQueries(); ++i) { + List scores = results.getIDScore(i); + Assertions.assertTrue(scores.isEmpty()); + } - HybridSearchParam searchParam =
HybridSearchParam.newBuilder() + // insert data to multiple vector fields + int rowCount = 10000; + List fields = generateColumnsData(schema, rowCount, 0); + InsertParam insertParam = InsertParam.newBuilder() .withCollectionName(randomCollectionName) - .addOutField(DataType.SparseFloatVector.name()) - .addSearchRequest(param1) - .addSearchRequest(param2) - .addSearchRequest(param3) - .withLimit(3L) - .withConsistencyLevel(ConsistencyLevelEnum.STRONG) - .withRanker(WeightedRanker.newBuilder() - .withWeights(Lists.newArrayList(0.5f, 0.5f, 1.0f)) - .build()) - .withOutFields(Collections.singletonList("*")) + .withFields(fields) .build(); + R insertR = client.insert(insertParam); + Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue()); - R searchR = client.hybridSearch(searchParam); + // search on multiple vector fields + searchR = client.hybridSearch(genRequestFunc.apply(nq)); Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue()); - // print search result - SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults()); + // check search result + results = new SearchResultsWrapper(searchR.getData().getResults()); List scores = results.getIDScore(0); for (SearchResultsWrapper.IDScore score : scores) { - System.out.println(score); Object id = score.get("id"); Assertions.assertInstanceOf(Long.class, id); Object fv = score.get(DataType.FloatVector.name()); @@ -1457,6 +1480,10 @@ void testMultipleVectorFields() { Object sv = score.get(DataType.SparseFloatVector.name()); Assertions.assertInstanceOf(SortedMap.class, sv); } + for (int i = 0; i < results.getNumQueries(); ++i) { + scores = results.getIDScore(i); + Assertions.assertEquals(topk, scores.size()); + } // drop collection DropCollectionParam dropParam = DropCollectionParam.newBuilder() diff --git a/sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java 
b/sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java index d36a01d9c..88b887055 100644 --- a/sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java +++ b/sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java @@ -2969,6 +2969,8 @@ void testSearchResultsWrapper() { String fieldName = "test"; SearchResultData results = SearchResultData.newBuilder() .setTopK(topK) + .addTopks(topK) + .addTopks(topK) // numQueries=2, the topks list must have 2 elements .setNumQueries(numQueries) .setIds(IDs.newBuilder() .setIntId(LongArray.newBuilder() @@ -2996,6 +2998,8 @@ void testSearchResultsWrapper() { // for string id results = SearchResultData.newBuilder() .setTopK(topK) + .addTopks(topK) + .addTopks(topK) // numQueries=2, the topks list must have 2 elements .setNumQueries(numQueries) .setIds(IDs.newBuilder() .setStrId(StringArray.newBuilder() diff --git a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java index 9eeb7fb0f..e760429ff 100644 --- a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java +++ b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java @@ -33,6 +33,7 @@ import io.milvus.orm.iterator.SearchIterator; import io.milvus.orm.iterator.SearchIteratorV2; import io.milvus.param.Constant; +import io.milvus.param.dml.HybridSearchParam; import io.milvus.pool.MilvusClientV2Pool; import io.milvus.pool.PoolConfig; import io.milvus.response.QueryResultsWrapper; @@ -1010,6 +1011,63 @@ void testHybridSearch() { Assertions.assertEquals(16, descResp.getFieldNames().size()); Assertions.assertEquals(3, descResp.getVectorFieldNames().size()); + // prepare sub requests + int nq = 5; + int topk = 10; + Function genRequestFunc = + sparseCount -> { + List floatVectors = new ArrayList<>(); + List binaryVectors = new ArrayList<>(); + List sparseVectors = new ArrayList<>(); + for (int i = 0; i < nq; i++) { + 
floatVectors.add(new FloatVec(utils.generateFloatVector())); + binaryVectors.add(new BinaryVec(utils.generateBinaryVector())); + } + for (int i = 0; i < sparseCount; i++) { + sparseVectors.add(new SparseFloatVec(utils.generateSparseVector())); + } + + List searchRequests = new ArrayList<>(); + searchRequests.add(AnnSearchReq.builder() + .vectorFieldName("float_vector") + .vectors(floatVectors) + .params("{\"nprobe\": 10}") + .limit(15) + .build()); + searchRequests.add(AnnSearchReq.builder() + .vectorFieldName("binary_vector") + .vectors(binaryVectors) + .limit(5) + .build()); + searchRequests.add(AnnSearchReq.builder() + .vectorFieldName("sparse_vector") + .vectors(sparseVectors) + .limit(7) + .build()); + + return HybridSearchReq.builder() + .collectionName(randomCollectionName) + .searchRequests(searchRequests) + .ranker(new RRFRanker(20)) + .limit(topk) + .consistencyLevel(ConsistencyLevel.BOUNDED) + .build(); + }; + + // search with an empty nq, return error + Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(0))); + + // unequal nq, return error + Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(1))); + + // search on empty collection, no result returned + SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(nq)); + List> searchResults = searchResp.getSearchResults(); + Assertions.assertEquals(nq, searchResults.size()); + for (List result : searchResults) { + Assertions.assertTrue(result.isEmpty()); + } + // insert rows long count = 10000; List data = generateRandomData(collectionSchema, count); @@ -1023,45 +1081,9 @@ void testHybridSearch() { long rowCount = getRowCount(randomCollectionName); Assertions.assertEquals(count, rowCount); - // hybrid search in collection - int nq = 5; - int topk = 10; - List floatVectors = new ArrayList<>(); - List binaryVectors = new ArrayList<>(); - List sparseVectors = new ArrayList<>(); - for (int i = 0; i < nq; i++) { 
- floatVectors.add(new FloatVec(utils.generateFloatVector())); - binaryVectors.add(new BinaryVec(utils.generateBinaryVector())); - sparseVectors.add(new SparseFloatVec(utils.generateSparseVector())); - } - - List searchRequests = new ArrayList<>(); - searchRequests.add(AnnSearchReq.builder() - .vectorFieldName("float_vector") - .vectors(floatVectors) - .params("{\"nprobe\": 10}") - .limit(10) - .build()); - searchRequests.add(AnnSearchReq.builder() - .vectorFieldName("binary_vector") - .vectors(binaryVectors) - .limit(50) - .build()); - searchRequests.add(AnnSearchReq.builder() - .vectorFieldName("sparse_vector") - .vectors(sparseVectors) - .limit(100) - .build()); - - HybridSearchReq hybridSearchReq = HybridSearchReq.builder() - .collectionName(randomCollectionName) - .searchRequests(searchRequests) - .ranker(new RRFRanker(20)) - .limit(topk) - .consistencyLevel(ConsistencyLevel.BOUNDED) - .build(); - SearchResp searchResp = client.hybridSearch(hybridSearchReq); - List> searchResults = searchResp.getSearchResults(); + // search again, there are results + searchResp = client.hybridSearch(genRequestFunc.apply(nq)); + searchResults = searchResp.getSearchResults(); Assertions.assertEquals(nq, searchResults.size()); for (int i = 0; i < nq; i++) { List results = searchResults.get(i);