Skip to content

Commit a873189

Browse files
authored
Add session() interface (#1879)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 216c84e commit a873189

17 files changed

Lines changed: 516 additions & 5 deletions

sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ private QueryResults executeQuery(String expr, long offset, long limit, long ts,
217217
QueryReq queryReq = QueryReq.builder()
218218
.databaseName(queryIteratorReq.getDatabaseName())
219219
.collectionName(queryIteratorReq.getCollectionName())
220+
.clusterId(queryIteratorReq.getClusterId())
220221
.partitionNames(queryIteratorReq.getPartitionNames())
221222
.consistencyLevel(queryIteratorReq.getConsistencyLevel())
222223
.outputFields(outputFields)

sdk-core/src/main/java/io/milvus/orm/iterator/SearchIterator.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ public class SearchIterator {
7676
private Float filteredDistance = null;
7777
private Map<String, Object> params;
7878
private final RpcUtils rpcUtils;
79+
private String clusterId = "";
7980
private long sessionTs = 0;
8081

8182
public SearchIterator(SearchIteratorParam searchIteratorParam,
@@ -113,6 +114,7 @@ public SearchIterator(SearchIteratorReq searchIteratorReq,
113114
this.expr = this.searchIteratorParam.getExpr();
114115
this.topK = this.searchIteratorParam.getTopK();
115116
this.rpcUtils = new RpcUtils();
117+
this.clusterId = searchIteratorReq.getClusterId();
116118

117119
initParams();
118120
checkForSpecialIndexParam();
@@ -292,6 +294,13 @@ private SearchResults executeSearch(Map<String, Object> params, String nextExpr,
292294
.setKey(Constant.ITERATOR_FIELD)
293295
.setValue(String.valueOf(Boolean.TRUE))
294296
.build());
297+
if (StringUtils.isNotEmpty(clusterId)) {
298+
builder.addSearchParams(
299+
KeyValuePair.newBuilder()
300+
.setKey(Constant.CLUSTER_ID)
301+
.setValue(clusterId)
302+
.build());
303+
}
295304

296305
// pass the session ts to search interface
297306
builder.setGuaranteeTimestamp(ts).build();

sdk-core/src/main/java/io/milvus/orm/iterator/SearchIteratorV2.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ private SearchResults executeSearch(int limit) {
115115
.collectionName(searchIteratorReq.getCollectionName())
116116
.partitionNames(searchIteratorReq.getPartitionNames())
117117
.databaseName(searchIteratorReq.getDatabaseName())
118+
.clusterId(searchIteratorReq.getClusterId())
118119
.annsField(searchIteratorReq.getVectorFieldName())
119120
.data(searchIteratorReq.getVectors())
120121
.limit(limit)

sdk-core/src/main/java/io/milvus/param/Constant.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public class Constant {
3434
public static final String TIMEZONE = "timezone";
3535
public static final String REDUCE_STOP_FOR_BEST = "reduce_stop_for_best";
3636
public static final String ITERATOR_FIELD = "iterator";
37+
public static final String CLUSTER_ID = "cluster_id";
3738
public static final String GROUP_BY_FIELD = "group_by_field";
3839
public static final String GROUP_SIZE = "group_size";
3940
public static final String STRICT_GROUP_SIZE = "strict_group_size";

sdk-core/src/main/java/io/milvus/v2/client/MilvusClientV2.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,13 @@ public String currentUsedDatabase() {
284284
return dbName;
285285
}
286286

287+
public MilvusClientV2Session session(String clusterId) {
288+
if (StringUtils.isEmpty(clusterId)) {
289+
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "clusterId cannot be null or empty");
290+
}
291+
return new MilvusClientV2Session(this, clusterId);
292+
}
293+
287294

288295
/////////////////////////////////////////////////////////////////////////////////////////////
289296
// Database Operations
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package io.milvus.v2.client;
21+
22+
import io.milvus.orm.iterator.QueryIterator;
23+
import io.milvus.orm.iterator.SearchIterator;
24+
import io.milvus.orm.iterator.SearchIteratorV2;
25+
import io.milvus.v2.exception.ErrorCode;
26+
import io.milvus.v2.exception.MilvusClientException;
27+
import io.milvus.v2.service.vector.request.*;
28+
import io.milvus.v2.service.vector.response.GetResp;
29+
import io.milvus.v2.service.vector.response.QueryResp;
30+
import io.milvus.v2.service.vector.response.SearchResp;
31+
import org.apache.commons.lang3.StringUtils;
32+
33+
public class MilvusClientV2Session {
34+
private final MilvusClientV2 parent;
35+
private final String clusterId;
36+
private boolean closed = false;
37+
38+
MilvusClientV2Session(MilvusClientV2 parent, String clusterId) {
39+
this.parent = parent;
40+
this.clusterId = clusterId;
41+
}
42+
43+
public SearchResp search(SearchReq request) {
44+
ensureOpen();
45+
return parent.search(copy(request));
46+
}
47+
48+
public SearchResp hybridSearch(HybridSearchReq request) {
49+
ensureOpen();
50+
return parent.hybridSearch(copy(request));
51+
}
52+
53+
public QueryResp query(QueryReq request) {
54+
ensureOpen();
55+
return parent.query(copy(request));
56+
}
57+
58+
public QueryIterator queryIterator(QueryIteratorReq request) {
59+
ensureOpen();
60+
return parent.queryIterator(copy(request));
61+
}
62+
63+
public SearchIterator searchIterator(SearchIteratorReq request) {
64+
ensureOpen();
65+
return parent.searchIterator(copy(request));
66+
}
67+
68+
public SearchIteratorV2 searchIteratorV2(SearchIteratorReqV2 request) {
69+
ensureOpen();
70+
return parent.searchIteratorV2(copy(request));
71+
}
72+
73+
public GetResp get(GetReq request) {
74+
ensureOpen();
75+
return parent.get(copy(request));
76+
}
77+
78+
public void close() {
79+
closed = true;
80+
}
81+
82+
private void ensureOpen() {
83+
if (closed) {
84+
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "MilvusClient session is closed");
85+
}
86+
}
87+
88+
private void checkClusterId(String requestClusterId) {
89+
if (StringUtils.isNotEmpty(requestClusterId) && !clusterId.equals(requestClusterId)) {
90+
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "clusterId conflicts with session clusterId");
91+
}
92+
}
93+
94+
private SearchReq copy(SearchReq request) {
95+
checkClusterId(request.getClusterId());
96+
return SearchReq.builder()
97+
.databaseName(request.getDatabaseName())
98+
.collectionName(request.getCollectionName())
99+
.clusterId(clusterId)
100+
.partitionNames(request.getPartitionNames())
101+
.annsField(request.getAnnsField())
102+
.metricType(request.getMetricType())
103+
.filter(request.getFilter())
104+
.outputFields(request.getOutputFields())
105+
.data(request.getData())
106+
.ids(request.getIds())
107+
.offset(request.getOffset())
108+
.limit(request.getLimit())
109+
.roundDecimal(request.getRoundDecimal())
110+
.searchParams(request.getSearchParams())
111+
.guaranteeTimestamp(request.getGuaranteeTimestamp())
112+
.gracefulTime(request.getGracefulTime())
113+
.consistencyLevel(request.getConsistencyLevel())
114+
.ignoreGrowing(request.isIgnoreGrowing())
115+
.timezone(request.getTimezone())
116+
.groupByFieldName(request.getGroupByFieldName())
117+
.groupSize(request.getGroupSize())
118+
.strictGroupSize(request.getStrictGroupSize())
119+
.ranker(request.getRanker())
120+
.functionScore(request.getFunctionScore())
121+
.filterTemplateValues(request.getFilterTemplateValues())
122+
.highlighter(request.getHighlighter())
123+
.build();
124+
}
125+
126+
private HybridSearchReq copy(HybridSearchReq request) {
127+
checkClusterId(request.getClusterId());
128+
return HybridSearchReq.builder()
129+
.databaseName(request.getDatabaseName())
130+
.collectionName(request.getCollectionName())
131+
.clusterId(clusterId)
132+
.partitionNames(request.getPartitionNames())
133+
.searchRequests(request.getSearchRequests())
134+
.ranker(request.getRanker())
135+
.functionScore(request.getFunctionScore())
136+
.limit(request.getLimit())
137+
.outFields(request.getOutFields())
138+
.offset(request.getOffset())
139+
.roundDecimal(request.getRoundDecimal())
140+
.consistencyLevel(request.getConsistencyLevel())
141+
.groupByFieldName(request.getGroupByFieldName())
142+
.groupSize(request.getGroupSize())
143+
.strictGroupSize(request.getStrictGroupSize())
144+
.build();
145+
}
146+
147+
private QueryReq copy(QueryReq request) {
148+
checkClusterId(request.getClusterId());
149+
return QueryReq.builder()
150+
.databaseName(request.getDatabaseName())
151+
.collectionName(request.getCollectionName())
152+
.clusterId(clusterId)
153+
.partitionNames(request.getPartitionNames())
154+
.outputFields(request.getOutputFields())
155+
.ids(request.getIds())
156+
.filter(request.getFilter())
157+
.consistencyLevel(request.getConsistencyLevel())
158+
.offset(request.getOffset())
159+
.limit(request.getLimit())
160+
.ignoreGrowing(request.isIgnoreGrowing())
161+
.timezone(request.getTimezone())
162+
.queryParams(request.getQueryParams())
163+
.filterTemplateValues(request.getFilterTemplateValues())
164+
.build();
165+
}
166+
167+
private QueryIteratorReq copy(QueryIteratorReq request) {
168+
checkClusterId(request.getClusterId());
169+
return QueryIteratorReq.builder()
170+
.databaseName(request.getDatabaseName())
171+
.collectionName(request.getCollectionName())
172+
.clusterId(clusterId)
173+
.partitionNames(request.getPartitionNames())
174+
.outputFields(request.getOutputFields())
175+
.expr(request.getExpr())
176+
.consistencyLevel(request.getConsistencyLevel())
177+
.offset(request.getOffset())
178+
.limit(request.getLimit())
179+
.ignoreGrowing(request.isIgnoreGrowing())
180+
.timezone(request.getTimezone())
181+
.batchSize(request.getBatchSize())
182+
.reduceStopForBest(request.isReduceStopForBest())
183+
.filterTemplateValues(request.getFilterTemplateValues())
184+
.build();
185+
}
186+
187+
private SearchIteratorReq copy(SearchIteratorReq request) {
188+
checkClusterId(request.getClusterId());
189+
return SearchIteratorReq.builder()
190+
.databaseName(request.getDatabaseName())
191+
.collectionName(request.getCollectionName())
192+
.clusterId(clusterId)
193+
.partitionNames(request.getPartitionNames())
194+
.metricType(request.getMetricType())
195+
.vectorFieldName(request.getVectorFieldName())
196+
.limit(request.getLimit())
197+
.expr(request.getExpr())
198+
.outputFields(request.getOutputFields())
199+
.vectors(request.getVectors())
200+
.roundDecimal(request.getRoundDecimal())
201+
.params(request.getParams())
202+
.consistencyLevel(request.getConsistencyLevel())
203+
.ignoreGrowing(request.isIgnoreGrowing())
204+
.groupByFieldName(request.getGroupByFieldName())
205+
.batchSize(request.getBatchSize())
206+
.build();
207+
}
208+
209+
private SearchIteratorReqV2 copy(SearchIteratorReqV2 request) {
210+
checkClusterId(request.getClusterId());
211+
return SearchIteratorReqV2.builder()
212+
.databaseName(request.getDatabaseName())
213+
.collectionName(request.getCollectionName())
214+
.clusterId(clusterId)
215+
.partitionNames(request.getPartitionNames())
216+
.metricType(request.getMetricType())
217+
.vectorFieldName(request.getVectorFieldName())
218+
.limit(request.getLimit())
219+
.filter(request.getFilter())
220+
.outputFields(request.getOutputFields())
221+
.vectors(request.getVectors())
222+
.roundDecimal(request.getRoundDecimal())
223+
.searchParams(request.getSearchParams())
224+
.consistencyLevel(request.getConsistencyLevel())
225+
.ignoreGrowing(request.isIgnoreGrowing())
226+
.timezone(request.getTimezone())
227+
.groupByFieldName(request.getGroupByFieldName())
228+
.batchSize(request.getBatchSize())
229+
.externalFilterFunc(request.getExternalFilterFunc())
230+
.filterTemplateValues(request.getFilterTemplateValues())
231+
.build();
232+
}
233+
234+
private GetReq copy(GetReq request) {
235+
checkClusterId(request.getClusterId());
236+
return GetReq.builder()
237+
.databaseName(request.getDatabaseName())
238+
.collectionName(request.getCollectionName())
239+
.clusterId(clusterId)
240+
.partitionName(request.getPartitionName())
241+
.ids(request.getIds())
242+
.outputFields(request.getOutputFields())
243+
.build();
244+
}
245+
}

sdk-core/src/main/java/io/milvus/v2/service/vector/VectorService.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.slf4j.LoggerFactory;
4242

4343
import java.util.ArrayList;
44+
import java.util.Collections;
4445
import java.util.List;
4546
import java.util.concurrent.ConcurrentHashMap;
4647

@@ -350,11 +351,15 @@ public GetResp get(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, Get
350351
String collectionName = request.getCollectionName();
351352
String title = String.format("Get entities of collection: '%s' in database: '%s'", collectionName, dbName);
352353
logger.debug(title);
353-
QueryReq queryReq = QueryReq.builder()
354+
QueryReq.QueryReqBuilder queryReqBuilder = QueryReq.builder()
354355
.databaseName(dbName)
355356
.collectionName(collectionName)
356-
.ids(request.getIds())
357-
.build();
357+
.clusterId(request.getClusterId())
358+
.ids(request.getIds());
359+
if (StringUtils.isNotEmpty(request.getPartitionName())) {
360+
queryReqBuilder.partitionNames(Collections.singletonList(request.getPartitionName()));
361+
}
362+
QueryReq queryReq = queryReqBuilder.build();
358363
if (request.getOutputFields() != null) {
359364
queryReq.setOutputFields(request.getOutputFields());
360365
}

sdk-core/src/main/java/io/milvus/v2/service/vector/request/GetReq.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@
2424
public class GetReq {
2525
private String databaseName;
2626
private String collectionName;
27+
private String clusterId;
2728
private String partitionName = "";
2829
private List<Object> ids;
2930
private List<String> outputFields;
3031

3132
private GetReq(GetReqBuilder builder) {
3233
this.databaseName = builder.databaseName;
3334
this.collectionName = builder.collectionName;
35+
this.clusterId = builder.clusterId;
3436
this.partitionName = builder.partitionName;
3537
this.ids = builder.ids;
3638
this.outputFields = builder.outputFields;
@@ -56,6 +58,14 @@ public void setCollectionName(String collectionName) {
5658
this.collectionName = collectionName;
5759
}
5860

61+
public String getClusterId() {
62+
return clusterId;
63+
}
64+
65+
public void setClusterId(String clusterId) {
66+
this.clusterId = clusterId;
67+
}
68+
5969
public String getPartitionName() {
6070
return partitionName;
6171
}
@@ -85,6 +95,7 @@ public String toString() {
8595
return "GetReq{" +
8696
"databaseName='" + databaseName + '\'' +
8797
", collectionName='" + collectionName + '\'' +
98+
", clusterId='" + clusterId + '\'' +
8899
", partitionName='" + partitionName + '\'' +
89100
", ids=" + ids +
90101
", outputFields=" + outputFields +
@@ -94,6 +105,7 @@ public String toString() {
94105
public static class GetReqBuilder {
95106
private String databaseName;
96107
private String collectionName;
108+
private String clusterId;
97109
private String partitionName = "";
98110
private List<Object> ids;
99111
private List<String> outputFields;
@@ -108,6 +120,11 @@ public GetReqBuilder collectionName(String collectionName) {
108120
return this;
109121
}
110122

123+
public GetReqBuilder clusterId(String clusterId) {
124+
this.clusterId = clusterId;
125+
return this;
126+
}
127+
111128
public GetReqBuilder partitionName(String partitionName) {
112129
this.partitionName = partitionName;
113130
return this;

0 commit comments

Comments
 (0)