4242import java .util .ArrayList ;
4343import java .util .Collections ;
4444import java .util .List ;
45+ import java .util .concurrent .TimeUnit ;
4546
4647public class CollectionService extends BaseService {
4748 public IndexService indexService = new IndexService ();
@@ -458,21 +459,28 @@ public Void renameCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockin
458459 public Void loadCollection (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub , LoadCollectionReq request ) {
459460 String dbName = request .getDatabaseName ();
460461 String collectionName = request .getCollectionName ();
462+ boolean sync = Boolean .TRUE .equals (request .getSync ());
463+ boolean refresh = Boolean .TRUE .equals (request .getRefresh ());
464+ boolean skipLoadDynamicField = Boolean .TRUE .equals (request .getSkipLoadDynamicField ());
461465 String title = String .format ("Load collection: '%s' in database: '%s'" , collectionName , dbName );
462466 LoadCollectionRequest .Builder builder = LoadCollectionRequest .newBuilder ()
463467 .setCollectionName (collectionName )
464468 .setReplicaNumber (request .getNumReplicas ())
465- .setRefresh (request . getRefresh () )
469+ .setRefresh (refresh )
466470 .addAllLoadFields (request .getLoadFields ())
467- .setSkipLoadDynamicField (request . getSkipLoadDynamicField () )
471+ .setSkipLoadDynamicField (skipLoadDynamicField )
468472 .addAllResourceGroups (request .getResourceGroups ());
469473 if (StringUtils .isNotEmpty (dbName )) {
470474 builder .setDbName (dbName );
471475 }
472- Status status = blockingStub .loadCollection (builder .build ());
476+ MilvusServiceGrpc .MilvusServiceBlockingStub tempBlockingStub = blockingStub ;
477+ if (request .getTimeout () != null && request .getTimeout () > 0 ) {
478+ tempBlockingStub = tempBlockingStub .withDeadlineAfter (request .getTimeout (), TimeUnit .MILLISECONDS );
479+ }
480+ Status status = tempBlockingStub .loadCollection (builder .build ());
473481 rpcUtils .handleResponse (title , status );
474- if (request . getSync () ) {
475- WaitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout ());
482+ if (sync ) {
483+ waitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout (), refresh );
476484 }
477485
478486 return null ;
@@ -481,17 +489,22 @@ public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
481489 public Void refreshLoad (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub , RefreshLoadReq request ) {
482490 String dbName = request .getDatabaseName ();
483491 String collectionName = request .getCollectionName ();
492+ boolean sync = Boolean .TRUE .equals (request .getSync ());
484493 String title = String .format ("Refresh load collection: '%s' in database: '%s'" , collectionName , dbName );
485494 LoadCollectionRequest .Builder builder = LoadCollectionRequest .newBuilder ()
486495 .setCollectionName (collectionName )
487496 .setRefresh (true );
488497 if (StringUtils .isNotEmpty (dbName )) {
489498 builder .setDbName (dbName );
490499 }
491- Status status = blockingStub .loadCollection (builder .build ());
500+ MilvusServiceGrpc .MilvusServiceBlockingStub tempBlockingStub = blockingStub ;
501+ if (request .getTimeout () != null && request .getTimeout () > 0 ) {
502+ tempBlockingStub = tempBlockingStub .withDeadlineAfter (request .getTimeout (), TimeUnit .MILLISECONDS );
503+ }
504+ Status status = tempBlockingStub .loadCollection (builder .build ());
492505 rpcUtils .handleResponse (title , status );
493- if (request . getSync () ) {
494- WaitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout ());
506+ if (sync ) {
507+ waitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout (), true );
495508 }
496509
497510 return null ;
@@ -521,9 +534,7 @@ public GetLoadStateResp getLoadStateV2(MilvusServiceGrpc.MilvusServiceBlockingSt
521534 GetLoadStateResp .GetLoadStateRespBuilder respBuilder = GetLoadStateResp .builder ()
522535 .state (response .getState ());
523536 if (response .getState () == LoadState .LoadStateLoading ) {
524- GetLoadingProgressResponse progressResponse = getLoadingProgressResponse (blockingStub , request );
525- respBuilder .progress (progressResponse .getProgress ())
526- .refreshProgress (progressResponse .getRefreshProgress ());
537+ respBuilder .progress (getLoadingProgress (blockingStub , request , false , null ));
527538 }
528539
529540 return respBuilder .build ();
@@ -556,8 +567,17 @@ private GetLoadStateResponse getLoadStateResponse(MilvusServiceGrpc.MilvusServic
556567 return response ;
557568 }
558569
559- private GetLoadingProgressResponse getLoadingProgressResponse (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub ,
560- GetLoadStateReq request ) {
570+ private Long getLoadingProgress (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub ,
571+ GetLoadStateReq request ,
572+ boolean refreshLoad ,
573+ Long timeoutMs ) {
574+ GetLoadingProgressResponse response = getLoadingProgressInternal (blockingStub , request , timeoutMs );
575+ return refreshLoad ? response .getRefreshProgress () : response .getProgress ();
576+ }
577+
578+ private GetLoadingProgressResponse getLoadingProgressInternal (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub ,
579+ GetLoadStateReq request ,
580+ Long timeoutMs ) {
561581 String dbName = request .getDatabaseName ();
562582 String collectionName = request .getCollectionName ();
563583 String partitionName = request .getPartitionName ();
@@ -569,7 +589,11 @@ private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.
569589 if (StringUtils .isNotEmpty (partitionName )) {
570590 builder .addPartitionNames (partitionName );
571591 }
572- GetLoadingProgressResponse response = blockingStub .getLoadingProgress (builder .build ());
592+ MilvusServiceGrpc .MilvusServiceBlockingStub tempBlockingStub = blockingStub ;
593+ if (timeoutMs != null && timeoutMs > 0 ) {
594+ tempBlockingStub = tempBlockingStub .withDeadlineAfter (timeoutMs , TimeUnit .MILLISECONDS );
595+ }
596+ GetLoadingProgressResponse response = tempBlockingStub .getLoadingProgress (builder .build ());
573597 String title = String .format ("Get loading progress of collection: '%s' in database: '%s'" , collectionName , dbName );
574598 rpcUtils .handleResponse (title , response .getStatus ());
575599 return response ;
@@ -711,31 +735,28 @@ public Void dropCollectionFunction(MilvusServiceGrpc.MilvusServiceBlockingStub b
711735 return null ;
712736 }
713737
714- private void WaitForLoadCollection (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub , String databaseName ,
715- String collectionName , long timeoutMs ) {
716- long startTime = System .currentTimeMillis (); // Capture start time/ Timeout in milliseconds (60 seconds)
738+ private void waitForLoadCollection (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub , String databaseName ,
739+ String collectionName , Long timeoutMs , boolean refreshLoad ) {
740+ long startTime = System .currentTimeMillis ();
741+ GetLoadStateReq request = GetLoadStateReq .builder ()
742+ .databaseName (databaseName )
743+ .collectionName (collectionName )
744+ .build ();
717745
718746 while (true ) {
719- // Call the getLoadState method
720- boolean isLoaded = getLoadState (blockingStub , GetLoadStateReq .builder ()
721- .databaseName (databaseName )
722- .collectionName (collectionName )
723- .build ());
724- if (isLoaded ) {
747+ if (getLoadingProgress (blockingStub , request , refreshLoad , timeoutMs ) >= 100L ) {
725748 return ;
726749 }
727750
728- // Check if timeout is exceeded
729- if (System .currentTimeMillis () - startTime > timeoutMs ) {
751+ if (timeoutMs != null && timeoutMs > 0 && System .currentTimeMillis () - startTime > timeoutMs ) {
730752 throw new MilvusClientException (ErrorCode .SERVER_ERROR , "Load collection timeout" );
731753 }
732- // Wait for a certain period before checking again
733754 try {
734- Thread .sleep (500 ); // Sleep for 0.5 second. Adjust this value as needed.
755+ Thread .sleep (500 );
735756 } catch (InterruptedException e ) {
736757 Thread .currentThread ().interrupt ();
737758 logger .error ("Thread was interrupted, Failed to complete operation" );
738- return ; // or handle interruption appropriately
759+ return ;
739760 }
740761 }
741762 }
0 commit comments