Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;

public class CollectionService extends BaseService {
public IndexService indexService = new IndexService();
Expand Down Expand Up @@ -458,21 +459,28 @@ public Void renameCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockin
public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, LoadCollectionReq request) {
String dbName = request.getDatabaseName();
String collectionName = request.getCollectionName();
boolean sync = Boolean.TRUE.equals(request.getSync());
boolean refresh = Boolean.TRUE.equals(request.getRefresh());
boolean skipLoadDynamicField = Boolean.TRUE.equals(request.getSkipLoadDynamicField());
String title = String.format("Load collection: '%s' in database: '%s'", collectionName, dbName);
LoadCollectionRequest.Builder builder = LoadCollectionRequest.newBuilder()
.setCollectionName(collectionName)
.setReplicaNumber(request.getNumReplicas())
.setRefresh(request.getRefresh())
.setRefresh(refresh)
.addAllLoadFields(request.getLoadFields())
.setSkipLoadDynamicField(request.getSkipLoadDynamicField())
.setSkipLoadDynamicField(skipLoadDynamicField)
.addAllResourceGroups(request.getResourceGroups());
if (StringUtils.isNotEmpty(dbName)) {
builder.setDbName(dbName);
}
Status status = blockingStub.loadCollection(builder.build());
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
if (request.getTimeout() != null && request.getTimeout() > 0) {
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
}
Status status = tempBlockingStub.loadCollection(builder.build());
rpcUtils.handleResponse(title, status);
if (request.getSync()) {
WaitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout());
if (sync) {
waitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout(), refresh);
}

return null;
Expand All @@ -481,17 +489,22 @@ public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
public Void refreshLoad(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, RefreshLoadReq request) {
String dbName = request.getDatabaseName();
String collectionName = request.getCollectionName();
boolean sync = Boolean.TRUE.equals(request.getSync());
String title = String.format("Refresh load collection: '%s' in database: '%s'", collectionName, dbName);
LoadCollectionRequest.Builder builder = LoadCollectionRequest.newBuilder()
.setCollectionName(collectionName)
.setRefresh(true);
if (StringUtils.isNotEmpty(dbName)) {
builder.setDbName(dbName);
}
Status status = blockingStub.loadCollection(builder.build());
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
if (request.getTimeout() != null && request.getTimeout() > 0) {
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
}
Status status = tempBlockingStub.loadCollection(builder.build());
rpcUtils.handleResponse(title, status);
if (request.getSync()) {
WaitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout());
if (sync) {
waitForLoadCollection(blockingStub, dbName, collectionName, request.getTimeout(), true);
}

return null;
Expand Down Expand Up @@ -521,9 +534,7 @@ public GetLoadStateResp getLoadStateV2(MilvusServiceGrpc.MilvusServiceBlockingSt
GetLoadStateResp.GetLoadStateRespBuilder respBuilder = GetLoadStateResp.builder()
.state(response.getState());
if (response.getState() == LoadState.LoadStateLoading) {
GetLoadingProgressResponse progressResponse = getLoadingProgressResponse(blockingStub, request);
respBuilder.progress(progressResponse.getProgress())
.refreshProgress(progressResponse.getRefreshProgress());
respBuilder.progress(getLoadingProgress(blockingStub, request, false, null));
}

return respBuilder.build();
Expand Down Expand Up @@ -556,8 +567,17 @@ private GetLoadStateResponse getLoadStateResponse(MilvusServiceGrpc.MilvusServic
return response;
}

private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
GetLoadStateReq request) {
private Long getLoadingProgress(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
GetLoadStateReq request,
boolean refreshLoad,
Long timeoutMs) {
GetLoadingProgressResponse response = getLoadingProgressInternal(blockingStub, request, timeoutMs);
return refreshLoad ? response.getRefreshProgress() : response.getProgress();
}

private GetLoadingProgressResponse getLoadingProgressInternal(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
GetLoadStateReq request,
Long timeoutMs) {
String dbName = request.getDatabaseName();
String collectionName = request.getCollectionName();
String partitionName = request.getPartitionName();
Expand All @@ -569,7 +589,11 @@ private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.
if (StringUtils.isNotEmpty(partitionName)) {
builder.addPartitionNames(partitionName);
}
GetLoadingProgressResponse response = blockingStub.getLoadingProgress(builder.build());
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
if (timeoutMs != null && timeoutMs > 0) {
tempBlockingStub = tempBlockingStub.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
}
GetLoadingProgressResponse response = tempBlockingStub.getLoadingProgress(builder.build());
String title = String.format("Get loading progress of collection: '%s' in database: '%s'", collectionName, dbName);
rpcUtils.handleResponse(title, response.getStatus());
return response;
Expand Down Expand Up @@ -711,31 +735,28 @@ public Void dropCollectionFunction(MilvusServiceGrpc.MilvusServiceBlockingStub b
return null;
}

private void WaitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String databaseName,
String collectionName, long timeoutMs) {
long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
private void waitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String databaseName,
String collectionName, Long timeoutMs, boolean refreshLoad) {
long startTime = System.currentTimeMillis();
GetLoadStateReq request = GetLoadStateReq.builder()
.databaseName(databaseName)
.collectionName(collectionName)
.build();

while (true) {
// Call the getLoadState method
boolean isLoaded = getLoadState(blockingStub, GetLoadStateReq.builder()
.databaseName(databaseName)
.collectionName(collectionName)
.build());
if (isLoaded) {
if (getLoadingProgress(blockingStub, request, refreshLoad, timeoutMs) >= 100L) {
return;
}

// Check if timeout is exceeded
if (System.currentTimeMillis() - startTime > timeoutMs) {
if (timeoutMs != null && timeoutMs > 0 && System.currentTimeMillis() - startTime > timeoutMs) {
throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load collection timeout");
}
// Wait for a certain period before checking again
try {
Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
Thread.sleep(500);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.error("Thread was interrupted, Failed to complete operation");
return; // or handle interruption appropriately
return;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@
public class GetLoadStateResp {
private LoadState state;
private Long progress;
private Long refreshProgress;

private GetLoadStateResp(GetLoadStateRespBuilder builder) {
this.state = builder.state;
this.progress = builder.progress;
this.refreshProgress = builder.refreshProgress;
}

public LoadState getState() {
Expand All @@ -52,21 +50,12 @@ public void setProgress(Long progress) {
this.progress = progress;
}

public Long getRefreshProgress() {
return refreshProgress;
}

public void setRefreshProgress(Long refreshProgress) {
this.refreshProgress = refreshProgress;
}

@Override
public String toString() {
return "GetLoadStateResp{" +
"state=" + state +
", stateName='" + getStateName() + '\'' +
", progress=" + progress +
", refreshProgress=" + refreshProgress +
'}';
}

Expand All @@ -77,7 +66,6 @@ public static GetLoadStateRespBuilder builder() {
public static class GetLoadStateRespBuilder {
private LoadState state;
private Long progress;
private Long refreshProgress;

private GetLoadStateRespBuilder() {
}
Expand All @@ -92,11 +80,6 @@ public GetLoadStateRespBuilder progress(Long progress) {
return this;
}

public GetLoadStateRespBuilder refreshProgress(Long refreshProgress) {
this.refreshProgress = refreshProgress;
return this;
}

public GetLoadStateResp build() {
return new GetLoadStateResp(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.commons.lang3.StringUtils;

import java.util.List;
import java.util.concurrent.TimeUnit;

public class PartitionService extends BaseService {
public Void createPartition(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreatePartitionReq request) {
Expand Down Expand Up @@ -130,24 +131,32 @@ public Void loadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
String dbName = request.getDatabaseName();
String collectionName = request.getCollectionName();
List<String> partitionNames = request.getPartitionNames();
boolean sync = Boolean.TRUE.equals(request.getSync());
boolean refresh = Boolean.TRUE.equals(request.getRefresh());
boolean skipLoadDynamicField = Boolean.TRUE.equals(request.getSkipLoadDynamicField());
String title = String.format("Load partitions: %s in collection: '%s' in database: '%s'",
partitionNames, collectionName, dbName);

LoadPartitionsRequest.Builder builder = LoadPartitionsRequest.newBuilder()
.setCollectionName(collectionName)
.addAllPartitionNames(partitionNames)
.setReplicaNumber(request.getNumReplicas())
.setRefresh(request.getRefresh())
.setRefresh(refresh)
.addAllLoadFields(request.getLoadFields())
.setSkipLoadDynamicField(request.getSkipLoadDynamicField())
.setSkipLoadDynamicField(skipLoadDynamicField)
.addAllResourceGroups(request.getResourceGroups());
if (StringUtils.isNotEmpty(dbName)) {
builder.setDbName(dbName);
}
Status status = blockingStub.loadPartitions(builder.build());

MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
if (request.getTimeout() != null && request.getTimeout() > 0) {
tempBlockingStub = tempBlockingStub.withDeadlineAfter(request.getTimeout(), TimeUnit.MILLISECONDS);
}
Status status = tempBlockingStub.loadPartitions(builder.build());
rpcUtils.handleResponse(title, status);
if (request.getSync()) {
WaitForLoadPartitions(blockingStub, dbName, collectionName, partitionNames, request.getTimeout());
if (sync) {
waitForLoadPartitions(blockingStub, dbName, collectionName, partitionNames, request.getTimeout(), refresh);
}

return null;
Expand All @@ -172,9 +181,9 @@ public Void releasePartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blocki
return null;
}

private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName,
String collectionName, List<String> partitions, long timeoutMs) {
long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
private void waitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName,
String collectionName, List<String> partitions, Long timeoutMs, boolean refreshLoad) {
long startTime = System.currentTimeMillis();

while (true) {
GetLoadingProgressRequest.Builder builder = GetLoadingProgressRequest.newBuilder()
Expand All @@ -183,24 +192,27 @@ private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub b
if (StringUtils.isNotEmpty(dbName)) {
builder.setDbName(dbName);
}
GetLoadingProgressResponse response = blockingStub.getLoadingProgress(builder.build());
MilvusServiceGrpc.MilvusServiceBlockingStub tempBlockingStub = blockingStub;
if (timeoutMs != null && timeoutMs > 0) {
tempBlockingStub = tempBlockingStub.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
}
GetLoadingProgressResponse response = tempBlockingStub.getLoadingProgress(builder.build());
String title = String.format("Get loading progress of collection: '%s' in database: '%s'", collectionName, dbName);
rpcUtils.handleResponse(title, response.getStatus());
if (response.getProgress() >= 100) {
long progress = refreshLoad ? response.getRefreshProgress() : response.getProgress();
if (progress >= 100L) {
return;
}

// Check if timeout is exceeded
if (System.currentTimeMillis() - startTime > timeoutMs) {
if (timeoutMs != null && timeoutMs > 0 && System.currentTimeMillis() - startTime > timeoutMs) {
throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load partitions timeout");
}
// Wait for a certain period before checking again
try {
Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
Thread.sleep(500);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.error("Thread was interrupted, failed to complete operation");
return; // or handle interruption appropriately
return;
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions sdk-core/src/test/java/io/milvus/v2/BaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@
import org.mockito.quality.Strictness;

import java.util.Collections;
import java.util.concurrent.TimeUnit;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
Expand All @@ -46,6 +49,7 @@ public class BaseTest {
@BeforeEach
public void setUp() {
client_v2.setBlockingStub(blockingStub);
when(blockingStub.withDeadlineAfter(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(blockingStub);

Status successStatus = Status.newBuilder().setCode(0).build();
BoolResponse trueResponse = BoolResponse.newBuilder().setStatus(successStatus).setValue(Boolean.TRUE).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2977,7 +2977,6 @@ void testOperationsAcrossDB() {
Assertions.assertEquals(LoadState.LoadStateLoaded, loadStateResp.getState());
Assertions.assertEquals(LoadState.LoadStateLoaded.name(), loadStateResp.getStateName());
Assertions.assertNull(loadStateResp.getProgress());
Assertions.assertNull(loadStateResp.getRefreshProgress());

// specify the temp database name to release partition
client.releasePartitions(ReleasePartitionsReq.builder()
Expand Down
Loading