Skip to content

Commit bc26791

Browse files
committed
SearchIteratorV2
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 67dca68 commit bc26791

7 files changed

Lines changed: 361 additions & 0 deletions

File tree

sdk-core/src/main/java/io/milvus/orm/iterator/IteratorAdapterV2.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
120
package io.milvus.orm.iterator;
221

322
import io.milvus.common.clientenum.ConsistencyLevelEnum;

sdk-core/src/main/java/io/milvus/orm/iterator/IteratorCache.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
120
package io.milvus.orm.iterator;
221

322
import io.milvus.response.QueryResultsWrapper;

sdk-core/src/main/java/io/milvus/orm/iterator/SearchIterator.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
120
package io.milvus.orm.iterator;
221

322

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package io.milvus.orm.iterator;
21+
22+
23+
import com.google.gson.reflect.TypeToken;
24+
import io.milvus.common.utils.ExceptionUtils;
25+
import io.milvus.common.utils.JsonUtils;
26+
import io.milvus.grpc.*;
27+
import io.milvus.param.Constant;
28+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
29+
import io.milvus.v2.service.collection.response.DescribeCollectionResp;
30+
import io.milvus.v2.service.vector.request.SearchIteratorReqV2;
31+
import io.milvus.v2.service.vector.request.SearchReq;
32+
import io.milvus.v2.service.vector.response.SearchResp;
33+
import io.milvus.v2.utils.ConvertUtils;
34+
import io.milvus.v2.utils.RpcUtils;
35+
import io.milvus.v2.utils.VectorUtils;
36+
import org.apache.commons.lang3.StringUtils;
37+
import org.slf4j.Logger;
38+
import org.slf4j.LoggerFactory;
39+
40+
import java.util.*;
41+
import java.util.concurrent.Callable;
42+
import java.util.function.Function;
43+
44+
import static io.milvus.param.Constant.MAX_BATCH_SIZE;
45+
import static io.milvus.param.Constant.UNLIMITED;
46+
47+
public class SearchIteratorV2 {
48+
private static final Logger logger = LoggerFactory.getLogger(SearchIterator.class);
49+
private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
50+
51+
private final SearchIteratorReqV2 searchIteratorReq;
52+
private final int batchSize;
53+
54+
private Map<String, Object> searchParams;
55+
private final RpcUtils rpcUtils;
56+
57+
private Integer leftResCnt = null;
58+
private Long collectionID = null;
59+
private Function<List<SearchResp.SearchResult>, List<SearchResp.SearchResult>> externalFilterFunc = null;
60+
private List<SearchResp.SearchResult> cache = new ArrayList<>();
61+
62+
// to support V2
63+
public SearchIteratorV2(SearchIteratorReqV2 searchIteratorReq,
64+
MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub) {
65+
this.blockingStub = blockingStub;
66+
this.searchIteratorReq = searchIteratorReq;
67+
68+
this.batchSize = (int) searchIteratorReq.getBatchSize();
69+
this.externalFilterFunc = searchIteratorReq.getExternalFilterFunc();
70+
this.rpcUtils = new RpcUtils();
71+
72+
checkParams();
73+
setupCollectionID();
74+
probeForCompability();
75+
}
76+
77+
private void checkParams() {
78+
if (this.batchSize < 0) {
79+
ExceptionUtils.throwUnExpectedException("Batch size cannot be less than zero");
80+
} else if (this.batchSize > MAX_BATCH_SIZE) {
81+
ExceptionUtils.throwUnExpectedException(String.format("Batch size cannot be larger than %d", MAX_BATCH_SIZE));
82+
}
83+
84+
searchParams = new HashMap<>();
85+
if (null != searchIteratorReq.getParams() && !searchIteratorReq.getParams().isEmpty()) {
86+
searchParams = JsonUtils.fromJson(searchIteratorReq.getParams(), new TypeToken<Map<String, Object>>() {}.getType());
87+
}
88+
89+
if (searchParams.containsKey(Constant.OFFSET) && (int)searchParams.get(Constant.OFFSET) > 0) {
90+
ExceptionUtils.throwUnExpectedException("Offset is not supported for SearchIterator");
91+
}
92+
93+
int rows = searchIteratorReq.getVectors().size();
94+
if (rows > 1) {
95+
ExceptionUtils.throwUnExpectedException("SearchIterator does not support processing multiple vectors simultaneously");
96+
} else if (rows <= 0) {
97+
ExceptionUtils.throwUnExpectedException("The vector data for search cannot be empty");
98+
}
99+
100+
if (searchIteratorReq.getTopK() != UNLIMITED) {
101+
this.leftResCnt = searchIteratorReq.getTopK();
102+
}
103+
}
104+
105+
private void setupCollectionID() {
106+
DescribeCollectionRequest.Builder builder = DescribeCollectionRequest.newBuilder()
107+
.setCollectionName(searchIteratorReq.getCollectionName());
108+
if (StringUtils.isNotEmpty(searchIteratorReq.getDatabaseName())) {
109+
builder.setDbName(searchIteratorReq.getDatabaseName());
110+
}
111+
DescribeCollectionResponse response = rpcUtils.retry(()->this.blockingStub.describeCollection(builder.build()));
112+
String title = String.format("DescribeCollectionRequest collectionName:%s", searchIteratorReq.getCollectionName());
113+
rpcUtils.handleResponse(title, response.getStatus());
114+
115+
DescribeCollectionResp respR = new ConvertUtils().convertDescCollectionResp(response);
116+
this.collectionID = respR.getCollectionID();
117+
}
118+
119+
private SearchResults executeSearch(int limit) {
120+
searchParams.put("search_iter_batch_size", limit);
121+
SearchReq request = SearchReq.builder()
122+
.collectionName(searchIteratorReq.getCollectionName())
123+
.partitionNames(searchIteratorReq.getPartitionNames())
124+
.databaseName(searchIteratorReq.getDatabaseName())
125+
.annsField(searchIteratorReq.getVectorFieldName())
126+
.searchParams(searchIteratorReq.getSearchParams())
127+
.data(searchIteratorReq.getVectors())
128+
.topK(limit)
129+
.filter(searchIteratorReq.getFilter())
130+
.consistencyLevel(searchIteratorReq.getConsistencyLevel())
131+
.outputFields(searchIteratorReq.getOutputFields())
132+
.roundDecimal(searchIteratorReq.getRoundDecimal())
133+
.searchParams(searchParams)
134+
.build();
135+
SearchRequest searchRequest = new VectorUtils().ConvertToGrpcSearchRequest(request);
136+
SearchResults response = rpcUtils.retry(()->this.blockingStub.search(searchRequest));
137+
String title = String.format("SearchRequest collectionName:%s", searchIteratorReq.getCollectionName());
138+
rpcUtils.handleResponse(title, response.getStatus());
139+
140+
return response;
141+
}
142+
143+
private void probeForCompability() {
144+
searchParams.put("collection_id", this.collectionID);
145+
searchParams.put("iterator", true);
146+
searchParams.put("search_iter_v2", true);
147+
searchParams.put("guarantee_timestamp", 0);
148+
149+
SearchResultData resultData = executeSearch(1).getResults();
150+
checkTokenExists(resultData);
151+
}
152+
153+
private void checkTokenExists(SearchResultData resultData) {
154+
String token = resultData.getSearchIteratorV2Results().getToken();
155+
if (StringUtils.isEmpty(token)) {
156+
ExceptionUtils.throwUnExpectedException("The server does not support Search Iterator V2." +
157+
" The search_iterator (v1) is used instead.\n" +
158+
" Please upgrade your Milvus server version to 2.5.2 and later,\n" +
159+
" or use a pymilvus version before 2.5.3 (excluded) to avoid this issue.");
160+
}
161+
}
162+
163+
public List<SearchResp.SearchResult> next() {
164+
if (leftResCnt != null && leftResCnt <= 0) {
165+
return null;
166+
}
167+
168+
if (externalFilterFunc == null) {
169+
return wrapReturnRes(_next());
170+
}
171+
172+
int targetLen = batchSize;
173+
if (leftResCnt != null && leftResCnt < targetLen) {
174+
targetLen = leftResCnt;
175+
}
176+
177+
while (true) {
178+
List<SearchResp.SearchResult> hits = _next();
179+
if (hits == null || hits.isEmpty()) {
180+
break;
181+
}
182+
183+
if (externalFilterFunc != null) {
184+
hits = externalFilterFunc.apply(hits);
185+
}
186+
187+
cache.addAll(hits);
188+
if (cache.size() >= targetLen) {
189+
break;
190+
}
191+
}
192+
193+
// create a list with elements from 0 to targetLen, and remove the elements from cache
194+
List<SearchResp.SearchResult> subList = cache.subList(0, targetLen);
195+
List<SearchResp.SearchResult> ret = new ArrayList<>(subList);
196+
subList.clear();
197+
return wrapReturnRes(ret);
198+
}
199+
200+
private List<SearchResp.SearchResult> _next() {
201+
SearchResults response = executeSearch(batchSize);
202+
checkTokenExists(response.getResults());
203+
SearchIteratorV2Results iterInfo = response.getResults().getSearchIteratorV2Results();
204+
searchParams.put("search_iter_last_bound", iterInfo.getLastBound());
205+
206+
if (!searchParams.containsKey("search_iter_id")) {
207+
searchParams.put("search_iter_id", iterInfo.getToken());
208+
}
209+
210+
long ts = (long)searchParams.get("guarantee_timestamp");
211+
if (ts <= 0) {
212+
if (response.getSessionTs() > 0) {
213+
searchParams.put("guarantee_timestamp", response.getSessionTs());
214+
} else {
215+
logger.warn("Failed to set up mvccTs from milvus server, use client-side ts instead");
216+
217+
long clientTs = System.currentTimeMillis() + 1000L;
218+
clientTs = clientTs << 18;
219+
searchParams.put("guarantee_timestamp", clientTs);
220+
}
221+
}
222+
223+
List<List<SearchResp.SearchResult>> res = new ConvertUtils().getEntities(response);
224+
return res.get(0);
225+
}
226+
227+
private List<SearchResp.SearchResult> wrapReturnRes(List<SearchResp.SearchResult> res) {
228+
if (leftResCnt == null) {
229+
return res;
230+
}
231+
232+
int currentLen = res.size();
233+
if (currentLen > leftResCnt) {
234+
res = res.subList(0, leftResCnt);
235+
}
236+
leftResCnt -= currentLen;
237+
return res;
238+
}
239+
}

sdk-core/src/main/java/io/milvus/v2/client/MilvusClientV2.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.milvus.grpc.*;
2424
import io.milvus.orm.iterator.QueryIterator;
2525
import io.milvus.orm.iterator.SearchIterator;
26+
import io.milvus.orm.iterator.SearchIteratorV2;
2627

2728
import io.milvus.v2.service.database.DatabaseService;
2829
import io.milvus.v2.service.database.request.*;
@@ -544,6 +545,16 @@ public SearchIterator searchIterator(SearchIteratorReq request) {
544545
return rpcUtils.retry(()->vectorService.searchIterator(this.getRpcStub(), request));
545546
}
546547

548+
/**
549+
* Get searchIteratorV2 based on a vector field. Use expression to do filtering before search.
550+
*
551+
* @param request {@link SearchIteratorReqV2}
552+
* @return {status:result code, data: SearchIteratorV2}
553+
*/
554+
public SearchIteratorV2 searchIteratorV2(SearchIteratorReqV2 request) {
555+
return rpcUtils.retry(()->vectorService.searchIteratorV2(this.getRpcStub(), request));
556+
}
557+
547558
/////////////////////////////////////////////////////////////////////////////////////////////
548559
// Partition Operations
549560
/////////////////////////////////////////////////////////////////////////////////////////////

sdk-core/src/main/java/io/milvus/v2/service/vector/VectorService.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,11 @@ public SearchIterator searchIterator(MilvusServiceGrpc.MilvusServiceBlockingStub
251251
return new SearchIterator(request, blockingStub, pkField);
252252
}
253253

254+
public SearchIteratorV2 searchIteratorV2(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
255+
SearchIteratorReqV2 request) {
256+
return new SearchIteratorV2(request, blockingStub);
257+
}
258+
254259
public DeleteResp delete(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, DeleteReq request) {
255260
String title = String.format("DeleteRequest collectionName:%s", request.getCollectionName());
256261

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package io.milvus.v2.service.vector.request;
2+
3+
import com.google.common.collect.Lists;
4+
import io.milvus.v2.common.ConsistencyLevel;
5+
import io.milvus.v2.common.IndexParam;
6+
import io.milvus.v2.service.vector.request.data.BaseVector;
7+
import io.milvus.v2.service.vector.response.SearchResp;
8+
import lombok.Builder;
9+
import lombok.Data;
10+
import lombok.experimental.SuperBuilder;
11+
12+
import java.util.HashMap;
13+
import java.util.List;
14+
import java.util.Map;
15+
import java.util.function.Function;
16+
17+
@Data
18+
@SuperBuilder
19+
public class SearchIteratorReqV2 {
20+
private String databaseName;
21+
private String collectionName;
22+
@Builder.Default
23+
private List<String> partitionNames = Lists.newArrayList();
24+
@Builder.Default
25+
private IndexParam.MetricType metricType = IndexParam.MetricType.INVALID;
26+
private String vectorFieldName;
27+
@Builder.Default
28+
private int topK = -1;
29+
@Builder.Default
30+
private String filter = "";
31+
@Builder.Default
32+
private List<String> outputFields = Lists.newArrayList();
33+
@Builder.Default
34+
private List<BaseVector> vectors = Lists.newArrayList();
35+
@Builder.Default
36+
private int roundDecimal = -1;
37+
@Builder.Default
38+
private Map<String, Object> searchParams = new HashMap<>();
39+
@Builder.Default
40+
private ConsistencyLevel consistencyLevel = null;
41+
@Builder.Default
42+
private boolean ignoreGrowing = false;
43+
@Builder.Default
44+
private String groupByFieldName = "";
45+
@Builder.Default
46+
private long batchSize = 1000L;
47+
@Builder.Default
48+
private Function<List<SearchResp.SearchResult>, List<SearchResp.SearchResult>> externalFilterFunc = null;
49+
}

0 commit comments

Comments
 (0)