Skip to content

Commit af26233

Browse files
committed
Align getLoadState/loadCollection/loadPartitions with PyMilvus
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent a666cd2 commit af26233

4 files changed

Lines changed: 61 additions & 54 deletions

File tree

sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import java.util.ArrayList;
4343
import java.util.Collections;
4444
import java.util.List;
45+
import java.util.concurrent.TimeUnit;
4546

4647
public class CollectionService extends BaseService {
4748
public IndexService indexService = new IndexService();
@@ -469,10 +470,14 @@ public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
469470
if (StringUtils.isNotEmpty(dbName)) {
470471
builder.setDbName(dbName);
471472
}
472-
Status status = blockingStub.loadCollection(builder.build());
473+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
474+
if (request.getTimeout() != null && request.getTimeout() > 0) {
475+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
476+
}
477+
Status status = tempBlockingStub.loadCollection(builder.build());
473478
rpcUtils.handleResponse(title, status);
474479
if (request.getSync()) {
475-
WaitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout());
480+
waitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout(), request.getRefresh());
476481
}
477482

478483
return null;
@@ -488,10 +493,14 @@ public Void refreshLoad(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub
488493
if (StringUtils.isNotEmpty(dbName)) {
489494
builder.setDbName(dbName);
490495
}
491-
Status status = blockingStub.loadCollection(builder.build());
496+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
497+
if (request.getTimeout() != null && request.getTimeout() > 0) {
498+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
499+
}
500+
Status status = tempBlockingStub.loadCollection(builder.build());
492501
rpcUtils.handleResponse(title, status);
493502
if (request.getSync()) {
494-
WaitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout());
503+
waitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout(), true);
495504
}
496505

497506
return null;
@@ -521,9 +530,7 @@ public GetLoadStateResp getLoadStateV2(MilvusServiceGrpc.MilvusServiceBlockingSt
521530
GetLoadStateResp.GetLoadStateRespBuilder respBuilder = GetLoadStateResp.builder()
522531
.state(response.getState());
523532
if (response.getState() == LoadState.LoadStateLoading) {
524-
GetLoadingProgressResponse progressResponse = getLoadingProgressResponse(blockingStub, request);
525-
respBuilder.progress(progressResponse.getProgress())
526-
.refreshProgress(progressResponse.getRefreshProgress());
533+
respBuilder.progress(getLoadingProgress(blockingStub, request, false, null));
527534
}
528535

529536
return respBuilder.build();
@@ -556,8 +563,17 @@ private GetLoadStateResponse getLoadStateResponse(MilvusServiceGrpc.MilvusServic
556563
return response;
557564
}
558565

559-
private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
560-
GetLoadStateReq request) {
566+
private Long getLoadingProgress(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
567+
GetLoadStateReq request,
568+
boolean refreshLoad,
569+
Long timeoutMs) {
570+
GetLoadingProgressResponse response = getLoadingProgressInternal(blockingStub, request, timeoutMs);
571+
return refreshLoad ? response.getRefreshProgress() : response.getProgress();
572+
}
573+
574+
private GetLoadingProgressResponse getLoadingProgressInternal(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
575+
GetLoadStateReq request,
576+
Long timeoutMs) {
561577
String dbName = request.getDatabaseName();
562578
String collectionName = request.getCollectionName();
563579
String partitionName = request.getPartitionName();
@@ -569,7 +585,11 @@ private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.
569585
if (StringUtils.isNotEmpty(partitionName)) {
570586
builder.addPartitionNames(partitionName);
571587
}
572-
GetLoadingProgressResponse response = blockingStub.getLoadingProgress(builder.build());
588+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
589+
if (timeoutMs != null && timeoutMs > 0) {
590+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
591+
}
592+
GetLoadingProgressResponse response = tempBlockingStub.getLoadingProgress(builder.build());
573593
String title = String.format("Get loading progress of collection: '%s' in database: '%s'", collectionName, dbName);
574594
rpcUtils.handleResponse(title, response.getStatus());
575595
return response;
@@ -711,31 +731,28 @@ public Void dropCollectionFunction(MilvusServiceGrpc.MilvusServiceBlockingStub b
711731
return null;
712732
}
713733

714-
private void WaitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String databaseName,
715-
String collectionName, long timeoutMs) {
716-
long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
734+
private void waitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String databaseName,
735+
String collectionName, Long timeoutMs, boolean refreshLoad) {
736+
long startTime = System.currentTimeMillis();
737+
GetLoadStateReq request = GetLoadStateReq.builder()
738+
.databaseName(databaseName)
739+
.collectionName(collectionName)
740+
.build();
717741

718742
while (true) {
719-
// Call the getLoadState method
720-
boolean isLoaded = getLoadState(blockingStub, GetLoadStateReq.builder()
721-
.databaseName(databaseName)
722-
.collectionName(collectionName)
723-
.build());
724-
if (isLoaded) {
743+
if (getLoadingProgress(blockingStub, request, refreshLoad, timeoutMs) >= 100L) {
725744
return;
726745
}
727746

728-
// Check if timeout is exceeded
729-
if (System.currentTimeMillis() - startTime > timeoutMs) {
747+
if (timeoutMs != null && timeoutMs > 0 && System.currentTimeMillis() - startTime > timeoutMs) {
730748
throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load collection timeout");
731749
}
732-
// Wait for a certain period before checking again
733750
try {
734-
Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
751+
Thread.sleep(500);
735752
} catch (InterruptedException e) {
736753
Thread.currentThread().interrupt();
737754
logger.error("Thread was interrupted, Failed to complete operation");
738-
return; // or handle interruption appropriately
755+
return;
739756
}
740757
}
741758
}

sdk-core/src/main/java/io/milvus/v2/service/collection/response/GetLoadStateResp.java

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,10 @@
2424
public class GetLoadStateResp {
2525
private LoadState state;
2626
private Long progress;
27-
private Long refreshProgress;
2827

2928
private GetLoadStateResp(GetLoadStateRespBuilder builder) {
3029
this.state = builder.state;
3130
this.progress = builder.progress;
32-
this.refreshProgress = builder.refreshProgress;
3331
}
3432

3533
public LoadState getState() {
@@ -52,21 +50,12 @@ public void setProgress(Long progress) {
5250
this.progress = progress;
5351
}
5452

55-
public Long getRefreshProgress() {
56-
return refreshProgress;
57-
}
58-
59-
public void setRefreshProgress(Long refreshProgress) {
60-
this.refreshProgress = refreshProgress;
61-
}
62-
6353
@Override
6454
public String toString() {
6555
return "GetLoadStateResp{" +
6656
"state=" + state +
6757
", stateName='" + getStateName() + '\'' +
6858
", progress=" + progress +
69-
", refreshProgress=" + refreshProgress +
7059
'}';
7160
}
7261

@@ -77,7 +66,6 @@ public static GetLoadStateRespBuilder builder() {
7766
public static class GetLoadStateRespBuilder {
7867
private LoadState state;
7968
private Long progress;
80-
private Long refreshProgress;
8169

8270
private GetLoadStateRespBuilder() {
8371
}
@@ -92,11 +80,6 @@ public GetLoadStateRespBuilder progress(Long progress) {
9280
return this;
9381
}
9482

95-
public GetLoadStateRespBuilder refreshProgress(Long refreshProgress) {
96-
this.refreshProgress = refreshProgress;
97-
return this;
98-
}
99-
10083
public GetLoadStateResp build() {
10184
return new GetLoadStateResp(this);
10285
}

sdk-core/src/main/java/io/milvus/v2/service/partition/PartitionService.java

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.commons.lang3.StringUtils;
2929

3030
import java.util.List;
31+
import java.util.concurrent.TimeUnit;
3132

3233
public class PartitionService extends BaseService {
3334
public Void createPartition(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreatePartitionReq request) {
@@ -144,10 +145,14 @@ public Void loadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
144145
if (StringUtils.isNotEmpty(dbName)) {
145146
builder.setDbName(dbName);
146147
}
147-
Status status = blockingStub.loadPartitions(builder.build());
148+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
149+
if (request.getTimeout() != null && request.getTimeout() > 0) {
150+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
151+
}
152+
Status status = tempBlockingStub.loadPartitions(builder.build());
148153
rpcUtils.handleResponse(title, status);
149154
if (request.getSync()) {
150-
WaitForLoadPartitions(blockingStub, dbName, collectionName, partitionNames, request.getTimeout());
155+
waitForLoadPartitions(blockingStub, dbName, collectionName, partitionNames, request.getTimeout(), request.getRefresh());
151156
}
152157

153158
return null;
@@ -172,9 +177,9 @@ public Void releasePartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blocki
172177
return null;
173178
}
174179

175-
private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName,
176-
String collectionName, List<String> partitions, long timeoutMs) {
177-
long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
180+
private void waitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName,
181+
String collectionName, List<String> partitions, Long timeoutMs, boolean refreshLoad) {
182+
long startTime = System.currentTimeMillis();
178183

179184
while (true) {
180185
GetLoadingProgressRequest.Builder builder = GetLoadingProgressRequest.newBuilder()
@@ -183,24 +188,27 @@ private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub b
183188
if (StringUtils.isNotEmpty(dbName)) {
184189
builder.setDbName(dbName);
185190
}
186-
GetLoadingProgressResponse response = blockingStub.getLoadingProgress(builder.build());
191+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
192+
if (timeoutMs != null && timeoutMs > 0) {
193+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
194+
}
195+
GetLoadingProgressResponse response = tempBlockingStub.getLoadingProgress(builder.build());
187196
String title = String.format("Get loading progress of collection: '%s' in database: '%s'", collectionName, dbName);
188197
rpcUtils.handleResponse(title, response.getStatus());
189-
if (response.getProgress() >= 100) {
198+
long progress = refreshLoad ? response.getRefreshProgress() : response.getProgress();
199+
if (progress >= 100L) {
190200
return;
191201
}
192202

193-
// Check if timeout is exceeded
194-
if (System.currentTimeMillis() - startTime > timeoutMs) {
203+
if (timeoutMs != null && timeoutMs > 0 && System.currentTimeMillis() - startTime > timeoutMs) {
195204
throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load partitions timeout");
196205
}
197-
// Wait for a certain period before checking again
198206
try {
199-
Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
207+
Thread.sleep(500);
200208
} catch (InterruptedException e) {
201209
Thread.currentThread().interrupt();
202210
logger.error("Thread was interrupted, failed to complete operation");
203-
return; // or handle interruption appropriately
211+
return;
204212
}
205213
}
206214
}

sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2977,7 +2977,6 @@ void testOperationsAcrossDB() {
29772977
Assertions.assertEquals(LoadState.LoadStateLoaded, loadStateResp.getState());
29782978
Assertions.assertEquals(LoadState.LoadStateLoaded.name(), loadStateResp.getStateName());
29792979
Assertions.assertNull(loadStateResp.getProgress());
2980-
Assertions.assertNull(loadStateResp.getRefreshProgress());
29812980

29822981
// specify the temp database name to release partition
29832982
client.releasePartitions(ReleasePartitionsReq.builder()

0 commit comments

Comments
 (0)