diff --git a/examples/src/main/java/io/milvus/v2/HybridSearchExample.java b/examples/src/main/java/io/milvus/v2/HybridSearchExample.java index 139e745b5..9608fc8cc 100644 --- a/examples/src/main/java/io/milvus/v2/HybridSearchExample.java +++ b/examples/src/main/java/io/milvus/v2/HybridSearchExample.java @@ -31,6 +31,7 @@ import io.milvus.v2.service.collection.request.CreateCollectionReq; import io.milvus.v2.service.collection.request.DropCollectionReq; import io.milvus.v2.service.vector.request.AnnSearchReq; +import io.milvus.v2.service.vector.request.FunctionScore; import io.milvus.v2.service.vector.request.HybridSearchReq; import io.milvus.v2.service.vector.request.InsertReq; import io.milvus.v2.service.vector.request.QueryReq; @@ -122,7 +123,6 @@ private void createCollection() { .metricType(BINARY_VECTOR_METRIC) .build()); Map fv16Params = new HashMap<>(); - fv16Params.clear(); fv16Params.put("M",16); fv16Params.put("efConstruction",64); indexes.add(IndexParam.builder() @@ -212,7 +212,9 @@ private void hybridSearch() { HybridSearchReq hybridSearchReq = HybridSearchReq.builder() .collectionName(COLLECTION_NAME) .searchRequests(searchRequests) - .ranker(WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build()) + .functionScore(FunctionScore.builder() + .addFunction(WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build()) + .build()) .limit(5) .consistencyLevel(ConsistencyLevel.BOUNDED) .build(); diff --git a/sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java b/sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java index 8286529d6..3d0c0c69d 100644 --- a/sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java +++ b/sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java @@ -210,8 +210,10 @@ private boolean isResSufficient(List ret) { private QueryResults executeQuery(String expr, long offset, long limit, long ts, boolean isSeek) { // for seeking offset, no need to return output fields 
List outputFields = new ArrayList<>(); + boolean reduceStopForBest = queryIteratorParam.isReduceStopForBest(); if (!isSeek) { outputFields = queryIteratorParam.getOutFields(); + reduceStopForBest = false; } QueryParam queryParam = QueryParam.newBuilder() .withDatabaseName(queryIteratorParam.getDatabaseName()) @@ -230,7 +232,7 @@ private QueryResults executeQuery(String expr, long offset, long limit, long ts, // reduce stop for best builder.addQueryParams(KeyValuePair.newBuilder() .setKey(Constant.REDUCE_STOP_FOR_BEST) - .setValue(String.valueOf(queryIteratorParam.isReduceStopForBest())) + .setValue(String.valueOf(reduceStopForBest)) .build()); // iterator diff --git a/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java b/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java index f0ef9ae1d..0e650d0c5 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java @@ -194,7 +194,8 @@ public static class FieldSchema { @Data @SuperBuilder public static class Function { - private String name; + @Builder.Default + private String name = ""; @Builder.Default private String description = ""; @Builder.Default diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/FunctionScore.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/FunctionScore.java new file mode 100644 index 000000000..28c987584 --- /dev/null +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/FunctionScore.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.milvus.v2.service.vector.request; + +import io.milvus.v2.service.collection.request.CreateCollectionReq; +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@Data +@SuperBuilder +public class FunctionScore { + @Builder.Default + private List functions = new ArrayList<>(); + @Builder.Default + private Map params = new HashMap<>(); + + public static abstract class FunctionScoreBuilder> { + public B addFunction(CreateCollectionReq.Function func) { + if(null == this.functions$value ){ + this.functions$value = new ArrayList<>(); + } + this.functions$value.add(func); + this.functions$set = true; + return self(); + } + } +} diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java index e651a6a10..52b6531ab 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java @@ -35,7 +35,6 @@ public class HybridSearchReq private String collectionName; private List partitionNames; private List searchRequests; - private CreateCollectionReq.Function ranker; @Builder.Default @Deprecated private int topK = 0; // 
deprecated, replaced by "limit" @@ -51,6 +50,11 @@ public class HybridSearchReq private String groupByFieldName; private Integer groupSize; private Boolean strictGroupSize; + @Deprecated + private CreateCollectionReq.Function ranker; + // Milvus v2.6.1 supports multi-rankers. The "ranker" still works. It is recommended + // to use functionScore even if you have only one ranker. Setting both is not allowed. + private FunctionScore functionScore; public static abstract class HybridSearchReqBuilder> { // topK is deprecated, topK and limit must be the same value diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java index f7dd43080..f70d677a9 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java @@ -67,7 +67,11 @@ public class SearchReq { private String groupByFieldName; private Integer groupSize; private Boolean strictGroupSize; + @Deprecated private CreateCollectionReq.Function ranker; + // Milvus v2.6.1 supports multi-rankers. The "ranker" still works. It is recommended + // to use functionScore even if you have only one ranker. Setting both is not allowed. + private FunctionScore functionScore; // Expression template, to improve expression parsing performance in complicated list // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......] 
diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java index 705b60104..5af8ff7b9 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java @@ -20,6 +20,8 @@ package io.milvus.v2.utils; import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; import com.google.gson.reflect.TypeToken; import com.google.protobuf.ByteString; import io.milvus.common.utils.GTsDict; @@ -33,6 +35,7 @@ import io.milvus.v2.exception.MilvusClientException; import io.milvus.v2.service.collection.request.CreateCollectionReq; import io.milvus.v2.service.vector.request.*; +import io.milvus.v2.service.vector.request.FunctionScore; import io.milvus.v2.service.vector.request.data.*; import io.milvus.v2.service.vector.request.ranker.RRFRanker; import io.milvus.v2.service.vector.request.ranker.WeightedRanker; @@ -283,8 +286,14 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) { // set ranker, support reranking search result from v2.6.1 CreateCollectionReq.Function ranker = request.getRanker(); - if (ranker != null) { - builder.setFunctionScore(convertFunctionScore(ranker)); + io.milvus.v2.service.vector.request.FunctionScore functionScore = request.getFunctionScore(); + if (ranker != null && functionScore != null) { + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Not allow to set both ranker and functionScore."); + } + if (functionScore != null) { + builder.setFunctionScore(convertFunctionScore(functionScore)); + } else if (ranker != null) { + builder.setFunctionScore(convertOneFunction(ranker)); } return builder.build(); @@ -480,24 +489,28 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ builder.addRequests(searchRequest); } - // set ranker - CreateCollectionReq.Function ranker = request.getRanker(); - if (ranker == null) { - 
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Ranker is null."); - } - Map props = new HashMap<>(); props.put(Constant.LIMIT, String.valueOf(request.getLimit())); props.put(Constant.ROUND_DECIMAL, String.valueOf(request.getRoundDecimal())); props.put(Constant.OFFSET, String.valueOf(request.getOffset())); - if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) { - // old logic for RRF/Weighted ranker - Map params = ranker.getParams(); - props.putAll(params); - } else { - // new logic for Decay/Model ranker - builder.setFunctionScore(convertFunctionScore(ranker)); + // set ranker + CreateCollectionReq.Function ranker = request.getRanker(); + io.milvus.v2.service.vector.request.FunctionScore functionScore = request.getFunctionScore(); + if (ranker != null && functionScore != null) { + throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Not allow to set both ranker and functionScore."); + } + if (functionScore != null) { + builder.setFunctionScore(convertFunctionScore(functionScore)); + } else if (ranker != null) { + if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) { + // old logic for RRF/Weighted ranker + Map params = ranker.getParams(); + props.putAll(params); + } else { + // new logic for Decay/Model ranker + builder.setFunctionScore(convertOneFunction(ranker)); + } } List propertiesList = ParamUtils.AssembleKvPair(props); @@ -545,15 +558,42 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ return builder.build(); } - private FunctionScore convertFunctionScore(CreateCollectionReq.Function function) { - FunctionSchema schema = FunctionSchema.newBuilder() + private FunctionSchema convertFunctionSchema(CreateCollectionReq.Function function) { + Map params = function.getParams(); + // FunctionSchema type keyword is "reranker", old RRF/Weighted ranker type keyword is "strategy" + // FunctionSchema parameters are flat, old RRF/Weighted parameters are wrapped by "params" + if 
(function instanceof RRFRanker || function instanceof WeightedRanker) { + String name = (function instanceof RRFRanker) ? "rrf" : "weighted"; + params.put("reranker", name); + JsonObject inner = JsonParser.parseString(params.get("params")).getAsJsonObject(); + for (String key : inner.keySet()) { + params.put(key, inner.get(key).toString()); + } + params.remove("strategy"); + params.remove("params"); + } + return FunctionSchema.newBuilder() .setName(function.getName()) .setDescription(function.getDescription()) .setType(FunctionType.forNumber(function.getFunctionType().getCode())) .addAllInputFieldNames(function.getInputFieldNames()) - .addAllParams(ParamUtils.AssembleKvPair(function.getParams())) + .addAllParams(ParamUtils.AssembleKvPair(params)) .build(); - return FunctionScore.newBuilder().addFunctions(schema).build(); + } + + private io.milvus.grpc.FunctionScore convertOneFunction(CreateCollectionReq.Function function) { + FunctionSchema schema = convertFunctionSchema(function); + return io.milvus.grpc.FunctionScore.newBuilder().addFunctions(schema).build(); + } + + private io.milvus.grpc.FunctionScore convertFunctionScore(FunctionScore functionScore) { + io.milvus.grpc.FunctionScore.Builder builder = io.milvus.grpc.FunctionScore.newBuilder(); + for (CreateCollectionReq.Function function : functionScore.getFunctions()) { + FunctionSchema schema = convertFunctionSchema(function); + builder.addFunctions(schema); + } + builder.addAllParams(ParamUtils.AssembleKvPair(functionScore.getParams())); + return builder.build(); } public String getExprById(String primaryFieldName, List ids) { diff --git a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java index 0289436ec..179f2e7f0 100644 --- a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java +++ b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java @@ -1074,8 +1074,8 @@ void 
testHybridSearch() { // prepare sub requests int nq = 5; int topk = 10; - Function genRequestFunc = - sparseCount -> { + Function, HybridSearchReq> genRequestFunc = + config -> { List floatVectors = new ArrayList<>(); List binaryVectors = new ArrayList<>(); List sparseVectors = new ArrayList<>(); @@ -1083,6 +1083,7 @@ void testHybridSearch() { floatVectors.add(new FloatVec(utils.generateFloatVector())); binaryVectors.add(new BinaryVec(utils.generateBinaryVector())); } + int sparseCount = (Integer)config.get("sparseCount"); for (int i = 0; i < sparseCount; i++) { sparseVectors.add(new SparseFloatVec(utils.generateSparseVector())); } @@ -1105,23 +1106,40 @@ void testHybridSearch() { .limit(7) .build()); - return HybridSearchReq.builder() - .collectionName(randomCollectionName) - .searchRequests(searchRequests) - .ranker(RRFRanker.builder().k(20).build()) - .limit(topk) - .consistencyLevel(ConsistencyLevel.BOUNDED) - .build(); + CreateCollectionReq.Function ranker = WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build(); + boolean useFunctionScore = (Boolean)config.get("useFunctionScore"); + if (useFunctionScore) { + return HybridSearchReq.builder() + .collectionName(randomCollectionName) + .searchRequests(searchRequests) + .functionScore(FunctionScore.builder().addFunction(ranker).build()) + .limit(topk) + .consistencyLevel(ConsistencyLevel.BOUNDED) + .build(); + } else { + return HybridSearchReq.builder() + .collectionName(randomCollectionName) + .searchRequests(searchRequests) + .ranker(RRFRanker.builder().k(20).build()) + .limit(topk) + .consistencyLevel(ConsistencyLevel.BOUNDED) + .build(); + } }; + Map config = new HashMap<>(); + config.put("sparseCount", 0); + config.put("useFunctionScore", false); // search with an empty nq, return error - Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(0))); + Assertions.assertThrows(MilvusClientException.class, 
()->client.hybridSearch(genRequestFunc.apply(config))); // unequal nq, return error - Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(1))); + config.put("sparseCount", 1); + Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(config))); // search on empty collection, no result returned - SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(nq)); + config.put("sparseCount", nq); + SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(config)); List> searchResults = searchResp.getSearchResults(); Assertions.assertEquals(nq, searchResults.size()); for (List result : searchResults) { @@ -1142,7 +1160,8 @@ void testHybridSearch() { Assertions.assertEquals(count, rowCount); // search again, there are results - searchResp = client.hybridSearch(genRequestFunc.apply(nq)); + config.put("useFunctionScore", true); + searchResp = client.hybridSearch(genRequestFunc.apply(config)); searchResults = searchResp.getSearchResults(); Assertions.assertEquals(nq, searchResults.size()); for (int i = 0; i < nq; i++) {