diff --git a/docker-compose.yml b/docker-compose.yml index ab4ede9be..8ccda4dff 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -32,7 +32,7 @@ services: standalone: container_name: milvus-javasdk-test-standalone - image: milvusdb/milvus:v2.6.0-rc1 + image: milvusdb/milvus:v2.6.0 command: ["milvus", "run", "standalone"] environment: ETCD_ENDPOINTS: etcd:2379 @@ -77,7 +77,7 @@ services: standaloneslave: container_name: milvus-javasdk-test-slave-standalone - image: milvusdb/milvus:v2.6.0-rc1 + image: milvusdb/milvus:v2.6.0 command: ["milvus", "run", "standalone"] environment: ETCD_ENDPOINTS: etcdslave:2379 diff --git a/examples/src/main/java/io/milvus/v2/HybridSearchExample.java b/examples/src/main/java/io/milvus/v2/HybridSearchExample.java index 31f2f98dd..8c510ba90 100644 --- a/examples/src/main/java/io/milvus/v2/HybridSearchExample.java +++ b/examples/src/main/java/io/milvus/v2/HybridSearchExample.java @@ -213,7 +213,7 @@ private void hybridSearch() { HybridSearchReq hybridSearchReq = HybridSearchReq.builder() .collectionName(COLLECTION_NAME) .searchRequests(searchRequests) - .ranker(new WeightedRanker(Arrays.asList(0.2f, 0.5f, 0.6f))) + .ranker(WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build()) .limit(5) .consistencyLevel(ConsistencyLevel.BOUNDED) .build(); diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java index 624f2cf8d..e651a6a10 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java @@ -20,8 +20,7 @@ package io.milvus.v2.service.vector.request; import io.milvus.v2.common.ConsistencyLevel; -import io.milvus.v2.service.collection.request.LoadCollectionReq; -import io.milvus.v2.service.vector.request.ranker.BaseRanker; +import 
io.milvus.v2.service.collection.request.CreateCollectionReq; import lombok.Builder; import lombok.Data; import lombok.experimental.SuperBuilder; @@ -36,7 +35,7 @@ public class HybridSearchReq private String collectionName; private List partitionNames; private List searchRequests; - private BaseRanker ranker; + private CreateCollectionReq.Function ranker; @Builder.Default @Deprecated private int topK = 0; // deprecated, replaced by "limit" diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java index 4c50ff568..f7dd43080 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java @@ -21,6 +21,7 @@ import io.milvus.v2.common.ConsistencyLevel; import io.milvus.v2.common.IndexParam; +import io.milvus.v2.service.collection.request.CreateCollectionReq; import io.milvus.v2.service.vector.request.data.BaseVector; import lombok.Builder; @@ -66,6 +67,7 @@ public class SearchReq { private String groupByFieldName; private Integer groupSize; private Boolean strictGroupSize; + private CreateCollectionReq.Function ranker; // Expression template, to improve expression parsing performance in complicated list // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......] diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/BaseRanker.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/BaseRanker.java deleted file mode 100644 index dd6b907a9..000000000 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/BaseRanker.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package io.milvus.v2.service.vector.request.ranker; - -import java.util.Map; - -public abstract class BaseRanker { - public abstract Map getProperties(); -} diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/DecayRanker.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/DecayRanker.java new file mode 100644 index 000000000..fa2dc0cc3 --- /dev/null +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/DecayRanker.java @@ -0,0 +1,52 @@ +package io.milvus.v2.service.vector.request.ranker; + +import io.milvus.common.clientenum.FunctionType; +import io.milvus.v2.service.collection.request.CreateCollectionReq; +import lombok.Builder; +import lombok.experimental.SuperBuilder; + +import java.util.Map; + +/** + * The Decay reranking strategy, which adjusts search rankings based on numeric field values.
+ * Read the doc for more info: https://milvus.io/docs/decay-ranker-overview.md + * + * You also can declare a decay ranker by Function + * CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder() + * .functionType(FunctionType.RERANK) + * .name("time_decay") + * .description("time decay") + * .inputFieldNames(Collections.singletonList("timestamp")) + * .param("reranker", "decay") + * .param("function", "gauss") + * .param("origin", "1000") + * .param("scale", "10000") + * .param("offset", "24") + * .param("decay", "0.5") + * .build(); + */ +@SuperBuilder +public class DecayRanker extends CreateCollectionReq.Function { + @Builder.Default + private String function = "gauss"; + private Number origin; + private Number scale; + + public FunctionType getFunctionType() { + return FunctionType.RERANK; + } + + public Map getParams() { + // the parent params might contain "offset" and "decay" + Map props = super.getParams(); + props.put("reranker", "decay"); + props.put("function", function); // "gauss", "exp", or "linear" + if (origin != null) { + props.put("origin", origin.toString()); + } + if (scale != null) { + props.put("scale", scale.toString()); + } + return props; + } +} diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/ModelRanker.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/ModelRanker.java new file mode 100644 index 000000000..2084d19bc --- /dev/null +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/ModelRanker.java @@ -0,0 +1,55 @@ +package io.milvus.v2.service.vector.request.ranker; + +import com.google.gson.JsonArray; +import io.milvus.common.clientenum.FunctionType; +import io.milvus.v2.service.collection.request.CreateCollectionReq; +import lombok.Builder; +import lombok.experimental.SuperBuilder; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * The Model reranking strategy, which transforms Milvus search by integrating 
advanced language models + * that understand semantic relationships between queries and documents. + * Read the doc for more info: https://milvus.io/docs/model-ranker-overview.md + * + * You also can declare a model ranker by Function + * CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder() + * .functionType(FunctionType.RERANK) + * .name("semantic_ranker") + * .description("semantic ranker") + * .inputFieldNames(Collections.singletonList("document")) + * .param("reranker", "model") + * .param("provider", "tei") + * .param("queries", "[\"machine learning for time series\"]") + * .param("endpoint", "http://model-service:8080") + * .build(); + */ +@SuperBuilder +public class ModelRanker extends CreateCollectionReq.Function { + @Builder.Default + private String provider = "tei"; + @Builder.Default + private List queries = new ArrayList<>(); + private String endpoint; + + public FunctionType getFunctionType() { + return FunctionType.RERANK; + } + + public Map getParams() { + // start from the params already set on the parent Function (e.g. via .param(...)) + Map props = super.getParams(); + props.put("reranker", "model"); + props.put("provider", provider); // "tei" or "vllm" + JsonArray json = new JsonArray(); + queries.forEach(json::add); + props.put("queries", json.toString()); + if (endpoint != null) { + props.put("endpoint", endpoint); + } + return props; + } +} diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/RRFRanker.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/RRFRanker.java index b35f7a834..68151b9af 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/RRFRanker.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/RRFRanker.java @@ -20,26 +20,49 @@ package io.milvus.v2.service.vector.request.ranker; import com.google.gson.JsonObject; +import io.milvus.common.clientenum.FunctionType; +import io.milvus.v2.service.collection.request.CreateCollectionReq; +import
lombok.Builder; +import lombok.experimental.SuperBuilder; import java.util.HashMap; import java.util.Map; /** * The RRF reranking strategy, which merges results from multiple searches, favoring items that consistently appear. + * Read the doc for more info: https://milvus.io/docs/rrf-ranker.md + * + * Note: In v2.6, the Function and Rerank have been unified to support more rerank types: decay and model ranker + * https://milvus.io/docs/decay-ranker-overview.md + * https://milvus.io/docs/model-ranker-overview.md + * So we have to inherit the BaseRanker from Function, this change will lead to uncomfortable issues with + * RRFRanker/WeightedRanker in some users client code. We will mention it in release note. + * * In old client code, to declare an RRFRanker: + * * RRFRanker ranker = new RRFRanker(20) + * * After this change, the client code should be changed accordingly: + * * RRFRanker ranker = RRFRanker.builder().k(20).build() + * + * You also can declare a rrf ranker by Function + * CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder() + * .functionType(FunctionType.RERANK) + * .param("strategy", "rrf") + * .param("params", "{\"k\": 60}") + * .build(); */ -public class RRFRanker extends BaseRanker { +@SuperBuilder +public class RRFRanker extends CreateCollectionReq.Function { + @Builder.Default private int k = 60; - public RRFRanker(int k) { - this.k = k; + public FunctionType getFunctionType() { + return FunctionType.RERANK; } - @Override - public Map getProperties() { + public Map getParams() { JsonObject params = new JsonObject(); params.addProperty("k", this.k); - Map props = new HashMap<>(); + Map props = super.getParams(); props.put("strategy", "rrf"); props.put("params", params.toString()); return props; diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/WeightedRanker.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/WeightedRanker.java index 191c728f5..eedd5fcb9 100644 ---
a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/WeightedRanker.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/WeightedRanker.java @@ -20,8 +20,13 @@ package io.milvus.v2.service.vector.request.ranker; import com.google.gson.JsonObject; +import io.milvus.common.clientenum.FunctionType; import io.milvus.common.utils.JsonUtils; +import io.milvus.v2.service.collection.request.CreateCollectionReq; +import lombok.Builder; +import lombok.experimental.SuperBuilder; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -29,20 +34,39 @@ /** * The Average Weighted Scoring reranking strategy, which prioritizes vectors based on relevance, * averaging their significance. + * Read the doc for more info: https://milvus.io/docs/weighted-ranker.md + * + * Note: In v2.6, the Function and Rerank have been unified to support more rerank types: decay and model ranker + * https://milvus.io/docs/decay-ranker-overview.md + * https://milvus.io/docs/model-ranker-overview.md + * So we have to inherit the BaseRanker from Function, this change will lead to uncomfortable issues with + * RRFRanker/WeightedRanker in some users client code. We will mention it in release note. 
+ * In old client code, to declare a WeightedRanker: + * WeightedRanker ranker = new WeightedRanker(Arrays.asList(0.2f, 0.5f, 0.6f)) + * After this change, the client code should be changed accordingly: + * WeightedRanker ranker = WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build() + * + * You also can declare a weighted ranker by Function + * CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder() + * .functionType(FunctionType.RERANK) + * .param("strategy", "weighted") + * .param("params", "{\"weights\": [0.4, 0.6]}") + * .build(); */ -public class WeightedRanker extends BaseRanker { - private List weights; +@SuperBuilder +public class WeightedRanker extends CreateCollectionReq.Function { + @Builder.Default + private List weights = new ArrayList<>(); - public WeightedRanker(List weights) { - this.weights = weights; + public FunctionType getFunctionType() { + return FunctionType.RERANK; } - @Override - public Map getProperties() { + public Map getParams() { JsonObject params = new JsonObject(); params.add("weights", JsonUtils.toJsonTree(this.weights).getAsJsonArray()); - Map props = new HashMap<>(); + Map props = super.getParams(); props.put("strategy", "weighted"); props.put("params", params.toString()); return props; diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java index 7bec544fa..7f6e8d37d 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java @@ -31,9 +31,11 @@ import io.milvus.param.ParamUtils; import io.milvus.v2.exception.ErrorCode; import io.milvus.v2.exception.MilvusClientException; +import io.milvus.v2.service.collection.request.CreateCollectionReq; import io.milvus.v2.service.vector.request.*; -import io.milvus.v2.service.vector.request.ranker.BaseRanker; import io.milvus.v2.service.vector.request.data.*; +import
io.milvus.v2.service.vector.request.ranker.RRFRanker; +import io.milvus.v2.service.vector.request.ranker.WeightedRanker; import lombok.NonNull; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; @@ -279,6 +281,12 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) { builder.setConsistencyLevelValue(request.getConsistencyLevel().getCode()); } + // set ranker, support reranking search result from v2.6.1 + CreateCollectionReq.Function ranker = request.getRanker(); + if (ranker != null) { + builder.setFunctionScore(convertFunctionScore(ranker)); + } + return builder.build(); } @@ -473,16 +481,25 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ } // set ranker - BaseRanker ranker = request.getRanker(); - if (request.getRanker() == null) { + CreateCollectionReq.Function ranker = request.getRanker(); + if (ranker == null) { throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Ranker is null."); } - // topK value is deprecated, always use "limit" to set the topK - Map props = ranker.getProperties(); + Map props = new HashMap<>(); props.put(Constant.LIMIT, String.valueOf(request.getLimit())); props.put(Constant.ROUND_DECIMAL, String.valueOf(request.getRoundDecimal())); props.put(Constant.OFFSET, String.valueOf(request.getOffset())); + + if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) { + // old logic for RRF/Weighted ranker + Map params = ranker.getParams(); + props.putAll(params); + } else { + // new logic for Decay/Model ranker + builder.setFunctionScore(convertFunctionScore(ranker)); + } + List propertiesList = ParamUtils.AssembleKvPair(props); if (CollectionUtils.isNotEmpty(propertiesList)) { propertiesList.forEach(builder::addRankParams); @@ -528,6 +545,17 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ return builder.build(); } + private FunctionScore convertFunctionScore(CreateCollectionReq.Function 
function) { + FunctionSchema schema = FunctionSchema.newBuilder() + .setName(function.getName()) + .setDescription(function.getDescription()) + .setType(FunctionType.forNumber(function.getFunctionType().getCode())) + .addAllInputFieldNames(function.getInputFieldNames()) + .addAllParams(ParamUtils.AssembleKvPair(function.getParams())) + .build(); + return FunctionScore.newBuilder().addFunctions(schema).build(); + } + public String getExprById(String primaryFieldName, List ids) { StringBuilder sb = new StringBuilder(); sb.append(primaryFieldName).append(" in ["); diff --git a/sdk-core/src/test/java/io/milvus/TestUtils.java b/sdk-core/src/test/java/io/milvus/TestUtils.java index 401e278d9..b3d4de70f 100644 --- a/sdk-core/src/test/java/io/milvus/TestUtils.java +++ b/sdk-core/src/test/java/io/milvus/TestUtils.java @@ -11,7 +11,7 @@ public class TestUtils { private int dimension = 256; private static final Random RANDOM = new Random(); - public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.0-rc1"; + public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.0"; public TestUtils(int dimension) { this.dimension = dimension; diff --git a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java index 993a3f0cc..57007da5b 100644 --- a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java +++ b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java @@ -1047,7 +1047,7 @@ void testHybridSearch() { return HybridSearchReq.builder() .collectionName(randomCollectionName) .searchRequests(searchRequests) - .ranker(new RRFRanker(20)) + .ranker(RRFRanker.builder().k(20).build()) .limit(topk) .consistencyLevel(ConsistencyLevel.BOUNDED) .build(); @@ -2853,7 +2853,7 @@ void testConsistencyLevel() throws InterruptedException { .databaseName(dbName) .collectionName(randomCollectionName) .searchRequests(Collections.singletonList(subReq)) 
- .ranker(new RRFRanker(20)) + .ranker(RRFRanker.builder().k(20).build()) .limit(5) .build()); List> oneResult = searchResp.getSearchResults();