Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ services:

standalone:
container_name: milvus-javasdk-test-standalone
image: milvusdb/milvus:v2.6.0-rc1
image: milvusdb/milvus:v2.6.0
command: ["milvus", "run", "standalone"]
environment:
ETCD_ENDPOINTS: etcd:2379
Expand Down Expand Up @@ -77,7 +77,7 @@ services:

standaloneslave:
container_name: milvus-javasdk-test-slave-standalone
image: milvusdb/milvus:v2.6.0-rc1
image: milvusdb/milvus:v2.6.0
command: ["milvus", "run", "standalone"]
environment:
ETCD_ENDPOINTS: etcdslave:2379
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ private void hybridSearch() {
HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
.collectionName(COLLECTION_NAME)
.searchRequests(searchRequests)
.ranker(new WeightedRanker(Arrays.asList(0.2f, 0.5f, 0.6f)))
.ranker(WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build())
.limit(5)
.consistencyLevel(ConsistencyLevel.BOUNDED)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
package io.milvus.v2.service.vector.request;

import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.service.collection.request.LoadCollectionReq;
import io.milvus.v2.service.vector.request.ranker.BaseRanker;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import lombok.Builder;
import lombok.Data;
import lombok.experimental.SuperBuilder;
Expand All @@ -36,7 +35,7 @@ public class HybridSearchReq
private String collectionName;
private List<String> partitionNames;
private List<AnnSearchReq> searchRequests;
private BaseRanker ranker;
private CreateCollectionReq.Function ranker;
@Builder.Default
@Deprecated
private int topK = 0; // deprecated, replaced by "limit"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.vector.request.data.BaseVector;

import lombok.Builder;
Expand Down Expand Up @@ -66,6 +67,7 @@ public class SearchReq {
private String groupByFieldName;
private Integer groupSize;
private Boolean strictGroupSize;
private CreateCollectionReq.Function ranker;

// Expression template, to improve expression parsing performance in complicated list
// Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......]
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.milvus.v2.service.vector.request.ranker;

import io.milvus.common.clientenum.FunctionType;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import lombok.Builder;
import lombok.experimental.SuperBuilder;

import java.util.Map;

/**
* The Decay reranking strategy, which by adjusting search rankings based on numeric field values.
* Read the doc for more info: https://milvus.io/docs/decay-ranker-overview.md
*
* You also can declare a decay ranker by Function
* CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
* .functionType(FunctionType.RERANK)
* .name("time_decay")
* .description("time decay")
* .inputFieldNames(Collections.singletonList("timestamp"))
* .param("reranker", "decay")
* .param("function", "gauss")
* .param("origin", "1000")
* .param("scale", "10000")
* .param("offset", "24")
* .param("decay", "0.5")
* .build();
*/
@SuperBuilder
public class DecayRanker extends CreateCollectionReq.Function {
@Builder.Default
private String function = "gauss";
private Number origin;
private Number scale;

public FunctionType getFunctionType() {
return FunctionType.RERANK;
}

public Map<String, String> getParams() {
// the parent params might contain "offset" and "decay"
Map<String, String> props = super.getParams();
props.put("reranker", "decay");
props.put("function", function); // "gauss", "exp", or "linear"
if (origin != null) {
props.put("origin", origin.toString());
}
if (scale != null) {
props.put("scale", scale.toString());
}
return props;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package io.milvus.v2.service.vector.request.ranker;

import com.google.gson.JsonArray;
import io.milvus.common.clientenum.FunctionType;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import lombok.Builder;
import lombok.experimental.SuperBuilder;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
* The Model reranking strategy, which transforms Milvus search by integrating advanced language models
* that understand semantic relationships between queries and documents.
* Read the doc for more info: https://milvus.io/docs/model-ranker-overview.md
*
* You also can declare a model ranker by Function
* CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
* .functionType(FunctionType.RERANK)
* .name("semantic_ranker")
* .description("semantic ranker")
* .inputFieldNames(Collections.singletonList("document"))
* .param("reranker", "model")
* .param("provider", "tei")
* .param("queries", "[\"machine learning for time series\"]")
* .param("endpoint", "http://model-service:8080")
* .build();
*/
@SuperBuilder
public class ModelRanker extends CreateCollectionReq.Function {
@Builder.Default
private String provider = "tei";
@Builder.Default
private List<String> queries = new ArrayList<>();
private String endpoint;

public FunctionType getFunctionType() {
return FunctionType.RERANK;
}

public Map<String, String> getParams() {
// the parent params might contain "offset" and "decay"
Map<String, String> props = super.getParams();
props.put("reranker", "model");
props.put("provider", provider); // "tei" or "vllm"
JsonArray json = new JsonArray();
queries.forEach(json::add);
props.put("queries", json.toString());
if (endpoint != null) {
props.put("endpoint", endpoint);
}
return props;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,49 @@
package io.milvus.v2.service.vector.request.ranker;

import com.google.gson.JsonObject;
import io.milvus.common.clientenum.FunctionType;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import lombok.Builder;
import lombok.experimental.SuperBuilder;

import java.util.HashMap;
import java.util.Map;

/**
* The RRF reranking strategy, which merges results from multiple searches, favoring items that consistently appear.
* Read the doc for more info: https://milvus.io/docs/rrf-ranker.md
*
* Note: In v2.6, the Function and Rerank have been unified to support more rerank types: decay and model ranker
* https://milvus.io/docs/decay-ranker-overview.md
* https://milvus.io/docs/model-ranker-overview.md
* So we have to inherit the BaseRanker from Function, this change will lead to uncomfortable issues with
* RRFRanker/WeightedRanker in some users client code. We will mention it in release note.
* * In old client code, to declare a WeightedRanker:
* * RRFRanker ranker = new RRFRanker(20)
* * After this change, the client code should be changed accordingly:
* * RRFRanker ranker = RRFRanker.builder().k(20).build()
*
* You also can declare a rrf ranker by Function
* CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
* .functionType(FunctionType.RERANK)
* .param("strategy", "rrf")
* .param("params", "{\"k\": 60}")
* .build();
*/
public class RRFRanker extends BaseRanker {
@SuperBuilder
public class RRFRanker extends CreateCollectionReq.Function {
@Builder.Default
private int k = 60;

public RRFRanker(int k) {
this.k = k;
public FunctionType getFunctionType() {
return FunctionType.RERANK;
}

@Override
public Map<String, String> getProperties() {
public Map<String, String> getParams() {
JsonObject params = new JsonObject();
params.addProperty("k", this.k);

Map<String, String> props = new HashMap<>();
Map<String, String> props = super.getParams();
props.put("strategy", "rrf");
props.put("params", params.toString());
return props;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,53 @@
package io.milvus.v2.service.vector.request.ranker;

import com.google.gson.JsonObject;
import io.milvus.common.clientenum.FunctionType;
import io.milvus.common.utils.JsonUtils;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import lombok.Builder;
import lombok.experimental.SuperBuilder;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* The Average Weighted Scoring reranking strategy, which prioritizes vectors based on relevance,
* averaging their significance.
* Read the doc for more info: https://milvus.io/docs/weighted-ranker.md
*
* Note: In v2.6, the Function and Rerank have been unified to support more rerank types: decay and model ranker
* https://milvus.io/docs/decay-ranker-overview.md
* https://milvus.io/docs/model-ranker-overview.md
* So we have to inherit the BaseRanker from Function, this change will lead to uncomfortable issues with
* RRFRanker/WeightedRanker in some users client code. We will mention it in release note.
* In old client code, to declare a WeightedRanker:
* WeightedRanker ranker = new WeightedRanker(Arrays.asList(0.2f, 0.5f, 0.6f))
* After this change, the client code should be changed accordingly:
* WeightedRanker ranker = WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build()
*
* You also can declare a weighter ranker by Function
* CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
* .functionType(FunctionType.RERANK)
* .param("strategy", "weighted")
* .param("params", "{\"weights\": [0.4, 0.6]}")
* .build();
*/
public class WeightedRanker extends BaseRanker {
private List<Float> weights;
@SuperBuilder
public class WeightedRanker extends CreateCollectionReq.Function {
@Builder.Default
private List<Float> weights = new ArrayList<>();

public WeightedRanker(List<Float> weights) {
this.weights = weights;
public FunctionType getFunctionType() {
return FunctionType.RERANK;
}

@Override
public Map<String, String> getProperties() {
public Map<String, String> getParams() {
JsonObject params = new JsonObject();
params.add("weights", JsonUtils.toJsonTree(this.weights).getAsJsonArray());

Map<String, String> props = new HashMap<>();
Map<String, String> props = super.getParams();
props.put("strategy", "weighted");
props.put("params", params.toString());
return props;
Expand Down
38 changes: 33 additions & 5 deletions sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
import io.milvus.param.ParamUtils;
import io.milvus.v2.exception.ErrorCode;
import io.milvus.v2.exception.MilvusClientException;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.vector.request.*;
import io.milvus.v2.service.vector.request.ranker.BaseRanker;
import io.milvus.v2.service.vector.request.data.*;
import io.milvus.v2.service.vector.request.ranker.RRFRanker;
import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
import lombok.NonNull;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -279,6 +281,12 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) {
builder.setConsistencyLevelValue(request.getConsistencyLevel().getCode());
}

// set ranker, support reranking search result from v2.6.1
CreateCollectionReq.Function ranker = request.getRanker();
if (ranker != null) {
builder.setFunctionScore(convertFunctionScore(ranker));
}

return builder.build();
}

Expand Down Expand Up @@ -473,16 +481,25 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ
}

// set ranker
BaseRanker ranker = request.getRanker();
if (request.getRanker() == null) {
CreateCollectionReq.Function ranker = request.getRanker();
if (ranker == null) {
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Ranker is null.");
}

// topK value is deprecated, always use "limit" to set the topK
Map<String, String> props = ranker.getProperties();
Map<String, String> props = new HashMap<>();
props.put(Constant.LIMIT, String.valueOf(request.getLimit()));
props.put(Constant.ROUND_DECIMAL, String.valueOf(request.getRoundDecimal()));
props.put(Constant.OFFSET, String.valueOf(request.getOffset()));

if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) {
// old logic for RRF/Weighted ranker
Map<String, String> params = ranker.getParams();
props.putAll(params);
} else {
// new logic for Decay/Model ranker
builder.setFunctionScore(convertFunctionScore(ranker));
}

List<KeyValuePair> propertiesList = ParamUtils.AssembleKvPair(props);
if (CollectionUtils.isNotEmpty(propertiesList)) {
propertiesList.forEach(builder::addRankParams);
Expand Down Expand Up @@ -528,6 +545,17 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ
return builder.build();
}

private FunctionScore convertFunctionScore(CreateCollectionReq.Function function) {
FunctionSchema schema = FunctionSchema.newBuilder()
.setName(function.getName())
.setDescription(function.getDescription())
.setType(FunctionType.forNumber(function.getFunctionType().getCode()))
.addAllInputFieldNames(function.getInputFieldNames())
.addAllParams(ParamUtils.AssembleKvPair(function.getParams()))
.build();
return FunctionScore.newBuilder().addFunctions(schema).build();
}

public String getExprById(String primaryFieldName, List<?> ids) {
StringBuilder sb = new StringBuilder();
sb.append(primaryFieldName).append(" in [");
Expand Down
2 changes: 1 addition & 1 deletion sdk-core/src/test/java/io/milvus/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public class TestUtils {
private int dimension = 256;
private static final Random RANDOM = new Random();

public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.0-rc1";
public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.0";

public TestUtils(int dimension) {
this.dimension = dimension;
Expand Down
Loading
Loading