Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/src/main/java/io/milvus/v2/HybridSearchExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.DropCollectionReq;
import io.milvus.v2.service.vector.request.AnnSearchReq;
import io.milvus.v2.service.vector.request.FunctionScore;
import io.milvus.v2.service.vector.request.HybridSearchReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.QueryReq;
Expand Down Expand Up @@ -122,7 +123,6 @@ private void createCollection() {
.metricType(BINARY_VECTOR_METRIC)
.build());
Map<String,Object> fv16Params = new HashMap<>();
fv16Params.clear();
fv16Params.put("M",16);
fv16Params.put("efConstruction",64);
indexes.add(IndexParam.builder()
Expand Down Expand Up @@ -212,7 +212,9 @@ private void hybridSearch() {
HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
.collectionName(COLLECTION_NAME)
.searchRequests(searchRequests)
.ranker(WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build())
.functionScore(FunctionScore.builder()
.addFunction(WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build())
.build())
.limit(5)
.consistencyLevel(ConsistencyLevel.BOUNDED)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,10 @@ private boolean isResSufficient(List<QueryResultsWrapper.RowRecord> ret) {
private QueryResults executeQuery(String expr, long offset, long limit, long ts, boolean isSeek) {
// for seeking offset, no need to return output fields
List<String> outputFields = new ArrayList<>();
boolean reduceStopForBest = queryIteratorParam.isReduceStopForBest();
if (!isSeek) {
outputFields = queryIteratorParam.getOutFields();
reduceStopForBest = false;
}
QueryParam queryParam = QueryParam.newBuilder()
.withDatabaseName(queryIteratorParam.getDatabaseName())
Expand All @@ -230,7 +232,7 @@ private QueryResults executeQuery(String expr, long offset, long limit, long ts,
// reduce stop for best
builder.addQueryParams(KeyValuePair.newBuilder()
.setKey(Constant.REDUCE_STOP_FOR_BEST)
.setValue(String.valueOf(queryIteratorParam.isReduceStopForBest()))
.setValue(String.valueOf(reduceStopForBest))
.build());

// iterator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ public static class FieldSchema {
@Data
@SuperBuilder
public static class Function {
private String name;
@Builder.Default
private String name = "";
@Builder.Default
private String description = "";
@Builder.Default
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package io.milvus.v2.service.vector.request;

import io.milvus.v2.service.collection.request.CreateCollectionReq;
import lombok.Builder;
import lombok.Data;
import lombok.experimental.SuperBuilder;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Request-side container for a list of functions plus a string-to-string
 * parameter map. Both collections default to empty, mutable instances.
 */
@Data
@SuperBuilder
public class FunctionScore {
    /** Functions carried by this score; starts as an empty list. */
    @Builder.Default
    private List<CreateCollectionReq.Function> functions = new ArrayList<>();

    /** Extra key/value parameters; starts as an empty map. */
    @Builder.Default
    private Map<String, String> params = new HashMap<>();

    /**
     * Builder customization that adds an incremental {@code addFunction} step.
     * The {@code functions$value} / {@code functions$set} members used below are not
     * declared in this file; presumably they are generated for the
     * {@code @Builder.Default} field -- confirm against the Lombok version in use.
     */
    public static abstract class FunctionScoreBuilder<C extends FunctionScore, B extends FunctionScore.FunctionScoreBuilder<C, B>> {
        /**
         * Appends one function to the builder's pending function list, creating
         * the list on first use, and marks the list as explicitly set.
         *
         * @param func the function to append
         * @return this builder, for chaining
         */
        public B addFunction(CreateCollectionReq.Function func) {
            List<CreateCollectionReq.Function> pending = this.functions$value;
            if (pending == null) {
                // Nothing collected yet: start a fresh list and store it on the builder.
                pending = new ArrayList<>();
                this.functions$value = pending;
            }
            pending.add(func);
            this.functions$set = true;
            return self();
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ public class HybridSearchReq
private String collectionName;
private List<String> partitionNames;
private List<AnnSearchReq> searchRequests;
private CreateCollectionReq.Function ranker;
@Builder.Default
@Deprecated
private int topK = 0; // deprecated, replaced by "limit"
Expand All @@ -51,6 +50,11 @@ public class HybridSearchReq
private String groupByFieldName;
private Integer groupSize;
private Boolean strictGroupSize;
@Deprecated
private CreateCollectionReq.Function ranker;
// Milvus v2.6.1 supports multiple rankers. The "ranker" field still works, but it is recommended
// to use functionScore even if you have only one ranker. Setting both is not allowed.
private FunctionScore functionScore;

public static abstract class HybridSearchReqBuilder<C extends HybridSearchReq, B extends HybridSearchReq.HybridSearchReqBuilder<C, B>> {
// topK is deprecated, topK and limit must be the same value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ public class SearchReq {
private String groupByFieldName;
private Integer groupSize;
private Boolean strictGroupSize;
@Deprecated
private CreateCollectionReq.Function ranker;
// Milvus v2.6.1 supports multiple rankers. The "ranker" field still works, but it is recommended
// to use functionScore even if you have only one ranker. Setting both is not allowed.
private FunctionScore functionScore;

// Expression template, to improve expression parsing performance in complicated list
// Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......]
Expand Down
78 changes: 59 additions & 19 deletions sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package io.milvus.v2.utils;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.reflect.TypeToken;
import com.google.protobuf.ByteString;
import io.milvus.common.utils.GTsDict;
Expand All @@ -33,6 +35,7 @@
import io.milvus.v2.exception.MilvusClientException;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.vector.request.*;
import io.milvus.v2.service.vector.request.FunctionScore;
import io.milvus.v2.service.vector.request.data.*;
import io.milvus.v2.service.vector.request.ranker.RRFRanker;
import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
Expand Down Expand Up @@ -283,8 +286,14 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) {

// set ranker, support reranking search result from v2.6.1
CreateCollectionReq.Function ranker = request.getRanker();
if (ranker != null) {
builder.setFunctionScore(convertFunctionScore(ranker));
io.milvus.v2.service.vector.request.FunctionScore functionScore = request.getFunctionScore();
if (ranker != null && functionScore != null) {
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Not allow to set both ranker and functionScore.");
}
if (functionScore != null) {
builder.setFunctionScore(convertFunctionScore(functionScore));
} else if (ranker != null) {
builder.setFunctionScore(convertOneFunction(ranker));
}

return builder.build();
Expand Down Expand Up @@ -480,24 +489,28 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ
builder.addRequests(searchRequest);
}

// set ranker
CreateCollectionReq.Function ranker = request.getRanker();
if (ranker == null) {
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Ranker is null.");
}

Map<String, String> props = new HashMap<>();
props.put(Constant.LIMIT, String.valueOf(request.getLimit()));
props.put(Constant.ROUND_DECIMAL, String.valueOf(request.getRoundDecimal()));
props.put(Constant.OFFSET, String.valueOf(request.getOffset()));

if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) {
// old logic for RRF/Weighted ranker
Map<String, String> params = ranker.getParams();
props.putAll(params);
} else {
// new logic for Decay/Model ranker
builder.setFunctionScore(convertFunctionScore(ranker));
// set ranker
CreateCollectionReq.Function ranker = request.getRanker();
io.milvus.v2.service.vector.request.FunctionScore functionScore = request.getFunctionScore();
if (ranker != null && functionScore != null) {
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Not allow to set both ranker and functionScore.");
}
if (functionScore != null) {
builder.setFunctionScore(convertFunctionScore(functionScore));
} else if (ranker != null) {
if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) {
// old logic for RRF/Weighted ranker
Map<String, String> params = ranker.getParams();
props.putAll(params);
} else {
// new logic for Decay/Model ranker
builder.setFunctionScore(convertOneFunction(ranker));
}
}

List<KeyValuePair> propertiesList = ParamUtils.AssembleKvPair(props);
Expand Down Expand Up @@ -545,15 +558,42 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ
return builder.build();
}

private FunctionScore convertFunctionScore(CreateCollectionReq.Function function) {
FunctionSchema schema = FunctionSchema.newBuilder()
/**
 * Builds a gRPC FunctionSchema from an SDK function: name, description, type,
 * input field names and parameters.
 *
 * For RRFRanker / WeightedRanker the parameter map is rewritten first: a
 * "reranker" entry ("rrf" or "weighted") is added, the JSON object stored under
 * "params" is expanded into top-level entries, and the "strategy" and "params"
 * entries are removed.
 *
 * NOTE(review): the rewrite is applied directly to the map returned by
 * function.getParams(); whether that map is a fresh copy or the function's own
 * state is not visible here -- verify before relying on the ranker afterwards.
 *
 * @param function the SDK function to convert
 * @return the equivalent FunctionSchema message
 */
private FunctionSchema convertFunctionSchema(CreateCollectionReq.Function function) {
Map<String, String> params = function.getParams();
// FunctionSchema type keyword is "reranker", old RRF/Weighted ranker type keyword is "strategy"
// FunctionSchema parameters are flat, old RRF/Weighted parameters are wrapped by "params"
if (function instanceof RRFRanker || function instanceof WeightedRanker) {
String name = (function instanceof RRFRanker) ? "rrf" : "weighted";
params.put("reranker", name);
// Expand the wrapped JSON object; each value is stored as its JSON text (toString()).
// NOTE(review): assumes a "params" entry holding a JSON object is present -- confirm.
JsonObject inner = JsonParser.parseString(params.get("params")).getAsJsonObject();
for (String key : inner.keySet()) {
params.put(key, inner.get(key).toString());
}
// Drop the legacy keys now that their content has been flattened.
params.remove("strategy");
params.remove("params");
}
return FunctionSchema.newBuilder()
.setName(function.getName())
.setDescription(function.getDescription())
.setType(FunctionType.forNumber(function.getFunctionType().getCode()))
.addAllInputFieldNames(function.getInputFieldNames())
// NOTE(review): the next line and the trailing "return FunctionScore..." line look like
// removed-side lines from the diff rendering rather than part of the new method -- confirm.
.addAllParams(ParamUtils.AssembleKvPair(function.getParams()))
.addAllParams(ParamUtils.AssembleKvPair(params))
.build();
return FunctionScore.newBuilder().addFunctions(schema).build();
}

/**
 * Wraps a single SDK function into a gRPC FunctionScore that contains exactly
 * one function schema and no extra parameters.
 *
 * @param function the SDK function to wrap
 * @return a gRPC FunctionScore holding the converted schema
 */
private io.milvus.grpc.FunctionScore convertOneFunction(CreateCollectionReq.Function function) {
    return io.milvus.grpc.FunctionScore.newBuilder()
            .addFunctions(convertFunctionSchema(function))
            .build();
}

/**
 * Converts a request-side FunctionScore into its gRPC counterpart: every
 * function is converted to a schema in list order, then the score's own
 * parameters are attached as key/value pairs.
 *
 * @param functionScore the request-side function score
 * @return the gRPC FunctionScore message
 */
private io.milvus.grpc.FunctionScore convertFunctionScore(FunctionScore functionScore) {
    io.milvus.grpc.FunctionScore.Builder grpcScore = io.milvus.grpc.FunctionScore.newBuilder();
    // Schemas are appended in the same order as the source list.
    functionScore.getFunctions()
            .forEach(fn -> grpcScore.addFunctions(convertFunctionSchema(fn)));
    return grpcScore
            .addAllParams(ParamUtils.AssembleKvPair(functionScore.getParams()))
            .build();
}

public String getExprById(String primaryFieldName, List<?> ids) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1074,15 +1074,16 @@ void testHybridSearch() {
// prepare sub requests
int nq = 5;
int topk = 10;
Function<Integer, HybridSearchReq> genRequestFunc =
sparseCount -> {
Function<Map<String, Object>, HybridSearchReq> genRequestFunc =
config -> {
List<BaseVector> floatVectors = new ArrayList<>();
List<BaseVector> binaryVectors = new ArrayList<>();
List<BaseVector> sparseVectors = new ArrayList<>();
for (int i = 0; i < nq; i++) {
floatVectors.add(new FloatVec(utils.generateFloatVector()));
binaryVectors.add(new BinaryVec(utils.generateBinaryVector()));
}
int sparseCount = (Integer)config.get("sparseCount");
for (int i = 0; i < sparseCount; i++) {
sparseVectors.add(new SparseFloatVec(utils.generateSparseVector()));
}
Expand All @@ -1105,23 +1106,40 @@ void testHybridSearch() {
.limit(7)
.build());

return HybridSearchReq.builder()
.collectionName(randomCollectionName)
.searchRequests(searchRequests)
.ranker(RRFRanker.builder().k(20).build())
.limit(topk)
.consistencyLevel(ConsistencyLevel.BOUNDED)
.build();
CreateCollectionReq.Function ranker = WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build();
boolean useFunctionScore = (Boolean)config.get("useFunctionScore");
if (useFunctionScore) {
return HybridSearchReq.builder()
.collectionName(randomCollectionName)
.searchRequests(searchRequests)
.functionScore(FunctionScore.builder().addFunction(ranker).build())
.limit(topk)
.consistencyLevel(ConsistencyLevel.BOUNDED)
.build();
} else {
return HybridSearchReq.builder()
.collectionName(randomCollectionName)
.searchRequests(searchRequests)
.ranker(RRFRanker.builder().k(20).build())
.limit(topk)
.consistencyLevel(ConsistencyLevel.BOUNDED)
.build();
}
};

Map<String, Object> config = new HashMap<>();
config.put("sparseCount", 0);
config.put("useFunctionScore", false);
// search with an empty nq, return error
Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(0)));
Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(config)));

// unequal nq, return error
Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(1)));
config.put("sparseCount", 1);
Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(config)));

// search on empty collection, no result returned
SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(nq));
config.put("sparseCount", nq);
SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(config));
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
Assertions.assertEquals(nq, searchResults.size());
for (List<SearchResp.SearchResult> result : searchResults) {
Expand All @@ -1142,7 +1160,8 @@ void testHybridSearch() {
Assertions.assertEquals(count, rowCount);

// search again, there are results
searchResp = client.hybridSearch(genRequestFunc.apply(nq));
config.put("useFunctionScore", true);
searchResp = client.hybridSearch(genRequestFunc.apply(config));
searchResults = searchResp.getSearchResults();
Assertions.assertEquals(nq, searchResults.size());
for (int i = 0; i < nq; i++) {
Expand Down
Loading