diff --git a/examples/src/main/java/io/milvus/v2/IteratorExample.java b/examples/src/main/java/io/milvus/v2/IteratorExample.java index d7791313d..3dde2a88e 100644 --- a/examples/src/main/java/io/milvus/v2/IteratorExample.java +++ b/examples/src/main/java/io/milvus/v2/IteratorExample.java @@ -160,6 +160,44 @@ private static void queryIterator(String expr, int batchSize, int offset, int li System.out.printf("%d query results returned%n", counter); } + private static void queryIteratorWithTemplate(int batchSize) { + System.out.println("\n========== queryIterator() =========="); + List ids = new ArrayList<>(); + for (long i = 500L; i < 600L; i++) { + ids.add(i); + } + Map template = new HashMap<>(); + template.put("my_ids", ids); + + String filter = ID_FIELD + " in {my_ids}"; + QueryIterator queryIterator = client.queryIterator(QueryIteratorReq.builder() + .collectionName(COLLECTION_NAME) + .expr(filter) + .outputFields(Lists.newArrayList(ID_FIELD, AGE_FIELD)) + .batchSize(batchSize) + .filterTemplateValues(template) + .consistencyLevel(ConsistencyLevel.BOUNDED) + .build()); + + System.out.println("QueryIterator with filter template results:"); + int counter = 0; + while (true) { + List res = queryIterator.next(); + if (res.isEmpty()) { + System.out.println("query iteration finished, close"); + queryIterator.close(); + break; + } + + for (QueryResultsWrapper.RowRecord record : res) { + System.out.println(record); + counter++; + } + } + System.out.printf("%d query results returned%n", counter); + } + + // Search iterator V1 private static void searchIteratorV1(String expr, String params, int batchSize, int topK) { System.out.println("\n========== searchIteratorV1() =========="); @@ -235,9 +273,51 @@ private static void searchIteratorV2(String filter, Map params, System.out.printf("%d search results returned\n%n", counter); } + private static void searchIteratorV2WithTemplate(int batchSize) { + System.out.println("\n========== searchIteratorV2() =========="); + List ids = new ArrayList<>(); + for (long i = 500L; i < 600L; i++) { + ids.add(i); + } + Map template = new HashMap<>(); + template.put("my_ids", ids); + + String filter = ID_FIELD + " in {my_ids}"; + SearchIteratorV2 searchIterator = client.searchIteratorV2(SearchIteratorReqV2.builder() + .collectionName(COLLECTION_NAME) + .outputFields(Lists.newArrayList(AGE_FIELD)) + .batchSize(batchSize) + .vectorFieldName(VECTOR_FIELD) + .vectors(Collections.singletonList(new FloatVec(CommonUtils.generateFloatVector(VECTOR_DIM)))) + .filter(filter) + .filterTemplateValues(template) + .metricType(IndexParam.MetricType.L2) + .consistencyLevel(ConsistencyLevel.BOUNDED) + .build()); + + System.out.println("SearchIteratorV2 with filter template results:"); + int counter = 0; + while (true) { + List res = searchIterator.next(); + if (res.isEmpty()) { + System.out.println("Search iteration finished, close"); + searchIterator.close(); + break; + } + + for (SearchResp.SearchResult record : res) { + System.out.println(record); + counter++; + } + } + System.out.printf("%d search results returned\n%n", counter); + } + public static void main(String[] args) { buildCollection(); queryIterator("userID < 300", 50, 5, 400); + queryIteratorWithTemplate(80); + searchIteratorV1("userAge > 50 &&userAge < 100", "{\"range_filter\": 15.0, \"radius\": 20.0}", 100, 500); searchIteratorV1("", "", 10, 99); searchIteratorV2("userAge > 10 &&userAge < 20", null, 50, 120, null); @@ -245,6 +325,7 @@ public static void main(String[] args) { Map extraParams = new HashMap<>(); extraParams.put("radius", 15.0); searchIteratorV2("", extraParams, 50, 100, null); + searchIteratorV2WithTemplate(80); // use external function to filter the result Function, List> externalFilterFunc = (List src) -> { diff --git a/sdk-core/src/main/java/io/milvus/orm/iterator/IteratorAdapterV2.java b/sdk-core/src/main/java/io/milvus/orm/iterator/IteratorAdapterV2.java index 952e334a6..05fe7ae23 100644 --- a/sdk-core/src/main/java/io/milvus/orm/iterator/IteratorAdapterV2.java +++ b/sdk-core/src/main/java/io/milvus/orm/iterator/IteratorAdapterV2.java @@ -27,11 +27,13 @@ import io.milvus.param.collection.FieldType; import io.milvus.param.dml.QueryIteratorParam; import io.milvus.param.dml.SearchIteratorParam; +import io.milvus.v2.common.ConsistencyLevel; import io.milvus.v2.common.IndexParam; import io.milvus.v2.service.collection.request.CreateCollectionReq; import io.milvus.v2.service.vector.request.QueryIteratorReq; import io.milvus.v2.service.vector.request.SearchIteratorReq; -import io.milvus.v2.service.vector.request.data.BaseVector; +import io.milvus.v2.service.vector.request.data.*; +import org.apache.commons.lang3.StringUtils; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -59,6 +61,27 @@ public static QueryIteratorParam convertV2Req(QueryIteratorReq queryIteratorReq) return builder.build(); } + public static QueryIteratorReq convertV1Param(QueryIteratorParam param) { + ConsistencyLevel level = null; + if (param.getConsistencyLevel() != null) { + level = ConsistencyLevel.valueOf(param.getConsistencyLevel().name()); + } + + return QueryIteratorReq.builder() + .databaseName(param.getDatabaseName()) + .collectionName(param.getCollectionName()) + .partitionNames(param.getPartitionNames()) + .expr(param.getExpr()) + .outputFields(param.getOutFields()) + .offset(param.getOffset()) + .limit(param.getLimit()) + .ignoreGrowing(param.isIgnoreGrowing()) + .batchSize(param.getBatchSize()) + .reduceStopForBest(param.isReduceStopForBest()) + .consistencyLevel(level) + .build(); + } + public static SearchIteratorParam convertV2Req(SearchIteratorReq searchIteratorReq) { MetricType metricType = MetricType.None; if (searchIteratorReq.getMetricType() != IndexParam.MetricType.INVALID) { @@ -130,6 +153,67 @@ public static SearchIteratorParam convertV2Req(SearchIteratorReq searchIteratorR return builder.build(); } + public static SearchIteratorReq convertV1Param(SearchIteratorParam param) { + ConsistencyLevel level = null; + if (param.getConsistencyLevel() != null) { + level = ConsistencyLevel.valueOf(param.getConsistencyLevel().name()); + } + + IndexParam.MetricType metricType = IndexParam.MetricType.INVALID; + if (StringUtils.isEmpty(param.getMetricType())) { + metricType = IndexParam.MetricType.valueOf(param.getMetricType()); + } + + List vectors = new ArrayList<>(); + switch (param.getPlType()) { + case FloatVector: { + List> data = (List>) param.getVectors(); + data.forEach(vector -> vectors.add(new FloatVec(vector))); + break; + } + case BinaryVector: { + List data = (List) param.getVectors(); + data.forEach(vector -> vectors.add(new BinaryVec(vector))); + break; + } + case Float16Vector: { + List data = (List) param.getVectors(); + data.forEach(vector -> vectors.add(new Float16Vec(vector))); + break; + } + case BFloat16Vector: { + List data = (List) param.getVectors(); + data.forEach(vector -> vectors.add(new BFloat16Vec(vector))); + break; + } + case SparseFloatVector: { + List> data = (List>) param.getVectors(); + data.forEach(vector -> vectors.add(new SparseFloatVec(vector))); + break; + } + default: + throw new ParamException("Unsupported vector type."); + } + + return SearchIteratorReq.builder() + .databaseName(param.getDatabaseName()) + .collectionName(param.getCollectionName()) + .partitionNames(param.getPartitionNames()) + .vectorFieldName(param.getVectorFieldName()) + .vectors(vectors) + .consistencyLevel(level) + .metricType(metricType) + .limit(param.getTopK()) + .expr(param.getExpr()) + .outputFields(param.getOutFields()) + .roundDecimal(param.getRoundDecimal()) + .params(param.getParams()) + .groupByFieldName(param.getGroupByFieldName()) + .ignoreGrowing(param.isIgnoreGrowing()) + .batchSize(param.getBatchSize()) + .build(); + } + public static FieldType convertV2Field(CreateCollectionReq.FieldSchema schema) { FieldType.Builder builder = FieldType.newBuilder() .withName(schema.getName()) diff --git a/sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java b/sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java index a62fe7259..f95df2998 100644 --- a/sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java +++ b/sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java @@ -21,14 +21,14 @@ import io.milvus.grpc.*; import io.milvus.param.Constant; -import io.milvus.param.ParamUtils; import io.milvus.param.collection.FieldType; import io.milvus.param.dml.QueryIteratorParam; -import io.milvus.param.dml.QueryParam; import io.milvus.response.QueryResultsWrapper; import io.milvus.v2.service.collection.request.CreateCollectionReq; import io.milvus.v2.service.vector.request.QueryIteratorReq; +import io.milvus.v2.service.vector.request.QueryReq; import io.milvus.v2.utils.RpcUtils; +import io.milvus.v2.utils.VectorUtils; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,7 +44,7 @@ public class QueryIterator { private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub; private final FieldType primaryField; - private final QueryIteratorParam queryIteratorParam; + private final QueryIteratorReq queryIteratorReq; private final int batchSize; private final long limit; private final String expr; @@ -61,7 +61,7 @@ public QueryIterator(QueryIteratorParam queryIteratorParam, this.iteratorCache = new IteratorCache(); this.blockingStub = blockingStub; this.primaryField = primaryField; - this.queryIteratorParam = queryIteratorParam; + this.queryIteratorReq = IteratorAdapterV2.convertV1Param(queryIteratorParam); this.batchSize = (int) queryIteratorParam.getBatchSize(); this.expr = queryIteratorParam.getExpr(); @@ -78,15 +78,14 @@ public QueryIterator(QueryIteratorReq queryIteratorReq, CreateCollectionReq.FieldSchema primaryField) { this.iteratorCache = new IteratorCache(); this.blockingStub = blockingStub; - IteratorAdapterV2 adapter = new IteratorAdapterV2(); - this.queryIteratorParam = adapter.convertV2Req(queryIteratorReq); - this.primaryField = adapter.convertV2Field(primaryField); + this.queryIteratorReq = queryIteratorReq; + this.primaryField = IteratorAdapterV2.convertV2Field(primaryField); - this.batchSize = (int) queryIteratorParam.getBatchSize(); - this.expr = queryIteratorParam.getExpr(); - this.limit = queryIteratorParam.getLimit(); - this.offset = queryIteratorParam.getOffset(); + this.batchSize = (int) queryIteratorReq.getBatchSize(); + this.expr = queryIteratorReq.getExpr(); + this.limit = queryIteratorReq.getLimit(); + this.offset = queryIteratorReq.getOffset(); this.rpcUtils = new RpcUtils(); setupTsByRequest(); @@ -208,24 +207,27 @@ private boolean isResSufficient(List ret) { private QueryResults executeQuery(String expr, long offset, long limit, long ts, boolean isSeek) { // for seeking offset, no need to return output fields List outputFields = new ArrayList<>(); - boolean reduceStopForBest = queryIteratorParam.isReduceStopForBest(); + boolean reduceStopForBest = queryIteratorReq.isReduceStopForBest(); if (!isSeek) { - outputFields = queryIteratorParam.getOutFields(); + outputFields = queryIteratorReq.getOutputFields(); reduceStopForBest = false; } - QueryParam queryParam = QueryParam.newBuilder() - .withDatabaseName(queryIteratorParam.getDatabaseName()) - .withCollectionName(queryIteratorParam.getCollectionName()) - .withConsistencyLevel(queryIteratorParam.getConsistencyLevel()) - .withPartitionNames(queryIteratorParam.getPartitionNames()) - .withOutFields(outputFields) - .withExpr(expr) - .withOffset(offset) - .withLimit(limit) - .withIgnoreGrowing(queryIteratorParam.isIgnoreGrowing()) + QueryReq queryReq = QueryReq.builder() + .databaseName(queryIteratorReq.getDatabaseName()) + .collectionName(queryIteratorReq.getCollectionName()) + .partitionNames(queryIteratorReq.getPartitionNames()) + .consistencyLevel(queryIteratorReq.getConsistencyLevel()) + .outputFields(outputFields) + .filter(expr) + .offset(offset) + .limit(limit) + .ignoreGrowing(queryIteratorReq.isIgnoreGrowing()) + .timezone(queryIteratorReq.getTimezone()) + .filterTemplateValues(queryIteratorReq.getFilterTemplateValues()) .build(); - QueryRequest queryRequest = ParamUtils.convertQueryParam(queryParam); + VectorUtils vectorUtils = new VectorUtils(); + QueryRequest queryRequest = vectorUtils.ConvertToGrpcQueryRequest(queryReq); QueryRequest.Builder builder = queryRequest.toBuilder(); // reduce stop for best builder.addQueryParams(KeyValuePair.newBuilder() @@ -246,7 +248,7 @@ private QueryResults executeQuery(String expr, long offset, long limit, long ts, builder.setUseDefaultConsistency(true); QueryResults response = rpcUtils.retry(() -> blockingStub.query(builder.build())); - String title = String.format("QueryRequest collectionName:%s", queryIteratorParam.getCollectionName()); + String title = String.format("QueryRequest collectionName:%s", queryIteratorReq.getCollectionName()); rpcUtils.handleResponse(title, response.getStatus()); return response; } diff --git a/sdk-core/src/main/java/io/milvus/orm/iterator/SearchIteratorV2.java b/sdk-core/src/main/java/io/milvus/orm/iterator/SearchIteratorV2.java index 7af337d78..b925d82c4 100644 --- a/sdk-core/src/main/java/io/milvus/orm/iterator/SearchIteratorV2.java +++ b/sdk-core/src/main/java/io/milvus/orm/iterator/SearchIteratorV2.java @@ -42,7 +42,7 @@ import static io.milvus.param.Constant.UNLIMITED; public class SearchIteratorV2 { - private static final Logger logger = LoggerFactory.getLogger(SearchIterator.class); + private static final Logger logger = LoggerFactory.getLogger(SearchIteratorV2.class); private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub; private final SearchIteratorReqV2 searchIteratorReq; @@ -86,12 +86,12 @@ private void checkParams() { int rows = searchIteratorReq.getVectors().size(); if (rows > 1) { ExceptionUtils.throwUnExpectedException("SearchIterator does not support processing multiple vectors simultaneously"); - } else if (rows <= 0) { + } else if (rows == 0) { ExceptionUtils.throwUnExpectedException("The vector data for search cannot be empty"); } - if (searchIteratorReq.getTopK() != UNLIMITED) { - this.leftResCnt = searchIteratorReq.getTopK(); + if (searchIteratorReq.getLimit() != UNLIMITED) { + this.leftResCnt = (int) searchIteratorReq.getLimit(); } } @@ -117,15 +117,17 @@ private SearchResults executeSearch(int limit) { .databaseName(searchIteratorReq.getDatabaseName()) .annsField(searchIteratorReq.getVectorFieldName()) .data(searchIteratorReq.getVectors()) - .topK(limit) + .limit(limit) .filter(searchIteratorReq.getFilter()) .consistencyLevel(searchIteratorReq.getConsistencyLevel()) .outputFields(searchIteratorReq.getOutputFields()) .roundDecimal(searchIteratorReq.getRoundDecimal()) .searchParams(searchParams) .metricType(searchIteratorReq.getMetricType()) + .timezone(searchIteratorReq.getTimezone()) .ignoreGrowing(searchIteratorReq.isIgnoreGrowing()) .groupByFieldName(searchIteratorReq.getGroupByFieldName()) + .filterTemplateValues(searchIteratorReq.getFilterTemplateValues()) .build(); SearchRequest searchRequest = new VectorUtils().ConvertToGrpcSearchRequest(request); SearchResults response = rpcUtils.retry(() -> this.blockingStub.search(searchRequest)); diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/AnnSearchReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/AnnSearchReq.java index 5abc123f0..b4f0eede4 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/AnnSearchReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/AnnSearchReq.java @@ -22,7 +22,9 @@ import io.milvus.v2.common.IndexParam; import io.milvus.v2.service.vector.request.data.BaseVector; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class AnnSearchReq { private String vectorFieldName; @@ -37,6 +39,16 @@ public class AnnSearchReq { private IndexParam.MetricType metricType; private String timezone; + // Expression template, to improve expression parsing performance in complicated list + // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......] + // The long list of city will increase the time cost to parse this expression. + // So, we provide exprTemplateValues for this purpose, user can set filter like this: + // filter = "pk > {age} and city in {city}" + // filterTemplateValues = Map{"age": 3, "city": List{"beijing", "shanghai", ......}} + // Valid value of this map can be: + // Boolean, Long, Double, String, List, List, List, List + private Map filterTemplateValues; + private AnnSearchReq(AnnSearchReqBuilder builder) { this.vectorFieldName = builder.vectorFieldName; this.topK = builder.topK; @@ -47,6 +59,7 @@ private AnnSearchReq(AnnSearchReqBuilder builder) { this.params = builder.params; this.metricType = builder.metricType; this.timezone = builder.timezone; + this.filterTemplateValues = builder.filterTemplateValues; } public static AnnSearchReqBuilder builder() { @@ -129,6 +142,10 @@ public String getTimezone() { return timezone; } + public Map getFilterTemplateValues() { + return filterTemplateValues; + } + @Override public String toString() { return "AnnSearchReq{" + @@ -141,6 +158,7 @@ public String toString() { ", params='" + params + '\'' + ", metricType=" + metricType + ", timezone='" + timezone + '\'' + +// ", filterTemplateValues=" + filterTemplateValues + '}'; } @@ -154,6 +172,7 @@ public static class AnnSearchReqBuilder { private String params; private IndexParam.MetricType metricType = null; private String timezone = ""; + private Map filterTemplateValues = new HashMap<>(); public AnnSearchReqBuilder vectorFieldName(String vectorFieldName) { this.vectorFieldName = vectorFieldName; @@ -208,6 +227,11 @@ public AnnSearchReqBuilder timezone(String timezone) { return this; } + public AnnSearchReqBuilder filterTemplateValues(Map filterTemplateValues) { + this.filterTemplateValues = filterTemplateValues; + return this; + } + public AnnSearchReq build() { return new AnnSearchReq(this); } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/DeleteReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/DeleteReq.java index 3a70876b2..479089423 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/DeleteReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/DeleteReq.java @@ -109,7 +109,7 @@ public String toString() { ", partitionName='" + partitionName + '\'' + ", filter='" + filter + '\'' + ", ids=" + ids + - ", filterTemplateValues=" + filterTemplateValues + +// ", filterTemplateValues=" + filterTemplateValues + '}'; } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/QueryIteratorReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/QueryIteratorReq.java index fb57d356f..d9c583db2 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/QueryIteratorReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/QueryIteratorReq.java @@ -3,7 +3,9 @@ import com.google.common.collect.Lists; import io.milvus.v2.common.ConsistencyLevel; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class QueryIteratorReq { private String databaseName; @@ -15,9 +17,20 @@ public class QueryIteratorReq { private long offset; private long limit; private boolean ignoreGrowing; + private String timezone; private long batchSize; private boolean reduceStopForBest; + // Expression template, to improve expression parsing performance in complicated list + // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......] + // The long list of city will increase the time cost to parse this expression. + // So, we provide exprTemplateValues for this purpose, user can set filter like this: + // filter = "pk > {age} and city in {city}" + // filterTemplateValues = Map{"age": 3, "city": List{"beijing", "shanghai", ......}} + // Valid value of this map can be: + // Boolean, Long, Double, String, List, List, List, List + private Map filterTemplateValues; + private QueryIteratorReq(QueryIteratorReqBuilder builder) { this.databaseName = builder.databaseName; this.collectionName = builder.collectionName; @@ -28,8 +41,10 @@ private QueryIteratorReq(QueryIteratorReqBuilder builder) { this.offset = builder.offset; this.limit = builder.limit; this.ignoreGrowing = builder.ignoreGrowing; + this.timezone = builder.timezone; this.batchSize = builder.batchSize; this.reduceStopForBest = builder.reduceStopForBest; + this.filterTemplateValues = builder.filterTemplateValues; } public static QueryIteratorReqBuilder builder() { @@ -108,6 +123,10 @@ public void setIgnoreGrowing(boolean ignoreGrowing) { this.ignoreGrowing = ignoreGrowing; } + public String getTimezone() { + return timezone; + } + public long getBatchSize() { return batchSize; } @@ -124,6 +143,10 @@ public void setReduceStopForBest(boolean reduceStopForBest) { this.reduceStopForBest = reduceStopForBest; } + public Map getFilterTemplateValues() { + return filterTemplateValues; + } + @Override public String toString() { return "QueryIteratorReq{" + @@ -136,6 +159,7 @@ public String toString() { ", offset=" + offset + ", limit=" + limit + ", ignoreGrowing=" + ignoreGrowing + + ", timezone='" + timezone + '\'' + ", batchSize=" + batchSize + ", reduceStopForBest=" + reduceStopForBest + '}'; @@ -151,8 +175,10 @@ public static class QueryIteratorReqBuilder { private long offset = 0; private long limit = -1; private boolean ignoreGrowing = false; + private String timezone = ""; private long batchSize = 1000L; private boolean reduceStopForBest = false; + private Map filterTemplateValues = new HashMap<>(); public QueryIteratorReqBuilder databaseName(String databaseName) { this.databaseName = databaseName; @@ -199,6 +225,11 @@ public QueryIteratorReqBuilder ignoreGrowing(boolean ignoreGrowing) { return this; } + public QueryIteratorReqBuilder timezone(String timezone) { + this.timezone = timezone; + return this; + } + public QueryIteratorReqBuilder batchSize(long batchSize) { this.batchSize = batchSize; return this; @@ -209,6 +240,11 @@ public QueryIteratorReqBuilder reduceStopForBest(boolean reduceStopForBest) { return this; } + public QueryIteratorReqBuilder filterTemplateValues(Map filterTemplateValues) { + this.filterTemplateValues = filterTemplateValues; + return this; + } + public QueryIteratorReq build() { return new QueryIteratorReq(this); } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/QueryReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/QueryReq.java index 547b16bcc..5c64ecf99 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/QueryReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/QueryReq.java @@ -186,7 +186,7 @@ public String toString() { ", ignoreGrowing=" + ignoreGrowing + ", timezone='" + timezone + '\'' + ", queryParams=" + queryParams + - ", filterTemplateValues=" + filterTemplateValues + +// ", filterTemplateValues=" + filterTemplateValues + '}'; } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchIteratorReqV2.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchIteratorReqV2.java index c9aa4985e..126ecf790 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchIteratorReqV2.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchIteratorReqV2.java @@ -28,10 +28,21 @@ public class SearchIteratorReqV2 { private Map searchParams; private ConsistencyLevel consistencyLevel; private boolean ignoreGrowing; + private String timezone; private String groupByFieldName; private long batchSize; private Function, List> externalFilterFunc; + // Expression template, to improve expression parsing performance in complicated list + // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......] + // The long list of city will increase the time cost to parse this expression. + // So, we provide exprTemplateValues for this purpose, user can set filter like this: + // filter = "pk > {age} and city in {city}" + // filterTemplateValues = Map{"age": 3, "city": List{"beijing", "shanghai", ......}} + // Valid value of this map can be: + // Boolean, Long, Double, String, List, List, List, List + private Map filterTemplateValues; + private SearchIteratorReqV2(SearchIteratorReqV2Builder builder) { this.databaseName = builder.databaseName; this.collectionName = builder.collectionName; @@ -47,9 +58,11 @@ private SearchIteratorReqV2(SearchIteratorReqV2Builder builder) { this.searchParams = builder.searchParams; this.consistencyLevel = builder.consistencyLevel; this.ignoreGrowing = builder.ignoreGrowing; + this.timezone = builder.timezone; this.groupByFieldName = builder.groupByFieldName; this.batchSize = builder.batchSize; this.externalFilterFunc = builder.externalFilterFunc; + this.filterTemplateValues = builder.filterTemplateValues; } public static SearchIteratorReqV2Builder builder() { @@ -172,6 +185,10 @@ public void setIgnoreGrowing(boolean ignoreGrowing) { this.ignoreGrowing = ignoreGrowing; } + public String getTimezone() { + return timezone; + } + public String getGroupByFieldName() { return groupByFieldName; } @@ -196,6 +213,10 @@ public void setExternalFilterFunc(Function, List getFilterTemplateValues() { + return filterTemplateValues; + } + @Override public String toString() { return "SearchIteratorReqV2{" + @@ -213,6 +234,7 @@ public String toString() { ", searchParams=" + searchParams + ", consistencyLevel=" + consistencyLevel + ", ignoreGrowing=" + ignoreGrowing + + ", timezone='" + timezone + '\'' + ", groupByFieldName='" + groupByFieldName + '\'' + ", batchSize=" + batchSize + ", externalFilterFunc=" + externalFilterFunc + @@ -234,9 +256,11 @@ public static class SearchIteratorReqV2Builder { private Map searchParams = new HashMap<>(); private ConsistencyLevel consistencyLevel = null; private boolean ignoreGrowing = false; + private String timezone = ""; private String groupByFieldName = ""; private long batchSize = 1000L; private Function, List> externalFilterFunc = null; + private Map filterTemplateValues = new HashMap<>(); public SearchIteratorReqV2Builder databaseName(String databaseName) { this.databaseName = databaseName; @@ -312,6 +336,11 @@ public SearchIteratorReqV2Builder ignoreGrowing(boolean ignoreGrowing) { return this; } + public SearchIteratorReqV2Builder timezone(String timezone) { + this.timezone = timezone; + return this; + } + public SearchIteratorReqV2Builder groupByFieldName(String groupByFieldName) { this.groupByFieldName = groupByFieldName; return this; @@ -327,6 +356,11 @@ public SearchIteratorReqV2Builder externalFilterFunc(Function filterTemplateValues) { + this.filterTemplateValues = filterTemplateValues; + return this; + } + public SearchIteratorReqV2 build() { return new SearchIteratorReqV2(this); } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java index 514d10ced..e240f9b81 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java @@ -66,7 +66,6 @@ public class SearchReq { // filterTemplateValues = Map{"age": 3, "city": List{"beijing", "shanghai", ......}} // Valid value of this map can be: // Boolean, Long, Double, String, List, List, List, List - private Map filterTemplateValues; private SearchReq(SearchReqBuilder builder) { @@ -315,7 +314,7 @@ public String toString() { ", strictGroupSize=" + strictGroupSize + ", ranker=" + ranker + ", functionScore=" + functionScore + - ", filterTemplateValues=" + filterTemplateValues + +// ", filterTemplateValues=" + filterTemplateValues + '}'; } diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java index f0d71adae..361d4ec49 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java @@ -501,6 +501,10 @@ public static SearchRequest convertAnnSearchParam(AnnSearchReq annSearchReq, builder.setDslType(DslType.BoolExprV1); if (annSearchReq.getExpr() != null && !annSearchReq.getExpr().isEmpty()) { builder.setDsl(annSearchReq.getExpr()); + Map filterTemplateValues = annSearchReq.getFilterTemplateValues(); + filterTemplateValues.forEach((key, value) -> { + builder.putExprTemplateValues(key, deduceAndCreateTemplateValue(value)); + }); } if (consistencyLevel == null) {