Skip to content

Commit d172d62

Browse files
authored
Align getLoadState/loadCollection/loadPartitions with PyMilvus (#1887)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 4a44668 commit d172d62

5 files changed

Lines changed: 80 additions & 61 deletions

File tree

sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import java.util.ArrayList;
4242
import java.util.Collections;
4343
import java.util.List;
44+
import java.util.concurrent.TimeUnit;
4445

4546
public class CollectionService extends BaseService {
4647
public IndexService indexService = new IndexService();
@@ -437,21 +438,28 @@ public Void renameCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockin
437438
public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, LoadCollectionReq request) {
438439
String dbName = request.getDatabaseName();
439440
String collectionName = request.getCollectionName();
441+
boolean sync = Boolean.TRUE.equals(request.getSync());
442+
boolean refresh = Boolean.TRUE.equals(request.getRefresh());
443+
boolean skipLoadDynamicField = Boolean.TRUE.equals(request.getSkipLoadDynamicField());
440444
String title = String.format("Load collection: '%s' in database: '%s'", collectionName, dbName);
441445
LoadCollectionRequest.Builder builder = LoadCollectionRequest.newBuilder()
442446
.setCollectionName(collectionName)
443447
.setReplicaNumber(request.getNumReplicas())
444-
.setRefresh(request.getRefresh())
448+
.setRefresh(refresh)
445449
.addAllLoadFields(request.getLoadFields())
446-
.setSkipLoadDynamicField(request.getSkipLoadDynamicField())
450+
.setSkipLoadDynamicField(skipLoadDynamicField)
447451
.addAllResourceGroups(request.getResourceGroups());
448452
if (StringUtils.isNotEmpty(dbName)) {
449453
builder.setDbName(dbName);
450454
}
451-
Status status = blockingStub.loadCollection(builder.build());
455+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
456+
if (request.getTimeout() != null && request.getTimeout() > 0) {
457+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
458+
}
459+
Status status = tempBlockingStub.loadCollection(builder.build());
452460
rpcUtils.handleResponse(title, status);
453-
if (request.getSync()) {
454-
WaitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout());
461+
if (sync) {
462+
waitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout(), refresh);
455463
}
456464

457465
return null;
@@ -460,17 +468,22 @@ public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
460468
public Void refreshLoad(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, RefreshLoadReq request) {
461469
String dbName = request.getDatabaseName();
462470
String collectionName = request.getCollectionName();
471+
boolean sync = Boolean.TRUE.equals(request.getSync());
463472
String title = String.format("Refresh load collection: '%s' in database: '%s'", collectionName, dbName);
464473
LoadCollectionRequest.Builder builder = LoadCollectionRequest.newBuilder()
465474
.setCollectionName(collectionName)
466475
.setRefresh(true);
467476
if (StringUtils.isNotEmpty(dbName)) {
468477
builder.setDbName(dbName);
469478
}
470-
Status status = blockingStub.loadCollection(builder.build());
479+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
480+
if (request.getTimeout() != null && request.getTimeout() > 0) {
481+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
482+
}
483+
Status status = tempBlockingStub.loadCollection(builder.build());
471484
rpcUtils.handleResponse(title, status);
472-
if (request.getSync()) {
473-
WaitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout());
485+
if (sync) {
486+
waitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout(), true);
474487
}
475488

476489
return null;
@@ -500,9 +513,7 @@ public GetLoadStateResp getLoadStateV2(MilvusServiceGrpc.MilvusServiceBlockingSt
500513
GetLoadStateResp.GetLoadStateRespBuilder respBuilder = GetLoadStateResp.builder()
501514
.state(response.getState());
502515
if (response.getState() == LoadState.LoadStateLoading) {
503-
GetLoadingProgressResponse progressResponse = getLoadingProgressResponse(blockingStub, request);
504-
respBuilder.progress(progressResponse.getProgress())
505-
.refreshProgress(progressResponse.getRefreshProgress());
516+
respBuilder.progress(getLoadingProgress(blockingStub, request, false, null));
506517
}
507518

508519
return respBuilder.build();
@@ -535,8 +546,17 @@ private GetLoadStateResponse getLoadStateResponse(MilvusServiceGrpc.MilvusServic
535546
return response;
536547
}
537548

538-
private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
539-
GetLoadStateReq request) {
549+
private Long getLoadingProgress(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
550+
GetLoadStateReq request,
551+
boolean refreshLoad,
552+
Long timeoutMs) {
553+
GetLoadingProgressResponse response = getLoadingProgressInternal(blockingStub, request, timeoutMs);
554+
return refreshLoad ? response.getRefreshProgress() : response.getProgress();
555+
}
556+
557+
private GetLoadingProgressResponse getLoadingProgressInternal(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
558+
GetLoadStateReq request,
559+
Long timeoutMs) {
540560
String dbName = request.getDatabaseName();
541561
String collectionName = request.getCollectionName();
542562
String partitionName = request.getPartitionName();
@@ -548,7 +568,11 @@ private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.
548568
if (StringUtils.isNotEmpty(partitionName)) {
549569
builder.addPartitionNames(partitionName);
550570
}
551-
GetLoadingProgressResponse response = blockingStub.getLoadingProgress(builder.build());
571+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
572+
if (timeoutMs != null && timeoutMs > 0) {
573+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
574+
}
575+
GetLoadingProgressResponse response = tempBlockingStub.getLoadingProgress(builder.build());
552576
String title = String.format("Get loading progress of collection: '%s' in database: '%s'", collectionName, dbName);
553577
rpcUtils.handleResponse(title, response.getStatus());
554578
return response;
@@ -690,31 +714,28 @@ public Void dropCollectionFunction(MilvusServiceGrpc.MilvusServiceBlockingStub b
690714
return null;
691715
}
692716

693-
private void WaitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String databaseName,
694-
String collectionName, long timeoutMs) {
695-
long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
717+
private void waitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String databaseName,
718+
String collectionName, Long timeoutMs, boolean refreshLoad) {
719+
long startTime = System.currentTimeMillis();
720+
GetLoadStateReq request = GetLoadStateReq.builder()
721+
.databaseName(databaseName)
722+
.collectionName(collectionName)
723+
.build();
696724

697725
while (true) {
698-
// Call the getLoadState method
699-
boolean isLoaded = getLoadState(blockingStub, GetLoadStateReq.builder()
700-
.databaseName(databaseName)
701-
.collectionName(collectionName)
702-
.build());
703-
if (isLoaded) {
726+
if (getLoadingProgress(blockingStub, request, refreshLoad, timeoutMs) >= 100L) {
704727
return;
705728
}
706729

707-
// Check if timeout is exceeded
708-
if (System.currentTimeMillis() - startTime > timeoutMs) {
730+
if (timeoutMs != null && timeoutMs > 0 && System.currentTimeMillis() - startTime > timeoutMs) {
709731
throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load collection timeout");
710732
}
711-
// Wait for a certain period before checking again
712733
try {
713-
Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
734+
Thread.sleep(500);
714735
} catch (InterruptedException e) {
715736
Thread.currentThread().interrupt();
716737
logger.error("Thread was interrupted, Failed to complete operation");
717-
return; // or handle interruption appropriately
738+
return;
718739
}
719740
}
720741
}

sdk-core/src/main/java/io/milvus/v2/service/collection/response/GetLoadStateResp.java

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,10 @@
2424
public class GetLoadStateResp {
2525
private LoadState state;
2626
private Long progress;
27-
private Long refreshProgress;
2827

2928
private GetLoadStateResp(GetLoadStateRespBuilder builder) {
3029
this.state = builder.state;
3130
this.progress = builder.progress;
32-
this.refreshProgress = builder.refreshProgress;
3331
}
3432

3533
public LoadState getState() {
@@ -52,21 +50,12 @@ public void setProgress(Long progress) {
5250
this.progress = progress;
5351
}
5452

55-
public Long getRefreshProgress() {
56-
return refreshProgress;
57-
}
58-
59-
public void setRefreshProgress(Long refreshProgress) {
60-
this.refreshProgress = refreshProgress;
61-
}
62-
6353
@Override
6454
public String toString() {
6555
return "GetLoadStateResp{" +
6656
"state=" + state +
6757
", stateName='" + getStateName() + '\'' +
6858
", progress=" + progress +
69-
", refreshProgress=" + refreshProgress +
7059
'}';
7160
}
7261

@@ -77,7 +66,6 @@ public static GetLoadStateRespBuilder builder() {
7766
public static class GetLoadStateRespBuilder {
7867
private LoadState state;
7968
private Long progress;
80-
private Long refreshProgress;
8169

8270
private GetLoadStateRespBuilder() {
8371
}
@@ -92,11 +80,6 @@ public GetLoadStateRespBuilder progress(Long progress) {
9280
return this;
9381
}
9482

95-
public GetLoadStateRespBuilder refreshProgress(Long refreshProgress) {
96-
this.refreshProgress = refreshProgress;
97-
return this;
98-
}
99-
10083
public GetLoadStateResp build() {
10184
return new GetLoadStateResp(this);
10285
}

sdk-core/src/main/java/io/milvus/v2/service/partition/PartitionService.java

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.commons.lang3.StringUtils;
2929

3030
import java.util.List;
31+
import java.util.concurrent.TimeUnit;
3132

3233
public class PartitionService extends BaseService {
3334
public Void createPartition(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreatePartitionReq request) {
@@ -130,24 +131,32 @@ public Void loadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
130131
String dbName = request.getDatabaseName();
131132
String collectionName = request.getCollectionName();
132133
List<String> partitionNames = request.getPartitionNames();
134+
boolean sync = Boolean.TRUE.equals(request.getSync());
135+
boolean refresh = Boolean.TRUE.equals(request.getRefresh());
136+
boolean skipLoadDynamicField = Boolean.TRUE.equals(request.getSkipLoadDynamicField());
133137
String title = String.format("Load partitions: %s in collection: '%s' in database: '%s'",
134138
partitionNames, collectionName, dbName);
135139

136140
LoadPartitionsRequest.Builder builder = LoadPartitionsRequest.newBuilder()
137141
.setCollectionName(collectionName)
138142
.addAllPartitionNames(partitionNames)
139143
.setReplicaNumber(request.getNumReplicas())
140-
.setRefresh(request.getRefresh())
144+
.setRefresh(refresh)
141145
.addAllLoadFields(request.getLoadFields())
142-
.setSkipLoadDynamicField(request.getSkipLoadDynamicField())
146+
.setSkipLoadDynamicField(skipLoadDynamicField)
143147
.addAllResourceGroups(request.getResourceGroups());
144148
if (StringUtils.isNotEmpty(dbName)) {
145149
builder.setDbName(dbName);
146150
}
147-
Status status = blockingStub.loadPartitions(builder.build());
151+
152+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
153+
if (request.getTimeout() != null && request.getTimeout() > 0) {
154+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
155+
}
156+
Status status = tempBlockingStub.loadPartitions(builder.build());
148157
rpcUtils.handleResponse(title, status);
149-
if (request.getSync()) {
150-
WaitForLoadPartitions(blockingStub, dbName, collectionName, partitionNames, request.getTimeout());
158+
if (sync) {
159+
waitForLoadPartitions(blockingStub, dbName, collectionName, partitionNames, request.getTimeout(), refresh);
151160
}
152161

153162
return null;
@@ -172,9 +181,9 @@ public Void releasePartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blocki
172181
return null;
173182
}
174183

175-
private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName,
176-
String collectionName, List<String> partitions, long timeoutMs) {
177-
long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
184+
private void waitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName,
185+
String collectionName, List<String> partitions, Long timeoutMs, boolean refreshLoad) {
186+
long startTime = System.currentTimeMillis();
178187

179188
while (true) {
180189
GetLoadingProgressRequest.Builder builder = GetLoadingProgressRequest.newBuilder()
@@ -183,24 +192,27 @@ private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub b
183192
if (StringUtils.isNotEmpty(dbName)) {
184193
builder.setDbName(dbName);
185194
}
186-
GetLoadingProgressResponse response = blockingStub.getLoadingProgress(builder.build());
195+
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
196+
if (timeoutMs != null && timeoutMs > 0) {
197+
tempBlockingStub = tempBlockingStub.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
198+
}
199+
GetLoadingProgressResponse response = tempBlockingStub.getLoadingProgress(builder.build());
187200
String title = String.format("Get loading progress of collection: '%s' in database: '%s'", collectionName, dbName);
188201
rpcUtils.handleResponse(title, response.getStatus());
189-
if (response.getProgress() >= 100) {
202+
long progress = refreshLoad ? response.getRefreshProgress() : response.getProgress();
203+
if (progress >= 100L) {
190204
return;
191205
}
192206

193-
// Check if timeout is exceeded
194-
if (System.currentTimeMillis() - startTime > timeoutMs) {
207+
if (timeoutMs != null && timeoutMs > 0 && System.currentTimeMillis() - startTime > timeoutMs) {
195208
throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load partitions timeout");
196209
}
197-
// Wait for a certain period before checking again
198210
try {
199-
Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
211+
Thread.sleep(500);
200212
} catch (InterruptedException e) {
201213
Thread.currentThread().interrupt();
202214
logger.error("Thread was interrupted, failed to complete operation");
203-
return; // or handle interruption appropriately
215+
return;
204216
}
205217
}
206218
}

sdk-core/src/test/java/io/milvus/v2/BaseTest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,11 @@
3131
import org.mockito.quality.Strictness;
3232

3333
import java.util.Collections;
34+
import java.util.concurrent.TimeUnit;
3435

3536
import static org.mockito.ArgumentMatchers.any;
37+
import static org.mockito.ArgumentMatchers.anyLong;
38+
import static org.mockito.ArgumentMatchers.eq;
3639
import static org.mockito.Mockito.when;
3740

3841
@ExtendWith(MockitoExtension.class)
@@ -46,6 +49,7 @@ public class BaseTest {
4649
@BeforeEach
4750
public void setUp() {
4851
client_v2.setBlockingStub(blockingStub);
52+
when(blockingStub.withDeadlineAfter(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(blockingStub);
4953

5054
Status successStatus = Status.newBuilder().setCode(0).build();
5155
BoolResponse trueResponse = BoolResponse.newBuilder().setStatus(successStatus).setValue(Boolean.TRUE).build();

sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2980,7 +2980,6 @@ void testOperationsAcrossDB() {
29802980
Assertions.assertEquals(LoadState.LoadStateLoaded, loadStateResp.getState());
29812981
Assertions.assertEquals(LoadState.LoadStateLoaded.name(), loadStateResp.getStateName());
29822982
Assertions.assertNull(loadStateResp.getProgress());
2983-
Assertions.assertNull(loadStateResp.getRefreshProgress());
29842983

29852984
// specify the temp database name to release partition
29862985
client.releasePartitions(ReleasePartitionsReq.builder()

0 commit comments

Comments
 (0)