Skip to content

Commit 3153b72

Browse files
committed
Align getLoadState/loadCollection/loadPartitions with PyMilvus
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent a666cd2 commit 3153b72

5 files changed

Lines changed: 77 additions & 61 deletions

File tree

sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import java.util.ArrayList;
4343
import java.util.Collections;
4444
import java.util.List;
45+
import java.util.concurrent.TimeUnit;
4546

4647
public class CollectionService extends BaseService {
4748
public IndexService indexService = new IndexService();
@@ -458,21 +459,28 @@ public Void renameCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockin
458459
public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, LoadCollectionReq request) {
459460
String dbName = request.getDatabaseName();
460461
String collectionName = request.getCollectionName();
462+
boolean sync = Boolean.TRUE.equals(request.getSync());
463+
boolean refresh = Boolean.TRUE.equals(request.getRefresh());
464+
boolean skipLoadDynamicField = Boolean.TRUE.equals(request.getSkipLoadDynamicField());
461465
String title = String.format("Load collection: '%s' in database: '%s'", collectionName, dbName);
462466
LoadCollectionRequest.Builder builder = LoadCollectionRequest.newBuilder()
463467
.setCollectionName(collectionName)
464468
.setReplicaNumber(request.getNumReplicas())
465-
.setRefresh(request.getRefresh())
469+
.setRefresh(refresh)
466470
.addAllLoadFields(request.getLoadFields())
467-
.setSkipLoadDynamicField(request.getSkipLoadDynamicField())
471+
.setSkipLoadDynamicField(skipLoadDynamicField)
468472
.addAllResourceGroups(request.getResourceGroups());
469473
if (StringUtils.isNotEmpty(dbName)) {
470474
builder.setDbName(dbName);
471475
}
472-
Status status = blockingStub.loadCollection(builder.build());
476+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
477+
if (request.getTimeout() != null && request.getTimeout() > 0) {
478+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
479+
}
480+
Status status = tempBlockingStub.loadCollection(builder.build());
473481
rpcUtils.handleResponse(title, status);
474-
if (request.getSync()) {
475-
WaitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout());
482+
if (sync) {
483+
waitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout(), refresh);
476484
}
477485

478486
return null;
@@ -481,17 +489,22 @@ public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
481489
public Void refreshLoad(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, RefreshLoadReq request) {
482490
String dbName = request.getDatabaseName();
483491
String collectionName = request.getCollectionName();
492+
boolean sync = Boolean.TRUE.equals(request.getSync());
484493
String title = String.format("Refresh load collection: '%s' in database: '%s'", collectionName, dbName);
485494
LoadCollectionRequest.Builder builder = LoadCollectionRequest.newBuilder()
486495
.setCollectionName(collectionName)
487496
.setRefresh(true);
488497
if (StringUtils.isNotEmpty(dbName)) {
489498
builder.setDbName(dbName);
490499
}
491-
Status status = blockingStub.loadCollection(builder.build());
500+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
501+
if (request.getTimeout() != null && request.getTimeout() > 0) {
502+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
503+
}
504+
Status status = tempBlockingStub.loadCollection(builder.build());
492505
rpcUtils.handleResponse(title, status);
493-
if (request.getSync()) {
494-
WaitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout());
506+
if (sync) {
507+
waitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout(), true);
495508
}
496509

497510
return null;
@@ -521,9 +534,7 @@ public GetLoadStateResp getLoadStateV2(MilvusServiceGrpc.MilvusServiceBlockingSt
521534
GetLoadStateResp.GetLoadStateRespBuilder respBuilder = GetLoadStateResp.builder()
522535
.state(response.getState());
523536
if (response.getState() == LoadState.LoadStateLoading) {
524-
GetLoadingProgressResponse progressResponse = getLoadingProgressResponse(blockingStub, request);
525-
respBuilder.progress(progressResponse.getProgress())
526-
.refreshProgress(progressResponse.getRefreshProgress());
537+
respBuilder.progress(getLoadingProgress(blockingStub, request, false, null));
527538
}
528539

529540
return respBuilder.build();
@@ -556,8 +567,17 @@ private GetLoadStateResponse getLoadStateResponse(MilvusServiceGrpc.MilvusServic
556567
return response;
557568
}
558569

559-
private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
560-
GetLoadStateReq request) {
570+
private Long getLoadingProgress(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
571+
GetLoadStateReq request,
572+
boolean refreshLoad,
573+
Long timeoutMs) {
574+
GetLoadingProgressResponse response = getLoadingProgressInternal(blockingStub, request, timeoutMs);
575+
return refreshLoad ? response.getRefreshProgress() : response.getProgress();
576+
}
577+
578+
private GetLoadingProgressResponse getLoadingProgressInternal(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
579+
GetLoadStateReq request,
580+
Long timeoutMs) {
561581
String dbName = request.getDatabaseName();
562582
String collectionName = request.getCollectionName();
563583
String partitionName = request.getPartitionName();
@@ -569,7 +589,11 @@ private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.
569589
if (StringUtils.isNotEmpty(partitionName)) {
570590
builder.addPartitionNames(partitionName);
571591
}
572-
GetLoadingProgressResponse response = blockingStub.getLoadingProgress(builder.build());
592+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
593+
if (timeoutMs != null && timeoutMs > 0) {
594+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
595+
}
596+
GetLoadingProgressResponse response = tempBlockingStub.getLoadingProgress(builder.build());
573597
String title = String.format("Get loading progress of collection: '%s' in database: '%s'", collectionName, dbName);
574598
rpcUtils.handleResponse(title, response.getStatus());
575599
return response;
@@ -711,31 +735,28 @@ public Void dropCollectionFunction(MilvusServiceGrpc.MilvusServiceBlockingStub b
711735
return null;
712736
}
713737

714-
private void WaitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String databaseName,
715-
String collectionName, long timeoutMs) {
716-
long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
738+
private void waitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String databaseName,
739+
String collectionName, Long timeoutMs, boolean refreshLoad) {
740+
long startTime = System.currentTimeMillis();
741+
GetLoadStateReq request = GetLoadStateReq.builder()
742+
.databaseName(databaseName)
743+
.collectionName(collectionName)
744+
.build();
717745

718746
while (true) {
719-
// Call the getLoadState method
720-
boolean isLoaded = getLoadState(blockingStub, GetLoadStateReq.builder()
721-
.databaseName(databaseName)
722-
.collectionName(collectionName)
723-
.build());
724-
if (isLoaded) {
747+
if (getLoadingProgress(blockingStub, request, refreshLoad, timeoutMs) >= 100L) {
725748
return;
726749
}
727750

728-
// Check if timeout is exceeded
729-
if (System.currentTimeMillis() - startTime > timeoutMs) {
751+
if (timeoutMs != null && timeoutMs > 0 && System.currentTimeMillis() - startTime > timeoutMs) {
730752
throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load collection timeout");
731753
}
732-
// Wait for a certain period before checking again
733754
try {
734-
Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
755+
Thread.sleep(500);
735756
} catch (InterruptedException e) {
736757
Thread.currentThread().interrupt();
737758
logger.error("Thread was interrupted, Failed to complete operation");
738-
return; // or handle interruption appropriately
759+
return;
739760
}
740761
}
741762
}

sdk-core/src/main/java/io/milvus/v2/service/collection/response/GetLoadStateResp.java

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,10 @@
2424
public class GetLoadStateResp {
2525
private LoadState state;
2626
private Long progress;
27-
private Long refreshProgress;
2827

2928
private GetLoadStateResp(GetLoadStateRespBuilder builder) {
3029
this.state = builder.state;
3130
this.progress = builder.progress;
32-
this.refreshProgress = builder.refreshProgress;
3331
}
3432

3533
public LoadState getState() {
@@ -52,21 +50,12 @@ public void setProgress(Long progress) {
5250
this.progress = progress;
5351
}
5452

55-
public Long getRefreshProgress() {
56-
return refreshProgress;
57-
}
58-
59-
public void setRefreshProgress(Long refreshProgress) {
60-
this.refreshProgress = refreshProgress;
61-
}
62-
6353
@Override
6454
public String toString() {
6555
return "GetLoadStateResp{" +
6656
"state=" + state +
6757
", stateName='" + getStateName() + '\'' +
6858
", progress=" + progress +
69-
", refreshProgress=" + refreshProgress +
7059
'}';
7160
}
7261

@@ -77,7 +66,6 @@ public static GetLoadStateRespBuilder builder() {
7766
public static class GetLoadStateRespBuilder {
7867
private LoadState state;
7968
private Long progress;
80-
private Long refreshProgress;
8169

8270
private GetLoadStateRespBuilder() {
8371
}
@@ -92,11 +80,6 @@ public GetLoadStateRespBuilder progress(Long progress) {
9280
return this;
9381
}
9482

95-
public GetLoadStateRespBuilder refreshProgress(Long refreshProgress) {
96-
this.refreshProgress = refreshProgress;
97-
return this;
98-
}
99-
10083
public GetLoadStateResp build() {
10184
return new GetLoadStateResp(this);
10285
}

sdk-core/src/main/java/io/milvus/v2/service/partition/PartitionService.java

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.commons.lang3.StringUtils;
2929

3030
import java.util.List;
31+
import java.util.concurrent.TimeUnit;
3132

3233
public class PartitionService extends BaseService {
3334
public Void createPartition(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreatePartitionReq request) {
@@ -130,24 +131,32 @@ public Void loadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
130131
String dbName = request.getDatabaseName();
131132
String collectionName = request.getCollectionName();
132133
List<String> partitionNames = request.getPartitionNames();
134+
boolean sync = Boolean.TRUE.equals(request.getSync());
135+
boolean refresh = Boolean.TRUE.equals(request.getRefresh());
136+
boolean skipLoadDynamicField = Boolean.TRUE.equals(request.getSkipLoadDynamicField());
133137
String title = String.format("Load partitions: %s in collection: '%s' in database: '%s'",
134138
partitionNames, collectionName, dbName);
135139

136140
LoadPartitionsRequest.Builder builder = LoadPartitionsRequest.newBuilder()
137141
.setCollectionName(collectionName)
138142
.addAllPartitionNames(partitionNames)
139143
.setReplicaNumber(request.getNumReplicas())
140-
.setRefresh(request.getRefresh())
144+
.setRefresh(refresh)
141145
.addAllLoadFields(request.getLoadFields())
142-
.setSkipLoadDynamicField(request.getSkipLoadDynamicField())
146+
.setSkipLoadDynamicField(skipLoadDynamicField)
143147
.addAllResourceGroups(request.getResourceGroups());
144148
if (StringUtils.isNotEmpty(dbName)) {
145149
builder.setDbName(dbName);
146150
}
147-
Status status = blockingStub.loadPartitions(builder.build());
151+
152+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
153+
if (request.getTimeout() != null && request.getTimeout() > 0) {
154+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
155+
}
156+
Status status = tempBlockingStub.loadPartitions(builder.build());
148157
rpcUtils.handleResponse(title, status);
149-
if (request.getSync()) {
150-
WaitForLoadPartitions(blockingStub, dbName, collectionName, partitionNames, request.getTimeout());
158+
if (sync) {
159+
waitForLoadPartitions(blockingStub, dbName, collectionName, partitionNames, request.getTimeout(), refresh);
151160
}
152161

153162
return null;
@@ -172,9 +181,9 @@ public Void releasePartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blocki
172181
return null;
173182
}
174183

175-
private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName,
176-
String collectionName, List<String> partitions, long timeoutMs) {
177-
long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
184+
private void waitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName,
185+
String collectionName, List<String> partitions, Long timeoutMs, boolean refreshLoad) {
186+
long startTime = System.currentTimeMillis();
178187

179188
while (true) {
180189
GetLoadingProgressRequest.Builder builder = GetLoadingProgressRequest.newBuilder()
@@ -183,24 +192,27 @@ private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub b
183192
if (StringUtils.isNotEmpty(dbName)) {
184193
builder.setDbName(dbName);
185194
}
186-
GetLoadingProgressResponse response = blockingStub.getLoadingProgress(builder.build());
195+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
196+
if (timeoutMs != null && timeoutMs > 0) {
197+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
198+
}
199+
GetLoadingProgressResponse response = tempBlockingStub.getLoadingProgress(builder.build());
187200
String title = String.format("Get loading progress of collection: '%s' in database: '%s'", collectionName, dbName);
188201
rpcUtils.handleResponse(title, response.getStatus());
189-
if (response.getProgress() >= 100) {
202+
long progress = refreshLoad ? response.getRefreshProgress() : response.getProgress();
203+
if (progress >= 100L) {
190204
return;
191205
}
192206

193-
// Check if timeout is exceeded
194-
if (System.currentTimeMillis() - startTime > timeoutMs) {
207+
if (timeoutMs != null && timeoutMs > 0 && System.currentTimeMillis() - startTime > timeoutMs) {
195208
throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load partitions timeout");
196209
}
197-
// Wait for a certain period before checking again
198210
try {
199-
Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
211+
Thread.sleep(500);
200212
} catch (InterruptedException e) {
201213
Thread.currentThread().interrupt();
202214
logger.error("Thread was interrupted, failed to complete operation");
203-
return; // or handle interruption appropriately
215+
return;
204216
}
205217
}
206218
}

sdk-core/src/test/java/io/milvus/v2/BaseTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public class BaseTest {
4646
@BeforeEach
4747
public void setUp() {
4848
client_v2.setBlockingStub(blockingStub);
49+
when(blockingStub.withDeadlineAfter(any(Long.class), any())).thenReturn(blockingStub);
4950

5051
Status successStatus = Status.newBuilder().setCode(0).build();
5152
BoolResponse trueResponse = BoolResponse.newBuilder().setStatus(successStatus).setValue(Boolean.TRUE).build();

sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2977,7 +2977,6 @@ void testOperationsAcrossDB() {
29772977
Assertions.assertEquals(LoadState.LoadStateLoaded, loadStateResp.getState());
29782978
Assertions.assertEquals(LoadState.LoadStateLoaded.name(), loadStateResp.getStateName());
29792979
Assertions.assertNull(loadStateResp.getProgress());
2980-
Assertions.assertNull(loadStateResp.getRefreshProgress());
29812980

29822981
// specify the temp database name to release partition
29832982
client.releasePartitions(ReleasePartitionsReq.builder()

0 commit comments

Comments
 (0)