Skip to content

Commit bec779e

Browse files
authored
Unify Function and Rerank (#1515)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent e7f829e commit bec779e

12 files changed

Lines changed: 210 additions & 53 deletions

File tree

docker-compose.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ services:
3232

3333
standalone:
3434
container_name: milvus-javasdk-test-standalone
35-
image: milvusdb/milvus:v2.6.0-rc1
35+
image: milvusdb/milvus:v2.6.0
3636
command: ["milvus", "run", "standalone"]
3737
environment:
3838
ETCD_ENDPOINTS: etcd:2379
@@ -77,7 +77,7 @@ services:
7777

7878
standaloneslave:
7979
container_name: milvus-javasdk-test-slave-standalone
80-
image: milvusdb/milvus:v2.6.0-rc1
80+
image: milvusdb/milvus:v2.6.0
8181
command: ["milvus", "run", "standalone"]
8282
environment:
8383
ETCD_ENDPOINTS: etcdslave:2379

examples/src/main/java/io/milvus/v2/HybridSearchExample.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ private void hybridSearch() {
213213
HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
214214
.collectionName(COLLECTION_NAME)
215215
.searchRequests(searchRequests)
216-
.ranker(new WeightedRanker(Arrays.asList(0.2f, 0.5f, 0.6f)))
216+
.ranker(WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build())
217217
.limit(5)
218218
.consistencyLevel(ConsistencyLevel.BOUNDED)
219219
.build();

sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
package io.milvus.v2.service.vector.request;
2121

2222
import io.milvus.v2.common.ConsistencyLevel;
23-
import io.milvus.v2.service.collection.request.LoadCollectionReq;
24-
import io.milvus.v2.service.vector.request.ranker.BaseRanker;
23+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
2524
import lombok.Builder;
2625
import lombok.Data;
2726
import lombok.experimental.SuperBuilder;
@@ -36,7 +35,7 @@ public class HybridSearchReq
3635
private String collectionName;
3736
private List<String> partitionNames;
3837
private List<AnnSearchReq> searchRequests;
39-
private BaseRanker ranker;
38+
private CreateCollectionReq.Function ranker;
4039
@Builder.Default
4140
@Deprecated
4241
private int topK = 0; // deprecated, replaced by "limit"

sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import io.milvus.v2.common.ConsistencyLevel;
2323
import io.milvus.v2.common.IndexParam;
24+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
2425
import io.milvus.v2.service.vector.request.data.BaseVector;
2526

2627
import lombok.Builder;
@@ -66,6 +67,7 @@ public class SearchReq {
6667
private String groupByFieldName;
6768
private Integer groupSize;
6869
private Boolean strictGroupSize;
70+
private CreateCollectionReq.Function ranker;
6971

7072
// Expression template, to improve expression parsing performance in complicated list
7173
// Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......]

sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/BaseRanker.java

Lines changed: 0 additions & 26 deletions
This file was deleted.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package io.milvus.v2.service.vector.request.ranker;
2+
3+
import io.milvus.common.clientenum.FunctionType;
4+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
5+
import lombok.Builder;
6+
import lombok.experimental.SuperBuilder;
7+
8+
import java.util.Map;
9+
10+
/**
11+
* The Decay reranking strategy, which adjusts search rankings based on numeric field values.
12+
* Read the doc for more info: https://milvus.io/docs/decay-ranker-overview.md
13+
*
14+
* You also can declare a decay ranker by Function
15+
* CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
16+
* .functionType(FunctionType.RERANK)
17+
* .name("time_decay")
18+
* .description("time decay")
19+
* .inputFieldNames(Collections.singletonList("timestamp"))
20+
* .param("reranker", "decay")
21+
* .param("function", "gauss")
22+
* .param("origin", "1000")
23+
* .param("scale", "10000")
24+
* .param("offset", "24")
25+
* .param("decay", "0.5")
26+
* .build();
27+
*/
28+
@SuperBuilder
29+
public class DecayRanker extends CreateCollectionReq.Function {
30+
@Builder.Default
31+
private String function = "gauss";
32+
private Number origin;
33+
private Number scale;
34+
35+
public FunctionType getFunctionType() {
36+
return FunctionType.RERANK;
37+
}
38+
39+
public Map<String, String> getParams() {
40+
// the parent params might contain "offset" and "decay"
41+
Map<String, String> props = super.getParams();
42+
props.put("reranker", "decay");
43+
props.put("function", function); // "gauss", "exp", or "linear"
44+
if (origin != null) {
45+
props.put("origin", origin.toString());
46+
}
47+
if (scale != null) {
48+
props.put("scale", scale.toString());
49+
}
50+
return props;
51+
}
52+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package io.milvus.v2.service.vector.request.ranker;
2+
3+
import com.google.gson.JsonArray;
4+
import io.milvus.common.clientenum.FunctionType;
5+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
6+
import lombok.Builder;
7+
import lombok.experimental.SuperBuilder;
8+
9+
import java.util.ArrayList;
10+
import java.util.List;
11+
import java.util.Map;
12+
13+
/**
14+
* The Model reranking strategy, which transforms Milvus search by integrating advanced language models
15+
* that understand semantic relationships between queries and documents.
16+
* Read the doc for more info: https://milvus.io/docs/model-ranker-overview.md
17+
*
18+
* You also can declare a model ranker by Function
19+
* CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
20+
* .functionType(FunctionType.RERANK)
21+
* .name("semantic_ranker")
22+
* .description("semantic ranker")
23+
* .inputFieldNames(Collections.singletonList("document"))
24+
* .param("reranker", "model")
25+
* .param("provider", "tei")
26+
* .param("queries", "[\"machine learning for time series\"]")
27+
* .param("endpoint", "http://model-service:8080")
28+
* .build();
29+
*/
30+
@SuperBuilder
31+
public class ModelRanker extends CreateCollectionReq.Function {
32+
@Builder.Default
33+
private String provider = "tei";
34+
@Builder.Default
35+
private List<String> queries = new ArrayList<>();
36+
private String endpoint;
37+
38+
public FunctionType getFunctionType() {
39+
return FunctionType.RERANK;
40+
}
41+
42+
public Map<String, String> getParams() {
43+
// start from the parent params so entries already set there are kept (unless overwritten below)
44+
Map<String, String> props = super.getParams();
45+
props.put("reranker", "model");
46+
props.put("provider", provider); // "tei" or "vllm"
47+
JsonArray json = new JsonArray();
48+
queries.forEach(json::add);
49+
props.put("queries", json.toString());
50+
if (endpoint != null) {
51+
props.put("endpoint", endpoint);
52+
}
53+
return props;
54+
}
55+
}

sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/RRFRanker.java

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,49 @@
2020
package io.milvus.v2.service.vector.request.ranker;
2121

2222
import com.google.gson.JsonObject;
23+
import io.milvus.common.clientenum.FunctionType;
24+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
25+
import lombok.Builder;
26+
import lombok.experimental.SuperBuilder;
2327

2428
import java.util.HashMap;
2529
import java.util.Map;
2630

2731
/**
2832
* The RRF reranking strategy, which merges results from multiple searches, favoring items that consistently appear.
33+
* Read the doc for more info: https://milvus.io/docs/rrf-ranker.md
34+
*
35+
* Note: In v2.6, the Function and Rerank have been unified to support more rerank types: decay and model ranker
36+
* https://milvus.io/docs/decay-ranker-overview.md
37+
* https://milvus.io/docs/model-ranker-overview.md
38+
* So the rankers now inherit from Function instead of BaseRanker. This change breaks existing client code that
39+
* constructs RRFRanker/WeightedRanker with "new". We will mention it in the release notes.
40+
* * In old client code, to declare an RRFRanker:
41+
* * RRFRanker ranker = new RRFRanker(20)
42+
* * After this change, the client code should be changed accordingly:
43+
* * RRFRanker ranker = RRFRanker.builder().k(20).build()
44+
*
45+
* You also can declare a rrf ranker by Function
46+
* CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
47+
* .functionType(FunctionType.RERANK)
48+
* .param("strategy", "rrf")
49+
* .param("params", "{\"k\": 60}")
50+
* .build();
2951
*/
30-
public class RRFRanker extends BaseRanker {
52+
@SuperBuilder
53+
public class RRFRanker extends CreateCollectionReq.Function {
54+
@Builder.Default
3155
private int k = 60;
3256

33-
public RRFRanker(int k) {
34-
this.k = k;
57+
public FunctionType getFunctionType() {
58+
return FunctionType.RERANK;
3559
}
3660

37-
@Override
38-
public Map<String, String> getProperties() {
61+
public Map<String, String> getParams() {
3962
JsonObject params = new JsonObject();
4063
params.addProperty("k", this.k);
4164

42-
Map<String, String> props = new HashMap<>();
65+
Map<String, String> props = super.getParams();
4366
props.put("strategy", "rrf");
4467
props.put("params", params.toString());
4568
return props;

sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/WeightedRanker.java

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,53 @@
2020
package io.milvus.v2.service.vector.request.ranker;
2121

2222
import com.google.gson.JsonObject;
23+
import io.milvus.common.clientenum.FunctionType;
2324
import io.milvus.common.utils.JsonUtils;
25+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
26+
import lombok.Builder;
27+
import lombok.experimental.SuperBuilder;
2428

29+
import java.util.ArrayList;
2530
import java.util.HashMap;
2631
import java.util.List;
2732
import java.util.Map;
2833

2934
/**
3035
* The Average Weighted Scoring reranking strategy, which prioritizes vectors based on relevance,
3136
* averaging their significance.
37+
* Read the doc for more info: https://milvus.io/docs/weighted-ranker.md
38+
*
39+
* Note: In v2.6, the Function and Rerank have been unified to support more rerank types: decay and model ranker
40+
* https://milvus.io/docs/decay-ranker-overview.md
41+
* https://milvus.io/docs/model-ranker-overview.md
42+
* So the rankers now inherit from Function instead of BaseRanker. This change breaks existing client code that
43+
* constructs RRFRanker/WeightedRanker with "new". We will mention it in the release notes.
44+
* In old client code, to declare a WeightedRanker:
45+
* WeightedRanker ranker = new WeightedRanker(Arrays.asList(0.2f, 0.5f, 0.6f))
46+
* After this change, the client code should be changed accordingly:
47+
* WeightedRanker ranker = WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build()
48+
*
49+
* You can also declare a weighted ranker by Function
50+
* CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
51+
* .functionType(FunctionType.RERANK)
52+
* .param("strategy", "weighted")
53+
* .param("params", "{\"weights\": [0.4, 0.6]}")
54+
* .build();
3255
*/
33-
public class WeightedRanker extends BaseRanker {
34-
private List<Float> weights;
56+
@SuperBuilder
57+
public class WeightedRanker extends CreateCollectionReq.Function {
58+
@Builder.Default
59+
private List<Float> weights = new ArrayList<>();
3560

36-
public WeightedRanker(List<Float> weights) {
37-
this.weights = weights;
61+
public FunctionType getFunctionType() {
62+
return FunctionType.RERANK;
3863
}
3964

40-
@Override
41-
public Map<String, String> getProperties() {
65+
public Map<String, String> getParams() {
4266
JsonObject params = new JsonObject();
4367
params.add("weights", JsonUtils.toJsonTree(this.weights).getAsJsonArray());
4468

45-
Map<String, String> props = new HashMap<>();
69+
Map<String, String> props = super.getParams();
4670
props.put("strategy", "weighted");
4771
props.put("params", params.toString());
4872
return props;

sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131
import io.milvus.param.ParamUtils;
3232
import io.milvus.v2.exception.ErrorCode;
3333
import io.milvus.v2.exception.MilvusClientException;
34+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
3435
import io.milvus.v2.service.vector.request.*;
35-
import io.milvus.v2.service.vector.request.ranker.BaseRanker;
3636
import io.milvus.v2.service.vector.request.data.*;
37+
import io.milvus.v2.service.vector.request.ranker.RRFRanker;
38+
import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
3739
import lombok.NonNull;
3840
import org.apache.commons.collections4.CollectionUtils;
3941
import org.apache.commons.lang3.StringUtils;
@@ -279,6 +281,12 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) {
279281
builder.setConsistencyLevelValue(request.getConsistencyLevel().getCode());
280282
}
281283

284+
// set ranker, support reranking search result from v2.6.1
285+
CreateCollectionReq.Function ranker = request.getRanker();
286+
if (ranker != null) {
287+
builder.setFunctionScore(convertFunctionScore(ranker));
288+
}
289+
282290
return builder.build();
283291
}
284292

@@ -473,16 +481,25 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ
473481
}
474482

475483
// set ranker
476-
BaseRanker ranker = request.getRanker();
477-
if (request.getRanker() == null) {
484+
CreateCollectionReq.Function ranker = request.getRanker();
485+
if (ranker == null) {
478486
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Ranker is null.");
479487
}
480488

481-
// topK value is deprecated, always use "limit" to set the topK
482-
Map<String, String> props = ranker.getProperties();
489+
Map<String, String> props = new HashMap<>();
483490
props.put(Constant.LIMIT, String.valueOf(request.getLimit()));
484491
props.put(Constant.ROUND_DECIMAL, String.valueOf(request.getRoundDecimal()));
485492
props.put(Constant.OFFSET, String.valueOf(request.getOffset()));
493+
494+
if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) {
495+
// old logic for RRF/Weighted ranker
496+
Map<String, String> params = ranker.getParams();
497+
props.putAll(params);
498+
} else {
499+
// new logic for Decay/Model ranker
500+
builder.setFunctionScore(convertFunctionScore(ranker));
501+
}
502+
486503
List<KeyValuePair> propertiesList = ParamUtils.AssembleKvPair(props);
487504
if (CollectionUtils.isNotEmpty(propertiesList)) {
488505
propertiesList.forEach(builder::addRankParams);
@@ -528,6 +545,17 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ
528545
return builder.build();
529546
}
530547

548+
private FunctionScore convertFunctionScore(CreateCollectionReq.Function function) {
549+
FunctionSchema schema = FunctionSchema.newBuilder()
550+
.setName(function.getName())
551+
.setDescription(function.getDescription())
552+
.setType(FunctionType.forNumber(function.getFunctionType().getCode()))
553+
.addAllInputFieldNames(function.getInputFieldNames())
554+
.addAllParams(ParamUtils.AssembleKvPair(function.getParams()))
555+
.build();
556+
return FunctionScore.newBuilder().addFunctions(schema).build();
557+
}
558+
531559
public String getExprById(String primaryFieldName, List<?> ids) {
532560
StringBuilder sb = new StringBuilder();
533561
sb.append(primaryFieldName).append(" in [");

0 commit comments

Comments
 (0)