Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,14 @@ public List<IDScore> getIDScore(int indexOfTarget) throws ParamException, Illega
return idScores;
}

/**
* Gets how many nq are searched.
* @return how many nq are searched
*/
public long getNumQueries() {
return results.getNumQueries();
}

@Getter
private static final class Position {
private final long offset;
Expand All @@ -250,11 +258,12 @@ public Position(long offset, long k) {
private Position getOffsetByIndex(int indexOfTarget) {
List<Long> kList = results.getTopksList();

// if the server didn't return separate topK, use same topK value
// if the server didn't return separate topK, use same topK value "0"
// will return an empty result for each nq instead of throwing an exception
if (kList.isEmpty()) {
kList = new ArrayList<>();
for (long i = 0; i < results.getNumQueries(); ++i) {
kList.add(results.getTopK());
kList.add(0L);
}
}

Expand Down
129 changes: 78 additions & 51 deletions sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.milvus.common.utils.Float16Utils;
import io.milvus.common.utils.GTsDict;
import io.milvus.common.utils.JsonUtils;
import io.milvus.exception.ParamException;
import io.milvus.grpc.*;
import io.milvus.orm.iterator.QueryIterator;
import io.milvus.orm.iterator.SearchIterator;
Expand Down Expand Up @@ -63,6 +64,7 @@
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

@Testcontainers(disabledWithoutDocker = true)
class MilvusClientDockerTest {
Expand Down Expand Up @@ -1345,18 +1347,6 @@ void testMultipleVectorFields() {
R<RpcStatus> createR = client.createCollection(createParam);
Assertions.assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());

// insert data to multiple vector fields
int rowCount = 10000;
List<InsertParam.Field> fields = generateColumnsData(schema, rowCount, 0);

InsertParam insertParam = InsertParam.newBuilder()
.withCollectionName(randomCollectionName)
.withFields(fields)
.build();

R<MutationResult> insertR = client.insert(insertParam);
Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());

// create indexes on multiple vector fields
CreateIndexParam indexParam = CreateIndexParam.newBuilder()
.withCollectionName(randomCollectionName)
Expand Down Expand Up @@ -1397,53 +1387,86 @@ void testMultipleVectorFields() {
.build());
Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());

// search on multiple vector fields
AnnSearchParam param1 = AnnSearchParam.newBuilder()
.withVectorFieldName(DataType.FloatVector.name())
.withFloatVectors(utils.generateFloatVectors(1))
.withMetricType(MetricType.COSINE)
.withParams("{\"nprobe\": 32}")
.withLimit(10L)
.build();

AnnSearchParam param2 = AnnSearchParam.newBuilder()
.withVectorFieldName(DataType.BinaryVector.name())
.withBinaryVectors(utils.generateBinaryVectors(1))
.withMetricType(MetricType.HAMMING)
.withParams("{}")
.withLimit(5L)
.build();

AnnSearchParam param3 = AnnSearchParam.newBuilder()
.withVectorFieldName(DataType.SparseFloatVector.name())
.withSparseFloatVectors(utils.generateSparseVectors(1))
.withMetricType(MetricType.IP)
.withParams("{\"drop_ratio_search\":0.2}")
.withLimit(7L)
.build();
// prepare sub requests
int nq = 5;
long topk = 10L;
Function<Integer, HybridSearchParam> genRequestFunc =
sparseCount -> {
AnnSearchParam param1 = AnnSearchParam.newBuilder()
.withVectorFieldName(DataType.FloatVector.name())
.withFloatVectors(utils.generateFloatVectors(nq))
.withMetricType(MetricType.COSINE)
.withParams("{\"nprobe\": 32}")
.withLimit(15L)
.build();

AnnSearchParam param2 = AnnSearchParam.newBuilder()
.withVectorFieldName(DataType.BinaryVector.name())
.withBinaryVectors(utils.generateBinaryVectors(nq))
.withMetricType(MetricType.HAMMING)
.withParams("{}")
.withLimit(5L)
.build();

List<SortedMap<Long, Float>> sparseVEctors = sparseCount > 0 ?
utils.generateSparseVectors(sparseCount) : new ArrayList<>();
AnnSearchParam param3 = AnnSearchParam.newBuilder()
.withVectorFieldName(DataType.SparseFloatVector.name())
.withSparseFloatVectors(sparseVEctors)
.withMetricType(MetricType.IP)
.withParams("{\"drop_ratio_search\":0.2}")
.withLimit(7L)
.build();

// search with an empty nq, return error
return HybridSearchParam.newBuilder()
.withCollectionName(randomCollectionName)
.addOutField(DataType.SparseFloatVector.name())
.addSearchRequest(param1)
.addSearchRequest(param2)
.addSearchRequest(param3)
.withLimit(topk)
.withConsistencyLevel(ConsistencyLevelEnum.STRONG)
.withRanker(WeightedRanker.newBuilder()
.withWeights(Lists.newArrayList(0.5f, 0.5f, 1.0f))
.build())
.withOutFields(Collections.singletonList("*"))
.build();
};

// search with an empty nq, return error
Assertions.assertThrows(ParamException.class, ()->genRequestFunc.apply(0));

// unequal nq, return error
Assertions.assertThrows(ParamException.class, ()->genRequestFunc.apply(1));

// search on empty collection, no result returned
R<SearchResults> searchR = client.hybridSearch(genRequestFunc.apply(nq));
Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
for (int i = 0; i < results.getNumQueries(); ++i) {
List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
Assertions.assertTrue(scores.isEmpty());
}

HybridSearchParam searchParam = HybridSearchParam.newBuilder()
// insert data to multiple vector fields
int rowCount = 10000;
List<InsertParam.Field> fields = generateColumnsData(schema, rowCount, 0);
InsertParam insertParam = InsertParam.newBuilder()
.withCollectionName(randomCollectionName)
.addOutField(DataType.SparseFloatVector.name())
.addSearchRequest(param1)
.addSearchRequest(param2)
.addSearchRequest(param3)
.withLimit(3L)
.withConsistencyLevel(ConsistencyLevelEnum.STRONG)
.withRanker(WeightedRanker.newBuilder()
.withWeights(Lists.newArrayList(0.5f, 0.5f, 1.0f))
.build())
.withOutFields(Collections.singletonList("*"))
.withFields(fields)
.build();
R<MutationResult> insertR = client.insert(insertParam);
Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());

R<SearchResults> searchR = client.hybridSearch(searchParam);
// search on multiple vector fields
searchR = client.hybridSearch(genRequestFunc.apply(nq));
Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());

// print search result
SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
// check search result
results = new SearchResultsWrapper(searchR.getData().getResults());
List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
for (SearchResultsWrapper.IDScore score : scores) {
System.out.println(score);
Object id = score.get("id");
Assertions.assertInstanceOf(Long.class, id);
Object fv = score.get(DataType.FloatVector.name());
Expand All @@ -1457,6 +1480,10 @@ void testMultipleVectorFields() {
Object sv = score.get(DataType.SparseFloatVector.name());
Assertions.assertInstanceOf(SortedMap.class, sv);
}
for (int i = 0; i < results.getNumQueries(); ++i) {
scores = results.getIDScore(i);
Assertions.assertEquals(topk, scores.size());
}

// drop collection
DropCollectionParam dropParam = DropCollectionParam.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2969,6 +2969,8 @@ void testSearchResultsWrapper() {
String fieldName = "test";
SearchResultData results = SearchResultData.newBuilder()
.setTopK(topK)
.addTopks(topK)
.addTopks(topK) // numQueries=2, the topks list must have 2 elements
.setNumQueries(numQueries)
.setIds(IDs.newBuilder()
.setIntId(LongArray.newBuilder()
Expand Down Expand Up @@ -2996,6 +2998,8 @@ void testSearchResultsWrapper() {
// for string id
results = SearchResultData.newBuilder()
.setTopK(topK)
.addTopks(topK)
.addTopks(topK) // numQueries=2, the topks list must have 2 elements
.setNumQueries(numQueries)
.setIds(IDs.newBuilder()
.setStrId(StringArray.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.milvus.orm.iterator.SearchIterator;
import io.milvus.orm.iterator.SearchIteratorV2;
import io.milvus.param.Constant;
import io.milvus.param.dml.HybridSearchParam;
import io.milvus.pool.MilvusClientV2Pool;
import io.milvus.pool.PoolConfig;
import io.milvus.response.QueryResultsWrapper;
Expand Down Expand Up @@ -1010,6 +1011,63 @@ void testHybridSearch() {
Assertions.assertEquals(16, descResp.getFieldNames().size());
Assertions.assertEquals(3, descResp.getVectorFieldNames().size());

// prepare sub requests
int nq = 5;
int topk = 10;
Function<Integer, HybridSearchReq> genRequestFunc =
sparseCount -> {
List<BaseVector> floatVectors = new ArrayList<>();
List<BaseVector> binaryVectors = new ArrayList<>();
List<BaseVector> sparseVectors = new ArrayList<>();
for (int i = 0; i < nq; i++) {
floatVectors.add(new FloatVec(utils.generateFloatVector()));
binaryVectors.add(new BinaryVec(utils.generateBinaryVector()));
}
for (int i = 0; i < sparseCount; i++) {
sparseVectors.add(new SparseFloatVec(utils.generateSparseVector()));
}

List<AnnSearchReq> searchRequests = new ArrayList<>();
searchRequests.add(AnnSearchReq.builder()
.vectorFieldName("float_vector")
.vectors(floatVectors)
.params("{\"nprobe\": 10}")
.limit(15)
.build());
searchRequests.add(AnnSearchReq.builder()
.vectorFieldName("binary_vector")
.vectors(binaryVectors)
.limit(5)
.build());
searchRequests.add(AnnSearchReq.builder()
.vectorFieldName("sparse_vector")
.vectors(sparseVectors)
.limit(7)
.build());

return HybridSearchReq.builder()
.collectionName(randomCollectionName)
.searchRequests(searchRequests)
.ranker(new RRFRanker(20))
.limit(topk)
.consistencyLevel(ConsistencyLevel.BOUNDED)
.build();
};

// search with an empty nq, return error
Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(0)));

// unequal nq, return error
Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(1)));

// search on empty collection, no result returned
SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(nq));
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
Assertions.assertEquals(nq, searchResults.size());
for (List<SearchResp.SearchResult> result : searchResults) {
Assertions.assertTrue(result.isEmpty());
}

// insert rows
long count = 10000;
List<JsonObject> data = generateRandomData(collectionSchema, count);
Expand All @@ -1023,45 +1081,9 @@ void testHybridSearch() {
long rowCount = getRowCount(randomCollectionName);
Assertions.assertEquals(count, rowCount);

// hybrid search in collection
int nq = 5;
int topk = 10;
List<BaseVector> floatVectors = new ArrayList<>();
List<BaseVector> binaryVectors = new ArrayList<>();
List<BaseVector> sparseVectors = new ArrayList<>();
for (int i = 0; i < nq; i++) {
floatVectors.add(new FloatVec(utils.generateFloatVector()));
binaryVectors.add(new BinaryVec(utils.generateBinaryVector()));
sparseVectors.add(new SparseFloatVec(utils.generateSparseVector()));
}

List<AnnSearchReq> searchRequests = new ArrayList<>();
searchRequests.add(AnnSearchReq.builder()
.vectorFieldName("float_vector")
.vectors(floatVectors)
.params("{\"nprobe\": 10}")
.limit(10)
.build());
searchRequests.add(AnnSearchReq.builder()
.vectorFieldName("binary_vector")
.vectors(binaryVectors)
.limit(50)
.build());
searchRequests.add(AnnSearchReq.builder()
.vectorFieldName("sparse_vector")
.vectors(sparseVectors)
.limit(100)
.build());

HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
.collectionName(randomCollectionName)
.searchRequests(searchRequests)
.ranker(new RRFRanker(20))
.limit(topk)
.consistencyLevel(ConsistencyLevel.BOUNDED)
.build();
SearchResp searchResp = client.hybridSearch(hybridSearchReq);
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
// search again, there are results
searchResp = client.hybridSearch(genRequestFunc.apply(nq));
searchResults = searchResp.getSearchResults();
Assertions.assertEquals(nq, searchResults.size());
for (int i = 0; i < nq; i++) {
List<SearchResp.SearchResult> results = searchResults.get(i);
Expand Down
Loading