4242import java .util .ArrayList ;
4343import java .util .Collections ;
4444import java .util .List ;
45+ import java .util .concurrent .TimeUnit ;
4546
4647public class CollectionService extends BaseService {
4748 public IndexService indexService = new IndexService ();
@@ -469,10 +470,14 @@ public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
469470 if (StringUtils .isNotEmpty (dbName )) {
470471 builder .setDbName (dbName );
471472 }
472- Status status = blockingStub .loadCollection (builder .build ());
473+ MilvusServiceGrpc .MilvusServiceBlockingStub tempBlockingStub = blockingStub ;
474+ if (request .getTimeout () != null && request .getTimeout () > 0 ) {
475+ tempBlockingStub = tempBlockingStub .withDeadlineAfter (request .getTimeout (), TimeUnit .MILLISECONDS );
476+ }
477+ Status status = tempBlockingStub .loadCollection (builder .build ());
473478 rpcUtils .handleResponse (title , status );
474479 if (request .getSync ()) {
475- WaitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout ());
480+ waitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout (), request . getRefresh ());
476481 }
477482
478483 return null ;
@@ -488,10 +493,14 @@ public Void refreshLoad(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub
488493 if (StringUtils .isNotEmpty (dbName )) {
489494 builder .setDbName (dbName );
490495 }
491- Status status = blockingStub .loadCollection (builder .build ());
496+ MilvusServiceGrpc .MilvusServiceBlockingStub tempBlockingStub = blockingStub ;
497+ if (request .getTimeout () != null && request .getTimeout () > 0 ) {
498+ tempBlockingStub = tempBlockingStub .withDeadlineAfter (request .getTimeout (), TimeUnit .MILLISECONDS );
499+ }
500+ Status status = tempBlockingStub .loadCollection (builder .build ());
492501 rpcUtils .handleResponse (title , status );
493502 if (request .getSync ()) {
494- WaitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout ());
503+ waitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout (), true );
495504 }
496505
497506 return null ;
@@ -521,9 +530,7 @@ public GetLoadStateResp getLoadStateV2(MilvusServiceGrpc.MilvusServiceBlockingSt
521530 GetLoadStateResp .GetLoadStateRespBuilder respBuilder = GetLoadStateResp .builder ()
522531 .state (response .getState ());
523532 if (response .getState () == LoadState .LoadStateLoading ) {
524- GetLoadingProgressResponse progressResponse = getLoadingProgressResponse (blockingStub , request );
525- respBuilder .progress (progressResponse .getProgress ())
526- .refreshProgress (progressResponse .getRefreshProgress ());
533+ respBuilder .progress (getLoadingProgress (blockingStub , request , false , null ));
527534 }
528535
529536 return respBuilder .build ();
@@ -556,8 +563,17 @@ private GetLoadStateResponse getLoadStateResponse(MilvusServiceGrpc.MilvusServic
556563 return response ;
557564 }
558565
559- private GetLoadingProgressResponse getLoadingProgressResponse (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub ,
560- GetLoadStateReq request ) {
566+ private Long getLoadingProgress (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub ,
567+ GetLoadStateReq request ,
568+ boolean refreshLoad ,
569+ Long timeoutMs ) {
570+ GetLoadingProgressResponse response = getLoadingProgressInternal (blockingStub , request , timeoutMs );
571+ return refreshLoad ? response .getRefreshProgress () : response .getProgress ();
572+ }
573+
574+ private GetLoadingProgressResponse getLoadingProgressInternal (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub ,
575+ GetLoadStateReq request ,
576+ Long timeoutMs ) {
561577 String dbName = request .getDatabaseName ();
562578 String collectionName = request .getCollectionName ();
563579 String partitionName = request .getPartitionName ();
@@ -569,7 +585,11 @@ private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.
569585 if (StringUtils .isNotEmpty (partitionName )) {
570586 builder .addPartitionNames (partitionName );
571587 }
572- GetLoadingProgressResponse response = blockingStub .getLoadingProgress (builder .build ());
588+ MilvusServiceGrpc .MilvusServiceBlockingStub tempBlockingStub = blockingStub ;
589+ if (timeoutMs != null && timeoutMs > 0 ) {
590+ tempBlockingStub = tempBlockingStub .withDeadlineAfter (timeoutMs , TimeUnit .MILLISECONDS );
591+ }
592+ GetLoadingProgressResponse response = tempBlockingStub .getLoadingProgress (builder .build ());
573593 String title = String .format ("Get loading progress of collection: '%s' in database: '%s'" , collectionName , dbName );
574594 rpcUtils .handleResponse (title , response .getStatus ());
575595 return response ;
@@ -711,31 +731,28 @@ public Void dropCollectionFunction(MilvusServiceGrpc.MilvusServiceBlockingStub b
711731 return null ;
712732 }
713733
714- private void WaitForLoadCollection (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub , String databaseName ,
715- String collectionName , long timeoutMs ) {
716- long startTime = System .currentTimeMillis (); // Capture start time/ Timeout in milliseconds (60 seconds)
734+ private void waitForLoadCollection (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub , String databaseName ,
735+ String collectionName , Long timeoutMs , boolean refreshLoad ) {
736+ long startTime = System .currentTimeMillis ();
737+ GetLoadStateReq request = GetLoadStateReq .builder ()
738+ .databaseName (databaseName )
739+ .collectionName (collectionName )
740+ .build ();
717741
718742 while (true ) {
719- // Call the getLoadState method
720- boolean isLoaded = getLoadState (blockingStub , GetLoadStateReq .builder ()
721- .databaseName (databaseName )
722- .collectionName (collectionName )
723- .build ());
724- if (isLoaded ) {
743+ if (getLoadingProgress (blockingStub , request , refreshLoad , timeoutMs ) >= 100L ) {
725744 return ;
726745 }
727746
728- // Check if timeout is exceeded
729- if (System .currentTimeMillis () - startTime > timeoutMs ) {
747+ if (timeoutMs != null && timeoutMs > 0 && System .currentTimeMillis () - startTime > timeoutMs ) {
730748 throw new MilvusClientException (ErrorCode .SERVER_ERROR , "Load collection timeout" );
731749 }
732- // Wait for a certain period before checking again
733750 try {
734- Thread .sleep (500 ); // Sleep for 0.5 second. Adjust this value as needed.
751+ Thread .sleep (500 );
735752 } catch (InterruptedException e ) {
736753 Thread .currentThread ().interrupt ();
737754 logger .error ("Thread was interrupted, Failed to complete operation" );
738- return ; // or handle interruption appropriately
755+ return ;
739756 }
740757 }
741758 }
0 commit comments