Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions examples/src/main/java/io/milvus/v2/IteratorExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,44 @@ private static void queryIterator(String expr, int batchSize, int offset, int li
System.out.printf("%d query results returned%n", counter);
}

private static void queryIteratorWithTemplate(int batchSize) {
System.out.println("\n========== queryIterator() ==========");
List<Long> ids = new ArrayList<>();
for (long i = 500L; i < 600L; i++) {
ids.add(i);
}
Map<String, Object> template = new HashMap<>();
template.put("my_ids", ids);

String filter = ID_FIELD + " in {my_ids}";
QueryIterator queryIterator = client.queryIterator(QueryIteratorReq.builder()
.collectionName(COLLECTION_NAME)
.expr(filter)
.outputFields(Lists.newArrayList(ID_FIELD, AGE_FIELD))
.batchSize(batchSize)
.filterTemplateValues(template)
.consistencyLevel(ConsistencyLevel.BOUNDED)
.build());

System.out.println("QueryIterator with filter template results:");
int counter = 0;
while (true) {
List<QueryResultsWrapper.RowRecord> res = queryIterator.next();
if (res.isEmpty()) {
System.out.println("query iteration finished, close");
queryIterator.close();
break;
}

for (QueryResultsWrapper.RowRecord record : res) {
System.out.println(record);
counter++;
}
}
System.out.printf("%d query results returned%n", counter);
}


// Search iterator V1
private static void searchIteratorV1(String expr, String params, int batchSize, int topK) {
System.out.println("\n========== searchIteratorV1() ==========");
Expand Down Expand Up @@ -235,16 +273,59 @@ private static void searchIteratorV2(String filter, Map<String, Object> params,
System.out.printf("%d search results returned\n%n", counter);
}

private static void searchIteratorV2WithTemplate(int batchSize) {
System.out.println("\n========== searchIteratorV2() ==========");
List<Long> ids = new ArrayList<>();
for (long i = 500L; i < 600L; i++) {
ids.add(i);
}
Map<String, Object> template = new HashMap<>();
template.put("my_ids", ids);

String filter = ID_FIELD + " in {my_ids}";
SearchIteratorV2 searchIterator = client.searchIteratorV2(SearchIteratorReqV2.builder()
.collectionName(COLLECTION_NAME)
.outputFields(Lists.newArrayList(AGE_FIELD))
.batchSize(batchSize)
.vectorFieldName(VECTOR_FIELD)
.vectors(Collections.singletonList(new FloatVec(CommonUtils.generateFloatVector(VECTOR_DIM))))
.filter(filter)
.filterTemplateValues(template)
.metricType(IndexParam.MetricType.L2)
.consistencyLevel(ConsistencyLevel.BOUNDED)
.build());

System.out.println("SearchIteratorV2 with filter template results:");
int counter = 0;
while (true) {
List<SearchResp.SearchResult> res = searchIterator.next();
if (res.isEmpty()) {
System.out.println("Search iteration finished, close");
searchIterator.close();
break;
}

for (SearchResp.SearchResult record : res) {
System.out.println(record);
counter++;
}
}
System.out.printf("%d search results returned\n%n", counter);
}

public static void main(String[] args) {
buildCollection();
queryIterator("userID < 300", 50, 5, 400);
queryIteratorWithTemplate(80);

searchIteratorV1("userAge > 50 &&userAge < 100", "{\"range_filter\": 15.0, \"radius\": 20.0}", 100, 500);
searchIteratorV1("", "", 10, 99);
searchIteratorV2("userAge > 10 &&userAge < 20", null, 50, 120, null);

Map<String, Object> extraParams = new HashMap<>();
extraParams.put("radius", 15.0);
searchIteratorV2("", extraParams, 50, 100, null);
searchIteratorV2WithTemplate(80);

// use external function to filter the result
Function<List<SearchResp.SearchResult>, List<SearchResp.SearchResult>> externalFilterFunc = (List<SearchResp.SearchResult> src) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import io.milvus.param.collection.FieldType;
import io.milvus.param.dml.QueryIteratorParam;
import io.milvus.param.dml.SearchIteratorParam;
import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.vector.request.QueryIteratorReq;
import io.milvus.v2.service.vector.request.SearchIteratorReq;
import io.milvus.v2.service.vector.request.data.BaseVector;
import io.milvus.v2.service.vector.request.data.*;
import org.apache.commons.lang3.StringUtils;

import java.nio.ByteBuffer;
import java.util.ArrayList;
Expand Down Expand Up @@ -59,6 +61,27 @@ public static QueryIteratorParam convertV2Req(QueryIteratorReq queryIteratorReq)
return builder.build();
}

public static QueryIteratorReq convertV1Param(QueryIteratorParam param) {
ConsistencyLevel level = null;
if (param.getConsistencyLevel() != null) {
level = ConsistencyLevel.valueOf(param.getConsistencyLevel().name());
}

return QueryIteratorReq.builder()
.databaseName(param.getDatabaseName())
.collectionName(param.getCollectionName())
.partitionNames(param.getPartitionNames())
.expr(param.getExpr())
.outputFields(param.getOutFields())
.offset(param.getOffset())
.limit(param.getLimit())
.ignoreGrowing(param.isIgnoreGrowing())
.batchSize(param.getBatchSize())
.reduceStopForBest(param.isReduceStopForBest())
.consistencyLevel(level)
.build();
}

public static SearchIteratorParam convertV2Req(SearchIteratorReq searchIteratorReq) {
MetricType metricType = MetricType.None;
if (searchIteratorReq.getMetricType() != IndexParam.MetricType.INVALID) {
Expand Down Expand Up @@ -130,6 +153,67 @@ public static SearchIteratorParam convertV2Req(SearchIteratorReq searchIteratorR
return builder.build();
}

public static SearchIteratorReq convertV1Param(SearchIteratorParam param) {
ConsistencyLevel level = null;
if (param.getConsistencyLevel() != null) {
level = ConsistencyLevel.valueOf(param.getConsistencyLevel().name());
}

IndexParam.MetricType metricType = IndexParam.MetricType.INVALID;
if (StringUtils.isEmpty(param.getMetricType())) {
metricType = IndexParam.MetricType.valueOf(param.getMetricType());
}

List<BaseVector> vectors = new ArrayList<>();
switch (param.getPlType()) {
case FloatVector: {
List<List<Float>> data = (List<List<Float>>) param.getVectors();
data.forEach(vector -> vectors.add(new FloatVec(vector)));
break;
}
case BinaryVector: {
List<ByteBuffer> data = (List<ByteBuffer>) param.getVectors();
data.forEach(vector -> vectors.add(new BinaryVec(vector)));
break;
}
case Float16Vector: {
List<ByteBuffer> data = (List<ByteBuffer>) param.getVectors();
data.forEach(vector -> vectors.add(new Float16Vec(vector)));
break;
}
case BFloat16Vector: {
List<ByteBuffer> data = (List<ByteBuffer>) param.getVectors();
data.forEach(vector -> vectors.add(new BFloat16Vec(vector)));
break;
}
case SparseFloatVector: {
List<SortedMap<Long, Float>> data = (List<SortedMap<Long, Float>>) param.getVectors();
data.forEach(vector -> vectors.add(new SparseFloatVec(vector)));
break;
}
default:
throw new ParamException("Unsupported vector type.");
}

return SearchIteratorReq.builder()
.databaseName(param.getDatabaseName())
.collectionName(param.getCollectionName())
.partitionNames(param.getPartitionNames())
.vectorFieldName(param.getVectorFieldName())
.vectors(vectors)
.consistencyLevel(level)
.metricType(metricType)
.limit(param.getTopK())
.expr(param.getExpr())
.outputFields(param.getOutFields())
.roundDecimal(param.getRoundDecimal())
.params(param.getParams())
.groupByFieldName(param.getGroupByFieldName())
.ignoreGrowing(param.isIgnoreGrowing())
.batchSize(param.getBatchSize())
.build();
}

public static FieldType convertV2Field(CreateCollectionReq.FieldSchema schema) {
FieldType.Builder builder = FieldType.newBuilder()
.withName(schema.getName())
Expand Down
52 changes: 27 additions & 25 deletions sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@

import io.milvus.grpc.*;
import io.milvus.param.Constant;
import io.milvus.param.ParamUtils;
import io.milvus.param.collection.FieldType;
import io.milvus.param.dml.QueryIteratorParam;
import io.milvus.param.dml.QueryParam;
import io.milvus.response.QueryResultsWrapper;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.vector.request.QueryIteratorReq;
import io.milvus.v2.service.vector.request.QueryReq;
import io.milvus.v2.utils.RpcUtils;
import io.milvus.v2.utils.VectorUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -44,7 +44,7 @@ public class QueryIterator {
private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
private final FieldType primaryField;

private final QueryIteratorParam queryIteratorParam;
private final QueryIteratorReq queryIteratorReq;
private final int batchSize;
private final long limit;
private final String expr;
Expand All @@ -61,7 +61,7 @@ public QueryIterator(QueryIteratorParam queryIteratorParam,
this.iteratorCache = new IteratorCache();
this.blockingStub = blockingStub;
this.primaryField = primaryField;
this.queryIteratorParam = queryIteratorParam;
this.queryIteratorReq = IteratorAdapterV2.convertV1Param(queryIteratorParam);

this.batchSize = (int) queryIteratorParam.getBatchSize();
this.expr = queryIteratorParam.getExpr();
Expand All @@ -78,15 +78,14 @@ public QueryIterator(QueryIteratorReq queryIteratorReq,
CreateCollectionReq.FieldSchema primaryField) {
this.iteratorCache = new IteratorCache();
this.blockingStub = blockingStub;
IteratorAdapterV2 adapter = new IteratorAdapterV2();
this.queryIteratorParam = adapter.convertV2Req(queryIteratorReq);
this.primaryField = adapter.convertV2Field(primaryField);
this.queryIteratorReq = queryIteratorReq;
this.primaryField = IteratorAdapterV2.convertV2Field(primaryField);


this.batchSize = (int) queryIteratorParam.getBatchSize();
this.expr = queryIteratorParam.getExpr();
this.limit = queryIteratorParam.getLimit();
this.offset = queryIteratorParam.getOffset();
this.batchSize = (int) queryIteratorReq.getBatchSize();
this.expr = queryIteratorReq.getExpr();
this.limit = queryIteratorReq.getLimit();
this.offset = queryIteratorReq.getOffset();
this.rpcUtils = new RpcUtils();

setupTsByRequest();
Expand Down Expand Up @@ -208,24 +207,27 @@ private boolean isResSufficient(List<QueryResultsWrapper.RowRecord> ret) {
private QueryResults executeQuery(String expr, long offset, long limit, long ts, boolean isSeek) {
// for seeking offset, no need to return output fields
List<String> outputFields = new ArrayList<>();
boolean reduceStopForBest = queryIteratorParam.isReduceStopForBest();
boolean reduceStopForBest = queryIteratorReq.isReduceStopForBest();
if (!isSeek) {
outputFields = queryIteratorParam.getOutFields();
outputFields = queryIteratorReq.getOutputFields();
reduceStopForBest = false;
}
QueryParam queryParam = QueryParam.newBuilder()
.withDatabaseName(queryIteratorParam.getDatabaseName())
.withCollectionName(queryIteratorParam.getCollectionName())
.withConsistencyLevel(queryIteratorParam.getConsistencyLevel())
.withPartitionNames(queryIteratorParam.getPartitionNames())
.withOutFields(outputFields)
.withExpr(expr)
.withOffset(offset)
.withLimit(limit)
.withIgnoreGrowing(queryIteratorParam.isIgnoreGrowing())
QueryReq queryReq = QueryReq.builder()
.databaseName(queryIteratorReq.getDatabaseName())
.collectionName(queryIteratorReq.getCollectionName())
.partitionNames(queryIteratorReq.getPartitionNames())
.consistencyLevel(queryIteratorReq.getConsistencyLevel())
.outputFields(outputFields)
.filter(expr)
.offset(offset)
.limit(limit)
.ignoreGrowing(queryIteratorReq.isIgnoreGrowing())
.timezone(queryIteratorReq.getTimezone())
.filterTemplateValues(queryIteratorReq.getFilterTemplateValues())
.build();

QueryRequest queryRequest = ParamUtils.convertQueryParam(queryParam);
VectorUtils vectorUtils = new VectorUtils();
QueryRequest queryRequest = vectorUtils.ConvertToGrpcQueryRequest(queryReq);
QueryRequest.Builder builder = queryRequest.toBuilder();
// reduce stop for best
builder.addQueryParams(KeyValuePair.newBuilder()
Expand All @@ -246,7 +248,7 @@ private QueryResults executeQuery(String expr, long offset, long limit, long ts,
builder.setUseDefaultConsistency(true);

QueryResults response = rpcUtils.retry(() -> blockingStub.query(builder.build()));
String title = String.format("QueryRequest collectionName:%s", queryIteratorParam.getCollectionName());
String title = String.format("QueryRequest collectionName:%s", queryIteratorReq.getCollectionName());
rpcUtils.handleResponse(title, response.getStatus());
return response;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import static io.milvus.param.Constant.UNLIMITED;

public class SearchIteratorV2 {
private static final Logger logger = LoggerFactory.getLogger(SearchIterator.class);
private static final Logger logger = LoggerFactory.getLogger(SearchIteratorV2.class);
private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;

private final SearchIteratorReqV2 searchIteratorReq;
Expand Down Expand Up @@ -86,12 +86,12 @@ private void checkParams() {
int rows = searchIteratorReq.getVectors().size();
if (rows > 1) {
ExceptionUtils.throwUnExpectedException("SearchIterator does not support processing multiple vectors simultaneously");
} else if (rows <= 0) {
} else if (rows == 0) {
ExceptionUtils.throwUnExpectedException("The vector data for search cannot be empty");
}

if (searchIteratorReq.getTopK() != UNLIMITED) {
this.leftResCnt = searchIteratorReq.getTopK();
if (searchIteratorReq.getLimit() != UNLIMITED) {
this.leftResCnt = (int) searchIteratorReq.getLimit();
}
}

Expand All @@ -117,15 +117,17 @@ private SearchResults executeSearch(int limit) {
.databaseName(searchIteratorReq.getDatabaseName())
.annsField(searchIteratorReq.getVectorFieldName())
.data(searchIteratorReq.getVectors())
.topK(limit)
.limit(limit)
.filter(searchIteratorReq.getFilter())
.consistencyLevel(searchIteratorReq.getConsistencyLevel())
.outputFields(searchIteratorReq.getOutputFields())
.roundDecimal(searchIteratorReq.getRoundDecimal())
.searchParams(searchParams)
.metricType(searchIteratorReq.getMetricType())
.timezone(searchIteratorReq.getTimezone())
.ignoreGrowing(searchIteratorReq.isIgnoreGrowing())
.groupByFieldName(searchIteratorReq.getGroupByFieldName())
.filterTemplateValues(searchIteratorReq.getFilterTemplateValues())
.build();
SearchRequest searchRequest = new VectorUtils().ConvertToGrpcSearchRequest(request);
SearchResults response = rpcUtils.retry(() -> this.blockingStub.search(searchRequest));
Expand Down
Loading
Loading