4141import java .util .ArrayList ;
4242import java .util .Collections ;
4343import java .util .List ;
44+ import java .util .concurrent .TimeUnit ;
4445
4546public class CollectionService extends BaseService {
4647 public IndexService indexService = new IndexService ();
@@ -437,21 +438,28 @@ public Void renameCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockin
437438 public Void loadCollection (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub , LoadCollectionReq request ) {
438439 String dbName = request .getDatabaseName ();
439440 String collectionName = request .getCollectionName ();
441+ boolean sync = Boolean .TRUE .equals (request .getSync ());
442+ boolean refresh = Boolean .TRUE .equals (request .getRefresh ());
443+ boolean skipLoadDynamicField = Boolean .TRUE .equals (request .getSkipLoadDynamicField ());
440444 String title = String .format ("Load collection: '%s' in database: '%s'" , collectionName , dbName );
441445 LoadCollectionRequest .Builder builder = LoadCollectionRequest .newBuilder ()
442446 .setCollectionName (collectionName )
443447 .setReplicaNumber (request .getNumReplicas ())
444- .setRefresh (request . getRefresh () )
448+ .setRefresh (refresh )
445449 .addAllLoadFields (request .getLoadFields ())
446- .setSkipLoadDynamicField (request . getSkipLoadDynamicField () )
450+ .setSkipLoadDynamicField (skipLoadDynamicField )
447451 .addAllResourceGroups (request .getResourceGroups ());
448452 if (StringUtils .isNotEmpty (dbName )) {
449453 builder .setDbName (dbName );
450454 }
451- Status status = blockingStub .loadCollection (builder .build ());
455+ MilvusServiceGrpc .MilvusServiceBlockingStub tempBlockingStub = blockingStub ;
456+ if (request .getTimeout () != null && request .getTimeout () > 0 ) {
457+ tempBlockingStub = tempBlockingStub .withDeadlineAfter (request .getTimeout (), TimeUnit .MILLISECONDS );
458+ }
459+ Status status = tempBlockingStub .loadCollection (builder .build ());
452460 rpcUtils .handleResponse (title , status );
453- if (request . getSync () ) {
454- WaitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout ());
461+ if (sync ) {
462+ waitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout (), refresh );
455463 }
456464
457465 return null ;
@@ -460,17 +468,22 @@ public Void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingS
460468 public Void refreshLoad (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub , RefreshLoadReq request ) {
461469 String dbName = request .getDatabaseName ();
462470 String collectionName = request .getCollectionName ();
471+ boolean sync = Boolean .TRUE .equals (request .getSync ());
463472 String title = String .format ("Refresh load collection: '%s' in database: '%s'" , collectionName , dbName );
464473 LoadCollectionRequest .Builder builder = LoadCollectionRequest .newBuilder ()
465474 .setCollectionName (collectionName )
466475 .setRefresh (true );
467476 if (StringUtils .isNotEmpty (dbName )) {
468477 builder .setDbName (dbName );
469478 }
470- Status status = blockingStub .loadCollection (builder .build ());
479+ MilvusServiceGrpc .MilvusServiceBlockingStub tempBlockingStub = blockingStub ;
480+ if (request .getTimeout () != null && request .getTimeout () > 0 ) {
481+ tempBlockingStub = tempBlockingStub .withDeadlineAfter (request .getTimeout (), TimeUnit .MILLISECONDS );
482+ }
483+ Status status = tempBlockingStub .loadCollection (builder .build ());
471484 rpcUtils .handleResponse (title , status );
472- if (request . getSync () ) {
473- WaitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout ());
485+ if (sync ) {
486+ waitForLoadCollection (blockingStub , dbName , collectionName , request .getTimeout (), true );
474487 }
475488
476489 return null ;
@@ -500,9 +513,7 @@ public GetLoadStateResp getLoadStateV2(MilvusServiceGrpc.MilvusServiceBlockingSt
500513 GetLoadStateResp .GetLoadStateRespBuilder respBuilder = GetLoadStateResp .builder ()
501514 .state (response .getState ());
502515 if (response .getState () == LoadState .LoadStateLoading ) {
503- GetLoadingProgressResponse progressResponse = getLoadingProgressResponse (blockingStub , request );
504- respBuilder .progress (progressResponse .getProgress ())
505- .refreshProgress (progressResponse .getRefreshProgress ());
516+ respBuilder .progress (getLoadingProgress (blockingStub , request , false , null ));
506517 }
507518
508519 return respBuilder .build ();
@@ -535,8 +546,17 @@ private GetLoadStateResponse getLoadStateResponse(MilvusServiceGrpc.MilvusServic
535546 return response ;
536547 }
537548
538- private GetLoadingProgressResponse getLoadingProgressResponse (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub ,
539- GetLoadStateReq request ) {
549+ private Long getLoadingProgress (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub ,
550+ GetLoadStateReq request ,
551+ boolean refreshLoad ,
552+ Long timeoutMs ) {
553+ GetLoadingProgressResponse response = getLoadingProgressInternal (blockingStub , request , timeoutMs );
554+ return refreshLoad ? response .getRefreshProgress () : response .getProgress ();
555+ }
556+
557+ private GetLoadingProgressResponse getLoadingProgressInternal (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub ,
558+ GetLoadStateReq request ,
559+ Long timeoutMs ) {
540560 String dbName = request .getDatabaseName ();
541561 String collectionName = request .getCollectionName ();
542562 String partitionName = request .getPartitionName ();
@@ -548,7 +568,11 @@ private GetLoadingProgressResponse getLoadingProgressResponse(MilvusServiceGrpc.
548568 if (StringUtils .isNotEmpty (partitionName )) {
549569 builder .addPartitionNames (partitionName );
550570 }
551- GetLoadingProgressResponse response = blockingStub .getLoadingProgress (builder .build ());
571+ MilvusServiceGrpc .MilvusServiceBlockingStub tempBlockingStub = blockingStub ;
572+ if (timeoutMs != null && timeoutMs > 0 ) {
573+ tempBlockingStub = tempBlockingStub .withDeadlineAfter (timeoutMs , TimeUnit .MILLISECONDS );
574+ }
575+ GetLoadingProgressResponse response = tempBlockingStub .getLoadingProgress (builder .build ());
552576 String title = String .format ("Get loading progress of collection: '%s' in database: '%s'" , collectionName , dbName );
553577 rpcUtils .handleResponse (title , response .getStatus ());
554578 return response ;
@@ -690,31 +714,28 @@ public Void dropCollectionFunction(MilvusServiceGrpc.MilvusServiceBlockingStub b
690714 return null ;
691715 }
692716
693- private void WaitForLoadCollection (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub , String databaseName ,
694- String collectionName , long timeoutMs ) {
695- long startTime = System .currentTimeMillis (); // Capture start time/ Timeout in milliseconds (60 seconds)
717+ private void waitForLoadCollection (MilvusServiceGrpc .MilvusServiceBlockingStub blockingStub , String databaseName ,
718+ String collectionName , Long timeoutMs , boolean refreshLoad ) {
719+ long startTime = System .currentTimeMillis ();
720+ GetLoadStateReq request = GetLoadStateReq .builder ()
721+ .databaseName (databaseName )
722+ .collectionName (collectionName )
723+ .build ();
696724
697725 while (true ) {
698- // Call the getLoadState method
699- boolean isLoaded = getLoadState (blockingStub , GetLoadStateReq .builder ()
700- .databaseName (databaseName )
701- .collectionName (collectionName )
702- .build ());
703- if (isLoaded ) {
726+ if (getLoadingProgress (blockingStub , request , refreshLoad , timeoutMs ) >= 100L ) {
704727 return ;
705728 }
706729
707- // Check if timeout is exceeded
708- if (System .currentTimeMillis () - startTime > timeoutMs ) {
730+ if (timeoutMs != null && timeoutMs > 0 && System .currentTimeMillis () - startTime > timeoutMs ) {
709731 throw new MilvusClientException (ErrorCode .SERVER_ERROR , "Load collection timeout" );
710732 }
711- // Wait for a certain period before checking again
712733 try {
713- Thread .sleep (500 ); // Sleep for 0.5 second. Adjust this value as needed.
734+ Thread .sleep (500 );
714735 } catch (InterruptedException e ) {
715736 Thread .currentThread ().interrupt ();
716737 logger .error ("Thread was interrupted, Failed to complete operation" );
717- return ; // or handle interruption appropriately
738+ return ;
718739 }
719740 }
720741 }
0 commit comments