From 8bee1a094de8c424fcf144841b2b4dbe36c6ef81 Mon Sep 17 00:00:00 2001 From: yhmo Date: Wed, 3 Jun 2026 15:01:06 +0800 Subject: [PATCH] Align getLoadState/loadCollection/loadPartitions with PyMilvus Signed-off-by: yhmo --- .../service/collection/CollectionService.java | 77 ++++++++++++------- .../collection/response/GetLoadStateResp.java | 17 ---- .../service/partition/PartitionService.java | 42 ++++++---- .../src/test/java/io/milvus/v2/BaseTest.java | 4 + .../v2/client/MilvusClientV2DockerTest.java | 1 - 5 files changed, 80 insertions(+), 61 deletions(-) diff --git a/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java b/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java index 2acc74714..521b174ea 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java @@ -42,6 +42,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.TimeUnit; public class CollectionService extends BaseService { public IndexService indexService = new IndexService(); @@ -458,21 +459,28 @@ public Void renameCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockin public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, LoadCollectionReq request) { String dbName = request.getDatabaseName(); String collectionName = request.getCollectionName(); + boolean sync = Boolean.TRUE.equals(request.getSync()); + boolean refresh = Boolean.TRUE.equals(request.getRefresh()); + boolean skipLoadDynamicField = Boolean.TRUE.equals(request.getSkipLoadDynamicField()); String title = String.format("Load collection: '%s' in database: '%s'", collectionName, dbName); LoadCollectionRequest.Builder builder = LoadCollectionRequest.newBuilder() .setCollectionName(collectionName) .setReplicaNumber(request.getNumReplicas()) - .setRefresh(request.getRefresh()) + .setRefresh(refresh) .addAllLoadFields(request.getLoadFields()) - .setSkipLoadDynamicField(request.getSkipLoadDynamicField()) + .setSkipLoadDynamicField(skipLoadDynamicField) .addAllResourceGroups(request.getResourceGroups()); if (StringUtils.isNotEmpty(dbName)) { builder.setDbName(dbName); } - Status status = blockingStub.loadCollection(builder.build()); + MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub; + if (request.getTimeout() != null && request.getTimeout() > 0) { + tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS); + } + Status status = tempBlockingStub.loadCollection(builder.build()); rpcUtils.handleResponse(title, status); - if (request.getSync()) { - WaitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout()); + if (sync) { + waitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout(), refresh); } return null; @@ -481,6 +489,7 @@ public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS public Void refreshLoad(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, RefreshLoadReq request) { String dbName = request.getDatabaseName(); String collectionName = request.getCollectionName(); + boolean sync = Boolean.TRUE.equals(request.getSync()); String title = String.format("Refresh load collection: '%s' in database: '%s'", collectionName, dbName); LoadCollectionRequest.Builder builder = LoadCollectionRequest.newBuilder() .setCollectionName(collectionName) @@ -488,10 +497,14 @@ public Void refreshLoad(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub if (StringUtils.isNotEmpty(dbName)) { builder.setDbName(dbName); } - Status status = blockingStub.loadCollection(builder.build()); + MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub; + if (request.getTimeout() != null && request.getTimeout() > 0) { + tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS); + } + Status status = tempBlockingStub.loadCollection(builder.build()); rpcUtils.handleResponse(title, status); - if (request.getSync()) { - WaitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout()); + if (sync) { + waitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout(), true); } return null; @@ -521,9 +534,7 @@ public GetLoadStateResp getLoadStateV2(MilvusServiceGrpc.MilvusServiceBlockingSt GetLoadStateResp.GetLoadStateRespBuilder respBuilder = GetLoadStateResp.builder() .state(response.getState()); if (response.getState() == LoadState.LoadStateLoading) { - GetLoadingProgressResponse progressResponse = getLoadingProgressResponse(blockingStub, request); - respBuilder.progress(progressResponse.getProgress()) - .refreshProgress(progressResponse.getRefreshProgress()); + respBuilder.progress(getLoadingProgress(blockingStub, request, false, null)); } return respBuilder.build(); @@ -556,8 +567,17 @@ private GetLoadStateResponse getLoadStateResponse(MilvusServiceGrpc.MilvusServic return response; } - private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, - GetLoadStateReq request) { + private Long getLoadingProgress(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, + GetLoadStateReq request, + boolean refreshLoad, + Long timeoutMs) { + GetLoadingProgressResponse response = getLoadingProgressInternal(blockingStub, request, timeoutMs); + return refreshLoad ? response.getRefreshProgress() : response.getProgress(); + } + + private GetLoadingProgressResponse getLoadingProgressInternal(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, + GetLoadStateReq request, + Long timeoutMs) { String dbName = request.getDatabaseName(); String collectionName = request.getCollectionName(); String partitionName = request.getPartitionName(); @@ -569,7 +589,11 @@ private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc. if (StringUtils.isNotEmpty(partitionName)) { builder.addPartitionNames(partitionName); } - GetLoadingProgressResponse response = blockingStub.getLoadingProgress(builder.build()); + MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub; + if (timeoutMs != null && timeoutMs > 0) { + tempBlockingStub = tempBlockingStub.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS); + } + GetLoadingProgressResponse response = tempBlockingStub.getLoadingProgress(builder.build()); String title = String.format("Get loading progress of collection: '%s' in database: '%s'", collectionName, dbName); rpcUtils.handleResponse(title, response.getStatus()); return response; @@ -711,31 +735,28 @@ public Void dropCollectionFunction(MilvusServiceGrpc.MilvusServiceBlockingStub b return null; } - private void WaitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String databaseName, - String collectionName, long timeoutMs) { - long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds) + private void waitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String databaseName, + String collectionName, Long timeoutMs, boolean refreshLoad) { + long startTime = System.currentTimeMillis(); + GetLoadStateReq request = GetLoadStateReq.builder() + .databaseName(databaseName) + .collectionName(collectionName) + .build(); while (true) { - // Call the getLoadState method - boolean isLoaded = getLoadState(blockingStub, GetLoadStateReq.builder() - .databaseName(databaseName) - .collectionName(collectionName) - .build()); - if (isLoaded) { + if (getLoadingProgress(blockingStub, request, refreshLoad, timeoutMs) >= 100L) { return; } - // Check if timeout is exceeded - if (System.currentTimeMillis() - startTime > timeoutMs) { + if (timeoutMs != null && timeoutMs > 0 && System.currentTimeMillis() - startTime > timeoutMs) { throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load collection timeout"); } - // Wait for a certain period before checking again try { - Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed. + Thread.sleep(500); } catch (InterruptedException e) { Thread.currentThread().interrupt(); logger.error("Thread was interrupted, Failed to complete operation"); - return; // or handle interruption appropriately + return; } } } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/collection/response/GetLoadStateResp.java b/sdk-core/src/main/java/io/milvus/v2/service/collection/response/GetLoadStateResp.java index 99ae98055..fb10ed4ab 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/collection/response/GetLoadStateResp.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/collection/response/GetLoadStateResp.java @@ -24,12 +24,10 @@ public class GetLoadStateResp { private LoadState state; private Long progress; - private Long refreshProgress; private GetLoadStateResp(GetLoadStateRespBuilder builder) { this.state = builder.state; this.progress = builder.progress; - this.refreshProgress = builder.refreshProgress; } public LoadState getState() { @@ -52,21 +50,12 @@ public void setProgress(Long progress) { this.progress = progress; } - public Long getRefreshProgress() { - return refreshProgress; - } - - public void setRefreshProgress(Long refreshProgress) { - this.refreshProgress = refreshProgress; - } - @Override public String toString() { return "GetLoadStateResp{" + "state=" + state + ", stateName='" + getStateName() + '\'' + ", progress=" + progress + - ", refreshProgress=" + refreshProgress + '}'; } @@ -77,7 +66,6 @@ public static GetLoadStateRespBuilder builder() { public static class GetLoadStateRespBuilder { private LoadState state; private Long progress; - private Long refreshProgress; private GetLoadStateRespBuilder() { } @@ -92,11 +80,6 @@ public GetLoadStateRespBuilder progress(Long progress) { return this; } - public GetLoadStateRespBuilder refreshProgress(Long refreshProgress) { - this.refreshProgress = refreshProgress; - return this; - } - public GetLoadStateResp build() { return new GetLoadStateResp(this); } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/partition/PartitionService.java b/sdk-core/src/main/java/io/milvus/v2/service/partition/PartitionService.java index a9f005e49..0ba362138 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/partition/PartitionService.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/partition/PartitionService.java @@ -28,6 +28,7 @@ import org.apache.commons.lang3.StringUtils; import java.util.List; +import java.util.concurrent.TimeUnit; public class PartitionService extends BaseService { public Void createPartition(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreatePartitionReq request) { @@ -130,6 +131,9 @@ public Void loadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS String dbName = request.getDatabaseName(); String collectionName = request.getCollectionName(); List partitionNames = request.getPartitionNames(); + boolean sync = Boolean.TRUE.equals(request.getSync()); + boolean refresh = Boolean.TRUE.equals(request.getRefresh()); + boolean skipLoadDynamicField = Boolean.TRUE.equals(request.getSkipLoadDynamicField()); String title = String.format("Load partitions: %s in collection: '%s' in database: '%s'", partitionNames, collectionName, dbName); @@ -137,17 +141,22 @@ public Void loadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS .setCollectionName(collectionName) .addAllPartitionNames(partitionNames) .setReplicaNumber(request.getNumReplicas()) - .setRefresh(request.getRefresh()) + .setRefresh(refresh) .addAllLoadFields(request.getLoadFields()) - .setSkipLoadDynamicField(request.getSkipLoadDynamicField()) + .setSkipLoadDynamicField(skipLoadDynamicField) .addAllResourceGroups(request.getResourceGroups()); if (StringUtils.isNotEmpty(dbName)) { builder.setDbName(dbName); } - Status status = blockingStub.loadPartitions(builder.build()); + + MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub; + if (request.getTimeout() != null && request.getTimeout() > 0) { + tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS); + } + Status status = tempBlockingStub.loadPartitions(builder.build()); rpcUtils.handleResponse(title, status); - if (request.getSync()) { - WaitForLoadPartitions(blockingStub, dbName, collectionName, partitionNames, request.getTimeout()); + if (sync) { + waitForLoadPartitions(blockingStub, dbName, collectionName, partitionNames, request.getTimeout(), refresh); } return null; @@ -172,9 +181,9 @@ public Void releasePartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blocki return null; } - private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName, - String collectionName, List partitions, long timeoutMs) { - long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds) + private void waitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName, + String collectionName, List partitions, Long timeoutMs, boolean refreshLoad) { + long startTime = System.currentTimeMillis(); while (true) { GetLoadingProgressRequest.Builder builder = GetLoadingProgressRequest.newBuilder() @@ -183,24 +192,27 @@ private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub b if (StringUtils.isNotEmpty(dbName)) { builder.setDbName(dbName); } - GetLoadingProgressResponse response = blockingStub.getLoadingProgress(builder.build()); + MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub; + if (timeoutMs != null && timeoutMs > 0) { + tempBlockingStub = tempBlockingStub.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS); + } + GetLoadingProgressResponse response = tempBlockingStub.getLoadingProgress(builder.build()); String title = String.format("Get loading progress of collection: '%s' in database: '%s'", collectionName, dbName); rpcUtils.handleResponse(title, response.getStatus()); - if (response.getProgress() >= 100) { + long progress = refreshLoad ? response.getRefreshProgress() : response.getProgress(); + if (progress >= 100L) { return; } - // Check if timeout is exceeded - if (System.currentTimeMillis() - startTime > timeoutMs) { + if (timeoutMs != null && timeoutMs > 0 && System.currentTimeMillis() - startTime > timeoutMs) { throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load partitions timeout"); } - // Wait for a certain period before checking again try { - Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed. + Thread.sleep(500); } catch (InterruptedException e) { Thread.currentThread().interrupt(); logger.error("Thread was interrupted, failed to complete operation"); - return; // or handle interruption appropriately + return; } } } diff --git a/sdk-core/src/test/java/io/milvus/v2/BaseTest.java b/sdk-core/src/test/java/io/milvus/v2/BaseTest.java index 5db352d55..c584b285a 100644 --- a/sdk-core/src/test/java/io/milvus/v2/BaseTest.java +++ b/sdk-core/src/test/java/io/milvus/v2/BaseTest.java @@ -31,8 +31,11 @@ import org.mockito.quality.Strictness; import java.util.Collections; +import java.util.concurrent.TimeUnit; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -46,6 +49,7 @@ public class BaseTest { @BeforeEach public void setUp() { client_v2.setBlockingStub(blockingStub); + when(blockingStub.withDeadlineAfter(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(blockingStub); Status successStatus = Status.newBuilder().setCode(0).build(); BoolResponse trueResponse = BoolResponse.newBuilder().setStatus(successStatus).setValue(Boolean.TRUE).build(); diff --git a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java index 41ade2d19..9fa99e70a 100644 --- a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java +++ b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java @@ -2977,7 +2977,6 @@ void testOperationsAcrossDB() { Assertions.assertEquals(LoadState.LoadStateLoaded, loadStateResp.getState()); Assertions.assertEquals(LoadState.LoadStateLoaded.name(), loadStateResp.getStateName()); Assertions.assertNull(loadStateResp.getProgress()); - Assertions.assertNull(loadStateResp.getRefreshProgress()); // specify the temp database name to release partition client.releasePartitions(ReleasePartitionsReq.builder()