Skip to content

Commit e1c6744

Browse files
committed
Add timezone/filterTemplate for QueryIterator/SearchIterator
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent b776006 commit e1c6744

10 files changed

Lines changed: 220 additions & 35 deletions

File tree

sdk-core/src/main/java/io/milvus/orm/iterator/IteratorAdapterV2.java

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
import io.milvus.param.collection.FieldType;
2828
import io.milvus.param.dml.QueryIteratorParam;
2929
import io.milvus.param.dml.SearchIteratorParam;
30+
import io.milvus.v2.common.ConsistencyLevel;
3031
import io.milvus.v2.common.IndexParam;
3132
import io.milvus.v2.service.collection.request.CreateCollectionReq;
3233
import io.milvus.v2.service.vector.request.QueryIteratorReq;
3334
import io.milvus.v2.service.vector.request.SearchIteratorReq;
34-
import io.milvus.v2.service.vector.request.data.BaseVector;
35+
import io.milvus.v2.service.vector.request.data.*;
36+
import org.apache.commons.lang3.StringUtils;
3537

3638
import java.nio.ByteBuffer;
3739
import java.util.ArrayList;
@@ -59,6 +61,27 @@ public static QueryIteratorParam convertV2Req(QueryIteratorReq queryIteratorReq)
5961
return builder.build();
6062
}
6163

64+
public static QueryIteratorReq convertV1Param(QueryIteratorParam param) {
65+
ConsistencyLevel level = null;
66+
if (param.getConsistencyLevel() != null) {
67+
level = ConsistencyLevel.valueOf(param.getConsistencyLevel().name());
68+
}
69+
70+
return QueryIteratorReq.builder()
71+
.databaseName(param.getDatabaseName())
72+
.collectionName(param.getCollectionName())
73+
.partitionNames(param.getPartitionNames())
74+
.expr(param.getExpr())
75+
.outputFields(param.getOutFields())
76+
.offset(param.getOffset())
77+
.limit(param.getLimit())
78+
.ignoreGrowing(param.isIgnoreGrowing())
79+
.batchSize(param.getBatchSize())
80+
.reduceStopForBest(param.isReduceStopForBest())
81+
.consistencyLevel(level)
82+
.build();
83+
}
84+
6285
public static SearchIteratorParam convertV2Req(SearchIteratorReq searchIteratorReq) {
6386
MetricType metricType = MetricType.None;
6487
if (searchIteratorReq.getMetricType() != IndexParam.MetricType.INVALID) {
@@ -130,6 +153,67 @@ public static SearchIteratorParam convertV2Req(SearchIteratorReq searchIteratorR
130153
return builder.build();
131154
}
132155

156+
public static SearchIteratorReq convertV1Param(SearchIteratorParam param) {
157+
ConsistencyLevel level = null;
158+
if (param.getConsistencyLevel() != null) {
159+
level = ConsistencyLevel.valueOf(param.getConsistencyLevel().name());
160+
}
161+
162+
IndexParam.MetricType metricType = IndexParam.MetricType.INVALID;
163+
if (StringUtils.isEmpty(param.getMetricType())) {
164+
metricType = IndexParam.MetricType.valueOf(param.getMetricType());
165+
}
166+
167+
List<BaseVector> vectors = new ArrayList<>();
168+
switch (param.getPlType()) {
169+
case FloatVector: {
170+
List<List<Float>> data = (List<List<Float>>) param.getVectors();
171+
data.forEach(vector -> vectors.add(new FloatVec(vector)));
172+
break;
173+
}
174+
case BinaryVector: {
175+
List<ByteBuffer> data = (List<ByteBuffer>) param.getVectors();
176+
data.forEach(vector -> vectors.add(new BinaryVec(vector)));
177+
break;
178+
}
179+
case Float16Vector: {
180+
List<ByteBuffer> data = (List<ByteBuffer>) param.getVectors();
181+
data.forEach(vector -> vectors.add(new Float16Vec(vector)));
182+
break;
183+
}
184+
case BFloat16Vector: {
185+
List<ByteBuffer> data = (List<ByteBuffer>) param.getVectors();
186+
data.forEach(vector -> vectors.add(new BFloat16Vec(vector)));
187+
break;
188+
}
189+
case SparseFloatVector: {
190+
List<SortedMap<Long, Float>> data = (List<SortedMap<Long, Float>>) param.getVectors();
191+
data.forEach(vector -> vectors.add(new SparseFloatVec(vector)));
192+
break;
193+
}
194+
default:
195+
throw new ParamException("Unsupported vector type.");
196+
}
197+
198+
return SearchIteratorReq.builder()
199+
.databaseName(param.getDatabaseName())
200+
.collectionName(param.getCollectionName())
201+
.partitionNames(param.getPartitionNames())
202+
.vectorFieldName(param.getVectorFieldName())
203+
.vectors(vectors)
204+
.consistencyLevel(level)
205+
.metricType(metricType)
206+
.limit(param.getTopK())
207+
.expr(param.getExpr())
208+
.outputFields(param.getOutFields())
209+
.roundDecimal(param.getRoundDecimal())
210+
.params(param.getParams())
211+
.groupByFieldName(param.getGroupByFieldName())
212+
.ignoreGrowing(param.isIgnoreGrowing())
213+
.batchSize(param.getBatchSize())
214+
.build();
215+
}
216+
133217
public static FieldType convertV2Field(CreateCollectionReq.FieldSchema schema) {
134218
FieldType.Builder builder = FieldType.newBuilder()
135219
.withName(schema.getName())

sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121

2222
import io.milvus.grpc.*;
2323
import io.milvus.param.Constant;
24-
import io.milvus.param.ParamUtils;
2524
import io.milvus.param.collection.FieldType;
2625
import io.milvus.param.dml.QueryIteratorParam;
27-
import io.milvus.param.dml.QueryParam;
2826
import io.milvus.response.QueryResultsWrapper;
2927
import io.milvus.v2.service.collection.request.CreateCollectionReq;
3028
import io.milvus.v2.service.vector.request.QueryIteratorReq;
29+
import io.milvus.v2.service.vector.request.QueryReq;
3130
import io.milvus.v2.utils.RpcUtils;
31+
import io.milvus.v2.utils.VectorUtils;
3232
import org.apache.commons.lang3.StringUtils;
3333
import org.slf4j.Logger;
3434
import org.slf4j.LoggerFactory;
@@ -44,7 +44,7 @@ public class QueryIterator {
4444
private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
4545
private final FieldType primaryField;
4646

47-
private final QueryIteratorParam queryIteratorParam;
47+
private final QueryIteratorReq queryIteratorReq;
4848
private final int batchSize;
4949
private final long limit;
5050
private final String expr;
@@ -61,7 +61,7 @@ public QueryIterator(QueryIteratorParam queryIteratorParam,
6161
this.iteratorCache = new IteratorCache();
6262
this.blockingStub = blockingStub;
6363
this.primaryField = primaryField;
64-
this.queryIteratorParam = queryIteratorParam;
64+
this.queryIteratorReq = IteratorAdapterV2.convertV1Param(queryIteratorParam);
6565

6666
this.batchSize = (int) queryIteratorParam.getBatchSize();
6767
this.expr = queryIteratorParam.getExpr();
@@ -78,15 +78,14 @@ public QueryIterator(QueryIteratorReq queryIteratorReq,
7878
CreateCollectionReq.FieldSchema primaryField) {
7979
this.iteratorCache = new IteratorCache();
8080
this.blockingStub = blockingStub;
81-
IteratorAdapterV2 adapter = new IteratorAdapterV2();
82-
this.queryIteratorParam = adapter.convertV2Req(queryIteratorReq);
83-
this.primaryField = adapter.convertV2Field(primaryField);
81+
this.queryIteratorReq = queryIteratorReq;
82+
this.primaryField = IteratorAdapterV2.convertV2Field(primaryField);
8483

8584

86-
this.batchSize = (int) queryIteratorParam.getBatchSize();
87-
this.expr = queryIteratorParam.getExpr();
88-
this.limit = queryIteratorParam.getLimit();
89-
this.offset = queryIteratorParam.getOffset();
85+
this.batchSize = (int) queryIteratorReq.getBatchSize();
86+
this.expr = queryIteratorReq.getExpr();
87+
this.limit = queryIteratorReq.getLimit();
88+
this.offset = queryIteratorReq.getOffset();
9089
this.rpcUtils = new RpcUtils();
9190

9291
setupTsByRequest();
@@ -208,24 +207,27 @@ private boolean isResSufficient(List<QueryResultsWrapper.RowRecord> ret) {
208207
private QueryResults executeQuery(String expr, long offset, long limit, long ts, boolean isSeek) {
209208
// for seeking offset, no need to return output fields
210209
List<String> outputFields = new ArrayList<>();
211-
boolean reduceStopForBest = queryIteratorParam.isReduceStopForBest();
210+
boolean reduceStopForBest = queryIteratorReq.isReduceStopForBest();
212211
if (!isSeek) {
213-
outputFields = queryIteratorParam.getOutFields();
212+
outputFields = queryIteratorReq.getOutputFields();
214213
reduceStopForBest = false;
215214
}
216-
QueryParam queryParam = QueryParam.newBuilder()
217-
.withDatabaseName(queryIteratorParam.getDatabaseName())
218-
.withCollectionName(queryIteratorParam.getCollectionName())
219-
.withConsistencyLevel(queryIteratorParam.getConsistencyLevel())
220-
.withPartitionNames(queryIteratorParam.getPartitionNames())
221-
.withOutFields(outputFields)
222-
.withExpr(expr)
223-
.withOffset(offset)
224-
.withLimit(limit)
225-
.withIgnoreGrowing(queryIteratorParam.isIgnoreGrowing())
215+
QueryReq queryReq = QueryReq.builder()
216+
.databaseName(queryIteratorReq.getDatabaseName())
217+
.collectionName(queryIteratorReq.getCollectionName())
218+
.partitionNames(queryIteratorReq.getPartitionNames())
219+
.consistencyLevel(queryIteratorReq.getConsistencyLevel())
220+
.outputFields(outputFields)
221+
.filter(expr)
222+
.offset(offset)
223+
.limit(limit)
224+
.ignoreGrowing(queryIteratorReq.isIgnoreGrowing())
225+
.timezone(queryIteratorReq.getTimezone())
226+
.filterTemplateValues(queryIteratorReq.getFilterTemplateValues())
226227
.build();
227228

228-
QueryRequest queryRequest = ParamUtils.convertQueryParam(queryParam);
229+
VectorUtils vectorUtils = new VectorUtils();
230+
QueryRequest queryRequest = vectorUtils.ConvertToGrpcQueryRequest(queryReq);
229231
QueryRequest.Builder builder = queryRequest.toBuilder();
230232
// reduce stop for best
231233
builder.addQueryParams(KeyValuePair.newBuilder()
@@ -246,7 +248,7 @@ private QueryResults executeQuery(String expr, long offset, long limit, long ts,
246248
builder.setUseDefaultConsistency(true);
247249

248250
QueryResults response = rpcUtils.retry(() -> blockingStub.query(builder.build()));
249-
String title = String.format("QueryRequest collectionName:%s", queryIteratorParam.getCollectionName());
251+
String title = String.format("QueryRequest collectionName:%s", queryIteratorReq.getCollectionName());
250252
rpcUtils.handleResponse(title, response.getStatus());
251253
return response;
252254
}

sdk-core/src/main/java/io/milvus/orm/iterator/SearchIteratorV2.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import static io.milvus.param.Constant.UNLIMITED;
4343

4444
public class SearchIteratorV2 {
45-
private static final Logger logger = LoggerFactory.getLogger(SearchIterator.class);
45+
private static final Logger logger = LoggerFactory.getLogger(SearchIteratorV2.class);
4646
private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
4747

4848
private final SearchIteratorReqV2 searchIteratorReq;
@@ -86,12 +86,12 @@ private void checkParams() {
8686
int rows = searchIteratorReq.getVectors().size();
8787
if (rows > 1) {
8888
ExceptionUtils.throwUnExpectedException("SearchIterator does not support processing multiple vectors simultaneously");
89-
} else if (rows <= 0) {
89+
} else if (rows == 0) {
9090
ExceptionUtils.throwUnExpectedException("The vector data for search cannot be empty");
9191
}
9292

93-
if (searchIteratorReq.getTopK() != UNLIMITED) {
94-
this.leftResCnt = searchIteratorReq.getTopK();
93+
if (searchIteratorReq.getLimit() != UNLIMITED) {
94+
this.leftResCnt = (int) searchIteratorReq.getLimit();
9595
}
9696
}
9797

@@ -117,15 +117,17 @@ private SearchResults executeSearch(int limit) {
117117
.databaseName(searchIteratorReq.getDatabaseName())
118118
.annsField(searchIteratorReq.getVectorFieldName())
119119
.data(searchIteratorReq.getVectors())
120-
.topK(limit)
120+
.limit(limit)
121121
.filter(searchIteratorReq.getFilter())
122122
.consistencyLevel(searchIteratorReq.getConsistencyLevel())
123123
.outputFields(searchIteratorReq.getOutputFields())
124124
.roundDecimal(searchIteratorReq.getRoundDecimal())
125125
.searchParams(searchParams)
126126
.metricType(searchIteratorReq.getMetricType())
127+
.timezone(searchIteratorReq.getTimezone())
127128
.ignoreGrowing(searchIteratorReq.isIgnoreGrowing())
128129
.groupByFieldName(searchIteratorReq.getGroupByFieldName())
130+
.filterTemplateValues(searchIteratorReq.getFilterTemplateValues())
129131
.build();
130132
SearchRequest searchRequest = new VectorUtils().ConvertToGrpcSearchRequest(request);
131133
SearchResults response = rpcUtils.retry(() -> this.blockingStub.search(searchRequest));

sdk-core/src/main/java/io/milvus/v2/service/vector/request/AnnSearchReq.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
import io.milvus.v2.common.IndexParam;
2323
import io.milvus.v2.service.vector.request.data.BaseVector;
2424

25+
import java.util.HashMap;
2526
import java.util.List;
27+
import java.util.Map;
2628

2729
public class AnnSearchReq {
2830
private String vectorFieldName;
@@ -37,6 +39,16 @@ public class AnnSearchReq {
3739
private IndexParam.MetricType metricType;
3840
private String timezone;
3941

42+
// Expression template, to improve expression parsing performance in complicated list
43+
// Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......]
44+
// The long list of city will increase the time cost to parse this expression.
45+
// So, we provide exprTemplateValues for this purpose, user can set filter like this:
46+
// filter = "pk > {age} and city in {city}"
47+
// filterTemplateValues = Map{"age": 3, "city": List<String>{"beijing", "shanghai", ......}}
48+
// Valid value of this map can be:
49+
// Boolean, Long, Double, String, List<Boolean>, List<Long>, List<Double>, List<String>
50+
private Map<String, Object> filterTemplateValues;
51+
4052
private AnnSearchReq(AnnSearchReqBuilder builder) {
4153
this.vectorFieldName = builder.vectorFieldName;
4254
this.topK = builder.topK;
@@ -47,6 +59,7 @@ private AnnSearchReq(AnnSearchReqBuilder builder) {
4759
this.params = builder.params;
4860
this.metricType = builder.metricType;
4961
this.timezone = builder.timezone;
62+
this.filterTemplateValues = builder.filterTemplateValues;
5063
}
5164

5265
public static AnnSearchReqBuilder builder() {
@@ -129,6 +142,10 @@ public String getTimezone() {
129142
return timezone;
130143
}
131144

145+
public Map<String, Object> getFilterTemplateValues() {
146+
return filterTemplateValues;
147+
}
148+
132149
@Override
133150
public String toString() {
134151
return "AnnSearchReq{" +
@@ -141,6 +158,7 @@ public String toString() {
141158
", params='" + params + '\'' +
142159
", metricType=" + metricType +
143160
", timezone='" + timezone + '\'' +
161+
// ", filterTemplateValues=" + filterTemplateValues +
144162
'}';
145163
}
146164

@@ -154,6 +172,7 @@ public static class AnnSearchReqBuilder {
154172
private String params;
155173
private IndexParam.MetricType metricType = null;
156174
private String timezone = "";
175+
private Map<String, Object> filterTemplateValues = new HashMap<>();
157176

158177
public AnnSearchReqBuilder vectorFieldName(String vectorFieldName) {
159178
this.vectorFieldName = vectorFieldName;
@@ -208,6 +227,11 @@ public AnnSearchReqBuilder timezone(String timezone) {
208227
return this;
209228
}
210229

230+
public AnnSearchReqBuilder filterTemplateValues(Map<String, Object> filterTemplateValues) {
231+
this.filterTemplateValues = filterTemplateValues;
232+
return this;
233+
}
234+
211235
public AnnSearchReq build() {
212236
return new AnnSearchReq(this);
213237
}

sdk-core/src/main/java/io/milvus/v2/service/vector/request/DeleteReq.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ public String toString() {
109109
", partitionName='" + partitionName + '\'' +
110110
", filter='" + filter + '\'' +
111111
", ids=" + ids +
112-
", filterTemplateValues=" + filterTemplateValues +
112+
// ", filterTemplateValues=" + filterTemplateValues +
113113
'}';
114114
}
115115

0 commit comments

Comments
 (0)